Commit 06ec111 · Parent: a596e84
Vulkan: Add DP4A MMQ and Q8_1 quantization shader (llama/12135)
* Vulkan: Add DP4A MMQ and Q8_1 quantization shader
* Add q4_0 x q8_1 matrix matrix multiplication support
* Vulkan: Add int8 coopmat MMQ support
* Vulkan: Add q4_1, q5_0 and q5_1 quants, improve integer dot code
* Add GL_EXT_integer_dot_product check
* Remove ggml changes, fix mmq pipeline picker
* Remove ggml changes, restore Intel coopmat behaviour
* Fix glsl compile attempt when integer vec dot is not supported
* Remove redundant code, use non-saturating integer dot, enable all matmul sizes for mmq
* Remove redundant comment
* Fix integer dot check
* Fix compile issue with unsupported int dot glslc
* Update Windows build Vulkan SDK version
- ggml/src/ggml-vulkan/CMakeLists.txt +14 -0
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +440 -80
- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +3 -5
- ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +444 -0
- ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- ggml/src/ggml-vulkan/vulkan-shaders/types.comp +45 -1
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -8
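For background on what the new mul_mmq and quantize_q8_1 shaders compute: a q4_0 x q8_1 matmul accumulates products of packed 8-bit integers (the DP4A-style path) and applies the float scales once per 32-element block. The sketch below is a hypothetical host-side reference of that per-block arithmetic, not code from this commit; the struct names are illustrative and the scales are kept as plain float where ggml stores fp16.

    // Standalone sketch of the block-level math an integer-dot q4_0 x q8_1 kernel relies on.
    #include <cstdint>
    #include <cstdio>

    struct block_q4_0_ref { float d;          uint8_t qs[16]; }; // 32 4-bit weights, scale d
    struct block_q8_1_ref { float d; float s; int8_t  qs[32]; }; // 32 8-bit activations, s = d * sum(qs)

    // dot(x, y) over one 32-element block:
    //   sum_j (q4_j - 8) * d4 * q8_j * d8
    // = d4 * d8 * sum_j q4_j * q8_j  -  8 * d4 * (d8 * sum_j q8_j)
    // = d4 * d8 * sumi               -  8 * d4 * y.s
    // so the inner loop is pure int8 multiply-accumulate (the part DP4A accelerates),
    // followed by one float correction per block.
    static float vec_dot_q4_0_q8_1(const block_q4_0_ref & x, const block_q8_1_ref & y) {
        int32_t sumi = 0;
        for (int j = 0; j < 16; ++j) {
            const int32_t lo = x.qs[j] & 0x0F;   // weights  0..15 live in the low nibbles
            const int32_t hi = x.qs[j] >> 4;     // weights 16..31 live in the high nibbles
            sumi += lo * y.qs[j] + hi * y.qs[j + 16];
        }
        return x.d * y.d * (float) sumi - 8.0f * x.d * y.s;
    }

    int main() {
        block_q4_0_ref x{0.05f, {}};
        block_q8_1_ref y{0.02f, 0.0f, {}};
        int32_t ysum = 0;
        for (int j = 0; j < 16; ++j) x.qs[j] = (uint8_t)(j | ((15 - j) << 4));
        for (int j = 0; j < 32; ++j) { y.qs[j] = (int8_t)(j - 16); ysum += y.qs[j]; }
        y.s = y.d * (float) ysum;

        // float reference: dequantize both blocks and take the ordinary dot product
        float ref = 0.0f;
        for (int j = 0; j < 32; ++j) {
            const int nib = j < 16 ? (x.qs[j] & 0x0F) : (x.qs[j - 16] >> 4);
            ref += (nib - 8) * x.d * y.qs[j] * y.d;
        }
        printf("int-dot form: %f, float reference: %f\n", vec_dot_q4_0_q8_1(x, y), ref);
    }

The q4_1/q5_0/q5_1 variants follow the same decomposition, except that the per-block minimum multiplies the y.s term instead of the constant -8 offset.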
ggml/src/ggml-vulkan/CMakeLists.txt
CHANGED
@@ -69,6 +69,20 @@ if (Vulkan_FOUND)
         add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     endif()
 
+    # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
+        message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
+    else()
+        message(STATUS "GL_EXT_integer_dot_product supported by glslc")
+        add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    endif()
+
     target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
     target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
@@ -234,6 +234,8 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
 
+    bool integer_dot_product;
+
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
     uint32_t subgroup_max_size;

@@ -245,6 +247,12 @@ struct vk_device_struct {
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;
+
+    bool coopmat_int_support;
+    uint32_t coopmat_int_m;
+    uint32_t coopmat_int_n;
+    uint32_t coopmat_int_k;
+
     bool coopmat2;
 
     size_t idx;

@@ -263,10 +271,10 @@ struct vk_device_struct {
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
-    vk_pipeline pipeline_matmul_split_k_reduce;
 
-    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32 {};
     vk_matmul_pipeline2 pipeline_matmul_id_f16;

@@ -274,6 +282,9 @@ struct vk_device_struct {
 
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
 
+    vk_pipeline pipeline_matmul_split_k_reduce;
+    vk_pipeline pipeline_quantize_q8_1;
+
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];

@@ -640,6 +651,13 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_upscale_push_constants {
+    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
+    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
+    float sf0; float sf1; float sf2; float sf3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}

@@ -649,13 +667,6 @@ struct vk_staging_memcpy {
     size_t n;
 };
 
-struct vk_op_upscale_push_constants {
-    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
-    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
-    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
-    float sf0; float sf1; float sf2; float sf3;
-};
-
 struct vk_context_struct {
     vk_submission * s;
     std::vector<vk_sequence> seqs;
@@ -1598,6 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // mulmat
     std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+                          l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
                           l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
                           l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
     std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,

@@ -1662,6 +1674,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
         s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
+        const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
+        const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
+        const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
+        const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
+        const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
+        const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2;
+        const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
+        const uint32_t tk_int_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
+        const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
+
+        l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 };
+        m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 };
+        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 };
+
         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
         m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
         s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
@@ -2000,6 +2026,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
 
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
         // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
         CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \

@@ -2031,6 +2065,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product) {
+            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        }
+#endif
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);

@@ -2056,6 +2100,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
+#undef CREATE_MMQ
 #undef CREATE_MM
     } else {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}

@@ -2073,6 +2118,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
 
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );

@@ -2099,6 +2152,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product) {
+            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        }
+#endif
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);

@@ -2132,7 +2195,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t rm_stdq = 1;
     uint32_t rm_kq = 2;
     if (device->vendor_id == VK_VENDOR_ID_AMD) {
-        if (device->
+        if (device->architecture == AMD_GCN) {
            rm_stdq = 2;
            rm_kq = 4;
        }
@@ -2266,6 +2329,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
         if (device->subgroup_add && device->subgroup_require_full_support) {
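The quantize_q8_1 pipeline registered above runs the new quantize_q8_1.comp shader over the activation matrix before an integer-dot matmul. Below is a hypothetical CPU-side sketch of what it produces for one 32-value block (plain float scales for clarity; ggml packs them as fp16). The stored s = d * sum(q) is what lets the matmul kernels fold per-block offsets into a single float correction.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // One q8_1 block: 32 int8 quants, a scale d, and s = d * sum(quants).
    struct block_q8_1_ref { float d; float s; int8_t qs[32]; };

    static block_q8_1_ref quantize_block_q8_1(const float * x) {
        block_q8_1_ref y{};
        float amax = 0.0f;
        for (int j = 0; j < 32; ++j) amax = std::max(amax, std::fabs(x[j]));
        const float d  = amax / 127.0f;                 // symmetric scale
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        int sum = 0;
        for (int j = 0; j < 32; ++j) {
            const int q = (int) std::lround(x[j] * id); // in [-127, 127]
            y.qs[j] = (int8_t) q;
            sum += q;
        }
        y.d = d;
        y.s = d * (float) sum;                          // precomputed for the offset terms
        return y;
    }

    int main() {
        float x[32];
        for (int j = 0; j < 32; ++j) x[j] = 0.01f * (float)(j - 16);
        const block_q8_1_ref y = quantize_block_q8_1(x);
        printf("d = %f, s = %f, q[0] = %d, q[31] = %d\n", y.d, y.s, y.qs[0], y.qs[31]);
    }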
@@ -2452,6 +2516,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     bool pipeline_robustness = false;
     bool coopmat2_support = false;
     device->coopmat_support = false;
+    device->integer_dot_product = false;
 
     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {

@@ -2477,6 +2542,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
         } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT2")) {
             coopmat2_support = true;
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+                   !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+            device->integer_dot_product = true;
+#endif
         }
     }
 

@@ -2490,6 +2560,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     vk::PhysicalDeviceVulkan11Properties vk11_props;
     vk::PhysicalDeviceVulkan12Properties vk12_props;
     vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
 
     props2.pNext = &props3;
     props3.pNext = &subgroup_props;

@@ -2524,6 +2595,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
 #endif
 
+    if (device->integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+    }
+
     device->physical_device.getProperties2(&props2);
     device->properties = props2.properties;
     device->vendor_id = device->properties.vendorID;

@@ -2570,6 +2646,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
     }
 
+    device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
+
     std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 
     // Try to find a non-graphics compute queue and transfer-focused queues

@@ -2662,6 +2740,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_KHR_maintenance4");
     }
 
+    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+    if (device->integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+        device_extensions.push_back("VK_KHR_shader_integer_dot_product");
+    }
+
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
     device->fp16 = device->fp16 && vk12_features.shaderFloat16;
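The hunks above chain VK_KHR_shader_integer_dot_product into the existing feature and property queries and then gate the MMQ path on integerDotProduct4x8BitPackedSignedAccelerated. A condensed stand-alone sketch of that query, using the plain C Vulkan entry points rather than ggml's shared pNext chains (as in the surrounding code, the extension string should first be found via enumerateDeviceExtensionProperties):

    #include <vulkan/vulkan.h>

    // Returns true if the device enables the integer-dot feature AND reports the
    // signed 4x8-bit packed path as accelerated (the DP4A-style path the new MMQ
    // shaders rely on). Sketch only; ggml-vulkan folds these structs into its
    // existing chained queries instead of issuing separate calls.
    static bool supports_packed_int8_dot(VkPhysicalDevice dev) {
        VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR dot_features = {};
        dot_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;

        VkPhysicalDeviceFeatures2 features2 = {};
        features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        features2.pNext = &dot_features;
        vkGetPhysicalDeviceFeatures2(dev, &features2);

        VkPhysicalDeviceShaderIntegerDotProductPropertiesKHR dot_props = {};
        dot_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES_KHR;

        VkPhysicalDeviceProperties2 props2 = {};
        props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
        props2.pNext = &dot_props;
        vkGetPhysicalDeviceProperties2(dev, &props2);

        return dot_features.shaderIntegerDotProduct == VK_TRUE &&
               dot_props.integerDotProduct4x8BitPackedSignedAccelerated == VK_TRUE;
    }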
@@ -2831,6 +2917,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
                         device->coopmat_acc_f16_support = true;
                     }
                 }
+            } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
+                       (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
+                       (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
+                       (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
+                       (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
+                       device->coopmat_int_m == 0
+                       ) {
+                device->coopmat_int_support = true;
+                device->coopmat_int_m = prop.MSize;
+                device->coopmat_int_n = prop.NSize;
+                device->coopmat_int_k = prop.KSize;
             }
         }
 
@@ -2935,25 +3032,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     vk::PhysicalDevice physical_device = devices[dev_num];
     std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
 
-    vk::PhysicalDeviceProperties2 props2;
-    vk::PhysicalDeviceMaintenance3Properties props3;
-    vk::PhysicalDeviceSubgroupProperties subgroup_props;
-    vk::PhysicalDeviceDriverProperties driver_props;
-    props2.pNext = &props3;
-    props3.pNext = &subgroup_props;
-    subgroup_props.pNext = &driver_props;
-    physical_device.getProperties2(&props2);
-
-    vk_device_architecture arch = get_device_architecture(physical_device);
-    uint32_t default_subgroup_size = get_subgroup_size("", arch);
-    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
-
-    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
     bool fp16_storage = false;
     bool fp16_compute = false;
     bool coopmat_support = false;
     bool coopmat2_support = false;
+    bool integer_dot_product = false;
 
     for (auto properties : ext_props) {
         if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {

@@ -2969,27 +3052,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT2")) {
             coopmat2_support = true;
+#endif
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+                   !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+            integer_dot_product = true;
 #endif
         }
     }
 
     const vk_device_architecture device_architecture = get_device_architecture(physical_device);
 
-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
-        coopmat_support = false;
-    }
-
     const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
     bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
     bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-    vk::
+    vk::PhysicalDeviceProperties2 props2;
+    vk::PhysicalDeviceMaintenance3Properties props3;
+    vk::PhysicalDeviceSubgroupProperties subgroup_props;
+    vk::PhysicalDeviceDriverProperties driver_props;
+    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+    props2.pNext = &props3;
+    props3.pNext = &subgroup_props;
+    subgroup_props.pNext = &driver_props;
+
+    // Pointer to the last chain element
+    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
+
+    if (integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+    }
+
+    physical_device.getProperties2(&props2);
 
     VkPhysicalDeviceFeatures2 device_features2;
     device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
     device_features2.pNext = nullptr;
-    device_features2.features = (VkPhysicalDeviceFeatures)device_features;
 
     VkPhysicalDeviceVulkan11Features vk11_features;
     vk11_features.pNext = nullptr;

@@ -3002,7 +3102,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     vk11_features.pNext = &vk12_features;
 
     // Pointer to the last chain element
-
+    last_struct = (VkBaseOutStructure *)&vk12_features;
 
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;

@@ -3014,20 +3114,37 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
         last_struct = (VkBaseOutStructure *)&coopmat_features;
     }
+#endif
+
+    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+    if (integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+    }
 
     vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
 
     fp16 = fp16 && vk12_features.shaderFloat16;
 
-
-
+    uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+    integer_dot_product = integer_dot_product
+        && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
+        && shader_integer_dot_product_features.shaderIntegerDotProduct;
+
+    coopmat_support = coopmat_support
+        && coopmat_features.cooperativeMatrix
+        && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
 
     std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
 
     std::string device_name = props2.properties.deviceName.data();
-    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
                    idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
-                   props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
 
     if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
         GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -3293,6 +3410,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ...
         }
     }
 
     if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
         return nullptr;
     }

@@ -3585,8 +3713,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
     return s;
 }
 
-
-
 static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
     const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
     const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);

@@ -4016,8 +4142,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
     return split_k;
 }
 
-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp,
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
         // Use large shader when the N dimension is greater than the medium shader's tile size

@@ -4042,9 +4168,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp,
     return aligned ? mmp->a_l : mmp->l;
 }
 
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
-    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
 }
 
 static void ggml_vk_matmul(

@@ -4054,7 +4180,7 @@ static void ggml_vk_matmul(
     uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
     uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
     uint32_t padded_n) {
-    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
         const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };

@@ -4072,7 +4198,7 @@ static void ggml_vk_matmul(
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
     }
 
-static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {

@@ -4214,6 +4340,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
 }
 
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];

@@ -4265,10 +4410,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat

@@ -4278,13 +4432,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
-    const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
-    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
     const int x_ne = ne01 * ne00;
     const int y_ne = padded_n * ne10;
     const int d_ne = ne11 * ne01;

@@ -4294,11 +4448,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
 
     if (x_non_contig) {
         to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);

@@ -4313,6 +4468,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
 
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;

@@ -4326,7 +4485,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
         if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
             ctx->prealloc_size_x = x_sz_upd;
         }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
             ctx->prealloc_size_y = y_sz_upd;
         }
         if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {

@@ -4341,6 +4500,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
         if (qy_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
         }
         if (split_k > 1) {
             ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
         }

@@ -4376,6 +4538,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
         GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;

@@ -4392,6 +4557,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
     if (y_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
 
     uint32_t stride_batch_x = ne00*ne01;
     uint32_t stride_batch_y = ne10*ne11;

@@ -4400,7 +4568,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, ...
         stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
     }
 
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
@@ -6929,6 +7097,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, ...
         }
     }
 
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
     vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);

@@ -7177,6 +7349,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
 
     ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
 
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@@ -7236,66 +7412,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
     free(x_chk);
 }
 
-
     VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
     const size_t x_ne = m * k * batch;
     const size_t y_ne = k * n * batch;
     const size_t d_ne = m * n * batch;
 
     vk_pipeline p;
     std::string shname;
     if (shader_size == 0) {
-        p =
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
     } else if (shader_size == 1) {
-        p =
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
     } else if (shader_size == 2) {
-        p =
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
     } else {
         GGML_ASSERT(0);
     }
 
-    const size_t kpad = ggml_vk_align_size(k, p->align);
 
-    if (k != kpad) {
         if (shader_size == 0) {
-            p =
             shname = std::string(ggml_type_name(quant)) + "_S";
         } else if (shader_size == 1) {
-            p =
             shname = std::string(ggml_type_name(quant)) + "_M";
         } else if (shader_size == 2) {
-            p =
             shname = std::string(ggml_type_name(quant)) + "_L";
         } else {
             GGML_ASSERT(0);
         }
     }
 
     const size_t x_sz = sizeof(float) * x_ne;
     const size_t y_sz = sizeof(float) * y_ne;
     const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
     const size_t d_sz = sizeof(float) * d_ne;
     float * x = (float *) malloc(x_sz);
     float * y = (float *) malloc(y_sz);
     void * qx = malloc(qx_sz);
     vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     float * d = (float *) malloc(d_sz);
     float * d_chk = (float *) malloc(d_sz);
 
     for (size_t i = 0; i < x_ne; i++) {
         x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
     }
 
     ggml_vk_quantize_data(x, qx, x_ne, quant);
 
     for (size_t i = 0; i < y_ne; i++) {
-
-        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
     }
 
     ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);

@@ -7310,6 +7618,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ...
         ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
         }
     }
 
     ggml_pipeline_allocate_descriptor_sets(ctx->device);

@@ -7318,13 +7633,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ...
 
     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx->device, subctx);
-
-
-        ctx, subctx,
-
-
-
-
     }
     ggml_vk_ctx_end(subctx);

@@ -7382,7 +7709,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ...
 
     double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
 
-    std::cerr << "TEST
 
     if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;

@@ -7392,6 +7723,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ...
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
 
         if (split_k > 1) {
             float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
             ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);

@@ -7414,6 +7751,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ...
 
     ggml_vk_destroy_buffer(qx_buf);
     ggml_vk_destroy_buffer(y_buf);
     ggml_vk_destroy_buffer(d_buf);
 
     free(x);

@@ -7446,7 +7784,25 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         128, 49, 49,
         4096, 49, 4096,
     };
-    const size_t num_it =
 
     for (size_t i = 0; i < vals.size(); i += 3) {
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
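The test hunks above compare the GPU result against a CPU reference and fail when the average error exceeds 0.01. A trivial stand-alone sketch of that kind of check (the names here are illustrative, not the helpers used by the test harness):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Mean and max absolute error between a GPU result and a CPU reference,
    // mirroring the avg_err threshold check used by the Vulkan test harness.
    struct err_stats { double avg; double max; };

    static err_stats compare_results(const std::vector<float> & out, const std::vector<float> & ref) {
        double sum = 0.0, max = 0.0;
        for (size_t i = 0; i < out.size(); ++i) {
            const double e = std::fabs((double) out[i] - (double) ref[i]);
            sum += e;
            if (e > max) max = e;
        }
        return { out.empty() ? 0.0 : sum / (double) out.size(), max };
    }

    int main() {
        const std::vector<float> ref = { 1.0f, 2.0f, 3.0f };
        const std::vector<float> out = { 1.01f, 1.98f, 3.0f };
        const err_stats e = compare_results(out, ref);
        printf("avg err: %f, max err: %f\n", e.avg, e.max);  // flag failure if avg > 0.01
    }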
@@ -9258,7 +9614,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     }
 
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
-        const float *params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);

@@ -9275,7 +9631,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_UPSCALE) {
         tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_SCALE) {
-
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {

@@ -9283,7 +9640,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_COS) {
         tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
-
     } else if (tensor->op == GGML_OP_PAD) {
         tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
     } else if (tensor->op == GGML_OP_REPEAT) {

@@ -9297,7 +9655,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_NORM) {
         tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_GROUP_NORM) {
-
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {

@@ -9310,14 +9669,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
-
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
     } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
         tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
-        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0],
     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims = ((int32_t *) tensor->op_params)[1];
         const int mode = ((int32_t *) tensor->op_params)[2];
| 234 |
bool float_controls_rte_fp16;
|
| 235 |
bool subgroup_add;
|
| 236 |
|
| 237 |
+
bool integer_dot_product;
|
| 238 |
+
|
| 239 |
bool subgroup_size_control;
|
| 240 |
uint32_t subgroup_min_size;
|
| 241 |
uint32_t subgroup_max_size;
|
|
|
|
| 247 |
uint32_t coopmat_m;
|
| 248 |
uint32_t coopmat_n;
|
| 249 |
uint32_t coopmat_k;
|
| 250 |
+
|
| 251 |
+
bool coopmat_int_support;
|
| 252 |
+
uint32_t coopmat_int_m;
|
| 253 |
+
uint32_t coopmat_int_n;
|
| 254 |
+
uint32_t coopmat_int_k;
|
| 255 |
+
|
| 256 |
bool coopmat2;
|
| 257 |
|
| 258 |
size_t idx;
|
|
|
|
| 271 |
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
| 272 |
vk_matmul_pipeline2 pipeline_matmul_f16;
|
| 273 |
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
|
|
|
| 274 |
|
|
|
|
| 275 |
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
| 276 |
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
| 277 |
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
| 278 |
|
| 279 |
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
| 280 |
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
|
|
|
| 282 |
|
| 283 |
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
| 284 |
|
| 285 |
+
vk_pipeline pipeline_matmul_split_k_reduce;
|
| 286 |
+
vk_pipeline pipeline_quantize_q8_1;
|
| 287 |
+
|
| 288 |
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
| 289 |
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
| 290 |
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
|
|
|
| 651 |
uint32_t H;
|
| 652 |
};
|
| 653 |
|
| 654 |
+
struct vk_op_upscale_push_constants {
|
| 655 |
+
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
| 656 |
+
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
| 657 |
+
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
| 658 |
+
float sf0; float sf1; float sf2; float sf3;
|
| 659 |
+
};
|
| 660 |
+
|
| 661 |
// Allow pre-recording command buffers
|
| 662 |
struct vk_staging_memcpy {
|
| 663 |
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
|
|
|
| 667 |
size_t n;
|
| 668 |
};
|
| 669 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 670 |
struct vk_context_struct {
|
| 671 |
vk_submission * s;
|
| 672 |
std::vector<vk_sequence> seqs;
|
|
|
|
| 1609 |
// mulmat
|
| 1610 |
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
|
| 1611 |
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
| 1612 |
+
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
|
| 1613 |
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
| 1614 |
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
|
| 1615 |
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
|
|
|
| 1674 |
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
| 1675 |
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
| 1676 |
|
| 1677 |
+
const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
|
| 1678 |
+
const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
|
| 1679 |
+
const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
|
| 1680 |
+
const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
|
| 1681 |
+
const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
|
| 1682 |
+
const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2;
|
| 1683 |
+
const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
|
| 1684 |
+
const uint32_t tk_int_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
|
| 1685 |
+
const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
|
| 1686 |
+
|
| 1687 |
+
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 };
|
| 1688 |
+
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 };
|
| 1689 |
+
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 };
|
| 1690 |
+
|
| 1691 |
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
| 1692 |
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
| 1693 |
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
|
|
|
| 2026 |
if (device->mul_mat ## ID ## _s[TYPE]) \
|
| 2027 |
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
| 2028 |
|
| 2029 |
+
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 2030 |
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
| 2031 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
| 2032 |
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
| 2033 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
| 2034 |
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
| 2035 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
| 2036 |
+
|
| 2037 |
// Create 2 variants, {f16,f32} accumulator
|
| 2038 |
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 2039 |
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
|
|
| 2065 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2066 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2067 |
|
| 2068 |
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 2069 |
+
if (device->integer_dot_product) {
|
| 2070 |
+
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2071 |
+
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2072 |
+
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2073 |
+
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2074 |
+
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2075 |
+
}
|
| 2076 |
+
#endif
|
| 2077 |
+
|
| 2078 |
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2079 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2080 |
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
|
|
| 2100 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2101 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 2102 |
#undef CREATE_MM2
|
| 2103 |
+
#undef CREATE_MMQ
|
| 2104 |
#undef CREATE_MM
|
| 2105 |
} else {
|
| 2106 |
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
|
|
| 2118 |
if (device->mul_mat ## ID ## _s[TYPE]) \
|
| 2119 |
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
| 2120 |
|
| 2121 |
+
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 2122 |
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
| 2123 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
| 2124 |
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
| 2125 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
| 2126 |
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
| 2127 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
| 2128 |
+
|
| 2129 |
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2130 |
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 2131 |
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
|
|
| 2152 |
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2153 |
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 2154 |
|
| 2155 |
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 2156 |
+
if (device->integer_dot_product) {
|
| 2157 |
+
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2158 |
+
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2159 |
+
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2160 |
+
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2161 |
+
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
| 2162 |
+
}
|
| 2163 |
+
#endif
|
| 2164 |
+
|
| 2165 |
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2166 |
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 2167 |
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
|
|
| 2195 |
uint32_t rm_stdq = 1;
|
| 2196 |
uint32_t rm_kq = 2;
|
| 2197 |
if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
| 2198 |
+
if (device->architecture == AMD_GCN) {
|
| 2199 |
rm_stdq = 2;
|
| 2200 |
rm_kq = 4;
|
| 2201 |
}
|
|
|
|
| 2329 |
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
| 2330 |
|
| 2331 |
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
| 2332 |
+
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
| 2333 |
|
| 2334 |
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
| 2335 |
if (device->subgroup_add && device->subgroup_require_full_support) {
|
|
|
|
| 2516 |
bool pipeline_robustness = false;
|
| 2517 |
bool coopmat2_support = false;
|
| 2518 |
device->coopmat_support = false;
|
| 2519 |
+
device->integer_dot_product = false;
|
| 2520 |
|
| 2521 |
for (const auto& properties : ext_props) {
|
| 2522 |
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
|
|
|
| 2542 |
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
| 2543 |
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
| 2544 |
coopmat2_support = true;
|
| 2545 |
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 2546 |
+
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
| 2547 |
+
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
| 2548 |
+
device->integer_dot_product = true;
|
| 2549 |
+
#endif
|
| 2550 |
}
|
| 2551 |
}
|
| 2552 |
|
|
|
|
| 2560 |
vk::PhysicalDeviceVulkan11Properties vk11_props;
|
| 2561 |
vk::PhysicalDeviceVulkan12Properties vk12_props;
|
| 2562 |
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
| 2563 |
+
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
|
| 2564 |
|
| 2565 |
props2.pNext = &props3;
|
| 2566 |
props3.pNext = &subgroup_props;
|
|
|
|
| 2595 |
}
|
| 2596 |
#endif
|
| 2597 |
|
| 2598 |
+
if (device->integer_dot_product) {
|
| 2599 |
+
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
| 2600 |
+
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
| 2601 |
+
}
|
| 2602 |
+
|
| 2603 |
device->physical_device.getProperties2(&props2);
|
| 2604 |
device->properties = props2.properties;
|
| 2605 |
device->vendor_id = device->properties.vendorID;
|
|
|
|
| 2646 |
device->coopmat_support = false;
|
| 2647 |
}
|
| 2648 |
|
| 2649 |
+
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
| 2650 |
+
|
| 2651 |
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
| 2652 |
|
| 2653 |
// Try to find a non-graphics compute queue and transfer-focused queues
|
|
|
|
| 2740 |
device_extensions.push_back("VK_KHR_maintenance4");
|
| 2741 |
}
|
| 2742 |
|
| 2743 |
+
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
|
| 2744 |
+
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
|
| 2745 |
+
if (device->integer_dot_product) {
|
| 2746 |
+
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
| 2747 |
+
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
| 2748 |
+
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
|
| 2749 |
+
}
|
| 2750 |
+
|
| 2751 |
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
| 2752 |
|
| 2753 |
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
|
|
|
|
| 2917 |
device->coopmat_acc_f16_support = true;
|
| 2918 |
}
|
| 2919 |
}
|
| 2920 |
+
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
| 2921 |
+
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
| 2922 |
+
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
|
| 2923 |
+
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
|
| 2924 |
+
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
|
| 2925 |
+
device->coopmat_int_m == 0
|
| 2926 |
+
) {
|
| 2927 |
+
device->coopmat_int_support = true;
|
| 2928 |
+
device->coopmat_int_m = prop.MSize;
|
| 2929 |
+
device->coopmat_int_n = prop.NSize;
|
| 2930 |
+
device->coopmat_int_k = prop.KSize;
|
| 2931 |
}
|
| 2932 |
}
|
| 2933 |
|
|
|
|
| 3032 |
vk::PhysicalDevice physical_device = devices[dev_num];
|
| 3033 |
std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
|
| 3034 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3035 |
bool fp16_storage = false;
|
| 3036 |
bool fp16_compute = false;
|
| 3037 |
bool coopmat_support = false;
|
| 3038 |
bool coopmat2_support = false;
|
| 3039 |
+
bool integer_dot_product = false;
|
| 3040 |
|
| 3041 |
for (auto properties : ext_props) {
|
| 3042 |
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
|
|
|
| 3052 |
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
| 3053 |
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
| 3054 |
coopmat2_support = true;
|
| 3055 |
+
#endif
|
| 3056 |
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 3057 |
+
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
| 3058 |
+
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
| 3059 |
+
integer_dot_product = true;
|
| 3060 |
#endif
|
| 3061 |
}
|
| 3062 |
}
|
| 3063 |
|
| 3064 |
const vk_device_architecture device_architecture = get_device_architecture(physical_device);
|
| 3065 |
|
|
|
|
|
|
|
|
|
|
|
|
|
        const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
        bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;

        bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;

+       vk::PhysicalDeviceProperties2 props2;
+       vk::PhysicalDeviceMaintenance3Properties props3;
+       vk::PhysicalDeviceSubgroupProperties subgroup_props;
+       vk::PhysicalDeviceDriverProperties driver_props;
+       vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+       props2.pNext = &props3;
+       props3.pNext = &subgroup_props;
+       subgroup_props.pNext = &driver_props;
+
+       // Pointer to the last chain element
+       VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
+
+       if (integer_dot_product) {
+           last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+           last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+       }
+
+       physical_device.getProperties2(&props2);

        VkPhysicalDeviceFeatures2 device_features2;
        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        device_features2.pNext = nullptr;

        VkPhysicalDeviceVulkan11Features vk11_features;
        vk11_features.pNext = nullptr;
...
        vk11_features.pNext = &vk12_features;

        // Pointer to the last chain element
+       last_struct = (VkBaseOutStructure *)&vk12_features;

#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
        VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
...
            last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
            last_struct = (VkBaseOutStructure *)&coopmat_features;
        }
+#endif
+
+       VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+       shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+       if (integer_dot_product) {
+           last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+           last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+       }

        vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);

        fp16 = fp16 && vk12_features.shaderFloat16;

+       uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
+       const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+       const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+       integer_dot_product = integer_dot_product
+           && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
+           && shader_integer_dot_product_features.shaderIntegerDotProduct;
+
+       coopmat_support = coopmat_support
+           && coopmat_features.cooperativeMatrix
+           && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);

        std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";

        std::string device_name = props2.properties.deviceName.data();
+       GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
                       idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
+                      props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());

        if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
            GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
...
        }
    }

+   // MMQ
+   if (src1_type == GGML_TYPE_Q8_1) {
+       vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
+
+       if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+           return nullptr;
+       }
+
+       return pipelines;
+   }
+
    if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
        return nullptr;
    }
| 3713 |
return s;
|
| 3714 |
}
|
| 3715 |
|
|
|
|
|
|
|
| 3716 |
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
|
| 3717 |
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
|
| 3718 |
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
|
|
|
|
| 4142 |
return split_k;
|
| 4143 |
}
|
| 4144 |
|
| 4145 |
+
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
|
| 4146 |
+
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
| 4147 |
|
| 4148 |
if (ctx->device->coopmat2) {
|
| 4149 |
// Use large shader when the N dimension is greater than the medium shader's tile size
|
|
|
|
| 4168 |
return aligned ? mmp->a_l : mmp->l;
|
| 4169 |
}
|
| 4170 |
|
| 4171 |
+
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
|
| 4172 |
+
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
| 4173 |
+
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
|
| 4174 |
}
|
| 4175 |
|
| 4176 |
static void ggml_vk_matmul(
|
|
|
|
| 4180 |
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
|
| 4181 |
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
|
| 4182 |
uint32_t padded_n) {
|
| 4183 |
+
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
|
| 4184 |
ggml_vk_sync_buffers(subctx);
|
| 4185 |
if (split_k == 1) {
|
| 4186 |
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
|
|
| 4198 |
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
| 4199 |
}
|
| 4200 |
|
| 4201 |
+
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
|
| 4202 |
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
| 4203 |
|
| 4204 |
if (ctx->device->coopmat2) {
|
|
|
|
...
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
}

+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+    switch(type) {
+        case GGML_TYPE_Q8_1:
+            return ctx->device->pipeline_quantize_q8_1;
+        default:
+            std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
+            GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+    VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
+
+    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+
+    ggml_vk_sync_buffers(subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
+}
+
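As a side note on the Q8_1 path wired up above: here is a minimal CPU-side sketch of what a Q8_1 quantization step produces per 32-element block, assuming the standard ggml Q8_1 scheme (d = max|x| / 127, qs = round(x / d), s = d * sum(qs); the real block stores d and s as fp16, see the commented-out block_q8_1 struct further down in this diff). The names below are illustrative only and not part of the patch.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative Q8_1 block (hypothetical reference type, floats instead of fp16):
// a delta d, the precomputed value s = d * sum(qs), and 32 signed 8-bit quants.
struct q8_1_block_ref {
    float  d;
    float  s;
    int8_t qs[32];
};

// Quantize one block of 32 floats: d maps the largest magnitude to +/-127.
static q8_1_block_ref quantize_q8_1_ref(const float * x) {
    q8_1_block_ref b{};
    float amax = 0.0f;
    for (int i = 0; i < 32; i++) {
        amax = std::max(amax, std::fabs(x[i]));
    }
    const float d  = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    int sum = 0;
    for (int i = 0; i < 32; i++) {
        const int q = (int) std::lround(x[i] * id);
        b.qs[i] = (int8_t) q;
        sum += q;
    }
    b.d = d;
    b.s = d * (float) sum;   // precomputed so integer-dot kernels can fold out quant offsets cheaply
    return b;
}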
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
| 4363 |
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
| 4364 |
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
|
|
| 4410 |
|
| 4411 |
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
| 4412 |
|
| 4413 |
+
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
| 4414 |
+
|
| 4415 |
+
// Check for mmq first
|
| 4416 |
+
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
|
| 4417 |
+
|
| 4418 |
+
if (mmp == nullptr) {
|
| 4419 |
+
// Fall back to f16 dequant mul mat
|
| 4420 |
+
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
| 4421 |
+
quantize_y = false;
|
| 4422 |
+
}
|
| 4423 |
|
| 4424 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 4425 |
+
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
|
| 4426 |
|
| 4427 |
if (qx_needs_dequant) {
|
| 4428 |
// Fall back to dequant + f16 mulmat
|
|
|
|
| 4432 |
// Not implemented
|
| 4433 |
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
| 4434 |
|
| 4435 |
+
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
|
| 4436 |
+
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
| 4437 |
|
| 4438 |
+
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
|
| 4439 |
|
| 4440 |
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
| 4441 |
+
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
|
| 4442 |
const int x_ne = ne01 * ne00;
|
| 4443 |
const int y_ne = padded_n * ne10;
|
| 4444 |
const int d_ne = ne11 * ne01;
|
|
|
|
| 4448 |
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
| 4449 |
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
| 4450 |
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
| 4451 |
+
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
|
| 4452 |
const uint64_t d_sz = sizeof(float) * d_ne;
|
| 4453 |
|
| 4454 |
vk_pipeline to_fp16_vk_0 = nullptr;
|
| 4455 |
vk_pipeline to_fp16_vk_1 = nullptr;
|
| 4456 |
+
vk_pipeline to_q8_1 = nullptr;
|
| 4457 |
|
| 4458 |
if (x_non_contig) {
|
| 4459 |
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
|
|
|
|
| 4468 |
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
| 4469 |
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
| 4470 |
|
| 4471 |
+
if (quantize_y) {
|
| 4472 |
+
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
| 4473 |
+
}
|
| 4474 |
+
|
| 4475 |
if (dryrun) {
|
| 4476 |
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
| 4477 |
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
|
|
|
| 4485 |
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
| 4486 |
ctx->prealloc_size_x = x_sz_upd;
|
| 4487 |
}
|
| 4488 |
+
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
|
| 4489 |
ctx->prealloc_size_y = y_sz_upd;
|
| 4490 |
}
|
| 4491 |
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
|
|
|
|
| 4500 |
if (qy_needs_dequant) {
|
| 4501 |
ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
|
| 4502 |
}
|
| 4503 |
+
if (quantize_y) {
|
| 4504 |
+
ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
|
| 4505 |
+
}
|
| 4506 |
if (split_k > 1) {
|
| 4507 |
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
|
| 4508 |
}
|
|
|
|
| 4538 |
if (qy_needs_dequant) {
|
| 4539 |
d_Y = ctx->prealloc_y;
|
| 4540 |
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
|
| 4541 |
+
} else if (quantize_y) {
|
| 4542 |
+
d_Y = ctx->prealloc_y;
|
| 4543 |
+
GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
|
| 4544 |
} else {
|
| 4545 |
d_Y = d_Qy;
|
| 4546 |
y_buf_offset = qy_buf_offset;
|
|
|
|
| 4557 |
if (y_non_contig) {
|
| 4558 |
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
| 4559 |
}
|
| 4560 |
+
if (quantize_y) {
|
| 4561 |
+
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
|
| 4562 |
+
}
|
| 4563 |
|
| 4564 |
uint32_t stride_batch_x = ne00*ne01;
|
| 4565 |
uint32_t stride_batch_y = ne10*ne11;
|
|
|
|
| 4568 |
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
| 4569 |
}
|
| 4570 |
|
| 4571 |
+
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
|
| 4572 |
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
| 4573 |
}
|
| 4574 |
|
|
|
|
| 7097 |
}
|
| 7098 |
}
|
| 7099 |
|
| 7100 |
+
if (ctx->device->need_compiles) {
|
| 7101 |
+
ggml_vk_load_shaders(ctx->device);
|
| 7102 |
+
}
|
| 7103 |
+
|
| 7104 |
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
| 7105 |
|
| 7106 |
vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
|
|
| 7349 |
|
| 7350 |
ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
|
| 7351 |
|
| 7352 |
+
if (ctx->device->need_compiles) {
|
| 7353 |
+
ggml_vk_load_shaders(ctx->device);
|
| 7354 |
+
}
|
| 7355 |
+
|
| 7356 |
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
| 7357 |
|
| 7358 |
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
|
|
|
|
| 7412 |
free(x_chk);
|
| 7413 |
}
|
| 7414 |
|
| 7415 |
+
// This does not work without ggml q8_1 quantization support
|
| 7416 |
+
//
|
| 7417 |
+
// typedef uint16_t ggml_half;
|
| 7418 |
+
// typedef uint32_t ggml_half2;
|
| 7419 |
+
//
|
| 7420 |
+
// #define QK8_1 32
|
| 7421 |
+
// typedef struct {
|
| 7422 |
+
// union {
|
| 7423 |
+
// struct {
|
| 7424 |
+
// ggml_half d; // delta
|
| 7425 |
+
// ggml_half s; // d * sum(qs[i])
|
| 7426 |
+
// } GGML_COMMON_AGGR_S;
|
| 7427 |
+
// ggml_half2 ds;
|
| 7428 |
+
// } GGML_COMMON_AGGR_U;
|
| 7429 |
+
// int8_t qs[QK8_1]; // quants
|
| 7430 |
+
// } block_q8_1;
|
| 7431 |
+
//
|
| 7432 |
+
// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
|
| 7433 |
+
// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
|
| 7434 |
+
// GGML_ASSERT(quant == GGML_TYPE_Q8_1);
|
| 7435 |
+
//
|
| 7436 |
+
// const size_t x_sz = sizeof(float) * ne;
|
| 7437 |
+
// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
|
| 7438 |
+
// float * x = (float *) malloc(x_sz);
|
| 7439 |
+
// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
|
| 7440 |
+
// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
|
| 7441 |
+
// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
| 7442 |
+
// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
| 7443 |
+
//
|
| 7444 |
+
// for (size_t i = 0; i < ne; i++) {
|
| 7445 |
+
// x[i] = rand() / (float)RAND_MAX;
|
| 7446 |
+
// }
|
| 7447 |
+
//
|
| 7448 |
+
// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
|
| 7449 |
+
//
|
| 7450 |
+
// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
|
| 7451 |
+
//
|
| 7452 |
+
// if (ctx->device->need_compiles) {
|
| 7453 |
+
// ggml_vk_load_shaders(ctx->device);
|
| 7454 |
+
// }
|
| 7455 |
+
//
|
| 7456 |
+
// ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
| 7457 |
+
//
|
| 7458 |
+
// ggml_vk_buffer_write(x_buf, 0, x, x_sz);
|
| 7459 |
+
//
|
| 7460 |
+
// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
| 7461 |
+
// ggml_vk_ctx_begin(ctx->device, subctx);
|
| 7462 |
+
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
|
| 7463 |
+
// ggml_vk_ctx_end(subctx);
|
| 7464 |
+
//
|
| 7465 |
+
// auto begin = std::chrono::high_resolution_clock::now();
|
| 7466 |
+
//
|
| 7467 |
+
// ggml_vk_submit(subctx, ctx->fence);
|
| 7468 |
+
// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
|
| 7469 |
+
// ctx->device->device.resetFences({ ctx->fence });
|
| 7470 |
+
//
|
| 7471 |
+
// auto end = std::chrono::high_resolution_clock::now();
|
| 7472 |
+
//
|
| 7473 |
+
// double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
|
| 7474 |
+
// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
|
| 7475 |
+
//
|
| 7476 |
+
// ggml_vk_quantize_data(x, qx_res, ne, quant);
|
| 7477 |
+
//
|
| 7478 |
+
// int first_err = -1;
|
| 7479 |
+
//
|
| 7480 |
+
// for (size_t i = 0; i < ne / 32; i++) {
|
| 7481 |
+
// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
|
| 7482 |
+
//
|
| 7483 |
+
// if (first_err < 0 && error > 0.1) {
|
| 7484 |
+
// first_err = i;
|
| 7485 |
+
// }
|
| 7486 |
+
//
|
| 7487 |
+
// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
|
| 7488 |
+
//
|
| 7489 |
+
// if (first_err < 0 && error > 0.1) {
|
| 7490 |
+
// first_err = i;
|
| 7491 |
+
// }
|
| 7492 |
+
//
|
| 7493 |
+
// for (size_t j = 0; j < 32; j++) {
|
| 7494 |
+
// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
|
| 7495 |
+
//
|
| 7496 |
+
// if (first_err < 0 && error > 1) {
|
| 7497 |
+
// first_err = i;
|
| 7498 |
+
// }
|
| 7499 |
+
// }
|
| 7500 |
+
// }
|
| 7501 |
+
//
|
| 7502 |
+
// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
|
| 7503 |
+
//
|
| 7504 |
+
// if (first_err != -1) {
|
| 7505 |
+
// std::cerr << "first_error = " << first_err << std::endl;
|
| 7506 |
+
// std::cerr << "Actual result: " << std::endl << std::endl;
|
| 7507 |
+
// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
|
| 7508 |
+
// for (size_t j = 0; j < 32; j++) {
|
| 7509 |
+
// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
|
| 7510 |
+
// }
|
| 7511 |
+
// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
|
| 7512 |
+
// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
|
| 7513 |
+
// for (size_t j = 0; j < 32; j++) {
|
| 7514 |
+
// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
|
| 7515 |
+
// }
|
| 7516 |
+
// std::cerr << std::endl;
|
| 7517 |
+
// }
|
| 7518 |
+
//
|
| 7519 |
+
// ggml_vk_destroy_buffer(x_buf);
|
| 7520 |
+
// ggml_vk_destroy_buffer(qx_buf);
|
| 7521 |
+
//
|
| 7522 |
+
// free(x);
|
| 7523 |
+
// free(qx);
|
| 7524 |
+
// free(qx_res);
|
| 7525 |
+
// }
|
| 7526 |
+
|
| 7527 |
+
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
|
| 7528 |
VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
|
| 7529 |
const size_t x_ne = m * k * batch;
|
| 7530 |
const size_t y_ne = k * n * batch;
|
| 7531 |
const size_t d_ne = m * n * batch;
|
| 7532 |
|
| 7533 |
+
vk_matmul_pipeline2 * pipelines;
|
| 7534 |
+
|
| 7535 |
+
if (mmq) {
|
| 7536 |
+
pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
|
| 7537 |
+
} else {
|
| 7538 |
+
pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
|
| 7539 |
+
}
|
| 7540 |
+
|
| 7541 |
+
const bool fp16acc = ctx->device->fp16;
|
| 7542 |
+
|
| 7543 |
vk_pipeline p;
|
| 7544 |
std::string shname;
|
| 7545 |
if (shader_size == 0) {
|
| 7546 |
+
p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
|
| 7547 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
|
| 7548 |
} else if (shader_size == 1) {
|
| 7549 |
+
p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
|
| 7550 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
|
| 7551 |
} else if (shader_size == 2) {
|
| 7552 |
+
p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
|
| 7553 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
|
| 7554 |
} else {
|
| 7555 |
GGML_ASSERT(0);
|
| 7556 |
}
|
| 7557 |
|
| 7558 |
+
const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
|
| 7559 |
|
| 7560 |
+
if (mmq || k != kpad) {
|
| 7561 |
if (shader_size == 0) {
|
| 7562 |
+
p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
|
| 7563 |
shname = std::string(ggml_type_name(quant)) + "_S";
|
| 7564 |
} else if (shader_size == 1) {
|
| 7565 |
+
p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
|
| 7566 |
shname = std::string(ggml_type_name(quant)) + "_M";
|
| 7567 |
} else if (shader_size == 2) {
|
| 7568 |
+
p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
|
| 7569 |
shname = std::string(ggml_type_name(quant)) + "_L";
|
| 7570 |
} else {
|
| 7571 |
GGML_ASSERT(0);
|
| 7572 |
}
|
| 7573 |
}
|
| 7574 |
|
| 7575 |
+
if (p == nullptr) {
|
| 7576 |
+
std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
|
| 7577 |
+
return;
|
| 7578 |
+
}
|
| 7579 |
+
|
| 7580 |
const size_t x_sz = sizeof(float) * x_ne;
|
| 7581 |
const size_t y_sz = sizeof(float) * y_ne;
|
| 7582 |
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
|
| 7583 |
+
const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
|
| 7584 |
const size_t d_sz = sizeof(float) * d_ne;
|
| 7585 |
float * x = (float *) malloc(x_sz);
|
| 7586 |
float * y = (float *) malloc(y_sz);
|
| 7587 |
void * qx = malloc(qx_sz);
|
| 7588 |
vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
| 7589 |
vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
| 7590 |
+
vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
| 7591 |
vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
| 7592 |
float * d = (float *) malloc(d_sz);
|
| 7593 |
float * d_chk = (float *) malloc(d_sz);
|
| 7594 |
|
| 7595 |
for (size_t i = 0; i < x_ne; i++) {
|
| 7596 |
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
| 7597 |
+
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
| 7598 |
+
// x[i] = i % k;
|
| 7599 |
}
|
| 7600 |
|
| 7601 |
ggml_vk_quantize_data(x, qx, x_ne, quant);
|
| 7602 |
|
| 7603 |
for (size_t i = 0; i < y_ne; i++) {
|
| 7604 |
+
y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
| 7605 |
+
// y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
| 7606 |
+
// y[i] = i % k;
|
| 7607 |
}
|
| 7608 |
|
| 7609 |
ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
|
|
|
|
| 7618 |
ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
| 7619 |
}
|
| 7620 |
}
|
| 7621 |
+
if (mmq) {
|
| 7622 |
+
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
|
| 7623 |
+
}
|
| 7624 |
+
|
| 7625 |
+
if (ctx->device->need_compiles) {
|
| 7626 |
+
ggml_vk_load_shaders(ctx->device);
|
| 7627 |
+
}
|
| 7628 |
|
| 7629 |
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
| 7630 |
|
|
|
|
| 7633 |
|
| 7634 |
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
| 7635 |
ggml_vk_ctx_begin(ctx->device, subctx);
|
| 7636 |
+
if (mmq) {
|
| 7637 |
+
for (size_t i = 0; i < num_it; i++) {
|
| 7638 |
+
ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
|
| 7639 |
+
ggml_vk_matmul(
|
| 7640 |
+
ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
|
| 7641 |
+
m, n, k,
|
| 7642 |
+
k, k, m, k*m, k*n, m*n,
|
| 7643 |
+
split_k, batch, batch, batch, 1, 1, n
|
| 7644 |
+
);
|
| 7645 |
+
}
|
| 7646 |
+
} else {
|
| 7647 |
+
for (size_t i = 0; i < num_it; i++) {
|
| 7648 |
+
ggml_vk_matmul(
|
| 7649 |
+
ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
|
| 7650 |
+
m, n, k,
|
| 7651 |
+
k, k, m, k*m, k*n, m*n,
|
| 7652 |
+
split_k, batch, batch, batch, 1, 1, n
|
| 7653 |
+
);
|
| 7654 |
+
}
|
| 7655 |
}
|
| 7656 |
ggml_vk_ctx_end(subctx);
|
| 7657 |
|
|
|
|
| 7709 |
|
| 7710 |
double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
|
| 7711 |
|
| 7712 |
+
std::cerr << "TEST dequant matmul " << shname;
|
| 7713 |
+
if (mmq) {
|
| 7714 |
+
std::cerr << " mmq";
|
| 7715 |
+
}
|
| 7716 |
+
std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
|
| 7717 |
|
| 7718 |
if (avg_err > 0.01 || std::isnan(avg_err)) {
|
| 7719 |
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
|
|
|
| 7723 |
std::cerr << "Expected result: " << std::endl << std::endl;
|
| 7724 |
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
| 7725 |
|
| 7726 |
+
std::cerr << "src0: " << std::endl << std::endl;
|
| 7727 |
+
ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
|
| 7728 |
+
std::cerr << std::endl;
|
| 7729 |
+
std::cerr << "src1: " << std::endl << std::endl;
|
| 7730 |
+
ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
|
| 7731 |
+
|
| 7732 |
if (split_k > 1) {
|
| 7733 |
float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
|
| 7734 |
ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
|
|
|
|
| 7751 |
|
| 7752 |
ggml_vk_destroy_buffer(qx_buf);
|
| 7753 |
ggml_vk_destroy_buffer(y_buf);
|
| 7754 |
+
ggml_vk_destroy_buffer(qy_buf);
|
| 7755 |
ggml_vk_destroy_buffer(d_buf);
|
| 7756 |
|
| 7757 |
free(x);
|
|
|
|
| 7784 |
128, 49, 49,
|
| 7785 |
4096, 49, 4096,
|
| 7786 |
};
|
| 7787 |
+
const size_t num_it = 1;
|
| 7788 |
+
|
| 7789 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
|
| 7790 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
|
| 7791 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
|
| 7792 |
+
|
| 7793 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
|
| 7794 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
|
| 7795 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
|
| 7796 |
+
|
| 7797 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
|
| 7798 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
|
| 7799 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
|
| 7800 |
+
|
| 7801 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
|
| 7802 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
|
| 7803 |
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
|
| 7804 |
+
|
| 7805 |
+
abort();
|
| 7806 |
|
| 7807 |
for (size_t i = 0; i < vals.size(); i += 3) {
|
| 7808 |
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
|
|
|
|
| 9614 |
}
|
| 9615 |
|
| 9616 |
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
| 9617 |
+
const float * params = (const float *)tensor->op_params;
|
| 9618 |
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
|
| 9619 |
} else if (tensor->op == GGML_OP_MUL_MAT) {
|
| 9620 |
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
|
|
|
|
| 9631 |
} else if (tensor->op == GGML_OP_UPSCALE) {
|
| 9632 |
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
| 9633 |
} else if (tensor->op == GGML_OP_SCALE) {
|
| 9634 |
+
const float * params = (const float *)tensor->op_params;
|
| 9635 |
+
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
|
| 9636 |
} else if (tensor->op == GGML_OP_SQR) {
|
| 9637 |
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
|
| 9638 |
} else if (tensor->op == GGML_OP_SIN) {
|
|
|
|
| 9640 |
} else if (tensor->op == GGML_OP_COS) {
|
| 9641 |
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
|
| 9642 |
} else if (tensor->op == GGML_OP_CLAMP) {
|
| 9643 |
+
const float * params = (const float *)tensor->op_params;
|
| 9644 |
+
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
|
| 9645 |
} else if (tensor->op == GGML_OP_PAD) {
|
| 9646 |
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
|
| 9647 |
} else if (tensor->op == GGML_OP_REPEAT) {
|
|
|
|
| 9655 |
} else if (tensor->op == GGML_OP_NORM) {
|
| 9656 |
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
| 9657 |
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
| 9658 |
+
const float * float_params = (const float *)tensor->op_params;
|
| 9659 |
+
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
|
| 9660 |
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
| 9661 |
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
| 9662 |
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
|
|
|
|
| 9669 |
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
|
| 9670 |
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
| 9671 |
if (src1 != nullptr) {
|
| 9672 |
+
const float * params = (const float *)tensor->op_params;
|
| 9673 |
+
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
|
| 9674 |
} else {
|
| 9675 |
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
|
| 9676 |
}
|
| 9677 |
} else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
|
| 9678 |
tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
| 9679 |
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
|
| 9680 |
+
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
|
| 9681 |
} else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
|
| 9682 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 9683 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
CHANGED

@@ -212,7 +212,7 @@ void main() {
 #else
     ACC_TYPE sums[WMITER * TM * WNITER * TN];
     FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
+    FLOAT_TYPE cache_b[TN];

     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = ACC_TYPE(0.0f);

@@ -744,16 +744,14 @@ void main() {
         }
         [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
             [[unroll]] for (uint j = 0; j < TN; j++) {
-                cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
+                cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
             }
-        }

-        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
             [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                     [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                         const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                        sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
+                        sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
                     }
                 }
             }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
ADDED
|
@@ -0,0 +1,444 @@
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 4 |
+
#extension GL_EXT_shader_16bit_storage : require
|
| 5 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
| 6 |
+
|
| 7 |
+
#extension GL_EXT_integer_dot_product : require
|
| 8 |
+
|
| 9 |
+
#ifdef FLOAT16
|
| 10 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#ifdef COOPMAT
|
| 14 |
+
#extension GL_KHR_cooperative_matrix : enable
|
| 15 |
+
#extension GL_KHR_memory_scope_semantics : enable
|
| 16 |
+
#extension GL_KHR_shader_subgroup_basic : enable
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#ifdef MUL_MAT_ID
|
| 20 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
| 21 |
+
#endif
|
| 22 |
+
|
| 23 |
+
#include "types.comp"
|
| 24 |
+
|
| 25 |
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
| 26 |
+
|
| 27 |
+
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
| 28 |
+
#if defined(A_TYPE_PACKED32)
|
| 29 |
+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
| 30 |
+
#endif
|
| 31 |
+
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
|
| 32 |
+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
| 33 |
+
|
| 34 |
+
#ifdef MUL_MAT_ID
|
| 35 |
+
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
| 36 |
+
#endif
|
| 37 |
+
|
| 38 |
+
layout (push_constant) uniform parameter
|
| 39 |
+
{
|
| 40 |
+
uint M;
|
| 41 |
+
uint N;
|
| 42 |
+
uint K;
|
| 43 |
+
uint stride_a;
|
| 44 |
+
uint stride_b;
|
| 45 |
+
uint stride_d;
|
| 46 |
+
|
| 47 |
+
uint batch_stride_a;
|
| 48 |
+
uint batch_stride_b;
|
| 49 |
+
uint batch_stride_d;
|
| 50 |
+
|
| 51 |
+
#ifdef MUL_MAT_ID
|
| 52 |
+
uint nei0;
|
| 53 |
+
uint nei1;
|
| 54 |
+
uint nbi1;
|
| 55 |
+
uint ne11;
|
| 56 |
+
#else
|
| 57 |
+
uint k_split;
|
| 58 |
+
uint ne02;
|
| 59 |
+
uint ne12;
|
| 60 |
+
uint broadcast2;
|
| 61 |
+
uint broadcast3;
|
| 62 |
+
#endif
|
| 63 |
+
} p;
|
| 64 |
+
|
| 65 |
+
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
|
| 66 |
+
layout (constant_id = 1) const uint BM = 64;
|
| 67 |
+
layout (constant_id = 2) const uint BN = 64;
|
| 68 |
+
// layout (constant_id = 3) const uint BK = 32;
|
| 69 |
+
layout (constant_id = 4) const uint WM = 32;
|
| 70 |
+
layout (constant_id = 5) const uint WN = 32;
|
| 71 |
+
layout (constant_id = 6) const uint WMITER = 2;
|
| 72 |
+
layout (constant_id = 7) const uint TM = 4;
|
| 73 |
+
layout (constant_id = 8) const uint TN = 2;
|
| 74 |
+
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
|
| 75 |
+
layout (constant_id = 10) const uint WARP = 32;
|
| 76 |
+
|
| 77 |
+
#define BK 32
|
| 78 |
+
|
| 79 |
+
#ifdef COOPMAT
|
| 80 |
+
#define SHMEM_STRIDE (BK / 4 + 4)
|
| 81 |
+
#else
|
| 82 |
+
#define SHMEM_STRIDE (BK / 4 + 1)
|
| 83 |
+
#endif
|
| 84 |
+
|
| 85 |
+
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
|
| 86 |
+
|
| 87 |
+
#ifndef COOPMAT
|
| 88 |
+
#if QUANT_AUXF == 1
|
| 89 |
+
shared FLOAT_TYPE buf_a_dm[BM];
|
| 90 |
+
#else
|
| 91 |
+
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
|
| 92 |
+
#endif
|
| 93 |
+
#endif
|
| 94 |
+
|
| 95 |
+
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
|
| 96 |
+
#ifndef COOPMAT
|
| 97 |
+
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
| 98 |
+
#endif
|
| 99 |
+
|
| 100 |
+
#define LOAD_VEC_A (4 * QUANT_R)
|
| 101 |
+
#define LOAD_VEC_B 4
|
| 102 |
+
|
| 103 |
+
#ifdef MUL_MAT_ID
|
| 104 |
+
shared u16vec2 row_ids[3072];
|
| 105 |
+
#endif // MUL_MAT_ID
|
| 106 |
+
|
| 107 |
+
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
| 108 |
+
|
| 109 |
+
#ifdef COOPMAT
|
| 110 |
+
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
| 111 |
+
#endif
|
| 112 |
+
|
| 113 |
+
#include "mul_mmq_funcs.comp"
|
| 114 |
+
|
| 115 |
+
void main() {
|
| 116 |
+
#ifdef NEEDS_INIT_IQ_SHMEM
|
| 117 |
+
init_iq_shmem(gl_WorkGroupSize);
|
| 118 |
+
#endif
|
| 119 |
+
|
| 120 |
+
#ifdef MUL_MAT_ID
|
| 121 |
+
const uint expert_idx = gl_GlobalInvocationID.z;
|
| 122 |
+
#else
|
| 123 |
+
const uint batch_idx = gl_GlobalInvocationID.z;
|
| 124 |
+
|
| 125 |
+
const uint i13 = batch_idx / p.ne12;
|
| 126 |
+
const uint i12 = batch_idx % p.ne12;
|
| 127 |
+
|
| 128 |
+
const uint i03 = i13 / p.broadcast3;
|
| 129 |
+
const uint i02 = i12 / p.broadcast2;
|
| 130 |
+
|
| 131 |
+
const uint batch_idx_a = i03 * p.ne02 + i02;
|
| 132 |
+
#endif
|
| 133 |
+
|
| 134 |
+
const uint blocks_m = (p.M + BM - 1) / BM;
|
| 135 |
+
const uint ir = gl_WorkGroupID.x % blocks_m;
|
| 136 |
+
const uint ik = gl_WorkGroupID.x / blocks_m;
|
| 137 |
+
const uint ic = gl_WorkGroupID.y;
|
| 138 |
+
|
| 139 |
+
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
| 140 |
+
const uint WSUBM = WM / WMITER;
|
| 141 |
+
const uint WSUBN = WN / WNITER;
|
| 142 |
+
|
| 143 |
+
#ifdef COOPMAT
|
| 144 |
+
const uint warp_i = gl_SubgroupID;
|
| 145 |
+
|
| 146 |
+
const uint tiw = gl_SubgroupInvocationID;
|
| 147 |
+
|
| 148 |
+
const uint cms_per_row = WM / TM;
|
| 149 |
+
const uint cms_per_col = WN / TN;
|
| 150 |
+
|
| 151 |
+
const uint storestride = WARP / TM;
|
| 152 |
+
const uint store_r = tiw % TM;
|
| 153 |
+
const uint store_c = tiw / TM;
|
| 154 |
+
#else
|
| 155 |
+
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
| 156 |
+
|
| 157 |
+
const uint tiw = gl_LocalInvocationID.x % WARP;
|
| 158 |
+
|
| 159 |
+
const uint tiwr = tiw % (WSUBM / TM);
|
| 160 |
+
const uint tiwc = tiw / (WSUBM / TM);
|
| 161 |
+
#endif
|
| 162 |
+
|
| 163 |
+
const uint warp_r = warp_i % (BM / WM);
|
| 164 |
+
const uint warp_c = warp_i / (BM / WM);
|
| 165 |
+
|
| 166 |
+
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
|
| 167 |
+
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
|
| 168 |
+
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
|
| 169 |
+
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
|
| 170 |
+
|
| 171 |
+
const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
|
| 172 |
+
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
|
| 173 |
+
|
| 174 |
+
#ifdef MUL_MAT_ID
|
| 175 |
+
uint _ne1 = 0;
|
| 176 |
+
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
| 177 |
+
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
| 178 |
+
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
| 179 |
+
row_ids[_ne1] = u16vec2(ii0, ii1);
|
| 180 |
+
_ne1++;
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
barrier();
|
| 186 |
+
|
| 187 |
+
// Workgroup has no work
|
| 188 |
+
if (ic * BN >= _ne1) return;
|
| 189 |
+
#endif
|
| 190 |
+
|
| 191 |
+
#ifdef MUL_MAT_ID
|
| 192 |
+
const uint start_k = 0;
|
| 193 |
+
const uint end_k = p.K;
|
| 194 |
+
#else
|
| 195 |
+
const uint start_k = ik * p.k_split;
|
| 196 |
+
const uint end_k = min(p.K, (ik + 1) * p.k_split);
|
| 197 |
+
#endif
|
| 198 |
+
|
| 199 |
+
uint pos_a_ib = (
|
| 200 |
+
#ifdef MUL_MAT_ID
|
| 201 |
+
expert_idx * p.batch_stride_a +
|
| 202 |
+
#else
|
| 203 |
+
batch_idx_a * p.batch_stride_a +
|
| 204 |
+
#endif
|
| 205 |
+
ir * BM * p.stride_a + start_k) / BK;
|
| 206 |
+
#ifdef MUL_MAT_ID
|
| 207 |
+
uint pos_b_ib = 0;
|
| 208 |
+
#else
|
| 209 |
+
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
|
| 210 |
+
#endif
|
| 211 |
+
|
| 212 |
+
#ifdef COOPMAT
|
| 213 |
+
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
| 214 |
+
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
| 215 |
+
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
|
| 216 |
+
|
| 217 |
+
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
|
| 218 |
+
|
| 219 |
+
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
| 220 |
+
|
| 221 |
+
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
| 222 |
+
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
| 223 |
+
}
|
| 224 |
+
#else
|
| 225 |
+
int32_t cache_a_qs[WMITER * TM * BK / 4];
|
| 226 |
+
|
| 227 |
+
int32_t cache_b_qs[TN * BK / 4];
|
| 228 |
+
|
| 229 |
+
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
| 230 |
+
|
| 231 |
+
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
| 232 |
+
sums[i] = ACC_TYPE(0.0f);
|
| 233 |
+
}
|
| 234 |
+
#endif
|
| 235 |
+
|
| 236 |
+
#if QUANT_AUXF == 1
|
| 237 |
+
FLOAT_TYPE cache_a_dm[TM];
|
| 238 |
+
#else
|
| 239 |
+
FLOAT_TYPE_VEC2 cache_a_dm[TM];
|
| 240 |
+
#endif
|
| 241 |
+
|
| 242 |
+
FLOAT_TYPE_VEC2 cache_b_ds[TN];
|
| 243 |
+
|
| 244 |
+
for (uint block = start_k; block < end_k; block += BK) {
|
| 245 |
+
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
|
| 246 |
+
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
|
| 247 |
+
const uint iqs = loadr_a;
|
| 248 |
+
const uint buf_ib = loadc_a + l;
|
| 249 |
+
|
| 250 |
+
// Should ds be gated to a single thread?
|
| 251 |
+
if (iqs == 0) {
|
| 252 |
+
#if QUANT_AUXF == 1
|
| 253 |
+
buf_a_dm[buf_ib] = get_d(ib);
|
| 254 |
+
#else
|
| 255 |
+
buf_a_dm[buf_ib] = get_dm(ib);
|
| 256 |
+
#endif
|
| 257 |
+
}
|
| 258 |
+
#if QUANT_R == 1
|
| 259 |
+
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
|
| 260 |
+
#else
|
| 261 |
+
const i32vec2 vals = repack(ib, iqs);
|
| 262 |
+
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
|
| 263 |
+
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
|
| 264 |
+
#endif
|
| 265 |
+
}
|
| 266 |
+
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
|
| 267 |
+
#ifdef MUL_MAT_ID
|
| 268 |
+
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
| 269 |
+
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
| 270 |
+
const uint ib = idx / 8;
|
| 271 |
+
const uint iqs = idx & 0x7;
|
| 272 |
+
#else
|
| 273 |
+
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
|
| 274 |
+
const uint iqs = loadr_b;
|
| 275 |
+
#endif
|
| 276 |
+
|
| 277 |
+
const uint buf_ib = loadc_b + l;
|
| 278 |
+
|
| 279 |
+
// Should ds be gated to a single thread?
|
| 280 |
+
if (iqs == 0) {
|
| 281 |
+
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
|
| 282 |
+
}
|
| 283 |
+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
barrier();
|
| 287 |
+
|
| 288 |
+
pos_a_ib += 1;
|
| 289 |
+
pos_b_ib += 1;
|
| 290 |
+
|
| 291 |
+
#ifdef COOPMAT
|
| 292 |
+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
| 293 |
+
const uint ib_a = warp_r * WM + cm_row * TM;
|
| 294 |
+
// Load from shared into cache
|
| 295 |
+
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
| 296 |
+
|
| 297 |
+
// TODO: only cache values that are actually needed
|
| 298 |
+
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
|
| 299 |
+
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
| 303 |
+
const uint ib_b = warp_c * WN + cm_col * TN;
|
| 304 |
+
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
| 305 |
+
|
| 306 |
+
// TODO: only cache values that are actually needed
|
| 307 |
+
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
|
| 308 |
+
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
|
| 312 |
+
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
|
| 313 |
+
|
| 314 |
+
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
| 315 |
+
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
| 319 |
+
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
|
| 320 |
+
}
|
| 321 |
+
}
|
| 322 |
+
#else
|
| 323 |
+
// Load from shared into cache
|
| 324 |
+
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
| 325 |
+
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
| 326 |
+
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
|
| 327 |
+
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
|
| 328 |
+
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
| 329 |
+
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
|
| 330 |
+
}
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
| 335 |
+
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
| 336 |
+
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
|
| 337 |
+
cache_b_ds[cc] = buf_b_ds[ib];
|
| 338 |
+
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
| 339 |
+
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
| 344 |
+
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
| 345 |
+
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
| 346 |
+
const uint cache_a_idx = wsir * TM + cr;
|
| 347 |
+
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
| 348 |
+
int32_t q_sum = 0;
|
| 349 |
+
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
| 350 |
+
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
|
| 351 |
+
cache_b_qs[cc * (BK / 4) + idx_k]);
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
}
|
| 359 |
+
#endif
|
| 360 |
+
|
| 361 |
+
barrier();
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
const uint dr = ir * BM + warp_r * WM;
|
| 365 |
+
const uint dc = ic * BN + warp_c * WN;
|
| 366 |
+
|
| 367 |
+
#ifndef MUL_MAT_ID
|
| 368 |
+
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
| 369 |
+
#endif
|
| 370 |
+
|
| 371 |
+
#ifdef COOPMAT
|
| 372 |
+
#ifdef MUL_MAT_ID
|
| 373 |
+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
| 374 |
+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
| 375 |
+
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
| 376 |
+
|
| 377 |
+
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
| 378 |
+
const uint row_i = dc + cm_col * TN + col + store_c;
|
| 379 |
+
if (row_i >= _ne1) break;
|
| 380 |
+
|
| 381 |
+
const u16vec2 row_idx = row_ids[row_i];
|
| 382 |
+
|
| 383 |
+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
}
|
| 387 |
+
#else
|
| 388 |
+
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
| 389 |
+
|
| 390 |
+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
| 391 |
+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
| 392 |
+
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
| 393 |
+
|
| 394 |
+
if (is_aligned && is_in_bounds) {
|
| 395 |
+
// Full coopMat is within bounds and stride_d is aligned with 16B
|
| 396 |
+
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
| 397 |
+
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
| 398 |
+
} else if (is_in_bounds) {
|
| 399 |
+
// Full coopMat is within bounds, but stride_d is not aligned
|
| 400 |
+
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
| 401 |
+
|
| 402 |
+
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
| 403 |
+
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
| 404 |
+
}
|
| 405 |
+
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
| 406 |
+
// Partial coopMat is within bounds
|
| 407 |
+
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
| 408 |
+
|
| 409 |
+
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
| 410 |
+
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
| 411 |
+
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
}
|
| 416 |
+
}
|
| 417 |
+
#endif // MUL_MAT_ID
|
| 418 |
+
#else
|
| 419 |
+
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
| 420 |
+
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
| 421 |
+
|
| 422 |
+
const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
| 423 |
+
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
| 424 |
+
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
| 425 |
+
#ifdef MUL_MAT_ID
|
| 426 |
+
const uint row_i = dc_warp + cc;
|
| 427 |
+
if (row_i >= _ne1) break;
|
| 428 |
+
|
| 429 |
+
const u16vec2 row_idx = row_ids[row_i];
|
| 430 |
+
#endif // MUL_MAT_ID
|
| 431 |
+
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
| 432 |
+
#ifdef MUL_MAT_ID
|
| 433 |
+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
| 434 |
+
#else
|
| 435 |
+
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
| 436 |
+
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
| 437 |
+
}
|
| 438 |
+
#endif // MUL_MAT_ID
|
| 439 |
+
}
|
| 440 |
+
}
|
| 441 |
+
}
|
| 442 |
+
}
|
| 443 |
+
#endif // COOPMAT
|
| 444 |
+
}
|
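The store epilogue above picks one of three paths per output tile: a direct coopMatStore when the whole tile is in bounds and the output stride is 16-byte aligned, a store staged through shared memory when only the alignment check fails, and a per-element guarded store for partial edge tiles. A minimal host-side sketch of that selection logic, with hypothetical names and the tile origin already offset to the current cooperative matrix (not part of the commit):

#include <cstdint>

enum class StorePath { DirectCoopMatStore, StagedFullTile, StagedPartialTile, Skip };

// dr/dc: tile origin in the output, TM/TN: tile size, stride_d in float elements.
StorePath pick_store_path(uint32_t M, uint32_t N, uint32_t stride_d,
                          uint32_t dr, uint32_t dc, uint32_t TM, uint32_t TN) {
    const bool is_aligned   = stride_d % 4 == 0;             // 4 floats == 16 bytes
    const bool is_in_bounds = dr + TM <= M && dc + TN <= N;  // whole tile fits

    if (is_aligned && is_in_bounds) return StorePath::DirectCoopMatStore; // coopMatStore straight to the buffer
    if (is_in_bounds)               return StorePath::StagedFullTile;     // stage via shared memory, unguarded writes
    if (dr < M && dc < N)           return StorePath::StagedPartialTile;  // stage and guard each element
    return StorePath::Skip;                                               // tile entirely out of bounds
}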
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp
ADDED
|
@@ -0,0 +1,99 @@
|
| 1 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
| 2 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
| 3 |
+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
| 4 |
+
|
| 5 |
+
#include "types.comp"
|
| 6 |
+
|
| 7 |
+
// Each iqs value maps to a 32-bit integer
|
| 8 |
+
|
| 9 |
+
#if defined(DATA_A_Q4_0)
|
| 10 |
+
i32vec2 repack(uint ib, uint iqs) {
|
| 11 |
+
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
|
| 12 |
+
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
| 13 |
+
data_a[ib].qs[iqs * 2 + 1]);
|
| 14 |
+
const uint32_t vui = pack32(quants);
|
| 15 |
+
return i32vec2( vui & 0x0F0F0F0F,
|
| 16 |
+
(vui >> 4) & 0x0F0F0F0F);
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
| 20 |
+
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y));
|
| 21 |
+
}
|
| 22 |
+
#endif
|
| 23 |
+
|
| 24 |
+
#if defined(DATA_A_Q4_1)
|
| 25 |
+
i32vec2 repack(uint ib, uint iqs) {
|
| 26 |
+
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
|
| 27 |
+
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
| 28 |
+
return i32vec2( vui & 0x0F0F0F0F,
|
| 29 |
+
(vui >> 4) & 0x0F0F0F0F);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
| 33 |
+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
| 34 |
+
}
|
| 35 |
+
#endif
|
| 36 |
+
|
| 37 |
+
#if defined(DATA_A_Q5_0)
|
| 38 |
+
i32vec2 repack(uint ib, uint iqs) {
|
| 39 |
+
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
|
| 40 |
+
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
| 41 |
+
data_a[ib].qs[iqs * 2 + 1]);
|
| 42 |
+
const uint32_t vui = pack32(quants);
|
| 43 |
+
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
|
| 44 |
+
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
| 45 |
+
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
| 46 |
+
|
| 47 |
+
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
| 48 |
+
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
| 49 |
+
|
| 50 |
+
return i32vec2(v0, v1);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
| 54 |
+
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y));
|
| 55 |
+
}
|
| 56 |
+
#endif
|
| 57 |
+
|
| 58 |
+
#if defined(DATA_A_Q5_1)
|
| 59 |
+
i32vec2 repack(uint ib, uint iqs) {
|
| 60 |
+
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
|
| 61 |
+
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
| 62 |
+
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
| 63 |
+
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
| 64 |
+
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
| 65 |
+
|
| 66 |
+
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
| 67 |
+
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
| 68 |
+
|
| 69 |
+
return i32vec2(v0, v1);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
| 73 |
+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
| 74 |
+
}
|
| 75 |
+
#endif
|
| 76 |
+
|
| 77 |
+
#if defined(DATA_A_Q8_0)
|
| 78 |
+
int32_t repack(uint ib, uint iqs) {
|
| 79 |
+
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
|
| 80 |
+
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
|
| 81 |
+
data_a[ib].qs[iqs * 2 + 1]));
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
| 85 |
+
return ACC_TYPE(float(q_sum) * da * dsb.x);
|
| 86 |
+
}
|
| 87 |
+
#endif
|
| 88 |
+
|
| 89 |
+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
| 90 |
+
FLOAT_TYPE get_d(uint ib) {
|
| 91 |
+
return FLOAT_TYPE(data_a[ib].d);
|
| 92 |
+
}
|
| 93 |
+
#endif
|
| 94 |
+
|
| 95 |
+
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
| 96 |
+
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
| 97 |
+
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
| 98 |
+
}
|
| 99 |
+
#endif
|
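For each quant type, repack() packs four 8-bit values into a 32-bit integer so the MMQ shader can feed them to an integer dot product, and mul_q8_1() applies the per-block scales afterwards. As a plausibility check, here is a scalar reference of the q4_0 x q8_1 block dot product that the q4_0 variant of mul_q8_1() corresponds to: the constant -8 offset of q4_0 is folded into the precomputed q8_1 sum, so only one integer accumulation over raw quants is needed. Struct and function names here are illustrative, not part of the commit.

#include <cstdint>

struct block_q4_0_ref { float d;          uint8_t qs[16]; }; // 32 4-bit quants with implicit offset 8
struct block_q8_1_ref { float d; float s; int8_t  qs[32]; }; // s = d * sum(qs), as written by quantize_q8_1

float dot_q4_0_q8_1(const block_q4_0_ref &a, const block_q8_1_ref &b) {
    int32_t q_sum = 0;
    for (int i = 0; i < 16; ++i) {
        q_sum += (a.qs[i] & 0x0F) * b.qs[i];       // low nibbles pair with the first 16 q8 values
        q_sum += (a.qs[i] >>   4) * b.qs[i + 16];  // high nibbles pair with the last 16
    }
    // Same decomposition as mul_q8_1() for q4_0: da * (q_sum * dsb.x - 8 * dsb.y)
    return a.d * (float(q_sum) * b.d - 8.0f * b.s);
}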
ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
ADDED
|
@@ -0,0 +1,77 @@
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#extension GL_EXT_control_flow_attributes : require
|
| 4 |
+
#extension GL_EXT_shader_16bit_storage : require
|
| 5 |
+
|
| 6 |
+
layout (push_constant) uniform parameter
|
| 7 |
+
{
|
| 8 |
+
uint ne;
|
| 9 |
+
} p;
|
| 10 |
+
|
| 11 |
+
#include "types.comp"
|
| 12 |
+
|
| 13 |
+
layout(constant_id = 0) const uint GROUP_SIZE = 32;
|
| 14 |
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
| 15 |
+
|
| 16 |
+
layout (binding = 0) readonly buffer A {vec4 data_a[];};
|
| 17 |
+
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
|
| 18 |
+
|
| 19 |
+
shared float shmem[GROUP_SIZE];
|
| 20 |
+
|
| 21 |
+
void quantize() {
|
| 22 |
+
const uint wgid = gl_WorkGroupID.x;
|
| 23 |
+
const uint tid = gl_LocalInvocationID.x;
|
| 24 |
+
|
| 25 |
+
// Each thread handles a vec4, so 8 threads handle a block
|
| 26 |
+
const uint blocks_per_group = GROUP_SIZE / 8;
|
| 27 |
+
|
| 28 |
+
const uint block_in_wg = tid / 8;
|
| 29 |
+
|
| 30 |
+
const uint ib = wgid * blocks_per_group + block_in_wg;
|
| 31 |
+
const uint iqs = tid % 8;
|
| 32 |
+
|
| 33 |
+
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
|
| 34 |
+
return;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
const uint a_idx = ib * 8 + iqs;
|
| 38 |
+
|
| 39 |
+
vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
|
| 40 |
+
const vec4 abs_vals = abs(vals);
|
| 41 |
+
|
| 42 |
+
// Find absolute max for each block
|
| 43 |
+
shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
| 44 |
+
barrier();
|
| 45 |
+
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
|
| 46 |
+
if (iqs < s) {
|
| 47 |
+
shmem[tid] = max(shmem[tid], shmem[tid + s]);
|
| 48 |
+
}
|
| 49 |
+
barrier();
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
const float amax = shmem[block_in_wg * 8];
|
| 53 |
+
const float d = amax / 127.0;
|
| 54 |
+
const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
|
| 55 |
+
vals = round(vals * d_inv);
|
| 56 |
+
data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
|
| 57 |
+
barrier();
|
| 58 |
+
|
| 59 |
+
// Calculate the sum for each block
|
| 60 |
+
shmem[tid] = vals.x + vals.y + vals.z + vals.w;
|
| 61 |
+
barrier();
|
| 62 |
+
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
|
| 63 |
+
if (iqs < s) {
|
| 64 |
+
shmem[tid] += shmem[tid + s];
|
| 65 |
+
}
|
| 66 |
+
barrier();
|
| 67 |
+
}
|
| 68 |
+
if (iqs == 0) {
|
| 69 |
+
const float sum = shmem[tid];
|
| 70 |
+
|
| 71 |
+
data_b[ib].ds = f16vec2(vec2(d, sum * d));
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
void main() {
|
| 76 |
+
quantize();
|
| 77 |
+
}
|
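quantize_q8_1.comp turns 32 consecutive floats of the B matrix into one q8_1 block: eight threads each handle a vec4, reduce the absolute maximum and the quant sum through shared memory, and the block then stores ds = (d, d * sum). A minimal host-side sketch of the same per-block math (illustrative struct and function names, assumed reference only, not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstdint>

struct block_q8_1_ref { float d; float s; int8_t qs[32]; };

block_q8_1_ref quantize_block_q8_1_ref(const float *x) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) amax = std::max(amax, std::fabs(x[i]));

    const float d     = amax / 127.0f;
    const float d_inv = d != 0.0f ? 1.0f / d : 0.0f;

    block_q8_1_ref b{};
    int32_t sum = 0;
    for (int i = 0; i < 32; ++i) {
        const int32_t q = (int32_t) std::lround(x[i] * d_inv); // shader uses round(vals * d_inv)
        b.qs[i] = (int8_t) q;
        sum += q;
    }
    b.d = d;
    b.s = d * (float) sum; // matches ds = f16vec2(vec2(d, sum * d)) in the shader
    return b;
}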
ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp
ADDED
|
@@ -0,0 +1,7 @@
|
| 1 |
+
#version 460
|
| 2 |
+
|
| 3 |
+
#extension GL_EXT_integer_dot_product : require
|
| 4 |
+
|
| 5 |
+
void main()
|
| 6 |
+
{
|
| 7 |
+
}
|
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
#if !defined(GGML_TYPES_COMP)
|
| 3 |
#define GGML_TYPES_COMP
|
| 4 |
|
|
@@ -51,6 +50,7 @@ struct block_q4_0_packed16
|
|
| 51 |
#if defined(DATA_A_Q4_0)
|
| 52 |
#define QUANT_K QUANT_K_Q4_0
|
| 53 |
#define QUANT_R QUANT_R_Q4_0
|
|
|
|
| 54 |
#define A_TYPE block_q4_0
|
| 55 |
#define A_TYPE_PACKED16 block_q4_0_packed16
|
| 56 |
#endif
|
|
@@ -72,11 +72,19 @@ struct block_q4_1_packed16
|
|
| 72 |
uint16_t qs[16/2];
|
| 73 |
};
|
| 74 |
|
| 75 |
#if defined(DATA_A_Q4_1)
|
| 76 |
#define QUANT_K QUANT_K_Q4_1
|
| 77 |
#define QUANT_R QUANT_R_Q4_1
|
|
|
|
| 78 |
#define A_TYPE block_q4_1
|
| 79 |
#define A_TYPE_PACKED16 block_q4_1_packed16
|
|
|
|
| 80 |
#endif
|
| 81 |
|
| 82 |
#define QUANT_K_Q5_0 32
|
|
@@ -99,6 +107,7 @@ struct block_q5_0_packed16
|
|
| 99 |
#if defined(DATA_A_Q5_0)
|
| 100 |
#define QUANT_K QUANT_K_Q5_0
|
| 101 |
#define QUANT_R QUANT_R_Q5_0
|
|
|
|
| 102 |
#define A_TYPE block_q5_0
|
| 103 |
#define A_TYPE_PACKED16 block_q5_0_packed16
|
| 104 |
#endif
|
|
@@ -122,11 +131,20 @@ struct block_q5_1_packed16
|
|
| 122 |
uint16_t qs[16/2];
|
| 123 |
};
|
| 124 |
|
| 125 |
#if defined(DATA_A_Q5_1)
|
| 126 |
#define QUANT_K QUANT_K_Q5_1
|
| 127 |
#define QUANT_R QUANT_R_Q5_1
|
|
|
|
| 128 |
#define A_TYPE block_q5_1
|
| 129 |
#define A_TYPE_PACKED16 block_q5_1_packed16
|
|
|
|
| 130 |
#endif
|
| 131 |
|
| 132 |
#define QUANT_K_Q8_0 32
|
|
@@ -142,14 +160,40 @@ struct block_q8_0_packed16
|
|
| 142 |
float16_t d;
|
| 143 |
int16_t qs[32/2];
|
| 144 |
};
|
| 145 |
|
| 146 |
#if defined(DATA_A_Q8_0)
|
| 147 |
#define QUANT_K QUANT_K_Q8_0
|
| 148 |
#define QUANT_R QUANT_R_Q8_0
|
|
|
|
| 149 |
#define A_TYPE block_q8_0
|
| 150 |
#define A_TYPE_PACKED16 block_q8_0_packed16
|
|
|
|
| 151 |
#endif
|
| 152 |
|
| 153 |
// K-quants
|
| 154 |
#define QUANT_K_Q2_K 256
|
| 155 |
|
|
|
|
|
|
|
| 1 |
#if !defined(GGML_TYPES_COMP)
|
| 2 |
#define GGML_TYPES_COMP
|
| 3 |
|
|
|
|
| 50 |
#if defined(DATA_A_Q4_0)
|
| 51 |
#define QUANT_K QUANT_K_Q4_0
|
| 52 |
#define QUANT_R QUANT_R_Q4_0
|
| 53 |
+
#define QUANT_AUXF 1
|
| 54 |
#define A_TYPE block_q4_0
|
| 55 |
#define A_TYPE_PACKED16 block_q4_0_packed16
|
| 56 |
#endif
|
|
|
|
| 72 |
uint16_t qs[16/2];
|
| 73 |
};
|
| 74 |
|
| 75 |
+
struct block_q4_1_packed32
|
| 76 |
+
{
|
| 77 |
+
f16vec2 dm;
|
| 78 |
+
uint32_t qs[16/4];
|
| 79 |
+
};
|
| 80 |
+
|
| 81 |
#if defined(DATA_A_Q4_1)
|
| 82 |
#define QUANT_K QUANT_K_Q4_1
|
| 83 |
#define QUANT_R QUANT_R_Q4_1
|
| 84 |
+
#define QUANT_AUXF 2
|
| 85 |
#define A_TYPE block_q4_1
|
| 86 |
#define A_TYPE_PACKED16 block_q4_1_packed16
|
| 87 |
+
#define A_TYPE_PACKED32 block_q4_1_packed32
|
| 88 |
#endif
|
| 89 |
|
| 90 |
#define QUANT_K_Q5_0 32
|
|
|
|
| 107 |
#if defined(DATA_A_Q5_0)
|
| 108 |
#define QUANT_K QUANT_K_Q5_0
|
| 109 |
#define QUANT_R QUANT_R_Q5_0
|
| 110 |
+
#define QUANT_AUXF 1
|
| 111 |
#define A_TYPE block_q5_0
|
| 112 |
#define A_TYPE_PACKED16 block_q5_0_packed16
|
| 113 |
#endif
|
|
|
|
| 131 |
uint16_t qs[16/2];
|
| 132 |
};
|
| 133 |
|
| 134 |
+
struct block_q5_1_packed32
|
| 135 |
+
{
|
| 136 |
+
f16vec2 dm;
|
| 137 |
+
uint qh;
|
| 138 |
+
uint32_t qs[16/4];
|
| 139 |
+
};
|
| 140 |
+
|
| 141 |
#if defined(DATA_A_Q5_1)
|
| 142 |
#define QUANT_K QUANT_K_Q5_1
|
| 143 |
#define QUANT_R QUANT_R_Q5_1
|
| 144 |
+
#define QUANT_AUXF 2
|
| 145 |
#define A_TYPE block_q5_1
|
| 146 |
#define A_TYPE_PACKED16 block_q5_1_packed16
|
| 147 |
+
#define A_TYPE_PACKED32 block_q5_1_packed32
|
| 148 |
#endif
|
| 149 |
|
| 150 |
#define QUANT_K_Q8_0 32
|
|
|
|
| 160 |
float16_t d;
|
| 161 |
int16_t qs[32/2];
|
| 162 |
};
|
| 163 |
+
struct block_q8_0_packed32
|
| 164 |
+
{
|
| 165 |
+
float16_t d;
|
| 166 |
+
int32_t qs[32/4];
|
| 167 |
+
};
|
| 168 |
|
| 169 |
#if defined(DATA_A_Q8_0)
|
| 170 |
#define QUANT_K QUANT_K_Q8_0
|
| 171 |
#define QUANT_R QUANT_R_Q8_0
|
| 172 |
+
#define QUANT_AUXF 1
|
| 173 |
#define A_TYPE block_q8_0
|
| 174 |
#define A_TYPE_PACKED16 block_q8_0_packed16
|
| 175 |
+
#define A_TYPE_PACKED32 block_q8_0_packed32
|
| 176 |
#endif
|
| 177 |
|
| 178 |
+
#define QUANT_K_Q8_1 32
|
| 179 |
+
#define QUANT_R_Q8_1 1
|
| 180 |
+
|
| 181 |
+
struct block_q8_1
|
| 182 |
+
{
|
| 183 |
+
f16vec2 ds;
|
| 184 |
+
int8_t qs[32];
|
| 185 |
+
};
|
| 186 |
+
struct block_q8_1_packed16
|
| 187 |
+
{
|
| 188 |
+
f16vec2 ds;
|
| 189 |
+
int16_t qs[16];
|
| 190 |
+
};
|
| 191 |
+
struct block_q8_1_packed32
|
| 192 |
+
{
|
| 193 |
+
f16vec2 ds;
|
| 194 |
+
int32_t qs[8];
|
| 195 |
+
};
|
| 196 |
+
|
| 197 |
// K-quants
|
| 198 |
#define QUANT_K_Q2_K 256
|
| 199 |
|
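types.comp gains 32-bit packed views of several block formats plus the q8_1 block used for the quantized B matrix. The "divisible by 4" comments in mul_mmq_funcs.comp refer to the total block size in bytes, which decides whether whole blocks can be addressed with 32-bit loads. A hypothetical host-side mirror of three of the new layouts with their sizes spelled out (ggml_half stands in for 16-bit float storage; illustrative, not part of the commit):

#include <cstdint>

using ggml_half = uint16_t; // 2-byte storage for float16_t / each half of an f16vec2

#pragma pack(push, 1)
struct q4_1_packed32 { ggml_half dm[2];              uint32_t qs[4]; };
struct q5_1_packed32 { ggml_half dm[2]; uint32_t qh; uint32_t qs[4]; };
struct q8_1_packed32 { ggml_half ds[2];              int32_t  qs[8]; };
#pragma pack(pop)

static_assert(sizeof(q4_1_packed32) == 20, "q4_1 block is 20 bytes, 4-byte divisible");
static_assert(sizeof(q5_1_packed32) == 24, "q5_1 block is 24 bytes, 4-byte divisible");
static_assert(sizeof(q8_1_packed32) == 36, "q8_1 block is 36 bytes, 4-byte divisible");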
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
CHANGED
|
@@ -295,7 +295,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|
| 295 |
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
| 296 |
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
| 297 |
|
| 298 |
-
std::map<std::string, std::string> base_dict = {
|
| 299 |
std::string shader_name = "matmul";
|
| 300 |
|
| 301 |
if (matmul_id) {
|
|
@@ -313,9 +316,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|
| 313 |
base_dict["COOPMAT"] = "1";
|
| 314 |
}
|
| 315 |
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
| 319 |
|
| 320 |
// Shaders with f16 B_TYPE
|
| 321 |
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
|
@@ -339,14 +340,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|
| 339 |
|
| 340 |
// don't generate f32 variants for coopmat2
|
| 341 |
if (!coopmat2) {
|
| 342 |
-
string_to_spv(shader_name + "_" + tname + "_f32",
|
| 343 |
-
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},
|
| 344 |
}
|
| 345 |
|
| 346 |
if (tname != "f16" && tname != "f32") {
|
| 347 |
-
string_to_spv(shader_name + "_" + tname + "_f16",
|
| 348 |
-
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"
|
| 349 |
}
|
|
| 350 |
}
|
| 351 |
}
|
| 352 |
|
|
@@ -458,6 +465,7 @@ void process_shaders() {
|
|
| 458 |
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
| 459 |
|
| 460 |
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
|
|
|
| 461 |
|
| 462 |
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
| 463 |
|
|
|
|
| 295 |
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
| 296 |
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
| 297 |
|
| 298 |
+
std::map<std::string, std::string> base_dict = {
|
| 299 |
+
{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
|
| 300 |
+
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
|
| 301 |
+
};
|
| 302 |
std::string shader_name = "matmul";
|
| 303 |
|
| 304 |
if (matmul_id) {
|
|
|
|
| 316 |
base_dict["COOPMAT"] = "1";
|
| 317 |
}
|
| 318 |
|
| 319 |
+
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
|
|
|
|
|
|
| 320 |
|
| 321 |
// Shaders with f16 B_TYPE
|
| 322 |
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
|
|
|
| 340 |
|
| 341 |
// don't generate f32 variants for coopmat2
|
| 342 |
if (!coopmat2) {
|
| 343 |
+
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
| 344 |
+
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 345 |
}
|
| 346 |
|
| 347 |
if (tname != "f16" && tname != "f32") {
|
| 348 |
+
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
| 349 |
+
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 350 |
}
|
| 351 |
+
|
| 352 |
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
| 353 |
+
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
| 354 |
+
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
| 355 |
+
}
|
| 356 |
+
#endif
|
| 357 |
}
|
| 358 |
}
|
| 359 |
|
|
|
|
| 465 |
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
| 466 |
|
| 467 |
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
| 468 |
+
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
|
| 469 |
|
| 470 |
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
| 471 |
|
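With the generator changes above, each of the five legacy quant types gains one extra pipeline built from mul_mmq.comp, gated on GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT and skipped for the coopmat and matmul_id variants. Purely illustrative listing of the names that branch produces, assuming shader_name == "matmul" (the non-matmul_id case):

#include <cstdio>

int main() {
    const char *tnames[] = {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0"};
    for (const char *t : tnames) {
        std::printf("matmul_%s_q8_1\n", t); // e.g. matmul_q4_0_q8_1, built from mul_mmq.comp
    }
    return 0;
}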