| import argparse |
| import logging |
| import os |
| import glob |
| import tqdm |
| import torch |
| import PIL |
| import cv2 |
| import numpy as np |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from utils import Config, Logger, CharsetMapper |
|
|
def get_model(config):
    """Instantiate the model named by ``config.model_name`` and put it in eval mode.

    ``config.model_name`` is read as a dotted path: everything before the last
    dot is the module to import, and the last component is the attribute
    (class) to fetch from that module. The class is called with ``config`` as
    its only argument, the resulting model is logged, and the value returned by
    its ``eval()`` method is handed back to the caller.
    """
    from importlib import import_module

    # Split "pkg.module.ClassName" into ("pkg.module", ".", "ClassName").
    module_path, _, cls_name = config.model_name.rpartition('.')
    model_cls = getattr(import_module(module_path), cls_name)
    net = model_cls(config)
    logging.info(net)
    return net.eval()
|
|
def preprocess(img, width, height):
    """Resize an image and turn it into a normalized, batched tensor.

    The image is converted to a NumPy array, resized to ``(width, height)``
    with ``cv2.resize``, converted with torchvision's ``ToTensor`` and given a
    leading batch dimension. Each of the three channels is then shifted by a
    fixed mean and divided by a fixed standard deviation.
    """
    resized = cv2.resize(np.array(img), (width, height))
    batch = transforms.ToTensor()(resized).unsqueeze(0)
    # Per-channel statistics, reshaped to (3, 1, 1) so they broadcast over the
    # spatial dimensions. NOTE(review): presumably the ImageNet RGB mean/std.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch - channel_mean) / channel_std
|
|
def postprocess(output, charset, model_eval):
    """Greedy-decode a model output into text, per-step scores and lengths.

    Args:
        output: Either a single result mapping containing a ``'logits'`` entry,
            or a tuple/list of such mappings, each also carrying a ``'name'``
            entry used for selection.
        charset: Object providing ``get_text(labels, padding=..., trim=...)``,
            ``null_char`` and ``max_length``.
        model_eval: Name of the result to pick when ``output`` is a tuple/list.

    Returns:
        A ``(pt_text, pt_scores, pt_lengths)`` tuple of lists with one entry
        per item in the batch: the decoded string (cut at the first null
        character), the maximum softmax probability at every step, and
        ``min(len(text) + 1, charset.max_length)``.

    Raises:
        ValueError: If ``output`` is a tuple/list and none of its entries has
            ``'name'`` equal to ``model_eval``.
    """
    def _get_output(last_output, model_eval):
        """Select the result whose ``'name'`` matches ``model_eval``."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        matches = [res for res in last_output if res['name'] == model_eval]
        if not matches:
            # Previously this fell through to an UnboundLocalError.
            available = [res['name'] for res in last_output]
            raise ValueError(
                f'No model output named {model_eval!r}; available: {available}')
        # When several entries match, the last one wins (as before).
        return matches[-1]

    def _decode(logit):
        """Greedy decode: take the most probable class at every step."""
        probs = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for sample in probs:
            text = charset.get_text(sample.argmax(dim=1), padding=False, trim=False)
            # Keep only what precedes the first null character.
            text = text.split(charset.null_char)[0]
            pt_text.append(text)
            pt_scores.append(sample.max(dim=1)[0])
            # +1 counts the terminating null character, capped at max_length.
            pt_lengths.append(min(len(text) + 1, charset.max_length))
        return pt_text, pt_scores, pt_lengths

    selected = _get_output(output, model_eval)
    # Only the logits are needed; lengths are recomputed from the decoded text.
    return _decode(selected['logits'])
|
|
def load(model, file, device=None, strict=True):
    """Load a checkpoint file into ``model`` and return the model.

    Args:
        model: Object exposing ``load_state_dict(state, strict=...)``.
        file: Path to the checkpoint file.
        device: ``None`` maps tensors to ``'cpu'``; an ``int`` is treated as a
            CUDA device index; anything else is passed to ``torch.load`` as
            ``map_location`` unchanged.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The same ``model`` object, after its state has been loaded.

    Raises:
        FileNotFoundError: If ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)

    # Explicit check instead of ``assert`` so it survives ``python -O`` and
    # tells the user which path was missing.
    if not os.path.isfile(file):
        raise FileNotFoundError(f'Checkpoint file not found: {file}')

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    state = torch.load(file, map_location=device)
    # Checkpoints holding exactly {'model', 'opt'} wrap the weights under 'model'.
    if set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
|
|
def main():
    """Run text recognition on one or more images and log the predictions.

    Steps: parse command-line arguments, load the configuration and override
    its checkpoint / evaluation branch from the arguments, initialise logging,
    build the model and load its checkpoint, then for every input image
    preprocess it, run the model, decode the output and log
    ``<path>: <predicted text>``.

    ``--input`` may be a directory (every regular file directly inside it is
    processed) or a glob pattern.

    Raises:
        FileNotFoundError: If no input path could be resolved.
    """
    # Imported explicitly: a bare ``import PIL`` does not load the ``PIL.Image``
    # submodule on its own, so ``PIL.Image`` only works if some other import
    # happened to load it first.
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train_iternet.yaml',
                        help='path to config file')
    parser.add_argument('--input', type=str, default='figures/demo')
    parser.add_argument('--cuda', type=int, default=-1)
    parser.add_argument('--checkpoint', type=str,
                        default='workdir/train-iternet/best-train-iternet.pth')
    parser.add_argument('--model_eval', type=str, default='alignment',
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    # Sub-model checkpoints are cleared; the full checkpoint is loaded below.
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    # A negative --cuda value selects the CPU.
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    if os.path.isdir(args.input):
        # Skip sub-directories: they cannot be opened as images.
        candidates = (os.path.join(args.input, fname)
                      for fname in os.listdir(args.input))
        paths = [p for p in candidates if os.path.isfile(p)]
    else:
        paths = glob.glob(os.path.expanduser(args.input))
    if not paths:
        raise FileNotFoundError("The input path(s) was not found")
    paths = sorted(paths)

    # Inference only: disable autograd so no graph is built per image.
    with torch.no_grad():
        for path in tqdm.tqdm(paths):
            # The context manager closes the underlying file deterministically.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
            img = img.to(device)
            res = model(img)
            pt_text, _, _ = postprocess(res, charset, config.model_eval)
            logging.info(f'{path}: {pt_text[0]}')
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|