# tests/test_hybrid_mapping.py from __future__ import annotations import os, json, time, csv from pathlib import Path from typing import Dict, List from src.ai_core import generate_soap_draft BASE_DIR = Path("/data/econsult/tests") BASE_DIR.mkdir(parents=True, exist_ok=True) CSV_PATH = BASE_DIR / "results.csv" LOG_PATH = BASE_DIR / "run_logs.txt" CASES: List[Dict[str, str]] = [ { "id": "lipids", "age": "58", "sex": "Male", "specialist": "Cardiology", "chief_complaint": "Exertional chest tightness for ~2 months", "history": "Type 2 diabetes, hyperlipidemia, no rest pain, no syncope.", "medications": "Atorvastatin 20 mg nightly; Metformin 1000 mg BID.", "vitals": "BP 132/78, HR 72, BMI 29", "labs": "LDL 155 mg/dL, A1C 7.8%, eGFR 52", "comorbidities": "DM2, CKD3a, hyperlipidemia", "question": "Should we escalate to high-intensity statin and start low-dose aspirin?", }, { "id": "ckd_dose", "age": "63", "sex": "Male", "specialist": "Cardiology", "chief_complaint": "Medication dosing in CKD3a", "history": "63 y/o M with DM2, CKD3a, HTN; needs metformin and statin dosing guidance.", "medications": "Atorvastatin 20 mg nightly; Metformin 1000 mg BID.", "vitals": "BP 128/80 mmHg, HR 70 bpm", "labs": "A1C 7.5%, eGFR 50 mL/min/1.73 m2", "comorbidities": "DM2, CKD3a, HTN", "question": "What are recommended statin intensity and metformin dosing for eGFR ≈ 50?", }, ] # --------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------- def count_annotated(meta: Dict[str, object]) -> int: ann = meta.get("annotated", {}) or {} return len(ann.get("assessment_html", [])) + len(ann.get("plan_html", [])) def run_case(intake: Dict[str, str]) -> Dict[str, object]: t0 = time.perf_counter() result = generate_soap_draft(intake, mode="mapping", rag_top_k=5, max_new_tokens=700) t1 = time.perf_counter() meta = result.meta timings = meta.get("timings", {}) rec = { "case_id": intake["id"], "generate_secs": timings.get("generate_secs", 0), "map_secs": timings.get("map_secs", 0), "total_runtime": round(t1 - t0, 2), "assessment_items": len(result.soap.get("assessment", [])), "plan_items": len(result.soap.get("plan", [])), "annotated_items": count_annotated(meta), "unique_evidence": len(result.citations), "cache_stub": meta.get("stub", ""), } (BASE_DIR / f"{intake['id']}_result.json").write_text( json.dumps(result.soap, ensure_ascii=False, indent=2) ) return rec def write_csv(rows: List[Dict[str, object]]) -> None: if not rows: return keys = list(rows[0].keys()) with CSV_PATH.open("w", newline="", encoding="utf-8") as f: w = csv.DictWriter(f, fieldnames=keys) w.writeheader() w.writerows(rows) def save_logs(log_text: str) -> None: """Append captured console logs to persistent file.""" LOG_PATH.parent.mkdir(parents=True, exist_ok=True) with LOG_PATH.open("a", encoding="utf-8") as f: f.write("\n" + log_text.strip() + "\n") # --------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------- def run_all() -> str: rows: List[Dict[str, object]] = [] print("=== Hybrid Mapping Validation Run ===") for case in CASES: print(f"\n--- Running case: {case['id']} ---") rec = run_case(case) rows.append(rec) print(f"Result: {rec}") write_csv(rows) print(f"\nResults saved to: {CSV_PATH}") return str(CSV_PATH) if __name__ == "__main__": run_all()