juanludataanalyst's picture
Update app.py
3b78ebf verified
import gradio as gr
import torch
from sentence_transformers import CrossEncoder
import numpy as np
from typing import List, Tuple, Dict, Any
# Load the model
MODEL_NAME = "juanludataanalyst/bge-spanish-salon-crossencoder"
def load_model():
"""Load the cross-encoder model"""
try:
model = CrossEncoder(MODEL_NAME)
return model
except Exception as e:
print(f"Error loading model: {e}")
return None
# Load model once at startup
model = load_model()
def predict_similarity(query: str, candidate: str) -> Dict[str, Any]:
"""
API function for similarity prediction - optimized for API calls
Args:
query: The input query/question
candidate: The candidate response/answer
Returns:
Dictionary with score, interpretation and success status
"""
if not model:
return {
"error": "Modelo no cargado correctamente",
"score": 0.0,
"interpretation": "error",
"success": False
}
if not query.strip() or not candidate.strip():
return {
"error": "Query y candidate son requeridos",
"score": 0.0,
"interpretation": "invalid_input",
"success": False
}
try:
# Predict similarity score
score = model.predict([query, candidate])
# Convert to float if it's a numpy array
if isinstance(score, np.ndarray):
score = float(score[0])
else:
score = float(score)
# Interpret the score
if score >= 0.8:
interpretation = "muy_relevante"
elif score >= 0.6:
interpretation = "moderadamente_relevante"
elif score >= 0.4:
interpretation = "poco_relevante"
else:
interpretation = "no_relevante"
return {
"score": round(score, 4),
"interpretation": interpretation,
"success": True,
"query": query,
"candidate": candidate
}
except Exception as e:
return {
"error": str(e),
"score": 0.0,
"interpretation": "error",
"success": False
}
def batch_predict(query: str, candidates: List[str]) -> Dict[str, Any]:
"""
API function for batch prediction - optimized for API calls
Args:
query: The input query/question
candidates: List of candidate responses
Returns:
Dictionary with ranked results
"""
if not model:
return {"error": "Modelo no cargado correctamente", "results": [], "success": False}
if not query.strip() or not candidates:
return {"error": "Query y candidates son requeridos", "results": [], "success": False}
try:
# Create pairs for batch prediction
pairs = [[query, candidate] for candidate in candidates]
scores = model.predict(pairs)
# Create results with scores
results = []
for i, (candidate, score) in enumerate(zip(candidates, scores)):
score_val = float(score) if isinstance(score, np.ndarray) else score
# Determine interpretation
if score_val >= 0.8:
interpretation = "muy_relevante"
elif score_val >= 0.6:
interpretation = "moderadamente_relevante"
elif score_val >= 0.4:
interpretation = "poco_relevante"
else:
interpretation = "no_relevante"
results.append({
"candidate": candidate,
"score": round(score_val, 4),
"interpretation": interpretation,
"original_index": i
})
# Sort by score (descending)
results.sort(key=lambda x: x["score"], reverse=True)
return {
"query": query,
"results": results,
"success": True,
"total_candidates": len(candidates)
}
except Exception as e:
return {"error": str(e), "results": [], "success": False}
# UI wrapper functions
def predict_similarity_ui(query: str, candidate: str) -> Tuple[float, str]:
"""UI wrapper for the similarity prediction"""
result = predict_similarity(query, candidate)
if not result["success"]:
return 0.0, f"❌ Error: {result['error']}"
# Add emoji for UI
interpretation_map = {
"muy_relevante": "🟢 Muy relevante",
"moderadamente_relevante": "🟡 Moderadamente relevante",
"poco_relevante": "🟠 Poco relevante",
"no_relevante": "🔴 No relevante"
}
return result["score"], interpretation_map.get(result["interpretation"], result["interpretation"])
def batch_predict_ui(query: str, candidates_text: str) -> str:
"""UI wrapper for batch prediction"""
if not candidates_text.strip():
return "⚠️ No se encontraron respuestas candidatas"
# Split candidates by newlines and filter empty ones
candidates = [c.strip() for c in candidates_text.split('\n') if c.strip()]
if not candidates:
return "⚠️ No se encontraron respuestas candidatas válidas"
# Call the API function
result = batch_predict(query, candidates)
if not result["success"]:
return f"❌ Error: {result['error']}"
# Format results for UI
output = f"**Consulta:** {result['query']}\n\n**Ranking de respuestas ({result['total_candidates']} candidatos):**\n\n"
for i, item in enumerate(result['results'], 1):
# Add emoji based on interpretation
emoji_map = {
"muy_relevante": "🟢",
"moderadamente_relevante": "🟡",
"poco_relevante": "🟠",
"no_relevante": "🔴"
}
emoji = emoji_map.get(item['interpretation'], "⚪")
output += f"{i}. {emoji} **Score: {item['score']:.4f}**\n"
output += f" {item['candidate']}\n\n"
return output
# Example data for the interface
example_queries = [
"¿Cuál es el horario de la peluquería?",
"¿Hacen tratamientos capilares?",
"¿Cuánto cuesta un corte de pelo?",
"¿Necesito cita previa?",
"¿Ofrecen descuentos para estudiantes?"
]
example_responses = [
"Abrimos de lunes a viernes de 9:00 a 18:00 y sábados de 9:00 a 14:00",
"Sí, ofrecemos diversos tratamientos capilares especializados",
"El precio de un corte básico es 25 euros",
"Sí, es necesario reservar cita con antelación",
"Ofrecemos descuento estudiantil del 10% con credencial válida"
]
# Create the Gradio interface
with gr.Blocks(title="BGE Spanish Salon Cross-Encoder", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎯 BGE Spanish Salon Cross-Encoder API
**Modelo Fine-tuneado para evaluar la relevancia entre consultas y respuestas de salones de belleza en español**
Este cross-encoder fue entrenado específicamente para emparejar preguntas de clientes con respuestas apropiadas
en el contexto de salones de belleza y peluquerías.
- **Base Model:** BAAI/bge-reranker-v2-m3
- **Entrenado en:** 75 pares de consulta-respuesta en español
- **Uso:** Determinar la relevancia entre una consulta y una respuesta candidata
## 🔗 **API Endpoints**
Una vez desplegado, este Space generará automáticamente endpoints REST:
- **Single Prediction:** `/predict_similarity`
- **Batch Prediction:** `/batch_predict`
Usa la pestaña "API" en Hugging Face para ver ejemplos de llamadas.
""")
with gr.Tab("🔍 Predicción Individual"):
with gr.Row():
with gr.Column():
query_input = gr.Textbox(
label="Consulta del Cliente",
placeholder="Ej: ¿Cuál es el horario de la peluquería?",
lines=2
)
candidate_input = gr.Textbox(
label="Respuesta Candidata",
placeholder="Ej: Abrimos de lunes a viernes de 9:00 a 18:00",
lines=2
)
predict_btn = gr.Button("🎯 Predecir Relevancia", variant="primary")
with gr.Column():
score_output = gr.Number(
label="Score de Similitud (0-1)",
precision=4
)
interpretation_output = gr.Textbox(
label="Interpretación",
lines=1
)
# Examples for individual prediction
gr.Examples(
examples=[
[example_queries[0], example_responses[0]],
[example_queries[1], example_responses[1]],
[example_queries[2], example_responses[2]]
],
inputs=[query_input, candidate_input]
)
with gr.Tab("📊 Ranking Multiple"):
with gr.Row():
with gr.Column():
batch_query_input = gr.Textbox(
label="Consulta del Cliente",
placeholder="Ej: ¿Hacen tratamientos capilares?",
lines=2
)
batch_candidates_input = gr.Textbox(
label="Respuestas Candidatas (una por línea)",
placeholder="Sí, ofrecemos tratamientos especializados\nNo, solo hacemos cortes básicos\nTenemos keratina y botox capilar disponible",
lines=8
)
batch_predict_btn = gr.Button("📊 Ranking de Respuestas", variant="primary")
with gr.Column():
batch_output = gr.Markdown(
label="Resultados Ordenados"
)
# Example for batch prediction
batch_example_candidates = "\n".join([
"Sí, ofrecemos diversos tratamientos capilares especializados",
"No, solo realizamos cortes y peinados básicos",
"Tenemos keratina, botox capilar y tratamientos de hidratación",
"Consulta nuestro catálogo de servicios en recepción"
])
gr.Examples(
examples=[["¿Hacen tratamientos capilares?", batch_example_candidates]],
inputs=[batch_query_input, batch_candidates_input]
)
with gr.Tab("🔗 Uso de API"):
gr.Markdown("""
### 📡 Llamadas a la API REST
Una vez desplegado, podrás hacer llamadas HTTP a los endpoints:
#### **Predicción Individual**
```bash
curl -X POST "https://tu-usuario-nombre-space.hf.space/api/predict" \\
-H "Content-Type: application/json" \\
-d '{
"data": ["¿Cuál es el horario?", "Abrimos de 9 a 18h"]
}'
```
#### **Predicción Batch**
```bash
curl -X POST "https://tu-usuario-nombre-space.hf.space/api/predict" \\
-H "Content-Type: application/json" \\
-d '{
"data": ["¿Hacen tratamientos?", ["Sí, ofrecemos keratina", "No, solo cortes", "Tratamientos disponibles"]]
}'
```
#### **Respuesta JSON**
```json
{
"data": {
"score": 0.8532,
"interpretation": "muy_relevante",
"success": true
}
}
```
### 🐍 Ejemplo en Python
```python
import requests
url = "https://tu-usuario-nombre-space.hf.space/api/predict"
# Predicción individual
response = requests.post(url, json={
"data": ["¿Cuánto cuesta un corte?", "El precio es 25 euros"]
})
result = response.json()
print(f"Score: {result['data']['score']}")
```
""")
with gr.Tab("ℹ️ Información del Modelo"):
gr.Markdown("""
### 📋 Detalles Técnicos
- **Arquitectura:** XLMRobertaForSequenceClassification
- **Parámetros:** ~560M
- **Secuencia Máxima:** 512 tokens
- **Framework:** sentence-transformers 5.1.0
- **Activación:** Sigmoid
- **Pérdida:** BinaryCrossEntropyLoss
### 🎯 Casos de Uso
1. **Sistemas de FAQ:** Encontrar la respuesta más relevante para una pregunta
2. **Chatbots:** Evaluar la calidad de respuestas generadas
3. **Búsqueda Semántica:** Rankear documentos por relevancia
4. **Control de Calidad:** Verificar coherencia entre preguntas y respuestas
### 📊 Interpretación de Scores
- **0.8 - 1.0:** 🟢 Muy relevante - Respuesta altamente apropiada
- **0.6 - 0.8:** 🟡 Moderadamente relevante - Respuesta parcialmente apropiada
- **0.4 - 0.6:** 🟠 Poco relevante - Respuesta tangencialmente relacionada
- **0.0 - 0.4:** 🔴 No relevante - Respuesta no apropiada
### 🔗 Enlaces
- [Modelo en Hugging Face](https://huggingface.co/juanludataanalyst/bge-spanish-salon-crossencoder)
- [Documentación BGE](https://huggingface.co/BAAI/bge-reranker-v2-m3)
- [Sentence Transformers](https://www.sbert.net/docs/usage/cross-encoder.html)
""")
# Connect the buttons to functions
predict_btn.click(
fn=predict_similarity_ui,
inputs=[query_input, candidate_input],
outputs=[score_output, interpretation_output]
)
batch_predict_btn.click(
fn=batch_predict_ui,
inputs=[batch_query_input, batch_candidates_input],
outputs=[batch_output]
)
# Launch the app
if __name__ == "__main__":
demo.launch()