harismlnaslm committed on
Commit
94aafab
·
1 Parent(s): c5e93f3

Integrate training data: Use actual training data instead of mock responses for intelligent AI responses

Browse files
Files changed (1) hide show
  1. app.py +105 -23
app.py CHANGED
@@ -9,13 +9,15 @@ import json
9
  import logging
10
  from pathlib import Path
11
  from datetime import datetime
12
- from typing import Optional, Dict, Any
13
  from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
14
  from fastapi.responses import HTMLResponse, JSONResponse
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from pydantic import BaseModel
17
  import uvicorn
18
  import requests
 
 
19
 
20
  # Setup logging
21
  logging.basicConfig(level=logging.INFO)
@@ -80,14 +82,76 @@ training_status = {
80
  "error": None
81
  }
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  class TextilindoAI:
84
  """Textilindo AI Assistant using HuggingFace Inference API"""
85
 
86
  def __init__(self):
87
  self.api_key = os.getenv('HUGGINGFACE_API_KEY')
88
  # Use a model available on free HuggingFace Inference API
89
- self.model = os.getenv('DEFAULT_MODEL', 'microsoft/DialoGPT-small') # Try DialoGPT-small
90
  self.system_prompt = self.load_system_prompt()
 
91
 
92
  if not self.api_key:
93
  logger.warning("HUGGINGFACE_API_KEY not found. Using mock responses.")
@@ -138,13 +202,21 @@ The company uses yards for sales.
138
  Minimum purchase is 1 roll (67-70 yards)."""
139
 
140
  def generate_response(self, user_message: str) -> str:
141
- """Generate response using HuggingFace Inference API"""
 
 
 
 
 
 
 
 
142
  if not self.client:
143
- logger.warning("No HuggingFace client available, using mock response")
144
- return self.get_mock_response(user_message)
145
 
146
  try:
147
- # Use DialoGPT conversation format
148
  prompt = f"User: {user_message}\nAssistant:"
149
 
150
  logger.info(f"Using model: {self.model}")
@@ -152,45 +224,55 @@ Minimum purchase is 1 roll (67-70 yards)."""
152
 
153
  logger.info(f"Generating response for prompt: {prompt[:100]}...")
154
 
155
- # Generate response with DialoGPT parameters
156
  response = self.client.text_generation(
157
  prompt,
158
- max_new_tokens=200,
159
- temperature=0.7,
160
  top_p=0.9,
161
- top_k=40,
162
- repetition_penalty=1.1,
163
- stop_sequences=["User:", "Assistant:"]
 
164
  )
165
 
166
  logger.info(f"Raw AI response: {response[:200]}...")
167
 
168
- # Clean up the response for DialoGPT
169
  if "Assistant:" in response:
170
  assistant_response = response.split("Assistant:")[-1].strip()
171
  else:
172
  assistant_response = response.strip()
173
 
174
- # Remove any remaining special tokens
175
  assistant_response = assistant_response.replace("<|end|>", "").replace("<|user|>", "").strip()
176
 
 
 
 
 
 
 
 
 
 
 
177
  logger.info(f"Cleaned AI response: {assistant_response[:100]}...")
178
 
179
- # If response is too short or generic, use mock response
180
  if len(assistant_response) < 10 or "I don't know" in assistant_response.lower():
181
- logger.warning("AI response too short, using mock response")
182
- return self.get_mock_response(user_message)
183
-
184
- # For testing: if it's a non-Textilindo question, return the AI response directly
185
- if not any(keyword in user_message.lower() for keyword in ['textilindo', 'lokasi', 'jam', 'katalog', 'produk', 'sample', 'pembelian', 'pembayaran', 'ongkir']):
186
- logger.info("Non-Textilindo question detected, returning AI response directly")
187
- return assistant_response
188
 
189
  return assistant_response
190
 
191
  except Exception as e:
192
  logger.error(f"Error generating response: {e}")
193
- return self.get_mock_response(user_message)
 
 
 
 
194
 
195
  def get_mock_response(self, user_message: str) -> str:
196
  """Enhanced mock responses with better context awareness"""
 
9
  import logging
10
  from pathlib import Path
11
  from datetime import datetime
12
+ from typing import Optional, Dict, Any, List
13
  from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
14
  from fastapi.responses import HTMLResponse, JSONResponse
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from pydantic import BaseModel
17
  import uvicorn
18
  import requests
19
+ import re
20
+ from difflib import SequenceMatcher
21
 
22
  # Setup logging
23
  logging.basicConfig(level=logging.INFO)
 
82
  "error": None
83
  }
84
 
85
class TrainingDataLoader:
    """Load and manage instruction/output training samples for intelligent responses.

    Samples are read from a JSONL file where each non-empty line is a JSON
    object expected to carry at least "instruction" and "output" keys.
    """

    def __init__(self, data_path: str = "data/textilindo_training_data.jsonl"):
        # Path to the JSONL training file; loading is best-effort and never raises.
        self.data_path = data_path
        self.training_data = []
        self.load_data()

    def load_data(self):
        """(Re)load training data from the JSONL file.

        Resets any previously loaded samples first, so repeated calls do not
        accumulate duplicates (the original implementation only appended).
        Malformed JSON lines are skipped silently (best-effort load); I/O
        errors are logged and leave the loader empty.
        """
        self.training_data = []  # reset so a reload does not duplicate samples
        try:
            if os.path.exists(self.data_path):
                with open(self.data_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            self.training_data.append(json.loads(line))
                        except json.JSONDecodeError:
                            # Skip malformed lines rather than abort the whole load.
                            continue
                logger.info(f"Loaded {len(self.training_data)} training samples")
            else:
                logger.warning(f"Training data file not found: {self.data_path}")
        except Exception as e:
            logger.error(f"Error loading training data: {e}")

    def find_best_match(self, user_input: str, threshold: float = 0.3) -> Optional[Dict]:
        """Return the training sample whose instruction best matches user_input.

        The score blends character-level SequenceMatcher similarity (weight
        0.7) with a word-overlap ratio relative to the user's words (weight
        0.3). Returns None when no sample reaches `threshold`.
        """
        if not self.training_data:
            return None

        user_input_lower = user_input.lower().strip()
        # Hoisted out of the loop: invariant per call (original rebuilt it per sample).
        user_words = set(user_input_lower.split())
        best_match = None
        best_score = 0.0

        for data in self.training_data:
            instruction = data.get('instruction', '').lower().strip()
            if not instruction:
                continue

            # Fuzzy character-level similarity.
            score = SequenceMatcher(None, user_input_lower, instruction).ratio()

            # Keyword overlap, normalized by the user's word count (guarded against 0).
            instruction_words = set(instruction.split())
            keyword_score = len(user_words & instruction_words) / max(len(user_words), 1)

            # Weighted combination of the two signals.
            combined_score = (score * 0.7) + (keyword_score * 0.3)

            if combined_score > best_score and combined_score >= threshold:
                best_score = combined_score
                best_match = data

        if best_match:
            logger.info(f"Found match with score {best_score:.2f}: {best_match.get('instruction', '')[:50]}...")

        return best_match
145
+
146
  class TextilindoAI:
147
  """Textilindo AI Assistant using HuggingFace Inference API"""
148
 
149
  def __init__(self):
150
  self.api_key = os.getenv('HUGGINGFACE_API_KEY')
151
  # Use a model available on free HuggingFace Inference API
152
+ self.model = os.getenv('DEFAULT_MODEL', 'gpt2') # Use GPT-2 which is available
153
  self.system_prompt = self.load_system_prompt()
154
+ self.data_loader = TrainingDataLoader()
155
 
156
  if not self.api_key:
157
  logger.warning("HUGGINGFACE_API_KEY not found. Using mock responses.")
 
202
  Minimum purchase is 1 roll (67-70 yards)."""
203
 
204
  def generate_response(self, user_message: str) -> str:
205
+ """Generate response using training data and HuggingFace Inference API"""
206
+
207
+ # First, try to find a match in training data
208
+ training_match = self.data_loader.find_best_match(user_message)
209
+ if training_match:
210
+ logger.info("Using training data response")
211
+ return training_match.get('output', '')
212
+
213
+ # If no training data match, try HuggingFace API if available
214
  if not self.client:
215
+ logger.warning("No HuggingFace client available, using fallback response")
216
+ return self.get_fallback_response(user_message)
217
 
218
  try:
219
+ # Use GPT-2 conversation format
220
  prompt = f"User: {user_message}\nAssistant:"
221
 
222
  logger.info(f"Using model: {self.model}")
 
224
 
225
  logger.info(f"Generating response for prompt: {prompt[:100]}...")
226
 
227
+ # Generate response with GPT-2 parameters
228
  response = self.client.text_generation(
229
  prompt,
230
+ max_new_tokens=150,
231
+ temperature=0.8,
232
  top_p=0.9,
233
+ top_k=50,
234
+ repetition_penalty=1.2,
235
+ do_sample=True,
236
+ stop_sequences=["User:", "Assistant:", "\n\n"]
237
  )
238
 
239
  logger.info(f"Raw AI response: {response[:200]}...")
240
 
241
+ # Clean up the response for GPT-2
242
  if "Assistant:" in response:
243
  assistant_response = response.split("Assistant:")[-1].strip()
244
  else:
245
  assistant_response = response.strip()
246
 
247
+ # Remove any remaining special tokens and clean up
248
  assistant_response = assistant_response.replace("<|end|>", "").replace("<|user|>", "").strip()
249
 
250
+ # Remove any incomplete sentences or cut-off text
251
+ if assistant_response.endswith(('.', '!', '?')):
252
+ pass # Complete sentence
253
+ elif '.' in assistant_response:
254
+ # Take only the first complete sentence
255
+ assistant_response = assistant_response.split('.')[0] + '.'
256
+ else:
257
+ # If no complete sentence, take first 100 characters
258
+ assistant_response = assistant_response[:100]
259
+
260
  logger.info(f"Cleaned AI response: {assistant_response[:100]}...")
261
 
262
+ # If response is too short or generic, use fallback
263
  if len(assistant_response) < 10 or "I don't know" in assistant_response.lower():
264
+ logger.warning("AI response too short, using fallback response")
265
+ return self.get_fallback_response(user_message)
 
 
 
 
 
266
 
267
  return assistant_response
268
 
269
  except Exception as e:
270
  logger.error(f"Error generating response: {e}")
271
+ return self.get_fallback_response(user_message)
272
+
273
+ def get_fallback_response(self, user_message: str) -> str:
274
+ """Fallback response when no training data match and no API available"""
275
+ return f"Halo! Saya adalah asisten AI Textilindo. Saya bisa membantu Anda dengan pertanyaan tentang produk dan layanan kami, atau sekadar mengobrol! Bagaimana saya bisa membantu Anda hari ini? 😊"
276
 
277
  def get_mock_response(self, user_message: str) -> str:
278
  """Enhanced mock responses with better context awareness"""