kashif (HF Staff) committed
Commit e0297b7 · 1 Parent(s): 30add1f

formatting

Files changed (1):
  1. custom_generate/generate.py +81 -22
custom_generate/generate.py CHANGED
@@ -11,7 +11,10 @@ from transformers.generation.logits_process import (
     TopKLogitsWarper,
     TopPLogitsWarper,
 )
-from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput
+from transformers.generation.utils import (
+    GenerateDecoderOnlyOutput,
+    GenerateEncoderDecoderOutput,
+)
 
 
 def generate(
@@ -50,8 +53,12 @@ def generate(
     # Get DeepCONF parameters from generation_config or set defaults
     enable_conf = getattr(generation_config, "enable_conf", False)
     window_size = getattr(generation_config, "window_size", 2048)
-    threshold = getattr(generation_config, "threshold", 17.0)  # Default threshold for confidence (positive value)
-    conf_topk = getattr(generation_config, "conf_topk", 20)  # Number of top tokens for confidence calculation
+    threshold = getattr(
+        generation_config, "threshold", 17.0
+    )  # Default threshold for confidence (positive value)
+    conf_topk = getattr(
+        generation_config, "conf_topk", 20
+    )  # Number of top tokens for confidence calculation
 
     # If DeepCONF is not enabled, fall back to standard sampling
     if not enable_conf:
@@ -83,16 +90,26 @@ def generate(
     return_dict_in_generate = generation_config.return_dict_in_generate
     output_confidences = getattr(generation_config, "output_confidences", False)
     # Optional DeepConf variant helpers (compute threshold from warmup confidences)
-    deepconf_variant = getattr(generation_config, "deepconf_variant", None)  # "low" or "high"
+    deepconf_variant = getattr(
+        generation_config, "deepconf_variant", None
+    )  # "low" or "high"
     deepconf_eta = getattr(generation_config, "deepconf_eta", None)  # float in (0,1)
-    deepconf_warmup_confidences = getattr(generation_config, "deepconf_warmup_confidences", None)  # list/1D tensor
-    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+    deepconf_warmup_confidences = getattr(
+        generation_config, "deepconf_warmup_confidences", None
+    )  # list/1D tensor
+    has_eos_stopping_criteria = any(
+        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+    )
     do_sample = generation_config.do_sample
 
     # If a variant is requested and a warmup set of confidences is provided, derive the threshold
     if enable_conf and threshold is not None:
         pass
-    elif enable_conf and deepconf_variant is not None and deepconf_warmup_confidences is not None:
+    elif (
+        enable_conf
+        and deepconf_variant is not None
+        and deepconf_warmup_confidences is not None
+    ):
         confs = deepconf_warmup_confidences
         if hasattr(confs, "detach"):
             confs = confs.detach().cpu().numpy()
@@ -101,7 +118,13 @@ def generate(
         confs = np.asarray(confs, dtype=np.float32).ravel()
         eta = deepconf_eta
         if eta is None:
-            eta = 0.1 if deepconf_variant == "low" else 0.9 if deepconf_variant == "high" else 0.5
+            eta = (
+                0.1
+                if deepconf_variant == "low"
+                else 0.9
+                if deepconf_variant == "high"
+                else 0.5
+            )
         pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
         threshold = float(np.percentile(confs, pct))
 
@@ -110,22 +133,36 @@ def generate(
     raw_logits = () if (return_dict_in_generate and output_logits) else None
     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+    decoder_hidden_states = (
+        () if (return_dict_in_generate and output_hidden_states) else None
+    )
 
     # If model is an encoder-decoder, retrieve encoder attention weights and hidden states
     if return_dict_in_generate and model.config.is_encoder_decoder:
-        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+        encoder_attentions = (
+            model_kwargs["encoder_outputs"].get("attentions")
+            if output_attentions
+            else None
+        )
+        encoder_hidden_states = (
+            model_kwargs["encoder_outputs"].get("hidden_states")
+            if output_hidden_states
+            else None
+        )
 
     # Keep track of which sequences are already finished
     batch_size, cur_len = input_ids.shape[:2]
-    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+    unfinished_sequences = torch.ones(
+        batch_size, dtype=torch.long, device=input_ids.device
+    )
     # Use public kv-cache via past_key_values
 
     # Initialize confidence tracking
     # Use deque for sliding window with fixed size
     conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
-    conf_grouped_sums = [0.0 for _ in range(batch_size)]  # Running sums for efficient mean calculation
+    conf_grouped_sums = [
+        0.0 for _ in range(batch_size)
+    ]  # Running sums for efficient mean calculation
 
     # Optional per-step confidences for debugging/visualization
     step_confidences = [] if (return_dict_in_generate and output_confidences) else None
@@ -141,8 +178,14 @@ def generate(
         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
         # Prepare variable output controls
-        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+        model_inputs.update(
+            {"output_attentions": output_attentions} if output_attentions else {}
+        )
+        model_inputs.update(
+            {"output_hidden_states": output_hidden_states}
+            if output_hidden_states
+            else {}
+        )
 
         # Forward pass with proper KV cache handling
         with torch.no_grad():
@@ -181,14 +224,18 @@ def generate(
                 raw_logits += (next_token_logits,)
             if output_attentions:
                 decoder_attentions += (
-                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
+                    (outputs.decoder_attentions,)
+                    if model.config.is_encoder_decoder
+                    else (outputs.attentions,)
                 )
                 if model.config.is_encoder_decoder:
                     cross_attentions += (outputs.cross_attentions,)
 
             if output_hidden_states:
                 decoder_hidden_states += (
-                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
+                    (outputs.decoder_hidden_states,)
+                    if model.config.is_encoder_decoder
+                    else (outputs.hidden_states,)
                 )
 
         # Token selection
@@ -203,8 +250,12 @@ def generate(
         # This uses the raw logits (next_token_logits) before warpers are applied.
        probs = F.softmax(next_token_logits, dim=-1)
 
-        deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
-        step_conf_values = [0.0] * batch_size  # collect per-sequence confidences for this step (full batch)
+        deepconf_stopping = torch.ones(
+            batch_size, dtype=torch.bool, device=input_ids.device
+        )
+        step_conf_values = [
+            0.0
+        ] * batch_size  # collect per-sequence confidences for this step (full batch)
 
         for i in range(batch_size):
             if not unfinished_sequences[i]:
@@ -233,11 +284,15 @@ def generate(
 
         if step_confidences is not None:
             # Store this step's confidences as a tensor of shape (batch,)
-            step_confidences.append(torch.tensor(step_conf_values, device=input_ids.device))
+            step_confidences.append(
+                torch.tensor(step_conf_values, device=input_ids.device)
+            )
 
         # Finished sentences should have their next token be a padding token
         if has_eos_stopping_criteria and pad_token_id is not None:
-            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                1 - unfinished_sequences
+            )
 
         # Update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -245,7 +300,11 @@ def generate(
         if model_kwargs.get("attention_mask") is not None:
             attn = model_kwargs["attention_mask"]
             model_kwargs["attention_mask"] = torch.cat(
-                [attn, torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device)], dim=-1
+                [
+                    attn,
+                    torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device),
+                ],
+                dim=-1,
             )
         # Update cache_position for next step (single next token)
         model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
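
The elif branch in the third hunk derives `threshold` from warmup confidences by mapping `deepconf_eta` to a percentile: with the "low" variant default (eta = 0.1) the threshold is the 90th percentile of the warmup values, so roughly the weakest traces later fall below it. A minimal sketch of that arithmetic, reusing the names from the diff with made-up warmup values for illustration:

import numpy as np

# Hypothetical warmup confidences collected from a few warmup generations (illustrative only).
deepconf_warmup_confidences = [12.3, 15.8, 9.4, 18.2, 14.1, 16.7, 11.0, 19.5]
deepconf_variant = "low"  # "low" -> eta = 0.1, "high" -> eta = 0.9
eta = 0.1 if deepconf_variant == "low" else 0.9

confs = np.asarray(deepconf_warmup_confidences, dtype=np.float32).ravel()
pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))  # 90.0 for the "low" variant
threshold = float(np.percentile(confs, pct))       # ~18.6 for the sample values above
print(pct, threshold)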
 
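
Since this commit only reformats the file, usage is unchanged. As a rough sketch of how the DeepCONF options read via getattr(generation_config, ...) above might be supplied, assuming the file is consumed through transformers' custom_generate mechanism (the model id, repo id, and parameter values below are placeholders, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Q: What is 17 * 23?\nA:", return_tensors="pt")

# DeepCONF options are plain attributes on GenerationConfig, matching the
# getattr(generation_config, "enable_conf", False) pattern in the diff.
gen_config = GenerationConfig(do_sample=True, max_new_tokens=256, return_dict_in_generate=True)
gen_config.enable_conf = True          # turn confidence-based early stopping on
gen_config.window_size = 2048          # sliding window over per-token confidences
gen_config.threshold = 17.0            # stop a sequence once its windowed confidence drops below this
gen_config.output_confidences = True   # also return the per-step confidences (step_confidences)

out = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="<user>/deepconf",  # placeholder repo id containing custom_generate/generate.py
    trust_remote_code=True,
)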