Spaces:
Paused
Paused
dylanglenister commited on
Commit ·
57635b5
1
Parent(s): 464bdf6
Updated tests to match the new models.
Browse files- tests/test_account.py +23 -51
- tests/test_patient.py +14 -33
- tests/test_session.py +42 -39
tests/test_account.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import unittest
|
| 2 |
from datetime import datetime
|
| 3 |
from unittest.mock import patch
|
|
@@ -7,6 +8,7 @@ from pymongo.errors import ConnectionFailure
|
|
| 7 |
|
| 8 |
from src.data.connection import ActionFailed, Collections, get_collection
|
| 9 |
from src.data.repositories import account as account_repo
|
|
|
|
| 10 |
from src.utils.logger import logger
|
| 11 |
from tests.base_test import BaseMongoTest
|
| 12 |
|
|
@@ -23,24 +25,20 @@ class TestAccountRepository(BaseMongoTest):
|
|
| 23 |
def test_init_functionality(self):
|
| 24 |
"""Test the init function's ability to create, drop, and preserve collections."""
|
| 25 |
self.assertIn(self.test_collection, self.db.list_collection_names())
|
| 26 |
-
|
| 27 |
-
account_id = account_repo.create_account("Persist Test", "Doctor", collection_name=self.test_collection)
|
| 28 |
account_repo.init(collection_name=self.test_collection, drop=False)
|
| 29 |
self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
|
| 30 |
-
# Test that data is deleted when drop=True
|
| 31 |
account_repo.init(collection_name=self.test_collection, drop=True)
|
| 32 |
self.assertEqual(get_collection(self.test_collection).count_documents({}), 0)
|
| 33 |
|
| 34 |
def test_create_account(self):
|
| 35 |
"""Test successful account creation, including optional fields."""
|
| 36 |
-
# Test basic creation
|
| 37 |
name, role = "Test Doctor", "Doctor"
|
| 38 |
account_id = account_repo.create_account(name=name, role=role, collection_name=self.test_collection)
|
| 39 |
self.assertIsInstance(account_id, str)
|
| 40 |
doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
|
| 41 |
self.assertIsNotNone(doc)
|
| 42 |
self.assertEqual(doc["name"], name) # type: ignore
|
| 43 |
-
# Test creation with specialty
|
| 44 |
spec_id = account_repo.create_account("Spec", "Nurse", specialty="Cardiology", collection_name=self.test_collection)
|
| 45 |
spec_doc = self.get_doc_by_id(Collections.ACCOUNT, spec_id)
|
| 46 |
self.assertEqual(spec_doc["specialty"], "Cardiology") # type: ignore
|
|
@@ -50,8 +48,6 @@ class TestAccountRepository(BaseMongoTest):
|
|
| 50 |
account_id = account_repo.create_account("Update Logic", "Doctor", collection_name=self.test_collection)
|
| 51 |
original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
|
| 52 |
self.assertIsNotNone(original_doc)
|
| 53 |
-
|
| 54 |
-
# Test that 'created_at' is immutable
|
| 55 |
updates = {"name": "Updated Name", "created_at": datetime(2000, 1, 1)}
|
| 56 |
success = account_repo.update_account(account_id, updates, collection_name=self.test_collection)
|
| 57 |
self.assertTrue(success)
|
|
@@ -59,80 +55,56 @@ class TestAccountRepository(BaseMongoTest):
|
|
| 59 |
self.assertIsNotNone(updated_doc)
|
| 60 |
self.assertEqual(updated_doc["created_at"], original_doc["created_at"]) # type: ignore
|
| 61 |
self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
|
| 62 |
-
|
| 63 |
-
# Test updating a non-existent account returns False
|
| 64 |
self.assertFalse(account_repo.update_account(str(ObjectId()), {"name": "No One"}, collection_name=self.test_collection))
|
| 65 |
|
| 66 |
def test_get_account_logic(self):
|
| 67 |
-
"""Test that get_account updates 'last_seen'
|
| 68 |
account_id = account_repo.create_account("GetMe", "Doctor", collection_name=self.test_collection)
|
| 69 |
original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
|
| 70 |
self.assertIsNotNone(original_doc)
|
| 71 |
-
|
| 72 |
|
| 73 |
-
# Get the account, which should add 'last_seen'
|
| 74 |
account = account_repo.get_account(account_id, collection_name=self.test_collection)
|
| 75 |
self.assertIsNotNone(account)
|
| 76 |
-
self.
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
self.assertIsNotNone(final_doc)
|
| 81 |
-
self.assertEqual(final_doc["updated_at"], original_doc["updated_at"]) # type: ignore
|
| 82 |
|
| 83 |
def test_get_account_by_name(self):
|
| 84 |
"""Test retrieving an account by name and check for deprecation warning."""
|
| 85 |
name = "FindByName"
|
| 86 |
account_repo.create_account(name, "Nurse", collection_name=self.test_collection)
|
| 87 |
-
|
| 88 |
-
# Check that the function works and raises the expected warning
|
| 89 |
account = account_repo.get_account_by_name(name, collection_name=self.test_collection)
|
| 90 |
self.assertIsNotNone(account)
|
| 91 |
-
self.
|
| 92 |
-
|
| 93 |
-
# Test retrieval of a non-existent name returns None
|
| 94 |
self.assertIsNone(account_repo.get_account_by_name("NonExistent", collection_name=self.test_collection))
|
| 95 |
|
| 96 |
def test_search_accounts(self):
|
| 97 |
-
"""Test search functionality
|
| 98 |
account_repo.create_account("Alpha Doctor", "Doctor", collection_name=self.test_collection)
|
| 99 |
account_repo.create_account("Beta Nurse", "Nurse", collection_name=self.test_collection)
|
| 100 |
-
account_repo.
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
self.assertEqual(
|
| 104 |
-
# Test query matching multiple documents
|
| 105 |
-
self.assertEqual(len(account_repo.search_accounts("Doctor", collection_name=self.test_collection)), 2)
|
| 106 |
-
# Test limit parameter
|
| 107 |
-
self.assertEqual(len(account_repo.search_accounts("Doctor", limit=1, collection_name=self.test_collection)), 1)
|
| 108 |
-
# Test query with no matches
|
| 109 |
self.assertEqual(len(account_repo.search_accounts("NonExistent", collection_name=self.test_collection)), 0)
|
| 110 |
-
# Test empty query string returns an empty list
|
| 111 |
-
self.assertEqual(len(account_repo.search_accounts("", collection_name=self.test_collection)), 0)
|
| 112 |
|
| 113 |
def test_get_all_accounts(self):
|
| 114 |
-
"""Test retrieving all accounts, verifying sorting and
|
| 115 |
account_repo.create_account("Charlie", "Doctor", collection_name=self.test_collection)
|
| 116 |
account_repo.create_account("Alpha", "Nurse", collection_name=self.test_collection)
|
| 117 |
-
account_repo.create_account("Beta", "Caregiver", collection_name=self.test_collection)
|
| 118 |
-
|
| 119 |
all_accounts = account_repo.get_all_accounts(collection_name=self.test_collection)
|
| 120 |
-
self.assertEqual(len(all_accounts),
|
| 121 |
-
|
| 122 |
-
self.assertEqual(all_accounts[0]
|
| 123 |
-
self.assertEqual(all_accounts[
|
| 124 |
-
|
| 125 |
-
# Test with a limit
|
| 126 |
-
limited_accounts = account_repo.get_all_accounts(limit=2, collection_name=self.test_collection)
|
| 127 |
-
self.assertEqual(len(limited_accounts), 2)
|
| 128 |
-
self.assertEqual(limited_accounts[1]["name"], "Beta")
|
| 129 |
|
| 130 |
def test_get_account_frame(self):
|
| 131 |
-
"""Test retrieving accounts as a pandas DataFrame
|
| 132 |
-
# Test with an empty collection
|
| 133 |
df_empty = account_repo.get_account_frame(collection_name=self.test_collection)
|
| 134 |
self.assertTrue(df_empty.empty)
|
| 135 |
-
# Test with data
|
| 136 |
account_repo.create_account("Frame Alpha", "Doctor", collection_name=self.test_collection)
|
| 137 |
df_full = account_repo.get_account_frame(collection_name=self.test_collection)
|
| 138 |
self.assertEqual(len(df_full), 1)
|
|
@@ -187,4 +159,4 @@ class TestAccountRepositoryExceptions(BaseMongoTest):
|
|
| 187 |
if __name__ == "__main__":
|
| 188 |
logger().info("Starting MongoDB repository integration tests...")
|
| 189 |
unittest.main(verbosity=2)
|
| 190 |
-
logger().info("Tests completed
|
|
|
|
| 1 |
+
import time
|
| 2 |
import unittest
|
| 3 |
from datetime import datetime
|
| 4 |
from unittest.mock import patch
|
|
|
|
| 8 |
|
| 9 |
from src.data.connection import ActionFailed, Collections, get_collection
|
| 10 |
from src.data.repositories import account as account_repo
|
| 11 |
+
from src.models.account import Account
|
| 12 |
from src.utils.logger import logger
|
| 13 |
from tests.base_test import BaseMongoTest
|
| 14 |
|
|
|
|
| 25 |
def test_init_functionality(self):
|
| 26 |
"""Test the init function's ability to create, drop, and preserve collections."""
|
| 27 |
self.assertIn(self.test_collection, self.db.list_collection_names())
|
| 28 |
+
account_repo.create_account("Persist Test", "Doctor", collection_name=self.test_collection)
|
|
|
|
| 29 |
account_repo.init(collection_name=self.test_collection, drop=False)
|
| 30 |
self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
|
|
|
|
| 31 |
account_repo.init(collection_name=self.test_collection, drop=True)
|
| 32 |
self.assertEqual(get_collection(self.test_collection).count_documents({}), 0)
|
| 33 |
|
| 34 |
def test_create_account(self):
|
| 35 |
"""Test successful account creation, including optional fields."""
|
|
|
|
| 36 |
name, role = "Test Doctor", "Doctor"
|
| 37 |
account_id = account_repo.create_account(name=name, role=role, collection_name=self.test_collection)
|
| 38 |
self.assertIsInstance(account_id, str)
|
| 39 |
doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
|
| 40 |
self.assertIsNotNone(doc)
|
| 41 |
self.assertEqual(doc["name"], name) # type: ignore
|
|
|
|
| 42 |
spec_id = account_repo.create_account("Spec", "Nurse", specialty="Cardiology", collection_name=self.test_collection)
|
| 43 |
spec_doc = self.get_doc_by_id(Collections.ACCOUNT, spec_id)
|
| 44 |
self.assertEqual(spec_doc["specialty"], "Cardiology") # type: ignore
|
|
|
|
| 48 |
account_id = account_repo.create_account("Update Logic", "Doctor", collection_name=self.test_collection)
|
| 49 |
original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
|
| 50 |
self.assertIsNotNone(original_doc)
|
|
|
|
|
|
|
| 51 |
updates = {"name": "Updated Name", "created_at": datetime(2000, 1, 1)}
|
| 52 |
success = account_repo.update_account(account_id, updates, collection_name=self.test_collection)
|
| 53 |
self.assertTrue(success)
|
|
|
|
| 55 |
self.assertIsNotNone(updated_doc)
|
| 56 |
self.assertEqual(updated_doc["created_at"], original_doc["created_at"]) # type: ignore
|
| 57 |
self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
|
|
|
|
|
|
|
| 58 |
self.assertFalse(account_repo.update_account(str(ObjectId()), {"name": "No One"}, collection_name=self.test_collection))
|
| 59 |
|
| 60 |
def test_get_account_logic(self):
|
| 61 |
+
"""Test that get_account updates 'last_seen' and returns a valid Account model."""
|
| 62 |
account_id = account_repo.create_account("GetMe", "Doctor", collection_name=self.test_collection)
|
| 63 |
original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
|
| 64 |
self.assertIsNotNone(original_doc)
|
| 65 |
+
time.sleep(0.01) # Ensure timestamp will be different
|
| 66 |
|
|
|
|
| 67 |
account = account_repo.get_account(account_id, collection_name=self.test_collection)
|
| 68 |
self.assertIsNotNone(account)
|
| 69 |
+
self.assertIsInstance(account, Account)
|
| 70 |
+
self.assertLess(original_doc["last_seen"], account.last_seen) # type: ignore
|
| 71 |
+
self.assertEqual(original_doc["updated_at"], account.updated_at) # type: ignore
|
| 72 |
+
self.assertIsNone(account_repo.get_account(str(ObjectId()), collection_name=self.test_collection))
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def test_get_account_by_name(self):
|
| 75 |
"""Test retrieving an account by name and check for deprecation warning."""
|
| 76 |
name = "FindByName"
|
| 77 |
account_repo.create_account(name, "Nurse", collection_name=self.test_collection)
|
|
|
|
|
|
|
| 78 |
account = account_repo.get_account_by_name(name, collection_name=self.test_collection)
|
| 79 |
self.assertIsNotNone(account)
|
| 80 |
+
self.assertIsInstance(account, Account)
|
| 81 |
+
self.assertEqual(account.name, name) # type: ignore
|
|
|
|
| 82 |
self.assertIsNone(account_repo.get_account_by_name("NonExistent", collection_name=self.test_collection))
|
| 83 |
|
| 84 |
def test_search_accounts(self):
|
| 85 |
+
"""Test search functionality returns a list of Account models."""
|
| 86 |
account_repo.create_account("Alpha Doctor", "Doctor", collection_name=self.test_collection)
|
| 87 |
account_repo.create_account("Beta Nurse", "Nurse", collection_name=self.test_collection)
|
| 88 |
+
results = account_repo.search_accounts("alpha", collection_name=self.test_collection)
|
| 89 |
+
self.assertEqual(len(results), 1)
|
| 90 |
+
self.assertIsInstance(results[0], Account)
|
| 91 |
+
self.assertEqual(results[0].name, "Alpha Doctor")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
self.assertEqual(len(account_repo.search_accounts("NonExistent", collection_name=self.test_collection)), 0)
|
|
|
|
|
|
|
| 93 |
|
| 94 |
def test_get_all_accounts(self):
|
| 95 |
+
"""Test retrieving all accounts, verifying sorting and model type."""
|
| 96 |
account_repo.create_account("Charlie", "Doctor", collection_name=self.test_collection)
|
| 97 |
account_repo.create_account("Alpha", "Nurse", collection_name=self.test_collection)
|
|
|
|
|
|
|
| 98 |
all_accounts = account_repo.get_all_accounts(collection_name=self.test_collection)
|
| 99 |
+
self.assertEqual(len(all_accounts), 2)
|
| 100 |
+
self.assertIsInstance(all_accounts[0], Account)
|
| 101 |
+
self.assertEqual(all_accounts[0].name, "Alpha")
|
| 102 |
+
self.assertEqual(all_accounts[1].name, "Charlie")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
def test_get_account_frame(self):
|
| 105 |
+
"""Test retrieving accounts as a pandas DataFrame."""
|
|
|
|
| 106 |
df_empty = account_repo.get_account_frame(collection_name=self.test_collection)
|
| 107 |
self.assertTrue(df_empty.empty)
|
|
|
|
| 108 |
account_repo.create_account("Frame Alpha", "Doctor", collection_name=self.test_collection)
|
| 109 |
df_full = account_repo.get_account_frame(collection_name=self.test_collection)
|
| 110 |
self.assertEqual(len(df_full), 1)
|
|
|
|
| 159 |
if __name__ == "__main__":
|
| 160 |
logger().info("Starting MongoDB repository integration tests...")
|
| 161 |
unittest.main(verbosity=2)
|
| 162 |
+
logger().info("Tests completed.")
|
tests/test_patient.py
CHANGED
|
@@ -6,6 +6,7 @@ from pymongo.errors import ConnectionFailure
|
|
| 6 |
|
| 7 |
from src.data.connection import ActionFailed, Collections, get_collection
|
| 8 |
from src.data.repositories import patient as patient_repo
|
|
|
|
| 9 |
from src.utils.logger import logger
|
| 10 |
from tests.base_test import BaseMongoTest
|
| 11 |
|
|
@@ -27,7 +28,6 @@ class TestPatientRepository(BaseMongoTest):
|
|
| 27 |
|
| 28 |
def test_create_patient(self):
|
| 29 |
"""Test patient creation with minimal and full data."""
|
| 30 |
-
# Test minimal creation
|
| 31 |
patient_id = patient_repo.create_patient(
|
| 32 |
"John Doe", 45, "Male", "Caucasian", collection_name=self.test_collection
|
| 33 |
)
|
|
@@ -36,41 +36,26 @@ class TestPatientRepository(BaseMongoTest):
|
|
| 36 |
self.assertIsNotNone(doc)
|
| 37 |
self.assertEqual(doc["name"], "John Doe") # type: ignore
|
| 38 |
|
| 39 |
-
# Test full creation with all optional parameters
|
| 40 |
doctor_id = str(ObjectId())
|
| 41 |
full_id = patient_repo.create_patient(
|
| 42 |
-
name="Jane Doe",
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
address="123 Wellness Way",
|
| 47 |
-
phone="555-123-4567",
|
| 48 |
-
email="jane.doe@example.com",
|
| 49 |
-
medications=["Lisinopril", "Metformin"],
|
| 50 |
-
past_assessment_summary="Routine check-up, generally healthy.",
|
| 51 |
-
assigned_doctor_id=doctor_id,
|
| 52 |
-
collection_name=self.test_collection
|
| 53 |
)
|
| 54 |
-
self.assertIsInstance(full_id, str)
|
| 55 |
full_doc = self.get_doc_by_id(Collections.PATIENT, full_id)
|
| 56 |
-
|
| 57 |
-
# Verify all fields were saved correctly
|
| 58 |
self.assertIsNotNone(full_doc)
|
| 59 |
-
self.assertEqual(full_doc["address"], "123 Wellness Way") # type: ignore
|
| 60 |
-
self.assertEqual(full_doc["phone"], "555-123-4567") # type: ignore
|
| 61 |
self.assertEqual(full_doc["email"], "jane.doe@example.com") # type: ignore
|
| 62 |
-
self.assertEqual(len(full_doc["medications"]), 2) # type: ignore
|
| 63 |
-
self.assertIn("Lisinopril", full_doc["medications"]) # type: ignore
|
| 64 |
-
self.assertEqual(full_doc["past_assessment_summary"], "Routine check-up, generally healthy.") # type: ignore
|
| 65 |
self.assertEqual(str(full_doc["assigned_doctor_id"]), doctor_id) # type: ignore
|
| 66 |
|
| 67 |
def test_get_patient_by_id(self):
|
| 68 |
-
"""Test retrieving an existing patient by ID."""
|
| 69 |
patient_id = patient_repo.create_patient("GetMe", 33, "Female", "Other", collection_name=self.test_collection)
|
| 70 |
patient = patient_repo.get_patient_by_id(patient_id, collection_name=self.test_collection)
|
| 71 |
self.assertIsNotNone(patient)
|
| 72 |
-
self.
|
| 73 |
-
|
|
|
|
| 74 |
self.assertIsNone(patient_repo.get_patient_by_id(str(ObjectId()), collection_name=self.test_collection))
|
| 75 |
|
| 76 |
def test_update_patient_profile(self):
|
|
@@ -80,18 +65,18 @@ class TestPatientRepository(BaseMongoTest):
|
|
| 80 |
modified_count = patient_repo.update_patient_profile(patient_id, updates, collection_name=self.test_collection)
|
| 81 |
self.assertEqual(modified_count, 1)
|
| 82 |
doc = self.get_doc_by_id(Collections.PATIENT, patient_id)
|
|
|
|
| 83 |
self.assertEqual(doc["age"], 26) # type: ignore
|
| 84 |
-
# Test updating a non-existent patient returns a modified count of 0
|
| 85 |
self.assertEqual(patient_repo.update_patient_profile(str(ObjectId()), {"name": "Ghost"}, collection_name=self.test_collection), 0)
|
| 86 |
|
| 87 |
def test_search_patients(self):
|
| 88 |
-
"""Test patient search functionality
|
| 89 |
patient_repo.create_patient("Alice Smith", 30, "Female", "Asian", collection_name=self.test_collection)
|
| 90 |
patient_repo.create_patient("Bob Smith", 45, "Male", "Caucasian", collection_name=self.test_collection)
|
| 91 |
results = patient_repo.search_patients("smith", collection_name=self.test_collection)
|
| 92 |
self.assertEqual(len(results), 2)
|
| 93 |
-
self.
|
| 94 |
-
|
| 95 |
self.assertEqual(len(patient_repo.search_patients("s", limit=1, collection_name=self.test_collection)), 1)
|
| 96 |
|
| 97 |
|
|
@@ -106,11 +91,8 @@ class TestPatientRepositoryExceptions(BaseMongoTest):
|
|
| 106 |
|
| 107 |
def test_write_error_raises_action_failed(self):
|
| 108 |
"""Test that creating or updating with data violating schema raises ActionFailed."""
|
| 109 |
-
# Test creating a patient with a 'sex' value not in the schema's enum
|
| 110 |
with self.assertRaises(ActionFailed):
|
| 111 |
patient_repo.create_patient("Schema Test", 25, "InvalidValue", "Other", collection_name=self.test_collection)
|
| 112 |
-
|
| 113 |
-
# Test updating a patient with an invalid value
|
| 114 |
patient_id = patient_repo.create_patient("UpdateSchema", 30, "Male", "Other", collection_name=self.test_collection)
|
| 115 |
with self.assertRaises(ActionFailed):
|
| 116 |
patient_repo.update_patient_profile(patient_id, {"ethnicity": 123}, collection_name=self.test_collection)
|
|
@@ -128,7 +110,6 @@ class TestPatientRepositoryExceptions(BaseMongoTest):
|
|
| 128 |
def test_all_functions_raise_on_connection_error(self, mock_get_collection):
|
| 129 |
"""Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
|
| 130 |
mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
|
| 131 |
-
|
| 132 |
with self.assertRaises(ActionFailed):
|
| 133 |
patient_repo.init(collection_name=self.test_collection, drop=True)
|
| 134 |
with self.assertRaises(ActionFailed):
|
|
@@ -143,4 +124,4 @@ class TestPatientRepositoryExceptions(BaseMongoTest):
|
|
| 143 |
if __name__ == "__main__":
|
| 144 |
logger().info("Starting MongoDB repository integration tests...")
|
| 145 |
unittest.main(verbosity=2)
|
| 146 |
-
logger().info("Tests completed
|
|
|
|
| 6 |
|
| 7 |
from src.data.connection import ActionFailed, Collections, get_collection
|
| 8 |
from src.data.repositories import patient as patient_repo
|
| 9 |
+
from src.models.patient import Patient
|
| 10 |
from src.utils.logger import logger
|
| 11 |
from tests.base_test import BaseMongoTest
|
| 12 |
|
|
|
|
| 28 |
|
| 29 |
def test_create_patient(self):
|
| 30 |
"""Test patient creation with minimal and full data."""
|
|
|
|
| 31 |
patient_id = patient_repo.create_patient(
|
| 32 |
"John Doe", 45, "Male", "Caucasian", collection_name=self.test_collection
|
| 33 |
)
|
|
|
|
| 36 |
self.assertIsNotNone(doc)
|
| 37 |
self.assertEqual(doc["name"], "John Doe") # type: ignore
|
| 38 |
|
|
|
|
| 39 |
doctor_id = str(ObjectId())
|
| 40 |
full_id = patient_repo.create_patient(
|
| 41 |
+
name="Jane Doe", age=30, sex="Female", ethnicity="Asian",
|
| 42 |
+
address="123 Wellness Way", phone="555-123-4567", email="jane.doe@example.com",
|
| 43 |
+
medications=["Lisinopril"], past_assessment_summary="Routine check-up.",
|
| 44 |
+
assigned_doctor_id=doctor_id, collection_name=self.test_collection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
)
|
|
|
|
| 46 |
full_doc = self.get_doc_by_id(Collections.PATIENT, full_id)
|
|
|
|
|
|
|
| 47 |
self.assertIsNotNone(full_doc)
|
|
|
|
|
|
|
| 48 |
self.assertEqual(full_doc["email"], "jane.doe@example.com") # type: ignore
|
|
|
|
|
|
|
|
|
|
| 49 |
self.assertEqual(str(full_doc["assigned_doctor_id"]), doctor_id) # type: ignore
|
| 50 |
|
| 51 |
def test_get_patient_by_id(self):
|
| 52 |
+
"""Test retrieving an existing patient by ID returns a Patient model."""
|
| 53 |
patient_id = patient_repo.create_patient("GetMe", 33, "Female", "Other", collection_name=self.test_collection)
|
| 54 |
patient = patient_repo.get_patient_by_id(patient_id, collection_name=self.test_collection)
|
| 55 |
self.assertIsNotNone(patient)
|
| 56 |
+
self.assertIsInstance(patient, Patient)
|
| 57 |
+
self.assertEqual(patient.id, patient_id) # type: ignore
|
| 58 |
+
self.assertEqual(patient.name, "GetMe") # type: ignore
|
| 59 |
self.assertIsNone(patient_repo.get_patient_by_id(str(ObjectId()), collection_name=self.test_collection))
|
| 60 |
|
| 61 |
def test_update_patient_profile(self):
|
|
|
|
| 65 |
modified_count = patient_repo.update_patient_profile(patient_id, updates, collection_name=self.test_collection)
|
| 66 |
self.assertEqual(modified_count, 1)
|
| 67 |
doc = self.get_doc_by_id(Collections.PATIENT, patient_id)
|
| 68 |
+
self.assertIsNotNone(doc)
|
| 69 |
self.assertEqual(doc["age"], 26) # type: ignore
|
|
|
|
| 70 |
self.assertEqual(patient_repo.update_patient_profile(str(ObjectId()), {"name": "Ghost"}, collection_name=self.test_collection), 0)
|
| 71 |
|
| 72 |
def test_search_patients(self):
|
| 73 |
+
"""Test patient search functionality returns a list of Patient models."""
|
| 74 |
patient_repo.create_patient("Alice Smith", 30, "Female", "Asian", collection_name=self.test_collection)
|
| 75 |
patient_repo.create_patient("Bob Smith", 45, "Male", "Caucasian", collection_name=self.test_collection)
|
| 76 |
results = patient_repo.search_patients("smith", collection_name=self.test_collection)
|
| 77 |
self.assertEqual(len(results), 2)
|
| 78 |
+
self.assertIsInstance(results[0], Patient)
|
| 79 |
+
self.assertEqual(results[0].name, "Alice Smith")
|
| 80 |
self.assertEqual(len(patient_repo.search_patients("s", limit=1, collection_name=self.test_collection)), 1)
|
| 81 |
|
| 82 |
|
|
|
|
| 91 |
|
| 92 |
def test_write_error_raises_action_failed(self):
|
| 93 |
"""Test that creating or updating with data violating schema raises ActionFailed."""
|
|
|
|
| 94 |
with self.assertRaises(ActionFailed):
|
| 95 |
patient_repo.create_patient("Schema Test", 25, "InvalidValue", "Other", collection_name=self.test_collection)
|
|
|
|
|
|
|
| 96 |
patient_id = patient_repo.create_patient("UpdateSchema", 30, "Male", "Other", collection_name=self.test_collection)
|
| 97 |
with self.assertRaises(ActionFailed):
|
| 98 |
patient_repo.update_patient_profile(patient_id, {"ethnicity": 123}, collection_name=self.test_collection)
|
|
|
|
| 110 |
def test_all_functions_raise_on_connection_error(self, mock_get_collection):
|
| 111 |
"""Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
|
| 112 |
mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
|
|
|
|
| 113 |
with self.assertRaises(ActionFailed):
|
| 114 |
patient_repo.init(collection_name=self.test_collection, drop=True)
|
| 115 |
with self.assertRaises(ActionFailed):
|
|
|
|
| 124 |
if __name__ == "__main__":
|
| 125 |
logger().info("Starting MongoDB repository integration tests...")
|
| 126 |
unittest.main(verbosity=2)
|
| 127 |
+
logger().info("Tests completed.")
|
tests/test_session.py
CHANGED
|
@@ -8,6 +8,7 @@ from pymongo.errors import ConnectionFailure, WriteError
|
|
| 8 |
|
| 9 |
from src.data.connection import ActionFailed, Collections, get_collection
|
| 10 |
from src.data.repositories import session as session_repo
|
|
|
|
| 11 |
from src.utils.logger import logger
|
| 12 |
from tests.base_test import BaseMongoTest
|
| 13 |
|
|
@@ -34,46 +35,58 @@ class TestSessionRepository(BaseMongoTest):
|
|
| 34 |
session = session_repo.create_session(
|
| 35 |
self.account_id, self.patient_id, "Test Chat", collection_name=self.test_collection
|
| 36 |
)
|
| 37 |
-
self.assertIsInstance(session
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
self.assertIsNotNone(retrieved)
|
| 40 |
-
self.
|
| 41 |
-
|
|
|
|
| 42 |
self.assertIsNone(session_repo.get_session(str(ObjectId()), collection_name=self.test_collection))
|
| 43 |
|
| 44 |
def test_add_and_get_messages(self):
|
| 45 |
-
"""Test adding messages and retrieving them
|
| 46 |
session = session_repo.create_session(self.account_id, self.patient_id, "Msg Test", collection_name=self.test_collection)
|
| 47 |
-
|
| 48 |
-
session_repo.add_message(session
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
self.assertEqual(len(messages), 2)
|
| 51 |
-
self.
|
| 52 |
-
self.assertEqual(messages[
|
| 53 |
-
|
| 54 |
-
self.assertEqual(len(session_repo.get_session_messages(session
|
| 55 |
|
| 56 |
def test_list_sessions(self):
|
| 57 |
-
"""Test listing sessions for a
|
| 58 |
session_repo.create_session(self.account_id, self.patient_id, "First", collection_name=self.test_collection)
|
| 59 |
-
time.sleep(0.01)
|
| 60 |
s2 = session_repo.create_session(self.account_id, self.patient_id, "Second", collection_name=self.test_collection)
|
| 61 |
-
|
| 62 |
patient_sessions = session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
|
| 63 |
self.assertEqual(len(patient_sessions), 2)
|
| 64 |
-
self.
|
| 65 |
-
|
|
|
|
| 66 |
user_sessions = session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
|
| 67 |
self.assertEqual(len(user_sessions), 2)
|
| 68 |
-
self.assertEqual(user_sessions[0]
|
| 69 |
|
| 70 |
def test_update_session_title(self):
|
| 71 |
"""Test updating a session's title and its timestamp."""
|
| 72 |
session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
|
| 73 |
-
original_doc = self.get_doc_by_id(Collections.SESSION, session
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
self.assertTrue(success)
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
self.assertEqual(updated_doc["title"], "New") # type: ignore
|
| 78 |
self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
|
| 79 |
self.assertFalse(session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection))
|
|
@@ -81,27 +94,21 @@ class TestSessionRepository(BaseMongoTest):
|
|
| 81 |
def test_delete_session(self):
|
| 82 |
"""Test deleting a session."""
|
| 83 |
session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
|
| 84 |
-
self.assertTrue(session_repo.delete_session(session
|
| 85 |
-
self.assertIsNone(session_repo.get_session(session
|
| 86 |
self.assertFalse(session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection))
|
| 87 |
|
| 88 |
def test_prune_old_sessions(self):
|
| 89 |
"""Test deleting sessions older than a specified number of days."""
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
session_repo.create_session(self.account_id, self.patient_id, "New Session", collection_name=self.test_collection)
|
| 93 |
-
|
| 94 |
-
# Manually update one session to be old
|
| 95 |
old_date = datetime.now(timezone.utc) - timedelta(days=31)
|
| 96 |
get_collection(self.test_collection).update_one(
|
| 97 |
-
{"_id": ObjectId(old_session
|
| 98 |
-
{"$set": {"updated_at": old_date}}
|
| 99 |
)
|
| 100 |
-
|
| 101 |
self.assertEqual(get_collection(self.test_collection).count_documents({}), 2)
|
| 102 |
deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
|
| 103 |
self.assertEqual(deleted_count, 1)
|
| 104 |
-
self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
|
| 105 |
|
| 106 |
|
| 107 |
class TestSessionRepositoryExceptions(BaseMongoTest):
|
|
@@ -118,14 +125,11 @@ class TestSessionRepositoryExceptions(BaseMongoTest):
|
|
| 118 |
@patch('src.data.repositories.session.get_collection')
|
| 119 |
def test_write_error_raises_action_failed(self, mock_get_collection):
|
| 120 |
"""Test that a WriteError during an operation is raised as ActionFailed."""
|
| 121 |
-
# Configure the mock to return a collection object whose methods raise errors
|
| 122 |
mock_collection = mock_get_collection.return_value
|
| 123 |
mock_collection.update_one.side_effect = WriteError("Simulated schema validation error")
|
| 124 |
-
mock_collection.find_one.return_value = {"messages": []}
|
| 125 |
-
|
| 126 |
with self.assertRaises(ActionFailed):
|
| 127 |
-
|
| 128 |
-
session_repo.add_message("68e212480769b3f99015f43c", "content", True, collection_name=self.test_collection)
|
| 129 |
|
| 130 |
def test_invalid_id_raises_action_failed(self):
|
| 131 |
"""Test that functions raise ActionFailed when given a malformed ObjectId string."""
|
|
@@ -150,7 +154,6 @@ class TestSessionRepositoryExceptions(BaseMongoTest):
|
|
| 150 |
def test_all_functions_raise_on_connection_error(self, mock_get_collection):
|
| 151 |
"""Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
|
| 152 |
mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
|
| 153 |
-
|
| 154 |
with self.assertRaises(ActionFailed):
|
| 155 |
session_repo.init(collection_name=self.test_collection, drop=True)
|
| 156 |
with self.assertRaises(ActionFailed):
|
|
@@ -175,4 +178,4 @@ class TestSessionRepositoryExceptions(BaseMongoTest):
|
|
| 175 |
if __name__ == "__main__":
|
| 176 |
logger().info("Starting MongoDB repository integration tests...")
|
| 177 |
unittest.main(verbosity=2)
|
| 178 |
-
logger().info("Tests completed
|
|
|
|
| 8 |
|
| 9 |
from src.data.connection import ActionFailed, Collections, get_collection
|
| 10 |
from src.data.repositories import session as session_repo
|
| 11 |
+
from src.models.session import Message, Session
|
| 12 |
from src.utils.logger import logger
|
| 13 |
from tests.base_test import BaseMongoTest
|
| 14 |
|
|
|
|
| 35 |
session = session_repo.create_session(
|
| 36 |
self.account_id, self.patient_id, "Test Chat", collection_name=self.test_collection
|
| 37 |
)
|
| 38 |
+
self.assertIsInstance(session, Session)
|
| 39 |
+
self.assertEqual(session.title, "Test Chat")
|
| 40 |
+
|
| 41 |
+
retrieved = session_repo.get_session(session.id, collection_name=self.test_collection)
|
| 42 |
self.assertIsNotNone(retrieved)
|
| 43 |
+
self.assertIsInstance(retrieved, Session)
|
| 44 |
+
self.assertEqual(retrieved.id, session.id) # type: ignore
|
| 45 |
+
self.assertEqual(retrieved.account_id, self.account_id) # type: ignore
|
| 46 |
self.assertIsNone(session_repo.get_session(str(ObjectId()), collection_name=self.test_collection))
|
| 47 |
|
| 48 |
def test_add_and_get_messages(self):
|
| 49 |
+
"""Test adding messages and retrieving them as Message models."""
|
| 50 |
session = session_repo.create_session(self.account_id, self.patient_id, "Msg Test", collection_name=self.test_collection)
|
| 51 |
+
|
| 52 |
+
session_repo.add_message(session.id, "User message 1", True, collection_name=self.test_collection)
|
| 53 |
+
# Add a small delay to ensure a distinct timestamp for the next message
|
| 54 |
+
time.sleep(0.01)
|
| 55 |
+
session_repo.add_message(session.id, "AI response 1", False, collection_name=self.test_collection)
|
| 56 |
+
|
| 57 |
+
messages = session_repo.get_session_messages(session.id, collection_name=self.test_collection)
|
| 58 |
self.assertEqual(len(messages), 2)
|
| 59 |
+
self.assertIsInstance(messages[0], Message)
|
| 60 |
+
self.assertEqual(messages[0].content, "AI response 1") # Descending order is now guaranteed
|
| 61 |
+
self.assertEqual(messages[1].id, 0)
|
| 62 |
+
self.assertEqual(len(session_repo.get_session_messages(session.id, limit=1, collection_name=self.test_collection)), 1)
|
| 63 |
|
| 64 |
def test_list_sessions(self):
|
| 65 |
+
"""Test listing sessions for a patient and user returns lists of Session models."""
|
| 66 |
session_repo.create_session(self.account_id, self.patient_id, "First", collection_name=self.test_collection)
|
| 67 |
+
time.sleep(0.01)
|
| 68 |
s2 = session_repo.create_session(self.account_id, self.patient_id, "Second", collection_name=self.test_collection)
|
| 69 |
+
|
| 70 |
patient_sessions = session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
|
| 71 |
self.assertEqual(len(patient_sessions), 2)
|
| 72 |
+
self.assertIsInstance(patient_sessions[0], Session)
|
| 73 |
+
self.assertEqual(patient_sessions[0].id, s2.id)
|
| 74 |
+
|
| 75 |
user_sessions = session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
|
| 76 |
self.assertEqual(len(user_sessions), 2)
|
| 77 |
+
self.assertEqual(user_sessions[0].id, s2.id)
|
| 78 |
|
| 79 |
def test_update_session_title(self):
|
| 80 |
"""Test updating a session's title and its timestamp."""
|
| 81 |
session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
|
| 82 |
+
original_doc = self.get_doc_by_id(Collections.SESSION, session.id)
|
| 83 |
+
self.assertIsNotNone(original_doc)
|
| 84 |
+
|
| 85 |
+
success = session_repo.update_session_title(session.id, "New", collection_name=self.test_collection)
|
| 86 |
self.assertTrue(success)
|
| 87 |
+
|
| 88 |
+
updated_doc = self.get_doc_by_id(Collections.SESSION, session.id)
|
| 89 |
+
self.assertIsNotNone(updated_doc)
|
| 90 |
self.assertEqual(updated_doc["title"], "New") # type: ignore
|
| 91 |
self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
|
| 92 |
self.assertFalse(session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection))
|
|
|
|
| 94 |
def test_delete_session(self):
|
| 95 |
"""Test deleting a session."""
|
| 96 |
session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
|
| 97 |
+
self.assertTrue(session_repo.delete_session(session.id, collection_name=self.test_collection))
|
| 98 |
+
self.assertIsNone(session_repo.get_session(session.id, collection_name=self.test_collection))
|
| 99 |
self.assertFalse(session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection))
|
| 100 |
|
| 101 |
def test_prune_old_sessions(self):
|
| 102 |
"""Test deleting sessions older than a specified number of days."""
|
| 103 |
+
old_session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
|
| 104 |
+
session_repo.create_session(self.account_id, self.patient_id, "New", collection_name=self.test_collection)
|
|
|
|
|
|
|
|
|
|
| 105 |
old_date = datetime.now(timezone.utc) - timedelta(days=31)
|
| 106 |
get_collection(self.test_collection).update_one(
|
| 107 |
+
{"_id": ObjectId(old_session.id)}, {"$set": {"updated_at": old_date}}
|
|
|
|
| 108 |
)
|
|
|
|
| 109 |
self.assertEqual(get_collection(self.test_collection).count_documents({}), 2)
|
| 110 |
deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
|
| 111 |
self.assertEqual(deleted_count, 1)
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
class TestSessionRepositoryExceptions(BaseMongoTest):
|
|
|
|
| 125 |
@patch('src.data.repositories.session.get_collection')
|
| 126 |
def test_write_error_raises_action_failed(self, mock_get_collection):
|
| 127 |
"""Test that a WriteError during an operation is raised as ActionFailed."""
|
|
|
|
| 128 |
mock_collection = mock_get_collection.return_value
|
| 129 |
mock_collection.update_one.side_effect = WriteError("Simulated schema validation error")
|
| 130 |
+
mock_collection.find_one.return_value = {"messages": []}
|
|
|
|
| 131 |
with self.assertRaises(ActionFailed):
|
| 132 |
+
session_repo.add_message(str(ObjectId()), "content", True, collection_name=self.test_collection)
|
|
|
|
| 133 |
|
| 134 |
def test_invalid_id_raises_action_failed(self):
|
| 135 |
"""Test that functions raise ActionFailed when given a malformed ObjectId string."""
|
|
|
|
| 154 |
def test_all_functions_raise_on_connection_error(self, mock_get_collection):
|
| 155 |
"""Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
|
| 156 |
mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
|
|
|
|
| 157 |
with self.assertRaises(ActionFailed):
|
| 158 |
session_repo.init(collection_name=self.test_collection, drop=True)
|
| 159 |
with self.assertRaises(ActionFailed):
|
|
|
|
| 178 |
if __name__ == "__main__":
|
| 179 |
logger().info("Starting MongoDB repository integration tests...")
|
| 180 |
unittest.main(verbosity=2)
|
| 181 |
+
logger().info("Tests completed.")
|