import asyncio
import glob
import json
import os
import uuid
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from urllib.parse import quote
from xml.etree import ElementTree as ET

import edge_tts
import numpy as np
import pandas as pd
import requests
import streamlit as st
import streamlit.components.v1 as components
import torch
from audio_recorder_streamlit import audio_recorder
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


# Default value for every st.session_state key this app relies on; missing keys
# are seeded from this mapping at startup.
SESSION_VARS = dict(
    search_history=[],
    last_voice_input="",
    transcript_history=[],
    should_rerun=False,
    search_columns=[],
    initial_search_done=False,
    tts_voice="en-US-AriaNeural",
    arxiv_last_query="",
    dataset_loaded=False,
    current_page=0,
    data_cache=None,
    dataset_info=None,
)

# Page size used by FastDatasetSearcher.load_page when calling load_dataset_page.
ROWS_PER_PAGE = 100


# Seed each missing session key with its default; keys that already exist
# (from an earlier rerun of this session) keep their current value.
for _state_key, _state_default in SESSION_VARS.items():
    if _state_key in st.session_state:
        continue
    st.session_state[_state_key] = _state_default


@st.cache_resource
def get_model():
    """Return the shared 'all-MiniLM-L6-v2' sentence-embedding model.

    ``st.cache_resource`` makes Streamlit build the model once and hand the
    same object back on later calls.
    """
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)


@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the dataset's ``train`` split as a DataFrame.

    The page covers rows ``page * rows_per_page`` up to (but excluding)
    ``(page + 1) * rows_per_page``, selected with a ``train[a:b]`` split slice.

    Args:
        dataset_id: Hugging Face dataset id.
        token: access token forwarded to ``datasets.load_dataset``.
        page: zero-based page index.
        rows_per_page: number of rows in a page.

    Returns:
        DataFrame with the page's rows, or an empty DataFrame when loading
        fails (the error is shown in the UI).

    NOTE(review): because of ``st.cache_data`` the empty DataFrame returned on
    failure is cached as well, so a transient error persists until the cache
    is cleared.
    """
    try:
        first_row = page * rows_per_page
        stop_row = first_row + rows_per_page
        page_split = f'train[{first_row}:{stop_row}]'
        page_rows = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=page_split,
        )
        return pd.DataFrame(page_rows)
    except Exception as err:
        st.error(f"Error loading page {page}: {str(err)}")
        return pd.DataFrame()


@st.cache_data
def get_dataset_info(dataset_id, token):
    """Return the ``info`` attribute of the dataset's ``train`` split.

    The dataset is opened with ``streaming=True``. On any failure the error is
    shown in the UI and None is returned.
    """
    try:
        streamed = load_dataset(dataset_id, token=token, streaming=True)
        train_split = streamed['train']
        return train_split.info
    except Exception as err:
        st.error(f"Error loading dataset info: {str(err)}")
        return None


def fetch_dataset_info(dataset_id):
    """Fetch a dataset's metadata JSON from the Hugging Face Hub API.

    Returns:
        The decoded JSON on HTTP 200; None for any other status, or when the
        request or JSON decoding fails (a warning is shown in the UI).
    """
    api_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    try:
        resp = requests.get(api_url, timeout=30)
        return resp.json() if resp.status_code == 200 else None
    except Exception as err:
        st.warning(f"Error fetching dataset info: {err}")
        return None


def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
    """Fetch the preview rows of one config/split from the HF datasets-server.

    Uses the ``/first-rows`` endpoint. In columns whose name contains "embed",
    "vector" or "encoding", string values such as ``"[0.1, 0.2]"`` are parsed
    into lists of floats; strings that are not numeric lists are left as-is.
    Every returned row is tagged with ``_config`` and ``_split``.

    Args:
        dataset_id: Hub dataset id, e.g. ``"org/name"``.
        config: dataset config name.
        split: split name.
        max_rows: maximum number of rows returned (None for no limit).

    Returns:
        List of row dicts. Empty on a non-200 response, a response without a
        ``rows`` key, or any request/processing error (a warning is shown).
    """
    try:
        response = requests.get(
            "https://datasets-server.huggingface.co/first-rows",
            # Let requests URL-encode the values; ids and config names are
            # user-controlled and may contain characters like '&' or '+'.
            params={"dataset": dataset_id, "config": config, "split": split},
            timeout=30,
        )
        if response.status_code != 200:
            return []
        data = response.json()
        if 'rows' not in data:
            return []

        processed_rows = []
        for row_data in data['rows'][:max_rows]:
            # Rows may be wrapped as {"row": {...}}; fall back to the raw object.
            row = row_data.get('row', row_data)
            for key, value in row.items():
                if not isinstance(value, str):
                    continue
                if not any(term in key.lower() for term in ('embed', 'vector', 'encoding')):
                    continue
                try:
                    # Replacing the value of an existing key is safe while iterating.
                    row[key] = [float(x.strip()) for x in value.strip('[]').split(',') if x.strip()]
                except ValueError:
                    pass  # not a numeric list; keep the original string
            row['_config'] = config
            row['_split'] = split
            processed_rows.append(row)
        return processed_rows
    except Exception as e:
        st.warning(f"Error fetching rows: {e}")
        return []


class FastDatasetSearcher:
    """Keyword-first search over pages of a Hugging Face dataset.

    Pages are loaded through ``load_dataset_page`` and ranked by
    ``quick_search``, which mixes token overlap with sentence-embedding
    similarity.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the searcher; halts the Streamlit script if DATASET_KEY is unset."""
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()

        # Dataset metadata is fetched once and kept in session state.
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Return zero-based page ``page`` of the train split (ROWS_PER_PAGE rows)."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Rank the rows of ``df`` against ``query``.

        A row is *matched* when at least one whitespace-separated query term
        appears as a whitespace-separated token (case-insensitive) in one of
        its searchable fields. Matched rows are scored as
        ``0.8 * keyword_score + 0.2 * semantic_score``. ``keyword_score`` is
        the fraction of query terms present in the row text. ``semantic_score``
        is the cosine similarity between the query and row-text embeddings.
        The score is then multiplied by 2.0 when the whole query occurs as a
        run of whole words in a field, otherwise by 1.2. Unmatched rows score
        0.0.

        Args:
            query: free-text query.
            df: DataFrame of candidate rows.

        Returns:
            A copy of the matched rows with added ``score`` and ``matched``
            columns, sorted by descending score. ``df`` itself is returned
            when it is empty, when the query is blank, or on error (the error
            is shown in the UI).
        """
        if df.empty or not query.strip():
            return df

        try:
            # Rows scoring above this are kept even if unmatched. Unmatched
            # rows always score 0.0, so in practice only matched rows remain.
            MIN_SCORE = 0.5
            EXACT_MATCH_BOOST = 2.0
            TERM_MATCH_BOOST = 1.2
            priority_fields = ['description', 'matched_text']

            # Columns whose first value is an array or raw bytes are not
            # meaningful text and are skipped.
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]
            # Priority fields first (when present), then every other searchable column.
            fields = [col for col in priority_fields if col in df.columns]
            fields += [col for col in searchable_cols if col not in priority_fields]

            query_lower = query.lower()
            query_terms = set(query_lower.split())
            # Space-padded, whitespace-normalized query: a substring test against a
            # padded token string only matches whole words, and for a one-word
            # query it is equivalent to plain token membership.
            query_phrase = f" {' '.join(query_lower.split())} "

            row_texts, row_exact, row_matched = [], [], []
            for _, row in df.iterrows():
                parts = []
                exact = matched = False
                for col in fields:
                    val = row[col]
                    if val is None:
                        continue
                    tokens = str(val).lower().split()
                    if query_phrase in f" {' '.join(tokens)} ":
                        exact = True
                    if not query_terms.isdisjoint(tokens):
                        matched = True
                    parts.append(str(val))
                row_texts.append(' '.join(parts))
                row_exact.append(exact)
                row_matched.append(matched)

            # Every row starts at 0.0, so rows with no text or no match get a
            # defined score instead of a stale/unbound one.
            scores = [0.0] * len(row_texts)
            candidates = [i for i, hit in enumerate(row_matched) if hit]
            if candidates:
                query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
                # One batched encode call for all candidates instead of one per row.
                text_embeddings = self.text_model.encode(
                    [row_texts[i] for i in candidates], show_progress_bar=False
                )
                semantic_scores = cosine_similarity([query_embedding], text_embeddings)[0]
                for i, semantic_score in zip(candidates, semantic_scores):
                    text_tokens = set(row_texts[i].lower().split())
                    keyword_score = len(query_terms & text_tokens) / len(query_terms)
                    score = 0.8 * keyword_score + 0.2 * float(semantic_score)
                    scores[i] = score * (EXACT_MATCH_BOOST if row_exact[i] else TERM_MATCH_BOOST)

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = row_matched

            filtered_df = results_df[
                results_df['matched'] | (results_df['score'] > MIN_SCORE)
            ]
            return filtered_df.sort_values('score', ascending=False)

        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df


class VideoSearch:
    """Embedding-based search over preview rows of a Hugging Face video dataset.

    Rows are loaded with ``search_dataset`` (falling back to a one-row
    example) and ranked against a text query by cosine similarity with
    per-row embedding vectors.
    """

    # Embedding columns never shown to users or offered as search columns.
    _HIDDEN_COLUMNS = ('video_embed', 'description_embed', 'audio_embed')
    # Column-name substrings that mark a column as holding vectors.
    _EMBED_MARKERS = ('embed', 'vector', 'encoding')

    def __init__(self):
        # get_model() is backed by st.cache_resource, so instances share one
        # model instead of each loading a fresh copy.
        self.text_model = get_model()
        self.dataset_id = "omegalabsinc/omega-multimodal"
        self.load_dataset()

    def fetch_dataset_rows(self):
        """Load all preview rows of ``self.dataset_id`` as a DataFrame.

        Also stores the user-facing column names in
        ``st.session_state['search_columns']``. Falls back to
        ``load_example_data()`` when nothing is returned or loading fails.
        """
        try:
            df, configs, splits = search_dataset(
                self.dataset_id,
                "",
                include_configs=None,
                include_splits=None
            )

            if not df.empty:
                st.session_state['search_columns'] = [
                    col for col in df.columns
                    if col not in self._HIDDEN_COLUMNS and not col.startswith('_')
                ]
                return df

            return self.load_example_data()

        except Exception as e:
            st.warning(f"Error loading videos: {e}")
            return self.load_example_data()

    def load_example_data(self):
        """Return a one-row placeholder DataFrame used when real data is unavailable."""
        example_data = [{
            "video_id": "sample-123",
            "youtube_id": "dQw4w9WgXcQ",
            "description": "An example video",
            "views": 12345,
            "start_time": 0,
            "end_time": 60
        }]
        return pd.DataFrame(example_data)

    def load_dataset(self):
        """Populate ``self.dataset`` and the embedding matrices used by ``search``."""
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()

    @staticmethod
    def _parse_vector(value):
        """Return ``value`` as a list of floats, or None if it is not a vector.

        Accepts lists, tuples, NumPy arrays and strings such as ``"[0.1, 0.2]"``.
        """
        if isinstance(value, str):
            try:
                return [float(x.strip()) for x in value.strip('[]').split(',') if x.strip()]
            except ValueError:
                return None
        if isinstance(value, (list, tuple, np.ndarray)):
            return list(value)
        return None

    def _encode_row_texts(self, skip_cols):
        """Embed each row's string fields (excluding ``skip_cols`` and '_' columns)."""
        text_cols = {
            col for col in self.dataset.columns
            if col not in skip_cols and not str(col).startswith('_')
        }
        texts = [
            ' '.join(v for col, v in row.items() if col in text_cols and isinstance(v, str))
            for _, row in self.dataset.iterrows()
        ]
        return np.asarray(self.text_model.encode(texts, show_progress_bar=False), dtype=float)

    def prepare_features(self):
        """Build ``self.video_embeds`` and ``self.text_embeds``, one vector per dataset row.

        An embedding-like column is used only when every row parses to a
        vector and (when the model reports it) the vectors match the text
        model's output width, because ``search`` compares them with an encoded
        query via cosine similarity. ``video_embed`` is preferred for video
        vectors (else the first usable column) and ``description_embed`` for
        text vectors (else the video vectors). If no column qualifies, the
        rows' string fields are embedded with the text model instead.
        """
        n_rows = len(self.dataset)
        dim = None
        try:
            dim = self.text_model.get_sentence_embedding_dimension()
            embed_cols = [
                col for col in self.dataset.columns
                if any(marker in str(col).lower() for marker in self._EMBED_MARKERS)
            ]

            embeddings = {}
            for col in embed_cols:
                vectors = [self._parse_vector(v) for v in self.dataset[col]]
                # Dropping individual rows would misalign matrix rows with
                # self.dataset rows, so a column is used only if all rows parse.
                if any(v is None for v in vectors):
                    continue
                if dim is not None and any(len(v) != dim for v in vectors):
                    continue
                try:
                    embeddings[col] = np.asarray(vectors, dtype=float)
                except ValueError:  # ragged vectors
                    continue

            video_embeds = embeddings.get('video_embed', next(iter(embeddings.values()), None))
            if video_embeds is None:
                video_embeds = self._encode_row_texts(embed_cols)
            self.video_embeds = video_embeds
            self.text_embeds = embeddings.get('description_embed', video_embeds)

        except Exception as e:
            st.warning(f"Could not prepare embeddings: {e}")
            # Zero vectors give a cosine similarity of 0 for every row, so
            # search() returns nothing rather than random matches.
            width = dim or 384
            self.video_embeds = np.zeros((n_rows, width))
            self.text_embeds = np.zeros((n_rows, width))

    def search(self, query, column=None, top_k=20):
        """Return up to ``top_k`` rows most similar to ``query``, best first.

        Similarity is ``0.7 * text_sim + 0.3 * video_sim`` (cosine). If
        ``column`` names a dataset column other than "All Fields", rows whose
        value contains ``query`` as a literal, case-insensitive substring get
        their score multiplied by 1.5. Rows scoring below 0.3 are dropped.

        Returns:
            List of dicts, each with ``relevance_score`` plus the row's
            non-embedding columns.
        """
        MIN_RELEVANCE = 0.3
        if len(self.dataset) == 0:
            return []

        query_embedding = self.text_model.encode([query])[0]
        video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
        text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
        combined_sims = 0.7 * text_sims + 0.3 * video_sims

        if column and column != "All Fields" and column in self.dataset.columns:
            # regex=False: match the query literally, so characters like '(' or
            # '*' cannot raise re.error or act as patterns.
            matches = (
                self.dataset[column].astype(str)
                .str.contains(query, case=False, regex=False)
                .to_numpy()
            )
            combined_sims[matches] *= 1.5

        relevant = np.flatnonzero(combined_sims >= MIN_RELEVANCE)
        if relevant.size == 0:
            return []
        best_first = relevant[np.argsort(combined_sims[relevant])[::-1][:top_k]]

        visible_cols = [col for col in self.dataset.columns if col not in self._HIDDEN_COLUMNS]
        results = []
        for idx in best_first:
            row = self.dataset.iloc[idx]
            result = {'relevance_score': float(combined_sims[idx])}
            result.update({col: row[col] for col in visible_cols})
            results.append(result)

        return results


def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
    """Collect preview rows, across configs and splits, that contain ``search_text``.

    For each config (``include_configs`` or the ``config_names`` from the Hub
    metadata, defaulting to ``['default']``), the split list is read from the
    datasets-server (``['train']`` if it reports none) and optionally filtered
    by ``include_splits``. Each preview row from ``fetch_dataset_rows`` is
    matched case-insensitively against the space-joined text of its
    str/int/float values. The ``_config``/``_split`` tags are excluded, so
    searching for a split name does not match every row.

    Returns:
        ``(df, configs, splits)``. ``df`` holds the matching rows with added
        ``_matched_text`` and ``_relevance_score`` (occurrence count) columns,
        sorted by descending score. If the Hub metadata cannot be fetched,
        returns ``(empty DataFrame, [], [])``.
    """
    dataset_info = fetch_dataset_info(dataset_id)
    if not dataset_info:
        return pd.DataFrame(), [], []

    configs = include_configs if include_configs else dataset_info.get('config_names', ['default'])
    needle = (search_text or "").lower()
    all_rows = []
    available_splits = set()

    for config in configs:
        try:
            splits_response = requests.get(
                "https://datasets-server.huggingface.co/splits",
                # requests URL-encodes the values (dataset ids contain '/').
                params={"dataset": dataset_id, "config": config},
                timeout=30,
            )
            if splits_response.status_code != 200:
                continue
            splits = [s['split'] for s in splits_response.json().get('splits', [])] or ['train']

            if include_splits:
                splits = [s for s in splits if s in include_splits]

            available_splits.update(splits)

            for split in splits:
                for row in fetch_dataset_rows(dataset_id, config, split):
                    text_content = ' '.join(
                        str(v) for k, v in row.items()
                        if k not in ('_config', '_split') and isinstance(v, (str, int, float))
                    )
                    haystack = text_content.lower()
                    if needle in haystack:
                        row['_matched_text'] = text_content
                        row['_relevance_score'] = haystack.count(needle)
                        all_rows.append(row)
        except Exception as e:
            st.warning(f"Error processing config {config}: {e}")
            continue

    if all_rows:
        df = pd.DataFrame(all_rows)
        # Stable sort keeps equally scored rows in fetch order.
        df = df.sort_values('_relevance_score', ascending=False, kind='stable')
        return df, configs, list(available_splits)

    return pd.DataFrame(), configs, list(available_splits)


@st.cache_resource
def get_speech_model():
    """Return the ``edge_tts.Communicate`` class (not an instance).

    ``generate_speech`` instantiates it as ``get_speech_model()(text, voice)``.
    """
    communicate_cls = edge_tts.Communicate
    return communicate_cls


async def generate_speech(text, voice=None):
    """Synthesize ``text`` to an MP3 file in the working directory with edge-tts.

    Args:
        text: text to speak; None or blank text yields None.
        voice: edge-tts voice name; defaults to ``st.session_state['tts_voice']``.

    Returns:
        Path of the written MP3, or None when the text is blank or synthesis
        fails (the error is shown in the UI).
    """
    if not text or not text.strip():
        return None
    if not voice:
        voice = st.session_state['tts_voice']
    # A random suffix keeps names unique. Several clips can be generated within
    # the same second (e.g. a quick summary and a full summary), and a
    # seconds-only name would let the later one overwrite the earlier one.
    audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}.mp3"
    try:
        communicate = get_speech_model()(text, voice)
        await communicate.save(audio_file)
        return audio_file
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None


def transcribe_audio(audio_path):
    """Placeholder speech-to-text hook.

    ``audio_path`` is currently ignored; a fixed notice is returned until a
    real ASR backend is plugged in.
    """
    notice = "ASR not implemented. Add your preferred speech recognition here!"
    return notice


def arxiv_search(query, max_results=5):
    """Query the arXiv Atom API and return ``(title, summary, link)`` tuples.

    Args:
        query: arXiv ``search_query`` string (URL-quoted here).
        max_results: maximum number of entries requested.

    Returns:
        List of ``(title, summary, html_link_or_None)``. Titles and summaries
        have whitespace runs collapsed to single spaces, and a missing element
        becomes ``""``. Returns an empty list on a non-200 response or when
        the request or XML parsing fails (the error is shown in the UI).
    """
    base_url = "http://export.arxiv.org/api/query?"
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    try:
        # Same timeout as the other HTTP calls in this app, so a stalled
        # connection cannot block the script indefinitely.
        r = requests.get(search_url, timeout=30)
        if r.status_code != 200:
            return []
        root = ET.fromstring(r.text)
    except Exception as e:
        st.error(f"ArXiv search error: {e}")
        return []

    ns = {'atom': 'http://www.w3.org/2005/Atom'}
    results = []
    for entry in root.findall('atom:entry', ns):
        title = _atom_text(entry, 'atom:title', ns)
        summary = _atom_text(entry, 'atom:summary', ns)
        link = next((l.get('href') for l in entry.findall('atom:link', ns)
                     if l.get('type') == 'text/html'), None)
        results.append((title, summary, link))
    return results


def _atom_text(entry, tag, ns):
    """Return the whitespace-collapsed text of child ``tag`` of ``entry`` ("" if absent)."""
    node = entry.find(tag, ns)
    if node is None or node.text is None:
        return ""
    return ' '.join(node.text.split())


def show_file_manager():
    """Render the file tab: upload, preview and delete .txt/.md/.mp3 files in the working directory."""
    # st.rerun superseded st.experimental_rerun; fall back for older Streamlit.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    patterns = ("*.txt", "*.md", "*.mp3")

    st.subheader("π File Manager")
    col1, col2 = st.columns(2)

    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            # file_uploader returns the same file on every rerun until it is
            # removed. Save each upload once, and do not rerun here: the old
            # write-then-rerun looped forever. The listing below already runs
            # after the write.
            upload_id = (uploaded_file.name, uploaded_file.size)
            if "_saved_uploads" not in st.session_state:
                st.session_state["_saved_uploads"] = set()
            if upload_id not in st.session_state["_saved_uploads"]:
                # The name is client-supplied; keep only the final component so an
                # upload cannot be written outside the working directory.
                safe_name = os.path.basename(uploaded_file.name)
                with open(safe_name, "wb") as f:
                    f.write(uploaded_file.getvalue())
                st.session_state["_saved_uploads"].add(upload_id)
                st.success(f"Uploaded: {safe_name}")

    with col2:
        if st.button("π Clear Files"):
            # NOTE(review): these globs match every .txt/.md/.mp3 in the working
            # directory, which may include project files (e.g. requirements.txt,
            # README.md); consider restricting this to files the app created.
            for f in [p for pattern in patterns for p in glob.glob(pattern)]:
                os.remove(f)
            st.success("All files cleared!")
            rerun()

    files = [p for pattern in patterns for p in glob.glob(pattern)]
    if files:
        st.write("### Existing Files")
        for f in files:
            name = os.path.basename(f)
            with st.expander(f"π {name}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    # errors='replace': a non-UTF-8 upload must not crash the tab.
                    with open(f, 'r', encoding='utf-8', errors='replace') as file:
                        content = file.read()
                    # Explicit key: files with identical content (e.g. two empty
                    # files) would otherwise get duplicate widget IDs.
                    st.text_area("Content", content, height=100,
                                 key=f"content_{f}_{hash(content)}")
                if st.button(f"Delete {name}", key=f"del_{f}"):
                    os.remove(f)
                    rerun()


def perform_arxiv_lookup(query, vocal_summary=True, titles_summary=True, full_audio=False):
    """Search arXiv for ``query``, render up to five hits, and optionally voice them.

    Args:
        query: search text passed to ``arxiv_search``.
        vocal_summary: play a short spoken summary.
        titles_summary: in the short summary, read all titles instead of the
            first 200 characters of the first abstract.
        full_audio: additionally play every title and abstract.
    """
    results = arxiv_search(query, max_results=5)
    if not results:
        st.write("No results found.")
        return

    st.markdown(f"**ArXiv Results for '{query}':**")
    for rank, (title, summary, link) in enumerate(results, start=1):
        st.markdown(f"**{rank}. {title}**")
        st.write(summary)
        if link:
            st.markdown(f"[View Paper]({link})")

    if vocal_summary:
        if titles_summary:
            detail = " Titles: " + ", ".join(title for title, _, _ in results)
        else:
            detail = " " + results[0][1][:200]
        spoken_text = f"Here are ArXiv results for {query}. " + detail
        summary_audio = asyncio.run(generate_speech(spoken_text))
        if summary_audio:
            st.audio(summary_audio)

    if full_audio:
        full_text = "".join(
            f"Result {rank}: {title}. {summary} "
            for rank, (title, summary, _) in enumerate(results, start=1)
        )
        full_audio_file = asyncio.run(generate_speech(full_text))
        if full_audio_file:
            st.write("### Full Audio Summary")
            st.audio(full_audio_file)


def render_result(result, key_prefix=""):
    """Render one search result: optional video, fields, score and a TTS button.

    Args:
        result: dict of field name -> value. ``relevance_score``,
            ``youtube_id``, ``start_time`` and ``video_id`` are used when present.
        key_prefix: text prepended to the widget keys, so the same result can
            be rendered more than once per page. The default keeps the
            existing keys.
    """
    score = result.get('relevance_score', 0)
    hidden = ('relevance_score', 'video_embed', 'description_embed', 'audio_embed')
    result_filtered = {k: v for k, v in result.items() if k not in hidden}

    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

    cols = st.columns([2, 1])
    text_content = []
    with cols[0]:
        for key, value in result_filtered.items():
            # NumPy scalars (e.g. int64 read from a DataFrame row) are not
            # Python ints, so list them explicitly or those fields are hidden.
            if isinstance(value, (str, int, float, np.integer, np.floating)):
                st.write(f"**{key}:** {value}")
            if isinstance(value, str) and value.strip():
                text_content.append(f"{key}: {value}")

    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")

    voices = {
        "Aria (US Female)": "en-US-AriaNeural",
        "Guy (US Male)": "en-US-GuyNeural",
        "Sonia (UK Female)": "en-GB-SoniaNeural",
        "Tony (UK Male)": "en-GB-TonyNeural",
        "Jenny (US Female)": "en-US-JennyNeural"
    }

    widget_id = result.get('video_id', '')
    if widget_id is None or (isinstance(widget_id, str) and not widget_id):
        # Without a video_id every such result would share the keys "voice_"
        # and "read_", which Streamlit rejects as duplicates. Derive an id from
        # the visible text instead (str hashes are stable within one process).
        widget_id = f"h{abs(hash(tuple(text_content)))}"

    selected_voice = st.selectbox(
        "Select Voice",
        list(voices.keys()),
        key=f"{key_prefix}voice_{widget_id}"
    )

    if st.button("π Read Description", key=f"{key_prefix}read_{widget_id}"):
        text_to_read = ". ".join(text_content)
        audio_file = asyncio.run(generate_speech(text_to_read, voices[selected_voice]))
        if audio_file:
            st.audio(audio_file)


def main():
    """Streamlit entry point: build the search backend and render the tabs and sidebar."""
    st.title("π₯ Advanced Video & Dataset Search with Voice")

    # Every widget interaction reruns the whole script; build the backend
    # (model + dataset download) once per session instead of on every rerun.
    if 'video_search' not in st.session_state:
        st.session_state['video_search'] = VideoSearch()
    search = st.session_state['video_search']

    tab1, tab2, tab3, tab4 = st.tabs([
        "π Search", "ποΈ Voice Input", "π ArXiv", "π Files"
    ])

    with tab1:
        _render_search_tab(search)

    with tab2:
        _render_voice_tab(search)

    with tab3:
        _render_arxiv_tab()

    with tab4:
        show_file_manager()

    with st.sidebar:
        _render_sidebar()


def _rerun():
    """Rerun the script via st.rerun, or st.experimental_rerun on older Streamlit."""
    (getattr(st, "rerun", None) or st.experimental_rerun)()


def _render_search_tab(search):
    """Render the video search tab; runs one automatic "aliens" search on first load."""
    st.subheader("Search Videos")
    col1, col2 = st.columns([3, 1])
    with col1:
        # Keep the default constant: changing a keyless widget's default value
        # gives it a new identity, which discarded whatever the user had typed.
        query = st.text_input("Enter search query:", value="aliens")
    with col2:
        search_column = st.selectbox("Search in:",
                                     ["All Fields"] + st.session_state['search_columns'])

    col3, col4 = st.columns(2)
    with col3:
        num_results = st.slider("Max results:", 1, 100, 20)
    with col4:
        search_button = st.button("π Search")

    if not ((search_button or not st.session_state['initial_search_done']) and query):
        return

    st.session_state['initial_search_done'] = True
    selected_column = None if search_column == "All Fields" else search_column

    with st.spinner("Searching..."):
        results = search.search(query, selected_column, num_results)

    if not results:
        st.warning("No matching results found.")
        return

    st.session_state['search_history'].append({
        'query': query,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'results': results[:5]
    })

    st.write(f"Found {len(results)} results:")
    for i, result in enumerate(results, 1):
        with st.expander(f"Result {i}", expanded=(i == 1)):
            render_result(result)


def _render_voice_tab(search):
    """Render the voice tab: record a clip, transcribe it, and offer a search with the text."""
    st.subheader("Voice Search")
    st.write("ποΈ Record your query:")
    audio_bytes = audio_recorder()
    if not audio_bytes:
        return

    with st.spinner("Processing audio..."):
        audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
        try:
            with open(audio_path, "wb") as f:
                f.write(audio_bytes)
            voice_query = transcribe_audio(audio_path)
        finally:
            # Remove the temporary clip even if transcription raises.
            if os.path.exists(audio_path):
                os.remove(audio_path)

    st.markdown("**Transcribed Text:**")
    st.write(voice_query)
    st.session_state['last_voice_input'] = voice_query

    if st.button("π Search from Voice"):
        results = search.search(voice_query, None, 20)
        if not results:
            st.warning("No matching results found.")
        for i, result in enumerate(results, 1):
            with st.expander(f"Result {i}", expanded=(i == 1)):
                render_result(result)


def _render_arxiv_tab():
    """Render the arXiv tab: query box, audio options and the search button."""
    st.subheader("ArXiv Search")
    # Constant default: the widget keeps its own value across reruns. Feeding
    # the last query back in as the default changed the widget's identity and
    # dropped the next query the user typed.
    arxiv_query = st.text_input("Search ArXiv:", value="")
    vocal_summary = st.checkbox("π Quick Audio Summary", value=True)
    titles_summary = st.checkbox("π Titles Only", value=True)
    full_audio = st.checkbox("π Full Audio Summary", value=False)

    if st.button("π Search ArXiv"):
        st.session_state['arxiv_last_query'] = arxiv_query
        perform_arxiv_lookup(arxiv_query, vocal_summary, titles_summary, full_audio)


def _render_sidebar():
    """Render the sidebar: clear-history button, last five searches, TTS voice picker."""
    st.subheader("βοΈ Settings & History")
    if st.button("ποΈ Clear History"):
        st.session_state['search_history'] = []
        _rerun()

    st.markdown("### Recent Searches")
    for entry in reversed(st.session_state['search_history'][-5:]):
        with st.expander(f"{entry['timestamp']}: {entry['query']}"):
            for i, result in enumerate(entry['results'], 1):
                # description may be missing, None or non-string; slicing it
                # directly raised TypeError.
                description = str(result.get('description') or '')
                st.write(f"{i}. {description[:100]}...")

    st.markdown("### Voice Settings")
    st.selectbox("TTS Voice:", [
        "en-US-AriaNeural",
        "en-US-GuyNeural",
        "en-GB-SoniaNeural"
    ], key="tts_voice")


if __name__ == "__main__":
    # Entry point when this file is executed as a script (e.g. via `streamlit run`).
    main()