from dataclasses import dataclass
from typing import List, Optional

from graphgen.bases import Token


@dataclass
class TopkTokenModel:
    """Generation settings plus the async interface of a top-k token model.

    Every method defined here raises ``NotImplementedError``; a concrete
    model is expected to override the ones it supports.
    """

    # Generation settings. This class only stores them; how each one is
    # consumed is decided by the overriding implementation.
    # NOTE(review): names look like the usual LLM decoding options - confirm
    # the exact semantics against the concrete subclasses.
    do_sample: bool = False
    temperature: float = 0
    max_tokens: int = 4096
    repetition_penalty: float = 1.05
    num_beams: int = 1
    topk: int = 50
    topp: float = 0.95

    # Number of top-k candidate tokens to generate for each token.
    topk_per_token: int = 5

    async def generate_topk_per_token(self, text: str) -> List[Token]:
        """Return per-token details for the model's output on ``text``.

        Each entry carries the token's probability, its text and its
        candidate tokens. Intended for visualizing the inference process.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    async def generate_inputs_prob(
        self,
        text: str,
        history: Optional[List[str]] = None,
    ) -> List[Token]:
        """Return the probability and text of each token of the input text.

        Intended for visualizing the perplexity (ppl).

        Args:
            text: The input text to score.
            history: Optional list of strings; presumably earlier
                conversation turns (not used here - confirm in subclasses).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[str]] = None,
    ) -> str:
        """Return the model's answer for ``text``.

        Args:
            text: The prompt to answer.
            history: Optional list of strings; presumably earlier
                conversation turns (not used here - confirm in subclasses).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError