"""
context_simplifier.py
---------------------
Optional pre-processor to shorten retrieved RAG evidence context before it’s
passed to MedGemma for SOUP generation.

Default: uses facebook/bart-large-cnn for summarization (seq2seq).
Optionally, you can set USE_MISTRAL = True to use Mistral-7B-Instruct
for summarization via text generation instead.
"""

import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# ---------------------------------------------------------------------
# ⚙️  Configuration toggle
# ---------------------------------------------------------------------
# When truthy ("true"/"1"/"yes", case-insensitive), summarize with a
# decoder-only instruct model instead of the default BART summarizer.
USE_MISTRAL = os.getenv("USE_MISTRAL", "false").lower() in ("true", "1", "yes")

# ---------------------------------------------------------------------
# ✅  Load summarizer pipeline
# ---------------------------------------------------------------------
if USE_MISTRAL:
    # Decoder-only model path: summarization is framed as plain text
    # generation with an instruction prompt.
    print("🔧 Loading Mistral-7B-Instruct as a text-generation summarizer...")
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"

    _simplifier = pipeline(
        "text-generation",
        model=model_id,
        device_map="auto",
        torch_dtype="auto",
    )

    def simplify_rag_context(context_text: str, max_words: int = 400) -> str:
        """Simplify RAG context using a decoder-only model (Mistral).

        Parameters
        ----------
        context_text : str
            Raw retrieved evidence text; may be empty or whitespace-only.
        max_words : int, default 400
            Approximate word budget for the generated summary.

        Returns
        -------
        str
            The condensed summary; a fixed placeholder when the input is
            empty; or an in-band ``[Simplification failed: ...]`` marker
            if generation raises.
        """
        if not context_text.strip():
            return "No evidence context available."

        prompt = (
            "Simplify and condense the following clinical evidence context. "
            "Keep only essential numeric values, treatment names, and short "
            "recommendations relevant to decision-making. "
            f"Limit the summary to about {max_words} words.\n\n"
            f"{context_text}\n\nSimplified summary:"
        )
        try:
            out = _simplifier(
                prompt,
                # ~1.3 tokens per English word is a rough output budget.
                max_new_tokens=int(max_words * 1.3),
                # Greedy decoding. Do NOT also pass temperature=0.0:
                # temperature is ignored when do_sample=False, and recent
                # transformers versions reject non-positive temperatures.
                do_sample=False,
            )
            # The pipeline echoes the prompt, so keep only the text after
            # the summarization cue.
            summary = out[0]["generated_text"].split("Simplified summary:")[-1].strip()
        except Exception as e:
            # Best-effort: surface the failure in-band rather than crash
            # the downstream SOUP-generation pipeline.
            summary = f"[Simplification failed: {e}]"
        return summary

else:
    # True summarization model (encoder-decoder)
    print("🔧 Loading BART Large CNN summarizer (default)...")
    _simplifier = pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device_map="auto",
        torch_dtype="auto",
    )

    def simplify_rag_context(context_text: str, max_words: int = 400) -> str:
        """Simplify RAG context using a true seq2seq summarizer (BART)."""
        if not context_text.strip():
            return "No evidence context available."

        prompt = (
            "Simplify and condense the following clinical evidence context. "
            "Keep key numeric values, drug names, and brief recommendations. "
            f"Limit to about {max_words} words.\n\n{context_text}"
        )
        try:
            result = _simplifier(
                prompt,
                max_length=int(max_words * 1.3),
                min_length=60,
                do_sample=False,
                temperature=0.0,
            )
            summary = result[0]["summary_text"].strip()
        except Exception as e:
            summary = f"[Simplification failed: {e}]"
        return summary