Kawrakow ikawrakow commited on
Commit
5e827d5
·
unverified ·
1 Parent(s): 6e822b8

ggml : SOTA 2-bit quants (add IQ2_XS) (llama/4856)

Browse files

* iq2_xs: basics

* iq2_xs: this should have been in the basics

* iq2_xs: CUDA and scalar CPU works

* iq2_xs: WIP Metal

* iq2_xs: Metal now works

* iq2_xs: working, but dog slow, ARM_NEON dot product

* iq2_xs: better ARM_NEON dot product

We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when
running on the CPU.

* iq2_xs: AVX2 dot product - 19.5 t/s

* iq2_xs: faster AVX2 dit product

21.4 t/s for TG-128, 59.2 t/s for PP-512.
The latter is 2x compared to the previous version.

* iq2_xs: had forgotten to delete iq2-data.h

* Add llama enum for IQ2_XS

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

Files changed (7) hide show
  1. ggml-cuda.cu +227 -5
  2. ggml-metal.m +36 -6
  3. ggml-metal.metal +374 -4
  4. ggml-quants.c +351 -9
  5. ggml-quants.h +12 -0
  6. ggml.c +28 -2
  7. ggml.h +3 -0
ggml-cuda.cu CHANGED
@@ -486,6 +486,15 @@ typedef struct {
486
  } block_iq2_xxs;
487
  static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
488
 
 
 
 
 
 
 
 
 
 
489
  #define WARP_SIZE 32
490
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
491
 
@@ -1328,7 +1337,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
1328
  #endif
1329
  }
1330
 
1331
- static const __device__ uint64_t kgrid_iq2xxs[256] = {
1332
  0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
1333
  0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
1334
  0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
@@ -1395,6 +1404,137 @@ static const __device__ uint64_t kgrid_iq2xxs[256] = {
1395
  0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
1396
  };
1397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1398
  static const __device__ uint8_t ksigns_iq2xs[128] = {
1399
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
1400
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -1439,7 +1579,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
1439
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
1440
  const uint16_t * q2 = x[i].qs + 4*ib;
1441
  const uint8_t * aux8 = (const uint8_t *)q2;
1442
- const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
1443
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
1444
  const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
1445
  const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
@@ -1450,6 +1590,28 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
1450
 
1451
  }
1452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1453
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1454
 
1455
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -3996,7 +4158,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
3996
  uint32_t aux32 = q2[2] | (q2[3] << 16);
3997
  int sumi = 0;
3998
  for (int l = 0; l < 4; ++l) {
3999
- const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
4000
  const uint8_t signs = ksigns_iq2xs[aux32 & 127];
4001
  for (int j = 0; j < 8; ++j) {
4002
  sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
@@ -4012,8 +4174,8 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
4012
  const int il = iqs%2;
4013
  const uint16_t * q2 = bq2->qs + 4*ib32;
4014
  const uint8_t * aux8 = (const uint8_t *)q2;
4015
- const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
4016
- const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
4017
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
4018
  const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
4019
  const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
@@ -4032,6 +4194,42 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
4032
  #endif
4033
  }
4034
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4035
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
4036
  allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
4037
  static __device__ __forceinline__ void mul_mat_q(
@@ -6035,6 +6233,12 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k,
6035
  dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
6036
  }
6037
 
 
 
 
 
 
 
6038
  template <typename src_t, typename dst_t>
6039
  static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
6040
  const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6065,6 +6269,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
6065
  return dequantize_row_q6_K_cuda;
6066
  case GGML_TYPE_IQ2_XXS:
6067
  return dequantize_row_iq2_xxs_cuda;
 
 
6068
  case GGML_TYPE_F32:
6069
  return convert_unary_cuda<float>;
6070
  default:
@@ -6096,6 +6302,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
6096
  return dequantize_row_q6_K_cuda;
6097
  case GGML_TYPE_IQ2_XXS:
6098
  return dequantize_row_iq2_xxs_cuda;
 
 
6099
  case GGML_TYPE_F16:
6100
  return convert_unary_cuda<half>;
6101
  default:
@@ -6299,6 +6507,15 @@ static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, floa
6299
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
6300
  }
6301
 
 
 
 
 
 
 
 
 
 
6302
  static void ggml_mul_mat_q4_0_q8_1_cuda(
6303
  const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
6304
  const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7871,6 +8088,7 @@ static int64_t get_row_rounding(ggml_type type) {
7871
  case GGML_TYPE_Q5_K:
7872
  case GGML_TYPE_Q6_K:
7873
  case GGML_TYPE_IQ2_XXS:
 
7874
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
7875
  default:
7876
  GGML_ASSERT(false);
@@ -7892,6 +8110,7 @@ static int64_t get_row_rounding(ggml_type type) {
7892
  case GGML_TYPE_Q4_K:
7893
  case GGML_TYPE_Q5_K:
7894
  case GGML_TYPE_IQ2_XXS:
 
7895
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
7896
  case GGML_TYPE_Q6_K:
7897
  return 64;
@@ -7945,6 +8164,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
7945
  case GGML_TYPE_IQ2_XXS:
7946
  mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
7947
  break;
 
 
 
7948
  default:
7949
  GGML_ASSERT(false);
7950
  break;
 
486
  } block_iq2_xxs;
487
  static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
488
 
489
+ #define QR2_XS 8
490
+ #define QI2_XS (QK_K / (4*QR2_XS))
491
+ typedef struct {
492
+ half d;
493
+ uint16_t qs[QK_K/8];
494
+ uint8_t scales[QK_K/32];
495
+ } block_iq2_xs;
496
+ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
497
+
498
  #define WARP_SIZE 32
499
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
500
 
 
1337
  #endif
1338
  }
1339
 
1340
+ static const __device__ uint64_t iq2xxs_grid[256] = {
1341
  0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
1342
  0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
1343
  0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
 
1404
  0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
1405
  };
1406
 
1407
+ static const __device__ uint64_t iq2xs_grid[512] = {
1408
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
1409
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
1410
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
1411
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
1412
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
1413
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
1414
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
1415
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
1416
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
1417
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
1418
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
1419
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
1420
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
1421
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
1422
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
1423
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
1424
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
1425
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
1426
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
1427
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
1428
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
1429
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
1430
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
1431
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
1432
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
1433
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
1434
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
1435
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
1436
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
1437
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
1438
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
1439
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
1440
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
1441
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
1442
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
1443
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
1444
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
1445
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
1446
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
1447
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
1448
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
1449
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
1450
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
1451
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
1452
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
1453
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
1454
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
1455
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
1456
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
1457
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
1458
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
1459
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
1460
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
1461
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
1462
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
1463
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
1464
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
1465
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
1466
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
1467
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
1468
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
1469
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
1470
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
1471
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
1472
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
1473
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
1474
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
1475
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
1476
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
1477
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
1478
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
1479
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
1480
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
1481
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
1482
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
1483
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
1484
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
1485
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
1486
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
1487
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
1488
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
1489
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
1490
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
1491
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
1492
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
1493
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
1494
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
1495
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
1496
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
1497
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
1498
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
1499
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
1500
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
1501
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
1502
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
1503
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
1504
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
1505
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
1506
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
1507
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
1508
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
1509
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
1510
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
1511
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
1512
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
1513
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
1514
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
1515
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
1516
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
1517
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
1518
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
1519
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
1520
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
1521
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
1522
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
1523
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
1524
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
1525
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
1526
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
1527
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
1528
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
1529
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
1530
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
1531
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
1532
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
1533
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
1534
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
1535
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
1536
+ };
1537
+
1538
  static const __device__ uint8_t ksigns_iq2xs[128] = {
1539
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
1540
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
 
1579
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
1580
  const uint16_t * q2 = x[i].qs + 4*ib;
1581
  const uint8_t * aux8 = (const uint8_t *)q2;
1582
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
1583
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
1584
  const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
1585
  const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
 
1590
 
1591
  }
1592
 
1593
+ template<typename dst_t>
1594
+ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
1595
+
1596
+ const int i = blockIdx.x;
1597
+ const block_iq2_xs * x = (const block_iq2_xs *) vx;
1598
+
1599
+ const int tid = threadIdx.x;
1600
+ #if QK_K == 256
1601
+ const int il = tid/8; // 0...3
1602
+ const int ib = tid%8; // 0...7
1603
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
1604
+ const uint16_t * q2 = x[i].qs + 4*ib;
1605
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
1606
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
1607
+ const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
1608
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
1609
+ #else
1610
+ assert(false);
1611
+ #endif
1612
+
1613
+ }
1614
+
1615
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1616
 
1617
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
4158
  uint32_t aux32 = q2[2] | (q2[3] << 16);
4159
  int sumi = 0;
4160
  for (int l = 0; l < 4; ++l) {
4161
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
4162
  const uint8_t signs = ksigns_iq2xs[aux32 & 127];
4163
  for (int j = 0; j < 8; ++j) {
4164
  sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
 
4174
  const int il = iqs%2;
4175
  const uint16_t * q2 = bq2->qs + 4*ib32;
4176
  const uint8_t * aux8 = (const uint8_t *)q2;
4177
+ const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4178
+ const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4179
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
4180
  const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
4181
  const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
 
4194
  #endif
4195
  }
4196
 
4197
+ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
4198
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
4199
+ #if QK_K == 256
4200
+ const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
4201
+
4202
+ const int ib32 = iqs;
4203
+ const uint16_t * q2 = bq2->qs + 4*ib32;
4204
+ const int8_t * q8 = bq8_1[ib32].qs;
4205
+ const uint8_t ls1 = bq2->scales[ib32] & 0xf;
4206
+ const uint8_t ls2 = bq2->scales[ib32] >> 4;
4207
+ int sumi1 = 0;
4208
+ for (int l = 0; l < 2; ++l) {
4209
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
4210
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
4211
+ for (int j = 0; j < 8; ++j) {
4212
+ sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
4213
+ }
4214
+ q8 += 8;
4215
+ }
4216
+ int sumi2 = 0;
4217
+ for (int l = 2; l < 4; ++l) {
4218
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
4219
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
4220
+ for (int j = 0; j < 8; ++j) {
4221
+ sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
4222
+ }
4223
+ q8 += 8;
4224
+ }
4225
+ const float d = (float)bq2->d * (float)bq8_1[ib32].ds.x * 0.25f;
4226
+ return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
4227
+ #else
4228
+ assert(false);
4229
+ return 0.f;
4230
+ #endif
4231
+ }
4232
+
4233
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
4234
  allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
4235
  static __device__ __forceinline__ void mul_mat_q(
 
6233
  dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
6234
  }
6235
 
6236
+ template<typename dst_t>
6237
+ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
6238
+ const int nb = k / QK_K;
6239
+ dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
6240
+ }
6241
+
6242
  template <typename src_t, typename dst_t>
6243
  static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
6244
  const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 
6269
  return dequantize_row_q6_K_cuda;
6270
  case GGML_TYPE_IQ2_XXS:
6271
  return dequantize_row_iq2_xxs_cuda;
6272
+ case GGML_TYPE_IQ2_XS:
6273
+ return dequantize_row_iq2_xs_cuda;
6274
  case GGML_TYPE_F32:
6275
  return convert_unary_cuda<float>;
6276
  default:
 
6302
  return dequantize_row_q6_K_cuda;
6303
  case GGML_TYPE_IQ2_XXS:
6304
  return dequantize_row_iq2_xxs_cuda;
6305
+ case GGML_TYPE_IQ2_XS:
6306
+ return dequantize_row_iq2_xs_cuda;
6307
  case GGML_TYPE_F16:
6308
  return convert_unary_cuda<half>;
6309
  default:
 
6507
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
6508
  }
6509
 
6510
+ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
6511
+ GGML_ASSERT(ncols % QK_K == 0);
6512
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
6513
+ const dim3 block_nums(block_num_y, 1, 1);
6514
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
6515
+ mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
6516
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
6517
+ }
6518
+
6519
  static void ggml_mul_mat_q4_0_q8_1_cuda(
6520
  const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
6521
  const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
8088
  case GGML_TYPE_Q5_K:
8089
  case GGML_TYPE_Q6_K:
8090
  case GGML_TYPE_IQ2_XXS:
8091
+ case GGML_TYPE_IQ2_XS:
8092
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
8093
  default:
8094
  GGML_ASSERT(false);
 
8110
  case GGML_TYPE_Q4_K:
8111
  case GGML_TYPE_Q5_K:
8112
  case GGML_TYPE_IQ2_XXS:
8113
+ case GGML_TYPE_IQ2_XS:
8114
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
8115
  case GGML_TYPE_Q6_K:
8116
  return 64;
 
8164
  case GGML_TYPE_IQ2_XXS:
8165
  mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
8166
  break;
8167
+ case GGML_TYPE_IQ2_XS:
8168
+ mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
8169
+ break;
8170
  default:
8171
  GGML_ASSERT(false);
8172
  break;
ggml-metal.m CHANGED
@@ -89,6 +89,7 @@ struct ggml_metal_context {
89
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
90
  GGML_METAL_DECL_KERNEL(get_rows_i32);
91
  GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
 
92
  GGML_METAL_DECL_KERNEL(rms_norm);
93
  GGML_METAL_DECL_KERNEL(group_norm);
94
  GGML_METAL_DECL_KERNEL(norm);
@@ -108,6 +109,7 @@ struct ggml_metal_context {
108
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
109
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
110
  GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
 
111
  GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
112
  //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
113
  GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
@@ -124,6 +126,7 @@ struct ggml_metal_context {
124
  GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
125
  GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
126
  GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
 
127
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
128
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
129
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -137,6 +140,7 @@ struct ggml_metal_context {
137
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
138
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
139
  GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
 
140
  GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
141
  GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
142
  GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
@@ -150,6 +154,7 @@ struct ggml_metal_context {
150
  GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
151
  GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
152
  GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
 
153
  GGML_METAL_DECL_KERNEL(rope_f32);
154
  GGML_METAL_DECL_KERNEL(rope_f16);
155
  GGML_METAL_DECL_KERNEL(alibi_f32);
@@ -385,6 +390,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
385
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
386
  GGML_METAL_ADD_KERNEL(get_rows_i32);
387
  GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
 
388
  GGML_METAL_ADD_KERNEL(rms_norm);
389
  GGML_METAL_ADD_KERNEL(group_norm);
390
  GGML_METAL_ADD_KERNEL(norm);
@@ -404,6 +410,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
404
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
405
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
406
  GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
 
407
  GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
408
  //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
409
  GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
@@ -420,6 +427,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
420
  GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
421
  GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
422
  GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
 
423
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
424
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
425
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -434,6 +442,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
434
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
435
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
436
  GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
 
437
  GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
438
  GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
439
  GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
@@ -447,6 +456,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
447
  GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
448
  GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
449
  GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
 
450
  }
451
  GGML_METAL_ADD_KERNEL(rope_f32);
452
  GGML_METAL_ADD_KERNEL(rope_f16);
@@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
513
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
514
  GGML_METAL_DEL_KERNEL(get_rows_i32);
515
  GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
 
516
  GGML_METAL_DEL_KERNEL(rms_norm);
517
  GGML_METAL_DEL_KERNEL(group_norm);
518
  GGML_METAL_DEL_KERNEL(norm);
@@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
532
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
533
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
534
  GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
 
535
  GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
536
  //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
537
  GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
@@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
548
  GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
549
  GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
550
  GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
 
551
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
552
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
553
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
562
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
563
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
564
  GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
 
565
  GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
566
  GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
567
  GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
@@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
575
  GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
576
  GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
577
  GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
 
578
  }
579
  GGML_METAL_DEL_KERNEL(rope_f32);
580
  GGML_METAL_DEL_KERNEL(rope_f16);
@@ -1561,6 +1576,7 @@ bool ggml_metal_graph_compute(
1561
  case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
1562
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
1563
  case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
 
1564
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1565
  }
1566
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1679,6 +1695,12 @@ bool ggml_metal_graph_compute(
1679
  nth1 = 16;
1680
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
1681
  } break;
 
 
 
 
 
 
1682
  default:
1683
  {
1684
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1712,12 +1734,12 @@ bool ggml_metal_graph_compute(
1712
 
1713
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1714
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1715
- //src0t == GGML_TYPE_IQ2_XXS ||
1716
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1717
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1718
  }
1719
- else if (src0t == GGML_TYPE_IQ2_XXS) {
1720
- [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
 
1721
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1722
  }
1723
  else if (src0t == GGML_TYPE_Q4_K) {
@@ -1810,6 +1832,7 @@ bool ggml_metal_graph_compute(
1810
  case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1811
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1812
  case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
 
1813
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1814
  }
1815
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1931,6 +1954,12 @@ bool ggml_metal_graph_compute(
1931
  nth1 = 16;
1932
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
1933
  } break;
 
 
 
 
 
 
1934
  default:
1935
  {
1936
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1980,12 +2009,12 @@ bool ggml_metal_graph_compute(
1980
 
1981
  if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1982
  src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1983
- //src2t == GGML_TYPE_IQ2_XXS ||
1984
  src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1985
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1986
  }
1987
- else if (src2t == GGML_TYPE_IQ2_XXS) {
1988
- [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
 
1989
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1990
  }
1991
  else if (src2t == GGML_TYPE_Q4_K) {
@@ -2026,6 +2055,7 @@ bool ggml_metal_graph_compute(
2026
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
2027
  case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
2028
  case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
 
2029
  default: GGML_ASSERT(false && "not implemented");
2030
  }
2031
 
 
89
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
90
  GGML_METAL_DECL_KERNEL(get_rows_i32);
91
  GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
92
+ GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
93
  GGML_METAL_DECL_KERNEL(rms_norm);
94
  GGML_METAL_DECL_KERNEL(group_norm);
95
  GGML_METAL_DECL_KERNEL(norm);
 
109
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
110
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
111
  GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
112
+ GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
113
  GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
114
  //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
115
  GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
 
126
  GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
127
  GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
128
  GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
129
+ GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
130
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
131
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
132
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
 
140
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
141
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
142
  GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
143
+ GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
144
  GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
145
  GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
146
  GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
 
154
  GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
155
  GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
156
  GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
157
+ GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
158
  GGML_METAL_DECL_KERNEL(rope_f32);
159
  GGML_METAL_DECL_KERNEL(rope_f16);
160
  GGML_METAL_DECL_KERNEL(alibi_f32);
 
390
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
391
  GGML_METAL_ADD_KERNEL(get_rows_i32);
392
  GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
393
+ GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
394
  GGML_METAL_ADD_KERNEL(rms_norm);
395
  GGML_METAL_ADD_KERNEL(group_norm);
396
  GGML_METAL_ADD_KERNEL(norm);
 
410
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
411
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
412
  GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
413
+ GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
414
  GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
415
  //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
416
  GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
 
427
  GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
428
  GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
429
  GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
430
+ GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
431
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
432
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
433
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
 
442
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
443
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
444
  GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
445
+ GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
446
  GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
447
  GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
448
  GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
 
456
  GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
457
  GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
458
  GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
459
+ GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
460
  }
461
  GGML_METAL_ADD_KERNEL(rope_f32);
462
  GGML_METAL_ADD_KERNEL(rope_f16);
 
523
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
524
  GGML_METAL_DEL_KERNEL(get_rows_i32);
525
  GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
526
+ GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
527
  GGML_METAL_DEL_KERNEL(rms_norm);
528
  GGML_METAL_DEL_KERNEL(group_norm);
529
  GGML_METAL_DEL_KERNEL(norm);
 
543
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
544
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
545
  GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
546
+ GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
547
  GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
548
  //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
549
  GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
 
560
  GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
561
  GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
562
  GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
563
+ GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
564
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
565
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
566
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
 
575
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
576
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
577
  GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
578
+ GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
579
  GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
580
  GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
581
  GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
 
589
  GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
590
  GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
591
  GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
592
+ GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
593
  }
594
  GGML_METAL_DEL_KERNEL(rope_f32);
595
  GGML_METAL_DEL_KERNEL(rope_f16);
 
1576
  case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
1577
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
1578
  case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
1579
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
1580
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1581
  }
1582
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1695
  nth1 = 16;
1696
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
1697
  } break;
1698
+ case GGML_TYPE_IQ2_XS:
1699
+ {
1700
+ nth0 = 4;
1701
+ nth1 = 16;
1702
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
1703
+ } break;
1704
  default:
1705
  {
1706
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
 
1734
 
1735
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1736
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
 
1737
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1738
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1739
  }
1740
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1741
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1742
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1743
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1744
  }
1745
  else if (src0t == GGML_TYPE_Q4_K) {
 
1832
  case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1833
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1834
  case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
1835
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
1836
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1837
  }
1838
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1954
  nth1 = 16;
1955
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
1956
  } break;
1957
+ case GGML_TYPE_IQ2_XS:
1958
+ {
1959
+ nth0 = 4;
1960
+ nth1 = 16;
1961
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
1962
+ } break;
1963
  default:
1964
  {
1965
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
 
2009
 
2010
  if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
2011
  src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
 
2012
  src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
2013
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2014
  }
2015
+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
2016
+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2017
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2018
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2019
  }
2020
  else if (src2t == GGML_TYPE_Q4_K) {
 
2055
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
2056
  case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
2057
  case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
2058
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
2059
  default: GGML_ASSERT(false && "not implemented");
2060
  }
2061
 
ggml-metal.metal CHANGED
@@ -2452,6 +2452,13 @@ typedef struct {
2452
  } block_iq2_xxs;
2453
  // 66 bytes / block for QK_K = 256, so 2.0625 bpw
2454
 
 
 
 
 
 
 
 
2455
  //====================================== dot products =========================
2456
 
2457
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3476,7 +3483,7 @@ kernel void kernel_mul_mv_q6_K_f32(
3476
 
3477
  // ======================= "True" 2-bit
3478
 
3479
- constexpr constant static uint64_t kgrid_iq2xxs[256] = {
3480
  0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3481
  0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3482
  0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
@@ -3543,6 +3550,137 @@ constexpr constant static uint64_t kgrid_iq2xxs[256] = {
3543
  0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3544
  };
3545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3546
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3547
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3548
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -3600,7 +3738,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3600
  {
3601
  int nval = 4;
3602
  int pos = (32*sgitg + tiisg)*nval;
3603
- for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
3604
  nval = 2;
3605
  pos = (32*sgitg + tiisg)*nval;
3606
  for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
@@ -3689,6 +3827,149 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
3689
  kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3690
  }
3691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3692
  //============================= templates and their specializations =============================
3693
 
3694
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3973,18 +4254,39 @@ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x
3973
  const uint32_t aux32_s = q2[2] | (q2[3] << 16);
3974
  thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
3975
  const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
3976
- constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
3977
  uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
3978
  for (int i = 0; i < 8; ++i) {
3979
  reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
3980
  }
3981
- grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
3982
  signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
3983
  for (int i = 0; i < 8; ++i) {
3984
  reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
3985
  }
3986
  }
3987
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3988
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3989
  kernel void kernel_get_rows(
3990
  device const void * src0,
@@ -4525,6 +4827,7 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
4525
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4526
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4527
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 
4528
 
4529
  //
4530
  // matrix-matrix multiplication
@@ -4562,6 +4865,7 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4562
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4563
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4564
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 
4565
 
4566
  //
4567
  // indirect matrix-matrix multiplication
@@ -4611,6 +4915,7 @@ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mu
4611
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4612
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4613
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 
4614
 
4615
  //
4616
  // matrix-vector multiplication
@@ -5448,3 +5753,68 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5448
  tiisg,
5449
  sgitg);
5450
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2452
  } block_iq2_xxs;
2453
  // 66 bytes / block for QK_K = 256, so 2.0625 bpw
2454
 
2455
+ typedef struct {
2456
+ half d;
2457
+ uint16_t qs[QK_K/8];
2458
+ uint8_t scales[QK_K/32];
2459
+ } block_iq2_xs;
2460
+ // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2461
+
2462
  //====================================== dot products =========================
2463
 
2464
  void kernel_mul_mv_q2_K_f32_impl(
 
3483
 
3484
  // ======================= "True" 2-bit
3485
 
3486
+ constexpr constant static uint64_t iq2xxs_grid[256] = {
3487
  0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3488
  0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3489
  0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
 
3550
  0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3551
  };
3552
 
3553
+ constexpr constant static uint64_t iq2xs_grid[512] = {
3554
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3555
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3556
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3557
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3558
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3559
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
3560
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
3561
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
3562
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
3563
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
3564
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
3565
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
3566
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
3567
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
3568
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
3569
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
3570
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
3571
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
3572
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
3573
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
3574
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
3575
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
3576
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
3577
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
3578
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
3579
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
3580
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
3581
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
3582
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
3583
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
3584
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
3585
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
3586
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
3587
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
3588
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
3589
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
3590
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
3591
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
3592
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
3593
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
3594
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
3595
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
3596
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
3597
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
3598
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
3599
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
3600
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
3601
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
3602
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
3603
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
3604
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
3605
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
3606
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
3607
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
3608
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
3609
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
3610
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
3611
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
3612
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
3613
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
3614
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
3615
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
3616
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
3617
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
3618
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
3619
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
3620
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
3621
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
3622
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
3623
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
3624
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
3625
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
3626
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
3627
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
3628
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
3629
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
3630
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
3631
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
3632
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
3633
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
3634
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
3635
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
3636
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
3637
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
3638
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
3639
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
3640
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
3641
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
3642
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
3643
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
3644
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
3645
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
3646
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
3647
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
3648
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
3649
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
3650
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
3651
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
3652
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
3653
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
3654
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
3655
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
3656
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
3657
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
3658
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
3659
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
3660
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
3661
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
3662
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
3663
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
3664
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
3665
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
3666
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
3667
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
3668
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
3669
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
3670
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
3671
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
3672
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
3673
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
3674
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
3675
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
3676
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
3677
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
3678
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
3679
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
3680
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
3681
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3682
+ };
3683
+
3684
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3685
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3686
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
 
3738
  {
3739
  int nval = 4;
3740
  int pos = (32*sgitg + tiisg)*nval;
3741
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
3742
  nval = 2;
3743
  pos = (32*sgitg + tiisg)*nval;
3744
  for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
 
3827
  kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3828
  }
3829
 
3830
+ void kernel_mul_mv_iq2_xs_f32_impl(
3831
+ device const void * src0,
3832
+ device const float * src1,
3833
+ device float * dst,
3834
+ constant int64_t & ne00,
3835
+ constant int64_t & ne01,
3836
+ constant int64_t & ne02,
3837
+ constant int64_t & ne10,
3838
+ constant int64_t & ne12,
3839
+ constant int64_t & ne0,
3840
+ constant int64_t & ne1,
3841
+ constant uint & r2,
3842
+ constant uint & r3,
3843
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3844
+ uint3 tgpig[[threadgroup_position_in_grid]],
3845
+ uint tiisg[[thread_index_in_simdgroup]],
3846
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3847
+
3848
+ const int nb = ne00/QK_K;
3849
+ const int r0 = tgpig.x;
3850
+ const int r1 = tgpig.y;
3851
+ const int im = tgpig.z;
3852
+
3853
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3854
+ const int ib_row = first_row * nb;
3855
+
3856
+ const uint i12 = im%ne12;
3857
+ const uint i13 = im/ne12;
3858
+
3859
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3860
+
3861
+ device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
3862
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3863
+
3864
+ float yl[32];
3865
+ float sumf[N_DST]={0.f}, all_sum;
3866
+
3867
+ const int nb32 = nb * (QK_K / 32);
3868
+
3869
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3870
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
3871
+ {
3872
+ int nval = 8;
3873
+ int pos = (32*sgitg + tiisg)*nval;
3874
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
3875
+ nval = 2;
3876
+ pos = (32*sgitg + tiisg)*nval;
3877
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3878
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3879
+ }
3880
+
3881
+ #if QK_K == 256
3882
+ const int ix = tiisg;
3883
+
3884
+ device const float * y4 = y + 32 * ix;
3885
+
3886
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3887
+
3888
+ for (int i = 0; i < 32; ++i) {
3889
+ yl[i] = y4[i];
3890
+ }
3891
+
3892
+ const int ibl = ib32 / (QK_K / 32);
3893
+ const int ib = ib32 % (QK_K / 32);
3894
+
3895
+ device const block_iq2_xs * xr = x + ibl;
3896
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3897
+ device const uint8_t * sc = xr->scales + ib;
3898
+ device const half * dh = &xr->d;
3899
+
3900
+ for (int row = 0; row < N_DST; row++) {
3901
+
3902
+ const float db = dh[0];
3903
+ const uint8_t ls1 = sc[0] & 0xf;
3904
+ const uint8_t ls2 = sc[0] >> 4;
3905
+ const float d1 = db * (0.5f + ls1);
3906
+ const float d2 = db * (0.5f + ls2);
3907
+
3908
+ float sum1 = 0, sum2 = 0;
3909
+ for (int l = 0; l < 2; ++l) {
3910
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3911
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3912
+ for (int j = 0; j < 8; ++j) {
3913
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3914
+ }
3915
+ }
3916
+ for (int l = 2; l < 4; ++l) {
3917
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3918
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3919
+ for (int j = 0; j < 8; ++j) {
3920
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3921
+ }
3922
+ }
3923
+ sumf[row] += d1 * sum1 + d2 * sum2;
3924
+
3925
+ dh += nb*sizeof(block_iq2_xs)/2;
3926
+ q2 += nb*sizeof(block_iq2_xs)/2;
3927
+ sc += nb*sizeof(block_iq2_xs);
3928
+ }
3929
+
3930
+ y4 += 32 * 32;
3931
+ }
3932
+ #else
3933
+ // TODO
3934
+ #endif
3935
+
3936
+ for (int row = 0; row < N_DST; ++row) {
3937
+ all_sum = simd_sum(sumf[row]);
3938
+ if (tiisg == 0) {
3939
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3940
+ }
3941
+ }
3942
+ }
3943
+
3944
+ [[host_name("kernel_mul_mv_iq2_xs_f32")]]
3945
+ kernel void kernel_mul_mv_iq2_xs_f32(
3946
+ device const void * src0,
3947
+ device const float * src1,
3948
+ device float * dst,
3949
+ constant int64_t & ne00,
3950
+ constant int64_t & ne01,
3951
+ constant int64_t & ne02,
3952
+ constant uint64_t & nb00,
3953
+ constant uint64_t & nb01,
3954
+ constant uint64_t & nb02,
3955
+ constant int64_t & ne10,
3956
+ constant int64_t & ne11,
3957
+ constant int64_t & ne12,
3958
+ constant uint64_t & nb10,
3959
+ constant uint64_t & nb11,
3960
+ constant uint64_t & nb12,
3961
+ constant int64_t & ne0,
3962
+ constant int64_t & ne1,
3963
+ constant uint & r2,
3964
+ constant uint & r3,
3965
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3966
+ uint3 tgpig[[threadgroup_position_in_grid]],
3967
+ uint tiisg[[thread_index_in_simdgroup]],
3968
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3969
+
3970
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3971
+ }
3972
+
3973
  //============================= templates and their specializations =============================
3974
 
3975
  // NOTE: this is not dequantizing - we are simply fitting the template
 
4254
  const uint32_t aux32_s = q2[2] | (q2[3] << 16);
4255
  thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
4256
  const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
4257
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4258
  uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
4259
  for (int i = 0; i < 8; ++i) {
4260
  reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4261
  }
4262
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4263
  signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
4264
  for (int i = 0; i < 8; ++i) {
4265
  reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4266
  }
4267
  }
4268
 
4269
+ template <typename type4x4>
4270
+ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
4271
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4272
+ const float d = xb->d;
4273
+ const int ib32 = il/2;
4274
+ il = il%2;
4275
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4276
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4277
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
4278
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
4279
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
4280
+ for (int i = 0; i < 8; ++i) {
4281
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4282
+ }
4283
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
4284
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
4285
+ for (int i = 0; i < 8; ++i) {
4286
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4287
+ }
4288
+ }
4289
+
4290
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4291
  kernel void kernel_get_rows(
4292
  device const void * src0,
 
4827
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4828
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4829
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4830
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4831
 
4832
  //
4833
  // matrix-matrix multiplication
 
4865
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4866
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4867
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4868
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4869
 
4870
  //
4871
  // indirect matrix-matrix multiplication
 
4915
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4916
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4917
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4918
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4919
 
4920
  //
4921
  // matrix-vector multiplication
 
5753
  tiisg,
5754
  sgitg);
5755
  }
5756
+
5757
+ [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
5758
+ kernel void kernel_mul_mv_id_iq2_xs_f32(
5759
+ device const char * ids,
5760
+ device const char * src1,
5761
+ device float * dst,
5762
+ constant uint64_t & nbi1,
5763
+ constant int64_t & ne00,
5764
+ constant int64_t & ne01,
5765
+ constant int64_t & ne02,
5766
+ constant uint64_t & nb00,
5767
+ constant uint64_t & nb01,
5768
+ constant uint64_t & nb02,
5769
+ constant int64_t & ne10,
5770
+ constant int64_t & ne11,
5771
+ constant int64_t & ne12,
5772
+ constant int64_t & ne13,
5773
+ constant uint64_t & nb10,
5774
+ constant uint64_t & nb11,
5775
+ constant uint64_t & nb12,
5776
+ constant int64_t & ne0,
5777
+ constant int64_t & ne1,
5778
+ constant uint64_t & nb1,
5779
+ constant uint & r2,
5780
+ constant uint & r3,
5781
+ constant int & idx,
5782
+ device const char * src00,
5783
+ device const char * src01,
5784
+ device const char * src02,
5785
+ device const char * src03,
5786
+ device const char * src04,
5787
+ device const char * src05,
5788
+ device const char * src06,
5789
+ device const char * src07,
5790
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5791
+ uint3 tgpig[[threadgroup_position_in_grid]],
5792
+ uint tiitg[[thread_index_in_threadgroup]],
5793
+ uint tiisg[[thread_index_in_simdgroup]],
5794
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5795
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5796
+
5797
+ const int64_t bid = tgpig.z/(ne12*ne13);
5798
+
5799
+ tgpig.z = tgpig.z%(ne12*ne13);
5800
+
5801
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5802
+
5803
+ kernel_mul_mv_iq2_xs_f32_impl(
5804
+ src0[id],
5805
+ (device const float *) (src1 + bid*nb11),
5806
+ dst + bid*ne0,
5807
+ ne00,
5808
+ ne01,
5809
+ ne02,
5810
+ ne10,
5811
+ ne12,
5812
+ ne0,
5813
+ ne1,
5814
+ r2,
5815
+ r3,
5816
+ shared_values,
5817
+ tgpig,
5818
+ tiisg,
5819
+ sgitg);
5820
+ }
ggml-quants.c CHANGED
@@ -2342,15 +2342,7 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t *
2342
 
2343
  // ====================== "True" 2-bit (de)-quantization
2344
 
2345
- void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
2346
- (void)x;
2347
- (void)y;
2348
- (void)k;
2349
- assert(k % QK_K == 0);
2350
- //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
2351
- }
2352
-
2353
- static const uint64_t iq2xxs_grid[256] = {
2354
  0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
2355
  0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
2356
  0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
@@ -2417,6 +2409,137 @@ static const uint64_t iq2xxs_grid[256] = {
2417
  0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
2418
  };
2419
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2420
  static const uint8_t ksigns_iq2xs[128] = {
2421
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
2422
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -2427,8 +2550,17 @@ static const uint8_t ksigns_iq2xs[128] = {
2427
  96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
2428
  240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
2429
  };
 
2430
  static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
2431
 
 
 
 
 
 
 
 
 
2432
  void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
2433
  assert(k % QK_K == 0);
2434
  const int nb = k / QK_K;
@@ -2472,6 +2604,58 @@ size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_
2472
  return (n/QK_K*sizeof(block_iq2_xxs));
2473
  }
2474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2475
  //===================================== Q8_K ==============================================
2476
 
2477
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -7357,3 +7541,161 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
7357
  *s = 0.125f * sumf;
7358
  #endif
7359
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2342
 
2343
  // ====================== "True" 2-bit (de)-quantization
2344
 
2345
+ static const uint64_t iq2xxs_grid[256] = {
 
 
 
 
 
 
 
 
2346
  0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
2347
  0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
2348
  0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
 
2409
  0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
2410
  };
2411
 
2412
+ static const uint64_t iq2xs_grid[512] = {
2413
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
2414
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
2415
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
2416
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
2417
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
2418
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
2419
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
2420
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
2421
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
2422
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
2423
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
2424
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
2425
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
2426
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
2427
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
2428
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
2429
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
2430
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
2431
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
2432
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
2433
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
2434
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
2435
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
2436
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
2437
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
2438
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
2439
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
2440
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
2441
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
2442
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
2443
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
2444
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
2445
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
2446
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
2447
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
2448
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
2449
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
2450
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
2451
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
2452
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
2453
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
2454
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
2455
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
2456
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
2457
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
2458
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
2459
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
2460
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
2461
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
2462
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
2463
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
2464
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
2465
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
2466
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
2467
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
2468
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
2469
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
2470
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
2471
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
2472
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
2473
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
2474
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
2475
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
2476
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
2477
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
2478
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
2479
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
2480
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
2481
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
2482
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
2483
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
2484
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
2485
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
2486
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
2487
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
2488
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
2489
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
2490
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
2491
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
2492
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
2493
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
2494
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
2495
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
2496
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
2497
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
2498
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
2499
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
2500
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
2501
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
2502
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
2503
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
2504
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
2505
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
2506
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
2507
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
2508
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
2509
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
2510
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
2511
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
2512
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
2513
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
2514
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
2515
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
2516
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
2517
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
2518
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
2519
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
2520
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
2521
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
2522
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
2523
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
2524
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
2525
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
2526
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
2527
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
2528
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
2529
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
2530
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
2531
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
2532
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
2533
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
2534
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
2535
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
2536
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
2537
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
2538
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
2539
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
2540
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
2541
+ };
2542
+
2543
  static const uint8_t ksigns_iq2xs[128] = {
2544
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
2545
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
 
2550
  96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
2551
  240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
2552
  };
2553
+
2554
  static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
2555
 
2556
+ void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
2557
+ (void)x;
2558
+ (void)y;
2559
+ (void)k;
2560
+ assert(k % QK_K == 0);
2561
+ //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
2562
+ }
2563
+
2564
  void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
2565
  assert(k % QK_K == 0);
2566
  const int nb = k / QK_K;
 
2604
  return (n/QK_K*sizeof(block_iq2_xxs));
2605
  }
2606
 
2607
+ // ====================== 2.3125 bpw (de)-quantization
2608
+
2609
+ void quantize_row_iq2_xs_reference(const float * restrict x, block_iq2_xs * restrict y, int k) {
2610
+ (void)x;
2611
+ (void)y;
2612
+ (void)k;
2613
+ assert(k % QK_K == 0);
2614
+ //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
2615
+ }
2616
+
2617
+ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
2618
+ assert(k % QK_K == 0);
2619
+ const int nb = k / QK_K;
2620
+
2621
+ float db[2];
2622
+
2623
+ for (int i = 0; i < nb; i++) {
2624
+
2625
+ const float d = GGML_FP16_TO_FP32(x[i].d);
2626
+
2627
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
2628
+ db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
2629
+ db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
2630
+ for (int l = 0; l < 4; ++l) {
2631
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
2632
+ const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
2633
+ for (int j = 0; j < 8; ++j) {
2634
+ y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
2635
+ }
2636
+ y += 8;
2637
+ }
2638
+ }
2639
+ }
2640
+ }
2641
+
2642
+ void quantize_row_iq2_xs(const float * restrict x, void * restrict vy, int k) {
2643
+ assert(k % QK_K == 0);
2644
+ block_iq2_xs * restrict y = vy;
2645
+ quantize_row_iq2_xs_reference(x, y, k);
2646
+ }
2647
+
2648
+ size_t ggml_quantize_iq2_xs(const float * src, void * dst, int n, int k, int64_t * hist) {
2649
+ assert(k % QK_K == 0);
2650
+ (void)hist; // TODO: collect histograms
2651
+
2652
+ for (int j = 0; j < n; j += k) {
2653
+ block_iq2_xs * restrict y = (block_iq2_xs *)dst + j/QK_K;
2654
+ quantize_row_iq2_xs_reference(src + j, y, k);
2655
+ }
2656
+ return (n/QK_K*sizeof(block_iq2_xs));
2657
+ }
2658
+
2659
  //===================================== Q8_K ==============================================
2660
 
2661
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
 
7541
  *s = 0.125f * sumf;
7542
  #endif
7543
  }
7544
+
7545
+ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
7546
+ assert(n % QK_K == 0);
7547
+
7548
+ const block_iq2_xs * restrict x = vx;
7549
+ const block_q8_K * restrict y = vy;
7550
+
7551
+ const int nb = n / QK_K;
7552
+
7553
+ #if defined(__ARM_NEON)
7554
+
7555
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
7556
+
7557
+ int8x16x4_t q2u;
7558
+ int8x16x4_t q2s;
7559
+ int8x16x4_t q8b;
7560
+
7561
+ int32x4x4_t scales32;
7562
+
7563
+ float sumf = 0;
7564
+ for (int i = 0; i < nb; ++i) {
7565
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
7566
+ const uint16_t * restrict q2 = x[i].qs;
7567
+ const int8_t * restrict q8 = y[i].qs;
7568
+ const uint8x8_t scales8 = vld1_u8(x[i].scales);
7569
+ const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
7570
+ const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
7571
+ uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
7572
+ scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
7573
+ const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
7574
+ const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
7575
+ scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
7576
+ scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
7577
+ scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
7578
+ scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
7579
+ int32x4_t sumi = vdupq_n_s32(0);
7580
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
7581
+ q8b = vld1q_s8_x4(q8); q8 += 64;
7582
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
7583
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
7584
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
7585
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
7586
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
7587
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
7588
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
7589
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
7590
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
7591
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
7592
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
7593
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
7594
+ const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
7595
+ const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
7596
+ const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
7597
+ const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
7598
+ const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
7599
+ sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
7600
+ q2 += 8;
7601
+ }
7602
+ sumf += d*vaddvq_s32(sumi);
7603
+ }
7604
+ *s = 0.125f * sumf;
7605
+
7606
+ #elif defined(__AVX2__)
7607
+
7608
+ const __m128i m4 = _mm_set1_epi8(0xf);
7609
+ const __m128i m1 = _mm_set1_epi8(1);
7610
+ const __m128i m511 = _mm_set1_epi16(511);
7611
+ const __m128i m127 = _mm_set1_epi16(127);
7612
+
7613
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
7614
+
7615
+ uint64_t aux64;
7616
+
7617
+ // somewhat hacky, but gives a significant boost in performance
7618
+ __m128i aux_gindex, aux_sindex;
7619
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
7620
+ const uint16_t * sindex = (const uint16_t *)&aux_sindex;
7621
+
7622
+ __m256 accumf = _mm256_setzero_ps();
7623
+ for (int i = 0; i < nb; ++i) {
7624
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
7625
+ const uint16_t * restrict q2 = x[i].qs;
7626
+ const int8_t * restrict q8 = y[i].qs;
7627
+
7628
+ memcpy(&aux64, x[i].scales, 8);
7629
+ __m128i stmp = _mm_set1_epi64x(aux64);
7630
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
7631
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
7632
+
7633
+ __m256i sumi1 = _mm256_setzero_si256();
7634
+ __m256i sumi2 = _mm256_setzero_si256();
7635
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
7636
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
7637
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
7638
+ const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
7639
+ aux_gindex = _mm_and_si128(q2_data, m511);
7640
+ aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
7641
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
7642
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
7643
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
7644
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
7645
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
7646
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
7647
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
7648
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
7649
+
7650
+ const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
7651
+ const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
7652
+
7653
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
7654
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
7655
+ }
7656
+
7657
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
7658
+
7659
+ }
7660
+
7661
+ *s = 0.125f * hsum_float_8(accumf);
7662
+
7663
+ #else
7664
+
7665
+ float sumf = 0.f;
7666
+ for (int i = 0; i < nb; ++i) {
7667
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
7668
+ const uint16_t * restrict q2 = x[i].qs;
7669
+ const uint8_t * restrict sc = x[i].scales;
7670
+ const int8_t * restrict q8 = y[i].qs;
7671
+ int32_t bsum = 0;
7672
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
7673
+ const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
7674
+ const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
7675
+ int32_t sumi = 0;
7676
+ for (int l = 0; l < 2; ++l) {
7677
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
7678
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
7679
+ for (int j = 0; j < 8; ++j) {
7680
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
7681
+ }
7682
+ q8 += 8;
7683
+ }
7684
+ bsum += sumi * ls1;
7685
+ sumi = 0;
7686
+ for (int l = 2; l < 4; ++l) {
7687
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
7688
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
7689
+ for (int j = 0; j < 8; ++j) {
7690
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
7691
+ }
7692
+ q8 += 8;
7693
+ }
7694
+ bsum += sumi * ls2;
7695
+ q2 += 4;
7696
+ }
7697
+ sumf += d * bsum;
7698
+ }
7699
+ *s = 0.125f * sumf;
7700
+ #endif
7701
+ }
ggml-quants.h CHANGED
@@ -174,6 +174,14 @@ typedef struct {
174
  } block_iq2_xxs;
175
  static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
176
 
 
 
 
 
 
 
 
 
177
  // Quantization
178
  void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
179
  void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
@@ -189,6 +197,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
189
  void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
190
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
191
  void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
 
192
 
193
  void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
194
  void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@@ -204,6 +213,7 @@ void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
204
  void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
205
  void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
206
  void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
 
207
 
208
  // Dequantization
209
  void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@@ -220,6 +230,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
220
  void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
221
  void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
222
  void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
 
223
 
224
  // Dot product
225
  void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -234,3 +245,4 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx,
234
  void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
235
  void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
236
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 
174
  } block_iq2_xxs;
175
  static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
176
 
177
+ // 2.3125 bpw quants
178
+ typedef struct {
179
+ ggml_fp16_t d;
180
+ uint16_t qs[QK_K/8];
181
+ uint8_t scales[QK_K/32];
182
+ } block_iq2_xs;
183
+ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
184
+
185
  // Quantization
186
  void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
187
  void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
 
197
  void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
198
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
199
  void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
200
+ void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs * restrict y, int k);
201
 
202
  void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
203
  void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
 
213
  void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
214
  void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
215
  void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
216
+ void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k);
217
 
218
  // Dequantization
219
  void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
 
230
  void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
231
  void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
232
  void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
233
+ void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k);
234
 
235
  // Dot product
236
  void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
245
  void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
246
  void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
247
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
248
+ void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
ggml.c CHANGED
@@ -584,6 +584,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
584
  .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
585
  .vec_dot_type = GGML_TYPE_Q8_K,
586
  },
 
 
 
 
 
 
 
 
 
 
 
587
  [GGML_TYPE_Q8_K] = {
588
  .type_name = "q8_K",
589
  .blck_size = QK_K,
@@ -2123,6 +2134,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2123
  case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
2124
  case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
2125
  case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
 
2126
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
2127
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
2128
  }
@@ -7435,6 +7447,7 @@ static void ggml_compute_forward_add(
7435
  case GGML_TYPE_Q5_K:
7436
  case GGML_TYPE_Q6_K:
7437
  case GGML_TYPE_IQ2_XXS:
 
7438
  {
7439
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
7440
  } break;
@@ -7700,6 +7713,7 @@ static void ggml_compute_forward_add1(
7700
  case GGML_TYPE_Q5_K:
7701
  case GGML_TYPE_Q6_K:
7702
  case GGML_TYPE_IQ2_XXS:
 
7703
  {
7704
  ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
7705
  } break;
@@ -7815,6 +7829,7 @@ static void ggml_compute_forward_acc(
7815
  case GGML_TYPE_Q5_K:
7816
  case GGML_TYPE_Q6_K:
7817
  case GGML_TYPE_IQ2_XXS:
 
7818
  default:
7819
  {
7820
  GGML_ASSERT(false);
@@ -10457,6 +10472,7 @@ static void ggml_compute_forward_out_prod(
10457
  case GGML_TYPE_Q5_K:
10458
  case GGML_TYPE_Q6_K:
10459
  case GGML_TYPE_IQ2_XXS:
 
10460
  {
10461
  ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10462
  } break;
@@ -10632,6 +10648,7 @@ static void ggml_compute_forward_set(
10632
  case GGML_TYPE_Q5_K:
10633
  case GGML_TYPE_Q6_K:
10634
  case GGML_TYPE_IQ2_XXS:
 
10635
  default:
10636
  {
10637
  GGML_ASSERT(false);
@@ -10827,6 +10844,7 @@ static void ggml_compute_forward_get_rows(
10827
  case GGML_TYPE_Q5_K:
10828
  case GGML_TYPE_Q6_K:
10829
  case GGML_TYPE_IQ2_XXS:
 
10830
  {
10831
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
10832
  } break;
@@ -11464,6 +11482,7 @@ static void ggml_compute_forward_alibi(
11464
  case GGML_TYPE_Q5_K:
11465
  case GGML_TYPE_Q6_K:
11466
  case GGML_TYPE_IQ2_XXS:
 
11467
  case GGML_TYPE_Q8_K:
11468
  case GGML_TYPE_I8:
11469
  case GGML_TYPE_I16:
@@ -11539,6 +11558,7 @@ static void ggml_compute_forward_clamp(
11539
  case GGML_TYPE_Q5_K:
11540
  case GGML_TYPE_Q6_K:
11541
  case GGML_TYPE_IQ2_XXS:
 
11542
  case GGML_TYPE_Q8_K:
11543
  case GGML_TYPE_I8:
11544
  case GGML_TYPE_I16:
@@ -18660,6 +18680,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
18660
  block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
18661
  result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
18662
  } break;
 
 
 
 
 
 
18663
  case GGML_TYPE_F16:
18664
  {
18665
  int elemsize = sizeof(ggml_fp16_t);
@@ -19015,8 +19041,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
19015
  (int64_t) info->ne[3];
19016
 
19017
  if (ne % ggml_blck_size(info->type) != 0) {
19018
- fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
19019
- __func__, info->name.data, ne, ggml_blck_size(info->type));
19020
  fclose(file);
19021
  gguf_free(ctx);
19022
  return NULL;
 
584
  .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
585
  .vec_dot_type = GGML_TYPE_Q8_K,
586
  },
587
+ [GGML_TYPE_IQ2_XS] = {
588
+ .type_name = "iq2_xs",
589
+ .blck_size = QK_K,
590
+ .type_size = sizeof(block_iq2_xs),
591
+ .is_quantized = true,
592
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
593
+ .from_float = quantize_row_iq2_xs,
594
+ .from_float_reference = (ggml_from_float_t) quantize_row_iq2_xs_reference,
595
+ .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
596
+ .vec_dot_type = GGML_TYPE_Q8_K,
597
+ },
598
  [GGML_TYPE_Q8_K] = {
599
  .type_name = "q8_K",
600
  .blck_size = QK_K,
 
2134
  case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
2135
  case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
2136
  case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
2137
+ case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
2138
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
2139
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
2140
  }
 
7447
  case GGML_TYPE_Q5_K:
7448
  case GGML_TYPE_Q6_K:
7449
  case GGML_TYPE_IQ2_XXS:
7450
+ case GGML_TYPE_IQ2_XS:
7451
  {
7452
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
7453
  } break;
 
7713
  case GGML_TYPE_Q5_K:
7714
  case GGML_TYPE_Q6_K:
7715
  case GGML_TYPE_IQ2_XXS:
7716
+ case GGML_TYPE_IQ2_XS:
7717
  {
7718
  ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
7719
  } break;
 
7829
  case GGML_TYPE_Q5_K:
7830
  case GGML_TYPE_Q6_K:
7831
  case GGML_TYPE_IQ2_XXS:
7832
+ case GGML_TYPE_IQ2_XS:
7833
  default:
7834
  {
7835
  GGML_ASSERT(false);
 
10472
  case GGML_TYPE_Q5_K:
10473
  case GGML_TYPE_Q6_K:
10474
  case GGML_TYPE_IQ2_XXS:
10475
+ case GGML_TYPE_IQ2_XS:
10476
  {
10477
  ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10478
  } break;
 
10648
  case GGML_TYPE_Q5_K:
10649
  case GGML_TYPE_Q6_K:
10650
  case GGML_TYPE_IQ2_XXS:
10651
+ case GGML_TYPE_IQ2_XS:
10652
  default:
10653
  {
10654
  GGML_ASSERT(false);
 
10844
  case GGML_TYPE_Q5_K:
10845
  case GGML_TYPE_Q6_K:
10846
  case GGML_TYPE_IQ2_XXS:
10847
+ case GGML_TYPE_IQ2_XS:
10848
  {
10849
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
10850
  } break;
 
11482
  case GGML_TYPE_Q5_K:
11483
  case GGML_TYPE_Q6_K:
11484
  case GGML_TYPE_IQ2_XXS:
11485
+ case GGML_TYPE_IQ2_XS:
11486
  case GGML_TYPE_Q8_K:
11487
  case GGML_TYPE_I8:
11488
  case GGML_TYPE_I16:
 
11558
  case GGML_TYPE_Q5_K:
11559
  case GGML_TYPE_Q6_K:
11560
  case GGML_TYPE_IQ2_XXS:
11561
+ case GGML_TYPE_IQ2_XS:
11562
  case GGML_TYPE_Q8_K:
11563
  case GGML_TYPE_I8:
11564
  case GGML_TYPE_I16:
 
18680
  block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
18681
  result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
18682
  } break;
18683
+ case GGML_TYPE_IQ2_XS:
18684
+ {
18685
+ GGML_ASSERT(start % QK_K == 0);
18686
+ block_iq2_xs * block = (block_iq2_xs*)dst + start / QK_K;
18687
+ result = ggml_quantize_iq2_xs(src + start, block, n, n, hist);
18688
+ } break;
18689
  case GGML_TYPE_F16:
18690
  {
18691
  int elemsize = sizeof(ggml_fp16_t);
 
19041
  (int64_t) info->ne[3];
19042
 
19043
  if (ne % ggml_blck_size(info->type) != 0) {
19044
+ fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
19045
+ __func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
19046
  fclose(file);
19047
  gguf_free(ctx);
19048
  return NULL;
ggml.h CHANGED
@@ -342,6 +342,7 @@ extern "C" {
342
  GGML_TYPE_Q6_K = 14,
343
  GGML_TYPE_Q8_K = 15,
344
  GGML_TYPE_IQ2_XXS = 16,
 
345
  GGML_TYPE_I8,
346
  GGML_TYPE_I16,
347
  GGML_TYPE_I32,
@@ -377,6 +378,7 @@ extern "C" {
377
  GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
378
  GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
379
  GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
 
380
  };
381
 
382
  // available tensor operations:
@@ -2061,6 +2063,7 @@ extern "C" {
2061
  GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
2062
  GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
2063
  GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
 
2064
 
2065
  GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
2066
 
 
342
  GGML_TYPE_Q6_K = 14,
343
  GGML_TYPE_Q8_K = 15,
344
  GGML_TYPE_IQ2_XXS = 16,
345
+ GGML_TYPE_IQ2_XS = 17,
346
  GGML_TYPE_I8,
347
  GGML_TYPE_I16,
348
  GGML_TYPE_I32,
 
378
  GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
379
  GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
380
  GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
381
+ GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
382
  };
383
 
384
  // available tensor operations:
 
2063
  GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
2064
  GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
2065
  GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
2066
+ GGML_API size_t ggml_quantize_iq2_xs (const float * src, void * dst, int n, int k, int64_t * hist);
2067
 
2068
  GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
2069