Kawrakow ikawrakow commited on
Commit
32589c9
·
unverified ·
1 Parent(s): a7eb9f6

IQ3_S: a much better alternative to Q3_K (llama/5676)

Browse files

* iq4_nl: squash commits for easier rebase

* Basics (quantize, dequantize)
* CUDA dequantize and dot product
* Slightly faster CUDA dot product (120 t/s)
* Switch to 6-bit scales
* Scalar dot product
* AVX2 dot product
* ARM_NEON dot product
* Works on metal, but still slow
* Slightly better Metal dot product
* Another small Metal improvement
* Metal dot product is getting there
* Faster CUDA dot product
* Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided
* Report the actual bpw
* Add _xs mix that is 4.05 bpw for non-MoE models
* Remove IQ4_XS for now, slightly adjust kvalues_iq4nl
* AVX2 dot product uses Q8_0 instead of Q8_K
* Add to test-backend-ops
* Minor fix
* Also use Q5_K for attn_output in MoE models
* Fixes after merging latest master
* Switching to blocks of 32
* AVX2 for blocks of 32
* Scalar dot product for blocks of 32
* ARM_NEON dot product for blocks of 32
* Metal kernels for blocks of 32
* Slightly faster Metal kernels

* Resurrecting iq3_xs

After all the experimentation, nothing was better than this.

* Minor PPL improvement via a block scale fudge factor

* Minor improvement via 3 neighbours

* iq3_xs: working scalar and AVX2 dot products

* iq3_xs: ARM_NEON dot product - works but extremely slow (10 t/s)

* iq3_xs: working Metal implementation

* Adding IQ3_M - IQ3_XS mix with mostly Q4_K

* iq3_xs: a 3.4375 bpw variant

* iq3_xs: make CUDA work for new version

* iq3_xs: make scalar and AVX2 work for new version

* iq3_s: make ARM_NEON work with new version

* iq3_xs: make new version work on metal

Performance is very similar to Q3_K_S

* iq3_xs: tiny Metal speed improvement

* iq3_xs: tiny Metal speed improvement

* Fix stupid warning

* Q3_K_XS now uses a mix of IQ3_XS and IQ3_XXS

* iq3_xs: rename to iq3_s

* iq3_s: make tests pass

* Move Q3_K_XS mix to 3.25 bpw

* Attempt to fix failing tests

* Another attempt to fix the Windows builds

* Attempt to fix ROCm

* ROCm again

* iq3_s: partial fix for QK_K = 64

* iq3_s: make it work on metal for QK_K = 64

Pleasant surprise: the coding was super-block size independent,
so all it took was to delete some QK_K == 256 guards.

* Will this fix ROCm?

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

Files changed (7) hide show
  1. ggml-cuda.cu +170 -1
  2. ggml-metal.m +29 -4
  3. ggml-metal.metal +304 -0
  4. ggml-quants.c +610 -64
  5. ggml-quants.h +20 -0
  6. ggml.c +31 -0
  7. ggml.h +2 -0
ggml-cuda.cu CHANGED
@@ -172,6 +172,7 @@
172
  #endif
173
 
174
  typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 
175
  static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
176
  const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
177
  const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
@@ -196,6 +197,18 @@ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
196
  return __vsubss4(a, b);
197
  }
198
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
200
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
201
  c = __builtin_amdgcn_sdot4(a, b, c, false);
@@ -518,6 +531,17 @@ typedef struct {
518
  } block_iq3_xxs;
519
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
520
 
 
 
 
 
 
 
 
 
 
 
 
521
  #define QR1_S 8
522
  #define QI1_S (QK_K / (4*QR1_S))
523
  typedef struct {
@@ -1700,6 +1724,74 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
1700
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
1701
  };
1702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1703
  static const __device__ uint64_t iq1s_grid[512] = {
1704
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
1705
  0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
@@ -1973,6 +2065,32 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
1973
 
1974
  }
1975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1976
  template<typename dst_t>
1977
  static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
1978
 
@@ -4717,6 +4835,41 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
4717
  #endif
4718
  }
4719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4720
  static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
4721
  const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
4722
  #if QK_K == 256
@@ -6849,6 +7002,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
6849
  dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
6850
  }
6851
 
 
 
 
 
 
 
6852
  template<typename dst_t>
6853
  static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
6854
  const int nb = k / QK_K;
@@ -6904,6 +7063,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
6904
  return dequantize_row_iq1_s_cuda;
6905
  case GGML_TYPE_IQ4_NL:
6906
  return dequantize_row_iq4_nl_cuda;
 
 
6907
  case GGML_TYPE_F32:
6908
  return convert_unary_cuda<float>;
6909
  default:
@@ -6943,6 +7104,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
6943
  return dequantize_row_iq1_s_cuda;
6944
  case GGML_TYPE_IQ4_NL:
6945
  return dequantize_row_iq4_nl_cuda;
 
 
6946
  case GGML_TYPE_F16:
6947
  return convert_unary_cuda<half>;
6948
  default:
@@ -8688,6 +8851,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
8688
  case GGML_TYPE_IQ3_XXS:
8689
  case GGML_TYPE_IQ1_S:
8690
  case GGML_TYPE_IQ4_NL:
 
8691
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
8692
  default:
8693
  GGML_ASSERT(false);
@@ -8713,6 +8877,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
8713
  case GGML_TYPE_IQ3_XXS:
8714
  case GGML_TYPE_IQ1_S:
8715
  case GGML_TYPE_IQ4_NL:
 
8716
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
8717
  case GGML_TYPE_Q6_K:
8718
  return 64;
@@ -8818,6 +8983,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
8818
  mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
8819
  (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8820
  break;
 
 
 
 
8821
  default:
8822
  GGML_ASSERT(false);
8823
  break;
@@ -11541,7 +11710,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
11541
  }
11542
  ggml_type a_type = a->type;
11543
  if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
11544
- a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) {
11545
  if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
11546
  return false;
11547
  }
 
172
  #endif
173
 
174
  typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
175
+ typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
176
  static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
177
  const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
178
  const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
 
197
  return __vsubss4(a, b);
198
  }
199
 
200
// Emulation of CUDA's per-byte compare intrinsic for the HIP/ROCm path:
// for each of the 4 bytes of a and b, the corresponding result byte is
// 0xff when the bytes are equal and 0x00 otherwise.
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
    unsigned int result = 0;
#pragma unroll
    for (int byte = 0; byte < 4; ++byte) {
        const unsigned int shift = 8u*byte;
        if (((a >> shift) & 0xffu) == ((b >> shift) & 0xffu)) {
            result |= 0xffu << shift;  // mark this byte as equal
        }
    }
    return result;
}
211
+
212
  static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
213
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
214
  c = __builtin_amdgcn_sdot4(a, b, c, false);
 
531
  } block_iq3_xxs;
532
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
533
 
534
#define QR3_XS 8
#define QI3_XS (QK_K / (4*QR3_XS))
// IQ3_S super-block: QK_K values quantized to 3.4375 bpw.
// Each group of 8 values is looked up in the 512-entry iq3xs_grid via a
// 9-bit code: 8 low bits from qs, the 9th (high) bit from qh.
typedef struct {
    half d;                   // per-super-block scale (fp16)
    uint8_t qs[QK_K/4];       // low 8 bits of the grid code, one byte per group of 8 values
    uint8_t qh[QK_K/32];      // high (9th) grid-code bits, one bit per group of 8 values
    uint8_t signs[QK_K/8];    // one sign bit per value
    uint8_t scales[QK_K/64];  // two 4-bit block scales packed per byte
} block_iq3_s;
// NOTE(review): for QK_K == 64 the member sum is 2 + 27 = 29 bytes (odd), so
// alignment padding of `half` could make sizeof != 29 and trip this assert —
// confirm QK_K == 64 builds (default QK_K == 256 gives an even 110 bytes).
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
544
+
545
  #define QR1_S 8
546
  #define QI1_S (QK_K / (4*QR1_S))
547
  typedef struct {
 
1724
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
1725
  };
1726
 
1727
// IQ3_S codebook: 512 entries addressed by the 9-bit code (qs byte | qh bit).
// Each 32-bit entry packs 4 unsigned byte magnitudes (read as uint8_t[4] by the
// dequantize/dot-product kernels); signs are applied separately from `signs`.
static const __device__ uint32_t iq3xs_grid[512] = {
    0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
    0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
    0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
    0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
    0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
    0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
    0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
    0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
    0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
    0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
    0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
    0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
    0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
    0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
    0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
    0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
    0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
    0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
    0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
    0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
    0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
    0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
    0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
    0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
    0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
    0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
    0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
    0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
    0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
    0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
    0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
    0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
    0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
    0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
    0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
    0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
    0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
    0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
    0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
    0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
    0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
    0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
    0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
    0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
    0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
    0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
    0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
    0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
    0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
    0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
    0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
    0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
    0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
    0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
    0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
    0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
    0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
    0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
    0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
    0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
    0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
    0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
    0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
    0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
};
1793
+
1794
+
1795
  static const __device__ uint64_t iq1s_grid[512] = {
1796
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
1797
  0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
 
2065
 
2066
  }
2067
 
2068
// Dequantize one IQ3_S super-block per CUDA block (launched with 32 threads).
// Each thread expands 8 consecutive values: two groups of 4 looked up in
// iq3xs_grid via a 9-bit code (8 bits from qs, high bit from qh), a per-group
// sign byte, and a 4-bit block scale. Requires QK_K == 256.
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;                       // super-block index
    const block_iq3_s * x = (const block_iq3_s *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * qs = x[i].qs + 8*ib;
    // 9-bit grid code: qs byte plus the matching bit of qh shifted into bit 8
    const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
    const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
    // effective scale: d * (2*s + 1) / 4 with s = 4-bit block scale nibble
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
    const uint8_t signs = x[i].signs[4*ib + il];    // one sign bit per value
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    assert(false);
#endif

}
2093
+
2094
  template<typename dst_t>
2095
  static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
2096
 
 
4835
  #endif
4836
  }
4837
 
4838
// TODO: don't use lookup table for signs
// Dot product of one IQ3_S 32-value sub-block (selected by iqs) with the
// matching Q8_1 block, using DP4A SIMD-in-register arithmetic.
// Requires QK_K == 256 and compute capability >= MIN_CC_DP4A.
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;

    const int ib32 = iqs;                    // which 32-value sub-block
    const uint8_t * qs = bq2->qs + 8*ib32;
    const int8_t * q8 = bq8_1[ib32].qs;
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        // 9-bit grid code: qs byte | matching qh bit shifted into bit 8
        const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
        const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
        // Expand 4 sign bits to a per-byte 0xff/0x00 mask: broadcast the nibble
        // to all 4 bytes (* 0x01010101), keep one distinct bit per byte
        // (& 0x08040201), then byte-compare against that same pattern.
        uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
        uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
        // (x ^ m) - m negates exactly the bytes whose mask byte is 0xff
        const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
        const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
        q8 += 8;
    }
    // combined scale: d * (2*s + 1)/4 * (Q8_1 activation scale)
    const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
    return d * sumi;
#else
    assert(false);
    return 0.f;
#endif
#else
    assert(false);
    return 0.f;
#endif
}
4871
+
4872
+
4873
  static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
4874
  const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
4875
  #if QK_K == 256
 
7002
  dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
7003
  }
7004
 
7005
// Host-side launcher: dequantize k IQ3_S-quantized values into y on `stream`.
// One 32-thread block handles one super-block of QK_K values; k is assumed
// to be a multiple of QK_K (same convention as the sibling launchers).
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int n_superblocks = k / QK_K;
    dequantize_block_iq3_s<<<n_superblocks, 32, 0, stream>>>(vx, y);
}
7010
+
7011
  template<typename dst_t>
7012
  static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
7013
  const int nb = k / QK_K;
 
7063
  return dequantize_row_iq1_s_cuda;
7064
  case GGML_TYPE_IQ4_NL:
7065
  return dequantize_row_iq4_nl_cuda;
7066
+ case GGML_TYPE_IQ3_S:
7067
+ return dequantize_row_iq3_s_cuda;
7068
  case GGML_TYPE_F32:
7069
  return convert_unary_cuda<float>;
7070
  default:
 
7104
  return dequantize_row_iq1_s_cuda;
7105
  case GGML_TYPE_IQ4_NL:
7106
  return dequantize_row_iq4_nl_cuda;
7107
+ case GGML_TYPE_IQ3_S:
7108
+ return dequantize_row_iq3_s_cuda;
7109
  case GGML_TYPE_F16:
7110
  return convert_unary_cuda<half>;
7111
  default:
 
8851
  case GGML_TYPE_IQ3_XXS:
8852
  case GGML_TYPE_IQ1_S:
8853
  case GGML_TYPE_IQ4_NL:
8854
+ case GGML_TYPE_IQ3_S:
8855
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
8856
  default:
8857
  GGML_ASSERT(false);
 
8877
  case GGML_TYPE_IQ3_XXS:
8878
  case GGML_TYPE_IQ1_S:
8879
  case GGML_TYPE_IQ4_NL:
8880
+ case GGML_TYPE_IQ3_S:
8881
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
8882
  case GGML_TYPE_Q6_K:
8883
  return 64;
 
8983
  mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
8984
  (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8985
  break;
8986
+ case GGML_TYPE_IQ3_S:
8987
+ mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
8988
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8989
+ break;
8990
  default:
8991
  GGML_ASSERT(false);
8992
  break;
 
11710
  }
11711
  ggml_type a_type = a->type;
11712
  if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
11713
+ a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S) {
11714
  if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
11715
  return false;
11716
  }
ggml-metal.m CHANGED
@@ -61,6 +61,7 @@ enum ggml_metal_kernel_type {
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
 
64
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
65
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -85,6 +86,7 @@ enum ggml_metal_kernel_type {
85
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
86
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
87
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
 
88
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
89
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
90
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -105,6 +107,7 @@ enum ggml_metal_kernel_type {
105
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
106
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
107
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
 
108
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
109
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
110
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -122,6 +125,7 @@ enum ggml_metal_kernel_type {
122
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
123
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
124
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
 
125
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
126
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
127
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -139,6 +143,7 @@ enum ggml_metal_kernel_type {
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
140
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
 
142
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
143
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
144
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -452,6 +457,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
452
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
453
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
454
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
 
455
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
456
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
457
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@@ -476,6 +482,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
 
479
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
480
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
481
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -496,6 +503,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
 
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@@ -513,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
513
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
514
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
515
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
 
516
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
517
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
518
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -530,6 +539,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
530
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
531
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
532
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
 
533
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
534
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
535
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@@ -1347,6 +1357,7 @@ static bool ggml_metal_graph_compute(
1347
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1348
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1349
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
 
1350
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1351
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1352
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1483,6 +1494,12 @@ static bool ggml_metal_graph_compute(
1483
  nth1 = 16;
1484
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1485
  } break;
 
 
 
 
 
 
1486
  case GGML_TYPE_IQ1_S:
1487
  {
1488
  nth0 = 4;
@@ -1537,8 +1554,8 @@ static bool ggml_metal_graph_compute(
1537
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1538
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1539
  }
1540
- else if (src0t == GGML_TYPE_IQ3_XXS) {
1541
- const int mem_size = 256*4+128;
1542
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1543
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1544
  }
@@ -1640,6 +1657,7 @@ static bool ggml_metal_graph_compute(
1640
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1641
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1642
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
 
1643
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1644
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1645
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1779,6 +1797,12 @@ static bool ggml_metal_graph_compute(
1779
  nth1 = 16;
1780
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1781
  } break;
 
 
 
 
 
 
1782
  case GGML_TYPE_IQ1_S:
1783
  {
1784
  nth0 = 4;
@@ -1849,8 +1873,8 @@ static bool ggml_metal_graph_compute(
1849
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1850
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1851
  }
1852
- else if (src2t == GGML_TYPE_IQ3_XXS) {
1853
- const int mem_size = 256*4+128;
1854
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1855
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1856
  }
@@ -1900,6 +1924,7 @@ static bool ggml_metal_graph_compute(
1900
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1901
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1902
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
 
1903
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
1904
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
1905
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
 
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
64
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
67
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
 
86
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
87
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
88
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
89
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
90
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
91
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
92
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
 
107
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
108
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
109
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
110
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
111
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
112
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
113
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
 
125
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
126
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
127
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
128
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
129
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
130
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
131
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
 
143
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
144
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
145
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
146
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
147
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
148
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
149
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
 
457
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
458
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
459
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
460
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
461
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
462
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
463
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
 
482
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
483
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
484
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
485
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
486
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
487
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
488
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
 
503
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
504
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
505
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
506
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
508
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
509
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
 
521
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
522
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
523
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
524
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
526
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
 
539
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
540
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
541
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
542
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
543
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
544
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
545
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
 
1357
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1358
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1359
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1360
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1361
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1362
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1363
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
 
1494
  nth1 = 16;
1495
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1496
  } break;
1497
+ case GGML_TYPE_IQ3_S:
1498
+ {
1499
+ nth0 = 4;
1500
+ nth1 = 16;
1501
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1502
+ } break;
1503
  case GGML_TYPE_IQ1_S:
1504
  {
1505
  nth0 = 4;
 
1554
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1555
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1556
  }
1557
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1558
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1559
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1560
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1561
  }
 
1657
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1658
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1659
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1660
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1661
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1662
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1663
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
 
1797
  nth1 = 16;
1798
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1799
  } break;
1800
+ case GGML_TYPE_IQ3_S:
1801
+ {
1802
+ nth0 = 4;
1803
+ nth1 = 16;
1804
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
1805
+ } break;
1806
  case GGML_TYPE_IQ1_S:
1807
  {
1808
  nth0 = 4;
 
1873
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1874
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1875
  }
1876
+ else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
1877
+ const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1878
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1879
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1880
  }
 
1924
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1925
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1926
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
1927
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
1928
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
1929
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
1930
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
ggml-metal.metal CHANGED
@@ -2525,6 +2525,20 @@ typedef struct {
2525
  } block_iq3_xxs;
2526
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2528
  typedef struct {
2529
  half d;
2530
  uint8_t qs[QK_K/8];
@@ -3795,6 +3809,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
3795
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3796
  };
3797
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3798
  #define NGRID_IQ1S 512
3799
  constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
3800
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@@ -4361,6 +4442,136 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4361
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4362
  }
4363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4364
  void kernel_mul_mv_iq1_s_f32_impl(
4365
  device const void * src0,
4366
  device const float * src1,
@@ -4952,6 +5163,31 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
4952
  }
4953
  }
4954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4955
  template <typename type4x4>
4956
  void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
4957
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -5525,6 +5761,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
5525
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5526
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5527
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
5528
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5529
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5530
 
@@ -5566,6 +5803,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
5566
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5567
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5568
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
5569
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5570
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5571
 
@@ -5619,6 +5857,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
5619
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5620
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5621
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
5622
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
5623
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5624
 
@@ -6589,6 +6828,71 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6589
  sgitg);
6590
  }
6591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6592
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6593
  kernel void kernel_mul_mv_id_iq1_s_f32(
6594
  device const char * ids,
 
2525
  } block_iq3_xxs;
2526
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2527
 
2528
+ // 3.4375 bpw
2529
+ #if QK_K == 64
2530
+ #define IQ3S_N_SCALE 2
2531
+ #else
2532
+ #define IQ3S_N_SCALE QK_K/64
2533
+ #endif
2534
+ typedef struct {
2535
+ half d;
2536
+ uint8_t qs[QK_K/4];
2537
+ uint8_t qh[QK_K/32];
2538
+ uint8_t signs[QK_K/8];
2539
+ uint8_t scales[IQ3S_N_SCALE];
2540
+ } block_iq3_s;
2541
+
2542
  typedef struct {
2543
  half d;
2544
  uint8_t qs[QK_K/8];
 
3809
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3810
  };
3811
 
3812
+ constexpr constant static uint32_t iq3xs_grid[512] = {
3813
+ 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
3814
+ 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
3815
+ 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
3816
+ 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
3817
+ 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
3818
+ 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
3819
+ 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
3820
+ 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
3821
+ 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
3822
+ 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
3823
+ 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
3824
+ 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
3825
+ 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
3826
+ 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
3827
+ 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
3828
+ 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
3829
+ 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
3830
+ 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
3831
+ 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
3832
+ 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
3833
+ 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
3834
+ 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
3835
+ 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
3836
+ 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
3837
+ 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
3838
+ 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
3839
+ 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
3840
+ 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
3841
+ 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
3842
+ 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
3843
+ 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
3844
+ 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
3845
+ 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
3846
+ 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
3847
+ 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
3848
+ 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
3849
+ 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
3850
+ 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
3851
+ 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
3852
+ 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
3853
+ 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
3854
+ 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
3855
+ 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
3856
+ 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
3857
+ 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
3858
+ 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
3859
+ 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
3860
+ 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
3861
+ 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
3862
+ 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
3863
+ 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
3864
+ 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
3865
+ 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
3866
+ 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
3867
+ 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
3868
+ 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
3869
+ 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
3870
+ 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
3871
+ 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
3872
+ 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
3873
+ 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
3874
+ 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
3875
+ 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
3876
+ 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
3877
+ };
3878
+
3879
  #define NGRID_IQ1S 512
3880
  constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
3881
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
 
4442
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4443
  }
4444
 
4445
+ void kernel_mul_mv_iq3_s_f32_impl(
4446
+ device const void * src0,
4447
+ device const float * src1,
4448
+ device float * dst,
4449
+ constant int64_t & ne00,
4450
+ constant int64_t & ne01,
4451
+ constant int64_t & ne02,
4452
+ constant int64_t & ne10,
4453
+ constant int64_t & ne12,
4454
+ constant int64_t & ne0,
4455
+ constant int64_t & ne1,
4456
+ constant uint & r2,
4457
+ constant uint & r3,
4458
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4459
+ uint3 tgpig[[threadgroup_position_in_grid]],
4460
+ uint tiisg[[thread_index_in_simdgroup]],
4461
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4462
+
4463
+ const int nb = ne00/QK_K;
4464
+ const int r0 = tgpig.x;
4465
+ const int r1 = tgpig.y;
4466
+ const int im = tgpig.z;
4467
+
4468
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4469
+ const int ib_row = first_row * nb;
4470
+
4471
+ const uint i12 = im%ne12;
4472
+ const uint i13 = im/ne12;
4473
+
4474
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4475
+
4476
+ device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
4477
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4478
+
4479
+ float yl[32];
4480
+ float sumf[N_DST]={0.f}, all_sum;
4481
+
4482
+ const int nb32 = nb * (QK_K / 32);
4483
+
4484
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
4485
+ {
4486
+ int nval = 8;
4487
+ int pos = (32*sgitg + tiisg)*nval;
4488
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
4489
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4490
+ }
4491
+
4492
+ const int ix = tiisg;
4493
+
4494
+ device const float * y4 = y + 32 * ix;
4495
+
4496
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4497
+
4498
+ for (int i = 0; i < 32; ++i) {
4499
+ yl[i] = y4[i];
4500
+ }
4501
+
4502
+ const int ibl = ib32 / (QK_K / 32);
4503
+ const int ib = ib32 % (QK_K / 32);
4504
+
4505
+ device const block_iq3_s * xr = x + ibl;
4506
+ device const uint8_t * qs = xr->qs + 8 * ib;
4507
+ device const uint8_t * qh = xr->qh + ib;
4508
+ device const uint8_t * sc = xr->scales + (ib/2);
4509
+ device const uint8_t * signs = xr->signs + 4 * ib;
4510
+ device const half * dh = &xr->d;
4511
+
4512
+ for (int row = 0; row < N_DST; row++) {
4513
+
4514
+ const float db = dh[0];
4515
+ const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
4516
+
4517
+ float2 sum = {0};
4518
+ for (int l = 0; l < 4; ++l) {
4519
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
4520
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
4521
+ for (int j = 0; j < 4; ++j) {
4522
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4523
+ sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
4524
+ }
4525
+ }
4526
+ sumf[row] += d * (sum[0] + sum[1]);
4527
+
4528
+ dh += nb*sizeof(block_iq3_s)/2;
4529
+ qs += nb*sizeof(block_iq3_s);
4530
+ qh += nb*sizeof(block_iq3_s);
4531
+ sc += nb*sizeof(block_iq3_s);
4532
+ signs += nb*sizeof(block_iq3_s);
4533
+ }
4534
+
4535
+ y4 += 32 * 32;
4536
+ }
4537
+
4538
+ for (int row = 0; row < N_DST; ++row) {
4539
+ all_sum = simd_sum(sumf[row]);
4540
+ if (tiisg == 0) {
4541
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
4542
+ }
4543
+ }
4544
+ }
4545
+
4546
+ [[host_name("kernel_mul_mv_iq3_s_f32")]]
4547
+ kernel void kernel_mul_mv_iq3_s_f32(
4548
+ device const void * src0,
4549
+ device const float * src1,
4550
+ device float * dst,
4551
+ constant int64_t & ne00,
4552
+ constant int64_t & ne01,
4553
+ constant int64_t & ne02,
4554
+ constant uint64_t & nb00,
4555
+ constant uint64_t & nb01,
4556
+ constant uint64_t & nb02,
4557
+ constant int64_t & ne10,
4558
+ constant int64_t & ne11,
4559
+ constant int64_t & ne12,
4560
+ constant uint64_t & nb10,
4561
+ constant uint64_t & nb11,
4562
+ constant uint64_t & nb12,
4563
+ constant int64_t & ne0,
4564
+ constant int64_t & ne1,
4565
+ constant uint & r2,
4566
+ constant uint & r3,
4567
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4568
+ uint3 tgpig[[threadgroup_position_in_grid]],
4569
+ uint tiisg[[thread_index_in_simdgroup]],
4570
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4571
+
4572
+ kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4573
+ }
4574
+
4575
  void kernel_mul_mv_iq1_s_f32_impl(
4576
  device const void * src0,
4577
  device const float * src1,
 
5163
  }
5164
  }
5165
 
5166
+ template <typename type4x4>
5167
+ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
5168
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5169
+ const float d = xb->d;
5170
+ const int ib32 = il/2;
5171
+ il = il%2;
5172
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5173
+ device const uint8_t * qs = xb->qs + 8*ib32;
5174
+ device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
5175
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
5176
+ const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
5177
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
5178
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
5179
+ for (int i = 0; i < 4; ++i) {
5180
+ reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
5181
+ reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
5182
+ }
5183
+ grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
5184
+ grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
5185
+ for (int i = 0; i < 4; ++i) {
5186
+ reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5187
+ reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
5188
+ }
5189
+ }
5190
+
5191
  template <typename type4x4>
5192
  void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
5193
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
 
5761
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5762
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5763
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5764
+ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
5765
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5766
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5767
 
 
5803
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5804
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5805
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5806
+ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
5807
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5808
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5809
 
 
5857
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5858
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5859
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5860
+ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
5861
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
5862
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5863
 
 
6828
  sgitg);
6829
  }
6830
 
6831
+ [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6832
+ kernel void kernel_mul_mv_id_iq3_s_f32(
6833
+ device const char * ids,
6834
+ device const char * src1,
6835
+ device float * dst,
6836
+ constant uint64_t & nbi1,
6837
+ constant int64_t & ne00,
6838
+ constant int64_t & ne01,
6839
+ constant int64_t & ne02,
6840
+ constant uint64_t & nb00,
6841
+ constant uint64_t & nb01,
6842
+ constant uint64_t & nb02,
6843
+ constant int64_t & ne10,
6844
+ constant int64_t & ne11,
6845
+ constant int64_t & ne12,
6846
+ constant int64_t & ne13,
6847
+ constant uint64_t & nb10,
6848
+ constant uint64_t & nb11,
6849
+ constant uint64_t & nb12,
6850
+ constant int64_t & ne0,
6851
+ constant int64_t & ne1,
6852
+ constant uint64_t & nb1,
6853
+ constant uint & r2,
6854
+ constant uint & r3,
6855
+ constant int & idx,
6856
+ device const char * src00,
6857
+ device const char * src01,
6858
+ device const char * src02,
6859
+ device const char * src03,
6860
+ device const char * src04,
6861
+ device const char * src05,
6862
+ device const char * src06,
6863
+ device const char * src07,
6864
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6865
+ uint3 tgpig[[threadgroup_position_in_grid]],
6866
+ uint tiitg[[thread_index_in_threadgroup]],
6867
+ uint tiisg[[thread_index_in_simdgroup]],
6868
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6869
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6870
+
6871
+ const int64_t bid = tgpig.z/(ne12*ne13);
6872
+
6873
+ tgpig.z = tgpig.z%(ne12*ne13);
6874
+
6875
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6876
+
6877
+ kernel_mul_mv_iq3_s_f32_impl(
6878
+ src0[id],
6879
+ (device const float *) (src1 + bid*nb11),
6880
+ dst + bid*ne0,
6881
+ ne00,
6882
+ ne01,
6883
+ ne02,
6884
+ ne10,
6885
+ ne12,
6886
+ ne0,
6887
+ ne1,
6888
+ r2,
6889
+ r3,
6890
+ shared_values,
6891
+ tgpig,
6892
+ tiisg,
6893
+ sgitg);
6894
+ }
6895
+
6896
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6897
  kernel void kernel_mul_mv_id_iq1_s_f32(
6898
  device const char * ids,
ggml-quants.c CHANGED
@@ -3505,6 +3505,73 @@ static const uint32_t iq3xxs_grid[256] = {
3505
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3506
  };
3507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3508
  #define NGRID_IQ2XXS 512
3509
  static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
3510
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@@ -3736,6 +3803,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
3736
  }
3737
  }
3738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3739
  // ====================== 1.5625 bpw (de)-quantization
3740
 
3741
  void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
@@ -8806,6 +8916,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8806
 
8807
  #endif
8808
 
 
8809
  static const int8_t keven_signs_q2xs[1024] = {
8810
  1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
8811
  1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
@@ -8840,6 +8951,7 @@ static const int8_t keven_signs_q2xs[1024] = {
8840
  1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
8841
  1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
8842
  };
 
8843
 
8844
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8845
  assert(n % QK_K == 0);
@@ -9327,6 +9439,202 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
9327
  #endif
9328
  }
9329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9330
  #ifdef __AVX2__
9331
  static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
9332
  const __m256i ax = _mm256_sign_epi8(x, x);
@@ -9523,6 +9831,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
9523
  float sumf = 0;
9524
 
9525
  for (int ib = 0; ib < nb; ib += 2) {
 
9526
  q4bits.val[0] = vld1q_u8(x[ib+0].qs);
9527
  q4bits.val[1] = vld1q_u8(x[ib+1].qs);
9528
  q8b.val[0] = vld1q_s8(y[ib+0].qs);
@@ -10239,14 +10548,15 @@ typedef struct {
10239
  uint16_t * neighbours;
10240
  } iq3_entry_t;
10241
 
10242
- static iq3_entry_t iq3_data[1] = {
 
10243
  {NULL, NULL, NULL},
10244
  };
10245
 
10246
  static inline int iq3_data_index(int grid_size) {
10247
  (void)grid_size;
10248
- GGML_ASSERT(grid_size == 256);
10249
- return 0;
10250
  }
10251
 
10252
  static int iq3_compare_func(const void * left, const void * right) {
@@ -10278,9 +10588,44 @@ void iq3xs_init_impl(int grid_size) {
10278
  3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
10279
  3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
10280
  };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10281
  const int kmap_size = 4096;
10282
- const int nwant = 2;
10283
- const uint16_t * kgrid = kgrid_256;
10284
  uint32_t * kgrid_q3xs;
10285
  int * kmap_q3xs;
10286
  uint16_t * kneighbors_q3xs;
@@ -10377,7 +10722,7 @@ void iq3xs_init_impl(int grid_size) {
10377
  }
10378
 
10379
  void iq3xs_free_impl(int grid_size) {
10380
- GGML_ASSERT(grid_size == 256);
10381
  const int gindex = iq3_data_index(grid_size);
10382
  if (iq3_data[gindex].grid) {
10383
  free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
@@ -10410,9 +10755,10 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u
10410
  return grid_index;
10411
  }
10412
 
10413
- static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
 
10414
 
10415
- const int gindex = iq3_data_index(256);
10416
 
10417
  const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
10418
  const int * kmap_q3xs = iq3_data[gindex].map;
@@ -10426,9 +10772,23 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
10426
 
10427
  const int kMaxQ = 8;
10428
 
10429
- const int nbl = n/256;
10430
 
10431
- block_iq3_xxs * y = vy;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10432
 
10433
  float scales[QK_K/32];
10434
  float weight[32];
@@ -10439,20 +10799,21 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
10439
  bool is_on_grid[8];
10440
  bool is_on_grid_aux[8];
10441
  uint8_t block_signs[8];
10442
- uint8_t q3[3*(QK_K/8)];
10443
  uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
 
10444
 
10445
  for (int ibl = 0; ibl < nbl; ++ibl) {
10446
 
10447
- y[ibl].d = GGML_FP32_TO_FP16(0.f);
10448
- memset(q3, 0, 3*QK_K/8);
10449
 
10450
  float max_scale = 0;
10451
 
10452
  const float * xbl = x + QK_K*ibl;
10453
  float sumx2 = 0;
10454
  for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
10455
- float sigma2 = sumx2/QK_K;
10456
 
10457
  for (int ib = 0; ib < QK_K/32; ++ib) {
10458
  const float * xb = xbl + 32*ib;
@@ -10570,7 +10931,13 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
10570
  printf("\n");
10571
  GGML_ASSERT(false);
10572
  }
10573
- q3[8*ib+k] = grid_index;
 
 
 
 
 
 
10574
  }
10575
  scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
10576
  GGML_ASSERT(scale >= 0);
@@ -10579,63 +10946,25 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
10579
  }
10580
 
10581
  if (!max_scale) {
10582
- memset(y[ibl].qs, 0, 3*QK_K/8);
 
 
10583
  continue;
10584
  }
10585
 
10586
  float d = max_scale/31;
10587
- y[ibl].d = GGML_FP32_TO_FP16(d);
10588
  float id = 1/d;
10589
- float sumqx = 0, sumq2 = 0;
10590
  for (int ib = 0; ib < QK_K/32; ++ib) {
10591
  int l = nearest_int(0.5f*(id*scales[ib]-1));
10592
  l = MAX(0, MIN(15, l));
10593
  scales_and_signs[ib] |= ((uint32_t)l << 28);
10594
- if (false) {
10595
- const float * xb = xbl + 32*ib;
10596
- if (quant_weights) {
10597
- const float * qw = quant_weights + QK_K*ibl + 32*ib;
10598
- for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
10599
- } else {
10600
- for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
10601
- }
10602
- const float db = 0.25f * d * (1 + 2*l);
10603
- for (int k = 0; k < 8; ++k) {
10604
- const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
10605
- const float * xk = xb + 4*k;
10606
- const float * wk = weight + 4*k;
10607
- //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
10608
- const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
10609
- float best_mse = 0; int best_index = q3[8*ib+k];
10610
- for (int j = 0; j < 4; ++j) {
10611
- float diff = db * grid[j] * signs[j] - xk[j];
10612
- best_mse += wk[j] * diff * diff;
10613
- }
10614
- for (int idx = 0; idx < 256; ++idx) {
10615
- //grid = (const uint8_t *)(kgrid_q3xs + idx);
10616
- grid = (const uint8_t *)(iq3xxs_grid + idx);
10617
- float mse = 0;
10618
- for (int j = 0; j < 4; ++j) {
10619
- float diff = db * grid[j] * signs[j] - xk[j];
10620
- mse += wk[j] * diff * diff;
10621
- }
10622
- if (mse < best_mse) {
10623
- best_mse = mse; best_index = idx;
10624
- }
10625
- }
10626
- q3[8*ib+k] = best_index;
10627
- //grid = (const uint8_t *)(kgrid_q3xs + best_index);
10628
- grid = (const uint8_t *)(iq3xxs_grid + best_index);
10629
- for (int j = 0; j < 4; ++j) {
10630
- float q = db * grid[j] * signs[j];
10631
- sumqx += wk[j] * q * xk[j];
10632
- sumq2 += wk[j] * q * q;
10633
- }
10634
- }
10635
- if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
10636
- }
10637
  }
10638
- memcpy(y[ibl].qs, q3, 3*QK_K/8);
 
 
 
 
10639
  }
10640
  }
10641
 
@@ -10645,7 +10974,7 @@ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row,
10645
  int nblock = n_per_row/QK_K;
10646
  char * qrow = (char *)dst;
10647
  for (int row = 0; row < nrow; ++row) {
10648
- quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
10649
  src += n_per_row;
10650
  qrow += nblock*sizeof(block_iq3_xxs);
10651
  }
@@ -10660,9 +10989,226 @@ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
10660
 
10661
  void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
10662
  assert(k % QK_K == 0);
10663
- quantize_row_iq3_xxs_impl(x, y, k, NULL);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10664
  }
10665
 
 
 
 
 
 
 
10666
  // =================================== 1.5 bpw ===================================================
10667
 
10668
  static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
 
3505
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3506
  };
3507
 
3508
+ static const uint32_t iq3xs_grid[512] = {
3509
+ 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
3510
+ 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
3511
+ 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
3512
+ 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
3513
+ 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
3514
+ 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
3515
+ 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
3516
+ 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
3517
+ 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
3518
+ 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
3519
+ 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
3520
+ 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
3521
+ 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
3522
+ 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
3523
+ 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
3524
+ 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
3525
+ 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
3526
+ 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
3527
+ 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
3528
+ 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
3529
+ 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
3530
+ 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
3531
+ 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
3532
+ 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
3533
+ 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
3534
+ 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
3535
+ 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
3536
+ 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
3537
+ 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
3538
+ 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
3539
+ 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
3540
+ 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
3541
+ 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
3542
+ 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
3543
+ 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
3544
+ 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
3545
+ 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
3546
+ 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
3547
+ 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
3548
+ 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
3549
+ 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
3550
+ 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
3551
+ 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
3552
+ 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
3553
+ 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
3554
+ 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
3555
+ 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
3556
+ 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
3557
+ 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
3558
+ 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
3559
+ 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
3560
+ 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
3561
+ 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
3562
+ 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
3563
+ 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
3564
+ 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
3565
+ 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
3566
+ 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
3567
+ 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
3568
+ 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
3569
+ 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
3570
+ 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
3571
+ 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
3572
+ 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
3573
+ };
3574
+
3575
  #define NGRID_IQ2XXS 512
3576
  static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
3577
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
 
3803
  }
3804
  }
3805
 
3806
+ // ====================== 3.3125 bpw (de)-quantization
3807
+
3808
+ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
3809
+ assert(k % QK_K == 0);
3810
+ const int nb = k / QK_K;
3811
+
3812
+ for (int i = 0; i < nb; i++) {
3813
+
3814
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3815
+ const uint8_t * qs = x[i].qs;
3816
+ const uint8_t * qh = x[i].qh;
3817
+ const uint8_t * signs = x[i].signs;
3818
+
3819
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3820
+ const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f;
3821
+ const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f;
3822
+ for (int l = 0; l < 4; ++l) {
3823
+ const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
3824
+ const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
3825
+ for (int j = 0; j < 4; ++j) {
3826
+ y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
3827
+ y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
3828
+ }
3829
+ y += 8;
3830
+ }
3831
+ qs += 8;
3832
+ signs += 4;
3833
+ for (int l = 0; l < 4; ++l) {
3834
+ const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
3835
+ const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
3836
+ for (int j = 0; j < 4; ++j) {
3837
+ y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
3838
+ y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
3839
+ }
3840
+ y += 8;
3841
+ }
3842
+ qh += 2;
3843
+ qs += 8;
3844
+ signs += 4;
3845
+ }
3846
+ }
3847
+ }
3848
+
3849
  // ====================== 1.5625 bpw (de)-quantization
3850
 
3851
  void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
 
8916
 
8917
  #endif
8918
 
8919
+ #if defined (__AVX2__) || defined (__ARM_NEON)
8920
  static const int8_t keven_signs_q2xs[1024] = {
8921
  1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
8922
  1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
 
8951
  1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
8952
  1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
8953
  };
8954
+ #endif
8955
 
8956
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8957
  assert(n % QK_K == 0);
 
9439
  #endif
9440
  }
9441
 
9442
+ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
9443
+ assert(n % QK_K == 0);
9444
+ assert(nrc == 1);
9445
+ UNUSED(nrc);
9446
+ UNUSED(bx);
9447
+ UNUSED(by);
9448
+ UNUSED(bs);
9449
+
9450
+ const block_iq3_s * restrict x = vx;
9451
+ const block_q8_K * restrict y = vy;
9452
+
9453
+ const int nb = n / QK_K;
9454
+
9455
+ #if defined(__ARM_NEON)
9456
+
9457
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
9458
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
9459
+ };
9460
+
9461
+ static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
9462
+
9463
+ const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
9464
+ const uint8x16_t mask2 = vld1q_u8(k_mask2);
9465
+
9466
+ uint8x16x2_t vs;
9467
+ ggml_int8x16x4_t q3s;
9468
+ ggml_int8x16x4_t q8b;
9469
+
9470
+ float sumf = 0;
9471
+ for (int i = 0; i < nb; ++i) {
9472
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9473
+ const uint8_t * restrict qs = x[i].qs;
9474
+ const uint8_t * restrict qh = x[i].qh;
9475
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
9476
+ const int8_t * restrict q8 = y[i].qs;
9477
+ int sumi1 = 0, sumi2 = 0;
9478
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
9479
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
9480
+ const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
9481
+ iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
9482
+ const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
9483
+ iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
9484
+ const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
9485
+ iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
9486
+ const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
9487
+ iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
9488
+ qs += 16;
9489
+
9490
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
9491
+ vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9492
+ vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9493
+ vs.val[0] = vceqq_u8(vs.val[0], mask2);
9494
+ vs.val[1] = vceqq_u8(vs.val[1], mask2);
9495
+
9496
+ q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
9497
+ q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
9498
+
9499
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
9500
+ vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9501
+ vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9502
+ vs.val[0] = vceqq_u8(vs.val[0], mask2);
9503
+ vs.val[1] = vceqq_u8(vs.val[1], mask2);
9504
+
9505
+ signs += 4;
9506
+
9507
+ q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
9508
+ q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));
9509
+
9510
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
9511
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
9512
+ sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
9513
+ sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
9514
+ }
9515
+ sumf += d*(sumi1 + sumi2);
9516
+ }
9517
+ *s = 0.25f * sumf;
9518
+
9519
+ #elif defined(__AVX2__)
9520
+
9521
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
9522
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
9523
+ };
9524
+
9525
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
9526
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
9527
+ };
9528
+
9529
+ const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
9530
+ const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
9531
+
9532
+ __m256 accumf = _mm256_setzero_ps();
9533
+ for (int i = 0; i < nb; ++i) {
9534
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9535
+ const uint8_t * restrict qs = x[i].qs;
9536
+ const uint8_t * restrict qh = x[i].qh;
9537
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
9538
+ const int8_t * restrict q8 = y[i].qs;
9539
+ __m256i sumi1 = _mm256_setzero_si256();
9540
+ __m256i sumi2 = _mm256_setzero_si256();
9541
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
9542
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
9543
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
9544
+ const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)],
9545
+ iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)],
9546
+ iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)],
9547
+ iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)],
9548
+ iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)],
9549
+ iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)],
9550
+ iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)],
9551
+ iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]);
9552
+ qs += 8;
9553
+ const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)],
9554
+ iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)],
9555
+ iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)],
9556
+ iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)],
9557
+ iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)],
9558
+ iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)],
9559
+ iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
9560
+ iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
9561
+ qs += 8;
9562
+
9563
+ __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
9564
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
9565
+ const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
9566
+ const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
9567
+
9568
+ aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
9569
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
9570
+ const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
9571
+ const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
9572
+
9573
+ signs += 4;
9574
+
9575
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
9576
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
9577
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
9578
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
9579
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
9580
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
9581
+ sumi1 = _mm256_add_epi32(sumi1, p1);
9582
+ sumi2 = _mm256_add_epi32(sumi2, p2);
9583
+ }
9584
+
9585
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
9586
+
9587
+ }
9588
+
9589
+ *s = 0.25f * hsum_float_8(accumf);
9590
+
9591
+ #else
9592
+
9593
+ float sumf = 0.f;
9594
+ for (int i = 0; i < nb; ++i) {
9595
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9596
+ const uint8_t * restrict qs = x[i].qs;
9597
+ const uint8_t * restrict qh = x[i].qh;
9598
+ const uint8_t * restrict signs = x[i].signs;
9599
+ const int8_t * restrict q8 = y[i].qs;
9600
+ int32_t bsum = 0;
9601
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
9602
+ const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
9603
+ const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
9604
+ int32_t sumi = 0;
9605
+ for (int l = 0; l < 4; ++l) {
9606
+ const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
9607
+ const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
9608
+ for (int j = 0; j < 4; ++j) {
9609
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
9610
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
9611
+ }
9612
+ q8 += 8;
9613
+ }
9614
+ qs += 8;
9615
+ signs += 4;
9616
+ bsum += sumi * ls1;
9617
+ sumi = 0;
9618
+ for (int l = 0; l < 4; ++l) {
9619
+ const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
9620
+ const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
9621
+ for (int j = 0; j < 4; ++j) {
9622
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
9623
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
9624
+ }
9625
+ q8 += 8;
9626
+ }
9627
+ qs += 8;
9628
+ signs += 4;
9629
+ bsum += sumi * ls2;
9630
+ }
9631
+ sumf += d * bsum;
9632
+ }
9633
+ *s = 0.25f * sumf;
9634
+ #endif
9635
+ }
9636
+
9637
+
9638
  #ifdef __AVX2__
9639
  static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
9640
  const __m256i ax = _mm256_sign_epi8(x, x);
 
9831
  float sumf = 0;
9832
 
9833
  for (int ib = 0; ib < nb; ib += 2) {
9834
+
9835
  q4bits.val[0] = vld1q_u8(x[ib+0].qs);
9836
  q4bits.val[1] = vld1q_u8(x[ib+1].qs);
9837
  q8b.val[0] = vld1q_s8(y[ib+0].qs);
 
10548
  uint16_t * neighbours;
10549
  } iq3_entry_t;
10550
 
10551
+ static iq3_entry_t iq3_data[2] = {
10552
+ {NULL, NULL, NULL},
10553
  {NULL, NULL, NULL},
10554
  };
10555
 
10556
  static inline int iq3_data_index(int grid_size) {
10557
  (void)grid_size;
10558
+ GGML_ASSERT(grid_size == 256 || grid_size == 512);
10559
+ return grid_size == 256 ? 0 : 1;
10560
  }
10561
 
10562
  static int iq3_compare_func(const void * left, const void * right) {
 
10588
  3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
10589
  3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
10590
  };
10591
+ static const uint16_t kgrid_512[512] = {
10592
+ 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34,
10593
+ 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77,
10594
+ 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142,
10595
+ 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210,
10596
+ 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288,
10597
+ 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393,
10598
+ 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514,
10599
+ 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576,
10600
+ 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653,
10601
+ 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727,
10602
+ 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833,
10603
+ 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977,
10604
+ 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047,
10605
+ 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103,
10606
+ 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199,
10607
+ 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296,
10608
+ 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415,
10609
+ 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561,
10610
+ 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648,
10611
+ 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761,
10612
+ 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877,
10613
+ 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068,
10614
+ 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177,
10615
+ 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269,
10616
+ 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520,
10617
+ 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634,
10618
+ 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805,
10619
+ 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083,
10620
+ 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276,
10621
+ 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591,
10622
+ 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729,
10623
+ 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032,
10624
+ };
10625
+
10626
  const int kmap_size = 4096;
10627
+ const int nwant = grid_size == 256 ? 2 : 3;
10628
+ const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
10629
  uint32_t * kgrid_q3xs;
10630
  int * kmap_q3xs;
10631
  uint16_t * kneighbors_q3xs;
 
10722
  }
10723
 
10724
  void iq3xs_free_impl(int grid_size) {
10725
+ GGML_ASSERT(grid_size == 256 || grid_size == 512);
10726
  const int gindex = iq3_data_index(grid_size);
10727
  if (iq3_data[gindex].grid) {
10728
  free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
 
10755
  return grid_index;
10756
  }
10757
 
10758
+ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n,
10759
+ const float * restrict quant_weights) {
10760
 
10761
+ const int gindex = iq3_data_index(grid_size);
10762
 
10763
  const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
10764
  const int * kmap_q3xs = iq3_data[gindex].map;
 
10772
 
10773
  const int kMaxQ = 8;
10774
 
10775
+ const int nbl = n/QK_K;
10776
 
10777
+ ggml_fp16_t * dh;
10778
+ uint8_t * qs;
10779
+ int block_size;
10780
+ if (grid_size == 256) {
10781
+ block_iq3_xxs * y = vy;
10782
+ dh = &y->d;
10783
+ qs = y->qs;
10784
+ block_size = sizeof(block_iq3_xxs);
10785
+ } else {
10786
+ block_iq3_s * y = vy;
10787
+ dh = &y->d;
10788
+ qs = y->qs;
10789
+ block_size = sizeof(block_iq3_s);
10790
+ }
10791
+ int quant_size = block_size - sizeof(ggml_fp16_t);
10792
 
10793
  float scales[QK_K/32];
10794
  float weight[32];
 
10799
  bool is_on_grid[8];
10800
  bool is_on_grid_aux[8];
10801
  uint8_t block_signs[8];
10802
+ uint8_t q3[3*(QK_K/8)+QK_K/32];
10803
  uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
10804
+ uint8_t * qh = q3 + 3*(QK_K/8);
10805
 
10806
  for (int ibl = 0; ibl < nbl; ++ibl) {
10807
 
10808
+ dh[0] = GGML_FP32_TO_FP16(0.f);
10809
+ memset(q3, 0, 3*QK_K/8+QK_K/32);
10810
 
10811
  float max_scale = 0;
10812
 
10813
  const float * xbl = x + QK_K*ibl;
10814
  float sumx2 = 0;
10815
  for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
10816
+ float sigma2 = 2*sumx2/QK_K;
10817
 
10818
  for (int ib = 0; ib < QK_K/32; ++ib) {
10819
  const float * xb = xbl + 32*ib;
 
10931
  printf("\n");
10932
  GGML_ASSERT(false);
10933
  }
10934
+ if (grid_size == 256) {
10935
+ q3[8*ib+k] = grid_index;
10936
+ } else {
10937
+ q3[8*ib+k] = grid_index & 255;
10938
+ qh[ib] |= ((grid_index >> 8) << k);
10939
+ }
10940
+
10941
  }
10942
  scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
10943
  GGML_ASSERT(scale >= 0);
 
10946
  }
10947
 
10948
  if (!max_scale) {
10949
+ memset(qs, 0, quant_size);
10950
+ dh += block_size/sizeof(ggml_fp16_t);
10951
+ qs += block_size;
10952
  continue;
10953
  }
10954
 
10955
  float d = max_scale/31;
10956
+ dh[0] = GGML_FP32_TO_FP16(d * 1.0125f); // small improvement via this fudge factor
10957
  float id = 1/d;
 
10958
  for (int ib = 0; ib < QK_K/32; ++ib) {
10959
  int l = nearest_int(0.5f*(id*scales[ib]-1));
10960
  l = MAX(0, MIN(15, l));
10961
  scales_and_signs[ib] |= ((uint32_t)l << 28);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10962
  }
10963
+ memcpy(qs, q3, quant_size);
10964
+
10965
+ dh += block_size/sizeof(ggml_fp16_t);
10966
+ qs += block_size;
10967
+
10968
  }
10969
  }
10970
 
 
10974
  int nblock = n_per_row/QK_K;
10975
  char * qrow = (char *)dst;
10976
  for (int row = 0; row < nrow; ++row) {
10977
+ quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
10978
  src += n_per_row;
10979
  qrow += nblock*sizeof(block_iq3_xxs);
10980
  }
 
10989
 
10990
  void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
10991
  assert(k % QK_K == 0);
10992
+ quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
10993
+ }
10994
+
10995
+ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n,
10996
+ const float * restrict quant_weights,
10997
+ float * scales,
10998
+ float * weight,
10999
+ float * xval,
11000
+ int8_t * L,
11001
+ int8_t * Laux,
11002
+ float * waux,
11003
+ bool * is_on_grid,
11004
+ bool * is_on_grid_aux,
11005
+ uint8_t * block_signs) {
11006
+
11007
+ const int gindex = iq3_data_index(512);
11008
+
11009
+ const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
11010
+ const int * kmap_q3xs = iq3_data[gindex].map;
11011
+ const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
11012
+
11013
+ //GGML_ASSERT(quant_weights && "missing quantization weights");
11014
+ GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
11015
+ GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
11016
+ GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
11017
+ GGML_ASSERT(n%QK_K == 0);
11018
+
11019
+ const int kMaxQ = 8;
11020
+
11021
+ const int nbl = n/QK_K;
11022
+
11023
+ block_iq3_s * y = vy;
11024
+
11025
+ const int bs4 = block_size/4;
11026
+ const int bs8 = block_size/8;
11027
+
11028
+ for (int ibl = 0; ibl < nbl; ++ibl) {
11029
+
11030
+ memset(&y[ibl], 0, sizeof(block_iq3_s));
11031
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
11032
+
11033
+ uint8_t * qs = y[ibl].qs;
11034
+ uint8_t * qh = y[ibl].qh;
11035
+ uint8_t * signs = y[ibl].signs;
11036
+
11037
+ float max_scale = 0;
11038
+
11039
+ const float * xbl = x + QK_K*ibl;
11040
+ float sumx2 = 0;
11041
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
11042
+ float sigma2 = 2*sumx2/QK_K;
11043
+
11044
+ for (int ib = 0; ib < QK_K/block_size; ++ib) {
11045
+ const float * xb = xbl + block_size*ib;
11046
+ if (quant_weights) {
11047
+ const float * qw = quant_weights + QK_K*ibl + block_size*ib;
11048
+ for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
11049
+ } else {
11050
+ for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
11051
+ }
11052
+ for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);
11053
+ for (int k = 0; k < bs8; ++k) {
11054
+ uint8_t s = 0;
11055
+ for (int i = 0; i < 8; ++i) {
11056
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
11057
+ else {
11058
+ xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
11059
+ }
11060
+ }
11061
+ block_signs[k] = s;
11062
+ }
11063
+ float max = xval[0];
11064
+ for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
11065
+ if (!max) {
11066
+ scales[ib] = 0;
11067
+ continue;
11068
+ }
11069
+ float best = 0;
11070
+ float scale = max/(2*kMaxQ-1);
11071
+ for (int is = -15; is <= 15; ++is) {
11072
+ float id = (2*kMaxQ-1+is*0.2f)/max;
11073
+ float this_scale = 1/id;
11074
+ for (int k = 0; k < bs4; ++k) {
11075
+ for (int i = 0; i < 4; ++i) {
11076
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
11077
+ Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
11078
+ }
11079
+ uint16_t u = 0;
11080
+ for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
11081
+ int grid_index = kmap_q3xs[u];
11082
+ is_on_grid_aux[k] = true;
11083
+ if (grid_index < 0) {
11084
+ is_on_grid_aux[k] = false;
11085
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
11086
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
11087
+ }
11088
+ }
11089
+ float sumqx = 0, sumq2 = 0;
11090
+ for (int i = 0; i < block_size; ++i) {
11091
+ float w = weight[i];
11092
+ float q = 2*Laux[i] + 1;
11093
+ sumqx += w*xval[i]*q;
11094
+ sumq2 += w*q*q;
11095
+ }
11096
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
11097
+ scale = sumqx/sumq2; best = scale*sumqx;
11098
+ for (int i = 0; i < block_size; ++i) L[i] = Laux[i];
11099
+ for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];
11100
+ }
11101
+ }
11102
+ int n_not_ongrid = 0;
11103
+ for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
11104
+ if (n_not_ongrid > 0 && scale > 0) {
11105
+ float id = 1/scale;
11106
+ for (int k = 0; k < bs4; ++k) {
11107
+ if (is_on_grid[k]) continue;
11108
+ uint16_t u = 0;
11109
+ for (int i = 0; i < 4; ++i) {
11110
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
11111
+ l = MAX(0, MIN(kMaxQ-1, l));
11112
+ u |= (l << 3*i);
11113
+ }
11114
+ int grid_index = kmap_q3xs[u];
11115
+ if (grid_index < 0) {
11116
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
11117
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
11118
+ }
11119
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
11120
+ for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
11121
+ }
11122
+ float sumqx = 0, sumq2 = 0;
11123
+ for (int i = 0; i < block_size; ++i) {
11124
+ float w = weight[i];
11125
+ float q = 2*L[i] + 1;
11126
+ sumqx += w*xval[i]*q;
11127
+ sumq2 += w*q*q;
11128
+ }
11129
+ if (sumq2 > 0) scale = sumqx/sumq2;
11130
+ }
11131
+ if (scale < 0) {
11132
+ // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
11133
+ // and correspondingly flip quant signs.
11134
+ scale = -scale;
11135
+ for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];
11136
+ }
11137
+ for (int k = 0; k < bs4; ++k) {
11138
+ uint16_t u = 0;
11139
+ for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
11140
+ int grid_index = kmap_q3xs[u];
11141
+ if (grid_index < 0) {
11142
+ printf("Oops: found point %u not on grid:", u);
11143
+ for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
11144
+ printf("\n");
11145
+ GGML_ASSERT(false);
11146
+ }
11147
+ qs[k] = grid_index & 255;
11148
+ qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
11149
+ }
11150
+ qs += bs4;
11151
+ for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];
11152
+ signs += bs8;
11153
+ GGML_ASSERT(scale >= 0);
11154
+ scales[ib] = scale;
11155
+ max_scale = MAX(max_scale, scale);
11156
+ }
11157
+
11158
+ if (!max_scale) {
11159
+ continue;
11160
+ }
11161
+
11162
+ float d = max_scale/31;
11163
+ y[ibl].d = GGML_FP32_TO_FP16(d);
11164
+ float id = 1/d;
11165
+ for (int ib = 0; ib < QK_K/block_size; ib += 2) {
11166
+ int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
11167
+ l1 = MAX(0, MIN(15, l1));
11168
+ int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
11169
+ l2 = MAX(0, MIN(15, l2));
11170
+ y[ibl].scales[ib/2] = l1 | (l2 << 4);
11171
+ }
11172
+
11173
+ }
11174
+ }
11175
+
11176
+ #define IQ3S_BLOCK_SIZE 32
11177
+ size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
11178
+ (void)hist;
11179
+ GGML_ASSERT(n_per_row%QK_K == 0);
11180
+ int nblock = n_per_row/QK_K;
11181
+ float scales[QK_K/IQ3S_BLOCK_SIZE];
11182
+ float weight[IQ3S_BLOCK_SIZE];
11183
+ float xval[IQ3S_BLOCK_SIZE];
11184
+ int8_t L[IQ3S_BLOCK_SIZE];
11185
+ int8_t Laux[IQ3S_BLOCK_SIZE];
11186
+ float waux[IQ3S_BLOCK_SIZE];
11187
+ bool is_on_grid[IQ3S_BLOCK_SIZE/4];
11188
+ bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
11189
+ uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
11190
+ char * qrow = (char *)dst;
11191
+ for (int row = 0; row < nrow; ++row) {
11192
+ quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
11193
+ scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
11194
+ src += n_per_row;
11195
+ qrow += nblock*sizeof(block_iq3_s);
11196
+ }
11197
+ return nrow * nblock * sizeof(block_iq3_s);
11198
+ }
11199
+
11200
+ void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) {
11201
+ assert(k % QK_K == 0);
11202
+ block_iq3_s * restrict y = vy;
11203
+ quantize_row_iq3_s_reference(x, y, k);
11204
  }
11205
 
11206
+ void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) {
11207
+ assert(k % QK_K == 0);
11208
+ quantize_iq3_s(x, y, 1, k, NULL, NULL);
11209
+ }
11210
+
11211
+
11212
  // =================================== 1.5 bpw ===================================================
11213
 
11214
  static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
ggml-quants.h CHANGED
@@ -191,6 +191,21 @@ typedef struct {
191
  } block_iq3_xxs;
192
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  typedef struct {
195
  ggml_fp16_t d;
196
  uint8_t qs[QK_K/8];
@@ -226,6 +241,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
226
  void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
227
  void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
228
  void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
 
229
 
230
  void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
231
  void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
@@ -242,6 +258,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
242
  void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
243
  void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
244
  void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 
245
 
246
  // Dequantization
247
  void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
@@ -262,6 +279,7 @@ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_
262
  void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
263
  void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
264
  void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
265
 
266
  // Dot product
267
  void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -280,6 +298,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
280
  void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
281
  void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
282
  void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
283
 
284
  //
285
  // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
@@ -289,6 +308,7 @@ size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row,
289
  size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
290
  size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
291
  size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 
292
  size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
293
  size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
294
  size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 
191
  } block_iq3_xxs;
192
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
193
 
194
+ // 3.4375 bpw
195
+ #if QK_K == 64
196
+ #define IQ3S_N_SCALE 2
197
+ #else
198
+ #define IQ3S_N_SCALE QK_K/64
199
+ #endif
200
+ typedef struct {
201
+ ggml_fp16_t d;
202
+ uint8_t qs[QK_K/4];
203
+ uint8_t qh[QK_K/32];
204
+ uint8_t signs[QK_K/8];
205
+ uint8_t scales[IQ3S_N_SCALE];
206
+ } block_iq3_s;
207
+ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
208
+
209
  typedef struct {
210
  ggml_fp16_t d;
211
  uint8_t qs[QK_K/8];
 
241
  void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
242
  void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
243
  void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
244
+ void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
245
 
246
  void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
247
  void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 
258
  void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
259
  void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
260
  void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
261
+ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
262
 
263
  // Dequantization
264
  void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
279
  void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
280
  void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
281
  void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
282
+ void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
283
 
284
  // Dot product
285
  void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
298
  void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
299
  void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
300
  void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
301
+ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
302
 
303
  //
304
  // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 
308
  size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
309
  size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
310
  size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
311
+ size_t quantize_iq3_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
312
  size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
313
  size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
314
  size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
ggml.c CHANGED
@@ -682,6 +682,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
682
  .vec_dot_type = GGML_TYPE_Q8_K,
683
  .nrows = 1,
684
  },
 
 
 
 
 
 
 
 
 
 
 
 
685
  [GGML_TYPE_IQ1_S] = {
686
  .type_name = "iq1_s",
687
  .blck_size = QK_K,
@@ -2308,6 +2320,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2308
  case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
2309
  case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
2310
  case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
 
2311
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
2312
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
2313
  }
@@ -7742,6 +7755,7 @@ static void ggml_compute_forward_add(
7742
  case GGML_TYPE_IQ3_XXS:
7743
  case GGML_TYPE_IQ1_S:
7744
  case GGML_TYPE_IQ4_NL:
 
7745
  {
7746
  ggml_compute_forward_add_q_f32(params, dst);
7747
  } break;
@@ -8021,6 +8035,7 @@ static void ggml_compute_forward_add1(
8021
  case GGML_TYPE_IQ3_XXS:
8022
  case GGML_TYPE_IQ1_S:
8023
  case GGML_TYPE_IQ4_NL:
 
8024
  {
8025
  ggml_compute_forward_add1_q_f32(params, dst);
8026
  } break;
@@ -8145,6 +8160,7 @@ static void ggml_compute_forward_acc(
8145
  case GGML_TYPE_IQ3_XXS:
8146
  case GGML_TYPE_IQ1_S:
8147
  case GGML_TYPE_IQ4_NL:
 
8148
  default:
8149
  {
8150
  GGML_ASSERT(false);
@@ -11043,6 +11059,7 @@ static void ggml_compute_forward_out_prod(
11043
  case GGML_TYPE_IQ3_XXS:
11044
  case GGML_TYPE_IQ1_S:
11045
  case GGML_TYPE_IQ4_NL:
 
11046
  {
11047
  ggml_compute_forward_out_prod_q_f32(params, dst);
11048
  } break;
@@ -11231,6 +11248,7 @@ static void ggml_compute_forward_set(
11231
  case GGML_TYPE_IQ3_XXS:
11232
  case GGML_TYPE_IQ1_S:
11233
  case GGML_TYPE_IQ4_NL:
 
11234
  default:
11235
  {
11236
  GGML_ASSERT(false);
@@ -11433,6 +11451,7 @@ static void ggml_compute_forward_get_rows(
11433
  case GGML_TYPE_IQ3_XXS:
11434
  case GGML_TYPE_IQ1_S:
11435
  case GGML_TYPE_IQ4_NL:
 
11436
  {
11437
  ggml_compute_forward_get_rows_q(params, dst);
11438
  } break;
@@ -12133,6 +12152,7 @@ static void ggml_compute_forward_alibi(
12133
  case GGML_TYPE_IQ3_XXS:
12134
  case GGML_TYPE_IQ1_S:
12135
  case GGML_TYPE_IQ4_NL:
 
12136
  case GGML_TYPE_Q8_K:
12137
  case GGML_TYPE_I8:
12138
  case GGML_TYPE_I16:
@@ -12216,6 +12236,7 @@ static void ggml_compute_forward_clamp(
12216
  case GGML_TYPE_IQ3_XXS:
12217
  case GGML_TYPE_IQ1_S:
12218
  case GGML_TYPE_IQ4_NL:
 
12219
  case GGML_TYPE_Q8_K:
12220
  case GGML_TYPE_I8:
12221
  case GGML_TYPE_I16:
@@ -19467,6 +19488,7 @@ void ggml_quantize_init(enum ggml_type type) {
19467
  case GGML_TYPE_IQ2_XS:
19468
  case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
19469
  case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
 
19470
  default: // nothing
19471
  break;
19472
  }
@@ -19741,6 +19763,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
19741
  result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
19742
  GGML_ASSERT(result == row_size * nrows);
19743
  } break;
 
 
 
 
 
 
 
 
 
19744
  case GGML_TYPE_IQ1_S:
19745
  {
19746
  GGML_ASSERT(start % QK_K == 0);
 
682
  .vec_dot_type = GGML_TYPE_Q8_K,
683
  .nrows = 1,
684
  },
685
+ [GGML_TYPE_IQ3_S] = {
686
+ .type_name = "iq3_s",
687
+ .blck_size = QK_K,
688
+ .type_size = sizeof(block_iq3_s),
689
+ .is_quantized = true,
690
+ .to_float = (ggml_to_float_t) dequantize_row_iq3_s,
691
+ .from_float = quantize_row_iq3_s,
692
+ .from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference,
693
+ .vec_dot = ggml_vec_dot_iq3_s_q8_K,
694
+ .vec_dot_type = GGML_TYPE_Q8_K,
695
+ .nrows = 1,
696
+ },
697
  [GGML_TYPE_IQ1_S] = {
698
  .type_name = "iq1_s",
699
  .blck_size = QK_K,
 
2320
  case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
2321
  case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
2322
  case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
2323
+ case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
2324
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
2325
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
2326
  }
 
7755
  case GGML_TYPE_IQ3_XXS:
7756
  case GGML_TYPE_IQ1_S:
7757
  case GGML_TYPE_IQ4_NL:
7758
+ case GGML_TYPE_IQ3_S:
7759
  {
7760
  ggml_compute_forward_add_q_f32(params, dst);
7761
  } break;
 
8035
  case GGML_TYPE_IQ3_XXS:
8036
  case GGML_TYPE_IQ1_S:
8037
  case GGML_TYPE_IQ4_NL:
8038
+ case GGML_TYPE_IQ3_S:
8039
  {
8040
  ggml_compute_forward_add1_q_f32(params, dst);
8041
  } break;
 
8160
  case GGML_TYPE_IQ3_XXS:
8161
  case GGML_TYPE_IQ1_S:
8162
  case GGML_TYPE_IQ4_NL:
8163
+ case GGML_TYPE_IQ3_S:
8164
  default:
8165
  {
8166
  GGML_ASSERT(false);
 
11059
  case GGML_TYPE_IQ3_XXS:
11060
  case GGML_TYPE_IQ1_S:
11061
  case GGML_TYPE_IQ4_NL:
11062
+ case GGML_TYPE_IQ3_S:
11063
  {
11064
  ggml_compute_forward_out_prod_q_f32(params, dst);
11065
  } break;
 
11248
  case GGML_TYPE_IQ3_XXS:
11249
  case GGML_TYPE_IQ1_S:
11250
  case GGML_TYPE_IQ4_NL:
11251
+ case GGML_TYPE_IQ3_S:
11252
  default:
11253
  {
11254
  GGML_ASSERT(false);
 
11451
  case GGML_TYPE_IQ3_XXS:
11452
  case GGML_TYPE_IQ1_S:
11453
  case GGML_TYPE_IQ4_NL:
11454
+ case GGML_TYPE_IQ3_S:
11455
  {
11456
  ggml_compute_forward_get_rows_q(params, dst);
11457
  } break;
 
12152
  case GGML_TYPE_IQ3_XXS:
12153
  case GGML_TYPE_IQ1_S:
12154
  case GGML_TYPE_IQ4_NL:
12155
+ case GGML_TYPE_IQ3_S:
12156
  case GGML_TYPE_Q8_K:
12157
  case GGML_TYPE_I8:
12158
  case GGML_TYPE_I16:
 
12236
  case GGML_TYPE_IQ3_XXS:
12237
  case GGML_TYPE_IQ1_S:
12238
  case GGML_TYPE_IQ4_NL:
12239
+ case GGML_TYPE_IQ3_S:
12240
  case GGML_TYPE_Q8_K:
12241
  case GGML_TYPE_I8:
12242
  case GGML_TYPE_I16:
 
19488
  case GGML_TYPE_IQ2_XS:
19489
  case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
19490
  case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
19491
+ case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
19492
  default: // nothing
19493
  break;
19494
  }
 
19763
  result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
19764
  GGML_ASSERT(result == row_size * nrows);
19765
  } break;
19766
+ case GGML_TYPE_IQ3_S:
19767
+ {
19768
+ GGML_ASSERT(start % QK_K == 0);
19769
+ GGML_ASSERT(start % n_per_row == 0);
19770
+ size_t start_row = start / n_per_row;
19771
+ size_t row_size = ggml_row_size(type, n_per_row);
19772
+ result = quantize_iq3_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
19773
+ GGML_ASSERT(result == row_size * nrows);
19774
+ } break;
19775
  case GGML_TYPE_IQ1_S:
19776
  {
19777
  GGML_ASSERT(start % QK_K == 0);
ggml.h CHANGED
@@ -350,6 +350,7 @@ extern "C" {
350
  GGML_TYPE_IQ3_XXS = 18,
351
  GGML_TYPE_IQ1_S = 19,
352
  GGML_TYPE_IQ4_NL = 20,
 
353
  GGML_TYPE_I8,
354
  GGML_TYPE_I16,
355
  GGML_TYPE_I32,
@@ -389,6 +390,7 @@ extern "C" {
389
  GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
390
  GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
391
  GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
 
392
  };
393
 
394
  // available tensor operations:
 
350
  GGML_TYPE_IQ3_XXS = 18,
351
  GGML_TYPE_IQ1_S = 19,
352
  GGML_TYPE_IQ4_NL = 20,
353
+ GGML_TYPE_IQ3_S = 21,
354
  GGML_TYPE_I8,
355
  GGML_TYPE_I16,
356
  GGML_TYPE_I32,
 
390
  GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
391
  GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
392
  GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
393
+ GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
394
  };
395
 
396
  // available tensor operations: