| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from abc import ABC, abstractmethod |
| | import re |
| | import torch |
| | import torch.nn as nn |
| | import random |
| | from typing import List, Optional, Tuple, Union, Dict |
| |
|
| | from transformers import AutoConfig, AutoModelForCausalLM |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from transformers.generation.utils import GenerateOutput |
| | from transformers import Qwen2Config |
| |
|
| | from .vision_tower_builder import build_vision_tower |
| | from .mm_projector_builder import build_vision_projector |
| |
|
| | from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN |
| | from .conversation import conv_templates, SeparatorStyle |
| | from .mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, get_anyres_image_grid_shape, load_video |
| | from .modeling_qwen2_flash import Qwen2Model_Flash, Qwen2ForCausalLM_Flash |
| |
|
| |
|
class LlavaMetaModel:
    """Mixin for the backbone model that owns the multimodal components:
    the vision tower, the multimodal projector, and the optional learned
    "newline" separator parameters inserted between spatial/temporal chunks.

    Expected to be combined (via multiple inheritance) with a language-model
    class whose ``__init__(config)`` is reached through ``super()``.
    """

    def __init__(self, config):
        # Cooperative init: forwards to the sibling LM base class in the MRO.
        super(LlavaMetaModel, self).__init__(config)

        # Vision modules are only built when the config was saved with a
        # vision tower (i.e. loading a multimodal checkpoint, not a bare LM).
        if hasattr(config, "mm_vision_tower"):
            delay_load = getattr(config, "delay_load", False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

            # Learned separator tokens, created empty here; real values come
            # from the checkpoint's state_dict on load.
            if "unpad" in getattr(config, "mm_patch_merge_type", ""):
                self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
            if "nopad" in getattr(config, "mm_patch_merge_type", "") and getattr(self.config, "mm_newline_position", "nothing") != "nothing":
                self.frame_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))

    def get_vision_tower(self):
        """Return the vision tower module (unwrapping the 1-element list used
        under FSDP), or None if no tower has been initialized."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build or load the vision tower and projector for training.

        Mirrors the relevant ``model_args`` into ``self.config`` so they are
        persisted with the checkpoint. When ``fsdp`` is a non-empty list the
        tower is wrapped in a list (FSDP sharding convention). Optionally
        restores projector weights from ``pretrain_mm_mlp_adapter``.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        if self.get_vision_tower() is None:
            # First-time setup: build a fresh tower from the training args.
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            # Tower object already exists (e.g. delay_load); load its weights now.
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        # Persist projector/selection settings on the config for checkpointing.
        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        if getattr(self, "mm_projector", None) is None:
            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)

            # Newline parameters are initialized from N(0, 1/sqrt(hidden_size)),
            # matching the scale of embedding entries.
            if "unpad" in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
            if "nopad" in getattr(self.config, "mm_patch_merge_type", "") and getattr(self.config, "mm_newline_position", "nothing") != "nothing":
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.frame_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
        else:
            # Projector already exists (resumed run): make sure it is trainable.
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")

            def get_w(weights, keyword):
                # Strip the "<prefix>.<keyword>." namespace from checkpoint keys.
                # NOTE(review): assumes every matching key contains "<keyword>." —
                # a key containing the bare keyword without a dot would IndexError.
                return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}

            if self.config.mm_projector_type =='lxh_qformer':
                # Q-Former projector may legitimately miss/extra some keys.
                incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"), strict=False)
            else:
                incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
            print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
| |
|
| |
|
class LlavaMetaForCausalLM(ABC):
    """Mixin for the causal-LM head class that provides the multimodal
    plumbing: vision encoding, splicing vision features into the token
    embedding stream, and tokenizer extension for image tokens."""

    @abstractmethod
    def get_model(self):
        """Return the backbone model (must expose ``embed_tokens``,
        ``mm_projector`` and ``get_vision_tower``)."""
        pass

    def get_vision_tower(self):
        # Delegate to the backbone, which owns the vision tower.
        return self.get_model().get_vision_tower()

    def encode_video_image(self, images_list, video_idx_in_batch):
        """Encode a mixed batch of images and videos with the shared vision
        tower, then project each item with the mm_projector.

        images_list: per-sample pixel tensors; entries whose index appears in
        ``video_idx_in_batch`` are treated as videos (frames stacked on dim 0),
        the rest as images (patches stacked on dim 0).
        Returns one projected feature tensor per sample, in original order.
        """
        bs = len(images_list)

        # Partition the batch so images and videos each go through the tower
        # in a single batched call.
        concat_images = []
        concat_videos = []
        for idx, image in enumerate(images_list):
            if idx in video_idx_in_batch:
                concat_videos.append(image)
            else:
                concat_images.append(image)

        has_image = len(concat_images) > 0
        has_video = len(concat_videos) > 0

        # Frames per local group for video encoding; must be set in the config.
        mm_local_num_frames = getattr(self.config, "mm_local_num_frames", -1)
        assert mm_local_num_frames != -1
        if has_image:
            # Add a singleton frame dimension so images share the video layout.
            image_split_sizes = [image.shape[0] for image in concat_images]
            concat_images = torch.cat([image.unsqueeze(1) for image in concat_images], dim=0)

            images_features = self.get_model().get_vision_tower()(concat_images)
            images_features = torch.split(images_features, image_split_sizes)

        if has_video:
            # Regroup each video's frames into chunks of mm_local_num_frames.
            # NOTE(review): assumes frame counts divide evenly — confirm upstream.
            video_split_sizes = [video.shape[0] // mm_local_num_frames for video in concat_videos]
            concat_videos = torch.cat([video.reshape(video.shape[0] // mm_local_num_frames, mm_local_num_frames, video.shape[1], video.shape[2], video.shape[3]) for video in concat_videos], dim=0)

            videos_features = self.get_model().get_vision_tower()(concat_videos)
            # Back to per-video tensors: (-1, tokens_per_frame, dim).
            videos_features = [v.reshape(-1, v.shape[-2] // mm_local_num_frames, v.shape[-1]) for v in torch.split(videos_features, video_split_sizes)]

        # Re-interleave projected features in the original batch order.
        all_videos_or_images_features = []
        img_idx = 0
        vid_idx = 0

        for idx in range(bs):
            if idx in video_idx_in_batch:
                # Videos are projected with temporal compression enabled.
                feat = self.get_model().mm_projector(videos_features[vid_idx], compress=True, local_num_frames=getattr(self.config, "mm_local_num_frames", -1))
                vid_idx += 1
            else:
                feat = self.get_model().mm_projector(images_features[img_idx], compress=False)
                img_idx += 1
            all_videos_or_images_features.append(feat)

        # Sanity: every encoded feature must be consumed exactly once.
        if has_video:
            assert vid_idx == len(videos_features), f"vid: {vid_idx} != {len(videos_features)}"
        if has_image:
            assert img_idx == len(images_features), f"img: {img_idx} != {len(images_features)}"

        return all_videos_or_images_features

    # NOTE(review): mutable default for `modalities` — safe only while no
    # caller mutates the list in place.
    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
        """Replace IMAGE_TOKEN_INDEX placeholders with projected vision
        features, producing ``inputs_embeds`` instead of ``input_ids``.

        Returns ``(None, position_ids, attention_mask, past_key_values,
        inputs_embeds, labels)``; when there is nothing multimodal to do
        (no tower, no images, or a 1-token decoding step) the original
        arguments are passed through with ``inputs_embeds=None``.
        """
        assert type(modalities) is list, modalities

        vision_tower = self.get_vision_tower()

        # Text-only calls and decoding steps after the first skip vision work.
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                # Promote bare (C, H, W) tensors to 4-D with a leading patch dim.
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

            # Which batch positions are videos (drives the encode/merge paths).
            video_idx_in_batch = []
            for _ in range(len(modalities)):
                if modalities[_] == "video":
                    video_idx_in_batch.append(_)

            images_list = []
            for image in images:
                if image.ndim == 4:
                    images_list.append(image)
                else:
                    images_list.append(image.unsqueeze(0))

            vision_encode_type = getattr(self.config, "vision_encode_type", "image")
            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
            frame_aspect_ratio = getattr(self.config, "frame_aspect_ratio", "square")
            mm_newline_position = getattr(self.config, "mm_newline_position", "nothing")

            if vision_encode_type == "video_image":
                image_features = self.encode_video_image(images_list, video_idx_in_batch=video_idx_in_batch)
            else:
                raise NotImplementedError(vision_encode_type)

            if mm_patch_merge_type == "flat":
                # Collapse (patches, tokens, dim) -> (patches*tokens, dim).
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):

                    if image_idx in video_idx_in_batch:
                        # --- video sample ---
                        if "anyres" in frame_aspect_ratio:
                            raise NotImplementedError
                        else:
                            frame_feature = image_feature

                        if "pad" in mm_patch_merge_type:
                            if mm_newline_position == 'one_token':
                                # Flatten frames, then append one learned
                                # newline token at the end of the video.
                                frame_feature = frame_feature.flatten(0, 1)
                                if "unpad" in mm_patch_merge_type:
                                    frame_feature = torch.cat((frame_feature, self.model.image_newline[None].to(frame_feature.device)), dim=0)
                                else:
                                    frame_feature = torch.cat((frame_feature, self.model.frame_newline[None].to(frame_feature.device)), dim=0)
                            elif mm_newline_position == 'nothing':
                                frame_feature = frame_feature.flatten(0, 1)
                            else:
                                raise NotImplementedError("add pad please!!")
                        else:
                            frame_feature = frame_feature.flatten(0, 1)

                        image_feature = frame_feature

                    elif image_feature.shape[0] > 1:
                        # --- multi-patch (anyres) image: patch 0 is the base view ---
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        origin_size = image_feature.shape

                        # Tokens per side of one patch, from the projector.
                        height = width = self.get_model().mm_projector.num_image_patches_per_side
                        assert height * width == base_image_feature.shape[0], f"height:{height}, width: {width}, base_image_feature: {base_image_feature.shape}"

                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))

                        if "anyres" in image_aspect_ratio:
                            if hasattr(self.get_vision_tower(), "image_size"):
                                vision_tower_image_size = self.get_vision_tower().image_size
                            else:
                                raise ValueError("vision_tower_image_size is not found in the vision tower.")
                            try:
                                num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size, max_resolutions=None)
                            except Exception as e:
                                print(f"Error: {e}")
                                raise e

                            # Arrange patches on their 2-D grid:
                            # (grid_h, grid_w, tokens_h, tokens_w, dim).
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError(image_aspect_ratio)
                            # NOTE(review): unreachable after the raise above.
                            image_feature = image_feature.view(2, 2, height, width, -1)

                        if "maxpool2x2" in mm_patch_merge_type:
                            raise NotImplementedError
                        elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                            raise NotImplementedError
                        elif "unpad" in mm_patch_merge_type:
                            raise NotImplementedError
                        else:
                            # Row-major merge of the patch grid into one
                            # flat token sequence.
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        if "nobase" in mm_patch_merge_type:
                            pass
                        else:
                            # Prepend the base (whole-image) view's tokens.
                            try:
                                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                            except Exception as e:
                                raise ValueError(f"{num_patch_width} {num_patch_height} now: base_image_feature: {base_image_feature.shape}, {image_feature.shape}, image_sizes[image_idx]: {image_sizes[image_idx]}, origin_size: {origin_size}, {image_sizes[image_idx]}, {self.config.image_grid_pinpoints}, {vision_tower_image_size}")
                    else:
                        # --- single-patch image ---
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)

                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            # Plain batched-image path (no per-sample list, no video).
            image_features = self.encode_image(images)

        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Remember which optional inputs were None so they can be returned
        # as None again, then normalize them for internal use.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Strip padding; from here on, each batch item has its own length.
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0

        # Optional in-LLM visual-token compression: publish its schedule and
        # bookkeeping on the backbone model for the attention layers to read.
        mm_llm_compress = getattr(self.config, "mm_llm_compress", False)

        if mm_llm_compress:
            self.model.llm_compress_type = getattr(self.config, "llm_compress_type", "attention")
            self.model.llm_compress_layer_list = getattr(self.config, "llm_compress_layer_list", [8, 16, 24])
            self.model.llm_image_token_ratio_list = getattr(self.config, "llm_image_token_ratio_list", [1.0, 0.5, 0.25, 0.125])
            first_image_token_position = []
            text_prompt_lens = []
        else:
            # Compression disabled: clear the schedule.
            self.model.llm_compress_type = "attention"
            self.model.llm_compress_layer_list = []
            self.model.llm_image_token_ratio_list = []
            first_image_token_position = []
            text_prompt_lens = []

        # Splice vision features into each sample at its IMAGE_TOKEN_INDEX
        # placeholders, building per-sample embedding/label sequences.
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()

            if mm_llm_compress:
                # Record where the (single) image placeholder sits.
                image_index = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
                assert len(image_index) == 1, f"Only support singe/video: {image_index}"
                # NOTE(review): with the assert above, image_index can never be
                # [] here — this branch looks dead.
                if image_index == []:
                    first_image_token_position.append(-1)
                else:
                    first_image_token_position.append(image_index[0])

                if not self.training:
                    if image_index == []:
                        assert num_images == 0, num_images
                    else:
                        assert num_images == 1, f"num_images={num_images}"
                    # Length of the text-only part of the prompt.
                    text_prompt_lens.append(cur_input_ids.shape[0] - num_images)

            if num_images == 0:
                # Text-only sample: concatenate a zero-width slice of an image
                # feature so the projector still participates in the graph
                # (presumably to keep gradients/FSDP happy — confirm).
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Sentinel-bounded placeholder positions -> text segments between them.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text segments in one call, then split back per segment.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # Interleave: text segment, image features, text segment, ...
            # Image positions get IGNORE_INDEX labels (no loss on vision tokens).
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    try:
                        cur_image_features = image_features[cur_image_idx]
                    except IndexError:
                        # NOTE(review): silently reuses the previous feature on
                        # overflow instead of failing — confirm this is intended.
                        print(f"cur_image_idx={cur_image_idx} is not ok")
                        cur_image_features = image_features[cur_image_idx - 1]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        if mm_llm_compress:
            # Hand the per-sample bookkeeping to the backbone's layers.
            self.model.first_image_token_position = first_image_token_position
            self.model.text_prompt_lens = text_prompt_lens
            self.model.num_image_token_lens = [image_feature.shape[0] for image_feature in image_features]

        # Truncate sequences to the tokenizer limit (a None limit is a no-op slice).
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)

        # NOTE(review): zipping with `modalities` truncates the batch if
        # modalities is shorter than new_input_embeds — confirm callers always
        # pass one modality per sample.
        new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
        new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]

        # Re-pad the ragged per-sample sequences into batch tensors.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                # Left padding: content is right-aligned.
                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                # Right padding (default): content is left-aligned.
                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Restore None for any input the caller originally passed as None.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None
        # Position-skipping augmentation (training only): shift position ids by
        # random offsets on either side of a random split point.
        if getattr(self.config, "use_pos_skipping", False) and self.training:
            position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
            split_position = random.randint(0, new_input_embeds.size(1))
            left_add = random.randint(0, self.config.pos_skipping_range)
            right_add = random.randint(left_add, self.config.pos_skipping_range)
            position_ids[:, :split_position] += left_add
            position_ids[:, split_position:] += right_add

        # input_ids are returned as None: the model must consume inputs_embeds.
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Extend the tokenizer/embeddings with the special image tokens and
        set their training flags according to ``model_args``."""
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialize the new rows with the mean of existing embeddings.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Adapter tuning: train input embeddings only, freeze the head.
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # Restore the start/end token embeddings saved with the adapter.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                # Patch-token-only mode with adapter tuning: freeze both embeddings.
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
| |
|
| |
|
| |
|
class VideoChatFlashQwenConfig(Qwen2Config):
    """Qwen2 config subclass carrying the custom ``model_type`` under which
    this architecture is registered with AutoConfig/AutoModelForCausalLM."""
    model_type = "videochat_flash_qwen"
| |
|
| |
|
class VideoChatFlashQwenModel(LlavaMetaModel, Qwen2Model_Flash):
    """Backbone model: the LlavaMetaModel vision mixin layered over the
    flash-attention Qwen2 model."""

    config_class = VideoChatFlashQwenConfig

    def __init__(self, config: VideoChatFlashQwenConfig):
        # Cooperative init walks the MRO: LlavaMetaModel.__init__ builds the
        # vision modules and chains into Qwen2Model_Flash.__init__.
        super().__init__(config)
| |
|
| |
|
class VideoChatFlashQwenForCausalLM(LlavaMetaForCausalLM, Qwen2ForCausalLM_Flash):
    """Causal LM head for VideoChat-Flash: multimodal input preparation from
    LlavaMetaForCausalLM on top of the flash-attention Qwen2 LM."""
    config_class = VideoChatFlashQwenConfig

    def __init__(self, config):
        # Deliberately bypasses LlavaMetaForCausalLM in the init chain; the
        # multimodal state lives on self.model (VideoChatFlashQwenModel).
        Qwen2ForCausalLM_Flash.__init__(self, config)
        config.model_type = "videochat_flash_qwen"

        self.model = VideoChatFlashQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight init / tying per the transformers PreTrainedModel contract.
        self.post_init()

    def get_model(self):
        """Return the backbone model used by the multimodal mixin."""
        return self.model

    # NOTE(review): mutable default for `modalities` — safe only while callers
    # never mutate it.
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = ["image"],
        dpo_forward: Optional[bool] = False,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Standard Qwen2 forward, except that when ``inputs_embeds`` is not
        supplied the image/video features are spliced into the embeddings
        first via ``prepare_inputs_labels_for_multimodal``."""

        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)

        if dpo_forward:
            raise NotImplementedError
        else:
            # Delegate to the plain LM forward with multimodal embeddings.
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Multimodal generation entry point: converts ids (+ images) to
        ``inputs_embeds`` and defers to the HF generation loop."""
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            # past_key_values/labels are None at generation prefill time.
            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
        else:
            # Text-only: reset the compression bookkeeping the backbone reads.
            self.model.image_token_posi = [-1]
            self.model.prompt_len = None
            self.model.image_tokens = [0]
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    @torch.no_grad()
    def chat(self,
             video_path,
             tokenizer,
             user_prompt,
             chat_history=None,
             return_history=True,
             max_num_frames=512,
             media_dict=None,
             generation_config={}):
        """Single-turn (or continued) video chat convenience wrapper.

        Loads the video, builds a qwen_2 conversation prompt, generates, and
        returns the decoded reply (optionally with updated chat history).
        NOTE(review): mutable default for `generation_config`; also assumes a
        CUDA device (`.cuda()` calls below).
        """

        # load_video returns frames and a human-readable timing/frame message.
        frames, time_msg = load_video(video_path, max_num_frames=max_num_frames, media_dict=media_dict)

        image_sizes = [frames[0].shape[:2]]

        # Preprocess all frames as one video tensor in a single-element list.
        frames = [self.get_vision_tower().image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(self.model.dtype).cuda()]

        conv = conv_templates["qwen_2"].copy()

        if chat_history is None or len(chat_history) == 0:
            # First turn: prepend the image token and timing message.
            user_prompt = f'{DEFAULT_IMAGE_TOKEN}\n{time_msg.strip()} {user_prompt}'
        else:
            # Continued chat: the image token must already be in turn 0.
            assert DEFAULT_IMAGE_TOKEN in chat_history[0]['content'], chat_history
            for msg in chat_history:
                conv.append_message(msg['role'], msg['content'])

        conv.append_message(conv.roles[0], user_prompt)
        conv.append_message(conv.roles[1], None)

        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

        if tokenizer.pad_token_id is None:
            if "qwen" in tokenizer.name_or_path.lower():
                print("Setting pad token to bos token for qwen model.")
                # 151643 is Qwen's <|endoftext|> id.
                tokenizer.pad_token_id = 151643

        attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()

        # Stop generation on the conversation separator.
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.generate(
                inputs=input_ids,
                images=frames,
                attention_mask=attention_masks,
                modalities=["video"],
                image_sizes=image_sizes,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                **generation_config
            )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        # Trim a trailing separator the stopping criteria may have included.
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]

        outputs = outputs.strip()

        if chat_history is None:
            chat_history = []

        chat_history.append({"role":conv.roles[0], "content":user_prompt})
        chat_history.append({"role":conv.roles[1], "content":outputs})
        if return_history:
            return outputs, chat_history
        else:
            return outputs

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Thread `images`/`image_sizes` through the HF generation step inputs."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
| |
|
| |
|
# Register the config/model pair so AutoConfig / AutoModelForCausalLM can
# resolve checkpoints whose config declares model_type "videochat_flash_qwen".
AutoConfig.register("videochat_flash_qwen", VideoChatFlashQwenConfig)
AutoModelForCausalLM.register(VideoChatFlashQwenConfig, VideoChatFlashQwenForCausalLM)