import gradio as gr
import numpy as np
from sentence_transformers import CrossEncoder
from typing import List, Tuple, Dict, Any

MODEL_NAME = "juanludataanalyst/bge-spanish-salon-crossencoder"


def load_model():
    """Load the cross-encoder model, returning None on failure."""
    try:
        return CrossEncoder(MODEL_NAME)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


def interpret_score(score: float) -> str:
    """Map a relevance score to the label buckets used throughout the app."""
    if score >= 0.8:
        return "muy_relevante"
    if score >= 0.6:
        return "moderadamente_relevante"
    if score >= 0.4:
        return "poco_relevante"
    return "no_relevante"


model = load_model()


def predict_similarity(query: str, candidate: str) -> Dict[str, Any]:
    """
    Similarity prediction for a single query/candidate pair (API-friendly).

    Args:
        query: The input query/question
        candidate: The candidate response/answer

    Returns:
        Dictionary with score, interpretation, and success status
    """
    if model is None:
        return {
            "error": "Model was not loaded correctly",
            "score": 0.0,
            "interpretation": "error",
            "success": False
        }

    if not query.strip() or not candidate.strip():
        return {
            "error": "Both query and candidate are required",
            "score": 0.0,
            "interpretation": "invalid_input",
            "success": False
        }

    try:
        # CrossEncoder.predict accepts a single [query, candidate] pair and
        # returns a scalar (or a one-element array, depending on the version).
        score = model.predict([query, candidate])
        score = float(np.asarray(score).reshape(-1)[0])

        return {
            "score": round(score, 4),
            "interpretation": interpret_score(score),
            "success": True,
            "query": query,
            "candidate": candidate
        }
    except Exception as e:
        return {
            "error": str(e),
            "score": 0.0,
            "interpretation": "error",
            "success": False
        }


def batch_predict(query: str, candidates: List[str]) -> Dict[str, Any]:
    """
    Batch prediction over several candidates (API-friendly).

    Args:
        query: The input query/question
        candidates: List of candidate responses

    Returns:
        Dictionary with results ranked by score
    """
    if model is None:
        return {"error": "Model was not loaded correctly", "results": [], "success": False}

    if not query.strip() or not candidates:
        return {"error": "Both query and candidates are required", "results": [], "success": False}

    try:
        # Score every (query, candidate) pair in a single batch.
        pairs = [[query, candidate] for candidate in candidates]
        scores = model.predict(pairs)

        results = []
        for i, (candidate, score) in enumerate(zip(candidates, scores)):
            # Cast numpy scalars to plain floats so the output is JSON-serializable.
            score_val = float(score)
            results.append({
                "candidate": candidate,
                "score": round(score_val, 4),
                "interpretation": interpret_score(score_val),
                "original_index": i
            })

        # Rank candidates from most to least relevant.
        results.sort(key=lambda x: x["score"], reverse=True)

        return {
            "query": query,
            "results": results,
            "success": True,
            "total_candidates": len(candidates)
        }
    except Exception as e:
        return {"error": str(e), "results": [], "success": False}


def predict_similarity_ui(query: str, candidate: str) -> Tuple[float, str]:
    """UI wrapper for single-pair similarity prediction."""
    result = predict_similarity(query, candidate)

    if not result["success"]:
        return 0.0, f"❌ Error: {result['error']}"

    interpretation_map = {
        "muy_relevante": "🟢 Highly relevant",
        "moderadamente_relevante": "🟡 Moderately relevant",
        "poco_relevante": "🟠 Slightly relevant",
        "no_relevante": "🔴 Not relevant"
    }

    return result["score"], interpretation_map.get(result["interpretation"], result["interpretation"])


def batch_predict_ui(query: str, candidates_text: str) -> str:
    """UI wrapper for batch prediction."""
    if not candidates_text.strip():
        return "⚠️ No candidate responses found"

    # One candidate per non-empty line.
    candidates = [c.strip() for c in candidates_text.split('\n') if c.strip()]

    if not candidates:
        return "⚠️ No valid candidate responses found"

    result = batch_predict(query, candidates)

    if not result["success"]:
        return f"❌ Error: {result['error']}"

    emoji_map = {
        "muy_relevante": "🟢",
        "moderadamente_relevante": "🟡",
        "poco_relevante": "🟠",
        "no_relevante": "🔴"
    }

    output = f"**Query:** {result['query']}\n\n**Response ranking ({result['total_candidates']} candidates):**\n\n"

    for i, item in enumerate(result['results'], 1):
        emoji = emoji_map.get(item['interpretation'], "⚪")
        output += f"{i}. {emoji} **Score: {item['score']:.4f}**\n"
        output += f"   {item['candidate']}\n\n"

    return output


# Spanish example data for the demo (the model was trained on Spanish text).
example_queries = [
    "¿Cuál es el horario de la peluquería?",
    "¿Hacen tratamientos capilares?",
    "¿Cuánto cuesta un corte de pelo?",
    "¿Necesito cita previa?",
    "¿Ofrecen descuentos para estudiantes?"
]

example_responses = [
    "Abrimos de lunes a viernes de 9:00 a 18:00 y sábados de 9:00 a 14:00",
    "Sí, ofrecemos diversos tratamientos capilares especializados",
    "El precio de un corte básico es 25 euros",
    "Sí, es necesario reservar cita con antelación",
    "Ofrecemos descuento estudiantil del 10% con credencial válida"
]


with gr.Blocks(title="BGE Spanish Salon Cross-Encoder", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎯 BGE Spanish Salon Cross-Encoder API

    **A fine-tuned model for scoring the relevance between Spanish beauty-salon queries and candidate responses**

    This cross-encoder was trained specifically to match customer questions with appropriate answers
    in the context of beauty salons and hairdressers.

    - **Base Model:** BAAI/bge-reranker-v2-m3
    - **Training data:** 75 Spanish query-response pairs
    - **Use:** Scoring how relevant a candidate response is to a query

    ## 🔗 **API Endpoints**

    Once deployed, this Space automatically exposes REST endpoints for the prediction functions:
    - **Single Prediction:** `/predict_similarity`
    - **Batch Prediction:** `/batch_predict`

    Use the "API" tab on Hugging Face to see example calls.
    """)

    with gr.Tab("🔍 Single Prediction"):
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(
                    label="Customer Query",
                    placeholder="e.g. ¿Cuál es el horario de la peluquería?",
                    lines=2
                )
                candidate_input = gr.Textbox(
                    label="Candidate Response",
                    placeholder="e.g. Abrimos de lunes a viernes de 9:00 a 18:00",
                    lines=2
                )
                predict_btn = gr.Button("🎯 Predict Relevance", variant="primary")

            with gr.Column():
                score_output = gr.Number(
                    label="Similarity Score (0-1)",
                    precision=4
                )
                interpretation_output = gr.Textbox(
                    label="Interpretation",
                    lines=1
                )

        gr.Examples(
            examples=[
                [example_queries[0], example_responses[0]],
                [example_queries[1], example_responses[1]],
                [example_queries[2], example_responses[2]]
            ],
            inputs=[query_input, candidate_input]
        )
with gr.Tab("📊 Ranking Multiple"): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
batch_query_input = gr.Textbox( |
|
|
label="Consulta del Cliente", |
|
|
placeholder="Ej: ¿Hacen tratamientos capilares?", |
|
|
lines=2 |
|
|
) |
|
|
batch_candidates_input = gr.Textbox( |
|
|
label="Respuestas Candidatas (una por línea)", |
|
|
placeholder="Sí, ofrecemos tratamientos especializados\nNo, solo hacemos cortes básicos\nTenemos keratina y botox capilar disponible", |
|
|
lines=8 |
|
|
) |
|
|
batch_predict_btn = gr.Button("📊 Ranking de Respuestas", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
batch_output = gr.Markdown( |
|
|
label="Resultados Ordenados" |
|
|
) |
|
|
|
|
|
|
|
|
batch_example_candidates = "\n".join([ |
|
|
"Sí, ofrecemos diversos tratamientos capilares especializados", |
|
|
"No, solo realizamos cortes y peinados básicos", |
|
|
"Tenemos keratina, botox capilar y tratamientos de hidratación", |
|
|
"Consulta nuestro catálogo de servicios en recepción" |
|
|
]) |
|
|
|
|
|
gr.Examples( |
|
|
examples=[["¿Hacen tratamientos capilares?", batch_example_candidates]], |
|
|
inputs=[batch_query_input, batch_candidates_input] |
|
|
) |
|
|
|
|
|
with gr.Tab("🔗 Uso de API"): |
|
|
gr.Markdown(""" |
|
|
### 📡 Llamadas a la API REST |
|
|
|
|
|
Una vez desplegado, podrás hacer llamadas HTTP a los endpoints: |
|
|
|
|
|
#### **Predicción Individual** |
|
|
```bash |
|
|
curl -X POST "https://tu-usuario-nombre-space.hf.space/api/predict" \\ |
|
|
-H "Content-Type: application/json" \\ |
|
|
-d '{ |
|
|
"data": ["¿Cuál es el horario?", "Abrimos de 9 a 18h"] |
|
|
}' |
|
|
``` |
|
|
|
|
|
#### **Predicción Batch** |
|
|
```bash |
|
|
curl -X POST "https://tu-usuario-nombre-space.hf.space/api/predict" \\ |
|
|
-H "Content-Type: application/json" \\ |
|
|
-d '{ |
|
|
"data": ["¿Hacen tratamientos?", ["Sí, ofrecemos keratina", "No, solo cortes", "Tratamientos disponibles"]] |
|
|
}' |
|
|
``` |
|
|
|
|
|
#### **Respuesta JSON** |
|
|
```json |
|
|
{ |
|
|
"data": { |
|
|
"score": 0.8532, |
|
|
"interpretation": "muy_relevante", |
|
|
"success": true |
|
|
} |
|
|
} |
|
|
``` |
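
        #### **Batch JSON Response**
        The batch endpoint returns ranked results whose keys mirror `batch_predict` above; the scores shown here are illustrative, not real model output:
        ```json
        {
          "data": {
            "query": "¿Hacen tratamientos?",
            "results": [
              {"candidate": "Sí, ofrecemos keratina", "score": 0.91, "interpretation": "muy_relevante", "original_index": 0},
              {"candidate": "No, solo cortes", "score": 0.12, "interpretation": "no_relevante", "original_index": 1}
            ],
            "success": true,
            "total_candidates": 2
          }
        }
        ```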

        ### 🐍 Python Example
        ```python
        import requests

        url = "https://tu-usuario-nombre-space.hf.space/api/predict"

        # Single prediction
        response = requests.post(url, json={
            "data": ["¿Cuánto cuesta un corte?", "El precio es 25 euros"]
        })

        result = response.json()
        print(f"Score: {result['data']['score']}")
        ```
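
        ### 🤝 gradio_client Example
        Alternatively, use Gradio's official Python client. This is a minimal sketch: the Space id is a placeholder, and the `api_name` is the Gradio default derived from the function name; check the Space's "Use via API" page for the exact endpoint names.
        ```python
        from gradio_client import Client

        # Placeholder Space id; replace with the real one.
        client = Client("tu-usuario/nombre-space")

        # Assumed default endpoint name (from predict_similarity_ui); verify it.
        score, interpretation = client.predict(
            "¿Cuál es el horario?",
            "Abrimos de 9 a 18h",
            api_name="/predict_similarity_ui",
        )
        print(score, interpretation)
        ```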
        """)
with gr.Tab("ℹ️ Información del Modelo"): |
|
|
gr.Markdown(""" |
|
|
### 📋 Detalles Técnicos |
|
|
|
|
|
- **Arquitectura:** XLMRobertaForSequenceClassification |
|
|
- **Parámetros:** ~560M |
|
|
- **Secuencia Máxima:** 512 tokens |
|
|
- **Framework:** sentence-transformers 5.1.0 |
|
|
- **Activación:** Sigmoid |
|
|
- **Pérdida:** BinaryCrossEntropyLoss |
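
        ### 🧪 Direct Use with sentence-transformers
        The model can also be used outside this Space. A minimal sketch with the standard `CrossEncoder` API (the printed scores are illustrative, not real output):
        ```python
        from sentence_transformers import CrossEncoder

        # Downloads the fine-tuned reranker from the Hugging Face Hub.
        model = CrossEncoder("juanludataanalyst/bge-spanish-salon-crossencoder")

        # Score (query, candidate) pairs; higher means more relevant.
        scores = model.predict([
            ["¿Necesito cita previa?", "Sí, es necesario reservar cita con antelación"],
            ["¿Necesito cita previa?", "El precio de un corte básico es 25 euros"],
        ])
        print(scores)  # e.g. [0.93, 0.07] -- illustrative values
        ```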

        ### 🎯 Use Cases

        1. **FAQ systems:** Find the most relevant answer to a question
        2. **Chatbots:** Evaluate the quality of generated responses
        3. **Semantic search:** Rank documents by relevance
        4. **Quality control:** Check coherence between questions and answers

        ### 📊 Score Interpretation

        - **0.8 - 1.0:** 🟢 Highly relevant - Highly appropriate response
        - **0.6 - 0.8:** 🟡 Moderately relevant - Partially appropriate response
        - **0.4 - 0.6:** 🟠 Slightly relevant - Tangentially related response
        - **0.0 - 0.4:** 🔴 Not relevant - Inappropriate response

        ### 🔗 Links

        - [Model on Hugging Face](https://huggingface.co/juanludataanalyst/bge-spanish-salon-crossencoder)
        - [BGE Documentation](https://huggingface.co/BAAI/bge-reranker-v2-m3)
        - [Sentence Transformers](https://www.sbert.net/docs/usage/cross-encoder.html)
        """)

    # Wire up the UI events.
    predict_btn.click(
        fn=predict_similarity_ui,
        inputs=[query_input, candidate_input],
        outputs=[score_output, interpretation_output]
    )

    batch_predict_btn.click(
        fn=batch_predict_ui,
        inputs=[batch_query_input, batch_candidates_input],
        outputs=[batch_output]
    )


if __name__ == "__main__":
    demo.launch()