Kawrakow (ikawrakow) committed
Commit f3a62cc · unverified · 1 parent: 2957823

Better 1.5 bit quantization (llama/5971)


* Trying blocks of 16 for IQ1_S - seems slightly better

* iq1s_blocks16: Adjust scale fudge factor to 1.125

* iq1s_blocks16: going to blocks of 32

with 2048 lattice points, so the same bpw (see the bit accounting
after this list). This is even better than blocks of 16.
Should I try blocks of 64? To keep the same bpw with
4096 lattice points, I would need to drop blocks
altogether and just have superblocks of 256 weights.

* iq1s_blocks16: Use 2*<x^2> as sigma2 in weight adjustment

* iq1s_blocks16: scalar and AVX2 dot products

* iq1s_blocks16: CUDA dot product

* iq1s_blocks16: Metal works, Neon does not

Metal works but TG is dog slow (35 t/s). PP is OKish (493 t/s).
Not seeing the bug in the Neon implementation for now.

* iq1s_blocks16: fixed Neon

* iq1s_blocks16: very slightly faster TG on Metal

Still pathetic at 37 t/s.

* iq1s_blocks16: speedup Metal by packing codebook into uint32_t's

* Formatting

* iq1s_blocks16: uint32_t codebook is also better in CUDA

TG-128 is now 204 t/s, up from 194 t/s.
PP-512 is 5890 t/s, significantly better than the other quants.

* iq1s_blocks16: slightly faster Neon dot product

* iq1s_blocks16: faster AVX2 dot product

* iq1s_blocks16: adjust to ggml-common.h
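
(Bit accounting for "same bpw": a QK_K = 256 superblock stores one fp16
scale (16 bits), qs with 32 bytes of low index bits (256 bits), and qh
with eight uint16 words (128 bits). Each qh word carries 4 x 3 high
index bits, so every group of 8 weights gets an 11-bit index, hence
2048 lattice points, plus a 4-bit block scale. Total: 400 bits for 256
weights = 1.5625 bpw, exactly what the old qs + scales layout used.)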

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
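
For readers skimming the diffs: the new layout and its CPU decode
distill to the following minimal sketch (hypothetical sk_* names,
QK_K = 256 assumed; the real struct and loop are in the ggml-quants.c
hunks below).

#include <stdint.h>

#define QK_K 256

typedef struct {
    uint16_t d;              /* fp16 superblock scale                       */
    uint8_t  qs[QK_K/8];     /* low 8 bits of the grid index per 8 weights  */
    uint16_t qh[QK_K/32];    /* bits 0..11: 4 x 3 high index bits,
                                bits 12..15: 4-bit block scale              */
} sk_block_iq1_s;

/* grid: the 2048 x 8 int8 codebook (iq1s_grid in ggml-quants.c);
 * d: the superblock scale already converted from fp16. */
static void sk_dequant_iq1_s(const sk_block_iq1_s * x, const int8_t (*grid)[8],
                             float d, float * y) {
    const uint8_t * qs = x->qs;
    for (int ib = 0; ib < QK_K/32; ++ib) {               /* blocks of 32    */
        const float dl = d * (2*(x->qh[ib] >> 12) + 1);  /* odd scale 1..31 */
        for (int l = 0; l < 4; ++l) {                    /* groups of 8     */
            const int idx = qs[l] | (((x->qh[ib] >> 3*l) & 7) << 8);
            for (int j = 0; j < 8; ++j) *y++ = dl * grid[idx][j];
        }
        qs += 4;
    }
}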

Files changed (4)
  1. ggml-cuda.cu +31 -31
  2. ggml-metal.metal +37 -29
  3. ggml-quants.c +309 -203
  4. ggml-quants.h +2 -2
ggml-cuda.cu CHANGED
@@ -565,8 +565,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N
 #define QI1_S (QK_K / (4*QR1_S))
 typedef struct {
     half d;
-    uint8_t  qs[QK_K/8];
-    uint8_t  scales[QK_K/16];
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
@@ -1722,11 +1722,22 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const int i8 = 4*ib+il;
-    uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
-    const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
-    const float d = (float)x[i].d * (2*(h & 7) + 1);
-    for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
+    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = *((const int *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8))));
+    grid32[1] = __vsub4((grid32[0] >> 4) & 0x0f0f0f0f, 0x01010101);
+    grid32[0] = __vsub4(grid32[0] & 0x0f0f0f0f, 0x01010101);
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * q[j];
+    }
+#else
+    const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)));
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * ((grid[j] & 0xf) - 1);
+        y[j+4] = d * ((grid[j] >> 4) - 1);
+    }
+#endif
 #else
     assert(false);
 #endif
@@ -4538,44 +4549,33 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
 #endif
 }
 
-
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 #if QK_K == 256
     const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
 
     const int ib32 = iqs;
-    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
-    const uint8_t h1 = bq1->scales[2*ib32+0];
-    const uint8_t h2 = bq1->scales[2*ib32+1];
+    int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const int * q8 = (const int *)bq8_1[ib32].qs;
-    const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
-    const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
-    const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
-    const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
-    for (int j = 0; j < 2; ++j) {
-        sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
-        sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
-        sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
-        sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
+    for (int l = 0; l < 4; ++l) {
+        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+        int grid0 = __vsub4(grid[0] & 0x0f0f0f0f, 0x01010101);
+        int grid1 = __vsub4((grid[0] >> 4) & 0x0f0f0f0f, 0x01010101);
+        sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi));
     }
 #else
     const int8_t * q8 = bq8_1[ib32].qs;
-    const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
-    const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
-    const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
-    const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
-    for (int j = 0; j < 8; ++j) {
-        sumi1 += q8[j+ 0] * grid1[j];
-        sumi2 += q8[j+ 8] * grid2[j];
-        sumi3 += q8[j+16] * grid3[j];
-        sumi4 += q8[j+24] * grid4[j];
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+        for (int j = 0; j < 4; ++j) {
+            sumi += q8[j] * ((grid[j] & 0xf) - 1) + q8[j+4] * ((grid[j] >> 4) - 1);
+        }
+        q8 += 8;
     }
 #endif
     const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
-    return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
-                sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
+    return d * sumi * (2*(bq1->qh[ib32] >> 12) + 1);
 #else
     assert(false);
     return 0.f;
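
The Metal and CUDA speedups in the log come from iq1s_grid_gpu: each
8-value lattice point is packed into a single uint32_t of nibbles, each
nibble storing q+1 with q in {-1, 0, +1}, so a kernel fetches one 32-bit
word instead of eight bytes. A minimal sketch of that packing
(hypothetical sk_* helpers, little-endian byte order assumed; not the
actual ggml code):

#include <stdint.h>

/* Byte j carries weight j in its low nibble and weight j+4 in its high
 * nibble, matching the ((grid[j] & 0xf) - 1) / ((grid[j] >> 4) - 1)
 * decode in the kernels above. */
static uint32_t sk_pack_grid_row(const int8_t q[8]) {
    uint32_t v = 0;
    for (int j = 0; j < 4; ++j) {
        v |= (uint32_t)(q[j]     + 1) << (8*j);      /* weights 0..3 */
        v |= (uint32_t)(q[j + 4] + 1) << (8*j + 4);  /* weights 4..7 */
    }
    return v;
}

static void sk_unpack_grid_row(uint32_t v, int8_t q[8]) {
    const uint8_t * b = (const uint8_t *)&v;
    for (int j = 0; j < 4; ++j) {
        q[j]     = (int8_t)((b[j] & 0xf) - 1);
        q[j + 4] = (int8_t)((b[j] >> 4) - 1);
    }
}

At 4 bytes per lattice point the whole 2048-entry table is 8 KB, half
of what an int8 layout would need for the same number of entries, which
is presumably why it helps on both GPU backends.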
ggml-metal.metal CHANGED
@@ -2595,8 +2595,8 @@ typedef struct {
 
 typedef struct {
     half d;
-    uint8_t  qs[QK_K/8];
-    uint8_t  scales[QK_K/16];
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
 } block_iq1_s;
 
 // Non-linear quants
@@ -4338,48 +4338,53 @@ void kernel_mul_mv_iq1_s_f32_impl(
     device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
     device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg;
 
-    device const float * y4 = y + 32 * ix + 16 * il;
+    device const float * y4 = y + 32 * ix;
 
-    for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
 
-        for (int i = 0; i < 16; ++i) {
+        float sumy = 0;
+        for (int i = 0; i < 32; ++i) {
             yl[i] = y4[i];
+            sumy += yl[i];
         }
 
        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);
 
        device const block_iq1_s * xr = x + ibl;
-        device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
-        device const uint8_t * sc = xr->scales + 2 * ib + il;
-        device const half    * dh = &xr->d;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint16_t * qh = xr->qh + ib;
+        device const half     * dh = &xr->d;
 
        for (int row = 0; row < N_DST; row++) {
 
-            constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
-            constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
 
-            float2 sum = {0};
-            for (int j = 0; j < 8; ++j) {
-                sum[0] += yl[j+ 0] * grid1[j];
-                sum[1] += yl[j+ 8] * grid2[j];
+            float sum = 0;
+            for (int j = 0; j < 4; ++j) {
+                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
            }
-            sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
+            sumf[row] += (float)dh[0] * (sum - sumy) * (2*(qh[0] >> 12) + 1);
 
            dh += nb*sizeof(block_iq1_s)/2;
            qs += nb*sizeof(block_iq1_s);
-            sc += nb*sizeof(block_iq1_s);
+            qh += nb*sizeof(block_iq1_s)/2;
        }
 
-        y4 += 16 * 32;
+        y4 += 32 * 32;
    }
 
    for (int row = 0; row < N_DST; ++row) {
@@ -5066,16 +5071,19 @@ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 &
 template <typename type4x4>
 void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
     const float d = xb->d;
-    device const uint8_t * qs = xb->qs + 2*il;
-    device const uint8_t * sc = xb->scales + il;
-    const float dl1 = d * (2*(sc[0] & 7) + 1);
-    const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
-    constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
-    constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4+0][i%4] = dl1 * grid1[i];
-        reg[i/4+2][i%4] = dl2 * grid2[i];
+    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint16_t * qh = xb->qh;
+    const float dl = d * (2*(qh[ib32] >> 12) + 1);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) - dl;
+        reg[1][i] = dl * (grid1[i] >> 4)  - dl;
+        reg[2][i] = dl * (grid2[i] & 0xf) - dl;
+        reg[3][i] = dl * (grid2[i] >> 4)  - dl;
     }
 }
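
One detail of the Metal kernel worth spelling out: since every nibble
stores q+1 rather than q, the accumulated sum satisfies
sum_j (q_j + 1)*y_j = dot(q, y) + sum_j y_j, so the kernel multiplies
against the raw nibbles and subtracts the precomputed sumy once per
32-weight block instead of applying the -1 offset per element.
Illustration (hypothetical helper, not the actual kernel):

#include <stdint.h>

static float sk_dot_with_offset_trick(const uint8_t nib[8], const float y[8]) {
    float sum = 0, sumy = 0;
    for (int j = 0; j < 8; ++j) {
        sum  += y[j] * nib[j];   /* nib[j] == q[j] + 1; no per-element -1 */
        sumy += y[j];
    }
    return sum - sumy;           /* == sum_j q[j]*y[j] */
}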
 
ggml-quants.c CHANGED
@@ -3449,39 +3449,22 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
-    float db[4];
-    uint16_t idx[4];
-    //const int8_t * grid[4];
-
     for (int i = 0; i < nb; i++) {
 
         const float d = GGML_FP16_TO_FP32(x[i].d);
-        const uint8_t * sc = x[i].scales;
-        const uint8_t * qs = x[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
-        for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
-            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
-            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
-            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
-            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
-            //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
-            //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
-            //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
-            //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
-            db[0] = d * (2*(sc[0] & 7) + 1);
-            db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
-            db[2] = d * (2*(sc[1] & 7) + 1);
-            db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float dl = d * (2*(qh[ib] >> 12) + 1);
             for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
                 for (int j = 0; j < 8; ++j) {
-                    //y[j] = db[l] * grid[l][j];
-                    y[j] = db[l] * grid[j];
+                    y[j] = dl * grid[j];
                 }
                 y += 8;
             }
             qs += 4;
-            sc += 2;
         }
     }
 }
@@ -9587,113 +9570,72 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
 
     const int nb = n / QK_K;
 
-    // TODO: implement for QK_K = 64
-#if defined __ARM_NEON && QK_K == 256
-
-    const uint8x16_t m8 = vdupq_n_u8(0x08);
-    const uint8x16_t m7 = vdupq_n_u8(0x07);
-    const uint8x16_t m1 = vdupq_n_u8(0x01);
-    const int32x4_t vzero = vdupq_n_s32(0);
-
-    uint16_t gindex[8];
-    uint16x8x2_t vindex;
-    int8x16x4_t  q1b;
+#if defined __ARM_NEON
+
+    ggml_int8x16x4_t q1b;
     ggml_int8x16x4_t q8b;
-    uint16x8x4_t scales;
-    int32x4x2_t  sumi;
-    int32x4x2_t  dotq;
 
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
-        sumi.val[0] = sumi.val[1] = vzero;
+        int sumi1 = 0, sumi2 = 0;
 
-        for (int i128 = 0; i128 < QK_K/128; ++i128) {
-            const uint8x16_t ql = vld1q_u8(qs); qs += 16;
-            const uint8x8_t  tm1 = vld1_u8 (sc); sc += 8;
-            const uint8x8_t  tm2 = vshr_n_u8(tm1, 4);
-            const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
-            const uint8x16_t hbit = vandq_u8(qh, m8);
-            vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
-            vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
-            const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
-            scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
-            scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
 
-            for (int l = 0; l < 2; ++l) {
-                vst1q_u16(gindex+0, vindex.val[l]);
-                q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
-                q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
-                q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
-                q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
-                q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
+            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
+            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
+            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
+            qs += 8;
 
-                dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
-                dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
 
-                sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
-                sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
-            }
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
+
+            sumi1 += vaddvq_s32(p1) * (2*(qh[ib+0] >> 12) + 1);
+            sumi2 += vaddvq_s32(p2) * (2*(qh[ib+1] >> 12) + 1);
         }
 
-        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
+        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2);
     }
 
     *s = sumf;
 
-    // TODO: implement for QK_K = 64
-#elif defined __AVX2__ && QK_K == 256
-
-    const __m128i m8 = _mm_set1_epi8(0x08);
-    const __m128i m7 = _mm_set1_epi8(0x07);
-    const __m128i m1 = _mm_set1_epi8(0x01);
-    const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
-    const __m128i shuffle_s[4] = {
-        _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
-        _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
-        _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
-        _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
-    };
-
-    uint64_t aux64;
-
-    typedef union m256i_uint16 {
-        __m256i reg;
-        uint16_t s[16];
-    } m256i_uint16_t;
-
-    m256i_uint16_t v_gindex;
+#elif defined __AVX2__
 
     __m256 accum = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
         __m256i sumi = _mm256_setzero_si256();
-        for (int i128 = 0; i128 < QK_K/128; ++i128) {
-            const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
-            memcpy(&aux64, sc, 8); sc += 8;
-            const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
-            const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
-            v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
-            const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
+                                                    iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
+                                                    iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+            qs += 8;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
 
-            for (int i32 = 0; i32 < 4; ++i32) {
-                const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]],
-                                                      iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]);
-                const __m256i dot = mul_add_epi8(q1b, q8b);
-                const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
-                const __m256i p   = _mm256_madd_epi16(s16, dot);
-                sumi = _mm256_add_epi32(sumi, p);
-            }
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+            const __m256i p1   = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*(qh[ib+0] >> 12) + 1));
+            const __m256i p2   = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*(qh[ib+1] >> 12) + 1));
 
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
         }
 
         accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
@@ -9704,35 +9646,26 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
 
 #else
 
-    int db[4];
-    uint16_t idx[4];
-
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; i++) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
         int sumi = 0;
-        for (int i32 = 0; i32 < QK_K/32; ++i32) {
-            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
-            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
-            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
-            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
-            db[0] = (2*(sc[0] & 7) + 1);
-            db[1] = (2*((sc[0] >> 4) & 7) + 1);
-            db[2] = (2*(sc[1] & 7) + 1);
-            db[3] = (2*((sc[1] >> 4) & 7) + 1);
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = 2*(qh[ib] >> 12) + 1;
+            int lsum = 0;
             for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
-                int suml = 0;
-                for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
-                sumi += db[l] * suml;
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+                for (int j = 0; j < 8; ++j) {
+                    lsum += q8[j] * grid[j];
+                }
                 q8 += 8;
            }
+            sumi += ls * lsum;
            qs += 4;
-            sc += 2;
        }
 
        sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
@@ -9996,7 +9929,7 @@ static inline int iq2_grid_size(enum ggml_type type) {
     GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
     return type == GGML_TYPE_IQ2_XXS ? 256 :
            type == GGML_TYPE_IQ2_XS  ? 512 :
-           type == GGML_TYPE_IQ1_S   ? 512 : 1024;
+           type == GGML_TYPE_IQ1_S   ? NGRID_IQ1S : 1024;
 }
 
 static int iq2_compare_func(const void * left, const void * right) {
@@ -10063,39 +9996,135 @@ void iq2xs_init_impl(enum ggml_type type) {
         40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
         42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
     };
-    static const uint16_t kgrid_1bit_512[512] = {
-           10,    33,    41,    85,   132,   134,   160,   162,   277,   337,   340,   345,   357,   405,   516,   545,
-          553,   598,   641,   650,   681,  1042,  1044,  1097,  1169,  1176,  1320,  1345,  1365,  1378,  1434,  1444,
-         1545,  1617,  1642,  1685,  2053,  2080,  2089,  2133,  2176,  2182,  2208,  2214,  2306,  2384,  2393,  2440,
-         2453,  2581,  2664,  2690,  2721,  4117,  4161,  4182,  4184,  4261,  4357,  4369,  4372,  4377,  4390,  4422,
-         4432,  4437,  4449,  4457,  4485,  4497,  4505,  4629,  4677,  4696,  4774,  5205,  5217,  5225,  5386,  5397,
-         5409,  5445,  5457,  5460,  5461,  5462,  5465,  5472,  5477,  5525,  5545,  5650,  5668,  5717,  5729,  5769,
-         5777,  6212,  6234,  6244,  6293,  6424,  6482,  6485,  6502,  6505,  6529,  6538,  6565,  6656,  6682,  6788,
-         6806,  6820,  8218,  8224,  8226,  8232,  8277,  8326,  8354,  8469,  8521,  8530,  8549,  8596,  8737,  8794,
-         9221,  9253,  9348,  9369,  9380,  9474,  9557,  9633,  9732,  9753,  9793,  9830,  9862,  9880, 10240, 10272,
-        10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
-        16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
-        17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
-        18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
-        20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
-        20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
-        21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
-        21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
-        21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
-        22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
-        23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
-        25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
-        25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
-        26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
-        32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
-        33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
-        34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
-        37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
-        38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
-        38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
-        39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
-        41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
-        42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
+    static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = {
+            0,     2,     5,     8,    10,    17,    21,    32,    34,    40,    42,    69,    81,    84,    86,   101,
+          128,   130,   136,   138,   149,   160,   162,   168,   170,   260,   261,   273,   276,   278,   281,   282,
+          293,   321,   326,   329,   338,   341,   346,   353,   356,   358,   360,   389,   401,   404,   406,   421,
+          512,   514,   520,   522,   533,   544,   546,   552,   554,   581,   593,   601,   612,   617,   640,   642,
+          648,   650,   657,   661,   665,   672,   674,   680,   682,  1041,  1044,  1046,  1061,  1089,  1097,  1109,
+         1114,  1124,  1125,  1169,  1177,  1189,  1281,  1284,  1285,  1286,  1301,  1304,  1306,  1321,  1344,  1349,
+         1354,  1360,  1361,  1364,  1365,  1366,  1369,  1376,  1378,  1381,  1384,  1386,  1409,  1425,  1429,  1432,
+         1434,  1441,  1444,  1445,  1446,  1449,  1556,  1561,  1601,  1604,  1616,  1618,  1621,  1624,  1632,  1633,
+         1638,  1641,  1669,  1681,  1684,  1689,  2048,  2050,  2056,  2058,  2069,  2080,  2082,  2088,  2090,  2117,
+         2129,  2134,  2149,  2176,  2178,  2184,  2186,  2197,  2208,  2210,  2216,  2218,  2309,  2321,  2324,  2329,
+         2340,  2341,  2369,  2384,  2385,  2389,  2401,  2404,  2409,  2449,  2452,  2454,  2457,  2469,  2560,  2562,
+         2568,  2570,  2581,  2592,  2594,  2600,  2602,  2629,  2641,  2649,  2657,  2661,  2688,  2690,  2693,  2696,
+         2698,  2709,  2720,  2722,  2728,  2730,  4112,  4113,  4116,  4121,  4132,  4133,  4161,  4164,  4176,  4181,
+         4184,  4193,  4196,  4197,  4201,  4241,  4244,  4246,  4257,  4261,  4353,  4356,  4358,  4361,  4368,  4370,
+         4373,  4376,  4385,  4388,  4393,  4421,  4426,  4432,  4433,  4434,  4436,  4437,  4438,  4441,  4448,  4453,
+         4484,  4498,  4501,  4513,  4516,  4625,  4628,  4630,  4645,  4672,  4678,  4681,  4690,  4693,  4696,  4698,
+         4708,  4710,  4741,  4753,  4756,  4758,  4773,  5121,  5126,  5129,  5140,  5141,  5144,  5145,  5153,  5158,
+         5185,  5189,  5190,  5192,  5194,  5201,  5204,  5205,  5206,  5209,  5218,  5221,  5224,  5252,  5257,  5264,
+         5268,  5269,  5272,  5273,  5274,  5281,  5284,  5285,  5289,  5378,  5381,  5386,  5393,  5396,  5397,  5398,
+         5401,  5408,  5410,  5413,  5416,  5418,  5441,  5444,  5445,  5446,  5457,  5458,  5460,  5461,  5462,  5465,
+         5466,  5473,  5476,  5477,  5478,  5481,  5504,  5506,  5508,  5509,  5512,  5514,  5520,  5521,  5524,  5525,
+         5526,  5529,  5530,  5536,  5538,  5541,  5633,  5636,  5637,  5638,  5653,  5654,  5656,  5658,  5665,  5670,
+         5696,  5698,  5700,  5701,  5704,  5706,  5713,  5717,  5718,  5720,  5721,  5729,  5732,  5733,  5736,  5737,
+         5738,  5766,  5770,  5778,  5781,  5796,  5801,  6161,  6166,  6181,  6209,  6212,  6214,  6217,  6224,  6229,
+         6232,  6234,  6240,  6241,  6244,  6246,  6249,  6277,  6289,  6292,  6309,  6416,  6418,  6421,  6426,  6433,
+         6437,  6466,  6468,  6469,  6472,  6481,  6484,  6485,  6486,  6489,  6490,  6496,  6501,  6506,  6537,  6545,
+         6546,  6549,  6552,  6561,  6566,  6569,  6665,  6678,  6692,  6694,  6724,  6726,  6729,  6736,  6738,  6741,
+         6744,  6753,  6758,  6761,  6789,  6801,  6806,  6810,  8192,  8194,  8200,  8202,  8213,  8224,  8226,  8229,
+         8232,  8234,  8261,  8273,  8281,  8289,  8293,  8320,  8322,  8328,  8330,  8341,  8352,  8354,  8357,  8360,
+         8362,  8453,  8465,  8468,  8473,  8485,  8514,  8516,  8521,  8533,  8536,  8538,  8545,  8548,  8549,  8550,
+         8581,  8592,  8598,  8601,  8613,  8705,  8712,  8714,  8721,  8725,  8736,  8738,  8744,  8746,  8773,  8785,
+         8790,  8793,  8805,  8833,  8840,  8842,  8849,  8853,  8864,  8866,  8872,  8874,  9221,  9236,  9238,  9241,
+         9253,  9284,  9285,  9286,  9289,  9298,  9301,  9304,  9306,  9318,  9349,  9361,  9364,  9369,  9377,  9381,
+         9481,  9493,  9505,  9513,  9536,  9541,  9544,  9553,  9556,  9557,  9561,  9570,  9573,  9576,  9609,  9616,
+         9620,  9621,  9624,  9626,  9633,  9636,  9638,  9641,  9733,  9744,  9746,  9753,  9765,  9793,  9801,  9813,
+         9824,  9825,  9833,  9860,  9862,  9872,  9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282,
+        10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521,
+        10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752,
+        10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890,
+        10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484,
+        16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673,
+        16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772,
+        16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986,
+        16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494,
+        17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666,
+        17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744,
+        17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809,
+        17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953,
+        17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049,
+        18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517,
+        18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704,
+        18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784,
+        18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012,
+        19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501,
+        20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617,
+        20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761,
+        20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822,
+        20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896,
+        20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078,
+        21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526,
+        21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589,
+        21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653,
+        21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780,
+        21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832,
+        21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864,
+        21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924,
+        21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048,
+        22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098,
+        22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154,
+        22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561,
+        22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665,
+        22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821,
+        22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884,
+        22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061,
+        23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144,
+        23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656,
+        24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850,
+        24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970,
+        24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221,
+        25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674,
+        25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749,
+        25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926,
+        25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001,
+        26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176,
+        26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250,
+        26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721,
+        26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949,
+        26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044,
+        27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270,
+        27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852,
+        32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046,
+        33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161,
+        33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369,
+        33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877,
+        33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117,
+        34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192,
+        34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394,
+        34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858,
+        34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986,
+        35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172,
+        35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412,
+        35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901,
+        36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124,
+        37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205,
+        37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396,
+        37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889,
+        37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985,
+        37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161,
+        38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226,
+        38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290,
+        38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432,
+        38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538,
+        38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998,
+        39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194,
+        39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269,
+        39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497,
+        39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994,
+        41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130,
+        41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349,
+        41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561,
+        41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068,
+        42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278,
+        42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386,
+        42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592,
+        42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048,
+        43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284,
+        43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530,
+        43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690,
     };
     static const uint16_t kgrid_2bit_1024[1024] = {
         0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
@@ -10169,7 +10198,7 @@ void iq2xs_init_impl(enum ggml_type type) {
     const int nwant = type == GGML_TYPE_IQ1_S ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
     const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
                              type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 :
-                             type == GGML_TYPE_IQ1_S   ? kgrid_1bit_512 : kgrid_2bit_1024;
+                             type == GGML_TYPE_IQ1_S   ? kgrid_1bit_2048 : kgrid_2bit_1024;
     uint64_t * kgrid_q2xs;
     int      * kmap_q2xs;
     uint16_t * kneighbors_q2xs;
@@ -11408,12 +11437,70 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u
     return grid_index;
 }
 
+static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L, int ngrid) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_score = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = (pg[i] - 3)/2;
+            float w = weight[i];
+            float diff = scale*q - xval[i];
+            d2 += w*diff*diff;
+        }
+        if (d2 < best_score) {
+            best_score = d2;
+            grid_index = neighbours[j];
+        }
+    }
+    if (grid_index < 0) {
+        for (int i = 0; i < ngrid; ++i) {
+            const int8_t * grid_i = (const int8_t *)(grid + i);
+            float d2 = 0;
+            for (int j = 0; j < 8; ++j) {
+                float w = weight[j];
+                float q = (grid_i[j] - 3)/2;
+                float diff = scale*q - xval[i];
+                d2 += w*diff*diff;
+            }
+            if (d2 < best_score) {
+                best_score = d2;
+                grid_index = i;
+            }
+        }
+    }
+    if (grid_index < 0) {
+        printf("Oops, did not find grid point\n");
+        printf("Have %d neighbours\n", num_neighbors);
+        for (int j = 1; j <= num_neighbors; ++j) {
+            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+            float sumqx = 0, sumq2 = 0;
+            for (int i = 0; i < 8; ++i) {
+                float q = (pg[i] - 3)/2;
+                float w = weight[i];
+                sumqx += w*q*xval[i];
+                sumq2 += w*q*q;
+            }
+            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
 static int iq1_sort_helper(const void * left, const void * right) {
     const float * l = left;
     const float * r = right;
     return *l < *r ? -1 : *l > *r ? 1 : 0;
 }
 
+#define IQ1S_BLOCK_SIZE 32
 static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
 
     const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
@@ -11432,37 +11519,37 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
 
     block_iq1_s * y = vy;
 
-    float  scales[QK_K/8];
-    float  weight[8];
-    int8_t L[8];
-    float  sumx[9];
-    float  sumw[9];
-    float  pairs[16];
+    float  scales[QK_K/IQ1S_BLOCK_SIZE];
+    float  weight[IQ1S_BLOCK_SIZE];
+    int8_t L[IQ1S_BLOCK_SIZE];
+    float  sumx[IQ1S_BLOCK_SIZE+1];
+    float  sumw[IQ1S_BLOCK_SIZE+1];
+    float  pairs[2*IQ1S_BLOCK_SIZE];
     int * idx = (int *)(pairs + 1);
-    uint8_t hbit[QK_K/8];
+    uint16_t index[IQ1S_BLOCK_SIZE/8];
 
     for (int ibl = 0; ibl < nbl; ++ibl) {
 
         y[ibl].d = GGML_FP32_TO_FP16(0.f);
         memset(y[ibl].qs, 0, QK_K/8);
-        memset(y[ibl].scales, 0, QK_K/16);
+        memset(y[ibl].qh, 0, QK_K/16);
 
         float max_scale = 0;
 
         const float * xbl = x + QK_K*ibl;
         float sumx2 = 0;
         for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
-        float sigma2 = sumx2/QK_K;
+        float sigma2 = 2*sumx2/QK_K;
 
-        for (int ib = 0; ib < QK_K/8; ++ib) {
-            const float * xb = xbl + 8*ib;
-            const float * qw = quant_weights + QK_K*ibl + 8*ib;
-            for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
+            const float * xb = xbl + IQ1S_BLOCK_SIZE*ib;
+            const float * qw = quant_weights + QK_K*ibl + IQ1S_BLOCK_SIZE*ib;
+            for (int i = 0; i < IQ1S_BLOCK_SIZE; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
             float max = fabsf(xb[0]);
-            for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
+            for (int i = 1; i < IQ1S_BLOCK_SIZE; ++i) max = MAX(max, fabsf(xb[i]));
             if (!max) {
                 scales[ib] = 0;
-                memset(L, 1, 8);
+                memset(L, 1, IQ1S_BLOCK_SIZE);
                 continue;
             }
             // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
@@ -11471,14 +11558,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
             // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
            // and score for each possible split.
-            for (int j = 0; j < 8; ++j) {
+            for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
                 pairs[2*j] = xb[j];
                 idx[2*j] = j;
             }
-            qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
+            qsort(pairs, IQ1S_BLOCK_SIZE, 2*sizeof(float), iq1_sort_helper);
             {
                 sumx[0] = sumw[0] = 0;
-                for (int j = 0; j < 8; ++j) {
+                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
                     int i = idx[2*j];
                     sumx[j+1] = sumx[j] + weight[i]*xb[i];
                     sumw[j+1] = sumw[j] + weight[i];
@@ -11486,10 +11573,10 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             }
             float best_score = 0, scale = max;
             int besti1 = 0, besti2 = 0;
-            for (int i1 = 0; i1 <= 8; ++i1) {
-                for (int i2 = i1; i2 <= 8; ++i2) {
-                    float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
-                    float sumq2 =  (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
+            for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) {
+                for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) {
+                    float sumqx = -(sumx[i1] - sumx[0]) + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2]);
+                    float sumq2 =  (sumw[i1] - sumw[0]) + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2]);
                     if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                         scale = sumqx/sumq2; best_score = scale*sumqx;
                         besti1 = i1; besti2 = i2;
@@ -11498,23 +11585,43 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
                 }
             }
             for (int j = 0; j <      besti1; ++j) L[idx[2*j]] = 0;
             for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
-            for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
+            for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2;
             if (scale < 0) {
-                for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
+                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j];
                 scale = -scale;
             }
-            // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
-            // grid point that minimizes SSD.
-            uint16_t u = 0;
-            for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
-            int grid_index = kmap_q2xs[u];
-            if (grid_index < 0) {
-                const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
-                grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
-                GGML_ASSERT(grid_index >= 0);
-            }
-            y[ibl].qs[ib] = grid_index & 255;
-            hbit[ib] = grid_index >> 8;
+            bool all_on_grid = true;
+            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+                uint16_t u = 0;
+                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    all_on_grid = false;
+                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, L + 8*k, NGRID_IQ1S);
+                    GGML_ASSERT(grid_index >= 0);
+                }
+                index[k] = grid_index;
+            }
+            if (!all_on_grid) {
+                float sumqx = 0, sumq2 = 0;
+                for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
+                    for (int j = 0; j < 8; ++j) {
+                        float w = weight[8*k + j];
+                        float q = (pg[j] - 3)/2;
+                        sumqx += w*q*xb[8*k+j];
+                        sumq2 += w*q*q;
+                    }
+                }
+                if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
+            }
+            uint16_t h = 0;
+            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
+                y[ibl].qs[(IQ1S_BLOCK_SIZE/8)*ib + k] = index[k] & 255;
+                h |= (index[k] >> 8) << 3*k;
+            }
+            y[ibl].qh[ib] = h;
             GGML_ASSERT(scale >= 0);
             scales[ib] = scale;
             max_scale = MAX(max_scale, scale);
@@ -11525,14 +11632,13 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
             continue;
         }
 
-        float d = max_scale/15;
-        y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.
         float id = 1/d;
-        for (int ib = 0; ib < QK_K/8; ++ib) {
+        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
             int l = nearest_int(0.5f*(id*scales[ib]-1));
-            l = MAX(0, MIN(7, l));
-            if (hbit[ib]) l |= 8;
-            y[ibl].scales[ib/2] |= (l << 4*(ib%2));
+            l = MAX(0, MIN(15, l));
+            y[ibl].qh[ib] |= (l << 12);
        }
    }
 }
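
The SSD comment inside quantize_row_iq1_s_impl is terse, so here is the
same split search as a self-contained sketch (hypothetical sk_ name,
n <= 32 assumed; the real code also records the split indices, handles
the sign flip, and snaps the result to the lattice): sort the values,
build prefix sums of w*x and w, and score every (i1, i2) split into
levels {-1, 0, +1} in O(1).

#define SK_MAX_N 32

static float sk_best_ternary_scale(const float * x_sorted, const float * w_sorted, int n) {
    float sumx[SK_MAX_N + 1], sumw[SK_MAX_N + 1];
    sumx[0] = sumw[0] = 0;
    for (int j = 0; j < n; ++j) {
        sumx[j+1] = sumx[j] + w_sorted[j]*x_sorted[j];
        sumw[j+1] = sumw[j] + w_sorted[j];
    }
    float best_score = 0, scale = 0;
    for (int i1 = 0; i1 <= n; ++i1) {          /* values below i1 get q = -1 */
        for (int i2 = i1; i2 <= n; ++i2) {     /* values from i2 on get q = +1 */
            /* optimal scale for this split: <q,x>_w / <q,q>_w, q in {-1,0,1} */
            float sumqx = -sumx[i1] + (sumx[n] - sumx[i2]);
            float sumq2 =  sumw[i1] + (sumw[n] - sumw[i2]);
            if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                scale = sumqx/sumq2; best_score = scale*sumqx;
            }
        }
    }
    return scale;
}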
ggml-quants.h CHANGED
@@ -217,8 +217,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N
 
 typedef struct {
     ggml_fp16_t d;
-    uint8_t  qs[QK_K/8];
-    uint8_t  scales[QK_K/16];
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
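
Side note on the unchanged static_assert: uint16_t qh[QK_K/32] occupies
exactly the QK_K/16 bytes the old scales array did, so the size check
holds as written. Illustrative check (not in the repo):

#include <assert.h>
#include <stdint.h>

#define QK_K 256

/* 2 bytes x QK_K/32 entries == 1 byte x QK_K/16 entries == 16 bytes */
static_assert(sizeof(uint16_t[QK_K/32]) == sizeof(uint8_t[QK_K/16]),
              "qh has the same footprint as the old scales array");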