| import torch |
| from typing import Any, Optional |
| from transformers import LayoutLMv2ForQuestionAnswering |
| from transformers import LayoutLMv2Processor |
| from transformers import LayoutLMv2FeatureExtractor |
| from transformers import LayoutLMv2ImageProcessor |
| from transformers import LayoutLMv2TokenizerFast |
| from transformers.tokenization_utils_base import BatchEncoding |
| from transformers.tokenization_utils_base import TruncationStrategy |
| from transformers.utils import TensorType |
| |
| |
| |
| import numpy as np |
| |
| |
| import pdf2image |
| |
| import logging |
| from os import environ |
| |
|
|
| |
| |
| |
|
|
# Module-level OCR feature extractor shared by pdf_to_image below.
# NOTE(review): LayoutLMv2ImageProcessor is imported above but unused —
# it is the upstream replacement for LayoutLMv2FeatureExtractor; consider migrating.
feature_extractor = LayoutLMv2FeatureExtractor()
|
|
| |
| |
| |
|
|
class NoOCRReaderFound(Exception):
    """Raised when an OCR backend could not be loaded."""

    def __init__(self, e):
        # Forward the cause to Exception so args / repr / pickling
        # follow the standard BaseException protocol.
        super().__init__(e)
        self.e = e

    def __str__(self):
        return f"Could not load OCR Reader: {self.e}"
|
|
def pdf_to_image(b: bytes) -> dict:
    """Render a PDF given as raw bytes and extract LayoutLMv2 features.

    Each PDF page is rasterized and converted to RGB, then passed through
    the module-level ``feature_extractor``.

    Args:
        b: raw bytes of a PDF document.

    Returns:
        dict with keys ``image`` (pixel values), ``words`` (OCR words) and
        ``boxes`` (normalized bounding boxes), one entry per page.
    """
    images = [page.convert("RGB") for page in pdf2image.convert_from_bytes(b)]
    encoded_inputs = feature_extractor(images)
    # Debug output goes through the module logger instead of print().
    logger.debug('feature_extractor keys: %s', list(encoded_inputs.keys()))
    return {
        'image': encoded_inputs.pixel_values,
        'words': encoded_inputs.words,
        'boxes': encoded_inputs.boxes,
    }
|
|
|
|
| def setup_logger(which_logger: Optional[str] = None): |
| lib_level = logging.DEBUG |
| root_level = logging.INFO |
| log_format = '%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s' |
| logging.basicConfig( |
| filename=environ.get('LOG_FILE_PATH_LAYOUTLM_V2'), |
| format=log_format, |
| datefmt='%d-%b-%y %H:%M:%S', |
| level=root_level, |
| force=True |
| ) |
| log = logging.getLogger(which_logger) |
| log.setLevel(lib_level) |
| return log |
|
|
# Module-wide logger; root logging is configured as a side effect of import.
logger = setup_logger(__name__)
|
|
|
|
class Funcs:
    """Stateless helpers for LayoutLMv2 question-answering pre/post-processing."""

    @staticmethod
    def unnormalize_box(bbox, width, height):
        """Scale a 0-1000 normalized LayoutLM bbox back to pixel coordinates.

        Args:
            bbox: [x0, y0, x1, y1] with each coordinate in [0, 1000].
            width: page width in pixels.
            height: page height in pixels.
        """
        return [
            width * (bbox[0] / 1000),
            height * (bbox[1] / 1000),
            width * (bbox[2] / 1000),
            height * (bbox[3] / 1000),
        ]

    @staticmethod
    def num_spans(encoding: "BatchEncoding") -> int:
        """Number of spans (rows of input_ids) the tokenizer produced."""
        return len(encoding["input_ids"])

    @staticmethod
    def p_mask(num_spans: int, encoding: "BatchEncoding") -> list:
        """Per-span mask: True for tokens outside the document context.

        A position is masked (True) when its sequence id is not 1, i.e. it
        belongs to the question or to special tokens and cannot contain the
        answer.  (The original wrapped this in a try/except that only
        re-raised; that no-op handler has been removed.)
        """
        return [
            [tok != 1 for tok in encoding.sequence_ids(span_id)]
            for span_id in range(num_spans)
        ]

    @staticmethod
    def token_start_end(encoding, tokenizer):
        """Return the first and last token indices of the document context.

        Scans the encoding's sequence ids for the span of tokens whose
        sequence id is 1 (the document, as opposed to the question).
        """
        sequence_ids = encoding.sequence_ids()

        token_start_index = 0
        while sequence_ids[token_start_index] != 1:
            token_start_index += 1

        token_end_index = len(encoding.input_ids) - 1
        while sequence_ids[token_end_index] != 1:
            token_end_index -= 1

        # Debug output goes through the module logger instead of print().
        logger.debug("Token start index: %s", token_start_index)
        logger.debug("Token end index: %s", token_end_index)
        logger.debug(
            'token_start_end: %s',
            tokenizer.decode(encoding.input_ids[token_start_index:token_end_index + 1]),
        )
        return token_start_index, token_end_index

    @staticmethod
    def reconstruct_answer(word_idx_start, word_idx_end, encoding, tokenizer):
        """Map word-level answer indices to token positions and decode them.

        Bug fix: the original referenced ``token_start_index`` /
        ``token_end_index`` that were never defined in this scope (a
        guaranteed NameError); they are now derived from
        :meth:`token_start_end`.  ``start_position`` / ``end_position`` could
        also be unbound when no token matched; that case now raises a clear
        ValueError.

        Returns:
            (start_position, end_position) token indices, inclusive.

        Raises:
            ValueError: if either word index is absent from the context span.
        """
        token_start_index, token_end_index = Funcs.token_start_end(encoding, tokenizer)
        word_ids = encoding.word_ids()[token_start_index:token_end_index + 1]
        logger.debug("Word ids: %s", word_ids)

        start_position = None
        end_position = None
        # Walk forward to the first token of the start word...
        for word_id in word_ids:
            if word_id == word_idx_start:
                start_position = token_start_index
                break
            token_start_index += 1
        # ...and backward to the last token of the end word.
        for word_id in word_ids[::-1]:
            if word_id == word_idx_end:
                end_position = token_end_index
                break
            token_end_index -= 1

        if start_position is None or end_position is None:
            raise ValueError("answer word indices not found in encoding context")

        logger.debug(
            "Reconstructed answer: %s",
            tokenizer.decode(encoding.input_ids[start_position:end_position + 1]),
        )
        return start_position, end_position

    @staticmethod
    def sigmoid(_outputs):
        """Element-wise logistic sigmoid."""
        return 1.0 / (1.0 + np.exp(-_outputs))

    @staticmethod
    def softmax(_outputs):
        """Numerically stable softmax over the last axis."""
        maxes = np.max(_outputs, axis=-1, keepdims=True)
        shifted_exp = np.exp(_outputs - maxes)
        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
|
|
class EndpointHandler:
    """Inference endpoint for LayoutLMv2 document question answering.

    Loads model, tokenizer and processor from *path* and answers a
    (currently hard-coded) question against every page of a PDF supplied
    as raw bytes.  Predicted answers are logged, not returned.
    """

    def __init__(self, path="./"):
        # Model and fast tokenizer come from the same checkpoint directory;
        # the processor reuses that tokenizer so both stay in sync.
        self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
        self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
        self.processor = LayoutLMv2Processor.from_pretrained(
            path,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, bytes]):
        """Run QA over each page of the PDF in ``data["inputs"]``.

        Args:
            data: request payload; ``data["inputs"]`` holds the raw PDF
                bytes.  (If the key is missing, ``data`` itself is used —
                kept from the original for backward compatibility.)

        Returns:
            dict: ``{'data': 'success'}``; per-page answers are written to
            the module logger only.
        """
        pdf_bytes = data.pop("inputs", data)
        pages = [page.convert("RGB") for page in pdf2image.convert_from_bytes(pdf_bytes)]

        # NOTE(review): the question is hard-coded; presumably it should
        # come from the request payload — confirm against the caller.
        question = "what is the bill date"
        with torch.no_grad():
            for page in pages:
                encoding = self.processor(
                    page,
                    question,
                    truncation=True,
                    return_tensors=TensorType.PYTORCH,
                )
                logger.debug('encoding keys: %s', list(encoding.keys()))

                outputs = self.model(**encoding)

                # Highest-scoring start/end token positions.
                predicted_start_idx = outputs.start_logits.argmax(-1).item()
                predicted_end_idx = outputs.end_logits.argmax(-1).item()

                predicted_answer_tokens = encoding.input_ids.squeeze()[
                    predicted_start_idx : predicted_end_idx + 1
                ]
                predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)

                # The original ran a second forward pass with hard-coded
                # target indices (7 / 14) whose loss was never used; that
                # dead computation has been removed.

                logger.info(f'''
                START
                predicted_start_idx: {predicted_start_idx}
                predicted_end_idx: {predicted_end_idx}
                ---
                answer: {predicted_answer}

                END''')
        return {'data': 'success'}
|
|