import html
import io
import os
import time
import traceback

import gradio as gr
import numpy as np
import png
import tensorflow as tf
import tensorflow_hub as tf_hub
import tensorflow_text as tf_text
from huggingface_hub import snapshot_download
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
|
|
| |
# Hugging Face repository that hosts the models, and the local directory the
# snapshot is downloaded into.
MODEL_REPO_ID = "google/cxr-foundation"
MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'

# Decision thresholds used when classifying each criterion:
#   - "Comp": PASS when sim(positive) - sim(negative) is above the difference threshold.
#   - "Simp": PASS when sim(positive) alone is above the positive threshold.
SIMILARITY_DIFFERENCE_THRESHOLD = 0.1
POSITIVE_SIMILARITY_THRESHOLD = 0.1

print(
    f"Usando umbrales: Comp Δ={SIMILARITY_DIFFERENCE_THRESHOLD}, "
    f"Simp τ={POSITIVE_SIMILARITY_THRESHOLD}"
)
|
|
| |
# Default quality criteria as (positive prompt, negative prompt) pairs.
#
# The two public lists below are consumed pairwise by index, so every positive
# prompt must sit at the same index as the negative prompt describing the
# opposite finding. Declaring them as explicit pairs keeps them aligned:
# previously "complete lung fields" was compared against "underexposed image",
# "scapulae retracted outside lungs" against "cropped lung fields", and
# "sharp contrast" against "scapular overlay on lungs".
_DEFAULT_CRITERIA_PAIRS = [
    ("optimal centering mediastinum", "poor centering"),
    ("deep inspiration", "shallow inspiration"),
    ("adequate penetration", "overexposed or underexposed image"),
    ("complete lung fields", "cropped lung fields"),
    ("scapulae retracted outside lungs", "scapular overlay on lungs"),
    ("sharp contrast", "blurred image"),
    ("artifact-free image", "image with artifacts"),
]

criteria_list_positive = [positive for positive, _ in _DEFAULT_CRITERIA_PAIRS]
criteria_list_negative = [negative for _, negative in _DEFAULT_CRITERIA_PAIRS]
|
|
| |
def bert_tokenize(text, preprocessor):
    """Tokenize a prompt with the BERT preprocessor.

    Args:
        text: Prompt to tokenize; it is converted with ``str`` and lower-cased.
        preprocessor: Callable that receives a string tensor and returns a
            mapping with ``input_word_ids`` and ``input_mask`` tensors.

    Returns:
        A ``(ids, paddings)`` tuple: int32 token ids and float32 paddings
        (``1.0 - input_mask``). Two-dimensional results get an extra axis
        inserted at position 1.

    Raises:
        ValueError: If ``preprocessor`` is ``None``.
    """
    if preprocessor is None:
        raise ValueError("BERT preprocessor no está cargado.")

    encoded = preprocessor(tf.constant([str(text).lower()]))
    token_ids = encoded['input_word_ids'].numpy().astype(np.int32)
    attention = encoded['input_mask'].numpy().astype(np.float32)
    padding_flags = 1.0 - attention

    # Token id 102 is blanked out and flagged as padding.
    # NOTE(review): 102 is presumably the [SEP] id of the uncased BERT
    # vocabulary - confirm against the preprocessor in use.
    is_end_token = token_ids == 102
    token_ids[is_end_token] = 0
    padding_flags[is_end_token] = 1.0

    if token_ids.ndim == 2:
        token_ids = np.expand_dims(token_ids, 1)
    if padding_flags.ndim == 2:
        padding_flags = np.expand_dims(padding_flags, 1)
    return token_ids, padding_flags
|
|
def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
    """Encode a greyscale image array as PNG inside a ``tf.train.Example``.

    Args:
        image_array: 2-D array, or 3-D array whose last axis has size 1.

    Returns:
        An example with ``image/encoded`` (the PNG bytes) and ``image/format``
        (``b'png'``) features.

    Raises:
        ValueError: If the array is neither 2-D nor 3-D with a trailing
            singleton axis.
    """
    if image_array.ndim == 3 and image_array.shape[2] == 1:
        image_array = np.squeeze(image_array, axis=2)
    elif image_array.ndim != 2:
        raise ValueError(f'Array debe ser 2-D. Dimensiones: {image_array.ndim}')

    source_is_uint8 = image_array.dtype == np.uint8
    scaled = image_array.astype(np.float32)
    lowest, highest = scaled.min(), scaled.max()

    if highest <= lowest:
        # Constant image: nothing to stretch.
        if source_is_uint8 or (lowest >= 0 and highest <= 255):
            pixels, bitdepth = scaled.astype(np.uint8), 8
        else:
            pixels, bitdepth = np.zeros_like(scaled, dtype=np.uint16), 16
    else:
        # Stretch the value range to the full range of the target bit depth:
        # 8 bits for uint8 input, 16 bits for everything else.
        scaled -= lowest
        value_range = highest - lowest
        if not source_is_uint8:
            scaled *= 65535 / value_range
            pixels, bitdepth = scaled.astype(np.uint16), 16
        else:
            scaled *= 255 / value_range
            pixels, bitdepth = scaled.astype(np.uint8), 8

    buffer = io.BytesIO()
    writer = png.Writer(
        width=pixels.shape[1],
        height=pixels.shape[0],
        greyscale=True,
        bitdepth=bitdepth,
    )
    writer.write(buffer, pixels.tolist())

    example = tf.train.Example()
    feature_map = example.features.feature
    feature_map['image/encoded'].bytes_list.value.append(buffer.getvalue())
    feature_map['image/format'].bytes_list.value.append(b'png')
    return example
|
|
def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
    """Compute the contrastive image embedding for a greyscale image array.

    The image is serialized with ``png_to_tfexample``, passed through
    ``elixrc_infer`` and the resulting ``feature_maps_0`` are fed to
    ``qformer_infer`` together with all-zero ids and all-one paddings.

    Args:
        img_np: Image array accepted by ``png_to_tfexample``.
        elixrc_infer: Callable invoked with ``input_example=...``.
        qformer_infer: Callable invoked with ``image_feature``, ``ids`` and
            ``paddings`` keyword arguments.

    Returns:
        The ``all_contrastive_img_emb`` output as a NumPy array. Outputs with
        more than two dimensions are averaged over every axis except the
        first and the last; a 1-D output gains a leading axis.

    Raises:
        ValueError: If either inference callable is ``None``.
        Exception: Any inference error is logged and re-raised unchanged.
    """
    if elixrc_infer is None or qformer_infer is None:
        raise ValueError("Modelos ELIXR-C o QFormer no cargados.")

    try:
        example_bytes = png_to_tfexample(img_np).SerializeToString()
        encoder_out = elixrc_infer(input_example=tf.constant([example_bytes]))
        feature_maps = encoder_out['feature_maps_0'].numpy()

        qformer_out = qformer_infer(
            image_feature=feature_maps.tolist(),
            ids=np.zeros((1, 1, 128), dtype=np.int32).tolist(),
            paddings=np.ones((1, 1, 128), dtype=np.float32).tolist(),
        )
        embedding = qformer_out['all_contrastive_img_emb'].numpy()

        if embedding.ndim > 2:
            middle_axes = tuple(range(1, embedding.ndim - 1))
            embedding = embedding.mean(axis=middle_axes)
        if embedding.ndim == 1:
            embedding = embedding[np.newaxis, :]
        return embedding
    except Exception as e:
        print(f"Error embedding imagen: {e}")
        traceback.print_exc()
        raise
|
|
def _embed_prompt_text(prompt, bert_preprocessor, qformer_infer, blank_image_feature):
    """Return the contrastive text embedding of ``prompt`` reshaped to ``(1, -1)``."""
    ids, paddings = bert_tokenize(prompt, bert_preprocessor)
    outputs = qformer_infer(
        image_feature=blank_image_feature,
        ids=ids.tolist(),
        paddings=paddings.tolist(),
    )
    return outputs['contrastive_txt_emb'].numpy().reshape(1, -1)


def calculate_similarities_and_classify(image_embedding, bert_preprocessor, qformer_infer,
                                        criteria_positive, criteria_negative):
    """Score an image embedding against paired positive/negative prompts.

    Prompts are paired by index (``zip``), so surplus entries in the longer
    list are ignored. For every pair the cosine similarity between the image
    embedding and each prompt embedding is computed and two verdicts derived:

    * ``comp``: ``"PASS"`` when ``sim_pos - sim_neg`` exceeds
      ``SIMILARITY_DIFFERENCE_THRESHOLD``, otherwise ``"FAIL"``.
    * ``simp``: ``"PASS"`` when ``sim_pos`` exceeds
      ``POSITIVE_SIMILARITY_THRESHOLD``, otherwise ``"FAIL"``.

    A failure while processing one pair is logged and does not abort the
    others: that entry keeps ``"ERROR"`` verdicts and ``None`` for every
    value that had not been computed yet.

    Args:
        image_embedding: Image embedding array passed to ``cosine_similarity``.
        bert_preprocessor: Preprocessor forwarded to ``bert_tokenize``.
        qformer_infer: Callable invoked with ``image_feature``, ``ids`` and
            ``paddings``; its ``contrastive_txt_emb`` output is used.
        criteria_positive: Iterable of positive prompts.
        criteria_negative: Iterable of negative prompts.

    Returns:
        Dict keyed by positive prompt. When the same positive prompt occurs
        more than once, later occurrences are keyed ``"<prompt> (2)"``,
        ``"<prompt> (3)"``, ... so that no pair is silently overwritten. Each
        value holds ``positive_prompt``, ``negative_prompt``, ``sim_pos``,
        ``sim_neg``, ``difference``, ``comp`` and ``simp``.
    """
    # The text branch ignores the image input, so the all-zero placeholder is
    # built (and converted to nested lists) once instead of twice per pair.
    blank_image_feature = np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist()

    results = {}
    for pos, neg in zip(criteria_positive, criteria_negative):
        sim_pos = sim_neg = diff = None
        comp = simp = "ERROR"
        try:
            txt_p = _embed_prompt_text(pos, bert_preprocessor, qformer_infer, blank_image_feature)
            txt_n = _embed_prompt_text(neg, bert_preprocessor, qformer_infer, blank_image_feature)

            sim_pos = float(cosine_similarity(image_embedding, txt_p)[0][0])
            sim_neg = float(cosine_similarity(image_embedding, txt_n)[0][0])
            diff = sim_pos - sim_neg
            comp = "PASS" if diff > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
            simp = "PASS" if sim_pos > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
        except Exception as e:
            # Best effort: report the failing criterion and keep going.
            print(f"Error en criterio '{pos}': {e}")
            traceback.print_exc()

        # Keep every pair even when the same positive prompt is repeated.
        key = pos
        occurrence = 2
        while key in results:
            key = f"{pos} ({occurrence})"
            occurrence += 1

        results[key] = {
            'positive_prompt': pos,
            'negative_prompt': neg,
            'sim_pos': sim_pos,
            'sim_neg': sim_neg,
            'difference': diff,
            'comp': comp,
            'simp': simp,
        }
    return results
|
|
| |
# --- Model loading (runs once, at import time) ---
print("--- Iniciando carga de modelos ---")
start_time = time.time()
models_loaded = False
# `elixr` is the name read by the request handler, so it must always be bound
# (even when loading fails). `elixrc` is kept as an alias of the same model
# because earlier versions of this module exposed that name.
bert_preproc = elixr = elixrc = qformer = None
try:
    hf_token = os.environ.get("HF_TOKEN")
    os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
    # Only the two model folders that are loaded below are downloaded.
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=MODEL_DOWNLOAD_DIR,
        allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        local_dir_use_symlinks=False,
        token=hf_token,
    )
    bert_preproc = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    elixr = tf.saved_model.load(
        os.path.join(MODEL_DOWNLOAD_DIR, 'elixr-c-v2-pooled')
    ).signatures['serving_default']
    elixrc = elixr
    qformer = tf.saved_model.load(
        os.path.join(MODEL_DOWNLOAD_DIR, 'pax-elixr-b-text')
    ).signatures['serving_default']
    models_loaded = True
    print(f"Modelos cargados en {time.time()-start_time:.2f}s")
except Exception as e:
    # Keep the app importable; the request handler checks `models_loaded`
    # and reports the failure to the user.
    print("ERROR cargando modelos:", e)
    traceback.print_exc()
|
|
| |
def _split_prompts(raw_text):
    """Return the stripped, non-empty lines of ``raw_text`` (``None`` gives no lines)."""
    return [line.strip() for line in (raw_text or "").splitlines() if line.strip()]


def _format_similarity(value):
    """Format a similarity with four decimals, or ``N/A`` when it was not computed."""
    return "N/A" if value is None else f"{value:.4f}"


def _build_results_table(details):
    """Render the per-criterion results as an HTML table."""
    pass_style = "color:#22c55e;font-weight:bold;"
    fail_style = "color:#ef4444;font-weight:bold;"

    row_parts = []
    for crit, d in details.items():
        c_style = pass_style if d['comp'] == "PASS" else fail_style
        s_style = pass_style if d['simp'] == "PASS" else fail_style
        row_parts.append(
            f"<tr>"
            # Criterion text comes from the user-editable prompt boxes.
            f"<td>{html.escape(str(crit))}</td>"
            f"<td>{_format_similarity(d['sim_pos'])}</td>"
            f"<td>{_format_similarity(d['sim_neg'])}</td>"
            f"<td>{_format_similarity(d['difference'])}</td>"
            f"<td style='{c_style}'>{d['comp']}</td>"
            f"<td style='{s_style}'>{d['simp']}</td>"
            f"</tr>"
        )
    rows = "".join(row_parts)
    return f"""
    <table style="width:100%;border-collapse:collapse;">
      <thead style="background:#f2f2f2;">
        <tr>
          <th>Criterion</th><th>Sim (+)</th><th>Sim (-)</th><th>Diff</th>
          <th>Assessment (Comp)</th><th>Assessment (Simp)</th>
        </tr>
      </thead>
      <tbody>{rows}</tbody>
    </table>
    """


def _overall_quality_label(passed, total):
    """Map the share of passed criteria to the overall quality label."""
    pass_rate = passed / total if total > 0 else 0
    if pass_rate >= 0.85:
        overall = "Excellent"
    elif pass_rate >= 0.70:
        overall = "Good"
    elif pass_rate >= 0.50:
        overall = "Fair"
    else:
        overall = "Poor"
    return f"{overall} ({passed}/{total} passed)"


def assess_quality_and_update_ui(image_pil, pos_input, neg_input):
    """Run the quality assessment for one image and build the UI updates.

    Args:
        image_pil: Uploaded PIL image, or ``None`` when nothing was uploaded.
        pos_input: Positive prompts, one per line.
        neg_input: Negative prompts, one per line.

    Returns:
        A 6-tuple of updates: welcome block, results block, analysed image,
        overall quality label, HTML results table and the raw details dict.
        Without an image the welcome block stays visible and the remaining
        outputs are cleared.

    Raises:
        gr.Error: If the models are not loaded, the prompt lists differ in
            length, or no prompt pair was provided.
    """
    if not models_loaded:
        raise gr.Error("No se pudieron cargar los modelos.")
    if image_pil is None:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            None,
            "N/A",
            "",
            {},
        )

    pos_list = _split_prompts(pos_input)
    neg_list = _split_prompts(neg_input)
    if len(pos_list) != len(neg_list):
        raise gr.Error("El número de prompts positivos y negativos debe coincidir.")
    if not pos_list:
        # Avoid running inference only to report a meaningless "0/0 passed".
        raise gr.Error("Introduce al menos un par de prompts (positivo y negativo).")

    img_np = np.array(image_pil.convert('L'))
    emb = generate_image_embedding(img_np, elixr, qformer)
    details = calculate_similarities_and_classify(emb, bert_preproc, qformer, pos_list, neg_list)

    # Only the comparative verdict counts towards the overall score; criteria
    # that errored are included in the total but never counted as passed.
    total = len(details)
    passed = sum(1 for d in details.values() if d['comp'] == "PASS")

    return (
        gr.update(visible=False),
        gr.update(visible=True),
        image_pil,
        _overall_quality_label(passed, total),
        _build_results_table(details),
        details,
    )
|
|
def reset_ui():
    """Return the updates that bring the interface back to its initial state.

    The 7-tuple shows the first block, hides the second, clears two image
    slots, resets the label to ``"N/A"``, empties the HTML and empties the
    JSON output.
    """
    show_welcome = gr.update(visible=True)
    hide_results = gr.update(visible=False)
    cleared_input_image = None
    cleared_output_image = None
    return (
        show_welcome,
        hide_results,
        cleared_input_image,
        cleared_output_image,
        "N/A",
        "",
        {},
    )
|
|
| |
# Dark colour scheme: the default Gradio theme with blue/gray hues and custom
# fonts, plus explicit overrides for backgrounds, text, borders, buttons,
# inputs and shadows.
_dark_theme_overrides = {
    "body_background_fill": "#111827",
    "background_fill_primary": "#1f2937",
    "background_fill_secondary": "#374151",
    "block_background_fill": "#1f2937",
    "body_text_color": "#d1d5db",
    "block_label_text_color": "#d1d5db",
    "block_title_text_color": "#ffffff",
    "border_color_accent": "#374151",
    "border_color_primary": "#4b5563",
    "button_primary_background_fill": "*primary_600",
    "button_primary_text_color": "#ffffff",
    "button_secondary_background_fill": "*neutral_700",
    "button_secondary_text_color": "#ffffff",
    "input_background_fill": "#374151",
    "input_border_color": "#4b5563",
    "shadow_drop": "rgba(0,0,0,0.2) 0px 2px 4px",
    "block_shadow": "rgba(0,0,0,0.2) 0px 2px 5px",
}

dark_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "Consolas", "monospace"],
).set(**_dark_theme_overrides)
|
|
| |
# Interface definition.
# NOTE(review): the nesting of the layout containers below is reconstructed
# from the order of the components - confirm it against the rendered page.
with gr.Blocks(theme=dark_theme, title="CXR Quality Assessment") as demo:

    # Page header.
    gr.Markdown("""
    # <span style="color: #e5e7eb;">CXR Quality Assessment</span>
    <p style="color: #9ca3af;">Evalúa la calidad técnica de radiografías de tórax con AI</p>
    """)

    # Editable prompt lists, pre-filled with the default criteria.
    with gr.Row():
        positive_prompts_input = gr.Textarea(
            label="Prompts Positivos (uno por línea)",
            value="\n".join(criteria_list_positive),
            lines=7,
        )
        negative_prompts_input = gr.Textarea(
            label="Prompts Negativos (uno por línea)",
            value="\n".join(criteria_list_negative),
            lines=7,
        )

    with gr.Row(equal_height=False):
        # Left column: image upload and action buttons.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Carga de Imagen")
            input_image = gr.Image(type="pil", label="Sube tu CXR", height=300)
            with gr.Row():
                analyze_btn = gr.Button("Analizar", variant="primary")
                reset_btn = gr.Button("Reset", variant="secondary")
            gr.Markdown("<p style='color:#9ca3af; font-size:0.9em;'>La carga de modelos tarda ~1 min; el análisis ~15–40 s.</p>")

        # Right column: welcome message, replaced by the results after a run.
        with gr.Column(scale=2):
            with gr.Column(visible=True) as welcome_block:
                gr.Markdown("### ¡Bienvenido! Sube una radiografía y haz clic en «Analizar».")
            with gr.Column(visible=False) as results_block:
                gr.Markdown("### 2. Resultados")
                with gr.Row():
                    output_image = gr.Image(type="pil", label="Imagen Analizada", interactive=False)
                    with gr.Column():
                        gr.Markdown("#### Calidad Global")
                        output_label = gr.Label(value="N/A")
                gr.Markdown("#### Evaluación Detallada")
                output_html = gr.HTML()
                with gr.Accordion("Ver JSON (debug)", open=False):
                    output_json = gr.JSON()

    # Event wiring: the order of each outputs list must match the order of
    # the tuple returned by the corresponding handler.
    analysis_inputs = [input_image, positive_prompts_input, negative_prompts_input]
    analysis_outputs = [
        welcome_block,
        results_block,
        output_image,
        output_label,
        output_html,
        output_json,
    ]
    reset_outputs = [
        welcome_block,
        results_block,
        input_image,
        output_image,
        output_label,
        output_html,
        output_json,
    ]

    analyze_btn.click(
        fn=assess_quality_and_update_ui,
        inputs=analysis_inputs,
        outputs=analysis_outputs,
    )
    reset_btn.click(
        fn=reset_ui,
        inputs=None,
        outputs=reset_outputs,
    )
|
|
# Script entry point: serve the app on all network interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|