Cardiosense-AG committed on
Commit bcde322 · verified · 1 Parent(s): 4e46ea8

Update src/explainability.py

Files changed (1)
  1. src/explainability.py +63 -103
src/explainability.py CHANGED
@@ -1,128 +1,88 @@
 # src/explainability.py
 from __future__ import annotations
+"""Explainability helpers (post-hoc only).
+
+Provides deterministic "chips" extracted from assessment/plan text.
+Caching by (case_id, section, text_hash) can be layered on top by the UI.
+"""
 
 import math
-import os
 import re
-from functools import lru_cache
-from typing import Dict, List, Tuple
-
-# Streamlit is only used in optional helpers to avoid import-time overhead in non-UI contexts.
-try:
-    import streamlit as st  # noqa: F401
-except Exception:
-    st = None  # type: ignore
-
-# Prefer the same embedding model elsewhere in the Space
-E5_MODEL_ID = os.environ.get("E5_MODEL_ID", "intfloat/e5-base-v2")
+from typing import Dict, List
 
-_WORD_RE = re.compile(r"[A-Za-z0-9\-\%]+")
-_STOPWORDS = {
-    "the","a","an","and","or","but","if","in","on","at","by","for","to","of","with","without",
-    "is","are","was","were","be","been","being","as","that","this","these","those","it","its",
-    "patient","pt","hx","h/o","pmh","psh","ros","pe","labs","lab","imaging","plan","assessment",
-    "subjective","objective","cc","chief","complaint"
-}
-
-def _tokens(s: str) -> List[str]:
-    return [w.lower() for w in _WORD_RE.findall(s or "") if w and w.lower() not in _STOPWORDS]
+def _tokenize(s: str) -> List[str]:
+    s = s.lower()
+    # Keep simple alphanumerics
+    toks = re.findall(r"[a-z0-9]+", s)
+    return [t for t in toks if len(t) >= 3]
 
 def segment_claims(text: str) -> List[str]:
-    """Split free text into sentence-like claims."""
+    """Split text into claim-like sentences/lines."""
     if not text:
         return []
-    t = text.replace("•", ". ").replace("\n", " ").replace(" - ", ". ")
-    parts = re.split(r"(?<=[\.\?\!])\s+", t)
-    claims = [p.strip() for p in parts if len(p.strip()) > 12]
-    return claims[:16]
-
-@lru_cache(maxsize=1)
-def _get_e5():
-    """Load sentence-transformers E5 model lazily."""
-    from sentence_transformers import SentenceTransformer
-    model = SentenceTransformer(E5_MODEL_ID)
-    return model
-
-def _embed_query_passages(summary: str, claims: List[str]):
-    """E5 uses 'query:' for queries and 'passage:' for documents."""
-    model = _get_e5()
-    q = f"query: {summary.strip()}"
-    ps = [f"passage: {c.strip()}" for c in claims]
-    import numpy as np  # local import to avoid global dependency at parse-time
-    qv = model.encode(q, normalize_embeddings=True)
-    pvs = model.encode(ps, normalize_embeddings=True)
-    return qv, pvs
-
-def _cos(a, b):
-    import numpy as np
-    return float(np.dot(a, b))
-
-def _idf(corpus_tokens: List[List[str]]) -> Dict[str, float]:
-    N = max(1, len(corpus_tokens))
-    df: Dict[str,int] = {}
-    for toks in corpus_tokens:
-        for t in set(toks):
-            df[t] = df.get(t, 0) + 1
-    return {t: math.log((N + 1) / (df[t] + 0.5)) + 1.0 for t in df}
+    # Split by newline or period, keep moderately long segments
+    raw = re.split(r"[.\n]+", text)
+    claims = [c.strip() for c in raw if len(c.strip()) >= 12]
+    return claims[:10]
 
 def _tf(tokens: List[str]) -> Dict[str, float]:
     tf: Dict[str, float] = {}
     for t in tokens:
         tf[t] = tf.get(t, 0.0) + 1.0
-    L = float(len(tokens) or 1.0)
-    return {t: tf[t] / L for t in tf}
-
-def l1_normalize(weights: Dict[str, float]) -> Dict[str, float]:
-    s = sum(max(0.0, v) for v in weights.values())
-    if s <= 0:
-        return {k: 0.0 for k in weights}
-    return {k: max(0.0, v) / s for k, v in weights.items()}
-
-def compute_referral_tokens_for_section(
-    section_text: str,
-    referral_summary: str,
-    *,
-    top_n: int = 4,
-    min_weight: float = 0.03,
-) -> List[Dict]:
-    """Compute weighted referral tokens for a section based on post-hoc similarity.
-
-    Steps:
-    1) Segment section into sentence-like claims.
-    2) Get E5 similarity between referral summary and each claim.
-    3) For each claim, compute TF-IDF over its tokens; weight each claim's tokens by sim.
-    4) Aggregate across claims; L1-normalize over the section.
-    5) Return top-N tokens as chips (token, weight).
-    """
-    claims = segment_claims(section_text)
+    s = sum(tf.values()) or 1.0
+    for k in list(tf.keys()):
+        tf[k] = tf[k] / s
+    return tf
+
+def _idf(docs: List[List[str]]) -> Dict[str, float]:
+    df: Dict[str, int] = {}
+    N = max(1, len(docs))
+    for doc in docs:
+        for t in set(doc):
+            df[t] = df.get(t, 0) + 1
+    return {t: math.log((N + 1) / (df_t + 1)) + 1.0 for t, df_t in df.items()}
+
+def chips_from_text(text: str, top_n: int = 10, min_weight: float = 0.02) -> List[Dict[str, float]]:
+    """Generate top-n weighted tokens from text using simple TF-IDF."""
+    claims = segment_claims(text)
     if not claims:
         return []
-
-    # Tokenize claims for IDF
-    claim_tokens = [_tokens(c) for c in claims]
-    idf = _idf(claim_tokens)
-
-    # Embed for similarity
-    try:
-        qv, pvs = _embed_query_passages(referral_summary, claims)
-        sims = [_cos(qv, pv) for pv in pvs]  # already normalized vectors
-    except Exception:
-        # If embedding fails (no internet or package missing), fall back to simple heuristics
-        sims = [1.0 for _ in claims]
-
-    # Weight tokens per-claim and aggregate
+    docs = [_tokenize(c) for c in claims]
+    idf = _idf(docs)
+    # Weight each token by TF * IDF and aggregate across claims
     agg: Dict[str, float] = {}
-    for toks, sim in zip(claim_tokens, sims):
+    for toks in docs:
         tf = _tf(toks)
         for t, tv in tf.items():
-            w = tv * idf.get(t, 1.0) * max(0.0, sim)
-            agg[t] = agg.get(t, 0.0) + w
-
-    agg = l1_normalize(agg)
-    # Keep top_n tokens above cutoff
+            agg[t] = agg.get(t, 0.0) + tv * idf.get(t, 1.0)
+    # L1-normalize so the full token distribution sums to 1
+    s = sum(agg.values()) or 1.0
+    for k in list(agg.keys()):
+        agg[k] /= s
     ranked = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
-    chips = [{"token": tok, "weight": round(w, 4)} for tok, w in ranked if w >= min_weight][:top_n]
-    return chips
+    return [{"token": tok, "weight": round(w, 4)} for tok, w in ranked if w >= min_weight][:top_n]
+
+# --- V2 helpers (post-hoc only, deterministic) ---
+def chip_cache_key(case_id: str, section: str, text: str) -> str:
+    """Deterministic cache key for explainability chips."""
+    import hashlib, json
+    blob = json.dumps({"case_id": case_id, "section": section, "text": text}, sort_keys=True).encode("utf-8")
+    return hashlib.sha256(blob).hexdigest()
+
+def ensure_chip_schema(chips):
+    """Force a consistent chip schema: [{token, weight}] sorted by weight desc."""
+    if not isinstance(chips, (list, tuple)):
+        return []
+    norm = []
+    for c in chips:
+        if not isinstance(c, dict):
+            continue
+        tok = str(c.get("token", "")).strip()
+        w = float(c.get("weight", 0.0))
+        if tok:
+            norm.append({"token": tok, "weight": round(w, 4)})
+    norm.sort(key=lambda x: x["weight"], reverse=True)
+    return norm
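
For reference, a minimal usage sketch of the new chips_from_text helper. The sample plan text, the variable names, and the assumption that src/ is importable as a package (src.explainability) are illustrative only and not part of this commit.

# Minimal usage sketch (illustrative; assumes src/ is an importable package).
from src.explainability import chips_from_text

# Hypothetical assessment/plan snippet; any clinical free text works the same way.
plan_text = (
    "Start metoprolol 25 mg twice daily for rate control.\n"
    "Order transthoracic echocardiogram to assess LV function.\n"
    "Follow up in cardiology clinic in 4 weeks."
)

chips = chips_from_text(plan_text, top_n=5, min_weight=0.02)
# Each chip is {"token": ..., "weight": ...}; weights come from an
# L1-normalized TF-IDF distribution, so the returned top-n weights sum to <= 1.0.
for chip in chips:
    print(chip["token"], chip["weight"])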
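The module docstring leaves caching by (case_id, section, text_hash) to the UI. A hedged sketch of how a caller might layer that on top with chip_cache_key and ensure_chip_schema; the in-memory dict and the cached_chips wrapper are hypothetical, not code from the Space.

# Hypothetical caching wrapper; the dict stands in for whatever store the UI
# actually uses (e.g. Streamlit session state). Only chip_cache_key,
# chips_from_text, and ensure_chip_schema come from this commit.
from typing import Dict, List

from src.explainability import chip_cache_key, chips_from_text, ensure_chip_schema

_CHIP_CACHE: Dict[str, List[dict]] = {}

def cached_chips(case_id: str, section: str, text: str) -> List[dict]:
    # sha256 over the JSON blob of (case_id, section, text), so the key is deterministic
    key = chip_cache_key(case_id, section, text)
    if key not in _CHIP_CACHE:
        # ensure_chip_schema coerces output to [{token, weight}] sorted by weight desc
        _CHIP_CACHE[key] = ensure_chip_schema(chips_from_text(text))
    return _CHIP_CACHE[key]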