Qsevent77 committed (verified)
Commit 3c1448c · 1 Parent(s): 8838f8a

Upload folder using huggingface_hub

Files changed (12):
  1. README.md +114 -0
  2. block.py +413 -0
  3. configuration_xlm_roberta.py +130 -0
  4. convert_roberta_weights_to_flash.py +170 -0
  5. embedding.py +95 -0
  6. mha.py +830 -0
  7. mlp.py +237 -0
  8. modeling_lora.py +426 -0
  9. modeling_xlm_roberta.py +1254 -0
  10. rotary.py +659 -0
  11. stochastic_depth.py +97 -0
  12. xlm_padding.py +236 -0
README.md ADDED
@@ -0,0 +1,114 @@
1
+ ---
2
+ tags:
3
+ - transformers
4
+ - xlm-roberta
5
+ library_name: transformers
6
+ license: cc-by-nc-4.0
7
+ language:
8
+ - multilingual
9
+ - af
10
+ - am
11
+ - ar
12
+ - as
13
+ - az
14
+ - be
15
+ - bg
16
+ - bn
17
+ - br
18
+ - bs
19
+ - ca
20
+ - cs
21
+ - cy
22
+ - da
23
+ - de
24
+ - el
25
+ - en
26
+ - eo
27
+ - es
28
+ - et
29
+ - eu
30
+ - fa
31
+ - fi
32
+ - fr
33
+ - fy
34
+ - ga
35
+ - gd
36
+ - gl
37
+ - gu
38
+ - ha
39
+ - he
40
+ - hi
41
+ - hr
42
+ - hu
43
+ - hy
44
+ - id
45
+ - is
46
+ - it
47
+ - ja
48
+ - jv
49
+ - ka
50
+ - kk
51
+ - km
52
+ - kn
53
+ - ko
54
+ - ku
55
+ - ky
56
+ - la
57
+ - lo
58
+ - lt
59
+ - lv
60
+ - mg
61
+ - mk
62
+ - ml
63
+ - mn
64
+ - mr
65
+ - ms
66
+ - my
67
+ - ne
68
+ - nl
69
+ - 'no'
70
+ - om
71
+ - or
72
+ - pa
73
+ - pl
74
+ - ps
75
+ - pt
76
+ - ro
77
+ - ru
78
+ - sa
79
+ - sd
80
+ - si
81
+ - sk
82
+ - sl
83
+ - so
84
+ - sq
85
+ - sr
86
+ - su
87
+ - sv
88
+ - sw
89
+ - ta
90
+ - te
91
+ - th
92
+ - tl
93
+ - tr
94
+ - ug
95
+ - uk
96
+ - ur
97
+ - uz
98
+ - vi
99
+ - xh
100
+ - yi
101
+ - zh
102
+ ---
103
+ Core implementation of Jina XLM-RoBERTa
104
+
105
+ This implementation is adapted from [XLM-RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/xlm-roberta). In contrast to the original implementation, this model uses rotary position embeddings (RoPE) and supports FlashAttention-2.
106
+
107
+ ### Models that use this implementation
108
+
109
+ - [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)
110
+ - [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2)
111
+
112
+ ### Converting weights
113
+
114
+ Weights from an [original XLM-RoBERTa model](https://huggingface.co/FacebookAI/xlm-roberta-large) can be converted using the `convert_roberta_weights_to_flash.py` script in the model repository.
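As a quick orientation for the README above: the models listed under "Models that use this implementation" load this code through `trust_remote_code`. The snippet below is a minimal, hedged sketch of that flow; it assumes a recent `transformers` install, and the downstream model cards document higher-level helpers (e.g. `encode()`) that are not part of this folder.

```python
# Minimal sketch: load a model that ships this implementation via trust_remote_code.
# Assumes `transformers` is installed; FlashAttention-2 is only used when available
# on CUDA (the config forces float32 otherwise).
from transformers import AutoModel, AutoTokenizer

model_id = "jinaai/jina-embeddings-v3"  # or "jinaai/jina-colbert-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("A multilingual test sentence.", return_tensors="pt")
outputs = model(**inputs)  # token-level hidden states; pooling depends on the downstream model
```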
block.py ADDED
@@ -0,0 +1,413 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+
13
+ from .mha import MHA
14
+ from .mlp import Mlp
15
+ from .stochastic_depth import StochasticDepth
16
+
17
+ try:
18
+ from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn
19
+ except ImportError:
20
+ layer_norm_fn, RMSNorm = None, None
21
+
22
+
23
+ class Block(nn.Module):
24
+ def __init__(
25
+ self,
26
+ dim,
27
+ mixer_cls=None,
28
+ mlp_cls=None,
29
+ norm_cls=nn.LayerNorm,
30
+ dropout_cls=nn.Dropout,
31
+ prenorm=True,
32
+ resid_dropout1=0.0,
33
+ resid_dropout2=0.0,
34
+ drop_path1=0.0,
35
+ drop_path2=0.0,
36
+ fused_dropout_add_ln=False,
37
+ return_residual=False,
38
+ residual_in_fp32=False,
39
+ sequence_parallel=False,
40
+ mark_shared_params=False,
41
+ ):
42
+ """
43
+ For prenorm=True, this Block has a slightly different structure compared to a regular
44
+ prenorm Transformer block.
45
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
46
+ [Ref: https://arxiv.org/abs/2002.04745]
47
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
48
+ the hidden_states (output of the MLP) and the residual.
49
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
50
+ The residual needs to be provided (except for the very first block).
51
+
52
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
53
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
54
+
55
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
56
+ This is for performance reasons: for a post-norm architecture, returning the input allows us
57
+ to fuse the backward of nn.Linear with the residual connection.
58
+ """
59
+ super().__init__()
60
+ self.prenorm = prenorm
61
+ self.fused_dropout_add_ln = fused_dropout_add_ln
62
+ self.return_residual = return_residual
63
+ self.residual_in_fp32 = residual_in_fp32
64
+ if self.residual_in_fp32:
65
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
66
+ if mixer_cls is None:
67
+ mixer_cls = partial(MHA, num_heads=dim // 64)
68
+ if mlp_cls is None:
69
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
70
+ self.mixer = mixer_cls(dim)
71
+ self.dropout1 = dropout_cls(resid_dropout1)
72
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
73
+ self.norm1 = norm_cls(dim)
74
+ self.mlp = mlp_cls(dim)
75
+ if not isinstance(self.mlp, nn.Identity):
76
+ self.dropout2 = dropout_cls(resid_dropout2)
77
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
78
+ self.norm2 = norm_cls(dim)
79
+
80
+ if self.fused_dropout_add_ln:
81
+ assert layer_norm_fn is not None, "Triton is not installed"
82
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
83
+ self.dropout1, nn.Dropout
84
+ )
85
+
86
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
87
+ # then the input to each worker in the tensor parallel group will be different.
88
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
89
+ # For now this is not an issue because we always use sequence_parallel=True during training
90
+ # and only use sequence_parallel=False during inference.
91
+
92
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
93
+ if sequence_parallel:
94
+ for p in self.norm1.parameters():
95
+ p._sequence_parallel = True
96
+ if hasattr(self, "norm2"):
97
+ for p in self.norm2.parameters():
98
+ p._sequence_parallel = True
99
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
100
+ if mark_shared_params:
101
+ for p in self.norm1.parameters():
102
+ p._shared_params = True
103
+ if hasattr(self, "norm2"):
104
+ for p in self.norm2.parameters():
105
+ p._shared_params = True
106
+
107
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
108
+ return self.mixer.allocate_inference_cache(
109
+ batch_size, max_seqlen, dtype=dtype, **kwargs
110
+ )
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: Tensor,
115
+ residual: Optional[Tensor] = None,
116
+ mixer_subset=None,
117
+ mixer_kwargs=None,
118
+ ):
119
+ r"""Pass the input through the encoder layer.
120
+
121
+ Args:
122
+ hidden_states: the sequence to the encoder layer (required).
123
+ residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual))
124
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
125
+ before applying the query projection. Useful for e.g., ViT where we only care
126
+ about the CLS token in the last layer.
127
+ """
128
+ if self.prenorm:
129
+ if not self.fused_dropout_add_ln:
130
+ dropped = self.drop_path1(self.dropout1(hidden_states))
131
+ residual = (dropped + residual) if residual is not None else dropped
132
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
133
+ if self.residual_in_fp32:
134
+ residual = residual.to(torch.float32)
135
+ else:
136
+ if self.drop_path1.p == 0 or not self.training:
137
+ rowscale1 = None
138
+ else:
139
+ rowscale1 = self.drop_path1(
140
+ torch.ones(
141
+ hidden_states.shape[:-1],
142
+ device=hidden_states.device,
143
+ dtype=hidden_states.dtype,
144
+ )
145
+ )
146
+ hidden_states, residual = layer_norm_fn(
147
+ hidden_states,
148
+ self.norm1.weight,
149
+ self.norm1.bias,
150
+ residual=residual,
151
+ eps=self.norm1.eps,
152
+ dropout_p=self.dropout1.p if self.training else 0.0,
153
+ rowscale=rowscale1,
154
+ prenorm=True,
155
+ residual_in_fp32=self.residual_in_fp32,
156
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
157
+ )
158
+ if mixer_kwargs is None:
159
+ mixer_kwargs = {}
160
+ if mixer_subset is not None:
161
+ mixer_kwargs["mixer_subset"] = mixer_subset
162
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
163
+ if mixer_subset is not None:
164
+ residual = residual[:, mixer_subset]
165
+ if not isinstance(self.mlp, nn.Identity):
166
+ if not self.fused_dropout_add_ln:
167
+ dropped = self.drop_path2(self.dropout2(hidden_states))
168
+ residual = (dropped + residual) if residual is not None else dropped
169
+ hidden_states = self.norm2(
170
+ residual.to(dtype=self.norm2.weight.dtype)
171
+ )
172
+ if self.residual_in_fp32:
173
+ residual = residual.to(torch.float32)
174
+ else:
175
+ if self.drop_path2.p == 0 or not self.training:
176
+ rowscale2 = None
177
+ else:
178
+ rowscale2 = self.drop_path2(
179
+ torch.ones(
180
+ hidden_states.shape[:-1],
181
+ device=hidden_states.device,
182
+ dtype=hidden_states.dtype,
183
+ )
184
+ )
185
+ hidden_states, residual = layer_norm_fn(
186
+ hidden_states,
187
+ self.norm2.weight,
188
+ self.norm2.bias,
189
+ residual=residual,
190
+ eps=self.norm2.eps,
191
+ dropout_p=self.dropout2.p if self.training else 0.0,
192
+ rowscale=rowscale2,
193
+ prenorm=True,
194
+ residual_in_fp32=self.residual_in_fp32,
195
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
196
+ )
197
+ hidden_states = self.mlp(hidden_states)
198
+ return hidden_states, residual
199
+ else:
200
+ assert residual is None
201
+ mixer_out = self.mixer(
202
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
203
+ )
204
+ if self.return_residual: # mixer out is actually a pair here
205
+ mixer_out, hidden_states = mixer_out
206
+ if not self.fused_dropout_add_ln:
207
+ hidden_states = self.norm1(
208
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
209
+ dtype=self.norm1.weight.dtype
210
+ )
211
+ )
212
+ else:
213
+ if self.drop_path1.p == 0 or not self.training:
214
+ rowscale1 = None
215
+ else:
216
+ rowscale1 = self.drop_path1(
217
+ torch.ones(
218
+ mixer_out.shape[:-1],
219
+ device=mixer_out.device,
220
+ dtype=mixer_out.dtype,
221
+ )
222
+ )
223
+ hidden_states = layer_norm_fn(
224
+ mixer_out,
225
+ self.norm1.weight,
226
+ self.norm1.bias,
227
+ residual=hidden_states,
228
+ eps=self.norm1.eps,
229
+ dropout_p=self.dropout1.p if self.training else 0.0,
230
+ rowscale=rowscale1,
231
+ prenorm=False,
232
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
233
+ )
234
+ if not isinstance(self.mlp, nn.Identity):
235
+ mlp_out = self.mlp(
236
+ hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask")
237
+ )
238
+ if self.return_residual: # mlp out is actually a pair here
239
+ mlp_out, hidden_states = mlp_out
240
+ if not self.fused_dropout_add_ln:
241
+ hidden_states = self.norm2(
242
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
243
+ dtype=self.norm2.weight.dtype
244
+ )
245
+ )
246
+ else:
247
+ if self.drop_path2.p == 0 or not self.training:
248
+ rowscale2 = None
249
+ else:
250
+ rowscale2 = self.drop_path2(
251
+ torch.ones(
252
+ mlp_out.shape[:-1],
253
+ device=mlp_out.device,
254
+ dtype=mlp_out.dtype,
255
+ )
256
+ )
257
+ hidden_states = layer_norm_fn(
258
+ mlp_out,
259
+ self.norm2.weight,
260
+ self.norm2.bias,
261
+ residual=hidden_states,
262
+ eps=self.norm2.eps,
263
+ dropout_p=self.dropout2.p if self.training else 0.0,
264
+ rowscale=rowscale2,
265
+ prenorm=False,
266
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
267
+ )
268
+ return hidden_states
269
+
270
+
271
+ class ParallelBlock(nn.Module):
272
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
273
+ and PaLM.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ dim,
279
+ mixer_cls=None,
280
+ mlp_cls=None,
281
+ norm_cls=nn.LayerNorm,
282
+ dropout_cls=nn.Dropout,
283
+ resid_dropout1=0.0,
284
+ resid_dropout2=0.0,
285
+ tied_norm=False,
286
+ fused_dropout_add_ln=False,
287
+ residual_in_fp32=False,
288
+ sequence_parallel=False,
289
+ mark_shared_params=False,
290
+ ):
291
+ """
292
+ This Block has a slightly different structure compared to a regular
293
+ prenorm Transformer block.
294
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
295
+ [Ref: https://arxiv.org/abs/2002.04745]
296
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
297
+ the hidden_states (output1 of the MHA / MLP) and the residual.
298
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
299
+ The residual needs to be provided (except for the very first block).
300
+ """
301
+ super().__init__()
302
+ self.tied_norm = tied_norm
303
+ self.fused_dropout_add_ln = fused_dropout_add_ln
304
+ self.residual_in_fp32 = residual_in_fp32
305
+ if mixer_cls is None:
306
+ mixer_cls = partial(MHA, num_heads=dim // 64)
307
+ if mlp_cls is None:
308
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
309
+ self.mixer = mixer_cls(dim)
310
+ self.dropout1 = dropout_cls(resid_dropout1)
311
+ self.norm1 = norm_cls(dim)
312
+ self.mlp = mlp_cls(dim)
313
+ self.dropout2 = dropout_cls(resid_dropout2)
314
+ if not self.tied_norm:
315
+ self.norm2 = norm_cls(dim)
316
+
317
+ if self.fused_dropout_add_ln:
318
+ assert layer_norm_fn is not None, "Triton is not installed"
319
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
320
+ self.dropout1, nn.Dropout
321
+ )
322
+
323
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
324
+ # then the input to each worker in the tensor parallel group will be different.
325
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
326
+ # For now this is not an issue because we always use sequence_parallel=True during training
327
+ # and only use sequence_parallel=False during inference.
328
+
329
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
330
+ if sequence_parallel:
331
+ for p in self.norm1.parameters():
332
+ p._sequence_parallel = True
333
+ if hasattr(self, "norm2"):
334
+ for p in self.norm2.parameters():
335
+ p._sequence_parallel = True
336
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
337
+ if mark_shared_params:
338
+ for p in self.norm1.parameters():
339
+ p._shared_params = True
340
+ if hasattr(self, "norm2"):
341
+ for p in self.norm2.parameters():
342
+ p._shared_params = True
343
+
344
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
345
+ return self.mixer.allocate_inference_cache(
346
+ batch_size, max_seqlen, dtype=dtype, **kwargs
347
+ )
348
+
349
+ def forward(
350
+ self,
351
+ hidden_states1: Tensor,
352
+ hidden_states2: Optional[Tensor] = None,
353
+ residual: Optional[Tensor] = None,
354
+ mixer_kwargs=None,
355
+ ):
356
+ r"""Pass the input through the encoder layer.
357
+
358
+ Args:
359
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
360
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
361
+ residual: the residual from the previous block (None for the very first block).
362
+ """
363
+ # TODO: Ideally we should only do the allgather / allreduce once for
364
+ # the Linear to MLP & Attention
365
+ if not self.fused_dropout_add_ln:
366
+ dropped1 = self.dropout1(hidden_states1)
367
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
368
+ if hidden_states2 is not None:
369
+ dropped2 = self.dropout2(hidden_states2)
370
+ residual = (
371
+ (residual + dropped1 + dropped2)
372
+ if residual is not None
373
+ else dropped1 + dropped2
374
+ )
375
+ else:
376
+ residual = (residual + dropped1) if residual is not None else dropped1
377
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
378
+ hidden_states2 = (
379
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
380
+ if not self.tied_norm
381
+ else hidden_states1
382
+ )
383
+ if self.residual_in_fp32:
384
+ residual = residual.to(torch.float32)
385
+ else:
386
+ weight2, bias2 = (
387
+ (self.norm2.weight, self.norm2.bias)
388
+ if not self.tied_norm
389
+ else (None, None)
390
+ )
391
+ hidden_states1, *rest, residual = layer_norm_fn(
392
+ hidden_states1,
393
+ self.norm1.weight,
394
+ self.norm1.bias,
395
+ residual=residual,
396
+ x1=hidden_states2,
397
+ weight1=weight2,
398
+ bias1=bias2,
399
+ eps=self.norm1.eps,
400
+ dropout_p=self.dropout1.p if self.training else 0.0,
401
+ prenorm=True,
402
+ residual_in_fp32=self.residual_in_fp32,
403
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
404
+ )
405
+ if self.tied_norm:
406
+ hidden_states2 = hidden_states1
407
+ else:
408
+ (hidden_states2,) = rest
409
+ if mixer_kwargs is None:
410
+ mixer_kwargs = {}
411
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
412
+ hidden_states2 = self.mlp(hidden_states2)
413
+ return hidden_states1, hidden_states2, residual
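The prenorm `Block` above returns a `(hidden_states, residual)` pair instead of a single tensor, so a stack of blocks has to thread the residual explicitly (the first block receives `residual=None`). The sketch below illustrates that wiring under stated assumptions: the folder is importable as a package (the files use relative imports; `xlmr_flash` is a hypothetical name), the non-flash attention path behaves like the upstream flash-attn `MHA`, and the sizes are toy values.

```python
# Sketch: threading (hidden_states, residual) through a stack of prenorm Blocks.
# Assumptions: this folder is importable as the hypothetical package `xlmr_flash`,
# use_flash_attn=False (pure-PyTorch SelfAttention), toy dimensions.
from functools import partial

import torch

from xlmr_flash.block import Block
from xlmr_flash.mha import MHA
from xlmr_flash.mlp import Mlp

dim, depth = 64, 2
blocks = torch.nn.ModuleList(
    [
        Block(
            dim,
            mixer_cls=partial(MHA, num_heads=4, use_flash_attn=False),
            mlp_cls=partial(Mlp, hidden_features=4 * dim),
            prenorm=True,
        )
        for _ in range(depth)
    ]
)

x = torch.randn(2, 8, dim)           # (batch, seqlen, dim)
hidden_states, residual = x, None    # the very first block takes residual=None
for blk in blocks:
    hidden_states, residual = blk(hidden_states, residual)
# An encoder would normally apply a final dropout + add + LayerNorm to this pair,
# as done in modeling_xlm_roberta.py.
```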
configuration_xlm_roberta.py ADDED
@@ -0,0 +1,130 @@
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from transformers import PretrainedConfig
5
+
6
+
7
+ class XLMRobertaFlashConfig(PretrainedConfig):
8
+
9
+ model_type = "xlm-roberta"
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size: int = 250002,
14
+ hidden_size: int = 1024,
15
+ num_hidden_layers: int = 24,
16
+ num_attention_heads: int = 16,
17
+ intermediate_size: int = 4096,
18
+ hidden_act: str = "gelu",
19
+ hidden_dropout_prob: float = 0.1,
20
+ attention_probs_dropout_prob: float = 0.1,
21
+ max_position_embeddings: int = 8194,
22
+ type_vocab_size: int = 1,
23
+ initializer_range: float = 0.02,
24
+ layer_norm_eps: float = 1e-05,
25
+ pad_token_id: int = 1,
26
+ bos_token_id: int = 0,
27
+ eos_token_id: int = 2,
28
+ position_embedding_type: str = "rotary",
29
+ rotary_emb_base: float = 10000.0,
30
+ use_cache: bool = True,
31
+ use_reentrant: bool = False,
32
+ classifier_dropout: Optional[float] = None,
33
+ lora_adaptations: Optional[List[str]] = None,
34
+ task_instructions: Optional[Dict[str, str]] = None,
35
+ lora_rank: int = 4,
36
+ lora_dropout_p: float = 0.0,
37
+ lora_alpha: int = 1,
38
+ lora_main_params_trainable: bool = False,
39
+ load_trained_adapters: bool = False,
40
+ use_flash_attn: bool = True,
41
+ torch_dtype: Optional[Union[str, torch.dtype]] = None,
42
+ emb_pooler: Optional[str] = None,
43
+ matryoshka_dimensions: Optional[List[int]] = None,
44
+ truncate_dim: Optional[int] = None,
45
+ **kwargs: Dict[str, Any],
46
+ ):
47
+ """
48
+ Initialize the XLMRobertaFlashConfig configuration.
49
+
50
+ Args:
51
+ vocab_size (int): Size of the vocabulary.
52
+ hidden_size (int): Dimensionality of the encoder layers and the pooler layer.
53
+ num_hidden_layers (int): Number of hidden layers in the Transformer encoder.
54
+ num_attention_heads (int): Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (int): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer.
56
+ hidden_act (str): The activation function to use.
57
+ hidden_dropout_prob (float): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58
+ attention_probs_dropout_prob (float): The dropout ratio for the attention probabilities.
59
+ max_position_embeddings (int): The maximum length of the position embeddings.
60
+ type_vocab_size (int): The vocabulary size of the token type ids.
61
+ initializer_range (float): The standard deviation for initializing all weight matrices.
62
+ layer_norm_eps (float): The epsilon used by the layer normalization layers.
63
+ pad_token_id (int): The ID of the padding token.
64
+ bos_token_id (int): The ID of the beginning-of-sequence token.
65
+ eos_token_id (int): The ID of the end-of-sequence token.
66
+ position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
67
+ rotary_emb_base (float): Base for rotary embeddings.
68
+ use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
69
+ use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
70
+ classifier_dropout (Optional[float]): The dropout ratio for the classification head.
71
+ lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
72
+ lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
73
+ lora_rank (int): Rank for LoRA adaptations.
74
+ lora_dropout_p (float): Dropout probability for LoRA adaptations.
75
+ lora_alpha (int): Alpha parameter for LoRA.
76
+ lora_main_params_trainable (bool): Whether to make the main model parameters trainable when using LoRA.
77
+ load_trained_adapters (bool): Whether to load trained adapters.
78
+ use_flash_attn (bool): Whether to use FlashAttention.
79
+ torch_dtype (Optional[Union[str, torch.dtype]]): Data type for the tensors.
80
+ emb_pooler (Optional[str]): Pooling layer configuration.
81
+ matryoshka_dimensions (Optional[List[int]]): Configuration for matryoshka dimension reduction.
82
+ truncate_dim (Optional[int]): Dimension to truncate embeddings to, if any.
83
+ **kwargs (Dict[str, Any]): Additional keyword arguments passed to the configuration.
84
+ """
85
+
86
+ super().__init__(
87
+ pad_token_id=pad_token_id,
88
+ bos_token_id=bos_token_id,
89
+ eos_token_id=eos_token_id,
90
+ **kwargs,
91
+ )
92
+
93
+ self.vocab_size = vocab_size
94
+ self.hidden_size = hidden_size
95
+ self.num_hidden_layers = num_hidden_layers
96
+ self.num_attention_heads = num_attention_heads
97
+ self.hidden_act = hidden_act
98
+ self.intermediate_size = intermediate_size
99
+ self.hidden_dropout_prob = hidden_dropout_prob
100
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
101
+ self.max_position_embeddings = max_position_embeddings
102
+ self.type_vocab_size = type_vocab_size
103
+ self.initializer_range = initializer_range
104
+ self.layer_norm_eps = layer_norm_eps
105
+ self.position_embedding_type = position_embedding_type
106
+ self.rotary_emb_base = rotary_emb_base
107
+ self.use_cache = use_cache
108
+ self.use_reentrant = use_reentrant
109
+ self.classifier_dropout = classifier_dropout
110
+ self.load_trained_adapters = load_trained_adapters
111
+ self.lora_adaptations = lora_adaptations
112
+ self.task_instructions = task_instructions
113
+ self.lora_rank = lora_rank
114
+ self.lora_dropout_p = lora_dropout_p
115
+ self.lora_alpha = lora_alpha
116
+ self.lora_main_params_trainable = lora_main_params_trainable
117
+ self.use_flash_attn = use_flash_attn
118
+ self.emb_pooler = emb_pooler
119
+ self.matryoshka_dimensions = matryoshka_dimensions
120
+ self.truncate_dim = truncate_dim
121
+ if (
122
+ torch_dtype
123
+ and isinstance(torch_dtype, str)  # hasattr needs a string attribute name
+ and hasattr(torch, torch_dtype)
124
+ and type(getattr(torch, torch_dtype)) is torch.dtype
125
+ ):
126
+ self.torch_dtype = getattr(torch, torch_dtype)
127
+ else:
128
+ self.torch_dtype = torch_dtype
129
+ if not self.use_flash_attn or not torch.cuda.is_available():
130
+ self.torch_dtype = torch.float32
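A small usage sketch for the config above: every field that is not overridden keeps the XLM-R-large-style default from `__init__`, and the final branch of `__init__` forces `torch_dtype` to `float32` whenever FlashAttention is disabled or CUDA is unavailable. The import path is an assumption; adjust it to wherever this file lives.

```python
# Sketch: build a smaller, CPU-friendly config from the defaults above.
# Assumption: configuration_xlm_roberta.py is importable from the current directory.
from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    position_embedding_type="rotary",
    use_flash_attn=False,   # no flash-attn / CPU-only setup
)
print(config.torch_dtype)   # torch.float32, forced by the last branch of __init__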
convert_roberta_weights_to_flash.py ADDED
@@ -0,0 +1,170 @@
1
+ import re
2
+ from collections import OrderedDict
3
+ from transformers import PretrainedConfig
4
+ from transformers import XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification
5
+
6
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
7
+ from .modeling_xlm_roberta import XLMRobertaForMaskedLM as FlashXLMRobertaForMaskedLM
8
+ from .modeling_xlm_roberta import XLMRobertaForSequenceClassification as FlashXLMRobertaForSequenceClassification
9
+ import torch
+ import torch.nn.functional as F  # needed for F.pad in remap_state_dict
10
+
11
+ import click
12
+
13
+ ## inspired by https://github.com/Dao-AILab/flash-attention/blob/85881f547fd1053a7b4a2c3faad6690cca969279/flash_attn/models/bert.py
14
+
15
+
16
+ def remap_state_dict(state_dict, config: PretrainedConfig):
17
+ """
18
+ Map the state_dict of a Hugging Face XLM-RoBERTa model to be flash_attn compatible.
19
+ """
20
+
21
+ # LayerNorm
22
+ def key_mapping_ln_gamma_beta(key):
23
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
24
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
25
+ return key
26
+
27
+ state_dict = OrderedDict(
28
+ (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
29
+ )
30
+
31
+ # Layers
32
+ def key_mapping_layers(key):
33
+ return re.sub(r"^roberta.encoder.layer.", "roberta.encoder.layers.", key)
34
+
35
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
36
+
37
+ # LayerNorm
38
+ def key_mapping_ln(key):
39
+ key = re.sub(r"^roberta.embeddings.LayerNorm.", "roberta.emb_ln.", key)
40
+ key = re.sub(
41
+ r"^roberta.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
42
+ r"roberta.encoder.layers.\1.norm1.\2",
43
+ key,
44
+ )
45
+ key = re.sub(
46
+ r"^roberta.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
47
+ r"roberta.encoder.layers.\1.norm2.\2",
48
+ key,
49
+ )
50
+ key = re.sub(
51
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
52
+ r"cls.predictions.transform.layer_norm.\1",
53
+ key,
54
+ )
55
+ return key
56
+
57
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
58
+
59
+ # MLP
60
+ def key_mapping_mlp(key):
61
+ key = re.sub(
62
+ r"^roberta.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
63
+ r"roberta.encoder.layers.\1.mlp.fc1.\2",
64
+ key,
65
+ )
66
+ key = re.sub(
67
+ r"^roberta.encoder.layers.(\d+).output.dense.(weight|bias)",
68
+ r"roberta.encoder.layers.\1.mlp.fc2.\2",
69
+ key,
70
+ )
71
+ return key
72
+
73
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
74
+
75
+ # Attention
76
+ last_layer_subset = getattr(config, "last_layer_subset", False)
77
+ for d in range(config.num_hidden_layers):
78
+ Wq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.weight")
79
+ Wk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.weight")
80
+ Wv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.weight")
81
+ bq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.bias")
82
+ bk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.bias")
83
+ bv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.bias")
84
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
85
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
86
+ [Wq, Wk, Wv], dim=0
87
+ )
88
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
89
+ [bq, bk, bv], dim=0
90
+ )
91
+ else:
92
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.weight"] = Wq
93
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
94
+ [Wk, Wv], dim=0
95
+ )
96
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.bias"] = bq
97
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
98
+ [bk, bv], dim=0
99
+ )
100
+
101
+ def key_mapping_attn(key):
102
+ return re.sub(
103
+ r"^roberta.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
104
+ r"roberta.encoder.layers.\1.mixer.out_proj.\2",
105
+ key,
106
+ )
107
+
108
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
109
+
110
+ def key_mapping_decoder_bias(key):
111
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
112
+
113
+ state_dict = OrderedDict(
114
+ (key_mapping_decoder_bias(k), v) for k, v in state_dict.items()
115
+ )
116
+
117
+ # Word embedding
118
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
119
+ if pad_vocab_size_multiple > 1:
120
+ word_embeddings = state_dict["roberta.embeddings.word_embeddings.weight"]
121
+ state_dict["roberta.embeddings.word_embeddings.weight"] = F.pad(
122
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
123
+ )
124
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
125
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
126
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
127
+ )
128
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
129
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
130
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
131
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
132
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
133
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
134
+ )
135
+
136
+ return state_dict
137
+
138
+
139
+ @click.command()
140
+ @click.option('--model_name', default='FacebookAI/xlm-roberta-base', help='model name')
141
+ @click.option('--revision', default='main', help='revision')
142
+ @click.option('--task', default='masked_lm', help='task')
143
+ @click.option('--output', default='converted_roberta_weights.bin', help='output path for the converted state dict')
144
+ def main(model_name, revision, task, output):
145
+
146
+ if task == 'masked_lm':
147
+ roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name, revision=revision)
148
+ elif task == 'sequence_classification':
149
+ roberta_model = XLMRobertaForSequenceClassification.from_pretrained(model_name, revision=revision,num_labels=1)
150
+ config = BertConfig.from_dict(roberta_model.config.to_dict())
151
+ state_dict = roberta_model.state_dict()
152
+ new_state_dict = remap_state_dict(state_dict, config)
153
+
154
+ if task == 'masked_lm':
155
+ flash_model = FlashXLMRobertaForMaskedLM(config)
156
+ elif task == 'sequence_classification':
157
+ flash_model = FlashXLMRobertaForSequenceClassification(config)
158
+
159
+ for k, v in flash_model.state_dict().items():
160
+ if k not in new_state_dict:
161
+ print(f'Key {k} missing from converted weights; keeping the initialized value')
162
+ new_state_dict[k] = v
163
+
164
+ flash_model.load_state_dict(new_state_dict)
165
+
166
+ torch.save(new_state_dict, output)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ main()
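The heart of `remap_state_dict` is fusing the separate query/key/value projections of the Hugging Face checkpoint into the single `Wqkv` parameter that `MHA` expects (the `torch.cat` calls above); the script itself is a `click` CLI driven by `--model_name`, `--revision`, `--task`, and `--output`. The toy illustration below shows why that fusion is lossless, using hypothetical shapes rather than a real checkpoint.

```python
# Toy illustration of the per-layer Q/K/V fusion done in remap_state_dict.
# Shapes are hypothetical (hidden=8); a real XLM-R-large layer uses hidden=1024.
import torch

hidden = 8
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
bq, bk, bv = (torch.randn(hidden) for _ in range(3))

Wqkv = torch.cat([Wq, Wk, Wv], dim=0)  # -> mixer.Wqkv.weight, shape (3 * hidden, hidden)
bqkv = torch.cat([bq, bk, bv], dim=0)  # -> mixer.Wqkv.bias,   shape (3 * hidden,)

# The fused projection reproduces the three separate projections exactly:
x = torch.randn(4, hidden)
q, k, v = (x @ Wqkv.T + bqkv).split(hidden, dim=-1)
assert torch.allclose(q, x @ Wq.T + bq, atol=1e-6)
assert torch.allclose(v, x @ Wv.T + bv, atol=1e-6)
```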
embedding.py ADDED
@@ -0,0 +1,95 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
2
+ # Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
3
+
4
+ # Copyright (c) 2022, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import \
9
+ create_position_ids_from_input_ids
10
+
11
+
12
+ class XLMRobertaEmbeddings(nn.Module):
13
+ def __init__(
14
+ self,
15
+ embed_dim,
16
+ vocab_size,
17
+ max_position_embeddings,
18
+ type_vocab_size,
19
+ padding_idx=None,
20
+ device=None,
21
+ dtype=None,
22
+ ):
23
+ """
24
+ If max_position_embeddings <= 0, there's no position embeddings
25
+ If type_vocab_size <= 0, there's no token type embeddings
26
+ """
27
+ factory_kwargs = {"device": device, "dtype": dtype}
28
+ super().__init__()
29
+ self.word_embeddings = nn.Embedding(
30
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
31
+ )
32
+ self.max_position_embeddings = max_position_embeddings
33
+ self.type_vocab_size = type_vocab_size
34
+ if self.max_position_embeddings > 0:
35
+ self.position_embeddings = nn.Embedding(
36
+ max_position_embeddings, embed_dim, **factory_kwargs
37
+ )
38
+ if self.type_vocab_size > 0:
39
+ self.token_type_embeddings = nn.Embedding(
40
+ type_vocab_size, embed_dim, **factory_kwargs
41
+ )
42
+
43
+ def forward(
44
+ self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None
45
+ ):
46
+ """
47
+ input_ids: (batch, seqlen)
48
+ position_ids: (batch, seqlen)
49
+ token_type_ids: (batch, seqlen)
50
+ adapter_mask: (batch, 1)
51
+ """
52
+ batch_size, seqlen = input_ids.shape
53
+ if adapter_mask is not None:
54
+ unique_tasks = torch.unique(adapter_mask)
55
+ embedding_dtype = next(self.word_embeddings.parameters()).dtype
56
+ embeddings = torch.empty(
57
+ *input_ids.shape,
58
+ self.word_embeddings.embedding_dim,
59
+ dtype=embedding_dtype,
60
+ device=input_ids.device
61
+ )
62
+ for task_id in unique_tasks:
63
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
64
+ task_input_ids = input_ids[task_indices]
65
+ task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id)
66
+ embeddings[task_indices] = task_embeddings
67
+ else:
68
+ embeddings = self.word_embeddings(input_ids)
69
+ if self.max_position_embeddings > 0:
70
+ if position_ids is None:
71
+ position_ids = create_position_ids_from_input_ids(
72
+ input_ids, padding_idx=self.word_embeddings.padding_idx
73
+ ).to(input_ids.device)
74
+ position_embeddings = self.position_embeddings(position_ids)
75
+ embeddings = embeddings + position_embeddings
76
+ if self.type_vocab_size > 0:
77
+ if token_type_ids is None:
78
+ token_type_ids = torch.zeros(
79
+ seqlen, dtype=torch.long, device=input_ids.device
80
+ )
81
+
82
+ if adapter_mask is not None:
83
+ unique_tasks = torch.unique(adapter_mask)
84
+ for task_id in unique_tasks:
85
+ task_token_type_embeddings = self.token_type_embeddings(
86
+ token_type_ids, task_id=task_id
87
+ )
88
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
89
+ embeddings[task_indices] = (
90
+ embeddings[task_indices] + task_token_type_embeddings
91
+ )
92
+ else:
93
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
94
+ embeddings = embeddings + token_type_embeddings
95
+ return embeddings
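A minimal forward pass through `XLMRobertaEmbeddings` without adapters, exercising the documented `max_position_embeddings <= 0` / `type_vocab_size <= 0` behaviour (only word embeddings are returned). Toy sizes; the import path is an assumption.

```python
# Sketch: word-embedding-only forward, matching the docstring above.
# Assumption: embedding.py is importable from the current directory.
import torch

from embedding import XLMRobertaEmbeddings

emb = XLMRobertaEmbeddings(
    embed_dim=16,
    vocab_size=32,
    max_position_embeddings=0,  # no absolute positions (rotary is applied inside MHA)
    type_vocab_size=0,          # no token-type embeddings
    padding_idx=1,              # XLM-R convention: pad_token_id = 1
)
input_ids = torch.tensor([[0, 5, 7, 2, 1, 1]])  # (batch, seqlen), trailing padding
embeddings = emb(input_ids)                     # (1, 6, 16): word embeddings only
```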
mha.py ADDED
@@ -0,0 +1,830 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
+ # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
3
+ # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
+
5
+ # Copyright (c) 2023, Tri Dao.
6
+
7
+ import math
8
+ from functools import partial
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange, repeat
13
+
14
+ try:
15
+ from flash_attn import (flash_attn_kvpacked_func,
16
+ flash_attn_qkvpacked_func,
17
+ flash_attn_varlen_kvpacked_func,
18
+ flash_attn_varlen_qkvpacked_func,
19
+ flash_attn_with_kvcache)
20
+ except ImportError:
21
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
+ flash_attn_with_kvcache = None
24
+
25
+ try:
26
+ from flash_attn.ops.fused_dense import (ColumnParallelLinear, FusedDense,
27
+ RowParallelLinear)
28
+ except ImportError:
29
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
30
+
31
+ from .rotary import RotaryEmbedding
32
+
33
+
34
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
35
+ def get_alibi_slopes(nheads):
36
+ def get_slopes_power_of_2(nheads):
37
+ start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
38
+ ratio = start
39
+ return [start * ratio**i for i in range(nheads)]
40
+
41
+ if math.log2(nheads).is_integer():
42
+ return get_slopes_power_of_2(nheads)
43
+ else:
44
+ closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
45
+ return (
46
+ get_slopes_power_of_2(closest_power_of_2)
47
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][
48
+ : nheads - closest_power_of_2
49
+ ]
50
+ )
51
+
52
+
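`get_alibi_slopes` follows the standard ALiBi recipe: for a power-of-two head count the slopes form a geometric sequence. A tiny check of that property, assuming it runs in the same module so the function above is in scope:

```python
# For nheads = 8 (a power of two), the slopes are 2^-1, 2^-2, ..., 2^-8.
slopes = get_alibi_slopes(8)
assert slopes == [2.0 ** -(i + 1) for i in range(8)]
```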
53
+ class FlashSelfAttention(nn.Module):
54
+ """Implement the scaled dot product attention with softmax.
55
+ Arguments
56
+ ---------
57
+ softmax_scale: The temperature to use for the softmax attention.
58
+ (default: 1/sqrt(d_keys) where d_keys is computed at
59
+ runtime)
60
+ attention_dropout: The dropout rate to apply to the attention
61
+ (default: 0.0)
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ causal=False,
67
+ softmax_scale=None,
68
+ attention_dropout=0.0,
69
+ window_size=(-1, -1),
70
+ alibi_slopes=None,
71
+ deterministic=False,
72
+ ):
73
+ super().__init__()
74
+ assert (
75
+ flash_attn_varlen_qkvpacked_func is not None
76
+ ), "FlashAttention is not installed"
77
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
78
+ self.causal = causal
79
+ self.softmax_scale = softmax_scale
80
+ self.drop = nn.Dropout(attention_dropout)
81
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
82
+ self.window_size = window_size
83
+ self.deterministic = deterministic
84
+
85
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
86
+ """Implements the multihead softmax attention.
87
+ Arguments
88
+ ---------
89
+ qkv: The tensor containing the query, key, and value.
90
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
91
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
92
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
93
+ causal: if passed, will override self.causal
94
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
95
+ of the sequences in the batch, used to index into qkv.
96
+ max_seqlen: int. Maximum sequence length in the batch.
97
+ Returns:
98
+ --------
99
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
100
+ else (B, S, H, D).
101
+ """
102
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
103
+ assert qkv.is_cuda
104
+ causal = self.causal if causal is None else causal
105
+ unpadded = cu_seqlens is not None
106
+ if self.alibi_slopes is not None:
107
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
108
+ if unpadded:
109
+ assert cu_seqlens.dtype == torch.int32
110
+ assert max_seqlen is not None
111
+ assert isinstance(max_seqlen, int)
112
+ return flash_attn_varlen_qkvpacked_func(
113
+ qkv,
114
+ cu_seqlens,
115
+ max_seqlen,
116
+ self.drop.p if self.training else 0.0,
117
+ softmax_scale=self.softmax_scale,
118
+ causal=causal,
119
+ alibi_slopes=self.alibi_slopes,
120
+ window_size=self.window_size,
121
+ deterministic=self.deterministic,
122
+ )
123
+ else:
124
+ return flash_attn_qkvpacked_func(
125
+ qkv,
126
+ self.drop.p if self.training else 0.0,
127
+ softmax_scale=self.softmax_scale,
128
+ causal=causal,
129
+ alibi_slopes=self.alibi_slopes,
130
+ window_size=self.window_size,
131
+ deterministic=self.deterministic,
132
+ )
133
+
134
+
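The varlen path of `FlashSelfAttention.forward` above takes packed `qkv` of shape `(total, 3, H, D)` together with `cu_seqlens` and `max_seqlen`. The repository's `xlm_padding.py` provides the full unpad/pad helpers; the sketch below only illustrates the `cu_seqlens` convention those kernels expect.

```python
# Sketch: building cu_seqlens for the varlen kernels from per-sequence lengths.
# Two sequences of lengths 3 and 5 packed into a single (total=8, ...) tensor.
import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 5], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
# tensor([0, 3, 8], dtype=torch.int32): sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1]
max_seqlen = int(seqlens.max())  # 5, passed alongside cu_seqlens
```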
135
+ class FlashCrossAttention(nn.Module):
136
+ """Implement the scaled dot product attention with softmax.
137
+ Arguments
138
+ ---------
139
+ softmax_scale: The temperature to use for the softmax attention.
140
+ (default: 1/sqrt(d_keys) where d_keys is computed at
141
+ runtime)
142
+ attention_dropout: The dropout rate to apply to the attention
143
+ (default: 0.0)
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ causal=False,
149
+ softmax_scale=None,
150
+ attention_dropout=0.0,
151
+ alibi_slopes=None,
152
+ window_size=(-1, -1),
153
+ deterministic=False,
154
+ ):
155
+ super().__init__()
156
+ assert (
157
+ flash_attn_varlen_kvpacked_func is not None
158
+ ), "FlashAttention is not installed"
159
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
160
+ self.causal = causal
161
+ self.softmax_scale = softmax_scale
162
+ self.drop = nn.Dropout(attention_dropout)
163
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
164
+ self.window_size = window_size
165
+ self.deterministic = deterministic
166
+
167
+ def forward(
168
+ self,
169
+ q,
170
+ kv,
171
+ causal=None,
172
+ cu_seqlens=None,
173
+ max_seqlen=None,
174
+ cu_seqlens_k=None,
175
+ max_seqlen_k=None,
176
+ ):
177
+ """Implements the multihead softmax attention.
178
+ Arguments
179
+ ---------
180
+ q: The tensor containing the query. (B, Sq, H, D)
181
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
182
+ causal: if passed, will override self.causal
183
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
184
+ of the sequences in the batch, used to index into q.
185
+ max_seqlen: int. Maximum sequence length in the batch of q.
186
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
187
+ of the sequences in the batch, used to index into kv.
188
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
189
+ """
190
+ assert q.dtype in [torch.float16, torch.bfloat16]
191
+ assert q.is_cuda and kv.is_cuda
192
+ causal = self.causal if causal is None else causal
193
+ unpadded = cu_seqlens is not None
194
+ if self.alibi_slopes is not None:
195
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
196
+ if unpadded:
197
+ assert cu_seqlens.dtype == torch.int32
198
+ assert max_seqlen is not None
199
+ assert isinstance(max_seqlen, int)
200
+ assert cu_seqlens_k is not None
201
+ assert cu_seqlens_k.dtype == torch.int32
202
+ assert max_seqlen_k is not None
203
+ assert isinstance(max_seqlen_k, int)
204
+ return flash_attn_varlen_kvpacked_func(
205
+ q,
206
+ kv,
207
+ cu_seqlens,
208
+ cu_seqlens_k,
209
+ max_seqlen,
210
+ max_seqlen_k,
211
+ self.drop.p if self.training else 0.0,
212
+ softmax_scale=self.softmax_scale,
213
+ causal=causal,
214
+ alibi_slopes=self.alibi_slopes,
215
+ window_size=self.window_size,
216
+ deterministic=self.deterministic,
217
+ )
218
+ else:
219
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
220
+ seqlen_k = kv.shape[1]
221
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
222
+ return flash_attn_kvpacked_func(
223
+ q,
224
+ kv,
225
+ self.drop.p if self.training else 0.0,
226
+ causal=causal,
227
+ softmax_scale=self.softmax_scale,
228
+ alibi_slopes=self.alibi_slopes,
229
+ window_size=self.window_size,
230
+ deterministic=self.deterministic,
231
+ )
232
+
233
+
234
+ class SelfAttention(nn.Module):
235
+ """Implement the scaled dot product attention with softmax.
236
+ Arguments
237
+ ---------
238
+ softmax_scale: The temperature to use for the softmax attention.
239
+ (default: 1/sqrt(d_keys) where d_keys is computed at
240
+ runtime)
241
+ attention_dropout: The dropout rate to apply to the attention
242
+ (default: 0.0)
243
+ """
244
+
245
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
246
+ super().__init__()
247
+ self.causal = causal
248
+ self.softmax_scale = softmax_scale
249
+ self.drop = nn.Dropout(attention_dropout)
250
+
251
+ def forward(self, qkv, causal=None, key_padding_mask=None):
252
+ """Implements the multihead softmax attention.
253
+ Arguments
254
+ ---------
255
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
256
+ causal: if passed, will override self.causal
257
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
258
+ False means to mask out. (B, S)
259
+ """
260
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
261
+ causal = self.causal if causal is None else causal
262
+ q, k, v = qkv.unbind(dim=2)
263
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
264
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
265
+ if key_padding_mask is not None:
266
+ padding_mask = torch.full(
267
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
268
+ )
269
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
270
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
271
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
272
+ if causal:
273
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
274
+ # So we have to construct the mask in float
275
+ causal_mask = torch.triu(
276
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
277
+ )
278
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
279
+ scores = scores + causal_mask.to(dtype=scores.dtype)
280
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
281
+ attention_drop = self.drop(attention)
282
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
283
+ return output
284
+
285
+
286
+ class CrossAttention(nn.Module):
287
+ """Implement the scaled dot product attention with softmax.
288
+ Arguments
289
+ ---------
290
+ softmax_scale: The temperature to use for the softmax attention.
291
+ (default: 1/sqrt(d_keys) where d_keys is computed at
292
+ runtime)
293
+ attention_dropout: The dropout rate to apply to the attention
294
+ (default: 0.0)
295
+ """
296
+
297
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
298
+ super().__init__()
299
+ self.causal = causal
300
+ self.softmax_scale = softmax_scale
301
+ self.drop = nn.Dropout(attention_dropout)
302
+
303
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
304
+ """Implements the multihead softmax attention.
305
+ Arguments
306
+ ---------
307
+ q: The tensor containing the query. (B, Sq, H, D)
308
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
309
+ causal: if passed, will override self.causal
310
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
311
+ False means to mask out. (B, Sk)
312
+ """
313
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
314
+ causal = self.causal if causal is None else causal
315
+ seqlen_k = kv.shape[1]
316
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
317
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
318
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
319
+ k, v = kv.unbind(dim=2)
320
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
321
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
322
+ if key_padding_mask is not None:
323
+ padding_mask = torch.full(
324
+ (batch_size, seqlen_k),
325
+ -10000.0,
326
+ dtype=scores.dtype,
327
+ device=scores.device,
328
+ )
329
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
330
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
331
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
332
+ if causal:
333
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
334
+ row_idx = rearrange(
335
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
336
+ )
337
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
338
+ sk = (
339
+ seqlen_k
340
+ if key_padding_mask is None
341
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
342
+ )
343
+ causal_mask = col_idx > row_idx + sk - seqlen_q
344
+ scores = scores.masked_fill(causal_mask, -10000.0)
345
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
346
+ attention_drop = self.drop(attention)
347
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
348
+ return output
349
+
350
+
351
+ class LinearResidual(nn.Linear):
352
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
353
+
354
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
355
+ return super().forward(input), input
356
+
357
+
358
+ def _update_kv_cache(kv, inference_params, layer_idx):
359
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
360
+ # Pre-allocate memory for key-values for inference.
361
+ num_heads, head_dim = kv.shape[-2:]
362
+ if layer_idx not in inference_params.key_value_memory_dict:
363
+ kv_cache = torch.empty(
364
+ inference_params.max_batch_size,
365
+ inference_params.max_seqlen,
366
+ 2,
367
+ num_heads,
368
+ head_dim,
369
+ dtype=kv.dtype,
370
+ device=kv.device,
371
+ )
372
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
373
+ else:
374
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
375
+ # Adjust key and value for inference
376
+ batch_start = inference_params.batch_size_offset
377
+ batch_end = batch_start + kv.shape[0]
378
+ sequence_start = inference_params.seqlen_offset
379
+ sequence_end = sequence_start + kv.shape[1]
380
+ assert batch_end <= kv_cache.shape[0]
381
+ assert sequence_end <= kv_cache.shape[1]
382
+ assert kv_cache is not None
383
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
384
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
385
+
386
+
387
+ class MHA(nn.Module):
388
+ """Multi-head self-attention and cross-attention"""
389
+
390
+ def __init__(
391
+ self,
392
+ embed_dim,
393
+ num_heads,
394
+ num_heads_kv=None,
395
+ cross_attn=False,
396
+ qkv_proj_bias=True,
397
+ out_proj_bias=True,
398
+ dropout=0.0,
399
+ softmax_scale=None,
400
+ causal=False,
401
+ layer_idx=None,
402
+ dwconv=False,
403
+ rotary_emb_dim=0,
404
+ rotary_emb_base=10000.0,
405
+ rotary_emb_scale_base=None,
406
+ rotary_emb_interleaved=False,
407
+ use_alibi=False,
408
+ window_size=(-1, -1),
409
+ fused_bias_fc=False,
410
+ use_flash_attn=False,
411
+ return_residual=False,
412
+ checkpointing=False,
413
+ device=None,
414
+ dtype=None,
415
+ ) -> None:
416
+ """
417
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
418
+ return_residual: whether to return the input x along with the output. This is for
419
+ performance reasons: for a post-norm architecture, returning the input allows us
420
+ to fuse the backward of nn.Linear with the residual connection.
421
+ """
422
+ factory_kwargs = {"device": device, "dtype": dtype}
423
+ super().__init__()
424
+ self.embed_dim = embed_dim
425
+ self.cross_attn = cross_attn
426
+ self.causal = causal
427
+ self.layer_idx = layer_idx
428
+ self.dwconv = dwconv
429
+ self.rotary_emb_dim = rotary_emb_dim
430
+ self.use_flash_attn = use_flash_attn
431
+ self.return_residual = return_residual
432
+ self.checkpointing = checkpointing
433
+ if use_alibi:
434
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
435
+ alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
436
+ else:
437
+ alibi_slopes = None
438
+ if window_size != (-1, -1):
439
+ assert (
440
+ use_flash_attn
441
+ ), "Local (sliding window) attention code path requires flash_attn"
442
+
443
+ self.num_heads = num_heads
444
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
445
+ assert (
446
+ self.num_heads % self.num_heads_kv == 0
447
+ ), "num_heads must be divisible by num_heads_kv"
448
+ assert (
449
+ self.embed_dim % num_heads == 0
450
+ ), "embed_dim must be divisible by num_heads"
451
+ self.head_dim = self.embed_dim // num_heads
452
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
453
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
454
+
455
+ if self.rotary_emb_dim > 0:
456
+ assert (
457
+ not cross_attn
458
+ ), "MHA with rotary embedding does not support cross-attention yet"
459
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
460
+ self.rotary_emb = RotaryEmbedding(
461
+ self.rotary_emb_dim,
462
+ base=rotary_emb_base,
463
+ scale_base=rotary_emb_scale_base,
464
+ interleaved=rotary_emb_interleaved,
465
+ device=device,
466
+ use_flash_attn=use_flash_attn,
467
+ )
468
+
469
+ if fused_bias_fc and FusedDense is None:
470
+ raise ImportError("fused_dense is not installed")
471
+
472
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
473
+ linear_resid_cls = (
474
+ LinearResidual
475
+ if not fused_bias_fc
476
+ else partial(FusedDense, return_residual=True)
477
+ )
478
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
479
+ inner_attn_cls = (
480
+ partial(
481
+ FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
482
+ )
483
+ if use_flash_attn
484
+ else SelfAttention
485
+ )
486
+ inner_cross_attn_cls = (
487
+ partial(
488
+ FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
489
+ )
490
+ if use_flash_attn
491
+ else CrossAttention
492
+ )
493
+ if not self.cross_attn:
494
+ self.Wqkv = wqkv_cls(
495
+ embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
496
+ )
497
+ else:
498
+ self.Wq = linear_cls(
499
+ embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
500
+ )
501
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
502
+ if self.dwconv:
503
+ if self.num_heads_kv == self.num_heads:
504
+ self.dwconv_qkv = nn.Conv1d(
505
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
506
+ )
507
+ else:
508
+ self.dwconv_q = nn.Conv1d(
509
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
510
+ )
511
+ self.dwconv_kv = nn.Conv1d(
512
+ kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
513
+ )
514
+ self.inner_attn = inner_attn_cls(
515
+ causal=causal,
516
+ softmax_scale=softmax_scale,
517
+ attention_dropout=dropout,
518
+ )
519
+ self.inner_cross_attn = inner_cross_attn_cls(
520
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
521
+ )
522
+ self.out_proj = linear_cls(
523
+ embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
524
+ )
525
+
526
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
527
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
528
+ device = self.out_proj.weight.device
529
+ return torch.empty(
530
+ batch_size,
531
+ max_seqlen,
532
+ 2,
533
+ self.num_heads_kv,
534
+ self.head_dim,
535
+ dtype=dtype,
536
+ device=device,
537
+ )
538
+
539
+ def _update_kv_cache(self, kv, inference_params):
540
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
541
+ assert not self.dwconv, "Generation does not support dwconv yet"
542
+ assert (
543
+ self.layer_idx is not None
544
+ ), "Generation requires layer_idx in the constructor"
545
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
546
+
547
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
548
+ """
549
+ Fast path that combines 3 steps: apply rotary to Q and K, update the KV cache, and apply attention.
550
+ q: (batch_size, seqlen_q, nheads, head_dim)
551
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
552
+ """
553
+ assert inference_params is not None and inference_params.seqlen_offset > 0
554
+ assert self.use_flash_attn
555
+ if self.rotary_emb_dim > 0:
556
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
557
+ self.rotary_emb._update_cos_sin_cache(
558
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
559
+ )
560
+ rotary_cos, rotary_sin = (
561
+ self.rotary_emb._cos_cached,
562
+ self.rotary_emb._sin_cached,
563
+ )
564
+ else:
565
+ rotary_cos, rotary_sin = None, None
566
+ batch = q.shape[0]
567
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
568
+ cache_seqlens = (
569
+ inference_params.lengths_per_sample[:batch]
570
+ if inference_params.lengths_per_sample is not None
571
+ else inference_params.seqlen_offset
572
+ )
573
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
574
+ context = flash_attn_with_kvcache(
575
+ q,
576
+ kv_cache[:, :, 0],
577
+ kv_cache[:, :, 1],
578
+ kv[:, :, 0],
579
+ kv[:, :, 1],
580
+ rotary_cos=rotary_cos,
581
+ rotary_sin=rotary_sin,
582
+ cache_seqlens=cache_seqlens,
583
+ softmax_scale=self.inner_cross_attn.softmax_scale,
584
+ causal=self.inner_cross_attn.causal,
585
+ rotary_interleaved=(
586
+ self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
587
+ ),
588
+ alibi_slopes=alibi_slopes,
589
+ )
590
+ return context
591
+
592
+ def _update_kvcache_attention(self, q, kv, inference_params):
593
+ """Write kv to inference_params, then do attention"""
594
+ if (
595
+ inference_params.seqlen_offset == 0
596
+ or flash_attn_with_kvcache is None
597
+ or not self.use_flash_attn
598
+ ):
599
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
600
+ kv = self._update_kv_cache(kv, inference_params)
601
+ return self.inner_cross_attn(q, kv)
602
+ else:
603
+ batch = q.shape[0]
604
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
605
+ cache_seqlens = (
606
+ inference_params.lengths_per_sample[:batch]
607
+ if inference_params.lengths_per_sample is not None
608
+ else inference_params.seqlen_offset
609
+ )
610
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
611
+ return flash_attn_with_kvcache(
612
+ q,
613
+ kv_cache[:, :, 0],
614
+ kv_cache[:, :, 1],
615
+ kv[:, :, 0],
616
+ kv[:, :, 1],
617
+ cache_seqlens=cache_seqlens,
618
+ softmax_scale=self.inner_cross_attn.softmax_scale,
619
+ causal=self.inner_cross_attn.causal,
620
+ alibi_slopes=alibi_slopes,
621
+ )
622
+
623
+ def forward(
624
+ self,
625
+ x,
626
+ x_kv=None,
627
+ key_padding_mask=None,
628
+ cu_seqlens=None,
629
+ max_seqlen=None,
630
+ mixer_subset=None,
631
+ inference_params=None,
632
+ adapter_mask=None,
633
+ **kwargs,
634
+ ):
635
+ """
636
+ Arguments:
637
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
638
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
639
+ is the sum of the sequence lengths in the batch.
640
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
641
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
642
+ of the sequences in the batch, used to index into x. Only applicable when using
643
+ FlashAttention.
644
+ max_seqlen: int. Maximum sequence length in the batch.
645
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
646
+ (batch, seqlen). Only applicable when not using FlashAttention.
647
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
648
+ before applying the query projection. Useful for e.g., ViT where we only care
649
+ about the CLS token in the last layer.
650
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
651
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
652
+ """
653
+ if cu_seqlens is not None:
654
+ assert max_seqlen is not None
655
+ assert key_padding_mask is None
656
+ assert self.use_flash_attn
657
+ assert not self.dwconv
658
+ if key_padding_mask is not None:
659
+ assert cu_seqlens is None
660
+ assert max_seqlen is None
661
+ assert not self.use_flash_attn
662
+ if inference_params is not None:
663
+ assert key_padding_mask is None
664
+ assert cu_seqlens is None and max_seqlen is None
665
+ assert not self.dwconv
666
+
667
+ kwargs = (
668
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
669
+ if self.use_flash_attn
670
+ else {"key_padding_mask": key_padding_mask, **kwargs}
671
+ )
672
+ seqlen_offset = (
673
+ 0
674
+ if inference_params is None
675
+ else (
676
+ inference_params.lengths_per_sample
677
+ if inference_params.lengths_per_sample is not None
678
+ else inference_params.seqlen_offset
679
+ )
680
+ )
681
+ rotary_max_seqlen = (
682
+ inference_params.max_sequence_len
683
+ if inference_params is not None
684
+ else max_seqlen
685
+ )
686
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
687
+ assert x_kv is None and mixer_subset is None
688
+
689
+ if adapter_mask is not None:
690
+ unique_tasks = torch.unique(adapter_mask)
691
+ qkv_dtype = next(self.Wqkv.parameters()).dtype
692
+ qkv = torch.empty(
693
+ *x.shape[:-1],
694
+ self.Wqkv.out_features,
695
+ dtype=qkv_dtype,
696
+ device=x.device,
697
+ )
698
+ for task_id in unique_tasks:
699
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
700
+ task_tensor = x[task_indices]
701
+ if not self.return_residual:
702
+ task_qkv = self.Wqkv(task_tensor, task_id=task_id)
703
+ else:
704
+ task_qkv, _ = self.Wqkv(
705
+ task_tensor, task_id=task_id, residual=True
706
+ )
707
+ qkv[task_indices] = task_qkv
708
+ else:
709
+ if not self.return_residual:
710
+ qkv = self.Wqkv(x)
711
+ else:
712
+ if hasattr(self.Wqkv, "parametrizations"):
713
+ qkv, x = self.Wqkv(x, residual=True)
714
+ else:
715
+ qkv, x = self.Wqkv(x)
716
+
717
+ if self.dwconv:
718
+ qkv = rearrange(
719
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
720
+ "b d s -> b s d",
721
+ ).contiguous()
722
+ qkv = rearrange(
723
+ qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
724
+ )
725
+ if (
726
+ inference_params is None
727
+ or inference_params.seqlen_offset == 0
728
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
729
+ or not self.use_flash_attn
730
+ ):
731
+ if self.rotary_emb_dim > 0:
732
+ qkv = self.rotary_emb(
733
+ qkv,
734
+ seqlen_offset=seqlen_offset,
735
+ cu_seqlens=cu_seqlens,
736
+ max_seqlen=rotary_max_seqlen,
737
+ )
738
+ if inference_params is None:
739
+ if not self.checkpointing:
740
+ context = self.inner_attn(qkv, **kwargs)
741
+ else:
742
+ context = torch.utils.checkpoint.checkpoint(
743
+ self.inner_attn, qkv, **kwargs
744
+ )
745
+ else:
746
+ context = self._update_kvcache_attention(
747
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
748
+ )
749
+ else:
750
+ context = self._apply_rotary_update_kvcache_attention(
751
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
752
+ )
753
+ else:
754
+ if self.cross_attn:
755
+ if not self.return_residual:
756
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
757
+ kv = self.Wkv(x_kv if x_kv is not None else x)
758
+ else:
759
+ if x_kv is not None:
760
+ kv, x_kv = self.Wkv(x_kv)
761
+ else:
762
+ kv, x = self.Wkv(x)
763
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
764
+ else:
765
+ assert self.num_heads_kv != self.num_heads
766
+ if not self.return_residual:
767
+ qkv = self.Wqkv(x)
768
+ else:
769
+ qkv, x = self.Wqkv(x)
770
+ q = qkv[..., : self.num_heads * self.head_dim]
771
+ kv = qkv[..., self.num_heads * self.head_dim :]
772
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
773
+ kv = rearrange(
774
+ kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
775
+ )
776
+ if self.dwconv:
777
+ q = rearrange(
778
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
779
+ "b d s -> b s d",
780
+ ).contiguous()
781
+ kv = rearrange(
782
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
783
+ "b d s -> b s d",
784
+ ).contiguous()
785
+ if (
786
+ inference_params is None
787
+ or inference_params.seqlen_offset == 0
788
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
789
+ or not self.use_flash_attn
790
+ ):
791
+ if self.rotary_emb_dim > 0:
792
+ q, kv = self.rotary_emb(
793
+ q,
794
+ kv,
795
+ seqlen_offset=seqlen_offset,
796
+ cu_seqlens=cu_seqlens,
797
+ max_seqlen=rotary_max_seqlen,
798
+ )
799
+ if inference_params is None:
800
+ if not self.checkpointing:
801
+ context = self.inner_cross_attn(q, kv, **kwargs)
802
+ else:
803
+ context = torch.utils.checkpoint.checkpoint(
804
+ self.inner_cross_attn, q, kv, **kwargs
805
+ )
806
+ else:
807
+ context = self._update_kvcache_attention(q, kv, inference_params)
808
+ else:
809
+ context = self._apply_rotary_update_kvcache_attention(
810
+ q, kv, inference_params
811
+ )
812
+
813
+ inp = rearrange(context, "... h d -> ... (h d)")
814
+ if adapter_mask is not None:
815
+ unique_tasks = torch.unique(adapter_mask)
816
+ out_dtype = next(self.out_proj.parameters()).dtype
817
+ out = torch.empty(
818
+ *inp.shape[:-1],
819
+ self.out_proj.out_features,
820
+ dtype=out_dtype,
821
+ device=inp.device,
822
+ )
823
+ for task_id in unique_tasks:
824
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
825
+ task_tensor = inp[task_indices]
826
+ task_out = self.out_proj(task_tensor, task_id=task_id)
827
+ out[task_indices] = task_out
828
+ else:
829
+ out = self.out_proj(inp)
830
+ return out if not self.return_residual else (out, x)
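
For orientation, here is a minimal pure-PyTorch sketch of the projection-size arithmetic that `Wqkv` and the split in `forward` above rely on. The sizes are made-up example values (a grouped-query setup with half as many KV heads), not taken from any shipped config.

```python
import torch
import torch.nn as nn

# Assumed example sizes, not from a released config.
embed_dim, num_heads, num_heads_kv = 1024, 16, 8
head_dim = embed_dim // num_heads                      # 64
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)    # 64 * 32 = 2048
kv_dim = 2 * head_dim * num_heads_kv                   # 1024

x = torch.randn(2, 128, embed_dim)                     # (batch, seqlen, hidden)
wqkv = nn.Linear(embed_dim, qkv_dim)
qkv = wqkv(x)

# Split the packed projection the same way the forward pass does when num_heads_kv != num_heads.
q = qkv[..., : num_heads * head_dim].view(2, 128, num_heads, head_dim)
kv = qkv[..., num_heads * head_dim :].view(2, 128, 2, num_heads_kv, head_dim)
print(q.shape, kv.shape)  # torch.Size([2, 128, 16, 64]) torch.Size([2, 128, 2, 8, 64])
```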
mlp.py ADDED
@@ -0,0 +1,237 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
2
+ # Commit id: c3b219665292c61a51153d0ded4473c494296382
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.distributed import ProcessGroup
10
+
11
+ try:
12
+ from flash_attn.ops.activations import swiglu
13
+ except ImportError:
14
+ swiglu = None
15
+
16
+ try:
17
+ from flash_attn.ops.fused_dense import (ColumnParallelLinear,
18
+ RowParallelLinear)
19
+ except ImportError:
20
+ ColumnParallelLinear, RowParallelLinear = None, None
21
+
22
+ try:
23
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
24
+ except ImportError:
25
+ FusedMLP, ParallelFusedMLP = None, None
26
+
27
+
28
+ class Mlp(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_features,
32
+ hidden_features=None,
33
+ out_features=None,
34
+ activation=F.gelu,
35
+ bias1=True,
36
+ bias2=True,
37
+ return_residual=False,
38
+ device=None,
39
+ dtype=None,
40
+ ):
41
+ factory_kwargs = {"device": device, "dtype": dtype}
42
+ super().__init__()
43
+ out_features = out_features if out_features is not None else in_features
44
+ hidden_features = (
45
+ hidden_features if hidden_features is not None else in_features * 4
46
+ )
47
+ self.return_residual = return_residual
48
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
49
+ self.activation = activation
50
+ self.fc2 = nn.Linear(
51
+ hidden_features, out_features, bias=bias2, **factory_kwargs
52
+ )
53
+
54
+ def forward(self, x, adapter_mask=None):
55
+ if adapter_mask is not None:
56
+ unique_tasks = torch.unique(adapter_mask)
57
+ fc1_dtype = next(self.fc1.parameters()).dtype
58
+ y = torch.empty(
59
+ *x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device
60
+ )
61
+ for task_id in unique_tasks:
62
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
63
+ task_tensor = x[task_indices]
64
+ task_y = self.fc1(task_tensor, task_id=task_id)
65
+ y[task_indices] = task_y
66
+ else:
67
+ y = self.fc1(x)
68
+
69
+ y = self.activation(y)
70
+
71
+ if adapter_mask is not None:
72
+ unique_tasks = torch.unique(adapter_mask)
73
+ fc2_dtype = next(self.fc2.parameters()).dtype
74
+ out = torch.empty(
75
+ *y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device
76
+ )
77
+ for task_id in unique_tasks:
78
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
79
+ task_tensor = y[task_indices]
80
+ task_out = self.fc2(task_tensor, task_id=task_id)
81
+ out[task_indices] = task_out
82
+ else:
83
+ out = self.fc2(y)
84
+
85
+ return out if not self.return_residual else (out, x)
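
The `adapter_mask` branch above routes each row of the batch through the weights of its own task before the shared activation. Below is a toy sketch of that routing pattern in plain PyTorch, with a `ModuleList` standing in for the task-specific LoRA-modulated weights used in the real code.

```python
import torch
import torch.nn as nn

# Toy stand-in: one independent linear layer per task (the real code reuses a
# single layer whose weights are modulated per task by a LoRA parametrization).
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
x = torch.randn(5, 8)
adapter_mask = torch.tensor([0, 2, 0, 1, 2], dtype=torch.int32)

out = torch.empty(5, 8)
for task_id in torch.unique(adapter_mask):
    idx = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
    out[idx] = layers[int(task_id)](x[idx])  # rows of each task go through "their" weights
```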
86
+
87
+
88
+ class ParallelMLP(nn.Module):
89
+ def __init__(
90
+ self,
91
+ in_features,
92
+ hidden_features=None,
93
+ out_features=None,
94
+ activation=F.gelu,
95
+ process_group: ProcessGroup = None,
96
+ sequence_parallel=True,
97
+ bias1=True,
98
+ bias2=True,
99
+ device=None,
100
+ dtype=None,
101
+ ):
102
+ factory_kwargs = {"device": device, "dtype": dtype}
103
+ super().__init__()
104
+ assert ColumnParallelLinear is not None, "Need to install fused_dense"
105
+ assert RowParallelLinear is not None, "Need to install fused_dense"
106
+ out_features = out_features if out_features is not None else in_features
107
+ hidden_features = (
108
+ hidden_features if hidden_features is not None else in_features * 4
109
+ )
110
+ self.fc1 = ColumnParallelLinear(
111
+ in_features,
112
+ hidden_features,
113
+ process_group,
114
+ bias=bias1,
115
+ sequence_parallel=sequence_parallel,
116
+ **factory_kwargs,
117
+ )
118
+ self.activation = activation
119
+ self.fc2 = RowParallelLinear(
120
+ hidden_features,
121
+ out_features,
122
+ process_group,
123
+ bias=bias2,
124
+ sequence_parallel=sequence_parallel,
125
+ **factory_kwargs,
126
+ )
127
+
128
+ def forward(self, x):
129
+ y = self.fc1(x)
130
+ y = self.activation(y)
131
+ y = self.fc2(y)
132
+ return y
133
+
134
+
135
+ class GatedMlp(nn.Module):
136
+ def __init__(
137
+ self,
138
+ in_features,
139
+ hidden_features=None,
140
+ out_features=None,
141
+ activation=F.sigmoid,
142
+ bias1=True,
143
+ bias2=True,
144
+ multiple_of=128,
145
+ return_residual=False,
146
+ device=None,
147
+ dtype=None,
148
+ ):
149
+ factory_kwargs = {"device": device, "dtype": dtype}
150
+ super().__init__()
151
+ out_features = out_features if out_features is not None else in_features
152
+ hidden_features = (
153
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
154
+ )
155
+ hidden_features = (
156
+ (hidden_features + multiple_of - 1) // multiple_of * multiple_of
157
+ )
158
+ self.return_residual = return_residual
159
+ self.fc1 = nn.Linear(
160
+ in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
161
+ )
162
+ self.activation = activation
163
+ self.fc2 = nn.Linear(
164
+ hidden_features, out_features, bias=bias2, **factory_kwargs
165
+ )
166
+
167
+ def forward(self, x):
168
+ y = self.fc1(x)
169
+ if self.activation == F.sigmoid: # Special case for GLU
170
+ y = F.glu(y, dim=-1)
171
+ elif (
172
+ self.activation == F.silu and swiglu is not None
173
+ ): # Special case for SwiGLU
174
+ y, gate = y.chunk(2, dim=-1)
175
+ y = swiglu(gate, y)
176
+ else:
177
+ y, gate = y.chunk(2, dim=-1)
178
+ y = y * self.activation(gate)
179
+ y = self.fc2(y)
180
+ return y if not self.return_residual else (y, x)
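
The gating above packs the value and gate halves into a single `fc1`; with `in_features=1024`, for example, `int(8 * 1024 / 3) = 2730` is rounded up to `2816` when `multiple_of=128`. A plain-tensor sketch of the SwiGLU-style fallback path (assumed toy sizes, no fused `swiglu` kernel):

```python
import torch
import torch.nn.functional as F

in_features, hidden_features = 16, 32                  # assumed toy sizes
fc1 = torch.nn.Linear(in_features, 2 * hidden_features)
fc2 = torch.nn.Linear(hidden_features, in_features)

x = torch.randn(4, in_features)
y, gate = fc1(x).chunk(2, dim=-1)                      # value half and gate half
out = fc2(y * F.silu(gate))                            # y * activation(gate), then project back
```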
181
+
182
+
183
+ class ParallelGatedMlp(nn.Module):
184
+ """Parallel GatedMlp"""
185
+
186
+ def __init__(
187
+ self,
188
+ in_features,
189
+ process_group,
190
+ hidden_features=None,
191
+ out_features=None,
192
+ activation=F.sigmoid,
193
+ bias1=True,
194
+ bias2=True,
195
+ multiple_of=128,
196
+ sequence_parallel=True,
197
+ device=None,
198
+ dtype=None,
199
+ ):
200
+ factory_kwargs = {"device": device, "dtype": dtype}
201
+ super().__init__()
202
+ out_features = out_features if out_features is not None else in_features
203
+ hidden_features = (
204
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
205
+ )
206
+ hidden_features = (
207
+ (hidden_features + multiple_of - 1) // multiple_of * multiple_of
208
+ )
209
+ if ColumnParallelLinear is None or RowParallelLinear is None:
210
+ raise ImportError("fused_dense is not installed")
211
+ self.fc1 = ColumnParallelLinear(
212
+ in_features,
213
+ 2 * hidden_features,
214
+ process_group,
215
+ bias=bias1,
216
+ sequence_parallel=sequence_parallel,
217
+ **factory_kwargs,
218
+ )
219
+ self.activation = activation
220
+ self.fc2 = RowParallelLinear(
221
+ hidden_features,
222
+ out_features,
223
+ process_group,
224
+ bias=bias2,
225
+ sequence_parallel=sequence_parallel,
226
+ **factory_kwargs,
227
+ )
228
+
229
+ def forward(self, x):
230
+ y = self.fc1(x)
231
+ if self.activation == F.sigmoid: # Special case for GLU
232
+ y = F.glu(y, dim=-1)
233
+ else:
234
+ y, gate = y.chunk(2, dim=-1)
235
+ y = y * self.activation(gate)
236
+ y = self.fc2(y)
237
+ return y
modeling_lora.py ADDED
@@ -0,0 +1,426 @@
1
+ import math
2
+ import os
3
+ from functools import partial
4
+ from typing import Iterator, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.utils.parametrize as parametrize
9
+ from torch import nn
10
+ from torch.nn import Parameter
11
+ from torch.nn import functional as F
12
+ from transformers import PretrainedConfig
13
+
14
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
15
+ from .modeling_xlm_roberta import (
16
+ XLMRobertaFlashConfig,
17
+ XLMRobertaModel,
18
+ XLMRobertaPreTrainedModel,
19
+ )
20
+
21
+
22
+ def initialized_weights(
23
+ shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
24
+ ) -> torch.Tensor:
25
+ weight_data = []
26
+ for _ in range(num_adaptations):
27
+ new_adaption = torch.zeros(shape)
28
+ if init == "kaiming":
29
+ nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
30
+ elif init == "normal":
31
+ nn.init.normal_(new_adaption)
32
+ else:
33
+ raise NotImplementedError
34
+ weight_data.append(new_adaption)
35
+ return torch.stack(weight_data, dim=0)
36
+
37
+
38
+ class LoRAParametrization(nn.Module):
39
+ """
40
+ This LoRA implementation was inspired by https://github.com/cccntu/minLoRA
41
+ The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
42
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software
43
+ and associated documentation files (the "Software"), to deal in the Software without restriction,
44
+ including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
45
+ and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
46
+ subject to the following conditions:
47
+ The above copyright notice and this permission notice shall be included in all copies or substantial
48
+ portions of the Software.
49
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
50
+ LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
51
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
52
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
53
+ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ fan_in: int,
59
+ fan_out: int,
60
+ layer_type: str = "linear",
61
+ num_adaptations: int = 1,
62
+ rank: int = 4,
63
+ dropout_p: float = 0.0,
64
+ alpha: float = 1,
65
+ ):
66
+ super().__init__()
67
+ # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
68
+ # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
69
+ fan_in_fan_out = layer_type == "embedding"
70
+ self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
71
+
72
+ if layer_type == "linear":
73
+ self.lora_A = nn.Parameter(
74
+ initialized_weights((rank, fan_in), num_adaptations, init="kaiming")
75
+ )
76
+ self.lora_B = nn.Parameter(torch.zeros((num_adaptations, fan_out, rank)))
77
+ elif layer_type == "embedding":
78
+ self.lora_A = nn.Parameter(torch.zeros((num_adaptations, fan_in, rank)))
79
+ self.lora_B = nn.Parameter(
80
+ initialized_weights(
81
+ (rank, fan_out), num_adaptations=num_adaptations, init="normal"
82
+ )
83
+ )
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ self.lora_alpha, self.rank = alpha, rank
88
+ self.scaling = alpha / rank
89
+ self.lora_dropout = nn.Dropout(p=dropout_p) if dropout_p > 0 else lambda x: x
90
+ self.dropout_fn = self._dropout if dropout_p > 0 else lambda x: x
91
+ self.register_buffer(
92
+ "lora_dropout_mask",
93
+ torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
94
+ persistent=False,
95
+ )
96
+
97
+ def _dropout(self, A):
98
+ # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
99
+ return A * self.lora_dropout(self.lora_dropout_mask)
100
+
101
+ def lora_forward(self, X, current_task):
102
+ return (
103
+ X
104
+ + torch.matmul(
105
+ *self.swap(
106
+ (
107
+ self.lora_B[current_task],
108
+ self.dropout_fn(self.lora_A[current_task]),
109
+ )
110
+ )
111
+ ).view(X.shape)
112
+ * self.scaling
113
+ )
114
+
115
+ def forward(self, X):
116
+ return X
117
+
118
+ @classmethod
119
+ def from_linear(
120
+ cls,
121
+ layer: nn.Module,
122
+ num_adaptations: int,
123
+ rank: int,
124
+ dropout_p: float,
125
+ alpha: float,
126
+ ):
127
+ assert isinstance(layer, nn.Linear)
128
+ fan_out, fan_in = layer.weight.shape
129
+ return cls(
130
+ fan_in,
131
+ fan_out,
132
+ num_adaptations=num_adaptations,
133
+ layer_type="linear",
134
+ rank=rank,
135
+ dropout_p=dropout_p,
136
+ alpha=alpha,
137
+ )
138
+
139
+ @classmethod
140
+ def from_embedding(
141
+ cls,
142
+ layer: nn.Module,
143
+ num_adaptations: int,
144
+ rank: int,
145
+ dropout_p: float,
146
+ alpha: float,
147
+ ):
148
+ assert isinstance(layer, nn.Embedding)
149
+ fan_in, fan_out = layer.weight.shape
150
+ return cls(
151
+ fan_in,
152
+ fan_out,
153
+ num_adaptations=num_adaptations,
154
+ layer_type="embedding",
155
+ rank=rank,
156
+ dropout_p=dropout_p,
157
+ alpha=alpha,
158
+ )
159
+
160
+ @classmethod
161
+ def add_to_layer(
162
+ cls,
163
+ layer: nn.Module,
164
+ num_adaptations: int,
165
+ rank: int,
166
+ dropout_p: float,
167
+ alpha: float,
168
+ ):
169
+ """
170
+ Registering LoRA adapters to all embedding and linear layers.
171
+ Additionally, we implement a custom forward function for LoRA parametrization.
172
+ This function modifies the layer's forward pass to optionally use task-specific
173
+ parameters. When a `task_id` is provided, it employs a LoRA parametrization
174
+ to modify the original weights according to the specific task. This allows
175
+ the layer to adapt dynamically to different tasks at runtime. If no `task_id`
176
+ is specified, the layer uses its original weights.
177
+ """
178
+ if isinstance(layer, nn.Linear):
179
+ parametrize.register_parametrization(
180
+ layer,
181
+ "weight",
182
+ cls.from_linear(
183
+ layer,
184
+ num_adaptations=num_adaptations,
185
+ rank=rank,
186
+ dropout_p=dropout_p,
187
+ alpha=alpha,
188
+ ),
189
+ )
190
+
191
+ def new_forward(self, input, task_id=None, residual=False):
192
+ if task_id is not None:
193
+ weights = self.parametrizations.weight[0].lora_forward(
194
+ self.weight, current_task=task_id
195
+ )
196
+ else:
197
+ weights = self.weight
198
+
199
+ out = F.linear(input, weights, self.bias)
200
+
201
+ if residual:
202
+ return out, input
203
+ return out
204
+
205
+ layer.forward = new_forward.__get__(layer, layer.__class__)
206
+
207
+ elif isinstance(layer, nn.Embedding):
208
+ parametrize.register_parametrization(
209
+ layer,
210
+ "weight",
211
+ cls.from_embedding(
212
+ layer,
213
+ num_adaptations=num_adaptations,
214
+ rank=rank,
215
+ dropout_p=dropout_p,
216
+ alpha=alpha,
217
+ ),
218
+ )
219
+
220
+ def new_forward(self, input, task_id=None):
221
+ if task_id is not None:
222
+ weights = self.parametrizations.weight[0].lora_forward(
223
+ self.weight, current_task=task_id
224
+ )
225
+ else:
226
+ weights = self.weight
227
+
228
+ out = F.embedding(
229
+ input,
230
+ weights,
231
+ self.padding_idx,
232
+ self.max_norm,
233
+ self.norm_type,
234
+ self.scale_grad_by_freq,
235
+ self.sparse,
236
+ )
237
+
238
+ return out
239
+
240
+ layer.forward = new_forward.__get__(layer, layer.__class__)
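
A minimal usage sketch of the class defined above, applied to a standalone `nn.Linear` (the layer and sizes are made up for illustration). Because `lora_B` is zero-initialized, the adapted forward initially matches the plain forward.

```python
import torch
import torch.nn as nn

layer = nn.Linear(16, 16)
LoRAParametrization.add_to_layer(
    layer, num_adaptations=2, rank=4, dropout_p=0.0, alpha=1.0
)

x = torch.randn(3, 16)
base = layer(x)                # task_id=None -> original weights
adapted = layer(x, task_id=0)  # weights modulated by adapter 0
assert torch.allclose(base, adapted)  # holds before training, since lora_B == 0
```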
241
+
242
+
243
+ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
244
+ """
245
+ A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
246
+ """
247
+
248
+ def __init__(
249
+ self,
250
+ config: XLMRobertaFlashConfig,
251
+ roberta: Optional[XLMRobertaModel] = None,
252
+ add_pooling_layer: bool = True,
253
+ ):
254
+ super().__init__(config)
255
+ if roberta is None:
256
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
257
+ else:
258
+ self.roberta = roberta
259
+
260
+ self._lora_adaptations = config.lora_adaptations
261
+ if (
262
+ not isinstance(self._lora_adaptations, list)
263
+ or len(self._lora_adaptations) < 1
264
+ ):
265
+ raise ValueError(
266
+ f"`lora_adaptations` must be a list and contain at least one element"
267
+ )
268
+ self._task_instructions = config.task_instructions
269
+ if (
270
+ not isinstance(self._task_instructions, dict)
271
+ or len(self._task_instructions) != len(self._lora_adaptations)
272
+ or not all(
273
+ [v in self._lora_adaptations for v in self._task_instructions.keys()]
274
+ )
275
+ ):
276
+ raise ValueError(
277
+ f"`task_instructions` must be a dict and contain the same number of elements "
278
+ f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
279
+ )
280
+ self._adaptation_map = {
281
+ name: idx for idx, name in enumerate(self._lora_adaptations)
282
+ }
283
+ self._rank = config.lora_rank
284
+ self._dropout_p = config.lora_dropout_p
285
+ self._alpha = config.lora_alpha
286
+ self._register_lora(
287
+ num_adaptations=len(self._lora_adaptations),
288
+ rank=self._rank,
289
+ dropout_p=self._dropout_p,
290
+ alpha=self._alpha,
291
+ )
292
+ self.main_params_trainable = config.lora_main_params_trainable
293
+
294
+ @property
295
+ def rotary_emb_base(self):
296
+ return self.roberta.rotary_emb_base
297
+
298
+ @rotary_emb_base.setter
299
+ def rotary_emb_base(self, base):
300
+ self.roberta.rotary_emb_base = base
301
+
302
+ @property
303
+ def main_params_trainable(self):
304
+ return self._main_params_trainable
305
+
306
+ @main_params_trainable.setter
307
+ def main_params_trainable(self, val: bool):
308
+ """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
309
+ This method sets the `requires_grad_` attribute of the main weights
310
+ and controls which parameters are returned in `self.parameters()`.
311
+ :param val: Whether or not to make the parameters trainable.
312
+ :return: None
313
+ """
314
+ self._main_params_trainable = val
315
+ for name, param in super().named_parameters():
316
+ if "lora" not in name:
317
+ param.requires_grad_(val)
318
+
319
+ @classmethod
320
+ def from_pretrained(
321
+ cls,
322
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
323
+ *model_args,
324
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
325
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
326
+ ignore_mismatched_sizes: bool = False,
327
+ force_download: bool = False,
328
+ local_files_only: bool = False,
329
+ token: Optional[Union[str, bool]] = None,
330
+ revision: str = "main",
331
+ use_safetensors: bool = None,
332
+ **kwargs,
333
+ ):
334
+ for key in list(kwargs.keys()):
335
+ if key in config.to_dict():
336
+ config.update({key: kwargs.pop(key)})
337
+ if config.load_trained_adapters: # checkpoint already contains LoRA adapters
338
+ return super().from_pretrained(
339
+ pretrained_model_name_or_path,
340
+ *model_args,
341
+ config=config,
342
+ cache_dir=cache_dir,
343
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
344
+ force_download=force_download,
345
+ local_files_only=local_files_only,
346
+ token=token,
347
+ revision=revision,
348
+ use_safetensors=use_safetensors,
349
+ **kwargs,
350
+ )
351
+ else: # initializing new adapters
352
+ roberta = XLMRobertaModel.from_pretrained(
353
+ pretrained_model_name_or_path,
354
+ *model_args,
355
+ use_flash_attn=config.use_flash_attn,
356
+ **kwargs,
357
+ )
358
+ return cls(config, roberta=roberta)
359
+
360
+ def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
361
+ self.apply(
362
+ partial(
363
+ LoRAParametrization.add_to_layer,
364
+ num_adaptations=num_adaptations,
365
+ rank=rank,
366
+ dropout_p=dropout_p,
367
+ alpha=alpha,
368
+ )
369
+ )
370
+
371
+ def forward(self, *args, **kwargs):
372
+ return self.roberta(*args, **kwargs)
373
+
374
+ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
375
+ for _, param in self.named_parameters(recurse=recurse):
376
+ yield param
377
+
378
+ def named_parameters(
379
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
380
+ ) -> Iterator[Tuple[str, Parameter]]:
381
+ for name, param in super().named_parameters(
382
+ prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
383
+ ):
384
+ if "lora" in name or self.main_params_trainable:
385
+ yield name, param
386
+
387
+ @torch.inference_mode()
388
+ def encode(
389
+ self,
390
+ sentences: Union[str, List[str]],
391
+ *args,
392
+ task: Optional[str] = None,
393
+ **kwargs,
394
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
395
+ """
396
+ Computes sentence embeddings.
397
+ sentences(`str` or `List[str]`):
398
+ Sentence or sentences to be encoded
399
+ task(`str`, *optional*, defaults to `None`):
400
+ Specifies the task for which the encoding is intended. If `task` is not provided,
401
+ all LoRA adapters are disabled, and the model reverts to its original,
402
+ general-purpose weights.
403
+ """
404
+ if task and task not in self._lora_adaptations:
405
+ raise ValueError(
406
+ f"Unsupported task '{task}'. "
407
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
408
+ f"Alternatively, don't pass the `task` argument to disable LoRA."
409
+ )
410
+ adapter_mask = None
411
+ if task:
412
+ task_id = self._adaptation_map[task]
413
+ num_examples = 1 if isinstance(sentences, str) else len(sentences)
414
+ adapter_mask = torch.full(
415
+ (num_examples,), task_id, dtype=torch.int32, device=self.device
416
+ )
417
+ if isinstance(sentences, str):
418
+ sentences = self._task_instructions[task] + sentences
419
+ else:
420
+ sentences = [
421
+ self._task_instructions[task] + sentence for sentence in sentences
422
+ ]
423
+ return self.roberta.encode(
424
+ sentences, *args, adapter_mask=adapter_mask, **kwargs
425
+ )
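
A hedged end-to-end sketch of calling `encode` with a task adapter. The model id comes from the README; the task name is only an example and must be one of the `lora_adaptations` entries in the loaded checkpoint's config.

```python
from transformers import AutoModel

# Model id from the README; "retrieval.query" is an example task name and must
# appear in config.lora_adaptations of the loaded checkpoint.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
embeddings = model.encode(["How do rotary embeddings work?"], task="retrieval.query")
```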
426
+
modeling_xlm_roberta.py ADDED
@@ -0,0 +1,1254 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+
10
+ import importlib.util
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+ from transformers import AutoTokenizer, PretrainedConfig
25
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.models.bert.modeling_bert import (
28
+ BaseModelOutputWithPoolingAndCrossAttentions,
29
+ BertForPreTrainingOutput,
30
+ )
31
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
32
+
33
+ from .rotary import RotaryEmbedding
34
+ from .block import Block
35
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
36
+ from .embedding import XLMRobertaEmbeddings
37
+ from .mha import MHA
38
+ from .mlp import FusedMLP, Mlp
39
+ from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
40
+
41
+ try:
42
+ from flash_attn.ops.fused_dense import FusedDense
43
+ except ImportError:
44
+ FusedDense = None
45
+
46
+ try:
47
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
48
+ except ImportError:
49
+ layer_norm_fn = None
50
+
51
+
52
+ try:
53
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
54
+ except ImportError:
55
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
56
+
57
+ try:
58
+ from tqdm.autonotebook import trange
59
+ except ImportError:
60
+ trange = None
61
+
62
+
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
67
+ if not getattr(config, "use_flash_attn", False) or not torch.cuda.is_available():
68
+ return False
69
+ if importlib.util.find_spec("flash_attn") is None:
70
+ logger.warning(
71
+ "flash_attn is not installed. Using PyTorch native attention implementation."
72
+ )
73
+ return False
74
+ return True
75
+
76
+
77
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
78
+ use_flash_attn = get_use_flash_attn(config)
79
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
80
+ rotary_kwargs = {}
81
+ if config.position_embedding_type == "rotary":
82
+ rotary_kwargs["rotary_emb_dim"] = getattr(
83
+ config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
84
+ )
85
+ rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
86
+ rotary_kwargs["rotary_emb_scale_base"] = getattr(
87
+ config, "rotary_emb_scale_base", None
88
+ )
89
+ rotary_kwargs["rotary_emb_interleaved"] = getattr(
90
+ config, "rotary_emb_interleaved", False
91
+ )
92
+ mixer_cls = partial(
93
+ MHA,
94
+ num_heads=config.num_attention_heads,
95
+ cross_attn=cross_attn,
96
+ dropout=config.attention_probs_dropout_prob,
97
+ causal=False,
98
+ fused_bias_fc=fused_bias_fc,
99
+ use_flash_attn=use_flash_attn,
100
+ return_residual=return_residual,
101
+ use_alibi=config.position_embedding_type == "alibi",
102
+ **rotary_kwargs,
103
+ )
104
+ return mixer_cls
105
+
106
+
107
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
108
+ inner_dim = config.intermediate_size
109
+ fused_mlp = getattr(config, "fused_mlp", False)
110
+ if fused_mlp:
111
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
112
+ "fused_mlp only " "supports approximate gelu"
113
+ )
114
+ if not fused_mlp:
115
+ approximate = (
116
+ "tanh"
117
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
118
+ else "none"
119
+ )
120
+ mlp_cls = partial(
121
+ Mlp,
122
+ hidden_features=inner_dim,
123
+ activation=partial(F.gelu, approximate=approximate),
124
+ return_residual=return_residual,
125
+ )
126
+ else:
127
+ if FusedMLP is None:
128
+ raise ImportError("fused_dense is not installed")
129
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
130
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
131
+ if isinstance(mlp_checkpoint_lvl, Sequence):
132
+ assert layer_idx is not None
133
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
134
+ mlp_cls = partial(
135
+ FusedMLP,
136
+ hidden_features=inner_dim,
137
+ checkpoint_lvl=mlp_checkpoint_lvl,
138
+ return_residual=return_residual,
139
+ )
140
+ return mlp_cls
141
+
142
+
143
+ def create_block(config, layer_idx=None):
144
+ last_layer_subset = getattr(config, "last_layer_subset", False)
145
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
146
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
147
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
148
+ # one layer) so we just choose not to return residual in this case.
149
+ return_residual = not cross_attn
150
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
151
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
152
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
153
+ block = Block(
154
+ config.hidden_size,
155
+ mixer_cls,
156
+ mlp_cls,
157
+ norm_cls=norm_cls,
158
+ prenorm=False,
159
+ resid_dropout1=config.hidden_dropout_prob,
160
+ resid_dropout2=config.hidden_dropout_prob,
161
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
162
+ return_residual=return_residual,
163
+ )
164
+ return block
165
+
166
+
167
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
168
+ def _init_weights(module, initializer_range=0.02):
169
+ if isinstance(module, nn.Linear):
170
+ nn.init.normal_(module.weight, std=initializer_range)
171
+ if module.bias is not None:
172
+ nn.init.zeros_(module.bias)
173
+ elif isinstance(module, nn.Embedding):
174
+ nn.init.normal_(module.weight, std=initializer_range)
175
+ if module.padding_idx is not None:
176
+ nn.init.zeros_(module.weight[module.padding_idx])
177
+
178
+
179
+ class XLMRobertaEncoder(nn.Module):
180
+ def __init__(self, config: XLMRobertaFlashConfig):
181
+ super().__init__()
182
+ self.use_flash_attn = get_use_flash_attn(config)
183
+ self.use_reentrant = config.use_reentrant
184
+ self.layers = nn.ModuleList(
185
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
186
+ )
187
+ self._grad_checkpointing = False
188
+
189
+ @property
190
+ def gradient_checkpointing(self):
191
+ return self._grad_checkpointing
192
+
193
+ @gradient_checkpointing.setter
194
+ def gradient_checkpointing(self, value):
195
+ self._grad_checkpointing = value
196
+
197
+ def forward(
198
+ self,
199
+ hidden_states,
200
+ key_padding_mask=None,
201
+ subset_mask=None,
202
+ adapter_mask=None,
203
+ output_hidden_states: Optional[bool] = None,
204
+ ):
205
+ """If subset_mask is not None, we only want output for the subset of the sequence.
206
+ This means that we only compute the last layer output for these tokens.
207
+ subset_mask: (batch, seqlen), dtype=torch.bool
208
+ """
209
+
210
+ all_hidden_states = () if output_hidden_states else None
211
+
212
+ if output_hidden_states and subset_mask:
213
+ raise ValueError("output_hidden_states is not supported together with subset_mask")
214
+
215
+ if key_padding_mask is None or not self.use_flash_attn:
216
+ mixer_kwargs = {"adapter_mask": adapter_mask}
217
+ if key_padding_mask is not None:
218
+ mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
219
+ for layer in self.layers:
220
+ if output_hidden_states:
221
+ all_hidden_states = all_hidden_states + (hidden_states,)
222
+ if self._grad_checkpointing:
223
+ hidden_states = torch.utils.checkpoint.checkpoint(
224
+ layer,
225
+ hidden_states,
226
+ use_reentrant=self.use_reentrant,
227
+ mixer_kwargs=mixer_kwargs,
228
+ )
229
+ else:
230
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
231
+ if output_hidden_states:
232
+ all_hidden_states = all_hidden_states + (hidden_states,)
233
+ if subset_mask is not None:
234
+ hidden_states = hidden_states[subset_mask]
235
+ else:
236
+ batch, seqlen = hidden_states.shape[:2]
237
+ if output_hidden_states:
238
+ all_hidden_states = all_hidden_states + (hidden_states,)
239
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
240
+ unpad_input(hidden_states, key_padding_mask, adapter_mask)
241
+ )
242
+ mixer_kwargs = {
243
+ "cu_seqlens": cu_seqlens,
244
+ "max_seqlen": max_seqlen_in_batch,
245
+ "adapter_mask": cu_adapter_mask,
246
+ }
247
+
248
+ if subset_mask is None:
249
+ for layer in self.layers:
250
+ if self._grad_checkpointing:
251
+ hidden_states = torch.utils.checkpoint.checkpoint(
252
+ layer,
253
+ hidden_states,
254
+ use_reentrant=self.use_reentrant,
255
+ mixer_kwargs=mixer_kwargs,
256
+ )
257
+ else:
258
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
259
+ if output_hidden_states:
260
+ all_hidden_states = all_hidden_states + (
261
+ pad_input(hidden_states, indices, batch, seqlen),
262
+ )
263
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
264
+ else:
265
+ for layer in self.layers[:-1]:
266
+ if self._grad_checkpointing:
267
+ hidden_states = torch.utils.checkpoint.checkpoint(
268
+ layer,
269
+ hidden_states,
270
+ use_reentrant=self.use_reentrant,
271
+ mixer_kwargs=mixer_kwargs,
272
+ )
273
+ else:
274
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
275
+ if key_padding_mask is not None:
276
+ subset_idx = torch.nonzero(
277
+ subset_mask[key_padding_mask], as_tuple=False
278
+ ).flatten()
279
+ subset_seqlens = (subset_mask & key_padding_mask).sum(
280
+ dim=-1, dtype=torch.int32
281
+ )
282
+ subset_cu_seqlens = F.pad(
283
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
284
+ (1, 0),
285
+ )
286
+ else:
287
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
288
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
289
+ subset_cu_seqlens = F.pad(
290
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
291
+ (1, 0),
292
+ )
293
+ hidden_states_subset, hidden_states = index_first_axis_residual(
294
+ hidden_states, subset_idx
295
+ )
296
+ # It's ok to set max_seqlen_q to be much larger
297
+ mixer_kwargs = {
298
+ "x_kv": hidden_states,
299
+ "cu_seqlens": subset_cu_seqlens,
300
+ "max_seqlen": max_seqlen_in_batch,
301
+ "cu_seqlens_k": cu_seqlens,
302
+ "max_seqlen_k": max_seqlen_in_batch,
303
+ }
304
+ if self._grad_checkpointing:
305
+ hidden_states = torch.utils.checkpoint.checkpoint(
306
+ self.layers[-1],
307
+ hidden_states_subset,
308
+ use_reentrant=self.use_reentrant,
309
+ mixer_kwargs=mixer_kwargs,
310
+ )
311
+ else:
312
+ hidden_states = self.layers[-1](
313
+ hidden_states_subset, mixer_kwargs=mixer_kwargs
314
+ )
315
+ return all_hidden_states if output_hidden_states else hidden_states
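
For readers unfamiliar with the unpadded code path, here is a small plain-PyTorch sketch of what `unpad_input` / `pad_input` and `cu_seqlens` represent (toy tensors, not the actual helpers from `xlm_padding.py`):

```python
import torch
import torch.nn.functional as F

hidden = torch.randn(2, 4, 8)                              # (batch, seqlen, hidden)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool()   # key_padding_mask

indices = mask.flatten().nonzero(as_tuple=True)[0]         # positions of real tokens
unpadded = hidden.reshape(-1, 8)[indices]                  # (total_tokens, hidden) == (5, 8)

seqlens = mask.sum(-1, dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 3, 5]): attention kernels use it to find sequence boundaries.

repadded = torch.zeros(2 * 4, 8)
repadded[indices] = unpadded                               # inverse of the gather above
repadded = repadded.reshape(2, 4, 8)
```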
316
+
317
+
318
+ class XLMRobertaPooler(nn.Module):
319
+ def __init__(self, config):
320
+ super().__init__()
321
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
322
+ if fused_bias_fc and FusedDense is None:
323
+ raise ImportError("fused_dense is not installed")
324
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
325
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
326
+ self.activation = nn.Tanh()
327
+
328
+ def forward(self, hidden_states, pool=True, adapter_mask=None):
329
+ # We "pool" the model by simply taking the hidden state corresponding
330
+ # to the first token.
331
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
332
+ if adapter_mask is not None:
333
+ unique_tasks = torch.unique(adapter_mask)
334
+ pool_dtype = next(self.dense.parameters()).dtype
335
+ pooled_output = torch.empty(
336
+ first_token_tensor.shape[0],
337
+ self.dense.out_features,
338
+ dtype=pool_dtype,
339
+ device=first_token_tensor.device,
340
+ )
341
+ for task_id in unique_tasks:
342
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
343
+ task_first_token_tensor = first_token_tensor[task_indices]
344
+ task_pooled_output = self.dense(
345
+ task_first_token_tensor, task_id=task_id
346
+ )
347
+ pooled_output[task_indices] = task_pooled_output
348
+ else:
349
+ pooled_output = self.dense(first_token_tensor)
350
+ pooled_output = self.activation(pooled_output)
351
+ return pooled_output
352
+
353
+
354
+ class XLMRobertaPredictionHeadTransform(nn.Module):
355
+ def __init__(self, config):
356
+ super().__init__()
357
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
358
+ if fused_bias_fc and FusedDense is None:
359
+ raise ImportError("fused_dense is not installed")
360
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
361
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
362
+ raise ImportError("Triton is not installed")
363
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
364
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
365
+ approximate = (
366
+ "tanh"
367
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
368
+ else "none"
369
+ )
370
+ self.transform_act_fn = nn.GELU(approximate=approximate)
371
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
372
+
373
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
374
+ hidden_states = self.dense(hidden_states)
375
+ hidden_states = self.transform_act_fn(hidden_states)
376
+ if not self.fused_dropout_add_ln:
377
+ hidden_states = self.layer_norm(hidden_states)
378
+ else:
379
+ hidden_states = layer_norm_fn(
380
+ hidden_states,
381
+ self.layer_norm.weight,
382
+ self.layer_norm.bias,
383
+ eps=self.layer_norm.eps,
384
+ )
385
+ return hidden_states
386
+
387
+
388
+ class XLMRobertaLMPredictionHead(nn.Module):
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
392
+ if fused_bias_fc and FusedDense is None:
393
+ raise ImportError("fused_dense is not installed")
394
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
395
+
396
+ self.transform = XLMRobertaPredictionHeadTransform(config)
397
+
398
+ # The output weights are the same as the input embeddings, but there is
399
+ # an output-only bias for each token.
400
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
401
+
402
+ def forward(self, hidden_states):
403
+ hidden_states = self.transform(hidden_states)
404
+ hidden_states = self.decoder(hidden_states)
405
+ return hidden_states
406
+
407
+
408
+ class XLMRobertaPreTrainingHeads(nn.Module):
409
+ def __init__(self, config):
410
+ super().__init__()
411
+ self.predictions = XLMRobertaLMPredictionHead(config)
412
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
413
+
414
+ def forward(self, sequence_output, pooled_output):
415
+ prediction_scores = self.predictions(sequence_output)
416
+ seq_relationship_score = self.seq_relationship(pooled_output)
417
+ return prediction_scores, seq_relationship_score
418
+
419
+
420
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
421
+ """An abstract class to handle weights initialization and
422
+ a simple interface for downloading and loading pretrained models.
423
+ """
424
+
425
+ config_class = XLMRobertaFlashConfig
426
+ base_model_prefix = "roberta"
427
+ supports_gradient_checkpointing = True
428
+ _supports_param_buffer_assignment = False
429
+
430
+ def _set_gradient_checkpointing(self, module, value=False):
431
+ if isinstance(module, XLMRobertaEncoder):
432
+ module.gradient_checkpointing = value
433
+
434
+ @classmethod
435
+ def from_pretrained(
436
+ cls,
437
+ *args,
438
+ **kwargs,
439
+ ):
440
+ if not "torch_dtype" in kwargs:
441
+ kwargs["torch_dtype"] = "auto"
442
+ return super().from_pretrained(*args, **kwargs)
443
+
444
+
445
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
446
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
447
+ super().__init__(config)
448
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
449
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
450
+ config.vocab_size += self.pad_vocab_size_multiple - (
451
+ config.vocab_size % self.pad_vocab_size_multiple
452
+ )
453
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
454
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
455
+ raise ImportError("Triton is not installed")
456
+ assert config.hidden_act in [
457
+ "gelu",
458
+ "gelu_new",
459
+ "gelu_fast",
460
+ "gelu_pytorch_tanh",
461
+ ]
462
+ self.embeddings = XLMRobertaEmbeddings(
463
+ config.hidden_size,
464
+ config.vocab_size,
465
+ (
466
+ config.max_position_embeddings
467
+ if config.position_embedding_type == "absolute"
468
+ else -1
469
+ ),
470
+ config.type_vocab_size,
471
+ padding_idx=config.pad_token_id,
472
+ )
473
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
474
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
475
+ self.encoder = XLMRobertaEncoder(config)
476
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
477
+
478
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
479
+ self.tokenizer = AutoTokenizer.from_pretrained(
480
+ self.name_or_path, trust_remote_code=True
481
+ )
482
+ self._rotary_emb_base = config.rotary_emb_base
483
+
484
+ @torch.inference_mode()
485
+ def encode(
486
+ self: "XLMRobertaModel",
487
+ sentences: Union[str, List[str]],
488
+ batch_size: int = 32,
489
+ show_progress_bar: Optional[bool] = None,
490
+ output_value: str = "sentence_embedding",
491
+ convert_to_numpy: bool = True,
492
+ convert_to_tensor: bool = False,
493
+ device: Optional[torch.device] = None,
494
+ normalize_embeddings: bool = True,
495
+ truncate_dim: Optional[int] = None,
496
+ adapter_mask: Optional[torch.Tensor] = None,
497
+ task: Optional[str] = None,
498
+ **tokenizer_kwargs,
499
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
500
+ """
501
+ Computes sentence embeddings
502
+ Args:
503
+ sentences(`str` or `List[str]`):
504
+ Sentence or sentences to be encoded
505
+ batch_size(`int`, *optional*, defaults to 32):
506
+ Batch size for the computation
507
+ show_progress_bar(`bool`, *optional*, defaults to None):
508
+ Show a progress bar when encoding sentences.
509
+ If set to None, progress bar is only shown when
510
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
511
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
512
+ Default sentence_embedding, to get sentence embeddings.
513
+ Can be set to token_embeddings to get wordpiece token embeddings.
514
+ Set to None, to get all output values
515
+ convert_to_numpy(`bool`, *optional*, defaults to True):
516
+ If true, the output is a list of numpy vectors.
517
+ Else, it is a list of pytorch tensors.
518
+ convert_to_tensor(`bool`, *optional*, defaults to False):
519
+ If true, you get one large tensor as return.
520
+ Overwrites any setting from convert_to_numpy
521
+ device(`torch.device`, *optional*, defaults to None):
522
+ Which torch.device to use for the computation
523
+ normalize_embeddings(`bool`, *optional*, defaults to True):
524
+ If set to true, returned vectors will have length 1. In that case, the
525
+ faster dot-product (util.dot_score) instead of cosine similarity can
526
+ be used.
527
+ truncate_dim(`int`, *optional*, defaults to None):
528
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
529
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
530
+ Keyword arguments for the tokenizer
531
+ Returns:
532
+ By default, a list of tensors is returned.
533
+ If convert_to_tensor, a stacked tensor is returned.
534
+ If convert_to_numpy, a numpy matrix is returned.
535
+ """
536
+ is_training = self.training
537
+ self.eval()
538
+
539
+ if show_progress_bar is None:
540
+ show_progress_bar = (
541
+ logger.getEffectiveLevel() == logging.INFO
542
+ or logger.getEffectiveLevel() == logging.DEBUG
543
+ )
544
+
545
+ if convert_to_tensor:
546
+ convert_to_numpy = False
547
+
548
+ if output_value != "sentence_embedding":
549
+ convert_to_tensor = False
550
+ convert_to_numpy = False
551
+
552
+ input_was_string = False
553
+ if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
554
+ sentences = [sentences]
555
+ input_was_string = True
556
+
557
+ if device is not None:
558
+ self.to(device)
559
+
560
+ permutation = np.argsort([-len(i) for i in sentences])
561
+ inverse_permutation = np.argsort(permutation)
562
+ sentences = [sentences[idx] for idx in permutation]
563
+
564
+ tokenizer_kwargs["padding"] = tokenizer_kwargs.get("padding", True)
565
+ tokenizer_kwargs["max_length"] = tokenizer_kwargs.get(
566
+ "max_length", self.tokenizer.init_kwargs.get("model_max_length", 8192)
567
+ )
568
+ tokenizer_kwargs["truncation"] = tokenizer_kwargs.get("truncation", True)
569
+
570
+ all_embeddings = []
571
+
572
+ if trange is not None:
573
+ range_iter = trange(
574
+ 0,
575
+ len(sentences),
576
+ batch_size,
577
+ desc="Encoding",
578
+ disable=not show_progress_bar,
579
+ )
580
+ else:
581
+ range_iter = range(0, len(sentences), batch_size)
582
+
583
+ for i in range_iter:
584
+ encoded_input = self.tokenizer(
585
+ sentences[i : i + batch_size],
586
+ return_tensors="pt",
587
+ **tokenizer_kwargs,
588
+ ).to(self.device)
589
+ lora_arguments = (
590
+ {"adapter_mask": adapter_mask[i : i + batch_size]}
591
+ if adapter_mask is not None
592
+ else {}
593
+ )
594
+ token_embs = self.forward(**encoded_input, **lora_arguments)[0]
595
+
596
+ # Accumulate in fp32 to avoid overflow
597
+ token_embs = token_embs.float()
598
+
599
+ if output_value == "token_embeddings":
600
+ raise NotImplementedError
601
+ elif output_value is None:
602
+ raise NotImplementedError
603
+ else:
604
+ if self.config.emb_pooler == "cls":
605
+ embeddings = self.cls_pooling(
606
+ token_embs, encoded_input["attention_mask"]
607
+ )
608
+ else:
609
+ embeddings = self.mean_pooling(
610
+ token_embs, encoded_input["attention_mask"]
611
+ )
612
+
613
+ all_embeddings.extend(embeddings)
614
+
615
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
616
+
617
+ truncate_dim = truncate_dim or self.config.truncate_dim
618
+ if truncate_dim:
619
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
620
+
621
+ if normalize_embeddings:
622
+ all_embeddings = [
623
+ torch.nn.functional.normalize(embedding, p=2, dim=0)
624
+ for embedding in all_embeddings
625
+ ]
626
+
627
+ if convert_to_tensor:
628
+ all_embeddings = torch.stack(all_embeddings)
629
+ elif convert_to_numpy:
630
+ all_embeddings = np.asarray([emb.cpu().numpy() for emb in all_embeddings])
631
+
632
+ if input_was_string:
633
+ all_embeddings = all_embeddings[0]
634
+
635
+ self.train(is_training)
636
+ return all_embeddings
637
+
638
+ def truncate_embeddings(self, embeddings, truncate_dim):
639
+ if not self.config.matryoshka_dimensions:
640
+ logger.warning(
641
+ "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
642
+ )
643
+ return embeddings
644
+ elif truncate_dim in self.config.matryoshka_dimensions:
645
+ return [tensor[:truncate_dim] for tensor in embeddings]
646
+ else:
647
+ raise ValueError(
648
+ f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
649
+ f"Supported dimensions are {self.config.matryoshka_dimensions}."
650
+ )
651
+
652
+ def mean_pooling(
653
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
654
+ ):
655
+ input_mask_expanded = (
656
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
657
+ )
658
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
659
+ input_mask_expanded.sum(1), min=1e-9
660
+ )
661
+
662
+ def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
663
+ return token_embeddings[:, 0]
664
+
665
+ @property
666
+ def rotary_emb_base(self):
667
+ return self._rotary_emb_base
668
+
669
+ @rotary_emb_base.setter
670
+ def rotary_emb_base(self, base):
671
+ if not isinstance(base, (int, float)):
672
+ raise TypeError("Base must be an integer or float")
673
+ logger.info(f"Changing RoPE base value to {base}")
674
+ for layer in self.encoder.layers:
675
+ layer.mixer.rotary_emb.base = base
676
+ self._rotary_emb_base = base
677
+
678
+ def forward(
679
+ self,
680
+ input_ids,
681
+ position_ids=None,
682
+ token_type_ids=None,
683
+ attention_mask=None,
684
+ masked_tokens_mask=None,
685
+ return_dict=None,
686
+ output_hidden_states=None,
687
+ **kwargs,
688
+ ):
689
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
690
+ we only want the output for the masked tokens. This means that we only compute the last
691
+ layer output for these tokens.
692
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
693
+ """
694
+ adapter_mask = kwargs.pop("adapter_mask", None)
695
+ if kwargs:
696
+ for key, value in kwargs.items():
697
+ if value is not None:
698
+ logger.warning(
699
+ "Flash attention implementation does not support kwargs: %s",
700
+ key,
701
+ )
702
+
703
+ return_dict = (
704
+ return_dict if return_dict is not None else self.config.use_return_dict
705
+ )
706
+
707
+ hidden_states = self.embeddings(
708
+ input_ids,
709
+ position_ids=position_ids,
710
+ token_type_ids=token_type_ids,
711
+ adapter_mask=adapter_mask,
712
+ )
713
+ # TD [2022-12-18]: Don't need to force residual in fp32
714
+ # BERT puts embedding LayerNorm before embedding dropout.
715
+ if not self.fused_dropout_add_ln:
716
+ hidden_states = self.emb_ln(hidden_states)
717
+ else:
718
+ hidden_states = layer_norm_fn(
719
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
720
+ )
721
+ hidden_states = self.emb_drop(hidden_states)
722
+
723
+ if masked_tokens_mask is not None:
724
+ batch_size, seqlen = input_ids.shape[:2]
725
+ # We also need the first column for the CLS token
726
+ first_col_mask = torch.zeros(
727
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
728
+ )
729
+ first_col_mask[:, 0] = True
730
+ subset_mask = masked_tokens_mask | first_col_mask
731
+ else:
732
+ subset_mask = None
733
+
734
+ sequence_output = self.encoder(
735
+ hidden_states,
736
+ key_padding_mask=attention_mask,
737
+ subset_mask=subset_mask,
738
+ adapter_mask=adapter_mask,
739
+ output_hidden_states=output_hidden_states,
740
+ )
741
+
742
+ if output_hidden_states:
743
+ all_hidden_states = sequence_output
744
+ sequence_output = sequence_output[-1]
745
+ else:
746
+ all_hidden_states = None
747
+
748
+ if masked_tokens_mask is None:
749
+ pooled_output = (
750
+ self.pooler(sequence_output, adapter_mask=adapter_mask)
751
+ if self.pooler is not None
752
+ else None
753
+ )
754
+ else:
755
+ # TD [2022-03-01]: the indexing here is very tricky.
756
+ if attention_mask is not None:
757
+ subset_idx = subset_mask[attention_mask]
758
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
759
+ sequence_output = sequence_output[
760
+ masked_tokens_mask[attention_mask][subset_idx]
761
+ ]
762
+ else:
763
+ pool_input = sequence_output[first_col_mask[subset_mask]]
764
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
765
+ pooled_output = (
766
+ self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
767
+ if self.pooler is not None
768
+ else None
769
+ )
770
+
771
+ if not return_dict:
772
+ return sequence_output, pooled_output
773
+
774
+ return BaseModelOutputWithPoolingAndCrossAttentions(
775
+ last_hidden_state=sequence_output,
776
+ pooler_output=pooled_output,
777
+ hidden_states=all_hidden_states,
778
+ )
779
+
780
+
781
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
782
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
783
+
784
+ def __init__(self, config):
785
+ super().__init__(config)
786
+
787
+ if config.is_decoder:
788
+ logger.warning(
789
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
790
+ "bi-directional self-attention."
791
+ )
792
+
793
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
794
+ self.lm_head = XLMRobertaLMHead(config)
795
+
796
+ # Initialize weights and apply final processing
797
+ self.post_init()
798
+
799
+ def get_input_embeddings(self):
800
+ return self.roberta.embeddings.word_embeddings
801
+
802
+ def get_output_embeddings(self):
803
+ return self.lm_head.decoder
804
+
805
+ def set_output_embeddings(self, new_embeddings):
806
+ self.lm_head.decoder = new_embeddings
807
+
808
+ def forward(
809
+ self,
810
+ input_ids: Optional[torch.LongTensor] = None,
811
+ attention_mask: Optional[torch.FloatTensor] = None,
812
+ token_type_ids: Optional[torch.LongTensor] = None,
813
+ position_ids: Optional[torch.LongTensor] = None,
814
+ head_mask: Optional[torch.FloatTensor] = None,
815
+ inputs_embeds: Optional[torch.FloatTensor] = None,
816
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
817
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
818
+ labels: Optional[torch.LongTensor] = None,
819
+ output_attentions: Optional[bool] = None,
820
+ output_hidden_states: Optional[bool] = None,
821
+ return_dict: Optional[bool] = None,
822
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
823
+ r"""
824
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
825
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
826
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
827
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
828
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
829
+ Used to hide legacy arguments that have been deprecated.
830
+ """
831
+ return_dict = (
832
+ return_dict if return_dict is not None else self.config.use_return_dict
833
+ )
834
+
835
+ outputs = self.roberta(
836
+ input_ids,
837
+ attention_mask=attention_mask,
838
+ token_type_ids=token_type_ids,
839
+ position_ids=position_ids,
840
+ head_mask=head_mask,
841
+ inputs_embeds=inputs_embeds,
842
+ encoder_hidden_states=encoder_hidden_states,
843
+ encoder_attention_mask=encoder_attention_mask,
844
+ output_attentions=output_attentions,
845
+ output_hidden_states=output_hidden_states,
846
+ return_dict=return_dict,
847
+ )
848
+ sequence_output = outputs[0]
849
+ prediction_scores = self.lm_head(sequence_output)
850
+
851
+ masked_lm_loss = None
852
+ if labels is not None:
853
+ # move labels to correct device to enable model parallelism
854
+ labels = labels.to(prediction_scores.device)
855
+ loss_fct = CrossEntropyLoss()
856
+ masked_lm_loss = loss_fct(
857
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
858
+ )
859
+
860
+ if not return_dict:
861
+ output = (prediction_scores,) + outputs[2:]
862
+ return (
863
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
864
+ )
865
+
866
+ return MaskedLMOutput(
867
+ loss=masked_lm_loss,
868
+ logits=prediction_scores,
869
+ hidden_states=outputs.hidden_states,
870
+ attentions=outputs.attentions,
871
+ )
872
+
873
+
874
+ def remap_state_dict(state_dict, config: PretrainedConfig):
875
+ """
876
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
877
+ """
878
+
879
+ # LayerNorm
880
+ def key_mapping_ln_gamma_beta(key):
881
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
882
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
883
+ return key
884
+
885
+ state_dict = OrderedDict(
886
+ (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
887
+ )
888
+
889
+ # Layers
890
+ def key_mapping_layers(key):
891
+ return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
892
+
893
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
894
+
895
+ # LayerNorm
896
+ def key_mapping_ln(key):
897
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
898
+ key = re.sub(
899
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
900
+ r"bert.encoder.layers.\1.norm1.\2",
901
+ key,
902
+ )
903
+ key = re.sub(
904
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
905
+ r"bert.encoder.layers.\1.norm2.\2",
906
+ key,
907
+ )
908
+ key = re.sub(
909
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
910
+ r"cls.predictions.transform.layer_norm.\1",
911
+ key,
912
+ )
913
+ return key
914
+
915
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
916
+
917
+ # MLP
918
+ def key_mapping_mlp(key):
919
+ key = re.sub(
920
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
921
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
922
+ key,
923
+ )
924
+ key = re.sub(
925
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
926
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
927
+ key,
928
+ )
929
+ return key
930
+
931
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
932
+
933
+ # Attention
934
+ last_layer_subset = getattr(config, "last_layer_subset", False)
935
+ for d in range(config.num_hidden_layers):
936
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
937
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
938
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
939
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
940
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
941
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
942
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
943
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
944
+ [Wq, Wk, Wv], dim=0
945
+ )
946
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
947
+ [bq, bk, bv], dim=0
948
+ )
949
+ else:
950
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
951
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
952
+ [Wk, Wv], dim=0
953
+ )
954
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
955
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
956
+ [bk, bv], dim=0
957
+ )
958
+
959
+ def key_mapping_attn(key):
960
+ return re.sub(
961
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
962
+ r"bert.encoder.layers.\1.mixer.out_proj.\2",
963
+ key,
964
+ )
965
+
966
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
967
+
968
+ def key_mapping_decoder_bias(key):
969
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
970
+
971
+ state_dict = OrderedDict(
972
+ (key_mapping_decoder_bias(k), v) for k, v in state_dict.items()
973
+ )
974
+
975
+ # Word embedding
976
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
977
+ if pad_vocab_size_multiple > 1:
978
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
979
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
980
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
981
+ )
982
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
983
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
984
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
985
+ )
986
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
987
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
988
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
989
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
990
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
991
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
992
+ )
993
+
994
+ return state_dict
995
+
996
+
997
+ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
998
+ """
999
+ Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
1000
+
1001
+ This function is meant to be the inverse of remap_state_dict.
1002
+ """
1003
+ # Word embedding
1004
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1005
+ if pad_vocab_size_multiple > 1:
1006
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
1007
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
1008
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
1009
+ # unpad embeddings
1010
+ state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
1011
+ : config.orig_vocab_size, :
1012
+ ]
1013
+ state_dict["cls.predictions.decoder.weight"] = decoder_weight[
1014
+ : config.orig_vocab_size, :
1015
+ ]
1016
+ state_dict["cls.predictions.decoder.bias"] = decoder_bias[
1017
+ : config.orig_vocab_size
1018
+ ]
1019
+
1020
+ for d in range(config.num_hidden_layers):
1021
+ last_layer_subset = getattr(config, "last_layer_subset", False)
1022
+ if not last_layer_subset or d != (config.num_hidden_layers - 1):
1023
+ Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
1024
+ Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
1025
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1026
+ Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
1027
+ )
1028
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1029
+ Wqkv_weights[
1030
+ Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
1031
+ ]
1032
+ )
1033
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1034
+ Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
1035
+ )
1036
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
1037
+ Wqkv_biases[: Wqkv_biases.shape[0] // 3]
1038
+ )
1039
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
1040
+ Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
1041
+ )
1042
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1043
+ Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
1044
+ )
1045
+ else:
1046
+ Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1047
+ Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1048
+ Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1049
+ Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1050
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1051
+ Wq_weight
1052
+ )
1053
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1054
+ Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1055
+ )
1056
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1057
+ Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1058
+ )
1059
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1060
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1061
+ : Wkv_biases.shape[0] // 2
1062
+ ]
1063
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1064
+ Wkv_biases[Wkv_biases.shape[0] // 2 :]
1065
+ )
1066
+
1067
+ def inv_key_mapping_ln(key):
1068
+ key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
1069
+ key = re.sub(
1070
+ r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
1071
+ r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
1072
+ key,
1073
+ )
1074
+ key = re.sub(
1075
+ r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
1076
+ r"bert.encoder.layers.\1.output.LayerNorm.\2",
1077
+ key,
1078
+ )
1079
+ key = re.sub(
1080
+ r"cls.predictions.transform.layer_norm.(weight|bias)",
1081
+ r"cls.predictions.transform.LayerNorm.\1",
1082
+ key,
1083
+ )
1084
+ return key
1085
+
1086
+ def inv_key_mapping_ln_gamma_beta(key):
1087
+ key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
1088
+ key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
1089
+ return key
1090
+
1091
+ def inv_key_mapping_layers(key):
1092
+ return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
1093
+
1094
+ def inv_key_mapping_mlp(key):
1095
+ key = re.sub(
1096
+ r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
1097
+ r"bert.encoder.layer.\1.intermediate.dense.\2",
1098
+ key,
1099
+ )
1100
+ key = re.sub(
1101
+ r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
1102
+ r"bert.encoder.layer.\1.output.dense.\2",
1103
+ key,
1104
+ )
1105
+ return key
1106
+
1107
+ def inv_key_mapping_attn(key):
1108
+ return re.sub(
1109
+ r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
1110
+ r"bert.encoder.layer.\1.attention.output.dense.\2",
1111
+ key,
1112
+ )
1113
+
1114
+ def inv_key_mapping_decoder_bias(key):
1115
+ return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
1116
+
1117
+ state_dict = OrderedDict(
1118
+ (inv_key_mapping_ln(key), value) for key, value in state_dict.items()
1119
+ )
1120
+ state_dict = OrderedDict(
1121
+ (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
1122
+ )
1123
+ state_dict = OrderedDict(
1124
+ (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
1125
+ )
1126
+ state_dict = OrderedDict(
1127
+ (inv_key_mapping_mlp(key), value) for key, value in state_dict.items()
1128
+ )
1129
+ state_dict = OrderedDict(
1130
+ (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
1131
+ )
1132
+ state_dict = OrderedDict(
1133
+ (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
1134
+ )
1135
+
1136
+ return state_dict
1137
+
1138
+
1139
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
1140
+ class XLMRobertaClassificationHead(nn.Module):
1141
+ """Head for sentence-level classification tasks."""
1142
+
1143
+ def __init__(self, config):
1144
+ super().__init__()
1145
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
1146
+ if fused_bias_fc and FusedDense is None:
1147
+ raise ImportError("fused_dense is not installed")
1148
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
1149
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
1150
+ classifier_dropout = (
1151
+ config.classifier_dropout
1152
+ if config.classifier_dropout is not None
1153
+ else config.hidden_dropout_prob
1154
+ )
1155
+ self.dropout = nn.Dropout(classifier_dropout)
1156
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
1157
+
1158
+ def forward(self, features, **kwargs):
1159
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1160
+ x = self.dropout(x)
1161
+ x = self.dense(x)
1162
+ x = torch.tanh(x)
1163
+ x = self.dropout(x)
1164
+ x = self.out_proj(x)
1165
+ return x
1166
+
1167
+
1168
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
1169
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
1170
+ def __init__(self, config):
1171
+ super().__init__(config)
1172
+ self.num_labels = config.num_labels
1173
+ self.config = config
1174
+
1175
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
1176
+ self.classifier = XLMRobertaClassificationHead(config)
1177
+
1178
+ # Initialize weights and apply final processing
1179
+ self.post_init()
1180
+
1181
+ def forward(
1182
+ self,
1183
+ input_ids: Optional[torch.LongTensor] = None,
1184
+ attention_mask: Optional[torch.FloatTensor] = None,
1185
+ token_type_ids: Optional[torch.LongTensor] = None,
1186
+ position_ids: Optional[torch.LongTensor] = None,
1187
+ head_mask: Optional[torch.FloatTensor] = None,
1188
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1189
+ labels: Optional[torch.LongTensor] = None,
1190
+ output_attentions: Optional[bool] = None,
1191
+ output_hidden_states: Optional[bool] = None,
1192
+ return_dict: Optional[bool] = None,
1193
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1194
+ r"""
1195
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1196
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1197
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1198
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1199
+ """
1200
+ return_dict = (
1201
+ return_dict if return_dict is not None else self.config.use_return_dict
1202
+ )
1203
+
1204
+ outputs = self.roberta(
1205
+ input_ids,
1206
+ attention_mask=attention_mask,
1207
+ token_type_ids=token_type_ids,
1208
+ position_ids=position_ids,
1209
+ head_mask=head_mask,
1210
+ inputs_embeds=inputs_embeds,
1211
+ output_attentions=output_attentions,
1212
+ output_hidden_states=output_hidden_states,
1213
+ return_dict=return_dict,
1214
+ )
1215
+ sequence_output = outputs[0]
1216
+ logits = self.classifier(sequence_output)
1217
+
1218
+ loss = None
1219
+ if labels is not None:
1220
+ # move labels to correct device to enable model parallelism
1221
+ labels = labels.to(logits.device)
1222
+ if self.config.problem_type is None:
1223
+ if self.num_labels == 1:
1224
+ self.config.problem_type = "regression"
1225
+ elif self.num_labels > 1 and (
1226
+ labels.dtype == torch.long or labels.dtype == torch.int
1227
+ ):
1228
+ self.config.problem_type = "single_label_classification"
1229
+ else:
1230
+ self.config.problem_type = "multi_label_classification"
1231
+
1232
+ if self.config.problem_type == "regression":
1233
+ loss_fct = MSELoss()
1234
+ if self.num_labels == 1:
1235
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1236
+ else:
1237
+ loss = loss_fct(logits, labels)
1238
+ elif self.config.problem_type == "single_label_classification":
1239
+ loss_fct = CrossEntropyLoss()
1240
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1241
+ elif self.config.problem_type == "multi_label_classification":
1242
+ loss_fct = BCEWithLogitsLoss()
1243
+ loss = loss_fct(logits, labels)
1244
+
1245
+ if not return_dict:
1246
+ output = (logits,) + outputs[2:]
1247
+ return ((loss,) + output) if loss is not None else output
1248
+
1249
+ return SequenceClassifierOutput(
1250
+ loss=loss,
1251
+ logits=logits,
1252
+ hidden_states=outputs.hidden_states,
1253
+ attentions=outputs.attentions,
1254
+ )
rotary.py ADDED
@@ -0,0 +1,659 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
2
+ # Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
3
+ # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
+
5
+ # Copyright (c) 2023, Tri Dao.
6
+
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from einops import rearrange, repeat
11
+
12
+ try:
+     if not torch.cuda.is_available():
+         # The fused Triton rotary kernel requires CUDA; fall back to the stub below on CPU-only setups.
+         raise ImportError("CUDA is not available")
+     from flash_attn.ops.triton.rotary import apply_rotary
+ except ImportError:
+
+     def apply_rotary(*args, **kwargs):
+         raise RuntimeError(
+             "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
+             "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
+         )
22
+
23
+
24
+ def rotate_half(x, interleaved=False):
25
+ if not interleaved:
26
+ x1, x2 = x.chunk(2, dim=-1)
27
+ return torch.cat((-x2, x1), dim=-1)
28
+ else:
29
+ x1, x2 = x[..., ::2], x[..., 1::2]
30
+ return rearrange(
31
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
32
+ )
33
+
34
+
35
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
36
+ """
37
+ x: (batch_size, seqlen, nheads, headdim)
38
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
39
+ """
40
+ ro_dim = cos.shape[-1] * 2
41
+ assert ro_dim <= x.shape[-1]
42
+ cos, sin = (
43
+ cos[: x.shape[1]],
44
+ sin[: x.shape[1]],
45
+ )
46
+ cos = repeat(
47
+ cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
48
+ )
49
+ sin = repeat(
50
+ sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
51
+ )
52
+ return torch.cat(
53
+ [
54
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
55
+ x[..., ro_dim:],
56
+ ],
57
+ dim=-1,
58
+ )
59
+
60
+
61
+ class ApplyRotaryEmb(torch.autograd.Function):
62
+ @staticmethod
63
+ def forward(
64
+ ctx,
65
+ x,
66
+ cos,
67
+ sin,
68
+ interleaved=False,
69
+ inplace=False,
70
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
71
+ cu_seqlens: Optional[torch.Tensor] = None,
72
+ max_seqlen: Optional[int] = None,
73
+ ):
74
+ out = apply_rotary(
75
+ x,
76
+ cos,
77
+ sin,
78
+ seqlen_offsets=seqlen_offsets,
79
+ cu_seqlens=cu_seqlens,
80
+ max_seqlen=max_seqlen,
81
+ interleaved=interleaved,
82
+ inplace=inplace,
83
+ )
84
+
85
+ if isinstance(seqlen_offsets, int):
86
+ ctx.save_for_backward(
87
+ cos, sin, cu_seqlens
88
+ ) # Can't save int with save_for_backward
89
+ ctx.seqlen_offsets = seqlen_offsets
90
+ else:
91
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
92
+ ctx.seqlen_offsets = None
93
+ ctx.interleaved = interleaved
94
+ ctx.inplace = inplace
95
+ ctx.max_seqlen = max_seqlen
96
+ return out if not inplace else x
97
+
98
+ @staticmethod
99
+ def backward(ctx, do):
100
+ seqlen_offsets = ctx.seqlen_offsets
101
+ if seqlen_offsets is None:
102
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
103
+ else:
104
+ cos, sin, cu_seqlens = ctx.saved_tensors
105
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
106
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
107
+ if not ctx.interleaved and not ctx.inplace:
108
+ do = do.clone()
109
+
110
+ dx = apply_rotary(
111
+ do,
112
+ cos,
113
+ sin,
114
+ seqlen_offsets=seqlen_offsets,
115
+ cu_seqlens=cu_seqlens,
116
+ max_seqlen=ctx.max_seqlen,
117
+ interleaved=ctx.interleaved,
118
+ inplace=ctx.inplace,
119
+ conjugate=True,
120
+ )
121
+ return dx, None, None, None, None, None, None, None
122
+
123
+
124
+ def apply_rotary_emb(
125
+ x,
126
+ cos,
127
+ sin,
128
+ interleaved=False,
129
+ inplace=False,
130
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
131
+ cu_seqlens: Optional[torch.Tensor] = None,
132
+ max_seqlen: Optional[int] = None,
133
+ ):
134
+ """
135
+ Arguments:
136
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
137
+ else (total_seqlen, nheads, headdim)
138
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
139
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
140
+ of 1st half and 2nd half (GPT-NeoX style).
141
+ inplace: if True, apply rotary embedding in-place.
142
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
143
+ Most commonly used in inference when we have KV cache.
144
+ cu_seqlens: (batch + 1,) or None
145
+ max_seqlen: int
146
+ Return:
147
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
148
+ else (total_seqlen, nheads, headdim)
149
+ rotary_dim must be <= headdim
150
+ Apply rotary embedding to the first rotary_dim of x.
151
+ """
152
+ return ApplyRotaryEmb.apply(
153
+ x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
154
+ )
155
+
156
+
157
+ # For backward compatibility
158
+ apply_rotary_emb_func = apply_rotary_emb
159
+
160
+
161
+ class ApplyRotaryEmbQKV_(torch.autograd.Function):
162
+ @staticmethod
163
+ def forward(
164
+ ctx,
165
+ qkv,
166
+ cos,
167
+ sin,
168
+ cos_k=None,
169
+ sin_k=None,
170
+ interleaved=False,
171
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
172
+ cu_seqlens: Optional[torch.Tensor] = None,
173
+ max_seqlen: Optional[int] = None,
174
+ use_flash_attn: bool = True,
175
+ ):
176
+ # batch, seqlen, three, nheads, headdim = qkv.shape
177
+ assert qkv.shape[-3] == 3
178
+ if cos_k is None and sin_k is None and qkv.is_contiguous():
179
+
180
+ if use_flash_attn:
181
+ # Call 1 kernel instead of 2 kernels
182
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
183
+ # dimensions, we get the same tensor
184
+ qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
185
+ # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
186
+ apply_rotary(
187
+ qk,
188
+ cos,
189
+ sin,
190
+ seqlen_offsets=seqlen_offsets,
191
+ interleaved=interleaved,
192
+ inplace=True,
193
+ cu_seqlens=cu_seqlens,
194
+ max_seqlen=max_seqlen,
195
+ )
196
+ else:
197
+ q_rot = apply_rotary_emb_torch(
198
+ qkv[:, :, 0],
199
+ cos,
200
+ sin,
201
+ interleaved=interleaved,
202
+ )
203
+ k_rot = apply_rotary_emb_torch(
204
+ qkv[:, :, 1],
205
+ cos,
206
+ sin,
207
+ interleaved=interleaved,
208
+ )
209
+ qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
210
+ else:
211
+ cos_k = cos if cos_k is None else cos_k
212
+ sin_k = sin if sin_k is None else sin_k
213
+ q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
214
+ apply_rotary(
215
+ q,
216
+ cos,
217
+ sin,
218
+ seqlen_offsets,
219
+ interleaved=interleaved,
220
+ inplace=True,
221
+ cu_seqlens=cu_seqlens,
222
+ max_seqlen=max_seqlen,
223
+ )
224
+ apply_rotary(
225
+ k,
226
+ cos_k,
227
+ sin_k,
228
+ seqlen_offsets,
229
+ interleaved=interleaved,
230
+ inplace=True,
231
+ cu_seqlens=cu_seqlens,
232
+ max_seqlen=max_seqlen,
233
+ )
234
+ ctx.save_for_backward(cos, sin, cos_k, sin_k)
235
+ if isinstance(seqlen_offsets, int):
236
+ ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
237
+ ctx.seqlen_offsets = seqlen_offsets
238
+ else:
239
+ ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
240
+ ctx.seqlen_offsets = None
241
+ ctx.max_seqlen = max_seqlen
242
+ ctx.interleaved = interleaved
243
+ return qkv
244
+
245
+ @staticmethod
246
+ def backward(ctx, dqkv):
247
+ seqlen_offsets = ctx.seqlen_offsets
248
+ if seqlen_offsets is None:
249
+ cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
250
+ else:
251
+ cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
252
+ if cos_k is None and sin_k is None and dqkv.is_contiguous():
253
+ # Call 1 kernel instead of 2 kernels
254
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
255
+ # dimensions, we get the same tensor
256
+ dqk = rearrange(dqkv[..., :2, :, :], "... t h d -> ... (t h) d")
257
+ apply_rotary(
258
+ dqk,
259
+ cos,
260
+ sin,
261
+ seqlen_offsets=seqlen_offsets,
262
+ interleaved=ctx.interleaved,
263
+ inplace=True,
264
+ conjugate=True,
265
+ cu_seqlens=cu_seqlens,
266
+ max_seqlen=ctx.max_seqlen,
267
+ )
268
+ else:
269
+ cos_k = cos if cos_k is None else cos_k
270
+ sin_k = sin if sin_k is None else sin_k
271
+ dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
272
+ apply_rotary(
273
+ dq,
274
+ cos,
275
+ sin,
276
+ seqlen_offsets,
277
+ interleaved=ctx.interleaved,
278
+ inplace=True,
279
+ conjugate=True,
280
+ cu_seqlens=cu_seqlens,
281
+ max_seqlen=ctx.max_seqlen,
282
+ )
283
+ apply_rotary(
284
+ dk,
285
+ cos_k,
286
+ sin_k,
287
+ seqlen_offsets,
288
+ interleaved=ctx.interleaved,
289
+ inplace=True,
290
+ conjugate=True,
291
+ cu_seqlens=cu_seqlens,
292
+ max_seqlen=ctx.max_seqlen,
293
+ )
294
+ return dqkv, None, None, None, None, None, None, None, None, None
295
+
296
+
297
+ def apply_rotary_emb_qkv_(
298
+ qkv,
299
+ cos,
300
+ sin,
301
+ cos_k=None,
302
+ sin_k=None,
303
+ interleaved=False,
304
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
305
+ cu_seqlens: Optional[torch.Tensor] = None,
306
+ max_seqlen: Optional[int] = None,
307
+ use_flash_attn=True,
308
+ ):
309
+ """
310
+ Arguments:
311
+ qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
312
+ else (total_seqlen, 3, nheads, headdim)
313
+ cos, sin: (seqlen, rotary_dim / 2)
314
+ cos_k, sin_k: (seqlen, rotary_dim / 2), optional
315
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
316
+ 1st half and 2nd half (GPT-NeoX style).
317
+ seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
318
+ Most commonly used in inference when we have KV cache.
319
+ cu_seqlens: (batch + 1,) or None
320
+ max_seqlen: int
321
+ Return:
322
+ qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
323
+ else (total_seqlen, 3, nheads, headdim)
324
+ rotary_dim must be <= headdim
325
+ Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
326
+ """
327
+ return ApplyRotaryEmbQKV_.apply(
328
+ qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
329
+ )
330
+
331
+
332
+ class ApplyRotaryEmbKV_(torch.autograd.Function):
333
+ @staticmethod
334
+ def forward(
335
+ ctx,
336
+ kv,
337
+ cos,
338
+ sin,
339
+ interleaved=False,
340
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
341
+ cu_seqlens: Optional[torch.Tensor] = None,
342
+ max_seqlen: Optional[int] = None,
343
+ ):
344
+ # batch, seqlen, two, nheads, headdim = kv.shape
345
+ assert kv.shape[-3] == 2
346
+ k = kv[..., 0, :, :]
347
+ apply_rotary(
348
+ k,
349
+ cos,
350
+ sin,
351
+ seqlen_offsets=seqlen_offsets,
352
+ interleaved=interleaved,
353
+ inplace=True,
354
+ cu_seqlens=cu_seqlens,
355
+ max_seqlen=max_seqlen,
356
+ )
357
+ if isinstance(seqlen_offsets, int):
358
+ ctx.save_for_backward(
359
+ cos, sin, cu_seqlens
360
+ ) # Can't save int with save_for_backward
361
+ ctx.seqlen_offsets = seqlen_offsets
362
+ else:
363
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
364
+ ctx.seqlen_offsets = None
365
+ ctx.max_seqlen = max_seqlen
366
+ ctx.interleaved = interleaved
367
+ return kv
368
+
369
+ @staticmethod
370
+ def backward(ctx, dkv):
371
+ seqlen_offsets = ctx.seqlen_offsets
372
+ if seqlen_offsets is None:
373
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
374
+ else:
375
+ cos, sin, cu_seqlens = ctx.saved_tensors
376
+ apply_rotary(
377
+ dkv[..., 0, :, :],
378
+ cos,
379
+ sin,
380
+ seqlen_offsets=seqlen_offsets,
381
+ interleaved=ctx.interleaved,
382
+ inplace=True,
383
+ conjugate=True,
384
+ cu_seqlens=cu_seqlens,
385
+ max_seqlen=ctx.max_seqlen,
386
+ )
387
+ return dkv, None, None, None, None, None, None
388
+
389
+
390
+ apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
391
+
392
+
393
+ def apply_rotary_emb_kv_(
394
+ kv,
395
+ cos,
396
+ sin,
397
+ interleaved=False,
398
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
399
+ cu_seqlens: Optional[torch.Tensor] = None,
400
+ max_seqlen: Optional[int] = None,
401
+ ):
402
+ """
403
+ Arguments:
404
+ kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
405
+ else (total_seqlen, 2, nheads, headdim)
406
+ cos, sin: (seqlen, rotary_dim / 2)
407
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
408
+ 1st half and 2nd half (GPT-NeoX style).
409
+ seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
410
+ Most commonly used in inference when we have KV cache.
411
+ cu_seqlens: (batch + 1,) or None
412
+ max_seqlen: int
413
+ Return:
414
+ kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
415
+ else (total_seqlen, 2, nheads, headdim)
416
+ rotary_dim must be <= headdim
417
+ Apply rotary embedding *inplace* to the first rotary_dim of K.
418
+ """
419
+ return ApplyRotaryEmbKV_.apply(
420
+ kv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
421
+ )
422
+
423
+
424
+ class RotaryEmbedding(torch.nn.Module):
425
+ """
426
+ The rotary position embeddings from RoFormer_ (Su et. al).
427
+ A crucial insight from the method is that the query and keys are
428
+ transformed by rotation matrices which depend on the relative positions.
429
+
430
+ Other implementations are available in the Rotary Transformer repo_ and in
431
+ GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.
432
+
433
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
434
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
435
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
436
+
437
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
438
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
439
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ dim: int,
445
+ base=10000.0,
446
+ interleaved=False,
447
+ scale_base=None,
448
+ pos_idx_in_fp32=True,
449
+ device=None,
450
+ use_flash_attn=True,
451
+ ):
452
+ """
453
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
454
+ of 1st half and 2nd half (GPT-NeoX style).
455
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
456
+ otherwise they might be in lower precision.
457
+ This option was added because previously (before 2023-07-02), when we construct
458
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
459
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
460
+ self.inv_freq would be bf16, and the position indices are also in bf16.
461
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
462
+ embeddings for some positions will coincide.
463
+ To maintain compatibility with models previously trained in pure bf16,
464
+ we add this option.
465
+ """
466
+ super().__init__()
467
+ self.dim = dim
468
+ self._base = float(base)
469
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
470
+ self.use_flash_attn = use_flash_attn
471
+ # Generate and save the inverse frequency buffer (non trainable)
472
+ inv_freq = self._compute_inv_freq(device)
473
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
474
+ self.interleaved = interleaved
475
+ self.scale_base = scale_base
476
+ scale = (
477
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
478
+ / (1.4 * dim)
479
+ if scale_base is not None
480
+ else None
481
+ )
482
+ self.register_buffer("scale", scale, persistent=False)
483
+
484
+ self._seq_len_cached = 0
485
+ self._cos_cached = None
486
+ self._sin_cached = None
487
+ self._cos_k_cached = None
488
+ self._sin_k_cached = None
489
+
490
+ @property
491
+ def base(self):
492
+ return self._base
493
+
494
+ @base.setter
495
+ def base(self, new_base):
496
+ new_base = float(new_base)
497
+ if new_base > 0:
498
+ if self._base != new_base: # only update if the base value has changed
499
+ self._base = new_base
500
+ self._update_cos_sin_cache(
501
+ self._seq_len_cached,
502
+ device=self.inv_freq.device,
503
+ dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
504
+ rotary_base_changed=True,
505
+ )
506
+ else:
507
+ raise ValueError("Rotary base value must be positive")
508
+
509
+ def _compute_inv_freq(self, device=None):
510
+ return 1.0 / (
511
+ self.base
512
+ ** (
513
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
514
+ / self.dim
515
+ )
516
+ )
517
+
518
+ def _update_cos_sin_cache(
519
+ self, seqlen, device=None, dtype=None, rotary_base_changed=False
520
+ ):
521
+ # Reset the tables if the sequence length has changed,
522
+ # if we're on a new device (possibly due to tracing for instance),
523
+ # or if we're switching from inference mode to training
524
+ # or if the rotary base value was changed
525
+ if (
526
+ seqlen > self._seq_len_cached
527
+ or self._cos_cached is None
528
+ or self._cos_cached.device != device
529
+ or self._cos_cached.dtype != dtype
530
+ or (self.training and self._cos_cached.is_inference())
531
+ or rotary_base_changed
532
+ ):
533
+ self._seq_len_cached = seqlen
534
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
535
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
536
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
537
+ if rotary_base_changed:
538
+ self.inv_freq = self._compute_inv_freq(device=device)
539
+ if self.pos_idx_in_fp32:
540
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
541
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
542
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
543
+ # cos & sin output to change significantly.
544
+ # We want to recompute self.inv_freq if it was not loaded in fp32
545
+ if self.inv_freq.dtype != torch.float32:
546
+ inv_freq = self._compute_inv_freq(device=device)
547
+ else:
548
+ inv_freq = self.inv_freq
549
+ else:
550
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
551
+ inv_freq = self.inv_freq
552
+
553
+ # Don't do einsum, it converts fp32 to fp16 under AMP
554
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
555
+ freqs = torch.outer(t, inv_freq)
556
+ if self.scale is None:
557
+ self._cos_cached = torch.cos(freqs).to(dtype)
558
+ self._sin_cached = torch.sin(freqs).to(dtype)
559
+ else:
560
+ power = (
561
+ torch.arange(
562
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
563
+ )
564
+ - seqlen // 2
565
+ ) / self.scale_base
566
+ scale = self.scale.to(device=power.device) ** rearrange(
567
+ power, "s -> s 1"
568
+ )
569
+ # We want the multiplication by scale to happen in fp32
570
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
571
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
572
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
573
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
574
+
575
+ def forward(
576
+ self,
577
+ qkv: torch.Tensor,
578
+ kv: Optional[torch.Tensor] = None,
579
+ seqlen_offset: Union[int, torch.Tensor] = 0,
580
+ cu_seqlens: Optional[torch.Tensor] = None,
581
+ max_seqlen: Optional[int] = None,
582
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
583
+ """
584
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
585
+ else it's just q of shape (batch, seqlen, nheads, headdim)
586
+ kv: (batch, seqlen, 2, nheads, headdim)
587
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
588
+ Most commonly used in inference when we have KV cache.
589
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
590
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
591
+ Apply rotary embedding *inplace* to qkv and / or kv.
592
+ """
593
+ if cu_seqlens is not None:
594
+ assert max_seqlen is not None
595
+ seqlen = qkv.shape[1] if max_seqlen is None else max_seqlen
596
+ if max_seqlen is not None:
597
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
598
+ elif isinstance(seqlen_offset, int):
599
+ self._update_cos_sin_cache(
600
+ seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
601
+ )
602
+ if kv is None:
603
+ if self.scale is None:
604
+ return apply_rotary_emb_qkv_(
605
+ qkv,
606
+ self._cos_cached,
607
+ self._sin_cached,
608
+ interleaved=self.interleaved,
609
+ seqlen_offsets=seqlen_offset,
610
+ cu_seqlens=cu_seqlens,
611
+ max_seqlen=max_seqlen,
612
+ use_flash_attn=self.use_flash_attn,
613
+ )
614
+ else:
615
+ return apply_rotary_emb_qkv_(
616
+ qkv,
617
+ self._cos_cached,
618
+ self._sin_cached,
619
+ self._cos_k_cached,
620
+ self._sin_k_cached,
621
+ interleaved=self.interleaved,
622
+ seqlen_offsets=seqlen_offset,
623
+ cu_seqlens=cu_seqlens,
624
+ max_seqlen=max_seqlen,
625
+ use_flash_attn=self.use_flash_attn,
626
+ )
627
+ else:
628
+ q = qkv
629
+ q = apply_rotary_emb_func(
630
+ q,
631
+ self._cos_cached,
632
+ self._sin_cached,
633
+ interleaved=self.interleaved,
634
+ inplace=True,
635
+ seqlen_offsets=seqlen_offset,
636
+ cu_seqlens=cu_seqlens,
637
+ max_seqlen=max_seqlen,
638
+ )
639
+ if self.scale is None:
640
+ kv = apply_rotary_emb_kv_(
641
+ kv,
642
+ self._cos_cached,
643
+ self._sin_cached,
644
+ interleaved=self.interleaved,
645
+ seqlen_offsets=seqlen_offset,
646
+ cu_seqlens=cu_seqlens,
647
+ max_seqlen=max_seqlen,
648
+ )
649
+ else:
650
+ kv = apply_rotary_emb_kv_(
651
+ kv,
652
+ self._cos_k_cached,
653
+ self._sin_k_cached,
654
+ interleaved=self.interleaved,
655
+ seqlen_offsets=seqlen_offset,
656
+ cu_seqlens=cu_seqlens,
657
+ max_seqlen=max_seqlen,
658
+ )
659
+ return q, kv
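
A small self-contained sketch (not part of the diff) of what the rotary helpers above compute: it rebuilds the cos/sin tables the same way `RotaryEmbedding._update_cos_sin_cache` does and applies the pure-PyTorch fallback `apply_rotary_emb_torch`, so it runs on CPU without flash-attention. The `rotary` import path is an assumption about how the module is placed on `sys.path`; only `torch` and `einops` (imported by `rotary.py` itself) are required.

```python
# Sketch assuming rotary.py is importable as `rotary`.
import torch
from rotary import apply_rotary_emb_torch

batch, seqlen, nheads, headdim = 2, 8, 4, 64
base = 10000.0

# Same construction as RotaryEmbedding._compute_inv_freq / _update_cos_sin_cache.
inv_freq = 1.0 / (base ** (torch.arange(0, headdim, 2, dtype=torch.float32) / headdim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)            # (seqlen, headdim / 2)
cos, sin = torch.cos(freqs), torch.sin(freqs)

q = torch.randn(batch, seqlen, nheads, headdim)
q_rot = apply_rotary_emb_torch(q, cos, sin)  # GPT-NeoX style rotation (interleaved=False)
assert q_rot.shape == q.shape
```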
stochastic_depth.py ADDED
@@ -0,0 +1,97 @@
1
+ # Implementation modified from torchvision:
2
+ # https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
3
+ #
4
+ # License:
5
+ # BSD 3-Clause License
6
+ #
7
+ # Copyright (c) Soumith Chintala 2016,
8
+ # All rights reserved.
9
+ #
10
+ # Redistribution and use in source and binary forms, with or without
11
+ # modification, are permitted provided that the following conditions are met:
12
+ #
13
+ # * Redistributions of source code must retain the above copyright notice, this
14
+ # list of conditions and the following disclaimer.
15
+ #
16
+ # * Redistributions in binary form must reproduce the above copyright notice,
17
+ # this list of conditions and the following disclaimer in the documentation
18
+ # and/or other materials provided with the distribution.
19
+ #
20
+ # * Neither the name of the copyright holder nor the names of its
21
+ # contributors may be used to endorse or promote products derived from
22
+ # this software without specific prior written permission.
23
+ #
24
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
+
35
+ import torch
36
+ import torch.fx
37
+ from torch import Tensor, nn
38
+
39
+
40
+ def stochastic_depth(
41
+ input: Tensor, p: float, mode: str, training: bool = True
42
+ ) -> Tensor:
43
+ """
44
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
45
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
46
+ branches of residual architectures.
47
+
48
+ Args:
49
+ input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
50
+ being its batch i.e. a batch with ``N`` rows.
51
+ p (float): probability of the input to be zeroed.
52
+ mode (str): ``"batch"`` or ``"row"``.
53
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
54
+ randomly selected rows from the batch.
55
+ training: apply stochastic depth if is ``True``. Default: ``True``
56
+
57
+ Returns:
58
+ Tensor[N, ...]: The randomly zeroed tensor.
59
+ """
60
+ if p < 0.0 or p > 1.0:
61
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
62
+ if mode not in ["batch", "row"]:
63
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
64
+ if not training or p == 0.0:
65
+ return input
66
+
67
+ survival_rate = 1.0 - p
68
+ if mode == "row":
69
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
70
+ else:
71
+ size = [1] * input.ndim
72
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
73
+ noise = noise.bernoulli_(survival_rate)
74
+ if survival_rate > 0.0:
75
+ noise.div_(survival_rate)
76
+ return input * noise
77
+
78
+
79
+ torch.fx.wrap("stochastic_depth")
80
+
81
+
82
+ class StochasticDepth(nn.Module):
83
+ """
84
+ See :func:`stochastic_depth`.
85
+ """
86
+
87
+ def __init__(self, p: float, mode: str) -> None:
88
+ super().__init__()
89
+ self.p = p
90
+ self.mode = mode
91
+
92
+ def forward(self, input: Tensor) -> Tensor:
93
+ return stochastic_depth(input, self.p, self.mode, self.training)
94
+
95
+ def __repr__(self) -> str:
96
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
97
+ return s
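
A quick behavioural sketch (not part of the diff) of `StochasticDepth` in "row" mode: during training each sample's residual branch is either zeroed or rescaled by 1 / (1 - p), and at evaluation time the module is the identity. The `stochastic_depth` import path is an assumption about module placement.

```python
# Sketch assuming stochastic_depth.py is importable as `stochastic_depth`.
import torch
from stochastic_depth import StochasticDepth

torch.manual_seed(0)
drop_path = StochasticDepth(p=0.2, mode="row")

residual = torch.ones(4, 3)        # 4 samples' residual-branch outputs

drop_path.train()
out = drop_path(residual)          # each row is either 0.0 or 1.0 / (1 - 0.2) = 1.25
print(out)

drop_path.eval()
assert torch.equal(drop_path(residual), residual)  # identity at inference time
```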
xlm_padding.py ADDED
@@ -0,0 +1,236 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
2
+ # Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
3
+
4
+ # Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
+ class IndexFirstAxis(torch.autograd.Function):
12
+ @staticmethod
13
+ def forward(ctx, input, indices):
14
+ ctx.save_for_backward(indices)
15
+ assert input.ndim >= 2
16
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
17
+ second_dim = other_shape.numel()
18
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
+ # return input[indices]
20
+ return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"),
22
+ 0,
23
+ repeat(indices, "z -> z d", d=second_dim),
24
+ ).reshape(-1, *other_shape)
25
+
26
+ @staticmethod
27
+ def backward(ctx, grad_output):
28
+ (indices,) = ctx.saved_tensors
29
+ assert grad_output.ndim >= 2
30
+ other_shape = grad_output.shape[1:]
31
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
32
+ grad_input = torch.zeros(
33
+ [ctx.first_axis_dim, grad_output.shape[1]],
34
+ device=grad_output.device,
35
+ dtype=grad_output.dtype,
36
+ )
37
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
38
+ # grad_input[indices] = grad_output
39
+ grad_input.scatter_(
40
+ 0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
41
+ )
42
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
43
+
44
+
45
+ index_first_axis = IndexFirstAxis.apply
46
+
47
+
48
+ class IndexPutFirstAxis(torch.autograd.Function):
49
+ @staticmethod
50
+ def forward(ctx, values, indices, first_axis_dim):
51
+ ctx.save_for_backward(indices)
52
+ assert indices.ndim == 1
53
+ assert values.ndim >= 2
54
+ output = torch.zeros(
55
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
56
+ )
57
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
58
+ output[indices] = values
59
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
60
+ return output
61
+
62
+ @staticmethod
63
+ def backward(ctx, grad_output):
64
+ (indices,) = ctx.saved_tensors
65
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
66
+ grad_values = grad_output[indices]
67
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
68
+ return grad_values, None, None
69
+
70
+
71
+ index_put_first_axis = IndexPutFirstAxis.apply


class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, PyTorch complains that output is a view and is being modified in place.
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply
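
# Illustrative note (not from the upstream docs): index_first_axis_residual
# returns both the gathered rows and the detached original input, so a caller
# can keep a full-length residual tensor around; in the backward pass the
# gradient of the gathered rows is scatter-added into the gradient flowing
# through that residual.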


def unpad_input(hidden_states, attention_mask, adapter_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        adapter_mask: (batch), optional, one adapter index per sequence in the batch.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        cu_adapter_mask: (total_nnz), the adapter index repeated for every unpadded token, or None if adapter_mask is None.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )

    cu_adapter_mask = (
        torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
        if adapter_mask is not None
        else None
    )

    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        cu_adapter_mask,
    )
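
# Illustrative shape sketch (made-up values, not from the upstream docs):
#
#     hidden = torch.randn(2, 4, 8)
#     mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
#     out, idx, cu, maxlen, _ = unpad_input(hidden, mask)
#     # out: (5, 8), idx: tensor([0, 1, 2, 4, 5]),
#     # cu: tensor([0, 3, 5], dtype=torch.int32), maxlen: 3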


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples into one sequence. The attention_mask_in_length is utilized to mask the other short samples. It enables efficient training on variable-length samples (e.g., the supervised fine-tuning task in large language models).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
    ```
    [
        [2, 3, 0, 0, 0, 0],
        [3, 2, 0, 0, 0, 0],
        [6, 0, 0, 0, 0, 0]
    ]
    ```
    which corresponds to the 3D attention mask:
    ```
    [
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
        ]
    ]
    ```

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero value (e.g., 1, 2, 3, etc.) means the length of a concatenated sample in the b-th batch row, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected via attention_mask_in_length.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (number of concatenated samples + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(
        seqlen, device=length.device, dtype=length.dtype
    ).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(
        attention_mask_in_length.flatten(), as_tuple=False
    ).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
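
# Illustrative sketch for the docstring example above: with the (3, 6)
# attention_mask_in_length shown there, seqlens_in_batch = tensor([2, 3, 3, 2, 6]),
# cu_seqlens = tensor([0, 2, 5, 8, 10, 16]) and max_seqlen_in_batch = 6, i.e.
# the three padded rows are flattened into five packed samples totalling 16 tokens.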


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
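

# A minimal round-trip sketch (illustrative only, with made-up shapes). It runs
# only when this file is executed directly and exercises unpad_input + pad_input.
if __name__ == "__main__":
    batch, seqlen, dim = 2, 4, 8
    hidden = torch.randn(batch, seqlen, dim)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    unpadded, indices, cu_seqlens, max_seqlen, _ = unpad_input(hidden, mask)
    repadded = pad_input(unpadded, indices, batch, seqlen)
    # Valid positions round-trip exactly; padded positions come back as zeros.
    assert torch.allclose(repadded[mask.bool()], hidden[mask.bool()])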