harismlnaslm committed
Commit baf2e68 · 1 Parent(s): b0f2f89

Switch to Llama-2-7b-chat-hf model with proper chat formatting

Files changed (1)
app.py +38 -16
app.py CHANGED
@@ -85,8 +85,8 @@ class TextilindoAI:
 
     def __init__(self):
         self.api_key = os.getenv('HUGGINGFACE_API_KEY')
-        # Use a more accessible model for free tier
-        self.model = os.getenv('DEFAULT_MODEL', 'microsoft/DialoGPT-small')
+        # Use Llama model for better performance
+        self.model = os.getenv('DEFAULT_MODEL', 'meta-llama/Llama-2-7b-chat-hf')
         self.system_prompt = self.load_system_prompt()
 
         if not self.api_key:
@@ -144,8 +144,11 @@ Minimum purchase is 1 roll (67-70 yards)."""
             return self.get_mock_response(user_message)
 
         try:
-            # For DialoGPT, use a simpler prompt format
-            if "dialogpt" in self.model.lower():
+            # For Llama models, use the proper chat format
+            if "llama" in self.model.lower():
+                # Llama 2 chat format
+                prompt = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"
+            elif "dialogpt" in self.model.lower():
                 # DialoGPT works better with conversation format
                 prompt = f"User: {user_message}\nAssistant:"
             else:
@@ -154,21 +157,40 @@ Minimum purchase is 1 roll (67-70 yards)."""
 
             logger.info(f"Generating response for prompt: {prompt[:100]}...")
 
-            # Generate response
-            response = self.client.text_generation(
-                prompt,
-                max_new_tokens=200,  # Reduced for better performance
-                temperature=0.7,
-                top_p=0.9,
-                top_k=40,
-                repetition_penalty=1.1,
-                stop_sequences=["<|end|>", "<|user|>", "User:", "Assistant:"]
-            )
+            # Generate response with model-specific parameters
+            if "llama" in self.model.lower():
+                response = self.client.text_generation(
+                    prompt,
+                    max_new_tokens=200,
+                    temperature=0.7,
+                    top_p=0.9,
+                    top_k=40,
+                    repetition_penalty=1.1,
+                    stop_sequences=["</s>", "[INST]", "User:", "Assistant:"]
+                )
+            else:
+                response = self.client.text_generation(
+                    prompt,
+                    max_new_tokens=200,
+                    temperature=0.7,
+                    top_p=0.9,
+                    top_k=40,
+                    repetition_penalty=1.1,
+                    stop_sequences=["<|end|>", "<|user|>", "User:", "Assistant:"]
+                )
 
             logger.info(f"Raw AI response: {response[:200]}...")
 
-            # Clean up the response
-            if "Assistant:" in response:
+            # Clean up the response based on model type
+            if "llama" in self.model.lower():
+                # For Llama models, extract content after [/INST]
+                if "[/INST]" in response:
+                    assistant_response = response.split("[/INST]")[-1].strip()
+                else:
+                    assistant_response = response.strip()
+                # Remove Llama-specific tokens
+                assistant_response = assistant_response.replace("<s>", "").replace("</s>", "").replace("[INST]", "").replace("[/INST]", "").strip()
+            elif "Assistant:" in response:
                 assistant_response = response.split("Assistant:")[-1].strip()
             elif "<|assistant|>" in response:
                 assistant_response = response.split("<|assistant|>")[-1].strip()
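
For reference, the Llama 2 chat template the new branch emits wraps the system prompt in <<SYS>> markers inside the first [INST] block; this is Meta's published single-turn format for the -chat-hf checkpoints. A minimal sketch of the string it produces, with illustrative placeholder values standing in for self.system_prompt and user_message:

# Sketch of the prompt string built by the new "llama" branch.
# Both values below are placeholders, not the real ones from app.py.
system_prompt = "You are Textilindo's customer-service assistant."
user_message = "What is the minimum purchase?"

prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"
# Produces:
# <s>[INST] <<SYS>>
# You are Textilindo's customer-service assistant.
# <</SYS>>
#
# What is the minimum purchase? [/INST]

Multi-turn history, if added later, would repeat [INST] ... [/INST] blocks with the model's prior replies between them.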
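
The generation call looks like huggingface_hub's InferenceClient.text_generation; the diff only changes its parameters, so the client setup is not shown. A self-contained sketch of the new Llama path under that assumption — the client construction here is illustrative, and only the model name, token variable, and generation parameters come from the diff:

import os
from huggingface_hub import InferenceClient

# Assumes the token's account has been granted access to the gated
# meta-llama/Llama-2-7b-chat-hf repository.
client = InferenceClient(
    model="meta-llama/Llama-2-7b-chat-hf",
    token=os.getenv("HUGGINGFACE_API_KEY"),
)

prompt = "<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\nHello! [/INST]"

response = client.text_generation(
    prompt,
    max_new_tokens=200,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.1,
    stop_sequences=["</s>", "[INST]", "User:", "Assistant:"],
)

# text_generation normally returns only the generated continuation;
# the [/INST] split in app.py is defensive cleanup for backends that
# echo the prompt back.
answer = response.split("[/INST]")[-1].strip() if "[/INST]" in response else response.strip()

Since the repository is gated, a token without approved access will make the request fail (typically with a 403), in which case app.py's get_mock_response fallback is the only path that still answers.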