|
|
|
|
|
|
|
|
""" |
|
|
DiMa_new — Tiny Gradio demo |
|
|
- Input: English or Russian sentence |
|
|
- English input is translated to Russian; Russian input is used as-is (and translated to English for display) |
|
|
- Detect candidates from gazetteer |
|
|
- Classify with MariaOls/DiMa_new |
|
|
- Output: ONLY candidates considered DM (or 'no DMs found') |
|
|
""" |
|
|
|
|
|
import json |
|
|
import re |
|
|
from typing import List, Tuple, Dict, Optional |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from huggingface_hub import hf_hub_download |
|
|
from transformers import ( |
|
|
AutoTokenizer, AutoModelForSequenceClassification, pipeline |
|
|
) |
|
|
import re |
|
|
from gradio.themes.utils import colors, sizes |
|
|
import random |
|
|
|
|
|
# Warm red/orange look for the demo, built on top of Gradio's Soft theme.
THEME = gr.themes.Soft(
    primary_hue=colors.red,
    secondary_hue=colors.orange,
    neutral_hue=colors.gray,
    radius_size=sizes.radius_xxl,
)

# Individual theme variables overridden on top of the Soft defaults,
# grouped by the part of the UI they affect.
_theme_overrides = {
    # page and block chrome
    "body_background_fill": "#FFF7F2",
    "block_background_fill": "#FFFFFF",
    "block_border_color": "#FFD6C2",
    "block_border_width": "1px",
    "block_shadow": "0 10px 30px rgba(255, 107, 53, 0.10)",
    # text inputs
    "input_background_fill": "#FFFDFC",
    "input_border_color": "#FFC7B3",
    # primary buttons ("*primary_500" style values refer to theme palette entries)
    "button_primary_background_fill": "*primary_500",
    "button_primary_background_fill_hover": "*primary_600",
    "button_primary_text_color": "#FFFFFF",
}
THEME.set(**_theme_overrides)
|
|
|
|
|
# Matches a single letter of the Russian alphabet (both cases, including Ё/ё).
CYRILLIC_RE = re.compile(r"[А-Яа-яЁё]")


def is_russian(text: str) -> bool:
    """Report whether *text* contains at least one Russian Cyrillic letter.

    A single matching character anywhere in the string is enough; empty or
    otherwise falsy input (including ``None``) yields ``False``.
    """
    if not text:
        return False
    return CYRILLIC_RE.search(text) is not None
|
|
|
|
|
# Hugging Face Hub repository used for the classifier (and, elsewhere in this
# file, for the gazetteer download).
MODEL_ID = "MariaOls/DiMa_new"
# Minimum class-1 probability for a candidate to be reported as a DM.
THRESHOLD = 0.5

# Device index handed to the translation pipelines: 0 when CUDA is available,
# otherwise -1.
_PIPELINE_DEVICE = 0 if torch.cuda.is_available() else -1

# EN -> RU translator (used when the input contains no Cyrillic).
translator = pipeline(
    task="translation_en_to_ru",
    model="Helsinki-NLP/opus-mt-en-ru",
    device=_PIPELINE_DEVICE,
)

# RU -> EN translator (used to show an English rendering of Russian input).
translator_ru_en = pipeline(
    task="translation_ru_to_en",
    model="Helsinki-NLP/opus-mt-ru-en",
    device=_PIPELINE_DEVICE,
)

# Sequence classifier and its tokenizer; the model is switched to eval mode
# once here because it is only ever used for inference in this script.
clf_tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
clf_mdl = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
clf_mdl.eval()
|
|
|
|
|
|
|
|
def load_gazetteer(repo_id: str) -> List[str]:
    """Download and normalise the candidate gazetteer from a Hub repository.

    The file ``assets/gazetteer.json`` is fetched from *repo_id* and its
    ``"items"`` list is read (a missing key is treated as an empty list).

    Args:
        repo_id: Hugging Face Hub repository identifier.

    Returns:
        The distinct, non-blank string entries ordered longest first, with
        ties broken alphabetically. Non-string entries are dropped.
    """
    p = hf_hub_download(repo_id=repo_id, filename="assets/gazetteer.json")
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(p, "r", encoding="utf-8") as fh:
        obj = json.load(fh)
    items = obj.get("items", [])

    unique_entries = {s for s in items if isinstance(s, str) and s.strip()}
    # Longest-first order lets the matcher claim multi-word entries before
    # any shorter entry contained in them.
    return sorted(unique_entries, key=lambda s: (-len(s), s))


# Loaded once at import time from the same repository as the classifier.
GAZ = load_gazetteer(MODEL_ID)
|
|
|
|
|
|
|
|
def split_sentences(text: str) -> List[str]:
    """Split *text* into trimmed, non-empty sentences.

    ``razdel.sentenize`` is tried first. If importing or running it raises
    any exception, the text is instead split on whitespace that follows one
    of ``.``, ``!``, ``?`` or ``…``.
    """
    try:
        from razdel import sentenize

        trimmed = [segment.text.strip() for segment in sentenize(text)]
    except Exception:
        # Best-effort fallback: simple punctuation-based splitting.
        trimmed = [
            chunk.strip()
            for chunk in re.split(r'(?<=[\.!\?…])\s+', text.strip())
        ]
    return [sentence for sentence in trimmed if sentence]
|
|
|
|
|
|
|
|
_RUS_PUNCT = set(list(" \t\r\n.,;:!?…()[]{}«»\"'“”„—-")) |
|
|
|
|
|
def _is_boundary(ch: Optional[str]) -> bool: |
|
|
return ch is None or ch in _RUS_PUNCT |
|
|
|
|
|
def detect_candidates_ci(text: str, gazetteer: List[str]) -> List[Tuple[int,int,str]]: |
|
|
""" |
|
|
Longest-first, no overlap, case-insensitive. |
|
|
Returns [(start, end, original_span), ...] in original text indices. |
|
|
""" |
|
|
low = text.lower() |
|
|
used = [False] * len(text) |
|
|
spans: List[Tuple[int,int,str]] = [] |
|
|
|
|
|
for cand in gazetteer: |
|
|
clow = cand.lower() |
|
|
start = 0 |
|
|
while True: |
|
|
i = low.find(clow, start) |
|
|
if i == -1: |
|
|
break |
|
|
j = i + len(clow) |
|
|
left_ch = low[i-1] if i-1 >= 0 else None |
|
|
right_ch = low[j] if j < len(low) else None |
|
|
if _is_boundary(left_ch) and _is_boundary(right_ch) and not any(used[i:j]): |
|
|
spans.append((i, j, text[i:j])) |
|
|
for k in range(i, j): |
|
|
used[k] = True |
|
|
start = j |
|
|
else: |
|
|
start = i + 1 |
|
|
spans.sort(key=lambda x: x[0]) |
|
|
return spans |
|
|
|
|
|
|
|
|
def mark_span(sentence: str, start: int, end: int) -> str:
    """Wrap ``sentence[start:end]`` in ``<cand> ... </cand>`` markers.

    The text before and after the span is kept unchanged; one space is
    inserted between each marker and the span itself.
    """
    before = sentence[:start]
    target = sentence[start:end]
    after = sentence[end:]
    return "".join((before, "<cand> ", target, " </cand>", after))
|
|
|
|
|
@torch.no_grad()
def classify_marked_batch(marked_texts: List[str]) -> List[float]:
    """Score every marked sentence with the classifier in one batch.

    Args:
        marked_texts: Sentences in which one candidate is wrapped in
            ``<cand> ... </cand>`` markers.

    Returns:
        One probability per input, in the same order, taken from class
        index 1 of the softmax output (presumably the "DM" label; verify
        against the model's label mapping). An empty input returns ``[]``
        without touching the model.
    """
    if not marked_texts:
        return []

    # Tokenise the whole batch at once, padded to a common length and
    # truncated to the tokenizer's limit.
    batch = clf_tok(marked_texts, return_tensors="pt", truncation=True, padding=True)
    logits = clf_mdl(**batch).logits
    class_one = logits.softmax(-1)[:, 1]
    return [float(score) for score in class_one.tolist()]
|
|
|
|
|
|
|
|
def run_pipeline(user_text: str) -> tuple[str, str, str, str]:
    """Run translation, candidate detection and classification on one input.

    Accepts English or Russian:
    - If Cyrillic is detected, the text is used as-is (Russian) and is also
      translated to English for display.
    - Otherwise the text is assumed to be English, translated to Russian,
      and classified in Russian.

    Returns:
        pretty: only the candidates classified as DMs, or 'no DMs found'
            ('no input' when the input is empty or blank),
        ru_text: the Russian text used for classification / display,
        en_text: the English translation or the original English text,
        info: a short debug summary.
    """
    if not user_text or not user_text.strip():
        return "no input", "", "", ""

    cleaned = user_text.strip()
    if is_russian(user_text):
        ru_text = cleaned
        en_text = translator_ru_en(ru_text)[0]["translation_text"].strip()
    else:
        en_text = cleaned
        ru_text = translator(en_text)[0]["translation_text"].strip()

    # One marked copy of the sentence per candidate, with the candidate's
    # surface form kept alongside at the same position.
    marked: List[str] = []
    surface_forms: List[str] = []
    for sentence in split_sentences(ru_text):
        for start, end, span in detect_candidates_ci(sentence, GAZ):
            marked.append(mark_span(sentence, start, end))
            surface_forms.append(span)

    probs = classify_marked_batch(marked)
    accepted = [span for span, prob in zip(surface_forms, probs) if prob >= THRESHOLD]

    # Drop repeated surface forms while keeping first-seen order.
    dm_candidates = list(dict.fromkeys(accepted))

    if dm_candidates:
        pretty = "🧡 " + " · ".join(dm_candidates)
    else:
        pretty = "no DMs found"
    info = f"RU: {ru_text}\nEN: {en_text}\nDMs: {len(dm_candidates)}"
    return pretty, ru_text, en_text, info
|
|
|
|
|
|
|
|
# Demo UI: input box, result box, translations, debug details and a small
# shuffled set of example sentences.
# NOTE(review): the `#title h1` rule targets an <h1> nested inside #title, but
# the markup below puts id='title' on the <h1> itself; `#result-pill` is not
# attached to any component in this block. Confirm whether both are intended.
with gr.Blocks(theme=THEME, css="""
/* fondo suave con degradado */
.gradio-container {
  background: radial-gradient(1200px 600px at 80% -10%, #FFE7DE 0%, rgba(255,231,222,0) 60%) ,
              linear-gradient(180deg, #FFF7F2 0%, #FFFFFF 60%);
}

/* títulos */
#title { text-align:center; }
#title h1 {
  font-weight: 800;
  letter-spacing: .2px;
  color: #E53935; /* rojo principal */
}
#subtitle {
  text-align:center;
  color: #FF7043; /* naranja suave */
  margin-top: -8px;
}

/* componentes redonditos + sombras suaves */
.gr-box, .gr-panel, .gr-group { border-radius: 20px !important; }
button, .gr-button { border-radius: 999px !important; }
textarea, input, .gr-textbox { border-radius: 16px !important; }

/* botones primarios con leve glow */
button.primary, .gr-button-primary {
  box-shadow: 0 8px 20px rgba(229,57,53,0.18);
}
button.primary:hover, .gr-button-primary:hover {
  box-shadow: 0 10px 28px rgba(229,57,53,0.25);
}

/* cajitas informativas */
.accordion { border-radius: 16px !important; overflow: hidden; }

/* pill para el resultado */
#result-pill {
  border-radius: 999px;
  padding: 12px 18px;
  background: #FFE6DE;
  color: #D84315;
  font-weight: 700;
  display: inline-block;
}
""") as demo:
    gr.Markdown("<h1 id='title'>DiMa — Automatic Russian Discourse Marker Detector</h1>")
    gr.Markdown("<div id='subtitle'>English <i>or</i> Russian → detect candidates → show only DMs</div>")

    with gr.Row():
        inp = gr.Textbox(label="English or Russian input", placeholder="e.g., In fact, we should probably leave now.", lines=3)
    with gr.Row():
        btn = gr.Button("Check 🧡", variant="primary")
    with gr.Row():
        out = gr.Textbox(label="Result (only DM candidates)", lines=1)
    with gr.Accordion("Show Russian translation", open=True):
        ru = gr.Textbox(label="Russian", interactive=False)
    with gr.Accordion("Show English translation", open=True):
        en = gr.Textbox(label="English", interactive=False)
    with gr.Accordion("Details", open=False):
        dbg = gr.Textbox(label="Debug", interactive=False)

    # Pool of example inputs (English and Russian); four are offered at a time.
    FUNNY_EXAMPLES = [
        "By the way, isn't ChatGPT supposed to solve this better?",
        "Honestly, I can't read Russian.",
        "For example, a free donut would drastically improve my focus.",
        "Honestly, my code only runs on Tuesdays.",
        "Actually, no one cares about Russian language.",
        "In fact, this has nothing to do with AI.",
        "Кстати, где тут бесплатная пицца?",
        "Честно, я не умею читать по-русски.",
        "По-моему, «кальсотс» переоценены.",
        "Вообще-то, я пришла только за стикерами.",
        "Кажется, Wi-Fi работает только когда не нужен.",
        "Итак, мы согласны, что это лучший стенд?",
    ]

    example_radio = gr.Radio(label="Try an example", choices=[], interactive=True)
    shuffle_btn = gr.Button("Shuffle examples 🔥")

    def _pick_examples():
        """Return four distinct examples chosen at random."""
        return random.sample(FUNNY_EXAMPLES, k=4)

    def shuffle_examples():
        """Offer a fresh set of examples and clear the current selection."""
        return gr.update(choices=_pick_examples(), value=None)

    def _example_to_input(choice):
        """Copy the picked example into the input box.

        Shuffling resets the radio to ``None``, and ``change`` also fires for
        programmatic updates; in that case the input box is left untouched
        instead of being wiped.
        """
        if choice is None:
            return gr.update()
        return choice

    def _run_example(choice):
        """Run the pipeline on a picked example, ignoring the reset to ``None``."""
        if choice is None:
            # Keep the currently displayed results as they are.
            return gr.update(), gr.update(), gr.update(), gr.update()
        return run_pipeline(choice)

    # Populate the example list on page load and whenever Shuffle is pressed.
    demo.load(fn=shuffle_examples, inputs=None, outputs=example_radio)
    shuffle_btn.click(fn=shuffle_examples, inputs=None, outputs=example_radio)

    # Picking an example fills the input box and runs the pipeline on it.
    example_radio.change(_example_to_input, inputs=example_radio, outputs=inp)

    btn.click(run_pipeline, inputs=[inp], outputs=[out, ru, en, dbg])
    example_radio.change(_run_example, inputs=example_radio, outputs=[out, ru, en, dbg])


if __name__ == "__main__":
    demo.launch()
|
|
|