import spacy
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from dataset import *
from model import *
from utils import *
spacy_eng = spacy.load('en_core_web_sm')  # the bare 'en' shortcut was removed in spaCy 3.x
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# init seed: draw a random seed once, then fix it for this run
seed = torch.randint(100, (1,)).item()
torch.manual_seed(seed)
shuffle = True
# src folders
root_folder = "/content/flickr8k/Images"     # change this
csv_file = "/content/flickr8k/captions.txt"  # change this
# image transforms and augmentation
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    # ImageNet channel means/stds, matching the pretrained encoder
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
# define dataset
dataset = FlickrDataset(root_folder, csv_file, transforms)
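# FlickrDataset (imported from dataset) is assumed to read captions.txt, build a
# Vocabulary with stoi/itos maps (including the special <SOS>, <EOS>, <PAD>
# tokens used below), and return (transformed_image, numericalized_caption) pairs.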
# split dataset
val_size = 512
test_size = 256
train_size = len(dataset) - val_size - test_size
train_ds, val_ds, test_ds = random_split(dataset,
                                         [train_size, val_size, test_size])
# define data loader parameters
num_workers = 4
pin_memory = True
batch_size_train = 256
batch_size_val_test = 128
pad_idx = dataset.vocab.stoi["<PAD>"]
# define loaders (val/test keep shuffle=True only so test_on_img previews a
# different sample each epoch; it does not affect the averaged losses)
dataloader_train = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              pin_memory=pin_memory,
                              num_workers=num_workers,
                              shuffle=shuffle,
                              collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True))
dataloader_validation = DataLoader(val_ds,
                                   batch_size=batch_size_val_test,
                                   pin_memory=pin_memory,
                                   num_workers=num_workers,
                                   shuffle=shuffle,
                                   collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True))
dataloader_test = DataLoader(test_ds,
                             batch_size=batch_size_val_test,
                             pin_memory=pin_memory,
                             num_workers=num_workers,
                             shuffle=shuffle,
                             collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True))
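# For reference, CapsCollate (imported from dataset) is assumed to stack the
# images and pad each batch of variable-length captions to the longest caption
# in the batch. A hypothetical sketch, not the actual implementation:
#
#   from torch.nn.utils.rnn import pad_sequence
#
#   class CapsCollate:
#       def __init__(self, pad_idx, batch_first=False):
#           self.pad_idx = pad_idx
#           self.batch_first = batch_first
#
#       def __call__(self, batch):
#           imgs = torch.stack([item[0] for item in batch])
#           caps = pad_sequence([item[1] for item in batch],
#                               batch_first=self.batch_first,
#                               padding_value=self.pad_idx)
#           return imgs, caps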
# model parameters
embed_wts, embed_size = load_embeding("/content/glove.42B.300d.txt", dataset.vocab)  # change path
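# load_embeding (from utils) is assumed to parse the GloVe text file and return
# a (len(vocab), 300) weight tensor aligned with vocab.stoi, plus the embedding
# size. A hypothetical sketch, not the actual implementation:
#
#   def load_embeding(path, vocab):
#       embed_size = 300  # glove.42B.300d vectors are 300-dimensional
#       weights = torch.randn(len(vocab), embed_size)  # random init for words not in GloVe
#       with open(path, encoding='utf-8') as f:
#           for line in f:
#               word, *vec = line.rstrip().split(' ')
#               if word in vocab.stoi:
#                   weights[vocab.stoi[word]] = torch.tensor([float(v) for v in vec])
#       return weights, embed_size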
vocab_size = len(dataset.vocab)
attention_dim = 256
encoder_dim = 2048
decoder_dim = 512
fc_dims = 256
learning_rate = 5e-4
model = EncoderDecoder(embed_size,
                       vocab_size,
                       attention_dim,
                       encoder_dim,
                       decoder_dim,
                       fc_dims,
                       p=0.3,
                       embeddings=embed_wts).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)  # padded positions don't contribute to the loss
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# training parameters
num_epochs = 35
train_loss_arr = []
val_loss_arr = []
def training(dataset, dataloader, loss_criteria, optimize, grad_clip=5.):
    total_loss = 0
    for i, (img, cap) in enumerate(tqdm(dataloader, total=len(dataloader))):
        img, cap = img.to(device), cap.to(device)
        optimize.zero_grad()
        output, attention = model(img, cap)
        # the decoder predicts the next word at each step, so targets drop <SOS>
        targets = cap[:, 1:]
        loss = loss_criteria(output.view(-1, vocab_size), targets.reshape(-1))
        total_loss += loss.item()
        loss.backward()
        if grad_clip:
            # clip gradient norm to stabilise LSTM training
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimize.step()
    total_loss = total_loss / len(dataloader)
    return total_loss
def validate(dataset, dataloader, loss_cr):
    total_loss = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for val_img, val_cap in tqdm(dataloader, total=len(dataloader)):
            val_img, val_cap = val_img.to(device), val_cap.to(device)
            output, attention = model(val_img, val_cap)
            targets = val_cap[:, 1:]
            loss = loss_cr(output.view(-1, vocab_size), targets.reshape(-1))
            total_loss += loss.item()
    total_loss /= len(dataloader)
    return total_loss
# show a sample prediction while training
def test_on_img(data, dataloader):
    dataiter = iter(dataloader)
    img, cap = next(dataiter)
    with torch.no_grad():  # inference only
        features = model.EncoderCNN(img[0:1].to(device))
        caps, alphas = model.DecoderLSTM.gen_captions(features, vocab=data.vocab)
    caption = ' '.join(caps)
    show_img(img[0], caption)
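# Note: gen_captions (defined in model) is assumed to decode greedily from the
# image features until it emits <EOS> or hits a maximum length, returning the
# predicted word list and the per-step attention weights (alphas).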
def main():
    best_val_loss = float('inf')  # so the first epoch always saves a checkpoint
    for epoch in range(num_epochs):
        print(f"Epoch: {epoch + 1}/{num_epochs}")
        model.train()
        train_loss = training(dataset, dataloader_train, loss_fn, optimizer)
        train_loss_arr.append(train_loss)
        model.eval()
        val_loss = validate(dataset, dataloader_validation, loss_fn)
        val_loss_arr.append(val_loss)
        print(f"train_loss: {train_loss:.4f} validation_loss: {val_loss:.4f}")
        test_on_img(dataset, dataloader_validation)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, epoch, optimizer, train_loss, val_loss, vocab=dataset.vocab)
            print("best model saved successfully")
if __name__ == "__main__":
    print(torch.cuda.is_available())  # confirm the GPU is visible
    main()
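    # Optional: visualize the collected losses (assumes matplotlib is installed)
    import matplotlib.pyplot as plt
    plt.plot(train_loss_arr, label="train")
    plt.plot(val_loss_arr, label="validation")
    plt.xlabel("epoch")
    plt.ylabel("cross-entropy loss")
    plt.legend()
    plt.show()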