Yassine Mhirsi
refactor: Simplify topic extraction logic in TopicService by removing Pydantic schema, enhancing JSON response handling, and adding fuzzy matching for improved topic validation.
94c2a9a
| """Service for topic extraction from text using LangChain Groq""" | |
import difflib
import json
import logging
import re
from typing import List, Optional

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from langsmith import traceable

from config import GROQ_API_KEY
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# The closed set of topics the classifier may return.
# Keep the order stable: the system prompt numbers the topics by their
# position in this list.
PREDEFINED_TOPICS: List[str] = [
    "Assisted suicide should be a criminal offence",
    "We should abolish intellectual property rights",
    "Homeschooling should be banned",
    "The vow of celibacy should be abandoned",
    "We should legalize prostitution",
    "We should ban private military companies",
    "We should abolish capital punishment",
    "Foster care brings more harm than good",
    "Routine child vaccinations should be mandatory",
    "We should abolish the three-strikes laws",
    "We should subsidize student loans",
    "We should end the use of economic sanctions",
    "We should end mandatory retirement",
    "We should close Guantanamo Bay detention camp",
    "We should subsidize space exploration",
    "We should abandon the use of school uniform",
    "The use of public defenders should be mandatory",
    "We should adopt an austerity regime",
    "Social media platforms should be regulated by the government",
    "We should ban human cloning",
    "We should adopt atheism",
    "We should introduce compulsory voting",
    "We should adopt libertarianism",
    "We should abolish the right to keep and bear arms",
    "We should legalize sex selection",
    "We should abandon marriage",
    "Entrapment should be legalized",
    "We should end affirmative action",
    "We should prohibit women in combat",
    "We should adopt a zero-tolerance policy in schools",
    "We should subsidize vocational education",
    "We should ban the use of child actors",
    "We should legalize cannabis",
    "We should ban cosmetic surgery",
    "We should end racial profiling",
    "We should prohibit flag burning",
    "The USA is a good country to live in",
    "We should ban algorithmic trading",
    "We should fight for the abolition of nuclear weapons",
    "We should fight urbanization",
    "We should subsidize journalism",
]
class TopicService:
    """Service for extracting topics from text arguments by matching to predefined topics.

    A successful extraction always returns an element of
    ``self.predefined_topics``. The LLM's reply is parsed leniently (JSON
    object, code-fenced JSON, bare JSON string or plain text) and then snapped
    onto the closed list by exact, normalized and fuzzy matching.
    """

    # Function words that occur across many predefined topics ("We should ...",
    # "... should be ..."). They carry no signal, so the word-overlap fallback
    # ignores them; otherwise any "We should ..." reply would match whichever
    # "We should ..." topic happens to come first in the list.
    _STOPWORDS = frozenset(
        {
            "a", "an", "and", "be", "by", "for", "in", "is", "it",
            "of", "on", "should", "than", "that", "the", "to", "we",
        }
    )
    # Minimum difflib similarity ratio for a normalized reply to be accepted as
    # a spelling/wording variant of a predefined topic. Chosen conservatively so
    # distinct topics sharing a "We should <verb> ..." prefix are not confused.
    _SIMILARITY_CUTOFF = 0.85
    # Minimum number of shared non-stopword words for the overlap fallback.
    _MIN_WORD_OVERLAP = 2
    # A reply wrapped in a Markdown code fence, e.g. ```json\n{...}\n```.
    _CODE_FENCE_RE = re.compile(r"^```[A-Za-z0-9_-]*\s*(.*?)\s*```$", re.DOTALL)

    def __init__(self):
        self.llm = None  # ChatGroq client; created by initialize()
        self.model_name = "openai/gpt-oss-safeguard-20b"  # Default model
        self.initialized = False
        self.predefined_topics = PREDEFINED_TOPICS

    def initialize(self, model_name: Optional[str] = None) -> None:
        """Create the Groq chat client used for topic classification.

        Calling this again after a successful initialization is a no-op, and
        *model_name* is then ignored.

        Args:
            model_name: Optional Groq model id overriding ``self.model_name``.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is not configured.
            RuntimeError: If constructing the ``ChatGroq`` client fails.
        """
        if self.initialized:
            logger.info("Topic service already initialized")
            return
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables")
        if model_name:
            self.model_name = model_name
        try:
            logger.info(f"Initializing topic extraction service with model: {self.model_name}")
            self.llm = ChatGroq(
                model=self.model_name,
                api_key=GROQ_API_KEY,
                # Temperature 0 minimises sampling variation: the answer must be
                # a verbatim entry from the list.
                temperature=0.0,
                max_tokens=512,
            )
        except Exception as e:
            logger.error(f"Error initializing topic service: {str(e)}")
            raise RuntimeError(f"Failed to initialize topic service: {str(e)}") from e
        self.initialized = True
        logger.info("✓ Topic extraction service initialized successfully")

    def _get_system_message(self) -> str:
        """Build the system prompt, embedding the numbered list of predefined topics."""
        topics_list = "\n".join(
            f"{number}. {topic}"
            for number, topic in enumerate(self.predefined_topics, start=1)
        )
        return f"""You are a topic classification model. Your task is to select the MOST SIMILAR topic from the predefined list below that best matches the user's input text.
IMPORTANT: You MUST return EXACTLY one of the predefined topics below. Do not create new topics or modify the wording.
Return your response as a JSON object with a single "topic" field containing the exact topic text from the list.
Predefined Topics:
{topics_list}
Instructions:
1. Analyze the user's input text carefully
2. Identify the main theme, subject, or argument being discussed
3. Find the topic from the predefined list that is MOST SIMILAR to the input text
4. Return a JSON object with the EXACT topic text as it appears in the list above
Examples:
- Input: "I think we need to make assisted suicide illegal and punishable by law."
Output: {{"topic": "Assisted suicide should be a criminal offence"}}
- Input: "Student debt is crushing young people. The government should help pay for college."
Output: {{"topic": "We should subsidize student loans"}}
- Input: "Marijuana should be legal for adults to use recreationally."
Output: {{"topic": "We should legalize cannabis"}}
"""

    @classmethod
    def _parse_llm_response(cls, content: str) -> str:
        """Extract the candidate topic string from a raw LLM reply.

        Accepted shapes, in order of preference:
        a JSON object with a ``"topic"`` key (optionally inside a Markdown code
        fence or surrounded by prose), a bare JSON string, or plain text
        (optionally wrapped in one pair of single or double quotes).

        Returns:
            The stripped candidate topic, or ``""`` if none could be found
            (including a missing/non-string ``"topic"`` value).
        """
        text = content.strip()
        fence = cls._CODE_FENCE_RE.match(text)
        if fence:
            text = fence.group(1).strip()

        # Try the whole reply first, then the outermost {...} span in case the
        # model surrounded the JSON object with extra prose.
        snippets = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            snippets.append(text[start:end + 1])

        topic = text  # plain-text fallback when nothing parses as JSON
        for snippet in snippets:
            try:
                parsed = json.loads(snippet)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                parsed = parsed.get("topic", "")
            # Non-string JSON values (null, numbers, lists) carry no topic.
            topic = parsed if isinstance(parsed, str) else ""
            break

        topic = topic.strip()
        if len(topic) >= 2 and topic[0] == topic[-1] and topic[0] in "\"'":
            topic = topic[1:-1].strip()
        return topic

    @staticmethod
    def _normalize(text: str) -> str:
        """Lowercase *text* and keep only its alphanumeric words, joined by single spaces.

        Makes comparisons insensitive to case, punctuation (trailing periods,
        quotes, hyphens) and whitespace.
        """
        return " ".join(re.findall(r"[a-z0-9]+", text.lower()))

    @classmethod
    def _content_words(cls, text: str) -> set:
        """Return the normalized words of *text* with ``_STOPWORDS`` removed."""
        return set(cls._normalize(text).split()) - cls._STOPWORDS

    def _match_predefined_topic(self, candidate: str) -> Optional[str]:
        """Map *candidate* onto an entry of ``self.predefined_topics``.

        Strategies, tried in order:
        1. exact string match;
        2. match after ``_normalize`` (case/punctuation/whitespace);
        3. ``difflib`` close match on normalized text (``_SIMILARITY_CUTOFF``);
        4. largest overlap of non-stopword words, accepted only if it is at
           least ``_MIN_WORD_OVERLAP`` and not tied with another topic.

        Returns:
            The matching predefined topic, or ``None`` if no strategy matches.
        """
        if candidate in self.predefined_topics:
            return candidate

        by_normalized = {self._normalize(topic): topic for topic in self.predefined_topics}
        normalized = self._normalize(candidate)
        if normalized in by_normalized:
            return by_normalized[normalized]

        close = difflib.get_close_matches(
            normalized, list(by_normalized), n=1, cutoff=self._SIMILARITY_CUTOFF
        )
        if close:
            return by_normalized[close[0]]

        candidate_words = self._content_words(candidate)
        best_topic, best_score, tied = None, 0, False
        for topic in self.predefined_topics:
            score = len(candidate_words & self._content_words(topic))
            if score > best_score:
                best_topic, best_score, tied = topic, score, False
            elif score == best_score:
                # An equal best score means the reply is ambiguous; refuse to
                # pick one arbitrarily by list order.
                tied = True
        if best_score >= self._MIN_WORD_OVERLAP and not tied:
            return best_topic
        return None

    def extract_topic(self, text: str) -> str:
        """Classify *text* into exactly one of ``self.predefined_topics``.

        Lazily initializes the service on first use. The LLM is asked for a
        JSON object ``{"topic": ...}``; its answer is parsed leniently and then
        matched onto the predefined list, so minor wording drift in the reply
        does not cause a failure.

        Args:
            text: The input text/argument to extract a topic from.

        Returns:
            One of the strings in ``self.predefined_topics``.

        Raises:
            ValueError: If *text* is not a non-empty string, or (from
                ``initialize``) if no API key is configured.
            RuntimeError: If initialization or the LLM call fails, or the
                reply cannot be mapped onto a predefined topic.
        """
        if not text or not isinstance(text, str):
            raise ValueError("Text must be a non-empty string")
        text = text.strip()
        if not text:
            raise ValueError("Text cannot be empty")
        # Validate before initializing so bad input never triggers client setup.
        if not self.initialized:
            self.initialize()

        system_message = self._get_system_message()
        try:
            result = self.llm.invoke(
                [
                    SystemMessage(content=system_message),
                    HumanMessage(content=text),
                ]
            )
            selected_topic = self._parse_llm_response(result.content)
            if not selected_topic:
                raise ValueError("No topic found in LLM response")

            matched_topic = self._match_predefined_topic(selected_topic)
            if matched_topic is None:
                logger.error(
                    f"Could not match returned topic '{selected_topic}' to any predefined topic. "
                    f"Available topics: {self.predefined_topics[:3]}..."
                )
                raise ValueError(
                    f"Returned topic '{selected_topic}' is not in the predefined topics list"
                )
            if matched_topic != selected_topic:
                logger.warning(
                    f"LLM returned topic not in predefined list: '{selected_topic}'. "
                    f"Matched it to '{matched_topic}'"
                )
            return matched_topic
        except Exception as e:
            logger.error(f"Error extracting topic: {str(e)}")
            raise RuntimeError(f"Topic extraction failed: {str(e)}") from e

    def batch_extract_topics(self, texts: List[str]) -> List[Optional[str]]:
        """Extract topics from multiple texts, one LLM call per text.

        Per-item failures are isolated: an item whose extraction raises is
        logged and yields ``None`` at its position. The result therefore has
        the same length and order as *texts*.

        Args:
            texts: List of input texts/arguments.

        Returns:
            List of extracted topics, with ``None`` for items that failed.

        Raises:
            ValueError: If *texts* is not a non-empty list.
            RuntimeError: If the service cannot be initialized.
        """
        if not texts or not isinstance(texts, list):
            raise ValueError("Texts must be a non-empty list")
        # Initialize up front so a configuration error fails the whole call
        # instead of turning every item into None.
        if not self.initialized:
            self.initialize()

        results: List[Optional[str]] = []
        for text in texts:
            try:
                results.append(self.extract_topic(text))
            except Exception as e:
                # str() first: invalid items (e.g. None) must not make the
                # error handler itself raise and abort the whole batch.
                logger.error(f"Error extracting topic for text '{str(text)[:50]}...': {str(e)}")
                results.append(None)
        return results
# Shared module-level instance. Construction only stores defaults; the Groq
# client is created lazily on the first extraction call unless initialize()
# is called explicitly beforehand.
topic_service: TopicService = TopicService()