dylanglenister commited on
Commit
2fba483
·
1 Parent(s): 65d3a5d

Added session tests

Browse files
Files changed (1) hide show
  1. tests/test_session.py +187 -0
tests/test_session.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import unittest
3
+ from datetime import datetime, timedelta, timezone
4
+
5
+ from bson import ObjectId
6
+
7
+ from src.data.connection import ActionFailed, Collections, get_collection
8
+ from src.data.repositories import session as session_repo
9
+ from src.utils.logger import logger
10
+ from tests.base_test import BaseMongoTest
11
+
12
+
13
+ class TestSessionRepository(BaseMongoTest):
14
+
15
+ def setUp(self):
16
+ """Set up a clean test environment before each test."""
17
+ super().setUp()
18
+ self.test_collection = self._collections[Collections.SESSION]
19
+ session_repo.init(collection_name=self.test_collection, drop=True)
20
+
21
+ self.account_id = str(ObjectId())
22
+ self.patient_id = str(ObjectId())
23
+
24
+ def test_init_functionality(self):
25
+ """Test that init sets up the collection and indexes correctly."""
26
+ self.assertIn(self.test_collection, self.db.list_collection_names())
27
+ index_info = get_collection(self.test_collection).index_information()
28
+ self.assertIn("messages._id_1", index_info)
29
+
30
+ def test_create_and_get_session(self):
31
+ """Test chat session creation and retrieval by ID."""
32
+ # Test creation
33
+ session = session_repo.create_session(
34
+ self.account_id,
35
+ self.patient_id,
36
+ "Test Chat",
37
+ collection_name=self.test_collection
38
+ )
39
+ self.assertIn("_id", session)
40
+ self.assertIsInstance(session["_id"], str)
41
+ self.assertEqual(session["title"], "Test Chat")
42
+ self.assertEqual(len(session["messages"]), 0)
43
+
44
+ # Test retrieval
45
+ retrieved = session_repo.get_session(session["_id"], collection_name=self.test_collection)
46
+ self.assertIsNotNone(retrieved)
47
+ self.assertEqual(retrieved["_id"], session["_id"]) # type: ignore
48
+ self.assertEqual(retrieved["account_id"], self.account_id) # type: ignore
49
+ self.assertEqual(retrieved["patient_id"], self.patient_id) # type: ignore
50
+
51
+ # Test getting a non-existent session
52
+ non_existent = session_repo.get_session(str(ObjectId()), collection_name=self.test_collection)
53
+ self.assertIsNone(non_existent)
54
+
55
+ def test_add_and_get_messages(self):
56
+ """Test adding messages and retrieving them in the correct order."""
57
+ session = session_repo.create_session(
58
+ self.account_id, self.patient_id, "Message Test", collection_name=self.test_collection
59
+ )
60
+ session_id = session["_id"]
61
+
62
+ # Add messages and verify session's updated_at timestamp changes
63
+ original_doc = self.get_doc_by_id(Collections.SESSION, session_id)
64
+ time.sleep(0.01) # Ensure timestamp will be different
65
+ session_repo.add_message(session_id, "User message 1", True, collection_name=self.test_collection)
66
+ updated_doc = self.get_doc_by_id(Collections.SESSION, session_id)
67
+ self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
68
+
69
+ session_repo.add_message(session_id, "AI response 1", False, collection_name=self.test_collection)
70
+ session_repo.add_message(session_id, "User message 2", True, collection_name=self.test_collection)
71
+
72
+ # Test message retrieval (should be in descending order of creation)
73
+ messages = session_repo.get_session_messages(session_id, collection_name=self.test_collection)
74
+ self.assertEqual(len(messages), 3)
75
+ self.assertEqual(messages[0]["_id"], 2)
76
+ self.assertEqual(messages[0]["content"], "User message 2")
77
+ self.assertEqual(messages[1]["_id"], 1)
78
+ self.assertEqual(messages[2]["_id"], 0)
79
+
80
+ # Test limit
81
+ limited_messages = session_repo.get_session_messages(session_id, limit=2, collection_name=self.test_collection)
82
+ self.assertEqual(len(limited_messages), 2)
83
+ self.assertEqual(limited_messages[0]["_id"], 2)
84
+
85
+ # Test adding message to non-existent session
86
+ with self.assertRaises(ActionFailed):
87
+ session_repo.add_message(str(ObjectId()), "ghost", True, collection_name=self.test_collection)
88
+
89
+ def test_list_patient_sessions(self):
90
+ """Test listing sessions for a specific patient, sorted by update time."""
91
+ p_id_1 = str(ObjectId())
92
+ p_id_2 = str(ObjectId())
93
+
94
+ # Create sessions, sleeping briefly to ensure distinct updated_at times
95
+ session_repo.create_session(self.account_id, p_id_1, "P1 Chat 1", collection_name=self.test_collection)
96
+ time.sleep(0.01)
97
+ session_repo.create_session(self.account_id, p_id_2, "P2 Chat 1", collection_name=self.test_collection) # Belongs to other patient
98
+ time.sleep(0.01)
99
+ s2 = session_repo.create_session(self.account_id, p_id_1, "P1 Chat 2", collection_name=self.test_collection)
100
+
101
+ # Test listing for patient 1
102
+ sessions = session_repo.list_patient_sessions(p_id_1, collection_name=self.test_collection)
103
+ self.assertEqual(len(sessions), 2)
104
+ self.assertEqual(sessions[0]["_id"], s2["_id"]) # Most recently created should be first
105
+
106
+ def test_get_user_sessions(self):
107
+ """Test listing sessions for a specific user, sorted by update time."""
108
+ user1 = str(ObjectId())
109
+ user2 = str(ObjectId())
110
+
111
+ s1 = session_repo.create_session(user1, self.patient_id, "U1 Chat 1", collection_name=self.test_collection)
112
+ time.sleep(0.01)
113
+ session_repo.create_session(user2, self.patient_id, "U2 Chat 1", collection_name=self.test_collection)
114
+ time.sleep(0.01)
115
+ s3 = session_repo.create_session(user1, self.patient_id, "U1 Chat 2", collection_name=self.test_collection)
116
+
117
+ sessions = session_repo.get_user_sessions(user1, collection_name=self.test_collection)
118
+ self.assertEqual(len(sessions), 2)
119
+ self.assertEqual(sessions[0]["_id"], s3["_id"]) # s3 was updated most recently
120
+ self.assertEqual(sessions[1]["_id"], s1["_id"])
121
+
122
+ # Test limit
123
+ sessions_limited = session_repo.get_user_sessions(user1, limit=1, collection_name=self.test_collection)
124
+ self.assertEqual(len(sessions_limited), 1)
125
+
126
+ def test_update_session_title(self):
127
+ """Test updating a session's title."""
128
+ session = session_repo.create_session(self.account_id, self.patient_id, "Old Title", collection_name=self.test_collection)
129
+ session_id = session["_id"]
130
+
131
+ success = session_repo.update_session_title(session_id, "New Title", collection_name=self.test_collection)
132
+ self.assertTrue(success)
133
+
134
+ updated_session = session_repo.get_session(session_id, collection_name=self.test_collection)
135
+ self.assertEqual(updated_session["title"], "New Title") # type: ignore
136
+
137
+ # Test updating non-existent session
138
+ success_fail = session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection)
139
+ self.assertFalse(success_fail)
140
+
141
+ def test_delete_session(self):
142
+ """Test deleting a session."""
143
+ session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
144
+ session_id = session["_id"]
145
+
146
+ success = session_repo.delete_session(session_id, collection_name=self.test_collection)
147
+ self.assertTrue(success)
148
+
149
+ deleted_session = session_repo.get_session(session_id, collection_name=self.test_collection)
150
+ self.assertIsNone(deleted_session)
151
+
152
+ # Test deleting non-existent session
153
+ success_fail = session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection)
154
+ self.assertFalse(success_fail)
155
+
156
+ def test_prune_old_sessions(self):
157
+ """Test deleting sessions older than a specified number of days."""
158
+ coll = get_collection(self.test_collection)
159
+ now = datetime.now(timezone.utc)
160
+ old_date = now - timedelta(days=31)
161
+
162
+ # Manually insert one old and one new session
163
+ coll.insert_one({
164
+ "account_id": ObjectId(self.account_id), "patient_id": ObjectId(self.patient_id),
165
+ "title": "Old Session", "created_at": old_date, "updated_at": old_date, "messages": []
166
+ })
167
+ coll.insert_one({
168
+ "account_id": ObjectId(self.account_id), "patient_id": ObjectId(self.patient_id),
169
+ "title": "New Session", "created_at": now, "updated_at": now, "messages": []
170
+ })
171
+
172
+ self.assertEqual(coll.count_documents({}), 2)
173
+
174
+ deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
175
+ self.assertEqual(deleted_count, 1)
176
+ self.assertEqual(coll.count_documents({}), 1)
177
+
178
+ remaining = coll.find_one()
179
+ self.assertEqual(remaining["title"], "New Session") # type: ignore
180
+
181
+
182
+ if __name__ == "__main__":
183
+ try:
184
+ logger().info("Starting MongoDB repository integration tests...")
185
+ unittest.main(verbosity=2)
186
+ finally:
187
+ logger().info("Tests completed and database connection closed.")