import os
import cv2
import gradio as gr
import numpy as np
import json
import pickle
from PIL import Image
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BridgeTowerProcessor
from tqdm import tqdm
from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
import faiss
import webvtt
from pytube import YouTube
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.formatters import WebVTTFormatter

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
model = BridgeTowerForITC.from_pretrained(model_name).to(device)
text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
processor = BridgeTowerProcessor.from_pretrained(model_name)
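# BridgeTowerForITC and BridgeTowerTextFeatureExtractor come from the local
# bridgetower_custom module, not from transformers itself. The batching code
# below assumes the ITC head's logits have shape (batch, 3, dim), with the
# joint image-text embedding at index 2 of the middle axis.
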
def download_video(video_url, path='/tmp/'):
    yt = YouTube(video_url)
    # Highest-resolution progressive (audio+video) mp4 stream
    stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
    if not os.path.exists(path):
        os.makedirs(path)
    filepath = os.path.join(path, stream.default_filename)
    if not os.path.exists(filepath):
        print('Downloading video from YouTube...')
        stream.download(path)
    return filepath
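# Usage sketch (using one of the example URLs from the demo below):
#   video_file = download_video('https://www.youtube.com/watch?v=CvjoXdC-WkM', path='/tmp/CvjoXdC-WkM')
# Errors from pytube (e.g. unavailable or region-locked videos) propagate to
# the caller; this helper does not catch them.
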
# Fetch the transcript and cache it as a WebVTT file
def get_transcript_vtt(video_id, path='/tmp'):
    filepath = os.path.join(path, 'test_vm.vtt')
    if os.path.exists(filepath):
        return filepath
    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    formatter = WebVTTFormatter()
    webvtt_formatted = formatter.format_transcript(transcript)
    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
        webvtt_file.write(webvtt_formatted)
    return filepath
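# The formatter emits standard WebVTT cues, e.g.:
#   WEBVTT
#
#   00:00:01.000 --> 00:00:04.500
#   some caption text
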
# https://stackoverflow.com/a/57781047
# Resizes an image while maintaining its aspect ratio
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    # Grab the image size and initialize dimensions
    dim = None
    (h, w) = image.shape[:2]

    # Return the original image if there is no need to resize
    if width is None and height is None:
        return image

    # If width is None, scale to the requested height
    if width is None:
        # Calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)
    # Otherwise scale to the requested width
    else:
        # Calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # Return the resized image
    return cv2.resize(image, dim, interpolation=inter)
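# Example: a 1920x1080 frame resized with height=400 gives
# r = 400/1080 and dim = (int(1920 * 400/1080), 400) = (711, 400).
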
def time_to_frame(time, fps):
    '''
    Convert a time in seconds into a frame number.
    '''
    return int(time * fps - 1)
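# Example: time_to_frame(12.5, fps=30) -> int(12.5 * 30 - 1) = 374.
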
def str2time(strtime):
    # Parse an 'HH:MM:SS.mmm' WebVTT timestamp into seconds
    strtime = strtime.strip('"')
    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
    total_seconds = hrs * 60**2 + mins * 60 + seconds
    return total_seconds
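# Example: str2time('00:01:05.500') -> 0*3600 + 1*60 + 5.5 = 65.5 seconds.
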
def collate_fn(batch_list):
    batch = {}
    batch['input_ids'] = pad_sequence([encoding['input_ids'].squeeze(0) for encoding in batch_list], batch_first=True)
    batch['attention_mask'] = pad_sequence([encoding['attention_mask'].squeeze(0) for encoding in batch_list], batch_first=True)
    batch['pixel_values'] = torch.cat([encoding['pixel_values'] for encoding in batch_list], dim=0)
    batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
    return batch
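# pad_sequence right-pads input_ids/attention_mask to the longest caption in
# the batch (pad value 0, matching the attention mask), producing shapes
# (batch, max_len); pixel_values/pixel_mask are fixed-size, so torch.cat works.
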
def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2, progress=gr.Progress()):
    if os.path.exists(os.path.join(output, 'embeddings.pkl')):
        return
    os.makedirs(output, exist_ok=True)
    os.makedirs(os.path.join(output, 'frames'), exist_ok=True)
    os.makedirs(os.path.join(output, 'frames_thumb'), exist_ok=True)

    vidcap = cv2.VideoCapture(video_path)

    # Get the frames per second
    fps = vidcap.get(cv2.CAP_PROP_FPS)

    # Get the total number of frames in the video
    frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)

    count = 0
    anno = []
    embeddings = []
    batch_list = []

    vtt = webvtt.read(subtitles)
    for idx, caption in enumerate(tqdm(vtt.captions, desc="Generating embeddings")):
        st_time = str2time(caption.start)
        ed_time = str2time(caption.end)
        mid_time = (ed_time + st_time) / 2
        text = caption.text.replace('\n', ' ')
        if expanded:
            raise NotImplementedError

        # Grab the frame at the midpoint of the caption's time span
        frame_no = time_to_frame(mid_time, fps)
        mid_time_ms = mid_time * 1000
        vidcap.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
        print('Read a new frame: ', idx, mid_time, frame_no, text)
        success, frame = vidcap.read()
        if success:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            img_fname = f'{video_id}_{idx:06d}'
            img_fpath = os.path.join(output, 'frames', img_fname + '.jpg')
            # frame = maintain_aspect_ratio_resize(frame, height=350)
            # frame.save(img_fpath)  # save frame as a JPEG file
            count += 1
            anno.append({
                'image_id': idx,
                'img_fname': img_fname,
                'caption': text,
                'time': mid_time_ms,
                'frame_no': frame_no
            })
            encoding = processor(frame, text, return_tensors="pt").to(device)
            encoding['text'] = text
            encoding['image_filepath'] = img_fpath
            encoding['start_time'] = caption.start
            encoding['time'] = mid_time_ms
            batch_list.append(encoding)
        else:
            break

        # Flush a full batch through the model
        if len(batch_list) == batch_size:
            batch = collate_fn(batch_list)
            with torch.no_grad():
                outputs = model(**batch, output_hidden_states=True)
            for i in range(batch_size):
                embeddings.append({
                    'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
                    'text': batch_list[i]['text'],
                    'image_filepath': batch_list[i]['image_filepath'],
                    'start_time': batch_list[i]['start_time'],
                    'time': batch_list[i]['time'],
                })
            batch_list = []

    # Flush any remaining partial batch
    if batch_list:
        batch = collate_fn(batch_list)
        with torch.no_grad():
            outputs = model(**batch, output_hidden_states=True)
        for i in range(len(batch_list)):
            embeddings.append({
                'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
                'text': batch_list[i]['text'],
                'image_filepath': batch_list[i]['image_filepath'],
                'start_time': batch_list[i]['start_time'],
                'time': batch_list[i]['time'],
            })
        batch_list = []

    with open(os.path.join(output, 'annotations.json'), 'w') as fh:
        json.dump(anno, fh)
    with open(os.path.join(output, 'embeddings.pkl'), 'wb') as fh:
        pickle.dump(embeddings, fh)
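# Two artifacts are cached per video: annotations.json (frame/caption metadata)
# and embeddings.pkl, the list of cross-modal vectors that run_query() indexes
# with FAISS. The embeddings.pkl check at the top makes this step a no-op on
# repeat queries for the same video.
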
def run_query(video_path, text_query, path='/tmp'):
    vidcap = cv2.VideoCapture(video_path)
    embeddings_filepath = os.path.join(path, 'embeddings.pkl')
    faiss_filepath = os.path.join(path, 'faiss_index.pkl')

    embeddings = pickle.load(open(embeddings_filepath, 'rb'))
    if os.path.exists(faiss_filepath):
        faiss_index = pickle.load(open(faiss_filepath, 'rb'))
    else:
        # Build an inner-product index over the frame/caption embeddings
        embs = [emb['embeddings'] for emb in embeddings]
        vectors = np.stack(embs, axis=0)
        num_vectors, vector_dim = vectors.shape
        faiss_index = faiss.IndexFlatIP(vector_dim)
        faiss_index.add(vectors)
        pickle.dump(faiss_index, open(faiss_filepath, 'wb'))

    print('Processing query')
    encoding = processor.tokenizer(text_query, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = text_model(**encoding)
    emb_query = outputs.cpu().numpy()

    print('Running FAISS search')
    _, I = faiss_index.search(emb_query, 6)

    clip_images = []
    transcripts = []
    for idx in I[0]:
        frame_timestamp = embeddings[idx]['time']
        vidcap.set(cv2.CAP_PROP_POS_MSEC, frame_timestamp)
        success, frame = vidcap.read()
        if success:
            frame = maintain_aspect_ratio_resize(frame, height=400)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            clip_images.append(frame)
            transcripts.append(f"({embeddings[idx]['start_time']}) {embeddings[idx]['text']}")
    return clip_images, transcripts
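# Note: IndexFlatIP ranks by raw inner product. If cosine similarity were
# wanted instead (an alternative, not what this Space does), both sides could
# be L2-normalized before add/search, e.g. with faiss.normalize_L2(vectors)
# and faiss.normalize_L2(emb_query).
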
# https://stackoverflow.com/a/7936523
def get_video_id_from_url(video_url):
    """
    Examples:
    - http://youtu.be/SA2iWivDJiE
    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
    - http://www.youtube.com/embed/SA2iWivDJiE
    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US
    """
    import urllib.parse
    url = urllib.parse.urlparse(video_url)
    if url.hostname == 'youtu.be':
        return url.path[1:]
    if url.hostname in ('www.youtube.com', 'youtube.com'):
        if url.path == '/watch':
            p = urllib.parse.parse_qs(url.query)
            return p['v'][0]
        if url.path[:7] == '/embed/':
            return url.path.split('/')[2]
        if url.path[:3] == '/v/':
            return url.path.split('/')[2]
    return None
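# Example: get_video_id_from_url('http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu')
# returns '_oPAwA_Udwc'; unrecognized URLs return None.
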
def process(video_url, text_query, progress=gr.Progress(track_tqdm=True)):
    tmp_dir = os.environ.get('TMPDIR', '/tmp')
    video_id = get_video_id_from_url(video_url)
    # Fail fast with a visible error if the URL cannot be parsed
    if video_id is None:
        raise gr.Error('Could not parse a YouTube video id from the given URL.')
    output_dir = os.path.join(tmp_dir, video_id)
    video_file = download_video(video_url, path=output_dir)
    subtitles = get_transcript_vtt(video_id, path=output_dir)
    extract_images_and_embeds(video_id=video_id,
                              video_path=video_file,
                              subtitles=subtitles,
                              output=output_dir,
                              expanded=False,
                              batch_size=8,
                              progress=progress,
                              )
    frames, transcripts = run_query(video_file, text_query, path=output_dir)
    return video_file, list(zip(frames, transcripts))
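# End-to-end flow: resolve the video id, download the video and transcript,
# embed one frame per caption (cached on disk), then answer the text query
# from the cached FAISS index. run_query returns PIL images, so the gallery
# receives (image, caption) pairs.
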
description = "This Space lets you run semantic search on a video."

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            video_url = gr.Text(label="YouTube URL")
            text_query = gr.Text(label="Text query")
            btn = gr.Button("Run query")
        video_player = gr.Video(label="Video")
    with gr.Row():
        gallery = gr.Gallery(label="Images").style(grid=6)
    gr.Examples(
        examples=[
            ['https://www.youtube.com/watch?v=CvjoXdC-WkM', 'wedding'],
            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake'],
            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'bunny'],
            ['https://www.youtube.com/watch?v=KCFYf4TJdN0', 'sandwich'],
        ],
        inputs=[video_url, text_query],
    )
    btn.click(fn=process,
              inputs=[video_url, text_query],
              outputs=[video_player, gallery],
              )

try:
    demo.queue(concurrency_count=3)
    demo.launch(share=True)
except Exception:
    demo.launch()
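# Note: Gallery(...).style(grid=...) and queue(concurrency_count=...) are
# Gradio 3.x APIs; on Gradio 4+ these would need gr.Gallery(columns=6) and
# queue(default_concurrency_limit=...) instead.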