Commit
·
ee8863c
1
Parent(s):
4b000ec
feat: matryoshka embeddings
Browse files
Signed-off-by: jupyterjazz <[email protected]>
- configuration_xlm_roberta.py +2 -0
- modeling_xlm_roberta.py +14 -0
configuration_xlm_roberta.py
CHANGED
|
@@ -31,6 +31,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
| 31 |
use_flash_attn=True,
|
| 32 |
torch_dtype=None,
|
| 33 |
emb_pooler=None,
|
|
|
|
| 34 |
**kwargs,
|
| 35 |
):
|
| 36 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
@@ -59,6 +60,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
| 59 |
self.lora_main_params_trainable = lora_main_params_trainable
|
| 60 |
self.use_flash_attn = use_flash_attn
|
| 61 |
self.emb_pooler = emb_pooler
|
|
|
|
| 62 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
| 63 |
self.torch_dtype = getattr(torch, torch_dtype)
|
| 64 |
else:
|
|
|
|
| 31 |
use_flash_attn=True,
|
| 32 |
torch_dtype=None,
|
| 33 |
emb_pooler=None,
|
| 34 |
+
matryoshka_dimensions=None,
|
| 35 |
**kwargs,
|
| 36 |
):
|
| 37 |
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
|
|
| 60 |
self.lora_main_params_trainable = lora_main_params_trainable
|
| 61 |
self.use_flash_attn = use_flash_attn
|
| 62 |
self.emb_pooler = emb_pooler
|
| 63 |
+
self.matryoshka_dimensions = matryoshka_dimensions
|
| 64 |
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
| 65 |
self.torch_dtype = getattr(torch, torch_dtype)
|
| 66 |
else:
|
modeling_xlm_roberta.py
CHANGED
|
@@ -452,6 +452,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 452 |
convert_to_tensor: bool = False,
|
| 453 |
device: Optional[torch.device] = None,
|
| 454 |
normalize_embeddings: bool = False,
|
|
|
|
| 455 |
**tokenizer_kwargs,
|
| 456 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 457 |
"""
|
|
@@ -481,6 +482,8 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 481 |
If set to true, returned vectors will have length 1. In that case, the
|
| 482 |
faster dot-product (util.dot_score) instead of cosine similarity can
|
| 483 |
be used.
|
|
|
|
|
|
|
| 484 |
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
| 485 |
Keyword arguments for the tokenizer
|
| 486 |
Returns:
|
|
@@ -575,6 +578,17 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 575 |
|
| 576 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
if convert_to_tensor:
|
| 579 |
all_embeddings = torch.stack(all_embeddings)
|
| 580 |
elif convert_to_numpy:
|
|
|
|
| 452 |
convert_to_tensor: bool = False,
|
| 453 |
device: Optional[torch.device] = None,
|
| 454 |
normalize_embeddings: bool = False,
|
| 455 |
+
truncate_dim: int = None,
|
| 456 |
**tokenizer_kwargs,
|
| 457 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 458 |
"""
|
|
|
|
| 482 |
If set to true, returned vectors will have length 1. In that case, the
|
| 483 |
faster dot-product (util.dot_score) instead of cosine similarity can
|
| 484 |
be used.
|
| 485 |
+
truncate_dim(`int`, *optional*, defaults to None):
|
| 486 |
+
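The dimension to truncate sentence embeddings to. `None` does no truncation. Must be one of the Matryoshka dimensions listed in the model config, otherwise a `ValueError` is raised; if the config lists none, a warning is logged and no truncation is performed.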
|
| 487 |
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
| 488 |
Keyword arguments for the tokenizer
|
| 489 |
Returns:
|
|
|
|
| 578 |
|
| 579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 580 |
|
| 581 |
+
if truncate_dim:
|
| 582 |
+
if not self.config.matryoshka_dimensions:
|
| 583 |
+
logger.warning(
|
| 584 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
| 585 |
+
)
|
| 586 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
| 587 |
+
all_embeddings = [tensor[:truncate_dim] for tensor in all_embeddings]
|
| 588 |
+
else:
|
| 589 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
| 590 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
| 591 |
+
|
| 592 |
if convert_to_tensor:
|
| 593 |
all_embeddings = torch.stack(all_embeddings)
|
| 594 |
elif convert_to_numpy:
|