gmastrapas committed
Commit 331a585 · verified · Parent: 8605b06

Model update

Files changed (4):
  1. README.md (+12 -16)
  2. blocks_jvlm.py (+2 -0)
  3. modeling_jvlm.py (+1 -0)
  4. test.py (+70 -0)
README.md CHANGED
@@ -209,7 +209,7 @@ python test_jvlm.py -i assets/the_persistence_of_memory.jpg -p "What's in this i
 python test_jvlm.py -i https://picsum.photos/id/1025/800/600.jpg -p "Describe this image"
 
 # Multiple images (local and remote)
-python test_jvlm.py -i https://picsum.photos/id/1015/800/600.jpg -i https://picsum.photos/id/1016/800/600.jpg -i https://picsum.photos/id/1021/800/600.jpg -p "What is the difference between these pictures?"
+python test_jvlm.py -i https://picsum.photos/id/1015/800/600.jpg -i https://picsum.photos/id/1016/800/600.jpg -i https://picsum.photos/id/1021/800/600.jpg -p "Describe these images"
 
 # Text only input
 python test_jvlm.py -p "How many planets are in our solar system?"
@@ -302,7 +302,7 @@ model = AutoModelForCausalLM.from_pretrained(
 #
 # model = AutoModelForCausalLM.from_pretrained(
 #     'jinaai/jina-vlm-v1',
-#     torch_dtype=torch.bfloat16,
+#     dtype=torch.bfloat16,
 #     attn_implementation='flash_attention_2',
 #     device_map='auto',
 #     trust_remote_code=True
@@ -317,7 +317,7 @@ conversation = [
                 'type': 'image',
                 'image': image,
             },
-            {'type': 'text', 'text': 'Describe this image.'},
+            {'type': 'text', 'text': 'Describe this image'},
         ],
     }
 ]
@@ -347,14 +347,10 @@ inputs = processor(
 
 # Move the inputs to the appropriate device and/or dtype
 device = torch.device('cuda')
-dtype = torch.float16
 model_inputs = {}
 for k, v in inputs.items():
     if isinstance(v, torch.Tensor):
-        if v.is_floating_point():
-            model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
-        else:
-            model_inputs[k] = v.to(device, non_blocking=True)
+        model_inputs[k] = v.to(device, non_blocking=True)
     else:
         model_inputs[k] = v
 
@@ -362,7 +358,7 @@ for k, v in inputs.items():
 output = model.generate(
     **model_inputs,
     generation_config=GenerationConfig(
-        max_new_tokens=20, do_sample=False,
+        max_new_tokens=1024, do_sample=False,
     ),
     return_dict_in_generate=True,
     use_model_defaults=True,
@@ -390,7 +386,7 @@ processor = AutoProcessor.from_pretrained(
 model = AutoModelForCausalLM.from_pretrained(
     'jinaai/jina-vlm-v1',
     device_map='auto',
-    torch_dtype=torch.bfloat16,
+    dtype=torch.bfloat16,
     attn_implementation='flash_attention_2',
     trust_remote_code=True
 )
@@ -441,7 +437,7 @@ for k, v in inputs.items():
 output = model.generate(
     **model_inputs,
     generation_config=GenerationConfig(
-        max_new_tokens=20, do_sample=False,
+        max_new_tokens=1024, do_sample=False,
     ),
     return_dict_in_generate=True,
     use_model_defaults=True,
@@ -468,7 +464,7 @@ processor = AutoProcessor.from_pretrained(
 model = AutoModelForCausalLM.from_pretrained(
     'jinaai/jina-vlm-v1',
     device_map='auto',
-    torch_dtype=torch.bfloat16,
+    dtype=torch.bfloat16,
     attn_implementation='flash_attention_2',
     trust_remote_code=True
 )
@@ -508,7 +504,7 @@ for k, v in inputs.items():
 output = model.generate(
     **model_inputs,
     generation_config=GenerationConfig(
-        max_new_tokens=20, do_sample=False,
+        max_new_tokens=1024, do_sample=False,
     ),
     return_dict_in_generate=True,
     use_model_defaults=True,
@@ -535,7 +531,7 @@ processor = AutoProcessor.from_pretrained(
 model = AutoModelForCausalLM.from_pretrained(
     'jinaai/jina-vlm-v1',
     device_map='auto',
-    torch_dtype=torch.bfloat16,
+    dtype=torch.bfloat16,
     attn_implementation='flash_attention_2',
     trust_remote_code=True
 )
@@ -599,7 +595,7 @@ processor = AutoProcessor.from_pretrained(
 model = AutoModelForCausalLM.from_pretrained(
     'jinaai/jina-vlm-v1',
     device_map='auto',
-    torch_dtype=torch.bfloat16,
+    dtype=torch.bfloat16,
     attn_implementation='flash_attention_2',
     trust_remote_code=True
 )
@@ -701,7 +697,7 @@ processor = AutoProcessor.from_pretrained(
 model = AutoModel.from_pretrained(
     'jinaai/jina-vlm-v1',
     device_map='auto',
-    torch_dtype=torch.bfloat16,
+    dtype=torch.bfloat16,
     attn_implementation='flash_attention_2',
     trust_remote_code=True
 )
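
Taken together, the README hunks make three changes to the usage examples: `from_pretrained` now passes `dtype=` instead of the older `torch_dtype=` keyword, the manual float16 cast of floating-point inputs is dropped (tensors are only moved to the device), and `max_new_tokens` goes from 20 to 1024. A minimal consolidated sketch of the updated flow, assuming a CUDA device with flash-attn installed and that the processor accepts image URLs as in test.py below:

```python
# Consolidated sketch of the updated README usage. Assumes a CUDA device,
# flash-attn installed, and a transformers release that accepts `dtype`
# in from_pretrained (the successor of the older `torch_dtype` keyword).
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

processor = AutoProcessor.from_pretrained(
    'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    'jinaai/jina-vlm-v1',
    device_map='auto',
    dtype=torch.bfloat16,                     # was torch_dtype before this commit
    attn_implementation='flash_attention_2',
    trust_remote_code=True,
)

image = 'https://picsum.photos/id/1025/800/600.jpg'
conversation = [
    {
        'role': 'user',
        'content': [
            {'type': 'image', 'image': image},
            {'type': 'text', 'text': 'Describe this image'},
        ],
    }
]
text = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=text, images=[image], padding='longest', return_tensors='pt')

# Move the inputs to the device; the explicit float16 cast is gone.
device = torch.device('cuda')
model_inputs = {
    k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
    for k, v in inputs.items()
}

output = model.generate(
    **model_inputs,
    generation_config=GenerationConfig(max_new_tokens=1024, do_sample=False),
    return_dict_in_generate=True,
    use_model_defaults=True,
)
response = processor.tokenizer.decode(
    output.sequences[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
)
print(response)
```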
blocks_jvlm.py CHANGED
@@ -1294,6 +1294,7 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         # image_features:
         # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
         bs, ncrops = image_features.shape[:2]
+        ogtype = image_features.dtype
 
         if self.padding_embed_type is not None:
             assert image_masks is not None
@@ -1322,6 +1323,7 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
                     partial_pad, -1
                 )
 
+        image_features = image_features.to(dtype=ogtype)
         image_features = self.feature_dropout(image_features)
         image_features = image_features.reshape((bs, ncrops) + self.n_patches + (-1,))
         pad_h = self.n_patches[0] % self.pooling_h
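
The connector change records the dtype of `image_features` on entry (`ogtype`) and restores it just before dropout and pooling, which suggests the padding-embedding branch in between can silently upcast bf16 features (for instance via float32 masks or embeddings). A small self-contained sketch of that capture-and-restore pattern; `blend_padding` is a hypothetical stand-in for the connector's actual padding logic:

```python
# Illustrative sketch only: `blend_padding` mimics a branch that upcasts
# bf16 features to float32, as the real padding-embedding math may do.
import torch

def blend_padding(features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Multiplying bf16 features by a float32 mask promotes the result to float32.
    return features * masks.to(torch.float32).unsqueeze(-1)

features = torch.randn(2, 4, 16, dtype=torch.bfloat16)
masks = torch.ones(2, 4)

ogtype = features.dtype                  # remember the incoming dtype
features = blend_padding(features, masks)
assert features.dtype == torch.float32   # the branch upcast the features
features = features.to(dtype=ogtype)     # cast back before dropout / pooling
assert features.dtype == torch.bfloat16
```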
modeling_jvlm.py CHANGED
@@ -388,6 +388,7 @@ class JinaVLMTextModel(JinaPreTrainedModel):
             batch_idx = torch.arange(bs, device=x.device)
             batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])
             image_features = image_features.to(x.device)
+            x = x.clone()  # Clone x to avoid in-place operation on leaf tensor
             x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
 
             if not self.rope:
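
The modeling fix clones `x` before the indexed `+=` that splices image features into the text embeddings. PyTorch raises a RuntimeError when an in-place operation touches a leaf tensor that requires grad, which would presumably be hit whenever the embeddings arrive as a trainable leaf (e.g. during fine-tuning); cloning produces a non-leaf copy whose in-place update autograd can track. A standalone sketch of the failure and the workaround, not the model's actual code path:

```python
# Minimal reproduction of the error the added clone() guards against
# (illustrative only, not the model's embedding path).
import torch

x = torch.zeros(4, 8, requires_grad=True)  # leaf tensor, e.g. trainable embeddings
patch = torch.randn(2, 8)

try:
    x[[0, 2]] += patch                     # in-place indexed add on a leaf that requires grad
except RuntimeError as err:
    print('in-place on leaf fails:', err)

y = x.clone()                              # non-leaf copy, still connected to x in the graph
y[[0, 2]] += patch                         # autograd records the indexed add on the clone
y.sum().backward()                         # gradients still flow back to x through the clone
print(x.grad.shape)                        # torch.Size([4, 8])
```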
test.py ADDED
@@ -0,0 +1,70 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+processor = AutoProcessor.from_pretrained(
+    'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
+)
+model = AutoModelForCausalLM.from_pretrained(
+    'jinaai/jina-vlm-v1',
+    device_map='auto',
+    torch_dtype=torch.bfloat16,
+    attn_implementation='flash_attention_2',
+    trust_remote_code=True
+)
+images = [
+    'https://picsum.photos/id/22/4434/3729',
+    'https://picsum.photos/id/49/1280/792'
+]
+conversations = [
+    [
+        {
+            'role': 'user',
+            'content': [
+                {'type': 'image', 'image': images[0]},
+                {'type': 'text', 'text': 'What is the man doing in this image?'},
+            ],
+        }
+    ],
+    [
+        {
+            'role': 'user',
+            'content': [
+                {'type': 'image', 'image': images[1]},
+                {'type': 'text', 'text': 'What country\'s flag is in this image?'},
+            ],
+        }
+    ],
+
+]
+texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
+inputs = processor(
+    text=texts,
+    images=images,
+    padding='longest',
+    return_tensors='pt',
+)
+device = torch.device('cuda')
+dtype = torch.bfloat16
+model_inputs = {}
+for k, v in inputs.items():
+    if isinstance(v, torch.Tensor):
+        if v.is_floating_point():
+            model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
+        else:
+            model_inputs[k] = v.to(device, non_blocking=True)
+    else:
+        model_inputs[k] = v
+
+output = model.generate(
+    **model_inputs,
+    generation_config=GenerationConfig(
+        max_new_tokens=1024, do_sample=False,
+    ),
+    return_dict_in_generate=True,
+    use_model_defaults=True,
+)
+input_sequence_length = inputs.input_ids.shape[-1]
+for idx in range(len(output.sequences)):
+    gen_ids = output.sequences[idx][input_sequence_length:]
+    response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
+    print(response)
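
The new test.py exercises batched inference end to end: two single-image conversations are templated together, padded to the longest prompt, and decoded per sample after generation. Note that the script still loads the model with `torch_dtype=torch.bfloat16` and casts floating-point inputs by hand, whereas the updated README switches to `dtype=` and drops the cast; presumably both spellings still work, with `torch_dtype` being the older one. Running the script requires a CUDA device and flash-attn installed. As a small follow-up, the per-sample decode loop at the end could equivalently use `batch_decode`, continuing from the script's own variables:

```python
# Equivalent batched decode (sketch), reusing `output`, `inputs`, and
# `processor` from test.py above. Prompts were padded to the same length,
# so one slice removes them from every generated sequence at once.
gen_ids = output.sequences[:, inputs.input_ids.shape[-1]:]
responses = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
for response in responses:
    print(response)
```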