Cardiosense-AG commited on
Commit
073dfbe
·
verified ·
1 Parent(s): 0884be0

Create tests/test_hybrid_mapping.py

Browse files
Files changed (1) hide show
  1. tests/test_hybrid_mapping.py +101 -0
tests/test_hybrid_mapping.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_hybrid_mapping.py
2
+ from __future__ import annotations
3
+ import os, json, time, csv
4
+ from pathlib import Path
5
+ from typing import Dict, List
6
+
7
+ from src.ai_core import generate_soap_draft
8
+
9
+ # writable folder (created automatically if missing)
10
+ BASE_DIR = Path("/data/econsult/tests")
11
+ BASE_DIR.mkdir(parents=True, exist_ok=True)
12
+ CSV_PATH = BASE_DIR / "results.csv"
13
+
14
+ # ---------------------------------------------------------------------
15
+ # Example intake cases for validation
16
+ # ---------------------------------------------------------------------
17
+ CASES: List[Dict[str, str]] = [
18
+ {
19
+ "id": "lipids",
20
+ "age": "58",
21
+ "sex": "Male",
22
+ "specialist": "Cardiology",
23
+ "chief_complaint": "Exertional chest tightness for ~2 months",
24
+ "history": "Type 2 diabetes, hyperlipidemia, no rest pain, no syncope.",
25
+ "medications": "Atorvastatin 20 mg nightly; Metformin 1000 mg BID.",
26
+ "vitals": "BP 132/78, HR 72, BMI 29",
27
+ "labs": "LDL 155 mg/dL, A1C 7.8%, eGFR 52",
28
+ "comorbidities": "DM2, CKD3a, hyperlipidemia",
29
+ "question": "Should we escalate to high-intensity statin and start low-dose aspirin?",
30
+ },
31
+ {
32
+ "id": "ckd_dose",
33
+ "age": "63",
34
+ "sex": "Male",
35
+ "specialist": "Cardiology",
36
+ "chief_complaint": "Medication dosing in CKD3a",
37
+ "history": "63 y/o M with DM2, CKD3a, HTN; needs metformin and statin dosing guidance.",
38
+ "medications": "Atorvastatin 20 mg nightly; Metformin 1000 mg BID.",
39
+ "vitals": "BP 128/80 mmHg, HR 70 bpm",
40
+ "labs": "A1C 7.5%, eGFR 50 mL/min/1.73 m2",
41
+ "comorbidities": "DM2, CKD3a, HTN",
42
+ "question": "What are recommended statin intensity and metformin dosing for eGFR ≈ 50?",
43
+ },
44
+ ]
45
+
46
+ # ---------------------------------------------------------------------
47
+ # Helpers
48
+ # ---------------------------------------------------------------------
49
+ def count_annotated(meta: Dict[str, object]) -> int:
50
+ """Count total annotated bullets (assessment+plan)"""
51
+ ann = meta.get("annotated", {}) or {}
52
+ return len(ann.get("assessment_html", [])) + len(ann.get("plan_html", []))
53
+
54
+ def run_case(intake: Dict[str, str]) -> Dict[str, object]:
55
+ """Generate SOAP + mapping metrics for one case"""
56
+ t0 = time.perf_counter()
57
+ result = generate_soap_draft(intake, mode="mapping", rag_top_k=5, max_new_tokens=700)
58
+ t1 = time.perf_counter()
59
+
60
+ meta = result.meta
61
+ timings = meta.get("timings", {})
62
+ rec = {
63
+ "case_id": intake["id"],
64
+ "generate_secs": timings.get("generate_secs", 0),
65
+ "map_secs": timings.get("map_secs", 0),
66
+ "total_runtime": round(t1 - t0, 2),
67
+ "assessment_items": len(result.soap.get("assessment", [])),
68
+ "plan_items": len(result.soap.get("plan", [])),
69
+ "annotated_items": count_annotated(meta),
70
+ "unique_evidence": len(result.citations),
71
+ "cache_stub": meta.get("stub", ""),
72
+ }
73
+ # save raw JSON too
74
+ (BASE_DIR / f"{intake['id']}_result.json").write_text(
75
+ json.dumps(result.soap, ensure_ascii=False, indent=2)
76
+ )
77
+ return rec
78
+
79
+ def write_csv(rows: List[Dict[str, object]]) -> None:
80
+ keys = list(rows[0].keys())
81
+ with CSV_PATH.open("w", newline="", encoding="utf-8") as f:
82
+ w = csv.DictWriter(f, fieldnames=keys)
83
+ w.writeheader()
84
+ w.writerows(rows)
85
+
86
+ # ---------------------------------------------------------------------
87
+ # Entry point
88
+ # ---------------------------------------------------------------------
89
+ def run_all() -> str:
90
+ rows: List[Dict[str, object]] = []
91
+ for case in CASES:
92
+ print(f"Running case: {case['id']} ...")
93
+ rec = run_case(case)
94
+ rows.append(rec)
95
+ print(rec)
96
+ write_csv(rows)
97
+ print(f"\nResults saved to: {CSV_PATH}")
98
+ return str(CSV_PATH)
99
+
100
+ if __name__ == "__main__":
101
+ run_all()