"""Training script for the ``EncoderDecoder`` image-captioning model.

Builds a ``FlickrDataset``, splits it into train/validation/test subsets,
trains for ``num_epochs`` epochs, and saves a checkpoint whenever the
validation loss improves.
"""
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from tqdm import tqdm

# NOTE: wildcard imports are kept as in the original script. Names such as
# ``spacy``, ``FlickrDataset``, ``CapsCollate``, ``load_embeding``,
# ``EncoderDecoder``, ``show_img`` and ``save_model`` are not imported
# explicitly and are expected to be provided by these modules.
from dataset import *
from model import *
from utils import *

try:
    spacy_eng = spacy.load('en')
except OSError:
    # spaCy >= 3 removed the 'en' shortcut link; fall back to the full
    # package name so the script also starts on newer installations.
    spacy_eng = spacy.load('en_core_web_sm')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# init seed
# NOTE(review): the seed itself is drawn at random (an integer in [0, 100))
# and never logged, so runs are not reproducible as written.
seed = torch.randint(100, (1,))
torch.manual_seed(seed)
shuffle = True

# src folders
root_folder = "/content/flickr8k/Images"  # change this
csv_file = "/content/flickr8k/captions.txt"  # change this

# image transforms and augmentation
# The same pipeline (including the random crop) is applied to every split,
# because all three subsets below are views of this one dataset object.
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# define dataset
dataset = FlickrDataset(root_folder, csv_file, transforms)

# split dataset
val_size = 512
test_size = 256
train_size = len(dataset) - val_size - test_size
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])

# Define data loader parameters
num_workers = 4
pin_memory = True
batch_size_train = 256
batch_size_val_test = 128
# NOTE(review): the empty-string key looks like a padding token whose name was
# stripped (e.g. "<PAD>") -- confirm against the vocabulary before changing it.
pad_idx = dataset.vocab.stoi[""]

# define loaders
# All three loaders use ``shuffle``; as a consequence ``test_on_img`` below
# sees a different first validation batch each time it is called.
dataloader_train = DataLoader(
    train_ds,
    batch_size=batch_size_train,
    pin_memory=pin_memory,
    num_workers=num_workers,
    shuffle=shuffle,
    collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True),
)
dataloader_validation = DataLoader(
    val_ds,
    batch_size=batch_size_val_test,
    pin_memory=pin_memory,
    num_workers=num_workers,
    shuffle=shuffle,
    collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True),
)
dataloader_test = DataLoader(
    test_ds,
    batch_size=batch_size_val_test,
    pin_memory=pin_memory,
    num_workers=num_workers,
    shuffle=shuffle,
    collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True),
)

# model parameters
embed_wts, embed_size = load_embeding("/content/glove.42B.300d.txt", dataset.vocab)  # change path
vocab_size = len(dataset.vocab)
attention_dim = 256
encoder_dim = 2048
decoder_dim = 512
fc_dims = 256
learning_rate = 5e-4

model = EncoderDecoder(
    embed_size,
    vocab_size,
    attention_dim,
    encoder_dim,
    decoder_dim,
    fc_dims,
    p=0.3,
    embeddings=embed_wts,
).to(device)
# Positions whose target equals the padding index do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# training parameters
num_epochs = 35
train_loss_arr = []
val_loss_arr = []


def training(dataset, dataloader, loss_criteria, optimize, grad_clip=5.):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Uses the module-level ``model``, ``device`` and ``vocab_size``.

    Args:
        dataset: Unused; kept so existing call sites keep working.
        dataloader: Iterable yielding ``(images, captions)`` batches.
        loss_criteria: Callable ``(logits, targets) -> loss tensor``.
        optimize: Optimizer that updates the parameters of ``model``.
        grad_clip: Maximum gradient norm; a falsy value disables clipping.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    total_loss = 0
    for img, cap in tqdm(dataloader, total=len(dataloader)):
        img, cap = img.to(device), cap.to(device)

        optimize.zero_grad()
        output, attention = model(img, cap)

        # Targets are the captions without their first column
        # (presumably the start-of-sentence token -- TODO confirm).
        targets = cap[:, 1:]
        # ``reshape`` (rather than ``view``) also handles a non-contiguous
        # model output and matches how ``targets`` is flattened.
        loss = loss_criteria(output.reshape(-1, vocab_size), targets.reshape(-1))
        total_loss += loss.item()

        loss.backward()
        if grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimize.step()

    total_loss = total_loss / len(dataloader)
    return total_loss


@torch.no_grad()
def validate(dataset, dataloader, loss_cr):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Runs with gradient tracking disabled and performs no parameter updates.
    The caller is responsible for putting ``model`` into eval mode.

    Args:
        dataset: Unused; kept so existing call sites keep working.
        dataloader: Iterable yielding ``(images, captions)`` batches.
        loss_cr: Callable ``(logits, targets) -> loss tensor``.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    total_loss = 0
    for val_img, val_cap in tqdm(dataloader, total=len(dataloader)):
        val_img, val_cap = val_img.to(device), val_cap.to(device)

        output, attention = model(val_img, val_cap)

        # Same target construction as in ``training``.
        targets = val_cap[:, 1:]
        loss = loss_cr(output.reshape(-1, vocab_size), targets.reshape(-1))
        total_loss += loss.item()

    total_loss /= len(dataloader)
    return total_loss


# to see results while training
@torch.no_grad()
def test_on_img(data, dataloader):
    """Caption the first image of the next batch from ``dataloader`` and show it.

    Args:
        data: Object exposing a ``vocab`` attribute used for decoding.
        dataloader: Iterable yielding ``(images, captions)`` batches.
    """
    dataiter = iter(dataloader)
    img, cap = next(dataiter)

    # ``img[0:1]`` keeps the batch dimension while selecting a single image.
    features = model.EncoderCNN(img[0:1].to(device))
    caps, alphas = model.DecoderLSTM.gen_captions(features, vocab=data.vocab)

    caption = ' '.join(caps)
    show_img(img[0], caption)


def main():
    """Train for ``num_epochs`` epochs, checkpointing on validation improvement.

    Appends the per-epoch losses to the module-level ``train_loss_arr`` and
    ``val_loss_arr`` lists. The first epoch is always saved; later epochs are
    saved only when the validation loss is strictly lower than the best so far.
    """
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        print(f"Epoch: {epoch + 1}/{num_epochs}")

        model.train()
        train_loss = training(dataset, dataloader_train, loss_fn, optimizer)
        train_loss_arr.append(train_loss)

        model.eval()
        val_loss = validate(dataset, dataloader_validation, loss_fn)
        val_loss_arr.append(val_loss)

        print(f"train_loss: {train_loss} validation_loss: {val_loss}")
        test_on_img(dataset, dataloader_validation)

        # Keyed on the epoch index rather than on the length of the
        # module-level history list, so a repeated call to ``main`` behaves
        # the same as the first one.
        if epoch == 0 or val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, epoch, optimizer, train_loss, val_loss, vocab=dataset.vocab)
            print("best model saved successfully")


if __name__ == "__main__":
    print(torch.cuda.is_available())
    main()