import os
from typing import Any, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from dotenv import load_dotenv

load_dotenv()

# --- CRITICAL FIX: Handle Import Error ---
try:
    from mlx_lm import load, generate
    HAS_MLX = True
except ImportError:
    HAS_MLX = False
# ----------------------------------------


class MLXLLM(LLM):
    """Custom LangChain Wrapper for MLX Models (with Cloud Fallback)"""

    model_id: str = os.getenv("MODEL_ID", "mlx-community/Llama-3.2-3B-Instruct-4bit")
    model: Any = None
    tokenizer: Any = None
    max_tokens: int = int(os.getenv("MAX_TOKENS", 512))
    pipeline: Any = None  # For Cloud Fallback

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if HAS_MLX:
            print(f"🚀 Loading MLX Model: {self.model_id}")
            self.model, self.tokenizer = load(self.model_id)
        else:
            print("⚠️ MLX not found. Falling back to HuggingFace Transformers (CPU/Cloud).")
            # Fallback: use standard Transformers
            from transformers import pipeline

            # Use the MODEL_ID env var (set to 'gpt2' or 'facebook/opt-125m' in HF Secrets).
            # Do NOT use the MLX model ID here, as it requires MLX format.
            cloud_model_id = os.getenv("MODEL_ID", "gpt2")
            self.pipeline = pipeline(
                "text-generation",
                model=cloud_model_id,
                max_new_tokens=self.max_tokens,
            )

    @property
    def _llm_type(self) -> str:
        return "mlx_llama" if HAS_MLX else "transformers_fallback"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        if HAS_MLX:
            # MLX generation logic: format the prompt with the model's chat template.
            messages = [{"role": "user", "content": prompt}]
            formatted_prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            response = generate(
                self.model,
                self.tokenizer,
                prompt=formatted_prompt,
                verbose=False,
                max_tokens=self.max_tokens,
            )
            return response
        else:
            # Cloud/CPU fallback logic: simple text generation for MVP.
            response = self.pipeline(prompt)[0]["generated_text"]
            # The HF pipeline echoes the prompt, so strip it from the response.
            return response[len(prompt):]
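

# --- Usage sketch (illustrative, not part of the wrapper above) ---
# Minimal example of driving the wrapper, assuming either mlx_lm or
# transformers is installed and MODEL_ID/MAX_TOKENS are set (or defaulted).
# The prompt text below is hypothetical.
if __name__ == "__main__":
    llm = MLXLLM()
    # LLM subclasses expose the standard LangChain Runnable interface,
    # so invoke() on a plain string returns the generated string.
    print(llm.invoke("Explain what MLX is in one sentence."))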