from collections import deque
from typing import Any, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.logits_process import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.utils import (
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)


def generate(
    model: Any,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    generation_config: Optional[GenerationConfig] = None,
    synced_gpus: bool = False,
    streamer: Optional[Any] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
    """Custom decoding with DeepCONF (confidence-based early stopping).

    Args:
        model: PreTrainedModel with a LM head.
        input_ids: Prompt ids of shape (batch, seq_len).
        logits_processor: Optional logits processors.
        stopping_criteria: Optional stopping criteria.
        generation_config: GenerationConfig controlling sampling/outputs.
        synced_gpus: Keep looping to max length for distributed setups.
        streamer: Optional streamer for incremental tokens.
        **model_kwargs: Forward pass kwargs (e.g., attention_mask).

    Returns:
        GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, or LongTensor
        depending on `return_dict_in_generate` and model type.
    """
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if stopping_criteria is None:
        stopping_criteria = StoppingCriteriaList()
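
    # DeepCONF controls, read from (possibly custom) attributes on the
    # GenerationConfig; the defaults below apply when an attribute is absent.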
    enable_conf = getattr(generation_config, "enable_conf", False)
    enable_early_stopping = getattr(generation_config, "enable_early_stopping", True)
    window_size = getattr(generation_config, "window_size", 2048)
    threshold = getattr(generation_config, "threshold", 17.0)
    conf_topk = getattr(generation_config, "conf_topk", 20)
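
    # Without DeepCONF enabled, defer to the model's standard sampling loop.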
    if not enable_conf:
        return model._sample(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )
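
    # Resolve a pad token id, falling back to the EOS token when none is set.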
    pad_token_id = generation_config.pad_token_id
    if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
        pad_token_id = generation_config._pad_token_tensor
    if pad_token_id is None and hasattr(model.config, "pad_token_id"):
        pad_token_id = model.config.pad_token_id
    if pad_token_id is None and generation_config.eos_token_id is not None:
        pad_token_id = generation_config.eos_token_id

    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    output_confidences = getattr(generation_config, "output_confidences", False)

    deepconf_variant = getattr(generation_config, "deepconf_variant", None)
    deepconf_eta = getattr(generation_config, "deepconf_eta", None)
    deepconf_warmup_confidences = getattr(
        generation_config, "deepconf_warmup_confidences", None
    )
    has_eos_stopping_criteria = any(
        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
    )
    do_sample = generation_config.do_sample
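
    # An explicit threshold takes precedence. Otherwise, if warm-up trace
    # confidences are provided, calibrate the threshold as the
    # (100 - 100 * eta)-th percentile of those confidences, so that roughly an
    # eta fraction of warm-up traces would have stayed above it.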
    if enable_conf and threshold is not None:
        pass
    elif (
        enable_conf
        and deepconf_variant is not None
        and deepconf_warmup_confidences is not None
    ):
        confs = deepconf_warmup_confidences
        if hasattr(confs, "detach"):
            confs = confs.detach().cpu().numpy()
        elif isinstance(confs, torch.Tensor):
            confs = confs.cpu().numpy()
        confs = np.asarray(confs, dtype=np.float32).ravel()
        eta = deepconf_eta
        if eta is None:
            eta = (
                0.1
                if deepconf_variant == "low"
                else 0.9
                if deepconf_variant == "high"
                else 0.5
            )
        pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
        threshold = float(np.percentile(confs, pct))

    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = (
        () if (return_dict_in_generate and output_hidden_states) else None
    )

    if return_dict_in_generate and model.config.is_encoder_decoder:
        encoder_attentions = (
            model_kwargs["encoder_outputs"].get("attentions")
            if output_attentions
            else None
        )
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states")
            if output_hidden_states
            else None
        )

    batch_size, cur_len = input_ids.shape[:2]
    unfinished_sequences = torch.ones(
        batch_size, dtype=torch.long, device=input_ids.device
    )
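
    # Per-sequence sliding window of the last `window_size` token confidences,
    # plus a running sum so the windowed mean can be updated in O(1) per step.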
    conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
    conf_grouped_sums = [0.0 for _ in range(batch_size)]

    step_confidences = [] if (return_dict_in_generate and output_confidences) else None

    steps = 0
    max_new_tokens = getattr(generation_config, "max_new_tokens", None) or 512
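
    # Manual decoding loop: one forward pass per step, with explicit
    # cache_position / past_key_values bookkeeping.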
    model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
    while steps < max_new_tokens and unfinished_sequences.max() != 0:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        model_inputs.update(
            {"output_attentions": output_attentions} if output_attentions else {}
        )
        model_inputs.update(
            {"output_hidden_states": output_hidden_states}
            if output_hidden_states
            else {}
        )

        with torch.no_grad():
            outputs = model(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :].detach()

        if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
            model_kwargs["past_key_values"] = outputs.past_key_values

        next_token_scores = logits_processor(input_ids, next_token_logits)
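
        # Re-apply temperature / top-k / top-p warping manually, since this
        # loop bypasses the warper list that `model.generate` would build.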
        warpers = LogitsProcessorList()
        temperature = getattr(generation_config, "temperature", 1.0)
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        top_k = getattr(generation_config, "top_k", None)
        if top_k is not None and isinstance(top_k, int) and top_k > 0:
            warpers.append(TopKLogitsWarper(top_k))
        top_p = getattr(generation_config, "top_p", None)
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p))
        if len(warpers) > 0:
            next_token_scores = warpers(input_ids, next_token_scores)

        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if model.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if model.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        if do_sample:
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)
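
        # DeepCONF step confidence: the negative mean log-probability of the
        # top-`conf_topk` tokens under the unwarped distribution. Each
        # sequence's sliding-window average is compared against `threshold`,
        # and sequences that fall below it are marked for early stopping.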
        probs = F.softmax(next_token_logits, dim=-1)

        deepconf_stopping = torch.ones(
            batch_size, dtype=torch.bool, device=input_ids.device
        )
        step_conf_values = [0.0] * batch_size

        for i in range(batch_size):
            if not unfinished_sequences[i]:
                continue

            top_probs, _ = torch.topk(probs[i], k=conf_topk, dim=-1)
            eps = torch.finfo(top_probs.dtype).eps if top_probs.dtype == torch.float32 else 1e-7
            top_probs = torch.clamp(top_probs, min=eps)
            log_probs = torch.log(top_probs)
            conf = -log_probs.mean().item()

            if len(conf_group_lists[i]) >= window_size:
                conf_grouped_sums[i] -= conf_group_lists[i][0]
            conf_group_lists[i].append(conf)
            conf_grouped_sums[i] += conf

            # Only stop once the window is full and a threshold is available
            # (it can still be None if neither an explicit value nor warm-up
            # confidences were supplied).
            if (
                enable_early_stopping
                and threshold is not None
                and len(conf_group_lists[i]) >= window_size
            ):
                avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
                if avg_conf < threshold:
                    deepconf_stopping[i] = False

            if step_confidences is not None:
                step_conf_values[i] = conf

        if step_confidences is not None:
            step_confidences.append(
                torch.tensor(step_conf_values, device=input_ids.device)
            )
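
        # Already-finished sequences keep emitting the pad token; then append
        # the new token and grow the attention mask and cache position by one.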
        if has_eos_stopping_criteria and pad_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if model_kwargs.get("attention_mask") is not None:
            attn = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [
                    attn,
                    torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device),
                ],
                dim=-1,
            )

        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        if streamer is not None:
            streamer.put(next_tokens.cpu())
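
        # Fold the standard stopping criteria and the DeepCONF early-stopping
        # mask into the per-sequence "unfinished" flags.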
        sc = stopping_criteria(input_ids, scores)
        if isinstance(sc, torch.Tensor):
            unfinished_sequences = unfinished_sequences & ~sc
        elif sc:
            unfinished_sequences = torch.zeros_like(unfinished_sequences)

        unfinished_sequences = unfinished_sequences & deepconf_stopping

        if unfinished_sequences.max() == 0 and not synced_gpus:
            break
        cur_len += 1
        steps += 1

        del outputs

    if streamer is not None:
        streamer.end()
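
    # Assemble the return value; per-step confidences of shape (batch, steps)
    # are attached to the output object when requested.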
    if return_dict_in_generate:
        confidences_tensor = None
        if step_confidences is not None and len(step_confidences) > 0:
            confidences_tensor = torch.stack(step_confidences, dim=0).transpose(0, 1)
        if model.config.is_encoder_decoder:
            output = GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            output = GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        if confidences_tensor is not None:
            output["confidences"] = confidences_tensor
            try:
                setattr(output, "confidences", confidences_tensor)
            except Exception:
                pass
        return output
    else:
        return input_ids
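
# --- Usage sketch -----------------------------------------------------------
# A minimal, illustrative way to drive `generate` above. The model name,
# prompt, and the concrete DeepCONF values (window_size, threshold, conf_topk)
# are placeholder assumptions, not recommended settings; the extra attributes
# are simply set on the GenerationConfig object, matching the getattr-based
# reads in the function.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   inputs = tokenizer("The capital of France is", return_tensors="pt")
#
#   config = GenerationConfig(
#       do_sample=True,
#       temperature=0.7,
#       top_p=0.95,
#       max_new_tokens=64,
#       return_dict_in_generate=True,
#   )
#   config.enable_conf = True
#   config.enable_early_stopping = True
#   config.window_size = 32
#   config.threshold = 2.5
#   config.conf_topk = 20
#   config.output_confidences = True
#
#   out = generate(
#       model,
#       inputs["input_ids"],
#       generation_config=config,
#       attention_mask=inputs["attention_mask"],
#   )
#   print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
#   print(out["confidences"].shape)  # (batch, generated steps)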