# Phishing_Testing / llm_client.py
# Author: dungeon29 - last commit 323cb1a "Update llm_client.py" (verified)
import os
import requests
import subprocess
import tarfile
import stat
import time
import atexit
from huggingface_hub import hf_hub_download
from langchain_core.language_models import LLM
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from typing import Any, List, Optional, Mapping
# --- Helper to Setup llama-server ---
# llama.cpp prebuilt release for Linux x64, pinned to build b7312.
_LLAMA_RELEASE_URL = (
    "https://github.com/ggml-org/llama.cpp/releases/download/b7312/"
    "llama-b7312-bin-ubuntu-x64.tar.gz"
)


def _find_llama_server(bin_dir):
    """Return the path of the first ``llama-server`` file under *bin_dir*, or None.

    The binary's location inside the release tarball is not assumed; it is
    located by file name (os.walk yields nothing for a missing directory).
    """
    for root, _dirs, files in os.walk(bin_dir):
        if "llama-server" in files:
            return os.path.join(root, "llama-server")
    return None


def setup_llama_binaries(url=_LLAMA_RELEASE_URL, bin_dir="./llama_bin",
                         local_tar="llama-cli.tar.gz"):
    """Download and extract the llama-server binary and libs from official releases.

    If a ``llama-server`` file already exists anywhere under *bin_dir* it is
    reused and nothing is downloaded. This is the same search used after
    extraction, so a previous successful setup is always detected.

    Args:
        url: Release tarball to download.
        bin_dir: Directory the tarball is extracted into.
        local_tar: Temporary path for the downloaded tarball; it is deleted
            once setup finishes (successfully or not).

    Returns:
        ``(server_binary_path, bin_dir)`` on success, ``(None, None)`` on any
        failure. Failures are printed, not raised.
    """
    existing = _find_llama_server(bin_dir)
    if existing:
        return existing, bin_dir

    try:
        print("⬇️ Downloading llama.cpp binaries...")
        # Timeout so a stalled connection cannot hang start-up forever;
        # the context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=60) as response:
            if response.status_code != 200:
                print(f"❌ Failed to download binaries: {response.status_code}")
                return None, None
            with open(local_tar, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        print("📦 Extracting binaries...")
        os.makedirs(bin_dir, exist_ok=True)
        with tarfile.open(local_tar, "r:gz") as tar:
            # The "data" filter rejects absolute paths / path traversal in the
            # downloaded archive; it only exists on patched Pythons (3.12+ and
            # security backports), so fall back to the legacy call otherwise.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=bin_dir, filter="data")
            else:
                tar.extractall(path=bin_dir)

        found_bin = _find_llama_server(bin_dir)
        if not found_bin:
            print("❌ Could not find llama-server in extracted files.")
            return None, None

        # Make executable
        st = os.stat(found_bin)
        os.chmod(found_bin, st.st_mode | stat.S_IEXEC)
        print(f"✅ llama-server binary ready at {found_bin}!")
        return found_bin, bin_dir
    except Exception as e:
        # Best-effort setup: the caller treats (None, None) as "no local model".
        print(f"❌ Error setting up llama-server: {e}")
        return None, None
    finally:
        # The tarball is only needed for extraction; don't leave it behind.
        if os.path.exists(local_tar):
            try:
                os.remove(local_tar)
            except OSError:
                pass  # cleanup is best-effort
# --- Local LLM Wrapper ---
def _strip_think_block(text):
    """Return *text* with any reasoning section ending in ``</think>`` removed.

    Qwen3 models typically emit their reasoning as ``<think>...</think>``
    before the answer; only the text after the last closing tag is kept
    (left-stripped). Output with no closing tag, including reasoning that was
    cut off by the token limit, is returned unchanged.
    """
    _, closing, answer = text.rpartition("</think>")
    return answer.lstrip() if closing else text


class LocalLLM(LLM):
    """LangChain ``LLM`` adapter for a local llama.cpp ``llama-server``.

    Sends the fully rendered prompt to the server's ``/completion`` endpoint
    and returns the generated text with any closed ``<think>`` reasoning
    removed. This matches GroqLLM, which requests hidden reasoning. Failures
    are returned as ``"❌ ..."`` strings rather than raised.
    """

    local_server_url: str = "http://localhost:8080"  # base URL of llama-server

    @property
    def _llm_type(self) -> str:
        return "local_qwen"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Generate a completion for *prompt* (extra ``kwargs`` are ignored)."""
        print("💻 Using Local Qwen3-4B...")
        try:
            # llama.cpp's native completion API (``n_predict`` etc.), not the
            # OpenAI-compatible ``/v1/completions`` route.
            payload = {
                "prompt": prompt,
                "n_predict": 1024,
                "temperature": 0.3,
                # Caller stops first, then fixed ones, including the ChatML
                # end-of-turn marker.
                "stop": (stop or []) + ["<|im_end|>", "Input:", "Context:"],
            }
            response = requests.post(
                f"{self.local_server_url}/completion",
                json=payload,
                timeout=300,
            )
            if response.status_code == 200:
                return _strip_think_block(response.json()["content"])
            return f"❌ Local Server Error: {response.text}"
        except Exception as e:
            return f"❌ Local Inference Failed: {e}"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"local_server_url": self.local_server_url}
# --- Groq API LLM Wrapper ---
class GroqLLM(LLM):
    """LangChain ``LLM`` that forwards prompts to a Groq chat-completions client.

    ``groq_client`` must expose ``chat.completions.create``. The whole prompt
    is sent as one user message, so any chat markup already inside it arrives
    as plain user text. Errors are returned as ``"❌ ..."`` strings, not raised.
    """

    groq_client: Any = None  # Groq SDK client; a falsy value disables this backend
    groq_model: str = "qwen/qwen3-32b"

    @property
    def _llm_type(self) -> str:
        return "groq_qwen"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.groq_model}

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Return the model's reply to *prompt*, or an error string on failure.

        Extra ``kwargs`` are accepted for LangChain compatibility and ignored.
        """
        client = self.groq_client
        if not client:
            return "❌ Groq API Key not set or client initialization failed."

        print(f"⚑ Using Groq API ({self.groq_model})...")
        try:
            # Fixed stop strings are appended after any caller-supplied ones.
            # NOTE(review): OpenAI-style APIs usually cap `stop` at 4 entries;
            # with 3 fixed here, more than one caller stop may be rejected -
            # confirm against the Groq API reference.
            request = {
                "messages": [{"role": "user", "content": prompt}],
                "model": self.groq_model,
                "temperature": 0.3,
                "max_tokens": 1024,
                "stop": (stop or []) + ["<|im_end|>", "Input:", "Context:"],
                # Ask Groq to leave the model's reasoning out of the reply.
                "reasoning_format": "hidden",
            }
            completion = client.chat.completions.create(**request)
            first_choice = completion.choices[0]
            return first_choice.message.content
        except Exception as err:
            return f"❌ Groq API Failed: {err}"
class LLMClient:
    """Phishing-analysis front end over a Groq-hosted model and a local llama-server.

    Construction (1) creates a Groq client when ``GROQ_API_KEY`` is set and
    (2) downloads llama-server plus a Qwen3-4B GGUF model and starts the server
    as a local fallback. ``analyze`` runs a LangChain ``RetrievalQA`` chain over
    ``vector_store`` with the selected backend. Every failure is reported as a
    ``"❌ ..."`` / ``"⚠️ ..."`` message rather than raised.
    """

    # Prompt for the "stuff" RetrievalQA chain, written in ChatML format.
    # {context} receives the retrieved documents; {question} the analyzed text.
    _PROMPT_TEMPLATE = """<|im_start|>system
You are CyberGuard - an AI specialized in Phishing Detection.
Task: Analyze the provided URL and HTML snippet to classify the website as 'PHISHING' or 'BENIGN'.
Check specifically for BRAND IMPERSONATION (e.g. Facebook, Google, Banks).
If the HTML content is missing, empty, or contains an error message (like "Could not fetch website content"), YOU MUST RETURN classification by ANALYZING the URL.
Classification Rules:
- PHISHING: Typosquatting URLs (e.g., paypa1.com), hidden login forms, obfuscated javascript, mismatched branding vs URL.
- BENIGN: Legitimate website, clean code, URL matches the content/brand.
Let's think step by step to verify logical inconsistencies between URL and Content before deciding.
RETURN THE RESULT IN THE EXACT FOLLOWING FORMAT (NO PREAMBLE):
CLASSIFICATION: [PHISHING or BENIGN]
CONFIDENCE SCORE: [0-100]%
EXPLANATION: [Write 3-4 concise sentences explaining the main reason]
<|im_end|>
<|im_start|>user
Context from knowledge base:
{context}
Input to analyze:
{question}
<|im_end|>
<|im_start|>assistant
"""

    def __init__(self, vector_store=None):
        """Initialize LLM Client with support for both API and Local.

        Args:
            vector_store: LangChain vector store used as the retriever in
                ``analyze`` (must provide ``as_retriever``). A falsy value makes
                ``analyze`` return an error string.
        """
        self.vector_store = vector_store
        self.server_process = None
        # Loopback only: llama-server is started without an API key.
        self.server_host = "127.0.0.1"
        self.server_port = 8080
        self.server_bin = None
        self.lib_path = None
        self.model_path = None
        self.groq_client = None
        self.groq_model = "qwen/qwen3-32b"
        self.local_llm_instance = None
        self.groq_llm_instance = None

        self._setup_groq()
        # The local fallback is always prepared, even when Groq is available.
        self._setup_local_fallback()

    def _setup_groq(self):
        """Create the Groq client and its LangChain wrapper if GROQ_API_KEY is set."""
        groq_api_key = os.environ.get("GROQ_API_KEY")
        if not groq_api_key:
            return
        try:
            from groq import Groq  # optional dependency, only needed for the API path
            print(f"⚑ Initializing Native Groq Client ({self.groq_model})...")
            self.groq_client = Groq(api_key=groq_api_key)
            self.groq_llm_instance = GroqLLM(
                groq_client=self.groq_client,
                groq_model=self.groq_model,
            )
            print("✅ Groq Client ready.")
        except Exception as e:
            print(f"⚠️ Groq Init Failed: {e}")

    def _setup_local_fallback(self):
        """Fetch llama-server and the GGUF model, start the server, wrap it.

        ``local_llm_instance`` is only set when the server process is actually
        running. Otherwise ``analyze`` reports the local model as unavailable
        instead of failing on a connection error later.
        """
        try:
            self.server_bin, self.lib_path = setup_llama_binaries()
            if not self.server_bin:
                # No point downloading a multi-GB model without a server to run it.
                print("⚠️ Could not setup local fallback: llama-server binary unavailable")
                return

            print("📥 Loading Local Qwen3-4B (GGUF)...")
            self.model_path = hf_hub_download(
                repo_id="Qwen/Qwen3-4B-GGUF",
                filename="Qwen3-4B-Q4_K_M.gguf",
            )
            print(f"✅ Model downloaded to: {self.model_path}")

            self.start_local_server()
            # A timed-out but still-loading server is kept; a dead one is not.
            if not self._server_alive():
                print("⚠️ Could not setup local fallback: llama-server is not running")
                return

            self.local_llm_instance = LocalLLM(
                local_server_url=f"http://{self.server_host}:{self.server_port}"
            )
        except Exception as e:
            print(f"⚠️ Could not setup local fallback: {e}")

    def _server_alive(self):
        """Return True if a llama-server process was started and has not exited."""
        return self.server_process is not None and self.server_process.poll() is None

    def start_local_server(self, startup_timeout=20):
        """Start llama-server in background and wait for it to become healthy.

        Args:
            startup_timeout: Seconds to wait for ``/health`` to return 200.

        Returns:
            True if the server reported healthy (or was already running),
            False if it could not be started, exited, or is still loading
            when the timeout expires.
        """
        if not self.server_bin or not self.model_path:
            return False
        if self._server_alive():
            return True  # don't start a second process competing for the port

        print("🚀 Starting llama-server...")

        # Shared libraries ship next to the binary (and possibly in lib/).
        env = os.environ.copy()
        lib_paths = [os.path.dirname(self.server_bin)]
        if self.lib_path:
            lib_subdir = os.path.join(self.lib_path, "lib")
            if os.path.exists(lib_subdir):
                lib_paths.append(lib_subdir)
        inherited = env.get("LD_LIBRARY_PATH")
        if inherited:
            lib_paths.append(inherited)
        # Joined without a trailing ":" because an empty entry would make the
        # dynamic loader search the current directory.
        env["LD_LIBRARY_PATH"] = ":".join(lib_paths)

        cmd = [
            self.server_bin,
            "-m", self.model_path,
            "--port", str(self.server_port),
            "-c", "8192",  # larger context window to fit retrieved docs + web content
            "--host", self.server_host,
            "--mlock",  # lock model in RAM to prevent swapping
        ]

        self.server_process = subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            env=env,
        )
        atexit.register(self.stop_server)

        print("⏳ Waiting for server to be ready...")
        health_url = f"http://{self.server_host}:{self.server_port}/health"
        deadline = time.monotonic() + startup_timeout
        while time.monotonic() < deadline:
            if self.server_process.poll() is not None:
                print(f"❌ llama-server exited during startup (exit code {self.server_process.returncode}).")
                return False
            try:
                # Any reply other than 200 (e.g. 503 while the model loads)
                # means "not ready yet".
                if requests.get(health_url, timeout=1).status_code == 200:
                    print("✅ llama-server is ready!")
                    return True
            except requests.RequestException:
                pass  # not accepting connections yet
            time.sleep(1)
        print("⚠️ Server start timed out (but might still be loading).")
        return False

    def stop_server(self, timeout=10):
        """Terminate the server process and reap it (kill if it ignores SIGTERM).

        Safe to call repeatedly; it is also registered with ``atexit``.
        """
        proc = self.server_process
        if proc is None:
            return
        self.server_process = None
        print("🛑 Stopping llama-server...")
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()

    def analyze(self, text, model_selection="api"):
        """Analyze text using LangChain RetrievalQA with selected model.

        Args:
            text: URL / HTML snippet to classify; passed as the chain's query.
            model_selection: Any string containing "api" (case-insensitive)
                selects Groq; anything else selects the local llama-server.

        Returns:
            The model's answer, or a ``"❌ ..."`` error string.
        """
        if not self.vector_store:
            return "❌ Vector Store not initialized."

        if "api" in model_selection.lower():
            selected_llm = self.groq_llm_instance
            if not selected_llm:
                return "❌ Groq API not available. Please check API Key."
        else:
            selected_llm = self.local_llm_instance
            if not selected_llm:
                return "❌ Local Model not available. Please check server logs."

        try:
            prompt = PromptTemplate(
                template=self._PROMPT_TEMPLATE,
                input_variables=["context", "question"],
            )
            # MMR retrieval: 3 diverse documents chosen from the top 10 matches.
            qa_chain = RetrievalQA.from_chain_type(
                llm=selected_llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(
                    search_type="mmr",
                    search_kwargs={"k": 3, "fetch_k": 10},
                ),
                chain_type_kwargs={"prompt": prompt},
            )
            print(f"🤖 Generating response using {model_selection}...")
            response = qa_chain.invoke(text)
            return response["result"]
        except Exception as e:
            return f"❌ Error: {str(e)}"