gmastrapas committed · verified
Commit 3d813dc · 1 Parent(s): 1ba27de

Model update

README.md CHANGED
@@ -286,7 +286,9 @@ processor = AutoProcessor.from_pretrained(
286
 
287
  # Load the model on the available device(s)
288
  model = AutoModelForCausalLM.from_pretrained(
289
- 'jinaai/jina-vlm-v1', device_map='auto', trust_remote_code=True
290
  )
291
 
292
  # You can specify a different model dtype and/or attention implementation
 
286
 
287
  # Load the model on the available device(s)
288
  model = AutoModelForCausalLM.from_pretrained(
289
+ 'jinaai/jina-vlm-v1',
290
+ device_map='auto',
291
+ trust_remote_code=True
292
  )
293
 
294
  # You can specify a different model dtype and/or attention implementation
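For reference, the call above is only reflowed; the trailing comment about dtype and attention implementation is not expanded in the README, so the following is a hedged sketch of what such a call could look like (the bfloat16 dtype and the 'sdpa' backend are illustrative choices, not values prescribed by this commit):

import torch
from transformers import AutoModelForCausalLM

# Illustrative only: pick a dtype and attention backend your hardware supports
model = AutoModelForCausalLM.from_pretrained(
    'jinaai/jina-vlm-v1',
    device_map='auto',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,   # assumption: half precision is acceptable
    attn_implementation='sdpa',   # assumption: SDPA is available in this environment
)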
blocks_jvlm.py CHANGED
@@ -1,6 +1,7 @@
1
  # Copyright 2025 Jina AI. All rights reserved.
2
 
3
  from abc import ABCMeta, abstractmethod
 
4
  from copy import deepcopy
5
  from functools import wraps
6
  from math import prod, sqrt
@@ -11,6 +12,7 @@ import torch
11
  import torch.backends.cuda
12
  import torch.nn as nn
13
  import torch.nn.functional as f
 
14
  from transformers import PretrainedConfig
15
  from transformers.activations import ACT2FN
16
  from transformers.cache_utils import Cache
@@ -324,10 +326,11 @@ modeling_rope_utils.py
324
 
325
 
326
  def inv_freq_to_device(rope_forward):
327
- """
328
- Sometimes the inv_freq is calculated on the wrong device, or ends up in lower
329
- precision than float32. This wrapper ensures that inv_freq is always on the right
330
- device and in float32 precision.
 
331
  """
332
 
333
  @wraps(rope_forward)
@@ -353,7 +356,6 @@ class RotaryEmbedding(nn.Module):
353
  theta: float,
354
  head_dim: int,
355
  hidden_size: int,
356
- n_heads: int,
357
  partial_rotary_factor: float,
358
  device: Optional[torch.device] = None,
359
  scaling: Optional[Dict[str, Any]] = None,
@@ -366,7 +368,6 @@ class RotaryEmbedding(nn.Module):
366
  setattr(self.config, 'rope_theta', theta)
367
  setattr(self.config, 'partial_rotary_factor', partial_rotary_factor)
368
  setattr(self.config, 'head_dim', head_dim)
369
- setattr(self.config, 'num_attention_heads', n_heads)
370
  setattr(self.config, 'hidden_size', hidden_size)
371
  setattr(self.config, 'rope_scaling', scaling or {})
372
 
@@ -377,9 +378,7 @@ class RotaryEmbedding(nn.Module):
377
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
378
  device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
379
  seqlen = config.max_position_embeddings or config.max_sequence_length
380
- invfreq, self.attention_scaling = self.rope_init_fn(
381
- self.config, device, seqlen
382
- )
383
  self.rope_init_device = device
384
  self.register_buffer('inv_freq', invfreq, persistent=False)
385
  self.original_inv_freq = self.inv_freq
@@ -617,11 +616,9 @@ def _create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
617
  def _ensure_finite(
618
  x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False
619
  ):
620
- """
621
- Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the
622
  dtype when ``check_neg_inf`` is ``True`` and replace ``float("inf")`` with the
623
- maximum value of the dtype when ``check_pos_inf`` is ``True``
624
- """
625
  if check_neg_inf:
626
  x.masked_fill_(x == float('-inf'), torch.finfo(x.dtype).min)
627
  if check_pos_inf:
@@ -641,14 +638,12 @@ def resolve_causal_mask(
641
  # shape: (batch_size, 1, 1, seq_len)
642
  if len(attention_mask.shape) == 2:
643
  attention_mask = attention_mask[:, : past_length + seq_len]
644
- attention_mask = attention_mask.to(dtype=torch.float).view(
645
- batch_size, -1
646
- )[:, None, None, :]
647
  else:
648
  attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
649
- attention_mask = (1.0 - attention_mask) * torch.finfo(
650
- attention_mask.dtype
651
- ).min
652
 
653
  # Merge attention mask with causal mask (attention bias)
654
  # NOTE: We need to initialize the attn bias in order for attn to
@@ -660,9 +655,7 @@ def resolve_causal_mask(
660
  or past_key_values is not None
661
  ):
662
  if causal_mask is None:
663
- causal_mask = _create_causal_mask(
664
- past_length + seq_len, device
665
- )
666
  elif causal_mask.dtype in (torch.int8, torch.bool):
667
  causal_mask = causal_mask.to(dtype=torch.float)
668
  causal_mask.masked_fill_(
@@ -719,6 +712,7 @@ def eager_attention_forward(
719
  dropout: float = 0.0,
720
  **_,
721
  ):
 
722
  key_states = repeat_kv(key, module.num_key_value_groups)
723
  value_states = repeat_kv(value, module.num_key_value_groups)
724
 
@@ -745,7 +739,9 @@ def rotate_half(x: torch.Tensor):
745
 
746
 
747
  def apply_rotary_positional_embeddings(
748
- x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
749
  ) -> torch.Tensor:
750
  return (x * cos + rotate_half(x) * sin).to(x.dtype)
751
 
@@ -890,7 +886,6 @@ class MHSDPA(nn.Module):
890
  attn_mask: Optional[torch.Tensor] = None,
891
  is_causal: Optional[bool] = None,
892
  ) -> Tuple[Callable, Optional[torch.Tensor], Optional[bool]]:
893
-
894
  if 'flash' in attn_implementation and self.fp32_attn:
895
  raise ValueError('Flash attention does not support fp32 attention')
896
  if self.sliding_window != -1 and 'flash' not in attn_implementation:
@@ -1071,9 +1066,7 @@ class FFN(nn.Module):
1071
  if self.gated_activation:
1072
  intermediate_size = 2 * self.intermediate_size
1073
 
1074
- self.up = nn.Linear(
1075
- self.hidden_size, intermediate_size, bias=self.use_bias
1076
- )
1077
  self.down = nn.Linear(
1078
  self.intermediate_size, self.output_size, bias=self.use_bias
1079
  )
@@ -1245,6 +1238,8 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
1245
  assert config.attn_pooling_config is not None
1246
  if config.pooling_type == ImagePooling2DType.attention_2wide:
1247
  pooling_input_size *= 2
1248
  self.pooling = MHSDPA(
1249
  config.attn_pooling_config,
1250
  hidden_size=pooling_input_size,
@@ -1285,11 +1280,29 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
1285
  self.projector_dropout = Dropout(config.projector_dropout)
1286
  self.feature_dropout = Dropout(config.feature_dropout)
1287
 
1288
  def forward(
1289
  self,
1290
  image_features: torch.Tensor,
1291
  image_masks: Optional[torch.Tensor] = None,
1292
  attn_implementation: Optional[str] = None,
 
1293
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1294
  # image_features:
1295
  # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
@@ -1345,11 +1358,19 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
1345
  dh=self.pooling_h,
1346
  dw=self.pooling_w,
1347
  )
 
1348
  if self.pooling_type == ImagePooling2DType.attention_meanq:
1349
  query = image_features.mean(-2, keepdim=True)
1350
- image_features, _ = self.pooling(
1351
- xq=query, xk=image_features, attn_implementation=attn_implementation
1352
  )
1353
  elif self.pooling_type not in {
1354
  ImagePooling2DType.none,
1355
  ImagePooling2DType.stack,
 
1
  # Copyright 2025 Jina AI. All rights reserved.
2
 
3
  from abc import ABCMeta, abstractmethod
4
+ from contextlib import nullcontext
5
  from copy import deepcopy
6
  from functools import wraps
7
  from math import prod, sqrt
 
12
  import torch.backends.cuda
13
  import torch.nn as nn
14
  import torch.nn.functional as f
15
+ from torch.nn.attention import SDPBackend, sdpa_kernel
16
  from transformers import PretrainedConfig
17
  from transformers.activations import ACT2FN
18
  from transformers.cache_utils import Cache
 
326
 
327
 
328
  def inv_freq_to_device(rope_forward):
329
+ """Sometimes the inv_freq is calculated on the wrong device, or ends up in lower
330
+ precision than float32.
331
+
332
+ This wrapper ensures that inv_freq is always on the right device and in float32
333
+ precision.
334
  """
335
 
336
  @wraps(rope_forward)
 
356
  theta: float,
357
  head_dim: int,
358
  hidden_size: int,
 
359
  partial_rotary_factor: float,
360
  device: Optional[torch.device] = None,
361
  scaling: Optional[Dict[str, Any]] = None,
 
368
  setattr(self.config, 'rope_theta', theta)
369
  setattr(self.config, 'partial_rotary_factor', partial_rotary_factor)
370
  setattr(self.config, 'head_dim', head_dim)
 
371
  setattr(self.config, 'hidden_size', hidden_size)
372
  setattr(self.config, 'rope_scaling', scaling or {})
373
 
 
378
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
379
  device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
380
  seqlen = config.max_position_embeddings or config.max_sequence_length
381
+ invfreq, self.attention_scaling = self.rope_init_fn(self.config, device, seqlen)
382
  self.rope_init_device = device
383
  self.register_buffer('inv_freq', invfreq, persistent=False)
384
  self.original_inv_freq = self.inv_freq
 
616
  def _ensure_finite(
617
  x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False
618
  ):
619
+ """Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the
 
620
  dtype when ``check_neg_inf`` is ``True`` and replace ``float("inf")`` with the
621
+ maximum value of the dtype when ``check_pos_inf`` is ``True``"""
 
622
  if check_neg_inf:
623
  x.masked_fill_(x == float('-inf'), torch.finfo(x.dtype).min)
624
  if check_pos_inf:
 
638
  # shape: (batch_size, 1, 1, seq_len)
639
  if len(attention_mask.shape) == 2:
640
  attention_mask = attention_mask[:, : past_length + seq_len]
641
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[
642
+ :, None, None, :
643
+ ]
644
  else:
645
  attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
646
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
647
 
648
  # Merge attention mask with causal mask (attention bias)
649
  # NOTE: We need to initialize the attn bias in order for attn to
 
655
  or past_key_values is not None
656
  ):
657
  if causal_mask is None:
658
+ causal_mask = _create_causal_mask(past_length + seq_len, device)
659
  elif causal_mask.dtype in (torch.int8, torch.bool):
660
  causal_mask = causal_mask.to(dtype=torch.float)
661
  causal_mask.masked_fill_(
 
712
  dropout: float = 0.0,
713
  **_,
714
  ):
715
+ assert isinstance(module.num_key_value_groups, int)
716
  key_states = repeat_kv(key, module.num_key_value_groups)
717
  value_states = repeat_kv(value, module.num_key_value_groups)
718
 
 
739
 
740
 
741
  def apply_rotary_positional_embeddings(
742
+ x: torch.Tensor,
743
+ cos: torch.Tensor,
744
+ sin: torch.Tensor,
745
  ) -> torch.Tensor:
746
  return (x * cos + rotate_half(x) * sin).to(x.dtype)
747
 
 
886
  attn_mask: Optional[torch.Tensor] = None,
887
  is_causal: Optional[bool] = None,
888
  ) -> Tuple[Callable, Optional[torch.Tensor], Optional[bool]]:
 
889
  if 'flash' in attn_implementation and self.fp32_attn:
890
  raise ValueError('Flash attention does not support fp32 attention')
891
  if self.sliding_window != -1 and 'flash' not in attn_implementation:
 
1066
  if self.gated_activation:
1067
  intermediate_size = 2 * self.intermediate_size
1068
 
1069
+ self.up = nn.Linear(self.hidden_size, intermediate_size, bias=self.use_bias)
1070
  self.down = nn.Linear(
1071
  self.intermediate_size, self.output_size, bias=self.use_bias
1072
  )
 
1238
  assert config.attn_pooling_config is not None
1239
  if config.pooling_type == ImagePooling2DType.attention_2wide:
1240
  pooling_input_size *= 2
1241
+
1242
+ attn_implementation, _ = self._resolve_attn_pooling(attn_implementation)
1243
  self.pooling = MHSDPA(
1244
  config.attn_pooling_config,
1245
  hidden_size=pooling_input_size,
 
1280
  self.projector_dropout = Dropout(config.projector_dropout)
1281
  self.feature_dropout = Dropout(config.feature_dropout)
1282
 
1283
+ @staticmethod
1284
+ def _resolve_attn_pooling(attn_implementation: Optional[str] = None):
1285
+ """
1286
+ Flash Attention can cause Inf grads in the attention pooling layer because of
1287
+ very large batch sizes. Setting this to sdpa does not cost us much since
1288
+ sequence lengths in the case of attention pooling are tiny
1289
+ """
1290
+ attn_runtime_ctx = nullcontext()
1291
+ if (
1292
+ attn_implementation is not None
1293
+ and attn_implementation.startswith('flash')
1294
+ ):
1295
+ attn_implementation = 'sdpa'
1296
+ attn_runtime_ctx = sdpa_kernel(backends=[SDPBackend.MATH])
1297
+
1298
+ return attn_implementation, attn_runtime_ctx
1299
+
1300
  def forward(
1301
  self,
1302
  image_features: torch.Tensor,
1303
  image_masks: Optional[torch.Tensor] = None,
1304
  attn_implementation: Optional[str] = None,
1305
+ **kwargs: Unpack[FlashAttentionKwargs],
1306
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1307
  # image_features:
1308
  # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
 
1358
  dh=self.pooling_h,
1359
  dw=self.pooling_w,
1360
  )
1361
+ image_features = image_features.contiguous()
1362
  if self.pooling_type == ImagePooling2DType.attention_meanq:
1363
  query = image_features.mean(-2, keepdim=True)
1364
+ attn_implementation, attn_runtime_ctx = self._resolve_attn_pooling(
1365
+ attn_implementation
1366
  )
1367
+ with attn_runtime_ctx:
1368
+ image_features, _ = self.pooling(
1369
+ xq=query,
1370
+ xk=image_features,
1371
+ attn_implementation=attn_implementation,
1372
+ **kwargs,
1373
+ )
1374
  elif self.pooling_type not in {
1375
  ImagePooling2DType.none,
1376
  ImagePooling2DType.stack,
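The `_resolve_attn_pooling` helper added above routes the attention-pooling call away from flash attention and onto SDPA's math backend via the `sdpa_kernel` context manager. A minimal standalone sketch of that pattern, with made-up tensor shapes, looks roughly like this:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical shapes: (batch, heads, seq_len, head_dim); pooling uses a single query
query = torch.randn(8, 4, 1, 64)
keys = values = torch.randn(8, 4, 144, 64)

# Force the reference math backend, mirroring the fallback in _resolve_attn_pooling
with sdpa_kernel(backends=[SDPBackend.MATH]):
    pooled = F.scaled_dot_product_attention(query, keys, values)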
config.json CHANGED
@@ -4,6 +4,7 @@
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_jvlm.JinaVLMConfig",
 
7
  "AutoModelForCausalLM": "modeling_jvlm.JinaVLMForConditionalGeneration"
8
  },
9
  "bos_token_id": 151643,
@@ -214,4 +215,4 @@
214
  "spatial_merge_size": 2
215
  }
216
  }
217
- }
 
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_jvlm.JinaVLMConfig",
7
+ "AutoModel": "modeling_jvlm.JinaVLM",
8
  "AutoModelForCausalLM": "modeling_jvlm.JinaVLMForConditionalGeneration"
9
  },
10
  "bos_token_id": 151643,
 
215
  "spatial_merge_size": 2
216
  }
217
  }
218
+ }
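With the new `AutoModel` entry in `auto_map`, the base `JinaVLM` backbone should be loadable directly. A short hedged example (untested here):

from transformers import AutoModel

# Relies on the "AutoModel": "modeling_jvlm.JinaVLM" mapping added in this commit
model = AutoModel.from_pretrained('jinaai/jina-vlm-v1', trust_remote_code=True)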
configuration_jvlm.py CHANGED
@@ -530,6 +530,11 @@ class JinaVLMTextConfig(PretrainedConfigWithDataclasses):
530
  self.rope_theta = rope_theta
531
  self.rope_scaling = rope_scaling
532
 
 
 
 
 
 
533
 
534
  class JinaVLMConfig(PretrainedConfig):
535
  """JinaVLM configuration.
@@ -545,7 +550,8 @@ class JinaVLMConfig(PretrainedConfig):
545
 
546
  model_type = 'jvlm'
547
  sub_configs = {
548
- 'vision_config': JinaVLMVisionConfig, 'text_config': JinaVLMTextConfig
 
549
  }
550
 
551
  def __init__(
 
530
  self.rope_theta = rope_theta
531
  self.rope_scaling = rope_scaling
532
 
533
+ # Needed for vLLM
534
+ @property
535
+ def num_attention_heads(self) -> int:
536
+ return self.block_config.attn_config.n_heads
537
+
538
 
539
  class JinaVLMConfig(PretrainedConfig):
540
  """JinaVLM configuration.
 
550
 
551
  model_type = 'jvlm'
552
  sub_configs = {
553
+ 'vision_config': JinaVLMVisionConfig,
554
+ 'text_config': JinaVLMTextConfig,
555
  }
556
 
557
  def __init__(
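The `num_attention_heads` property is added for vLLM, which introspects configs for this attribute. A hedged illustration of what it exposes; the attribute path follows the sub-config names shown above and is otherwise an assumption:

from transformers import AutoConfig

config = AutoConfig.from_pretrained('jinaai/jina-vlm-v1', trust_remote_code=True)
# The new property forwards to block_config.attn_config.n_heads on the text sub-config
print(config.text_config.num_attention_heads)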
image_processing_jvlm.py CHANGED
@@ -437,6 +437,17 @@ class JinaVLMImageProcessor(BaseImageProcessor):
437
 
438
  """ Base cropping via resizing """
439
 
440
  def base_resize_cropping(self, image: np.ndarray):
441
  resized, mask = self.resize_image(image, list(self.base_input_size))
442
  resized = self.normalize_image(resized)
@@ -497,6 +508,117 @@ class JinaVLMImageProcessor(BaseImageProcessor):
497
 
498
  return candidate_tilings[ix]
499
 
500
  def molmo_overlap_and_resize_cropping(self, image: np.ndarray):
501
  # Discard this many patches from the (left/top, right/bottom) of crops
502
  left_margin, right_margin = self.overlap_margins
@@ -625,37 +747,23 @@ class JinaVLMImageProcessor(BaseImageProcessor):
625
  # new order into sparse structure of `patch_ordering` to fix it
626
  patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
627
 
628
- def get_num_patches(num_tiles, pooling_size) -> int:
629
- if num_tiles > 1:
630
- left_crop_window_patches = (
631
- (crop_window_patches + left_margin + pooling_size - 1)
632
- // pooling_size
633
- * pooling_size
634
- )
635
- middle_crop_window_patches = (
636
- (crop_window_patches + pooling_size - 1)
637
- // pooling_size
638
- * pooling_size
639
- )
640
- right_crop_window_patches = (
641
- (crop_window_patches + right_margin + pooling_size - 1)
642
- // pooling_size
643
- * pooling_size
644
- )
645
- return (
646
- left_crop_window_patches
647
- + (num_tiles - 2) * middle_crop_window_patches
648
- + right_crop_window_patches
649
- )
650
- else:
651
- single_crop_window_patches = (
652
- (crop_patches + pooling_size - 1) // pooling_size * pooling_size
653
- )
654
- return single_crop_window_patches
655
-
656
  # Now build the output tokens
657
- h = get_num_patches(tiling[0], self.pooling_h)
658
- w = get_num_patches(tiling[1], self.pooling_w)
659
  # for each row of patches, add a patch token per patch
660
  per_row = np.full((w // self.pooling_w,), self.patch_token_id, dtype=np.int32)
661
  if self.use_column_tokens:
@@ -810,6 +918,14 @@ class JinaVLMImageProcessor(BaseImageProcessor):
810
 
811
  return slices, image_masks, patch_ordering_arr, best_grid
812
 
813
  def minicpm_adaptive_slicing(self, image: np.ndarray, with_thumbnail: bool = True):
814
  scale_resolution = self.base_input_size[0]
815
  refine_image, image_mask, best_grid = self._minicpm_refine_image_for_slicing(
@@ -946,23 +1062,12 @@ class JinaVLMImageProcessor(BaseImageProcessor):
946
  self.start_token_id = start_token_id
947
  self.end_token_id = end_token_id
948
 
949
- def preprocess(
950
- self,
951
- images: ImageInput,
952
- **kwargs: Unpack[JinaVLMImagesKwargs],
953
- ) -> Dict[str, List[np.ndarray]]:
954
- """Preprocess an image or batch of images."""
955
- if images is None or len(images) == 0:
956
- return {
957
- 'image_crops': [],
958
- 'image_tokens': [],
959
- 'image_input_idx': [],
960
- 'image_padding_mask': [],
961
- }
962
-
963
  if 'max_crops' in kwargs and kwargs['max_crops'] is not None:
964
  max_crops = kwargs['max_crops']
965
- self.max_crops = max_crops
966
 
967
  min_pixels = self.min_pixels
968
  if 'min_pixels' in kwargs and kwargs['min_pixels'] is not None:
@@ -984,14 +1089,93 @@ class JinaVLMImageProcessor(BaseImageProcessor):
984
  size = {'shortest_edge': min_pixels, 'longest_edge': max_pixels}
985
  else:
986
  size = {**self.size}
987
-
 
988
  do_resize = self.do_resize
989
  if 'do_resize' in kwargs and kwargs['do_resize'] is not None:
990
  do_resize = kwargs['do_resize']
991
-
992
  do_convert_rgb = self.do_convert_rgb
993
  if 'do_convert_rgb' in kwargs and kwargs['do_convert_rgb'] is not None:
994
  do_convert_rgb = kwargs['do_convert_rgb']
995
 
996
  # noinspection PyTypeChecker
997
  images = self.fetch_images(images)
@@ -1001,16 +1185,11 @@ class JinaVLMImageProcessor(BaseImageProcessor):
1001
  'Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray '
1002
  'or torch.Tensor'
1003
  )
1004
-
1005
  if do_convert_rgb:
1006
  images = [convert_to_rgb(image) for image in images]
1007
 
1008
  # All transformations expect numpy arrays
1009
  images = [to_numpy_array(image) for image in images]
1010
-
1011
- input_data_format = None
1012
- if 'input_data_format' in kwargs:
1013
- input_data_format = kwargs['input_data_format']
1014
  if input_data_format is None:
1015
  # We assume that all images have the same channel dimension format.
1016
  input_data_format = infer_channel_dimension_format(images[0])
 
437
 
438
  """ Base cropping via resizing """
439
 
440
+ def base_get_n_image_patches(
441
+ self,
442
+ height: int,
443
+ width: int,
444
+ max_crops: int,
445
+ ) -> int:
446
+ raise NotImplementedError(
447
+ 'Function `get_n_image_patches` is not implemented for cropping method '
448
+ f'{CroppingMethod.RESIZE}'
449
+ )
450
+
451
  def base_resize_cropping(self, image: np.ndarray):
452
  resized, mask = self.resize_image(image, list(self.base_input_size))
453
  resized = self.normalize_image(resized)
 
508
 
509
  return candidate_tilings[ix]
510
 
511
+ @staticmethod
512
+ def _molmo_get_patches_from_tiling(
513
+ num_tiles,
514
+ pooling_size,
515
+ crop_patches,
516
+ crop_window_patches,
517
+ left_margin,
518
+ right_margin,
519
+ ) -> np.int32:
520
+ if num_tiles > 1:
521
+ left_crop_window_patches = (
522
+ (crop_window_patches + left_margin + pooling_size - 1)
523
+ // pooling_size
524
+ * pooling_size
525
+ )
526
+ middle_crop_window_patches = (
527
+ (crop_window_patches + pooling_size - 1) // pooling_size * pooling_size
528
+ )
529
+ right_crop_window_patches = (
530
+ (crop_window_patches + right_margin + pooling_size - 1)
531
+ // pooling_size
532
+ * pooling_size
533
+ )
534
+ return (
535
+ left_crop_window_patches
536
+ + (num_tiles - 2) * middle_crop_window_patches
537
+ + right_crop_window_patches
538
+ )
539
+ else:
540
+ single_crop_window_patches = (
541
+ (crop_patches + pooling_size - 1) // pooling_size * pooling_size
542
+ )
543
+ return single_crop_window_patches
544
+
545
+ def molmo_get_n_image_patches(
546
+ self,
547
+ height: int,
548
+ width: int,
549
+ max_crops: int,
550
+ ) -> int:
551
+ # Discard this many patches from the (left/top, right/bottom) of crops
552
+ left_margin, right_margin = self.overlap_margins
553
+ # Required for compatibility with image pooling
554
+ assert left_margin % self.pooling_w == 0 and right_margin % self.pooling_w == 0
555
+ assert left_margin % self.pooling_h == 0 and right_margin % self.pooling_h == 0
556
+ # pixels removed per dim
557
+ total_margin_pixels = self.patch_size * (right_margin + left_margin)
558
+ # patches per crop dim
559
+ crop_patches = self.base_input_size[0] // self.patch_size
560
+
561
+ # usable patches
562
+ crop_window_patches = crop_patches - (right_margin + left_margin)
563
+ crop_window_size = crop_window_patches * self.patch_size
564
+
565
+ # We assume hxw pooling, but can allow padding the right/bottom with extra
566
+ # patches if the number of patches per side is not divisible by h/w
567
+ assert (
568
+ crop_patches + self.pooling_h - 1
569
+ ) // self.pooling_h == self.token_length_h
570
+ assert (
571
+ crop_patches + self.pooling_w - 1
572
+ ) // self.pooling_w == self.token_length_w
573
+
574
+ # Decide how to tile the image, to account for the overlap margins we
575
+ # compute the tiling as if we had an image without the margins and were
576
+ # using a crop size without the margins
577
+ tiling = self._molmo_select_tiling(
578
+ height - total_margin_pixels,
579
+ width - total_margin_pixels,
580
+ crop_window_size,
581
+ max_crops,
582
+ )
583
+
584
+ # Now build the output tokens
585
+ h = self._molmo_get_patches_from_tiling(
586
+ tiling[0],
587
+ self.pooling_h,
588
+ crop_patches,
589
+ crop_window_patches,
590
+ left_margin,
591
+ right_margin,
592
+ )
593
+ w = self._molmo_get_patches_from_tiling(
594
+ tiling[1],
595
+ self.pooling_w,
596
+ crop_patches,
597
+ crop_window_patches,
598
+ left_margin,
599
+ right_margin,
600
+ )
601
+ # for each row of patches, add a patch token per patch
602
+ n_tokens = w.item() // self.pooling_w
603
+ if self.use_column_tokens:
604
+ # after each row, one column token is added
605
+ n_tokens += 1
606
+ # replicate each row of patch tokens by number of rows, i.e.
607
+ # proportional to image height
608
+ n_tokens *= h.item() // self.pooling_h
609
+ # add start and end image tokens
610
+ n_tokens += 2
611
+
612
+ # Global image goes first, so the order of patches in previous crops gets
613
+ # increased
614
+ n_thumbnail_tokens = self.token_length_w
615
+ if self.use_column_tokens:
616
+ n_thumbnail_tokens += 1
617
+ n_thumbnail_tokens *= self.token_length_h
618
+ n_thumbnail_tokens += 2
619
+
620
+ return n_tokens + n_thumbnail_tokens
621
+
622
  def molmo_overlap_and_resize_cropping(self, image: np.ndarray):
623
  # Discard this many patches from the (left/top, right/bottom) of crops
624
  left_margin, right_margin = self.overlap_margins
 
747
  # new order into sparse structure of `patch_ordering` to fix it
748
  patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
749
 
750
  # Now build the output tokens
751
+ h = self._molmo_get_patches_from_tiling(
752
+ tiling[0],
753
+ self.pooling_h,
754
+ crop_patches,
755
+ crop_window_patches,
756
+ left_margin,
757
+ right_margin,
758
+ )
759
+ w = self._molmo_get_patches_from_tiling(
760
+ tiling[1],
761
+ self.pooling_w,
762
+ crop_patches,
763
+ crop_window_patches,
764
+ left_margin,
765
+ right_margin,
766
+ )
767
  # for each row of patches, add a patch token per patch
768
  per_row = np.full((w // self.pooling_w,), self.patch_token_id, dtype=np.int32)
769
  if self.use_column_tokens:
 
918
 
919
  return slices, image_masks, patch_ordering_arr, best_grid
920
 
921
+ def minicpm_get_n_image_patches(
922
+ self, height: int, width: int, max_crops: int, with_thumbnail: bool = False
923
+ ) -> int:
924
+ raise NotImplementedError(
925
+ 'Function `get_n_image_patches` is not implemented for cropping method '
926
+ f'{CroppingMethod.ADAPTIVE_SLICING}'
927
+ )
928
+
929
  def minicpm_adaptive_slicing(self, image: np.ndarray, with_thumbnail: bool = True):
930
  scale_resolution = self.base_input_size[0]
931
  refine_image, image_mask, best_grid = self._minicpm_refine_image_for_slicing(
 
1062
  self.start_token_id = start_token_id
1063
  self.end_token_id = end_token_id
1064
 
1065
+ def _resolve_images_kwargs(
1066
+ self, **kwargs: Unpack[JinaVLMImagesKwargs]
1067
+ ) -> JinaVLMImagesKwargs:
1068
+ max_crops = self.max_crops
1069
  if 'max_crops' in kwargs and kwargs['max_crops'] is not None:
1070
  max_crops = kwargs['max_crops']
 
1071
 
1072
  min_pixels = self.min_pixels
1073
  if 'min_pixels' in kwargs and kwargs['min_pixels'] is not None:
 
1089
  size = {'shortest_edge': min_pixels, 'longest_edge': max_pixels}
1090
  else:
1091
  size = {**self.size}
1092
+ min_pixels = size['shortest_edge']
1093
+ max_pixels = size['longest_edge']
1094
  do_resize = self.do_resize
1095
  if 'do_resize' in kwargs and kwargs['do_resize'] is not None:
1096
  do_resize = kwargs['do_resize']
 
1097
  do_convert_rgb = self.do_convert_rgb
1098
  if 'do_convert_rgb' in kwargs and kwargs['do_convert_rgb'] is not None:
1099
  do_convert_rgb = kwargs['do_convert_rgb']
1100
+ input_data_format = None
1101
+ if 'input_data_format' in kwargs:
1102
+ input_data_format = kwargs['input_data_format']
1103
+
1104
+ return JinaVLMImagesKwargs(
1105
+ do_convert_rgb=do_convert_rgb,
1106
+ do_resize=do_resize,
1107
+ min_pixels=min_pixels,
1108
+ max_pixels=max_pixels,
1109
+ size=size,
1110
+ max_crops=max_crops,
1111
+ input_data_format=input_data_format,
1112
+ )
1113
+
1114
+ def get_n_image_patches(
1115
+ self,
1116
+ height: int,
1117
+ width: int,
1118
+ **kwargs: Unpack[JinaVLMImagesKwargs],
1119
+ ) -> int:
1120
+ """A utility that returns number of image patches for a given image size.
1121
+
1122
+ Args:
1123
+ height (`int`):
1124
+ Height of the input image.
1125
+ width (`int`):
1126
+ Width of the input image.
1127
+ **kwargs (`dict`, *optional*)
1128
+ Any kwargs to override defaults of the image processor.
1129
+ Returns:
1130
+ `int`: Number of image patches
1131
+ """
1132
+ if self.cropping_method != CroppingMethod.OVERLAP_AND_RESIZE:
1133
+ raise NotImplementedError(
1134
+ 'Function is only implemented for cropping method '
1135
+ f'{CroppingMethod.OVERLAP_AND_RESIZE}'
1136
+ )
1137
+ kwargs = self._resolve_images_kwargs(**kwargs)
1138
+ do_resize = kwargs['do_resize']
1139
+ size = kwargs['size']
1140
+ max_crops = kwargs['max_crops']
1141
+ if do_resize:
1142
+ height, width = smart_resize(
1143
+ height,
1144
+ width,
1145
+ factor=self.patch_size,
1146
+ min_pixels=size['shortest_edge'],
1147
+ max_pixels=size['longest_edge'],
1148
+ )
1149
+
1150
+ if self.cropping_method == CroppingMethod.RESIZE:
1151
+ return self.base_get_n_image_patches(height, width, max_crops)
1152
+ elif self.cropping_method == CroppingMethod.OVERLAP_AND_RESIZE:
1153
+ return self.molmo_get_n_image_patches(height, width, max_crops)
1154
+ elif self.cropping_method == CroppingMethod.ADAPTIVE_SLICING:
1155
+ return self.minicpm_get_n_image_patches(height, width, max_crops)
1156
+ return self.minicpm_get_n_image_patches(
1157
+ height, width, max_crops, with_thumbnail=True
1158
+ )
1159
+
1160
+ def preprocess(
1161
+ self,
1162
+ images: ImageInput,
1163
+ **kwargs: Unpack[JinaVLMImagesKwargs],
1164
+ ) -> Dict[str, List[np.ndarray]]:
1165
+ """Preprocess an image or batch of images."""
1166
+ if images is None or len(images) == 0:
1167
+ return {
1168
+ 'image_crops': [],
1169
+ 'image_tokens': [],
1170
+ 'image_input_idx': [],
1171
+ 'image_padding_mask': [],
1172
+ }
1173
+ kwargs = self._resolve_images_kwargs(**kwargs)
1174
+ do_convert_rgb = kwargs['do_convert_rgb']
1175
+ do_resize = kwargs['do_resize']
1176
+ input_data_format = kwargs['input_data_format']
1177
+ size = kwargs['size']
1178
+ self.max_crops = kwargs['max_crops']
1179
 
1180
  # noinspection PyTypeChecker
1181
  images = self.fetch_images(images)
 
1185
  'Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray '
1186
  'or torch.Tensor'
1187
  )
 
1188
  if do_convert_rgb:
1189
  images = [convert_to_rgb(image) for image in images]
1190
 
1191
  # All transformations expect numpy arrays
1192
  images = [to_numpy_array(image) for image in images]
 
1193
  if input_data_format is None:
1194
  # We assume that all images have the same channel dimension format.
1195
  input_data_format = infer_channel_dimension_format(images[0])
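The new `get_n_image_patches` utility makes it possible to estimate image token counts without running the full preprocessing pipeline. A hedged usage sketch (image size and max_crops values are arbitrary):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('jinaai/jina-vlm-v1', trust_remote_code=True)
# Only supported for the OVERLAP_AND_RESIZE cropping method, per the docstring above
n_patches = processor.image_processor.get_n_image_patches(1080, 1920, max_crops=12)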
modeling_jvlm.py CHANGED
@@ -27,14 +27,13 @@ from .blocks_jvlm import (
27
  TransformerBlock,
28
  VisionLanguageConnector,
29
  build_layer_norm,
30
- resolve_causal_mask
31
  )
32
  from .configuration_jvlm import JinaVLMConfig, JinaVLMTextConfig, JinaVLMVisionConfig
33
 
34
 
35
  class JinaPreTrainedModel(PreTrainedModel):
36
  config: JinaVLMConfig
37
- config_class = JinaVLMConfig
38
  base_model_prefix = 'model'
39
  supports_gradient_checkpointing = True
40
  _supports_flash_attn = True
@@ -51,8 +50,6 @@ class JinaPreTrainedModel(PreTrainedModel):
51
 
52
  class JinaVLMVisionModel(JinaPreTrainedModel):
53
  config: JinaVLMVisionConfig
54
- config_class = JinaVLMVisionConfig
55
- base_model_prefix = ''
56
 
57
  def __init__(self, config: JinaVLMVisionConfig, *args, **kwargs):
58
  super().__init__(config, *args, **kwargs)
@@ -186,7 +183,11 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
186
  pos = pos_emb[None, :, :].to(x.dtype)
187
  return x + pos
188
 
189
- def get_visual_features(self, images: torch.Tensor) -> BaseModelOutput:
190
  x, shape = self.patch_embed(images)
191
  if self.cls_embed is not None:
192
  cls = self.cls_embed.view(1, 1, -1).expand(x.shape[0], -1, -1).to(x.dtype)
@@ -201,7 +202,11 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
201
  hidden_states = []
202
  attentions = []
203
  for layer in self.layers:
204
- x, attn = layer(x, attn_implementation=self.config._attn_implementation)
205
  hidden_states.append(x)
206
  attentions.append(attn)
207
  x = self.post_lnorm(x)
@@ -214,12 +219,15 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
214
  )
215
 
216
  def forward(
217
- self, images: torch.Tensor, image_masks: torch.Tensor
218
  ) -> BaseModelOutput:
219
  b, t, n, d = images.shape
220
  mask = ~torch.all(images.view(b * t, n, d) == -1, dim=(1, 2), keepdim=True)
221
  images = images.view(b * t, n, d)
222
- out = self.get_visual_features(images)
223
  image_features = out.hidden_states
224
 
225
  features = []
@@ -230,14 +238,13 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
230
  features.append(feats)
231
  image_features = torch.cat(features, dim=-1)
232
  image_features = image_features * mask
233
- image_features = image_features.view(b, t, n, -1)
234
-
235
  image_features = self.vl_connector(
236
  image_features,
237
  image_masks,
238
  attn_implementation=self.config._attn_implementation,
 
239
  )
240
-
241
  return BaseModelOutput(
242
  last_hidden_state=image_features,
243
  hidden_states=out.hidden_states,
@@ -246,11 +253,7 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
246
 
247
 
248
  class JinaVLMTextModel(JinaPreTrainedModel):
249
- """Decoder-only language model."""
250
-
251
  config: JinaVLMTextConfig
252
- config_class = JinaVLMTextConfig
253
- base_model_prefix = ''
254
 
255
  def __init__(self, config: JinaVLMTextConfig, *args, **kwargs):
256
  super().__init__(config, *args, **kwargs)
@@ -297,7 +300,6 @@ class JinaVLMTextModel(JinaPreTrainedModel):
297
  theta=self.config.rope_theta,
298
  head_dim=self.config.block_config.attn_config.head_dim,
299
  hidden_size=self.config.hidden_size,
300
- n_heads=self.config.block_config.attn_config.n_heads,
301
  partial_rotary_factor=self.config.partial_rotary_factor,
302
  scaling=self.config.rope_scaling,
303
  )
@@ -444,7 +446,7 @@ class JinaVLMTextModel(JinaPreTrainedModel):
444
 
445
 
446
  class JinaVLM(JinaPreTrainedModel):
447
- base_model_prefix = ''
448
 
449
  def __init__(self, config: JinaVLMConfig):
450
  super().__init__(config)
@@ -493,7 +495,7 @@ class JinaVLM(JinaPreTrainedModel):
493
  ) -> BaseModelOutputWithPast:
494
  image_features = None
495
  if images is not None and images.shape[1] > 0:
496
- image_out = self.vision_model(images, image_masks)
497
  image_features = image_out.last_hidden_state
498
  return self.language_model(
499
  input_ids=input_ids,
@@ -512,10 +514,10 @@ class JinaVLM(JinaPreTrainedModel):
512
 
513
 
514
  class JinaVLMForConditionalGeneration(JinaPreTrainedModel, GenerationMixin):
515
- _checkpoint_conversion_mapping = {}
516
- _tied_weights_keys = ['lm_head.weight']
 
517
  accepts_loss_kwargs = False
518
- base_model_prefix = 'model'
519
  config: JinaVLMConfig
520
 
521
  def __init__(self, config: JinaVLMConfig):
 
27
  TransformerBlock,
28
  VisionLanguageConnector,
29
  build_layer_norm,
30
+ resolve_causal_mask,
31
  )
32
  from .configuration_jvlm import JinaVLMConfig, JinaVLMTextConfig, JinaVLMVisionConfig
33
 
34
 
35
  class JinaPreTrainedModel(PreTrainedModel):
36
  config: JinaVLMConfig
 
37
  base_model_prefix = 'model'
38
  supports_gradient_checkpointing = True
39
  _supports_flash_attn = True
 
50
 
51
  class JinaVLMVisionModel(JinaPreTrainedModel):
52
  config: JinaVLMVisionConfig
53
 
54
  def __init__(self, config: JinaVLMVisionConfig, *args, **kwargs):
55
  super().__init__(config, *args, **kwargs)
 
183
  pos = pos_emb[None, :, :].to(x.dtype)
184
  return x + pos
185
 
186
+ def get_visual_features(
187
+ self,
188
+ images: torch.Tensor,
189
+ **kwargs: Unpack[FlashAttentionKwargs],
190
+ ) -> BaseModelOutput:
191
  x, shape = self.patch_embed(images)
192
  if self.cls_embed is not None:
193
  cls = self.cls_embed.view(1, 1, -1).expand(x.shape[0], -1, -1).to(x.dtype)
 
202
  hidden_states = []
203
  attentions = []
204
  for layer in self.layers:
205
+ x, attn = layer(
206
+ x,
207
+ attn_implementation=self.config._attn_implementation,
208
+ **kwargs,
209
+ )
210
  hidden_states.append(x)
211
  attentions.append(attn)
212
  x = self.post_lnorm(x)
 
219
  )
220
 
221
  def forward(
222
+ self,
223
+ images: torch.Tensor,
224
+ image_masks: torch.Tensor,
225
+ **kwargs: Unpack[FlashAttentionKwargs],
226
  ) -> BaseModelOutput:
227
  b, t, n, d = images.shape
228
  mask = ~torch.all(images.view(b * t, n, d) == -1, dim=(1, 2), keepdim=True)
229
  images = images.view(b * t, n, d)
230
+ out = self.get_visual_features(images, **kwargs)
231
  image_features = out.hidden_states
232
 
233
  features = []
 
238
  features.append(feats)
239
  image_features = torch.cat(features, dim=-1)
240
  image_features = image_features * mask
241
+ image_features = image_features.view(b, t, n, -1).contiguous()
 
242
  image_features = self.vl_connector(
243
  image_features,
244
  image_masks,
245
  attn_implementation=self.config._attn_implementation,
246
+ **kwargs,
247
  )
 
248
  return BaseModelOutput(
249
  last_hidden_state=image_features,
250
  hidden_states=out.hidden_states,
 
253
 
254
 
255
  class JinaVLMTextModel(JinaPreTrainedModel):
256
  config: JinaVLMTextConfig
257
 
258
  def __init__(self, config: JinaVLMTextConfig, *args, **kwargs):
259
  super().__init__(config, *args, **kwargs)
 
300
  theta=self.config.rope_theta,
301
  head_dim=self.config.block_config.attn_config.head_dim,
302
  hidden_size=self.config.hidden_size,
 
303
  partial_rotary_factor=self.config.partial_rotary_factor,
304
  scaling=self.config.rope_scaling,
305
  )
 
446
 
447
 
448
  class JinaVLM(JinaPreTrainedModel):
449
+ config: JinaVLMConfig
450
 
451
  def __init__(self, config: JinaVLMConfig):
452
  super().__init__(config)
 
495
  ) -> BaseModelOutputWithPast:
496
  image_features = None
497
  if images is not None and images.shape[1] > 0:
498
+ image_out = self.vision_model(images, image_masks, **kwargs)
499
  image_features = image_out.last_hidden_state
500
  return self.language_model(
501
  input_ids=input_ids,
 
514
 
515
 
516
  class JinaVLMForConditionalGeneration(JinaPreTrainedModel, GenerationMixin):
517
+ _tied_weights_keys = {
518
+ 'lm_head.weight': 'model.language_model.embedding.embedding.weight'
519
+ }
520
  accepts_loss_kwargs = False
 
521
  config: JinaVLMConfig
522
 
523
  def __init__(self, config: JinaVLMConfig):
processing_jvlm.py CHANGED
@@ -10,11 +10,14 @@ from transformers.image_utils import ImageInput
10
  from transformers.processing_utils import (
11
  AllKwargsForChatTemplate,
12
  CommonKwargs,
 
13
  ProcessorMixin,
14
  Unpack,
15
  )
16
  from transformers.tokenization_utils_base import (
17
- PaddingStrategy, PreTokenizedInput, TextInput,
18
  )
19
 
20
  from .image_processing_jvlm import JinaVLMImageProcessor, JinaVLMImagesKwargs
@@ -38,7 +41,7 @@ class JinaVLMTextKwargs(TypedDict, total=False):
38
  is_split_into_words: Optional[bool]
39
 
40
 
41
- class JinaVLProcessingKwargs(JinaVLMTextKwargs, JinaVLMImagesKwargs, CommonKwargs):
42
  return_labels: Optional[bool]
43
 
44
 
@@ -171,8 +174,8 @@ class JinaVLMProcessor(ProcessorMixin):
171
  def _collate(
172
  self,
173
  batch: Dict[str, List[Optional[np.ndarray]]],
174
- max_sequence_length: Optional[int] = None,
175
- max_crops: Optional[int] = None,
176
  padding: Union[
177
  PaddingStrategy.MAX_LENGTH, PaddingStrategy.LONGEST
178
  ] = PaddingStrategy.MAX_LENGTH,
@@ -185,10 +188,10 @@ class JinaVLMProcessor(ProcessorMixin):
185
  _padding_side = 'right'
186
  if key in self.TEXT_KEYS:
187
  _padding_side = padding_side
188
- max_len = max_sequence_length
189
  dtype = np.int64
190
  elif key in self.IMAGE_KEYS:
191
- max_len = max_crops
192
  dtype = np.int64
193
  if key == 'images':
194
  dtype = np.float32
@@ -214,22 +217,22 @@ class JinaVLMProcessor(ProcessorMixin):
214
  shift = input_ids_padlens[:, np.newaxis, np.newaxis]
215
  shift = np.repeat(shift, n_image_tokens, axis=2)
216
  shift = np.repeat(shift, n_crops, axis=1)
217
- image_input_idx[image_input_idx < 0] = -max_sequence_length
218
  image_input_idx = image_input_idx + shift
219
  out['image_input_idx'] = image_input_idx
220
 
221
- if max_sequence_length is not None:
222
  image_input_idx = out.get('image_input_idx', [])
223
  n = len(image_input_idx)
224
  for i in range(n):
225
  arr = image_input_idx[i]
226
  if arr.ndim > 0 and arr.size > 0:
227
  n_image_tokens = arr.max()
228
- if n_image_tokens > max_sequence_length - 3:
229
  raise RuntimeError(
230
  'Image tokens truncation at sequence boundary. Max '
231
- f'sequence length ({max_sequence_length}) is too small '
232
- 'to fit the generated image tokens '
233
  f'({n_image_tokens}). Consider increasing the max '
234
  'sequence length or tweaking the image processing '
235
  'parameters (`max_crops`, `max_pixels`) to reduce the '
@@ -386,7 +389,7 @@ class JinaVLMProcessor(ProcessorMixin):
386
  text: Union[
387
  None, TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
388
  ] = None,
389
- **kwargs: Unpack[JinaVLProcessingKwargs],
390
  ) -> BatchFeature:
391
  """Main method to prepare for the model one or several sequences(s) and
392
  image(s). This method forwards the `text` and `kwargs` arguments to the
@@ -489,9 +492,11 @@ class JinaVLMProcessor(ProcessorMixin):
489
  )
490
 
491
  outputs = defaultdict(list)
 
492
  for idx in range(batch_size):
493
  _token_ids = token_ids[idx]
494
  _images = images[idx]
 
495
  image_inputs = self.image_processor(_images, **images_kwargs)
496
  image_crops = image_inputs['image_crops']
497
  image_tokens = image_inputs['image_tokens']
@@ -510,14 +515,42 @@ class JinaVLMProcessor(ProcessorMixin):
510
  outputs[k].append(v)
511
 
512
  if padding != PaddingStrategy.DO_NOT_PAD:
513
  outputs = self._collate(
514
  outputs,
515
- max_sequence_length=max_length or self.max_sequence_length,
516
- max_crops=max_crops or self.max_crops,
517
  padding=padding,
518
  padding_side=padding_side,
519
  )
520
  return BatchFeature(data=outputs, tensor_type=return_tensors)
521
 
522
 
523
  JinaVLMProcessor.register_for_auto_class()
 
10
  from transformers.processing_utils import (
11
  AllKwargsForChatTemplate,
12
  CommonKwargs,
13
+ MultiModalData,
14
  ProcessorMixin,
15
  Unpack,
16
  )
17
  from transformers.tokenization_utils_base import (
18
+ PaddingStrategy,
19
+ PreTokenizedInput,
20
+ TextInput,
21
  )
22
 
23
  from .image_processing_jvlm import JinaVLMImageProcessor, JinaVLMImagesKwargs
 
41
  is_split_into_words: Optional[bool]
42
 
43
 
44
+ class JinaVLMProcessingKwargs(JinaVLMTextKwargs, JinaVLMImagesKwargs, CommonKwargs):
45
  return_labels: Optional[bool]
46
 
47
 
 
174
  def _collate(
175
  self,
176
  batch: Dict[str, List[Optional[np.ndarray]]],
177
+ text_max_sequence_length: Optional[int] = None,
178
+ image_max_sequence_length: Optional[int] = None,
179
  padding: Union[
180
  PaddingStrategy.MAX_LENGTH, PaddingStrategy.LONGEST
181
  ] = PaddingStrategy.MAX_LENGTH,
 
188
  _padding_side = 'right'
189
  if key in self.TEXT_KEYS:
190
  _padding_side = padding_side
191
+ max_len = text_max_sequence_length
192
  dtype = np.int64
193
  elif key in self.IMAGE_KEYS:
194
+ max_len = image_max_sequence_length
195
  dtype = np.int64
196
  if key == 'images':
197
  dtype = np.float32
 
217
  shift = input_ids_padlens[:, np.newaxis, np.newaxis]
218
  shift = np.repeat(shift, n_image_tokens, axis=2)
219
  shift = np.repeat(shift, n_crops, axis=1)
220
+ image_input_idx[image_input_idx < 0] = -text_max_sequence_length
221
  image_input_idx = image_input_idx + shift
222
  out['image_input_idx'] = image_input_idx
223
 
224
+ if text_max_sequence_length is not None:
225
  image_input_idx = out.get('image_input_idx', [])
226
  n = len(image_input_idx)
227
  for i in range(n):
228
  arr = image_input_idx[i]
229
  if arr.ndim > 0 and arr.size > 0:
230
  n_image_tokens = arr.max()
231
+ if n_image_tokens > text_max_sequence_length - 3:
232
  raise RuntimeError(
233
  'Image tokens truncation at sequence boundary. Max '
234
+ f'sequence length ({text_max_sequence_length}) is too '
235
+ 'small to fit the generated image tokens '
236
  f'({n_image_tokens}). Consider increasing the max '
237
  'sequence length or tweaking the image processing '
238
  'parameters (`max_crops`, `max_pixels`) to reduce the '
 
389
  text: Union[
390
  None, TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
391
  ] = None,
392
+ **kwargs: Unpack[JinaVLMProcessingKwargs],
393
  ) -> BatchFeature:
394
  """Main method to prepare for the model one or several sequences(s) and
395
  image(s). This method forwards the `text` and `kwargs` arguments to the
 
492
  )
493
 
494
  outputs = defaultdict(list)
495
+ n_images = []
496
  for idx in range(batch_size):
497
  _token_ids = token_ids[idx]
498
  _images = images[idx]
499
+ n_images.append(len(_images))
500
  image_inputs = self.image_processor(_images, **images_kwargs)
501
  image_crops = image_inputs['image_crops']
502
  image_tokens = image_inputs['image_tokens']
 
515
  outputs[k].append(v)
516
 
517
  if padding != PaddingStrategy.DO_NOT_PAD:
518
+ text_max_sequence_length = max_length or self.max_sequence_length
519
+ max_crops = max_crops or self.max_crops
520
+ max_n_images = max(n_images)
521
+ image_max_sequence_length = (max_crops + 1) * max_n_images
522
  outputs = self._collate(
523
  outputs,
524
+ text_max_sequence_length=text_max_sequence_length,
525
+ image_max_sequence_length=image_max_sequence_length,
526
  padding=padding,
527
  padding_side=padding_side,
528
  )
529
  return BatchFeature(data=outputs, tensor_type=return_tensors)
530
 
531
+ def _get_num_multimodal_tokens(
532
+ self,
533
+ image_sizes: Optional[List[List[int]]] = None,
534
+ **kwargs: Unpack[JinaVLMImagesKwargs],
535
+ ) -> MultiModalData:
536
+ """Computes the number of placeholder tokens needed for multimodal inputs with
537
+ the given sizes.
538
+
539
+ Args:
540
+ image_sizes (`list[list[int]]`, *optional*):
541
+ The input sizes formatted as (height, width) per each image.
542
+ Returns:
543
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per
544
+ each of the provided input modalities, along with other useful data.
545
+ """
546
+ data = {}
547
+ if image_sizes is not None:
548
+ n_patches = [
549
+ self.image_processor.get_n_image_patches(h, w, **kwargs)
550
+ for h, w in image_sizes
551
+ ]
552
+ data.update({'num_image_tokens': n_patches, 'num_image_patches': n_patches})
553
+ return MultiModalData(**data)
554
+
555
 
556
  JinaVLMProcessor.register_for_auto_class()
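The new `_get_num_multimodal_tokens` hook exposes per-image token counts to callers such as vLLM. A hedged sketch of how it might be invoked; the image sizes are illustrative and the attribute access on the returned `MultiModalData` is an assumption:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('jinaai/jina-vlm-v1', trust_remote_code=True)
# (height, width) pairs for two hypothetical images
mm_data = processor._get_num_multimodal_tokens(image_sizes=[(768, 1024), (512, 512)])
print(mm_data.num_image_tokens)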
test_jvlm.py CHANGED
@@ -11,7 +11,10 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
11
 
12
  import torch
13
  from transformers import (
14
- AutoModelForCausalLM, AutoProcessor, GenerationConfig, TextStreamer
15
  )
16
  from transformers.utils import is_flash_attn_2_available
17
 
@@ -60,7 +63,8 @@ def _build_conversations(
60
  try:
61
  result = urlparse(_path)
62
  return result.scheme in ('http', 'https')
63
- except:
 
64
  return False
65
 
66
  images = images or []
@@ -83,8 +87,9 @@ def _build_conversations(
83
  images = [TEST_IMAGE]
84
  n_images = len(images)
85
  prompts = (
86
- ['Describe the image in 100 words'] if n_images == 1 or map_mode else
87
- ['Describe the images in 100 words']
 
88
  )
89
  n_prompts = len(prompts)
90
 
@@ -119,8 +124,16 @@ def _build_conversations(
119
  allimages = []
120
  allprompts = []
121
  ordinals = [
122
- 'first', 'second', 'third', 'fourth', 'fifth',
123
- 'sixth', 'seventh', 'eighth', 'ninth', 'tenth',
124
  ]
125
  for images, prompt in examples:
126
  content = []
@@ -130,15 +143,17 @@ def _build_conversations(
130
  content.append({'type': 'text', 'text': prompt})
131
  if len(images) > 1 and image_labels:
132
  for idx, img in enumerate(images):
133
- ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx+1}th'
134
  image = images[idx]
135
  descriptor = f'url: {image}'
136
  if os.path.isfile(image):
137
  descriptor = f'filename: {os.path.basename(image)}'
138
- content.append({
139
- 'type': 'text',
140
- 'text': f'(this is the {ordinal} image, {descriptor})',
141
- })
142
  content.append({'type': 'image', 'image': img})
143
  else:
144
  content.extend([{'type': 'image', 'image': image} for image in images])
@@ -189,9 +204,7 @@ def _token_usage_report(
189
  tokens_per_image_list = []
190
 
191
  # Find all img_start and img_end positions in input_ids
192
- start_positions = (input_ids == image_start_id).nonzero(
193
- as_tuple=True
194
- )[0].tolist()
195
  end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
196
 
197
  if len(start_positions) > 0 and len(end_positions) > 0:
@@ -211,9 +224,8 @@ def _token_usage_report(
211
  # Get the start and end indices for this image
212
  start_idx_begin = idx * n_starts_per_image
213
  end_idx_end = (idx + 1) * n_starts_per_image
214
- if (
215
- start_idx_begin < len(start_positions) and
216
- end_idx_end <= len(end_positions)
217
  ):
218
  # First start position and last end position define the image span
219
  first_start = start_positions[start_idx_begin]
@@ -233,10 +245,10 @@ def _token_usage_report(
233
 
234
  for idx in range(n_images):
235
  n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
236
- pct = (n_tokens / max_sequence_length * 100)
237
  report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')
238
 
239
- text_pct = (text_token_count / max_sequence_length * 100)
240
  report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
241
 
242
  return '\n'.join(report)
@@ -253,7 +265,7 @@ def test_jvlm():
253
  help=(
254
  'Model path (default: `"."`). Set this to `"jinaai/jina-vlm-v1"` if you '
255
  'are running this script outside this repo.'
256
- )
257
  )
258
  parser.add_argument(
259
  '-i',
@@ -339,7 +351,9 @@ def test_jvlm():
339
  print(f'Using dtype: {dtype}')
340
  print('Model path: ', args.model)
341
  processor = AutoProcessor.from_pretrained(
342
- args.model, trust_remote_code=True, use_fast=False,
343
  )
344
  model = AutoModelForCausalLM.from_pretrained(
345
  args.model,
@@ -356,13 +370,13 @@ def test_jvlm():
356
  print('Done ✅')
357
  print()
358
 
359
- print('--- Let\'s create some conversations ...')
360
  conversations, images, prompts = _build_conversations(
361
  args.image,
362
  args.prompt,
363
  map_mode=args.map,
364
  prompt_first=args.prompt_first,
365
- image_labels=args.image_labels
366
  )
367
  n_conversations = len(conversations)
368
  print(f'Built {n_conversations} conversations 🚀')
@@ -434,25 +448,28 @@ def test_jvlm():
434
  print(f'├── 🖼️Images: {images[idx]}')
435
  print(f'├── 📜Prompt: {prompts[idx]}')
436
  print(f'├── 💬Chat:{texts[idx]}')
437
- print(f'└── 🧠Response:', end='')
438
  ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
439
  with (
440
  timer,
441
  torch.no_grad(),
442
- torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype)
443
  ):
444
  output = model.generate(
445
  **ith_inputs,
446
  streamer=streamer,
447
  generation_config=GenerationConfig(
448
- max_new_tokens=args.max_tokens, do_sample=False,
 
449
  ),
450
  return_dict_in_generate=True,
451
  use_model_defaults=True,
452
  )
453
  generation_time += timer.time
454
 
455
- out = output.sequences[0][len(input_prompts[idx].tolist()):]
456
  generated_tokens += len(out)
457
  print('Token usage report:')
458
  print(token_usage_reports[idx])
@@ -470,7 +487,8 @@ def test_jvlm():
470
  output = model.generate(
471
  **device_inputs,
472
  generation_config=GenerationConfig(
473
- max_new_tokens=args.max_tokens, do_sample=False,
 
474
  ),
475
  return_dict_in_generate=True,
476
  use_model_defaults=True,
@@ -478,7 +496,7 @@ def test_jvlm():
478
  generation_time = timer.time
479
 
480
  for idx in range(n_conversations):
481
- out = output.sequences[idx][len(input_prompts[idx].tolist()):]
482
  generated_tokens += len(out)
483
  response = processor.tokenizer.decode(out, skip_special_tokens=True)
484
  print(f'* Conversation {idx + 1}/{n_conversations}')
 
11
 
12
  import torch
13
  from transformers import (
14
+ AutoModelForCausalLM,
15
+ AutoProcessor,
16
+ GenerationConfig,
17
+ TextStreamer,
18
  )
19
  from transformers.utils import is_flash_attn_2_available
20
 
 
63
  try:
64
  result = urlparse(_path)
65
  return result.scheme in ('http', 'https')
66
+ except Exception as e:
67
+ _ = str(e)
68
  return False
69
 
70
  images = images or []
 
87
  images = [TEST_IMAGE]
88
  n_images = len(images)
89
  prompts = (
90
+ ['Describe the image in 100 words']
91
+ if n_images == 1 or map_mode
92
+ else ['Describe the images in 100 words']
93
  )
94
  n_prompts = len(prompts)
95
 
 
124
  allimages = []
125
  allprompts = []
126
  ordinals = [
127
+ 'first',
128
+ 'second',
129
+ 'third',
130
+ 'fourth',
131
+ 'fifth',
132
+ 'sixth',
133
+ 'seventh',
134
+ 'eighth',
135
+ 'ninth',
136
+ 'tenth',
137
  ]
138
  for images, prompt in examples:
139
  content = []
 
143
  content.append({'type': 'text', 'text': prompt})
144
  if len(images) > 1 and image_labels:
145
  for idx, img in enumerate(images):
146
+ ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx + 1}th'
147
  image = images[idx]
148
  descriptor = f'url: {image}'
149
  if os.path.isfile(image):
150
  descriptor = f'filename: {os.path.basename(image)}'
151
+ content.append(
152
+ {
153
+ 'type': 'text',
154
+ 'text': f'(this is the {ordinal} image, {descriptor})',
155
+ }
156
+ )
157
  content.append({'type': 'image', 'image': img})
158
  else:
159
  content.extend([{'type': 'image', 'image': image} for image in images])
 
204
  tokens_per_image_list = []
205
 
206
  # Find all img_start and img_end positions in input_ids
207
+ start_positions = (input_ids == image_start_id).nonzero(as_tuple=True)[0].tolist()
208
  end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
209
 
210
  if len(start_positions) > 0 and len(end_positions) > 0:
 
224
  # Get the start and end indices for this image
225
  start_idx_begin = idx * n_starts_per_image
226
  end_idx_end = (idx + 1) * n_starts_per_image
227
+ if start_idx_begin < len(start_positions) and end_idx_end <= len(
228
+ end_positions
 
229
  ):
230
  # First start position and last end position define the image span
231
  first_start = start_positions[start_idx_begin]
 
245
 
246
  for idx in range(n_images):
247
  n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
248
+ pct = n_tokens / max_sequence_length * 100
249
  report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')
250
 
251
+ text_pct = text_token_count / max_sequence_length * 100
252
  report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
253
 
254
  return '\n'.join(report)
 
265
  help=(
266
  'Model path (default: `"."`). Set this to `"jinaai/jina-vlm-v1"` if you '
267
  'are running this script outside this repo.'
268
+ ),
269
  )
270
  parser.add_argument(
271
  '-i',
 
351
  print(f'Using dtype: {dtype}')
352
  print('Model path: ', args.model)
353
  processor = AutoProcessor.from_pretrained(
354
+ args.model,
355
+ trust_remote_code=True,
356
+ use_fast=False,
357
  )
358
  model = AutoModelForCausalLM.from_pretrained(
359
  args.model,
 
370
  print('Done ✅')
371
  print()
372
 
373
+ print("--- Let's create some conversations ...")
374
  conversations, images, prompts = _build_conversations(
375
  args.image,
376
  args.prompt,
377
  map_mode=args.map,
378
  prompt_first=args.prompt_first,
379
+ image_labels=args.image_labels,
380
  )
381
  n_conversations = len(conversations)
382
  print(f'Built {n_conversations} conversations 🚀')
 
448
  print(f'├── 🖼️Images: {images[idx]}')
449
  print(f'├── 📜Prompt: {prompts[idx]}')
450
  print(f'├── 💬Chat:{texts[idx]}')
451
+ print('└── 🧠Response:', end='')
452
  ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
453
  with (
454
  timer,
455
  torch.no_grad(),
456
+ torch.autocast(
457
+ device.type, enabled=(device.type != 'mps'), dtype=dtype
458
+ ),
459
  ):
460
  output = model.generate(
461
  **ith_inputs,
462
  streamer=streamer,
463
  generation_config=GenerationConfig(
464
+ max_new_tokens=args.max_tokens,
465
+ do_sample=False,
466
  ),
467
  return_dict_in_generate=True,
468
  use_model_defaults=True,
469
  )
470
  generation_time += timer.time
471
 
472
+ out = output.sequences[0][len(input_prompts[idx].tolist()) :]
473
  generated_tokens += len(out)
474
  print('Token usage report:')
475
  print(token_usage_reports[idx])
 
487
  output = model.generate(
488
  **device_inputs,
489
  generation_config=GenerationConfig(
490
+ max_new_tokens=args.max_tokens,
491
+ do_sample=False,
492
  ),
493
  return_dict_in_generate=True,
494
  use_model_defaults=True,
 
496
  generation_time = timer.time
497
 
498
  for idx in range(n_conversations):
499
+ out = output.sequences[idx][len(input_prompts[idx].tolist()) :]
500
  generated_tokens += len(out)
501
  response = processor.tokenizer.decode(out, skip_special_tokens=True)
502
  print(f'* Conversation {idx + 1}/{n_conversations}')