Commit 851aaca
Parent(s): 5418705

refactor: set task in lora class rather than xlm roberta

- modeling_lora.py +43 -7
- modeling_xlm_roberta.py +1 -24
modeling_lora.py CHANGED

@@ -1,18 +1,17 @@
 import math
 import os
+import warnings
 from functools import partial
-from typing import Iterator, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.nn.utils.parametrize as parametrize
 from torch import nn
 from torch.nn import Parameter
 from transformers import PretrainedConfig

-from .modeling_xlm_roberta import (
-    XLMRobertaFlashConfig,
-    XLMRobertaModel,
-)
+from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel


 def initialized_weights(
@@ -231,7 +230,6 @@ class XLMRobertaLoRA(XLMRobertaModel):
         # By default, disable LoRA until it's specified which adapter/task to use
         self.current_task = None

-
     @property
     def main_params_trainable(self):
         return self._main_params_trainable
@@ -273,7 +271,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
         else:
-
+            dtype = config.torch_dtype if config.torch_dtype else torch.bfloat16
+            torch.set_default_dtype(dtype)
             return cls(config)

     def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
@@ -327,3 +326,40 @@ class XLMRobertaLoRA(XLMRobertaModel):
         ):
             if "lora" in name or self.main_params_trainable:
                 yield name, param
+
+    @torch.inference_mode()
+    def encode(
+        self,
+        *args,
+        task: Optional[str] = None,
+        **kwargs,
+    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+        """
+        Computes sentence embeddings
+
+        task(`str`, *optional*, defaults to None):
+            Specifies the task for which the encoding is intended. This
+            controls the use of specialized LoRA adapters that are tuned for specific tasks.
+            If provided, the corresponding LoRA adapter is enabled, enhancing the model's
+            performance for that task. If `None` or not provided, LoRA is disabled, and the
+            model uses its original, general-purpose weights.
+        """
+        lora_adapter_num = None
+        if self.config.lora_adaptations:
+            if task:
+                if task in self.config.lora_adaptations:
+                    lora_adapter_num = self.config.lora_adaptations.index(task)
+                else:
+                    raise ValueError(
+                        f"Unsupported task '{task}'. "
+                        f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
+                    )
+            else:
+                warnings.warn(
+                    f"Task-specific embeddings are disabled. To enable, specify the `task` "
+                    f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
+                    category=UserWarning,
+                )
+        self.current_task = lora_adapter_num
+
+        return super().encode(*args, **kwargs)
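For context, a minimal usage sketch of the new LoRA-side entry point (hedged: the import path, checkpoint path, and task name below are placeholders; valid task names are whatever the checkpoint's `config.lora_adaptations` lists):

# Sketch only: assumes this module is importable as `modeling_lora` and that a
# local checkpoint carries `lora_adaptations` in its config.
from modeling_lora import XLMRobertaLoRA

model = XLMRobertaLoRA.from_pretrained("path/to/checkpoint")  # placeholder path

# Passing `task` resolves the adapter index, sets `self.current_task`, and then
# delegates to the base XLMRobertaModel.encode(); omitting it keeps LoRA disabled.
embeddings = model.encode(["A sample sentence."], task="retrieval")  # placeholder task name

Note also that when a model is built fresh (the `else` branch of `from_pretrained`), the default dtype now falls back to `config.torch_dtype` or `torch.bfloat16` via `torch.set_default_dtype`, so newly initialized weights are created in that dtype.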
modeling_xlm_roberta.py CHANGED

@@ -452,7 +452,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
-        task: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -482,12 +481,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                If set to true, returned vectors will have length 1. In that case, the
                faster dot-product (util.dot_score) instead of cosine similarity can
                be used.
-            task(`str`, *optional*, defaults to None):
-                Specifies the task for which the encoding is intended. This
-                controls the use of specialized LoRA adapters that are tuned for specific tasks.
-                If provided, the corresponding LoRA adapter is enabled, enhancing the model's
-                performance for that task. If `None` or not provided, LoRA is disabled, and the
-                model uses its original, general-purpose weights.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                Keyword arguments for the tokenizer
         Returns:
@@ -525,22 +518,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         if device is not None:
             self.to(device)

-        lora_adapter_num = None
-        if self.config.lora_adaptations:
-            if task:
-                if task in self.config.lora_adaptations:
-                    lora_adapter_num = self.config.lora_adaptations.index(task)
-                else:
-                    raise ValueError(
-                        f"Unsupported task '{task}'. "
-                        f"Supported tasks are: {', '.join(self.config.lora_adaptations)}.")
-            else:
-                logger.warning(
-                    f"Task-specific embeddings are disabled. To enable, specify the `task` "
-                    f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}"
-                )
-
-
         permutation = np.argsort([-len(i) for i in sentences])
         inverse_permutation = np.argsort(permutation)
         sentences = [sentences[idx] for idx in permutation]
@@ -570,7 +547,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            token_embs = self.forward(**encoded_input
+            token_embs = self.forward(**encoded_input)[0]

             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()
|