import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Encoder(nn.Module):
    """Frozen ResNet-50 backbone that returns the final conv feature map."""

    def __init__(self):
        super(Encoder, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)
        # Drop the final avgpool and fc layers, keeping everything up to the last conv block
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

    def forward(self, imgs):
        features = self.resnet(imgs)             # batch x 2048 x 7 x 7
        features = features.permute(0, 2, 3, 1)  # batch x 7 x 7 x 2048
        features = features.view(features.size(0), -1, features.size(-1))  # batch x 49 x 2048
        return features


class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over the 49 encoder feature vectors."""

    def __init__(self, encoder_dims, decoder_dims, attention_dims):
        super(Attention, self).__init__()
        self.attention_dims = attention_dims              # size of the attention network
        self.U = nn.Linear(encoder_dims, attention_dims)  # projects encoder features a(t)
        self.W = nn.Linear(decoder_dims, attention_dims)  # projects decoder state s(t-1)
        self.A = nn.Linear(attention_dims, 1)             # collapses the attention dims to one score

    def forward(self, features, hidden):
        u_as = self.U(features)                        # batch x 49 x attention_dims
        w_as = self.W(hidden)                          # batch x attention_dims
        combined_state = torch.tanh(u_as + w_as.unsqueeze(1))
        attention_score = self.A(combined_state)       # batch x 49 x 1
        attention_score = attention_score.squeeze(2)   # batch x 49
        alpha = F.softmax(attention_score, dim=1)      # weights over the 49 locations
        context = features * alpha.unsqueeze(2)        # batch x 49 x encoder_dims
        context = context.sum(dim=1)                   # batch x encoder_dims
        return alpha, context


class Decoder(nn.Module):
    """LSTM decoder that attends over the encoder features at every time step."""

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim,
                 decoder_dim, fc_dims, p=0.3, embeddings=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim=embed_size)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm = nn.LSTMCell(encoder_dim + embed_size, decoder_dim, bias=True)
        self.fcn1 = nn.Linear(decoder_dim, vocab_size)
        self.fcn2 = nn.Linear(fc_dims, vocab_size)  # defined but not used in forward
        self.drop = nn.Dropout(p)
        if embeddings is not None:
            self.load_pretrained_embed(embeddings)

    def forward(self, features, captions):
        seq_length = len(captions[0]) - 1  # exclude the last token; it is never fed as input
        batch_size = captions.size(0)
        num_timesteps = features.size(1)

        embed = self.embedding(captions)
        h, c = self.init_hidden_state(features)  # initialize h and c for the LSTM

        preds = torch.zeros(batch_size, seq_length, self.vocab_size).to(device)
        alphas = torch.zeros(batch_size, seq_length, num_timesteps).to(device)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_inp = torch.cat((embed[:, s], context), dim=1)
            h, c = self.lstm(lstm_inp, (h, c))
            out = self.fcn1(self.drop(h))  # dropout on the hidden state, then project to vocab logits
            preds[:, s] = out
            alphas[:, s] = alpha

        return preds, alphas

    def gen_captions(self, features, max_len=20, vocab=None):
        """Greedy decoding for a single image (batch size 1)."""
        h, c = self.init_hidden_state(features)
        alphas = []
        captions = []

        # The start/end token strings were lost in the original source;
        # "<SOS>" and "<EOS>" are assumed names, adjust to the vocabulary in use.
        word = torch.tensor(vocab.stoi["<SOS>"]).view(1, -1).to(device)
        embed = self.embedding(word)

        for i in range(max_len):
            alpha, context = self.attention(features, h)
            alphas.append(alpha.cpu().detach().numpy())
            lstm_inp = torch.cat((embed[:, 0], context), dim=1)
            h, c = self.lstm(lstm_inp, (h, c))
            out = self.fcn1(self.drop(h))
            word_out_idx = torch.argmax(out, dim=1)
            captions.append(word_out_idx.item())
            if vocab.itos[word_out_idx.item()] == "<EOS>":
                break
            embed = self.embedding(word_out_idx.unsqueeze(0))

        return [vocab.itos[word] for word in captions], alphas
    def load_pretrained_embed(self, embeddings):
        # Initialize the embedding layer with pretrained vectors and keep it trainable
        self.embedding.weight = nn.Parameter(embeddings)
        for p in self.embedding.parameters():
            p.requires_grad = True

    def init_hidden_state(self, encoder_output):
        # Initialize the LSTM state from the mean of the encoder feature vectors
        mean_encoder_out = encoder_output.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c


class EncoderDecoder(nn.Module):
    """Full captioning model: frozen CNN encoder plus attention LSTM decoder."""

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim,
                 decoder_dim, fc_dims, p=0.3, embeddings=None):
        super().__init__()
        self.EncoderCNN = Encoder()
        self.DecoderLSTM = Decoder(embed_size, vocab_size, attention_dim,
                                   encoder_dim, decoder_dim, fc_dims, p, embeddings)

    def forward(self, imgs, caps):
        features = self.EncoderCNN(imgs)
        out = self.DecoderLSTM(features, caps)
        return out
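

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring the model together and running a forward pass.
# The hyperparameter values below are assumptions chosen for demonstration only;
# a real run also needs a `vocab` object with `stoi`/`itos` mappings for
# `gen_captions`, which is not shown here.
if __name__ == "__main__":
    vocab_size = 5000  # assumed vocabulary size
    model = EncoderDecoder(
        embed_size=300,
        vocab_size=vocab_size,
        attention_dim=256,
        encoder_dim=2048,  # channel count of the ResNet-50 conv feature map
        decoder_dim=512,
        fc_dims=512,
        p=0.3,
    ).to(device)

    imgs = torch.randn(4, 3, 224, 224).to(device)            # dummy image batch
    caps = torch.randint(0, vocab_size, (4, 20)).to(device)   # dummy caption indices

    preds, alphas = model(imgs, caps)
    print(preds.shape)   # (4, 19, vocab_size): one distribution per caption position
    print(alphas.shape)  # (4, 19, 49): attention weights over the 7x7 feature grid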