dungeon29 committed
Commit 269756e · verified · 1 Parent(s): 78920ba

Create llm_client.py

Files changed (1)
  1. llm_client.py +309 -0
llm_client.py ADDED
@@ -0,0 +1,309 @@
import os
import requests
import subprocess
import tarfile
import stat
import time
import atexit
from huggingface_hub import hf_hub_download
from langchain_core.language_models import LLM
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from typing import Any, List, Optional, Mapping

# --- Helper to Setup llama-server ---
def setup_llama_binaries():
    """
    Download and extract llama-server binary and libs from official releases
    """
    # Release URL for Linux x64 (llama.cpp build b7312)
    CLI_URL = "https://github.com/ggml-org/llama.cpp/releases/download/b7312/llama-b7312-bin-ubuntu-x64.tar.gz"
    LOCAL_TAR = "llama-cli.tar.gz"
    BIN_DIR = "./llama_bin"
    SERVER_BIN = os.path.join(BIN_DIR, "bin/llama-server")  # Look for server binary

    if os.path.exists(SERVER_BIN):
        return SERVER_BIN, BIN_DIR

    try:
        print("⬇️ Downloading llama.cpp binaries...")
        response = requests.get(CLI_URL, stream=True)
        if response.status_code == 200:
            with open(LOCAL_TAR, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            print("📦 Extracting binaries...")
            os.makedirs(BIN_DIR, exist_ok=True)

            with tarfile.open(LOCAL_TAR, "r:gz") as tar:
                tar.extractall(path=BIN_DIR)

            # Locate llama-server
            found_bin = None
            for root, dirs, files in os.walk(BIN_DIR):
                if "llama-server" in files:
                    found_bin = os.path.join(root, "llama-server")
                    break

            if not found_bin:
                print("❌ Could not find llama-server in extracted files.")
                return None, None

            # Make executable
            st = os.stat(found_bin)
            os.chmod(found_bin, st.st_mode | stat.S_IEXEC)
            print(f"✅ llama-server binary ready at {found_bin}!")
            return found_bin, BIN_DIR
        else:
            print(f"❌ Failed to download binaries: {response.status_code}")
            return None, None
    except Exception as e:
        print(f"❌ Error setting up llama-server: {e}")
        return None, None

# --- Local LLM Wrapper ---
class LocalLLM(LLM):
    local_server_url: str = "http://localhost:8080"

    @property
    def _llm_type(self) -> str:
        return "local_qwen"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        print("💻 Using Local Qwen3-4B...")
        try:
            # llama.cpp server's native /completion endpoint
            payload = {
                "prompt": prompt,
                "n_predict": 1024,
                "temperature": 0.3,
                "stop": (stop or []) + ["<|im_end|>", "Input:", "Context:"]
            }
            response = requests.post(
                f"{self.local_server_url}/completion",
                json=payload,
                timeout=300
            )
            if response.status_code == 200:
                return response.json()["content"]
            else:
                return f"❌ Local Server Error: {response.text}"
        except Exception as e:
            return f"❌ Local Inference Failed: {e}"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"local_server_url": self.local_server_url}

# --- Groq API LLM Wrapper ---
class GroqLLM(LLM):
    groq_client: Any = None
    groq_model: str = "qwen/qwen3-32b"

    @property
    def _llm_type(self) -> str:
        return "groq_qwen"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        if not self.groq_client:
            return "❌ Groq API Key not set or client initialization failed."

        print(f"⚡ Using Groq API ({self.groq_model})...")
        try:
            stop_seq = (stop or []) + ["<|im_end|>", "Input:", "Context:"]

            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {"role": "user", "content": prompt}
                ],
                model=self.groq_model,
                temperature=0.3,
                max_tokens=1024,
                stop=stop_seq
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            return f"❌ Groq API Failed: {e}"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.groq_model}

class LLMClient:
    def __init__(self, vector_store=None):
        """
        Initialize LLM Client with support for both API and Local
        """
        self.vector_store = vector_store
        self.server_process = None
        self.server_port = 8080
        self.groq_client = None
        self.local_llm_instance = None
        self.groq_llm_instance = None

        # 1. Setup Groq Client
        groq_api_key = os.environ.get("GROQ_API_KEY")
        self.groq_model = "qwen/qwen3-32b"

        if groq_api_key:
            try:
                from groq import Groq
                print(f"⚡ Initializing Native Groq Client ({self.groq_model})...")
                self.groq_client = Groq(api_key=groq_api_key)
                self.groq_llm_instance = GroqLLM(
                    groq_client=self.groq_client,
                    groq_model=self.groq_model
                )
                print("✅ Groq Client ready.")
            except Exception as e:
                print(f"⚠️ Groq Init Failed: {e}")

        # 2. Setup Local Fallback (always set up as requested)
        try:
            # Setup Binary
            self.server_bin, self.lib_path = setup_llama_binaries()

            # Download Model (Qwen3-4B GGUF)
            print("⬇️ Loading Local Qwen3-4B (GGUF)...")
            model_repo = "Qwen/Qwen3-4B-GGUF"
            filename = "Qwen3-4B-Q4_K_M.gguf"

            self.model_path = hf_hub_download(
                repo_id=model_repo,
                filename=filename
            )
            print(f"✅ Model downloaded to: {self.model_path}")

            # Start Server
            self.start_local_server()

            self.local_llm_instance = LocalLLM(
                local_server_url=f"http://localhost:{self.server_port}"
            )

        except Exception as e:
            print(f"⚠️ Could not set up local fallback: {e}")

    def start_local_server(self):
        """Start llama-server in background"""
        if not self.server_bin or not self.model_path:
            return

        print("🚀 Starting llama-server...")

        # Setup Env
        env = os.environ.copy()
        lib_paths = [os.path.dirname(self.server_bin)]
        lib_subdir = os.path.join(self.lib_path, "lib")
        if os.path.exists(lib_subdir):
            lib_paths.append(lib_subdir)
        env["LD_LIBRARY_PATH"] = ":".join(lib_paths) + ":" + env.get("LD_LIBRARY_PATH", "")

        cmd = [
            self.server_bin,
            "-m", self.model_path,
            "--port", str(self.server_port),
            "-c", "8192",
            "--host", "0.0.0.0",
            "--mlock"  # Lock model in RAM to prevent swapping
        ]

        # Launch process
        self.server_process = subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            env=env
        )

        # Register cleanup
        atexit.register(self.stop_server)

        # Wait for server to be ready
        print("⏳ Waiting for server to be ready...")
        for _ in range(20):  # Poll up to 20 times, ~1s apart
            try:
                health = requests.get(f"http://localhost:{self.server_port}/health", timeout=1)
                if health.status_code == 200:  # /health returns 503 while the model is still loading
                    print("✅ llama-server is ready!")
                    return
            except requests.exceptions.RequestException:
                pass
            time.sleep(1)

        print("⚠️ Server start timed out (but might still be loading).")

    def stop_server(self):
        """Kill the server process"""
        if self.server_process:
            print("🛑 Stopping llama-server...")
            self.server_process.terminate()
            self.server_process = None

    def analyze(self, text, model_selection="api"):
        """
        Analyze text using LangChain RetrievalQA with selected model
        """
        if not self.vector_store:
            return "❌ Vector Store not initialized."

        # Select LLM
        selected_llm = None
        if "api" in model_selection.lower():
            if self.groq_llm_instance:
                selected_llm = self.groq_llm_instance
            else:
                return "❌ Groq API not available. Please check API Key."
        else:
            if self.local_llm_instance:
                selected_llm = self.local_llm_instance
            else:
                return "❌ Local Model not available. Please check server logs."

        # Custom Prompt Template
        template = """<|im_start|>system
You are CyberGuard - an AI specialized in Phishing Detection.
Task: Analyze the provided URL and HTML snippet to classify the website as 'PHISHING' or 'BENIGN'.
Check specifically for BRAND IMPERSONATION (e.g. Facebook, Google, Banks).
If the HTML content is missing, empty, or contains an error message (like "Could not fetch website content"), YOU MUST RETURN a classification by ANALYZING the URL.
Classification Rules:
- PHISHING: Typosquatting URLs (e.g., paypa1.com), hidden login forms, obfuscated javascript, mismatched branding vs URL.
- BENIGN: Legitimate website, clean code, URL matches the content/brand.

RETURN THE RESULT IN THE EXACT FOLLOWING FORMAT (NO PREAMBLE):

CLASSIFICATION: [PHISHING or BENIGN]
CONFIDENCE SCORE: [0-100]%
EXPLANATION: [Write 3-4 concise sentences explaining the main reason]
<|im_end|>
<|im_start|>user
Context from knowledge base:
{context}

Input to analyze:
{question}
<|im_end|>
<|im_start|>assistant
"""

        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )

        # Create QA Chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=selected_llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(
                search_type="mmr",
                search_kwargs={"k": 3, "fetch_k": 10}
            ),
            chain_type_kwargs={"prompt": PROMPT}
        )

        try:
            print(f"🤖 Generating response using {model_selection}...")
            response = qa_chain.invoke(text)
            return response['result']
        except Exception as e:
            return f"❌ Error: {str(e)}"