Update modeling_xlm_roberta.py
modeling_xlm_roberta.py (+4 -3)
@@ -61,7 +61,7 @@ except ImportError:
 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
 except ImportError:
-    CrossEntropyLoss = None
+    CrossEntropyLoss = torch.nn.CrossEntropyLoss
 
 try:
     from tqdm.autonotebook import trange
@@ -1168,14 +1168,15 @@ class XLMRobertaClassificationHead(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        self.dense = linear_cls(config.hidden_size, config.hidden_size)
         classifier_dropout = (
             config.classifier_dropout
             if config.classifier_dropout is not None
             else config.hidden_dropout_prob
         )
         self.dropout = nn.Dropout(classifier_dropout)
-        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+        self.out_proj = linear_cls(config.hidden_size, config.num_labels)
 
     def forward(self, features, **kwargs):
         x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
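The first hunk changes the fallback used when flash-attn is unavailable: instead of leaving CrossEntropyLoss bound to a non-callable placeholder, the name falls back to the stock torch.nn.CrossEntropyLoss, so downstream code constructs and calls the loss the same way either way. A minimal sketch of the pattern (shapes and values below are illustrative, not from this repo):

    import torch

    # Prefer flash-attn's fused cross-entropy kernel; fall back to the
    # stock PyTorch implementation when flash-attn is not installed.
    try:
        from flash_attn.losses.cross_entropy import CrossEntropyLoss
    except ImportError:
        CrossEntropyLoss = torch.nn.CrossEntropyLoss

    # Either class is constructed and called identically downstream:
    loss_fn = CrossEntropyLoss(ignore_index=-100)
    logits = torch.randn(4, 10)          # (batch, num_classes)
    labels = torch.randint(0, 10, (4,))  # target class ids
    loss = loss_fn(logits, labels)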
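The second hunk routes both linear layers of the classification head through a single linear_cls, picking flash-attn's FusedDense kernel when the fused_bias_fc flag is set and plain nn.Linear otherwise. A self-contained sketch of the idea; the guarded import and flag below are hypothetical stand-ins for how the rest of this file resolves them, and FusedDense itself expects CUDA tensors at runtime:

    import torch
    import torch.nn as nn

    # Hypothetical guard: in the actual file, `fused_bias_fc` comes from
    # the model config and FusedDense from flash-attn's fused dense op.
    try:
        from flash_attn.ops.fused_dense import FusedDense
        fused_bias_fc = True
    except ImportError:
        FusedDense = None
        fused_bias_fc = False

    class ClassificationHeadSketch(nn.Module):
        """Illustrative stand-in for XLMRobertaClassificationHead."""

        def __init__(self, hidden_size, num_labels, dropout=0.1):
            super().__init__()
            # Same selection logic as the diff: fused kernel when available.
            linear_cls = nn.Linear if not fused_bias_fc else FusedDense
            self.dense = linear_cls(hidden_size, hidden_size)
            self.dropout = nn.Dropout(dropout)
            self.out_proj = linear_cls(hidden_size, num_labels)

        def forward(self, features, **kwargs):
            x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
            x = self.dropout(x)
            x = torch.tanh(self.dense(x))
            x = self.dropout(x)
            return self.out_proj(x)

    # (batch, seq_len, hidden) -> (batch, num_labels)
    head = ClassificationHeadSketch(hidden_size=768, num_labels=3)
    print(head(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 3])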