Kawrakow (ikawrakow) committed
Commit 0ee1bfb · unverified · 1 parent: 753b30d

IQ4_XS: a 4.25 bpw quantization (llama/5747)

* Try IQ4_NL with blocks of 64 - does not look good

* iq4_xs: go to super-blocks of 256 and 6-bit scales for blocks of 32

* iq4_xs: CUDA works - 133.2 t/s

* iq4_xs: AVX2 dot product

* iq4_xs: ARM_NEON dot product

* iq4_nl: Metal implementation

As usual, Metal / Apple Silicon don't like my quants.

* iq3_xs: minor fix

* iq4_xs: shrink by using IQ3_S for attn_k and attn_q

* iq4_xs: revert using IQ3_S for attn_k and attn_v

PPL vs size is good, but CPU performance suffers: on M2 Max
TG-128 drops to 21.7 t/s from 28.8, and on a Ryzen-7950X
to 14.5 t/s from 15.8 t/s. On CUDA we have 135 t/s when
using IQ3_S vs 133 t/s with pure IQ4_XS.

* Fix CI

* iq4_xs: Added forgotten check for 256 divisibility

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
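For reference, the 4.25 bpw figure in the title follows directly from the block_iq4_xs layout this commit introduces (QK_K = 256): 2 bytes for the fp16 super-block scale, 2 bytes of high scale bits, 4 bytes of low scale nibbles, and 128 bytes of packed 4-bit quants, i.e. 136 bytes per 256 weights. A minimal standalone C sketch of that arithmetic (ggml_fp16_t stubbed as uint16_t purely for sizing; not part of the commit):

#include <stdint.h>
#include <stdio.h>

#define QK_K 256
typedef uint16_t ggml_fp16_t;      // stand-in with the same size as ggml's fp16 type

typedef struct {
    ggml_fp16_t d;                 // super-block scale
    uint16_t    scales_h;          // 2 high bits of each of the 8 block scales
    uint8_t     scales_l[QK_K/64]; // 4 low bits of each of the 8 block scales
    uint8_t     qs[QK_K/2];        // 256 4-bit quants, packed two per byte
} block_iq4_xs;

int main(void) {
    // 2 + 2 + 4 + 128 = 136 bytes per 256 weights -> 136*8/256 = 4.25 bits per weight
    printf("%zu bytes per super-block, %.2f bpw\n",
           sizeof(block_iq4_xs), 8.0*sizeof(block_iq4_xs)/QK_K);
    return 0;
}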

Files changed (7)
  1. ggml-cuda.cu +118 -1
  2. ggml-metal.m +27 -2
  3. ggml-metal.metal +221 -3
  4. ggml-quants.c +240 -21
  5. ggml-quants.h +13 -0
  6. ggml.c +30 -0
  7. ggml.h +2 -0
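All of the backend changes below decode the per-32-weight scale in the same way: a 4-bit low nibble from scales_l and a 2-bit high part from scales_h are combined into a 6-bit value, biased by -32, and multiplied by the fp16 super-block scale d. A small C sketch of just that decode, extracted from the formulas in the diffs (the helper name is mine, not part of the commit):

#include <stdint.h>

// ib is the index of the 32-weight block inside a 256-weight super-block (0..7).
// Returns the signed 6-bit block scale in [-32, 31]; the effective scale is d * this value.
static inline int iq4_xs_block_scale(const uint8_t * scales_l, uint16_t scales_h, int ib) {
    const int lo = (scales_l[ib/2] >> 4*(ib%2)) & 0xf; // low 4 bits, two blocks per byte
    const int hi = (scales_h >> 2*ib) & 3;             // high 2 bits, packed in one uint16
    return ((hi << 4) | lo) - 32;
}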
ggml-cuda.cu CHANGED
@@ -571,6 +571,18 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
+// QR4_XS = 8 is very slightly faster than QR4_XS = 4
+#define QR4_XS 8
+#define QI4_XS (QK_K / (4*QR4_XS))
+typedef struct {
+    half d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -2427,6 +2439,25 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
 
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+    const int tid = threadIdx.x;
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+    }
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -5286,6 +5317,76 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
     return d * (sumi1 + sumi2);
 }
 
+static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+#if QK_K == 256
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    //// iqs is 0...7
+    //const int ib64 = iqs/2;
+    //const int il = iqs%2;
+    //const int32_t  * q8_1 = (const int *)bq8_1[2*ib64+0].qs + 2*il;
+    //const int32_t  * q8_2 = (const int *)bq8_1[2*ib64+1].qs + 2*il;
+    //const uint32_t * q4_1 = (const uint32_t *)bq4->qs + 8*ib64 + 2*il;
+    //const uint32_t * q4_2 = q4_1 + 4;
+    //const int8_t ls1 = (bq4->scales_l[ib64] & 0xf) | (((bq4->scales_h >> (4*ib64+0)) & 3) << 4);
+    //const int8_t ls2 = (bq4->scales_l[ib64] >> 4) | (((bq4->scales_h >> (4*ib64+2)) & 3) << 4);
+    //const float d1 = (float)bq4->d * (ls1 - 32) * __low2float(bq8_1[2*ib64+0].ds);
+    //const float d2 = (float)bq4->d * (ls2 - 32) * __low2float(bq8_1[2*ib64+1].ds);
+    //int v1, v2;
+    //int sumi1 = 0, sumi2 = 0;
+    //for (int j = 0; j < 2; ++j) {
+    //    get_int_from_table_16(q4_1[j], values, v1, v2);
+    //    sumi1 = __dp4a(v2, q8_1[j+4], __dp4a(v1, q8_1[j+0], sumi1));
+    //    get_int_from_table_16(q4_2[j], values, v1, v2);
+    //    sumi2 = __dp4a(v2, q8_2[j+4], __dp4a(v1, q8_2[j+0], sumi2));
+    //}
+    //return d1 * sumi1 + d2 * sumi2;
+
+    // iqs is 0...7
+    const int ib32 = iqs;
+    const int32_t  * q8 = (const int *)bq8_1[ib32].qs;
+    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+    const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 4; ++j) {
+        get_int_from_table_16(q4[j], values, v1, v2);
+        sumi1 = __dp4a(v1, q8[j+0], sumi1);
+        sumi2 = __dp4a(v2, q8[j+4], sumi2);
+    }
+    return d * (sumi1 + sumi2);
+
+    //// iqs is 0...15
+    //const int ib32 = iqs/2;
+    //const int il = iqs%2;
+    //const int32_t  * q8 = (const int *)bq8_1[ib32].qs + 2*il;
+    //const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il;
+    //const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+    //const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
+    //int v1, v2;
+    //int sumi1 = 0, sumi2 = 0;
+    //for (int j = 0; j < 2; ++j) {
+    //    get_int_from_table_16(q4[j], values, v1, v2);
+    //    sumi1 = __dp4a(v1, q8[j+0], sumi1);
+    //    sumi2 = __dp4a(v2, q8[j+4], sumi2);
+    //}
+    //return d * (sumi1 + sumi2);
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -7340,6 +7441,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -7385,6 +7492,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         return dequantize_row_iq1_s_cuda;
     case GGML_TYPE_IQ4_NL:
         return dequantize_row_iq4_nl_cuda;
+    case GGML_TYPE_IQ4_XS:
+        return dequantize_row_iq4_xs_cuda;
     case GGML_TYPE_IQ3_S:
         return dequantize_row_iq3_s_cuda;
     case GGML_TYPE_F32:
@@ -7428,6 +7537,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         return dequantize_row_iq1_s_cuda;
     case GGML_TYPE_IQ4_NL:
         return dequantize_row_iq4_nl_cuda;
+    case GGML_TYPE_IQ4_XS:
+        return dequantize_row_iq4_xs_cuda;
     case GGML_TYPE_IQ3_S:
         return dequantize_row_iq3_s_cuda;
     case GGML_TYPE_F16:
@@ -9176,6 +9287,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
     case GGML_TYPE_IQ3_XXS:
     case GGML_TYPE_IQ1_S:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_IQ4_XS:
    case GGML_TYPE_IQ3_S:
         return max_compute_capability >= CC_RDNA2 ? 128 : 64;
     default:
@@ -9203,6 +9315,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
     case GGML_TYPE_IQ3_XXS:
     case GGML_TYPE_IQ1_S:
     case GGML_TYPE_IQ4_NL:
+    case GGML_TYPE_IQ4_XS:
     case GGML_TYPE_IQ3_S:
         return max_compute_capability >= CC_VOLTA ? 128 : 64;
     case GGML_TYPE_Q6_K:
@@ -9313,6 +9426,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
         mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
             (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
         break;
+    case GGML_TYPE_IQ4_XS:
+        mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
+            (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+        break;
     case GGML_TYPE_IQ3_S:
         mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
             (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
@@ -12041,7 +12158,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     ggml_type a_type = a->type;
     if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
         a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
-        a_type == GGML_TYPE_IQ2_S) {
+        a_type == GGML_TYPE_IQ2_S   || a_type == GGML_TYPE_IQ4_XS) {
         if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
             return false;
         }
ggml-metal.m CHANGED
@@ -65,6 +65,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -113,6 +115,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -132,6 +135,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -151,6 +155,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -466,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -492,6 +498,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -514,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -533,6 +541,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -552,6 +561,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -1371,6 +1381,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
                 case GGML_TYPE_IQ1_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
                 case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+                case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
             }
 
@@ -1529,6 +1540,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ4_XS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1576,7 +1593,7 @@ static bool ggml_metal_graph_compute(
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
-            else if (src0t == GGML_TYPE_IQ4_NL) {
+            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                 const int mem_size = 32*sizeof(float);
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1678,6 +1695,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
                 case GGML_TYPE_IQ1_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
                 case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+                case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
             }
 
@@ -1839,6 +1857,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ4_XS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1902,7 +1926,7 @@ static bool ggml_metal_graph_compute(
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
-            else if (src2t == GGML_TYPE_IQ4_NL) {
+            else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
                 const int mem_size = 32*sizeof(float);
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1952,6 +1976,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
                 case GGML_TYPE_IQ1_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
                 case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+                case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
                 case GGML_TYPE_I32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                 default: GGML_ASSERT(false && "not implemented");
             }
ggml-metal.metal CHANGED
@@ -2560,6 +2560,13 @@ typedef struct {
     uint8_t qs[QK4_NL/2];
 } block_iq4_nl;
 
+typedef struct {
+    half d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+
 //====================================== dot products =========================
 
 void kernel_mul_mv_q2_K_f32_impl(
@@ -5160,6 +5167,100 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     }
 }
 
+void kernel_mul_mv_iq4_xs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        threadgroup float  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    const int ix = tiisg/16;  // 0 or 1
+    const int it = tiisg%16;  // 0...15
+    const int ib = it/2;
+    const int il = it%2;
+
+    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ibl = ix; ibl < nb; ibl += 2) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+
+            device const block_iq4_xs & xb = x[row*nb + ibl];
+            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] & 0x0f0f0f0f;
+            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[1] & 0x0f0f0f0f;
+            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
+            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 2 * QK_K;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
 [[host_name("kernel_mul_mv_iq1_s_f32")]]
 kernel void kernel_mul_mv_iq1_s_f32(
         device const  void * src0,
@@ -5217,6 +5318,35 @@ kernel void kernel_mul_mv_iq4_nl_f32(
     kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
+[[host_name("kernel_mul_mv_iq4_xs_f32")]]
+kernel void kernel_mul_mv_iq4_xs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        threadgroup float  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -5638,6 +5768,26 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
     }
 }
 
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)xb->d * (ls - 32);
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
     device const  void * src0,
@@ -6183,7 +6333,8 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
 template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
 
 //
 // matrix-matrix multiplication
@@ -6226,7 +6377,8 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
 
 //
 // indirect matrix-matrix multiplication
@@ -6281,7 +6433,8 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
 
 //
 // matrix-vector multiplication
@@ -7507,3 +7660,68 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
         tiisg,
         sgitg);
 }
+
+[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
+kernel void kernel_mul_mv_id_iq4_xs_f32(
+        device const   char * ids,
+        device const   char * src1,
+        device        float * dst,
+        constant   uint64_t & nbi1,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant    int64_t & ne02,
+        constant   uint64_t & nb00,
+        constant   uint64_t & nb01,
+        constant   uint64_t & nb02,
+        constant    int64_t & ne10,
+        constant    int64_t & ne11,
+        constant    int64_t & ne12,
+        constant    int64_t & ne13,
+        constant   uint64_t & nb10,
+        constant   uint64_t & nb11,
+        constant   uint64_t & nb12,
+        constant    int64_t & ne0,
+        constant    int64_t & ne1,
+        constant   uint64_t & nb1,
+        constant       uint & r2,
+        constant       uint & r3,
+        constant        int & idx,
+        device const   char * src00,
+        device const   char * src01,
+        device const   char * src02,
+        device const   char * src03,
+        device const   char * src04,
+        device const   char * src05,
+        device const   char * src06,
+        device const   char * src07,
+        threadgroup float   * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq4_xs_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_values,
+        tgpig,
+        tiisg,
+        sgitg);
+}
ggml-quants.c CHANGED
@@ -4225,6 +4225,29 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
4225
  }
4226
  }
4227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4228
  //===================================== Q8_K ==============================================
4229
 
4230
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -9675,8 +9698,8 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9675
  qs += 8;
9676
 
9677
  vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
9678
- vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9679
- vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9680
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
9681
  vs.val[1] = vceqq_u8(vs.val[1], mask2);
9682
 
@@ -9684,8 +9707,8 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9684
  q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
9685
 
9686
  vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
9687
- vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9688
- vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9689
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
9690
  vs.val[1] = vceqq_u8(vs.val[1], mask2);
9691
 
@@ -10425,6 +10448,134 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
10425
  #endif
10426
  }
10427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10428
  // ================================ IQ2 quantization =============================================
10429
 
10430
  typedef struct {
@@ -12021,23 +12172,23 @@ static inline int best_index_int8(int n, const int8_t * val, float x) {
12021
  return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
12022
  }
12023
 
12024
- static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
12025
- ggml_fp16_t * dh, uint8_t * q4,
12026
- float * weight, uint8_t * L,
12027
  const int8_t * values,
12028
  const float * quant_weights) {
12029
 
12030
  const int ntry = 7;
12031
 
12032
  float sigma2 = 0;
12033
- for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
12034
- sigma2 *= 2.f/QK4_NL;
12035
 
12036
- const int nb = QK4_NL/block_size;
 
12037
 
12038
- memset(q4, 0, QK4_NL/2);
12039
- for (int ib = 0; ib < nb; ++ib) {
12040
- dh[ib] = GGML_FP32_TO_FP16(0.f);
12041
  const float * xb = x + ib*block_size;
12042
  if (quant_weights) {
12043
  const float * qw = quant_weights + ib*block_size;
@@ -12053,6 +12204,7 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
12053
  }
12054
  }
12055
  if (!amax) {
 
12056
  continue;
12057
  }
12058
  float d = -max/values[0];
@@ -12066,7 +12218,6 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
12066
  sumqx += w*q*xb[j];
12067
  sumq2 += w*q*q;
12068
  }
12069
- float best_id = id;
12070
  d = sumqx/sumq2;
12071
  float best = d*sumqx;
12072
  for (int itry = -ntry; itry <= ntry; ++itry) {
@@ -12082,15 +12233,47 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
12082
  }
12083
  if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
12084
  d = sumqx/sumq2; best = d * sumqx;
12085
- best_id = id;
12086
  }
12087
  }
12088
- dh[ib] = GGML_FP32_TO_FP16(d);
12089
- for (int j = 0; j < block_size; ++j) {
12090
- L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
 
12091
  }
12092
  }
12093
- for (int i = 0; i < QK4_NL/32; ++i) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12094
  for (int j = 0; j < 16; ++j) {
12095
  q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
12096
  }
@@ -12103,12 +12286,16 @@ size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, i
12103
  int nblock = n_per_row/QK4_NL;
12104
  char * qrow = (char *)dst;
12105
  uint8_t L[QK4_NL];
12106
- float weight[32];
 
 
 
12107
  for (int row = 0; row < nrow; ++row) {
12108
  block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
12109
  for (int ibl = 0; ibl < nblock; ++ibl) {
12110
  const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
12111
- quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
 
12112
  }
12113
  src += n_per_row;
12114
  qrow += nblock*sizeof(block_iq4_nl);
@@ -12127,6 +12314,38 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest
12127
  quantize_iq4_nl(x, y, 1, k, NULL, NULL);
12128
  }
12129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12130
  // =============================== 2.5625 bpw
12131
 
12132
  static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
 
4225
  }
4226
  }
4227
 
4228
+ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
4229
+ assert(k % QK_K == 0);
4230
+ const int nb = k / QK_K;
4231
+
4232
+ for (int i = 0; i < nb; i++) {
4233
+
4234
+ const uint8_t * qs = x[i].qs;
4235
+
4236
+ const float d = GGML_FP16_TO_FP32(x[i].d);
4237
+
4238
+ for (int ib = 0; ib < QK_K/32; ++ib) {
4239
+ const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4);
4240
+ const float dl = d * (ls - 32);
4241
+ for (int j = 0; j < 16; ++j) {
4242
+ y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
4243
+ y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4];
4244
+ }
4245
+ y += 32;
4246
+ qs += 16;
4247
+ }
4248
+ }
4249
+ }
4250
+
4251
  //===================================== Q8_K ==============================================
4252
 
4253
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
 
9698
  qs += 8;
9699
 
9700
  vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
9701
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9702
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9703
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
9704
  vs.val[1] = vceqq_u8(vs.val[1], mask2);
9705
 
 
9707
  q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
9708
 
9709
  vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
9710
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9711
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9712
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
9713
  vs.val[1] = vceqq_u8(vs.val[1], mask2);
9714
 
 
10448
  #endif
10449
  }
10450
 
10451
+ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
10452
+ assert(nrc == 1);
10453
+ UNUSED(nrc);
10454
+ UNUSED(bx);
10455
+ UNUSED(by);
10456
+ UNUSED(bs);
10457
+ assert(n % QK_K == 0);
10458
+
10459
+ const block_iq4_xs * restrict x = vx;
10460
+ const block_q8_K * restrict y = vy;
10461
+
10462
+ const int nb = n / QK_K;
10463
+
10464
+ #if defined __ARM_NEON
10465
+ const int8x16_t values = vld1q_s8(kvalues_iq4nl);
10466
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
10467
+ uint8x16x2_t q4bits;
10468
+ int8x16x4_t q4b;
10469
+ int8x16x4_t q8b;
10470
+ int32x4_t prod_1, prod_2;
10471
+
10472
+ float sumf = 0;
10473
+
10474
+ for (int ibl = 0; ibl < nb; ++ibl) {
10475
+
10476
+ const int8_t * q8 = y[ibl].qs;
10477
+ const uint8_t * q4 = x[ibl].qs;
10478
+ uint16_t h = x[ibl].scales_h;
10479
+
10480
+ int sumi1 = 0, sumi2 = 0;
10481
+ for (int ib = 0; ib < QK_K/64; ++ib) {
10482
+
10483
+ q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
10484
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
10485
+
10486
+ q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
10487
+ q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
10488
+            q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+            q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+            prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+            prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+            int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+            int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
+            h >>= 4;
+            sumi1 += vaddvq_s32(prod_1) * ls1;
+            sumi2 += vaddvq_s32(prod_2) * ls2;
+
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                                   _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+            const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                                   _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
+            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
+            sumi1 = _mm256_add_epi32(p_1, sumi1);
+            sumi2 = _mm256_add_epi32(p_2, sumi2);
+        }
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#else
+    float sumf = 0;
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
+        uint16_t h = x[ibl].scales_h;
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
+            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
+            h >>= 4;
+            const float d1 = d4d8*(ls1 - 32);
+            const float d2 = d4d8*(ls2 - 32);
+            int sumi1 = 0, sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
+            }
+            sumf += d1 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+            sumi1 = sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
+            }
+            sumf += d2 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+        }
+    }
+    *s = sumf;
+#endif
+}
+
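Note on the scale packing used in the dot product above: each 32-weight block inside an IQ4_XS super-block carries a 6-bit scale, stored with a +32 bias, with its low 4 bits in scales_l (two blocks per byte) and its high 2 bits in scales_h (2 bits per block). A minimal scalar decoding sketch, mirroring the ls1/ls2 computation in the kernels above (the helper name is illustrative, not part of this change):

// Sketch only: recover the signed per-block scale for 32-weight block ib (0..QK_K/32-1)
// of one block_iq4_xs.
static inline int iq4_xs_block_scale(const block_iq4_xs * b, int ib) {
    const int lo = (b->scales_l[ib/2] >> 4*(ib%2)) & 0xf; // low 4 bits, two blocks per byte
    const int hi = (b->scales_h >> 2*ib) & 3;             // high 2 bits, 2 bits per block
    return ((hi << 4) | lo) - 32;                         // stored with a +32 bias
}

Weight j of that block then dequantizes to GGML_FP16_TO_FP32(d) * scale * kvalues_iq4nl[q_j], where q_j is the 4-bit index stored in qs.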
 // ================================ IQ2 quantization =============================================
 
 typedef struct {

     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
+static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
+        ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
+        float * scales, float * weight, uint8_t * L,
         const int8_t * values,
         const float * quant_weights) {
 
     const int ntry = 7;
 
     float sigma2 = 0;
+    for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
+    sigma2 *= 2.f/super_block_size;
 
+    memset(q4, 0, super_block_size/2);
+    dh[0] = GGML_FP32_TO_FP16(0.f);
 
+    float max_scale = 0, amax_scale = 0;
+    for (int ib = 0; ib < super_block_size/block_size; ++ib) {

         const float * xb = x + ib*block_size;
         if (quant_weights) {
             const float * qw = quant_weights + ib*block_size;

             }
         }
         if (!amax) {
+            scales[ib] = 0;
             continue;
         }
         float d = -max/values[0];

             sumqx += w*q*xb[j];
             sumq2 += w*q*q;
         }
         d = sumqx/sumq2;
         float best = d*sumqx;
         for (int itry = -ntry; itry <= ntry; ++itry) {

             }
             if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
                 d = sumqx/sumq2; best = d * sumqx;
             }
         }
+        scales[ib] = d;
+        float abs_d = fabsf(d);
+        if (abs_d > amax_scale) {
+            amax_scale = abs_d; max_scale = d;
         }
     }
+
+    if (super_block_size/block_size > 1) {
+        int nb = super_block_size/block_size;
+        memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));
+        float d = -max_scale/32;
+        dh[0] = GGML_FP32_TO_FP16(d);
+        float id = d ? 1/d : 0.f;
+        for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+            int l = nearest_int(id*scales[ib]);
+            l = MAX(-32, MIN(31, l));
+            float dl = d * l;
+            float idl = dl ? 1/dl : 0.f;
+            uint8_t * Lb = L + ib*block_size;
+            const float * xb = x + ib*block_size;
+            for (int j = 0; j < block_size; ++j) {
+                Lb[j] = best_index_int8(16, values, idl*xb[j]);
+            }
+            l += 32;
+            uint8_t l_l = l & 0xf;
+            uint8_t l_h = l >> 4;
+            if (ib%2 == 0) scales_l[ib/2] = l_l;
+            else scales_l[ib/2] |= (l_l << 4);
+            scales_h[ib/8] |= (l_h << 2*(ib%8));
+        }
+    } else {
+        dh[0] = GGML_FP32_TO_FP16(scales[0]);
+        float id = scales[0] ? 1/scales[0] : 0;
+        for (int j = 0; j < super_block_size; ++j) {
+            L[j] = best_index_int8(16, values, id*x[j]);
+        }
+    }
+
+    for (int i = 0; i < super_block_size/32; ++i) {
         for (int j = 0; j < 16; ++j) {
             q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
         }

     int nblock = n_per_row/QK4_NL;
     char * qrow = (char *)dst;
     uint8_t L[QK4_NL];
+    float weight[QK4_NL];
+    uint16_t unused_h;
+    uint8_t * unused_l = NULL;
+    float scale;
     for (int row = 0; row < nrow; ++row) {
         block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
         for (int ibl = 0; ibl < nblock; ++ibl) {
             const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
+            quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+                                     &scale, weight, L, kvalues_iq4nl, qw);
         }
         src += n_per_row;
         qrow += nblock*sizeof(block_iq4_nl);

     quantize_iq4_nl(x, y, 1, k, NULL, NULL);
 }
 
+size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    uint8_t L[QK_K];
+    float weight[32];
+    float scales[QK_K/32];
+    for (int row = 0; row < nrow; ++row) {
+        block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
+        for (int ibl = 0; ibl < nblock; ++ibl) {
+            const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
+            quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
+                                     scales, weight, L, kvalues_iq4nl, qw);
+        }
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq4_xs);
+    }
+    return nrow * nblock * sizeof(block_iq4_xs);
+}
+
+void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_iq4_xs * restrict y = vy;
+    quantize_row_iq4_xs_reference(x, y, k);
+}
+
+void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) {
+    assert(k % QK_K == 0);
+    quantize_iq4_xs(x, y, 1, k, NULL, NULL);
+}
+
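This layout is where the 4.25 bpw figure in the commit title comes from: with QK_K == 256, a block_iq4_xs holds 2 bytes (d) + 2 bytes (scales_h) + 4 bytes (scales_l) + 128 bytes (qs) = 136 bytes per 256 weights, i.e. 136*8/256 = 4.25 bits per weight. A compile-time sketch of that arithmetic (not part of the diff):

#if QK_K == 256
// 136 bytes * 8 bits per super-block == 4.25 bits * 256 weights
static_assert(sizeof(block_iq4_xs) == 136, "IQ4_XS super-block is 136 bytes");
static_assert(sizeof(block_iq4_xs)*8*100 == 425*QK_K, "IQ4_XS is 4.25 bpw");
#endif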
 // =============================== 2.5625 bpw
 
 static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
ggml-quants.h CHANGED
@@ -230,6 +230,14 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
+typedef struct {
+    ggml_fp16_t d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -250,6 +258,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
 void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
 void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
+void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int k);
 void quantize_row_iq3_s_reference  (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
 void quantize_row_iq2_s_reference  (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k);
 
@@ -268,6 +277,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq3_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq2_s  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 
@@ -291,6 +301,7 @@ void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq1_s  (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq3_s  (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
@@ -311,6 +322,7 @@ void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
@@ -322,6 +334,7 @@ size_t quantize_iq2_s (const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq1_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq4_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq3_s  (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
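For readers following the new declarations, a rough scalar sketch of what dequantize_row_iq4_xs has to do per super-block, based on the packing shown above (illustrative only; it mirrors the scalar dot-product fallback in ggml-quants.c rather than the committed implementation line for line):

// Sketch: dequantize one block_iq4_xs super-block (QK_K weights) into y.
static void dequantize_one_iq4_xs(const block_iq4_xs * x, float * y) {
    const float d = GGML_FP16_TO_FP32(x->d);
    const uint8_t * qs = x->qs;
    uint16_t h = x->scales_h;
    for (int ib = 0; ib < QK_K/32; ib += 2) {
        // two 32-weight blocks per iteration; 6-bit scales with a +32 bias
        const float dl1 = d * (((x->scales_l[ib/2] & 0xf) | ((h << 4) & 0x30)) - 32);
        const float dl2 = d * (((x->scales_l[ib/2] >>  4) | ((h << 2) & 0x30)) - 32);
        h >>= 4;
        for (int j = 0; j < 16; ++j) {
            y[j+ 0] = dl1 * kvalues_iq4nl[qs[j] & 0xf];
            y[j+16] = dl1 * kvalues_iq4nl[qs[j] >>  4];
        }
        qs += 16; y += 32;
        for (int j = 0; j < 16; ++j) {
            y[j+ 0] = dl2 * kvalues_iq4nl[qs[j] & 0xf];
            y[j+16] = dl2 * kvalues_iq4nl[qs[j] >>  4];
        }
        qs += 16; y += 32;
    }
}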
ggml.c CHANGED
@@ -730,6 +730,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
     },
+    [GGML_TYPE_IQ4_XS] = {
+        .type_name = "iq4_xs",
+        .blck_size = QK_K,
+        .type_size = sizeof(block_iq4_xs),
+        .is_quantized = true,
+        .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
+        .from_float = quantize_row_iq4_xs,
+        .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
+        .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
+        .vec_dot_type = GGML_TYPE_Q8_K,
+        .nrows = 1,
+    },
     [GGML_TYPE_Q8_K] = {
         .type_name = "q8_K",
         .blck_size = QK_K,
@@ -2338,6 +2350,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
         case GGML_FTYPE_MOSTLY_IQ1_S:   wtype = GGML_TYPE_IQ1_S;   break;
         case GGML_FTYPE_MOSTLY_IQ4_NL:  wtype = GGML_TYPE_IQ4_NL;  break;
+        case GGML_FTYPE_MOSTLY_IQ4_XS:  wtype = GGML_TYPE_IQ4_XS;  break;
         case GGML_FTYPE_MOSTLY_IQ3_S:   wtype = GGML_TYPE_IQ3_S;   break;
         case GGML_FTYPE_MOSTLY_IQ2_S:   wtype = GGML_TYPE_IQ2_S;   break;
         case GGML_FTYPE_UNKNOWN:        wtype = GGML_TYPE_COUNT;   break;
@@ -7776,6 +7789,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -8057,6 +8071,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -8183,6 +8198,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         default:
@@ -11083,6 +11099,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -11273,6 +11290,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         default:
@@ -11477,6 +11495,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
@@ -12179,6 +12198,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_Q8_K:
@@ -12264,6 +12284,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_Q8_K:
@@ -19835,6 +19856,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ4_XS:
+            {
+                GGML_ASSERT(start % QK4_NL == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
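Putting the new entry points together, a small usage sketch (assumes QK_K == 256; the buffers and call pattern are illustrative, not code from this commit): quantize a weight row to IQ4_XS, quantize the activation row to Q8_K (the vec_dot_type registered above), then take the dot product.

#include <stdlib.h>

static void iq4_xs_dot_example(const float * w, const float * act, int n, float * out) {
    // n must be a multiple of QK_K == 256 (the divisibility check mentioned in the commit message)
    block_iq4_xs * wq = malloc((n/QK_K) * sizeof(block_iq4_xs));
    block_q8_K   * aq = malloc((n/QK_K) * sizeof(block_q8_K));
    quantize_iq4_xs(w, wq, /*nrow=*/1, /*n_per_row=*/n, /*hist=*/NULL, /*quant_weights=*/NULL);
    quantize_row_q8_K(act, aq, n);
    ggml_vec_dot_iq4_xs_q8_K(n, out, 0, wq, 0, aq, 0, /*nrc=*/1);
    free(wq); free(aq);
}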
ggml.h CHANGED
@@ -352,6 +352,7 @@ extern "C" {
     GGML_TYPE_IQ4_NL = 20,
     GGML_TYPE_IQ3_S  = 21,
     GGML_TYPE_IQ2_S  = 22,
+    GGML_TYPE_IQ4_XS = 23,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -393,6 +394,7 @@ extern "C" {
     GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
     GGML_FTYPE_MOSTLY_IQ3_S  = 20, // except 1d tensors
     GGML_FTYPE_MOSTLY_IQ2_S  = 21, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
 };
 
 // available tensor operations: