Kawrakow (ikawrakow) committed
Commit 75de5bf · unverified · Parent(s): 70c8d60

SOTA 2-bit quants (llama/4773)


* iq2_xxs: basics

* iq2_xxs: scalar and AVX2 dot products

Needed to change Q8_K to have quants in the -127...127 range,
else the IQ2_XXS AVX implementation becomes very awkward.
The alternative would have been to use Q8_0 instead. Perhaps
I'll change this later; for now this is what we have.
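As a sketch of what this amounts to (hedged; the real code is quantize_row_q8_K_reference in ggml-quants.c, and the helper names here are illustrative): scaling by -127/max instead of -128/max keeps every quant inside [-127, 127], which is what the AVX2 iq2_xxs dot product needs for its int8 sign manipulation.

#include <math.h>
#include <stdint.h>

// Illustrative stand-in for ggml's nearest_int().
static inline int nearest_int(float v) { return (int)lrintf(v); }

// Quantize n floats to int8 with a symmetric -127...127 range.
// Returns the dequantization scale d such that x[j] ~= d * q[j].
static float quantize_block_q8_sym(const float * x, int8_t * q, int n) {
    float amax = 0.f, max = 0.f;   // max = signed value of largest magnitude
    for (int j = 0; j < n; ++j) {
        const float ax = fabsf(x[j]);
        if (ax > amax) { amax = ax; max = x[j]; }
    }
    if (amax == 0.f) {                 // all zeros
        for (int j = 0; j < n; ++j) q[j] = 0;
        return 0.f;
    }
    const float iscale = -127.f/max;   // was -128.f/max before this commit
    for (int j = 0; j < n; ++j) {
        const int v = nearest_int(iscale*x[j]);
        q[j] = (int8_t)(v > 127 ? 127 : v);   // MIN(127, v) as in the original
    }
    return 1.f/iscale;
}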

* iq2_xxs: ARM_NEON dot product

Strangely slow for reasons not yet understood (112 ms/token).

* iq2_xxs: WIP Metal

Dequantize works, something is still wrong with the
dot product.

* iq2_xxs: Metal dot product now works

We have
PP-512 = 475 t/s
TG-128 = 47.3 t/s

Not the greatest performance, but not complete garbage either.

* iq2_xxs: slightly faster dot product

TG-128 is now 48.4 t/s

* iq2_xxs: slightly faster dot product

TG-128 is now 50.9 t/s

* iq2_xxs: even faster Metal dot product

TG-128 is now 54.1 t/s.

Strangely enough, putting the signs lookup table
into shared memory has a bigger impact than putting
the grid values into shared memory.

* iq2_xxs: dequantize CUDA kernel - fix conflict with master

* iq2_xxs: quantized CUDA dot product (MMVQ)

We get TG-128 = 153.1 t/s

* iq2_xxs: slightly faster CUDA dot product

TG-128 is now at 155.1 t/s.

* iq2_xxs: add to llama ftype enum

* iq2_xxs: fix MoE on Metal

* Fix missing MMQ ops when on hipBLAS

I had put the ggml_cuda_supports_mmq call in the wrong place.
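For reference, this is the one-line guard as it lands in ggml_cuda_mul_mat in the diff below; since ggml_cuda_supports_mmq returns false for IQ2_XXS, matrix-matrix multiplication for this type falls back to the dequantize + cuBLAS path:

// ggml-cuda.cu, ggml_cuda_mul_mat(): only use quantized mat-mat (MMQ) kernels
// for types that actually have them.
use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);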

* Fix bug in quantize_row_iq2_xxs

The 0.25f factor was missing.
Great detective work by @ggerganov!
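Where that factor lives: each group of 32 quants stores a 4-bit scale s in the top bits of its 32-bit sign word, and the kernels in this commit decode the effective multiplier as d * (0.5f + s) * 0.25f, i.e. d*(2s+1)/8. A minimal sketch:

#include <stdint.h>

// Effective scale of one 32-quant group. d is the block's fp16 super-block
// scale (converted to float); aux32 packs 4x7 sign bits in bits 0..27 and
// the 4-bit group scale in bits 28..31.
static inline float iq2_xxs_group_scale(float d, uint32_t aux32) {
    const uint32_t s = aux32 >> 28;   // 0..15
    return d * (0.5f + s) * 0.25f;    // == d * (2*s + 1) / 8
}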

* Fixing tests

* PR suggestion

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

Files changed (7)
  1. ggml-cuda.cu +205 -0
  2. ggml-metal.m +40 -0
  3. ggml-metal.metal +314 -0
  4. ggml-quants.c +293 -1
  5. ggml-quants.h +12 -0
  6. ggml.c +26 -0
  7. ggml.h +3 -0
ggml-cuda.cu CHANGED
@@ -477,6 +477,14 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -1292,6 +1300,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 #endif
 }
 
+static const __device__ uint64_t kgrid_iq2xxs[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const __device__ uint8_t ksigns_iq2xs[128] = {
+    0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+    144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+    160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+    48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+    192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+    80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+    96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -3825,6 +3955,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+#if QR2_XXS == 8
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = q2[2] | (q2[3] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
+        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+        for (int j = 0; j < 8; ++j) {
+            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
+    return d * sumi;
+#else
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -5664,6 +5843,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -5692,6 +5877,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        default:
@@ -5721,6 +5908,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
+       case GGML_TYPE_IQ2_XXS:
+           return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half>;
        default:
@@ -5915,6 +6104,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7407,6 +7605,7 @@ static int64_t get_row_rounding(ggml_type type) {
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
+       case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
        default:
            GGML_ASSERT(false);
@@ -7427,6 +7626,7 @@ static int64_t get_row_rounding(ggml_type type) {
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
+       case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_VOLTA ? 128 : 64;
        case GGML_TYPE_Q6_K:
            return 64;
@@ -7477,6 +7677,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
            break;
+       case GGML_TYPE_IQ2_XXS:
+           mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+           break;
        default:
            GGML_ASSERT(false);
            break;
@@ -8693,6 +8896,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
+    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
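To summarize the layout the CUDA kernels above decode (a hedged scalar sketch, not part of the commit itself): each block_iq2_xxs holds 256 weights in eight groups of 32; a group's four uint16 values provide four 8-bit indices into the 256-entry grid (8 weights each) in q2[0..1], and 4x7 bits of sign indices plus a 4-bit scale in q2[2..3].

#include <stdint.h>

// Tables as defined in the diff above.
extern const uint64_t kgrid_iq2xxs[256];
extern const uint8_t  ksigns_iq2xs[128];
extern const uint8_t  kmask_iq2xs[8];

// Scalar reference decode of one group of 32 weights, mirroring
// dequantize_block_iq2_xxs. d is the block scale x->d converted to float.
static void iq2_xxs_decode_group32(float d, const uint16_t q2[4], float y[32]) {
    const uint8_t * aux8  = (const uint8_t *)q2;              // 4 grid indices
    const uint32_t  aux32 = q2[2] | ((uint32_t)q2[3] << 16);  // signs + scale
    const float dl = d * (0.5f + (aux32 >> 28)) * 0.25f;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid  = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
        const uint8_t   signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
        for (int j = 0; j < 8; ++j) {
            y[8*l + j] = dl * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
        }
    }
}

At 2 + 32*2 = 66 bytes per 256 weights this is 66*8/256 = 2.0625 bits per weight, matching the comment on the block definition in ggml-metal.metal below.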
ggml-metal.m CHANGED
@@ -88,6 +88,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(get_rows_i32);
+    GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
@@ -106,6 +107,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
@@ -121,6 +123,7 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -133,6 +136,7 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
@@ -145,6 +149,7 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
    GGML_METAL_DECL_KERNEL(rope_f32);
    GGML_METAL_DECL_KERNEL(rope_f16);
    GGML_METAL_DECL_KERNEL(alibi_f32);
@@ -379,6 +384,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
        GGML_METAL_ADD_KERNEL(get_rows_i32);
+       GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
        GGML_METAL_ADD_KERNEL(rms_norm);
        GGML_METAL_ADD_KERNEL(group_norm);
        GGML_METAL_ADD_KERNEL(norm);
@@ -397,6 +403,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
        GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
@@ -412,6 +419,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -425,6 +433,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
            GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+           GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
@@ -437,6 +446,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
+           GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
        }
        GGML_METAL_ADD_KERNEL(rope_f32);
        GGML_METAL_ADD_KERNEL(rope_f16);
@@ -502,6 +512,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
    GGML_METAL_DEL_KERNEL(get_rows_i32);
+   GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
    GGML_METAL_DEL_KERNEL(rms_norm);
    GGML_METAL_DEL_KERNEL(group_norm);
    GGML_METAL_DEL_KERNEL(norm);
@@ -520,6 +531,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
@@ -535,6 +547,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -548,6 +561,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
        GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+       GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
@@ -560,6 +574,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
+       GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
    }
    GGML_METAL_DEL_KERNEL(rope_f32);
    GGML_METAL_DEL_KERNEL(rope_f16);
@@ -1541,6 +1556,7 @@ bool ggml_metal_graph_compute(
    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+   case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
    default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
    }
    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1653,6 +1669,12 @@ bool ggml_metal_graph_compute(
        nth1 = 32;
        [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
    } break;
+   case GGML_TYPE_IQ2_XXS:
+   {
+       nth0 = 4;
+       nth1 = 16;
+       [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
+   } break;
    default:
    {
        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1686,9 +1708,14 @@ bool ggml_metal_graph_compute(
 
    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
        src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+       //src0t == GGML_TYPE_IQ2_XXS ||
        src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
    }
+   else if (src0t == GGML_TYPE_IQ2_XXS) {
+       [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+       [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+   }
    else if (src0t == GGML_TYPE_Q4_K) {
        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
    }
@@ -1778,6 +1805,7 @@ bool ggml_metal_graph_compute(
    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+   case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
    default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
    }
    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1893,6 +1921,12 @@ bool ggml_metal_graph_compute(
        nth1 = 32;
        [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
    } break;
+   case GGML_TYPE_IQ2_XXS:
+   {
+       nth0 = 4;
+       nth1 = 16;
+       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
+   } break;
    default:
    {
        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1942,9 +1976,14 @@ bool ggml_metal_graph_compute(
 
    if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
        src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+       //src2t == GGML_TYPE_IQ2_XXS ||
        src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
    }
+   else if (src2t == GGML_TYPE_IQ2_XXS) {
+       [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+       [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+   }
    else if (src2t == GGML_TYPE_Q4_K) {
        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
    }
@@ -1982,6 +2021,7 @@ bool ggml_metal_graph_compute(
    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
    case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
+   case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
    default: GGML_ASSERT(false && "not implemented");
    }
 
ggml-metal.metal CHANGED
@@ -2446,6 +2446,12 @@ typedef struct {
2446
  } block_q6_K;
2447
  // 210 bytes / block
2448
 
 
 
 
 
 
 
2449
  //====================================== dot products =========================
2450
 
2451
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3468,6 +3474,221 @@ kernel void kernel_mul_mv_q6_K_f32(
3468
  kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3469
  }
3470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3471
  //============================= templates and their specializations =============================
3472
 
3473
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3739,6 +3960,31 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3739
  }
3740
  }
3741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3742
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3743
  kernel void kernel_get_rows(
3744
  device const void * src0,
@@ -4278,6 +4524,7 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
4278
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
4279
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4280
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
4281
 
4282
  //
4283
  // matrix-matrix multiplication
@@ -4314,6 +4561,7 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4314
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
4315
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4316
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 
4317
 
4318
  //
4319
  // indirect matrix-matrix multiplication
@@ -4362,6 +4610,7 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
4362
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4363
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4364
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
 
4365
 
4366
  //
4367
  // matrix-vector multiplication
@@ -5134,3 +5383,68 @@ kernel void kernel_mul_mv_id_q6_K_f32(
5134
  tiisg,
5135
  sgitg);
5136
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2446
  } block_q6_K;
2447
  // 210 bytes / block
2448
 
2449
+ typedef struct {
2450
+ half d;
2451
+ uint16_t qs[QK_K/8];
2452
+ } block_iq2_xxs;
2453
+ // 66 bytes / block for QK_K = 256, so 2.0625 bpw
2454
+
2455
  //====================================== dot products =========================
2456
 
2457
  void kernel_mul_mv_q2_K_f32_impl(
 
3474
  kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3475
  }
3476
 
3477
+ // ======================= "True" 2-bit
3478
+
3479
+ constexpr constant static uint64_t kgrid_iq2xxs[256] = {
3480
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3481
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3482
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
3483
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
3484
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
3485
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
3486
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
3487
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
3488
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
3489
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
3490
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
3491
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
3492
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
3493
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
3494
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
3495
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
3496
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
3497
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
3498
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
3499
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
3500
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
3501
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
3502
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
3503
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
3504
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
3505
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
3506
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
3507
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
3508
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
3509
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
3510
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
3511
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
3512
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
3513
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
3514
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
3515
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
3516
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
3517
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
3518
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
3519
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
3520
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
3521
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
3522
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
3523
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
3524
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
3525
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
3526
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
3527
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
3528
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
3529
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
3530
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
3531
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
3532
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
3533
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
3534
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
3535
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
3536
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
3537
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
3538
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
3539
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
3540
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
3541
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
3542
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
3543
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3544
+ };
3545
+
3546
+ constexpr constant static uint8_t ksigns_iq2xs[128] = {
3547
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3548
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
3549
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
3550
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
3551
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
3552
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
3553
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
3554
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
3555
+ };
3556
+
3557
+ constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
3558
+
3559
+ void kernel_mul_mv_iq2_xxs_f32_impl(
3560
+ device const void * src0,
3561
+ device const float * src1,
3562
+ device float * dst,
3563
+ constant int64_t & ne00,
3564
+ constant int64_t & ne01,
3565
+ constant int64_t & ne02,
3566
+ constant int64_t & ne10,
3567
+ constant int64_t & ne12,
3568
+ constant int64_t & ne0,
3569
+ constant int64_t & ne1,
3570
+ constant uint & r2,
3571
+ constant uint & r3,
3572
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3573
+ uint3 tgpig[[threadgroup_position_in_grid]],
3574
+ uint tiisg[[thread_index_in_simdgroup]],
3575
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3576
+
3577
+ const int nb = ne00/QK_K;
3578
+ const int r0 = tgpig.x;
3579
+ const int r1 = tgpig.y;
3580
+ const int im = tgpig.z;
3581
+
3582
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3583
+ const int ib_row = first_row * nb;
3584
+
3585
+ const uint i12 = im%ne12;
3586
+ const uint i13 = im/ne12;
3587
+
3588
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3589
+
3590
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
3591
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3592
+
3593
+ float yl[32];
3594
+ float sumf[N_DST]={0.f}, all_sum;
3595
+
3596
+ const int nb32 = nb * (QK_K / 32);
3597
+
3598
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
3599
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
3600
+ {
3601
+ int nval = 4;
3602
+ int pos = (32*sgitg + tiisg)*nval;
3603
+ for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
3604
+ nval = 2;
3605
+ pos = (32*sgitg + tiisg)*nval;
3606
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3607
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3608
+ }
3609
+
3610
+ #if QK_K == 256
3611
+ const int ix = tiisg;
3612
+
3613
+ device const float * y4 = y + 32 * ix;
3614
+
3615
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
3616
+
3617
+ for (int i = 0; i < 32; ++i) {
3618
+ yl[i] = y4[i];
3619
+ }
3620
+
3621
+ const int ibl = ib32 / (QK_K / 32);
3622
+ const int ib = ib32 % (QK_K / 32);
3623
+
3624
+ device const block_iq2_xxs * xr = x + ibl;
3625
+ device const uint16_t * q2 = xr->qs + 4 * ib;
3626
+ device const half * dh = &xr->d;
3627
+
3628
+ for (int row = 0; row < N_DST; row++) {
3629
+
3630
+ const float db = dh[0];
3631
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
3632
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
3633
+ const float d = db * (0.5f + (aux32 >> 28));
3634
+
3635
+ float sum = 0;
3636
+ for (int l = 0; l < 4; ++l) {
3637
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
3638
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
3639
+ for (int j = 0; j < 8; ++j) {
3640
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3641
+ }
3642
+ }
3643
+ sumf[row] += d * sum;
3644
+
3645
+ dh += nb*sizeof(block_iq2_xxs)/2;
3646
+ q2 += nb*sizeof(block_iq2_xxs)/2;
3647
+ }
3648
+
3649
+ y4 += 32 * 32;
3650
+ }
3651
+ #else
3652
+ // TODO
3653
+ #endif
3654
+
3655
+ for (int row = 0; row < N_DST; ++row) {
3656
+ all_sum = simd_sum(sumf[row]);
3657
+ if (tiisg == 0) {
3658
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
3659
+ }
3660
+ }
3661
+ }
3662
+
3663
+ [[host_name("kernel_mul_mv_iq2_xxs_f32")]]
3664
+ kernel void kernel_mul_mv_iq2_xxs_f32(
3665
+ device const void * src0,
3666
+ device const float * src1,
3667
+ device float * dst,
3668
+ constant int64_t & ne00,
3669
+ constant int64_t & ne01,
3670
+ constant int64_t & ne02,
3671
+ constant uint64_t & nb00,
3672
+ constant uint64_t & nb01,
3673
+ constant uint64_t & nb02,
3674
+ constant int64_t & ne10,
3675
+ constant int64_t & ne11,
3676
+ constant int64_t & ne12,
3677
+ constant uint64_t & nb10,
3678
+ constant uint64_t & nb11,
3679
+ constant uint64_t & nb12,
3680
+ constant int64_t & ne0,
3681
+ constant int64_t & ne1,
3682
+ constant uint & r2,
3683
+ constant uint & r3,
3684
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3685
+ uint3 tgpig[[threadgroup_position_in_grid]],
3686
+ uint tiisg[[thread_index_in_simdgroup]],
3687
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3688
+
3689
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3690
+ }
3691
+
3692
  //============================= templates and their specializations =============================
3693
 
3694
  // NOTE: this is not dequantizing - we are simply fitting the template
 
3960
  }
3961
  }
3962
 
3963
+ template <typename type4x4>
3964
+ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
3965
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
3966
+ const float d = xb->d;
3967
+ const int ib32 = il/2;
3968
+ il = il%2;
3969
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
3970
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
3971
+ device const uint16_t * q2 = xb->qs + 4*ib32;
3972
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
3973
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
3974
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ }
+
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
  kernel void kernel_get_rows(
  device const void * src0,

  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;

  //
  // matrix-matrix multiplication

  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;

  //
  // indirect matrix-matrix multiplication

  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;

  //
  // matrix-vector multiplication

  tiisg,
  sgitg);
  }
+
+ [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
+ kernel void kernel_mul_mv_id_iq2_xxs_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_iq2_xxs_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ shared_values,
+ tgpig,
+ tiisg,
+ sgitg);
+ }
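For orientation, a minimal C sketch of the IQ2_XXS bit layout that the Metal kernel above (and the CUDA and CPU kernels) decodes; this is an illustration, not part of the commit. Each group of four uint16 values from block_iq2_xxs.qs covers 32 weights: viewed as two uint32, the first word holds four byte-sized indices into the 256-entry iq2xxs_grid (eight small magnitudes per entry), while the second word packs four 7-bit sign indices in its low 28 bits and a 4-bit sub-block scale in its top 4 bits. The tables are the ones defined in ggml-quants.c below.

    #include <stdint.h>
    #include <string.h>

    // Unpack one 32-weight IQ2_XXS sub-block. q points at 4 consecutive uint16
    // from block_iq2_xxs.qs; d is the fp16 super-block scale converted to float.
    static void decode_iq2_xxs_subblock(const uint16_t q[4], float d, float out[32]) {
        uint32_t aux32[2];
        memcpy(aux32, q, 2*sizeof(uint32_t));                    // aux32[0]: grid indices, aux32[1]: signs + scale
        const uint8_t * aux8 = (const uint8_t *)aux32;           // the four grid indices, one byte each
        const float dl = d * (0.5f + (aux32[1] >> 28)) * 0.25f;  // 4-bit sub-block scale
        for (int l = 0; l < 4; ++l) {
            const uint8_t * grid  = (const uint8_t *)(iq2xxs_grid + aux8[l]);  // 8 magnitudes
            const uint8_t   signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];     // 7 stored signs + 1 parity sign
            for (int j = 0; j < 8; ++j) {
                out[8*l + j] = dl * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
            }
        }
    }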
ggml-quants.c CHANGED
@@ -2340,6 +2340,138 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t *
  return (n/QK_K*sizeof(block_q6_K));
  }

+ // ====================== "True" 2-bit (de)-quantization
+
+ void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
+ (void)x;
+ (void)y;
+ (void)k;
+ assert(k % QK_K == 0);
+ //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
+ }
+
+ static const uint64_t iq2xxs_grid[256] = {
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+ };
+
+ static const uint8_t ksigns_iq2xs[128] = {
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+ };
+ static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
+ const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ }
+ }
+ }
+
+ void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK_K == 0);
+ block_iq2_xxs * restrict y = vy;
+ quantize_row_iq2_xxs_reference(x, y, k);
+ }
+
+ size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
+ assert(k % QK_K == 0);
+ (void)hist; // TODO: collect histograms
+
+ for (int j = 0; j < n; j += k) {
+ block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
+ quantize_row_iq2_xxs_reference(src + j, y, k);
+ }
+ return (n/QK_K*sizeof(block_iq2_xxs));
+ }
+
  //===================================== Q8_K ==============================================

  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -2362,7 +2494,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict
  x += QK_K;
  continue;
  }
- const float iscale = -128.f/max;
+ //const float iscale = -128.f/max;
+ // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
+ const float iscale = -127.f/max;
  for (int j = 0; j < QK_K; ++j) {
  int v = nearest_int(iscale*x[j]);
  y[i].qs[j] = MIN(127, v);
@@ -7065,3 +7199,161 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  }

  #endif
+
+ static const int8_t keven_signs_q2xs[1024] = {
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
+ };
+
+ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_iq2_xxs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__ARM_NEON)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ int8x16x4_t q2u;
+ int8x16x4_t q2s;
+ int8x16x4_t q8b;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ float sumf1 = 0, sumf2 = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ q8b = vld1q_s8_x4(q8); q8 += 64;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
+ }
+ sumf += d*(sumf1 + sumf2);
+ }
+ *s = 0.25f * sumf;
+
+ #elif defined(__AVX2__)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[1] >> 28;
+ const uint16_t ls2 = aux32[3] >> 28;
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+ sumi1 = _mm256_add_epi32(sumi1, p1);
+ sumi2 = _mm256_add_epi32(sumi2, p2);
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+ #else
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
+ q2 += 4;
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.125f * sumf;
+ #endif
+ }
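A note on the two sign tables: each group of eight signs stores only seven bits, and the eighth sign is a parity bit chosen so that the count of negative signs is always even (hence "keven_signs"). ksigns_iq2xs packs the result as one byte, while keven_signs_q2xs expands the same byte into eight int8 values so the SIMD paths can fetch all eight signs with a single 64-bit load. The sketch below is illustrative only (it assumes GCC/Clang for __builtin_popcount and must live in the same translation unit as the static tables), but it reproduces both tables from that rule.

    #include <assert.h>
    #include <stdint.h>

    // Check that both sign tables follow the 7-bit-index + even-parity construction.
    static void check_sign_tables(void) {
        for (int n = 0; n < 128; ++n) {
            const int parity = __builtin_popcount(n) & 1;         // 1 if n has an odd number of set bits
            const uint8_t packed = (uint8_t)(n | (parity << 7));  // 8th sign restores even parity
            assert(packed == ksigns_iq2xs[n]);                    // byte-packed table
            for (int j = 0; j < 8; ++j) {
                const int8_t sgn = (packed & (1 << j)) ? -1 : 1;  // bit set -> negative weight
                assert(sgn == keven_signs_q2xs[8*n + j]);         // expanded int8 table
            }
        }
    }

Relatedly, the 0.25f and 0.125f factors on the final sums are the sub-block scale folded into integer math: (0.5 + s) * 0.25 == (2*s + 1) * 0.125, which is why the integer paths multiply by ls = 2*scale + 1 and rescale once at the end.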
ggml-quants.h CHANGED
@@ -165,6 +165,14 @@ typedef struct {
  } block_q8_K;
  static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");

+ // (Almost) "true" 2-bit quantization.
+ // Due to the need to use blocks as per ggml design, it ends up using
+ // 2.0625 bpw because of the 16-bit scale for each block of 256.
+ typedef struct {
+ ggml_fp16_t d;
+ uint16_t qs[QK_K/8];
+ } block_iq2_xxs;
+ static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");

  // Quantization
  void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
@@ -180,6 +188,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
  void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
  void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+ void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);

  void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
  void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@@ -194,6 +203,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
+ void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);

  // Dequantization
  void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@@ -209,6 +219,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
  void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
  void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
  void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
+ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);

  // Dot product
  void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -222,3 +233,4 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx,
  void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
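The 2.0625 bpw figure quoted in the block_iq2_xxs comment follows directly from the layout: 2 bytes of fp16 super-block scale plus 32 uint16 of packed data is 66 bytes per 256 weights, and 66 * 8 / 256 = 2.0625. A standalone check (illustrative only; it hard-codes QK_K = 256 and a uint16 stand-in for the real ggml_fp16_t typedef):

    #include <stdio.h>
    #include <stdint.h>

    #define QK_K 256
    typedef uint16_t ggml_fp16_t;   // stand-in with the same size as the real typedef

    typedef struct {
        ggml_fp16_t d;              //  2 bytes: fp16 super-block scale
        uint16_t qs[QK_K/8];        // 64 bytes: 32 x uint16 of grid/sign/scale data
    } block_iq2_xxs;

    int main(void) {
        // 66 bytes for 256 weights -> 2.0625 bits per weight
        printf("%zu bytes, %.4f bpw\n", sizeof(block_iq2_xxs), 8.0*sizeof(block_iq2_xxs)/QK_K);
        return 0;
    }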
ggml.c CHANGED
@@ -573,6 +573,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  .vec_dot = ggml_vec_dot_q6_K_q8_K,
  .vec_dot_type = GGML_TYPE_Q8_K,
  },
+ [GGML_TYPE_IQ2_XXS] = {
+ .type_name = "iq2_xxs",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_xxs),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
+ .from_float = quantize_row_iq2_xxs,
+ .from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
+ .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
  [GGML_TYPE_Q8_K] = {
  .type_name = "q8_K",
  .blck_size = QK_K,
@@ -2111,6 +2122,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
  case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
  case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
  case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
+ case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
  }
@@ -7436,6 +7448,7 @@ static void ggml_compute_forward_add(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  {
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
  } break;
@@ -7700,6 +7713,7 @@ static void ggml_compute_forward_add1(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  {
  ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
  } break;
@@ -7814,6 +7828,7 @@ static void ggml_compute_forward_acc(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  default:
  {
  GGML_ASSERT(false);
@@ -10455,6 +10470,7 @@ static void ggml_compute_forward_out_prod(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  {
  ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
  } break;
@@ -10629,6 +10645,7 @@ static void ggml_compute_forward_set(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  default:
  {
  GGML_ASSERT(false);
@@ -10823,6 +10840,7 @@ static void ggml_compute_forward_get_rows(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  {
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
  } break;
@@ -11459,6 +11477,7 @@ static void ggml_compute_forward_alibi(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_Q8_K:
  case GGML_TYPE_I8:
  case GGML_TYPE_I16:
@@ -11533,6 +11552,7 @@ static void ggml_compute_forward_clamp(
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_Q8_K:
  case GGML_TYPE_I8:
  case GGML_TYPE_I16:
@@ -18648,6 +18668,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
  block_q6_K * block = (block_q6_K*)dst + start / QK_K;
  result = ggml_quantize_q6_K(src + start, block, n, n, hist);
  } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ GGML_ASSERT(start % QK_K == 0);
+ block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
+ result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
+ } break;
  case GGML_TYPE_F16:
  {
  int elemsize = sizeof(ggml_fp16_t);
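The type_traits entry above is the hook through which the rest of ggml reaches the new type: dequantization, quantization and dot products all dispatch through these function pointers. A minimal usage sketch (not from the commit; it assumes the ggml_internal_get_type_traits accessor that ggml.h exports):

    #include "ggml.h"

    // Dequantize one row of IQ2_XXS data through the traits table.
    // n_per_row must be a multiple of the block size (QK_K = 256 for IQ2_XXS).
    static void dequantize_iq2_xxs_row(const void * quantized, float * dst, int n_per_row) {
        const ggml_type_traits_t traits = ggml_internal_get_type_traits(GGML_TYPE_IQ2_XXS);
        traits.to_float(quantized, dst, n_per_row);
    }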
ggml.h CHANGED
@@ -339,6 +339,7 @@ extern "C" {
  GGML_TYPE_Q5_K = 13,
  GGML_TYPE_Q6_K = 14,
  GGML_TYPE_Q8_K = 15,
+ GGML_TYPE_IQ2_XXS = 16,
  GGML_TYPE_I8,
  GGML_TYPE_I16,
  GGML_TYPE_I32,
@@ -373,6 +374,7 @@ extern "C" {
  GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
  GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
  GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
  };

  // available tensor operations:
@@ -2067,6 +2069,7 @@ extern "C" {
  GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
  GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
  GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);

  GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
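Putting the public pieces together, quantization of a whole tensor goes through ggml_quantize_chunk, which now routes GGML_TYPE_IQ2_XXS to ggml_quantize_iq2_xxs. A sketch of the call (illustrative only; note that in this commit quantize_row_iq2_xxs_reference is still a placeholder, so the plumbing is exercised but the produced quants are not yet meaningful):

    #include <stdint.h>
    #include "ggml.h"

    // Quantize nelements floats (a multiple of QK_K = 256) into IQ2_XXS blocks.
    static size_t quantize_iq2_xxs_tensor(const float * data, void * out, int nelements) {
        int64_t hist[16] = {0};  // histogram buffer expected by the quantize API
        return ggml_quantize_chunk(GGML_TYPE_IQ2_XXS, data, out, 0, nelements, hist);
    }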