| | """Model manager for stance detection model""" |
| |
|
| | import os |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class StanceModelManager:
    """Manage loading of a stance-detection model and single predictions.

    The manager holds a Hugging Face tokenizer and sequence-classification
    model together with the torch device they run on. ``load_model`` must
    complete successfully before ``predict`` can be used.
    """

    def __init__(self):
        # Everything starts unset; load_model() fills these in on success.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: str = None):
        """Load the tokenizer and model identified by ``model_id``.

        Calling this again after a successful load is a no-op.

        Args:
            model_id: Identifier passed to ``from_pretrained``.
            api_key: Optional access token forwarded as ``token``; an empty
                value is treated the same as ``None``.

        Raises:
            RuntimeError: If any step of loading fails. The original
                exception is attached as ``__cause__`` and the manager is
                left in its previous (unloaded) state.
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info(f"Loading stance model from Hugging Face: {model_id}")

            # Prefer the GPU when torch reports one, otherwise run on CPU.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {device}")

            token = api_key or None

            # NOTE(security): trust_remote_code=True allows Python code shipped
            # with the model repository to run locally. Only use model ids
            # from sources you trust.
            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )

            logger.info("Loading model...")
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            model.to(device)
            model.eval()  # inference mode: disables dropout and similar layers
        except Exception as e:
            # logger.exception keeps the traceback; "from e" keeps the cause.
            logger.exception("Error loading stance model: %s", e)
            raise RuntimeError(f"Failed to load stance model: {str(e)}") from e

        # Commit state only after every step succeeded, so a failed load never
        # leaves a tokenizer or device behind without a matching model.
        self.device = device
        self.tokenizer = tokenizer
        self.model = model
        self.model_loaded = True
        logger.info("✓ Stance model loaded successfully from Hugging Face!")

    def predict(self, topic: str, argument: str) -> dict:
        """Predict the stance of ``argument`` towards ``topic``.

        Args:
            topic: The topic text.
            argument: The argument text to classify.

        Returns:
            A dict with keys ``predicted_stance`` ("PRO" or "CON"),
            ``confidence`` (probability of the predicted class),
            ``probability_con`` and ``probability_pro``.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # Topic and argument are joined into a single input string.
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,  # inputs longer than max_length are cut off
            max_length=512,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        # A single text was encoded, so only row 0 of the batch is read.
        # Index 0 is reported as CON and index 1 as PRO.
        # NOTE(review): assumes the model's label order is [CON, PRO] - confirm
        # against the model's configuration.
        row = probabilities[0]
        prob_con = row[0].item()
        prob_pro = row[1].item()

        stance = "PRO" if predicted_class == 1 else "CON"
        confidence = row[predicted_class].item()

        return {
            "predicted_stance": stance,
            "confidence": confidence,
            "probability_con": prob_con,
            "probability_pro": prob_pro,
        }
| |
|
| |
|
| | |
# Shared module-level manager instance. It is created without a model, so
# load_model() has to be called on it before predict() can be used.
stance_model_manager: StanceModelManager = StanceModelManager()
| |
|
| |
|