SOTA 2-bit quants (llama/4773)
* iq2_xxs: basics
* iq2_xxs: scalar and AVX2 dot products
Needed to change Q8_K to have quants in the -127...127 range;
otherwise the IQ2_XXS AVX2 implementation becomes very awkward.
The alternative would have been to use Q8_0 instead. Perhaps
I'll change this later; for now this is what we have
(a short sketch of the -127...127 quantization follows the commit message).
* iq2_xxs: ARM_NEON dot product
Strangely slow for some reason (112 ms/token).
* iq2_xxs: WIP Metal
Dequantize works, something is still wrong with the
dot product.
* iq2_xxs: Metal dot product now works
We have
PP-512 = 475 t/s
TG-128 = 47.3 t/s
Not the greatest performance, but not complete garbage either.
* iq2_xxs: slightly faster dot product
TG-128 is now 48.4 t/s
* iq2_xxs: slightly faster dot product
TG-128 is now 50.9 t/s
* iq2_xxs: even faster Metal dot product
TG-128 is now 54.1 t/s.
Strangely enough, putting the signs lookup table
into shared memory has a bigger impact than putting
the grid values into shared memory.
* iq2_xxs: dequantize CUDA kernel - fix conflict with master
* iq2_xxs: quantized CUDA dot product (MMVQ)
We get TG-128 = 153.1 t/s
* iq2_xxs: slightly faster CUDA dot product
TG-128 is now at 155.1 t/s.
* iq2_xxs: add to llama ftype enum
* iq2_xxs: fix MoE on Metal
* Fix missing MMQ ops when on hipBLAS
I had put the ggml_cuda_supports_mmq call in the wrong place.
* Fix bug in quantize_row_iq2_xxs
The 0.25f factor was missing (the scalar sketch after the
ggml-cuda.cu changes below shows where it enters the dequantization).
Great detective work by @ggerganov!
* Fixing tests
* PR suggestion
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
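For context on the Q8_K range change mentioned above, here is a minimal,
illustrative C sketch of symmetric 8-bit quantization restricted to -127...127.
It is not the actual ggml quantize_row_q8_K (which also stores per-group sums
in the block); the function name and signature are made up for the example.

    #include <math.h>
    #include <stdint.h>

    // Hypothetical helper: quantize n floats to int8 so every value stays in
    // -127...127 (leaving -128 unused keeps the dot product symmetric, which
    // is what the IQ2_XXS AVX2 path relies on).
    static float quantize_row_symmetric_i8(const float * x, int8_t * q, int n) {
        float amax = 0.0f;
        for (int i = 0; i < n; ++i) {
            const float ax = fabsf(x[i]);
            if (ax > amax) amax = ax;
        }
        const float d  = amax / 127.0f;               // scale
        const float id = d > 0.0f ? 1.0f / d : 0.0f;  // inverse scale
        for (int i = 0; i < n; ++i) {
            int v = (int)roundf(x[i] * id);
            if (v >  127) v =  127;                   // clamp to -127...127
            if (v < -127) v = -127;
            q[i] = (int8_t)v;
        }
        return d; // caller stores d (as fp16 in Q8_K) next to the quants
    }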
- ggml-cuda.cu +205 -0
- ggml-metal.m +40 -0
- ggml-metal.metal +314 -0
- ggml-quants.c +293 -1
- ggml-quants.h +12 -0
- ggml.c +26 -0
- ggml.h +3 -0
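The new block type that these diffs add stores one fp16 scale plus 32 uint16
words per 256 weights. A quick, self-contained check of the resulting block
size and bits per weight; the struct mirrors block_iq2_xxs from the diffs
below, with uint16_t standing in for half so it compiles as plain C:

    #include <stdio.h>
    #include <stdint.h>

    #define QK_K 256

    // Same layout as block_iq2_xxs in ggml-cuda.cu / ggml-metal.metal below
    // (uint16_t used here only to give the fp16 scale its 2-byte size).
    typedef struct {
        uint16_t d;           // fp16 super-block scale
        uint16_t qs[QK_K/8];  // 32 x 16 bits: grid indices, signs, 4-bit scales
    } block_iq2_xxs_layout;

    int main(void) {
        const double bytes = (double)sizeof(block_iq2_xxs_layout);  // 66 bytes
        printf("%.4f bpw\n", bytes * 8.0 / QK_K);                   // 2.0625 bpw
        return 0;
    }

66 bytes per 256 weights is where the "2.0625 bpw" comment in ggml-metal.metal
comes from.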
ggml-cuda.cu:

@@ -477,6 +477,14 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

@@ -1292,6 +1300,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 #endif
 }
 
+static const __device__ uint64_t kgrid_iq2xxs[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const __device__ uint8_t ksigns_iq2xs[128] = {
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

@@ -3825,6 +3955,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+#if QR2_XXS == 8
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = q2[2] | (q2[3] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
+        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+        for (int j = 0; j < 8; ++j) {
+            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
+    return d * sumi;
+#else
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(

@@ -5664,6 +5843,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;

@@ -5692,6 +5877,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
+       case GGML_TYPE_IQ2_XXS:
+           return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        default:

@@ -5721,6 +5908,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
+       case GGML_TYPE_IQ2_XXS:
+           return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half>;
        default:

@@ -5915,6 +6104,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

@@ -7407,6 +7605,7 @@ static int64_t get_row_rounding(ggml_type type) {
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
        default:
            GGML_ASSERT(false);

@@ -7427,6 +7626,7 @@ static int64_t get_row_rounding(ggml_type type) {
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
+       case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_VOLTA ? 128 : 64;
        case GGML_TYPE_Q6_K:
            return 64;

@@ -7477,6 +7677,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
            break;
+       case GGML_TYPE_IQ2_XXS:
+           mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+           break;
        default:
            GGML_ASSERT(false);
            break;

@@ -8693,6 +8896,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
+    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
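The 0.25f factor from the bug-fix bullet above is the same one that appears in
dequantize_block_iq2_xxs and vec_dot_iq2_xxs_q8_1. Below is a plain-C sketch of
how one group of 8 weights is reconstructed, mirroring the CUDA code; the
helper name is made up, and the kgrid_iq2xxs / ksigns_iq2xs tables from the
diff are passed in as parameters rather than repeated here.

    #include <stdint.h>

    // Scalar mirror of the decoding in dequantize_block_iq2_xxs above.
    // grid: the 256-entry kgrid_iq2xxs table; signs_table: the 128-entry ksigns_iq2xs table.
    // q2 points at the 4 uint16 words of one group of 32 quants; d_fp16 is the
    // block scale x->d converted to float; il selects which group of 8 (0..3).
    static void decode_iq2_xxs_group8(const uint64_t * grid, const uint8_t * signs_table,
                                      const uint16_t * q2, float d_fp16, int il,
                                      float * out) {
        const uint8_t  * aux8  = (const uint8_t *)q2;                 // q2[0..1]: four grid indices
        const uint32_t   aux32 = q2[2] | ((uint32_t)q2[3] << 16);     // sign bits + 4-bit scale
        // top 4 bits are the group scale; this is where the 0.25f factor enters
        const float d = d_fp16 * (0.5f + (aux32 >> 28)) * 0.25f;
        const uint8_t * g     = (const uint8_t *)(grid + aux8[il]);   // 8 grid bytes
        const uint8_t   signs = signs_table[(aux32 >> (7*il)) & 127];
        for (int j = 0; j < 8; ++j) {
            // (1u << j) plays the role of kmask_iq2xs[j] = {1, 2, 4, ..., 128}
            out[j] = d * g[j] * ((signs & (1u << j)) ? -1.0f : 1.0f);
        }
    }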
ggml-metal.m:

Five new IQ2_XXS kernels are declared in struct ggml_metal_context, created in
ggml_metal_init and released in ggml_metal_free, alongside the existing kernel
declarations (hunks @@ -88,6 +88,7 @@, @@ -106,6 +107,7 @@, @@ -121,6 +123,7 @@,
@@ -133,6 +136,7 @@, @@ -145,6 +149,7 @@, @@ -379,6 +384,7 @@, @@ -397,6 +403,7 @@,
@@ -412,6 +419,7 @@, @@ -425,6 +433,7 @@, @@ -437,6 +446,7 @@, @@ -502,6 +512,7 @@,
@@ -520,6 +531,7 @@, @@ -535,6 +547,7 @@, @@ -548,6 +561,7 @@, @@ -560,6 +574,7 @@):

+    GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
+    GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);

with the corresponding GGML_METAL_ADD_KERNEL(...) and GGML_METAL_DEL_KERNEL(...)
lines; the mul_mm variants sit inside the supportsFamily:MTLGPUFamilyApple7 branch.

In ggml_metal_graph_compute:

@@ -1541,6 +1556,7 @@ bool ggml_metal_graph_compute(
         case GGML_TYPE_Q6_K:    [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+        case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
         default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");

@@ -1653,6 +1669,12 @@ bool ggml_metal_graph_compute(
             nth1 = 32;
             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
         } break;
+    case GGML_TYPE_IQ2_XXS:
+        {
+            nth0 = 4;
+            nth1 = 16;
+            [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
+        } break;
     default:
         {
             GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);

@@ -1686,9 +1708,14 @@ bool ggml_metal_graph_compute(
     if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
         src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+        //src0t == GGML_TYPE_IQ2_XXS ||
         src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
     }
+    else if (src0t == GGML_TYPE_IQ2_XXS) {
+        [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+    }
     else if (src0t == GGML_TYPE_Q4_K) {
         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
     }

@@ -1778,6 +1805,7 @@ bool ggml_metal_graph_compute(
         case GGML_TYPE_Q6_K:    [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+        case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
         default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");

@@ -1893,6 +1921,12 @@ bool ggml_metal_graph_compute(
             nth1 = 32;
             [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
         } break;
+    case GGML_TYPE_IQ2_XXS:
+        {
+            nth0 = 4;
+            nth1 = 16;
+            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
+        } break;
     default:
         {
             GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);

@@ -1942,9 +1976,14 @@ bool ggml_metal_graph_compute(
     if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
         src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+        //src2t == GGML_TYPE_IQ2_XXS ||
         src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
         [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
     }
+    else if (src2t == GGML_TYPE_IQ2_XXS) {
+        [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+    }
     else if (src2t == GGML_TYPE_Q4_K) {
         [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
     }

@@ -1982,6 +2021,7 @@ bool ggml_metal_graph_compute(
         case GGML_TYPE_I32:     [encoder setComputePipelineState:ctx->pipeline_get_rows_i32];     break;
+        case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
         default: GGML_ASSERT(false && "not implemented");
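The setThreadgroupMemoryLength:(256*8+128) calls above size the threadgroup
buffer that kernel_mul_mv_iq2_xxs_f32_impl (next file) fills with the two
lookup tables. A small C check of that number; the variable names are
illustrative only:

    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        // 256 grid entries of 8 bytes each, followed by 128 one-byte sign entries,
        // matching how kernel_mul_mv_iq2_xxs_f32_impl carves up shared_values.
        const unsigned grid_bytes  = 256 * (unsigned)sizeof(uint64_t); // 2048
        const unsigned signs_bytes = 128 * (unsigned)sizeof(uint8_t);  //  128
        printf("%u bytes of threadgroup memory\n", grid_bytes + signs_bytes); // 2176 = 256*8+128
        return 0;
    }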
@@ -2446,6 +2446,12 @@ typedef struct {
|
|
| 2446 |
} block_q6_K;
|
| 2447 |
// 210 bytes / block
|
| 2448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2449 |
//====================================== dot products =========================
|
| 2450 |
|
| 2451 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3468,6 +3474,221 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
| 3468 |
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 3469 |
}
|
| 3470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3471 |
//============================= templates and their specializations =============================
|
| 3472 |
|
| 3473 |
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
@@ -3739,6 +3960,31 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
| 3739 |
}
|
| 3740 |
}
|
| 3741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3742 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 3743 |
kernel void kernel_get_rows(
|
| 3744 |
device const void * src0,
|
|
@@ -4278,6 +4524,7 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
|
|
| 4278 |
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
| 4279 |
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4280 |
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
|
| 4281 |
|
| 4282 |
//
|
| 4283 |
// matrix-matrix multiplication
|
|
@@ -4314,6 +4561,7 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
| 4314 |
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
| 4315 |
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4316 |
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
|
| 4317 |
|
| 4318 |
//
|
| 4319 |
// indirect matrix-matrix multiplication
|
|
@@ -4362,6 +4610,7 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
| 4362 |
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
| 4363 |
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4364 |
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
|
| 4365 |
|
| 4366 |
//
|
| 4367 |
// matrix-vector multiplication
|
|
@@ -5134,3 +5383,68 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
| 5134 |
tiisg,
|
| 5135 |
sgitg);
|
| 5136 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2446 |
} block_q6_K;
|
| 2447 |
// 210 bytes / block
|
| 2448 |
|
| 2449 |
+
typedef struct {
|
| 2450 |
+
half d;
|
| 2451 |
+
uint16_t qs[QK_K/8];
|
| 2452 |
+
} block_iq2_xxs;
|
| 2453 |
+
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
|
| 2454 |
+
|
| 2455 |
//====================================== dot products =========================
|
| 2456 |
|
| 2457 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
|
|
| 3474 |
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 3475 |
}
|
| 3476 |
|
| 3477 |
+
// ======================= "True" 2-bit
|
| 3478 |
+
|
| 3479 |
+
constexpr constant static uint64_t kgrid_iq2xxs[256] = {
|
| 3480 |
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 3481 |
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
| 3482 |
+
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
| 3483 |
+
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
|
| 3484 |
+
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
|
| 3485 |
+
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
|
| 3486 |
+
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
|
| 3487 |
+
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
|
| 3488 |
+
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
|
| 3489 |
+
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
|
| 3490 |
+
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
|
| 3491 |
+
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
|
| 3492 |
+
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
|
| 3493 |
+
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
|
| 3494 |
+
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
|
| 3495 |
+
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
|
| 3496 |
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
|
| 3497 |
+
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
|
| 3498 |
+
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
|
| 3499 |
+
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
|
| 3500 |
+
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
|
| 3501 |
+
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
|
| 3502 |
+
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
|
| 3503 |
+
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
|
| 3504 |
+
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
|
| 3505 |
+
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
|
| 3506 |
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
|
| 3507 |
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
|
| 3508 |
+
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
|
| 3509 |
+
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
|
| 3510 |
+
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
|
| 3511 |
+
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
|
| 3512 |
+
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
|
| 3513 |
+
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
|
| 3514 |
+
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
|
| 3515 |
+
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
|
| 3516 |
+
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
|
| 3517 |
+
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
|
| 3518 |
+
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
|
| 3519 |
+
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
|
| 3520 |
+
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
|
| 3521 |
+
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
|
| 3522 |
+
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
|
| 3523 |
+
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
|
| 3524 |
+
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
|
| 3525 |
+
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
|
| 3526 |
+
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
|
| 3527 |
+
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
|
| 3528 |
+
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
|
| 3529 |
+
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
|
| 3530 |
+
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
|
| 3531 |
+
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
|
| 3532 |
+
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
|
| 3533 |
+
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
|
| 3534 |
+
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
|
| 3535 |
+
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
|
| 3536 |
+
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
|
| 3537 |
+
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
|
| 3538 |
+
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
|
| 3539 |
+
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
|
| 3540 |
+
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
|
| 3541 |
+
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
|
| 3542 |
+
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
|
| 3543 |
+
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
| 3544 |
+
};
|
| 3545 |
+
|
| 3546 |
+
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
| 3547 |
+
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3548 |
+
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
| 3549 |
+
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
|
| 3550 |
+
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
|
| 3551 |
+
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
|
| 3552 |
+
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
|
| 3553 |
+
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
| 3554 |
+
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
| 3555 |
+
};
|
| 3556 |
+
|
| 3557 |
+
constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
| 3558 |
+
|
| 3559 |
+
void kernel_mul_mv_iq2_xxs_f32_impl(
|
| 3560 |
+
device const void * src0,
|
| 3561 |
+
device const float * src1,
|
| 3562 |
+
device float * dst,
|
| 3563 |
+
constant int64_t & ne00,
|
| 3564 |
+
constant int64_t & ne01,
|
| 3565 |
+
constant int64_t & ne02,
|
| 3566 |
+
constant int64_t & ne10,
|
| 3567 |
+
constant int64_t & ne12,
|
| 3568 |
+
constant int64_t & ne0,
|
| 3569 |
+
constant int64_t & ne1,
|
| 3570 |
+
constant uint & r2,
|
| 3571 |
+
constant uint & r3,
|
| 3572 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 3573 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3574 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 3575 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3576 |
+
|
| 3577 |
+
const int nb = ne00/QK_K;
|
| 3578 |
+
const int r0 = tgpig.x;
|
| 3579 |
+
const int r1 = tgpig.y;
|
| 3580 |
+
const int im = tgpig.z;
|
| 3581 |
+
|
| 3582 |
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 3583 |
+
const int ib_row = first_row * nb;
|
| 3584 |
+
|
| 3585 |
+
const uint i12 = im%ne12;
|
| 3586 |
+
const uint i13 = im/ne12;
|
| 3587 |
+
|
| 3588 |
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
| 3589 |
+
|
| 3590 |
+
device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
|
| 3591 |
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
| 3592 |
+
|
| 3593 |
+
float yl[32];
|
| 3594 |
+
float sumf[N_DST]={0.f}, all_sum;
|
| 3595 |
+
|
| 3596 |
+
const int nb32 = nb * (QK_K / 32);
|
| 3597 |
+
|
| 3598 |
+
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
| 3599 |
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
| 3600 |
+
+    {
+        int nval = 4;
+        int pos = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
+        nval = 2;
+        pos = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+#if QK_K == 256
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xxs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
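+            // q2[0..1] pack four 8-bit indices into the 256-entry value grid (8 quants each);
+            // q2[2..3] pack four 7-bit indices into the sign table plus the 4-bit scale in the top bits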
+            device const uint8_t * aux8 = (device const uint8_t *)q2;
+            const uint32_t aux32 = q2[2] | (q2[3] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float sum = 0;
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
+                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * sum;
+
+            dh += nb*sizeof(block_iq2_xxs)/2;
+            q2 += nb*sizeof(block_iq2_xxs)/2;
+        }
+
+        y4 += 32 * 32;
+    }
+#else
+    // TODO
+#endif
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
|
| 3664 |
+
kernel void kernel_mul_mv_iq2_xxs_f32(
|
| 3665 |
+
device const void * src0,
|
| 3666 |
+
device const float * src1,
|
| 3667 |
+
device float * dst,
|
| 3668 |
+
constant int64_t & ne00,
|
| 3669 |
+
constant int64_t & ne01,
|
| 3670 |
+
constant int64_t & ne02,
|
| 3671 |
+
constant uint64_t & nb00,
|
| 3672 |
+
constant uint64_t & nb01,
|
| 3673 |
+
constant uint64_t & nb02,
|
| 3674 |
+
constant int64_t & ne10,
|
| 3675 |
+
constant int64_t & ne11,
|
| 3676 |
+
constant int64_t & ne12,
|
| 3677 |
+
constant uint64_t & nb10,
|
| 3678 |
+
constant uint64_t & nb11,
|
| 3679 |
+
constant uint64_t & nb12,
|
| 3680 |
+
constant int64_t & ne0,
|
| 3681 |
+
constant int64_t & ne1,
|
| 3682 |
+
constant uint & r2,
|
| 3683 |
+
constant uint & r3,
|
| 3684 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 3685 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3686 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 3687 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3688 |
+
|
| 3689 |
+
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 3690 |
+
}
|
| 3691 |
+
|
| 3692 |
//============================= templates and their specializations =============================

// NOTE: this is not dequantizing - we are simply fitting the template

}
}

+
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
    device const void * src0,

template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;

//
// matrix-matrix multiplication

template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;

//
// indirect matrix-matrix multiplication

template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;

//
// matrix-vector multiplication

        tiisg,
        sgitg);
}
+
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
| 5388 |
+
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
| 5389 |
+
device const char * ids,
|
| 5390 |
+
device const char * src1,
|
| 5391 |
+
device float * dst,
|
| 5392 |
+
constant uint64_t & nbi1,
|
| 5393 |
+
constant int64_t & ne00,
|
| 5394 |
+
constant int64_t & ne01,
|
| 5395 |
+
constant int64_t & ne02,
|
| 5396 |
+
constant uint64_t & nb00,
|
| 5397 |
+
constant uint64_t & nb01,
|
| 5398 |
+
constant uint64_t & nb02,
|
| 5399 |
+
constant int64_t & ne10,
|
| 5400 |
+
constant int64_t & ne11,
|
| 5401 |
+
constant int64_t & ne12,
|
| 5402 |
+
constant int64_t & ne13,
|
| 5403 |
+
constant uint64_t & nb10,
|
| 5404 |
+
constant uint64_t & nb11,
|
| 5405 |
+
constant uint64_t & nb12,
|
| 5406 |
+
constant int64_t & ne0,
|
| 5407 |
+
constant int64_t & ne1,
|
| 5408 |
+
constant uint64_t & nb1,
|
| 5409 |
+
constant uint & r2,
|
| 5410 |
+
constant uint & r3,
|
| 5411 |
+
constant int & idx,
|
| 5412 |
+
device const char * src00,
|
| 5413 |
+
device const char * src01,
|
| 5414 |
+
device const char * src02,
|
| 5415 |
+
device const char * src03,
|
| 5416 |
+
device const char * src04,
|
| 5417 |
+
device const char * src05,
|
| 5418 |
+
device const char * src06,
|
| 5419 |
+
device const char * src07,
|
| 5420 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 5421 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5422 |
+
uint tiitg[[thread_index_in_threadgroup]],
|
| 5423 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 5424 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5425 |
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 5426 |
+
|
| 5427 |
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 5428 |
+
|
| 5429 |
+
tgpig.z = tgpig.z%(ne12*ne13);
|
| 5430 |
+
|
| 5431 |
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 5432 |
+
|
| 5433 |
+
kernel_mul_mv_iq2_xxs_f32_impl(
|
| 5434 |
+
src0[id],
|
| 5435 |
+
(device const float *) (src1 + bid*nb11),
|
| 5436 |
+
dst + bid*ne0,
|
| 5437 |
+
ne00,
|
| 5438 |
+
ne01,
|
| 5439 |
+
ne02,
|
| 5440 |
+
ne10,
|
| 5441 |
+
ne12,
|
| 5442 |
+
ne0,
|
| 5443 |
+
ne1,
|
| 5444 |
+
r2,
|
| 5445 |
+
r3,
|
| 5446 |
+
shared_values,
|
| 5447 |
+
tgpig,
|
| 5448 |
+
tiisg,
|
| 5449 |
+
sgitg);
|
| 5450 |
+
}
|
|
@@ -2340,6 +2340,138 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t *
    return (n/QK_K*sizeof(block_q6_K));
}

+// ====================== "True" 2-bit (de)-quantization
+
+void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
+    (void)x;
+    (void)y;
+    (void)k;
+    assert(k % QK_K == 0);
+    //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
+}
+
+static const uint64_t iq2xxs_grid[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const uint8_t ksigns_iq2xs[128] = {
+    0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+    144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+    160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+    48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+    192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+    80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+    96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
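+// Layout of a block_iq2_xxs: one fp16 scale d followed by 32 uint16_t.
+// Each group of 32 weights uses 4 of those uint16_t (two uint32_t): the first uint32_t
+// packs four 8-bit indices into iq2xxs_grid (8 quants each), the second packs four
+// 7-bit indices into ksigns_iq2xs plus a 4-bit group scale in its top bits.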
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
+            const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+        }
+    }
+}
+
+void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_iq2_xxs * restrict y = vy;
+    quantize_row_iq2_xxs_reference(x, y, k);
+}
+
+size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
+        quantize_row_iq2_xxs_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_iq2_xxs));
+}
+
//===================================== Q8_K ==============================================

void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {

@@ -2362,7 +2494,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict
            x += QK_K;
            continue;
        }
-       const float iscale = -128.f/max;
+       //const float iscale = -128.f/max;
+       // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
+       const float iscale = -127.f/max;
        for (int j = 0; j < QK_K; ++j) {
            int v = nearest_int(iscale*x[j]);
            y[i].qs[j] = MIN(127, v);
@@ -7065,3 +7199,161 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
}

#endif
+
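+// For each 7-bit sign index, keven_signs_q2xs stores the eight +1/-1 signs of one grid
+// vector as int8_t (128*8 = 1024 entries); the eighth sign keeps the number of -1s even,
+// which lets the SIMD paths below load a whole sign pattern as one 64-bit value.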
+static const int8_t keven_signs_q2xs[1024] = {
+    1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
+    1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
+    1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
+    1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
+    1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
+    1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
+    1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
+    1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
+    1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
+    1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
+    1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
+    1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
+    1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
+    1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
+    1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
+    1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
+    1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
+    1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
+    1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
+    1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
+    1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
+    1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
+    1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
+    1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
+    1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
+    1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
+    1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
+    1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
+    1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
+    1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
+    1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
+    1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+
+void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_iq2_xxs * restrict x = vx;
+    const block_q8_K    * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    int8x16x4_t q2u;
+    int8x16x4_t q2s;
+    int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = vld1q_s8_x4(q8); q8 += 64;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
+            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
+            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
+            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
+            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >>  7) & 127))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
+            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.25f * sumf;
+
+#elif defined(__AVX2__)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
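+            // _mm256_maddubs_epi16 below multiplies unsigned (grid) bytes with signed bytes,
+            // so the +1/-1 sign patterns are applied to the q8 values rather than to the grid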
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(aux32, q2, 2*sizeof(uint32_t));
+            q2 += 4;
+            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
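+    // 0.125f = 0.25f (the grid scaling used in dequantize_row_iq2_xxs)
+    //        * 0.5f  (ls above is 2*scale + 1, i.e. twice the (0.5f + scale) used there)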
+    *s = 0.125f * sumf;
+#endif
+}
@@ -165,6 +165,14 @@ typedef struct {
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");

+// (Almost) "true" 2-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 2.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_fp16_t d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
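+// For QK_K = 256 that is 2 + 64 = 66 bytes per block of 256 weights, i.e. 66*8/256 = 2.0625 bpw.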

// Quantization
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);

@@ -180,6 +188,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);

void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);

@@ -194,6 +203,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);

// Dequantization
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);

@@ -209,6 +219,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);

// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);

@@ -222,3 +233,4 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx,
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -573,6 +573,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
        .vec_dot      = ggml_vec_dot_q6_K_q8_K,
        .vec_dot_type = GGML_TYPE_Q8_K,
    },
+    [GGML_TYPE_IQ2_XXS] = {
+        .type_name            = "iq2_xxs",
+        .blck_size            = QK_K,
+        .type_size            = sizeof(block_iq2_xxs),
+        .is_quantized         = true,
+        .to_float             = (ggml_to_float_t) dequantize_row_iq2_xxs,
+        .from_float           = quantize_row_iq2_xxs,
+        .from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
+        .vec_dot              = ggml_vec_dot_iq2_xxs_q8_K,
+        .vec_dot_type         = GGML_TYPE_Q8_K,
+    },
    [GGML_TYPE_Q8_K] = {
        .type_name = "q8_K",
        .blck_size = QK_K,
@@ -2111,6 +2122,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
        case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
        case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
        case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
+       case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
        case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
        case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
    }
@@ -7436,6 +7448,7 @@ static void ggml_compute_forward_add(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
            } break;
@@ -7700,6 +7713,7 @@ static void ggml_compute_forward_add1(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
            } break;
@@ -7814,6 +7828,7 @@ static void ggml_compute_forward_acc(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
        default:
            {
                GGML_ASSERT(false);
@@ -10455,6 +10470,7 @@ static void ggml_compute_forward_out_prod(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
            } break;
@@ -10629,6 +10645,7 @@ static void ggml_compute_forward_set(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
        default:
            {
                GGML_ASSERT(false);
@@ -10823,6 +10840,7 @@ static void ggml_compute_forward_get_rows(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
            } break;
@@ -11459,6 +11477,7 @@ static void ggml_compute_forward_alibi(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_Q8_K:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
@@ -11533,6 +11552,7 @@ static void ggml_compute_forward_clamp(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_Q8_K:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
@@ -18648,6 +18668,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                block_q6_K * block = (block_q6_K*)dst + start / QK_K;
                result = ggml_quantize_q6_K(src + start, block, n, n, hist);
            } break;
+       case GGML_TYPE_IQ2_XXS:
+           {
+               GGML_ASSERT(start % QK_K == 0);
+               block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
+               result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
+           } break;
        case GGML_TYPE_F16:
            {
                int elemsize = sizeof(ggml_fp16_t);
@@ -339,6 +339,7 @@ extern "C" {
|
|
| 339 |
GGML_TYPE_Q5_K = 13,
|
| 340 |
GGML_TYPE_Q6_K = 14,
|
| 341 |
GGML_TYPE_Q8_K = 15,
|
|
|
|
| 342 |
GGML_TYPE_I8,
|
| 343 |
GGML_TYPE_I16,
|
| 344 |
GGML_TYPE_I32,
|
|
@@ -373,6 +374,7 @@ extern "C" {
|
|
| 373 |
GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
|
| 374 |
GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
|
| 375 |
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
|
|
|
|
| 376 |
};
|
| 377 |
|
| 378 |
// available tensor operations:
|
|
@@ -2067,6 +2069,7 @@ extern "C" {
|
|
| 2067 |
GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2068 |
GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2069 |
GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
|
|
|
|
| 2070 |
|
| 2071 |
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
|
| 2072 |
|
|
|
|
        GGML_TYPE_Q5_K = 13,
        GGML_TYPE_Q6_K = 14,
        GGML_TYPE_Q8_K = 15,
+       GGML_TYPE_IQ2_XXS = 16,
        GGML_TYPE_I8,
        GGML_TYPE_I16,
        GGML_TYPE_I32,
@@ -373,6 +374,7 @@ extern "C" {
        GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+       GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
    };

    // available tensor operations:
@@ -2067,6 +2069,7 @@ extern "C" {
    GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
    GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
    GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+   GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);

    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);