Kawrakow (ikawrakow) committed
Commit 9a07f42 · unverified · Parent: e720b3b

ggml : make i-quants work with super-blocks of 64 (CPU,Metal) (llama/5760)


* WIP: make i-quants work for QK_K = 64

* iq2_xs: attempt to fix AVX dot product for QK_K = 64

Tests pass, but I get gibberish.

* QK_K = 64 tests pass on ARM_NEON and Metal

Sadly, that does not mean it actually works.

* Make CUDA compile with QK_K = 64

Tests don't pass, and we get misaligned memory accesses.

* Q2_K: fixed bug in imatrix quantization for QK_K = 64

* iq1_s: turn off SIMD implementation for QK_K = 64 (it does not work)

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

Files changed (5):
  1. ggml-cuda.cu +20 -7
  2. ggml-metal.metal +30 -28
  3. ggml-quants.c +125 -23
  4. ggml-quants.h +5 -0
  5. ggml.c +14 -1
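
Every file below follows the same pattern: compile-time dispatch on QK_K. When ggml is built with 64-weight super-blocks, IQ4_XS is aliased to IQ4_NL and reuses its kernels wholesale, while the remaining i-quants get block layouts parameterized by QK_K instead of hard-coded for 256. For orientation, here is a standalone sketch (not part of the commit) of the IQ4_NL block that IQ4_XS collapses onto, with its size checked the same way the diffs below check theirs:

    #include <assert.h>
    #include <stdint.h>

    typedef uint16_t ggml_fp16_t;  // stand-in for ggml's fp16 storage type

    #define QK4_NL 32              // IQ4_NL uses plain blocks of 32 weights

    typedef struct {
        ggml_fp16_t d;             // per-block fp16 scale
        uint8_t     qs[QK4_NL/2];  // 32 x 4-bit indices into a non-linear codebook
    } block_iq4_nl;                // 18 bytes per 32 weights -> 4.5 bits per weight

    static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2,
                  "wrong iq4_nl block size/padding");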
ggml-cuda.cu CHANGED
@@ -544,14 +544,19 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
 
 #define QR3_XS 8
 #define QI3_XS (QK_K / (4*QR3_XS))
+#if QK_K == 64
+#define IQ3S_N_SCALE 2
+#else
+#define IQ3S_N_SCALE QK_K/64
+#endif
 typedef struct {
     half d;
     uint8_t qs[QK_K/4];
     uint8_t qh[QK_K/32];
     uint8_t signs[QK_K/8];
-    uint8_t scales[QK_K/64];
+    uint8_t scales[IQ3S_N_SCALE];
 } block_iq3_s;
-static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
 
 #define QR1_S 8
 #define QI1_S (QK_K / (4*QR1_S))
@@ -571,6 +576,11 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
+#if QK_K == 64
+#define block_iq4_xs block_iq4_nl
+#define QR4_XS QR4_NL
+#define QI4_XS QI4_NL
+#else
 // QR4_XS = 8 is very slightly faster than QR4_XS = 4
 #define QR4_XS 8
 #define QI4_XS (QK_K / (4*QR4_XS))
@@ -581,7 +591,7 @@ typedef struct {
     uint8_t qs[QK_K/2];
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
-
+#endif
 
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -2439,9 +2449,9 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
 
 }
 
+#if QK_K != 64
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
-
     const int i = blockIdx.x;
     const block_iq4_xs * x = (const block_iq4_xs *)vx;
 
@@ -2455,8 +2465,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
         y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
         y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
     }
-
 }
+#endif
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
@@ -5382,8 +5392,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     return 0.f;
 #endif
 #else
-    assert(false);
-    return 0.f;
+    return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
 #endif
 }
 
@@ -7444,7 +7453,11 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
 template<typename dst_t>
 static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
+#if QK_K == 64
+    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+#else
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 template <typename src_t, typename dst_t>
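
A note on the first hunk above: the old iq3_s static_assert expressed the block size as 2 + 27*(QK_K/64), which is the correct field sum only when QK_K = 256. Since qs + qh + signs occupy QK_K/4 + QK_K/32 + QK_K/8 = 13*(QK_K/32) bytes, the rewritten assert holds for both super-block sizes. A standalone check of that arithmetic (my sketch, not commit code; half is assumed to be 2 bytes):

    #include <assert.h>

    // field sum of block_iq3_s: d (2) + qs (QK_K/4) + qh (QK_K/32) + signs (QK_K/8) + scales (N_SCALE)
    #define IQ3S_BYTES(QK_K, N_SCALE) (2 + (QK_K)/4 + (QK_K)/32 + (QK_K)/8 + (N_SCALE))

    static_assert(IQ3S_BYTES(256, 256/64) == 2 + 13*(256/32) + 256/64, "QK_K = 256: 110 bytes");
    static_assert(IQ3S_BYTES( 64, 2)      == 2 + 13*( 64/32) + 2,      "QK_K = 64:  30 bytes");
    // 110 bytes / 256 weights = 3.4375 bpw;  30 bytes / 64 weights = 3.75 bpw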
ggml-metal.metal CHANGED
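
Most of the churn below is the removal of #if QK_K == 256 guards from the iq2/iq3/iq1 mul-mv kernels: their former #else branches merely (void)-ed the arguments, so for any other block size the kernels silently accumulated nothing. The bodies iterate over 32-weight groups, so once they were made correct for QK_K = 64 the guards could simply go. A small sketch of why the iteration itself is size-agnostic (my illustration, reusing the kernels' variable names):

    // The kernels split each row into 32-weight groups; the group count is
    // independent of the super-block size, so one loop serves both builds.
    static int row_group_count(int ne00 /* row length */, int qk_k /* 64 or 256 */) {
        const int nb   = ne00 / qk_k;       // super-blocks per row
        const int nb32 = nb * (qk_k / 32);  // 32-weight groups = ne00/32 either way
        return nb32;
    }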
@@ -2560,12 +2560,16 @@ typedef struct {
     uint8_t qs[QK4_NL/2];
 } block_iq4_nl;
 
+#if QK_K == 64
+#define block_iq4_xs block_iq4_nl
+#else
 typedef struct {
     half d;
     uint16_t scales_h;
     uint8_t scales_l[QK_K/64];
     uint8_t qs[QK_K/2];
 } block_iq4_xs;
+#endif
 
 //====================================== dot products =========================
 
@@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-#if QK_K == 256
     const int ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
@@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
         y4 += 32 * 32;
     }
-#else
-    (void) x;
-    (void) y;
-    (void) yl;
-    (void) nb32;
-#endif
 
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
@@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-#if QK_K == 256
     const int ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
@@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
         y4 += 32 * 32;
     }
-#else
-    (void) x;
-    (void) y;
-    (void) yl;
-    (void) nb32;
-#endif
 
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
@@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-#if QK_K == 256
     const int ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
@@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
         y4 += 32 * 32;
     }
-#else
-    (void) x;
-    (void) y;
-    (void) yl;
-    (void) nb32;
-#endif
 
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
@@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     const int nb32 = nb * (QK_K / 32);
 
-#if QK_K == 256
     const int ix = tiisg/2;
     const int il = tiisg%2;
 
@@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
         y4 += 16 * 32;
     }
-#else
-    (void) x;
-    (void) y;
-    (void) yl;
-    (void) nb32;
-#endif
 
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
@@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     }
 }
 
+#if QK_K != 64
 void kernel_mul_mv_iq4_xs_f32_impl(
         device const void * src0,
         device const float * src1,
@@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
         }
     }
 }
+#endif
 
 [[host_name("kernel_mul_mv_iq1_s_f32")]]
 kernel void kernel_mul_mv_iq1_s_f32(
@@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
+#if QK_K == 64
+    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+#else
     kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+#endif
 }
 
 //============================= templates and their specializations =============================
@@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
 
 template <typename type4x4>
 void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+#if QK_K == 64
+    dequantize_iq4_nl(xb, il, reg);
+#else
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
     const int ib32 = il/2;
     il = il%2;
@@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
         reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
         reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
     }
+#endif
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
 template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
+#if QK_K == 64
+template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
+#else
 template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+#endif
 
 //
 // matrix-matrix multiplication
@@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
+#if QK_K == 64
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
+#else
 template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+#endif
 
 //
 // indirect matrix-matrix multiplication
@@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
+#if QK_K == 64
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
+#else
 template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+#endif
 
 //
 // matrix-vector multiplication
@@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
 
+#if QK_K == 64
+    kernel_mul_mv_iq4_nl_f32_impl(
+#else
     kernel_mul_mv_iq4_xs_f32_impl(
+#endif
         src0[id],
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
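
For QK_K = 64 every iq4_xs entry point above forwards to the iq4_nl implementation, which works because the two block types are then literally the same struct. For reference, a scalar C sketch of the per-block dequantization those kernels perform (kvalues_iq4nl is the codebook defined in ggml-quants.c; the fp16-to-float conversion of d is elided and the scale passed in directly):

    #include <stdint.h>

    #define QK4_NL 32

    typedef struct {
        uint16_t d;             // fp16 scale (stand-in type; conversion elided)
        uint8_t  qs[QK4_NL/2];  // two 4-bit codebook indices per byte
    } block_iq4_nl;

    // the non-linear 4-bit codebook from ggml-quants.c
    static const int8_t kvalues_iq4nl[16] = {
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
    };

    // low nibbles give the first 16 values of the block, high nibbles the rest
    static void dequantize_block_iq4_nl_scalar(const block_iq4_nl * x, float d, float * y) {
        for (int j = 0; j < QK4_NL/2; ++j) {
            y[j]            = d * kvalues_iq4nl[x->qs[j] & 0xf];
            y[j + QK4_NL/2] = d * kvalues_iq4nl[x->qs[j] >>  4];
        }
    }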
ggml-quants.c CHANGED
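
The first two hunks are the Q2_K imatrix fix from the commit message: weight was declared and filled with QK_K/16 entries, but each 16-weight sub-block needs exactly 16, and the two counts coincide only at QK_K = 256. With QK_K = 64 only 4 of the 16 importance weights were set (and make_qkx3_quants was handed the wrong count as well). A minimal sketch of the corrected per-sub-block setup (my illustration, not commit code):

    #include <math.h>

    // Each 16-weight sub-block gets exactly 16 importance weights, whatever QK_K is.
    static void q2k_subblock_weights(const float * qw,   // imatrix weights for this sub-block
                                     const float * x16,  // the 16 model weights
                                     float sigma2,       // mean squared weight of the super-block
                                     float weight[16]) {
        for (int l = 0; l < 16; ++l) {                   // was: l < QK_K/16
            weight[l] = qw[l] * sqrtf(sigma2 + x16[l]*x16[l]);
        }
    }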
@@ -1877,7 +1877,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
     float mins[QK_K/16];
     float scales[QK_K/16];
     float sw[QK_K/16];
-    float weight[QK_K/16];
+    float weight[16];
     uint8_t Ls[QK_K/16], Lm[QK_K/16];
 
     for (int i = 0; i < nb; i++) {
@@ -1887,13 +1887,42 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
         float sigma2 = sumx2/QK_K;
         for (int j = 0; j < QK_K/16; ++j) {
             const float * restrict qw = quant_weights + QK_K * i + 16*j;
-            for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+            for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
             for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
-            scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+            scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
         }
 
-        float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
-        float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
+        float dm, mm;
+#if QK_K == 64
+        float max_scale = 0, max_min = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            max_scale = MAX(max_scale, scales[j]);
+            max_min = MAX(max_min, mins[j]);
+        }
+        dm = max_scale/15;
+        mm = max_min/15;
+        if (max_scale) {
+            float id = 1/dm;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(id*scales[j]);
+                Ls[j] = MAX(0, MIN(15, l));
+            }
+        } else {
+            memset(Ls, 0, QK_K/16);
+        }
+        if (max_min) {
+            float id = 1/mm;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(id*mins[j]);
+                Lm[j] = MAX(0, MIN(15, l));
+            }
+        } else {
+            memset(Lm, 0, QK_K/16);
+        }
+#else
+        dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+        mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
+#endif
         y[i].d = GGML_FP32_TO_FP16(dm);
         y[i].dmin = GGML_FP32_TO_FP16(mm);
         dm = GGML_FP16_TO_FP32(y[i].d);
@@ -4227,6 +4256,9 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
 
 void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
     assert(k % QK_K == 0);
+#if QK_K == 64
+    dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k);
+#else
     const int nb = k / QK_K;
 
     for (int i = 0; i < nb; i++) {
@@ -4246,6 +4278,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
             qs += 16;
         }
     }
+#endif
 }
 
 //===================================== Q8_K ==============================================
@@ -6306,7 +6339,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     float sumf = 0;
 
-    int isum[4];
+    int isum[QK_K/16];
 
     for (int i = 0; i < nb; ++i) {
 
@@ -6322,14 +6355,14 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-        isum[0] = isum[1] = isum[2] = isum[3] = 0;
+        memset(isum, 0, (QK_K/16)*sizeof(int));
         for (int l = 0; l < 16; ++l) {
             isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
             isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
             isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
            isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
         }
-        for (int l = 0; l < 4; ++l) {
+        for (int l = 0; l < QK_K/16; ++l) {
             isum[l] *= (sc[l] & 0xF);
         }
         sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
@@ -9488,15 +9521,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 
 #elif defined(__AVX2__)
 
-    const __m128i m4 = _mm_set1_epi8(0xf);
-    const __m128i m1 = _mm_set1_epi8(1);
-    const __m256i m511 = _mm256_set1_epi16(511);
     const __m256i mone = _mm256_set1_epi8(1);
-
-    static const uint8_t k_bit_helper[32] = {
-        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
-        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
-    };
     static const char block_sign_shuffle_mask_1[32] = {
         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
         0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
@@ -9510,11 +9535,77 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
         0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
     };
 
-    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
     const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
     const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
     const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
 
+#if QK_K == 64
+    static const uint8_t k_bit_helper[16] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper);
+    const __m128i m511 = _mm_set1_epi16(511);
+    typedef union {
+        __m128i vec_index;
+        uint16_t index[8];
+    } index_t;
+
+    index_t idx;
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs);
+        idx.vec_index = _mm_and_si128(q2_data, m511);
+
+        const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9);
+        const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13);
+        const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper);
+
+        const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+        const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits);
+        const __m256i full_signs = _mm256_set_m128i(full_sign_bits, full_sign_bits);
+
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs);
+        const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32));
+
+        const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]],
+                                               iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]);
+        const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]],
+                                               iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]);
+
+        __m256i signs;
+        signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1);
+        signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+        const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+        signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2);
+        signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+        const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+        const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+        const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+
+        const __m256i sc1 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
+        const __m256i sc2 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
+
+        const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2));
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+#else
+
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+    const __m256i m511 = _mm256_set1_epi16(511);
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
     uint64_t aux64;
 
     // somewhat hacky, but gives a significant boost in performance
@@ -9603,6 +9694,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     }
 
     *s = 0.125f * hsum_float_8(accumf);
+#endif
 
 #else
 
@@ -10199,7 +10291,8 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 
     const int nb = n / QK_K;
 
-#if defined __ARM_NEON
+    // TODO: implement for QK_K = 64
+#if defined __ARM_NEON && QK_K == 256
 
     const uint8x16_t m8 = vdupq_n_u8(0x08);
     const uint8x16_t m7 = vdupq_n_u8(0x07);
@@ -10256,7 +10349,8 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 
     *s = sumf;
 
-#elif defined __AVX2__
+    // TODO: implement for QK_K = 64
+#elif defined __AVX2__ && QK_K == 256
 
     const __m128i m8 = _mm_set1_epi8(0x08);
     const __m128i m7 = _mm_set1_epi8(0x07);
@@ -10455,6 +10549,9 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     UNUSED(by);
     UNUSED(bs);
    assert(n % QK_K == 0);
+#if QK_K == 64
+    ggml_vec_dot_iq4_nl_q8_0(n, s, bs, vx, bx, vy, by, nrc);
+#else
 
     const block_iq4_xs * restrict x = vx;
     const block_q8_K * restrict y = vy;
@@ -10574,6 +10671,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     }
     *s = sumf;
 #endif
+#endif
 }
 
 // ================================ IQ2 quantization =============================================
@@ -10921,7 +11019,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
 
     const int kMaxQ = 3;
 
-    const int nbl = n/256;
+    const int nbl = n/QK_K;
 
     block_iq2_xxs * y = vy;
 
@@ -11094,7 +11192,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
 
     const int kMaxQ = 3;
 
-    const int nbl = n/256;
+    const int nbl = n/QK_K;
 
     block_iq2_xs * y = vy;
 
@@ -12037,7 +12135,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
     GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
     GGML_ASSERT(n%QK_K == 0);
 
-    const int nbl = n/256;
+    const int nbl = n/QK_K;
 
     block_iq1_s * y = vy;
 
@@ -12315,6 +12413,9 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest
 }
 
 size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+#if QK_K == 64
+    return quantize_iq4_nl(src, dst, nrow, n_per_row, hist, quant_weights);
+#else
     (void)hist;
     GGML_ASSERT(n_per_row%QK_K == 0);
     int nblock = n_per_row/QK_K;
@@ -12333,6 +12434,7 @@ size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, i
         qrow += nblock*sizeof(block_iq4_xs);
     }
    return nrow * nblock * sizeof(block_iq4_xs);
+#endif
 }
 
 void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
@@ -12363,7 +12465,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
 
     const int kMaxQ = 3;
 
-    const int nbl = n/256;
+    const int nbl = n/QK_K;
 
    block_iq2_s * y = vy;
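
The new QK_K = 64 AVX2 branch in ggml_vec_dot_iq2_xs_q8_K is mostly a 128-bit re-plumbing of the 256-bit path, but the sign handling deserves a note: each 16-bit iq2_xs code word packs a 9-bit index into iq2xs_grid (hence the 511 mask) plus 7 of the 8 sign bits of the group; the 8th bit is implicit, chosen so that the total number of set sign bits is even. k_bit_helper is a nibble-parity lookup table (0x80 where the nibble has odd popcount), and the two shifts plus XOR fold the 7 bits into one nibble before the pshufb. A scalar equivalent (my sketch, not commit code):

    #include <stdint.h>

    // Recover all 8 sign bits of an iq2_xs group from one 16-bit code word.
    static uint8_t iq2xs_full_signs(uint16_t q2) {
        uint8_t signs = q2 >> 9;  // the 7 explicit sign bits
        // the implied 8th bit is the parity of the 7 stored ones
        // (__builtin_popcount is a GCC/Clang builtin)
        signs |= (uint8_t)((__builtin_popcount(signs) & 1) << 7);
        return signs;
    }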
 
ggml-quants.h CHANGED
@@ -230,6 +230,10 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
+#if QK_K == 64
+#define block_iq4_xs block_iq4_nl
+//typedef struct block_iq4_nl block_iq4_xs;
+#else
 typedef struct {
     ggml_fp16_t d;
     uint16_t scales_h;
@@ -237,6 +241,7 @@ typedef struct {
     uint8_t qs[QK_K/2];
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+#endif
 
 #ifdef __cplusplus
 extern "C" {
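
A design note on the alias above: the commented-out typedef hints at why a macro is used. block_iq4_nl is a typedef of an anonymous struct, so there is no struct tag to refer to, and a sketch of the alternatives (my illustration, hypothetical names):

    // typedef struct block_iq4_nl block_iq4_xs;  // would declare a NEW, incomplete
    //                                            // struct tag, not the existing type
    typedef struct { int x; } like_iq4_nl;        // anonymous struct + typedef, as in ggml
    #define like_iq4_xs like_iq4_nl               // macro: both names name the same type

With the macro in effect, every declaration that mentions block_iq4_xs, such as the dequantize_row_iq4_xs prototype, refers to block_iq4_nl after preprocessing.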
ggml.c CHANGED
@@ -732,14 +732,22 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     },
     [GGML_TYPE_IQ4_XS] = {
         .type_name                = "iq4_xs",
+#if QK_K == 64
+        .blck_size                = QK4_NL,
+#else
         .blck_size                = QK_K,
+#endif
         .type_size                = sizeof(block_iq4_xs),
         .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float               = quantize_row_iq4_xs,
         .from_float_reference     = (ggml_from_float_t)quantize_row_iq4_xs_reference,
         .vec_dot                  = ggml_vec_dot_iq4_xs_q8_K,
+#if QK_K == 64
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+#else
         .vec_dot_type             = GGML_TYPE_Q8_K,
+#endif
         .nrows                    = 1,
     },
     [GGML_TYPE_Q8_K] = {
@@ -19848,6 +19856,9 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 GGML_ASSERT(result == row_size * nrows);
             } break;
         case GGML_TYPE_IQ4_NL:
+#if QK_K == 64
+        case GGML_TYPE_IQ4_XS:
+#endif
             {
                 GGML_ASSERT(start % QK4_NL == 0);
                 GGML_ASSERT(start % n_per_row == 0);
@@ -19856,15 +19867,17 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+#if QK_K != 64
         case GGML_TYPE_IQ4_XS:
             {
-                GGML_ASSERT(start % QK4_NL == 0);
+                GGML_ASSERT(start % QK_K == 0);
                 GGML_ASSERT(start % n_per_row == 0);
                 size_t start_row = start / n_per_row;
                 size_t row_size = ggml_row_size(type, n_per_row);
                 result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+#endif
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
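
The ggml.c hunks keep the type-traits table consistent with the aliasing: with QK_K = 64 an IQ4_XS row really is an IQ4_NL row, so its block size must become QK4_NL = 32 and its dot-product activation type Q8_0 (what the iq4_nl vec_dot consumes) instead of Q8_K. A condensed view of the invariant being maintained (hypothetical macro names, for illustration only):

    // the vec_dot kernel, the activation quantization it expects, and the
    // block size must move together; QK_K selects the whole triple at once
    #if QK_K == 64
        #define IQ4_XS_BLCK_SIZE    QK4_NL          // 32-weight blocks
        #define IQ4_XS_VEC_DOT_TYPE GGML_TYPE_Q8_0  // matches ggml_vec_dot_iq4_nl_q8_0
    #else
        #define IQ4_XS_BLCK_SIZE    QK_K            // 256-weight super-blocks
        #define IQ4_XS_VEC_DOT_TYPE GGML_TYPE_Q8_K  // matches ggml_vec_dot_iq4_xs_q8_K
    #endif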