OccamRazor committed
Commit 06ec111 · Parent(s): a596e84

Vulkan: Add DP4A MMQ and Q8_1 quantization shader (llama/12135)

* Vulkan: Add DP4A MMQ and Q8_1 quantization shader

* Add q4_0 x q8_1 matrix matrix multiplication support

* Vulkan: Add int8 coopmat MMQ support

* Vulkan: Add q4_1, q5_0 and q5_1 quants, improve integer dot code

* Add GL_EXT_integer_dot_product check

* Remove ggml changes, fix mmq pipeline picker

* Remove ggml changes, restore Intel coopmat behaviour

* Fix glsl compile attempt when integer vec dot is not supported

* Remove redundant code, use non-saturating integer dot, enable all matmul sizes for mmq

* Remove redundant comment

* Fix integer dot check

* Fix compile issue with unsupported int dot glslc

* Update Windows build Vulkan SDK version

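For context on the commit message: the MMQ path relies on a DP4A-style primitive, a dot product of four packed signed 8-bit values per 32-bit word accumulated into a 32-bit integer, which Vulkan exposes through GL_EXT_integer_dot_product / VK_KHR_shader_integer_dot_product. The scalar C++ sketch below only illustrates that primitive's semantics (it is not code from this commit); note it is non-saturating, matching the "use non-saturating integer dot" item above.

    #include <cstdint>

    // Reference semantics of a packed 4x8-bit signed dot product (DP4A-like):
    // treat each 32-bit word as four int8 lanes, multiply lane-wise and
    // accumulate into a 32-bit integer without saturation.
    static int32_t dot4x8_packed(uint32_t a, uint32_t b, int32_t acc) {
        for (int i = 0; i < 4; ++i) {
            const int8_t ai = static_cast<int8_t>((a >> (8 * i)) & 0xff);
            const int8_t bi = static_cast<int8_t>((b >> (8 * i)) & 0xff);
            acc += static_cast<int32_t>(ai) * static_cast<int32_t>(bi);
        }
        return acc;
    }
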
ggml/src/ggml-vulkan/CMakeLists.txt CHANGED
@@ -69,6 +69,20 @@ if (Vulkan_FOUND)
         add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     endif()
 
+    # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
+        message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
+    else()
+        message(STATUS "GL_EXT_integer_dot_product supported by glslc")
+        add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    endif()
+
     target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
     target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -234,6 +234,8 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
 
+    bool integer_dot_product;
+
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
     uint32_t subgroup_max_size;
@@ -245,6 +247,12 @@ struct vk_device_struct {
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;
+
+    bool coopmat_int_support;
+    uint32_t coopmat_int_m;
+    uint32_t coopmat_int_n;
+    uint32_t coopmat_int_k;
+
     bool coopmat2;
 
     size_t idx;
@@ -263,10 +271,10 @@ struct vk_device_struct {
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
-    vk_pipeline pipeline_matmul_split_k_reduce;
 
-    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32 {};
     vk_matmul_pipeline2 pipeline_matmul_id_f16;
@@ -274,6 +282,9 @@ struct vk_device_struct {
 
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
 
+    vk_pipeline pipeline_matmul_split_k_reduce;
+    vk_pipeline pipeline_quantize_q8_1;
+
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -640,6 +651,13 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_upscale_push_constants {
+    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
+    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
+    float sf0; float sf1; float sf2; float sf3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -649,13 +667,6 @@ struct vk_staging_memcpy {
     size_t n;
 };
 
-struct vk_op_upscale_push_constants {
-    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
-    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
-    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
-    float sf0; float sf1; float sf2; float sf3;
-};
-
 struct vk_context_struct {
     vk_submission * s;
     std::vector<vk_sequence> seqs;
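
The new pipeline_quantize_q8_1 pipeline added above quantizes the f32 activations to 8-bit blocks on the GPU before the integer matmul runs. As a rough scalar reference of the general q8_1 scheme (a sketch only, not the shader from this commit; the block size of 32 and the stored per-block sum follow ggml's usual layout and are assumptions here):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Sketch: quantize one block of 32 floats to q8_1-style data.
    // d is the per-block scale; s caches d * sum(q) so asymmetric weight
    // formats (q4_1/q5_1) can fold their offset term into a single multiply.
    struct block_q8_1_ref {
        float  d;        // scale (stored as fp16 in the real format)
        float  s;        // d * sum of the quantized values
        int8_t qs[32];   // quantized values
    };

    static void quantize_block_q8_1_ref(const float * x, block_q8_1_ref & out) {
        float amax = 0.0f;
        for (int i = 0; i < 32; ++i) {
            amax = std::max(amax, std::fabs(x[i]));
        }
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f;

        int sum = 0;
        for (int i = 0; i < 32; ++i) {
            const int q = static_cast<int>(std::lround(x[i] * id));
            out.qs[i] = static_cast<int8_t>(q);
            sum += q;
        }
        out.d = d;
        out.s = d * sum;
    }
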
@@ -1598,6 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // mulmat
     std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+                          l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
                           l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
                           l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
     std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@@ -1662,6 +1674,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
     m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
     s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
+    const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
+    const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
+    const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
+    const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
+    const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
+    const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2;
+    const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
+    const uint32_t tk_int_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
+    const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
+
+    l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 };
+    m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 };
+    s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 };
+
     l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
     m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
     s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@@ -2000,6 +2026,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
 
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
     // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
     CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2031,6 +2065,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    if (device->integer_dot_product) {
+        CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+    }
+#endif
+
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2056,6 +2100,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
+#undef CREATE_MMQ
 #undef CREATE_MM
     } else {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -2073,6 +2118,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
 
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2099,6 +2152,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    if (device->integer_dot_product) {
+        CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+    }
+#endif
+
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2132,7 +2195,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t rm_stdq = 1;
     uint32_t rm_kq = 2;
     if (device->vendor_id == VK_VENDOR_ID_AMD) {
-        if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
+        if (device->architecture == AMD_GCN) {
            rm_stdq = 2;
            rm_kq = 4;
        }
@@ -2266,6 +2329,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
         if (device->subgroup_add && device->subgroup_require_full_support) {
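
The matmul_q4_0_q8_1 family registered above accumulates integer dot products of 4-bit weights against the q8_1 activation blocks and applies the floating-point scales afterwards. A scalar sketch of that per-block math follows (illustrative only, not the shader code; the q4_0 layout with a shared scale and an implicit -8 offset is ggml's usual convention, assumed here rather than shown in this diff):

    #include <cstdint>

    // Sketch: dot product of one q4_0 weight block (32 x 4-bit values with an
    // implicit -8 offset and shared scale d0) with one q8_1 activation block
    // (32 x int8 with scale d1 and s1 = d1 * sum of the int8 values).
    static float dot_q4_0_q8_1_ref(float d0, const uint8_t * q4, // 16 bytes, two nibbles each
                                   float d1, float s1, const int8_t * q8) {
        int32_t sumi = 0;
        for (int i = 0; i < 16; ++i) {
            const int lo = q4[i] & 0x0f;   // element i
            const int hi = q4[i] >> 4;     // element i + 16
            sumi += lo * q8[i] + hi * q8[i + 16];
        }
        // sum((q - 8) * q8) == sumi - 8 * sum(q8), and d1 * sum(q8) is the
        // precomputed s1, which is why q8_1 stores that per-block sum.
        return d0 * d1 * sumi - 8.0f * d0 * s1;
    }
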
@@ -2452,6 +2516,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     bool pipeline_robustness = false;
     bool coopmat2_support = false;
     device->coopmat_support = false;
+    device->integer_dot_product = false;
 
     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2477,6 +2542,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
         } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT2")) {
             coopmat2_support = true;
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+                   !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+            device->integer_dot_product = true;
+#endif
         }
     }
 
@@ -2490,6 +2560,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     vk::PhysicalDeviceVulkan11Properties vk11_props;
     vk::PhysicalDeviceVulkan12Properties vk12_props;
     vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
 
     props2.pNext = &props3;
     props3.pNext = &subgroup_props;
@@ -2524,6 +2595,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
 #endif
 
+    if (device->integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+    }
+
     device->physical_device.getProperties2(&props2);
     device->properties = props2.properties;
     device->vendor_id = device->properties.vendorID;
@@ -2570,6 +2646,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
     }
 
+    device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
+
     std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 
     // Try to find a non-graphics compute queue and transfer-focused queues
@@ -2662,6 +2740,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_KHR_maintenance4");
     }
 
+    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+    if (device->integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+        device_extensions.push_back("VK_KHR_shader_integer_dot_product");
+    }
+
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
     device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2831,6 +2917,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
                 device->coopmat_acc_f16_support = true;
             }
         }
+        } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
+                   (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
+                   (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
+                   (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
+                   (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
+                   device->coopmat_int_m == 0
+        ) {
+            device->coopmat_int_support = true;
+            device->coopmat_int_m = prop.MSize;
+            device->coopmat_int_n = prop.NSize;
+            device->coopmat_int_k = prop.KSize;
         }
     }
 
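The detection code above follows the standard Vulkan pattern of chaining extension structs through pNext before a single query call. A minimal standalone sketch of that pattern for VK_KHR_shader_integer_dot_product (using only core Vulkan 1.1+ entry points; the physical device handle is assumed to come from elsewhere):

    #include <vulkan/vulkan.h>

    // Sketch: report whether the packed 4x8-bit signed dot product is both
    // supported (feature) and accelerated (property) on a physical device.
    static bool has_accelerated_int8_dot(VkPhysicalDevice phys_dev) {
        VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR dot_features {};
        dot_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;

        VkPhysicalDeviceFeatures2 features2 {};
        features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        features2.pNext = &dot_features;
        vkGetPhysicalDeviceFeatures2(phys_dev, &features2);

        VkPhysicalDeviceShaderIntegerDotProductPropertiesKHR dot_props {};
        dot_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES_KHR;

        VkPhysicalDeviceProperties2 props2 {};
        props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
        props2.pNext = &dot_props;
        vkGetPhysicalDeviceProperties2(phys_dev, &props2);

        return dot_features.shaderIntegerDotProduct &&
               dot_props.integerDotProduct4x8BitPackedSignedAccelerated;
    }
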
@@ -2935,25 +3032,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     vk::PhysicalDevice physical_device = devices[dev_num];
     std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
 
-    vk::PhysicalDeviceProperties2 props2;
-    vk::PhysicalDeviceMaintenance3Properties props3;
-    vk::PhysicalDeviceSubgroupProperties subgroup_props;
-    vk::PhysicalDeviceDriverProperties driver_props;
-    props2.pNext = &props3;
-    props3.pNext = &subgroup_props;
-    subgroup_props.pNext = &driver_props;
-    physical_device.getProperties2(&props2);
-
-    vk_device_architecture arch = get_device_architecture(physical_device);
-    uint32_t default_subgroup_size = get_subgroup_size("", arch);
-    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
-
-    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
     bool fp16_storage = false;
     bool fp16_compute = false;
     bool coopmat_support = false;
     bool coopmat2_support = false;
+    bool integer_dot_product = false;
 
     for (auto properties : ext_props) {
         if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -2969,27 +3052,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT2")) {
             coopmat2_support = true;
+#endif
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+                   !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+            integer_dot_product = true;
 #endif
         }
     }
 
     const vk_device_architecture device_architecture = get_device_architecture(physical_device);
 
-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
-        coopmat_support = false;
-    }
-
     const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
     bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
     bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-    vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
+    vk::PhysicalDeviceProperties2 props2;
+    vk::PhysicalDeviceMaintenance3Properties props3;
+    vk::PhysicalDeviceSubgroupProperties subgroup_props;
+    vk::PhysicalDeviceDriverProperties driver_props;
+    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+    props2.pNext = &props3;
+    props3.pNext = &subgroup_props;
+    subgroup_props.pNext = &driver_props;
+
+    // Pointer to the last chain element
+    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
+
+    if (integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+    }
+
+    physical_device.getProperties2(&props2);
 
     VkPhysicalDeviceFeatures2 device_features2;
     device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
     device_features2.pNext = nullptr;
-    device_features2.features = (VkPhysicalDeviceFeatures)device_features;
 
     VkPhysicalDeviceVulkan11Features vk11_features;
     vk11_features.pNext = nullptr;
@@ -3002,7 +3102,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     vk11_features.pNext = &vk12_features;
 
     // Pointer to the last chain element
-    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
+    last_struct = (VkBaseOutStructure *)&vk12_features;
 
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
@@ -3014,20 +3114,37 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
         last_struct = (VkBaseOutStructure *)&coopmat_features;
     }
+#endif
+
+    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+    if (integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+    }
 
     vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
 
     fp16 = fp16 && vk12_features.shaderFloat16;
 
-    coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
-#endif
+    uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+    integer_dot_product = integer_dot_product
+                          && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
+                          && shader_integer_dot_product_features.shaderIntegerDotProduct;
+
+    coopmat_support = coopmat_support
+                      && coopmat_features.cooperativeMatrix
+                      && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
 
     std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
 
     std::string device_name = props2.properties.deviceName.data();
-    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
                    idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
-                   props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
+                   props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
 
     if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
         GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -3293,6 +3410,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         }
     }
 
+    // MMQ
+    if (src1_type == GGML_TYPE_Q8_1) {
+        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
+
+        if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+            return nullptr;
+        }
+
+        return pipelines;
+    }
+
     if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
         return nullptr;
     }
@@ -3585,8 +3713,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
     return s;
 }
 
-
-
 static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
     const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
     const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@@ -4016,8 +4142,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
     return split_k;
 }
 
-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
 
     if (ctx->device->coopmat2) {
         // Use large shader when the N dimension is greater than the medium shader's tile size
@@ -4042,9 +4168,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     return aligned ? mmp->a_l : mmp->l;
 }
 
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
-    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
+    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
 }
 
 static void ggml_vk_matmul(
@@ -4054,7 +4180,7 @@ static void ggml_vk_matmul(
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
         uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
         uint32_t padded_n) {
-    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
+    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
         const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
@@ -4072,7 +4198,7 @@ static void ggml_vk_matmul(
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
 }
 
-static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
+static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
     VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
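
The pipeline-guess helpers above also feed the alignment decision visible further down in ggml_vk_mul_mat_q_f16: the "aligned" shader variant (compiled without bounds checks) is only used when K is already a multiple of the chosen pipeline's alignment and the matrix is not tiny. A small C++ sketch of that decision; the helper names are illustrative, not ggml API:

    #include <cstdint>

    // Round x up to a multiple of a (a > 0).
    static uint32_t align_up(uint32_t x, uint32_t a) {
        return ((x + a - 1) / a) * a;
    }

    // Mirrors: kpad = ggml_vk_align_size(k, pipeline_align);
    //          aligned = (k == kpad) && m > 8 && n > 8;
    static bool can_use_aligned_pipeline(uint32_t m, uint32_t n, uint32_t k, uint32_t pipeline_align) {
        return k == align_up(k, pipeline_align) && m > 8 && n > 8;
    }
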
@@ -4214,6 +4340,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
4214
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
4215
  }
4216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4217
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
4218
  VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
4219
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4265,10 +4410,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4265
 
4266
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4267
 
4268
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
 
 
 
 
 
 
 
 
 
4269
 
4270
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
4271
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
4272
 
4273
  if (qx_needs_dequant) {
4274
  // Fall back to dequant + f16 mulmat
@@ -4278,13 +4432,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4278
  // Not implemented
4279
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4280
 
4281
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4282
- const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
4283
 
4284
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4285
 
4286
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4287
- uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
4288
  const int x_ne = ne01 * ne00;
4289
  const int y_ne = padded_n * ne10;
4290
  const int d_ne = ne11 * ne01;
@@ -4294,11 +4448,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4294
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
4295
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
4296
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
4297
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
4298
  const uint64_t d_sz = sizeof(float) * d_ne;
4299
 
4300
  vk_pipeline to_fp16_vk_0 = nullptr;
4301
  vk_pipeline to_fp16_vk_1 = nullptr;
 
4302
 
4303
  if (x_non_contig) {
4304
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
@@ -4313,6 +4468,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4313
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
4314
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
4315
 
 
 
 
 
4316
  if (dryrun) {
4317
  const uint64_t x_sz_upd = x_sz * ne02 * ne03;
4318
  const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -4326,7 +4485,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4326
  if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
4327
  ctx->prealloc_size_x = x_sz_upd;
4328
  }
4329
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
4330
  ctx->prealloc_size_y = y_sz_upd;
4331
  }
4332
  if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
@@ -4341,6 +4500,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4341
  if (qy_needs_dequant) {
4342
  ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
4343
  }
 
 
 
4344
  if (split_k > 1) {
4345
  ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
4346
  }
@@ -4376,6 +4538,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4376
  if (qy_needs_dequant) {
4377
  d_Y = ctx->prealloc_y;
4378
  GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
 
 
 
4379
  } else {
4380
  d_Y = d_Qy;
4381
  y_buf_offset = qy_buf_offset;
@@ -4392,6 +4557,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4392
  if (y_non_contig) {
4393
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
4394
  }
 
 
 
4395
 
4396
  uint32_t stride_batch_x = ne00*ne01;
4397
  uint32_t stride_batch_y = ne10*ne11;
@@ -4400,7 +4568,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4400
  stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
4401
  }
4402
 
4403
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
4404
  stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
4405
  }
4406
 
@@ -6929,6 +7097,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
6929
  }
6930
  }
6931
 
 
 
 
 
6932
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
6933
 
6934
  vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -7177,6 +7349,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7177
 
7178
  ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
7179
 
 
 
 
 
7180
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7181
 
7182
  ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@@ -7236,66 +7412,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7236
  free(x_chk);
7237
  }
7238
 
7239
- static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7240
  VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
7241
  const size_t x_ne = m * k * batch;
7242
  const size_t y_ne = k * n * batch;
7243
  const size_t d_ne = m * n * batch;
7244
 
 
 
 
 
 
 
 
 
 
 
7245
  vk_pipeline p;
7246
  std::string shname;
7247
  if (shader_size == 0) {
7248
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
7249
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
7250
  } else if (shader_size == 1) {
7251
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
7252
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
7253
  } else if (shader_size == 2) {
7254
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
7255
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
7256
  } else {
7257
  GGML_ASSERT(0);
7258
  }
7259
 
7260
- const size_t kpad = ggml_vk_align_size(k, p->align);
7261
 
7262
- if (k != kpad) {
7263
  if (shader_size == 0) {
7264
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
7265
  shname = std::string(ggml_type_name(quant)) + "_S";
7266
  } else if (shader_size == 1) {
7267
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
7268
  shname = std::string(ggml_type_name(quant)) + "_M";
7269
  } else if (shader_size == 2) {
7270
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
7271
  shname = std::string(ggml_type_name(quant)) + "_L";
7272
  } else {
7273
  GGML_ASSERT(0);
7274
  }
7275
  }
7276
 
 
 
 
 
 
7277
  const size_t x_sz = sizeof(float) * x_ne;
7278
  const size_t y_sz = sizeof(float) * y_ne;
7279
  const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
 
7280
  const size_t d_sz = sizeof(float) * d_ne;
7281
  float * x = (float *) malloc(x_sz);
7282
  float * y = (float *) malloc(y_sz);
7283
  void * qx = malloc(qx_sz);
7284
  vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7285
  vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
 
7286
  vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7287
  float * d = (float *) malloc(d_sz);
7288
  float * d_chk = (float *) malloc(d_sz);
7289
 
7290
  for (size_t i = 0; i < x_ne; i++) {
7291
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
 
 
7292
  }
7293
 
7294
  ggml_vk_quantize_data(x, qx, x_ne, quant);
7295
 
7296
  for (size_t i = 0; i < y_ne; i++) {
7297
- // y[i] = rand() / (float)RAND_MAX;
7298
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
 
7299
  }
7300
 
7301
  ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
@@ -7310,6 +7618,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7310
  ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
7311
  }
7312
  }
 
 
 
 
 
 
 
7313
 
7314
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7315
 
@@ -7318,13 +7633,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7318
 
7319
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7320
  ggml_vk_ctx_begin(ctx->device, subctx);
7321
- for (size_t i = 0; i < num_it; i++) {
7322
- ggml_vk_matmul(
7323
- ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
7324
- m, n, k,
7325
- k, k, m, k*m, k*n, m*n,
7326
- split_k, batch, batch, batch, 1, 1, n
7327
- );
 
 
 
 
 
 
 
 
 
 
 
 
7328
  }
7329
  ggml_vk_ctx_end(subctx);
7330
 
@@ -7382,7 +7709,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7382
 
7383
  double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
7384
 
7385
- std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
 
 
 
 
7386
 
7387
  if (avg_err > 0.01 || std::isnan(avg_err)) {
7388
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -7392,6 +7723,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7392
  std::cerr << "Expected result: " << std::endl << std::endl;
7393
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
7394
 
 
 
 
 
 
 
7395
  if (split_k > 1) {
7396
  float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
7397
  ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
@@ -7414,6 +7751,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7414
 
7415
  ggml_vk_destroy_buffer(qx_buf);
7416
  ggml_vk_destroy_buffer(y_buf);
 
7417
  ggml_vk_destroy_buffer(d_buf);
7418
 
7419
  free(x);
@@ -7446,7 +7784,25 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
7446
  128, 49, 49,
7447
  4096, 49, 4096,
7448
  };
7449
- const size_t num_it = 100;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7450
 
7451
  for (size_t i = 0; i < vals.size(); i += 3) {
7452
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
@@ -9258,7 +9614,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9258
  }
9259
 
9260
  if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
9261
- const float *params = (const float *)tensor->op_params;
9262
  tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
9263
  } else if (tensor->op == GGML_OP_MUL_MAT) {
9264
  tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
@@ -9275,7 +9631,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9275
  } else if (tensor->op == GGML_OP_UPSCALE) {
9276
  tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9277
  } else if (tensor->op == GGML_OP_SCALE) {
9278
- tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
 
9279
  } else if (tensor->op == GGML_OP_SQR) {
9280
  tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
9281
  } else if (tensor->op == GGML_OP_SIN) {
@@ -9283,7 +9640,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9283
  } else if (tensor->op == GGML_OP_COS) {
9284
  tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
9285
  } else if (tensor->op == GGML_OP_CLAMP) {
9286
- tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
 
9287
  } else if (tensor->op == GGML_OP_PAD) {
9288
  tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
9289
  } else if (tensor->op == GGML_OP_REPEAT) {
@@ -9297,7 +9655,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9297
  } else if (tensor->op == GGML_OP_NORM) {
9298
  tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9299
  } else if (tensor->op == GGML_OP_GROUP_NORM) {
9300
- tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
 
9301
  } else if (tensor->op == GGML_OP_RMS_NORM) {
9302
  tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9303
  } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
@@ -9310,14 +9669,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9310
  tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
9311
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
9312
  if (src1 != nullptr) {
9313
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
 
9314
  } else {
9315
  tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
9316
  }
9317
  } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
9318
  tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9319
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
9320
- tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
9321
  } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
9322
  const int n_dims = ((int32_t *) tensor->op_params)[1];
9323
  const int mode = ((int32_t *) tensor->op_params)[2];
 
4202
  VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
4203
 
4204
  if (ctx->device->coopmat2) {
 
4340
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
4341
  }
4342
 
4343
+ static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
4344
+ switch(type) {
4345
+ case GGML_TYPE_Q8_1:
4346
+ return ctx->device->pipeline_quantize_q8_1;
4347
+ default:
4348
+ std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
4349
+ GGML_ABORT("fatal error");
4350
+ }
4351
+ }
4352
+
4353
+ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
4354
+ VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
4355
+
4356
+ vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
4357
+
4358
+ ggml_vk_sync_buffers(subctx);
4359
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
4360
+ }
4361
+
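ggml_vk_quantize_q8_1 above just dispatches the quantize_q8_1.comp shader added later in this commit. For reference, the per-block math that shader implements (scale from the block's absolute maximum plus the precomputed d * sum term the q8_1 dot products consume) can be sketched on the host like this; the struct and function names are hypothetical, not part of the patch:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Reference sketch of one q8_1 block (32 values); the shader stores ds as an f16vec2.
    struct q8_1_block_ref { float d, s; int8_t qs[32]; };

    q8_1_block_ref quantize_block_q8_1_ref(const float * x) {  // x points at 32 floats
        float amax = 0.0f;
        for (int i = 0; i < 32; i++) amax = std::max(amax, std::fabs(x[i]));
        const float d     = amax / 127.0f;
        const float d_inv = d != 0.0f ? 1.0f / d : 0.0f;
        q8_1_block_ref out{};
        int sum = 0;
        for (int i = 0; i < 32; i++) {
            const int q = (int)std::lround(x[i] * d_inv);
            out.qs[i]   = (int8_t)q;
            sum        += q;
        }
        out.d = d;
        out.s = d * (float)sum;  // lets the mmq shader fold the block sum into one multiply-add
        return out;
    }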
4362
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
4363
  VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
4364
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
 
4410
 
4411
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4412
 
4413
+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
4414
+
4415
+ // Check for mmq first
4416
+ vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
4417
+
4418
+ if (mmp == nullptr) {
4419
+ // Fall back to f16 dequant mul mat
4420
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
4421
+ quantize_y = false;
4422
+ }
4423
 
4424
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
4425
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
4426
 
4427
  if (qx_needs_dequant) {
4428
  // Fall back to dequant + f16 mulmat
 
4432
  // Not implemented
4433
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4434
 
4435
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
4436
+ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
4437
 
4438
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
4439
 
4440
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4441
+ uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
4442
  const int x_ne = ne01 * ne00;
4443
  const int y_ne = padded_n * ne10;
4444
  const int d_ne = ne11 * ne01;
 
4448
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
4449
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
4450
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
4451
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
4452
  const uint64_t d_sz = sizeof(float) * d_ne;
4453
 
4454
  vk_pipeline to_fp16_vk_0 = nullptr;
4455
  vk_pipeline to_fp16_vk_1 = nullptr;
4456
+ vk_pipeline to_q8_1 = nullptr;
4457
 
4458
  if (x_non_contig) {
4459
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
 
4468
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
4469
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
4470
 
4471
+ if (quantize_y) {
4472
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
4473
+ }
4474
+
4475
  if (dryrun) {
4476
  const uint64_t x_sz_upd = x_sz * ne02 * ne03;
4477
  const uint64_t y_sz_upd = y_sz * ne12 * ne13;
 
4485
  if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
4486
  ctx->prealloc_size_x = x_sz_upd;
4487
  }
4488
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
4489
  ctx->prealloc_size_y = y_sz_upd;
4490
  }
4491
  if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
 
4500
  if (qy_needs_dequant) {
4501
  ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
4502
  }
4503
+ if (quantize_y) {
4504
+ ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
4505
+ }
4506
  if (split_k > 1) {
4507
  ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
4508
  }
 
4538
  if (qy_needs_dequant) {
4539
  d_Y = ctx->prealloc_y;
4540
  GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
4541
+ } else if (quantize_y) {
4542
+ d_Y = ctx->prealloc_y;
4543
+ GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
4544
  } else {
4545
  d_Y = d_Qy;
4546
  y_buf_offset = qy_buf_offset;
 
4557
  if (y_non_contig) {
4558
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
4559
  }
4560
+ if (quantize_y) {
4561
+ ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
4562
+ }
4563
 
4564
  uint32_t stride_batch_x = ne00*ne01;
4565
  uint32_t stride_batch_y = ne10*ne11;
 
4568
  stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
4569
  }
4570
 
4571
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
4572
  stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
4573
  }
4574
 
 
7097
  }
7098
  }
7099
 
7100
+ if (ctx->device->need_compiles) {
7101
+ ggml_vk_load_shaders(ctx->device);
7102
+ }
7103
+
7104
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7105
 
7106
  vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
 
7349
 
7350
  ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
7351
 
7352
+ if (ctx->device->need_compiles) {
7353
+ ggml_vk_load_shaders(ctx->device);
7354
+ }
7355
+
7356
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7357
 
7358
  ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
 
7412
  free(x_chk);
7413
  }
7414
 
7415
+ // This does not work without ggml q8_1 quantization support
7416
+ //
7417
+ // typedef uint16_t ggml_half;
7418
+ // typedef uint32_t ggml_half2;
7419
+ //
7420
+ // #define QK8_1 32
7421
+ // typedef struct {
7422
+ // union {
7423
+ // struct {
7424
+ // ggml_half d; // delta
7425
+ // ggml_half s; // d * sum(qs[i])
7426
+ // } GGML_COMMON_AGGR_S;
7427
+ // ggml_half2 ds;
7428
+ // } GGML_COMMON_AGGR_U;
7429
+ // int8_t qs[QK8_1]; // quants
7430
+ // } block_q8_1;
7431
+ //
7432
+ // static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
7433
+ // VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
7434
+ // GGML_ASSERT(quant == GGML_TYPE_Q8_1);
7435
+ //
7436
+ // const size_t x_sz = sizeof(float) * ne;
7437
+ // const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
7438
+ // float * x = (float *) malloc(x_sz);
7439
+ // block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
7440
+ // block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
7441
+ // vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7442
+ // vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7443
+ //
7444
+ // for (size_t i = 0; i < ne; i++) {
7445
+ // x[i] = rand() / (float)RAND_MAX;
7446
+ // }
7447
+ //
7448
+ // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
7449
+ //
7450
+ // ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
7451
+ //
7452
+ // if (ctx->device->need_compiles) {
7453
+ // ggml_vk_load_shaders(ctx->device);
7454
+ // }
7455
+ //
7456
+ // ggml_pipeline_allocate_descriptor_sets(ctx->device);
7457
+ //
7458
+ // ggml_vk_buffer_write(x_buf, 0, x, x_sz);
7459
+ //
7460
+ // vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7461
+ // ggml_vk_ctx_begin(ctx->device, subctx);
7462
+ // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
7463
+ // ggml_vk_ctx_end(subctx);
7464
+ //
7465
+ // auto begin = std::chrono::high_resolution_clock::now();
7466
+ //
7467
+ // ggml_vk_submit(subctx, ctx->fence);
7468
+ // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
7469
+ // ctx->device->device.resetFences({ ctx->fence });
7470
+ //
7471
+ // auto end = std::chrono::high_resolution_clock::now();
7472
+ //
7473
+ // double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
7474
+ // ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
7475
+ //
7476
+ // ggml_vk_quantize_data(x, qx_res, ne, quant);
7477
+ //
7478
+ // int first_err = -1;
7479
+ //
7480
+ // for (size_t i = 0; i < ne / 32; i++) {
7481
+ // double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
7482
+ //
7483
+ // if (first_err < 0 && error > 0.1) {
7484
+ // first_err = i;
7485
+ // }
7486
+ //
7487
+ // error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
7488
+ //
7489
+ // if (first_err < 0 && error > 0.1) {
7490
+ // first_err = i;
7491
+ // }
7492
+ //
7493
+ // for (size_t j = 0; j < 32; j++) {
7494
+ // uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
7495
+ //
7496
+ // if (first_err < 0 && error > 1) {
7497
+ // first_err = i;
7498
+ // }
7499
+ // }
7500
+ // }
7501
+ //
7502
+ // std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
7503
+ //
7504
+ // if (first_err != -1) {
7505
+ // std::cerr << "first_error = " << first_err << std::endl;
7506
+ // std::cerr << "Actual result: " << std::endl << std::endl;
7507
+ // std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
7508
+ // for (size_t j = 0; j < 32; j++) {
7509
+ // std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
7510
+ // }
7511
+ // std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
7512
+ // std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
7513
+ // for (size_t j = 0; j < 32; j++) {
7514
+ // std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
7515
+ // }
7516
+ // std::cerr << std::endl;
7517
+ // }
7518
+ //
7519
+ // ggml_vk_destroy_buffer(x_buf);
7520
+ // ggml_vk_destroy_buffer(qx_buf);
7521
+ //
7522
+ // free(x);
7523
+ // free(qx);
7524
+ // free(qx_res);
7525
+ // }
7526
+
7527
+ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
7528
  VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
7529
  const size_t x_ne = m * k * batch;
7530
  const size_t y_ne = k * n * batch;
7531
  const size_t d_ne = m * n * batch;
7532
 
7533
+ vk_matmul_pipeline2 * pipelines;
7534
+
7535
+ if (mmq) {
7536
+ pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
7537
+ } else {
7538
+ pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
7539
+ }
7540
+
7541
+ const bool fp16acc = ctx->device->fp16;
7542
+
7543
  vk_pipeline p;
7544
  std::string shname;
7545
  if (shader_size == 0) {
7546
+ p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
7547
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
7548
  } else if (shader_size == 1) {
7549
+ p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
7550
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
7551
  } else if (shader_size == 2) {
7552
+ p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
7553
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
7554
  } else {
7555
  GGML_ASSERT(0);
7556
  }
7557
 
7558
+ const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
7559
 
7560
+ if (mmq || k != kpad) {
7561
  if (shader_size == 0) {
7562
+ p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
7563
  shname = std::string(ggml_type_name(quant)) + "_S";
7564
  } else if (shader_size == 1) {
7565
+ p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
7566
  shname = std::string(ggml_type_name(quant)) + "_M";
7567
  } else if (shader_size == 2) {
7568
+ p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
7569
  shname = std::string(ggml_type_name(quant)) + "_L";
7570
  } else {
7571
  GGML_ASSERT(0);
7572
  }
7573
  }
7574
 
7575
+ if (p == nullptr) {
7576
+ std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
7577
+ return;
7578
+ }
7579
+
7580
  const size_t x_sz = sizeof(float) * x_ne;
7581
  const size_t y_sz = sizeof(float) * y_ne;
7582
  const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
7583
+ const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
7584
  const size_t d_sz = sizeof(float) * d_ne;
7585
  float * x = (float *) malloc(x_sz);
7586
  float * y = (float *) malloc(y_sz);
7587
  void * qx = malloc(qx_sz);
7588
  vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7589
  vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7590
+ vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7591
  vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7592
  float * d = (float *) malloc(d_sz);
7593
  float * d_chk = (float *) malloc(d_sz);
7594
 
7595
  for (size_t i = 0; i < x_ne; i++) {
7596
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7597
+ // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
7598
+ // x[i] = i % k;
7599
  }
7600
 
7601
  ggml_vk_quantize_data(x, qx, x_ne, quant);
7602
 
7603
  for (size_t i = 0; i < y_ne; i++) {
7604
+ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7605
+ // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
7606
+ // y[i] = i % k;
7607
  }
7608
 
7609
  ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
 
7618
  ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
7619
  }
7620
  }
7621
+ if (mmq) {
7622
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
7623
+ }
7624
+
7625
+ if (ctx->device->need_compiles) {
7626
+ ggml_vk_load_shaders(ctx->device);
7627
+ }
7628
 
7629
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7630
 
 
7633
 
7634
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7635
  ggml_vk_ctx_begin(ctx->device, subctx);
7636
+ if (mmq) {
7637
+ for (size_t i = 0; i < num_it; i++) {
7638
+ ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
7639
+ ggml_vk_matmul(
7640
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7641
+ m, n, k,
7642
+ k, k, m, k*m, k*n, m*n,
7643
+ split_k, batch, batch, batch, 1, 1, n
7644
+ );
7645
+ }
7646
+ } else {
7647
+ for (size_t i = 0; i < num_it; i++) {
7648
+ ggml_vk_matmul(
7649
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7650
+ m, n, k,
7651
+ k, k, m, k*m, k*n, m*n,
7652
+ split_k, batch, batch, batch, 1, 1, n
7653
+ );
7654
+ }
7655
  }
7656
  ggml_vk_ctx_end(subctx);
7657
 
 
7709
 
7710
  double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
7711
 
7712
+ std::cerr << "TEST dequant matmul " << shname;
7713
+ if (mmq) {
7714
+ std::cerr << " mmq";
7715
+ }
7716
+ std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
7717
 
7718
  if (avg_err > 0.01 || std::isnan(avg_err)) {
7719
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
 
7723
  std::cerr << "Expected result: " << std::endl << std::endl;
7724
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
7725
 
7726
+ std::cerr << "src0: " << std::endl << std::endl;
7727
+ ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
7728
+ std::cerr << std::endl;
7729
+ std::cerr << "src1: " << std::endl << std::endl;
7730
+ ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
7731
+
7732
  if (split_k > 1) {
7733
  float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
7734
  ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
 
7751
 
7752
  ggml_vk_destroy_buffer(qx_buf);
7753
  ggml_vk_destroy_buffer(y_buf);
7754
+ ggml_vk_destroy_buffer(qy_buf);
7755
  ggml_vk_destroy_buffer(d_buf);
7756
 
7757
  free(x);
 
7784
  128, 49, 49,
7785
  4096, 49, 4096,
7786
  };
7787
+ const size_t num_it = 1;
7788
+
7789
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
7790
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
7791
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
7792
+
7793
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
7794
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
7795
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
7796
+
7797
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
7798
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
7799
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
7800
+
7801
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
7802
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
7803
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
7804
+
7805
+ abort();
7806
 
7807
  for (size_t i = 0; i < vals.size(); i += 3) {
7808
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
 
9614
  }
9615
 
9616
  if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
9617
+ const float * params = (const float *)tensor->op_params;
9618
  tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
9619
  } else if (tensor->op == GGML_OP_MUL_MAT) {
9620
  tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
 
9631
  } else if (tensor->op == GGML_OP_UPSCALE) {
9632
  tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9633
  } else if (tensor->op == GGML_OP_SCALE) {
9634
+ const float * params = (const float *)tensor->op_params;
9635
+ tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
9636
  } else if (tensor->op == GGML_OP_SQR) {
9637
  tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
9638
  } else if (tensor->op == GGML_OP_SIN) {
 
9640
  } else if (tensor->op == GGML_OP_COS) {
9641
  tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
9642
  } else if (tensor->op == GGML_OP_CLAMP) {
9643
+ const float * params = (const float *)tensor->op_params;
9644
+ tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
9645
  } else if (tensor->op == GGML_OP_PAD) {
9646
  tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
9647
  } else if (tensor->op == GGML_OP_REPEAT) {
 
9655
  } else if (tensor->op == GGML_OP_NORM) {
9656
  tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9657
  } else if (tensor->op == GGML_OP_GROUP_NORM) {
9658
+ const float * float_params = (const float *)tensor->op_params;
9659
+ tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
9660
  } else if (tensor->op == GGML_OP_RMS_NORM) {
9661
  tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9662
  } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
 
9669
  tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
9670
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
9671
  if (src1 != nullptr) {
9672
+ const float * params = (const float *)tensor->op_params;
9673
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
9674
  } else {
9675
  tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
9676
  }
9677
  } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
9678
  tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9679
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
9680
+ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
9681
  } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
9682
  const int n_dims = ((int32_t *) tensor->op_params)[1];
9683
  const int mode = ((int32_t *) tensor->op_params)[2];
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp CHANGED
@@ -212,7 +212,7 @@ void main() {
212
  #else
213
  ACC_TYPE sums[WMITER * TM * WNITER * TN];
214
  FLOAT_TYPE cache_a[WMITER * TM];
215
- FLOAT_TYPE cache_b[WNITER * TN];
216
 
217
  [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
218
  sums[i] = ACC_TYPE(0.0f);
@@ -744,16 +744,14 @@ void main() {
744
  }
745
  [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
746
  [[unroll]] for (uint j = 0; j < TN; j++) {
747
- cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
748
  }
749
- }
750
 
751
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
752
  [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
753
  [[unroll]] for (uint cc = 0; cc < TN; cc++) {
754
  [[unroll]] for (uint cr = 0; cr < TM; cr++) {
755
  const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
756
- sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
757
  }
758
  }
759
  }
 
212
  #else
213
  ACC_TYPE sums[WMITER * TM * WNITER * TN];
214
  FLOAT_TYPE cache_a[WMITER * TM];
215
+ FLOAT_TYPE cache_b[TN];
216
 
217
  [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
218
  sums[i] = ACC_TYPE(0.0f);
 
744
  }
745
  [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
746
  [[unroll]] for (uint j = 0; j < TN; j++) {
747
+ cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
748
  }
 
749
 
 
750
  [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
751
  [[unroll]] for (uint cc = 0; cc < TN; cc++) {
752
  [[unroll]] for (uint cr = 0; cr < TM; cr++) {
753
  const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
754
+ sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
755
  }
756
  }
757
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp ADDED
@@ -0,0 +1,444 @@
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : enable
4
+ #extension GL_EXT_shader_16bit_storage : require
5
+ #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
6
+
7
+ #extension GL_EXT_integer_dot_product : require
8
+
9
+ #ifdef FLOAT16
10
+ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
11
+ #endif
12
+
13
+ #ifdef COOPMAT
14
+ #extension GL_KHR_cooperative_matrix : enable
15
+ #extension GL_KHR_memory_scope_semantics : enable
16
+ #extension GL_KHR_shader_subgroup_basic : enable
17
+ #endif
18
+
19
+ #ifdef MUL_MAT_ID
20
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
21
+ #endif
22
+
23
+ #include "types.comp"
24
+
25
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
26
+
27
+ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
28
+ #if defined(A_TYPE_PACKED32)
29
+ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
30
+ #endif
31
+ layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
32
+ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
33
+
34
+ #ifdef MUL_MAT_ID
35
+ layout (binding = 3) readonly buffer IDS {int data_ids[];};
36
+ #endif
37
+
38
+ layout (push_constant) uniform parameter
39
+ {
40
+ uint M;
41
+ uint N;
42
+ uint K;
43
+ uint stride_a;
44
+ uint stride_b;
45
+ uint stride_d;
46
+
47
+ uint batch_stride_a;
48
+ uint batch_stride_b;
49
+ uint batch_stride_d;
50
+
51
+ #ifdef MUL_MAT_ID
52
+ uint nei0;
53
+ uint nei1;
54
+ uint nbi1;
55
+ uint ne11;
56
+ #else
57
+ uint k_split;
58
+ uint ne02;
59
+ uint ne12;
60
+ uint broadcast2;
61
+ uint broadcast3;
62
+ #endif
63
+ } p;
64
+
65
+ layout (constant_id = 0) const uint BLOCK_SIZE = 64;
66
+ layout (constant_id = 1) const uint BM = 64;
67
+ layout (constant_id = 2) const uint BN = 64;
68
+ // layout (constant_id = 3) const uint BK = 32;
69
+ layout (constant_id = 4) const uint WM = 32;
70
+ layout (constant_id = 5) const uint WN = 32;
71
+ layout (constant_id = 6) const uint WMITER = 2;
72
+ layout (constant_id = 7) const uint TM = 4;
73
+ layout (constant_id = 8) const uint TN = 2;
74
+ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
75
+ layout (constant_id = 10) const uint WARP = 32;
76
+
77
+ #define BK 32
78
+
79
+ #ifdef COOPMAT
80
+ #define SHMEM_STRIDE (BK / 4 + 4)
81
+ #else
82
+ #define SHMEM_STRIDE (BK / 4 + 1)
83
+ #endif
84
+
85
+ shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
86
+
87
+ #ifndef COOPMAT
88
+ #if QUANT_AUXF == 1
89
+ shared FLOAT_TYPE buf_a_dm[BM];
90
+ #else
91
+ shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
92
+ #endif
93
+ #endif
94
+
95
+ shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
96
+ #ifndef COOPMAT
97
+ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
98
+ #endif
99
+
100
+ #define LOAD_VEC_A (4 * QUANT_R)
101
+ #define LOAD_VEC_B 4
102
+
103
+ #ifdef MUL_MAT_ID
104
+ shared u16vec2 row_ids[3072];
105
+ #endif // MUL_MAT_ID
106
+
107
+ #define NUM_WARPS (BLOCK_SIZE / WARP)
108
+
109
+ #ifdef COOPMAT
110
+ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
111
+ #endif
112
+
113
+ #include "mul_mmq_funcs.comp"
114
+
115
+ void main() {
116
+ #ifdef NEEDS_INIT_IQ_SHMEM
117
+ init_iq_shmem(gl_WorkGroupSize);
118
+ #endif
119
+
120
+ #ifdef MUL_MAT_ID
121
+ const uint expert_idx = gl_GlobalInvocationID.z;
122
+ #else
123
+ const uint batch_idx = gl_GlobalInvocationID.z;
124
+
125
+ const uint i13 = batch_idx / p.ne12;
126
+ const uint i12 = batch_idx % p.ne12;
127
+
128
+ const uint i03 = i13 / p.broadcast3;
129
+ const uint i02 = i12 / p.broadcast2;
130
+
131
+ const uint batch_idx_a = i03 * p.ne02 + i02;
132
+ #endif
133
+
134
+ const uint blocks_m = (p.M + BM - 1) / BM;
135
+ const uint ir = gl_WorkGroupID.x % blocks_m;
136
+ const uint ik = gl_WorkGroupID.x / blocks_m;
137
+ const uint ic = gl_WorkGroupID.y;
138
+
139
+ const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
140
+ const uint WSUBM = WM / WMITER;
141
+ const uint WSUBN = WN / WNITER;
142
+
143
+ #ifdef COOPMAT
144
+ const uint warp_i = gl_SubgroupID;
145
+
146
+ const uint tiw = gl_SubgroupInvocationID;
147
+
148
+ const uint cms_per_row = WM / TM;
149
+ const uint cms_per_col = WN / TN;
150
+
151
+ const uint storestride = WARP / TM;
152
+ const uint store_r = tiw % TM;
153
+ const uint store_c = tiw / TM;
154
+ #else
155
+ const uint warp_i = gl_LocalInvocationID.x / WARP;
156
+
157
+ const uint tiw = gl_LocalInvocationID.x % WARP;
158
+
159
+ const uint tiwr = tiw % (WSUBM / TM);
160
+ const uint tiwc = tiw / (WSUBM / TM);
161
+ #endif
162
+
163
+ const uint warp_r = warp_i % (BM / WM);
164
+ const uint warp_c = warp_i / (BM / WM);
165
+
166
+ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
167
+ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
168
+ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
169
+ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
170
+
171
+ const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
172
+ const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
173
+
174
+ #ifdef MUL_MAT_ID
175
+ uint _ne1 = 0;
176
+ for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
177
+ for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
178
+ if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
179
+ row_ids[_ne1] = u16vec2(ii0, ii1);
180
+ _ne1++;
181
+ }
182
+ }
183
+ }
184
+
185
+ barrier();
186
+
187
+ // Workgroup has no work
188
+ if (ic * BN >= _ne1) return;
189
+ #endif
190
+
191
+ #ifdef MUL_MAT_ID
192
+ const uint start_k = 0;
193
+ const uint end_k = p.K;
194
+ #else
195
+ const uint start_k = ik * p.k_split;
196
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
197
+ #endif
198
+
199
+ uint pos_a_ib = (
200
+ #ifdef MUL_MAT_ID
201
+ expert_idx * p.batch_stride_a +
202
+ #else
203
+ batch_idx_a * p.batch_stride_a +
204
+ #endif
205
+ ir * BM * p.stride_a + start_k) / BK;
206
+ #ifdef MUL_MAT_ID
207
+ uint pos_b_ib = 0;
208
+ #else
209
+ uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
210
+ #endif
211
+
212
+ #ifdef COOPMAT
213
+ coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
214
+ coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
215
+ coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
216
+
217
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
218
+
219
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
220
+
221
+ [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
222
+ sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
223
+ }
224
+ #else
225
+ int32_t cache_a_qs[WMITER * TM * BK / 4];
226
+
227
+ int32_t cache_b_qs[TN * BK / 4];
228
+
229
+ ACC_TYPE sums[WMITER * TM * WNITER * TN];
230
+
231
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
232
+ sums[i] = ACC_TYPE(0.0f);
233
+ }
234
+ #endif
235
+
236
+ #if QUANT_AUXF == 1
237
+ FLOAT_TYPE cache_a_dm[TM];
238
+ #else
239
+ FLOAT_TYPE_VEC2 cache_a_dm[TM];
240
+ #endif
241
+
242
+ FLOAT_TYPE_VEC2 cache_b_ds[TN];
243
+
244
+ for (uint block = start_k; block < end_k; block += BK) {
245
+ [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
246
+ const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
247
+ const uint iqs = loadr_a;
248
+ const uint buf_ib = loadc_a + l;
249
+
250
+ // Should ds be gated to a single thread?
251
+ if (iqs == 0) {
252
+ #if QUANT_AUXF == 1
253
+ buf_a_dm[buf_ib] = get_d(ib);
254
+ #else
255
+ buf_a_dm[buf_ib] = get_dm(ib);
256
+ #endif
257
+ }
258
+ #if QUANT_R == 1
259
+ buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
260
+ #else
261
+ const i32vec2 vals = repack(ib, iqs);
262
+ buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
263
+ buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
264
+ #endif
265
+ }
266
+ [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
267
+ #ifdef MUL_MAT_ID
268
+ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
269
+ const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
270
+ const uint ib = idx / 8;
271
+ const uint iqs = idx & 0x7;
272
+ #else
273
+ const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
274
+ const uint iqs = loadr_b;
275
+ #endif
276
+
277
+ const uint buf_ib = loadc_b + l;
278
+
279
+ // Should ds be gated to a single thread?
280
+ if (iqs == 0) {
281
+ buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
282
+ }
283
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
284
+ }
285
+
286
+ barrier();
287
+
288
+ pos_a_ib += 1;
289
+ pos_b_ib += 1;
290
+
291
+ #ifdef COOPMAT
292
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
293
+ const uint ib_a = warp_r * WM + cm_row * TM;
294
+ // Load from shared into cache
295
+ coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
296
+
297
+ // TODO: only cache values that are actually needed
298
+ [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
299
+ cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
300
+ }
301
+
302
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
303
+ const uint ib_b = warp_c * WN + cm_col * TN;
304
+ coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
305
+
306
+ // TODO: only cache values that are actually needed
307
+ [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
308
+ cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
309
+ }
310
+
311
+ cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
312
+ cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
313
+
314
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
315
+ coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
316
+ }
317
+
318
+ coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
319
+ sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
320
+ }
321
+ }
322
+ #else
323
+ // Load from shared into cache
324
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
325
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
326
+ const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
327
+ cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
328
+ [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
329
+ cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
330
+ }
331
+ }
332
+ }
333
+
334
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
335
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
336
+ const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
337
+ cache_b_ds[cc] = buf_b_ds[ib];
338
+ [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
339
+ cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
340
+ }
341
+ }
342
+
343
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
344
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
345
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
346
+ const uint cache_a_idx = wsir * TM + cr;
347
+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
348
+ int32_t q_sum = 0;
349
+ [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
350
+ q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
351
+ cache_b_qs[cc * (BK / 4) + idx_k]);
352
+ }
353
+
354
+ sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
355
+ }
356
+ }
357
+ }
358
+ }
359
+ #endif
360
+
361
+ barrier();
362
+ }
363
+
364
+ const uint dr = ir * BM + warp_r * WM;
365
+ const uint dc = ic * BN + warp_c * WN;
366
+
367
+ #ifndef MUL_MAT_ID
368
+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
369
+ #endif
370
+
371
+ #ifdef COOPMAT
372
+ #ifdef MUL_MAT_ID
373
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
374
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
375
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
376
+
377
+ [[unroll]] for (uint col = 0; col < BN; col += storestride) {
378
+ const uint row_i = dc + cm_col * TN + col + store_c;
379
+ if (row_i >= _ne1) break;
380
+
381
+ const u16vec2 row_idx = row_ids[row_i];
382
+
383
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
384
+ }
385
+ }
386
+ }
387
+ #else
388
+ const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
389
+
390
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
391
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
392
+ const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
393
+
394
+ if (is_aligned && is_in_bounds) {
395
+ // Full coopMat is within bounds and stride_d is aligned with 16B
396
+ coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
397
+ coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
398
+ } else if (is_in_bounds) {
399
+ // Full coopMat is within bounds, but stride_d is not aligned
400
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
401
+
402
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
403
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
404
+ }
405
+ } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
406
+ // Partial coopMat is within bounds
407
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
408
+
409
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
410
+ if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
411
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
412
+ }
413
+ }
414
+ }
415
+ }
416
+ }
417
+ #endif // MUL_MAT_ID
418
+ #else
419
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
420
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
421
+
422
+ const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
423
+ const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
424
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
425
+ #ifdef MUL_MAT_ID
426
+ const uint row_i = dc_warp + cc;
427
+ if (row_i >= _ne1) break;
428
+
429
+ const u16vec2 row_idx = row_ids[row_i];
430
+ #endif // MUL_MAT_ID
431
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
432
+ #ifdef MUL_MAT_ID
433
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
434
+ #else
435
+ if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
436
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
437
+ }
438
+ #endif // MUL_MAT_ID
439
+ }
440
+ }
441
+ }
442
+ }
443
+ #endif // COOPMAT
444
+ }
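The non-coopmat inner loop above accumulates dotPacked4x8EXT results over BK/4 packed words and only then applies the block scales via mul_q8_1. A scalar C++ equivalent of one packed dot product (what a DP4A instruction computes per word pair) is roughly:

    #include <cstdint>

    // Sketch: treat each 32-bit word as four signed 8-bit lanes and sum the lane products,
    // i.e. the scalar meaning of GLSL dotPacked4x8EXT(a, b).
    int32_t dot_packed_4x8_ref(uint32_t a, uint32_t b) {
        int32_t sum = 0;
        for (int i = 0; i < 4; i++) {
            const int8_t ai = (int8_t)((a >> (8 * i)) & 0xFF);
            const int8_t bi = (int8_t)((b >> (8 * i)) & 0xFF);
            sum += (int32_t)ai * (int32_t)bi;
        }
        return sum;
    }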
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp ADDED
@@ -0,0 +1,99 @@
1
+ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
2
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
3
+ #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
4
+
5
+ #include "types.comp"
6
+
7
+ // Each iqs value maps to a 32-bit integer
8
+
9
+ #if defined(DATA_A_Q4_0)
10
+ i32vec2 repack(uint ib, uint iqs) {
11
+ // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
12
+ const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
13
+ data_a[ib].qs[iqs * 2 + 1]);
14
+ const uint32_t vui = pack32(quants);
15
+ return i32vec2( vui & 0x0F0F0F0F,
16
+ (vui >> 4) & 0x0F0F0F0F);
17
+ }
18
+
19
+ ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
20
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y));
21
+ }
22
+ #endif
23
+
24
+ #if defined(DATA_A_Q4_1)
25
+ i32vec2 repack(uint ib, uint iqs) {
26
+ // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
27
+ const uint32_t vui = data_a_packed32[ib].qs[iqs];
28
+ return i32vec2( vui & 0x0F0F0F0F,
29
+ (vui >> 4) & 0x0F0F0F0F);
30
+ }
31
+
32
+ ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
33
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
34
+ }
35
+ #endif
36
+
37
+ #if defined(DATA_A_Q5_0)
38
+ i32vec2 repack(uint ib, uint iqs) {
39
+ // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
40
+ const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
41
+ data_a[ib].qs[iqs * 2 + 1]);
42
+ const uint32_t vui = pack32(quants);
43
+ const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
44
+ const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
45
+ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
46
+
47
+ const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
48
+ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
49
+
50
+ return i32vec2(v0, v1);
51
+ }
52
+
53
+ ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
54
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y));
55
+ }
56
+ #endif
57
+
58
+ #if defined(DATA_A_Q5_1)
59
+ i32vec2 repack(uint ib, uint iqs) {
60
+ // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
61
+ const uint32_t vui = data_a_packed32[ib].qs[iqs];
62
+ const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
63
+ const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
64
+ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
65
+
66
+ const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
67
+ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
68
+
69
+ return i32vec2(v0, v1);
70
+ }
71
+
72
+ ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
73
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
74
+ }
75
+ #endif
76
+
77
+ #if defined(DATA_A_Q8_0)
78
+ int32_t repack(uint ib, uint iqs) {
79
+ // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
80
+ return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
81
+ data_a[ib].qs[iqs * 2 + 1]));
82
+ }
83
+
84
+ ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
85
+ return ACC_TYPE(float(q_sum) * da * dsb.x);
86
+ }
87
+ #endif
88
+
89
+ #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
90
+ FLOAT_TYPE get_d(uint ib) {
91
+ return FLOAT_TYPE(data_a[ib].d);
92
+ }
93
+ #endif
94
+
95
+ #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
96
+ FLOAT_TYPE_VEC2 get_dm(uint ib) {
97
+ return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
98
+ }
99
+ #endif
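The q4_0 variant of mul_q8_1 above folds the format's implicit -8 offset into the q8_1 block's precomputed sum. Writing d_a for the q4_0 scale, q_a for its unsigned nibbles, and ds = (d_b, d_b * sum(q_b)) for the q8_1 header (notation only, not identifiers from the patch), the per-block dot product reduces as follows:

    //   sum_i  d_a * (q_a[i] - 8) * d_b * q_b[i]
    // = d_a * ( d_b * sum_i q_a[i] * q_b[i]  -  8 * d_b * sum_i q_b[i] )
    // = d_a * ( q_sum * ds.x                 -  8 * ds.y )
    // which is the expression mul_q8_1 returns for DATA_A_Q4_0.

The q5_0 case is identical with an offset of 16 in place of 8, while q4_1 and q5_1 carry an explicit minimum instead of an offset, which is why their mul_q8_1 adds dma.y * dsb.y.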
ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp ADDED
@@ -0,0 +1,77 @@
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : require
4
+ #extension GL_EXT_shader_16bit_storage : require
5
+
6
+ layout (push_constant) uniform parameter
7
+ {
8
+ uint ne;
9
+ } p;
10
+
11
+ #include "types.comp"
12
+
13
+ layout(constant_id = 0) const uint GROUP_SIZE = 32;
14
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
15
+
16
+ layout (binding = 0) readonly buffer A {vec4 data_a[];};
17
+ layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
18
+
19
+ shared float shmem[GROUP_SIZE];
20
+
21
+ void quantize() {
22
+ const uint wgid = gl_WorkGroupID.x;
23
+ const uint tid = gl_LocalInvocationID.x;
24
+
25
+ // Each thread handles a vec4, so 8 threads handle a block
26
+ const uint blocks_per_group = GROUP_SIZE / 8;
27
+
28
+ const uint block_in_wg = tid / 8;
29
+
30
+ const uint ib = wgid * blocks_per_group + block_in_wg;
31
+ const uint iqs = tid % 8;
32
+
33
+ if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
34
+ return;
35
+ }
36
+
37
+ const uint a_idx = ib * 8 + iqs;
38
+
39
+ vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
40
+ const vec4 abs_vals = abs(vals);
41
+
42
+ // Find absolute max for each block
43
+ shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
44
+ barrier();
45
+ [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
46
+ if (iqs < s) {
47
+ shmem[tid] = max(shmem[tid], shmem[tid + s]);
48
+ }
49
+ barrier();
50
+ }
51
+
52
+ const float amax = shmem[block_in_wg * 8];
53
+ const float d = amax / 127.0;
54
+ const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
55
+ vals = round(vals * d_inv);
56
+ data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
57
+ barrier();
58
+
59
+ // Calculate the sum for each block
60
+ shmem[tid] = vals.x + vals.y + vals.z + vals.w;
61
+ barrier();
62
+ [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
63
+ if (iqs < s) {
64
+ shmem[tid] += shmem[tid + s];
65
+ }
66
+ barrier();
67
+ }
68
+ if (iqs == 0) {
69
+ const float sum = shmem[tid];
70
+
71
+ data_b[ib].ds = f16vec2(vec2(d, sum * d));
72
+ }
73
+ }
74
+
75
+ void main() {
76
+ quantize();
77
+ }
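Each invocation of quantize_q8_1.comp handles one vec4, so eight threads cover a 32-value block and the default GROUP_SIZE of 32 quantizes four blocks per workgroup. Host-side, the dispatch arithmetic therefore works out roughly as in the sketch below (hypothetical names; in the backend the division is driven by the pipeline's wg_denoms rather than computed by hand):

    #include <cstdint>

    // Sketch: workgroups needed to quantize ne floats to q8_1 with GROUP_SIZE = 32.
    uint32_t q8_1_workgroups(uint32_t ne) {
        const uint32_t threads_per_block = 32 / 4;                  // one vec4 per thread, 32 values per block
        const uint32_t blocks_per_wg     = 32 / threads_per_block;  // 4 blocks per 32-thread workgroup
        const uint32_t values_per_wg     = blocks_per_wg * 32;      // 128 floats per workgroup
        return (ne + values_per_wg - 1) / values_per_wg;
    }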
ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp ADDED
@@ -0,0 +1,7 @@
1
+ #version 460
2
+
3
+ #extension GL_EXT_integer_dot_product : require
4
+
5
+ void main()
6
+ {
7
+ }
ggml/src/ggml-vulkan/vulkan-shaders/types.comp CHANGED
@@ -1,4 +1,3 @@
1
-
2
  #if !defined(GGML_TYPES_COMP)
3
  #define GGML_TYPES_COMP
4
 
@@ -51,6 +50,7 @@ struct block_q4_0_packed16
51
  #if defined(DATA_A_Q4_0)
52
  #define QUANT_K QUANT_K_Q4_0
53
  #define QUANT_R QUANT_R_Q4_0
 
54
  #define A_TYPE block_q4_0
55
  #define A_TYPE_PACKED16 block_q4_0_packed16
56
  #endif
@@ -72,11 +72,19 @@ struct block_q4_1_packed16
72
  uint16_t qs[16/2];
73
  };
74
 
 
 
 
 
 
 
75
  #if defined(DATA_A_Q4_1)
76
  #define QUANT_K QUANT_K_Q4_1
77
  #define QUANT_R QUANT_R_Q4_1
 
78
  #define A_TYPE block_q4_1
79
  #define A_TYPE_PACKED16 block_q4_1_packed16
 
80
  #endif
81
 
82
  #define QUANT_K_Q5_0 32
@@ -99,6 +107,7 @@ struct block_q5_0_packed16
99
  #if defined(DATA_A_Q5_0)
100
  #define QUANT_K QUANT_K_Q5_0
101
  #define QUANT_R QUANT_R_Q5_0
 
102
  #define A_TYPE block_q5_0
103
  #define A_TYPE_PACKED16 block_q5_0_packed16
104
  #endif
@@ -122,11 +131,20 @@ struct block_q5_1_packed16
122
  uint16_t qs[16/2];
123
  };
124
 
 
 
 
 
 
 
 
125
  #if defined(DATA_A_Q5_1)
126
  #define QUANT_K QUANT_K_Q5_1
127
  #define QUANT_R QUANT_R_Q5_1
 
128
  #define A_TYPE block_q5_1
129
  #define A_TYPE_PACKED16 block_q5_1_packed16
 
130
  #endif
131
 
132
  #define QUANT_K_Q8_0 32
@@ -142,14 +160,40 @@ struct block_q8_0_packed16
142
  float16_t d;
143
  int16_t qs[32/2];
144
  };
 
 
 
 
 
145
 
146
  #if defined(DATA_A_Q8_0)
147
  #define QUANT_K QUANT_K_Q8_0
148
  #define QUANT_R QUANT_R_Q8_0
 
149
  #define A_TYPE block_q8_0
150
  #define A_TYPE_PACKED16 block_q8_0_packed16
 
151
  #endif
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  // K-quants
154
  #define QUANT_K_Q2_K 256
155
 
 
 
1
  #if !defined(GGML_TYPES_COMP)
2
  #define GGML_TYPES_COMP
3
 
 
50
  #if defined(DATA_A_Q4_0)
51
  #define QUANT_K QUANT_K_Q4_0
52
  #define QUANT_R QUANT_R_Q4_0
53
+ #define QUANT_AUXF 1
54
  #define A_TYPE block_q4_0
55
  #define A_TYPE_PACKED16 block_q4_0_packed16
56
  #endif
 
72
  uint16_t qs[16/2];
73
  };
74
 
75
+ struct block_q4_1_packed32
76
+ {
77
+ f16vec2 dm;
78
+ uint32_t qs[16/4];
79
+ };
80
+
81
  #if defined(DATA_A_Q4_1)
82
  #define QUANT_K QUANT_K_Q4_1
83
  #define QUANT_R QUANT_R_Q4_1
84
+ #define QUANT_AUXF 2
85
  #define A_TYPE block_q4_1
86
  #define A_TYPE_PACKED16 block_q4_1_packed16
87
+ #define A_TYPE_PACKED32 block_q4_1_packed32
88
  #endif
89
 
90
  #define QUANT_K_Q5_0 32
 
107
  #if defined(DATA_A_Q5_0)
108
  #define QUANT_K QUANT_K_Q5_0
109
  #define QUANT_R QUANT_R_Q5_0
110
+ #define QUANT_AUXF 1
111
  #define A_TYPE block_q5_0
112
  #define A_TYPE_PACKED16 block_q5_0_packed16
113
  #endif
 
131
  uint16_t qs[16/2];
132
  };
133
 
134
+ struct block_q5_1_packed32
135
+ {
136
+ f16vec2 dm;
137
+ uint qh;
138
+ uint32_t qs[16/4];
139
+ };
140
+
141
  #if defined(DATA_A_Q5_1)
142
  #define QUANT_K QUANT_K_Q5_1
143
  #define QUANT_R QUANT_R_Q5_1
144
+ #define QUANT_AUXF 2
145
  #define A_TYPE block_q5_1
146
  #define A_TYPE_PACKED16 block_q5_1_packed16
147
+ #define A_TYPE_PACKED32 block_q5_1_packed32
148
  #endif
149
 
150
  #define QUANT_K_Q8_0 32
 
160
  float16_t d;
161
  int16_t qs[32/2];
162
  };
163
+ struct block_q8_0_packed32
164
+ {
165
+ float16_t d;
166
+ int32_t qs[32/4];
167
+ };
168
 
169
  #if defined(DATA_A_Q8_0)
170
  #define QUANT_K QUANT_K_Q8_0
171
  #define QUANT_R QUANT_R_Q8_0
172
+ #define QUANT_AUXF 1
173
  #define A_TYPE block_q8_0
174
  #define A_TYPE_PACKED16 block_q8_0_packed16
175
+ #define A_TYPE_PACKED32 block_q8_0_packed32
176
  #endif
177
 
178
+ #define QUANT_K_Q8_1 32
179
+ #define QUANT_R_Q8_1 1
180
+
181
+ struct block_q8_1
182
+ {
183
+ f16vec2 ds;
184
+ int8_t qs[32];
185
+ };
186
+ struct block_q8_1_packed16
187
+ {
188
+ f16vec2 ds;
189
+ int16_t qs[16];
190
+ };
191
+ struct block_q8_1_packed32
192
+ {
193
+ f16vec2 ds;
194
+ int32_t qs[8];
195
+ };
196
+
197
  // K-quants
198
  #define QUANT_K_Q2_K 256
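The block_q8_1 variants added to types.comp above are three views of the same 36-byte block: a two-component f16 ds header (d and d * sum) followed by 32 signed 8-bit quants, repacked as 16-bit or 32-bit words so the MMQ shader can do aligned packed loads. A host-side sanity check of that layout, assuming plain C++ stand-ins for the GLSL types, might look like:

    #include <cstdint>

    // Layout sketch of the q8_1 block written by quantize_q8_1.comp.
    struct block_q8_1_ref {
        uint16_t d;       // f16 delta
        uint16_t s;       // f16 d * sum(qs)
        int8_t   qs[32];  // quants
    };
    static_assert(sizeof(block_q8_1_ref) == 36, "f16vec2 header + 32 int8 quants");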
199
 
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -295,7 +295,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
  std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
  std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
  std::string shader_name = "matmul";
 
  if (matmul_id) {
@@ -313,9 +316,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
  base_dict["COOPMAT"] = "1";
  }
 
- base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
-
- std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
  // Shaders with f16 B_TYPE
  string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
@@ -339,14 +340,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
  // don't generate f32 variants for coopmat2
  if (!coopmat2) {
- string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  }
 
  if (tname != "f16" && tname != "f32") {
- string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  }
  }
  }
 
@@ -458,6 +465,7 @@ void process_shaders() {
  string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
  string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
 
  string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
  std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
  std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
+ std::map<std::string, std::string> base_dict = {
+ {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
+ {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
+ };
  std::string shader_name = "matmul";
 
  if (matmul_id) {
 
  base_dict["COOPMAT"] = "1";
  }
 
+ const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
  // Shaders with f16 B_TYPE
  string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
 
  // don't generate f32 variants for coopmat2
  if (!coopmat2) {
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  }
 
  if (tname != "f16" && tname != "f32") {
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
  }
+
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
+ string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+ }
+ #endif
  }
  }
 
  string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
  string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
+ string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
 
  string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
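
For the mul_mmq.comp pipelines generated above (guarded by GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT), the inner loop boils down to 4-wide int8 dot products over the packed 32-bit qs words, which GL_EXT_integer_dot_product lets the driver lower to DP4A-style hardware instructions. The C++ below is only a scalar model of that accumulation for one q8_0 x q8_1 block pair, written to show the arithmetic rather than the shader code; dp4a and block_dot_q8 are illustrative names, not symbols from the patch.

#include <cstdint>

// Scalar stand-in for a packed 4x int8 dot product
// (what the GPU does in one integer-dot / DP4A instruction).
static int32_t dp4a(int32_t acc, uint32_t a, uint32_t b) {
    for (int i = 0; i < 4; ++i) {
        const int8_t av = (int8_t) ((a >> (8 * i)) & 0xff);
        const int8_t bv = (int8_t) ((b >> (8 * i)) & 0xff);
        acc += (int32_t) av * (int32_t) bv;
    }
    return acc;
}

// One block product: 32 quants per block = 8 packed 32-bit words per side.
// q8_0 x q8_1 only needs d_a * d_b * sumi; the *_1 A-formats add an m_a * s_b
// term on top, which is why block_q8_1 carries the per-block sum in ds.y.
static float block_dot_q8(float d_a, float d_b, const uint32_t * qs_a, const uint32_t * qs_b) {
    int32_t sumi = 0;
    for (int i = 0; i < 8; ++i) {
        sumi = dp4a(sumi, qs_a[i], qs_b[i]);
    }
    return d_a * d_b * (float) sumi;
}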