Model update
- README.md +3 -1
- blocks_jvlm.py +50 -29
- config.json +2 -1
- configuration_jvlm.py +7 -1
- image_processing_jvlm.py +231 -52
- modeling_jvlm.py +23 -21
- processing_jvlm.py +47 -14
- test_jvlm.py +47 -29
README.md
CHANGED
@@ -286,7 +286,9 @@ processor = AutoProcessor.from_pretrained(
 
 # Load the model on the available device(s)
 model = AutoModelForCausalLM.from_pretrained(
-    'jinaai/jina-vlm-v1',
+    'jinaai/jina-vlm-v1',
+    device_map='auto',
+    trust_remote_code=True
 )
 
 # You can specify a different model dtype and/or attention implementation
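The README's closing comment mentions overriding the dtype and attention implementation. A minimal sketch of such a call; `torch_dtype` and `attn_implementation` are standard `transformers` keyword arguments and not something this commit introduces, and 'flash_attention_2' assumes the flash-attn package is installed:

import torch
from transformers import AutoModelForCausalLM

# Hypothetical override of dtype and attention backend when loading the model
model = AutoModelForCausalLM.from_pretrained(
    'jinaai/jina-vlm-v1',
    device_map='auto',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
)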
blocks_jvlm.py
CHANGED
@@ -1,6 +1,7 @@
 # Copyright 2025 Jina AI. All rights reserved.
 
 from abc import ABCMeta, abstractmethod
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import wraps
 from math import prod, sqrt

@@ -11,6 +12,7 @@ import torch
 import torch.backends.cuda
 import torch.nn as nn
 import torch.nn.functional as f
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers import PretrainedConfig
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache

@@ -324,10 +326,11 @@ modeling_rope_utils.py
 
 
 def inv_freq_to_device(rope_forward):
-    """
-    device and in float32
+    """Sometimes the inv_freq is calculated on the wrong device, or ends up in lower
+    precision than float32.
+
+    This wrapper ensures that inv_freq is always on the right device and in float32
+    precision.
     """
 
     @wraps(rope_forward)

@@ -353,7 +356,6 @@ class RotaryEmbedding(nn.Module):
         theta: float,
         head_dim: int,
         hidden_size: int,
-        n_heads: int,
         partial_rotary_factor: float,
         device: Optional[torch.device] = None,
         scaling: Optional[Dict[str, Any]] = None,

@@ -366,7 +368,6 @@ class RotaryEmbedding(nn.Module):
         setattr(self.config, 'rope_theta', theta)
         setattr(self.config, 'partial_rotary_factor', partial_rotary_factor)
         setattr(self.config, 'head_dim', head_dim)
-        setattr(self.config, 'num_attention_heads', n_heads)
         setattr(self.config, 'hidden_size', hidden_size)
         setattr(self.config, 'rope_scaling', scaling or {})
 

@@ -377,9 +378,7 @@ class RotaryEmbedding(nn.Module):
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
         device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         seqlen = config.max_position_embeddings or config.max_sequence_length
-        invfreq, self.attention_scaling = self.rope_init_fn(
-            self.config, device, seqlen
-        )
+        invfreq, self.attention_scaling = self.rope_init_fn(self.config, device, seqlen)
         self.rope_init_device = device
         self.register_buffer('inv_freq', invfreq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -617,11 +616,9 @@ def _create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
 def _ensure_finite(
     x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False
 ):
-    """
-    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the
+    """Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the
     dtype when ``check_neg_inf`` is ``True`` and replace ``float("inf")`` with the
-    maximum value of the dtype when ``check_pos_inf`` is ``True``
-    """
+    maximum value of the dtype when ``check_pos_inf`` is ``True``"""
     if check_neg_inf:
         x.masked_fill_(x == float('-inf'), torch.finfo(x.dtype).min)
     if check_pos_inf:

@@ -641,14 +638,12 @@ def resolve_causal_mask(
     # shape: (batch_size, 1, 1, seq_len)
     if len(attention_mask.shape) == 2:
         attention_mask = attention_mask[:, : past_length + seq_len]
-        attention_mask = attention_mask.to(dtype=torch.float).view(
+        attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[
+            :, None, None, :
+        ]
     else:
         attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
-    attention_mask = (1.0 - attention_mask) * torch.finfo(
-        attention_mask.dtype
-    ).min
+    attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
 
     # Merge attention mask with causal mask (attention bias)
     # NOTE: We need to initialize the attn bias in order for attn to

@@ -660,9 +655,7 @@ def resolve_causal_mask(
         or past_key_values is not None
     ):
         if causal_mask is None:
-            causal_mask = _create_causal_mask(
-                past_length + seq_len, device
-            )
+            causal_mask = _create_causal_mask(past_length + seq_len, device)
         elif causal_mask.dtype in (torch.int8, torch.bool):
             causal_mask = causal_mask.to(dtype=torch.float)
             causal_mask.masked_fill_(

@@ -719,6 +712,7 @@ def eager_attention_forward(
     dropout: float = 0.0,
     **_,
 ):
+    assert isinstance(module.num_key_value_groups, int)
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
 

@@ -745,7 +739,9 @@ def rotate_half(x: torch.Tensor):
 
 
 def apply_rotary_positional_embeddings(
-    x: torch.Tensor,
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
 ) -> torch.Tensor:
     return (x * cos + rotate_half(x) * sin).to(x.dtype)
 

@@ -890,7 +886,6 @@ class MHSDPA(nn.Module):
         attn_mask: Optional[torch.Tensor] = None,
         is_causal: Optional[bool] = None,
     ) -> Tuple[Callable, Optional[torch.Tensor], Optional[bool]]:
-
         if 'flash' in attn_implementation and self.fp32_attn:
             raise ValueError('Flash attention does not support fp32 attention')
         if self.sliding_window != -1 and 'flash' not in attn_implementation:

@@ -1071,9 +1066,7 @@ class FFN(nn.Module):
         if self.gated_activation:
             intermediate_size = 2 * self.intermediate_size
 
-        self.up = nn.Linear(
-            self.hidden_size, intermediate_size, bias=self.use_bias
-        )
+        self.up = nn.Linear(self.hidden_size, intermediate_size, bias=self.use_bias)
         self.down = nn.Linear(
             self.intermediate_size, self.output_size, bias=self.use_bias
         )

@@ -1245,6 +1238,8 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         assert config.attn_pooling_config is not None
         if config.pooling_type == ImagePooling2DType.attention_2wide:
             pooling_input_size *= 2
+
+        attn_implementation, _ = self._resolve_attn_pooling(attn_implementation)
         self.pooling = MHSDPA(
             config.attn_pooling_config,
             hidden_size=pooling_input_size,

@@ -1285,11 +1280,29 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         self.projector_dropout = Dropout(config.projector_dropout)
         self.feature_dropout = Dropout(config.feature_dropout)
 
+    @staticmethod
+    def _resolve_attn_pooling(attn_implementation: Optional[str] = None):
+        """
+        Flash Attention can cause Inf grads in the attention pooling layer because of
+        very large batch sizes. Setting this to sdpa does not cost us much since
+        sequence lengths in the case of attention pooling are tiny
+        """
+        attn_runtime_ctx = nullcontext()
+        if (
+            attn_implementation is not None
+            and attn_implementation.startswith('flash')
+        ):
+            attn_implementation = 'sdpa'
+            attn_runtime_ctx = sdpa_kernel(backends=[SDPBackend.MATH])
+
+        return attn_implementation, attn_runtime_ctx
+
     def forward(
         self,
         image_features: torch.Tensor,
         image_masks: Optional[torch.Tensor] = None,
         attn_implementation: Optional[str] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # image_features:
         # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)

@@ -1345,11 +1358,19 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
             dh=self.pooling_h,
             dw=self.pooling_w,
         )
+        image_features = image_features.contiguous()
         if self.pooling_type == ImagePooling2DType.attention_meanq:
             query = image_features.mean(-2, keepdim=True)
+            attn_implementation, attn_runtime_ctx = self._resolve_attn_pooling(
+                attn_implementation
             )
+            with attn_runtime_ctx:
+                image_features, _ = self.pooling(
+                    xq=query,
+                    xk=image_features,
+                    attn_implementation=attn_implementation,
+                    **kwargs,
+                )
         elif self.pooling_type not in {
             ImagePooling2DType.none,
             ImagePooling2DType.stack,
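The new `_resolve_attn_pooling` helper swaps Flash Attention for SDPA with the math backend when pooling, since pooling sequences are tiny. A standalone sketch of that fallback pattern; `pooled_attention` and the tensor shapes are illustrative, only `sdpa_kernel`/`SDPBackend` and `scaled_dot_product_attention` come from PyTorch:

from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def pooled_attention(q, k, v, attn_implementation='flash_attention_2'):
    # Attention pooling runs over very short sequences, so falling back from
    # flash attention to the math SDPA backend costs little and avoids Inf grads.
    ctx = nullcontext()
    if attn_implementation.startswith('flash'):
        ctx = sdpa_kernel(backends=[SDPBackend.MATH])
    with ctx:
        return F.scaled_dot_product_attention(q, k, v)

q = torch.randn(2, 8, 1, 64)    # (batch, heads, query_len=1, head_dim)
kv = torch.randn(2, 8, 16, 64)  # (batch, heads, kv_len, head_dim)
out = pooled_attention(q, kv, kv)
print(out.shape)  # torch.Size([2, 8, 1, 64])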
config.json
CHANGED
@@ -4,6 +4,7 @@
   ],
   "auto_map": {
     "AutoConfig": "configuration_jvlm.JinaVLMConfig",
+    "AutoModel": "modeling_jvlm.JinaVLM",
     "AutoModelForCausalLM": "modeling_jvlm.JinaVLMForConditionalGeneration"
   },
   "bos_token_id": 151643,

@@ -214,4 +215,4 @@
     "spatial_merge_size": 2
   }
 }
-}
+}
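The new "AutoModel" entry in `auto_map` points the generic auto class at the base `JinaVLM` model. A minimal usage sketch, assuming you trust the remote code in this repository:

from transformers import AutoModel

# The auto_map entry lets the base model (without the LM head) be loaded directly
model = AutoModel.from_pretrained('jinaai/jina-vlm-v1', trust_remote_code=True)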
configuration_jvlm.py
CHANGED
@@ -530,6 +530,11 @@ class JinaVLMTextConfig(PretrainedConfigWithDataclasses):
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
 
+    # Needed for vLLM
+    @property
+    def num_attention_heads(self) -> int:
+        return self.block_config.attn_config.n_heads
+
 
 class JinaVLMConfig(PretrainedConfig):
     """JinaVLM configuration.

@@ -545,7 +550,8 @@ class JinaVLMConfig(PretrainedConfig):
 
     model_type = 'jvlm'
     sub_configs = {
-        'vision_config': JinaVLMVisionConfig,
+        'vision_config': JinaVLMVisionConfig,
+        'text_config': JinaVLMTextConfig,
     }
 
     def __init__(
CHANGED
|
@@ -437,6 +437,17 @@ class JinaVLMImageProcessor(BaseImageProcessor):
|
|
| 437 |
|
| 438 |
""" Base cropping via resizing """
|
| 439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
def base_resize_cropping(self, image: np.ndarray):
|
| 441 |
resized, mask = self.resize_image(image, list(self.base_input_size))
|
| 442 |
resized = self.normalize_image(resized)
|
|
@@ -497,6 +508,117 @@ class JinaVLMImageProcessor(BaseImageProcessor):
|
|
| 497 |
|
| 498 |
return candidate_tilings[ix]
|
| 499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
def molmo_overlap_and_resize_cropping(self, image: np.ndarray):
|
| 501 |
# Discard this many patches from the (left/top, right/bottom) of crops
|
| 502 |
left_margin, right_margin = self.overlap_margins
|
|
@@ -625,37 +747,23 @@ class JinaVLMImageProcessor(BaseImageProcessor):
|
|
| 625 |
# new order into sparse structure of `patch_ordering` to fix it
|
| 626 |
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
|
| 627 |
|
| 628 |
-
def get_num_patches(num_tiles, pooling_size) -> int:
|
| 629 |
-
if num_tiles > 1:
|
| 630 |
-
left_crop_window_patches = (
|
| 631 |
-
(crop_window_patches + left_margin + pooling_size - 1)
|
| 632 |
-
// pooling_size
|
| 633 |
-
* pooling_size
|
| 634 |
-
)
|
| 635 |
-
middle_crop_window_patches = (
|
| 636 |
-
(crop_window_patches + pooling_size - 1)
|
| 637 |
-
// pooling_size
|
| 638 |
-
* pooling_size
|
| 639 |
-
)
|
| 640 |
-
right_crop_window_patches = (
|
| 641 |
-
(crop_window_patches + right_margin + pooling_size - 1)
|
| 642 |
-
// pooling_size
|
| 643 |
-
* pooling_size
|
| 644 |
-
)
|
| 645 |
-
return (
|
| 646 |
-
left_crop_window_patches
|
| 647 |
-
+ (num_tiles - 2) * middle_crop_window_patches
|
| 648 |
-
+ right_crop_window_patches
|
| 649 |
-
)
|
| 650 |
-
else:
|
| 651 |
-
single_crop_window_patches = (
|
| 652 |
-
(crop_patches + pooling_size - 1) // pooling_size * pooling_size
|
| 653 |
-
)
|
| 654 |
-
return single_crop_window_patches
|
| 655 |
-
|
| 656 |
# Now build the output tokens
|
| 657 |
-
h =
|
| 658 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
# for each row of patches, add a patch token per patch
|
| 660 |
per_row = np.full((w // self.pooling_w,), self.patch_token_id, dtype=np.int32)
|
| 661 |
if self.use_column_tokens:
|
|
@@ -810,6 +918,14 @@ class JinaVLMImageProcessor(BaseImageProcessor):
|
|
| 810 |
|
| 811 |
return slices, image_masks, patch_ordering_arr, best_grid
|
| 812 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
def minicpm_adaptive_slicing(self, image: np.ndarray, with_thumbnail: bool = True):
|
| 814 |
scale_resolution = self.base_input_size[0]
|
| 815 |
refine_image, image_mask, best_grid = self._minicpm_refine_image_for_slicing(
|
|
@@ -946,23 +1062,12 @@ class JinaVLMImageProcessor(BaseImageProcessor):
|
|
| 946 |
self.start_token_id = start_token_id
|
| 947 |
self.end_token_id = end_token_id
|
| 948 |
|
| 949 |
-
def
|
| 950 |
-
self,
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
) -> Dict[str, List[np.ndarray]]:
|
| 954 |
-
"""Preprocess an image or batch of images."""
|
| 955 |
-
if images is None or len(images) == 0:
|
| 956 |
-
return {
|
| 957 |
-
'image_crops': [],
|
| 958 |
-
'image_tokens': [],
|
| 959 |
-
'image_input_idx': [],
|
| 960 |
-
'image_padding_mask': [],
|
| 961 |
-
}
|
| 962 |
-
|
| 963 |
if 'max_crops' in kwargs and kwargs['max_crops'] is not None:
|
| 964 |
max_crops = kwargs['max_crops']
|
| 965 |
-
self.max_crops = max_crops
|
| 966 |
|
| 967 |
min_pixels = self.min_pixels
|
| 968 |
if 'min_pixels' in kwargs and kwargs['min_pixels'] is not None:
|
|
@@ -984,14 +1089,93 @@ class JinaVLMImageProcessor(BaseImageProcessor):
|
|
| 984 |
size = {'shortest_edge': min_pixels, 'longest_edge': max_pixels}
|
| 985 |
else:
|
| 986 |
size = {**self.size}
|
| 987 |
-
|
|
|
|
| 988 |
do_resize = self.do_resize
|
| 989 |
if 'do_resize' in kwargs and kwargs['do_resize'] is not None:
|
| 990 |
do_resize = kwargs['do_resize']
|
| 991 |
-
|
| 992 |
do_convert_rgb = self.do_convert_rgb
|
| 993 |
if 'do_convert_rgb' in kwargs and kwargs['do_convert_rgb'] is not None:
|
| 994 |
do_convert_rgb = kwargs['do_convert_rgb']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 995 |
|
| 996 |
# noinspection PyTypeChecker
|
| 997 |
images = self.fetch_images(images)
|
|
@@ -1001,16 +1185,11 @@ class JinaVLMImageProcessor(BaseImageProcessor):
|
|
| 1001 |
'Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray '
|
| 1002 |
'or torch.Tensor'
|
| 1003 |
)
|
| 1004 |
-
|
| 1005 |
if do_convert_rgb:
|
| 1006 |
images = [convert_to_rgb(image) for image in images]
|
| 1007 |
|
| 1008 |
# All transformations expect numpy arrays
|
| 1009 |
images = [to_numpy_array(image) for image in images]
|
| 1010 |
-
|
| 1011 |
-
input_data_format = None
|
| 1012 |
-
if 'input_data_format' in kwargs:
|
| 1013 |
-
input_data_format = kwargs['input_data_format']
|
| 1014 |
if input_data_format is None:
|
| 1015 |
# We assume that all images have the same channel dimension format.
|
| 1016 |
input_data_format = infer_channel_dimension_format(images[0])
|
|
|
|
| 437 |
|
| 438 |
""" Base cropping via resizing """
|
| 439 |
|
| 440 |
+
def base_get_n_image_patches(
|
| 441 |
+
self,
|
| 442 |
+
height: int,
|
| 443 |
+
width: int,
|
| 444 |
+
max_crops: int,
|
| 445 |
+
) -> int:
|
| 446 |
+
raise NotImplementedError(
|
| 447 |
+
'Function `get_n_image_patches` is not implemented for cropping method '
|
| 448 |
+
f'{CroppingMethod.RESIZE}'
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
def base_resize_cropping(self, image: np.ndarray):
|
| 452 |
resized, mask = self.resize_image(image, list(self.base_input_size))
|
| 453 |
resized = self.normalize_image(resized)
|
|
|
|
| 508 |
|
| 509 |
return candidate_tilings[ix]
|
| 510 |
|
| 511 |
+
@staticmethod
|
| 512 |
+
def _molmo_get_patches_from_tiling(
|
| 513 |
+
num_tiles,
|
| 514 |
+
pooling_size,
|
| 515 |
+
crop_patches,
|
| 516 |
+
crop_window_patches,
|
| 517 |
+
left_margin,
|
| 518 |
+
right_margin,
|
| 519 |
+
) -> np.int32:
|
| 520 |
+
if num_tiles > 1:
|
| 521 |
+
left_crop_window_patches = (
|
| 522 |
+
(crop_window_patches + left_margin + pooling_size - 1)
|
| 523 |
+
// pooling_size
|
| 524 |
+
* pooling_size
|
| 525 |
+
)
|
| 526 |
+
middle_crop_window_patches = (
|
| 527 |
+
(crop_window_patches + pooling_size - 1) // pooling_size * pooling_size
|
| 528 |
+
)
|
| 529 |
+
right_crop_window_patches = (
|
| 530 |
+
(crop_window_patches + right_margin + pooling_size - 1)
|
| 531 |
+
// pooling_size
|
| 532 |
+
* pooling_size
|
| 533 |
+
)
|
| 534 |
+
return (
|
| 535 |
+
left_crop_window_patches
|
| 536 |
+
+ (num_tiles - 2) * middle_crop_window_patches
|
| 537 |
+
+ right_crop_window_patches
|
| 538 |
+
)
|
| 539 |
+
else:
|
| 540 |
+
single_crop_window_patches = (
|
| 541 |
+
(crop_patches + pooling_size - 1) // pooling_size * pooling_size
|
| 542 |
+
)
|
| 543 |
+
return single_crop_window_patches
|
| 544 |
+
|
| 545 |
+
def molmo_get_n_image_patches(
|
| 546 |
+
self,
|
| 547 |
+
height: int,
|
| 548 |
+
width: int,
|
| 549 |
+
max_crops: int,
|
| 550 |
+
) -> int:
|
| 551 |
+
# Discard this many patches from the (left/top, right/bottom) of crops
|
| 552 |
+
left_margin, right_margin = self.overlap_margins
|
| 553 |
+
# Required for compatibility with image pooling
|
| 554 |
+
assert left_margin % self.pooling_w == 0 and right_margin % self.pooling_w == 0
|
| 555 |
+
assert left_margin % self.pooling_h == 0 and right_margin % self.pooling_h == 0
|
| 556 |
+
# pixels removed per dim
|
| 557 |
+
total_margin_pixels = self.patch_size * (right_margin + left_margin)
|
| 558 |
+
# patches per crop dim
|
| 559 |
+
crop_patches = self.base_input_size[0] // self.patch_size
|
| 560 |
+
|
| 561 |
+
# usable patches
|
| 562 |
+
crop_window_patches = crop_patches - (right_margin + left_margin)
|
| 563 |
+
crop_window_size = crop_window_patches * self.patch_size
|
| 564 |
+
|
| 565 |
+
# We assume hxw pooling, but can allow padding the right/bottom with extra
|
| 566 |
+
# patches if the number of patches per side is not divisible by h/w
|
| 567 |
+
assert (
|
| 568 |
+
crop_patches + self.pooling_h - 1
|
| 569 |
+
) // self.pooling_h == self.token_length_h
|
| 570 |
+
assert (
|
| 571 |
+
crop_patches + self.pooling_w - 1
|
| 572 |
+
) // self.pooling_w == self.token_length_w
|
| 573 |
+
|
| 574 |
+
# Decide how to tile the image, to account for the overlap margins we
|
| 575 |
+
# compute the tiling as if we had an image without the margins and were
|
| 576 |
+
# using a crop size without the margins
|
| 577 |
+
tiling = self._molmo_select_tiling(
|
| 578 |
+
height - total_margin_pixels,
|
| 579 |
+
width - total_margin_pixels,
|
| 580 |
+
crop_window_size,
|
| 581 |
+
max_crops,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# Now build the output tokens
|
| 585 |
+
h = self._molmo_get_patches_from_tiling(
|
| 586 |
+
tiling[0],
|
| 587 |
+
self.pooling_h,
|
| 588 |
+
crop_patches,
|
| 589 |
+
crop_window_patches,
|
| 590 |
+
left_margin,
|
| 591 |
+
right_margin,
|
| 592 |
+
)
|
| 593 |
+
w = self._molmo_get_patches_from_tiling(
|
| 594 |
+
tiling[1],
|
| 595 |
+
self.pooling_w,
|
| 596 |
+
crop_patches,
|
| 597 |
+
crop_window_patches,
|
| 598 |
+
left_margin,
|
| 599 |
+
right_margin,
|
| 600 |
+
)
|
| 601 |
+
# for each row of patches, add a patch token per patch
|
| 602 |
+
n_tokens = w.item() // self.pooling_w
|
| 603 |
+
if self.use_column_tokens:
|
| 604 |
+
# after each row, one column token is added
|
| 605 |
+
n_tokens += 1
|
| 606 |
+
# replicate each row of patch tokens by number of rows, i.e.
|
| 607 |
+
# proportional to image height
|
| 608 |
+
n_tokens *= h.item() // self.pooling_h
|
| 609 |
+
# add start and end image tokens
|
| 610 |
+
n_tokens += 2
|
| 611 |
+
|
| 612 |
+
# Global image goes first, so the order of patches in previous crops gets
|
| 613 |
+
# increased
|
| 614 |
+
n_thumbnail_tokens = self.token_length_w
|
| 615 |
+
if self.use_column_tokens:
|
| 616 |
+
n_thumbnail_tokens += 1
|
| 617 |
+
n_thumbnail_tokens *= self.token_length_h
|
| 618 |
+
n_thumbnail_tokens += 2
|
| 619 |
+
|
| 620 |
+
return n_tokens + n_thumbnail_tokens
|
| 621 |
+
|
| 622 |
def molmo_overlap_and_resize_cropping(self, image: np.ndarray):
|
| 623 |
# Discard this many patches from the (left/top, right/bottom) of crops
|
| 624 |
left_margin, right_margin = self.overlap_margins
|
|
|
|
| 747 |
# new order into sparse structure of `patch_ordering` to fix it
|
| 748 |
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
|
| 749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
# Now build the output tokens
|
| 751 |
+
h = self._molmo_get_patches_from_tiling(
|
| 752 |
+
tiling[0],
|
| 753 |
+
self.pooling_h,
|
| 754 |
+
crop_patches,
|
| 755 |
+
crop_window_patches,
|
| 756 |
+
left_margin,
|
| 757 |
+
right_margin,
|
| 758 |
+
)
|
| 759 |
+
w = self._molmo_get_patches_from_tiling(
|
| 760 |
+
tiling[1],
|
| 761 |
+
self.pooling_w,
|
| 762 |
+
crop_patches,
|
| 763 |
+
crop_window_patches,
|
| 764 |
+
left_margin,
|
| 765 |
+
right_margin,
|
| 766 |
+
)
|
| 767 |
# for each row of patches, add a patch token per patch
|
| 768 |
per_row = np.full((w // self.pooling_w,), self.patch_token_id, dtype=np.int32)
|
| 769 |
if self.use_column_tokens:
|
|
|
|
| 918 |
|
| 919 |
return slices, image_masks, patch_ordering_arr, best_grid
|
| 920 |
|
| 921 |
+
def minicpm_get_n_image_patches(
|
| 922 |
+
self, height: int, width: int, max_crops: int, with_thumbnail: bool = False
|
| 923 |
+
) -> int:
|
| 924 |
+
raise NotImplementedError(
|
| 925 |
+
'Function `get_n_image_patches` is not implemented for cropping method '
|
| 926 |
+
f'{CroppingMethod.ADAPTIVE_SLICING}'
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
def minicpm_adaptive_slicing(self, image: np.ndarray, with_thumbnail: bool = True):
|
| 930 |
scale_resolution = self.base_input_size[0]
|
| 931 |
refine_image, image_mask, best_grid = self._minicpm_refine_image_for_slicing(
|
|
|
|
| 1062 |
self.start_token_id = start_token_id
|
| 1063 |
self.end_token_id = end_token_id
|
| 1064 |
|
| 1065 |
+
def _resolve_images_kwargs(
|
| 1066 |
+
self, **kwargs: Unpack[JinaVLMImagesKwargs]
|
| 1067 |
+
) -> JinaVLMImagesKwargs:
|
| 1068 |
+
max_crops = self.max_crops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1069 |
if 'max_crops' in kwargs and kwargs['max_crops'] is not None:
|
| 1070 |
max_crops = kwargs['max_crops']
|
|
|
|
| 1071 |
|
| 1072 |
min_pixels = self.min_pixels
|
| 1073 |
if 'min_pixels' in kwargs and kwargs['min_pixels'] is not None:
|
|
|
|
| 1089 |
size = {'shortest_edge': min_pixels, 'longest_edge': max_pixels}
|
| 1090 |
else:
|
| 1091 |
size = {**self.size}
|
| 1092 |
+
min_pixels = size['shortest_edge']
|
| 1093 |
+
max_pixels = size['longest_edge']
|
| 1094 |
do_resize = self.do_resize
|
| 1095 |
if 'do_resize' in kwargs and kwargs['do_resize'] is not None:
|
| 1096 |
do_resize = kwargs['do_resize']
|
|
|
|
| 1097 |
do_convert_rgb = self.do_convert_rgb
|
| 1098 |
if 'do_convert_rgb' in kwargs and kwargs['do_convert_rgb'] is not None:
|
| 1099 |
do_convert_rgb = kwargs['do_convert_rgb']
|
| 1100 |
+
input_data_format = None
|
| 1101 |
+
if 'input_data_format' in kwargs:
|
| 1102 |
+
input_data_format = kwargs['input_data_format']
|
| 1103 |
+
|
| 1104 |
+
return JinaVLMImagesKwargs(
|
| 1105 |
+
do_convert_rgb=do_convert_rgb,
|
| 1106 |
+
do_resize=do_resize,
|
| 1107 |
+
min_pixels=min_pixels,
|
| 1108 |
+
max_pixels=max_pixels,
|
| 1109 |
+
size=size,
|
| 1110 |
+
max_crops=max_crops,
|
| 1111 |
+
input_data_format=input_data_format,
|
| 1112 |
+
)
|
| 1113 |
+
|
| 1114 |
+
def get_n_image_patches(
|
| 1115 |
+
self,
|
| 1116 |
+
height: int,
|
| 1117 |
+
width: int,
|
| 1118 |
+
**kwargs: Unpack[JinaVLMImagesKwargs],
|
| 1119 |
+
) -> int:
|
| 1120 |
+
"""A utility that returns number of image patches for a given image size.
|
| 1121 |
+
|
| 1122 |
+
Args:
|
| 1123 |
+
height (`int`):
|
| 1124 |
+
Height of the input image.
|
| 1125 |
+
width (`int`):
|
| 1126 |
+
Width of the input image.
|
| 1127 |
+
**kwargs (`dict`, *optional*)
|
| 1128 |
+
Any kwargs to override defaults of the image processor.
|
| 1129 |
+
Returns:
|
| 1130 |
+
`int`: Number of image patches
|
| 1131 |
+
"""
|
| 1132 |
+
if self.cropping_method != CroppingMethod.OVERLAP_AND_RESIZE:
|
| 1133 |
+
raise NotImplementedError(
|
| 1134 |
+
'Function is only implemented for cropping method '
|
| 1135 |
+
f'{CroppingMethod.OVERLAP_AND_RESIZE}'
|
| 1136 |
+
)
|
| 1137 |
+
kwargs = self._resolve_images_kwargs(**kwargs)
|
| 1138 |
+
do_resize = kwargs['do_resize']
|
| 1139 |
+
size = kwargs['size']
|
| 1140 |
+
max_crops = kwargs['max_crops']
|
| 1141 |
+
if do_resize:
|
| 1142 |
+
height, width = smart_resize(
|
| 1143 |
+
height,
|
| 1144 |
+
width,
|
| 1145 |
+
factor=self.patch_size,
|
| 1146 |
+
min_pixels=size['shortest_edge'],
|
| 1147 |
+
max_pixels=size['longest_edge'],
|
| 1148 |
+
)
|
| 1149 |
+
|
| 1150 |
+
if self.cropping_method == CroppingMethod.RESIZE:
|
| 1151 |
+
return self.base_get_n_image_patches(height, width, max_crops)
|
| 1152 |
+
elif self.cropping_method == CroppingMethod.OVERLAP_AND_RESIZE:
|
| 1153 |
+
return self.molmo_get_n_image_patches(height, width, max_crops)
|
| 1154 |
+
elif self.cropping_method == CroppingMethod.ADAPTIVE_SLICING:
|
| 1155 |
+
return self.minicpm_get_n_image_patches(height, width, max_crops)
|
| 1156 |
+
return self.minicpm_get_n_image_patches(
|
| 1157 |
+
height, width, max_crops, with_thumbnail=True
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
def preprocess(
|
| 1161 |
+
self,
|
| 1162 |
+
images: ImageInput,
|
| 1163 |
+
**kwargs: Unpack[JinaVLMImagesKwargs],
|
| 1164 |
+
) -> Dict[str, List[np.ndarray]]:
|
| 1165 |
+
"""Preprocess an image or batch of images."""
|
| 1166 |
+
if images is None or len(images) == 0:
|
| 1167 |
+
return {
|
| 1168 |
+
'image_crops': [],
|
| 1169 |
+
'image_tokens': [],
|
| 1170 |
+
'image_input_idx': [],
|
| 1171 |
+
'image_padding_mask': [],
|
| 1172 |
+
}
|
| 1173 |
+
kwargs = self._resolve_images_kwargs(**kwargs)
|
| 1174 |
+
do_convert_rgb = kwargs['do_convert_rgb']
|
| 1175 |
+
do_resize = kwargs['do_resize']
|
| 1176 |
+
input_data_format = kwargs['input_data_format']
|
| 1177 |
+
size = kwargs['size']
|
| 1178 |
+
self.max_crops = kwargs['max_crops']
|
| 1179 |
|
| 1180 |
# noinspection PyTypeChecker
|
| 1181 |
images = self.fetch_images(images)
|
|
|
|
| 1185 |
'Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray '
|
| 1186 |
'or torch.Tensor'
|
| 1187 |
)
|
|
|
|
| 1188 |
if do_convert_rgb:
|
| 1189 |
images = [convert_to_rgb(image) for image in images]
|
| 1190 |
|
| 1191 |
# All transformations expect numpy arrays
|
| 1192 |
images = [to_numpy_array(image) for image in images]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1193 |
if input_data_format is None:
|
| 1194 |
# We assume that all images have the same channel dimension format.
|
| 1195 |
input_data_format = infer_channel_dimension_format(images[0])
|
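`_molmo_get_patches_from_tiling` repeatedly rounds a patch count up to the next multiple of the pooling size via `(x + pooling_size - 1) // pooling_size * pooling_size`. A small worked example of that integer trick; the helper name and values are illustrative:

def round_up(x: int, multiple: int) -> int:
    # Ceil x to the nearest multiple of `multiple` using integer arithmetic
    return (x + multiple - 1) // multiple * multiple

# With 2x2 pooling, 22 patches already fit (22 -> 22), while 23 pad up to 24
# so the pooled grid stays whole
for patches in (22, 23, 24):
    print(patches, '->', round_up(patches, 2))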
modeling_jvlm.py
CHANGED
@@ -27,14 +27,13 @@ from .blocks_jvlm import (
     TransformerBlock,
     VisionLanguageConnector,
     build_layer_norm,
-    resolve_causal_mask
+    resolve_causal_mask,
 )
 from .configuration_jvlm import JinaVLMConfig, JinaVLMTextConfig, JinaVLMVisionConfig
 
 
 class JinaPreTrainedModel(PreTrainedModel):
     config: JinaVLMConfig
-    config_class = JinaVLMConfig
     base_model_prefix = 'model'
     supports_gradient_checkpointing = True
     _supports_flash_attn = True

@@ -51,8 +50,6 @@ class JinaPreTrainedModel(PreTrainedModel):
 
 class JinaVLMVisionModel(JinaPreTrainedModel):
     config: JinaVLMVisionConfig
-    config_class = JinaVLMVisionConfig
-    base_model_prefix = ''
 
     def __init__(self, config: JinaVLMVisionConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)

@@ -186,7 +183,11 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
         pos = pos_emb[None, :, :].to(x.dtype)
         return x + pos
 
-    def get_visual_features(
+    def get_visual_features(
+        self,
+        images: torch.Tensor,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> BaseModelOutput:
         x, shape = self.patch_embed(images)
         if self.cls_embed is not None:
             cls = self.cls_embed.view(1, 1, -1).expand(x.shape[0], -1, -1).to(x.dtype)

@@ -201,7 +202,11 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
         hidden_states = []
         attentions = []
         for layer in self.layers:
-            x, attn = layer(
+            x, attn = layer(
+                x,
+                attn_implementation=self.config._attn_implementation,
+                **kwargs,
+            )
             hidden_states.append(x)
             attentions.append(attn)
         x = self.post_lnorm(x)

@@ -214,12 +219,15 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
         )
 
     def forward(
-        self,
+        self,
+        images: torch.Tensor,
+        image_masks: torch.Tensor,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutput:
         b, t, n, d = images.shape
         mask = ~torch.all(images.view(b * t, n, d) == -1, dim=(1, 2), keepdim=True)
         images = images.view(b * t, n, d)
-        out = self.get_visual_features(images)
+        out = self.get_visual_features(images, **kwargs)
         image_features = out.hidden_states
 
         features = []

@@ -230,14 +238,13 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
             features.append(feats)
         image_features = torch.cat(features, dim=-1)
         image_features = image_features * mask
-        image_features = image_features.view(b, t, n, -1)
-
+        image_features = image_features.view(b, t, n, -1).contiguous()
         image_features = self.vl_connector(
             image_features,
             image_masks,
             attn_implementation=self.config._attn_implementation,
+            **kwargs,
         )
-
         return BaseModelOutput(
             last_hidden_state=image_features,
             hidden_states=out.hidden_states,

@@ -246,11 +253,7 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
 
 
 class JinaVLMTextModel(JinaPreTrainedModel):
-    """Decoder-only language model."""
-
     config: JinaVLMTextConfig
-    config_class = JinaVLMTextConfig
-    base_model_prefix = ''
 
     def __init__(self, config: JinaVLMTextConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)

@@ -297,7 +300,6 @@ class JinaVLMTextModel(JinaPreTrainedModel):
             theta=self.config.rope_theta,
             head_dim=self.config.block_config.attn_config.head_dim,
             hidden_size=self.config.hidden_size,
-            n_heads=self.config.block_config.attn_config.n_heads,
             partial_rotary_factor=self.config.partial_rotary_factor,
             scaling=self.config.rope_scaling,
         )

@@ -444,7 +446,7 @@ class JinaVLMTextModel(JinaPreTrainedModel):
 
 
 class JinaVLM(JinaPreTrainedModel):
+    config: JinaVLMConfig
 
     def __init__(self, config: JinaVLMConfig):
         super().__init__(config)

@@ -493,7 +495,7 @@ class JinaVLM(JinaPreTrainedModel):
     ) -> BaseModelOutputWithPast:
         image_features = None
         if images is not None and images.shape[1] > 0:
-            image_out = self.vision_model(images, image_masks)
+            image_out = self.vision_model(images, image_masks, **kwargs)
             image_features = image_out.last_hidden_state
         return self.language_model(
             input_ids=input_ids,

@@ -512,10 +514,10 @@ class JinaVLM(JinaPreTrainedModel):
 
 
 class JinaVLMForConditionalGeneration(JinaPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = {
+        'lm_head.weight': 'model.language_model.embedding.embedding.weight'
+    }
     accepts_loss_kwargs = False
-    base_model_prefix = 'model'
     config: JinaVLMConfig
 
     def __init__(self, config: JinaVLMConfig):
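The new `_tied_weights_keys` mapping tells `transformers` that the LM head shares its matrix with the text embedding; the exact key path is specific to this repository. In plain PyTorch the same tying boils down to pointing both modules at one parameter, as in this minimal sketch with made-up sizes:

import torch.nn as nn

vocab_size, hidden_size = 1000, 64
embedding = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tie the output projection to the input embedding matrix
lm_head.weight = embedding.weight
assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()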
processing_jvlm.py
CHANGED
@@ -10,11 +10,14 @@ from transformers.image_utils import ImageInput
 from transformers.processing_utils import (
     AllKwargsForChatTemplate,
     CommonKwargs,
+    MultiModalData,
     ProcessorMixin,
     Unpack,
 )
 from transformers.tokenization_utils_base import (
-    PaddingStrategy,
+    PaddingStrategy,
+    PreTokenizedInput,
+    TextInput,
 )
 
 from .image_processing_jvlm import JinaVLMImageProcessor, JinaVLMImagesKwargs

@@ -38,7 +41,7 @@ class JinaVLMTextKwargs(TypedDict, total=False):
     is_split_into_words: Optional[bool]
 
 
-class
+class JinaVLMProcessingKwargs(JinaVLMTextKwargs, JinaVLMImagesKwargs, CommonKwargs):
     return_labels: Optional[bool]
 
 

@@ -171,8 +174,8 @@ class JinaVLMProcessor(ProcessorMixin):
     def _collate(
         self,
         batch: Dict[str, List[Optional[np.ndarray]]],
+        text_max_sequence_length: Optional[int] = None,
+        image_max_sequence_length: Optional[int] = None,
         padding: Union[
             PaddingStrategy.MAX_LENGTH, PaddingStrategy.LONGEST
         ] = PaddingStrategy.MAX_LENGTH,

@@ -185,10 +188,10 @@ class JinaVLMProcessor(ProcessorMixin):
             _padding_side = 'right'
             if key in self.TEXT_KEYS:
                 _padding_side = padding_side
-                max_len =
+                max_len = text_max_sequence_length
                 dtype = np.int64
             elif key in self.IMAGE_KEYS:
-                max_len =
+                max_len = image_max_sequence_length
                 dtype = np.int64
                 if key == 'images':
                     dtype = np.float32

@@ -214,22 +217,22 @@ class JinaVLMProcessor(ProcessorMixin):
                 shift = input_ids_padlens[:, np.newaxis, np.newaxis]
                 shift = np.repeat(shift, n_image_tokens, axis=2)
                 shift = np.repeat(shift, n_crops, axis=1)
-                image_input_idx[image_input_idx < 0] = -
+                image_input_idx[image_input_idx < 0] = -text_max_sequence_length
                 image_input_idx = image_input_idx + shift
                 out['image_input_idx'] = image_input_idx
 
-        if
+        if text_max_sequence_length is not None:
             image_input_idx = out.get('image_input_idx', [])
             n = len(image_input_idx)
             for i in range(n):
                 arr = image_input_idx[i]
                 if arr.ndim > 0 and arr.size > 0:
                     n_image_tokens = arr.max()
-                    if n_image_tokens >
+                    if n_image_tokens > text_max_sequence_length - 3:
                         raise RuntimeError(
                             'Image tokens truncation at sequence boundary. Max '
-                            f'sequence length ({
-                            'to fit the generated image tokens '
+                            f'sequence length ({text_max_sequence_length}) is too '
+                            'small to fit the generated image tokens '
                             f'({n_image_tokens}). Consider increasing the max '
                             'sequence length or tweaking the image processing '
                             'parameters (`max_crops`, `max_pixels`) to reduce the '

@@ -386,7 +389,7 @@ class JinaVLMProcessor(ProcessorMixin):
         text: Union[
             None, TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
         ] = None,
-        **kwargs: Unpack[
+        **kwargs: Unpack[JinaVLMProcessingKwargs],
     ) -> BatchFeature:
         """Main method to prepare for the model one or several sequences(s) and
         image(s). This method forwards the `text` and `kwargs` arguments to the

@@ -489,9 +492,11 @@ class JinaVLMProcessor(ProcessorMixin):
         )
 
         outputs = defaultdict(list)
+        n_images = []
         for idx in range(batch_size):
             _token_ids = token_ids[idx]
             _images = images[idx]
+            n_images.append(len(_images))
             image_inputs = self.image_processor(_images, **images_kwargs)
             image_crops = image_inputs['image_crops']
             image_tokens = image_inputs['image_tokens']

@@ -510,14 +515,42 @@ class JinaVLMProcessor(ProcessorMixin):
                 outputs[k].append(v)
 
         if padding != PaddingStrategy.DO_NOT_PAD:
+            text_max_sequence_length = max_length or self.max_sequence_length
+            max_crops = max_crops or self.max_crops
+            max_n_images = max(n_images)
+            image_max_sequence_length = (max_crops + 1) * max_n_images
             outputs = self._collate(
                 outputs,
+                text_max_sequence_length=text_max_sequence_length,
+                image_max_sequence_length=image_max_sequence_length,
                 padding=padding,
                 padding_side=padding_side,
             )
         return BatchFeature(data=outputs, tensor_type=return_tensors)
 
+    def _get_num_multimodal_tokens(
+        self,
+        image_sizes: Optional[List[List[int]]] = None,
+        **kwargs: Unpack[JinaVLMImagesKwargs],
+    ) -> MultiModalData:
+        """Computes the number of placeholder tokens needed for multimodal inputs with
+        the given sizes.
+
+        Args:
+            image_sizes (`list[list[int]]`, *optional*):
+                The input sizes formatted as (height, width) per each image.
+        Returns:
+            `MultiModalData`: A `MultiModalData` object holding number of tokens per
+            each of the provided input modalities, along with other useful data.
+        """
+        data = {}
+        if image_sizes is not None:
+            n_patches = [
+                self.image_processor.get_n_image_patches(h, w, **kwargs)
+                for h, w in image_sizes
+            ]
+            data.update({'num_image_tokens': n_patches, 'num_image_patches': n_patches})
+        return MultiModalData(**data)
+
 
 JinaVLMProcessor.register_for_auto_class()
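The processor now pads the image axis to a length derived from the crop budget: each image contributes at most `max_crops` crops plus one global thumbnail, so the padded length is `(max_crops + 1) * max_n_images`. A quick check of that arithmetic; the function name and numbers are illustrative:

def image_max_sequence_length(max_crops: int, n_images_per_sample: list[int]) -> int:
    # Each image yields up to max_crops crops plus one thumbnail crop
    return (max_crops + 1) * max(n_images_per_sample)

print(image_max_sequence_length(12, [1, 3, 2]))  # 39 -> every sample padded to 39 crops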
test_jvlm.py
CHANGED
|
@@ -11,7 +11,10 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
|
| 11 |
|
| 12 |
import torch
|
| 13 |
from transformers import (
|
| 14 |
-
AutoModelForCausalLM,
|
|
|
|
|
|
|
|
|
|
| 15 |
)
|
| 16 |
from transformers.utils import is_flash_attn_2_available
|
| 17 |
|
|
@@ -60,7 +63,8 @@ def _build_conversations(
|
|
| 60 |
try:
|
| 61 |
result = urlparse(_path)
|
| 62 |
return result.scheme in ('http', 'https')
|
| 63 |
-
except:
|
|
|
|
| 64 |
return False
|
| 65 |
|
| 66 |
images = images or []
|
|
@@ -83,8 +87,9 @@ def _build_conversations(
|
|
| 83 |
images = [TEST_IMAGE]
|
| 84 |
n_images = len(images)
|
| 85 |
prompts = (
|
| 86 |
-
['Describe the image in 100 words']
|
| 87 |
-
|
|
|
|
| 88 |
)
|
| 89 |
n_prompts = len(prompts)
|
| 90 |
|
|
@@ -119,8 +124,16 @@ def _build_conversations(
|
|
| 119 |
allimages = []
|
| 120 |
allprompts = []
|
| 121 |
ordinals = [
|
| 122 |
-
'first',
|
| 123 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
]
|
| 125 |
for images, prompt in examples:
|
| 126 |
content = []
|
|
@@ -130,15 +143,17 @@ def _build_conversations(
|
|
| 130 |
content.append({'type': 'text', 'text': prompt})
|
| 131 |
if len(images) > 1 and image_labels:
|
| 132 |
for idx, img in enumerate(images):
|
| 133 |
-
ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx+1}th'
|
| 134 |
image = images[idx]
|
| 135 |
descriptor = f'url: {image}'
|
| 136 |
if os.path.isfile(image):
|
| 137 |
descriptor = f'filename: {os.path.basename(image)}'
|
| 138 |
-
content.append(
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
| 142 |
content.append({'type': 'image', 'image': img})
|
| 143 |
else:
|
| 144 |
content.extend([{'type': 'image', 'image': image} for image in images])
|
|
@@ -189,9 +204,7 @@ def _token_usage_report(
|
|
| 189 |
tokens_per_image_list = []
|
| 190 |
|
| 191 |
# Find all img_start and img_end positions in input_ids
|
| 192 |
-
start_positions = (input_ids == image_start_id).nonzero(
|
| 193 |
-
as_tuple=True
|
| 194 |
-
)[0].tolist()
|
| 195 |
end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
|
| 196 |
|
| 197 |
if len(start_positions) > 0 and len(end_positions) > 0:
|
|
@@ -211,9 +224,8 @@ def _token_usage_report(
|
|
| 211 |
# Get the start and end indices for this image
|
| 212 |
start_idx_begin = idx * n_starts_per_image
|
| 213 |
end_idx_end = (idx + 1) * n_starts_per_image
|
| 214 |
-
if (
|
| 215 |
-
|
| 216 |
-
end_idx_end <= len(end_positions)
|
| 217 |
):
|
| 218 |
# First start position and last end position define the image span
|
| 219 |
first_start = start_positions[start_idx_begin]
|
|
@@ -233,10 +245,10 @@ def _token_usage_report(
|
|
| 233 |
|
| 234 |
for idx in range(n_images):
|
| 235 |
n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
|
| 236 |
-
pct =
|
| 237 |
report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')
|
| 238 |
|
| 239 |
-
text_pct =
|
| 240 |
report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
|
| 241 |
|
| 242 |
return '\n'.join(report)
|
|
@@ -253,7 +265,7 @@ def test_jvlm():
|
|
| 253 |
help=(
|
| 254 |
'Model path (default: `"."`). Set this to `"jinaai/jina-vlm-v1"` if you '
|
| 255 |
'are running this script outside this repo.'
|
| 256 |
-
)
|
| 257 |
)
|
| 258 |
parser.add_argument(
|
| 259 |
'-i',
|
|
@@ -339,7 +351,9 @@ def test_jvlm():
|
|
| 339 |
print(f'Using dtype: {dtype}')
|
| 340 |
print('Model path: ', args.model)
|
| 341 |
processor = AutoProcessor.from_pretrained(
|
| 342 |
-
args.model,
|
|
|
|
|
|
|
| 343 |
)
|
| 344 |
model = AutoModelForCausalLM.from_pretrained(
|
| 345 |
args.model,
|
|
@@ -356,13 +370,13 @@ def test_jvlm():
|
|
| 356 |
print('Done ✅')
|
| 357 |
print()
|
| 358 |
|
| 359 |
-
print(
|
| 360 |
conversations, images, prompts = _build_conversations(
|
| 361 |
args.image,
|
| 362 |
args.prompt,
|
| 363 |
map_mode=args.map,
|
| 364 |
prompt_first=args.prompt_first,
|
| 365 |
-
image_labels=args.image_labels
|
| 366 |
)
|
| 367 |
n_conversations = len(conversations)
|
| 368 |
print(f'Built {n_conversations} conversations 🚀')
|
|
@@ -434,25 +448,28 @@ def test_jvlm():
|
|
| 434 |
print(f'├── 🖼️Images: {images[idx]}')
|
| 435 |
print(f'├── 📜Prompt: {prompts[idx]}')
|
| 436 |
print(f'├── 💬Chat:{texts[idx]}')
|
| 437 |
-
print(
|
| 438 |
ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
|
| 439 |
with (
|
| 440 |
timer,
|
| 441 |
torch.no_grad(),
|
| 442 |
-
torch.autocast(
|
| 443 |
):
|
| 444 |
output = model.generate(
|
| 445 |
**ith_inputs,
|
| 446 |
streamer=streamer,
|
| 447 |
generation_config=GenerationConfig(
|
| 448 |
-
max_new_tokens=args.max_tokens,
|
| 449 |
),
|
| 450 |
return_dict_in_generate=True,
|
| 451 |
use_model_defaults=True,
|
| 452 |
)
|
| 453 |
generation_time += timer.time
|
| 454 |
|
| 455 |
-
out = output.sequences[0][len(input_prompts[idx].tolist()):]
|
| 456 |
generated_tokens += len(out)
|
| 457 |
print('Token usage report:')
|
| 458 |
print(token_usage_reports[idx])
|
|
@@ -470,7 +487,8 @@ def test_jvlm():
|
|
| 470 |
output = model.generate(
|
| 471 |
**device_inputs,
|
| 472 |
generation_config=GenerationConfig(
|
| 473 |
-
max_new_tokens=args.max_tokens,
|
| 474 |
),
|
| 475 |
return_dict_in_generate=True,
|
| 476 |
use_model_defaults=True,
|
|
@@ -478,7 +496,7 @@ def test_jvlm():
|
|
| 478 |
generation_time = timer.time
|
| 479 |
|
| 480 |
for idx in range(n_conversations):
|
| 481 |
-
out = output.sequences[idx][len(input_prompts[idx].tolist()):]
|
| 482 |
generated_tokens += len(out)
|
| 483 |
response = processor.tokenizer.decode(out, skip_special_tokens=True)
|
| 484 |
print(f'* Conversation {idx + 1}/{n_conversations}')
|
|
|
|
| 11 |
|
| 12 |
import torch
|
| 13 |
from transformers import (
|
| 14 |
+
AutoModelForCausalLM,
|
| 15 |
+
AutoProcessor,
|
| 16 |
+
GenerationConfig,
|
| 17 |
+
TextStreamer,
|
| 18 |
)
|
| 19 |
from transformers.utils import is_flash_attn_2_available
|
| 20 |
|
|
|
|
| 63 |
try:
|
| 64 |
result = urlparse(_path)
|
| 65 |
return result.scheme in ('http', 'https')
|
| 66 |
+
except Exception as e:
|
| 67 |
+
_ = str(e)
|
| 68 |
return False
|
| 69 |
|
| 70 |
images = images or []
|
|
|
|
| 87 |
images = [TEST_IMAGE]
|
| 88 |
n_images = len(images)
|
| 89 |
prompts = (
|
| 90 |
+
['Describe the image in 100 words']
|
| 91 |
+
if n_images == 1 or map_mode
|
| 92 |
+
else ['Describe the images in 100 words']
|
| 93 |
)
|
| 94 |
n_prompts = len(prompts)
|
| 95 |
|
|
|
|
| 124 |
allimages = []
|
| 125 |
allprompts = []
|
| 126 |
ordinals = [
|
| 127 |
+
'first',
|
| 128 |
+
'second',
|
| 129 |
+
'third',
|
| 130 |
+
'fourth',
|
| 131 |
+
'fifth',
|
| 132 |
+
'sixth',
|
| 133 |
+
'seventh',
|
| 134 |
+
'eighth',
|
| 135 |
+
'ninth',
|
| 136 |
+
'tenth',
|
| 137 |
]
|
| 138 |
for images, prompt in examples:
|
| 139 |
content = []
|
|
|
|
| 143 |
content.append({'type': 'text', 'text': prompt})
|
| 144 |
if len(images) > 1 and image_labels:
|
| 145 |
for idx, img in enumerate(images):
|
| 146 |
+
ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx + 1}th'
|
| 147 |
image = images[idx]
|
| 148 |
descriptor = f'url: {image}'
|
| 149 |
if os.path.isfile(image):
|
| 150 |
descriptor = f'filename: {os.path.basename(image)}'
|
| 151 |
+
content.append(
|
| 152 |
+
{
|
| 153 |
+
'type': 'text',
|
| 154 |
+
'text': f'(this is the {ordinal} image, {descriptor})',
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
content.append({'type': 'image', 'image': img})
|
| 158 |
else:
|
| 159 |
content.extend([{'type': 'image', 'image': image} for image in images])
|
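Side note on the hunk above: `_build_conversations` assembles each conversation's content as a list of `{'type': 'text', ...}` and `{'type': 'image', ...}` items, and the labeled branch now prefixes every image with an ordinal/filename note. A minimal sketch of the resulting structure, under the assumption of two local image files (the file names and the `user` role wrapper are illustrative, not taken from the script):

```python
# Hypothetical output of the labeled-images branch above; paths are placeholders.
conversation = [
    {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'Describe the images in 100 words'},
            {'type': 'text', 'text': '(this is the first image, filename: cat.jpg)'},
            {'type': 'image', 'image': 'cat.jpg'},
            {'type': 'text', 'text': '(this is the second image, filename: dog.jpg)'},
            {'type': 'image', 'image': 'dog.jpg'},
        ],
    }
]
```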
|
|
|
| 204 |
tokens_per_image_list = []
|
| 205 |
|
| 206 |
# Find all img_start and img_end positions in input_ids
|
| 207 |
+
start_positions = (input_ids == image_start_id).nonzero(as_tuple=True)[0].tolist()
|
| 208 |
end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
|
| 209 |
|
| 210 |
if len(start_positions) > 0 and len(end_positions) > 0:
|
|
|
|
| 224 |
# Get the start and end indices for this image
|
| 225 |
start_idx_begin = idx * n_starts_per_image
|
| 226 |
end_idx_end = (idx + 1) * n_starts_per_image
|
| 227 |
+
if start_idx_begin < len(start_positions) and end_idx_end <= len(
|
| 228 |
+
end_positions
|
| 229 |
):
|
| 230 |
# First start position and last end position define the image span
|
| 231 |
first_start = start_positions[start_idx_begin]
|
|
|
|
| 245 |
|
| 246 |
for idx in range(n_images):
|
| 247 |
n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
|
| 248 |
+
pct = n_tokens / max_sequence_length * 100
|
| 249 |
report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')
|
| 250 |
|
| 251 |
+
text_pct = text_token_count / max_sequence_length * 100
|
| 252 |
report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
|
| 253 |
|
| 254 |
return '\n'.join(report)
|
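For readers skimming the diff: the token-usage report counts, for every image, the tokens enclosed by the image start/end marker IDs in `input_ids`, then expresses each count as a share of the sequence. A self-contained toy version of that accounting (the marker IDs and tensor below are made up, and the real script also handles multiple start markers per multi-crop image and divides by the full sequence length it tracks):

```python
import torch

# Made-up marker ids: 101 marks an image start, 102 an image end.
image_start_id, image_end_id = 101, 102
input_ids = torch.tensor([1, 101, 5, 5, 5, 102, 7, 7, 101, 5, 5, 102, 7])

starts = (input_ids == image_start_id).nonzero(as_tuple=True)[0].tolist()
ends = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()

# Count each image span inclusively, then report everything as a share of the sequence.
tokens_per_image = [end - start + 1 for start, end in zip(starts, ends)]
total = len(input_ids)
for i, n_tokens in enumerate(tokens_per_image, start=1):
    print(f'├── Image {i} → {n_tokens} tokens ({n_tokens / total * 100:.1f}%)')
text_tokens = total - sum(tokens_per_image)
print(f'└── Text: {text_tokens} tokens ({text_tokens / total * 100:.1f}%)')
```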
|
|
|
| 265 |
help=(
|
| 266 |
'Model path (default: `"."`). Set this to `"jinaai/jina-vlm-v1"` if you '
|
| 267 |
'are running this script outside this repo.'
|
| 268 |
+
),
|
| 269 |
)
|
| 270 |
parser.add_argument(
|
| 271 |
'-i',
|
|
|
|
| 351 |
print(f'Using dtype: {dtype}')
|
| 352 |
print('Model path: ', args.model)
|
| 353 |
processor = AutoProcessor.from_pretrained(
|
| 354 |
+
args.model,
|
| 355 |
+
trust_remote_code=True,
|
| 356 |
+
use_fast=False,
|
| 357 |
)
|
| 358 |
model = AutoModelForCausalLM.from_pretrained(
|
| 359 |
args.model,
|
|
|
|
| 370 |
print('Done ✅')
|
| 371 |
print()
|
| 372 |
|
| 373 |
+
print("--- Let's create some conversations ...")
|
| 374 |
conversations, images, prompts = _build_conversations(
|
| 375 |
args.image,
|
| 376 |
args.prompt,
|
| 377 |
map_mode=args.map,
|
| 378 |
prompt_first=args.prompt_first,
|
| 379 |
+
image_labels=args.image_labels,
|
| 380 |
)
|
| 381 |
n_conversations = len(conversations)
|
| 382 |
print(f'Built {n_conversations} conversations 🚀')
|
|
|
|
| 448 |
print(f'├── 🖼️Images: {images[idx]}')
|
| 449 |
print(f'├── 📜Prompt: {prompts[idx]}')
|
| 450 |
print(f'├── 💬Chat:{texts[idx]}')
|
| 451 |
+
print('└── 🧠Response:', end='')
|
| 452 |
ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
|
| 453 |
with (
|
| 454 |
timer,
|
| 455 |
torch.no_grad(),
|
| 456 |
+
torch.autocast(
|
| 457 |
+
device.type, enabled=(device.type != 'mps'), dtype=dtype
|
| 458 |
+
),
|
| 459 |
):
|
| 460 |
output = model.generate(
|
| 461 |
**ith_inputs,
|
| 462 |
streamer=streamer,
|
| 463 |
generation_config=GenerationConfig(
|
| 464 |
+
max_new_tokens=args.max_tokens,
|
| 465 |
+
do_sample=False,
|
| 466 |
),
|
| 467 |
return_dict_in_generate=True,
|
| 468 |
use_model_defaults=True,
|
| 469 |
)
|
| 470 |
generation_time += timer.time
|
| 471 |
|
| 472 |
+
out = output.sequences[0][len(input_prompts[idx].tolist()) :]
|
| 473 |
generated_tokens += len(out)
|
| 474 |
print('Token usage report:')
|
| 475 |
print(token_usage_reports[idx])
|
|
|
|
| 487 |
output = model.generate(
|
| 488 |
**device_inputs,
|
| 489 |
generation_config=GenerationConfig(
|
| 490 |
+
max_new_tokens=args.max_tokens,
|
| 491 |
+
do_sample=False,
|
| 492 |
),
|
| 493 |
return_dict_in_generate=True,
|
| 494 |
use_model_defaults=True,
|
|
|
|
| 496 |
generation_time = timer.time
|
| 497 |
|
| 498 |
for idx in range(n_conversations):
|
| 499 |
+
out = output.sequences[idx][len(input_prompts[idx].tolist()) :]
|
| 500 |
generated_tokens += len(out)
|
| 501 |
response = processor.tokenizer.decode(out, skip_special_tokens=True)
|
| 502 |
print(f'* Conversation {idx + 1}/{n_conversations}')
|
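Two small patterns the updated test relies on, shown as a self-contained toy rather than the script's own objects: `GenerationConfig(max_new_tokens=..., do_sample=False)` forces greedy decoding, and `generate()` returns the prompt followed by the new tokens, so the script slices off the prompt prefix before decoding. The token values and `max_new_tokens` below are placeholders:

```python
import torch
from transformers import GenerationConfig

# Greedy decoding settings, mirroring the arguments added in the diff
# (the max_new_tokens value is illustrative; the script reads it from the CLI).
gen_cfg = GenerationConfig(max_new_tokens=128, do_sample=False)
print(gen_cfg.do_sample)  # False → greedy decoding

# The generated sequence still contains the prompt, so strip it by prompt length
# before decoding. Toy tensors stand in for the real model outputs.
input_prompt = torch.tensor([5, 6, 7])
sequences = torch.tensor([[5, 6, 7, 42, 43, 44]])  # pretend output.sequences
new_tokens = sequences[0][len(input_prompt.tolist()):]
print(new_tokens.tolist())  # [42, 43, 44]
```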