Spaces:
Running
Running
1.5 bit quantization (llama/5453)
Browse files* iq1_s: WIP basics
* iq1_s: CUDA is working
* iq1_s: scalar CPU dot product
* iq1_s: WIP AVX2 dot product - something is not right
* Fix tests
* Fix shadow warnings
* Fix after merge with latest master
* iq1_s: AVX2 finally works
* iq1_s: ARM_NEON dot product. Works, but not very fast
* iq1_s: better grid
* iq1_s: use IQ2_XXS for attn_output
At a cost of 0.04 extra bpw this gives a big improvement in PPL.
* iq1_s: Metal basics
Dequantize works, but not dot product
* iq1_s: Metal works, but quite slow
As usual, Apple Silicon does not like the code I write.
* iq1_s: Tests
* iq1_s: slightly faster dot product
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
- ggml-backend.c +1 -1
- ggml-cuda.cu +223 -1
- ggml-metal.m +27 -2
- ggml-metal.metal +337 -0
- ggml-quants.c +627 -30
- ggml-quants.h +12 -2
- ggml.c +39 -5
- ggml.h +2 -0
ggml-backend.c
CHANGED
|
@@ -756,7 +756,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
|
|
| 756 |
GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
| 757 |
switch (op->op) {
|
| 758 |
case GGML_OP_CPY:
|
| 759 |
-
return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float
|
| 760 |
case GGML_OP_MUL_MAT:
|
| 761 |
return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
|
| 762 |
default:
|
|
|
|
| 756 |
GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
| 757 |
switch (op->op) {
|
| 758 |
case GGML_OP_CPY:
|
| 759 |
+
return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float
|
| 760 |
case GGML_OP_MUL_MAT:
|
| 761 |
return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
|
| 762 |
default:
|
ggml-cuda.cu
CHANGED
|
@@ -517,6 +517,15 @@ typedef struct {
|
|
| 517 |
} block_iq3_xxs;
|
| 518 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 519 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
#define WARP_SIZE 32
|
| 521 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
| 522 |
|
|
@@ -1681,6 +1690,137 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
|
|
| 1681 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 1682 |
};
|
| 1683 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1684 |
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
| 1685 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 1686 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -1823,6 +1963,29 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
|
|
| 1823 |
|
| 1824 |
}
|
| 1825 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1826 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 1827 |
|
| 1828 |
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
|
@@ -4522,6 +4685,49 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
|
| 4522 |
#endif
|
| 4523 |
}
|
| 4524 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4525 |
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
| 4526 |
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
| 4527 |
static __device__ __forceinline__ void mul_mat_q(
|
|
@@ -6561,6 +6767,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
|
|
| 6561 |
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6562 |
}
|
| 6563 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6564 |
template <typename src_t, typename dst_t>
|
| 6565 |
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
| 6566 |
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
|
@@ -6600,6 +6812,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
| 6600 |
return dequantize_row_iq2_xs_cuda;
|
| 6601 |
case GGML_TYPE_IQ3_XXS:
|
| 6602 |
return dequantize_row_iq3_xxs_cuda;
|
|
|
|
|
|
|
| 6603 |
case GGML_TYPE_F32:
|
| 6604 |
return convert_unary_cuda<float>;
|
| 6605 |
default:
|
|
@@ -6635,6 +6849,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|
| 6635 |
return dequantize_row_iq2_xs_cuda;
|
| 6636 |
case GGML_TYPE_IQ3_XXS:
|
| 6637 |
return dequantize_row_iq3_xxs_cuda;
|
|
|
|
|
|
|
| 6638 |
case GGML_TYPE_F16:
|
| 6639 |
return convert_unary_cuda<half>;
|
| 6640 |
default:
|
|
@@ -8378,6 +8594,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|
| 8378 |
case GGML_TYPE_IQ2_XXS:
|
| 8379 |
case GGML_TYPE_IQ2_XS:
|
| 8380 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 8381 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 8382 |
default:
|
| 8383 |
GGML_ASSERT(false);
|
|
@@ -8401,6 +8618,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|
| 8401 |
case GGML_TYPE_IQ2_XXS:
|
| 8402 |
case GGML_TYPE_IQ2_XS:
|
| 8403 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 8404 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 8405 |
case GGML_TYPE_Q6_K:
|
| 8406 |
return 64;
|
|
@@ -8498,6 +8716,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|
| 8498 |
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
| 8499 |
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
| 8500 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8501 |
default:
|
| 8502 |
GGML_ASSERT(false);
|
| 8503 |
break;
|
|
@@ -11214,7 +11436,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 11214 |
return false;
|
| 11215 |
}
|
| 11216 |
ggml_type a_type = a->type;
|
| 11217 |
-
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
|
| 11218 |
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
| 11219 |
return false;
|
| 11220 |
}
|
|
|
|
| 517 |
} block_iq3_xxs;
|
| 518 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 519 |
|
| 520 |
+
#define QR1_S 8
|
| 521 |
+
#define QI1_S (QK_K / (4*QR1_S))
|
| 522 |
+
typedef struct {
|
| 523 |
+
half d;
|
| 524 |
+
uint8_t qs[QK_K/8];
|
| 525 |
+
uint8_t scales[QK_K/16];
|
| 526 |
+
} block_iq1_s;
|
| 527 |
+
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
|
| 528 |
+
|
| 529 |
#define WARP_SIZE 32
|
| 530 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
| 531 |
|
|
|
|
| 1690 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 1691 |
};
|
| 1692 |
|
| 1693 |
+
static const __device__ uint64_t iq1s_grid[512] = {
|
| 1694 |
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
| 1695 |
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
| 1696 |
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
| 1697 |
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
| 1698 |
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
| 1699 |
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
| 1700 |
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
| 1701 |
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
| 1702 |
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
| 1703 |
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
| 1704 |
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
| 1705 |
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
| 1706 |
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
| 1707 |
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
| 1708 |
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
| 1709 |
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
| 1710 |
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
| 1711 |
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
| 1712 |
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
| 1713 |
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
| 1714 |
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
| 1715 |
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
| 1716 |
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
| 1717 |
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
| 1718 |
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
| 1719 |
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
| 1720 |
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
| 1721 |
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
| 1722 |
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
| 1723 |
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
| 1724 |
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
| 1725 |
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
| 1726 |
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
| 1727 |
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
| 1728 |
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
| 1729 |
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
| 1730 |
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
| 1731 |
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
| 1732 |
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
| 1733 |
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
| 1734 |
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
| 1735 |
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
| 1736 |
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
| 1737 |
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
| 1738 |
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
| 1739 |
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
| 1740 |
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
| 1741 |
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
| 1742 |
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
| 1743 |
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
| 1744 |
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
| 1745 |
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
| 1746 |
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
| 1747 |
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
| 1748 |
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
| 1749 |
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
| 1750 |
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
| 1751 |
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
| 1752 |
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
| 1753 |
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
| 1754 |
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
| 1755 |
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
| 1756 |
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
| 1757 |
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
| 1758 |
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
| 1759 |
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
| 1760 |
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
| 1761 |
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
| 1762 |
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
| 1763 |
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
| 1764 |
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
| 1765 |
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
| 1766 |
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
| 1767 |
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
| 1768 |
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
| 1769 |
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
| 1770 |
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
| 1771 |
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
| 1772 |
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
| 1773 |
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
| 1774 |
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
| 1775 |
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
| 1776 |
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
| 1777 |
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
| 1778 |
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
| 1779 |
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
| 1780 |
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
| 1781 |
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
| 1782 |
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
| 1783 |
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
| 1784 |
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
| 1785 |
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
| 1786 |
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
| 1787 |
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
| 1788 |
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
| 1789 |
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
| 1790 |
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
| 1791 |
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
| 1792 |
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
| 1793 |
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
| 1794 |
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
| 1795 |
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
| 1796 |
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
| 1797 |
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
| 1798 |
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
| 1799 |
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
| 1800 |
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
| 1801 |
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
| 1802 |
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
| 1803 |
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
| 1804 |
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
| 1805 |
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
| 1806 |
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
| 1807 |
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
| 1808 |
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
| 1809 |
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
| 1810 |
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
| 1811 |
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
| 1812 |
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
| 1813 |
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
| 1814 |
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
| 1815 |
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
| 1816 |
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
| 1817 |
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
| 1818 |
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
| 1819 |
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
| 1820 |
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
| 1821 |
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
| 1822 |
+
};
|
| 1823 |
+
|
| 1824 |
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
| 1825 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 1826 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 1963 |
|
| 1964 |
}
|
| 1965 |
|
| 1966 |
+
template<typename dst_t>
|
| 1967 |
+
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
| 1968 |
+
|
| 1969 |
+
const int i = blockIdx.x;
|
| 1970 |
+
const block_iq1_s * x = (const block_iq1_s *) vx;
|
| 1971 |
+
|
| 1972 |
+
const int tid = threadIdx.x;
|
| 1973 |
+
#if QK_K == 256
|
| 1974 |
+
const int il = tid/8; // 0...3
|
| 1975 |
+
const int ib = tid%8; // 0...7
|
| 1976 |
+
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 1977 |
+
const int i8 = 4*ib+il;
|
| 1978 |
+
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
|
| 1979 |
+
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
|
| 1980 |
+
const float d = (float)x[i].d * (2*(h & 7) + 1);
|
| 1981 |
+
for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
|
| 1982 |
+
#else
|
| 1983 |
+
assert(false);
|
| 1984 |
+
#endif
|
| 1985 |
+
|
| 1986 |
+
}
|
| 1987 |
+
|
| 1988 |
+
|
| 1989 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 1990 |
|
| 1991 |
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
|
|
|
| 4685 |
#endif
|
| 4686 |
}
|
| 4687 |
|
| 4688 |
+
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
| 4689 |
+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
| 4690 |
+
#if QK_K == 256
|
| 4691 |
+
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
|
| 4692 |
+
|
| 4693 |
+
const int ib32 = iqs;
|
| 4694 |
+
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
|
| 4695 |
+
const uint8_t h1 = bq1->scales[2*ib32+0];
|
| 4696 |
+
const uint8_t h2 = bq1->scales[2*ib32+1];
|
| 4697 |
+
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
| 4698 |
+
const int * q8 = (const int *)bq8_1[ib32].qs;
|
| 4699 |
+
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
|
| 4700 |
+
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
|
| 4701 |
+
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
|
| 4702 |
+
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
|
| 4703 |
+
for (int j = 0; j < 2; ++j) {
|
| 4704 |
+
sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
|
| 4705 |
+
sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
|
| 4706 |
+
sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
|
| 4707 |
+
sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
|
| 4708 |
+
}
|
| 4709 |
+
#else
|
| 4710 |
+
const int8_t * q8 = bq8_1[ib32].qs;
|
| 4711 |
+
const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
|
| 4712 |
+
const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
|
| 4713 |
+
const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
|
| 4714 |
+
const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
|
| 4715 |
+
for (int j = 0; j < 8; ++j) {
|
| 4716 |
+
sumi1 += q8[j+ 0] * grid1[j];
|
| 4717 |
+
sumi2 += q8[j+ 8] * grid2[j];
|
| 4718 |
+
sumi3 += q8[j+16] * grid3[j];
|
| 4719 |
+
sumi4 += q8[j+24] * grid4[j];
|
| 4720 |
+
}
|
| 4721 |
+
#endif
|
| 4722 |
+
const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
|
| 4723 |
+
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
|
| 4724 |
+
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
|
| 4725 |
+
#else
|
| 4726 |
+
assert(false);
|
| 4727 |
+
return 0.f;
|
| 4728 |
+
#endif
|
| 4729 |
+
}
|
| 4730 |
+
|
| 4731 |
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
| 4732 |
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
| 4733 |
static __device__ __forceinline__ void mul_mat_q(
|
|
|
|
| 6767 |
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6768 |
}
|
| 6769 |
|
| 6770 |
+
template<typename dst_t>
|
| 6771 |
+
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
| 6772 |
+
const int nb = k / QK_K;
|
| 6773 |
+
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
|
| 6774 |
+
}
|
| 6775 |
+
|
| 6776 |
template <typename src_t, typename dst_t>
|
| 6777 |
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
| 6778 |
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
|
|
|
| 6812 |
return dequantize_row_iq2_xs_cuda;
|
| 6813 |
case GGML_TYPE_IQ3_XXS:
|
| 6814 |
return dequantize_row_iq3_xxs_cuda;
|
| 6815 |
+
case GGML_TYPE_IQ1_S:
|
| 6816 |
+
return dequantize_row_iq1_s_cuda;
|
| 6817 |
case GGML_TYPE_F32:
|
| 6818 |
return convert_unary_cuda<float>;
|
| 6819 |
default:
|
|
|
|
| 6849 |
return dequantize_row_iq2_xs_cuda;
|
| 6850 |
case GGML_TYPE_IQ3_XXS:
|
| 6851 |
return dequantize_row_iq3_xxs_cuda;
|
| 6852 |
+
case GGML_TYPE_IQ1_S:
|
| 6853 |
+
return dequantize_row_iq1_s_cuda;
|
| 6854 |
case GGML_TYPE_F16:
|
| 6855 |
return convert_unary_cuda<half>;
|
| 6856 |
default:
|
|
|
|
| 8594 |
case GGML_TYPE_IQ2_XXS:
|
| 8595 |
case GGML_TYPE_IQ2_XS:
|
| 8596 |
case GGML_TYPE_IQ3_XXS:
|
| 8597 |
+
case GGML_TYPE_IQ1_S:
|
| 8598 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 8599 |
default:
|
| 8600 |
GGML_ASSERT(false);
|
|
|
|
| 8618 |
case GGML_TYPE_IQ2_XXS:
|
| 8619 |
case GGML_TYPE_IQ2_XS:
|
| 8620 |
case GGML_TYPE_IQ3_XXS:
|
| 8621 |
+
case GGML_TYPE_IQ1_S:
|
| 8622 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 8623 |
case GGML_TYPE_Q6_K:
|
| 8624 |
return 64;
|
|
|
|
| 8716 |
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
| 8717 |
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
| 8718 |
break;
|
| 8719 |
+
case GGML_TYPE_IQ1_S:
|
| 8720 |
+
mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
| 8721 |
+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
| 8722 |
+
break;
|
| 8723 |
default:
|
| 8724 |
GGML_ASSERT(false);
|
| 8725 |
break;
|
|
|
|
| 11436 |
return false;
|
| 11437 |
}
|
| 11438 |
ggml_type a_type = a->type;
|
| 11439 |
+
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) {
|
| 11440 |
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
| 11441 |
return false;
|
| 11442 |
}
|
ggml-metal.m
CHANGED
|
@@ -61,6 +61,7 @@ enum ggml_metal_kernel_type {
|
|
| 61 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
| 62 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
| 63 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
|
|
|
| 64 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
| 65 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 66 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -83,6 +84,7 @@ enum ggml_metal_kernel_type {
|
|
| 83 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
| 84 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
| 85 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
|
|
|
| 86 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
| 87 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
| 88 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
@@ -101,6 +103,7 @@ enum ggml_metal_kernel_type {
|
|
| 101 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
| 102 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
| 103 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
|
|
|
| 104 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
| 105 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
| 106 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
@@ -116,6 +119,7 @@ enum ggml_metal_kernel_type {
|
|
| 116 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
| 117 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
| 118 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
|
|
|
| 119 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
| 120 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
| 121 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
@@ -131,6 +135,7 @@ enum ggml_metal_kernel_type {
|
|
| 131 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
| 132 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
| 133 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
|
|
|
| 134 |
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
| 135 |
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
| 136 |
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
@@ -442,6 +447,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 442 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
| 443 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
| 444 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
|
|
|
| 445 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 446 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
| 447 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
@@ -464,6 +470,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 464 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 465 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 466 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 467 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
| 468 |
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
| 469 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
@@ -482,6 +489,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 482 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 483 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 484 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 485 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
| 486 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
| 487 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -497,6 +505,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 497 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 498 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 499 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 500 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
| 501 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
| 502 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -512,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 512 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 513 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 514 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 515 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
| 516 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
| 517 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
@@ -1327,6 +1337,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1327 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1328 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
| 1329 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
|
|
|
| 1330 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
| 1331 |
}
|
| 1332 |
|
|
@@ -1461,6 +1472,12 @@ static bool ggml_metal_graph_compute(
|
|
| 1461 |
nth1 = 16;
|
| 1462 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
| 1463 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1464 |
default:
|
| 1465 |
{
|
| 1466 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
@@ -1495,7 +1512,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1495 |
|
| 1496 |
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
| 1497 |
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
| 1498 |
-
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
| 1499 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1500 |
}
|
| 1501 |
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
@@ -1601,6 +1618,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1601 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
| 1602 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
| 1603 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
|
|
|
| 1604 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
| 1605 |
}
|
| 1606 |
|
|
@@ -1738,6 +1756,12 @@ static bool ggml_metal_graph_compute(
|
|
| 1738 |
nth1 = 16;
|
| 1739 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
| 1740 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1741 |
default:
|
| 1742 |
{
|
| 1743 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
@@ -1788,7 +1812,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1788 |
|
| 1789 |
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
| 1790 |
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
| 1791 |
-
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
| 1792 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1793 |
}
|
| 1794 |
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
|
@@ -1842,6 +1866,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1842 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
| 1843 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
| 1844 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
|
|
|
| 1845 |
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
| 1846 |
default: GGML_ASSERT(false && "not implemented");
|
| 1847 |
}
|
|
|
|
| 61 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
| 62 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
| 63 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
| 64 |
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
| 65 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
| 66 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 67 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
|
|
| 84 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
| 85 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
| 86 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
| 87 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
| 88 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
| 89 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
| 90 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
|
|
| 103 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
| 104 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
| 105 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
| 106 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
| 107 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
| 108 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
| 109 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
|
|
| 119 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
| 120 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
| 121 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
| 122 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
| 123 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
| 124 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
| 125 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
|
|
| 135 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
| 136 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
| 137 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
| 138 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
| 139 |
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
| 140 |
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
| 141 |
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
|
|
| 447 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
| 448 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
| 449 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
| 450 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
| 451 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 452 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
| 453 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
|
|
| 470 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 471 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 472 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 473 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 474 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
| 475 |
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
| 476 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 489 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 490 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 491 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 492 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 493 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
| 494 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
| 495 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 505 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 506 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 507 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 508 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 509 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
| 510 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
| 511 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 521 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 522 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 523 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 524 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 525 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
| 526 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
| 527 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
|
|
| 1337 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1338 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
| 1339 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
| 1340 |
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
| 1341 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
| 1342 |
}
|
| 1343 |
|
|
|
|
| 1472 |
nth1 = 16;
|
| 1473 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
| 1474 |
} break;
|
| 1475 |
+
case GGML_TYPE_IQ1_S:
|
| 1476 |
+
{
|
| 1477 |
+
nth0 = 4;
|
| 1478 |
+
nth1 = 16;
|
| 1479 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
| 1480 |
+
} break;
|
| 1481 |
default:
|
| 1482 |
{
|
| 1483 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
|
|
| 1512 |
|
| 1513 |
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
| 1514 |
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
| 1515 |
+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
|
| 1516 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1517 |
}
|
| 1518 |
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
|
|
| 1618 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
| 1619 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
| 1620 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
| 1621 |
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
| 1622 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
| 1623 |
}
|
| 1624 |
|
|
|
|
| 1756 |
nth1 = 16;
|
| 1757 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
| 1758 |
} break;
|
| 1759 |
+
case GGML_TYPE_IQ1_S:
|
| 1760 |
+
{
|
| 1761 |
+
nth0 = 4;
|
| 1762 |
+
nth1 = 16;
|
| 1763 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
| 1764 |
+
} break;
|
| 1765 |
default:
|
| 1766 |
{
|
| 1767 |
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
|
|
| 1812 |
|
| 1813 |
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
| 1814 |
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
| 1815 |
+
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
|
| 1816 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1817 |
}
|
| 1818 |
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
|
|
|
| 1866 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
| 1867 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
| 1868 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
| 1869 |
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
| 1870 |
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
| 1871 |
default: GGML_ASSERT(false && "not implemented");
|
| 1872 |
}
|
ggml-metal.metal
CHANGED
|
@@ -2525,6 +2525,13 @@ typedef struct {
|
|
| 2525 |
} block_iq3_xxs;
|
| 2526 |
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
| 2527 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2528 |
//====================================== dot products =========================
|
| 2529 |
|
| 2530 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3782,6 +3789,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
| 3782 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3783 |
};
|
| 3784 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3785 |
|
| 3786 |
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
| 3787 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
@@ -4208,6 +4346,123 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
| 4208 |
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4209 |
}
|
| 4210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4211 |
|
| 4212 |
//============================= templates and their specializations =============================
|
| 4213 |
|
|
@@ -4553,6 +4808,22 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
| 4553 |
}
|
| 4554 |
}
|
| 4555 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4556 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 4557 |
kernel void kernel_get_rows(
|
| 4558 |
device const void * src0,
|
|
@@ -5095,6 +5366,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
| 5095 |
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5096 |
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5097 |
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
|
| 5098 |
|
| 5099 |
//
|
| 5100 |
// matrix-matrix multiplication
|
|
@@ -5134,6 +5406,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
| 5134 |
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5135 |
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5136 |
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
|
| 5137 |
|
| 5138 |
//
|
| 5139 |
// indirect matrix-matrix multiplication
|
|
@@ -5185,6 +5458,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
| 5185 |
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5186 |
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5187 |
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
|
| 5188 |
|
| 5189 |
//
|
| 5190 |
// matrix-vector multiplication
|
|
@@ -6152,3 +6426,66 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
| 6152 |
tiisg,
|
| 6153 |
sgitg);
|
| 6154 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2525 |
} block_iq3_xxs;
|
| 2526 |
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
| 2527 |
|
| 2528 |
+
typedef struct {
|
| 2529 |
+
half d;
|
| 2530 |
+
uint8_t qs[QK_K/8];
|
| 2531 |
+
uint8_t scales[QK_K/16];
|
| 2532 |
+
} block_iq1_s;
|
| 2533 |
+
|
| 2534 |
+
|
| 2535 |
//====================================== dot products =========================
|
| 2536 |
|
| 2537 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
|
|
| 3789 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3790 |
};
|
| 3791 |
|
| 3792 |
+
#define NGRID_IQ1S 512
|
| 3793 |
+
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
| 3794 |
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
| 3795 |
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
| 3796 |
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
| 3797 |
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
| 3798 |
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
| 3799 |
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
| 3800 |
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
| 3801 |
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
| 3802 |
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
| 3803 |
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
| 3804 |
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
| 3805 |
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
| 3806 |
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
| 3807 |
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
| 3808 |
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
| 3809 |
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
| 3810 |
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
| 3811 |
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
| 3812 |
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
| 3813 |
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
| 3814 |
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
| 3815 |
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
| 3816 |
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
| 3817 |
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
| 3818 |
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
| 3819 |
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
| 3820 |
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
| 3821 |
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
| 3822 |
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
| 3823 |
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
| 3824 |
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
| 3825 |
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
| 3826 |
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
| 3827 |
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
| 3828 |
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
| 3829 |
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
| 3830 |
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
| 3831 |
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
| 3832 |
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
| 3833 |
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
| 3834 |
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
| 3835 |
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
| 3836 |
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
| 3837 |
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
| 3838 |
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
| 3839 |
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
| 3840 |
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
| 3841 |
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
| 3842 |
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
| 3843 |
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
| 3844 |
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
| 3845 |
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
| 3846 |
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
| 3847 |
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
| 3848 |
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
| 3849 |
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
| 3850 |
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
| 3851 |
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
| 3852 |
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
| 3853 |
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
| 3854 |
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
| 3855 |
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
| 3856 |
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
| 3857 |
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
| 3858 |
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
| 3859 |
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
| 3860 |
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
| 3861 |
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
| 3862 |
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
| 3863 |
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
| 3864 |
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
| 3865 |
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
| 3866 |
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
| 3867 |
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
| 3868 |
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
| 3869 |
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
| 3870 |
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
| 3871 |
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
| 3872 |
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
| 3873 |
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
| 3874 |
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
| 3875 |
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
| 3876 |
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
| 3877 |
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
| 3878 |
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
| 3879 |
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
| 3880 |
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
| 3881 |
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
| 3882 |
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
| 3883 |
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
| 3884 |
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
| 3885 |
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
| 3886 |
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
| 3887 |
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
| 3888 |
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
| 3889 |
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
| 3890 |
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
| 3891 |
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
| 3892 |
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
| 3893 |
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
| 3894 |
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
| 3895 |
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
| 3896 |
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
| 3897 |
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
| 3898 |
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
| 3899 |
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
| 3900 |
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
| 3901 |
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
| 3902 |
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
| 3903 |
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
| 3904 |
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
| 3905 |
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
| 3906 |
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
| 3907 |
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
| 3908 |
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
| 3909 |
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
| 3910 |
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
| 3911 |
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
| 3912 |
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
| 3913 |
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
| 3914 |
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
| 3915 |
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
| 3916 |
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
| 3917 |
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
| 3918 |
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
| 3919 |
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
| 3920 |
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
| 3921 |
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
| 3922 |
+
};
|
| 3923 |
|
| 3924 |
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
| 3925 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
|
|
| 4346 |
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4347 |
}
|
| 4348 |
|
| 4349 |
+
void kernel_mul_mv_iq1_s_f32_impl(
|
| 4350 |
+
device const void * src0,
|
| 4351 |
+
device const float * src1,
|
| 4352 |
+
device float * dst,
|
| 4353 |
+
constant int64_t & ne00,
|
| 4354 |
+
constant int64_t & ne01,
|
| 4355 |
+
constant int64_t & ne02,
|
| 4356 |
+
constant int64_t & ne10,
|
| 4357 |
+
constant int64_t & ne12,
|
| 4358 |
+
constant int64_t & ne0,
|
| 4359 |
+
constant int64_t & ne1,
|
| 4360 |
+
constant uint & r2,
|
| 4361 |
+
constant uint & r3,
|
| 4362 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4363 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 4364 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4365 |
+
|
| 4366 |
+
const int nb = ne00/QK_K;
|
| 4367 |
+
const int r0 = tgpig.x;
|
| 4368 |
+
const int r1 = tgpig.y;
|
| 4369 |
+
const int im = tgpig.z;
|
| 4370 |
+
|
| 4371 |
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4372 |
+
const int ib_row = first_row * nb;
|
| 4373 |
+
|
| 4374 |
+
const uint i12 = im%ne12;
|
| 4375 |
+
const uint i13 = im/ne12;
|
| 4376 |
+
|
| 4377 |
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
| 4378 |
+
|
| 4379 |
+
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
| 4380 |
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
| 4381 |
+
|
| 4382 |
+
float yl[16];
|
| 4383 |
+
float sumf[N_DST]={0.f}, all_sum;
|
| 4384 |
+
|
| 4385 |
+
const int nb32 = nb * (QK_K / 32);
|
| 4386 |
+
|
| 4387 |
+
#if QK_K == 256
|
| 4388 |
+
const int ix = tiisg/2;
|
| 4389 |
+
const int il = tiisg%2;
|
| 4390 |
+
|
| 4391 |
+
device const float * y4 = y + 32 * ix + 16 * il;
|
| 4392 |
+
|
| 4393 |
+
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
| 4394 |
+
|
| 4395 |
+
for (int i = 0; i < 16; ++i) {
|
| 4396 |
+
yl[i] = y4[i];
|
| 4397 |
+
}
|
| 4398 |
+
|
| 4399 |
+
const int ibl = ib32 / (QK_K / 32);
|
| 4400 |
+
const int ib = ib32 % (QK_K / 32);
|
| 4401 |
+
|
| 4402 |
+
device const block_iq1_s * xr = x + ibl;
|
| 4403 |
+
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
| 4404 |
+
device const uint8_t * sc = xr->scales + 2 * ib + il;
|
| 4405 |
+
device const half * dh = &xr->d;
|
| 4406 |
+
|
| 4407 |
+
for (int row = 0; row < N_DST; row++) {
|
| 4408 |
+
|
| 4409 |
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
| 4410 |
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
| 4411 |
+
|
| 4412 |
+
float2 sum = {0};
|
| 4413 |
+
for (int j = 0; j < 8; ++j) {
|
| 4414 |
+
sum[0] += yl[j+ 0] * grid1[j];
|
| 4415 |
+
sum[1] += yl[j+ 8] * grid2[j];
|
| 4416 |
+
}
|
| 4417 |
+
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
|
| 4418 |
+
|
| 4419 |
+
dh += nb*sizeof(block_iq1_s)/2;
|
| 4420 |
+
qs += nb*sizeof(block_iq1_s);
|
| 4421 |
+
sc += nb*sizeof(block_iq1_s);
|
| 4422 |
+
}
|
| 4423 |
+
|
| 4424 |
+
y4 += 16 * 32;
|
| 4425 |
+
}
|
| 4426 |
+
#else
|
| 4427 |
+
// TODO
|
| 4428 |
+
#endif
|
| 4429 |
+
|
| 4430 |
+
for (int row = 0; row < N_DST; ++row) {
|
| 4431 |
+
all_sum = simd_sum(sumf[row]);
|
| 4432 |
+
if (tiisg == 0) {
|
| 4433 |
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
| 4434 |
+
}
|
| 4435 |
+
}
|
| 4436 |
+
}
|
| 4437 |
+
|
| 4438 |
+
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
| 4439 |
+
kernel void kernel_mul_mv_iq1_s_f32(
|
| 4440 |
+
device const void * src0,
|
| 4441 |
+
device const float * src1,
|
| 4442 |
+
device float * dst,
|
| 4443 |
+
constant int64_t & ne00,
|
| 4444 |
+
constant int64_t & ne01,
|
| 4445 |
+
constant int64_t & ne02,
|
| 4446 |
+
constant uint64_t & nb00,
|
| 4447 |
+
constant uint64_t & nb01,
|
| 4448 |
+
constant uint64_t & nb02,
|
| 4449 |
+
constant int64_t & ne10,
|
| 4450 |
+
constant int64_t & ne11,
|
| 4451 |
+
constant int64_t & ne12,
|
| 4452 |
+
constant uint64_t & nb10,
|
| 4453 |
+
constant uint64_t & nb11,
|
| 4454 |
+
constant uint64_t & nb12,
|
| 4455 |
+
constant int64_t & ne0,
|
| 4456 |
+
constant int64_t & ne1,
|
| 4457 |
+
constant uint & r2,
|
| 4458 |
+
constant uint & r3,
|
| 4459 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4460 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 4461 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4462 |
+
|
| 4463 |
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 4464 |
+
}
|
| 4465 |
+
|
| 4466 |
|
| 4467 |
//============================= templates and their specializations =============================
|
| 4468 |
|
|
|
|
| 4808 |
}
|
| 4809 |
}
|
| 4810 |
|
| 4811 |
+
template <typename type4x4>
|
| 4812 |
+
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
| 4813 |
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
| 4814 |
+
const float d = xb->d;
|
| 4815 |
+
device const uint8_t * qs = xb->qs + 2*il;
|
| 4816 |
+
device const uint8_t * sc = xb->scales + il;
|
| 4817 |
+
const float dl1 = d * (2*(sc[0] & 7) + 1);
|
| 4818 |
+
const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
|
| 4819 |
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
| 4820 |
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
| 4821 |
+
for (int i = 0; i < 8; ++i) {
|
| 4822 |
+
reg[i/4+0][i%4] = dl1 * grid1[i];
|
| 4823 |
+
reg[i/4+2][i%4] = dl2 * grid2[i];
|
| 4824 |
+
}
|
| 4825 |
+
}
|
| 4826 |
+
|
| 4827 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 4828 |
kernel void kernel_get_rows(
|
| 4829 |
device const void * src0,
|
|
|
|
| 5366 |
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5367 |
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5368 |
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5369 |
+
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5370 |
|
| 5371 |
//
|
| 5372 |
// matrix-matrix multiplication
|
|
|
|
| 5406 |
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5407 |
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5408 |
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5409 |
+
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5410 |
|
| 5411 |
//
|
| 5412 |
// indirect matrix-matrix multiplication
|
|
|
|
| 5458 |
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5459 |
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5460 |
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5461 |
+
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5462 |
|
| 5463 |
//
|
| 5464 |
// matrix-vector multiplication
|
|
|
|
| 6426 |
tiisg,
|
| 6427 |
sgitg);
|
| 6428 |
}
|
| 6429 |
+
|
| 6430 |
+
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
| 6431 |
+
kernel void kernel_mul_mv_id_iq1_s_f32(
|
| 6432 |
+
device const char * ids,
|
| 6433 |
+
device const char * src1,
|
| 6434 |
+
device float * dst,
|
| 6435 |
+
constant uint64_t & nbi1,
|
| 6436 |
+
constant int64_t & ne00,
|
| 6437 |
+
constant int64_t & ne01,
|
| 6438 |
+
constant int64_t & ne02,
|
| 6439 |
+
constant uint64_t & nb00,
|
| 6440 |
+
constant uint64_t & nb01,
|
| 6441 |
+
constant uint64_t & nb02,
|
| 6442 |
+
constant int64_t & ne10,
|
| 6443 |
+
constant int64_t & ne11,
|
| 6444 |
+
constant int64_t & ne12,
|
| 6445 |
+
constant int64_t & ne13,
|
| 6446 |
+
constant uint64_t & nb10,
|
| 6447 |
+
constant uint64_t & nb11,
|
| 6448 |
+
constant uint64_t & nb12,
|
| 6449 |
+
constant int64_t & ne0,
|
| 6450 |
+
constant int64_t & ne1,
|
| 6451 |
+
constant uint64_t & nb1,
|
| 6452 |
+
constant uint & r2,
|
| 6453 |
+
constant uint & r3,
|
| 6454 |
+
constant int & idx,
|
| 6455 |
+
device const char * src00,
|
| 6456 |
+
device const char * src01,
|
| 6457 |
+
device const char * src02,
|
| 6458 |
+
device const char * src03,
|
| 6459 |
+
device const char * src04,
|
| 6460 |
+
device const char * src05,
|
| 6461 |
+
device const char * src06,
|
| 6462 |
+
device const char * src07,
|
| 6463 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6464 |
+
uint tiitg[[thread_index_in_threadgroup]],
|
| 6465 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 6466 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6467 |
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
| 6468 |
+
|
| 6469 |
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6470 |
+
|
| 6471 |
+
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6472 |
+
|
| 6473 |
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6474 |
+
|
| 6475 |
+
kernel_mul_mv_iq1_s_f32_impl(
|
| 6476 |
+
src0[id],
|
| 6477 |
+
(device const float *) (src1 + bid*nb11),
|
| 6478 |
+
dst + bid*ne0,
|
| 6479 |
+
ne00,
|
| 6480 |
+
ne01,
|
| 6481 |
+
ne02,
|
| 6482 |
+
ne10,
|
| 6483 |
+
ne12,
|
| 6484 |
+
ne0,
|
| 6485 |
+
ne1,
|
| 6486 |
+
r2,
|
| 6487 |
+
r3,
|
| 6488 |
+
tgpig,
|
| 6489 |
+
tiisg,
|
| 6490 |
+
sgitg);
|
| 6491 |
+
}
|
ggml-quants.c
CHANGED
|
@@ -3480,6 +3480,139 @@ static const uint32_t iq3xxs_grid[256] = {
|
|
| 3480 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3481 |
};
|
| 3482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3483 |
static const uint8_t ksigns_iq2xs[128] = {
|
| 3484 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3485 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -3578,6 +3711,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
|
|
| 3578 |
}
|
| 3579 |
}
|
| 3580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3581 |
//===================================== Q8_K ==============================================
|
| 3582 |
|
| 3583 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
|
@@ -3679,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) {
|
|
| 3679 |
}
|
| 3680 |
#endif
|
| 3681 |
|
| 3682 |
-
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t
|
| 3683 |
const int qk = QK8_0;
|
| 3684 |
const int nb = n / qk;
|
| 3685 |
|
|
@@ -3690,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
| 3690 |
assert(nrc == 1);
|
| 3691 |
#endif
|
| 3692 |
UNUSED(nrc);
|
| 3693 |
-
UNUSED(
|
| 3694 |
-
UNUSED(
|
| 3695 |
UNUSED(bs);
|
| 3696 |
|
| 3697 |
const block_q4_0 * restrict x = vx;
|
|
@@ -4046,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
| 4046 |
#endif
|
| 4047 |
}
|
| 4048 |
|
| 4049 |
-
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t
|
| 4050 |
const int qk = QK8_1;
|
| 4051 |
const int nb = n / qk;
|
| 4052 |
|
|
@@ -4057,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|
| 4057 |
assert(nrc == 1);
|
| 4058 |
#endif
|
| 4059 |
UNUSED(nrc);
|
| 4060 |
-
UNUSED(
|
| 4061 |
-
UNUSED(
|
| 4062 |
UNUSED(bs);
|
| 4063 |
|
| 4064 |
const block_q4_1 * restrict x = vx;
|
|
@@ -4264,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|
| 4264 |
#endif
|
| 4265 |
}
|
| 4266 |
|
| 4267 |
-
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t
|
| 4268 |
const int qk = QK8_0;
|
| 4269 |
const int nb = n / qk;
|
| 4270 |
|
|
@@ -4272,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
| 4272 |
assert(qk == QK5_0);
|
| 4273 |
assert(nrc == 1);
|
| 4274 |
UNUSED(nrc);
|
| 4275 |
-
UNUSED(
|
| 4276 |
-
UNUSED(
|
| 4277 |
UNUSED(bs);
|
| 4278 |
|
| 4279 |
const block_q5_0 * restrict x = vx;
|
|
@@ -4555,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
| 4555 |
#endif
|
| 4556 |
}
|
| 4557 |
|
| 4558 |
-
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t
|
| 4559 |
const int qk = QK8_1;
|
| 4560 |
const int nb = n / qk;
|
| 4561 |
|
|
@@ -4563,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|
| 4563 |
assert(qk == QK5_1);
|
| 4564 |
assert(nrc == 1);
|
| 4565 |
UNUSED(nrc);
|
| 4566 |
-
UNUSED(
|
| 4567 |
-
UNUSED(
|
| 4568 |
UNUSED(bs);
|
| 4569 |
|
| 4570 |
const block_q5_1 * restrict x = vx;
|
|
@@ -4859,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|
| 4859 |
#endif
|
| 4860 |
}
|
| 4861 |
|
| 4862 |
-
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t
|
| 4863 |
const int qk = QK8_0;
|
| 4864 |
const int nb = n / qk;
|
| 4865 |
|
|
@@ -4870,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
| 4870 |
assert(nrc == 1);
|
| 4871 |
#endif
|
| 4872 |
UNUSED(nrc);
|
| 4873 |
-
UNUSED(
|
| 4874 |
-
UNUSED(
|
| 4875 |
UNUSED(bs);
|
| 4876 |
|
| 4877 |
const block_q8_0 * restrict x = vx;
|
|
@@ -9107,6 +9283,178 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
|
|
| 9107 |
#endif
|
| 9108 |
}
|
| 9109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9110 |
// ================================ IQ2 quantization =============================================
|
| 9111 |
|
| 9112 |
typedef struct {
|
|
@@ -9115,14 +9463,22 @@ typedef struct {
|
|
| 9115 |
uint16_t * neighbours;
|
| 9116 |
} iq2_entry_t;
|
| 9117 |
|
| 9118 |
-
static iq2_entry_t iq2_data[
|
|
|
|
| 9119 |
{NULL, NULL, NULL},
|
| 9120 |
{NULL, NULL, NULL},
|
| 9121 |
};
|
| 9122 |
|
| 9123 |
-
static inline int iq2_data_index(
|
| 9124 |
-
GGML_ASSERT(
|
| 9125 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9126 |
}
|
| 9127 |
|
| 9128 |
static int iq2_compare_func(const void * left, const void * right) {
|
|
@@ -9131,12 +9487,13 @@ static int iq2_compare_func(const void * left, const void * right) {
|
|
| 9131 |
return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
|
| 9132 |
}
|
| 9133 |
|
| 9134 |
-
void iq2xs_init_impl(
|
| 9135 |
-
const int gindex = iq2_data_index(
|
|
|
|
| 9136 |
if (iq2_data[gindex].grid) {
|
| 9137 |
return;
|
| 9138 |
}
|
| 9139 |
-
static const uint16_t
|
| 9140 |
0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
|
| 9141 |
100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
|
| 9142 |
1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
|
|
@@ -9154,7 +9511,7 @@ void iq2xs_init_impl(int grid_size) {
|
|
| 9154 |
33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
|
| 9155 |
37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
|
| 9156 |
};
|
| 9157 |
-
static const uint16_t
|
| 9158 |
0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
|
| 9159 |
73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
|
| 9160 |
260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
|
|
@@ -9188,9 +9545,45 @@ void iq2xs_init_impl(int grid_size) {
|
|
| 9188 |
40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
|
| 9189 |
42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
|
| 9190 |
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9191 |
const int kmap_size = 43692;
|
| 9192 |
-
const int nwant = 2;
|
| 9193 |
-
const uint16_t * kgrid =
|
|
|
|
| 9194 |
uint64_t * kgrid_q2xs;
|
| 9195 |
int * kmap_q2xs;
|
| 9196 |
uint16_t * kneighbors_q2xs;
|
|
@@ -9286,9 +9679,9 @@ void iq2xs_init_impl(int grid_size) {
|
|
| 9286 |
free(dist2);
|
| 9287 |
}
|
| 9288 |
|
| 9289 |
-
void iq2xs_free_impl(
|
| 9290 |
-
GGML_ASSERT(
|
| 9291 |
-
const int gindex = iq2_data_index(
|
| 9292 |
if (iq2_data[gindex].grid) {
|
| 9293 |
free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
|
| 9294 |
free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
|
|
@@ -9322,7 +9715,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
|
|
| 9322 |
|
| 9323 |
static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
| 9324 |
|
| 9325 |
-
const int gindex = iq2_data_index(
|
| 9326 |
|
| 9327 |
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
| 9328 |
const int * kmap_q2xs = iq2_data[gindex].map;
|
|
@@ -9495,7 +9888,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
| 9495 |
|
| 9496 |
static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
| 9497 |
|
| 9498 |
-
const int gindex = iq2_data_index(
|
| 9499 |
|
| 9500 |
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
| 9501 |
const int * kmap_q2xs = iq2_data[gindex].map;
|
|
@@ -10132,3 +10525,207 @@ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * re
|
|
| 10132 |
assert(k % QK_K == 0);
|
| 10133 |
quantize_row_iq3_xxs_impl(x, y, k, NULL);
|
| 10134 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3480 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3481 |
};
|
| 3482 |
|
| 3483 |
+
#define NGRID_IQ2XXS 512
|
| 3484 |
+
static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
|
| 3485 |
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
| 3486 |
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
| 3487 |
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
| 3488 |
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
| 3489 |
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
| 3490 |
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
| 3491 |
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
| 3492 |
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
| 3493 |
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
| 3494 |
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
| 3495 |
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
| 3496 |
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
| 3497 |
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
| 3498 |
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
| 3499 |
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
| 3500 |
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
| 3501 |
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
| 3502 |
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
| 3503 |
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
| 3504 |
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
| 3505 |
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
| 3506 |
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
| 3507 |
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
| 3508 |
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
| 3509 |
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
| 3510 |
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
| 3511 |
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
| 3512 |
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
| 3513 |
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
| 3514 |
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
| 3515 |
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
| 3516 |
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
| 3517 |
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
| 3518 |
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
| 3519 |
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
| 3520 |
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
| 3521 |
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
| 3522 |
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
| 3523 |
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
| 3524 |
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
| 3525 |
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
| 3526 |
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
| 3527 |
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
| 3528 |
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
| 3529 |
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
| 3530 |
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
| 3531 |
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
| 3532 |
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
| 3533 |
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
| 3534 |
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
| 3535 |
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
| 3536 |
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
| 3537 |
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
| 3538 |
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
| 3539 |
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
| 3540 |
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
| 3541 |
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
| 3542 |
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
| 3543 |
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
| 3544 |
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
| 3545 |
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
| 3546 |
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
| 3547 |
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
| 3548 |
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
| 3549 |
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
| 3550 |
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
| 3551 |
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
| 3552 |
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
| 3553 |
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
| 3554 |
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
| 3555 |
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
| 3556 |
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
| 3557 |
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
| 3558 |
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
| 3559 |
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
| 3560 |
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
| 3561 |
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
| 3562 |
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
| 3563 |
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
| 3564 |
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
| 3565 |
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
| 3566 |
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
| 3567 |
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
| 3568 |
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
| 3569 |
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
| 3570 |
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
| 3571 |
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
| 3572 |
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
| 3573 |
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
| 3574 |
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
| 3575 |
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
| 3576 |
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
| 3577 |
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
| 3578 |
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
| 3579 |
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
| 3580 |
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
| 3581 |
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
| 3582 |
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
| 3583 |
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
| 3584 |
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
| 3585 |
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
| 3586 |
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
| 3587 |
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
| 3588 |
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
| 3589 |
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
| 3590 |
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
| 3591 |
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
| 3592 |
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
| 3593 |
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
| 3594 |
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
| 3595 |
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
| 3596 |
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
| 3597 |
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
| 3598 |
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
| 3599 |
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
| 3600 |
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
| 3601 |
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
| 3602 |
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
| 3603 |
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
| 3604 |
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
| 3605 |
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
| 3606 |
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
| 3607 |
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
| 3608 |
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
| 3609 |
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
| 3610 |
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
| 3611 |
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
| 3612 |
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
| 3613 |
+
|
| 3614 |
+
};
|
| 3615 |
+
|
| 3616 |
static const uint8_t ksigns_iq2xs[128] = {
|
| 3617 |
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
| 3618 |
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
|
|
| 3711 |
}
|
| 3712 |
}
|
| 3713 |
|
| 3714 |
+
// ====================== 1.5625 bpw (de)-quantization
|
| 3715 |
+
|
| 3716 |
+
void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
|
| 3717 |
+
assert(k % QK_K == 0);
|
| 3718 |
+
const int nb = k / QK_K;
|
| 3719 |
+
|
| 3720 |
+
float db[4];
|
| 3721 |
+
uint16_t idx[4];
|
| 3722 |
+
//const int8_t * grid[4];
|
| 3723 |
+
|
| 3724 |
+
for (int i = 0; i < nb; i++) {
|
| 3725 |
+
|
| 3726 |
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
| 3727 |
+
const uint8_t * sc = x[i].scales;
|
| 3728 |
+
const uint8_t * qs = x[i].qs;
|
| 3729 |
+
|
| 3730 |
+
for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
|
| 3731 |
+
idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
|
| 3732 |
+
idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
|
| 3733 |
+
idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
|
| 3734 |
+
idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
|
| 3735 |
+
//grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
| 3736 |
+
//grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
| 3737 |
+
//grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
|
| 3738 |
+
//grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
|
| 3739 |
+
db[0] = d * (2*(sc[0] & 7) + 1);
|
| 3740 |
+
db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
|
| 3741 |
+
db[2] = d * (2*(sc[1] & 7) + 1);
|
| 3742 |
+
db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
|
| 3743 |
+
for (int l = 0; l < 4; ++l) {
|
| 3744 |
+
const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
|
| 3745 |
+
for (int j = 0; j < 8; ++j) {
|
| 3746 |
+
//y[j] = db[l] * grid[l][j];
|
| 3747 |
+
y[j] = db[l] * grid[j];
|
| 3748 |
+
}
|
| 3749 |
+
y += 8;
|
| 3750 |
+
}
|
| 3751 |
+
qs += 4;
|
| 3752 |
+
sc += 2;
|
| 3753 |
+
}
|
| 3754 |
+
}
|
| 3755 |
+
}
|
| 3756 |
+
|
| 3757 |
//===================================== Q8_K ==============================================
|
| 3758 |
|
| 3759 |
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
|
|
|
| 3855 |
}
|
| 3856 |
#endif
|
| 3857 |
|
| 3858 |
+
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
| 3859 |
const int qk = QK8_0;
|
| 3860 |
const int nb = n / qk;
|
| 3861 |
|
|
|
|
| 3866 |
assert(nrc == 1);
|
| 3867 |
#endif
|
| 3868 |
UNUSED(nrc);
|
| 3869 |
+
UNUSED(bbx);
|
| 3870 |
+
UNUSED(bby);
|
| 3871 |
UNUSED(bs);
|
| 3872 |
|
| 3873 |
const block_q4_0 * restrict x = vx;
|
|
|
|
| 4222 |
#endif
|
| 4223 |
}
|
| 4224 |
|
| 4225 |
+
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
| 4226 |
const int qk = QK8_1;
|
| 4227 |
const int nb = n / qk;
|
| 4228 |
|
|
|
|
| 4233 |
assert(nrc == 1);
|
| 4234 |
#endif
|
| 4235 |
UNUSED(nrc);
|
| 4236 |
+
UNUSED(bbx);
|
| 4237 |
+
UNUSED(bby);
|
| 4238 |
UNUSED(bs);
|
| 4239 |
|
| 4240 |
const block_q4_1 * restrict x = vx;
|
|
|
|
| 4440 |
#endif
|
| 4441 |
}
|
| 4442 |
|
| 4443 |
+
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
| 4444 |
const int qk = QK8_0;
|
| 4445 |
const int nb = n / qk;
|
| 4446 |
|
|
|
|
| 4448 |
assert(qk == QK5_0);
|
| 4449 |
assert(nrc == 1);
|
| 4450 |
UNUSED(nrc);
|
| 4451 |
+
UNUSED(bbx);
|
| 4452 |
+
UNUSED(bby);
|
| 4453 |
UNUSED(bs);
|
| 4454 |
|
| 4455 |
const block_q5_0 * restrict x = vx;
|
|
|
|
| 4731 |
#endif
|
| 4732 |
}
|
| 4733 |
|
| 4734 |
+
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
| 4735 |
const int qk = QK8_1;
|
| 4736 |
const int nb = n / qk;
|
| 4737 |
|
|
|
|
| 4739 |
assert(qk == QK5_1);
|
| 4740 |
assert(nrc == 1);
|
| 4741 |
UNUSED(nrc);
|
| 4742 |
+
UNUSED(bbx);
|
| 4743 |
+
UNUSED(bby);
|
| 4744 |
UNUSED(bs);
|
| 4745 |
|
| 4746 |
const block_q5_1 * restrict x = vx;
|
|
|
|
| 5035 |
#endif
|
| 5036 |
}
|
| 5037 |
|
| 5038 |
+
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
| 5039 |
const int qk = QK8_0;
|
| 5040 |
const int nb = n / qk;
|
| 5041 |
|
|
|
|
| 5046 |
assert(nrc == 1);
|
| 5047 |
#endif
|
| 5048 |
UNUSED(nrc);
|
| 5049 |
+
UNUSED(bbx);
|
| 5050 |
+
UNUSED(bby);
|
| 5051 |
UNUSED(bs);
|
| 5052 |
|
| 5053 |
const block_q8_0 * restrict x = vx;
|
|
|
|
| 9283 |
#endif
|
| 9284 |
}
|
| 9285 |
|
| 9286 |
+
#ifdef __AVX2__
|
| 9287 |
+
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
| 9288 |
+
const __m256i ax = _mm256_sign_epi8(x, x);
|
| 9289 |
+
const __m256i sy = _mm256_sign_epi8(y, x);
|
| 9290 |
+
return _mm256_maddubs_epi16(ax, sy);
|
| 9291 |
+
}
|
| 9292 |
+
#endif
|
| 9293 |
+
|
| 9294 |
+
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
| 9295 |
+
assert(n % QK_K == 0);
|
| 9296 |
+
assert(nrc == 1);
|
| 9297 |
+
UNUSED(nrc);
|
| 9298 |
+
UNUSED(bx);
|
| 9299 |
+
UNUSED(by);
|
| 9300 |
+
UNUSED(bs);
|
| 9301 |
+
|
| 9302 |
+
const block_iq1_s * restrict x = vx;
|
| 9303 |
+
const block_q8_K * restrict y = vy;
|
| 9304 |
+
|
| 9305 |
+
const int nb = n / QK_K;
|
| 9306 |
+
|
| 9307 |
+
#if defined __ARM_NEON
|
| 9308 |
+
|
| 9309 |
+
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
| 9310 |
+
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
| 9311 |
+
const uint8x16_t m1 = vdupq_n_u8(0x01);
|
| 9312 |
+
const int32x4_t vzero = vdupq_n_s32(0);
|
| 9313 |
+
|
| 9314 |
+
uint16_t gindex[8];
|
| 9315 |
+
uint16x8x2_t vindex;
|
| 9316 |
+
int8x16x4_t q1b;
|
| 9317 |
+
int8x16x4_t q8b;
|
| 9318 |
+
uint16x8x4_t scales;
|
| 9319 |
+
int32x4x2_t sumi;
|
| 9320 |
+
int32x4x2_t dotq;
|
| 9321 |
+
|
| 9322 |
+
float sumf = 0;
|
| 9323 |
+
for (int i = 0; i < nb; ++i) {
|
| 9324 |
+
|
| 9325 |
+
const int8_t * q8 = y[i].qs;
|
| 9326 |
+
const uint8_t * qs = x[i].qs;
|
| 9327 |
+
const uint8_t * sc = x[i].scales;
|
| 9328 |
+
|
| 9329 |
+
sumi.val[0] = sumi.val[1] = vzero;
|
| 9330 |
+
|
| 9331 |
+
for (int i128 = 0; i128 < QK_K/128; ++i128) {
|
| 9332 |
+
const uint8x16_t ql = vld1q_u8(qs); qs += 16;
|
| 9333 |
+
const uint8x8_t tm1 = vld1_u8 (sc); sc += 8;
|
| 9334 |
+
const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
|
| 9335 |
+
const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
|
| 9336 |
+
const uint8x16_t hbit = vandq_u8(qh, m8);
|
| 9337 |
+
vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
|
| 9338 |
+
vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
|
| 9339 |
+
const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
|
| 9340 |
+
scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
|
| 9341 |
+
scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
|
| 9342 |
+
|
| 9343 |
+
for (int l = 0; l < 2; ++l) {
|
| 9344 |
+
vst1q_u16(gindex+0, vindex.val[l]);
|
| 9345 |
+
q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
|
| 9346 |
+
q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
|
| 9347 |
+
q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
|
| 9348 |
+
q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
|
| 9349 |
+
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
| 9350 |
+
|
| 9351 |
+
dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
|
| 9352 |
+
dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
|
| 9353 |
+
|
| 9354 |
+
sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
|
| 9355 |
+
sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
|
| 9356 |
+
}
|
| 9357 |
+
}
|
| 9358 |
+
|
| 9359 |
+
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
|
| 9360 |
+
}
|
| 9361 |
+
|
| 9362 |
+
*s = sumf;
|
| 9363 |
+
|
| 9364 |
+
#elif defined __AVX2__
|
| 9365 |
+
|
| 9366 |
+
const __m128i m8 = _mm_set1_epi8(0x08);
|
| 9367 |
+
const __m128i m7 = _mm_set1_epi8(0x07);
|
| 9368 |
+
const __m128i m1 = _mm_set1_epi8(0x01);
|
| 9369 |
+
const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
|
| 9370 |
+
const __m128i shuffle_s[4] = {
|
| 9371 |
+
_mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
|
| 9372 |
+
_mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
|
| 9373 |
+
_mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
|
| 9374 |
+
_mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
|
| 9375 |
+
};
|
| 9376 |
+
|
| 9377 |
+
uint64_t aux64;
|
| 9378 |
+
|
| 9379 |
+
__m256i v_gindex;
|
| 9380 |
+
const uint16_t * gindex = (const uint16_t *)&v_gindex;
|
| 9381 |
+
|
| 9382 |
+
__m256 accum = _mm256_setzero_ps();
|
| 9383 |
+
for (int i = 0; i < nb; ++i) {
|
| 9384 |
+
|
| 9385 |
+
const int8_t * q8 = y[i].qs;
|
| 9386 |
+
const uint8_t * qs = x[i].qs;
|
| 9387 |
+
const uint8_t * sc = x[i].scales;
|
| 9388 |
+
|
| 9389 |
+
__m256i sumi = _mm256_setzero_si256();
|
| 9390 |
+
for (int i128 = 0; i128 < QK_K/128; ++i128) {
|
| 9391 |
+
const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
|
| 9392 |
+
memcpy(&aux64, sc, 8); sc += 8;
|
| 9393 |
+
const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
|
| 9394 |
+
const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
|
| 9395 |
+
v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
|
| 9396 |
+
const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
|
| 9397 |
+
|
| 9398 |
+
for (int i32 = 0; i32 < 4; ++i32) {
|
| 9399 |
+
const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
| 9400 |
+
const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
|
| 9401 |
+
iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
|
| 9402 |
+
const __m256i dot = mul_add_epi8(q1b, q8b);
|
| 9403 |
+
const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
|
| 9404 |
+
const __m256i p = _mm256_madd_epi16(s16, dot);
|
| 9405 |
+
sumi = _mm256_add_epi32(sumi, p);
|
| 9406 |
+
}
|
| 9407 |
+
|
| 9408 |
+
}
|
| 9409 |
+
|
| 9410 |
+
accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
|
| 9411 |
+
|
| 9412 |
+
}
|
| 9413 |
+
|
| 9414 |
+
*s = hsum_float_8(accum);
|
| 9415 |
+
|
| 9416 |
+
#else
|
| 9417 |
+
|
| 9418 |
+
int db[4];
|
| 9419 |
+
uint16_t idx[4];
|
| 9420 |
+
|
| 9421 |
+
float sumf = 0;
|
| 9422 |
+
for (int i = 0; i < nb; ++i) {
|
| 9423 |
+
|
| 9424 |
+
const int8_t * q8 = y[i].qs;
|
| 9425 |
+
const uint8_t * qs = x[i].qs;
|
| 9426 |
+
const uint8_t * sc = x[i].scales;
|
| 9427 |
+
|
| 9428 |
+
int sumi = 0;
|
| 9429 |
+
for (int i32 = 0; i32 < QK_K/32; ++i32) {
|
| 9430 |
+
idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
|
| 9431 |
+
idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
|
| 9432 |
+
idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
|
| 9433 |
+
idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
|
| 9434 |
+
db[0] = (2*(sc[0] & 7) + 1);
|
| 9435 |
+
db[1] = (2*((sc[0] >> 4) & 7) + 1);
|
| 9436 |
+
db[2] = (2*(sc[1] & 7) + 1);
|
| 9437 |
+
db[3] = (2*((sc[1] >> 4) & 7) + 1);
|
| 9438 |
+
for (int l = 0; l < 4; ++l) {
|
| 9439 |
+
const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
|
| 9440 |
+
int suml = 0;
|
| 9441 |
+
for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
|
| 9442 |
+
sumi += db[l] * suml;
|
| 9443 |
+
q8 += 8;
|
| 9444 |
+
}
|
| 9445 |
+
qs += 4;
|
| 9446 |
+
sc += 2;
|
| 9447 |
+
}
|
| 9448 |
+
|
| 9449 |
+
sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
|
| 9450 |
+
}
|
| 9451 |
+
|
| 9452 |
+
*s = sumf;
|
| 9453 |
+
|
| 9454 |
+
#endif
|
| 9455 |
+
|
| 9456 |
+
}
|
| 9457 |
+
|
| 9458 |
// ================================ IQ2 quantization =============================================
|
| 9459 |
|
| 9460 |
typedef struct {
|
|
|
|
| 9463 |
uint16_t * neighbours;
|
| 9464 |
} iq2_entry_t;
|
| 9465 |
|
| 9466 |
+
static iq2_entry_t iq2_data[3] = {
|
| 9467 |
+
{NULL, NULL, NULL},
|
| 9468 |
{NULL, NULL, NULL},
|
| 9469 |
{NULL, NULL, NULL},
|
| 9470 |
};
|
| 9471 |
|
| 9472 |
+
static inline int iq2_data_index(enum ggml_type type) {
|
| 9473 |
+
GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
|
| 9474 |
+
return type == GGML_TYPE_IQ2_XXS ? 0 :
|
| 9475 |
+
type == GGML_TYPE_IQ2_XS ? 1 : 2;
|
| 9476 |
+
}
|
| 9477 |
+
|
| 9478 |
+
static inline int iq2_grid_size(enum ggml_type type) {
|
| 9479 |
+
GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
|
| 9480 |
+
return type == GGML_TYPE_IQ2_XXS ? 256 :
|
| 9481 |
+
type == GGML_TYPE_IQ2_XS ? 512 : 512;
|
| 9482 |
}
|
| 9483 |
|
| 9484 |
static int iq2_compare_func(const void * left, const void * right) {
|
|
|
|
| 9487 |
return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
|
| 9488 |
}
|
| 9489 |
|
| 9490 |
+
void iq2xs_init_impl(enum ggml_type type) {
|
| 9491 |
+
const int gindex = iq2_data_index(type);
|
| 9492 |
+
const int grid_size = iq2_grid_size(type);
|
| 9493 |
if (iq2_data[gindex].grid) {
|
| 9494 |
return;
|
| 9495 |
}
|
| 9496 |
+
static const uint16_t kgrid_2bit_256[256] = {
|
| 9497 |
0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
|
| 9498 |
100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
|
| 9499 |
1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
|
|
|
|
| 9511 |
33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
|
| 9512 |
37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
|
| 9513 |
};
|
| 9514 |
+
static const uint16_t kgrid_2bit_512[512] = {
|
| 9515 |
0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
|
| 9516 |
73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
|
| 9517 |
260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
|
|
|
|
| 9545 |
40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
|
| 9546 |
42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
|
| 9547 |
};
|
| 9548 |
+
static const uint16_t kgrid_1bit_512[512] = {
|
| 9549 |
+
10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545,
|
| 9550 |
+
553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444,
|
| 9551 |
+
1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440,
|
| 9552 |
+
2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422,
|
| 9553 |
+
4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397,
|
| 9554 |
+
5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769,
|
| 9555 |
+
5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788,
|
| 9556 |
+
6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794,
|
| 9557 |
+
9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272,
|
| 9558 |
+
10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
|
| 9559 |
+
16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
|
| 9560 |
+
17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
|
| 9561 |
+
18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
|
| 9562 |
+
20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
|
| 9563 |
+
20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
|
| 9564 |
+
21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
|
| 9565 |
+
21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
|
| 9566 |
+
21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
|
| 9567 |
+
22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
|
| 9568 |
+
23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
|
| 9569 |
+
25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
|
| 9570 |
+
25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
|
| 9571 |
+
26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
|
| 9572 |
+
32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
|
| 9573 |
+
33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
|
| 9574 |
+
34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
|
| 9575 |
+
37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
|
| 9576 |
+
38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
|
| 9577 |
+
38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
|
| 9578 |
+
39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
|
| 9579 |
+
41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
|
| 9580 |
+
42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
|
| 9581 |
+
};
|
| 9582 |
+
|
| 9583 |
const int kmap_size = 43692;
|
| 9584 |
+
const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
|
| 9585 |
+
const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
|
| 9586 |
+
type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : kgrid_1bit_512;
|
| 9587 |
uint64_t * kgrid_q2xs;
|
| 9588 |
int * kmap_q2xs;
|
| 9589 |
uint16_t * kneighbors_q2xs;
|
|
|
|
| 9679 |
free(dist2);
|
| 9680 |
}
|
| 9681 |
|
| 9682 |
+
void iq2xs_free_impl(enum ggml_type type) {
|
| 9683 |
+
GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
|
| 9684 |
+
const int gindex = iq2_data_index(type);
|
| 9685 |
if (iq2_data[gindex].grid) {
|
| 9686 |
free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
|
| 9687 |
free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
|
|
|
|
| 9715 |
|
| 9716 |
static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
| 9717 |
|
| 9718 |
+
const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
|
| 9719 |
|
| 9720 |
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
| 9721 |
const int * kmap_q2xs = iq2_data[gindex].map;
|
|
|
|
| 9888 |
|
| 9889 |
static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
| 9890 |
|
| 9891 |
+
const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
|
| 9892 |
|
| 9893 |
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
| 9894 |
const int * kmap_q2xs = iq2_data[gindex].map;
|
|
|
|
| 10525 |
assert(k % QK_K == 0);
|
| 10526 |
quantize_row_iq3_xxs_impl(x, y, k, NULL);
|
| 10527 |
}
|
| 10528 |
+
|
| 10529 |
+
// =================================== 1.5 bpw ===================================================
|
| 10530 |
+
|
| 10531 |
+
static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
|
| 10532 |
+
const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
|
| 10533 |
+
int num_neighbors = neighbours[0];
|
| 10534 |
+
GGML_ASSERT(num_neighbors > 0);
|
| 10535 |
+
float best_score = 0;
|
| 10536 |
+
int grid_index = -1;
|
| 10537 |
+
for (int j = 1; j <= num_neighbors; ++j) {
|
| 10538 |
+
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
| 10539 |
+
float sumqx = 0, sumq2 = 0;
|
| 10540 |
+
for (int i = 0; i < 8; ++i) {
|
| 10541 |
+
float q = (pg[i] - 3)/2;
|
| 10542 |
+
float w = weight[i];
|
| 10543 |
+
sumqx += w*q*xval[i];
|
| 10544 |
+
sumq2 += w*q*q;
|
| 10545 |
+
}
|
| 10546 |
+
if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
| 10547 |
+
*scale = sumqx/sumq2; best_score = *scale * sumqx;
|
| 10548 |
+
grid_index = neighbours[j];
|
| 10549 |
+
}
|
| 10550 |
+
}
|
| 10551 |
+
if (grid_index < 0) {
|
| 10552 |
+
for (int i = 0; i < ngrid; ++i) {
|
| 10553 |
+
const int8_t * grid_i = (const int8_t *)(grid + i);
|
| 10554 |
+
float sumqx = 0, sumq2 = 0;
|
| 10555 |
+
for (int j = 0; j < 8; ++j) {
|
| 10556 |
+
float w = weight[j];
|
| 10557 |
+
float q = (grid_i[j] - 3)/2;
|
| 10558 |
+
sumqx += w*q*xval[j];
|
| 10559 |
+
sumq2 += w*q*q;
|
| 10560 |
+
}
|
| 10561 |
+
if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
| 10562 |
+
*scale = sumqx/sumq2; best_score = *scale*sumqx;
|
| 10563 |
+
grid_index = i;
|
| 10564 |
+
}
|
| 10565 |
+
}
|
| 10566 |
+
}
|
| 10567 |
+
if (grid_index < 0) {
|
| 10568 |
+
printf("Oops, did not find grid point\n");
|
| 10569 |
+
printf("Have %d neighbours\n", num_neighbors);
|
| 10570 |
+
for (int j = 1; j <= num_neighbors; ++j) {
|
| 10571 |
+
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
| 10572 |
+
float sumqx = 0, sumq2 = 0;
|
| 10573 |
+
for (int i = 0; i < 8; ++i) {
|
| 10574 |
+
float q = (pg[i] - 3)/2;
|
| 10575 |
+
float w = weight[i];
|
| 10576 |
+
sumqx += w*q*xval[i];
|
| 10577 |
+
sumq2 += w*q*q;
|
| 10578 |
+
}
|
| 10579 |
+
printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
|
| 10580 |
+
}
|
| 10581 |
+
}
|
| 10582 |
+
GGML_ASSERT(grid_index >= 0);
|
| 10583 |
+
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 10584 |
+
*scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result.
|
| 10585 |
+
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 10586 |
+
const int8_t * pg = (const int8_t *)(grid + grid_index);
|
| 10587 |
+
for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
|
| 10588 |
+
return grid_index;
|
| 10589 |
+
}
|
| 10590 |
+
|
| 10591 |
+
static int iq1_sort_helper(const void * left, const void * right) {
|
| 10592 |
+
const float * l = left;
|
| 10593 |
+
const float * r = right;
|
| 10594 |
+
return *l < *r ? -1 : *l > *r ? 1 : 0;
|
| 10595 |
+
}
|
| 10596 |
+
|
| 10597 |
+
static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
| 10598 |
+
|
| 10599 |
+
const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
|
| 10600 |
+
|
| 10601 |
+
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
| 10602 |
+
const int * kmap_q2xs = iq2_data[gindex].map;
|
| 10603 |
+
const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
|
| 10604 |
+
|
| 10605 |
+
GGML_ASSERT(quant_weights && "missing quantization weights");
|
| 10606 |
+
GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
|
| 10607 |
+
GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
|
| 10608 |
+
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
|
| 10609 |
+
GGML_ASSERT(n%QK_K == 0);
|
| 10610 |
+
|
| 10611 |
+
const int nbl = n/256;
|
| 10612 |
+
|
| 10613 |
+
block_iq1_s * y = vy;
|
| 10614 |
+
|
| 10615 |
+
float scales[QK_K/8];
|
| 10616 |
+
float weight[8];
|
| 10617 |
+
int8_t L[8];
|
| 10618 |
+
float sumx[9];
|
| 10619 |
+
float sumw[9];
|
| 10620 |
+
float pairs[16];
|
| 10621 |
+
int * idx = (int *)(pairs + 1);
|
| 10622 |
+
uint8_t hbit[QK_K/8];
|
| 10623 |
+
|
| 10624 |
+
for (int ibl = 0; ibl < nbl; ++ibl) {
|
| 10625 |
+
|
| 10626 |
+
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
| 10627 |
+
memset(y[ibl].qs, 0, QK_K/8);
|
| 10628 |
+
memset(y[ibl].scales, 0, QK_K/16);
|
| 10629 |
+
|
| 10630 |
+
float max_scale = 0;
|
| 10631 |
+
|
| 10632 |
+
const float * xbl = x + QK_K*ibl;
|
| 10633 |
+
float sumx2 = 0;
|
| 10634 |
+
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
| 10635 |
+
float sigma2 = sumx2/QK_K;
|
| 10636 |
+
|
| 10637 |
+
for (int ib = 0; ib < QK_K/8; ++ib) {
|
| 10638 |
+
const float * xb = xbl + 8*ib;
|
| 10639 |
+
const float * qw = quant_weights + QK_K*ibl + 8*ib;
|
| 10640 |
+
for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
| 10641 |
+
float max = fabsf(xb[0]);
|
| 10642 |
+
for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
|
| 10643 |
+
if (!max) {
|
| 10644 |
+
scales[ib] = 0;
|
| 10645 |
+
memset(L, 1, 8);
|
| 10646 |
+
continue;
|
| 10647 |
+
}
|
| 10648 |
+
// Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
|
| 10649 |
+
// With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
|
| 10650 |
+
// boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
|
| 10651 |
+
// in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
|
| 10652 |
+
// Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale
|
| 10653 |
+
// for each possible and score for each split.
|
| 10654 |
+
for (int j = 0; j < 8; ++j) {
|
| 10655 |
+
pairs[2*j] = xb[j];
|
| 10656 |
+
idx[2*j] = j;
|
| 10657 |
+
}
|
| 10658 |
+
qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
|
| 10659 |
+
{
|
| 10660 |
+
sumx[0] = sumw[0] = 0;
|
| 10661 |
+
for (int j = 0; j < 8; ++j) {
|
| 10662 |
+
int i = idx[2*j];
|
| 10663 |
+
sumx[j+1] = sumx[j] + weight[i]*xb[i];
|
| 10664 |
+
sumw[j+1] = sumw[j] + weight[i];
|
| 10665 |
+
}
|
| 10666 |
+
}
|
| 10667 |
+
float best_score = 0, scale = max;
|
| 10668 |
+
int besti1 = 0, besti2 = 0;
|
| 10669 |
+
for (int i1 = 0; i1 <= 8; ++i1) {
|
| 10670 |
+
for (int i2 = i1; i2 <= 8; ++i2) {
|
| 10671 |
+
float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
|
| 10672 |
+
float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
|
| 10673 |
+
if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
| 10674 |
+
scale = sumqx/sumq2; best_score = scale*sumqx;
|
| 10675 |
+
besti1 = i1; besti2 = i2;
|
| 10676 |
+
}
|
| 10677 |
+
}
|
| 10678 |
+
}
|
| 10679 |
+
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
|
| 10680 |
+
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
|
| 10681 |
+
for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
|
| 10682 |
+
if (scale < 0) {
|
| 10683 |
+
for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
|
| 10684 |
+
scale = -scale;
|
| 10685 |
+
}
|
| 10686 |
+
// Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
|
| 10687 |
+
// grid point that minimizes SSD.
|
| 10688 |
+
uint16_t u = 0;
|
| 10689 |
+
for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
|
| 10690 |
+
int grid_index = kmap_q2xs[u];
|
| 10691 |
+
if (grid_index < 0) {
|
| 10692 |
+
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
|
| 10693 |
+
grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
|
| 10694 |
+
GGML_ASSERT(grid_index >= 0);
|
| 10695 |
+
}
|
| 10696 |
+
y[ibl].qs[ib] = grid_index & 255;
|
| 10697 |
+
hbit[ib] = grid_index >> 8;
|
| 10698 |
+
GGML_ASSERT(scale >= 0);
|
| 10699 |
+
scales[ib] = scale;
|
| 10700 |
+
max_scale = MAX(max_scale, scale);
|
| 10701 |
+
}
|
| 10702 |
+
|
| 10703 |
+
if (!max_scale) {
|
| 10704 |
+
memset(y[ibl].qs, 0, QK_K/8);
|
| 10705 |
+
continue;
|
| 10706 |
+
}
|
| 10707 |
+
|
| 10708 |
+
float d = max_scale/15;
|
| 10709 |
+
y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
|
| 10710 |
+
float id = 1/d;
|
| 10711 |
+
for (int ib = 0; ib < QK_K/8; ++ib) {
|
| 10712 |
+
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
| 10713 |
+
l = MAX(0, MIN(7, l));
|
| 10714 |
+
if (hbit[ib]) l |= 8;
|
| 10715 |
+
y[ibl].scales[ib/2] |= (l << 4*(ib%2));
|
| 10716 |
+
}
|
| 10717 |
+
}
|
| 10718 |
+
}
|
| 10719 |
+
|
| 10720 |
+
size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
| 10721 |
+
(void)hist;
|
| 10722 |
+
GGML_ASSERT(n_per_row%QK_K == 0);
|
| 10723 |
+
int nblock = n_per_row/QK_K;
|
| 10724 |
+
char * qrow = (char *)dst;
|
| 10725 |
+
for (int row = 0; row < nrow; ++row) {
|
| 10726 |
+
quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
|
| 10727 |
+
src += n_per_row;
|
| 10728 |
+
qrow += nblock*sizeof(block_iq1_s);
|
| 10729 |
+
}
|
| 10730 |
+
return nrow * nblock * sizeof(block_iq1_s);
|
| 10731 |
+
}
|
ggml-quants.h
CHANGED
|
@@ -191,6 +191,13 @@ typedef struct {
|
|
| 191 |
} block_iq3_xxs;
|
| 192 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
#ifdef __cplusplus
|
| 195 |
extern "C" {
|
| 196 |
#endif
|
|
@@ -243,6 +250,7 @@ void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRI
|
|
| 243 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 244 |
void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 245 |
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
|
|
|
| 246 |
|
| 247 |
// Dot product
|
| 248 |
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
|
@@ -259,6 +267,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|
| 259 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 260 |
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 261 |
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
|
|
|
| 262 |
|
| 263 |
//
|
| 264 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
|
@@ -266,6 +275,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
|
|
| 266 |
size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 267 |
size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 268 |
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
|
|
| 269 |
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 270 |
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 271 |
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
@@ -276,8 +286,8 @@ size_t quantize_q4_1 (const float * src, void * dst, int nrows, int n_per_row,
|
|
| 276 |
size_t quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 277 |
size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 278 |
|
| 279 |
-
void iq2xs_init_impl(
|
| 280 |
-
void iq2xs_free_impl(
|
| 281 |
void iq3xs_init_impl(int grid_size);
|
| 282 |
void iq3xs_free_impl(int grid_size);
|
| 283 |
|
|
|
|
| 191 |
} block_iq3_xxs;
|
| 192 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 193 |
|
| 194 |
+
typedef struct {
|
| 195 |
+
ggml_fp16_t d;
|
| 196 |
+
uint8_t qs[QK_K/8];
|
| 197 |
+
uint8_t scales[QK_K/16];
|
| 198 |
+
} block_iq1_s;
|
| 199 |
+
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
|
| 200 |
+
|
| 201 |
#ifdef __cplusplus
|
| 202 |
extern "C" {
|
| 203 |
#endif
|
|
|
|
| 250 |
void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 251 |
void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 252 |
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 253 |
+
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 254 |
|
| 255 |
// Dot product
|
| 256 |
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
|
|
|
| 267 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 268 |
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 269 |
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 270 |
+
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 271 |
|
| 272 |
//
|
| 273 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
|
|
|
| 275 |
size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 276 |
size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 277 |
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 278 |
+
size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 279 |
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 280 |
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 281 |
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
|
|
| 286 |
size_t quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 287 |
size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 288 |
|
| 289 |
+
void iq2xs_init_impl(enum ggml_type type);
|
| 290 |
+
void iq2xs_free_impl(enum ggml_type type);
|
| 291 |
void iq3xs_init_impl(int grid_size);
|
| 292 |
void iq3xs_free_impl(int grid_size);
|
| 293 |
|
ggml.c
CHANGED
|
@@ -673,6 +673,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|
| 673 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 674 |
.nrows = 1,
|
| 675 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
[GGML_TYPE_Q8_K] = {
|
| 677 |
.type_name = "q8_K",
|
| 678 |
.blck_size = QK_K,
|
|
@@ -2267,6 +2279,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
|
| 2267 |
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
|
| 2268 |
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
|
| 2269 |
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
|
|
|
| 2270 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2271 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2272 |
}
|
|
@@ -7677,6 +7690,7 @@ static void ggml_compute_forward_add(
|
|
| 7677 |
case GGML_TYPE_IQ2_XXS:
|
| 7678 |
case GGML_TYPE_IQ2_XS:
|
| 7679 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 7680 |
{
|
| 7681 |
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
| 7682 |
} break;
|
|
@@ -7944,6 +7958,7 @@ static void ggml_compute_forward_add1(
|
|
| 7944 |
case GGML_TYPE_IQ2_XXS:
|
| 7945 |
case GGML_TYPE_IQ2_XS:
|
| 7946 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 7947 |
{
|
| 7948 |
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
|
| 7949 |
} break;
|
|
@@ -8064,6 +8079,7 @@ static void ggml_compute_forward_acc(
|
|
| 8064 |
case GGML_TYPE_IQ2_XXS:
|
| 8065 |
case GGML_TYPE_IQ2_XS:
|
| 8066 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 8067 |
default:
|
| 8068 |
{
|
| 8069 |
GGML_ASSERT(false);
|
|
@@ -10830,6 +10846,7 @@ static void ggml_compute_forward_out_prod(
|
|
| 10830 |
case GGML_TYPE_IQ2_XXS:
|
| 10831 |
case GGML_TYPE_IQ2_XS:
|
| 10832 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 10833 |
{
|
| 10834 |
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
|
| 10835 |
} break;
|
|
@@ -11010,6 +11027,7 @@ static void ggml_compute_forward_set(
|
|
| 11010 |
case GGML_TYPE_IQ2_XXS:
|
| 11011 |
case GGML_TYPE_IQ2_XS:
|
| 11012 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 11013 |
default:
|
| 11014 |
{
|
| 11015 |
GGML_ASSERT(false);
|
|
@@ -11207,6 +11225,7 @@ static void ggml_compute_forward_get_rows(
|
|
| 11207 |
case GGML_TYPE_IQ2_XXS:
|
| 11208 |
case GGML_TYPE_IQ2_XS:
|
| 11209 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 11210 |
{
|
| 11211 |
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
| 11212 |
} break;
|
|
@@ -11880,6 +11899,7 @@ static void ggml_compute_forward_alibi(
|
|
| 11880 |
case GGML_TYPE_IQ2_XXS:
|
| 11881 |
case GGML_TYPE_IQ2_XS:
|
| 11882 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 11883 |
case GGML_TYPE_Q8_K:
|
| 11884 |
case GGML_TYPE_I8:
|
| 11885 |
case GGML_TYPE_I16:
|
|
@@ -11957,6 +11977,7 @@ static void ggml_compute_forward_clamp(
|
|
| 11957 |
case GGML_TYPE_IQ2_XXS:
|
| 11958 |
case GGML_TYPE_IQ2_XS:
|
| 11959 |
case GGML_TYPE_IQ3_XXS:
|
|
|
|
| 11960 |
case GGML_TYPE_Q8_K:
|
| 11961 |
case GGML_TYPE_I8:
|
| 11962 |
case GGML_TYPE_I16:
|
|
@@ -19136,8 +19157,9 @@ void ggml_quantize_init(enum ggml_type type) {
|
|
| 19136 |
ggml_critical_section_start();
|
| 19137 |
|
| 19138 |
switch (type) {
|
| 19139 |
-
case GGML_TYPE_IQ2_XXS:
|
| 19140 |
-
case GGML_TYPE_IQ2_XS:
|
|
|
|
| 19141 |
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
|
| 19142 |
default: // nothing
|
| 19143 |
break;
|
|
@@ -19149,8 +19171,10 @@ void ggml_quantize_init(enum ggml_type type) {
|
|
| 19149 |
void ggml_quantize_free(void) {
|
| 19150 |
ggml_critical_section_start();
|
| 19151 |
|
| 19152 |
-
iq2xs_free_impl(
|
| 19153 |
-
iq2xs_free_impl(
|
|
|
|
|
|
|
| 19154 |
|
| 19155 |
ggml_critical_section_end();
|
| 19156 |
}
|
|
@@ -19285,7 +19309,8 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
|
|
| 19285 |
bool ggml_quantize_requires_imatrix(enum ggml_type type) {
|
| 19286 |
return
|
| 19287 |
type == GGML_TYPE_IQ2_XXS ||
|
| 19288 |
-
type == GGML_TYPE_IQ2_XS
|
|
|
|
| 19289 |
}
|
| 19290 |
|
| 19291 |
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
|
|
@@ -19410,6 +19435,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
| 19410 |
result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19411 |
GGML_ASSERT(result == row_size * nrows);
|
| 19412 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19413 |
case GGML_TYPE_F16:
|
| 19414 |
{
|
| 19415 |
size_t elemsize = sizeof(ggml_fp16_t);
|
|
|
|
| 673 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 674 |
.nrows = 1,
|
| 675 |
},
|
| 676 |
+
[GGML_TYPE_IQ1_S] = {
|
| 677 |
+
.type_name = "iq1_s",
|
| 678 |
+
.blck_size = QK_K,
|
| 679 |
+
.type_size = sizeof(block_iq1_s),
|
| 680 |
+
.is_quantized = true,
|
| 681 |
+
.to_float = (ggml_to_float_t) dequantize_row_iq1_s,
|
| 682 |
+
.from_float = NULL,
|
| 683 |
+
.from_float_reference = NULL,
|
| 684 |
+
.vec_dot = ggml_vec_dot_iq1_s_q8_K,
|
| 685 |
+
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 686 |
+
.nrows = 1,
|
| 687 |
+
},
|
| 688 |
[GGML_TYPE_Q8_K] = {
|
| 689 |
.type_name = "q8_K",
|
| 690 |
.blck_size = QK_K,
|
|
|
|
| 2279 |
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
|
| 2280 |
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
|
| 2281 |
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
| 2282 |
+
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
|
| 2283 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2284 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2285 |
}
|
|
|
|
| 7690 |
case GGML_TYPE_IQ2_XXS:
|
| 7691 |
case GGML_TYPE_IQ2_XS:
|
| 7692 |
case GGML_TYPE_IQ3_XXS:
|
| 7693 |
+
case GGML_TYPE_IQ1_S:
|
| 7694 |
{
|
| 7695 |
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
| 7696 |
} break;
|
|
|
|
| 7958 |
case GGML_TYPE_IQ2_XXS:
|
| 7959 |
case GGML_TYPE_IQ2_XS:
|
| 7960 |
case GGML_TYPE_IQ3_XXS:
|
| 7961 |
+
case GGML_TYPE_IQ1_S:
|
| 7962 |
{
|
| 7963 |
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
|
| 7964 |
} break;
|
|
|
|
| 8079 |
case GGML_TYPE_IQ2_XXS:
|
| 8080 |
case GGML_TYPE_IQ2_XS:
|
| 8081 |
case GGML_TYPE_IQ3_XXS:
|
| 8082 |
+
case GGML_TYPE_IQ1_S:
|
| 8083 |
default:
|
| 8084 |
{
|
| 8085 |
GGML_ASSERT(false);
|
|
|
|
| 10846 |
case GGML_TYPE_IQ2_XXS:
|
| 10847 |
case GGML_TYPE_IQ2_XS:
|
| 10848 |
case GGML_TYPE_IQ3_XXS:
|
| 10849 |
+
case GGML_TYPE_IQ1_S:
|
| 10850 |
{
|
| 10851 |
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
|
| 10852 |
} break;
|
|
|
|
| 11027 |
case GGML_TYPE_IQ2_XXS:
|
| 11028 |
case GGML_TYPE_IQ2_XS:
|
| 11029 |
case GGML_TYPE_IQ3_XXS:
|
| 11030 |
+
case GGML_TYPE_IQ1_S:
|
| 11031 |
default:
|
| 11032 |
{
|
| 11033 |
GGML_ASSERT(false);
|
|
|
|
| 11225 |
case GGML_TYPE_IQ2_XXS:
|
| 11226 |
case GGML_TYPE_IQ2_XS:
|
| 11227 |
case GGML_TYPE_IQ3_XXS:
|
| 11228 |
+
case GGML_TYPE_IQ1_S:
|
| 11229 |
{
|
| 11230 |
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
| 11231 |
} break;
|
|
|
|
| 11899 |
case GGML_TYPE_IQ2_XXS:
|
| 11900 |
case GGML_TYPE_IQ2_XS:
|
| 11901 |
case GGML_TYPE_IQ3_XXS:
|
| 11902 |
+
case GGML_TYPE_IQ1_S:
|
| 11903 |
case GGML_TYPE_Q8_K:
|
| 11904 |
case GGML_TYPE_I8:
|
| 11905 |
case GGML_TYPE_I16:
|
|
|
|
| 11977 |
case GGML_TYPE_IQ2_XXS:
|
| 11978 |
case GGML_TYPE_IQ2_XS:
|
| 11979 |
case GGML_TYPE_IQ3_XXS:
|
| 11980 |
+
case GGML_TYPE_IQ1_S:
|
| 11981 |
case GGML_TYPE_Q8_K:
|
| 11982 |
case GGML_TYPE_I8:
|
| 11983 |
case GGML_TYPE_I16:
|
|
|
|
| 19157 |
ggml_critical_section_start();
|
| 19158 |
|
| 19159 |
switch (type) {
|
| 19160 |
+
case GGML_TYPE_IQ2_XXS:
|
| 19161 |
+
case GGML_TYPE_IQ2_XS:
|
| 19162 |
+
case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
|
| 19163 |
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
|
| 19164 |
default: // nothing
|
| 19165 |
break;
|
|
|
|
| 19171 |
void ggml_quantize_free(void) {
|
| 19172 |
ggml_critical_section_start();
|
| 19173 |
|
| 19174 |
+
iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
|
| 19175 |
+
iq2xs_free_impl(GGML_TYPE_IQ2_XS);
|
| 19176 |
+
iq2xs_free_impl(GGML_TYPE_IQ1_S);
|
| 19177 |
+
iq3xs_free_impl(256);
|
| 19178 |
|
| 19179 |
ggml_critical_section_end();
|
| 19180 |
}
|
|
|
|
| 19309 |
bool ggml_quantize_requires_imatrix(enum ggml_type type) {
|
| 19310 |
return
|
| 19311 |
type == GGML_TYPE_IQ2_XXS ||
|
| 19312 |
+
type == GGML_TYPE_IQ2_XS ||
|
| 19313 |
+
type == GGML_TYPE_IQ1_S;
|
| 19314 |
}
|
| 19315 |
|
| 19316 |
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
|
|
|
|
| 19435 |
result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19436 |
GGML_ASSERT(result == row_size * nrows);
|
| 19437 |
} break;
|
| 19438 |
+
case GGML_TYPE_IQ1_S:
|
| 19439 |
+
{
|
| 19440 |
+
GGML_ASSERT(start % QK_K == 0);
|
| 19441 |
+
GGML_ASSERT(start % n_per_row == 0);
|
| 19442 |
+
size_t start_row = start / n_per_row;
|
| 19443 |
+
size_t row_size = ggml_row_size(type, n_per_row);
|
| 19444 |
+
result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19445 |
+
GGML_ASSERT(result == row_size * nrows);
|
| 19446 |
+
} break;
|
| 19447 |
case GGML_TYPE_F16:
|
| 19448 |
{
|
| 19449 |
size_t elemsize = sizeof(ggml_fp16_t);
|
ggml.h
CHANGED
|
@@ -354,6 +354,7 @@ extern "C" {
|
|
| 354 |
GGML_TYPE_IQ2_XXS = 16,
|
| 355 |
GGML_TYPE_IQ2_XS = 17,
|
| 356 |
GGML_TYPE_IQ3_XXS = 18,
|
|
|
|
| 357 |
GGML_TYPE_I8,
|
| 358 |
GGML_TYPE_I16,
|
| 359 |
GGML_TYPE_I32,
|
|
@@ -391,6 +392,7 @@ extern "C" {
|
|
| 391 |
GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
|
| 392 |
GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
|
| 393 |
GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
|
|
|
|
| 394 |
};
|
| 395 |
|
| 396 |
// available tensor operations:
|
|
|
|
| 354 |
GGML_TYPE_IQ2_XXS = 16,
|
| 355 |
GGML_TYPE_IQ2_XS = 17,
|
| 356 |
GGML_TYPE_IQ3_XXS = 18,
|
| 357 |
+
GGML_TYPE_IQ1_S = 19,
|
| 358 |
GGML_TYPE_I8,
|
| 359 |
GGML_TYPE_I16,
|
| 360 |
GGML_TYPE_I32,
|
|
|
|
| 392 |
GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
|
| 393 |
GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
|
| 394 |
GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
|
| 395 |
+
GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
|
| 396 |
};
|
| 397 |
|
| 398 |
// available tensor operations:
|