refactor-task-type-to-task (#43)
Browse files- rename task type (3afddee7275504c48afc63049db9124f9e2871ce)
- modeling_lora.py +10 -10
- modeling_xlm_roberta.py +1 -1
modeling_lora.py
CHANGED
|
@@ -367,35 +367,35 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
| 367 |
self,
|
| 368 |
sentences: Union[str, List[str]],
|
| 369 |
*args,
|
| 370 |
-
|
| 371 |
**kwargs,
|
| 372 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 373 |
"""
|
| 374 |
Computes sentence embeddings.
|
| 375 |
sentences(`str` or `List[str]`):
|
| 376 |
Sentence or sentences to be encoded
|
| 377 |
-
|
| 378 |
-
Specifies the task for which the encoding is intended. If `
|
| 379 |
all LoRA adapters are disabled, and the model reverts to its original,
|
| 380 |
general-purpose weights.
|
| 381 |
"""
|
| 382 |
-
if
|
| 383 |
raise ValueError(
|
| 384 |
-
f"Unsupported task '{
|
| 385 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
| 386 |
-
f"Alternatively, don't pass the `
|
| 387 |
)
|
| 388 |
adapter_mask = None
|
| 389 |
-
if
|
| 390 |
-
task_id = self._adaptation_map[
|
| 391 |
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
| 392 |
adapter_mask = torch.full(
|
| 393 |
(num_examples,), task_id, dtype=torch.int32, device=self.device
|
| 394 |
)
|
| 395 |
if isinstance(sentences, str):
|
| 396 |
-
sentences = self._task_instructions[
|
| 397 |
else:
|
| 398 |
-
sentences = [self._task_instructions[
|
| 399 |
return self.roberta.encode(
|
| 400 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
| 401 |
)
|
|
|
|
| 367 |
self,
|
| 368 |
sentences: Union[str, List[str]],
|
| 369 |
*args,
|
| 370 |
+
task: Optional[str] = None,
|
| 371 |
**kwargs,
|
| 372 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 373 |
"""
|
| 374 |
Computes sentence embeddings.
|
| 375 |
sentences(`str` or `List[str]`):
|
| 376 |
Sentence or sentences to be encoded
|
| 377 |
+
task(`str`, *optional*, defaults to `None`):
|
| 378 |
+
Specifies the task for which the encoding is intended. If `task` is not provided,
|
| 379 |
all LoRA adapters are disabled, and the model reverts to its original,
|
| 380 |
general-purpose weights.
|
| 381 |
"""
|
| 382 |
+
if task and task not in self._lora_adaptations:
|
| 383 |
raise ValueError(
|
| 384 |
+
f"Unsupported task '{task}'. "
|
| 385 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
| 386 |
+
f"Alternatively, don't pass the `task` argument to disable LoRA."
|
| 387 |
)
|
| 388 |
adapter_mask = None
|
| 389 |
+
if task:
|
| 390 |
+
task_id = self._adaptation_map[task]
|
| 391 |
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
| 392 |
adapter_mask = torch.full(
|
| 393 |
(num_examples,), task_id, dtype=torch.int32, device=self.device
|
| 394 |
)
|
| 395 |
if isinstance(sentences, str):
|
| 396 |
+
sentences = self._task_instructions[task] + sentences
|
| 397 |
else:
|
| 398 |
+
sentences = [self._task_instructions[task] + sentence for sentence in sentences]
|
| 399 |
return self.roberta.encode(
|
| 400 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
| 401 |
)
|
modeling_xlm_roberta.py
CHANGED
|
@@ -473,7 +473,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 473 |
normalize_embeddings: bool = True,
|
| 474 |
truncate_dim: Optional[int] = None,
|
| 475 |
adapter_mask: Optional[torch.Tensor] = None,
|
| 476 |
-
|
| 477 |
**tokenizer_kwargs,
|
| 478 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 479 |
"""
|
|
|
|
| 473 |
normalize_embeddings: bool = True,
|
| 474 |
truncate_dim: Optional[int] = None,
|
| 475 |
adapter_mask: Optional[torch.Tensor] = None,
|
| 476 |
+
task: Optional[str] = None,
|
| 477 |
**tokenizer_kwargs,
|
| 478 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 479 |
"""
|