Kawrakow ikawrakow committed on
Commit
9c3aa6a
·
unverified ·
1 Parent(s): 26c019a

1.5 bit quantization (llama/5453)

Browse files

* iq1_s: WIP basics

* iq1_s: CUDA is working

* iq1_s: scalar CPU dot product

* iq1_s: WIP AVX2 dot product - something is not right

* Fix tests

* Fix shadow warnings

* Fix after merge with latest master

* iq1_s: AVX2 finally works

* iq1_s: ARM_NEON dot product. Works, but not very fast

* iq1_s: better grid

* iq1_s: use IQ2_XXS for attn_output

At a cost of 0.04 extra bpw this gives a big improvement in PPL.

* iq1_s: Metal basics

Dequantize works, but not dot product

* iq1_s: Metal works, but quite slow

As usual, Apple Silicon does not like the code I write.

* iq1_s: Tests

* iq1_s: slightly faster dot product

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

Files changed (8) hide show
  1. ggml-backend.c +1 -1
  2. ggml-cuda.cu +223 -1
  3. ggml-metal.m +27 -2
  4. ggml-metal.metal +337 -0
  5. ggml-quants.c +627 -30
  6. ggml-quants.h +12 -2
  7. ggml.c +39 -5
  8. ggml.h +2 -0
ggml-backend.c CHANGED
@@ -756,7 +756,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
756
  GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
757
  switch (op->op) {
758
  case GGML_OP_CPY:
759
- return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float
760
  case GGML_OP_MUL_MAT:
761
  return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
762
  default:
 
756
  GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
757
  switch (op->op) {
758
  case GGML_OP_CPY:
759
+ return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float
760
  case GGML_OP_MUL_MAT:
761
  return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
762
  default:
ggml-cuda.cu CHANGED
@@ -517,6 +517,15 @@ typedef struct {
517
  } block_iq3_xxs;
518
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
519
 
 
 
 
 
 
 
 
 
 
520
  #define WARP_SIZE 32
521
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
522
 
@@ -1681,6 +1690,137 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
1681
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
1682
  };
1683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1684
  static const __device__ uint8_t ksigns_iq2xs[128] = {
1685
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
1686
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -1823,6 +1963,29 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
1823
 
1824
  }
1825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1826
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1827
 
1828
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -4522,6 +4685,49 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
4522
  #endif
4523
  }
4524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4525
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
4526
  allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
4527
  static __device__ __forceinline__ void mul_mat_q(
@@ -6561,6 +6767,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
6561
  dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
6562
  }
6563
 
 
 
 
 
 
 
6564
  template <typename src_t, typename dst_t>
6565
  static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
6566
  const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6600,6 +6812,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
6600
  return dequantize_row_iq2_xs_cuda;
6601
  case GGML_TYPE_IQ3_XXS:
6602
  return dequantize_row_iq3_xxs_cuda;
 
 
6603
  case GGML_TYPE_F32:
6604
  return convert_unary_cuda<float>;
6605
  default:
@@ -6635,6 +6849,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
6635
  return dequantize_row_iq2_xs_cuda;
6636
  case GGML_TYPE_IQ3_XXS:
6637
  return dequantize_row_iq3_xxs_cuda;
 
 
6638
  case GGML_TYPE_F16:
6639
  return convert_unary_cuda<half>;
6640
  default:
@@ -8378,6 +8594,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
8378
  case GGML_TYPE_IQ2_XXS:
8379
  case GGML_TYPE_IQ2_XS:
8380
  case GGML_TYPE_IQ3_XXS:
 
8381
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
8382
  default:
8383
  GGML_ASSERT(false);
@@ -8401,6 +8618,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
8401
  case GGML_TYPE_IQ2_XXS:
8402
  case GGML_TYPE_IQ2_XS:
8403
  case GGML_TYPE_IQ3_XXS:
 
8404
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
8405
  case GGML_TYPE_Q6_K:
8406
  return 64;
@@ -8498,6 +8716,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
8498
  mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
8499
  (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8500
  break;
 
 
 
 
8501
  default:
8502
  GGML_ASSERT(false);
8503
  break;
@@ -11214,7 +11436,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
11214
  return false;
11215
  }
11216
  ggml_type a_type = a->type;
11217
- if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
11218
  if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
11219
  return false;
11220
  }
 
517
  } block_iq3_xxs;
518
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
519
 
520
// IQ1_S super-block description.
// QI1_S is the `qi` template argument handed to mul_mat_vec_q_cuda for this type;
// with QK_K == 256 it evaluates to 8.
#define QR1_S 8
#define QI1_S (QK_K / (4*QR1_S))
typedef struct {
    half    d;                // super-block scale, multiplied by the odd per-group factor 2*s+1
    uint8_t qs[QK_K/8];       // one byte per group of 8 weights: low 8 bits of the iq1s_grid index
    uint8_t scales[QK_K/16];  // one nibble per group of 8 weights: bits 0..2 = s, bit 3 = 9th bit of the grid index
} block_iq1_s;
// The struct must be tightly packed: 2 bytes of scale + QK_K/8 + QK_K/16 bytes of payload.
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
528
+
529
  #define WARP_SIZE 32
530
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
531
 
 
1690
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
1691
  };
1692
 
1693
+ static const __device__ uint64_t iq1s_grid[512] = {
1694
+ 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
1695
+ 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
1696
+ 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
1697
+ 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
1698
+ 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
1699
+ 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
1700
+ 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
1701
+ 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
1702
+ 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
1703
+ 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
1704
+ 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
1705
+ 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
1706
+ 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
1707
+ 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
1708
+ 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
1709
+ 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
1710
+ 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
1711
+ 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
1712
+ 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
1713
+ 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
1714
+ 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
1715
+ 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
1716
+ 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
1717
+ 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
1718
+ 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
1719
+ 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
1720
+ 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
1721
+ 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
1722
+ 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
1723
+ 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
1724
+ 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
1725
+ 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
1726
+ 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
1727
+ 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
1728
+ 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
1729
+ 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
1730
+ 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
1731
+ 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
1732
+ 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
1733
+ 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
1734
+ 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
1735
+ 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
1736
+ 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
1737
+ 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
1738
+ 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
1739
+ 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
1740
+ 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
1741
+ 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
1742
+ 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
1743
+ 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
1744
+ 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
1745
+ 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
1746
+ 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
1747
+ 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
1748
+ 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
1749
+ 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
1750
+ 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
1751
+ 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
1752
+ 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
1753
+ 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
1754
+ 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
1755
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
1756
+ 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
1757
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
1758
+ 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
1759
+ 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
1760
+ 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
1761
+ 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
1762
+ 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
1763
+ 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
1764
+ 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
1765
+ 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
1766
+ 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
1767
+ 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
1768
+ 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
1769
+ 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
1770
+ 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
1771
+ 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
1772
+ 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
1773
+ 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
1774
+ 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
1775
+ 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
1776
+ 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
1777
+ 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
1778
+ 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
1779
+ 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
1780
+ 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
1781
+ 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
1782
+ 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
1783
+ 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
1784
+ 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
1785
+ 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
1786
+ 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
1787
+ 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
1788
+ 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
1789
+ 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
1790
+ 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
1791
+ 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
1792
+ 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
1793
+ 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
1794
+ 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
1795
+ 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
1796
+ 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
1797
+ 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
1798
+ 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
1799
+ 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
1800
+ 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
1801
+ 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
1802
+ 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
1803
+ 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
1804
+ 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
1805
+ 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
1806
+ 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
1807
+ 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
1808
+ 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
1809
+ 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
1810
+ 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
1811
+ 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
1812
+ 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
1813
+ 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
1814
+ 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
1815
+ 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
1816
+ 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
1817
+ 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
1818
+ 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
1819
+ 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
1820
+ 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
1821
+ 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
1822
+ };
1823
+
1824
  static const __device__ uint8_t ksigns_iq2xs[128] = {
1825
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
1826
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
 
1963
 
1964
  }
1965
 
1966
// Dequantizes IQ1_S super-blocks into dst_t values.
//
// Launch layout (see dequantize_row_iq1_s_cuda): one CUDA block per super-block
// (blockIdx.x selects the block_iq1_s), 32 threads per block. Each thread expands
// one group of 8 weights, so the 32 threads together write QK_K == 256 outputs.
//
//   vx : array of block_iq1_s, indexed by blockIdx.x
//   yy : output buffer, QK_K elements per super-block
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    // 64-bit on purpose: i*QK_K below would overflow a 32-bit int once the
    // tensor holds 2^31 or more elements.
    const int64_t i = blockIdx.x;
    const block_iq1_s * x = (const block_iq1_s *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const int i8 = 4*ib+il; // index of the group of 8 weights handled by this thread
    // Each scales byte packs two nibbles; pick the one that belongs to group i8.
    uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
    // 9-bit grid index: 8 low bits from qs, 9th bit from bit 3 of the nibble.
    const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
    // Bits 0..2 of the nibble encode an odd multiplier 2*s+1 applied to the super-block scale.
    const float d = (float)x[i].d * (2*(h & 7) + 1);
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
#else
    assert(false);
#endif

}
1987
+
1988
+
1989
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1990
 
1991
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
4685
  #endif
4686
  }
4687
 
4688
// Dot product of one 32-weight slice of an IQ1_S super-block with the matching
// q8_1 block.
//
//   vbq   : pointer to the block_iq1_s super-block
//   bq8_1 : q8_1 blocks; bq8_1[iqs] is the one paired with slice iqs
//   iqs   : index of the 32-weight slice inside the super-block
//
// Returns the float partial sum for this slice. Only implemented for QK_K == 256.
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq1_s * blk = (const block_iq1_s *) vbq;

    const int ib32 = iqs;

    // Two scale bytes cover the four groups of 8 weights in this slice (one nibble each).
    const uint8_t h1 = blk->scales[2*ib32+0];
    const uint8_t h2 = blk->scales[2*ib32+1];

    // 9-bit grid indices: low 8 bits from qs, 9th bit from bit 3 of the group's nibble.
    const uint8_t  * qs = blk->qs + 4*ib32;
    const uint64_t * g1 = iq1s_grid + (qs[0] | ((h1 & 0x08) << 5));
    const uint64_t * g2 = iq1s_grid + (qs[1] | ((h1 & 0x80) << 1));
    const uint64_t * g3 = iq1s_grid + (qs[2] | ((h2 & 0x08) << 5));
    const uint64_t * g4 = iq1s_grid + (qs[3] | ((h2 & 0x80) << 1));

    int acc1 = 0;
    int acc2 = 0;
    int acc3 = 0;
    int acc4 = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    // Four int8 products per __dp4a; two calls consume the 8 weights of a group.
    const int * q8 = (const int *)bq8_1[ib32].qs;
    const int * w1 = (const int *)g1;
    const int * w2 = (const int *)g2;
    const int * w3 = (const int *)g3;
    const int * w4 = (const int *)g4;
    for (int j = 0; j < 2; ++j) {
        acc1 = __dp4a(q8[j+0], w1[j], acc1);
        acc2 = __dp4a(q8[j+2], w2[j], acc2);
        acc3 = __dp4a(q8[j+4], w3[j], acc3);
        acc4 = __dp4a(q8[j+6], w4[j], acc4);
    }
#else
    // Scalar fallback: same sums, one int8 product at a time.
    const int8_t * q8 = bq8_1[ib32].qs;
    const int8_t * w1 = (const int8_t *)g1;
    const int8_t * w2 = (const int8_t *)g2;
    const int8_t * w3 = (const int8_t *)g3;
    const int8_t * w4 = (const int8_t *)g4;
    for (int j = 0; j < 8; ++j) {
        acc1 += q8[j+ 0] * w1[j];
        acc2 += q8[j+ 8] * w2[j];
        acc3 += q8[j+16] * w3[j];
        acc4 += q8[j+24] * w4[j];
    }
#endif
    // Combined scale: super-block scale times the low half of the q8_1 `ds` pair.
    const float d = (float)blk->d * __low2float(bq8_1[ib32].ds);
    // Each group sum is weighted by its odd multiplier 2*s+1 (s = bits 0..2 of its nibble).
    return d * (acc1 * (2*(h1 & 7) + 1) + acc2 * (2*((h1 >> 4) & 7) + 1) +
                acc3 * (2*(h2 & 7) + 1) + acc4 * (2*((h2 >> 4) & 7) + 1));
#else
    assert(false);
    return 0.f;
#endif
}
4730
+
4731
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
4732
  allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
4733
  static __device__ __forceinline__ void mul_mat_q(
 
6767
  dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
6768
  }
6769
 
6770
// Host-side launcher for dequantize_block_iq1_s.
//   vx     : IQ1_S data (array of block_iq1_s)
//   y      : output buffer for k dequantized values
//   k      : number of values; k / QK_K super-blocks are processed
//   stream : CUDA stream the kernel is enqueued on
// One CUDA block of 32 threads is launched per super-block.
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int n_superblocks = k / QK_K;
    dequantize_block_iq1_s<<<n_superblocks, 32, 0, stream>>>(vx, y);
}
6775
+
6776
  template <typename src_t, typename dst_t>
6777
  static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
6778
  const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 
6812
  return dequantize_row_iq2_xs_cuda;
6813
  case GGML_TYPE_IQ3_XXS:
6814
  return dequantize_row_iq3_xxs_cuda;
6815
+ case GGML_TYPE_IQ1_S:
6816
+ return dequantize_row_iq1_s_cuda;
6817
  case GGML_TYPE_F32:
6818
  return convert_unary_cuda<float>;
6819
  default:
 
6849
  return dequantize_row_iq2_xs_cuda;
6850
  case GGML_TYPE_IQ3_XXS:
6851
  return dequantize_row_iq3_xxs_cuda;
6852
+ case GGML_TYPE_IQ1_S:
6853
+ return dequantize_row_iq1_s_cuda;
6854
  case GGML_TYPE_F16:
6855
  return convert_unary_cuda<half>;
6856
  default:
 
8594
  case GGML_TYPE_IQ2_XXS:
8595
  case GGML_TYPE_IQ2_XS:
8596
  case GGML_TYPE_IQ3_XXS:
8597
+ case GGML_TYPE_IQ1_S:
8598
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
8599
  default:
8600
  GGML_ASSERT(false);
 
8618
  case GGML_TYPE_IQ2_XXS:
8619
  case GGML_TYPE_IQ2_XS:
8620
  case GGML_TYPE_IQ3_XXS:
8621
+ case GGML_TYPE_IQ1_S:
8622
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
8623
  case GGML_TYPE_Q6_K:
8624
  return 64;
 
8716
  mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
8717
  (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8718
  break;
8719
+ case GGML_TYPE_IQ1_S:
8720
+ mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
8721
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8722
+ break;
8723
  default:
8724
  GGML_ASSERT(false);
8725
  break;
 
11436
  return false;
11437
  }
11438
  ggml_type a_type = a->type;
11439
+ if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) {
11440
  if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
11441
  return false;
11442
  }
ggml-metal.m CHANGED
@@ -61,6 +61,7 @@ enum ggml_metal_kernel_type {
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
 
64
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
65
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
66
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -83,6 +84,7 @@ enum ggml_metal_kernel_type {
83
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
84
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
85
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
 
86
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
87
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
88
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +103,7 @@ enum ggml_metal_kernel_type {
101
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
102
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
103
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
 
104
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
105
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
106
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +119,7 @@ enum ggml_metal_kernel_type {
116
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
117
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
118
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
 
119
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
120
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
121
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +135,7 @@ enum ggml_metal_kernel_type {
131
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
132
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
133
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
 
134
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
135
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
136
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -442,6 +447,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
442
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
443
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
444
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
 
445
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
446
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
447
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -464,6 +470,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
464
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
465
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
466
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
 
467
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
468
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
469
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -482,6 +489,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
482
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
483
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
484
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
 
485
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
486
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
487
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -497,6 +505,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
 
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
502
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -512,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
512
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
513
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
514
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
 
515
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
516
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
517
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -1327,6 +1337,7 @@ static bool ggml_metal_graph_compute(
1327
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1328
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1329
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
 
1330
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1331
  }
1332
 
@@ -1461,6 +1472,12 @@ static bool ggml_metal_graph_compute(
1461
  nth1 = 16;
1462
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1463
  } break;
 
 
 
 
 
 
1464
  default:
1465
  {
1466
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1495,7 +1512,7 @@ static bool ggml_metal_graph_compute(
1495
 
1496
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1497
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1498
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1499
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1500
  }
1501
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1601,6 +1618,7 @@ static bool ggml_metal_graph_compute(
1601
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1602
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1603
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
 
1604
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1605
  }
1606
 
@@ -1738,6 +1756,12 @@ static bool ggml_metal_graph_compute(
1738
  nth1 = 16;
1739
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1740
  } break;
 
 
 
 
 
 
1741
  default:
1742
  {
1743
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1788,7 +1812,7 @@ static bool ggml_metal_graph_compute(
1788
 
1789
  if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1790
  src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1791
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1792
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1793
  }
1794
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1842,6 +1866,7 @@ static bool ggml_metal_graph_compute(
1842
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1843
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1844
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
 
1845
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1846
  default: GGML_ASSERT(false && "not implemented");
1847
  }
 
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
64
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
65
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
66
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
67
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
 
84
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
85
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
86
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
87
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
88
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
89
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
90
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
 
103
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
104
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
105
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
106
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
107
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
108
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
109
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
 
119
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
120
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
121
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
122
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
123
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
124
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
125
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
 
135
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
136
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
137
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
138
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
139
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
140
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
141
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
 
447
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
448
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
449
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
450
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
451
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
452
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
453
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
 
470
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
471
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
472
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
473
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
474
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
475
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
 
489
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
490
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
492
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
493
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
494
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
 
505
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
506
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
508
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
509
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
510
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
511
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
 
521
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
522
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
523
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
524
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
526
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
 
1337
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1338
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1339
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1340
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1341
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1342
  }
1343
 
 
1472
  nth1 = 16;
1473
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1474
  } break;
1475
+ case GGML_TYPE_IQ1_S:
1476
+ {
1477
+ nth0 = 4;
1478
+ nth1 = 16;
1479
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1480
+ } break;
1481
  default:
1482
  {
1483
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
 
1512
 
1513
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1514
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1515
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
1516
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1517
  }
1518
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
 
1618
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1619
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1620
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1621
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1622
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1623
  }
1624
 
 
1756
  nth1 = 16;
1757
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1758
  } break;
1759
+ case GGML_TYPE_IQ1_S:
1760
+ {
1761
+ nth0 = 4;
1762
+ nth1 = 16;
1763
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
1764
+ } break;
1765
  default:
1766
  {
1767
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
 
1812
 
1813
  if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1814
  src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1815
+ src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
1816
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1817
  }
1818
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
 
1866
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1867
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1868
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
1869
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
1870
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1871
  default: GGML_ASSERT(false && "not implemented");
1872
  }
ggml-metal.metal CHANGED
@@ -2525,6 +2525,13 @@ typedef struct {
2525
  } block_iq3_xxs;
2526
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2527
 
 
 
 
 
 
 
 
2528
  //====================================== dot products =========================
2529
 
2530
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3782,6 +3789,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
3782
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3783
  };
3784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3785
 
3786
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3787
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
@@ -4208,6 +4346,123 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4208
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4209
  }
4210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4211
 
4212
  //============================= templates and their specializations =============================
4213
 
@@ -4553,6 +4808,22 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
4553
  }
4554
  }
4555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4556
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4557
  kernel void kernel_get_rows(
4558
  device const void * src0,
@@ -5095,6 +5366,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
5095
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5096
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5097
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
5098
 
5099
  //
5100
  // matrix-matrix multiplication
@@ -5134,6 +5406,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
5134
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5135
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5136
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
5137
 
5138
  //
5139
  // indirect matrix-matrix multiplication
@@ -5185,6 +5458,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
5185
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5186
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5187
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 
5188
 
5189
  //
5190
  // matrix-vector multiplication
@@ -6152,3 +6426,66 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6152
  tiisg,
6153
  sgitg);
6154
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2525
  } block_iq3_xxs;
2526
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2527
 
2528
// IQ1_S quantization block holding QK_K weights, as read by dequantize_iq1_s and
// kernel_mul_mv_iq1_s_f32_impl in this file.
+ typedef struct {
2529
// block-wide scale; multiplied into every dequantized value of the block
+ half d;
2530
// low 8 bits of each 9-bit iq1s_grid index; one grid entry supplies 8 weights
+ uint8_t qs[QK_K/8];
2531
// one byte per 16 weights: bits 0-2 and bits 4-6 are two 3-bit sub-scales (applied as 2*s + 1),
// bit 3 and bit 7 are the 9th (high) grid-index bit for the matching pair of qs bytes
+ uint8_t scales[QK_K/16];
2532
+ } block_iq1_s;
// 50 bytes / block for QK_K = 256, so 1.5625 bpw
2533
+
2534
+
2535
  //====================================== dot products =========================
2536
 
2537
  void kernel_mul_mv_q2_K_f32_impl(
 
3789
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3790
  };
3791
 
3792
+ #define NGRID_IQ1S 512
3793
+ constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
3794
+ 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
3795
+ 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
3796
+ 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
3797
+ 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
3798
+ 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
3799
+ 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
3800
+ 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
3801
+ 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
3802
+ 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
3803
+ 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
3804
+ 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
3805
+ 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
3806
+ 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
3807
+ 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
3808
+ 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
3809
+ 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
3810
+ 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
3811
+ 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
3812
+ 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
3813
+ 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
3814
+ 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
3815
+ 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
3816
+ 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
3817
+ 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
3818
+ 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
3819
+ 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
3820
+ 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
3821
+ 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
3822
+ 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
3823
+ 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
3824
+ 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
3825
+ 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
3826
+ 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
3827
+ 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
3828
+ 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
3829
+ 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
3830
+ 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
3831
+ 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
3832
+ 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
3833
+ 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
3834
+ 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
3835
+ 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
3836
+ 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
3837
+ 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
3838
+ 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
3839
+ 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
3840
+ 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
3841
+ 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
3842
+ 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
3843
+ 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
3844
+ 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
3845
+ 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
3846
+ 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
3847
+ 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
3848
+ 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
3849
+ 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
3850
+ 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
3851
+ 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
3852
+ 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
3853
+ 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
3854
+ 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
3855
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
3856
+ 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
3857
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
3858
+ 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
3859
+ 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
3860
+ 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
3861
+ 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
3862
+ 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
3863
+ 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
3864
+ 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
3865
+ 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
3866
+ 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
3867
+ 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
3868
+ 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
3869
+ 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
3870
+ 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
3871
+ 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
3872
+ 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
3873
+ 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
3874
+ 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
3875
+ 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
3876
+ 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
3877
+ 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
3878
+ 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
3879
+ 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
3880
+ 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
3881
+ 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
3882
+ 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
3883
+ 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
3884
+ 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
3885
+ 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
3886
+ 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
3887
+ 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
3888
+ 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
3889
+ 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
3890
+ 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
3891
+ 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
3892
+ 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
3893
+ 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
3894
+ 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
3895
+ 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
3896
+ 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
3897
+ 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
3898
+ 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
3899
+ 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
3900
+ 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
3901
+ 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
3902
+ 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
3903
+ 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
3904
+ 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
3905
+ 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
3906
+ 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
3907
+ 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
3908
+ 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
3909
+ 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
3910
+ 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
3911
+ 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
3912
+ 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
3913
+ 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
3914
+ 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
3915
+ 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
3916
+ 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
3917
+ 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
3918
+ 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
3919
+ 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
3920
+ 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
3921
+ 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
3922
+ };
3923
 
3924
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3925
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
 
4346
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4347
  }
4348
 
4349
// Matrix-vector product of an IQ1_S-quantized src0 with a float src1.
// Each simdgroup accumulates N_DST consecutive src0 rows starting at first_row; every lane
// handles 16 weights of a 32-weight sub-block per iteration, the per-lane partial sums are
// combined with simd_sum() and lane 0 writes one float per row into dst.
// Only the QK_K == 256 path is implemented: in any other configuration nothing is accumulated
// and the zero-initialized sums are written out.
+ void kernel_mul_mv_iq1_s_f32_impl(
4350
+ device const void * src0,
4351
+ device const float * src1,
4352
+ device float * dst,
4353
+ constant int64_t & ne00,
4354
+ constant int64_t & ne01,
4355
+ constant int64_t & ne02,
4356
+ constant int64_t & ne10,
4357
+ constant int64_t & ne12,
4358
+ constant int64_t & ne0,
4359
+ constant int64_t & ne1,
4360
+ constant uint & r2,
4361
+ constant uint & r3,
4362
+ uint3 tgpig[[threadgroup_position_in_grid]],
4363
+ uint tiisg[[thread_index_in_simdgroup]],
4364
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4365
+
4366
// number of IQ1_S blocks per src0 row
+ const int nb = ne00/QK_K;
4367
+ const int r0 = tgpig.x;
4368
+ const int r1 = tgpig.y;
4369
+ const int im = tgpig.z;
4370
+
4371
// first of the N_DST rows owned by this simdgroup, and its offset in blocks
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4372
+ const int ib_row = first_row * nb;
4373
+
4374
+ const uint i12 = im%ne12;
4375
+ const uint i13 = im/ne12;
4376
+
4377
// block offset of the src0 matrix used for this batch index; i12 and i13 are divided by r2 / r3
// (presumably the src1-to-src0 broadcast ratios -- confirm against the host code)
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4378
+
4379
+ device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
4380
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4381
+
4382
// yl caches the 16 src1 values this lane multiplies against in the current iteration
+ float yl[16];
4383
+ float sumf[N_DST]={0.f}, all_sum;
4384
+
4385
// number of 32-weight sub-blocks per row
+ const int nb32 = nb * (QK_K / 32);
4386
+
4387
+ #if QK_K == 256
4388
// two lanes share one 32-weight sub-block: ix picks the sub-block, il picks its 16-weight half
+ const int ix = tiisg/2;
4389
+ const int il = tiisg%2;
4390
+
4391
+ device const float * y4 = y + 32 * ix + 16 * il;
4392
+
4393
// stride of 16 sub-blocks per iteration matches the 16 distinct ix values (tiisg/2)
// when tiisg spans 0..31 -- presumably a 32-lane simdgroup; confirm against the host dispatch
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
4394
+
4395
+ for (int i = 0; i < 16; ++i) {
4396
+ yl[i] = y4[i];
4397
+ }
4398
+
4399
// ibl: block index within the row, ib: 32-weight sub-block index within that block
+ const int ibl = ib32 / (QK_K / 32);
4400
+ const int ib = ib32 % (QK_K / 32);
4401
+
4402
+ device const block_iq1_s * xr = x + ibl;
4403
// each 32-weight sub-block owns 4 qs bytes and 2 scale bytes; this lane uses 2 qs bytes and 1 scale byte
+ device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
4404
+ device const uint8_t * sc = xr->scales + 2 * ib + il;
4405
+ device const half * dh = &xr->d;
4406
+
4407
+ for (int row = 0; row < N_DST; row++) {
4408
+
4409
// 9-bit grid index: 8 low bits from qs, high bit from bit 3 (first group) or bit 7 (second group) of the scale byte;
// the selected 64-bit grid entry is read as 8 signed bytes
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
4410
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4411
+
4412
+ float2 sum = {0};
4413
+ for (int j = 0; j < 8; ++j) {
4414
+ sum[0] += yl[j+ 0] * grid1[j];
4415
+ sum[1] += yl[j+ 8] * grid2[j];
4416
+ }
4417
// apply the block scale and the two 3-bit sub-scales (each mapped to the odd value 2*s + 1)
+ sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
4418
+
4419
// step to the same position in the next row: rows are nb blocks apart
// (dh is a half pointer, hence the division by 2; qs and sc are byte pointers)
+ dh += nb*sizeof(block_iq1_s)/2;
4420
+ qs += nb*sizeof(block_iq1_s);
4421
+ sc += nb*sizeof(block_iq1_s);
4422
+ }
4423
+
4424
// advance src1 by the 16 sub-blocks of 32 values consumed by the simdgroup in this iteration
+ y4 += 16 * 32;
4425
+ }
4426
+ #else
4427
+ // TODO
4428
+ #endif
4429
+
4430
// NOTE(review): no guard for first_row + row >= ne01 is visible here -- presumably the host
// only dispatches full row groups; confirm
+ for (int row = 0; row < N_DST; ++row) {
4431
+ all_sum = simd_sum(sumf[row]);
4432
+ if (tiisg == 0) {
4433
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4434
+ }
4435
+ }
4436
+ }
4437
+
4438
// Kernel entry point exported under the host name "kernel_mul_mv_iq1_s_f32".
// It only forwards to kernel_mul_mv_iq1_s_f32_impl; the nb00..nb02, nb10..nb12 and ne11
// arguments are accepted but not passed on.
+ [[host_name("kernel_mul_mv_iq1_s_f32")]]
4439
+ kernel void kernel_mul_mv_iq1_s_f32(
4440
+ device const void * src0,
4441
+ device const float * src1,
4442
+ device float * dst,
4443
+ constant int64_t & ne00,
4444
+ constant int64_t & ne01,
4445
+ constant int64_t & ne02,
4446
+ constant uint64_t & nb00,
4447
+ constant uint64_t & nb01,
4448
+ constant uint64_t & nb02,
4449
+ constant int64_t & ne10,
4450
+ constant int64_t & ne11,
4451
+ constant int64_t & ne12,
4452
+ constant uint64_t & nb10,
4453
+ constant uint64_t & nb11,
4454
+ constant uint64_t & nb12,
4455
+ constant int64_t & ne0,
4456
+ constant int64_t & ne1,
4457
+ constant uint & r2,
4458
+ constant uint & r3,
4459
+ uint3 tgpig[[threadgroup_position_in_grid]],
4460
+ uint tiisg[[thread_index_in_simdgroup]],
4461
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4462
+
4463
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4464
+ }
4465
+
4466
 
4467
  //============================= templates and their specializations =============================
4468
 
 
4808
  }
4809
  }
4810
 
4811
// Dequantizes 16 weights of an IQ1_S block into reg.
// il selects the 16-weight group inside the block: it uses qs[2*il], qs[2*il + 1] and scales[il].
// reg[0..1] receive the 8 weights of the first grid entry, reg[2..3] those of the second.
+ template <typename type4x4>
4812
+ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
4813
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4814
+ const float d = xb->d;
4815
+ device const uint8_t * qs = xb->qs + 2*il;
4816
+ device const uint8_t * sc = xb->scales + il;
4817
// effective scales: block scale times the odd multiplier 2*s + 1 built from each 3-bit field of the scale byte
+ const float dl1 = d * (2*(sc[0] & 7) + 1);
4818
+ const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
4819
// 9-bit grid index: low 8 bits from qs, high bit from bit 3 / bit 7 of the scale byte;
// the selected 64-bit grid entry is read as 8 signed bytes
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
4820
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4821
+ for (int i = 0; i < 8; ++i) {
4822
+ reg[i/4+0][i%4] = dl1 * grid1[i];
4823
+ reg[i/4+2][i%4] = dl2 * grid2[i];
4824
+ }
4825
+ }
4826
+
4827
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4828
  kernel void kernel_get_rows(
4829
  device const void * src0,
 
5366
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5367
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5368
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5369
+ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5370
 
5371
  //
5372
  // matrix-matrix multiplication
 
5406
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5407
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5408
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5409
+ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5410
 
5411
  //
5412
  // indirect matrix-matrix multiplication
 
5458
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5459
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5460
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5461
+ template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
5462
 
5463
  //
5464
  // matrix-vector multiplication
 
6426
  tiisg,
6427
  sgitg);
6428
  }
6429
+
6430
// Indirect variant of the IQ1_S matrix-vector kernel, exported as "kernel_mul_mv_id_iq1_s_f32".
// An int32 id read from ids selects which of the eight src0 pointers (src00..src07) is used;
// the work is then delegated to kernel_mul_mv_iq1_s_f32_impl.
// tiitg and the nb00..nb02, nb10, nb12, nb1, ne11 arguments are accepted but not used in this body.
+ [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6431
+ kernel void kernel_mul_mv_id_iq1_s_f32(
6432
+ device const char * ids,
6433
+ device const char * src1,
6434
+ device float * dst,
6435
+ constant uint64_t & nbi1,
6436
+ constant int64_t & ne00,
6437
+ constant int64_t & ne01,
6438
+ constant int64_t & ne02,
6439
+ constant uint64_t & nb00,
6440
+ constant uint64_t & nb01,
6441
+ constant uint64_t & nb02,
6442
+ constant int64_t & ne10,
6443
+ constant int64_t & ne11,
6444
+ constant int64_t & ne12,
6445
+ constant int64_t & ne13,
6446
+ constant uint64_t & nb10,
6447
+ constant uint64_t & nb11,
6448
+ constant uint64_t & nb12,
6449
+ constant int64_t & ne0,
6450
+ constant int64_t & ne1,
6451
+ constant uint64_t & nb1,
6452
+ constant uint & r2,
6453
+ constant uint & r3,
6454
+ constant int & idx,
6455
+ device const char * src00,
6456
+ device const char * src01,
6457
+ device const char * src02,
6458
+ device const char * src03,
6459
+ device const char * src04,
6460
+ device const char * src05,
6461
+ device const char * src06,
6462
+ device const char * src07,
6463
+ uint3 tgpig[[threadgroup_position_in_grid]],
6464
+ uint tiitg[[thread_index_in_threadgroup]],
6465
+ uint tiisg[[thread_index_in_simdgroup]],
6466
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6467
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6468
+
6469
// the z grid coordinate carries bid on top of the ne12*ne13 batch index; split the two apart
+ const int64_t bid = tgpig.z/(ne12*ne13);
6470
+
6471
+ tgpig.z = tgpig.z%(ne12*ne13);
6472
+
6473
// ids is addressed in bytes: row bid starts at bid*nbi1, entry idx of that row is the matrix id
// NOTE(review): id is used unchecked as an index into the 8-entry src0 array
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6474
+
6475
+ kernel_mul_mv_iq1_s_f32_impl(
6476
+ src0[id],
6477
// src1 is advanced by bid*nb11 bytes, dst by bid*ne0 floats
+ (device const float *) (src1 + bid*nb11),
6478
+ dst + bid*ne0,
6479
+ ne00,
6480
+ ne01,
6481
+ ne02,
6482
+ ne10,
6483
+ ne12,
6484
+ ne0,
6485
+ ne1,
6486
+ r2,
6487
+ r3,
6488
+ tgpig,
6489
+ tiisg,
6490
+ sgitg);
6491
+ }
ggml-quants.c CHANGED
@@ -3480,6 +3480,139 @@ static const uint32_t iq3xxs_grid[256] = {
3480
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3481
  };
3482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3483
  static const uint8_t ksigns_iq2xs[128] = {
3484
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3485
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -3578,6 +3711,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
3578
  }
3579
  }
3580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3581
  //===================================== Q8_K ==============================================
3582
 
3583
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -3679,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) {
3679
  }
3680
  #endif
3681
 
3682
- void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
3683
  const int qk = QK8_0;
3684
  const int nb = n / qk;
3685
 
@@ -3690,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
3690
  assert(nrc == 1);
3691
  #endif
3692
  UNUSED(nrc);
3693
- UNUSED(bx);
3694
- UNUSED(by);
3695
  UNUSED(bs);
3696
 
3697
  const block_q4_0 * restrict x = vx;
@@ -4046,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4046
  #endif
4047
  }
4048
 
4049
- void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4050
  const int qk = QK8_1;
4051
  const int nb = n / qk;
4052
 
@@ -4057,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
4057
  assert(nrc == 1);
4058
  #endif
4059
  UNUSED(nrc);
4060
- UNUSED(bx);
4061
- UNUSED(by);
4062
  UNUSED(bs);
4063
 
4064
  const block_q4_1 * restrict x = vx;
@@ -4264,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
4264
  #endif
4265
  }
4266
 
4267
- void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4268
  const int qk = QK8_0;
4269
  const int nb = n / qk;
4270
 
@@ -4272,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4272
  assert(qk == QK5_0);
4273
  assert(nrc == 1);
4274
  UNUSED(nrc);
4275
- UNUSED(bx);
4276
- UNUSED(by);
4277
  UNUSED(bs);
4278
 
4279
  const block_q5_0 * restrict x = vx;
@@ -4555,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4555
  #endif
4556
  }
4557
 
4558
- void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4559
  const int qk = QK8_1;
4560
  const int nb = n / qk;
4561
 
@@ -4563,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
4563
  assert(qk == QK5_1);
4564
  assert(nrc == 1);
4565
  UNUSED(nrc);
4566
- UNUSED(bx);
4567
- UNUSED(by);
4568
  UNUSED(bs);
4569
 
4570
  const block_q5_1 * restrict x = vx;
@@ -4859,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
4859
  #endif
4860
  }
4861
 
4862
- void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4863
  const int qk = QK8_0;
4864
  const int nb = n / qk;
4865
 
@@ -4870,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4870
  assert(nrc == 1);
4871
  #endif
4872
  UNUSED(nrc);
4873
- UNUSED(bx);
4874
- UNUSED(by);
4875
  UNUSED(bs);
4876
 
4877
  const block_q8_0 * restrict x = vx;
@@ -9107,6 +9283,178 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
9107
  #endif
9108
  }
9109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9110
  // ================================ IQ2 quantization =============================================
9111
 
9112
  typedef struct {
@@ -9115,14 +9463,22 @@ typedef struct {
9115
  uint16_t * neighbours;
9116
  } iq2_entry_t;
9117
 
9118
- static iq2_entry_t iq2_data[2] = {
 
9119
  {NULL, NULL, NULL},
9120
  {NULL, NULL, NULL},
9121
  };
9122
 
9123
- static inline int iq2_data_index(int grid_size) {
9124
- GGML_ASSERT(grid_size == 256 || grid_size == 512);
9125
- return grid_size == 256 ? 0 : 1;
 
 
 
 
 
 
 
9126
  }
9127
 
9128
  static int iq2_compare_func(const void * left, const void * right) {
@@ -9131,12 +9487,13 @@ static int iq2_compare_func(const void * left, const void * right) {
9131
  return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
9132
  }
9133
 
9134
- void iq2xs_init_impl(int grid_size) {
9135
- const int gindex = iq2_data_index(grid_size);
 
9136
  if (iq2_data[gindex].grid) {
9137
  return;
9138
  }
9139
- static const uint16_t kgrid_256[256] = {
9140
  0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
9141
  100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
9142
  1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
@@ -9154,7 +9511,7 @@ void iq2xs_init_impl(int grid_size) {
9154
  33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
9155
  37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
9156
  };
9157
- static const uint16_t kgrid_512[512] = {
9158
  0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
9159
  73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
9160
  260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
@@ -9188,9 +9545,45 @@ void iq2xs_init_impl(int grid_size) {
9188
  40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
9189
  42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
9190
  };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9191
  const int kmap_size = 43692;
9192
- const int nwant = 2;
9193
- const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
 
9194
  uint64_t * kgrid_q2xs;
9195
  int * kmap_q2xs;
9196
  uint16_t * kneighbors_q2xs;
@@ -9286,9 +9679,9 @@ void iq2xs_init_impl(int grid_size) {
9286
  free(dist2);
9287
  }
9288
 
9289
- void iq2xs_free_impl(int grid_size) {
9290
- GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
9291
- const int gindex = iq2_data_index(grid_size);
9292
  if (iq2_data[gindex].grid) {
9293
  free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
9294
  free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
@@ -9322,7 +9715,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
9322
 
9323
  static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9324
 
9325
- const int gindex = iq2_data_index(256);
9326
 
9327
  const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
9328
  const int * kmap_q2xs = iq2_data[gindex].map;
@@ -9495,7 +9888,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9495
 
9496
  static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9497
 
9498
- const int gindex = iq2_data_index(512);
9499
 
9500
  const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
9501
  const int * kmap_q2xs = iq2_data[gindex].map;
@@ -10132,3 +10525,207 @@ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * re
10132
  assert(k % QK_K == 0);
10133
  quantize_row_iq3_xxs_impl(x, y, k, NULL);
10134
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3480
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3481
  };
3482
 
3483
+ #define NGRID_IQ2XXS 512
3484
+ static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
3485
+ 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
3486
+ 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
3487
+ 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
3488
+ 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
3489
+ 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
3490
+ 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
3491
+ 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
3492
+ 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
3493
+ 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
3494
+ 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
3495
+ 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
3496
+ 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
3497
+ 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
3498
+ 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
3499
+ 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
3500
+ 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
3501
+ 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
3502
+ 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
3503
+ 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
3504
+ 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
3505
+ 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
3506
+ 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
3507
+ 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
3508
+ 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
3509
+ 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
3510
+ 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
3511
+ 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
3512
+ 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
3513
+ 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
3514
+ 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
3515
+ 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
3516
+ 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
3517
+ 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
3518
+ 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
3519
+ 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
3520
+ 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
3521
+ 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
3522
+ 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
3523
+ 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
3524
+ 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
3525
+ 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
3526
+ 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
3527
+ 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
3528
+ 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
3529
+ 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
3530
+ 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
3531
+ 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
3532
+ 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
3533
+ 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
3534
+ 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
3535
+ 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
3536
+ 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
3537
+ 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
3538
+ 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
3539
+ 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
3540
+ 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
3541
+ 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
3542
+ 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
3543
+ 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
3544
+ 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
3545
+ 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
3546
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
3547
+ 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
3548
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
3549
+ 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
3550
+ 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
3551
+ 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
3552
+ 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
3553
+ 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
3554
+ 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
3555
+ 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
3556
+ 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
3557
+ 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
3558
+ 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
3559
+ 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
3560
+ 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
3561
+ 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
3562
+ 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
3563
+ 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
3564
+ 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
3565
+ 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
3566
+ 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
3567
+ 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
3568
+ 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
3569
+ 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
3570
+ 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
3571
+ 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
3572
+ 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
3573
+ 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
3574
+ 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
3575
+ 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
3576
+ 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
3577
+ 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
3578
+ 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
3579
+ 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
3580
+ 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
3581
+ 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
3582
+ 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
3583
+ 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
3584
+ 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
3585
+ 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
3586
+ 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
3587
+ 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
3588
+ 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
3589
+ 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
3590
+ 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
3591
+ 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
3592
+ 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
3593
+ 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
3594
+ 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
3595
+ 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
3596
+ 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
3597
+ 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
3598
+ 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
3599
+ 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
3600
+ 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
3601
+ 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
3602
+ 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
3603
+ 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
3604
+ 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
3605
+ 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
3606
+ 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
3607
+ 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
3608
+ 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
3609
+ 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
3610
+ 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
3611
+ 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
3612
+ 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
3613
+
3614
+ };
3615
+
3616
  static const uint8_t ksigns_iq2xs[128] = {
3617
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3618
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
 
3711
  }
3712
  }
3713
 
3714
+ // ====================== 1.5625 bpw (de)-quantization
3715
+
3716
+ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
3717
+ assert(k % QK_K == 0);
3718
+ const int nb = k / QK_K;
3719
+
3720
+ float db[4];
3721
+ uint16_t idx[4];
3722
+ //const int8_t * grid[4];
3723
+
3724
+ for (int i = 0; i < nb; i++) {
3725
+
3726
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3727
+ const uint8_t * sc = x[i].scales;
3728
+ const uint8_t * qs = x[i].qs;
3729
+
3730
+ for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
3731
+ idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
3732
+ idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
3733
+ idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
3734
+ idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
3735
+ //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
3736
+ //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
3737
+ //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
3738
+ //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
3739
+ db[0] = d * (2*(sc[0] & 7) + 1);
3740
+ db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
3741
+ db[2] = d * (2*(sc[1] & 7) + 1);
3742
+ db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
3743
+ for (int l = 0; l < 4; ++l) {
3744
+ const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
3745
+ for (int j = 0; j < 8; ++j) {
3746
+ //y[j] = db[l] * grid[l][j];
3747
+ y[j] = db[l] * grid[j];
3748
+ }
3749
+ y += 8;
3750
+ }
3751
+ qs += 4;
3752
+ sc += 2;
3753
+ }
3754
+ }
3755
+ }
3756
+
3757
  //===================================== Q8_K ==============================================
3758
 
3759
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
 
3855
  }
3856
  #endif
3857
 
3858
+ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
3859
  const int qk = QK8_0;
3860
  const int nb = n / qk;
3861
 
 
3866
  assert(nrc == 1);
3867
  #endif
3868
  UNUSED(nrc);
3869
+ UNUSED(bbx);
3870
+ UNUSED(bby);
3871
  UNUSED(bs);
3872
 
3873
  const block_q4_0 * restrict x = vx;
 
4222
  #endif
4223
  }
4224
 
4225
+ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
4226
  const int qk = QK8_1;
4227
  const int nb = n / qk;
4228
 
 
4233
  assert(nrc == 1);
4234
  #endif
4235
  UNUSED(nrc);
4236
+ UNUSED(bbx);
4237
+ UNUSED(bby);
4238
  UNUSED(bs);
4239
 
4240
  const block_q4_1 * restrict x = vx;
 
4440
  #endif
4441
  }
4442
 
4443
+ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
4444
  const int qk = QK8_0;
4445
  const int nb = n / qk;
4446
 
 
4448
  assert(qk == QK5_0);
4449
  assert(nrc == 1);
4450
  UNUSED(nrc);
4451
+ UNUSED(bbx);
4452
+ UNUSED(bby);
4453
  UNUSED(bs);
4454
 
4455
  const block_q5_0 * restrict x = vx;
 
4731
  #endif
4732
  }
4733
 
4734
+ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
4735
  const int qk = QK8_1;
4736
  const int nb = n / qk;
4737
 
 
4739
  assert(qk == QK5_1);
4740
  assert(nrc == 1);
4741
  UNUSED(nrc);
4742
+ UNUSED(bbx);
4743
+ UNUSED(bby);
4744
  UNUSED(bs);
4745
 
4746
  const block_q5_1 * restrict x = vx;
 
5035
  #endif
5036
  }
5037
 
5038
+ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
5039
  const int qk = QK8_0;
5040
  const int nb = n / qk;
5041
 
 
5046
  assert(nrc == 1);
5047
  #endif
5048
  UNUSED(nrc);
5049
+ UNUSED(bbx);
5050
+ UNUSED(bby);
5051
  UNUSED(bs);
5052
 
5053
  const block_q8_0 * restrict x = vx;
 
9283
  #endif
9284
  }
9285
 
9286
#ifdef __AVX2__
// Signed-by-signed 8-bit multiply with pairwise add into 16-bit lanes.
// _mm256_maddubs_epi16 needs an unsigned first operand, so |x| is passed there
// and the sign of x is transferred onto y (y is zeroed where x is zero).
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i x_abs    = _mm256_sign_epi8(x, x);
    const __m256i y_signed = _mm256_sign_epi8(y, x);
    const __m256i products = _mm256_maddubs_epi16(x_abs, y_signed);
    return products;
}
#endif
9293
+
9294
+ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
9295
+ assert(n % QK_K == 0);
9296
+ assert(nrc == 1);
9297
+ UNUSED(nrc);
9298
+ UNUSED(bx);
9299
+ UNUSED(by);
9300
+ UNUSED(bs);
9301
+
9302
+ const block_iq1_s * restrict x = vx;
9303
+ const block_q8_K * restrict y = vy;
9304
+
9305
+ const int nb = n / QK_K;
9306
+
9307
+ #if defined __ARM_NEON
9308
+
9309
+ const uint8x16_t m8 = vdupq_n_u8(0x08);
9310
+ const uint8x16_t m7 = vdupq_n_u8(0x07);
9311
+ const uint8x16_t m1 = vdupq_n_u8(0x01);
9312
+ const int32x4_t vzero = vdupq_n_s32(0);
9313
+
9314
+ uint16_t gindex[8];
9315
+ uint16x8x2_t vindex;
9316
+ int8x16x4_t q1b;
9317
+ int8x16x4_t q8b;
9318
+ uint16x8x4_t scales;
9319
+ int32x4x2_t sumi;
9320
+ int32x4x2_t dotq;
9321
+
9322
+ float sumf = 0;
9323
+ for (int i = 0; i < nb; ++i) {
9324
+
9325
+ const int8_t * q8 = y[i].qs;
9326
+ const uint8_t * qs = x[i].qs;
9327
+ const uint8_t * sc = x[i].scales;
9328
+
9329
+ sumi.val[0] = sumi.val[1] = vzero;
9330
+
9331
+ for (int i128 = 0; i128 < QK_K/128; ++i128) {
9332
+ const uint8x16_t ql = vld1q_u8(qs); qs += 16;
9333
+ const uint8x8_t tm1 = vld1_u8 (sc); sc += 8;
9334
+ const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
9335
+ const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
9336
+ const uint8x16_t hbit = vandq_u8(qh, m8);
9337
+ vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
9338
+ vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
9339
+ const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
9340
+ scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
9341
+ scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
9342
+
9343
+ for (int l = 0; l < 2; ++l) {
9344
+ vst1q_u16(gindex+0, vindex.val[l]);
9345
+ q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
9346
+ q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
9347
+ q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
9348
+ q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
9349
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
9350
+
9351
+ dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
9352
+ dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
9353
+
9354
+ sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
9355
+ sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
9356
+ }
9357
+ }
9358
+
9359
+ sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
9360
+ }
9361
+
9362
+ *s = sumf;
9363
+
9364
+ #elif defined __AVX2__
9365
+
9366
+ const __m128i m8 = _mm_set1_epi8(0x08);
9367
+ const __m128i m7 = _mm_set1_epi8(0x07);
9368
+ const __m128i m1 = _mm_set1_epi8(0x01);
9369
+ const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
9370
+ const __m128i shuffle_s[4] = {
9371
+ _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
9372
+ _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
9373
+ _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
9374
+ _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
9375
+ };
9376
+
9377
+ uint64_t aux64;
9378
+
9379
+ __m256i v_gindex;
9380
+ const uint16_t * gindex = (const uint16_t *)&v_gindex;
9381
+
9382
+ __m256 accum = _mm256_setzero_ps();
9383
+ for (int i = 0; i < nb; ++i) {
9384
+
9385
+ const int8_t * q8 = y[i].qs;
9386
+ const uint8_t * qs = x[i].qs;
9387
+ const uint8_t * sc = x[i].scales;
9388
+
9389
+ __m256i sumi = _mm256_setzero_si256();
9390
+ for (int i128 = 0; i128 < QK_K/128; ++i128) {
9391
+ const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
9392
+ memcpy(&aux64, sc, 8); sc += 8;
9393
+ const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
9394
+ const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
9395
+ v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
9396
+ const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
9397
+
9398
+ for (int i32 = 0; i32 < 4; ++i32) {
9399
+ const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
9400
+ const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
9401
+ iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
9402
+ const __m256i dot = mul_add_epi8(q1b, q8b);
9403
+ const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
9404
+ const __m256i p = _mm256_madd_epi16(s16, dot);
9405
+ sumi = _mm256_add_epi32(sumi, p);
9406
+ }
9407
+
9408
+ }
9409
+
9410
+ accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
9411
+
9412
+ }
9413
+
9414
+ *s = hsum_float_8(accum);
9415
+
9416
+ #else
9417
+
9418
+ int db[4];
9419
+ uint16_t idx[4];
9420
+
9421
+ float sumf = 0;
9422
+ for (int i = 0; i < nb; ++i) {
9423
+
9424
+ const int8_t * q8 = y[i].qs;
9425
+ const uint8_t * qs = x[i].qs;
9426
+ const uint8_t * sc = x[i].scales;
9427
+
9428
+ int sumi = 0;
9429
+ for (int i32 = 0; i32 < QK_K/32; ++i32) {
9430
+ idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
9431
+ idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
9432
+ idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
9433
+ idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
9434
+ db[0] = (2*(sc[0] & 7) + 1);
9435
+ db[1] = (2*((sc[0] >> 4) & 7) + 1);
9436
+ db[2] = (2*(sc[1] & 7) + 1);
9437
+ db[3] = (2*((sc[1] >> 4) & 7) + 1);
9438
+ for (int l = 0; l < 4; ++l) {
9439
+ const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
9440
+ int suml = 0;
9441
+ for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
9442
+ sumi += db[l] * suml;
9443
+ q8 += 8;
9444
+ }
9445
+ qs += 4;
9446
+ sc += 2;
9447
+ }
9448
+
9449
+ sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
9450
+ }
9451
+
9452
+ *s = sumf;
9453
+
9454
+ #endif
9455
+
9456
+ }
9457
+
9458
  // ================================ IQ2 quantization =============================================
9459
 
9460
  typedef struct {
 
9463
  uint16_t * neighbours;
9464
  } iq2_entry_t;
9465
 
9466
+ static iq2_entry_t iq2_data[3] = {
9467
+ {NULL, NULL, NULL},
9468
  {NULL, NULL, NULL},
9469
  {NULL, NULL, NULL},
9470
  };
9471
 
9472
+ static inline int iq2_data_index(enum ggml_type type) {
9473
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
9474
+ return type == GGML_TYPE_IQ2_XXS ? 0 :
9475
+ type == GGML_TYPE_IQ2_XS ? 1 : 2;
9476
+ }
9477
+
9478
+ static inline int iq2_grid_size(enum ggml_type type) {
9479
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
9480
+ return type == GGML_TYPE_IQ2_XXS ? 256 :
9481
+ type == GGML_TYPE_IQ2_XS ? 512 : 512;
9482
  }
9483
 
9484
  static int iq2_compare_func(const void * left, const void * right) {
 
9487
  return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
9488
  }
9489
 
9490
+ void iq2xs_init_impl(enum ggml_type type) {
9491
+ const int gindex = iq2_data_index(type);
9492
+ const int grid_size = iq2_grid_size(type);
9493
  if (iq2_data[gindex].grid) {
9494
  return;
9495
  }
9496
+ static const uint16_t kgrid_2bit_256[256] = {
9497
  0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
9498
  100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
9499
  1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
 
9511
  33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
9512
  37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
9513
  };
9514
+ static const uint16_t kgrid_2bit_512[512] = {
9515
  0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
9516
  73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
9517
  260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
 
9545
  40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
9546
  42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
9547
  };
9548
+ static const uint16_t kgrid_1bit_512[512] = {
9549
+ 10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545,
9550
+ 553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444,
9551
+ 1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440,
9552
+ 2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422,
9553
+ 4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397,
9554
+ 5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769,
9555
+ 5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788,
9556
+ 6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794,
9557
+ 9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272,
9558
+ 10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
9559
+ 16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
9560
+ 17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
9561
+ 18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
9562
+ 20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
9563
+ 20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
9564
+ 21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
9565
+ 21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
9566
+ 21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
9567
+ 22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
9568
+ 23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
9569
+ 25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
9570
+ 25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
9571
+ 26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
9572
+ 32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
9573
+ 33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
9574
+ 34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
9575
+ 37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
9576
+ 38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
9577
+ 38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
9578
+ 39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
9579
+ 41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
9580
+ 42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
9581
+ };
9582
+
9583
  const int kmap_size = 43692;
9584
+ const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
9585
+ const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
9586
+ type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : kgrid_1bit_512;
9587
  uint64_t * kgrid_q2xs;
9588
  int * kmap_q2xs;
9589
  uint16_t * kneighbors_q2xs;
 
9679
  free(dist2);
9680
  }
9681
 
9682
+ void iq2xs_free_impl(enum ggml_type type) {
9683
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
9684
+ const int gindex = iq2_data_index(type);
9685
  if (iq2_data[gindex].grid) {
9686
  free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
9687
  free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
 
9715
 
9716
  static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9717
 
9718
+ const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
9719
 
9720
  const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
9721
  const int * kmap_q2xs = iq2_data[gindex].map;
 
9888
 
9889
  static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9890
 
9891
+ const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
9892
 
9893
  const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
9894
  const int * kmap_q2xs = iq2_data[gindex].map;
 
10525
  assert(k % QK_K == 0);
10526
  quantize_row_iq3_xxs_impl(x, y, k, NULL);
10527
  }
10528
+
10529
+ // =================================== 1.5 bpw ===================================================
10530
+
10531
// Accumulates, for one 8-entry grid point pg, the weighted correlation with xval
// (sumqx = sum w*q*x) and the weighted squared norm (sumq2 = sum w*q*q),
// where q = (pg[k] - 3)/2 using integer division.
// NOTE(review): presumably grid entries are stored as 1/3/5 so that q is -1/0/+1 - confirm
// against the grid construction in iq2xs_init_impl.
static inline void iq1_grid_point_sums(const int8_t * restrict pg, const float * restrict xval,
        const float * restrict weight, float * sumqx, float * sumq2) {
    float sqx = 0, sq2 = 0;
    for (int k = 0; k < 8; ++k) {
        float q = (pg[k] - 3)/2;
        float w = weight[k];
        sqx += w*q*xval[k];
        sq2 += w*q*q;
    }
    *sumqx = sqx;
    *sumq2 = sq2;
}

// Picks the grid point that best represents the 8 values in xval under the given weights.
//
//   neighbours - neighbours[0] is the candidate count, neighbours[1..count] are grid indices
//   grid       - table of grid points, one uint64_t (8 packed int8 entries) per point
//   xval       - the 8 input values
//   weight     - the 8 importance weights
//   scale      - out: best-fit scale for the chosen grid point (already multiplied by 1.05)
//   L          - out: the 8 entries of the chosen grid point mapped via (entry - 1)/2
//   ngrid      - number of points in grid, used for the exhaustive fallback
//
// Returns the index of the chosen grid point. Asserts if no grid point has a positive
// correlation with xval, after printing per-neighbour diagnostics.
static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
    int num_neighbors = neighbours[0];
    GGML_ASSERT(num_neighbors > 0);

    float best_score = 0;
    int   grid_index = -1;
    float sumqx, sumq2;

    // First pass: only the pre-computed neighbour candidates.
    for (int j = 1; j <= num_neighbors; ++j) {
        const int candidate = neighbours[j];
        iq1_grid_point_sums((const int8_t *)(grid + candidate), xval, weight, &sumqx, &sumq2);
        // score = sumqx^2/sumq2, compared without the division
        if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
            *scale = sumqx/sumq2;
            best_score = *scale * sumqx;
            grid_index = candidate;
        }
    }

    // Fallback: none of the neighbours qualified, so scan the whole grid.
    if (grid_index < 0) {
        for (int i = 0; i < ngrid; ++i) {
            iq1_grid_point_sums((const int8_t *)(grid + i), xval, weight, &sumqx, &sumq2);
            if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                *scale = sumqx/sumq2;
                best_score = *scale*sumqx;
                grid_index = i;
            }
        }
    }

    // Still nothing: dump what each neighbour scored before the assertion below fires.
    if (grid_index < 0) {
        printf("Oops, did not find grid point\n");
        printf("Have %d neighbours\n", num_neighbors);
        for (int j = 1; j <= num_neighbors; ++j) {
            iq1_grid_point_sums((const int8_t *)(grid + neighbours[j]), xval, weight, &sumqx, &sumq2);
            printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
        }
    }
    GGML_ASSERT(grid_index >= 0);

    // Empirical fudge factor carried over from the original implementation: it improves the
    // result, for reasons the original author left unexplained.
    *scale *= 1.05f;

    const int8_t * pg = (const int8_t *)(grid + grid_index);
    for (int i = 0; i < 8; ++i) {
        L[i] = (pg[i] - 1)/2;
    }
    return grid_index;
}
10590
+
10591
// qsort comparator: orders elements ascending by the float stored at the start of each element.
// Returns -1, 1 or 0; unordered comparisons (NaN) yield 0.
static int iq1_sort_helper(const void * left, const void * right) {
    const float lhs = *(const float *)left;
    const float rhs = *(const float *)right;
    if (lhs < rhs) {
        return -1;
    }
    if (lhs > rhs) {
        return 1;
    }
    return 0;
}
10596
+
10597
+ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
10598
+
10599
+ const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
10600
+
10601
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
10602
+ const int * kmap_q2xs = iq2_data[gindex].map;
10603
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
10604
+
10605
+ GGML_ASSERT(quant_weights && "missing quantization weights");
10606
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
10607
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
10608
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
10609
+ GGML_ASSERT(n%QK_K == 0);
10610
+
10611
+ const int nbl = n/256;
10612
+
10613
+ block_iq1_s * y = vy;
10614
+
10615
+ float scales[QK_K/8];
10616
+ float weight[8];
10617
+ int8_t L[8];
10618
+ float sumx[9];
10619
+ float sumw[9];
10620
+ float pairs[16];
10621
+ int * idx = (int *)(pairs + 1);
10622
+ uint8_t hbit[QK_K/8];
10623
+
10624
+ for (int ibl = 0; ibl < nbl; ++ibl) {
10625
+
10626
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
10627
+ memset(y[ibl].qs, 0, QK_K/8);
10628
+ memset(y[ibl].scales, 0, QK_K/16);
10629
+
10630
+ float max_scale = 0;
10631
+
10632
+ const float * xbl = x + QK_K*ibl;
10633
+ float sumx2 = 0;
10634
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
10635
+ float sigma2 = sumx2/QK_K;
10636
+
10637
+ for (int ib = 0; ib < QK_K/8; ++ib) {
10638
+ const float * xb = xbl + 8*ib;
10639
+ const float * qw = quant_weights + QK_K*ibl + 8*ib;
10640
+ for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
10641
+ float max = fabsf(xb[0]);
10642
+ for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
10643
+ if (!max) {
10644
+ scales[ib] = 0;
10645
+ memset(L, 1, 8);
10646
+ continue;
10647
+ }
10648
+ // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
10649
+ // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
10650
+ // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
10651
+ // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
10652
+ // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
10653
+ // for each possible and score for each split.
10654
+ for (int j = 0; j < 8; ++j) {
10655
+ pairs[2*j] = xb[j];
10656
+ idx[2*j] = j;
10657
+ }
10658
+ qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
10659
+ {
10660
+ sumx[0] = sumw[0] = 0;
10661
+ for (int j = 0; j < 8; ++j) {
10662
+ int i = idx[2*j];
10663
+ sumx[j+1] = sumx[j] + weight[i]*xb[i];
10664
+ sumw[j+1] = sumw[j] + weight[i];
10665
+ }
10666
+ }
10667
+ float best_score = 0, scale = max;
10668
+ int besti1 = 0, besti2 = 0;
10669
+ for (int i1 = 0; i1 <= 8; ++i1) {
10670
+ for (int i2 = i1; i2 <= 8; ++i2) {
10671
+ float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
10672
+ float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
10673
+ if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
10674
+ scale = sumqx/sumq2; best_score = scale*sumqx;
10675
+ besti1 = i1; besti2 = i2;
10676
+ }
10677
+ }
10678
+ }
10679
+ for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
10680
+ for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
10681
+ for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
10682
+ if (scale < 0) {
10683
+ for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
10684
+ scale = -scale;
10685
+ }
10686
+ // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
10687
+ // grid point that minimizes SSD.
10688
+ uint16_t u = 0;
10689
+ for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
10690
+ int grid_index = kmap_q2xs[u];
10691
+ if (grid_index < 0) {
10692
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
10693
+ grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
10694
+ GGML_ASSERT(grid_index >= 0);
10695
+ }
10696
+ y[ibl].qs[ib] = grid_index & 255;
10697
+ hbit[ib] = grid_index >> 8;
10698
+ GGML_ASSERT(scale >= 0);
10699
+ scales[ib] = scale;
10700
+ max_scale = MAX(max_scale, scale);
10701
+ }
10702
+
10703
+ if (!max_scale) {
10704
+ memset(y[ibl].qs, 0, QK_K/8);
10705
+ continue;
10706
+ }
10707
+
10708
+ float d = max_scale/15;
10709
+ y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
10710
+ float id = 1/d;
10711
+ for (int ib = 0; ib < QK_K/8; ++ib) {
10712
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
10713
+ l = MAX(0, MIN(7, l));
10714
+ if (hbit[ib]) l |= 8;
10715
+ y[ibl].scales[ib/2] |= (l << 4*(ib%2));
10716
+ }
10717
+ }
10718
+ }
10719
+
10720
+ size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
10721
+ (void)hist;
10722
+ GGML_ASSERT(n_per_row%QK_K == 0);
10723
+ int nblock = n_per_row/QK_K;
10724
+ char * qrow = (char *)dst;
10725
+ for (int row = 0; row < nrow; ++row) {
10726
+ quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
10727
+ src += n_per_row;
10728
+ qrow += nblock*sizeof(block_iq1_s);
10729
+ }
10730
+ return nrow * nblock * sizeof(block_iq1_s);
10731
+ }
ggml-quants.h CHANGED
@@ -191,6 +191,13 @@ typedef struct {
191
  } block_iq3_xxs;
192
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
193
 
 
 
 
 
 
 
 
194
  #ifdef __cplusplus
195
  extern "C" {
196
  #endif
@@ -243,6 +250,7 @@ void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRI
243
  void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
244
  void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
245
  void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
246
 
247
  // Dot product
248
  void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -259,6 +267,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
259
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
260
  void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
261
  void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
262
 
263
  //
264
  // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
@@ -266,6 +275,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
266
  size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
267
  size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
268
  size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 
269
  size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
270
  size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
271
  size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
@@ -276,8 +286,8 @@ size_t quantize_q4_1 (const float * src, void * dst, int nrows, int n_per_row,
276
  size_t quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
277
  size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
278
 
279
- void iq2xs_init_impl(int grid_size);
280
- void iq2xs_free_impl(int grid_size);
281
  void iq3xs_init_impl(int grid_size);
282
  void iq3xs_free_impl(int grid_size);
283
 
 
191
  } block_iq3_xxs;
192
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
193
 
194
// One IQ1_S block: the quantized form of QK_K weights, handled in groups of 8.
+ typedef struct {
195
// fp16 block scale; the quantizer writes it with GGML_FP32_TO_FP16 and the dot product
// multiplies the integer sum by it.
+ ggml_fp16_t d;
196
// One byte per group of 8 weights: the low 8 bits of that group's grid index.
+ uint8_t qs[QK_K/8];
197
// One nibble per group of 8 weights (two groups per byte, even group in the low nibble):
// bits 0-2 are the scale code l (applied as 2*l + 1), bit 3 is bit 8 of the grid index.
+ uint8_t scales[QK_K/16];
198
+ } block_iq1_s;
199
// Guards against padding: the struct must be exactly the sum of its members.
+ static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
200
+
201
  #ifdef __cplusplus
202
  extern "C" {
203
  #endif
 
250
  void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
251
  void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
252
  void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
253
+ void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
254
 
255
  // Dot product
256
  void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
267
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
268
  void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
269
  void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
270
+ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
271
 
272
  //
273
  // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 
275
  size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
276
  size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
277
  size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
278
+ size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
279
  size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
280
  size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
281
  size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 
286
  size_t quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
287
  size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
288
 
289
+ void iq2xs_init_impl(enum ggml_type type);
290
+ void iq2xs_free_impl(enum ggml_type type);
291
  void iq3xs_init_impl(int grid_size);
292
  void iq3xs_free_impl(int grid_size);
293
 
ggml.c CHANGED
@@ -673,6 +673,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
673
  .vec_dot_type = GGML_TYPE_Q8_K,
674
  .nrows = 1,
675
  },
 
 
 
 
 
 
 
 
 
 
 
 
676
  [GGML_TYPE_Q8_K] = {
677
  .type_name = "q8_K",
678
  .blck_size = QK_K,
@@ -2267,6 +2279,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2267
  case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
2268
  case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
2269
  case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
 
2270
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
2271
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
2272
  }
@@ -7677,6 +7690,7 @@ static void ggml_compute_forward_add(
7677
  case GGML_TYPE_IQ2_XXS:
7678
  case GGML_TYPE_IQ2_XS:
7679
  case GGML_TYPE_IQ3_XXS:
 
7680
  {
7681
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
7682
  } break;
@@ -7944,6 +7958,7 @@ static void ggml_compute_forward_add1(
7944
  case GGML_TYPE_IQ2_XXS:
7945
  case GGML_TYPE_IQ2_XS:
7946
  case GGML_TYPE_IQ3_XXS:
 
7947
  {
7948
  ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
7949
  } break;
@@ -8064,6 +8079,7 @@ static void ggml_compute_forward_acc(
8064
  case GGML_TYPE_IQ2_XXS:
8065
  case GGML_TYPE_IQ2_XS:
8066
  case GGML_TYPE_IQ3_XXS:
 
8067
  default:
8068
  {
8069
  GGML_ASSERT(false);
@@ -10830,6 +10846,7 @@ static void ggml_compute_forward_out_prod(
10830
  case GGML_TYPE_IQ2_XXS:
10831
  case GGML_TYPE_IQ2_XS:
10832
  case GGML_TYPE_IQ3_XXS:
 
10833
  {
10834
  ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10835
  } break;
@@ -11010,6 +11027,7 @@ static void ggml_compute_forward_set(
11010
  case GGML_TYPE_IQ2_XXS:
11011
  case GGML_TYPE_IQ2_XS:
11012
  case GGML_TYPE_IQ3_XXS:
 
11013
  default:
11014
  {
11015
  GGML_ASSERT(false);
@@ -11207,6 +11225,7 @@ static void ggml_compute_forward_get_rows(
11207
  case GGML_TYPE_IQ2_XXS:
11208
  case GGML_TYPE_IQ2_XS:
11209
  case GGML_TYPE_IQ3_XXS:
 
11210
  {
11211
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
11212
  } break;
@@ -11880,6 +11899,7 @@ static void ggml_compute_forward_alibi(
11880
  case GGML_TYPE_IQ2_XXS:
11881
  case GGML_TYPE_IQ2_XS:
11882
  case GGML_TYPE_IQ3_XXS:
 
11883
  case GGML_TYPE_Q8_K:
11884
  case GGML_TYPE_I8:
11885
  case GGML_TYPE_I16:
@@ -11957,6 +11977,7 @@ static void ggml_compute_forward_clamp(
11957
  case GGML_TYPE_IQ2_XXS:
11958
  case GGML_TYPE_IQ2_XS:
11959
  case GGML_TYPE_IQ3_XXS:
 
11960
  case GGML_TYPE_Q8_K:
11961
  case GGML_TYPE_I8:
11962
  case GGML_TYPE_I16:
@@ -19136,8 +19157,9 @@ void ggml_quantize_init(enum ggml_type type) {
19136
  ggml_critical_section_start();
19137
 
19138
  switch (type) {
19139
- case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
19140
- case GGML_TYPE_IQ2_XS: iq2xs_init_impl(512); break;
 
19141
  case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
19142
  default: // nothing
19143
  break;
@@ -19149,8 +19171,10 @@ void ggml_quantize_init(enum ggml_type type) {
19149
  void ggml_quantize_free(void) {
19150
  ggml_critical_section_start();
19151
 
19152
- iq2xs_free_impl(256);
19153
- iq2xs_free_impl(512);
 
 
19154
 
19155
  ggml_critical_section_end();
19156
  }
@@ -19285,7 +19309,8 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
19285
  bool ggml_quantize_requires_imatrix(enum ggml_type type) {
19286
  return
19287
  type == GGML_TYPE_IQ2_XXS ||
19288
- type == GGML_TYPE_IQ2_XS;
 
19289
  }
19290
 
19291
  size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
@@ -19410,6 +19435,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
19410
  result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
19411
  GGML_ASSERT(result == row_size * nrows);
19412
  } break;
 
 
 
 
 
 
 
 
 
19413
  case GGML_TYPE_F16:
19414
  {
19415
  size_t elemsize = sizeof(ggml_fp16_t);
 
673
  .vec_dot_type = GGML_TYPE_Q8_K,
674
  .nrows = 1,
675
  },
676
+ [GGML_TYPE_IQ1_S] = {
677
+ .type_name = "iq1_s",
678
+ .blck_size = QK_K,
679
+ .type_size = sizeof(block_iq1_s),
680
+ .is_quantized = true,
681
+ .to_float = (ggml_to_float_t) dequantize_row_iq1_s,
682
+ .from_float = NULL,
683
+ .from_float_reference = NULL,
684
+ .vec_dot = ggml_vec_dot_iq1_s_q8_K,
685
+ .vec_dot_type = GGML_TYPE_Q8_K,
686
+ .nrows = 1,
687
+ },
688
  [GGML_TYPE_Q8_K] = {
689
  .type_name = "q8_K",
690
  .blck_size = QK_K,
 
2279
  case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
2280
  case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
2281
  case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
2282
+ case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
2283
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
2284
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
2285
  }
 
7690
  case GGML_TYPE_IQ2_XXS:
7691
  case GGML_TYPE_IQ2_XS:
7692
  case GGML_TYPE_IQ3_XXS:
7693
+ case GGML_TYPE_IQ1_S:
7694
  {
7695
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
7696
  } break;
 
7958
  case GGML_TYPE_IQ2_XXS:
7959
  case GGML_TYPE_IQ2_XS:
7960
  case GGML_TYPE_IQ3_XXS:
7961
+ case GGML_TYPE_IQ1_S:
7962
  {
7963
  ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
7964
  } break;
 
8079
  case GGML_TYPE_IQ2_XXS:
8080
  case GGML_TYPE_IQ2_XS:
8081
  case GGML_TYPE_IQ3_XXS:
8082
+ case GGML_TYPE_IQ1_S:
8083
  default:
8084
  {
8085
  GGML_ASSERT(false);
 
10846
  case GGML_TYPE_IQ2_XXS:
10847
  case GGML_TYPE_IQ2_XS:
10848
  case GGML_TYPE_IQ3_XXS:
10849
+ case GGML_TYPE_IQ1_S:
10850
  {
10851
  ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10852
  } break;
 
11027
  case GGML_TYPE_IQ2_XXS:
11028
  case GGML_TYPE_IQ2_XS:
11029
  case GGML_TYPE_IQ3_XXS:
11030
+ case GGML_TYPE_IQ1_S:
11031
  default:
11032
  {
11033
  GGML_ASSERT(false);
 
11225
  case GGML_TYPE_IQ2_XXS:
11226
  case GGML_TYPE_IQ2_XS:
11227
  case GGML_TYPE_IQ3_XXS:
11228
+ case GGML_TYPE_IQ1_S:
11229
  {
11230
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
11231
  } break;
 
11899
  case GGML_TYPE_IQ2_XXS:
11900
  case GGML_TYPE_IQ2_XS:
11901
  case GGML_TYPE_IQ3_XXS:
11902
+ case GGML_TYPE_IQ1_S:
11903
  case GGML_TYPE_Q8_K:
11904
  case GGML_TYPE_I8:
11905
  case GGML_TYPE_I16:
 
11977
  case GGML_TYPE_IQ2_XXS:
11978
  case GGML_TYPE_IQ2_XS:
11979
  case GGML_TYPE_IQ3_XXS:
11980
+ case GGML_TYPE_IQ1_S:
11981
  case GGML_TYPE_Q8_K:
11982
  case GGML_TYPE_I8:
11983
  case GGML_TYPE_I16:
 
19157
  ggml_critical_section_start();
19158
 
19159
  switch (type) {
19160
+ case GGML_TYPE_IQ2_XXS:
19161
+ case GGML_TYPE_IQ2_XS:
19162
+ case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
19163
  case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
19164
  default: // nothing
19165
  break;
 
19171
  void ggml_quantize_free(void) {
19172
  ggml_critical_section_start();
19173
 
19174
+ iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
19175
+ iq2xs_free_impl(GGML_TYPE_IQ2_XS);
19176
+ iq2xs_free_impl(GGML_TYPE_IQ1_S);
19177
+ iq3xs_free_impl(256);
19178
 
19179
  ggml_critical_section_end();
19180
  }
 
19309
  bool ggml_quantize_requires_imatrix(enum ggml_type type) {
19310
  return
19311
  type == GGML_TYPE_IQ2_XXS ||
19312
+ type == GGML_TYPE_IQ2_XS ||
19313
+ type == GGML_TYPE_IQ1_S;
19314
  }
19315
 
19316
  size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
 
19435
  result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
19436
  GGML_ASSERT(result == row_size * nrows);
19437
  } break;
19438
+ case GGML_TYPE_IQ1_S:
19439
+ {
19440
+ GGML_ASSERT(start % QK_K == 0);
19441
+ GGML_ASSERT(start % n_per_row == 0);
19442
+ size_t start_row = start / n_per_row;
19443
+ size_t row_size = ggml_row_size(type, n_per_row);
19444
+ result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
19445
+ GGML_ASSERT(result == row_size * nrows);
19446
+ } break;
19447
  case GGML_TYPE_F16:
19448
  {
19449
  size_t elemsize = sizeof(ggml_fp16_t);
ggml.h CHANGED
@@ -354,6 +354,7 @@ extern "C" {
354
  GGML_TYPE_IQ2_XXS = 16,
355
  GGML_TYPE_IQ2_XS = 17,
356
  GGML_TYPE_IQ3_XXS = 18,
 
357
  GGML_TYPE_I8,
358
  GGML_TYPE_I16,
359
  GGML_TYPE_I32,
@@ -391,6 +392,7 @@ extern "C" {
391
  GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
392
  GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
393
  GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
 
394
  };
395
 
396
  // available tensor operations:
 
354
  GGML_TYPE_IQ2_XXS = 16,
355
  GGML_TYPE_IQ2_XS = 17,
356
  GGML_TYPE_IQ3_XXS = 18,
357
+ GGML_TYPE_IQ1_S = 19,
358
  GGML_TYPE_I8,
359
  GGML_TYPE_I16,
360
  GGML_TYPE_I32,
 
392
  GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
393
  GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
394
  GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
395
+ GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
396
  };
397
 
398
  // available tensor operations: