harismlnaslm committed on
Commit
119d2a6
·
1 Parent(s): e035194

Add training API endpoints and production-ready files

Files changed (3)
  1. README.md +241 -8
  2. app.py +608 -28
  3. templates/chat.html +26 -26
README.md CHANGED
@@ -1,10 +1,243 @@
1
  ---
2
- title: Textilindo AI
3
- emoji: 📚
4
- colorFrom: purple
5
- colorTo: gray
6
- sdk: docker
7
- pinned: false
8
- ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Base LLM Setup - Llama 3.1 8B with LoRA
2
+
3
+ A complete setup for fine-tuning the Llama 3.1 8B model with LoRA (Low-Rank Adaptation).
4
+
5
+ ## 🚀 Features
6
+
7
+ - **Base Model**: Llama 3.1 8B Instruct
8
+ - **Fine-tuning**: LoRA for memory efficiency
9
+ - **Data Format**: JSONL (JSON Lines)
10
+ - **Environment**: Python virtual environment
11
+ - **Inference**: vLLM for model serving
12
+ - **Monitoring**: Logs and metrics
13
+
14
+ ## 📁 Directory Structure
15
+
16
+ ```
17
+ base-llm-setup/
18
+ ├── models/ # Model weights
19
+ ├── data/ # Training datasets (JSONL)
20
+ ├── scripts/ # Python scripts
21
+ │ ├── download_model.py # Download base model
22
+ │ ├── finetune_lora.py # LoRA fine-tuning
23
+ │ ├── test_model.py # Test fine-tuned model
24
+ │ └── create_sample_dataset.py # Create sample data
25
+ ├── configs/ # Configuration files
26
+ ├── logs/ # Training logs
27
+ ├── venv/ # Virtual environment
28
+ ├── requirements.txt # Python dependencies
29
+ ├── setup.sh # Setup script
30
+ ├── docker-compose.yml # Docker services
31
+ └── README.md # This file
32
+ ```
33
+
34
+ ## 🛠️ Prerequisites
35
+
36
+ - Python 3.8+
37
+ - CUDA-compatible GPU (for training)
38
+ - Docker & Docker Compose
39
+ - Hugging Face account and access token
40
+
41
+ ## ⚡ Quick Start
42
+
43
+ ### 1. Setup Environment
44
+
45
+ ```bash
46
+ # Clone the repository or create the folder
47
+ cd base-llm-setup
48
+
49
+ # Run the setup script
50
+ chmod +x setup.sh
51
+ ./setup.sh
52
+ ```
53
+
54
+ ### 2. Activate the Virtual Environment
55
+
56
+ ```bash
57
+ source venv/bin/activate
58
+ ```
59
+
60
+ ### 3. Set HuggingFace Token
61
+
62
+ ```bash
63
+ export HUGGINGFACE_TOKEN="your_token_here"
64
+ ```
65
+
66
+ ### 4. Download Base Model
67
+
68
+ ```bash
69
+ python scripts/download_model.py
70
+ ```
71
+
72
+ ### 5. Create the Dataset (JSONL)
73
+
74
+ ```bash
75
+ python scripts/create_sample_dataset.py
76
+ ```
77
+
78
+ ### 6. Fine-tune with LoRA
79
+
80
+ ```bash
81
+ python scripts/finetune_lora.py
82
+ ```
83
+
84
+ ### 7. Test Model
85
+
86
+ ```bash
87
+ python scripts/test_model.py
88
+ ```
89
+
90
+ ## 📊 JSONL Dataset Format
91
+
92
+ The dataset must be in JSONL (JSON Lines) format, one JSON object per line, with this structure:
93
+
94
+ ```jsonl
95
+ {"text": "Apa itu machine learning?", "category": "education", "language": "id"}
96
+ {"text": "Jelaskan tentang deep learning", "category": "education", "language": "id"}
97
+ {"text": "Bagaimana cara kerja neural network?", "category": "education", "language": "id"}
98
+ ```
99
+
100
+ **Fields:**
101
+ - `text`: Training text (required)
102
+ - `category`: Data category (optional)
103
+ - `language`: Language code (optional, default: "id")
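
A minimal sketch of checking records against the field list above, assuming `text` is the only required key and `language` defaults to `"id"`. The function name `load_text_records` and the in-memory sample are hypothetical and not part of the repository. Note that `load_training_data` in this commit's `app.py` reads `instruction` and `output` keys rather than `text`, so a file in the README format would produce empty question/answer pairs there.

```python
import json
from typing import Iterable, List


def load_text_records(lines: Iterable[str], default_language: str = "id") -> List[dict]:
    """Parse JSONL lines, requiring `text` and defaulting the optional fields."""
    records = []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # skip blank lines
        item = json.loads(line)
        text = item.get("text")
        if not isinstance(text, str) or not text.strip():
            raise ValueError(f"line {line_no}: missing required field 'text'")
        records.append({
            "text": text,
            "category": item.get("category"),  # optional, None when absent
            "language": item.get("language", default_language),  # optional
        })
    return records


# Hypothetical sample: the second record omits both optional fields.
sample = [
    '{"text": "Apa itu machine learning?", "category": "education", "language": "id"}',
    '{"text": "Jelaskan tentang deep learning"}',
]
records = load_text_records(sample)
print(records[1])
```

Failing fast on a missing `text` field surfaces a malformed dataset before any model weights are loaded.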
104
+
105
+ ## 🔧 Configuration
106
+
107
+ ### Model Configuration (`configs/llama_config.yaml`)
108
+
109
+ ```yaml
110
+ model_name: "meta-llama/Llama-3.1-8B-Instruct"
111
+ model_path: "./models/llama-3.1-8b-instruct"
112
+ max_length: 8192
113
+ temperature: 0.7
114
+ top_p: 0.9
115
+ top_k: 40
116
+ repetition_penalty: 1.1
117
+
118
+ # LoRA Configuration
119
+ lora_config:
120
+ r: 16
121
+ lora_alpha: 32
122
+ lora_dropout: 0.1
123
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
124
+
125
+ # Training Configuration
126
+ training_config:
127
+ learning_rate: 2e-4
128
+ batch_size: 4
129
+ gradient_accumulation_steps: 4
130
+ num_epochs: 3
131
+ warmup_steps: 100
132
+ save_steps: 500
133
+ eval_steps: 500
134
+ ```
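
The training values above interact: with gradient accumulation, one optimizer step consumes `batch_size * gradient_accumulation_steps` samples per device. The sketch below works through that arithmetic. The dataset size and device count are hypothetical, chosen only for illustration, and the step count is an approximation, since the exact rounding depends on the training loop that `finetune_lora.py` uses.

```python
import math

# Values taken from the training_config block above.
batch_size = 4
gradient_accumulation_steps = 4
num_epochs = 3
warmup_steps = 100
save_steps = 500

# Hypothetical dataset size and device count, for illustration only.
num_samples = 10_000
num_devices = 1

# Samples consumed per optimizer step.
effective_batch = batch_size * gradient_accumulation_steps * num_devices
steps_per_epoch = math.ceil(num_samples / effective_batch)
total_steps = steps_per_epoch * num_epochs

print(f"effective batch size: {effective_batch}")
print(f"optimizer steps per epoch: {steps_per_epoch}")
print(f"total optimizer steps: {total_steps}")
print(f"checkpoints written: {total_steps // save_steps}")
print(f"warmup fraction: {warmup_steps / total_steps:.1%}")
```

For a small dataset the total step count can fall below `warmup_steps` or `save_steps`, in which case the warmup never finishes or no intermediate checkpoint is written. It is worth running this check before a long job.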
135
+
136
+ ### Docker Configuration
137
+
138
+ ```bash
139
+ # Start vLLM service
140
+ docker-compose up -d vllm
141
+
142
+ # Check status
143
+ docker-compose ps
144
+
145
+ # View logs
146
+ docker-compose logs -f vllm
147
+ ```
148
+
149
+ ## 🧪 Testing
150
+
151
+ ### Interactive Mode
152
+ ```bash
153
+ python scripts/test_model.py
154
+ # Choose option 1 for interactive chat
155
+ ```
156
+
157
+ ### Batch Testing
158
+ ```bash
159
+ python scripts/test_model.py
160
+ # Choose option 2 for batch testing
161
+ ```
162
+
163
+ ### Custom Prompt
164
+ ```bash
165
+ python scripts/test_model.py
166
+ # Choose option 3 for a custom prompt
167
+ ```
168
+
169
+ ## 📈 Monitoring
170
+
171
+ ### Training Logs
172
+ - Logs are saved in the `logs/` folder
173
+ - Monitor GPU usage with `nvidia-smi`
174
+ - Check training progress in the console
175
+
176
+ ### Model Performance
177
+ - Loss metrics during training
178
+ - Model checkpoints are saved every `save_steps`
179
+ - Evaluation metrics are computed every `eval_steps`
180
+
181
+ ## 🔍 Troubleshooting
182
+
183
+ ### Common Issues
184
+
185
+ 1. **CUDA Out of Memory**
186
+ - Reduce `batch_size`
187
+ - Reduce `max_length`
188
+ - Use gradient accumulation
189
+
190
+ 2. **Model Download Failed**
191
+ - Check HuggingFace token
192
+ - Verify internet connection
193
+ - Check disk space
194
+
195
+ 3. **Training Slow**
196
+ - Increase `batch_size` if memory allows
197
+ - Optimize data loading
198
+ - Use mixed precision training
199
+
200
+ ### Performance Tips
201
+
202
+ - Use an SSD for large datasets
203
+ - Monitor GPU temperature
204
+ - Use appropriate learning rate scheduling
205
+ - Checkpoint regularly for recovery
206
+
207
+ ## 📚 Dependencies
208
+
209
+ See `requirements.txt` for the full list of dependencies:
210
+
211
+ - **Core**: torch, transformers, peft, datasets
212
+ - **Inference**: vllm, openai
213
+ - **Utils**: numpy, pandas, pyyaml
214
+ - **Dev**: pytest, black, flake8
215
+
216
+ ## 🤝 Contributing
217
+
218
+ 1. Fork repository
219
+ 2. Create feature branch
220
+ 3. Commit changes
221
+ 4. Push to branch
222
+ 5. Create Pull Request
223
+
224
+ ## 📄 License
225
+
226
+ MIT License - see the LICENSE file for details.
227
+
228
+ ## Support
229
+
230
+ If you run into problems or have questions:
231
+
232
+ 1. Check the troubleshooting section
233
+ 2. Review the logs in the `logs/` folder
234
+ 3. Open an issue in the repository
235
+ 4. Contact the maintainer
236
+
237
  ---
238
 
239
+ **Happy Fine-tuning! 🚀**
240
+
241
+
242
+
243
+
app.py CHANGED
@@ -1,56 +1,636 @@
1
  #!/usr/bin/env python3
2
  """
3
- Minimal working version to fix 503 error
 
4
  """
5
 
6
  import os
7
- from fastapi import FastAPI
8
  from pydantic import BaseModel
9
  import uvicorn
 
 
10
 
11
- app = FastAPI(title="Textilindo AI API")
12
 
13
  class ChatRequest(BaseModel):
14
  message: str
 
15
 
16
  class ChatResponse(BaseModel):
17
  response: str
 
18
  status: str = "success"
19
 
20
- @app.get("/")
21
  async def root():
22
- return {"message": "Textilindo AI API is running", "status": "ok"}
23
 
24
- @app.get("/health")
25
- async def health():
26
- return {"status": "healthy"}
27
 
28
- @app.get("/debug/env")
29
- async def debug_env():
30
- api_key = os.getenv("HUGGINGFACE_API_KEY")
31
  return {
32
- "api_key_present": bool(api_key),
33
- "api_key_length": len(api_key) if api_key else 0
34
  }
35
 
36
- @app.post("/chat")
37
- async def chat(request: ChatRequest):
38
- # Simple mock response for now
39
- mock_responses = {
40
- "jam berapa textilindo buka": "Jam operasional Senin-Jumat 08:00-17:00, Sabtu 08:00-12:00.",
41
- "dimana lokasi textilindo": "Textilindo berkantor pusat di Jl. Raya Prancis No.39, Kosambi Tim., Kec. Kosambi, Kabupaten Tangerang, Banten 15213",
42
- "apa ada gratis ongkir": "Gratis ongkir untuk order minimal 5 roll."
43
  }
44
 
45
- user_lower = request.message.lower()
46
- response = "Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini? 😊"
47
 
48
- for key, mock_response in mock_responses.items():
49
- if any(word in user_lower for word in key.split()):
50
- response = mock_response
51
- break
52
 
53
- return ChatResponse(response=response)
54
 
55
  if __name__ == "__main__":
56
- uvicorn.run(app, host="0.0.0.0", port=7860)
1
  #!/usr/bin/env python3
2
  """
3
+ Textilindo AI Assistant - Hugging Face Spaces FastAPI Application
4
+ Main application file for deployment on Hugging Face Spaces
5
  """
6
 
7
  import os
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ from datetime import datetime
12
+ from typing import Optional, Dict, Any
13
+ from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
14
+ from fastapi.responses import HTMLResponse, JSONResponse
15
+ from fastapi.staticfiles import StaticFiles
16
+ from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel
18
  import uvicorn
19
+ from huggingface_hub import InferenceClient
20
+ import requests
21
 
22
+ # Import torch only when needed for training
23
+ try:
24
+ import torch
25
+ TORCH_AVAILABLE = True
26
+ except ImportError:
27
+ TORCH_AVAILABLE = False
28
+ torch = None
29
 
30
+ # Setup logging
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Initialize FastAPI app
35
+ app = FastAPI(
36
+ title="Textilindo AI Assistant",
37
+ description="AI Assistant for Textilindo textile company",
38
+ version="1.0.0"
39
+ )
40
+
41
+ # Add CORS middleware
42
+ app.add_middleware(
43
+ CORSMiddleware,
44
+ allow_origins=["*"],
45
+ allow_credentials=True,
46
+ allow_methods=["*"],
47
+ allow_headers=["*"],
48
+ )
49
+
50
+ # Request/Response models
51
  class ChatRequest(BaseModel):
52
  message: str
53
+ conversation_id: Optional[str] = None
54
 
55
  class ChatResponse(BaseModel):
56
  response: str
57
+ conversation_id: str
58
  status: str = "success"
59
 
60
+ class HealthResponse(BaseModel):
61
+ status: str
62
+ message: str
63
+ version: str = "1.0.0"
64
+
65
+ # Training models
66
+ class TrainingRequest(BaseModel):
67
+ model_name: str = "distilgpt2"
68
+ dataset_path: str = "data/lora_dataset_20250910_145055.jsonl"
69
+ config_path: str = "configs/training_config.yaml"
70
+ max_samples: int = 20
71
+ epochs: int = 1
72
+ batch_size: int = 1
73
+ learning_rate: float = 5e-5
74
+
75
+ class TrainingResponse(BaseModel):
76
+ success: bool
77
+ message: str
78
+ training_id: str
79
+ status: str
80
+
81
+ # Training status storage
82
+ training_status = {
83
+ "is_training": False,
84
+ "progress": 0,
85
+ "status": "idle",
86
+ "current_step": 0,
87
+ "total_steps": 0,
88
+ "loss": 0.0,
89
+ "start_time": None,
90
+ "end_time": None,
91
+ "error": None
92
+ }
93
+
94
+ class TextilindoAI:
95
+ """Textilindo AI Assistant using HuggingFace Inference API"""
96
+
97
+ def __init__(self):
98
+ self.api_key = os.getenv('HUGGINGFACE_API_KEY')
99
+ self.model = os.getenv('DEFAULT_MODEL', 'meta-llama/Llama-3.1-8B-Instruct')
100
+ self.system_prompt = self.load_system_prompt()
101
+
102
+ if not self.api_key:
103
+ logger.warning("HUGGINGFACE_API_KEY not found. Using mock responses.")
104
+ self.client = None
105
+ else:
106
+ try:
107
+ self.client = InferenceClient(
108
+ token=self.api_key,
109
+ model=self.model
110
+ )
111
+ logger.info(f"Initialized with model: {self.model}")
112
+ except Exception as e:
113
+ logger.error(f"Failed to initialize InferenceClient: {e}")
114
+ self.client = None
115
+
116
+ def load_system_prompt(self) -> str:
117
+ """Load system prompt from config file"""
118
+ try:
119
+ prompt_path = Path("configs/system_prompt.md")
120
+ if prompt_path.exists():
121
+ with open(prompt_path, 'r', encoding='utf-8') as f:
122
+ content = f.read()
123
+
124
+ # Extract system prompt from markdown
125
+ if 'SYSTEM_PROMPT = """' in content:
126
+ start = content.find('SYSTEM_PROMPT = """') + len('SYSTEM_PROMPT = """')
127
+ end = content.find('"""', start)
128
+ return content[start:end].strip()
129
+ else:
130
+ # Fallback: use entire content
131
+ return content.strip()
132
+ else:
133
+ return self.get_default_system_prompt()
134
+ except Exception as e:
135
+ logger.error(f"Error loading system prompt: {e}")
136
+ return self.get_default_system_prompt()
137
+
138
+ def get_default_system_prompt(self) -> str:
139
+ """Default system prompt if file not found"""
140
+ return """You are a friendly and helpful AI assistant for Textilindo, a textile company.
141
+
142
+ Always respond in Indonesian (Bahasa Indonesia).
143
+ Keep responses short and direct.
144
+ Be friendly and helpful.
145
+ Use exact information from the knowledge base.
146
+ The company uses yards for sales.
147
+ Minimum purchase is 1 roll (67-70 yards)."""
148
+
149
+ def generate_response(self, user_message: str) -> str:
150
+ """Generate response using HuggingFace Inference API"""
151
+ if not self.client:
152
+ return self.get_mock_response(user_message)
153
+
154
+ try:
155
+ # Create full prompt with system prompt
156
+ full_prompt = f"<|system|>\n{self.system_prompt}\n<|user|>\n{user_message}\n<|assistant|>\n"
157
+
158
+ # Generate response
159
+ response = self.client.text_generation(
160
+ full_prompt,
161
+ max_new_tokens=512,
162
+ temperature=0.7,
163
+ top_p=0.9,
164
+ top_k=40,
165
+ repetition_penalty=1.1,
166
+ stop_sequences=["<|end|>", "<|user|>"]
167
+ )
168
+
169
+ # Extract only the assistant's response
170
+ if "<|assistant|>" in response:
171
+ assistant_response = response.split("<|assistant|>")[-1].strip()
172
+ assistant_response = assistant_response.replace("<|end|>", "").strip()
173
+ return assistant_response
174
+ else:
175
+ return response
176
+
177
+ except Exception as e:
178
+ logger.error(f"Error generating response: {e}")
179
+ return self.get_mock_response(user_message)
180
+
181
+ def get_mock_response(self, user_message: str) -> str:
182
+ """Mock responses for testing without API key"""
183
+ mock_responses = {
184
+ "dimana lokasi textilindo": "Textilindo berkantor pusat di Jl. Raya Prancis No.39, Kosambi Tim., Kec. Kosambi, Kabupaten Tangerang, Banten 15213",
185
+ "jam berapa textilindo beroperasional": "Jam operasional Senin-Jumat 08:00-17:00, Sabtu 08:00-12:00.",
186
+ "berapa ketentuan pembelian": "Minimal order 1 roll per jenis kain",
187
+ "bagaimana dengan pembayarannya": "Pembayaran dapat dilakukan via transfer bank atau cash on delivery",
188
+ "apa ada gratis ongkir": "Gratis ongkir untuk order minimal 5 roll.",
189
+ "apa bisa dikirimkan sample": "hallo kak untuk sampel kita bisa kirimkan gratis ya kak 😊"
190
+ }
191
+
192
+ # Simple keyword matching
193
+ user_lower = user_message.lower()
194
+ for key, response in mock_responses.items():
195
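+ # Require every keyword: with any(), a shared word such as "textilindo" matches the first entry for unrelated questions
+ if all(word in user_lower for word in key.split()):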
196
+ return response
197
+
198
+ return "Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini? 😊"
199
+
200
+ # Initialize AI assistant
201
+ ai_assistant = TextilindoAI()
202
+
203
+ # Training functions
204
+ def load_training_data(dataset_path: str, max_samples: int = 20) -> list:
205
+ """Load training data from JSONL file"""
206
+ data = []
207
+ try:
208
+ with open(dataset_path, 'r', encoding='utf-8') as f:
209
+ for i, line in enumerate(f):
210
+ if i >= max_samples:
211
+ break
212
+ if line.strip():
213
+ item = json.loads(line)
214
+ # Create training text
215
+ instruction = item.get('instruction', '')
216
+ output = item.get('output', '')
217
+ text = f"Question: {instruction} Answer: {output}"
218
+ data.append({"text": text})
219
+ logger.info(f"Loaded {len(data)} training samples")
220
+ return data
221
+ except Exception as e:
222
+ logger.error(f"Error loading training data: {e}")
223
+ return []
224
+
225
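+ # Plain (non-async) function: Starlette runs sync background tasks in a threadpool,
+ # so the blocking trainer.train() call does not stall the event loop.
+ def train_model_async(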
226
+ model_name: str,
227
+ dataset_path: str,
228
+ config_path: str,
229
+ max_samples: int,
230
+ epochs: int,
231
+ batch_size: int,
232
+ learning_rate: float
233
+ ):
234
+ """Async training function"""
235
+ global training_status
236
+
237
+ try:
238
+ training_status.update({
239
+ "is_training": True,
240
+ "status": "starting",
241
+ "progress": 0,
242
+ "start_time": datetime.now().isoformat(),
243
+ "error": None
244
+ })
245
+
246
+ logger.info("🚀 Starting training...")
247
+
248
+ # Import training libraries
249
+ from transformers import (
250
+ AutoTokenizer,
251
+ AutoModelForCausalLM,
252
+ TrainingArguments,
253
+ Trainer,
254
+ DataCollatorForLanguageModeling
255
+ )
256
+ from datasets import Dataset
257
+
258
+ # Check GPU
259
+ if not TORCH_AVAILABLE:
260
+ raise Exception("PyTorch is required for training but not available")
261
+
262
+ gpu_available = torch.cuda.is_available()
263
+ logger.info(f"GPU available: {gpu_available}")
264
+
265
+ # Load model and tokenizer
266
+ logger.info(f"📥 Loading model: {model_name}")
267
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
268
+ if tokenizer.pad_token is None:
269
+ tokenizer.pad_token = tokenizer.eos_token
270
+
271
+ # Load model
272
+ if gpu_available:
273
+ model = AutoModelForCausalLM.from_pretrained(
274
+ model_name,
275
+ torch_dtype=torch.float16,
276
+ device_map="auto"
277
+ )
278
+ else:
279
+ model = AutoModelForCausalLM.from_pretrained(model_name)
280
+
281
+ logger.info("✅ Model loaded successfully")
282
+
283
+ # Load training data
284
+ training_data = load_training_data(dataset_path, max_samples)
285
+ if not training_data:
286
+ raise Exception("No training data loaded")
287
+
288
+ # Convert to dataset
289
+ dataset = Dataset.from_list(training_data)
290
+
291
+ def tokenize_function(examples):
292
+ return tokenizer(
293
+ examples["text"],
294
+ truncation=True,
295
+ padding=True,
296
+ max_length=256,
297
+ return_tensors="pt"
298
+ )
299
+
300
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
301
+
302
+ # Training arguments
303
+ training_args = TrainingArguments(
304
+ output_dir="./models/textilindo-trained",
305
+ num_train_epochs=epochs,
306
+ per_device_train_batch_size=batch_size,
307
+ gradient_accumulation_steps=2,
308
+ learning_rate=learning_rate,
309
+ warmup_steps=5,
310
+ save_steps=10,
311
+ logging_steps=1,
312
+ save_total_limit=1,
313
+ prediction_loss_only=True,
314
+ remove_unused_columns=False,
315
+ fp16=gpu_available,
316
+ dataloader_pin_memory=gpu_available,
317
+ report_to="none",  # "none" disables experiment trackers; None may fall back to the library default
318
+ )
319
+
320
+ # Data collator
321
+ data_collator = DataCollatorForLanguageModeling(
322
+ tokenizer=tokenizer,
323
+ mlm=False,
324
+ )
325
+
326
+ # Create trainer
327
+ trainer = Trainer(
328
+ model=model,
329
+ args=training_args,
330
+ train_dataset=tokenized_dataset,
331
+ data_collator=data_collator,
332
+ tokenizer=tokenizer,
333
+ )
334
+
335
+ # Start training
336
+ training_status["status"] = "training"
337
+ trainer.train()
338
+
339
+ # Save model
340
+ model.save_pretrained("./models/textilindo-trained")
341
+ tokenizer.save_pretrained("./models/textilindo-trained")
342
+
343
+ # Update status
344
+ training_status.update({
345
+ "is_training": False,
346
+ "status": "completed",
347
+ "progress": 100,
348
+ "end_time": datetime.now().isoformat()
349
+ })
350
+
351
+ logger.info("✅ Training completed successfully!")
352
+
353
+ except Exception as e:
354
+ logger.error(f"Training failed: {e}")
355
+ training_status.update({
356
+ "is_training": False,
357
+ "status": "failed",
358
+ "error": str(e),
359
+ "end_time": datetime.now().isoformat()
360
+ })
361
+
362
+ # Routes
363
+ @app.get("/", response_class=HTMLResponse)
364
  async def root():
365
+ """Serve the main chat interface"""
366
+ try:
367
+ with open("templates/chat.html", "r", encoding="utf-8") as f:
368
+ return HTMLResponse(content=f.read())
369
+ except FileNotFoundError:
370
+ return HTMLResponse(content="""
371
+ <!DOCTYPE html>
372
+ <html>
373
+ <head>
374
+ <title>Textilindo AI Assistant</title>
375
+ <meta charset="utf-8">
376
+ <style>
377
+ body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
378
+ .chat-container { border: 1px solid #ddd; border-radius: 10px; padding: 20px; margin: 20px 0; }
379
+ .message { margin: 10px 0; padding: 10px; border-radius: 5px; }
380
+ .user { background-color: #e3f2fd; text-align: right; }
381
+ .assistant { background-color: #f5f5f5; }
382
+ input[type="text"] { width: 70%; padding: 10px; border: 1px solid #ddd; border-radius: 5px; }
383
+ button { padding: 10px 20px; background-color: #2196f3; color: white; border: none; border-radius: 5px; cursor: pointer; }
384
+ </style>
385
+ </head>
386
+ <body>
387
+ <h1>🤖 Textilindo AI Assistant</h1>
388
+ <div class="chat-container">
389
+ <div id="chat-messages"></div>
390
+ <div style="margin-top: 20px;">
391
+ <input type="text" id="message-input" placeholder="Tulis pesan Anda..." onkeypress="handleKeyPress(event)">
392
+ <button onclick="sendMessage()">Kirim</button>
393
+ </div>
394
+ </div>
395
+ <script>
396
+ async function sendMessage() {
397
+ const input = document.getElementById('message-input');
398
+ const message = input.value.trim();
399
+ if (!message) return;
400
+
401
+ // Add user message
402
+ addMessage(message, 'user');
403
+ input.value = '';
404
+
405
+ // Get AI response
406
+ try {
407
+ const response = await fetch('/chat', {
408
+ method: 'POST',
409
+ headers: { 'Content-Type': 'application/json' },
410
+ body: JSON.stringify({ message: message })
411
+ });
412
+ const data = await response.json();
413
+ addMessage(data.response, 'assistant');
414
+ } catch (error) {
415
+ addMessage('Maaf, terjadi kesalahan. Silakan coba lagi.', 'assistant');
416
+ }
417
+ }
418
+
419
+ function addMessage(text, sender) {
420
+ const messages = document.getElementById('chat-messages');
421
+ const div = document.createElement('div');
422
+ div.className = `message ${sender}`;
423
+ div.textContent = text;
424
+ messages.appendChild(div);
425
+ messages.scrollTop = messages.scrollHeight;
426
+ }
427
+
428
+ function handleKeyPress(event) {
429
+ if (event.key === 'Enter') {
430
+ sendMessage();
431
+ }
432
+ }
433
+ </script>
434
+ </body>
435
+ </html>
436
+ """)
437
 
438
+ @app.post("/chat", response_model=ChatResponse)
439
+ async def chat(request: ChatRequest):
440
+ """Chat endpoint"""
441
+ try:
442
+ response = ai_assistant.generate_response(request.message)
443
+ return ChatResponse(
444
+ response=response,
445
+ conversation_id=request.conversation_id or "default",
446
+ status="success"
447
+ )
448
+ except Exception as e:
449
+ logger.error(f"Error in chat endpoint: {e}")
450
+ raise HTTPException(status_code=500, detail="Internal server error")
451
+
452
+ @app.get("/health", response_model=HealthResponse)
453
+ async def health_check():
454
+ """Health check endpoint"""
455
+ return HealthResponse(
456
+ status="healthy",
457
+ message="Textilindo AI Assistant is running",
458
+ version="1.0.0"
459
+ )
460
 
461
+ @app.get("/info")
462
+ async def get_info():
463
+ """Get application information"""
464
  return {
465
+ "name": "Textilindo AI Assistant",
466
+ "version": "1.0.0",
467
+ "model": ai_assistant.model,
468
+ "has_api_key": bool(ai_assistant.api_key),
469
+ "client_initialized": bool(ai_assistant.client),
470
+ "endpoints": {
471
+ "training": {
472
+ "start": "POST /api/train/start",
473
+ "status": "GET /api/train/status",
474
+ "data": "GET /api/train/data",
475
+ "gpu": "GET /api/train/gpu",
476
+ "test": "POST /api/train/test"
477
+ },
478
+ "chat": {
479
+ "chat": "POST /chat",
480
+ "health": "GET /health"
481
+ }
482
+ }
483
  }
484
 
485
+ # Training API endpoints
486
+ @app.post("/api/train/start", response_model=TrainingResponse)
487
+ async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
488
+ """Start training process"""
489
+ global training_status
490
+
491
+ if training_status["is_training"]:
492
+ raise HTTPException(status_code=400, detail="Training already in progress")
493
+
494
+ # Validate inputs
495
+ if not Path(request.dataset_path).exists():
496
+ raise HTTPException(status_code=404, detail=f"Dataset not found: {request.dataset_path}")
497
+
498
+ if not Path(request.config_path).exists():
499
+ raise HTTPException(status_code=404, detail=f"Config not found: {request.config_path}")
500
+
501
+ # Start training in background
502
+ training_id = f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
503
+
504
+ background_tasks.add_task(
505
+ train_model_async,
506
+ request.model_name,
507
+ request.dataset_path,
508
+ request.config_path,
509
+ request.max_samples,
510
+ request.epochs,
511
+ request.batch_size,
512
+ request.learning_rate
513
+ )
514
+
515
+ return TrainingResponse(
516
+ success=True,
517
+ message="Training started successfully",
518
+ training_id=training_id,
519
+ status="started"
520
+ )
521
+
522
+ @app.get("/api/train/status")
523
+ async def get_training_status():
524
+ """Get current training status"""
525
+ return training_status
526
+
527
+ @app.get("/api/train/data")
528
+ async def get_training_data_info():
529
+ """Get information about available training data"""
530
+ data_dir = Path("data")
531
+ if not data_dir.exists():
532
+ return {"files": [], "count": 0}
533
+
534
+ jsonl_files = list(data_dir.glob("*.jsonl"))
535
+ files_info = []
536
+
537
+ for file in jsonl_files:
538
+ try:
539
+ with open(file, 'r', encoding='utf-8') as f:
540
+ lines = f.readlines()
541
+ files_info.append({
542
+ "name": file.name,
543
+ "size": file.stat().st_size,
544
+ "lines": len(lines)
545
+ })
546
+ except Exception as e:
547
+ files_info.append({
548
+ "name": file.name,
549
+ "error": str(e)
550
+ })
551
+
552
+ return {
553
+ "files": files_info,
554
+ "count": len(jsonl_files)
555
  }
556
+
557
+ @app.get("/api/train/gpu")
558
+ async def get_gpu_info():
559
+ """Get GPU information"""
560
+ if not TORCH_AVAILABLE:
561
+ return {"available": False, "error": "PyTorch not available"}
562
 
563
+ try:
564
+ gpu_available = torch.cuda.is_available()
565
+ if gpu_available:
566
+ gpu_count = torch.cuda.device_count()
567
+ gpu_name = torch.cuda.get_device_name(0)
568
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
569
+ return {
570
+ "available": True,
571
+ "count": gpu_count,
572
+ "name": gpu_name,
573
+ "memory_gb": round(gpu_memory, 2)
574
+ }
575
+ else:
576
+ return {"available": False}
577
+ except Exception as e:
578
+ return {"error": str(e)}
579
+
580
+ @app.post("/api/train/test")
581
+ async def test_trained_model():
582
+ """Test the trained model"""
583
+ if not TORCH_AVAILABLE:
584
+ return {"error": "PyTorch is required for model testing but not available"}
585
 
586
+ model_path = "./models/textilindo-trained"
587
+ if not Path(model_path).exists():
588
+ return {"error": "No trained model found"}
 
589
 
590
+ try:
591
+ from transformers import AutoTokenizer, AutoModelForCausalLM
592
+
593
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
594
+ model = AutoModelForCausalLM.from_pretrained(model_path)
595
+
596
+ # Test prompt
597
+ test_prompt = "Question: dimana lokasi textilindo? Answer:"
598
+ inputs = tokenizer(test_prompt, return_tensors="pt")
599
+
600
+ with torch.no_grad():
601
+ outputs = model.generate(
602
+ **inputs,
603
+ max_length=inputs.input_ids.shape[1] + 30,
604
+ temperature=0.7,
605
+ do_sample=True,
606
+ pad_token_id=tokenizer.eos_token_id
607
+ )
608
+
609
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
610
+
611
+ return {
612
+ "success": True,
613
+ "test_prompt": test_prompt,
614
+ "response": response,
615
+ "model_path": model_path
616
+ }
617
+
618
+ except Exception as e:
619
+ return {"error": str(e)}
620
+
621
+ # Mount static files if they exist
622
+ if Path("static").exists():
623
+ app.mount("/static", StaticFiles(directory="static"), name="static")
624
 
625
  if __name__ == "__main__":
626
+ # Get port from environment variable (Hugging Face Spaces uses 7860)
627
+ port = int(os.getenv("PORT", 7860))
628
+
629
+ # Run the application
630
+ uvicorn.run(
631
+ "app:app",
632
+ host="0.0.0.0",
633
+ port=port,
634
+ log_level="info"
635
+ )
636
+ # Updated Mon, Oct 27, 2025 9:53:55 AM
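
`load_system_prompt` in the diff above pulls the text between `SYSTEM_PROMPT = """` and the next `"""` out of `configs/system_prompt.md`. The following standalone sketch shows that extraction. The function name `extract_system_prompt` and the sample document are hypothetical. Unlike the committed code, the sketch falls back to the whole content when the closing delimiter is missing. In the committed version `find` returns `-1` in that case, so the slice silently drops the final character.

```python
MARKER = 'SYSTEM_PROMPT = """'
CLOSE = '"""'


def extract_system_prompt(content: str) -> str:
    """Return the text between MARKER and the next triple quote.

    Falls back to the whole (stripped) content when the marker or its
    closing delimiter is missing.
    """
    start = content.find(MARKER)
    if start == -1:
        return content.strip()
    start += len(MARKER)
    end = content.find(CLOSE, start)
    if end == -1:
        return content.strip()
    return content[start:end].strip()


# Hypothetical file content, for illustration only.
doc = '# Prompt\n\nSYSTEM_PROMPT = """\nAlways respond in Indonesian.\n"""\n'
print(extract_system_prompt(doc))
```

Keeping the parsing in a pure function like this makes the fallback behavior easy to unit test without touching the filesystem.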
templates/chat.html CHANGED
@@ -10,7 +10,7 @@
10
  padding: 0;
11
  box-sizing: border-box;
12
  }
13
-
14
  body {
15
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
16
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
@@ -20,7 +20,7 @@
20
  align-items: center;
21
  padding: 20px;
22
  }
23
-
24
  .chat-container {
25
  background: white;
26
  border-radius: 20px;
@@ -32,24 +32,24 @@
32
  flex-direction: column;
33
  overflow: hidden;
34
  }
35
-
36
  .chat-header {
37
  background: linear-gradient(135deg, #2196f3, #21cbf3);
38
  color: white;
39
  padding: 20px;
40
  text-align: center;
41
  }
42
-
43
  .chat-header h1 {
44
  font-size: 24px;
45
  margin-bottom: 5px;
46
  }
47
-
48
  .chat-header p {
49
  opacity: 0.9;
50
  font-size: 14px;
51
  }
52
-
53
  .chat-messages {
54
  flex: 1;
55
  padding: 20px;
@@ -58,7 +58,7 @@
58
  flex-direction: column;
59
  gap: 15px;
60
  }
61
-
62
  .message {
63
  max-width: 80%;
64
  padding: 12px 16px;
@@ -66,14 +66,14 @@
66
  word-wrap: break-word;
67
  animation: fadeIn 0.3s ease-in;
68
  }
69
-
70
  .user-message {
71
  background: #2196f3;
72
  color: white;
73
  align-self: flex-end;
74
  border-bottom-right-radius: 5px;
75
  }
76
-
77
  .assistant-message {
78
  background: #f5f5f5;
79
  color: #333;
@@ -88,7 +88,7 @@
88
  display: flex;
89
  gap: 10px;
90
  }
91
-
92
  .chat-input {
93
  flex: 1;
94
  padding: 12px 16px;
@@ -98,11 +98,11 @@
98
  font-size: 14px;
99
  transition: border-color 0.3s ease;
100
  }
101
-
102
  .chat-input:focus {
103
  border-color: #2196f3;
104
  }
105
-
106
  .send-button {
107
  background: #2196f3;
108
  color: white;
@@ -116,16 +116,16 @@
116
  justify-content: center;
117
  transition: background-color 0.3s ease;
118
  }
119
-
120
  .send-button:hover {
121
  background: #1976d2;
122
  }
123
-
124
  .send-button:disabled {
125
  background: #ccc;
126
  cursor: not-allowed;
127
  }
128
-
129
  .typing-indicator {
130
  display: none;
131
  align-self: flex-start;
@@ -136,7 +136,7 @@
136
  color: #666;
137
  font-style: italic;
138
  }
139
-
140
  @keyframes fadeIn {
141
  from { opacity: 0; transform: translateY(10px); }
142
  to { opacity: 1; transform: translateY(0); }
@@ -181,17 +181,17 @@
  <h1>🤖 Textilindo AI Assistant</h1>
  <p>Asisten AI untuk membantu pertanyaan tentang Textilindo</p>
  </div>
-
  <div class="chat-messages" id="chatMessages">
  <div class="welcome-message">
  👋 Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini?
  </div>
  </div>
-
  <div class="typing-indicator" id="typingIndicator">
  <span class="typing-dots">AI sedang mengetik</span>
  </div>
-
  <div class="chat-input-container">
  <input
  type="text"
@@ -202,7 +202,7 @@
  >
  <button id="sendButton" class="send-button" onclick="sendMessage()">
 
- </button>
  </div>
  </div>
 
@@ -236,15 +236,15 @@
  // Disable input and button
  messageInput.disabled = true;
  sendButton.disabled = true;
-
  // Add user message
  addMessage(message, 'user');
  messageInput.value = '';
  messageInput.style.height = 'auto';
-
  // Show typing indicator
  showTypingIndicator();
-
  try {
  const response = await fetch('/chat', {
  method: 'POST',
@@ -257,11 +257,11 @@
  if (!response.ok) {
  throw new Error(`HTTP error! status: ${response.status}`);
  }
-
  const data = await response.json();
  hideTypingIndicator();
  addMessage(data.response, 'assistant');
-
  } catch (error) {
  console.error('Error:', error);
  hideTypingIndicator();
@@ -346,4 +346,4 @@
  setTimeout(addSampleQuestions, 1000);
  </script>
  </body>
- </html>
 
  padding: 0;
  box-sizing: border-box;
  }
+
  body {
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
 
  align-items: center;
  padding: 20px;
  }
+
  .chat-container {
  background: white;
  border-radius: 20px;
 
  flex-direction: column;
  overflow: hidden;
  }
+
  .chat-header {
  background: linear-gradient(135deg, #2196f3, #21cbf3);
  color: white;
  padding: 20px;
  text-align: center;
  }
+
  .chat-header h1 {
  font-size: 24px;
  margin-bottom: 5px;
  }
+
  .chat-header p {
  opacity: 0.9;
  font-size: 14px;
  }
+
  .chat-messages {
  flex: 1;
  padding: 20px;
 
  flex-direction: column;
  gap: 15px;
  }
+
  .message {
  max-width: 80%;
  padding: 12px 16px;
 
  word-wrap: break-word;
  animation: fadeIn 0.3s ease-in;
  }
+
  .user-message {
  background: #2196f3;
  color: white;
  align-self: flex-end;
  border-bottom-right-radius: 5px;
  }
+
  .assistant-message {
  background: #f5f5f5;
  color: #333;
 
  display: flex;
  gap: 10px;
  }
+
  .chat-input {
  flex: 1;
  padding: 12px 16px;
 
  font-size: 14px;
  transition: border-color 0.3s ease;
  }
+
  .chat-input:focus {
  border-color: #2196f3;
  }
+
  .send-button {
  background: #2196f3;
  color: white;
 
  justify-content: center;
  transition: background-color 0.3s ease;
  }
+
  .send-button:hover {
  background: #1976d2;
  }
+
  .send-button:disabled {
  background: #ccc;
  cursor: not-allowed;
  }
+
  .typing-indicator {
  display: none;
  align-self: flex-start;
 
  color: #666;
  font-style: italic;
  }
+
  @keyframes fadeIn {
  from { opacity: 0; transform: translateY(10px); }
  to { opacity: 1; transform: translateY(0); }
 
  <h1>🤖 Textilindo AI Assistant</h1>
  <p>Asisten AI untuk membantu pertanyaan tentang Textilindo</p>
  </div>
+
  <div class="chat-messages" id="chatMessages">
  <div class="welcome-message">
  👋 Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini?
  </div>
  </div>
+
  <div class="typing-indicator" id="typingIndicator">
  <span class="typing-dots">AI sedang mengetik</span>
  </div>
+
  <div class="chat-input-container">
  <input
  type="text"
 
  >
  <button id="sendButton" class="send-button" onclick="sendMessage()">
 
+ </button>
  </div>
  </div>
 
 
  // Disable input and button
  messageInput.disabled = true;
  sendButton.disabled = true;
+
  // Add user message
  addMessage(message, 'user');
  messageInput.value = '';
  messageInput.style.height = 'auto';
+
  // Show typing indicator
  showTypingIndicator();
+
  try {
  const response = await fetch('/chat', {
  method: 'POST',
 
  if (!response.ok) {
  throw new Error(`HTTP error! status: ${response.status}`);
  }
+
  const data = await response.json();
  hideTypingIndicator();
  addMessage(data.response, 'assistant');
+
  } catch (error) {
  console.error('Error:', error);
  hideTypingIndicator();
 
  setTimeout(addSampleQuestions, 1000);
  </script>
  </body>
+ </html>