import gradio as gr import torch from sentence_transformers import CrossEncoder import numpy as np from typing import List, Tuple, Dict, Any # Load the model MODEL_NAME = "juanludataanalyst/bge-spanish-salon-crossencoder" def load_model(): """Load the cross-encoder model""" try: model = CrossEncoder(MODEL_NAME) return model except Exception as e: print(f"Error loading model: {e}") return None # Load model once at startup model = load_model() def predict_similarity(query: str, candidate: str) -> Dict[str, Any]: """ API function for similarity prediction - optimized for API calls Args: query: The input query/question candidate: The candidate response/answer Returns: Dictionary with score, interpretation and success status """ if not model: return { "error": "Modelo no cargado correctamente", "score": 0.0, "interpretation": "error", "success": False } if not query.strip() or not candidate.strip(): return { "error": "Query y candidate son requeridos", "score": 0.0, "interpretation": "invalid_input", "success": False } try: # Predict similarity score score = model.predict([query, candidate]) # Convert to float if it's a numpy array if isinstance(score, np.ndarray): score = float(score[0]) else: score = float(score) # Interpret the score if score >= 0.8: interpretation = "muy_relevante" elif score >= 0.6: interpretation = "moderadamente_relevante" elif score >= 0.4: interpretation = "poco_relevante" else: interpretation = "no_relevante" return { "score": round(score, 4), "interpretation": interpretation, "success": True, "query": query, "candidate": candidate } except Exception as e: return { "error": str(e), "score": 0.0, "interpretation": "error", "success": False } def batch_predict(query: str, candidates: List[str]) -> Dict[str, Any]: """ API function for batch prediction - optimized for API calls Args: query: The input query/question candidates: List of candidate responses Returns: Dictionary with ranked results """ if not model: return {"error": "Modelo no cargado correctamente", "results": [], "success": False} if not query.strip() or not candidates: return {"error": "Query y candidates son requeridos", "results": [], "success": False} try: # Create pairs for batch prediction pairs = [[query, candidate] for candidate in candidates] scores = model.predict(pairs) # Create results with scores results = [] for i, (candidate, score) in enumerate(zip(candidates, scores)): score_val = float(score) if isinstance(score, np.ndarray) else score # Determine interpretation if score_val >= 0.8: interpretation = "muy_relevante" elif score_val >= 0.6: interpretation = "moderadamente_relevante" elif score_val >= 0.4: interpretation = "poco_relevante" else: interpretation = "no_relevante" results.append({ "candidate": candidate, "score": round(score_val, 4), "interpretation": interpretation, "original_index": i }) # Sort by score (descending) results.sort(key=lambda x: x["score"], reverse=True) return { "query": query, "results": results, "success": True, "total_candidates": len(candidates) } except Exception as e: return {"error": str(e), "results": [], "success": False} # UI wrapper functions def predict_similarity_ui(query: str, candidate: str) -> Tuple[float, str]: """UI wrapper for the similarity prediction""" result = predict_similarity(query, candidate) if not result["success"]: return 0.0, f"❌ Error: {result['error']}" # Add emoji for UI interpretation_map = { "muy_relevante": "🟢 Muy relevante", "moderadamente_relevante": "🟡 Moderadamente relevante", "poco_relevante": "🟠 Poco relevante", "no_relevante": "🔴 No relevante" } return result["score"], interpretation_map.get(result["interpretation"], result["interpretation"]) def batch_predict_ui(query: str, candidates_text: str) -> str: """UI wrapper for batch prediction""" if not candidates_text.strip(): return "⚠️ No se encontraron respuestas candidatas" # Split candidates by newlines and filter empty ones candidates = [c.strip() for c in candidates_text.split('\n') if c.strip()] if not candidates: return "⚠️ No se encontraron respuestas candidatas válidas" # Call the API function result = batch_predict(query, candidates) if not result["success"]: return f"❌ Error: {result['error']}" # Format results for UI output = f"**Consulta:** {result['query']}\n\n**Ranking de respuestas ({result['total_candidates']} candidatos):**\n\n" for i, item in enumerate(result['results'], 1): # Add emoji based on interpretation emoji_map = { "muy_relevante": "🟢", "moderadamente_relevante": "🟡", "poco_relevante": "🟠", "no_relevante": "🔴" } emoji = emoji_map.get(item['interpretation'], "⚪") output += f"{i}. {emoji} **Score: {item['score']:.4f}**\n" output += f" {item['candidate']}\n\n" return output # Example data for the interface example_queries = [ "¿Cuál es el horario de la peluquería?", "¿Hacen tratamientos capilares?", "¿Cuánto cuesta un corte de pelo?", "¿Necesito cita previa?", "¿Ofrecen descuentos para estudiantes?" ] example_responses = [ "Abrimos de lunes a viernes de 9:00 a 18:00 y sábados de 9:00 a 14:00", "Sí, ofrecemos diversos tratamientos capilares especializados", "El precio de un corte básico es 25 euros", "Sí, es necesario reservar cita con antelación", "Ofrecemos descuento estudiantil del 10% con credencial válida" ] # Create the Gradio interface with gr.Blocks(title="BGE Spanish Salon Cross-Encoder", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🎯 BGE Spanish Salon Cross-Encoder API **Modelo Fine-tuneado para evaluar la relevancia entre consultas y respuestas de salones de belleza en español** Este cross-encoder fue entrenado específicamente para emparejar preguntas de clientes con respuestas apropiadas en el contexto de salones de belleza y peluquerías. - **Base Model:** BAAI/bge-reranker-v2-m3 - **Entrenado en:** 75 pares de consulta-respuesta en español - **Uso:** Determinar la relevancia entre una consulta y una respuesta candidata ## 🔗 **API Endpoints** Una vez desplegado, este Space generará automáticamente endpoints REST: - **Single Prediction:** `/predict_similarity` - **Batch Prediction:** `/batch_predict` Usa la pestaña "API" en Hugging Face para ver ejemplos de llamadas. """) with gr.Tab("🔍 Predicción Individual"): with gr.Row(): with gr.Column(): query_input = gr.Textbox( label="Consulta del Cliente", placeholder="Ej: ¿Cuál es el horario de la peluquería?", lines=2 ) candidate_input = gr.Textbox( label="Respuesta Candidata", placeholder="Ej: Abrimos de lunes a viernes de 9:00 a 18:00", lines=2 ) predict_btn = gr.Button("🎯 Predecir Relevancia", variant="primary") with gr.Column(): score_output = gr.Number( label="Score de Similitud (0-1)", precision=4 ) interpretation_output = gr.Textbox( label="Interpretación", lines=1 ) # Examples for individual prediction gr.Examples( examples=[ [example_queries[0], example_responses[0]], [example_queries[1], example_responses[1]], [example_queries[2], example_responses[2]] ], inputs=[query_input, candidate_input] ) with gr.Tab("📊 Ranking Multiple"): with gr.Row(): with gr.Column(): batch_query_input = gr.Textbox( label="Consulta del Cliente", placeholder="Ej: ¿Hacen tratamientos capilares?", lines=2 ) batch_candidates_input = gr.Textbox( label="Respuestas Candidatas (una por línea)", placeholder="Sí, ofrecemos tratamientos especializados\nNo, solo hacemos cortes básicos\nTenemos keratina y botox capilar disponible", lines=8 ) batch_predict_btn = gr.Button("📊 Ranking de Respuestas", variant="primary") with gr.Column(): batch_output = gr.Markdown( label="Resultados Ordenados" ) # Example for batch prediction batch_example_candidates = "\n".join([ "Sí, ofrecemos diversos tratamientos capilares especializados", "No, solo realizamos cortes y peinados básicos", "Tenemos keratina, botox capilar y tratamientos de hidratación", "Consulta nuestro catálogo de servicios en recepción" ]) gr.Examples( examples=[["¿Hacen tratamientos capilares?", batch_example_candidates]], inputs=[batch_query_input, batch_candidates_input] ) with gr.Tab("🔗 Uso de API"): gr.Markdown(""" ### 📡 Llamadas a la API REST Una vez desplegado, podrás hacer llamadas HTTP a los endpoints: #### **Predicción Individual** ```bash curl -X POST "https://tu-usuario-nombre-space.hf.space/api/predict" \\ -H "Content-Type: application/json" \\ -d '{ "data": ["¿Cuál es el horario?", "Abrimos de 9 a 18h"] }' ``` #### **Predicción Batch** ```bash curl -X POST "https://tu-usuario-nombre-space.hf.space/api/predict" \\ -H "Content-Type: application/json" \\ -d '{ "data": ["¿Hacen tratamientos?", ["Sí, ofrecemos keratina", "No, solo cortes", "Tratamientos disponibles"]] }' ``` #### **Respuesta JSON** ```json { "data": { "score": 0.8532, "interpretation": "muy_relevante", "success": true } } ``` ### 🐍 Ejemplo en Python ```python import requests url = "https://tu-usuario-nombre-space.hf.space/api/predict" # Predicción individual response = requests.post(url, json={ "data": ["¿Cuánto cuesta un corte?", "El precio es 25 euros"] }) result = response.json() print(f"Score: {result['data']['score']}") ``` """) with gr.Tab("ℹ️ Información del Modelo"): gr.Markdown(""" ### 📋 Detalles Técnicos - **Arquitectura:** XLMRobertaForSequenceClassification - **Parámetros:** ~560M - **Secuencia Máxima:** 512 tokens - **Framework:** sentence-transformers 5.1.0 - **Activación:** Sigmoid - **Pérdida:** BinaryCrossEntropyLoss ### 🎯 Casos de Uso 1. **Sistemas de FAQ:** Encontrar la respuesta más relevante para una pregunta 2. **Chatbots:** Evaluar la calidad de respuestas generadas 3. **Búsqueda Semántica:** Rankear documentos por relevancia 4. **Control de Calidad:** Verificar coherencia entre preguntas y respuestas ### 📊 Interpretación de Scores - **0.8 - 1.0:** 🟢 Muy relevante - Respuesta altamente apropiada - **0.6 - 0.8:** 🟡 Moderadamente relevante - Respuesta parcialmente apropiada - **0.4 - 0.6:** 🟠 Poco relevante - Respuesta tangencialmente relacionada - **0.0 - 0.4:** 🔴 No relevante - Respuesta no apropiada ### 🔗 Enlaces - [Modelo en Hugging Face](https://huggingface.co/juanludataanalyst/bge-spanish-salon-crossencoder) - [Documentación BGE](https://huggingface.co/BAAI/bge-reranker-v2-m3) - [Sentence Transformers](https://www.sbert.net/docs/usage/cross-encoder.html) """) # Connect the buttons to functions predict_btn.click( fn=predict_similarity_ui, inputs=[query_input, candidate_input], outputs=[score_output, interpretation_output] ) batch_predict_btn.click( fn=batch_predict_ui, inputs=[batch_query_input, batch_candidates_input], outputs=[batch_output] ) # Launch the app if __name__ == "__main__": demo.launch()