Spaces:
Running
Running
ggml : SOTA 2-bit quants (add IQ2_XS) (llama/4856)
Browse files
* iq2_xs: basics
* iq2_xs: this should have been in the basics
* iq2_xs: CUDA and scalar CPU works
* iq2_xs: WIP Metal
* iq2_xs: Metal now works
* iq2_xs: working, but dog slow, ARM_NEON dot product
* iq2_xs: better ARM_NEON dot product
We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when
running on the CPU.
* iq2_xs: AVX2 dot product - 19.5 t/s
* iq2_xs: faster AVX2 dot product
21.4 t/s for TG-128, 59.2 t/s for PP-512.
The latter is 2x compared to the previous version.
* iq2_xs: had forgotten to delete iq2-data.h
* Add llama enum for IQ2_XS
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
- ggml-cuda.cu +227 -5
- ggml-metal.m +36 -6
- ggml-metal.metal +374 -4
- ggml-quants.c +351 -9
- ggml-quants.h +12 -0
- ggml.c +28 -2
- ggml.h +3 -0
ggml-cuda.cu
CHANGED
|
@@ -486,6 +486,15 @@ typedef struct {
|
|
| 486 |
} block_iq2_xxs;
|
| 487 |
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
|
| 488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
#define WARP_SIZE 32
|
| 490 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
| 491 |
|
|
@@ -1328,7 +1337,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
|
|
| 1328 |
#endif
|
| 1329 |
}
|
| 1330 |
|
| 1331 |
-
static const __device__ uint64_t
|
| 1332 |
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 1333 |
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
| 1334 |
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
@@ -1395,6 +1404,137 @@ static const __device__ uint64_t kgrid_iq2xxs[256] = {
|
|
| 1395 |
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
| 1396 |
};
|
| 1397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1398 |
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
| 1399 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 1400 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -1439,7 +1579,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
|
|
| 1439 |
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 1440 |
const uint16_t * q2 = x[i].qs + 4*ib;
|
| 1441 |
const uint8_t * aux8 = (const uint8_t *)q2;
|
| 1442 |
-
const uint8_t * grid = (const uint8_t *)(
|
| 1443 |
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 1444 |
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
|
| 1445 |
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
|
|
@@ -1450,6 +1590,28 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
|
|
| 1450 |
|
| 1451 |
}
|
| 1452 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1453 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 1454 |
|
| 1455 |
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
|
@@ -3996,7 +4158,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|
| 3996 |
uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 3997 |
int sumi = 0;
|
| 3998 |
for (int l = 0; l < 4; ++l) {
|
| 3999 |
-
const uint8_t * grid = (const uint8_t *)(
|
| 4000 |
const uint8_t signs = ksigns_iq2xs[aux32 & 127];
|
| 4001 |
for (int j = 0; j < 8; ++j) {
|
| 4002 |
sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
|
@@ -4012,8 +4174,8 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|
| 4012 |
const int il = iqs%2;
|
| 4013 |
const uint16_t * q2 = bq2->qs + 4*ib32;
|
| 4014 |
const uint8_t * aux8 = (const uint8_t *)q2;
|
| 4015 |
-
const uint8_t * grid1 = (const uint8_t *)(
|
| 4016 |
-
const uint8_t * grid2 = (const uint8_t *)(
|
| 4017 |
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 4018 |
const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
|
| 4019 |
const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
|
@@ -4032,6 +4194,42 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|
| 4032 |
#endif
|
| 4033 |
}
|
| 4034 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4035 |
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
| 4036 |
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
| 4037 |
static __device__ __forceinline__ void mul_mat_q(
|
|
@@ -6035,6 +6233,12 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k,
|
|
| 6035 |
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6036 |
}
|
| 6037 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6038 |
template <typename src_t, typename dst_t>
|
| 6039 |
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
| 6040 |
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
|
@@ -6065,6 +6269,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
| 6065 |
return dequantize_row_q6_K_cuda;
|
| 6066 |
case GGML_TYPE_IQ2_XXS:
|
| 6067 |
return dequantize_row_iq2_xxs_cuda;
|
|
|
|
|
|
|
| 6068 |
case GGML_TYPE_F32:
|
| 6069 |
return convert_unary_cuda<float>;
|
| 6070 |
default:
|
|
@@ -6096,6 +6302,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|
| 6096 |
return dequantize_row_q6_K_cuda;
|
| 6097 |
case GGML_TYPE_IQ2_XXS:
|
| 6098 |
return dequantize_row_iq2_xxs_cuda;
|
|
|
|
|
|
|
| 6099 |
case GGML_TYPE_F16:
|
| 6100 |
return convert_unary_cuda<half>;
|
| 6101 |
default:
|
|
@@ -6299,6 +6507,15 @@ static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, floa
|
|
| 6299 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
| 6300 |
}
|
| 6301 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6302 |
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
| 6303 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 6304 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
|
@@ -7871,6 +8088,7 @@ static int64_t get_row_rounding(ggml_type type) {
|
|
| 7871 |
case GGML_TYPE_Q5_K:
|
| 7872 |
case GGML_TYPE_Q6_K:
|
| 7873 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 7874 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 7875 |
default:
|
| 7876 |
GGML_ASSERT(false);
|
|
@@ -7892,6 +8110,7 @@ static int64_t get_row_rounding(ggml_type type) {
|
|
| 7892 |
case GGML_TYPE_Q4_K:
|
| 7893 |
case GGML_TYPE_Q5_K:
|
| 7894 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 7895 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 7896 |
case GGML_TYPE_Q6_K:
|
| 7897 |
return 64;
|
|
@@ -7945,6 +8164,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|
| 7945 |
case GGML_TYPE_IQ2_XXS:
|
| 7946 |
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
| 7947 |
break;
|
|
|
|
|
|
|
|
|
|
| 7948 |
default:
|
| 7949 |
GGML_ASSERT(false);
|
| 7950 |
break;
|
|
|
|
| 486 |
} block_iq2_xxs;
|
| 487 |
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
|
| 488 |
|
| 489 |
+
#define QR2_XS 8
|
| 490 |
+
#define QI2_XS (QK_K / (4*QR2_XS))
|
| 491 |
+
// Storage layout of one IQ2_XS quantization block (covers QK_K values).
// The field notes below describe how the iq2_xs kernels in this file read the data.
typedef struct {
|
| 492 |
+
half d; // block-wide scale; the kernels convert it to float and multiply it with the per-group scale
|
| 493 |
+
uint16_t qs[QK_K/8]; // one entry per 8 values: low 9 bits index iq2xs_grid, high 7 bits index ksigns_iq2xs
|
| 494 |
+
uint8_t scales[QK_K/32]; // one byte per 32 values: two 4-bit scales (low nibble = first 16 values, high nibble = last 16)
|
| 495 |
+
} block_iq2_xs;
|
| 496 |
+
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
|
| 497 |
+
|
| 498 |
#define WARP_SIZE 32
|
| 499 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
| 500 |
|
|
|
|
| 1337 |
#endif
|
| 1338 |
}
|
| 1339 |
|
| 1340 |
+
static const __device__ uint64_t iq2xxs_grid[256] = {
|
| 1341 |
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 1342 |
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
| 1343 |
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
|
|
| 1404 |
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
| 1405 |
};
|
| 1406 |
|
| 1407 |
+
static const __device__ uint64_t iq2xs_grid[512] = {
|
| 1408 |
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 1409 |
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
| 1410 |
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
| 1411 |
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
| 1412 |
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
| 1413 |
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
| 1414 |
+
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
| 1415 |
+
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
| 1416 |
+
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
| 1417 |
+
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
| 1418 |
+
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
| 1419 |
+
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
| 1420 |
+
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
| 1421 |
+
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
| 1422 |
+
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
| 1423 |
+
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
| 1424 |
+
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
| 1425 |
+
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
| 1426 |
+
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
| 1427 |
+
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
| 1428 |
+
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
| 1429 |
+
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
| 1430 |
+
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
| 1431 |
+
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
| 1432 |
+
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
| 1433 |
+
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
| 1434 |
+
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
| 1435 |
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
| 1436 |
+
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
| 1437 |
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
| 1438 |
+
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
| 1439 |
+
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
| 1440 |
+
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
| 1441 |
+
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
| 1442 |
+
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
| 1443 |
+
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
| 1444 |
+
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
| 1445 |
+
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
| 1446 |
+
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
| 1447 |
+
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
| 1448 |
+
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
| 1449 |
+
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
| 1450 |
+
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
| 1451 |
+
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
| 1452 |
+
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
| 1453 |
+
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
| 1454 |
+
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
| 1455 |
+
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
| 1456 |
+
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
| 1457 |
+
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
| 1458 |
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
| 1459 |
+
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
| 1460 |
+
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
| 1461 |
+
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
| 1462 |
+
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
| 1463 |
+
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
| 1464 |
+
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
| 1465 |
+
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
| 1466 |
+
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
| 1467 |
+
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
| 1468 |
+
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
| 1469 |
+
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
| 1470 |
+
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
| 1471 |
+
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
| 1472 |
+
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
| 1473 |
+
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
| 1474 |
+
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
| 1475 |
+
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
| 1476 |
+
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
| 1477 |
+
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
| 1478 |
+
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
| 1479 |
+
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
| 1480 |
+
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
| 1481 |
+
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
| 1482 |
+
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
| 1483 |
+
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
| 1484 |
+
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
| 1485 |
+
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
| 1486 |
+
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
| 1487 |
+
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
| 1488 |
+
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
| 1489 |
+
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
| 1490 |
+
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
| 1491 |
+
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
| 1492 |
+
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
| 1493 |
+
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
| 1494 |
+
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
| 1495 |
+
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
| 1496 |
+
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
| 1497 |
+
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
| 1498 |
+
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
| 1499 |
+
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
| 1500 |
+
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
| 1501 |
+
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
| 1502 |
+
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
| 1503 |
+
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
| 1504 |
+
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
| 1505 |
+
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
| 1506 |
+
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
| 1507 |
+
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
| 1508 |
+
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
| 1509 |
+
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
| 1510 |
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
| 1511 |
+
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
| 1512 |
+
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
| 1513 |
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
| 1514 |
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
| 1515 |
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
| 1516 |
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
| 1517 |
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
| 1518 |
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
| 1519 |
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
| 1520 |
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
| 1521 |
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
| 1522 |
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
| 1523 |
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
| 1524 |
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
| 1525 |
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
| 1526 |
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
| 1527 |
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
| 1528 |
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
| 1529 |
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
| 1530 |
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
| 1531 |
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
| 1532 |
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
| 1533 |
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
| 1534 |
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
| 1535 |
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 1536 |
+
};
|
| 1537 |
+
|
| 1538 |
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
| 1539 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 1540 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 1579 |
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 1580 |
const uint16_t * q2 = x[i].qs + 4*ib;
|
| 1581 |
const uint8_t * aux8 = (const uint8_t *)q2;
|
| 1582 |
+
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
|
| 1583 |
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 1584 |
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
|
| 1585 |
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
|
|
|
|
| 1590 |
|
| 1591 |
}
|
| 1592 |
|
| 1593 |
+
// Dequantizes IQ2_XS data into dst_t values.
//   vx : array of block_iq2_xs; CUDA block blockIdx.x handles element x[blockIdx.x]
//   yy : output; each CUDA block writes into yy[i*QK_K ...], each thread writes 8 consecutive values
// The il/ib mapping below assumes threadIdx.x is in 0..31 (the launcher in this file uses 32 threads).
// Only QK_K == 256 is implemented; any other configuration hits assert(false).
template<typename dst_t>
|
| 1594 |
+
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
| 1595 |
+
|
| 1596 |
+
const int i = blockIdx.x; // index of the quant block handled by this CUDA block
|
| 1597 |
+
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
| 1598 |
+
|
| 1599 |
+
const int tid = threadIdx.x;
|
| 1600 |
+
#if QK_K == 256
|
| 1601 |
+
const int il = tid/8; // 0...3: which 8-value slice inside the 32-value group
|
| 1602 |
+
const int ib = tid%8; // 0...7: which 32-value group inside the block
|
| 1603 |
+
dst_t * y = yy + i*QK_K + 32*ib + 8*il; // first of the 8 outputs written by this thread
|
| 1604 |
+
const uint16_t * q2 = x[i].qs + 4*ib; // the four 16-bit codes of group ib
|
| 1605 |
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); // low 9 bits pick a 64-bit grid entry, read as 8 bytes
|
| 1606 |
+
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; // low nibble for il 0,1; high nibble for il 2,3
|
| 1607 |
+
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; // high 7 bits pick the sign pattern byte
|
| 1608 |
+
// A non-zero (signs & kmask_iq2xs[j]) negates value j; kmask_iq2xs is defined outside this excerpt (presumably one bit per j).
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
| 1609 |
+
#else
|
| 1610 |
+
assert(false);
|
| 1611 |
+
#endif
|
| 1612 |
+
|
| 1613 |
+
}
|
| 1614 |
+
|
| 1615 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 1616 |
|
| 1617 |
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
|
|
|
| 4158 |
uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 4159 |
int sumi = 0;
|
| 4160 |
for (int l = 0; l < 4; ++l) {
|
| 4161 |
+
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
|
| 4162 |
const uint8_t signs = ksigns_iq2xs[aux32 & 127];
|
| 4163 |
for (int j = 0; j < 8; ++j) {
|
| 4164 |
sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
|
|
|
| 4174 |
const int il = iqs%2;
|
| 4175 |
const uint16_t * q2 = bq2->qs + 4*ib32;
|
| 4176 |
const uint8_t * aux8 = (const uint8_t *)q2;
|
| 4177 |
+
const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
| 4178 |
+
const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
| 4179 |
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 4180 |
const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
|
| 4181 |
const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
|
|
|
| 4194 |
#endif
|
| 4195 |
}
|
| 4196 |
|
| 4197 |
+
// Dot product of one 32-value group of an IQ2_XS block with the matching q8_1 block.
//   vbq   : pointer to a block_iq2_xs
//   bq8_1 : q8_1 blocks; element bq8_1[iqs] supplies the 32 int8 values and the scale ds.x
//   iqs   : index of the 32-value group inside the IQ2_XS block
// Returns the scaled partial dot product as float. Only QK_K == 256 is implemented;
// otherwise the function asserts and returns 0.
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
| 4198 |
+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
| 4199 |
+
#if QK_K == 256
|
| 4200 |
+
const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
|
| 4201 |
+
|
| 4202 |
+
const int ib32 = iqs;
|
| 4203 |
+
const uint16_t * q2 = bq2->qs + 4*ib32; // the four 16-bit codes of this group
|
| 4204 |
+
const int8_t * q8 = bq8_1[ib32].qs;
|
| 4205 |
+
const uint8_t ls1 = bq2->scales[ib32] & 0xf; // 4-bit scale applied to the first two codes (16 values)
|
| 4206 |
+
const uint8_t ls2 = bq2->scales[ib32] >> 4; // 4-bit scale applied to the last two codes (16 values)
|
| 4207 |
+
int sumi1 = 0; // integer accumulator for codes 0 and 1
|
| 4208 |
+
for (int l = 0; l < 2; ++l) {
|
| 4209 |
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); // low 9 bits: grid entry read as 8 bytes
|
| 4210 |
+
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; // high 7 bits: sign pattern byte
|
| 4211 |
+
for (int j = 0; j < 8; ++j) {
|
| 4212 |
+
sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
| 4213 |
+
}
|
| 4214 |
+
q8 += 8; // advance to the next 8 int8 values
|
| 4215 |
+
}
|
| 4216 |
+
int sumi2 = 0; // integer accumulator for codes 2 and 3
|
| 4217 |
+
for (int l = 2; l < 4; ++l) {
|
| 4218 |
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
|
| 4219 |
+
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
|
| 4220 |
+
for (int j = 0; j < 8; ++j) {
|
| 4221 |
+
sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
| 4222 |
+
}
|
| 4223 |
+
q8 += 8;
|
| 4224 |
+
}
|
| 4225 |
+
const float d = (float)bq2->d * (float)bq8_1[ib32].ds.x * 0.25f; // combined scale; ds.x is presumably the q8_1 block scale (type declared elsewhere)
|
| 4226 |
+
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
|
| 4227 |
+
#else
|
| 4228 |
+
assert(false);
|
| 4229 |
+
return 0.f;
|
| 4230 |
+
#endif
|
| 4231 |
+
}
|
| 4232 |
+
|
| 4233 |
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
| 4234 |
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
| 4235 |
static __device__ __forceinline__ void mul_mat_q(
|
|
|
|
| 6233 |
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6234 |
}
|
| 6235 |
|
| 6236 |
+
// Host-side launcher: dequantizes k IQ2_XS-encoded values from vx into y on the given stream.
// Launches one 32-thread CUDA block per QK_K values.
template<typename dst_t>
|
| 6237 |
+
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
| 6238 |
+
const int nb = k / QK_K; // integer division: assumes k is a multiple of QK_K (not checked here) - TODO confirm with callers
|
| 6239 |
+
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6240 |
+
}
|
| 6241 |
+
|
| 6242 |
template <typename src_t, typename dst_t>
|
| 6243 |
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
| 6244 |
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
|
|
|
| 6269 |
return dequantize_row_q6_K_cuda;
|
| 6270 |
case GGML_TYPE_IQ2_XXS:
|
| 6271 |
return dequantize_row_iq2_xxs_cuda;
|
| 6272 |
+
case GGML_TYPE_IQ2_XS:
|
| 6273 |
+
return dequantize_row_iq2_xs_cuda;
|
| 6274 |
case GGML_TYPE_F32:
|
| 6275 |
return convert_unary_cuda<float>;
|
| 6276 |
default:
|
|
|
|
| 6302 |
return dequantize_row_q6_K_cuda;
|
| 6303 |
case GGML_TYPE_IQ2_XXS:
|
| 6304 |
return dequantize_row_iq2_xxs_cuda;
|
| 6305 |
+
case GGML_TYPE_IQ2_XS:
|
| 6306 |
+
return dequantize_row_iq2_xs_cuda;
|
| 6307 |
case GGML_TYPE_F16:
|
| 6308 |
return convert_unary_cuda<half>;
|
| 6309 |
default:
|
|
|
|
| 6507 |
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
| 6508 |
}
|
| 6509 |
|
| 6510 |
+
// Host-side launcher for the IQ2_XS x q8_1 matrix-vector kernel (mul_mat_vec_q instantiated
// with block_iq2_xs and vec_dot_iq2_xs_q8_1) on the given stream.
//   vx, vy, dst  : forwarded unchanged to the kernel
//   ncols, nrows : forwarded unchanged; ncols must be a multiple of QK_K (asserted below)
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
| 6511 |
+
GGML_ASSERT(ncols % QK_K == 0);
|
| 6512 |
+
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; // ceil(nrows / GGML_CUDA_MMV_Y)
|
| 6513 |
+
const dim3 block_nums(block_num_y, 1, 1); // the block count is passed as the grid's x dimension
|
| 6514 |
+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); // WARP_SIZE threads in x, GGML_CUDA_MMV_Y in y
|
| 6515 |
+
mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
| 6516 |
+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
| 6517 |
+
}
|
| 6518 |
+
|
| 6519 |
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
| 6520 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 6521 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
|
|
|
| 8088 |
case GGML_TYPE_Q5_K:
|
| 8089 |
case GGML_TYPE_Q6_K:
|
| 8090 |
case GGML_TYPE_IQ2_XXS:
|
| 8091 |
+
case GGML_TYPE_IQ2_XS:
|
| 8092 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 8093 |
default:
|
| 8094 |
GGML_ASSERT(false);
|
|
|
|
| 8110 |
case GGML_TYPE_Q4_K:
|
| 8111 |
case GGML_TYPE_Q5_K:
|
| 8112 |
case GGML_TYPE_IQ2_XXS:
|
| 8113 |
+
case GGML_TYPE_IQ2_XS:
|
| 8114 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 8115 |
case GGML_TYPE_Q6_K:
|
| 8116 |
return 64;
|
|
|
|
| 8164 |
case GGML_TYPE_IQ2_XXS:
|
| 8165 |
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
| 8166 |
break;
|
| 8167 |
+
case GGML_TYPE_IQ2_XS:
|
| 8168 |
+
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
| 8169 |
+
break;
|
| 8170 |
default:
|
| 8171 |
GGML_ASSERT(false);
|
| 8172 |
break;
|
ggml-metal.m
CHANGED
|
@@ -89,6 +89,7 @@ struct ggml_metal_context {
|
|
| 89 |
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
| 90 |
GGML_METAL_DECL_KERNEL(get_rows_i32);
|
| 91 |
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
|
|
|
|
| 92 |
GGML_METAL_DECL_KERNEL(rms_norm);
|
| 93 |
GGML_METAL_DECL_KERNEL(group_norm);
|
| 94 |
GGML_METAL_DECL_KERNEL(norm);
|
|
@@ -108,6 +109,7 @@ struct ggml_metal_context {
|
|
| 108 |
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
| 109 |
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
| 110 |
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
|
|
|
|
| 111 |
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
| 112 |
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
| 113 |
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
|
@@ -124,6 +126,7 @@ struct ggml_metal_context {
|
|
| 124 |
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
| 125 |
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
| 126 |
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
|
|
|
| 127 |
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
| 128 |
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
| 129 |
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
|
@@ -137,6 +140,7 @@ struct ggml_metal_context {
|
|
| 137 |
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
| 138 |
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
| 139 |
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
|
|
|
|
| 140 |
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
| 141 |
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
| 142 |
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
|
@@ -150,6 +154,7 @@ struct ggml_metal_context {
|
|
| 150 |
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
| 151 |
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
| 152 |
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
|
|
|
| 153 |
GGML_METAL_DECL_KERNEL(rope_f32);
|
| 154 |
GGML_METAL_DECL_KERNEL(rope_f16);
|
| 155 |
GGML_METAL_DECL_KERNEL(alibi_f32);
|
|
@@ -385,6 +390,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 385 |
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
| 386 |
GGML_METAL_ADD_KERNEL(get_rows_i32);
|
| 387 |
GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
|
|
|
|
| 388 |
GGML_METAL_ADD_KERNEL(rms_norm);
|
| 389 |
GGML_METAL_ADD_KERNEL(group_norm);
|
| 390 |
GGML_METAL_ADD_KERNEL(norm);
|
|
@@ -404,6 +410,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 404 |
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
| 405 |
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
| 406 |
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
|
|
|
|
| 407 |
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
| 408 |
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
| 409 |
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
|
@@ -420,6 +427,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 420 |
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
| 421 |
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
| 422 |
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
|
|
|
|
| 423 |
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
| 424 |
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
| 425 |
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
|
@@ -434,6 +442,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 434 |
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
| 435 |
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
| 436 |
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
|
|
|
|
| 437 |
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
|
| 438 |
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
|
| 439 |
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
|
|
@@ -447,6 +456,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 447 |
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
|
| 448 |
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
|
| 449 |
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
|
|
|
|
| 450 |
}
|
| 451 |
GGML_METAL_ADD_KERNEL(rope_f32);
|
| 452 |
GGML_METAL_ADD_KERNEL(rope_f16);
|
|
@@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
| 513 |
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
| 514 |
GGML_METAL_DEL_KERNEL(get_rows_i32);
|
| 515 |
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
|
|
|
|
| 516 |
GGML_METAL_DEL_KERNEL(rms_norm);
|
| 517 |
GGML_METAL_DEL_KERNEL(group_norm);
|
| 518 |
GGML_METAL_DEL_KERNEL(norm);
|
|
@@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
| 532 |
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
| 533 |
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
| 534 |
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
|
|
|
|
| 535 |
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
| 536 |
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
| 537 |
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
|
@@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
| 548 |
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
| 549 |
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
| 550 |
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
|
|
|
| 551 |
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
| 552 |
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
| 553 |
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
|
@@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
| 562 |
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
| 563 |
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
| 564 |
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
|
|
|
|
| 565 |
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
| 566 |
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
| 567 |
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
|
@@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
| 575 |
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
| 576 |
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
| 577 |
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
|
|
|
| 578 |
}
|
| 579 |
GGML_METAL_DEL_KERNEL(rope_f32);
|
| 580 |
GGML_METAL_DEL_KERNEL(rope_f16);
|
|
@@ -1561,6 +1576,7 @@ bool ggml_metal_graph_compute(
|
|
| 1561 |
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
| 1562 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
| 1563 |
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
|
|
|
|
| 1564 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
| 1565 |
}
|
| 1566 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
@@ -1679,6 +1695,12 @@ bool ggml_metal_graph_compute(
|
|
| 1679 |
nth1 = 16;
|
| 1680 |
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
|
| 1681 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1682 |
default:
|
| 1683 |
{
|
| 1684 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
@@ -1712,12 +1734,12 @@ bool ggml_metal_graph_compute(
|
|
| 1712 |
|
| 1713 |
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
| 1714 |
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
| 1715 |
-
//src0t == GGML_TYPE_IQ2_XXS ||
|
| 1716 |
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
| 1717 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1718 |
}
|
| 1719 |
-
else if (src0t == GGML_TYPE_IQ2_XXS) {
|
| 1720 |
-
|
|
|
|
| 1721 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1722 |
}
|
| 1723 |
else if (src0t == GGML_TYPE_Q4_K) {
|
|
@@ -1810,6 +1832,7 @@ bool ggml_metal_graph_compute(
|
|
| 1810 |
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
| 1811 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
| 1812 |
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
|
|
|
|
| 1813 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
| 1814 |
}
|
| 1815 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
@@ -1931,6 +1954,12 @@ bool ggml_metal_graph_compute(
|
|
| 1931 |
nth1 = 16;
|
| 1932 |
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
|
| 1933 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1934 |
default:
|
| 1935 |
{
|
| 1936 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
@@ -1980,12 +2009,12 @@ bool ggml_metal_graph_compute(
|
|
| 1980 |
|
| 1981 |
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
| 1982 |
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
| 1983 |
-
//src2t == GGML_TYPE_IQ2_XXS ||
|
| 1984 |
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
| 1985 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1986 |
}
|
| 1987 |
-
else if (src2t == GGML_TYPE_IQ2_XXS) {
|
| 1988 |
-
|
|
|
|
| 1989 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1990 |
}
|
| 1991 |
else if (src2t == GGML_TYPE_Q4_K) {
|
|
@@ -2026,6 +2055,7 @@ bool ggml_metal_graph_compute(
|
|
| 2026 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
| 2027 |
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
|
| 2028 |
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
|
|
|
|
| 2029 |
default: GGML_ASSERT(false && "not implemented");
|
| 2030 |
}
|
| 2031 |
|
|
|
|
| 89 |
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
| 90 |
GGML_METAL_DECL_KERNEL(get_rows_i32);
|
| 91 |
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
|
| 92 |
+
GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
|
| 93 |
GGML_METAL_DECL_KERNEL(rms_norm);
|
| 94 |
GGML_METAL_DECL_KERNEL(group_norm);
|
| 95 |
GGML_METAL_DECL_KERNEL(norm);
|
|
|
|
| 109 |
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
| 110 |
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
| 111 |
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
|
| 112 |
+
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
|
| 113 |
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
| 114 |
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
| 115 |
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
|
|
|
| 126 |
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
| 127 |
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
| 128 |
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
| 129 |
+
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
|
| 130 |
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
| 131 |
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
| 132 |
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
|
|
|
| 140 |
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
| 141 |
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
| 142 |
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
|
| 143 |
+
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
|
| 144 |
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
| 145 |
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
| 146 |
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
|
|
|
| 154 |
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
| 155 |
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
| 156 |
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
| 157 |
+
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
|
| 158 |
GGML_METAL_DECL_KERNEL(rope_f32);
|
| 159 |
GGML_METAL_DECL_KERNEL(rope_f16);
|
| 160 |
GGML_METAL_DECL_KERNEL(alibi_f32);
|
|
|
|
| 390 |
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
| 391 |
GGML_METAL_ADD_KERNEL(get_rows_i32);
|
| 392 |
GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
|
| 393 |
+
GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
|
| 394 |
GGML_METAL_ADD_KERNEL(rms_norm);
|
| 395 |
GGML_METAL_ADD_KERNEL(group_norm);
|
| 396 |
GGML_METAL_ADD_KERNEL(norm);
|
|
|
|
| 410 |
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
| 411 |
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
| 412 |
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
|
| 413 |
+
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
|
| 414 |
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
| 415 |
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
| 416 |
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
|
|
|
| 427 |
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
| 428 |
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
| 429 |
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
|
| 430 |
+
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
|
| 431 |
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
| 432 |
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
| 433 |
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
|
|
|
| 442 |
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
| 443 |
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
| 444 |
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
|
| 445 |
+
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
|
| 446 |
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
|
| 447 |
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
|
| 448 |
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
|
|
|
|
| 456 |
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
|
| 457 |
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
|
| 458 |
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
|
| 459 |
+
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
|
| 460 |
}
|
| 461 |
GGML_METAL_ADD_KERNEL(rope_f32);
|
| 462 |
GGML_METAL_ADD_KERNEL(rope_f16);
|
|
|
|
| 523 |
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
| 524 |
GGML_METAL_DEL_KERNEL(get_rows_i32);
|
| 525 |
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
|
| 526 |
+
GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
|
| 527 |
GGML_METAL_DEL_KERNEL(rms_norm);
|
| 528 |
GGML_METAL_DEL_KERNEL(group_norm);
|
| 529 |
GGML_METAL_DEL_KERNEL(norm);
|
|
|
|
| 543 |
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
| 544 |
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
| 545 |
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
|
| 546 |
+
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
|
| 547 |
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
| 548 |
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
| 549 |
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
|
|
|
| 560 |
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
| 561 |
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
| 562 |
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
| 563 |
+
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
|
| 564 |
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
| 565 |
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
| 566 |
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
|
|
|
| 575 |
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
| 576 |
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
| 577 |
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
|
| 578 |
+
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
|
| 579 |
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
| 580 |
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
| 581 |
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
|
|
|
| 589 |
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
| 590 |
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
| 591 |
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
| 592 |
+
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
|
| 593 |
}
|
| 594 |
GGML_METAL_DEL_KERNEL(rope_f32);
|
| 595 |
GGML_METAL_DEL_KERNEL(rope_f16);
|
|
|
|
| 1576 |
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
| 1577 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
| 1578 |
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
|
| 1579 |
+
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
|
| 1580 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
| 1581 |
}
|
| 1582 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
| 1695 |
nth1 = 16;
|
| 1696 |
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
|
| 1697 |
} break;
|
| 1698 |
+
case GGML_TYPE_IQ2_XS:
|
| 1699 |
+
{
|
| 1700 |
+
nth0 = 4;
|
| 1701 |
+
nth1 = 16;
|
| 1702 |
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
|
| 1703 |
+
} break;
|
| 1704 |
default:
|
| 1705 |
{
|
| 1706 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
|
|
| 1734 |
|
| 1735 |
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
| 1736 |
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
|
|
|
| 1737 |
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
| 1738 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1739 |
}
|
| 1740 |
+
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
| 1741 |
+
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
| 1742 |
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1743 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1744 |
}
|
| 1745 |
else if (src0t == GGML_TYPE_Q4_K) {
|
|
|
|
| 1832 |
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
| 1833 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
| 1834 |
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
|
| 1835 |
+
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
|
| 1836 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
| 1837 |
}
|
| 1838 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
| 1954 |
nth1 = 16;
|
| 1955 |
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
|
| 1956 |
} break;
|
| 1957 |
+
case GGML_TYPE_IQ2_XS:
|
| 1958 |
+
{
|
| 1959 |
+
nth0 = 4;
|
| 1960 |
+
nth1 = 16;
|
| 1961 |
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
|
| 1962 |
+
} break;
|
| 1963 |
default:
|
| 1964 |
{
|
| 1965 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
|
|
| 2009 |
|
| 2010 |
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
| 2011 |
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
|
|
|
| 2012 |
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
| 2013 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2014 |
}
|
| 2015 |
+
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
| 2016 |
+
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
| 2017 |
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 2018 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2019 |
}
|
| 2020 |
else if (src2t == GGML_TYPE_Q4_K) {
|
|
|
|
| 2055 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
| 2056 |
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
|
| 2057 |
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
|
| 2058 |
+
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
|
| 2059 |
default: GGML_ASSERT(false && "not implemented");
|
| 2060 |
}
|
| 2061 |
|
ggml-metal.metal
CHANGED
|
@@ -2452,6 +2452,13 @@ typedef struct {
|
|
| 2452 |
} block_iq2_xxs;
|
| 2453 |
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
|
| 2454 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2455 |
//====================================== dot products =========================
|
| 2456 |
|
| 2457 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3476,7 +3483,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
| 3476 |
|
| 3477 |
// ======================= "True" 2-bit
|
| 3478 |
|
| 3479 |
-
constexpr constant static uint64_t
|
| 3480 |
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 3481 |
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
| 3482 |
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
@@ -3543,6 +3550,137 @@ constexpr constant static uint64_t kgrid_iq2xxs[256] = {
|
|
| 3543 |
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
| 3544 |
};
|
| 3545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3546 |
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
| 3547 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3548 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -3600,7 +3738,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
| 3600 |
{
|
| 3601 |
int nval = 4;
|
| 3602 |
int pos = (32*sgitg + tiisg)*nval;
|
| 3603 |
-
for (int i = 0; i < nval; ++i) values[pos + i] =
|
| 3604 |
nval = 2;
|
| 3605 |
pos = (32*sgitg + tiisg)*nval;
|
| 3606 |
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
@@ -3689,6 +3827,149 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
| 3689 |
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 3690 |
}
|
| 3691 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3692 |
//============================= templates and their specializations =============================
|
| 3693 |
|
| 3694 |
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
@@ -3973,18 +4254,39 @@ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x
|
|
| 3973 |
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
| 3974 |
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
| 3975 |
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
| 3976 |
-
constant uint8_t * grid = (constant uint8_t *)(
|
| 3977 |
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
| 3978 |
for (int i = 0; i < 8; ++i) {
|
| 3979 |
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
| 3980 |
}
|
| 3981 |
-
grid = (constant uint8_t *)(
|
| 3982 |
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
| 3983 |
for (int i = 0; i < 8; ++i) {
|
| 3984 |
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
| 3985 |
}
|
| 3986 |
}
|
| 3987 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3988 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 3989 |
kernel void kernel_get_rows(
|
| 3990 |
device const void * src0,
|
|
@@ -4525,6 +4827,7 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
| 4525 |
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4526 |
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4527 |
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
|
| 4528 |
|
| 4529 |
//
|
| 4530 |
// matrix-matrix multiplication
|
|
@@ -4562,6 +4865,7 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
| 4562 |
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4563 |
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4564 |
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
|
| 4565 |
|
| 4566 |
//
|
| 4567 |
// indirect matrix-matrix multiplication
|
|
@@ -4611,6 +4915,7 @@ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
| 4611 |
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4612 |
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4613 |
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
|
| 4614 |
|
| 4615 |
//
|
| 4616 |
// matrix-vector multiplication
|
|
@@ -5448,3 +5753,68 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
| 5448 |
tiisg,
|
| 5449 |
sgitg);
|
| 5450 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2452 |
} block_iq2_xxs;
|
| 2453 |
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
|
| 2454 |
|
| 2455 |
+
typedef struct {
|
| 2456 |
+
half d;
|
| 2457 |
+
uint16_t qs[QK_K/8];
|
| 2458 |
+
uint8_t scales[QK_K/32];
|
| 2459 |
+
} block_iq2_xs;
|
| 2460 |
+
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
| 2461 |
+
|
| 2462 |
//====================================== dot products =========================
|
| 2463 |
|
| 2464 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
|
|
| 3483 |
|
| 3484 |
// ======================= "True" 2-bit
|
| 3485 |
|
| 3486 |
+
constexpr constant static uint64_t iq2xxs_grid[256] = {
|
| 3487 |
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 3488 |
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
| 3489 |
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
|
|
| 3550 |
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
| 3551 |
};
|
| 3552 |
|
| 3553 |
+
constexpr constant static uint64_t iq2xs_grid[512] = {
|
| 3554 |
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 3555 |
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
| 3556 |
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
| 3557 |
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
| 3558 |
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
| 3559 |
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
| 3560 |
+
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
| 3561 |
+
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
| 3562 |
+
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
| 3563 |
+
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
| 3564 |
+
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
| 3565 |
+
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
| 3566 |
+
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
| 3567 |
+
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
| 3568 |
+
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
| 3569 |
+
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
| 3570 |
+
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
| 3571 |
+
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
| 3572 |
+
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
| 3573 |
+
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
| 3574 |
+
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
| 3575 |
+
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
| 3576 |
+
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
| 3577 |
+
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
| 3578 |
+
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
| 3579 |
+
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
| 3580 |
+
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
| 3581 |
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
| 3582 |
+
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
| 3583 |
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
| 3584 |
+
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
| 3585 |
+
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
| 3586 |
+
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
| 3587 |
+
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
| 3588 |
+
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
| 3589 |
+
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
| 3590 |
+
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
| 3591 |
+
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
| 3592 |
+
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
| 3593 |
+
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
| 3594 |
+
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
| 3595 |
+
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
| 3596 |
+
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
| 3597 |
+
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
| 3598 |
+
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
| 3599 |
+
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
| 3600 |
+
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
| 3601 |
+
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
| 3602 |
+
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
| 3603 |
+
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
| 3604 |
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
| 3605 |
+
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
| 3606 |
+
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
| 3607 |
+
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
| 3608 |
+
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
| 3609 |
+
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
| 3610 |
+
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
| 3611 |
+
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
| 3612 |
+
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
| 3613 |
+
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
| 3614 |
+
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
| 3615 |
+
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
| 3616 |
+
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
| 3617 |
+
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
| 3618 |
+
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
| 3619 |
+
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
| 3620 |
+
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
| 3621 |
+
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
| 3622 |
+
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
| 3623 |
+
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
| 3624 |
+
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
| 3625 |
+
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
| 3626 |
+
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
| 3627 |
+
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
| 3628 |
+
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
| 3629 |
+
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
| 3630 |
+
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
| 3631 |
+
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
| 3632 |
+
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
| 3633 |
+
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
| 3634 |
+
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
| 3635 |
+
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
| 3636 |
+
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
| 3637 |
+
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
| 3638 |
+
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
| 3639 |
+
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
| 3640 |
+
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
| 3641 |
+
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
| 3642 |
+
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
| 3643 |
+
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
| 3644 |
+
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
| 3645 |
+
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
| 3646 |
+
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
| 3647 |
+
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
| 3648 |
+
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
| 3649 |
+
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
| 3650 |
+
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
| 3651 |
+
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
| 3652 |
+
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
| 3653 |
+
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
| 3654 |
+
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
| 3655 |
+
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
| 3656 |
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
| 3657 |
+
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
| 3658 |
+
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
| 3659 |
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
| 3660 |
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
| 3661 |
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
| 3662 |
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
| 3663 |
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
| 3664 |
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
| 3665 |
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
| 3666 |
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
| 3667 |
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
| 3668 |
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
| 3669 |
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
| 3670 |
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
| 3671 |
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
| 3672 |
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
| 3673 |
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
| 3674 |
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
| 3675 |
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
| 3676 |
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
| 3677 |
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
| 3678 |
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
| 3679 |
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
| 3680 |
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
| 3681 |
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 3682 |
+
};
|
| 3683 |
+
|
| 3684 |
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
| 3685 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3686 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 3738 |
{
|
| 3739 |
int nval = 4;
|
| 3740 |
int pos = (32*sgitg + tiisg)*nval;
|
| 3741 |
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
|
| 3742 |
nval = 2;
|
| 3743 |
pos = (32*sgitg + tiisg)*nval;
|
| 3744 |
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
|
|
| 3827 |
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 3828 |
}
|
| 3829 |
|
| 3830 |
+
// Matrix-vector product: IQ2_XS-quantized rows of src0 times f32 values of src1.
//
// Work split, as expressed by the indexing below:
//   - tgpig.x together with sgitg selects a group of N_DST consecutive src0 rows
//   - tgpig.y offsets src1 by r1*ne10 elements, tgpig.z is the batch index
//   - the lanes of a simdgroup stride over the 32-value sub-blocks of a row; the
//     per-lane partial sums are combined with simd_sum and lane 0 writes dst
//
// shared_values is threadgroup scratch: 512 uint64_t grid entries, followed by the
// sign bytes staged below.
// NOTE(review): the staging loops fill the whole grid only if (32*sgitg + tiisg)
// covers 0..63 (two simdgroups of 32 threads) - confirm against the host-side dispatch.
void kernel_mul_mv_iq2_xs_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;   // blocks of QK_K values per ne00 elements
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    // first of the N_DST rows handled by this simdgroup, and its block index
    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // r2 / r3 divide the batch indices before they are turned into a src0 block offset
    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST]={0.f};

    const int nb32 = nb * (QK_K / 32);   // number of 32-value sub-blocks

    // Stage the lookup tables in threadgroup memory: every thread copies
    // 8 grid entries and 2 sign bytes, then all threads meet at the barrier.
    threadgroup uint64_t * grid_tg  = (threadgroup uint64_t *)shared_values;
    threadgroup uint8_t  * signs_tg = (threadgroup uint8_t  *)(grid_tg + 512);

    const int grid_pos = (32*sgitg + tiisg)*8;
    for (int i = 0; i < 8; ++i) {
        grid_tg[grid_pos + i] = iq2xs_grid[grid_pos + i];
    }

    const int sign_pos = (32*sgitg + tiisg)*2;
    for (int i = 0; i < 2; ++i) {
        signs_tg[sign_pos + i] = ksigns_iq2xs[sign_pos + i];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

#if QK_K == 256
    const int lane = tiisg;

    device const float * y4 = y + 32 * lane;

    for (int ib32 = lane; ib32 < nb32; ib32 += 32) {

        // the 32 activations that pair with this sub-block
        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);   // block index
        const int ib  = ib32 % (QK_K / 32);   // sub-block index inside the block

        device const block_iq2_xs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const uint8_t  * sc = xr->scales + ib;
        device const half     * dh = &xr->d;

        for (int row = 0; row < N_DST; row++) {

            const float   db  = dh[0];
            const uint8_t ls1 = sc[0] & 0xf;   // low nibble:  scale for q2[0], q2[1]
            const uint8_t ls2 = sc[0] >> 4;    // high nibble: scale for q2[2], q2[3]
            const float   d1  = db * (0.5f + ls1);
            const float   d2  = db * (0.5f + ls2);

            // partial[0] accumulates q2[0..1], partial[1] accumulates q2[2..3]
            float partial[2] = {0, 0};
            for (int l = 0; l < 4; ++l) {
                const uint16_t code = q2[l];
                // low 9 bits pick the grid entry (8 bytes), upper 7 bits pick the sign pattern
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(grid_tg + (code & 511));
                const uint8_t signs = signs_tg[(code >> 9)];
                for (int j = 0; j < 8; ++j) {
                    partial[l/2] += yl[8*l + j] * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f);
                }
            }
            sumf[row] += d1 * partial[0] + d2 * partial[1];

            // step to the same sub-block in the next row, nb blocks further on;
            // dh and q2 point at 2-byte elements, sc at 1-byte elements
            dh += nb*sizeof(block_iq2_xs)/2;
            q2 += nb*sizeof(block_iq2_xs)/2;
            sc += nb*sizeof(block_iq2_xs);
        }

        y4 += 32 * 32;
    }
#else
    // TODO
#endif

    // combine the per-lane partial sums; only lane 0 stores the result
    for (int row = 0; row < N_DST; ++row) {
        const float total = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = total * 0.25f;
        }
    }
}
|
| 3943 |
+
|
| 3944 |
+
// Kernel entry point for the IQ2_XS matrix-vector product.
// It forwards to kernel_mul_mv_iq2_xs_f32_impl; the nb0x/nb1x strides and ne11
// are part of the argument list but are not used here.
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq2_xs_f32_impl(
        src0,
        src1,
        dst,
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        shared_values,
        tgpig,
        tiisg,
        sgitg);
}
|
| 3972 |
+
|
| 3973 |
//============================= templates and their specializations =============================
|
| 3974 |
|
| 3975 |
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
|
|
| 4254 |
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
| 4255 |
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
| 4256 |
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
| 4257 |
+
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
| 4258 |
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
| 4259 |
for (int i = 0; i < 8; ++i) {
|
| 4260 |
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
| 4261 |
}
|
| 4262 |
+
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
| 4263 |
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
| 4264 |
for (int i = 0; i < 8; ++i) {
|
| 4265 |
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
| 4266 |
}
|
| 4267 |
}
|
| 4268 |
|
| 4269 |
+
// Dequantizes 16 values of an IQ2_XS block into a 4x4 register tile.
//
// xb  - the quantized block
// il  - 0...15 for QK_K = 256; il/2 is the index of the block of 32 and il%2
//       selects the first (0) or second (1) group of 16 quants inside it
// reg - output tile; rows 0-1 receive the first 8 values, rows 2-3 the next 8
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    const float d    = xb->d;
    const int   ib32 = il/2;
    const short sub  = il%2;

    device const uint16_t * q2 = xb->qs + 4*ib32;

    // each scales byte carries two 4-bit scales: low nibble for sub == 0, high for sub == 1
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*sub) & 0xf)) * 0.25f;

    for (int k = 0; k < 2; ++k) {
        const uint16_t code = q2[2*sub + k];
        // low 9 bits index the grid (8 bytes per entry), upper 7 bits index the sign table
        constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (code & 511));
        const uint8_t signs = ksigns_iq2xs[code >> 9];
        for (int i = 0; i < 8; ++i) {
            reg[2*k + i/4][i%4] = dl * grid[i] * ((signs & kmask_iq2xs[i]) ? -1.f : 1.f);
        }
    }
}
|
| 4289 |
+
|
| 4290 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 4291 |
kernel void kernel_get_rows(
|
| 4292 |
device const void * src0,
|
|
|
|
| 4827 |
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4828 |
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4829 |
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 4830 |
+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 4831 |
|
| 4832 |
//
|
| 4833 |
// matrix-matrix multiplication
|
|
|
|
| 4865 |
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4866 |
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4867 |
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 4868 |
+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 4869 |
|
| 4870 |
//
|
| 4871 |
// indirect matrix-matrix multiplication
|
|
|
|
| 4915 |
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 4916 |
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 4917 |
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 4918 |
+
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 4919 |
|
| 4920 |
//
|
| 4921 |
// matrix-vector multiplication
|
|
|
|
| 5753 |
tiisg,
|
| 5754 |
sgitg);
|
| 5755 |
}
|
| 5756 |
+
|
| 5757 |
+
// Indirect variant of the IQ2_XS matrix-vector kernel.
// An int32 read from the ids buffer (row bid, column idx) selects which of the
// eight src0 buffers is multiplied; the work itself is done by
// kernel_mul_mv_iq2_xs_f32_impl. tiitg and the nb0x/nb10/nb12/nb1/ne11 arguments
// are accepted but not used in this body.
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
kernel void kernel_mul_mv_id_iq2_xs_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        threadgroup   int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // grid z packs (bid, batch index): split off bid and leave the remainder in tgpig.z
    const int64_t bid = tgpig.z/(ne12*ne13);
    tgpig.z = tgpig.z%(ne12*ne13);

    // entry idx of row bid in the ids buffer (rows are nbi1 bytes apart)
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_iq2_xs_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        dst + bid*ne0,
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        shared_values,
        tgpig,
        tiisg,
        sgitg);
}
|
ggml-quants.c
CHANGED
|
@@ -2342,15 +2342,7 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t *
|
|
| 2342 |
|
| 2343 |
// ====================== "True" 2-bit (de)-quantization
|
| 2344 |
|
| 2345 |
-
|
| 2346 |
-
(void)x;
|
| 2347 |
-
(void)y;
|
| 2348 |
-
(void)k;
|
| 2349 |
-
assert(k % QK_K == 0);
|
| 2350 |
-
//fprintf(stderr, "=========================== %s: not implemented\n", __func__);
|
| 2351 |
-
}
|
| 2352 |
-
|
| 2353 |
-
static const uint64_t iq2xxs_grid[256] = {
|
| 2354 |
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 2355 |
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
| 2356 |
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
@@ -2417,6 +2409,137 @@ static const uint64_t iq2xxs_grid[256] = {
|
|
| 2417 |
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
| 2418 |
};
|
| 2419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2420 |
static const uint8_t ksigns_iq2xs[128] = {
|
| 2421 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 2422 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -2427,8 +2550,17 @@ static const uint8_t ksigns_iq2xs[128] = {
|
|
| 2427 |
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
| 2428 |
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
| 2429 |
};
|
|
|
|
| 2430 |
// Single-bit masks: entry j holds the value 1 << j, for j = 0..7.
static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
| 2431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2432 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
|
| 2433 |
assert(k % QK_K == 0);
|
| 2434 |
const int nb = k / QK_K;
|
|
@@ -2472,6 +2604,58 @@ size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_
|
|
| 2472 |
return (n/QK_K*sizeof(block_iq2_xxs));
|
| 2473 |
}
|
| 2474 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2475 |
//===================================== Q8_K ==============================================
|
| 2476 |
|
| 2477 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
|
@@ -7357,3 +7541,161 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
|
|
| 7357 |
*s = 0.125f * sumf;
|
| 7358 |
#endif
|
| 7359 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2342 |
|
| 2343 |
// ====================== "True" 2-bit (de)-quantization
|
| 2344 |
|
| 2345 |
+
static const uint64_t iq2xxs_grid[256] = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2346 |
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 2347 |
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
| 2348 |
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
|
|
| 2409 |
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
| 2410 |
};
|
| 2411 |
|
| 2412 |
+
static const uint64_t iq2xs_grid[512] = {
|
| 2413 |
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
| 2414 |
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
| 2415 |
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
| 2416 |
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
| 2417 |
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
| 2418 |
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
| 2419 |
+
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
| 2420 |
+
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
| 2421 |
+
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
| 2422 |
+
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
| 2423 |
+
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
| 2424 |
+
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
| 2425 |
+
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
| 2426 |
+
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
| 2427 |
+
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
| 2428 |
+
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
| 2429 |
+
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
| 2430 |
+
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
| 2431 |
+
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
| 2432 |
+
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
| 2433 |
+
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
| 2434 |
+
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
| 2435 |
+
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
| 2436 |
+
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
| 2437 |
+
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
| 2438 |
+
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
| 2439 |
+
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
| 2440 |
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
| 2441 |
+
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
| 2442 |
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
| 2443 |
+
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
| 2444 |
+
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
| 2445 |
+
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
| 2446 |
+
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
| 2447 |
+
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
| 2448 |
+
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
| 2449 |
+
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
| 2450 |
+
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
| 2451 |
+
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
| 2452 |
+
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
| 2453 |
+
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
| 2454 |
+
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
| 2455 |
+
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
| 2456 |
+
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
| 2457 |
+
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
| 2458 |
+
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
| 2459 |
+
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
| 2460 |
+
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
| 2461 |
+
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
| 2462 |
+
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
| 2463 |
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
| 2464 |
+
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
| 2465 |
+
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
| 2466 |
+
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
| 2467 |
+
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
| 2468 |
+
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
| 2469 |
+
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
| 2470 |
+
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
| 2471 |
+
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
| 2472 |
+
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
| 2473 |
+
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
| 2474 |
+
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
| 2475 |
+
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
| 2476 |
+
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
| 2477 |
+
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
| 2478 |
+
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
| 2479 |
+
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
| 2480 |
+
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
| 2481 |
+
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
| 2482 |
+
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
| 2483 |
+
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
| 2484 |
+
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
| 2485 |
+
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
| 2486 |
+
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
| 2487 |
+
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
| 2488 |
+
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
| 2489 |
+
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
| 2490 |
+
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
| 2491 |
+
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
| 2492 |
+
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
| 2493 |
+
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
| 2494 |
+
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
| 2495 |
+
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
| 2496 |
+
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
| 2497 |
+
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
| 2498 |
+
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
| 2499 |
+
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
| 2500 |
+
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
| 2501 |
+
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
| 2502 |
+
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
| 2503 |
+
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
| 2504 |
+
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
| 2505 |
+
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
| 2506 |
+
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
| 2507 |
+
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
| 2508 |
+
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
| 2509 |
+
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
| 2510 |
+
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
| 2511 |
+
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
| 2512 |
+
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
| 2513 |
+
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
| 2514 |
+
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
| 2515 |
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
| 2516 |
+
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
| 2517 |
+
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
| 2518 |
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
| 2519 |
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
| 2520 |
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
| 2521 |
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
| 2522 |
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
| 2523 |
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
| 2524 |
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
| 2525 |
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
| 2526 |
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
| 2527 |
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
| 2528 |
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
| 2529 |
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
| 2530 |
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
| 2531 |
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
| 2532 |
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
| 2533 |
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
| 2534 |
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
| 2535 |
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
| 2536 |
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
| 2537 |
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
| 2538 |
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
| 2539 |
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
| 2540 |
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
| 2541 |
+
};
|
| 2542 |
+
|
| 2543 |
static const uint8_t ksigns_iq2xs[128] = {
|
| 2544 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 2545 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 2550 |
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
| 2551 |
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
| 2552 |
};
|
| 2553 |
+
|
| 2554 |
static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
| 2555 |
|
| 2556 |
+
// Reference quantizer for IQ2_XXS rows.
// Not implemented: nothing is read from x and nothing is written to y. The only
// effect is the debug check that k is a whole number of QK_K-sized super-blocks.
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
    assert(k % QK_K == 0);
    // Mark every parameter as used; k is otherwise unused when assert() is compiled out.
    (void)k;
    (void)y;
    (void)x;
}
|
| 2563 |
+
|
| 2564 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
|
| 2565 |
assert(k % QK_K == 0);
|
| 2566 |
const int nb = k / QK_K;
|
|
|
|
| 2604 |
return (n/QK_K*sizeof(block_iq2_xxs));
|
| 2605 |
}
|
| 2606 |
|
| 2607 |
+
// ====================== 2.3125 bpw (de)-quantization
|
| 2608 |
+
|
| 2609 |
+
// Reference quantizer for IQ2_XS rows.
// Not implemented: nothing is read from x and nothing is written to y. The only
// effect is the debug check that k is a whole number of QK_K-sized super-blocks.
void quantize_row_iq2_xs_reference(const float * restrict x, block_iq2_xs * restrict y, int k) {
    assert(k % QK_K == 0);
    // Mark every parameter as used; k is otherwise unused when assert() is compiled out.
    (void)k;
    (void)y;
    (void)x;
}
|
| 2616 |
+
|
| 2617 |
+
void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
|
| 2618 |
+
assert(k % QK_K == 0);
|
| 2619 |
+
const int nb = k / QK_K;
|
| 2620 |
+
|
| 2621 |
+
float db[2];
|
| 2622 |
+
|
| 2623 |
+
for (int i = 0; i < nb; i++) {
|
| 2624 |
+
|
| 2625 |
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
| 2626 |
+
|
| 2627 |
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
| 2628 |
+
db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
|
| 2629 |
+
db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
|
| 2630 |
+
for (int l = 0; l < 4; ++l) {
|
| 2631 |
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
|
| 2632 |
+
const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
|
| 2633 |
+
for (int j = 0; j < 8; ++j) {
|
| 2634 |
+
y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
| 2635 |
+
}
|
| 2636 |
+
y += 8;
|
| 2637 |
+
}
|
| 2638 |
+
}
|
| 2639 |
+
}
|
| 2640 |
+
}
|
| 2641 |
+
|
| 2642 |
+
void quantize_row_iq2_xs(const float * restrict x, void * restrict vy, int k) {
|
| 2643 |
+
assert(k % QK_K == 0);
|
| 2644 |
+
block_iq2_xs * restrict y = vy;
|
| 2645 |
+
quantize_row_iq2_xs_reference(x, y, k);
|
| 2646 |
+
}
|
| 2647 |
+
|
| 2648 |
+
size_t ggml_quantize_iq2_xs(const float * src, void * dst, int n, int k, int64_t * hist) {
|
| 2649 |
+
assert(k % QK_K == 0);
|
| 2650 |
+
(void)hist; // TODO: collect histograms
|
| 2651 |
+
|
| 2652 |
+
for (int j = 0; j < n; j += k) {
|
| 2653 |
+
block_iq2_xs * restrict y = (block_iq2_xs *)dst + j/QK_K;
|
| 2654 |
+
quantize_row_iq2_xs_reference(src + j, y, k);
|
| 2655 |
+
}
|
| 2656 |
+
return (n/QK_K*sizeof(block_iq2_xs));
|
| 2657 |
+
}
|
| 2658 |
+
|
| 2659 |
//===================================== Q8_K ==============================================
|
| 2660 |
|
| 2661 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
|
|
|
| 7541 |
*s = 0.125f * sumf;
|
| 7542 |
#endif
|
| 7543 |
}
|
| 7544 |
+
|
| 7545 |
+
void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
| 7546 |
+
assert(n % QK_K == 0);
|
| 7547 |
+
|
| 7548 |
+
const block_iq2_xs * restrict x = vx;
|
| 7549 |
+
const block_q8_K * restrict y = vy;
|
| 7550 |
+
|
| 7551 |
+
const int nb = n / QK_K;
|
| 7552 |
+
|
| 7553 |
+
#if defined(__ARM_NEON)
|
| 7554 |
+
|
| 7555 |
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
| 7556 |
+
|
| 7557 |
+
int8x16x4_t q2u;
|
| 7558 |
+
int8x16x4_t q2s;
|
| 7559 |
+
int8x16x4_t q8b;
|
| 7560 |
+
|
| 7561 |
+
int32x4x4_t scales32;
|
| 7562 |
+
|
| 7563 |
+
float sumf = 0;
|
| 7564 |
+
for (int i = 0; i < nb; ++i) {
|
| 7565 |
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
| 7566 |
+
const uint16_t * restrict q2 = x[i].qs;
|
| 7567 |
+
const int8_t * restrict q8 = y[i].qs;
|
| 7568 |
+
const uint8x8_t scales8 = vld1_u8(x[i].scales);
|
| 7569 |
+
const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
|
| 7570 |
+
const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
|
| 7571 |
+
uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
|
| 7572 |
+
scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
|
| 7573 |
+
const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
|
| 7574 |
+
const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
|
| 7575 |
+
scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
|
| 7576 |
+
scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
|
| 7577 |
+
scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
|
| 7578 |
+
scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
|
| 7579 |
+
int32x4_t sumi = vdupq_n_s32(0);
|
| 7580 |
+
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
|
| 7581 |
+
q8b = vld1q_s8_x4(q8); q8 += 64;
|
| 7582 |
+
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
|
| 7583 |
+
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
|
| 7584 |
+
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
|
| 7585 |
+
q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
|
| 7586 |
+
q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
|
| 7587 |
+
q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
|
| 7588 |
+
q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
|
| 7589 |
+
q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
|
| 7590 |
+
q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
|
| 7591 |
+
q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
|
| 7592 |
+
q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
|
| 7593 |
+
q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
|
| 7594 |
+
const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
|
| 7595 |
+
const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
|
| 7596 |
+
const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
|
| 7597 |
+
const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
|
| 7598 |
+
const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
|
| 7599 |
+
sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
|
| 7600 |
+
q2 += 8;
|
| 7601 |
+
}
|
| 7602 |
+
sumf += d*vaddvq_s32(sumi);
|
| 7603 |
+
}
|
| 7604 |
+
*s = 0.125f * sumf;
|
| 7605 |
+
|
| 7606 |
+
#elif defined(__AVX2__)
|
| 7607 |
+
|
| 7608 |
+
const __m128i m4 = _mm_set1_epi8(0xf);
|
| 7609 |
+
const __m128i m1 = _mm_set1_epi8(1);
|
| 7610 |
+
const __m128i m511 = _mm_set1_epi16(511);
|
| 7611 |
+
const __m128i m127 = _mm_set1_epi16(127);
|
| 7612 |
+
|
| 7613 |
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
| 7614 |
+
|
| 7615 |
+
uint64_t aux64;
|
| 7616 |
+
|
| 7617 |
+
// somewhat hacky, but gives a significant boost in performance
|
| 7618 |
+
__m128i aux_gindex, aux_sindex;
|
| 7619 |
+
const uint16_t * gindex = (const uint16_t *)&aux_gindex;
|
| 7620 |
+
const uint16_t * sindex = (const uint16_t *)&aux_sindex;
|
| 7621 |
+
|
| 7622 |
+
__m256 accumf = _mm256_setzero_ps();
|
| 7623 |
+
for (int i = 0; i < nb; ++i) {
|
| 7624 |
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
| 7625 |
+
const uint16_t * restrict q2 = x[i].qs;
|
| 7626 |
+
const int8_t * restrict q8 = y[i].qs;
|
| 7627 |
+
|
| 7628 |
+
memcpy(&aux64, x[i].scales, 8);
|
| 7629 |
+
__m128i stmp = _mm_set1_epi64x(aux64);
|
| 7630 |
+
stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
|
| 7631 |
+
const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
|
| 7632 |
+
|
| 7633 |
+
__m256i sumi1 = _mm256_setzero_si256();
|
| 7634 |
+
__m256i sumi2 = _mm256_setzero_si256();
|
| 7635 |
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 7636 |
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 7637 |
+
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
| 7638 |
+
const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
|
| 7639 |
+
aux_gindex = _mm_and_si128(q2_data, m511);
|
| 7640 |
+
aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
|
| 7641 |
+
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
|
| 7642 |
+
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
|
| 7643 |
+
const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
|
| 7644 |
+
const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
|
| 7645 |
+
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
|
| 7646 |
+
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
|
| 7647 |
+
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
| 7648 |
+
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
| 7649 |
+
|
| 7650 |
+
const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
|
| 7651 |
+
const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
|
| 7652 |
+
|
| 7653 |
+
sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
|
| 7654 |
+
sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
|
| 7655 |
+
}
|
| 7656 |
+
|
| 7657 |
+
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
|
| 7658 |
+
|
| 7659 |
+
}
|
| 7660 |
+
|
| 7661 |
+
*s = 0.125f * hsum_float_8(accumf);
|
| 7662 |
+
|
| 7663 |
+
#else
|
| 7664 |
+
|
| 7665 |
+
float sumf = 0.f;
|
| 7666 |
+
for (int i = 0; i < nb; ++i) {
|
| 7667 |
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
| 7668 |
+
const uint16_t * restrict q2 = x[i].qs;
|
| 7669 |
+
const uint8_t * restrict sc = x[i].scales;
|
| 7670 |
+
const int8_t * restrict q8 = y[i].qs;
|
| 7671 |
+
int32_t bsum = 0;
|
| 7672 |
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
| 7673 |
+
const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
|
| 7674 |
+
const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
|
| 7675 |
+
int32_t sumi = 0;
|
| 7676 |
+
for (int l = 0; l < 2; ++l) {
|
| 7677 |
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
|
| 7678 |
+
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
|
| 7679 |
+
for (int j = 0; j < 8; ++j) {
|
| 7680 |
+
sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
| 7681 |
+
}
|
| 7682 |
+
q8 += 8;
|
| 7683 |
+
}
|
| 7684 |
+
bsum += sumi * ls1;
|
| 7685 |
+
sumi = 0;
|
| 7686 |
+
for (int l = 2; l < 4; ++l) {
|
| 7687 |
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
|
| 7688 |
+
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
|
| 7689 |
+
for (int j = 0; j < 8; ++j) {
|
| 7690 |
+
sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
| 7691 |
+
}
|
| 7692 |
+
q8 += 8;
|
| 7693 |
+
}
|
| 7694 |
+
bsum += sumi * ls2;
|
| 7695 |
+
q2 += 4;
|
| 7696 |
+
}
|
| 7697 |
+
sumf += d * bsum;
|
| 7698 |
+
}
|
| 7699 |
+
*s = 0.125f * sumf;
|
| 7700 |
+
#endif
|
| 7701 |
+
}
|
ggml-quants.h
CHANGED
|
@@ -174,6 +174,14 @@ typedef struct {
|
|
| 174 |
} block_iq2_xxs;
|
| 175 |
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
// Quantization
|
| 178 |
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
|
| 179 |
void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
|
|
@@ -189,6 +197,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
|
| 189 |
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
|
| 190 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
|
| 191 |
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
|
|
|
|
| 192 |
|
| 193 |
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
|
| 194 |
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
|
|
@@ -204,6 +213,7 @@ void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
|
|
| 204 |
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
|
| 205 |
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
|
| 206 |
void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
|
|
|
|
| 207 |
|
| 208 |
// Dequantization
|
| 209 |
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
|
|
@@ -220,6 +230,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
|
|
| 220 |
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
|
| 221 |
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
|
| 222 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
|
|
|
|
| 223 |
|
| 224 |
// Dot product
|
| 225 |
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
|
@@ -234,3 +245,4 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx,
|
|
| 234 |
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 235 |
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 236 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
|
|
|
|
|
| 174 |
} block_iq2_xxs;
|
| 175 |
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
|
| 176 |
|
| 177 |
+
// 2.3125 bpw quants
|
| 178 |
+
typedef struct {
|
| 179 |
+
ggml_fp16_t d;
|
| 180 |
+
uint16_t qs[QK_K/8];
|
| 181 |
+
uint8_t scales[QK_K/32];
|
| 182 |
+
} block_iq2_xs;
|
| 183 |
+
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
|
| 184 |
+
|
| 185 |
// Quantization
|
| 186 |
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
|
| 187 |
void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
|
|
|
|
| 197 |
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
|
| 198 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
|
| 199 |
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
|
| 200 |
+
void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs * restrict y, int k);
|
| 201 |
|
| 202 |
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
|
| 203 |
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
|
|
|
|
| 213 |
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
|
| 214 |
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
|
| 215 |
void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
|
| 216 |
+
void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k);
|
| 217 |
|
| 218 |
// Dequantization
|
| 219 |
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
|
|
|
|
| 230 |
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
|
| 231 |
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
|
| 232 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
|
| 233 |
+
void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k);
|
| 234 |
|
| 235 |
// Dot product
|
| 236 |
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
|
|
|
| 245 |
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 246 |
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 247 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
| 248 |
+
void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
ggml.c
CHANGED
|
@@ -584,6 +584,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|
| 584 |
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
|
| 585 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 586 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
[GGML_TYPE_Q8_K] = {
|
| 588 |
.type_name = "q8_K",
|
| 589 |
.blck_size = QK_K,
|
|
@@ -2123,6 +2134,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
|
| 2123 |
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
|
| 2124 |
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
|
| 2125 |
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
|
|
|
|
| 2126 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2127 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2128 |
}
|
|
@@ -7435,6 +7447,7 @@ static void ggml_compute_forward_add(
|
|
| 7435 |
case GGML_TYPE_Q5_K:
|
| 7436 |
case GGML_TYPE_Q6_K:
|
| 7437 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 7438 |
{
|
| 7439 |
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
| 7440 |
} break;
|
|
@@ -7700,6 +7713,7 @@ static void ggml_compute_forward_add1(
|
|
| 7700 |
case GGML_TYPE_Q5_K:
|
| 7701 |
case GGML_TYPE_Q6_K:
|
| 7702 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 7703 |
{
|
| 7704 |
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
|
| 7705 |
} break;
|
|
@@ -7815,6 +7829,7 @@ static void ggml_compute_forward_acc(
|
|
| 7815 |
case GGML_TYPE_Q5_K:
|
| 7816 |
case GGML_TYPE_Q6_K:
|
| 7817 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 7818 |
default:
|
| 7819 |
{
|
| 7820 |
GGML_ASSERT(false);
|
|
@@ -10457,6 +10472,7 @@ static void ggml_compute_forward_out_prod(
|
|
| 10457 |
case GGML_TYPE_Q5_K:
|
| 10458 |
case GGML_TYPE_Q6_K:
|
| 10459 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 10460 |
{
|
| 10461 |
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
|
| 10462 |
} break;
|
|
@@ -10632,6 +10648,7 @@ static void ggml_compute_forward_set(
|
|
| 10632 |
case GGML_TYPE_Q5_K:
|
| 10633 |
case GGML_TYPE_Q6_K:
|
| 10634 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 10635 |
default:
|
| 10636 |
{
|
| 10637 |
GGML_ASSERT(false);
|
|
@@ -10827,6 +10844,7 @@ static void ggml_compute_forward_get_rows(
|
|
| 10827 |
case GGML_TYPE_Q5_K:
|
| 10828 |
case GGML_TYPE_Q6_K:
|
| 10829 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 10830 |
{
|
| 10831 |
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
| 10832 |
} break;
|
|
@@ -11464,6 +11482,7 @@ static void ggml_compute_forward_alibi(
|
|
| 11464 |
case GGML_TYPE_Q5_K:
|
| 11465 |
case GGML_TYPE_Q6_K:
|
| 11466 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 11467 |
case GGML_TYPE_Q8_K:
|
| 11468 |
case GGML_TYPE_I8:
|
| 11469 |
case GGML_TYPE_I16:
|
|
@@ -11539,6 +11558,7 @@ static void ggml_compute_forward_clamp(
|
|
| 11539 |
case GGML_TYPE_Q5_K:
|
| 11540 |
case GGML_TYPE_Q6_K:
|
| 11541 |
case GGML_TYPE_IQ2_XXS:
|
|
|
|
| 11542 |
case GGML_TYPE_Q8_K:
|
| 11543 |
case GGML_TYPE_I8:
|
| 11544 |
case GGML_TYPE_I16:
|
|
@@ -18660,6 +18680,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
| 18660 |
block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
|
| 18661 |
result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
|
| 18662 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18663 |
case GGML_TYPE_F16:
|
| 18664 |
{
|
| 18665 |
int elemsize = sizeof(ggml_fp16_t);
|
|
@@ -19015,8 +19041,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|
| 19015 |
(int64_t) info->ne[3];
|
| 19016 |
|
| 19017 |
if (ne % ggml_blck_size(info->type) != 0) {
|
| 19018 |
-
fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
|
| 19019 |
-
__func__, info->name.data, ne, ggml_blck_size(info->type));
|
| 19020 |
fclose(file);
|
| 19021 |
gguf_free(ctx);
|
| 19022 |
return NULL;
|
|
|
|
| 584 |
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
|
| 585 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 586 |
},
|
| 587 |
+
[GGML_TYPE_IQ2_XS] = {
|
| 588 |
+
.type_name = "iq2_xs",
|
| 589 |
+
.blck_size = QK_K,
|
| 590 |
+
.type_size = sizeof(block_iq2_xs),
|
| 591 |
+
.is_quantized = true,
|
| 592 |
+
.to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
|
| 593 |
+
.from_float = quantize_row_iq2_xs,
|
| 594 |
+
.from_float_reference = (ggml_from_float_t) quantize_row_iq2_xs_reference,
|
| 595 |
+
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
|
| 596 |
+
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 597 |
+
},
|
| 598 |
[GGML_TYPE_Q8_K] = {
|
| 599 |
.type_name = "q8_K",
|
| 600 |
.blck_size = QK_K,
|
|
|
|
| 2134 |
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
|
| 2135 |
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
|
| 2136 |
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
|
| 2137 |
+
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
|
| 2138 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2139 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2140 |
}
|
|
|
|
| 7447 |
case GGML_TYPE_Q5_K:
|
| 7448 |
case GGML_TYPE_Q6_K:
|
| 7449 |
case GGML_TYPE_IQ2_XXS:
|
| 7450 |
+
case GGML_TYPE_IQ2_XS:
|
| 7451 |
{
|
| 7452 |
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
| 7453 |
} break;
|
|
|
|
| 7713 |
case GGML_TYPE_Q5_K:
|
| 7714 |
case GGML_TYPE_Q6_K:
|
| 7715 |
case GGML_TYPE_IQ2_XXS:
|
| 7716 |
+
case GGML_TYPE_IQ2_XS:
|
| 7717 |
{
|
| 7718 |
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
|
| 7719 |
} break;
|
|
|
|
| 7829 |
case GGML_TYPE_Q5_K:
|
| 7830 |
case GGML_TYPE_Q6_K:
|
| 7831 |
case GGML_TYPE_IQ2_XXS:
|
| 7832 |
+
case GGML_TYPE_IQ2_XS:
|
| 7833 |
default:
|
| 7834 |
{
|
| 7835 |
GGML_ASSERT(false);
|
|
|
|
| 10472 |
case GGML_TYPE_Q5_K:
|
| 10473 |
case GGML_TYPE_Q6_K:
|
| 10474 |
case GGML_TYPE_IQ2_XXS:
|
| 10475 |
+
case GGML_TYPE_IQ2_XS:
|
| 10476 |
{
|
| 10477 |
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
|
| 10478 |
} break;
|
|
|
|
| 10648 |
case GGML_TYPE_Q5_K:
|
| 10649 |
case GGML_TYPE_Q6_K:
|
| 10650 |
case GGML_TYPE_IQ2_XXS:
|
| 10651 |
+
case GGML_TYPE_IQ2_XS:
|
| 10652 |
default:
|
| 10653 |
{
|
| 10654 |
GGML_ASSERT(false);
|
|
|
|
| 10844 |
case GGML_TYPE_Q5_K:
|
| 10845 |
case GGML_TYPE_Q6_K:
|
| 10846 |
case GGML_TYPE_IQ2_XXS:
|
| 10847 |
+
case GGML_TYPE_IQ2_XS:
|
| 10848 |
{
|
| 10849 |
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
| 10850 |
} break;
|
|
|
|
| 11482 |
case GGML_TYPE_Q5_K:
|
| 11483 |
case GGML_TYPE_Q6_K:
|
| 11484 |
case GGML_TYPE_IQ2_XXS:
|
| 11485 |
+
case GGML_TYPE_IQ2_XS:
|
| 11486 |
case GGML_TYPE_Q8_K:
|
| 11487 |
case GGML_TYPE_I8:
|
| 11488 |
case GGML_TYPE_I16:
|
|
|
|
| 11558 |
case GGML_TYPE_Q5_K:
|
| 11559 |
case GGML_TYPE_Q6_K:
|
| 11560 |
case GGML_TYPE_IQ2_XXS:
|
| 11561 |
+
case GGML_TYPE_IQ2_XS:
|
| 11562 |
case GGML_TYPE_Q8_K:
|
| 11563 |
case GGML_TYPE_I8:
|
| 11564 |
case GGML_TYPE_I16:
|
|
|
|
| 18680 |
block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
|
| 18681 |
result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
|
| 18682 |
} break;
|
| 18683 |
+
case GGML_TYPE_IQ2_XS:
|
| 18684 |
+
{
|
| 18685 |
+
GGML_ASSERT(start % QK_K == 0);
|
| 18686 |
+
block_iq2_xs * block = (block_iq2_xs*)dst + start / QK_K;
|
| 18687 |
+
result = ggml_quantize_iq2_xs(src + start, block, n, n, hist);
|
| 18688 |
+
} break;
|
| 18689 |
case GGML_TYPE_F16:
|
| 18690 |
{
|
| 18691 |
int elemsize = sizeof(ggml_fp16_t);
|
|
|
|
| 19041 |
(int64_t) info->ne[3];
|
| 19042 |
|
| 19043 |
if (ne % ggml_blck_size(info->type) != 0) {
|
| 19044 |
+
fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
|
| 19045 |
+
__func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
|
| 19046 |
fclose(file);
|
| 19047 |
gguf_free(ctx);
|
| 19048 |
return NULL;
|
ggml.h
CHANGED
|
@@ -342,6 +342,7 @@ extern "C" {
|
|
| 342 |
GGML_TYPE_Q6_K = 14,
|
| 343 |
GGML_TYPE_Q8_K = 15,
|
| 344 |
GGML_TYPE_IQ2_XXS = 16,
|
|
|
|
| 345 |
GGML_TYPE_I8,
|
| 346 |
GGML_TYPE_I16,
|
| 347 |
GGML_TYPE_I32,
|
|
@@ -377,6 +378,7 @@ extern "C" {
|
|
| 377 |
GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
|
| 378 |
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
|
| 379 |
GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
|
|
|
|
| 380 |
};
|
| 381 |
|
| 382 |
// available tensor operations:
|
|
@@ -2061,6 +2063,7 @@ extern "C" {
|
|
| 2061 |
GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2062 |
GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2063 |
GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
|
|
|
|
| 2064 |
|
| 2065 |
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
|
| 2066 |
|
|
|
|
| 342 |
GGML_TYPE_Q6_K = 14,
|
| 343 |
GGML_TYPE_Q8_K = 15,
|
| 344 |
GGML_TYPE_IQ2_XXS = 16,
|
| 345 |
+
GGML_TYPE_IQ2_XS = 17,
|
| 346 |
GGML_TYPE_I8,
|
| 347 |
GGML_TYPE_I16,
|
| 348 |
GGML_TYPE_I32,
|
|
|
|
| 378 |
GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
|
| 379 |
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
|
| 380 |
GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
|
| 381 |
+
GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
|
| 382 |
};
|
| 383 |
|
| 384 |
// available tensor operations:
|
|
|
|
| 2063 |
GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2064 |
GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2065 |
GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2066 |
+
GGML_API size_t ggml_quantize_iq2_xs (const float * src, void * dst, int n, int k, int64_t * hist);
|
| 2067 |
|
| 2068 |
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
|
| 2069 |
|