dylanglenister commited on
Commit
57635b5
·
1 Parent(s): 464bdf6

Updated tests to match the new models.

Browse files
Files changed (3) hide show
  1. tests/test_account.py +23 -51
  2. tests/test_patient.py +14 -33
  3. tests/test_session.py +42 -39
tests/test_account.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import unittest
2
  from datetime import datetime
3
  from unittest.mock import patch
@@ -7,6 +8,7 @@ from pymongo.errors import ConnectionFailure
7
 
8
  from src.data.connection import ActionFailed, Collections, get_collection
9
  from src.data.repositories import account as account_repo
 
10
  from src.utils.logger import logger
11
  from tests.base_test import BaseMongoTest
12
 
@@ -23,24 +25,20 @@ class TestAccountRepository(BaseMongoTest):
23
  def test_init_functionality(self):
24
  """Test the init function's ability to create, drop, and preserve collections."""
25
  self.assertIn(self.test_collection, self.db.list_collection_names())
26
- # Test that data persists when drop=False
27
- account_id = account_repo.create_account("Persist Test", "Doctor", collection_name=self.test_collection)
28
  account_repo.init(collection_name=self.test_collection, drop=False)
29
  self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
30
- # Test that data is deleted when drop=True
31
  account_repo.init(collection_name=self.test_collection, drop=True)
32
  self.assertEqual(get_collection(self.test_collection).count_documents({}), 0)
33
 
34
  def test_create_account(self):
35
  """Test successful account creation, including optional fields."""
36
- # Test basic creation
37
  name, role = "Test Doctor", "Doctor"
38
  account_id = account_repo.create_account(name=name, role=role, collection_name=self.test_collection)
39
  self.assertIsInstance(account_id, str)
40
  doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
41
  self.assertIsNotNone(doc)
42
  self.assertEqual(doc["name"], name) # type: ignore
43
- # Test creation with specialty
44
  spec_id = account_repo.create_account("Spec", "Nurse", specialty="Cardiology", collection_name=self.test_collection)
45
  spec_doc = self.get_doc_by_id(Collections.ACCOUNT, spec_id)
46
  self.assertEqual(spec_doc["specialty"], "Cardiology") # type: ignore
@@ -50,8 +48,6 @@ class TestAccountRepository(BaseMongoTest):
50
  account_id = account_repo.create_account("Update Logic", "Doctor", collection_name=self.test_collection)
51
  original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
52
  self.assertIsNotNone(original_doc)
53
-
54
- # Test that 'created_at' is immutable
55
  updates = {"name": "Updated Name", "created_at": datetime(2000, 1, 1)}
56
  success = account_repo.update_account(account_id, updates, collection_name=self.test_collection)
57
  self.assertTrue(success)
@@ -59,80 +55,56 @@ class TestAccountRepository(BaseMongoTest):
59
  self.assertIsNotNone(updated_doc)
60
  self.assertEqual(updated_doc["created_at"], original_doc["created_at"]) # type: ignore
61
  self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
62
-
63
- # Test updating a non-existent account returns False
64
  self.assertFalse(account_repo.update_account(str(ObjectId()), {"name": "No One"}, collection_name=self.test_collection))
65
 
66
  def test_get_account_logic(self):
67
- """Test that get_account updates 'last_seen' but not 'updated_at'."""
68
  account_id = account_repo.create_account("GetMe", "Doctor", collection_name=self.test_collection)
69
  original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
70
  self.assertIsNotNone(original_doc)
71
- self.assertNotIn("last_seen", original_doc) # type: ignore
72
 
73
- # Get the account, which should add 'last_seen'
74
  account = account_repo.get_account(account_id, collection_name=self.test_collection)
75
  self.assertIsNotNone(account)
76
- self.assertIn("last_seen", account) # type: ignore
77
-
78
- # Verify that 'updated_at' was NOT modified by the get operation
79
- final_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
80
- self.assertIsNotNone(final_doc)
81
- self.assertEqual(final_doc["updated_at"], original_doc["updated_at"]) # type: ignore
82
 
83
  def test_get_account_by_name(self):
84
  """Test retrieving an account by name and check for deprecation warning."""
85
  name = "FindByName"
86
  account_repo.create_account(name, "Nurse", collection_name=self.test_collection)
87
-
88
- # Check that the function works and raises the expected warning
89
  account = account_repo.get_account_by_name(name, collection_name=self.test_collection)
90
  self.assertIsNotNone(account)
91
- self.assertEqual(account["name"], name) # type: ignore
92
-
93
- # Test retrieval of a non-existent name returns None
94
  self.assertIsNone(account_repo.get_account_by_name("NonExistent", collection_name=self.test_collection))
95
 
96
  def test_search_accounts(self):
97
- """Test search functionality with various edge cases."""
98
  account_repo.create_account("Alpha Doctor", "Doctor", collection_name=self.test_collection)
99
  account_repo.create_account("Beta Nurse", "Nurse", collection_name=self.test_collection)
100
- account_repo.create_account("Charlie Doctor", "Doctor", collection_name=self.test_collection)
101
-
102
- # Test case-insensitive partial match
103
- self.assertEqual(len(account_repo.search_accounts("alpha", collection_name=self.test_collection)), 1)
104
- # Test query matching multiple documents
105
- self.assertEqual(len(account_repo.search_accounts("Doctor", collection_name=self.test_collection)), 2)
106
- # Test limit parameter
107
- self.assertEqual(len(account_repo.search_accounts("Doctor", limit=1, collection_name=self.test_collection)), 1)
108
- # Test query with no matches
109
  self.assertEqual(len(account_repo.search_accounts("NonExistent", collection_name=self.test_collection)), 0)
110
- # Test empty query string returns an empty list
111
- self.assertEqual(len(account_repo.search_accounts("", collection_name=self.test_collection)), 0)
112
 
113
  def test_get_all_accounts(self):
114
- """Test retrieving all accounts, verifying sorting and limit."""
115
  account_repo.create_account("Charlie", "Doctor", collection_name=self.test_collection)
116
  account_repo.create_account("Alpha", "Nurse", collection_name=self.test_collection)
117
- account_repo.create_account("Beta", "Caregiver", collection_name=self.test_collection)
118
-
119
  all_accounts = account_repo.get_all_accounts(collection_name=self.test_collection)
120
- self.assertEqual(len(all_accounts), 3)
121
- # Verify results are sorted by name
122
- self.assertEqual(all_accounts[0]["name"], "Alpha")
123
- self.assertEqual(all_accounts[2]["name"], "Charlie")
124
-
125
- # Test with a limit
126
- limited_accounts = account_repo.get_all_accounts(limit=2, collection_name=self.test_collection)
127
- self.assertEqual(len(limited_accounts), 2)
128
- self.assertEqual(limited_accounts[1]["name"], "Beta")
129
 
130
  def test_get_account_frame(self):
131
- """Test retrieving accounts as a pandas DataFrame for empty and populated collections."""
132
- # Test with an empty collection
133
  df_empty = account_repo.get_account_frame(collection_name=self.test_collection)
134
  self.assertTrue(df_empty.empty)
135
- # Test with data
136
  account_repo.create_account("Frame Alpha", "Doctor", collection_name=self.test_collection)
137
  df_full = account_repo.get_account_frame(collection_name=self.test_collection)
138
  self.assertEqual(len(df_full), 1)
@@ -187,4 +159,4 @@ class TestAccountRepositoryExceptions(BaseMongoTest):
187
  if __name__ == "__main__":
188
  logger().info("Starting MongoDB repository integration tests...")
189
  unittest.main(verbosity=2)
190
- logger().info("Tests completed and database connection closed.")
 
1
+ import time
2
  import unittest
3
  from datetime import datetime
4
  from unittest.mock import patch
 
8
 
9
  from src.data.connection import ActionFailed, Collections, get_collection
10
  from src.data.repositories import account as account_repo
11
+ from src.models.account import Account
12
  from src.utils.logger import logger
13
  from tests.base_test import BaseMongoTest
14
 
 
25
  def test_init_functionality(self):
26
  """Test the init function's ability to create, drop, and preserve collections."""
27
  self.assertIn(self.test_collection, self.db.list_collection_names())
28
+ account_repo.create_account("Persist Test", "Doctor", collection_name=self.test_collection)
 
29
  account_repo.init(collection_name=self.test_collection, drop=False)
30
  self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
 
31
  account_repo.init(collection_name=self.test_collection, drop=True)
32
  self.assertEqual(get_collection(self.test_collection).count_documents({}), 0)
33
 
34
  def test_create_account(self):
35
  """Test successful account creation, including optional fields."""
 
36
  name, role = "Test Doctor", "Doctor"
37
  account_id = account_repo.create_account(name=name, role=role, collection_name=self.test_collection)
38
  self.assertIsInstance(account_id, str)
39
  doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
40
  self.assertIsNotNone(doc)
41
  self.assertEqual(doc["name"], name) # type: ignore
 
42
  spec_id = account_repo.create_account("Spec", "Nurse", specialty="Cardiology", collection_name=self.test_collection)
43
  spec_doc = self.get_doc_by_id(Collections.ACCOUNT, spec_id)
44
  self.assertEqual(spec_doc["specialty"], "Cardiology") # type: ignore
 
48
  account_id = account_repo.create_account("Update Logic", "Doctor", collection_name=self.test_collection)
49
  original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
50
  self.assertIsNotNone(original_doc)
 
 
51
  updates = {"name": "Updated Name", "created_at": datetime(2000, 1, 1)}
52
  success = account_repo.update_account(account_id, updates, collection_name=self.test_collection)
53
  self.assertTrue(success)
 
55
  self.assertIsNotNone(updated_doc)
56
  self.assertEqual(updated_doc["created_at"], original_doc["created_at"]) # type: ignore
57
  self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
 
 
58
  self.assertFalse(account_repo.update_account(str(ObjectId()), {"name": "No One"}, collection_name=self.test_collection))
59
 
60
  def test_get_account_logic(self):
61
+ """Test that get_account updates 'last_seen' and returns a valid Account model."""
62
  account_id = account_repo.create_account("GetMe", "Doctor", collection_name=self.test_collection)
63
  original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
64
  self.assertIsNotNone(original_doc)
65
+ time.sleep(0.01) # Ensure timestamp will be different
66
 
 
67
  account = account_repo.get_account(account_id, collection_name=self.test_collection)
68
  self.assertIsNotNone(account)
69
+ self.assertIsInstance(account, Account)
70
+ self.assertLess(original_doc["last_seen"], account.last_seen) # type: ignore
71
+ self.assertEqual(original_doc["updated_at"], account.updated_at) # type: ignore
72
+ self.assertIsNone(account_repo.get_account(str(ObjectId()), collection_name=self.test_collection))
 
 
73
 
74
  def test_get_account_by_name(self):
75
  """Test retrieving an account by name and check for deprecation warning."""
76
  name = "FindByName"
77
  account_repo.create_account(name, "Nurse", collection_name=self.test_collection)
 
 
78
  account = account_repo.get_account_by_name(name, collection_name=self.test_collection)
79
  self.assertIsNotNone(account)
80
+ self.assertIsInstance(account, Account)
81
+ self.assertEqual(account.name, name) # type: ignore
 
82
  self.assertIsNone(account_repo.get_account_by_name("NonExistent", collection_name=self.test_collection))
83
 
84
  def test_search_accounts(self):
85
+ """Test search functionality returns a list of Account models."""
86
  account_repo.create_account("Alpha Doctor", "Doctor", collection_name=self.test_collection)
87
  account_repo.create_account("Beta Nurse", "Nurse", collection_name=self.test_collection)
88
+ results = account_repo.search_accounts("alpha", collection_name=self.test_collection)
89
+ self.assertEqual(len(results), 1)
90
+ self.assertIsInstance(results[0], Account)
91
+ self.assertEqual(results[0].name, "Alpha Doctor")
 
 
 
 
 
92
  self.assertEqual(len(account_repo.search_accounts("NonExistent", collection_name=self.test_collection)), 0)
 
 
93
 
94
  def test_get_all_accounts(self):
95
+ """Test retrieving all accounts, verifying sorting and model type."""
96
  account_repo.create_account("Charlie", "Doctor", collection_name=self.test_collection)
97
  account_repo.create_account("Alpha", "Nurse", collection_name=self.test_collection)
 
 
98
  all_accounts = account_repo.get_all_accounts(collection_name=self.test_collection)
99
+ self.assertEqual(len(all_accounts), 2)
100
+ self.assertIsInstance(all_accounts[0], Account)
101
+ self.assertEqual(all_accounts[0].name, "Alpha")
102
+ self.assertEqual(all_accounts[1].name, "Charlie")
 
 
 
 
 
103
 
104
  def test_get_account_frame(self):
105
+ """Test retrieving accounts as a pandas DataFrame."""
 
106
  df_empty = account_repo.get_account_frame(collection_name=self.test_collection)
107
  self.assertTrue(df_empty.empty)
 
108
  account_repo.create_account("Frame Alpha", "Doctor", collection_name=self.test_collection)
109
  df_full = account_repo.get_account_frame(collection_name=self.test_collection)
110
  self.assertEqual(len(df_full), 1)
 
159
  if __name__ == "__main__":
160
  logger().info("Starting MongoDB repository integration tests...")
161
  unittest.main(verbosity=2)
162
+ logger().info("Tests completed.")
tests/test_patient.py CHANGED
@@ -6,6 +6,7 @@ from pymongo.errors import ConnectionFailure
6
 
7
  from src.data.connection import ActionFailed, Collections, get_collection
8
  from src.data.repositories import patient as patient_repo
 
9
  from src.utils.logger import logger
10
  from tests.base_test import BaseMongoTest
11
 
@@ -27,7 +28,6 @@ class TestPatientRepository(BaseMongoTest):
27
 
28
  def test_create_patient(self):
29
  """Test patient creation with minimal and full data."""
30
- # Test minimal creation
31
  patient_id = patient_repo.create_patient(
32
  "John Doe", 45, "Male", "Caucasian", collection_name=self.test_collection
33
  )
@@ -36,41 +36,26 @@ class TestPatientRepository(BaseMongoTest):
36
  self.assertIsNotNone(doc)
37
  self.assertEqual(doc["name"], "John Doe") # type: ignore
38
 
39
- # Test full creation with all optional parameters
40
  doctor_id = str(ObjectId())
41
  full_id = patient_repo.create_patient(
42
- name="Jane Doe",
43
- age=30,
44
- sex="Female",
45
- ethnicity="Asian",
46
- address="123 Wellness Way",
47
- phone="555-123-4567",
48
- email="jane.doe@example.com",
49
- medications=["Lisinopril", "Metformin"],
50
- past_assessment_summary="Routine check-up, generally healthy.",
51
- assigned_doctor_id=doctor_id,
52
- collection_name=self.test_collection
53
  )
54
- self.assertIsInstance(full_id, str)
55
  full_doc = self.get_doc_by_id(Collections.PATIENT, full_id)
56
-
57
- # Verify all fields were saved correctly
58
  self.assertIsNotNone(full_doc)
59
- self.assertEqual(full_doc["address"], "123 Wellness Way") # type: ignore
60
- self.assertEqual(full_doc["phone"], "555-123-4567") # type: ignore
61
  self.assertEqual(full_doc["email"], "jane.doe@example.com") # type: ignore
62
- self.assertEqual(len(full_doc["medications"]), 2) # type: ignore
63
- self.assertIn("Lisinopril", full_doc["medications"]) # type: ignore
64
- self.assertEqual(full_doc["past_assessment_summary"], "Routine check-up, generally healthy.") # type: ignore
65
  self.assertEqual(str(full_doc["assigned_doctor_id"]), doctor_id) # type: ignore
66
 
67
  def test_get_patient_by_id(self):
68
- """Test retrieving an existing patient by ID."""
69
  patient_id = patient_repo.create_patient("GetMe", 33, "Female", "Other", collection_name=self.test_collection)
70
  patient = patient_repo.get_patient_by_id(patient_id, collection_name=self.test_collection)
71
  self.assertIsNotNone(patient)
72
- self.assertEqual(patient["_id"], patient_id) # type: ignore
73
- # Test retrieval of a non-existent patient returns None
 
74
  self.assertIsNone(patient_repo.get_patient_by_id(str(ObjectId()), collection_name=self.test_collection))
75
 
76
  def test_update_patient_profile(self):
@@ -80,18 +65,18 @@ class TestPatientRepository(BaseMongoTest):
80
  modified_count = patient_repo.update_patient_profile(patient_id, updates, collection_name=self.test_collection)
81
  self.assertEqual(modified_count, 1)
82
  doc = self.get_doc_by_id(Collections.PATIENT, patient_id)
 
83
  self.assertEqual(doc["age"], 26) # type: ignore
84
- # Test updating a non-existent patient returns a modified count of 0
85
  self.assertEqual(patient_repo.update_patient_profile(str(ObjectId()), {"name": "Ghost"}, collection_name=self.test_collection), 0)
86
 
87
  def test_search_patients(self):
88
- """Test patient search functionality with various queries and limits."""
89
  patient_repo.create_patient("Alice Smith", 30, "Female", "Asian", collection_name=self.test_collection)
90
  patient_repo.create_patient("Bob Smith", 45, "Male", "Caucasian", collection_name=self.test_collection)
91
  results = patient_repo.search_patients("smith", collection_name=self.test_collection)
92
  self.assertEqual(len(results), 2)
93
- self.assertEqual(results[0]["name"], "Alice Smith")
94
- # Test limit parameter
95
  self.assertEqual(len(patient_repo.search_patients("s", limit=1, collection_name=self.test_collection)), 1)
96
 
97
 
@@ -106,11 +91,8 @@ class TestPatientRepositoryExceptions(BaseMongoTest):
106
 
107
  def test_write_error_raises_action_failed(self):
108
  """Test that creating or updating with data violating schema raises ActionFailed."""
109
- # Test creating a patient with a 'sex' value not in the schema's enum
110
  with self.assertRaises(ActionFailed):
111
  patient_repo.create_patient("Schema Test", 25, "InvalidValue", "Other", collection_name=self.test_collection)
112
-
113
- # Test updating a patient with an invalid value
114
  patient_id = patient_repo.create_patient("UpdateSchema", 30, "Male", "Other", collection_name=self.test_collection)
115
  with self.assertRaises(ActionFailed):
116
  patient_repo.update_patient_profile(patient_id, {"ethnicity": 123}, collection_name=self.test_collection)
@@ -128,7 +110,6 @@ class TestPatientRepositoryExceptions(BaseMongoTest):
128
  def test_all_functions_raise_on_connection_error(self, mock_get_collection):
129
  """Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
130
  mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
131
-
132
  with self.assertRaises(ActionFailed):
133
  patient_repo.init(collection_name=self.test_collection, drop=True)
134
  with self.assertRaises(ActionFailed):
@@ -143,4 +124,4 @@ class TestPatientRepositoryExceptions(BaseMongoTest):
143
  if __name__ == "__main__":
144
  logger().info("Starting MongoDB repository integration tests...")
145
  unittest.main(verbosity=2)
146
- logger().info("Tests completed and database connection closed.")
 
6
 
7
  from src.data.connection import ActionFailed, Collections, get_collection
8
  from src.data.repositories import patient as patient_repo
9
+ from src.models.patient import Patient
10
  from src.utils.logger import logger
11
  from tests.base_test import BaseMongoTest
12
 
 
28
 
29
  def test_create_patient(self):
30
  """Test patient creation with minimal and full data."""
 
31
  patient_id = patient_repo.create_patient(
32
  "John Doe", 45, "Male", "Caucasian", collection_name=self.test_collection
33
  )
 
36
  self.assertIsNotNone(doc)
37
  self.assertEqual(doc["name"], "John Doe") # type: ignore
38
 
 
39
  doctor_id = str(ObjectId())
40
  full_id = patient_repo.create_patient(
41
+ name="Jane Doe", age=30, sex="Female", ethnicity="Asian",
42
+ address="123 Wellness Way", phone="555-123-4567", email="jane.doe@example.com",
43
+ medications=["Lisinopril"], past_assessment_summary="Routine check-up.",
44
+ assigned_doctor_id=doctor_id, collection_name=self.test_collection
 
 
 
 
 
 
 
45
  )
 
46
  full_doc = self.get_doc_by_id(Collections.PATIENT, full_id)
 
 
47
  self.assertIsNotNone(full_doc)
 
 
48
  self.assertEqual(full_doc["email"], "jane.doe@example.com") # type: ignore
 
 
 
49
  self.assertEqual(str(full_doc["assigned_doctor_id"]), doctor_id) # type: ignore
50
 
51
  def test_get_patient_by_id(self):
52
+ """Test retrieving an existing patient by ID returns a Patient model."""
53
  patient_id = patient_repo.create_patient("GetMe", 33, "Female", "Other", collection_name=self.test_collection)
54
  patient = patient_repo.get_patient_by_id(patient_id, collection_name=self.test_collection)
55
  self.assertIsNotNone(patient)
56
+ self.assertIsInstance(patient, Patient)
57
+ self.assertEqual(patient.id, patient_id) # type: ignore
58
+ self.assertEqual(patient.name, "GetMe") # type: ignore
59
  self.assertIsNone(patient_repo.get_patient_by_id(str(ObjectId()), collection_name=self.test_collection))
60
 
61
  def test_update_patient_profile(self):
 
65
  modified_count = patient_repo.update_patient_profile(patient_id, updates, collection_name=self.test_collection)
66
  self.assertEqual(modified_count, 1)
67
  doc = self.get_doc_by_id(Collections.PATIENT, patient_id)
68
+ self.assertIsNotNone(doc)
69
  self.assertEqual(doc["age"], 26) # type: ignore
 
70
  self.assertEqual(patient_repo.update_patient_profile(str(ObjectId()), {"name": "Ghost"}, collection_name=self.test_collection), 0)
71
 
72
  def test_search_patients(self):
73
+ """Test patient search functionality returns a list of Patient models."""
74
  patient_repo.create_patient("Alice Smith", 30, "Female", "Asian", collection_name=self.test_collection)
75
  patient_repo.create_patient("Bob Smith", 45, "Male", "Caucasian", collection_name=self.test_collection)
76
  results = patient_repo.search_patients("smith", collection_name=self.test_collection)
77
  self.assertEqual(len(results), 2)
78
+ self.assertIsInstance(results[0], Patient)
79
+ self.assertEqual(results[0].name, "Alice Smith")
80
  self.assertEqual(len(patient_repo.search_patients("s", limit=1, collection_name=self.test_collection)), 1)
81
 
82
 
 
91
 
92
  def test_write_error_raises_action_failed(self):
93
  """Test that creating or updating with data violating schema raises ActionFailed."""
 
94
  with self.assertRaises(ActionFailed):
95
  patient_repo.create_patient("Schema Test", 25, "InvalidValue", "Other", collection_name=self.test_collection)
 
 
96
  patient_id = patient_repo.create_patient("UpdateSchema", 30, "Male", "Other", collection_name=self.test_collection)
97
  with self.assertRaises(ActionFailed):
98
  patient_repo.update_patient_profile(patient_id, {"ethnicity": 123}, collection_name=self.test_collection)
 
110
  def test_all_functions_raise_on_connection_error(self, mock_get_collection):
111
  """Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
112
  mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
 
113
  with self.assertRaises(ActionFailed):
114
  patient_repo.init(collection_name=self.test_collection, drop=True)
115
  with self.assertRaises(ActionFailed):
 
124
  if __name__ == "__main__":
125
  logger().info("Starting MongoDB repository integration tests...")
126
  unittest.main(verbosity=2)
127
+ logger().info("Tests completed.")
tests/test_session.py CHANGED
@@ -8,6 +8,7 @@ from pymongo.errors import ConnectionFailure, WriteError
8
 
9
  from src.data.connection import ActionFailed, Collections, get_collection
10
  from src.data.repositories import session as session_repo
 
11
  from src.utils.logger import logger
12
  from tests.base_test import BaseMongoTest
13
 
@@ -34,46 +35,58 @@ class TestSessionRepository(BaseMongoTest):
34
  session = session_repo.create_session(
35
  self.account_id, self.patient_id, "Test Chat", collection_name=self.test_collection
36
  )
37
- self.assertIsInstance(session["_id"], str)
38
- retrieved = session_repo.get_session(session["_id"], collection_name=self.test_collection)
 
 
39
  self.assertIsNotNone(retrieved)
40
- self.assertEqual(retrieved["_id"], session["_id"]) # type: ignore
41
- # Test getting a non-existent session
 
42
  self.assertIsNone(session_repo.get_session(str(ObjectId()), collection_name=self.test_collection))
43
 
44
  def test_add_and_get_messages(self):
45
- """Test adding messages and retrieving them in the correct order."""
46
  session = session_repo.create_session(self.account_id, self.patient_id, "Msg Test", collection_name=self.test_collection)
47
- session_repo.add_message(session["_id"], "User message 1", True, collection_name=self.test_collection)
48
- session_repo.add_message(session["_id"], "AI response 1", False, collection_name=self.test_collection)
49
- messages = session_repo.get_session_messages(session["_id"], collection_name=self.test_collection)
 
 
 
 
50
  self.assertEqual(len(messages), 2)
51
- self.assertEqual(messages[0]["content"], "AI response 1") # Descending order
52
- self.assertEqual(messages[1]["_id"], 0)
53
- # Test limit
54
- self.assertEqual(len(session_repo.get_session_messages(session["_id"], limit=1, collection_name=self.test_collection)), 1)
55
 
56
  def test_list_sessions(self):
57
- """Test listing sessions for a specific patient and user, sorted by update time."""
58
  session_repo.create_session(self.account_id, self.patient_id, "First", collection_name=self.test_collection)
59
- time.sleep(0.01) # Ensure distinct timestamps
60
  s2 = session_repo.create_session(self.account_id, self.patient_id, "Second", collection_name=self.test_collection)
61
- # Test listing for patient
62
  patient_sessions = session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
63
  self.assertEqual(len(patient_sessions), 2)
64
- self.assertEqual(patient_sessions[0]["_id"], s2["_id"]) # Most recent first
65
- # Test listing for user
 
66
  user_sessions = session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
67
  self.assertEqual(len(user_sessions), 2)
68
- self.assertEqual(user_sessions[0]["_id"], s2["_id"])
69
 
70
  def test_update_session_title(self):
71
  """Test updating a session's title and its timestamp."""
72
  session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
73
- original_doc = self.get_doc_by_id(Collections.SESSION, session["_id"])
74
- success = session_repo.update_session_title(session["_id"], "New", collection_name=self.test_collection)
 
 
75
  self.assertTrue(success)
76
- updated_doc = self.get_doc_by_id(Collections.SESSION, session["_id"])
 
 
77
  self.assertEqual(updated_doc["title"], "New") # type: ignore
78
  self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
79
  self.assertFalse(session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection))
@@ -81,27 +94,21 @@ class TestSessionRepository(BaseMongoTest):
81
  def test_delete_session(self):
82
  """Test deleting a session."""
83
  session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
84
- self.assertTrue(session_repo.delete_session(session["_id"], collection_name=self.test_collection))
85
- self.assertIsNone(session_repo.get_session(session["_id"], collection_name=self.test_collection))
86
  self.assertFalse(session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection))
87
 
88
  def test_prune_old_sessions(self):
89
  """Test deleting sessions older than a specified number of days."""
90
- # Create two valid sessions
91
- old_session = session_repo.create_session(self.account_id, self.patient_id, "Old Session", collection_name=self.test_collection)
92
- session_repo.create_session(self.account_id, self.patient_id, "New Session", collection_name=self.test_collection)
93
-
94
- # Manually update one session to be old
95
  old_date = datetime.now(timezone.utc) - timedelta(days=31)
96
  get_collection(self.test_collection).update_one(
97
- {"_id": ObjectId(old_session["_id"])},
98
- {"$set": {"updated_at": old_date}}
99
  )
100
-
101
  self.assertEqual(get_collection(self.test_collection).count_documents({}), 2)
102
  deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
103
  self.assertEqual(deleted_count, 1)
104
- self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
105
 
106
 
107
  class TestSessionRepositoryExceptions(BaseMongoTest):
@@ -118,14 +125,11 @@ class TestSessionRepositoryExceptions(BaseMongoTest):
118
  @patch('src.data.repositories.session.get_collection')
119
  def test_write_error_raises_action_failed(self, mock_get_collection):
120
  """Test that a WriteError during an operation is raised as ActionFailed."""
121
- # Configure the mock to return a collection object whose methods raise errors
122
  mock_collection = mock_get_collection.return_value
123
  mock_collection.update_one.side_effect = WriteError("Simulated schema validation error")
124
- mock_collection.find_one.return_value = {"messages": []} # Needed for add_message to proceed
125
-
126
  with self.assertRaises(ActionFailed):
127
- # This will fail inside at the update_one call, which we've mocked
128
- session_repo.add_message("68e212480769b3f99015f43c", "content", True, collection_name=self.test_collection)
129
 
130
  def test_invalid_id_raises_action_failed(self):
131
  """Test that functions raise ActionFailed when given a malformed ObjectId string."""
@@ -150,7 +154,6 @@ class TestSessionRepositoryExceptions(BaseMongoTest):
150
  def test_all_functions_raise_on_connection_error(self, mock_get_collection):
151
  """Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
152
  mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
153
-
154
  with self.assertRaises(ActionFailed):
155
  session_repo.init(collection_name=self.test_collection, drop=True)
156
  with self.assertRaises(ActionFailed):
@@ -175,4 +178,4 @@ class TestSessionRepositoryExceptions(BaseMongoTest):
175
  if __name__ == "__main__":
176
  logger().info("Starting MongoDB repository integration tests...")
177
  unittest.main(verbosity=2)
178
- logger().info("Tests completed and database connection closed.")
 
8
 
9
  from src.data.connection import ActionFailed, Collections, get_collection
10
  from src.data.repositories import session as session_repo
11
+ from src.models.session import Message, Session
12
  from src.utils.logger import logger
13
  from tests.base_test import BaseMongoTest
14
 
 
35
  session = session_repo.create_session(
36
  self.account_id, self.patient_id, "Test Chat", collection_name=self.test_collection
37
  )
38
+ self.assertIsInstance(session, Session)
39
+ self.assertEqual(session.title, "Test Chat")
40
+
41
+ retrieved = session_repo.get_session(session.id, collection_name=self.test_collection)
42
  self.assertIsNotNone(retrieved)
43
+ self.assertIsInstance(retrieved, Session)
44
+ self.assertEqual(retrieved.id, session.id) # type: ignore
45
+ self.assertEqual(retrieved.account_id, self.account_id) # type: ignore
46
  self.assertIsNone(session_repo.get_session(str(ObjectId()), collection_name=self.test_collection))
47
 
48
  def test_add_and_get_messages(self):
49
+ """Test adding messages and retrieving them as Message models."""
50
  session = session_repo.create_session(self.account_id, self.patient_id, "Msg Test", collection_name=self.test_collection)
51
+
52
+ session_repo.add_message(session.id, "User message 1", True, collection_name=self.test_collection)
53
+ # Add a small delay to ensure a distinct timestamp for the next message
54
+ time.sleep(0.01)
55
+ session_repo.add_message(session.id, "AI response 1", False, collection_name=self.test_collection)
56
+
57
+ messages = session_repo.get_session_messages(session.id, collection_name=self.test_collection)
58
  self.assertEqual(len(messages), 2)
59
+ self.assertIsInstance(messages[0], Message)
60
+ self.assertEqual(messages[0].content, "AI response 1") # Descending order is now guaranteed
61
+ self.assertEqual(messages[1].id, 0)
62
+ self.assertEqual(len(session_repo.get_session_messages(session.id, limit=1, collection_name=self.test_collection)), 1)
63
 
64
  def test_list_sessions(self):
65
+ """Test listing sessions for a patient and user returns lists of Session models."""
66
  session_repo.create_session(self.account_id, self.patient_id, "First", collection_name=self.test_collection)
67
+ time.sleep(0.01)
68
  s2 = session_repo.create_session(self.account_id, self.patient_id, "Second", collection_name=self.test_collection)
69
+
70
  patient_sessions = session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
71
  self.assertEqual(len(patient_sessions), 2)
72
+ self.assertIsInstance(patient_sessions[0], Session)
73
+ self.assertEqual(patient_sessions[0].id, s2.id)
74
+
75
  user_sessions = session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
76
  self.assertEqual(len(user_sessions), 2)
77
+ self.assertEqual(user_sessions[0].id, s2.id)
78
 
79
  def test_update_session_title(self):
80
  """Test updating a session's title and its timestamp."""
81
  session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
82
+ original_doc = self.get_doc_by_id(Collections.SESSION, session.id)
83
+ self.assertIsNotNone(original_doc)
84
+
85
+ success = session_repo.update_session_title(session.id, "New", collection_name=self.test_collection)
86
  self.assertTrue(success)
87
+
88
+ updated_doc = self.get_doc_by_id(Collections.SESSION, session.id)
89
+ self.assertIsNotNone(updated_doc)
90
  self.assertEqual(updated_doc["title"], "New") # type: ignore
91
  self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
92
  self.assertFalse(session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection))
 
94
  def test_delete_session(self):
95
  """Test deleting a session."""
96
  session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
97
+ self.assertTrue(session_repo.delete_session(session.id, collection_name=self.test_collection))
98
+ self.assertIsNone(session_repo.get_session(session.id, collection_name=self.test_collection))
99
  self.assertFalse(session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection))
100
 
101
  def test_prune_old_sessions(self):
102
  """Test deleting sessions older than a specified number of days."""
103
+ old_session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
104
+ session_repo.create_session(self.account_id, self.patient_id, "New", collection_name=self.test_collection)
 
 
 
105
  old_date = datetime.now(timezone.utc) - timedelta(days=31)
106
  get_collection(self.test_collection).update_one(
107
+ {"_id": ObjectId(old_session.id)}, {"$set": {"updated_at": old_date}}
 
108
  )
 
109
  self.assertEqual(get_collection(self.test_collection).count_documents({}), 2)
110
  deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
111
  self.assertEqual(deleted_count, 1)
 
112
 
113
 
114
  class TestSessionRepositoryExceptions(BaseMongoTest):
 
125
  @patch('src.data.repositories.session.get_collection')
126
  def test_write_error_raises_action_failed(self, mock_get_collection):
127
  """Test that a WriteError during an operation is raised as ActionFailed."""
 
128
  mock_collection = mock_get_collection.return_value
129
  mock_collection.update_one.side_effect = WriteError("Simulated schema validation error")
130
+ mock_collection.find_one.return_value = {"messages": []}
 
131
  with self.assertRaises(ActionFailed):
132
+ session_repo.add_message(str(ObjectId()), "content", True, collection_name=self.test_collection)
 
133
 
134
  def test_invalid_id_raises_action_failed(self):
135
  """Test that functions raise ActionFailed when given a malformed ObjectId string."""
 
154
  def test_all_functions_raise_on_connection_error(self, mock_get_collection):
155
  """Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
156
  mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
 
157
  with self.assertRaises(ActionFailed):
158
  session_repo.init(collection_name=self.test_collection, drop=True)
159
  with self.assertRaises(ActionFailed):
 
178
  if __name__ == "__main__":
179
  logger().info("Starting MongoDB repository integration tests...")
180
  unittest.main(verbosity=2)
181
+ logger().info("Tests completed.")