from __future__ import annotations import dataclasses import glob from collections.abc import Callable from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union, cast import numpy as np import torch import torch.nn as nn from einops import rearrange from safetensors.torch import load_file as safetensors_load from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation.utils import GenerationMixin from transformers.integrations import use_kernel_forward_from_hub from transformers.masking_utils import create_causal_mask from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPooling, ) from transformers.modeling_rope_utils import ( ROPE_INIT_FUNCTIONS, dynamic_rope_update, ) from transformers.modeling_utils import ( ALL_ATTENTION_FUNCTIONS, PreTrainedModel, ) from transformers.processing_utils import Unpack from transformers.utils import ( ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, ) from transformers.utils.deprecation import deprecate_kwarg try: from transformers.utils.generic import check_model_inputs except ImportError: def check_model_inputs(*args, **kwargs): def _wrap(fn): return fn return _wrap from .configuration_yasa2 import ConvNextConfig, Yasa2Config, YasaConfig logger = logging.get_logger(__name__) # ---- Model outputs ---- @dataclasses.dataclass class Yasa2ModelOutputWithPast(BaseModelOutputWithPast): """ Base class for Yasa2 model outputs with past key values. Args: last_hidden_state (`torch.FloatTensor`, *optional*): Last hidden state of the model. past_key_values (`Cache`, *optional*): Cache of key/value tensors for each layer. hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of hidden states from the model. 
attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of attention maps from the model. """ last_hidden_state: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None vision_hidden_states: Optional[torch.FloatTensor] = None @dataclasses.dataclass class Yasa2ForConditionalGenerationModelOutput(ModelOutput): """ Outputs for Yasa2 conditional generation. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True`): Cache of key/value tensors for each layer. hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of hidden states from the language model. attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of attention maps from the language model. vision_hidden_states (`torch.FloatTensor`, *optional*): Vision embeddings after projection and pooling. language_model_outputs (`Yasa2ModelOutputWithPast`, *optional*): The full language model outputs. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None vision_hidden_states: Optional[torch.FloatTensor] = None language_model_outputs: Optional[Yasa2ModelOutputWithPast] = None # ---- Utilities ---- def get_2d_sincos_pos_embed( embed_dim: int, image_size: int | tuple[int, int] ) -> np.ndarray: """Generate 2D sincos positional embeddings for a vision grid. Args: embed_dim (int): Embedding dimension. 
image_size (int | tuple[int, int]): Image size as an int or (height, width) tuple. Returns: np.ndarray: Positional embedding array of shape (H*W, embed_dim). """ if isinstance(image_size, int): grid_h_size, grid_w_size = image_size, image_size else: grid_h_size, grid_w_size = image_size[0], image_size[1] grid_h = np.arange(grid_h_size, dtype=np.float32) grid_w = np.arange(grid_w_size, dtype=np.float32) # Build a meshgrid of spatial coordinates to compute positional embeddings. grid = np.meshgrid(grid_w, grid_h) grid = np.stack(grid, axis=0) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) return pos_embed def get_2d_sincos_pos_embed_from_grid( embed_dim: int, grid: np.ndarray ) -> np.ndarray: """Generate 2D sincos positional embeddings from a coordinate grid. Args: embed_dim (int): Embedding dimension. grid (np.ndarray): Grid array of shape (2, H, W). Returns: np.ndarray: Positional embedding array of shape (H, W, embed_dim). """ assert embed_dim % 2 == 0 emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) emb = np.concatenate([emb_h, emb_w], axis=-1) return emb def get_1d_sincos_pos_embed_from_grid( embed_dim: int, pos: np.ndarray ) -> np.ndarray: """Generate 1D sincos positional embeddings from a positional array. Args: embed_dim (int): Embedding dimension. pos (np.ndarray): Position grid array for one dimension. Returns: np.ndarray: Positional embedding array with sin/cos features. """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega out = np.einsum("hw,d->hwd", pos, omega) emb_sin = np.sin(out) emb_cos = np.cos(out) emb = np.concatenate([emb_sin, emb_cos], axis=-1) return emb # ---- ConvNeXt V2 backbone ---- def drop_path( input: torch.Tensor, drop_prob: float = 0.0, training: bool = False ) -> torch.Tensor: """Apply stochastic depth (drop path) to the input tensor. 
Args: input (torch.Tensor): Input tensor to apply drop path to. drop_prob (float): Probability of dropping a path. Defaults to 0.0. training (bool): Whether the model runs in training mode. Defaults to False. Returns: torch.Tensor: Tensor with drop path applied when enabled. """ if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob shape = (input.shape[0],) + (1,) * (input.ndim - 1) # Sample a random tensor that determines which paths to keep per sample. random_tensor = keep_prob + torch.rand( shape, dtype=input.dtype, device=input.device ) random_tensor.floor_() output = input.div(keep_prob) * random_tensor return output class ConvNextDropPath(nn.Module): """Drop paths (stochastic depth) per sample in residual blocks.""" def __init__(self, drop_prob: Optional[float] = None): """Initialize the drop-path module. Args: drop_prob (Optional[float]): Probability of dropping a path. """ super().__init__() self.drop_prob = drop_prob def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Apply drop path to the provided hidden states. Args: hidden_states (torch.Tensor): Tensor to apply stochastic depth to. Returns: torch.Tensor: Tensor after stochastic depth. """ return drop_path(hidden_states, self.drop_prob, self.training) def extra_repr(self) -> str: """Return a string representation for module printing. Returns: str: Description containing the configured drop probability. """ return "p={}".format(self.drop_prob) class ConvNextLayerNorm(nn.Module): r"""LayerNorm that supports channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). """ def __init__( self, normalized_shape: int, eps: float = 1e-6, data_format: str = "channels_last", ) -> None: """Initialize ConvNext LayerNorm. 
Args: normalized_shape (int): Expected shape of the input channels. eps (float): Small epsilon to avoid division by zero. data_format (str): Either 'channels_last' or 'channels_first'. Raises: NotImplementedError: If data_format is not supported. """ super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError( f"Unsupported data format: {self.data_format}" ) self.normalized_shape = (normalized_shape,) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply layer normalization according to the configured data format. Args: x (torch.Tensor): Input tensor of shape (N, C, H, W) or (N, H, W, C). Returns: torch.Tensor: Normalized tensor with the same shape as input. """ if self.data_format == "channels_last": x = nn.functional.layer_norm( x, self.normalized_shape, self.weight, self.bias, self.eps ) elif self.data_format == "channels_first": input_dtype = x.dtype x = x.float() u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) # Compute normalized values in fp32 for stable statistics before restoring dtype. x = (x - u) / torch.sqrt(s + self.eps) x = x.to(dtype=input_dtype) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x class ConvNextV2GRN(nn.Module): """Global Response Normalization (GRN) layer for ConvNeXt V2.""" def __init__(self, dim: int): """Initialize the GRN layer parameters. Args: dim (int): Channel dimension of the input tensor. """ super().__init__() self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim)) self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: """Apply Global Response Normalization to the hidden states. Args: hidden_states (torch.FloatTensor): Input tensor shaped (batch, height, width, channels). 
Returns: torch.FloatTensor: Normalized tensor with the same shape. """ # Compute and normalize global spatial feature maps global_features = torch.norm( hidden_states, p=2, dim=(1, 2), keepdim=True ) norm_features = global_features / ( global_features.mean(dim=-1, keepdim=True) + 1e-6 ) # Combine normalized features with learnable scale and bias. hidden_states = ( self.weight * (hidden_states * norm_features) + self.bias + hidden_states ) return hidden_states class ConvNextEmbeddings(nn.Module): """ConvNeXt patch embedding layer.""" def __init__( self, num_channels: int = 3, hidden_size: int = 96, patch_size: int = 4 ) -> None: """Initialize ConvNeXt patch embeddings. Args: num_channels (int): Number of image channels. Defaults to 3. hidden_size (int): Hidden dimension size. Defaults to 96. patch_size (int): Size of patches for initial convolution. Defaults to 4. """ super().__init__() self.patch_embeddings = nn.Conv2d( num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, ) self.layernorm = ConvNextLayerNorm( hidden_size, eps=1e-6, data_format="channels_first" ) self.num_channels = num_channels def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: """Create patch embeddings from pixel values. Args: pixel_values (torch.FloatTensor): Image tensor shaped (batch, channels, height, width). Returns: torch.Tensor: Embedded tensor after patch convolution. Raises: ValueError: If the channel dimension does not match the expected count. """ num_channels = pixel_values.shape[1] if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
) embeddings = self.patch_embeddings(pixel_values) embeddings = self.layernorm(embeddings) return embeddings class ConvNextLayer(nn.Module): """ConvNeXt V2 layer with GRN.""" def __init__( self, dim: int, drop_path: float = 0, layer_scale_init_value: float = 1e-6, use_grn: bool = True, ) -> None: """Construct a ConvNeXt V2 layer with GRN and scaling. Args: dim (int): Input/output channel dimension. drop_path (float): Drop path probability for stochastic depth. layer_scale_init_value (float): Initial scaling factor for residual branches. use_grn (bool): Whether to enable Global Response Normalization. """ super().__init__() self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) self.layernorm = ConvNextLayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, 4 * dim) self.act = nn.GELU() if not use_grn: raise ValueError("ConvNeXt V2 requires use_grn=True.") self.grn = ConvNextV2GRN(4 * dim) self.pwconv2 = nn.Linear(4 * dim, dim) self.layer_scale_parameter = ( nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True ) if layer_scale_init_value > 0 else None ) self.drop_path = ( ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity() ) def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: """Run the ConvNeXt layer forward. Args: hidden_states (torch.FloatTensor): Input tensor shaped (batch, channels, height, width). Returns: torch.Tensor: Tensor after depthwise conv, GRN, and residual connection. 
""" input = hidden_states x = self.dwconv(hidden_states) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.layernorm(x) x = self.pwconv1(x) x = self.act(x) x = self.grn(x) x = self.pwconv2(x) if self.layer_scale_parameter is not None: x = self.layer_scale_parameter * x x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) x = input + self.drop_path(x) return x class ConvNextStage(nn.Module): """ConvNeXt V2 stage with optional downsampling and residual blocks.""" def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 2, stride: int = 2, depth: int = 2, drop_path_rates: Optional[list[float]] = None, layer_scale_init_value: float = 1e-6, use_grn: bool = True, ) -> None: """Build a ConvNeXt stage that can downsample and stack layers. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. kernel_size (int): Kernel size for stripe downsampling. stride (int): Stride for downsampling. depth (int): Number of layers in the stage. drop_path_rates (Optional[list[float]]): Per-layer drop path rates. layer_scale_init_value (float): Residual scaling initial value. use_grn (bool): Whether to enable GRN. """ super().__init__() if in_channels != out_channels or stride > 1: self.downsampling_layer = nn.Sequential( ConvNextLayerNorm( in_channels, eps=1e-6, data_format="channels_first" ), nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, ), ) else: self.downsampling_layer = nn.Identity() drop_path_rates = drop_path_rates or [0.0] * depth self.layers = nn.Sequential( *[ ConvNextLayer( dim=out_channels, drop_path=drop_path_rates[j], layer_scale_init_value=layer_scale_init_value, use_grn=use_grn, ) for j in range(depth) ] ) def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: """Process a batch through downsampling and residual layers. Args: hidden_states (torch.FloatTensor): Input tensor of shape (batch, channels, height, width). 
Returns: torch.Tensor: Output tensor after the stage. """ hidden_states = self.downsampling_layer(hidden_states) hidden_states = self.layers(hidden_states) return hidden_states class ConvNextEncoder(nn.Module): """ConvNeXt V2 encoder.""" def __init__( self, hidden_sizes: list[int], depths: list[int], drop_path_rate: float = 0.0, layer_scale_init_value: float = 1e-6, use_grn: bool = True, ) -> None: """Construct the ConvNeXt encoder with multiple stages. Args: hidden_sizes (list[int]): Hidden dimensions per stage. depths (list[int]): Number of layers per stage. drop_path_rate (float): Maximum drop path rate (linear schedule). layer_scale_init_value (float): Initial residual scaling. use_grn (bool): Whether to use GRN within layers. """ super().__init__() self.stages = nn.ModuleList() self.gradient_checkpointing = False num_stages = len(hidden_sizes) total_depth = sum(depths) drop_path_schedule = np.linspace( 0.0, float(drop_path_rate), total_depth ).tolist() drop_path_rates = [] start = 0 for depth in depths: end = start + depth drop_path_rates.append(drop_path_schedule[start:end]) start = end # Keep track of the previous stage channel count for connecting stages. prev_chs = hidden_sizes[0] for i in range(num_stages): out_chs = hidden_sizes[i] stage = ConvNextStage( in_channels=prev_chs, out_channels=out_chs, stride=2 if i > 0 else 1, depth=depths[i], drop_path_rates=drop_path_rates[i], layer_scale_init_value=layer_scale_init_value, use_grn=use_grn, ) self.stages.append(stage) prev_chs = out_chs def forward( self, hidden_states: torch.FloatTensor, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, ) -> Tuple: """Forward propagate through the ConvNeXt encoder stack. Args: hidden_states (torch.FloatTensor): Input tensor shaped (batch, channels, height, width). output_hidden_states (Optional[bool]): Whether to collect intermediate states. return_dict (Optional[bool]): Whether to return tuple or dict-like output. 
Returns: Tuple: Last hidden state followed by optional hidden states tuple. """ all_hidden_states = () if output_hidden_states else None for i, layer_module in enumerate(self.stages): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: hidden_states = torch.utils.checkpoint.checkpoint( layer_module, hidden_states, use_reentrant=False, ) else: hidden_states = layer_module(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, all_hidden_states] if v is not None ) return (hidden_states, all_hidden_states) class ConvNextModel(nn.Module): """ConvNeXt V2 model.""" def __init__( self, hidden_sizes: list[int], depths: list[int], num_channels: int = 3, patch_size: int = 4, drop_path_rate: float = 0.0, layer_scale_init_value: float = 1e-6, use_grn: bool = True, ) -> None: """Build the ConvNeXt V2 model with embedding, encoder, and pooling. Args: hidden_sizes (list[int]): Hidden channel sizes per stage. depths (list[int]): Layer counts per stage. num_channels (int): Number of image channels. patch_size (int): Patch size for initial embedding. drop_path_rate (float): Drop path rate range for residual blocks. layer_scale_init_value (float): Initial scale for residuals. use_grn (bool): Whether to enable GRN. """ super().__init__() if not use_grn: raise ValueError("ConvNeXt V2 requires use_grn=True.") self.embeddings = ConvNextEmbeddings( num_channels, hidden_sizes[0], patch_size ) self.encoder = ConvNextEncoder( hidden_sizes, depths, drop_path_rate, layer_scale_init_value, use_grn, ) self.layernorm = nn.LayerNorm(hidden_sizes[-1], eps=1e-6) # Initialize weights self.apply(self._init_weights) def _init_weights(self, module: nn.Module) -> None: """Initialize module weights following standard ConvNeXt heuristics. Args: module (nn.Module): Module to initialize. 
""" if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=0.02) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = True, return_pooled: bool = True, ) -> Union[Tuple, BaseModelOutputWithPooling]: """Encode images and optionally return pooled features. Args: pixel_values (Optional[torch.FloatTensor]): Input tensor shaped (batch, channels, height, width). output_hidden_states (Optional[bool]): Whether to return intermediate hidden states. return_dict (Optional[bool]): Whether to return output as BaseModelOutput. return_pooled (bool): Whether to include pooled output. Returns: Union[Tuple, BaseModelOutputWithPooling]: Model outputs containing last hidden states and optionally pooled output. Raises: ValueError: If `pixel_values` is None. """ if pixel_values is None: raise ValueError("You have to specify pixel_values") embedding_output = self.embeddings(pixel_values) encoder_outputs = self.encoder( embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] all_hidden_states = ( encoder_outputs[1] if output_hidden_states else None ) # Skip pooled output when callers only need token features. pooled_output = None if return_pooled: # Global average pooling, (N, C, H, W) -> (N, C). 
pooled_output = self.layernorm(last_hidden_state.mean([-2, -1])) if not return_dict: outputs = [last_hidden_state] if return_pooled: outputs.append(pooled_output) if output_hidden_states: outputs.append(all_hidden_states) return tuple(outputs) return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=all_hidden_states, ) @staticmethod def from_pretrained(model_path: Path | str) -> "ConvNextModel": """Load ConvNeXt model weights from a pretrained checkpoint directory. Args: model_path (Path | str): Directory path containing the checkpoint files. Returns: ConvNextModel: Initialized model with weights loaded from checkpoint. Raises: NotImplementedError: If config.json is missing in the directory. FileNotFoundError: If no weight file is found. """ model_path_str = str(model_path) model_path_obj = Path(model_path_str) # Check if this is a HuggingFace model path is_ckpt_dir = ( model_path_obj.is_dir() and (model_path_obj / "config.json").exists() ) if not is_ckpt_dir: raise NotImplementedError( "The checkpoint path should be a directory containing config.json " "and model.safetensors or pytorch_model.bin files." ) # Load configuration config = ConvNextConfig.from_pretrained(model_path_str) checkpoint_dir = model_path_obj # Create our model directly if not config.use_grn: raise ValueError( "ConvNeXt V2 requires use_grn=True in the checkpoint config." 
) logger.info( "Loading ConvNeXt V2 model from checkpoint: %s", checkpoint_dir ) model = ConvNextModel( hidden_sizes=config.hidden_sizes, depths=config.depths, num_channels=config.num_channels, patch_size=config.patch_size, drop_path_rate=config.drop_path_rate, layer_scale_init_value=config.layer_scale_init_value, use_grn=config.use_grn, ) # Load state dict from checkpoint files state_dict = {} # Try to load from safetensors first (preferred) safetensors_file = checkpoint_dir / "model.safetensors" if safetensors_file.exists(): logger.info("Loading weights from %s", safetensors_file) state_dict = safetensors_load(str(safetensors_file)) else: # Try pytorch_model.bin pytorch_file = checkpoint_dir / "pytorch_model.bin" if pytorch_file.exists(): logger.info("Loading weights from %s", pytorch_file) state_dict = torch.load( str(pytorch_file), map_location="cpu", weights_only=False ) else: # Try sharded checkpoints shard_files = sorted( glob.glob(str(checkpoint_dir / "pytorch_model-*.bin")) ) if shard_files: logger.info( "Loading weights from %s sharded files", len(shard_files), ) for shard_file in shard_files: state_dict.update( torch.load( shard_file, map_location="cpu", weights_only=False, ) ) else: raise FileNotFoundError( f"Could not find model weights in {checkpoint_dir}. " "Expected model.safetensors, pytorch_model.bin, or pytorch_model-*.bin files." 
) # Load the mapped state dict into our model missing_keys, unexpected_keys = model.load_state_dict( state_dict, strict=False ) if missing_keys: logger.warning( "Some weights of the model were not initialized from the checkpoint " "and are newly initialized: %s", missing_keys, ) if unexpected_keys: logger.warning( "Some weights of the checkpoint were not used when initializing the model: %s", unexpected_keys, ) return model class ConvNextVisionModel(nn.Module): """Vision model wrapper around ConvNeXt V2 backbone.""" def __init__(self, config: Optional[ConvNextConfig] = None): """Wrap ConvNeXt backbone for use within the multimodal stack. Args: config (Optional[ConvNextConfig]): Configuration for the ConvNeXt backbone. Raises: ValueError: If the config lacks required ConvNeXt attributes. """ super().__init__() if config is None: config = ConvNextConfig.convnext_large() self.config = config # Support both HuggingFace config and ensure we extract the right parameters if hasattr(config, "hidden_sizes"): # HuggingFace-style config hidden_sizes = config.hidden_sizes depths = config.depths num_channels = config.num_channels patch_size = config.patch_size drop_path_rate = config.drop_path_rate layer_scale_init_value = config.layer_scale_init_value use_grn = config.use_grn else: raise ValueError("Config must be a ConvNextConfig") if not use_grn: raise ValueError("ConvNeXt V2 requires use_grn=True.") self.backbone = ConvNextModel( hidden_sizes=hidden_sizes, depths=depths, num_channels=num_channels, patch_size=patch_size, drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, use_grn=use_grn, ) @staticmethod def from_pretrained(model_path: Path | str) -> "ConvNextVisionModel": """Load a vision wrapper with pretrained ConvNeXt weights. Args: model_path (Path | str): Directory path containing the pretrained weights. Returns: ConvNextVisionModel: Wrapper instance with backbone weights loaded. 
""" # Load the backbone model backbone = ConvNextModel.from_pretrained(model_path) config = ConvNextConfig.from_pretrained(str(model_path)) wrapper = ConvNextVisionModel(config) wrapper.backbone = backbone return wrapper def forward( self, pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: bool = True, patch_attention_mask: Optional[torch.Tensor] = None, return_pooled: bool = True, ) -> Union[Tuple, BaseModelOutputWithPooling]: """Encode pixel values and reformat the ConvNeXt output. Args: pixel_values (torch.FloatTensor): Input tensor shaped (batch, channels, height, width). output_attentions (Optional[bool]): Ignored but present for compatibility. output_hidden_states (Optional[bool]): Whether to return staged hidden states. return_dict (bool): Whether to return `BaseModelOutputWithPooling`. patch_attention_mask (Optional[torch.Tensor]): Mask for patch tokens (unused here). return_pooled (bool): Whether to request pooled output. Returns: Union[Tuple, BaseModelOutputWithPooling]: Vision outputs in sequence format. """ # Avoid pooled output unless requested to reduce extra work. 
outputs = self.backbone( pixel_values, output_hidden_states=output_hidden_states, return_dict=True, return_pooled=return_pooled, ) outputs = cast(BaseModelOutputWithPooling, outputs) last_hidden_state = outputs.last_hidden_state # (b, c, h, w) pooled = outputs.pooler_output if return_pooled else None # Convert to sequence format: (b, c, h, w) -> (b, h*w, c) last_hidden_state = rearrange( last_hidden_state, "b c h w -> b (h w) c" ) if return_dict: return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled, hidden_states=( outputs.hidden_states if output_hidden_states else None ), ) if output_hidden_states: outputs_tuple = [last_hidden_state] if return_pooled: outputs_tuple.append(pooled) outputs_tuple.append(outputs.hidden_states) return tuple(outputs_tuple) if return_pooled: return (last_hidden_state, pooled) return (last_hidden_state,) # ---- Yasa language model utilities (inlined) ---- @use_kernel_forward_from_hub("RMSNorm") class YasaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ YasaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt( variance + self.variance_epsilon ) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class YasaRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: YasaConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance( config.rope_scaling, dict ): self.rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type") ) else: self.rope_type = "default" 
self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device ) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = ( self.inv_freq[None, :, None] .float() .expand(position_ids.shape[0], -1, 1) .to(x.device) ) position_ids_expanded = position_ids[:, None, :].float() device_type = ( x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" ) with torch.autocast( device_type=device_type, enabled=False ): # Force float32 freqs = ( inv_freq_expanded.float() @ position_ids_expanded.float() ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class YasaMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=config.mlp_bias ) self.up_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=config.mlp_bias ) self.down_proj = nn.Linear( self.intermediate_size, self.hidden_size, bias=config.mlp_bias ) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj( self.act_fn(self.gate_proj(x)) * self.up_proj(x) ) return down_proj def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape( batch, num_key_value_heads * n_rep, slen, head_dim ) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query.dtype) attn_weights = nn.functional.dropout( attn_weights, p=dropout, training=module.training ) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class YasaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: YasaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads, ) self.num_key_value_groups = ( config.num_attention_heads // config.num_key_value_heads ) self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias, ) self.k_proj = nn.Linear( config.hidden_size, 
class YasaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Queries use ``config.num_attention_heads`` heads while keys/values use
    ``config.num_key_value_heads``; the ratio is stored as
    ``num_key_value_groups`` for the attention kernels.
    """

    def __init__(self, config: YasaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        # An explicit ``config.head_dim`` wins over the derived value.
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // n_heads
        )
        self.num_key_value_groups = n_heads // n_kv_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        use_bias = config.attention_bias
        q_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=use_bias)

    @deprecate_kwarg(
        "past_key_value", new_name="past_key_values", version="4.58"
    )
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, rotate, optionally cache, attend, and re-project.

        Args:
            hidden_states: Input activations; the last axis is projected.
            position_embeddings: ``(cos, sin)`` pair fed to
                ``apply_rotary_pos_emb``.
            attention_mask: Mask forwarded to the attention kernel.
            past_key_values: Optional cache updated with this layer's
                keys/values.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded unchanged to the attention kernel.

        Returns:
            ``(attn_output, attn_weights)`` as produced by the selected
            attention kernel, with ``attn_output`` passed through ``o_proj``.
        """
        lead_shape = hidden_states.shape[:-1]
        per_head_shape = (*lead_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states)
        key_states = key_states.view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states)
        value_states = value_states.view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position is
            # needed for the static cache.
            key_states, value_states = past_key_values.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        implementation = self.config._attn_implementation
        attention_interface: Callable = (
            ALL_ATTENTION_FUNCTIONS[implementation]
            if implementation != "eager"
            else eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        merged = attn_output.reshape(*lead_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
class YasaDecoderLayer(GradientCheckpointingLayer):
    """One pre-norm decoder block: self-attention then MLP, each residual."""

    def __init__(self, config: YasaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = YasaAttention(config=config, layer_idx=layer_idx)
        self.mlp = YasaMLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = YasaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = YasaRMSNorm(
            config.hidden_size, eps=norm_eps
        )

    @deprecate_kwarg(
        "past_key_value", new_name="past_key_values", version="4.58"
    )
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks with residual connections.

        All arguments other than ``hidden_states`` are handed to
        ``self.self_attn`` unchanged; the attention weights it returns are
        discarded.

        Returns:
            The updated hidden states.
        """
        # Self-attention sub-block.
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


class YasaPreTrainedModel(PreTrainedModel):
    # Shared base class carrying the Transformers capability flags for the
    # Yasa model family.
    # NOTE(review): ``config`` is assigned the config *class* here rather
    # than annotated -- confirm this is intended.
    config = Yasa2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["YasaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps each recordable output name to the module class that emits it.
    _can_record_outputs = {
        "hidden_states": YasaDecoderLayer,
        "attentions": YasaAttention,
    }
@auto_docstring
class YasaModel(YasaPreTrainedModel):
    # Text decoder stack: token embedding, ``config.num_hidden_layers``
    # decoder layers, a final RMS norm and a shared rotary embedding module.
    # (Comments are used instead of docstrings because the class and its
    # forward are decorated with ``@auto_docstring``.)
    def __init__(self, config: YasaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        decoder_layers = [
            YasaDecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = YasaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = YasaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Continue numbering after whatever the cache already holds.
            offset = (
                0
                if past_key_values is None
                else past_key_values.get_seq_length()
            )
            cache_position: torch.Tensor = (
                torch.arange(
                    inputs_embeds.shape[1], device=inputs_embeds.device
                )
                + offset
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(
            hidden_states, position_ids=position_ids
        )

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer in active_layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
class Yasa2Model(YasaPreTrainedModel):
    """Pretrained base class that holds the full Yasa2 multimodal stack.

    The stack is a ConvNeXt vision backbone, an adaptive-average pooling
    step, a two-layer projection into the text hidden size, and a
    ``YasaModel`` text decoder. Projected vision embeddings overwrite the
    text embeddings at the positions of ``config.image_token_id``.
    """

    config_class: PretrainedConfig = Yasa2Config
    base_model_prefix: str = ""
    _checkpoint_conversion_mapping: Dict[str, str] = {}
    _no_split_modules = ["YasaDecoderLayer", "ConvNextVisionModel"]
    config: Yasa2Config

    def __init__(
        self,
        config: Yasa2Config,
    ):
        """Initialize the full Yasa2 multimodal stack.

        Args:
            config (Yasa2Config): Configuration for the multimodal model.

        Raises:
            ValueError: If ``config.vision_pooling`` is not
                ``"adaptive_avg"`` or ``config.num_query_tokens`` is not a
                perfect square.
        """
        super().__init__(config)
        self.vision_pooling = config.vision_pooling
        if self.vision_pooling != "adaptive_avg":
            raise ValueError(
                f"Yasa2 only supports adaptive_avg vision pooling, got {self.vision_pooling}"
            )
        # Validate before building the pooling layer so a bad value is
        # rejected with this message rather than silently truncated.
        pooled_side = config.num_query_tokens**0.5
        if not pooled_side.is_integer():
            raise ValueError(
                f"num_query_tokens {config.num_query_tokens} must be a "
                "square number for adaptive_avg pooling"
            )
        self.adaptive_pooling = nn.AdaptiveAvgPool2d(int(pooled_side))

        # Set up vision backbone. The config may arrive as a plain dict;
        # everything below reads the normalised object, never
        # ``config.vision_config`` directly.
        vision_config = config.vision_config
        if isinstance(vision_config, dict):
            vision_config = ConvNextConfig(**vision_config)
        self._vision_hidden_size = vision_config.hidden_size
        self.vision_model = ConvNextVisionModel(vision_config)

        self.language_projection = nn.Sequential(
            nn.Linear(
                self._vision_hidden_size,
                config.text_config.hidden_size,
            ),
            nn.GELU(),
            nn.Linear(
                config.text_config.hidden_size,
                config.text_config.hidden_size,
            ),
        )

        # Set up language model
        self.language_model = YasaModel(config.text_config)

        # Store only the raw non-learned vision positional embedding data.
        # Build device/dtype-specific tensors lazily in forward.
        self.add_vision_pos_embed = config.use_vision_pos_embed
        self._vision_pos_embed_np = get_2d_sincos_pos_embed(
            self._vision_hidden_size,
            image_size=50,
        )
        self._vision_pos_embed_cache: Dict[str, torch.Tensor] = {}

        self.post_init()

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the multimodal head's input embeddings.

        Returns:
            torch.nn.Module: Embedding module used by the language model.
        """
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: torch.nn.Module) -> None:
        """Override the multimodal head's input embeddings.

        Args:
            value (torch.nn.Module): Embedding module to register.
        """
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder: YasaModel) -> None:
        """Proxy to set the multimodal model decoder.

        Args:
            decoder: Decoder to register with the multimodal model.
        """
        self.language_model = decoder

    def get_decoder(self) -> YasaModel:
        """Return the decoder component.

        Returns:
            YasaModel: Registered decoder module.
        """
        return self.language_model

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Return a filtered state dict that omits derived or non-persistent buffers.

        Args:
            *args: Positional arguments forwarded to the superclass.
            **kwargs: Keyword arguments forwarded to the superclass.

        Returns:
            Dict[str, torch.Tensor]: Filtered parameter mapping.
        """
        state_dict = super().state_dict(*args, **kwargs)
        # masked_bias is a constant non-persistent attention buffer (-1e9);
        # rotary_emb.inv_freq is derived from rotary dims/base and rebuilt
        # at init.
        derived_markers = ("attention.masked_bias", "rotary_emb.inv_freq")
        derived_keys = [
            key
            for key in state_dict
            if any(marker in key for marker in derived_markers)
        ]
        for key in derived_keys:
            state_dict.pop(key, None)
        return state_dict

    def _encode_vision_adaptive_2d_avg_pooling(
        self,
        pixel_values: torch.Tensor,
        patch_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode vision inputs via the ConvNeXt backbone and adaptive avg pooling.

        Args:
            pixel_values (torch.Tensor): Vision input tensor.
            patch_attention_mask (Optional[torch.Tensor]): Optional patch mask.

        Returns:
            torch.Tensor: Vision embeddings projected into text hidden size.
        """
        # Vision prefill only needs patch tokens; skip pooled output.
        image_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=False,
            patch_attention_mask=patch_attention_mask,
            return_pooled=False,
        )[0]

        img_num, seq_length, vision_hidden_size = image_embeds.size()
        # The patch sequence is treated as a square grid.
        height = width = int(seq_length**0.5)

        if self.add_vision_pos_embed:
            vision_pos_embed = self._get_vision_pos_embed(
                device=image_embeds.device,
                dtype=image_embeds.dtype,
                seq_len=image_embeds.size(1),
            )
            image_embeds = image_embeds + vision_pos_embed

        # (img, seq, hidden) -> (img, hidden, height, width) for 2D pooling.
        image_embeds = image_embeds.permute(0, 2, 1).contiguous()
        image_embeds = image_embeds.reshape(
            img_num, vision_hidden_size, height, width
        )

        if (
            self.config.apply_patch_attention_mask
            and patch_attention_mask is not None
            and patch_attention_mask.numel() > 0
        ):
            patch_attention_mask = patch_attention_mask.reshape(
                img_num, height, width
            )
            image_embeds = image_embeds * patch_attention_mask.unsqueeze(1).to(
                dtype=image_embeds.dtype
            )

        # Force pooling in fp32 with autocast disabled; bf16 pooling can
        # produce NaNs. Autocast is disabled for the device the tensor
        # actually lives on (not a hard-coded "cuda"), so the guard is also
        # effective on CPU and other backends; "mps" falls back to "cpu".
        pooled_dtype = image_embeds.dtype
        device_type = image_embeds.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            image_embeds = torch.nn.functional.adaptive_avg_pool2d(
                image_embeds.float(), self.adaptive_pooling.output_size
            )
        image_embeds = image_embeds.to(dtype=pooled_dtype)

        # (img, hidden, h, w) -> (img, h * w, hidden) for the projection.
        image_embeds = image_embeds.flatten(2)
        image_embeds = image_embeds.permute(0, 2, 1).contiguous()

        return self.language_projection(image_embeds)

    def _get_vision_pos_embed(
        self,
        device: torch.device,
        dtype: torch.dtype,
        seq_len: int,
    ) -> torch.Tensor:
        """Return cached/runtime-built vision positional embeddings.

        The tensor is built once per ``(device, dtype)`` pair from the stored
        numpy array and the first ``seq_len`` positions are returned.
        """
        cache_key = f"{device}:{dtype}"
        cached = self._vision_pos_embed_cache.get(cache_key)
        if cached is None:
            cached = (
                torch.from_numpy(self._vision_pos_embed_np)
                .view(-1, self._vision_hidden_size)
                .to(device=device, dtype=dtype)
                .unsqueeze(0)
            )
            self._vision_pos_embed_cache[cache_key] = cached
        return cached[:, :seq_len, :]

    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Any
    ) -> torch.Tensor:
        """Return vision features for vLLM compatibility."""
        patch_attention_mask = kwargs.get("patch_attention_mask")
        return self._encode_vision_adaptive_2d_avg_pooling(
            pixel_values, patch_attention_mask=patch_attention_mask
        )

    @classmethod
    def scatter_embeddings_to_target_special_id(
        cls,
        target_tensor: torch.Tensor,
        target_input_ids: torch.Tensor,
        src_embeddings: torch.Tensor,
        special_token_id: int,
    ) -> torch.Tensor:
        """Scatter vision embeddings into the language embedding buffer at special tokens.

        The write happens in place on ``target_tensor`` (through a flattened
        view), and the same storage is returned reshaped to its original
        shape.

        Args:
            target_tensor (torch.Tensor): Target embedding buffer to update.
            target_input_ids (torch.Tensor): Input IDs aligned with the target tensor.
            src_embeddings (torch.Tensor): Source embeddings to scatter from vision outputs.
            special_token_id (int): Token ID used to locate insertion positions.

        Returns:
            torch.Tensor: Updated target tensor with vision embeddings placed at special IDs.

        Raises:
            ValueError: On batch, sequence-length, embedding-dimension, or
                special-token-count mismatch.
        """
        b_source, n_source, d_embedding = src_embeddings.shape
        b_target, n_target, d_target = target_tensor.shape

        if b_target != target_input_ids.size(0):
            raise ValueError(
                "Batch size mismatch: target_input_ids "
                f"{target_input_ids.size(0)} vs target_tensor {b_target}"
            )
        if n_target != target_input_ids.size(1):
            raise ValueError(
                "Sequence length mismatch: target_input_ids "
                f"{target_input_ids.size(1)} vs target_tensor {n_target}"
            )
        if d_embedding != d_target:
            raise ValueError(
                "Embedding dimension mismatch: src_embeddings "
                f"{d_embedding} vs target_tensor {d_target}"
            )

        special_token_mask = target_input_ids.reshape(-1) == special_token_id
        special_token_indices = torch.nonzero(special_token_mask).squeeze(-1)
        found = special_token_indices.numel()

        if found != b_source * n_source:
            raise ValueError(
                "Special token count mismatch: found "
                f"{found}, expected {b_source * n_source}"
            )

        target_tensor = target_tensor.view(-1, d_embedding)
        # ``reshape`` tolerates non-contiguous sources, and the cast keeps
        # the indexed assignment valid when the projection ran in a
        # different precision or on a different device than the embeddings.
        src_embeddings = src_embeddings.reshape(-1, d_embedding).to(
            device=target_tensor.device, dtype=target_tensor.dtype
        )
        special_token_indices = special_token_indices.to(target_tensor.device)
        target_tensor[special_token_indices] = src_embeddings
        return target_tensor.view(b_target, n_target, d_embedding)

    def _interleave_scatter(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        inputs_embeds: torch.Tensor,
        vision_embeds: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter vision embeddings into language embeddings at the image token positions.

        Args:
            input_ids (torch.Tensor): Token IDs containing image placeholders.
            attention_mask (torch.Tensor): Attention mask for text tokens.
            inputs_embeds (torch.Tensor): Language model input embeddings.
            vision_embeds (torch.Tensor): Vision embeddings to be inserted.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Updated inputs_embeds and the
            attention_mask, which is passed through unchanged.
        """
        inputs_embeds = Yasa2Model.scatter_embeddings_to_target_special_id(
            target_tensor=inputs_embeds,
            target_input_ids=input_ids,
            src_embeddings=vision_embeds,
            special_token_id=self.config.image_token_id,
        )
        return inputs_embeds, attention_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mm_token_type_ids: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Union[Tuple[torch.Tensor, ...], "Yasa2ModelOutputWithPast"]:
        """Forward pass combining language and vision inputs for Yasa2.

        Args:
            input_ids (Optional[torch.LongTensor]): Token IDs for the language model.
                May be omitted when ``inputs_embeds`` is given.
            attention_mask (Optional[torch.Tensor]): Attention mask aligned with `input_ids`.
                When omitted and ``input_ids`` contains the text pad token,
                a mask is derived from the pad positions.
            position_ids (Optional[torch.LongTensor]): Position indices feeding the language model.
            inputs_embeds (Optional[torch.FloatTensor]): Precomputed token embeddings.
            past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]):
                Cached decoder key/value tensors.
            cache_position (Optional[torch.LongTensor]): Positions used for cache alignment.
            use_cache (Optional[bool]): Whether to request cached key/values.
            output_attentions (Optional[bool]): Whether to return attention weights.
            output_hidden_states (Optional[bool]): Whether to return hidden states for each layer.
            return_dict (Optional[bool]): Accepted for API compatibility; this
                body always builds a `ModelOutput` and leaves any tuple
                conversion to the `can_return_tuple` decorator.
            pixel_values (Optional[torch.Tensor]): Vision inputs providing image context.
            patch_attention_mask (Optional[torch.Tensor]): Optional patch mask for vision tokens.
            token_type_ids (Optional[torch.Tensor]): Unused token type ids for compatibility.
            mm_token_type_ids (Optional[torch.Tensor]): Unused multimodal token type ids.

        Returns:
            Union[Tuple[torch.Tensor, ...], Yasa2ModelOutputWithPast]: Combined multimodal outputs.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is
                given, if ``pixel_values`` is combined with
                ``inputs_embeds``, or if ``cache_position`` does not match
                the input sequence length.
        """
        use_cache = (
            use_cache if use_cache is not None else self.config.use_cache
        )

        if input_ids is None and inputs_embeds is None:
            raise ValueError(
                "You must provide either input_ids or inputs_embeds."
            )
        if inputs_embeds is not None and pixel_values is not None:
            raise ValueError(
                "pixel_values cannot be used when inputs_embeds is provided."
            )

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(
                input_ids
            )

        if attention_mask is None:
            # Derive a padding mask only when padding is actually present.
            pad_token_id = self.config.text_config.pad_token_id
            if (
                input_ids is not None
                and pad_token_id is not None
                and (input_ids == pad_token_id).any()
            ):
                attention_mask = input_ids.ne(pad_token_id)
        if attention_mask is not None and attention_mask.numel() == 0:
            # An empty mask is treated the same as no mask.
            attention_mask = None

        if cache_position is not None:
            expected_len = inputs_embeds.shape[1]
            if cache_position.shape[-1] != expected_len:
                raise ValueError(
                    "cache_position length must match input sequence length: "
                    f"{cache_position.shape[-1]} vs {expected_len}"
                )

        vision_embeds = None
        if pixel_values is not None and len(pixel_values) > 0:
            if input_ids is None:
                raise ValueError(
                    "input_ids is required when pixel_values is provided."
                )
            vision_embeds = self._encode_vision_adaptive_2d_avg_pooling(
                pixel_values,
                patch_attention_mask=patch_attention_mask,
            )
            inputs_embeds, attention_mask = self._interleave_scatter(
                input_ids,
                attention_mask,
                inputs_embeds,
                vision_embeds,
            )

        # The decoder always receives embeddings, never token ids.
        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            head_mask=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=True,
            **kwargs,
        )

        return Yasa2ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            vision_hidden_states=vision_embeds,
        )
class Yasa2ForConditionalGeneration(YasaPreTrainedModel, GenerationMixin):
    """Yasa2 multimodal conditional generation model (vision + text).

    Wraps a ``Yasa2Model`` and adds an untied linear language-modeling head.
    """

    config_class = Yasa2Config
    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = []  # Weights are not tied
    config: Yasa2Config

    def __init__(self, config: Yasa2Config):
        """Initialize the Yasa2 conditional generation model.

        Args:
            config: Yasa2 configuration object.
        """
        super().__init__(config)
        self.model = Yasa2Model(config)
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the multimodal head's input embeddings.

        Returns:
            torch.nn.Module: Embedding module used by the language model.
        """
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: torch.nn.Module) -> None:
        """Override the multimodal head's input embeddings.

        Args:
            value (torch.nn.Module): Embedding module to register.
        """
        self.model.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Proxy to set the multimodal model decoder.

        Args:
            decoder: Decoder to register with the multimodal model.
        """
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Proxy to return the multimodal decoder."""
        return self.model.get_decoder()

    # Make modules available through the conditional class for BC
    @property
    def language_model(self) -> torch.nn.Module:
        """Expose the language model component.

        Returns:
            torch.nn.Module: Language model module.
        """
        return self.model.language_model

    @property
    def vision_backbone(self) -> torch.nn.Module:
        """Expose the vision encoder backbone.

        Returns:
            torch.nn.Module: Vision backbone module.
        """
        return self.model.vision_model

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mm_token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[
        Tuple[torch.Tensor, ...], "Yasa2ForConditionalGenerationModelOutput"
    ]:
        """Run the multimodal model, project outputs to logits, and compute loss if needed.

        Args:
            input_ids (Optional[torch.LongTensor]): Language token IDs.
            attention_mask (Optional[torch.Tensor]): Attention mask for language tokens.
            position_ids (Optional[torch.LongTensor]): Position indices.
            past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]):
                Cached decoder states.
            inputs_embeds (Optional[torch.FloatTensor]): Input embeddings instead of token IDs.
            use_cache (Optional[bool]): Whether to cache key/value pairs.
            output_attentions (Optional[bool]): Whether to return attention weights.
            output_hidden_states (Optional[bool]): Whether to return hidden states.
            cache_position (Optional[torch.LongTensor]): Positions used for caching.
            pixel_values (Optional[torch.Tensor]): Vision inputs.
            patch_attention_mask (Optional[torch.Tensor]): Optional mask for vision patches.
            token_type_ids (Optional[torch.Tensor]): Unused token type ids for compatibility.
            mm_token_type_ids (Optional[torch.Tensor]): Unused multimodal token type ids.
            labels (Optional[torch.LongTensor]): Labels for computing cross-entropy loss.
                Positions equal to ``config.label_ignore_index`` are ignored.
            return_dict (Optional[bool]): Accepted for API compatibility; this
                body always builds a `ModelOutput` and leaves any tuple
                conversion to the `can_return_tuple` decorator.

        Returns:
            Union[Tuple[torch.Tensor, ...], Yasa2ForConditionalGenerationModelOutput]:
                Model logits, caches, and optional loss.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            return_dict=True,
            **kwargs,
        )

        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Next-token objective: position t predicts label t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:]
            loss_fct = nn.CrossEntropyLoss(
                ignore_index=self.config.label_ignore_index
            )
            loss = loss_fct(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )

        return Yasa2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            vision_hidden_states=outputs.vision_hidden_states,
            language_model_outputs=outputs,
        )

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate text tokens conditioned on vision and/or language inputs.

        Args:
            input_ids (Optional[torch.LongTensor]): Seed language tokens. Now
                defaults to ``None`` so callers that supply the prompt by
                other keyword arguments are not forced to pass it.
            attention_mask (Optional[torch.Tensor]): Language attention mask.
            pixel_values (Optional[torch.Tensor]): Vision inputs appended to prompts.
            patch_attention_mask (Optional[torch.Tensor]): Mask for vision patches.
            **generate_kwargs: Additional generation options forwarded to the `super().generate`.

        Returns:
            torch.LongTensor: Generated token IDs.
        """
        return super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            **generate_kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Prepare multimodal inputs for generation bookkeeping.

        Vision inputs are attached only on the prefill step, i.e. when there
        is no cache yet or the first cache position is 0.

        Args:
            input_ids (torch.LongTensor): Current token IDs for generation.
            past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]):
                Cached past key/value tensors.
            inputs_embeds (Optional[torch.FloatTensor]): Optional token embeddings.
            attention_mask (Optional[torch.Tensor]): Language attention mask.
            cache_position (Optional[torch.LongTensor]): Cache alignment positions.
            pixel_values (Optional[torch.Tensor]): Vision inputs that should be reused.
            patch_attention_mask (Optional[torch.Tensor]): Vision patch mask for the prefill step.
            **kwargs: Additional arguments forwarded to the base implementation.

        Returns:
            Dict[str, Any]: Prepared inputs for the next generation step.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            **kwargs,
        )

        is_prefill = past_key_values is None or (
            cache_position is not None and cache_position[0] == 0
        )
        if is_prefill:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["patch_attention_mask"] = patch_attention_mask

        return model_inputs


Yasa2ForConditionalGeneration.register_for_auto_class(
    "AutoModelForImageTextToText"
)