Commit · 4c504d3
1 Parent(s): 98c3cd2

fix: lora bug

Signed-off-by: Meow <[email protected]>

- modeling_lora.py  +13 -8
modeling_lora.py CHANGED

@@ -11,7 +11,7 @@ from torch import nn
 from torch.nn import Parameter
 from transformers import PretrainedConfig
 
-from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
+from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
 
 
 LORA_NO_UPDATE = '__lora_no_update__'
@@ -210,13 +210,19 @@ class LoRAParametrization(nn.Module):
             layer.current_task = task_idx
 
 
-class XLMRobertaLoRA(XLMRobertaModel):
+class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def __init__(
         self,
         config: XLMRobertaFlashConfig,
+        roberta: Optional[XLMRobertaModel] = None
     ):
         super().__init__(config)
 
+        if roberta is None:
+            self.roberta = XLMRobertaModel(config)
+        else:
+            self.roberta = roberta
+
         self._lora_adaptations = config.lora_adaptations
         if (
             not isinstance(self._lora_adaptations, list)
@@ -231,7 +237,6 @@ class XLMRobertaLoRA(XLMRobertaModel):
         self._rank = config.lora_rank
         self._dropout_p = config.lora_dropout_p
         self._alpha = config.lora_alpha
-
         self._register_lora(
             num_adaptations=len(self._lora_adaptations),
             rank=self._rank,
@@ -284,9 +289,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
         else:
-
-
-            return cls(config)
+            roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            return cls(config, roberta=roberta)
 
     def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
         self.apply(
@@ -331,7 +335,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
     def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
         if task != LORA_NO_UPDATE:
             self.current_task = task
-
+
+        return self.roberta(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         for _, param in self.named_parameters(recurse=recurse):
@@ -373,4 +378,4 @@ class XLMRobertaLoRA(XLMRobertaModel):
         )
         self.current_task = task
 
-        return
+        return self.roberta.encode(*args, **kwargs)
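The net effect of the diff is a switch from inheritance to composition: XLMRobertaLoRA now extends XLMRobertaPreTrainedModel, keeps the backbone as self.roberta (built from the config when none is passed in, or taken over from from_pretrained), and delegates forward() and encode() to it. A minimal construction sketch follows; it assumes the repository's modules are importable under the names shown, and the adapter names and LoRA hyperparameters are made up for illustration (the real values come from the checkpoint's config):

# Sketch only: import paths, adapter names, and LoRA hyperparameters below are
# illustrative assumptions, not values taken from this repository's config.
from modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
from modeling_lora import XLMRobertaLoRA

config = XLMRobertaFlashConfig(
    lora_adaptations=["retrieval.query", "retrieval.passage"],  # hypothetical task names
    lora_rank=4,
    lora_dropout_p=0.0,
    lora_alpha=1.0,
)

# Path 1: no backbone passed in, so __init__ builds XLMRobertaModel(config) itself.
lora_model = XLMRobertaLoRA(config)

# Path 2: wrap an existing backbone, mirroring what from_pretrained's else-branch
# now does after XLMRobertaModel.from_pretrained(...).
backbone = XLMRobertaModel(config)
lora_model = XLMRobertaLoRA(config, roberta=backbone)

# forward() first selects the adapter via `task`, then runs the wrapped backbone:
# outputs = lora_model(input_ids, attention_mask=attention_mask, task="retrieval.query")

Passing the preloaded backbone through the constructor is what lets from_pretrained reuse XLMRobertaModel.from_pretrained for weight loading, while the wrapper itself only registers and manages the LoRA parametrizations.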