File size: 7,994 Bytes
532e8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# tests/routes/test_session_routes.py

from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI, status
from fastapi.testclient import TestClient

from src.api.routes.session import router as session_router
from src.core.state import get_state
from src.models.session import Message, Session

# --- Test Setup: Mocking and Dependency Injection ---

# Module-level mock shared by every MockAppState instance, so tests can
# configure its methods and inspect the calls made on it directly.
mock_memory_manager = MagicMock()

class MockAppState:
	"""Stand-in for the application state object injected into the session routes."""

	def __init__(self):
		# Every instance exposes the same module-level memory manager mock.
		self.memory_manager = mock_memory_manager
		# Mock rotators for the session endpoint; each instance gets fresh ones.
		self.gemini_rotator, self.nvidia_rotator = MagicMock(), MagicMock()

def override_get_state() -> MockAppState:
	"""Build a new MockAppState; registered as the override for ``get_state``."""
	state = MockAppState()
	return state

# Bare FastAPI app that serves only the session router, with the real
# ``get_state`` dependency replaced by the mock-backed override.
app = FastAPI()
app.dependency_overrides[get_state] = override_get_state
app.include_router(session_router)


# --- Fixtures ---

@pytest.fixture
def client():
	"""Yield a TestClient bound to the test app, closing it after the test."""
	test_client = TestClient(app)
	with test_client:
		yield test_client

@pytest.fixture(autouse=True)
def reset_mocks():
	"""Restore the shared memory-manager mock to a clean state before each test.

	A bare ``reset_mock()`` only clears recorded calls; any ``return_value``
	or ``side_effect`` configured by an earlier test would survive it and
	leak into later tests. Passing both keyword flags clears that
	configuration too (on the mock and on its child mocks), so each test
	starts from an unconfigured mock and must set up what it needs.
	"""
	mock_memory_manager.reset_mock(return_value=True, side_effect=True)


# --- Test Data ---

# Raw payload for a session with no messages. Both timestamps are ISO-format
# strings captured once, when this module is imported.
fake_session_dict = {
	"_id": "session456",
	"account_id": "doctor789",
	"patient_id": "patient123",
	"title": "Checkup",
	"created_at": datetime.now(tz=timezone.utc).isoformat(),
	"updated_at": datetime.now(tz=timezone.utc).isoformat(),
	"messages": [],
}
# Validated Session model that the tests hand back from the mocked manager.
fake_session = Session.model_validate(fake_session_dict)

# Raw payload for a single user-sent message, timestamped at import time.
fake_message_dict = {
	"_id": 1,
	"sent_by_user": True,
	"content": "Hello",
	"timestamp": datetime.now(tz=timezone.utc).isoformat(),
}
# Validated Message model used as the mocked message-list entry.
fake_message = Message.model_validate(fake_message_dict)


# --- Tests for POST /session ---

def test_create_session_success(client: TestClient):
	"""POST /session answers 201 with the session the memory manager created."""
	mock_memory_manager.create_session.return_value = fake_session
	payload = {"account_id": "doctor789", "patient_id": "patient123", "title": "Checkup"}

	resp = client.post("/session", json=payload)

	assert resp.status_code == status.HTTP_201_CREATED
	body = resp.json()
	assert body["title"] == "Checkup"
	# The manager must receive the request's account_id under the name user_id.
	expected_kwargs = {"user_id": "doctor789", "patient_id": "patient123", "title": "Checkup"}
	mock_memory_manager.create_session.assert_called_once_with(**expected_kwargs)

def test_create_session_failure(client: TestClient):
	"""POST /session answers 500 when the memory manager yields no session."""
	# A None result from the manager is the failure signal under test.
	mock_memory_manager.create_session.return_value = None
	payload = {"account_id": "doctor789", "patient_id": "patient123"}

	resp = client.post("/session", json=payload)

	assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
	error_body = resp.json()
	assert error_body["detail"] == "Failed to create session."


# --- Tests for GET /session/{session_id} ---

def test_get_session_success(client: TestClient):
	"""GET /session/{id} answers 200 with the stored session's data."""
	session_id = fake_session.id
	mock_memory_manager.get_session.return_value = fake_session

	resp = client.get(f"/session/{session_id}")

	assert resp.status_code == status.HTTP_200_OK
	body = resp.json()
	assert body["title"] == fake_session.title
	# The lookup must be made exactly once, with the id in string form.
	mock_memory_manager.get_session.assert_called_once_with(str(session_id))

def test_get_session_not_found(client: TestClient):
	"""GET /session/{id} answers 404 when the manager finds no such session."""
	mock_memory_manager.get_session.return_value = None

	resp = client.get("/session/non_existent_id")

	assert resp.status_code == status.HTTP_404_NOT_FOUND
	error_body = resp.json()
	assert error_body["detail"] == "Session not found"


# --- Tests for DELETE /session/{session_id} ---

def test_delete_session_success(client: TestClient):
	"""DELETE /session/{id} answers 204 when the manager reports a deletion."""
	session_id = fake_session.id
	mock_memory_manager.delete_session.return_value = True

	resp = client.delete(f"/session/{session_id}")

	assert resp.status_code == status.HTTP_204_NO_CONTENT
	# Deletion must be requested exactly once, with the id in string form.
	mock_memory_manager.delete_session.assert_called_once_with(str(session_id))

def test_delete_session_not_found(client: TestClient):
	"""DELETE /session/{id} answers 404 when the manager reports nothing deleted."""
	# False from the manager means there was no session to remove.
	mock_memory_manager.delete_session.return_value = False

	resp = client.delete("/session/non_existent_id")

	assert resp.status_code == status.HTTP_404_NOT_FOUND
	error_body = resp.json()
	assert error_body["detail"] == "Session not found or already deleted"


# --- Tests for GET /session/{session_id}/messages ---

def test_list_messages_success(client: TestClient):
	"""GET /session/{id}/messages answers 200 with the session's messages."""
	session_id = fake_session.id
	mock_memory_manager.get_session_messages.return_value = [fake_message]
	mock_memory_manager.get_session.return_value = fake_session

	resp = client.get(f"/session/{session_id}/messages?limit=10")

	assert resp.status_code == status.HTTP_200_OK
	messages = resp.json()
	assert len(messages) == 1
	assert messages[0]["content"] == "Hello"
	# The session is looked up first, then its messages fetched with the limit.
	mock_memory_manager.get_session.assert_called_once_with(str(session_id))
	mock_memory_manager.get_session_messages.assert_called_once_with(str(session_id), 10)

def test_list_messages_session_not_found(client: TestClient):
	"""GET /session/{id}/messages answers 404 when the session does not exist."""
	mock_memory_manager.get_session.return_value = None

	resp = client.get("/session/non_existent_id/messages")

	assert resp.status_code == status.HTTP_404_NOT_FOUND
	error_body = resp.json()
	assert error_body["detail"] == "Session not found"


# --- Tests for POST /session/{session_id}/messages ---

@patch('src.api.routes.session.generate_medical_response', new_callable=AsyncMock)
def test_post_chat_message_success(mock_generate_response: AsyncMock, client: TestClient):
	"""POST /session/{id}/messages answers 200 with the AI reply and its context."""
	# Arrange: every awaited collaborator is replaced with an AsyncMock.
	mock_generate_response.return_value = "This is the AI response."
	mock_memory_manager.get_enhanced_context = AsyncMock(return_value="Enhanced context.")
	mock_memory_manager.process_medical_exchange = AsyncMock(return_value="Generated summary.")
	payload = {"account_id": "doc1", "patient_id": "pat1", "message": "Patient has a fever."}

	# Act
	resp = client.post(f"/session/{fake_session.id}/messages", json=payload)

	# Assert: status and payload
	assert resp.status_code == status.HTTP_200_OK
	body = resp.json()
	assert body["response"] == "This is the AI response."
	assert body["medical_context"] == "Enhanced context."

	# Assert: each async step was awaited exactly once
	mock_memory_manager.get_enhanced_context.assert_awaited_once()
	mock_generate_response.assert_awaited_once()
	mock_memory_manager.process_medical_exchange.assert_awaited_once()

@patch('src.api.routes.session.generate_medical_response', new_callable=AsyncMock)
def test_post_chat_message_context_error(mock_generate_response: AsyncMock, client: TestClient):
	"""POST /session/{id}/messages answers 500 when building the context raises."""
	failing_context = AsyncMock(side_effect=Exception("Context DB failed"))
	mock_memory_manager.get_enhanced_context = failing_context
	payload = {"account_id": "doc1", "patient_id": "pat1", "message": "Patient has a fever."}

	resp = client.post(f"/session/{fake_session.id}/messages", json=payload)

	assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
	error_body = resp.json()
	assert error_body["detail"] == "Failed to build medical context."
	# The request must fail before response generation is ever attempted.
	mock_generate_response.assert_not_awaited()

@patch('src.api.routes.session.generate_medical_response', new_callable=AsyncMock)
def test_post_chat_message_generation_error(mock_generate_response: AsyncMock, client: TestClient):
	"""POST /session/{id}/messages answers 500 when AI response generation raises."""
	# Context building succeeds so that the failure comes from generation.
	mock_memory_manager.get_enhanced_context = AsyncMock(return_value="Context")
	mock_generate_response.side_effect = Exception("AI model API is down")
	payload = {"account_id": "doc1", "patient_id": "pat1", "message": "Patient has a fever."}

	resp = client.post(f"/session/{fake_session.id}/messages", json=payload)

	assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
	error_body = resp.json()
	assert error_body["detail"] == "Failed to generate AI response."