harismlnaslm committed on
Commit
6fc73ab
·
1 Parent(s): e513905

Fix chat API to use real AI instead of mock responses - integrate trained model and HuggingFace API

Browse files
Files changed (1) hide show
  1. app.py +147 -21
app.py CHANGED
@@ -232,25 +232,17 @@ async def test_trained_model():
232
  # Chat API endpoint
233
  @app.post("/chat", response_model=ChatResponse)
234
  async def chat(request: ChatRequest):
235
- """Chat with the AI assistant"""
236
  try:
237
- # Simple mock response for now
238
- mock_responses = {
239
- "dimana lokasi textilindo": "Textilindo berkantor pusat di Jl. Raya Prancis No.39, Kosambi Tim., Kec. Kosambi, Kabupaten Tangerang, Banten 15213",
240
- "jam berapa textilindo beroperasional": "Jam operasional Senin-Jumat 08:00-17:00, Sabtu 08:00-12:00.",
241
- "berapa ketentuan pembelian": "Minimal order 1 roll per jenis kain",
242
- "apa ada gratis ongkir": "Gratis ongkir untuk order minimal 5 roll.",
243
- "apa bisa dikirimkan sample": "Hallo kak untuk sampel kita bisa kirimkan gratis ya kak 😊"
244
- }
245
-
246
- # Simple keyword matching
247
- user_lower = request.message.lower()
248
- response = "Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini? 😊"
249
-
250
- for key, mock_response in mock_responses.items():
251
- if any(word in user_lower for word in key.split()):
252
- response = mock_response
253
- break
254
 
255
  return ChatResponse(
256
  response=response,
@@ -260,13 +252,147 @@ async def chat(request: ChatRequest):
260
 
261
  except Exception as e:
262
  logger.error(f"Chat error: {e}")
 
 
263
  return ChatResponse(
264
- response="Maaf, terjadi kesalahan. Silakan coba lagi.",
265
  conversation_id=request.conversation_id or "default",
266
- status="error"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
- # Training function
 
 
270
  async def train_model_async(
271
  model_name: str,
272
  dataset_path: str,
 
232
# Chat API endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat with the AI assistant, preferring the locally trained model.

    Order of preference: local fine-tuned checkpoint, then the HuggingFace
    Inference API, then a canned mock answer — so the endpoint always
    returns a usable reply.
    """
    conv_id = request.conversation_id or "default"
    try:
        trained_dir = "./models/textilindo-trained"
        if Path(trained_dir).exists():
            # Local fine-tuned weights are available — use them.
            logger.info("Using trained model for chat")
            reply = await generate_ai_response(request.message, trained_dir)
        else:
            # No local model: fall back to the hosted inference API.
            logger.info("Using HuggingFace Inference API for chat")
            reply = await generate_hf_response(request.message)
    except Exception as e:
        logger.error(f"Chat error: {e}")
        # Last-resort canned answer so the endpoint never errors out.
        reply = get_mock_response(request.message)
    return ChatResponse(
        response=reply,
        conversation_id=conv_id,
        status="success"
    )
263
# Cache of loaded (tokenizer, model) pairs keyed by model path, so the
# checkpoint is read from disk only once per process instead of once per
# chat request (the reload dominated per-request latency).
_MODEL_CACHE: dict = {}


async def generate_ai_response(message: str, model_path: str) -> str:
    """Generate a reply with the locally trained causal-LM checkpoint.

    Args:
        message: The user's chat message.
        model_path: Filesystem path of the fine-tuned model directory.

    Returns:
        The generated answer text; falls back to ``get_mock_response()``
        on any loading/generation error so the caller never fails hard.
    """
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        import torch

        # Load once and reuse on subsequent requests.
        if model_path not in _MODEL_CACHE:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(model_path)
            _MODEL_CACHE[model_path] = (tokenizer, model)
        tokenizer, model = _MODEL_CACHE[model_path]

        # Same "Question: ... Answer:" format the model was fine-tuned on.
        prompt = f"Question: {message} Answer:"
        inputs = tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=inputs.input_ids.shape[1] + 50,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text after the final "Answer:" marker; the decoded
        # output includes the prompt itself.
        if "Answer:" in full_response:
            return full_response.split("Answer:")[-1].strip()
        return full_response

    except Exception as e:
        logger.error(f"AI model error: {e}")
        return get_mock_response(message)
299
async def generate_hf_response(message: str) -> str:
    """Generate a reply via the HuggingFace Inference API.

    Falls back to ``get_mock_response()`` when no API key is configured
    or the remote call fails for any reason.
    """
    try:
        from huggingface_hub import InferenceClient

        # Credentials come from the environment; without them we can only
        # serve canned answers.
        token = os.getenv("HUGGINGFACE_API_KEY")
        if not token:
            logger.warning("HUGGINGFACE_API_KEY not found, using mock response")
            return get_mock_response(message)

        client = InferenceClient(token=token)

        # Wrap the message in a chat-template style prompt around the
        # configured system prompt.
        system_prompt = load_system_prompt()
        full_prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{message}\n<|assistant|>\n"

        generated = client.text_generation(
            full_prompt,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            top_k=40,
            repetition_penalty=1.1,
            stop_sequences=["<|end|>", "<|user|>"]
        )

        # Strip template markers if the service echoed them back.
        if "<|assistant|>" in generated:
            answer = generated.split("<|assistant|>")[-1].strip()
            return answer.replace("<|end|>", "").strip()
        return generated

    except Exception as e:
        logger.error(f"HuggingFace API error: {e}")
        return get_mock_response(message)
342
def get_mock_response(message: str) -> str:
    """Return a canned answer matched by simple keyword lookup.

    Each known question is split into words; if any word occurs as a
    substring of the lower-cased message, that question's answer is
    returned (first match wins). Otherwise a generic greeting is returned.
    """
    canned = [
        ("dimana lokasi textilindo",
         "Textilindo berkantor pusat di Jl. Raya Prancis No.39, Kosambi Tim., Kec. Kosambi, Kabupaten Tangerang, Banten 15213"),
        ("jam berapa textilindo beroperasional",
         "Jam operasional Senin-Jumat 08:00-17:00, Sabtu 08:00-12:00."),
        ("jam berapa textilindo buka",
         "Jam operasional Senin-Jumat 08:00-17:00, Sabtu 08:00-12:00."),
        ("berapa ketentuan pembelian",
         "Minimal order 1 roll per jenis kain"),
        ("apa ada gratis ongkir",
         "Gratis ongkir untuk order minimal 5 roll."),
        ("apa bisa dikirimkan sample",
         "Hallo kak untuk sampel kita bisa kirimkan gratis ya kak 😊"),
    ]

    text = message.lower()
    for keywords, answer in canned:
        if any(word in text for word in keywords.split()):
            return answer

    return "Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini? 😊"
361
# Default prompt used when configs/system_prompt.md is missing or unreadable.
_DEFAULT_SYSTEM_PROMPT = """You are a friendly and helpful AI assistant for Textilindo, a textile company.

Always respond in Indonesian (Bahasa Indonesia).
Keep responses short and direct.
Be friendly and helpful.
Use exact information from the knowledge base.
The company uses yards for sales.
Minimum purchase is 1 roll (67-70 yards)."""


def load_system_prompt() -> str:
    """Load the chat system prompt from configs/system_prompt.md.

    If the file embeds a ``SYSTEM_PROMPT = \"\"\"...\"\"\"`` block, only
    that block is used; otherwise the whole file content is the prompt.

    Returns:
        The prompt text, or the built-in default when the file is missing
        or cannot be read.
    """
    try:
        prompt_file = Path("configs/system_prompt.md")
        if not prompt_file.exists():
            return _DEFAULT_SYSTEM_PROMPT

        content = prompt_file.read_text(encoding='utf-8')

        marker = 'SYSTEM_PROMPT = """'
        if marker in content:
            # Extract only the quoted SYSTEM_PROMPT section.
            start = content.find(marker) + len(marker)
            end = content.find('"""', start)
            return content[start:end].strip()

        # No marker: the whole file is the prompt.
        return content.strip()

    except Exception as e:
        logger.error(f"Error loading system prompt: {e}")
        # Previously this branch returned a shorter, truncated prompt than
        # the missing-file branch; return the full default so all fallback
        # paths behave consistently.
        return _DEFAULT_SYSTEM_PROMPT
396
  async def train_model_async(
397
  model_name: str,
398
  dataset_path: str,