from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import BridgeTowerPreTrainedModel, BridgeTowerModel
from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerTextModel

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        # Compute the normalization in float32 for stability, then cast back to the input dtype.
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
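
# Hedged sketch (not part of the original file): the LayerNorm above can take fp16
# activations while keeping its own weights in fp32; the normalization runs in float32
# and the result is cast back to the activation dtype.
def _example_fp16_layernorm():
    ln = LayerNorm(8)                           # fp32 weights
    x = torch.randn(2, 8, dtype=torch.float16)  # fp16 activations
    y = ln(x)
    print(y.dtype)  # torch.float16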

class BridgeTowerImageFeatureExtractor(nn.Module):
    def __init__(
        self,
        patch_size=14,
        width=1024,
        resolution_after=294,
        ckpt_path=None,
    ):
        super().__init__()
        # ViT-style stem: a strided convolution that splits the image into non-overlapping patches.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((resolution_after // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        # Optionally load pretrained stem weights from a checkpoint file.
        if ckpt_path is not None:
            sd = torch.load(ckpt_path)
            if "state_dict" in sd:
                sd = sd["state_dict"]
            print(f"Loading feature extractor checkpoint from {ckpt_path}")
            self.load_state_dict(sd)
    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the learned class token to every sequence in the batch.
        t = self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([t, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        return x
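
# Hedged usage sketch (not part of the original file): a quick shape check for the image
# feature extractor above. With the defaults patch_size=14 and resolution_after=294 the
# patch grid is 294 // 14 = 21, so the output carries 21 ** 2 + 1 = 442 tokens in LND layout.
def _example_image_feature_shapes():
    extractor = BridgeTowerImageFeatureExtractor()
    dummy = torch.randn(2, 3, 294, 294)  # a batch of two RGB images
    feats = extractor(dummy)
    print(feats.shape)  # expected: torch.Size([442, 2, 1024])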

class BridgeTowerITCHead(nn.Module):
    def __init__(self, hidden_size, embed_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        x = self.fc(x)
        return x

class _BridgeTowerTextModelWrapper(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.text_model = BridgeTowerTextModel(config)

    def forward(self, **kwargs):
        return self.text_model(**kwargs)

class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = _BridgeTowerTextModelWrapper(config.text_config)
        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        outputs = self.bridgetower(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        final_hidden_cls = outputs.hidden_states[-1][:, 0, :]
        # Project the final [CLS] state into the contrastive space and L2-normalize it.
        final_hidden_cls = F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
        return final_hidden_cls
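
# Hedged usage sketch (not part of the original file): extracting a unit-normalized text
# embedding with the text-only tower above. The checkpoint name and processor usage are
# assumptions; any itc_text_head weights missing from the checkpoint are newly initialized,
# so a fine-tuned or converted checkpoint is needed for meaningful similarity scores.
def _example_text_embedding():
    from transformers import BridgeTowerProcessor

    name = "BridgeTower/bridgetower-large-itm-mtl-itc"  # assumed checkpoint
    processor = BridgeTowerProcessor.from_pretrained(name)
    model = BridgeTowerTextFeatureExtractor.from_pretrained(name)
    inputs = processor.tokenizer("a photo of a cat", return_tensors="pt")
    with torch.no_grad():
        embedding = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    print(embedding.shape)  # (1, config.contrastive_hidden_size)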

class BridgeTowerForITC(BridgeTowerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = BridgeTowerModel(config)

        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        assert output_hidden_states, 'output_hidden_states should be set to True for BridgeTowerForITC'

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooler_output = outputs.pooler_output if return_dict else outputs[2]

        hidden_states_txt, hidden_states_img, hidden_states_cross_modal = outputs.hidden_states

        final_hidden_txt = hidden_states_txt[-1]
        final_hidden_img = hidden_states_img[-1]

        # Apply the vision tower's post-layernorm, map the image tokens into the cross-modal
        # space, and add the image token-type embedding.
        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(final_hidden_img)
        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
        ).expand_as(image_embeds_with_ln)

        final_hidden_img = (
            self.bridgetower.cross_modal_image_transform(image_embeds_with_ln)
            + image_token_type_embeddings
        )

        # Project each [CLS]/pooled representation into the shared contrastive space and L2-normalize.
        final_hidden_txt = F.normalize(self.itc_text_head(final_hidden_txt[:, 0, :]), dim=-1, p=2)
        final_hidden_img = F.normalize(self.itc_image_head(final_hidden_img[:, 0, :]), dim=-1, p=2)
        final_hidden_cross = F.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)

        # Stack as [text, image, cross-modal] along dim=-2: shape (batch, 3, contrastive_hidden_size).
        logits = torch.stack([final_hidden_txt, final_hidden_img, final_hidden_cross], dim=-2)

        if not return_dict:
            return tuple(logits)

        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
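
# Hedged end-to-end sketch (not part of the original file): score an image/text pair with
# BridgeTowerForITC. The checkpoint name is an assumption; `logits` stacks the unit-normalized
# [text, image, cross-modal] embeddings along dim=-2 (see forward above), so a dot product
# between the text and image rows is a cosine similarity.
def _example_itc_similarity(image, text="a photo of a cat"):
    from transformers import BridgeTowerProcessor

    name = "BridgeTower/bridgetower-large-itm-mtl-itc"  # assumed checkpoint
    processor = BridgeTowerProcessor.from_pretrained(name)
    model = BridgeTowerForITC.from_pretrained(name)
    batch = processor(image, text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**batch, output_hidden_states=True)
    text_embed = outputs.logits[:, 0, :]
    image_embed = outputs.logits[:, 1, :]
    return (text_embed * image_embed).sum(dim=-1)  # cosine similarity per pair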