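"""Caption an image with a pretrained attention-based encoder-decoder model and
visualise the attention map for every generated word. Also includes a helper for
computing corpus BLEU-1..4 scores over a dataloader.

Example invocation (script and file names are placeholders):

    python caption.py --image sample.jpg --state_checkpoint model_state.pth
"""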
import argparse

import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm
from PIL import Image
from skimage import transform

from dataset import Vocabulary
from model import *
from utils import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Will only work for batch size 1
def get_all_captions(img, model, vocab=None):
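    """Generate a caption for a single image and return it as a list of word tokens.

    Expects a batch of size 1; only the first image in `img` is used.
    """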
    features = model.EncoderCNN(img[0:1].to(device))
    caps, alphas = model.DecoderLSTM.gen_captions(features, vocab=vocab)
    caps = caps[:-2]  # drop the trailing tokens (end-of-sequence marker) before scoring
    return caps


def calculate_bleu_score(dataloader, model, vocab):
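    """Report corpus BLEU-1 to BLEU-4 for captions generated over `dataloader`.

    Each batch is expected to be (image, caption, all_reference_captions) with
    batch size 1, matching get_all_captions above.
    """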
    candidate_corpus = []
    references_corpus = []
    for batch in tqdm(dataloader, total=len(dataloader)):
        img, cap, all_caps = batch
        img, cap = img.to(device), cap.to(device)
        caps = get_all_captions(img, model, vocab)
        candidate_corpus.append(caps)
        references_corpus.append(all_caps[0])
    assert len(candidate_corpus) == len(references_corpus)
    print(f"\nBLEU1 = {corpus_bleu(references_corpus, candidate_corpus, (1, 0, 0, 0))}")
    print(f"BLEU2 = {corpus_bleu(references_corpus, candidate_corpus, (0.5, 0.5, 0, 0))}")
    print(f"BLEU3 = {corpus_bleu(references_corpus, candidate_corpus, (0.33, 0.33, 0.33, 0))}")
    print(f"BLEU4 = {corpus_bleu(references_corpus, candidate_corpus, (0.25, 0.25, 0.25, 0.25))}")


def get_caps_from(features_tensors, model, vocab=None):
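    """Generate a caption for the given image tensor, display the image with the
    caption via show_img, and return the predicted tokens with their attention weights."""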
    model.eval()
    with torch.no_grad():
        features = model.EncoderCNN(features_tensors[0:1].to(device))
        caps, alphas = model.DecoderLSTM.gen_captions(features, vocab=vocab)
        caption = ' '.join(caps)
        show_img(features_tensors[0], caption)
    return caps, alphas


def plot_attention(img, target, attention_plot):
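    """Overlay each word's 7x7 attention map (upsampled) on the input image, one subplot per word."""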
    img = img.to('cpu').numpy().transpose((1, 2, 0))
    temp_image = img
    fig = plt.figure(figsize=(15, 15))
    len_caps = len(target)
    # A roughly square grid; floor(sqrt) + 1 always gives enough cells, whereas
    # len_caps // 2 per side can overflow the subplot index for short captions.
    grid_size = int(len_caps ** 0.5) + 1
    for i in range(len_caps):
        temp_att = attention_plot[i].reshape(7, 7)
        temp_att = transform.pyramid_expand(temp_att, upscale=24, sigma=8)
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.set_title(target[i])
        im = ax.imshow(temp_image)
        ax.imshow(temp_att, cmap='gray', alpha=0.5, extent=im.get_extent())
    plt.tight_layout()
    plt.show()


def plot_caption_with_attention(img_pth, model, transforms_=None, vocab=None):
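    """Load an image from disk, run the captioning model on it, and plot per-word attention maps."""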
    img = Image.open(img_pth).convert('RGB')  # force 3 channels so the 3-channel Normalize always applies
    img = transforms_(img)
    img.unsqueeze_(0)
    caps, attention = get_caps_from(img, model, vocab)
    plot_attention(img[0], caps, attention)


def main(arguments):
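    """Rebuild the EncoderDecoder from the checkpoint and CLI arguments, then caption the given image."""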
    state_checkpoint = torch.load(arguments.state_checkpoint, map_location=device)

    # model hyperparameters; these must match the values used when the checkpoint was trained
    vocab = state_checkpoint['vocab']
    embed_size = arguments.embed_size
    embed_wts = None
    vocab_size = state_checkpoint['vocab_size']
    attention_dim = arguments.attention_dim
    encoder_dim = arguments.encoder_dim
    decoder_dim = arguments.decoder_dim
    fc_dims = arguments.fc_dims

    model = EncoderDecoder(embed_size,
                           vocab_size,
                           attention_dim,
                           encoder_dim,
                           decoder_dim,
                           fc_dims,
                           p=0.3,
                           embeddings=embed_wts).to(device)
    model.load_state_dict(state_checkpoint['state_dict'])

    transforms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    img_path = arguments.image
    plot_caption_with_attention(img_path, model, transforms, vocab)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
    parser.add_argument('--state_checkpoint', type=str, required=True, help='path for state checkpoint')
    parser.add_argument('--embed_size', type=int, default=300, help='dimension of word embedding vectors')
    parser.add_argument('--attention_dim', type=int, default=256, help='dimension of attention layer')
    parser.add_argument('--encoder_dim', type=int, default=2048, help='dimension of encoder layer')
    parser.add_argument('--decoder_dim', type=int, default=512, help='dimension of decoder layer')
    parser.add_argument('--fc_dims', type=int, default=256, help='dimension of fully connected layer')
    args = parser.parse_args()
    main(args)