Spaces:
Running
IQ3_S: a much better alternative to Q3_K (llama/5676)
Browse files
* iq4_nl: squash commits for easier rebase
* Basics (quantize, dequantize)
* CUDA dequantize and dot product
* Slightly faster CUDA dot product (120 t/s)
* Switch to 6-bit scales
* Scalar dot product
* AVX2 dot product
* ARM_NEON dot product
* Works on metal, but still slow
* Slightly better Metal dot product
* Another small Metal improvement
* Metal dot product is getting there
* Faster CUDA dot product
* Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided
* Report the actual bpw
* Add _xs mix that is 4.05 bpw for non-MoE models
* Remove IQ4_XS for now, slightly adjust kvalues_iq4nl
* AVX2 dot product uses Q8_0 instead of Q8_K
* Add to test-backend-ops
* Minor fix
* Also use Q5_K for attn_output in MoE models
* Fixes after merging latest master
* Switching to blocks of 32
* AVX2 for blocks of 32
* Scalar dot product for blocks of 32
* ARM_NEON dot product for blocks of 32
* Metal kernels for blocks of 32
* Slightly faster Metal kernels
* Resurrecting iq3_xs
After all the experimentation, nothing was better than this.
* Minor PPL improvement via a block scale fudge factor
* Minor improvement via 3 neighbours
* iq3_xs: working scalar and AVX2 dot products
* iq3_xs: ARM_NEON dot product - works but extremely slow (10 t/s)
* iq3_xs: working Metal implementation
* Adding IQ3_M - IQ3_XS mix with mostly Q4_K
* iq3_xs: a 3.4375 bpw variant
* iq3_xs: make CUDA work for new version
* iq3_xs: make scalar and AVX2 work for new version
* iq3_s: make ARM_NEON work with new version
* iq3_xs: make new version work on metal
Performance is very similar to Q3_K_S
* iq3_xs: tiny Metal speed improvement
* iq3_xs: tiny Metal speed improvement
* Fix stupid warning
* Q3_K_XS now uses a mix of IQ3_XS and IQ3_XXS
* iq3_xs: rename to iq3_s
* iq3_s: make tests pass
* Move Q3_K_XS mix to 3.25 bpw
* Attempt to fix failing tests
* Another attempt to fix the Windows builds
* Attempt to fix ROCm
* ROCm again
* iq3_s: partial fix for QK_K = 64
* iq3_s: make it work on metal for QK_K = 64
Pleasant surprise: the coding was super-block size independent,
so all it took was to delete some QK_K == 256 guards.
* Will this fix ROCm?
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
- ggml-cuda.cu +170 -1
- ggml-metal.m +29 -4
- ggml-metal.metal +304 -0
- ggml-quants.c +610 -64
- ggml-quants.h +20 -0
- ggml.c +31 -0
- ggml.h +2 -0
|
@@ -172,6 +172,7 @@
|
|
| 172 |
#endif
|
| 173 |
|
| 174 |
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
|
|
|
| 175 |
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
| 176 |
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
| 177 |
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
|
@@ -196,6 +197,18 @@ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
|
|
| 196 |
return __vsubss4(a, b);
|
| 197 |
}
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
| 200 |
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
|
| 201 |
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
|
@@ -518,6 +531,17 @@ typedef struct {
|
|
| 518 |
} block_iq3_xxs;
|
| 519 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
#define QR1_S 8
|
| 522 |
#define QI1_S (QK_K / (4*QR1_S))
|
| 523 |
typedef struct {
|
|
@@ -1700,6 +1724,74 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
|
|
| 1700 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 1701 |
};
|
| 1702 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1703 |
static const __device__ uint64_t iq1s_grid[512] = {
|
| 1704 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
| 1705 |
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
@@ -1973,6 +2065,32 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
|
|
| 1973 |
|
| 1974 |
}
|
| 1975 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1976 |
template<typename dst_t>
|
| 1977 |
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
| 1978 |
|
|
@@ -4717,6 +4835,41 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
|
| 4717 |
#endif
|
| 4718 |
}
|
| 4719 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4720 |
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
| 4721 |
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
| 4722 |
#if QK_K == 256
|
|
@@ -6849,6 +7002,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
|
|
| 6849 |
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
| 6850 |
}
|
| 6851 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6852 |
template<typename dst_t>
|
| 6853 |
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
| 6854 |
const int nb = k / QK_K;
|
|
@@ -6904,6 +7063,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
| 6904 |
return dequantize_row_iq1_s_cuda;
|
| 6905 |
case GGML_TYPE_IQ4_NL:
|
| 6906 |
return dequantize_row_iq4_nl_cuda;
|
|
|
|
|
|
|
| 6907 |
case GGML_TYPE_F32:
|
| 6908 |
return convert_unary_cuda<float>;
|
| 6909 |
default:
|
|
@@ -6943,6 +7104,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|
| 6943 |
return dequantize_row_iq1_s_cuda;
|
| 6944 |
case GGML_TYPE_IQ4_NL:
|
| 6945 |
return dequantize_row_iq4_nl_cuda;
|
|
|
|
|
|
|
| 6946 |
case GGML_TYPE_F16:
|
| 6947 |
return convert_unary_cuda<half>;
|
| 6948 |
default:
|
|
@@ -8688,6 +8851,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|
| 8688 |
case GGML_TYPE_IQ3_XXS:
|
| 8689 |
case GGML_TYPE_IQ1_S:
|
| 8690 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 8691 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 8692 |
default:
|
| 8693 |
GGML_ASSERT(false);
|
|
@@ -8713,6 +8877,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|
| 8713 |
case GGML_TYPE_IQ3_XXS:
|
| 8714 |
case GGML_TYPE_IQ1_S:
|
| 8715 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 8716 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 8717 |
case GGML_TYPE_Q6_K:
|
| 8718 |
return 64;
|
|
@@ -8818,6 +8983,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|
| 8818 |
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
| 8819 |
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
| 8820 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8821 |
default:
|
| 8822 |
GGML_ASSERT(false);
|
| 8823 |
break;
|
|
@@ -11541,7 +11710,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 11541 |
}
|
| 11542 |
ggml_type a_type = a->type;
|
| 11543 |
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
|
| 11544 |
-
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) {
|
| 11545 |
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
| 11546 |
return false;
|
| 11547 |
}
|
|
|
|
| 172 |
#endif
|
| 173 |
|
| 174 |
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
| 175 |
+
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 176 |
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
| 177 |
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
| 178 |
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
|
|
|
| 197 |
return __vsubss4(a, b);
|
| 198 |
}
|
| 199 |
|
| 200 |
+
// Byte-wise equality compare of two 32-bit words.
// Each of the four result bytes is 0xff when the corresponding bytes of `a`
// and `b` are equal and 0x00 otherwise.
// NOTE(review): defined next to the __vsubss4/__vsub4/__dp4a fallbacks, so this
// presumably stands in for the CUDA SIMD intrinsic of the same name on HIP/ROCm
// builds - confirm against the enclosing preprocessor guard.
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
|
| 201 |
+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 202 |
+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 203 |
+
// `c` is deliberately left uninitialized: all four of its bytes are written through `vc` below.
unsigned int c;
|
| 204 |
+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 205 |
+
#pragma unroll
|
| 206 |
+
for (int i = 0; i < 4; ++i) {
|
| 207 |
+
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
| 208 |
+
}
|
| 209 |
+
return c;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
| 213 |
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
|
| 214 |
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
|
|
|
| 531 |
} block_iq3_xxs;
|
| 532 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 533 |
|
| 534 |
+
// IQ3_S quantization block and its mat-vec constants.
// QI3_XS is derived from QR3_XS and is passed as a template argument to
// mul_mat_vec_q_cuda for GGML_TYPE_IQ3_S.
#define QR3_XS 8
|
| 535 |
+
#define QI3_XS (QK_K / (4*QR3_XS))
|
| 536 |
+
// One IQ3_S super-block of QK_K weights. Field roles below are as used by
// dequantize_block_iq3_s and vec_dot_iq3_s_q8_1 in this file.
typedef struct {
|
| 537 |
+
// Super-block scale, multiplied with the per-32-weight 4-bit sub-scale.
half d;
|
| 538 |
+
// Low 8 bits of each iq3xs_grid index; one index selects 4 packed byte values.
uint8_t qs[QK_K/4];
|
| 539 |
+
// One byte per 32 weights; bit k supplies the 9th (high) bit of the k-th grid index in that group.
uint8_t qh[QK_K/32];
|
| 540 |
+
// One sign bit per weight; a set bit negates the corresponding value.
uint8_t signs[QK_K/8];
|
| 541 |
+
// Two 4-bit sub-scales per byte: low nibble for the even 32-weight group, high nibble for the odd one.
uint8_t scales[QK_K/64];
|
| 542 |
+
} block_iq3_s;
|
| 543 |
+
// Per 64 weights: 16 (qs) + 2 (qh) + 8 (signs) + 1 (scales) = 27 bytes, plus the fp16 scale.
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
|
| 544 |
+
|
| 545 |
#define QR1_S 8
|
| 546 |
#define QI1_S (QK_K / (4*QR1_S))
|
| 547 |
typedef struct {
|
|
|
|
| 1724 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 1725 |
};
|
| 1726 |
|
| 1727 |
+
static const __device__ uint32_t iq3xs_grid[512] = {
|
| 1728 |
+
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
|
| 1729 |
+
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
|
| 1730 |
+
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
|
| 1731 |
+
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
|
| 1732 |
+
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
|
| 1733 |
+
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
|
| 1734 |
+
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
|
| 1735 |
+
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
|
| 1736 |
+
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
|
| 1737 |
+
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
|
| 1738 |
+
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
|
| 1739 |
+
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
|
| 1740 |
+
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
|
| 1741 |
+
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
|
| 1742 |
+
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
|
| 1743 |
+
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
|
| 1744 |
+
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
|
| 1745 |
+
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
|
| 1746 |
+
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
|
| 1747 |
+
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
|
| 1748 |
+
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
|
| 1749 |
+
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
|
| 1750 |
+
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
|
| 1751 |
+
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
|
| 1752 |
+
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
|
| 1753 |
+
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
|
| 1754 |
+
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
|
| 1755 |
+
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
|
| 1756 |
+
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
|
| 1757 |
+
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
|
| 1758 |
+
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
|
| 1759 |
+
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
|
| 1760 |
+
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
|
| 1761 |
+
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
|
| 1762 |
+
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
|
| 1763 |
+
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
|
| 1764 |
+
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
|
| 1765 |
+
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
|
| 1766 |
+
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
|
| 1767 |
+
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
|
| 1768 |
+
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
|
| 1769 |
+
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
|
| 1770 |
+
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
|
| 1771 |
+
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
|
| 1772 |
+
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
|
| 1773 |
+
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
|
| 1774 |
+
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
|
| 1775 |
+
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
|
| 1776 |
+
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
|
| 1777 |
+
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
|
| 1778 |
+
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
|
| 1779 |
+
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
|
| 1780 |
+
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
|
| 1781 |
+
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
|
| 1782 |
+
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
|
| 1783 |
+
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
|
| 1784 |
+
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
|
| 1785 |
+
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
|
| 1786 |
+
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
|
| 1787 |
+
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
|
| 1788 |
+
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
|
| 1789 |
+
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
|
| 1790 |
+
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
|
| 1791 |
+
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
|
| 1792 |
+
};
|
| 1793 |
+
|
| 1794 |
+
|
| 1795 |
static const __device__ uint64_t iq1s_grid[512] = {
|
| 1796 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
| 1797 |
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
|
|
| 2065 |
|
| 2066 |
}
|
| 2067 |
|
| 2068 |
+
// Dequantizes IQ3_S data: CUDA block `blockIdx.x` expands super-block x[blockIdx.x]
// into QK_K consecutive dst_t values of `yy`.
// Launch layout: dequantize_row_iq3_s_cuda launches this with 32 threads per block;
// the tid/8 and tid%8 split below relies on that. Each thread writes 8 outputs.
// Only implemented for QK_K == 256; any other QK_K hits assert(false).
template<typename dst_t>
|
| 2069 |
+
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
| 2070 |
+
|
| 2071 |
+
const int i = blockIdx.x;
|
| 2072 |
+
const block_iq3_s * x = (const block_iq3_s *) vx;
|
| 2073 |
+
|
| 2074 |
+
const int tid = threadIdx.x;
|
| 2075 |
+
#if QK_K == 256
|
| 2076 |
+
// il: which 8-value slice inside a 32-value group this thread handles.
const int il = tid/8; // 0...3
|
| 2077 |
+
// ib: which 32-value group of the super-block this thread handles.
const int ib = tid%8; // 0...7
|
| 2078 |
+
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 2079 |
+
const uint8_t * qs = x[i].qs + 8*ib;
|
| 2080 |
+
// 9-bit grid index: low 8 bits from qs, bit 8 taken from bit (2*il) of qh[ib].
// The selected 32-bit grid entry is then read as 4 separate bytes.
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
| 2081 |
+
// Same for the second index of the pair; its high bit is bit (2*il+1) of qh[ib].
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
|
| 2082 |
+
// Effective scale = d * (0.5 + 4-bit sub-scale) * 0.5; the nibble is chosen by the parity of ib.
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
|
| 2083 |
+
// One sign byte covers the 8 values written by this thread.
const uint8_t signs = x[i].signs[4*ib + il];
|
| 2084 |
+
for (int j = 0; j < 4; ++j) {
|
| 2085 |
+
// kmask_iq2xs is defined elsewhere in the file; presumably kmask_iq2xs[k] == 1 << k - TODO confirm.
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 2086 |
+
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
| 2087 |
+
}
|
| 2088 |
+
#else
|
| 2089 |
+
assert(false);
|
| 2090 |
+
#endif
|
| 2091 |
+
|
| 2092 |
+
}
|
| 2093 |
+
|
| 2094 |
template<typename dst_t>
|
| 2095 |
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
| 2096 |
|
|
|
|
| 4835 |
#endif
|
| 4836 |
}
|
| 4837 |
|
| 4838 |
+
// Partial dot product between one IQ3_S super-block (`vbq`) and q8_1 activations.
// `iqs` selects the 32-weight group: it indexes qh/scales/signs of the IQ3_S block
// and also the q8_1 block (bq8_1[iqs]) whose 32 int8 values are consumed.
// Returns the scaled float contribution of that group.
// Requires __CUDA_ARCH__ >= MIN_CC_DP4A and QK_K == 256; otherwise the body is
// assert(false) followed by return 0.f.
// TODO: don't use lookup table for signs
// NOTE(review): the TODO above looks stale - the signs below are expanded with
// __vcmpeq4 rather than a lookup table; confirm and drop if so.
|
| 4839 |
+
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
|
| 4840 |
+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
| 4841 |
+
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
| 4842 |
+
#if QK_K == 256
|
| 4843 |
+
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
|
| 4844 |
+
|
| 4845 |
+
const int ib32 = iqs;
|
| 4846 |
+
const uint8_t * qs = bq2->qs + 8*ib32;
|
| 4847 |
+
const int8_t * q8 = bq8_1[ib32].qs;
|
| 4848 |
+
int sumi = 0;
|
| 4849 |
+
// Each iteration handles 8 weights: two grid entries of 4 packed bytes each.
for (int l = 0; l < 4; ++l) {
|
| 4850 |
+
// 9-bit grid index: low 8 bits from qs, bit 8 taken from bit (2*l) of qh[ib32].
const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
|
| 4851 |
+
// Second index of the pair; its high bit is bit (2*l+1) of qh[ib32].
const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
|
| 4852 |
+
// Replicate the sign nibble into all 4 bytes, isolate bit k in byte k, then
// turn each set bit into a 0xff byte mask (0x00 where the bit is clear).
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
| 4853 |
+
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
| 4854 |
+
// Per byte, (v ^ 0xff) - 0xff is the two's-complement negation of v, while
// (v ^ 0x00) - 0x00 leaves v unchanged: this applies the signs to 4 values at once.
const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
|
| 4855 |
+
const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
|
| 4856 |
+
// The casts read 4 consecutive int8 activations as one int (and drop const from q8).
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
|
| 4857 |
+
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
|
| 4858 |
+
q8 += 8;
|
| 4859 |
+
}
|
| 4860 |
+
// Combined scale: IQ3_S block scale * (0.5 + 4-bit sub-scale) * low half of the q8_1 `ds` pair * 0.5.
const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
|
| 4861 |
+
return d * sumi;
|
| 4862 |
+
#else
|
| 4863 |
+
assert(false);
|
| 4864 |
+
return 0.f;
|
| 4865 |
+
#endif
|
| 4866 |
+
#else
|
| 4867 |
+
assert(false);
|
| 4868 |
+
return 0.f;
|
| 4869 |
+
#endif
|
| 4870 |
+
}
|
| 4871 |
+
|
| 4872 |
+
|
| 4873 |
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
| 4874 |
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
| 4875 |
#if QK_K == 256
|
|
|
|
| 7002 |
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
| 7003 |
}
|
| 7004 |
|
| 7005 |
+
// Host-side launcher for dequantize_block_iq3_s.
//   vx     - IQ3_S data passed straight through to the kernel
//   y      - destination buffer of dst_t values
//   k      - number of values to produce
//   stream - CUDA stream the kernel is enqueued on
// Launches k / QK_K blocks of 32 threads. The division truncates, so k is
// assumed to be a multiple of QK_K - TODO confirm callers guarantee this.
// NOTE(review): no launch-error check is performed here.
template<typename dst_t>
|
| 7006 |
+
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
| 7007 |
+
const int nb = k / QK_K;
|
| 7008 |
+
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
|
| 7009 |
+
}
|
| 7010 |
+
|
| 7011 |
template<typename dst_t>
|
| 7012 |
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
| 7013 |
const int nb = k / QK_K;
|
|
|
|
| 7063 |
return dequantize_row_iq1_s_cuda;
|
| 7064 |
case GGML_TYPE_IQ4_NL:
|
| 7065 |
return dequantize_row_iq4_nl_cuda;
|
| 7066 |
+
case GGML_TYPE_IQ3_S:
|
| 7067 |
+
return dequantize_row_iq3_s_cuda;
|
| 7068 |
case GGML_TYPE_F32:
|
| 7069 |
return convert_unary_cuda<float>;
|
| 7070 |
default:
|
|
|
|
| 7104 |
return dequantize_row_iq1_s_cuda;
|
| 7105 |
case GGML_TYPE_IQ4_NL:
|
| 7106 |
return dequantize_row_iq4_nl_cuda;
|
| 7107 |
+
case GGML_TYPE_IQ3_S:
|
| 7108 |
+
return dequantize_row_iq3_s_cuda;
|
| 7109 |
case GGML_TYPE_F16:
|
| 7110 |
return convert_unary_cuda<half>;
|
| 7111 |
default:
|
|
|
|
| 8851 |
case GGML_TYPE_IQ3_XXS:
|
| 8852 |
case GGML_TYPE_IQ1_S:
|
| 8853 |
case GGML_TYPE_IQ4_NL:
|
| 8854 |
+
case GGML_TYPE_IQ3_S:
|
| 8855 |
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
| 8856 |
default:
|
| 8857 |
GGML_ASSERT(false);
|
|
|
|
| 8877 |
case GGML_TYPE_IQ3_XXS:
|
| 8878 |
case GGML_TYPE_IQ1_S:
|
| 8879 |
case GGML_TYPE_IQ4_NL:
|
| 8880 |
+
case GGML_TYPE_IQ3_S:
|
| 8881 |
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
| 8882 |
case GGML_TYPE_Q6_K:
|
| 8883 |
return 64;
|
|
|
|
| 8983 |
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
| 8984 |
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
| 8985 |
break;
|
| 8986 |
+
case GGML_TYPE_IQ3_S:
|
| 8987 |
+
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
| 8988 |
+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
| 8989 |
+
break;
|
| 8990 |
default:
|
| 8991 |
GGML_ASSERT(false);
|
| 8992 |
break;
|
|
|
|
| 11710 |
}
|
| 11711 |
ggml_type a_type = a->type;
|
| 11712 |
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
|
| 11713 |
+
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S) {
|
| 11714 |
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
| 11715 |
return false;
|
| 11716 |
}
|
|
@@ -61,6 +61,7 @@ enum ggml_metal_kernel_type {
|
|
| 61 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
| 62 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
| 63 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
|
|
|
| 64 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
| 65 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
| 66 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
@@ -85,6 +86,7 @@ enum ggml_metal_kernel_type {
|
|
| 85 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
| 86 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
| 87 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
|
|
|
| 88 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
| 89 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
| 90 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
|
@@ -105,6 +107,7 @@ enum ggml_metal_kernel_type {
|
|
| 105 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
| 106 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
| 107 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
|
|
|
| 108 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
| 109 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
| 110 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
|
@@ -122,6 +125,7 @@ enum ggml_metal_kernel_type {
|
|
| 122 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
| 123 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
| 124 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
|
|
|
| 125 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
| 126 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
| 127 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
|
@@ -139,6 +143,7 @@ enum ggml_metal_kernel_type {
|
|
| 139 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
| 140 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
| 141 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
|
|
|
| 142 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
| 143 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
| 144 |
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
|
@@ -452,6 +457,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 452 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
| 453 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
| 454 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
|
|
|
| 455 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
| 456 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
| 457 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
@@ -476,6 +482,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 476 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 477 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 478 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 479 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 480 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
| 481 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
|
@@ -496,6 +503,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 496 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 497 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 498 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 499 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 500 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
| 501 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
|
@@ -513,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 513 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 514 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 515 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 516 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 517 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 518 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
|
@@ -530,6 +539,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 530 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 531 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 532 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 533 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 534 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 535 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
|
@@ -1347,6 +1357,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1347 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1348 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
| 1349 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
|
|
|
| 1350 |
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
| 1351 |
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
| 1352 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
@@ -1483,6 +1494,12 @@ static bool ggml_metal_graph_compute(
|
|
| 1483 |
nth1 = 16;
|
| 1484 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
| 1485 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1486 |
case GGML_TYPE_IQ1_S:
|
| 1487 |
{
|
| 1488 |
nth0 = 4;
|
|
@@ -1537,8 +1554,8 @@ static bool ggml_metal_graph_compute(
|
|
| 1537 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1538 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1539 |
}
|
| 1540 |
-
else if (src0t == GGML_TYPE_IQ3_XXS) {
|
| 1541 |
-
const int mem_size = 256*4+128;
|
| 1542 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1543 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1544 |
}
|
|
@@ -1640,6 +1657,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1640 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
| 1641 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
| 1642 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
|
|
|
| 1643 |
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
| 1644 |
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
| 1645 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
@@ -1779,6 +1797,12 @@ static bool ggml_metal_graph_compute(
|
|
| 1779 |
nth1 = 16;
|
| 1780 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
| 1781 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1782 |
case GGML_TYPE_IQ1_S:
|
| 1783 |
{
|
| 1784 |
nth0 = 4;
|
|
@@ -1849,8 +1873,8 @@ static bool ggml_metal_graph_compute(
|
|
| 1849 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1850 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1851 |
}
|
| 1852 |
-
else if (src2t == GGML_TYPE_IQ3_XXS) {
|
| 1853 |
-
const int mem_size = 256*4+128;
|
| 1854 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1855 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1856 |
}
|
|
@@ -1900,6 +1924,7 @@ static bool ggml_metal_graph_compute(
|
|
| 1900 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
| 1901 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
| 1902 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
|
|
|
| 1903 |
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
| 1904 |
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
| 1905 |
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
|
|
|
| 61 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
| 62 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
| 63 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
| 64 |
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
| 65 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
| 66 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
| 67 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
|
|
| 86 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
| 87 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
| 88 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
| 89 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
|
| 90 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
| 91 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
| 92 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
|
|
|
| 107 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
| 108 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
| 109 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
| 110 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
|
| 111 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
| 112 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
| 113 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
|
|
|
| 125 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
| 126 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
| 127 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
| 128 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
|
| 129 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
| 130 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
| 131 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
|
|
|
| 143 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
| 144 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
| 145 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
| 146 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
| 147 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
| 148 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
| 149 |
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
|
|
|
| 457 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
| 458 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
| 459 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
| 460 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
| 461 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
| 462 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
| 463 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
|
|
| 482 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 483 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 484 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 485 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
| 486 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 487 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
| 488 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 503 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 504 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 505 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 506 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
| 507 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 508 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
| 509 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 521 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 522 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 523 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 524 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
| 525 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 526 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 527 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
|
|
|
| 539 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 540 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 541 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 542 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
| 543 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 544 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 545 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
|
|
|
| 1357 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1358 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
| 1359 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
| 1360 |
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
| 1361 |
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
| 1362 |
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
| 1363 |
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
|
|
| 1494 |
nth1 = 16;
|
| 1495 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
| 1496 |
} break;
|
| 1497 |
+
case GGML_TYPE_IQ3_S:
|
| 1498 |
+
{
|
| 1499 |
+
nth0 = 4;
|
| 1500 |
+
nth1 = 16;
|
| 1501 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
| 1502 |
+
} break;
|
| 1503 |
case GGML_TYPE_IQ1_S:
|
| 1504 |
{
|
| 1505 |
nth0 = 4;
|
|
|
|
| 1554 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1555 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1556 |
}
|
| 1557 |
+
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
| 1558 |
+
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
| 1559 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1560 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1561 |
}
|
|
|
|
| 1657 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
| 1658 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
| 1659 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
| 1660 |
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
| 1661 |
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
| 1662 |
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
| 1663 |
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
|
|
| 1797 |
nth1 = 16;
|
| 1798 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
| 1799 |
} break;
|
| 1800 |
+
case GGML_TYPE_IQ3_S:
|
| 1801 |
+
{
|
| 1802 |
+
nth0 = 4;
|
| 1803 |
+
nth1 = 16;
|
| 1804 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
| 1805 |
+
} break;
|
| 1806 |
case GGML_TYPE_IQ1_S:
|
| 1807 |
{
|
| 1808 |
nth0 = 4;
|
|
|
|
| 1873 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1874 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1875 |
}
|
| 1876 |
+
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
|
| 1877 |
+
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
| 1878 |
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 1879 |
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 1880 |
}
|
|
|
|
| 1924 |
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
| 1925 |
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
| 1926 |
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
| 1927 |
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
|
| 1928 |
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
| 1929 |
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
| 1930 |
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
|
@@ -2525,6 +2525,20 @@ typedef struct {
|
|
| 2525 |
} block_iq3_xxs;
|
| 2526 |
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
| 2527 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2528 |
typedef struct {
|
| 2529 |
half d;
|
| 2530 |
uint8_t qs[QK_K/8];
|
|
@@ -3795,6 +3809,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
| 3795 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3796 |
};
|
| 3797 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3798 |
#define NGRID_IQ1S 512
|
| 3799 |
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
| 3800 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
@@ -4361,6 +4442,136 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
| 4361 |
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4362 |
}
|
| 4363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4364 |
void kernel_mul_mv_iq1_s_f32_impl(
|
| 4365 |
device const void * src0,
|
| 4366 |
device const float * src1,
|
|
@@ -4952,6 +5163,31 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
| 4952 |
}
|
| 4953 |
}
|
| 4954 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4955 |
template <typename type4x4>
|
| 4956 |
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
| 4957 |
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
@@ -5525,6 +5761,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
| 5525 |
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5526 |
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5527 |
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
|
| 5528 |
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5529 |
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 5530 |
|
|
@@ -5566,6 +5803,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
| 5566 |
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5567 |
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5568 |
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
|
| 5569 |
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5570 |
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 5571 |
|
|
@@ -5619,6 +5857,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
| 5619 |
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5620 |
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5621 |
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
|
| 5622 |
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5623 |
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 5624 |
|
|
@@ -6589,6 +6828,71 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
| 6589 |
sgitg);
|
| 6590 |
}
|
| 6591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6592 |
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
| 6593 |
kernel void kernel_mul_mv_id_iq1_s_f32(
|
| 6594 |
device const char * ids,
|
|
|
|
| 2525 |
} block_iq3_xxs;
|
| 2526 |
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
| 2527 |
|
| 2528 |
+
// IQ3_S quantization block: 3.4375 bpw
// For QK_K = 256 the fields below add up to 2 + 64 + 8 + 32 + 4 = 110 bytes per 256 weights.
#if QK_K == 64
#define IQ3S_N_SCALE 2
#else
// Parenthesized so the macro expands safely inside larger expressions (e.g. `x % IQ3S_N_SCALE`).
#define IQ3S_N_SCALE (QK_K/64)
#endif
typedef struct {
    half d;                        // block-wide scale
    uint8_t qs[QK_K/4];            // low 8 bits of the iq3xs_grid index; each grid entry is read as 4 byte-sized magnitudes
    uint8_t qh[QK_K/32];           // 9th (high) bit of the grid index, one bit per qs entry, one byte per group of 32 weights
    uint8_t signs[QK_K/8];         // one bit per weight, tested against kmask_iq2xs[0..7] to select +1 / -1
    uint8_t scales[IQ3S_N_SCALE];  // 4-bit sub-scales, two per byte (low nibble = even group of 32, high nibble = odd group)
} block_iq3_s;
|
| 2541 |
+
|
| 2542 |
typedef struct {
|
| 2543 |
half d;
|
| 2544 |
uint8_t qs[QK_K/8];
|
|
|
|
| 3809 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3810 |
};
|
| 3811 |
|
| 3812 |
+
constexpr constant static uint32_t iq3xs_grid[512] = {
|
| 3813 |
+
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
|
| 3814 |
+
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
|
| 3815 |
+
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
|
| 3816 |
+
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
|
| 3817 |
+
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
|
| 3818 |
+
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
|
| 3819 |
+
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
|
| 3820 |
+
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
|
| 3821 |
+
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
|
| 3822 |
+
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
|
| 3823 |
+
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
|
| 3824 |
+
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
|
| 3825 |
+
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
|
| 3826 |
+
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
|
| 3827 |
+
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
|
| 3828 |
+
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
|
| 3829 |
+
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
|
| 3830 |
+
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
|
| 3831 |
+
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
|
| 3832 |
+
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
|
| 3833 |
+
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
|
| 3834 |
+
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
|
| 3835 |
+
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
|
| 3836 |
+
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
|
| 3837 |
+
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
|
| 3838 |
+
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
|
| 3839 |
+
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
|
| 3840 |
+
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
|
| 3841 |
+
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
|
| 3842 |
+
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
|
| 3843 |
+
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
|
| 3844 |
+
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
|
| 3845 |
+
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
|
| 3846 |
+
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
|
| 3847 |
+
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
|
| 3848 |
+
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
|
| 3849 |
+
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
|
| 3850 |
+
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
|
| 3851 |
+
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
|
| 3852 |
+
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
|
| 3853 |
+
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
|
| 3854 |
+
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
|
| 3855 |
+
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
|
| 3856 |
+
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
|
| 3857 |
+
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
|
| 3858 |
+
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
|
| 3859 |
+
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
|
| 3860 |
+
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
|
| 3861 |
+
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
|
| 3862 |
+
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
|
| 3863 |
+
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
|
| 3864 |
+
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
|
| 3865 |
+
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
|
| 3866 |
+
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
|
| 3867 |
+
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
|
| 3868 |
+
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
|
| 3869 |
+
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
|
| 3870 |
+
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
|
| 3871 |
+
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
|
| 3872 |
+
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
|
| 3873 |
+
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
|
| 3874 |
+
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
|
| 3875 |
+
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
|
| 3876 |
+
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
|
| 3877 |
+
};
|
| 3878 |
+
|
| 3879 |
#define NGRID_IQ1S 512
|
| 3880 |
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
| 3881 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
|
|
| 4442 |
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4443 |
}
|
| 4444 |
|
| 4445 |
+
// Matrix-vector product for IQ3_S quantized weights (src0) against f32 activations (src1).
//
// Each simdgroup accumulates N_DST consecutive output rows starting at `first_row`; the 32 lanes
// of the simdgroup split every row into groups of 32 weights (lane `tiisg` handles groups
// tiisg, tiisg + 32, ...) and the partial sums are combined with simd_sum().
//
// Threadgroup memory: `shared_values` is used as a uint32_t copy of iq3xs_grid. Every thread
// copies 8 entries at (32*sgitg + tiisg)*8, so the whole 512-entry table is only populated when
// the threadgroup has 64 threads, and the buffer must hold 512*4 bytes. The host dispatch visible
// in this change uses 4x16 threads and 512*4 bytes for IQ3_S.
// NOTE(review): this relies on a 32-wide simdgroup and N_SIMDGROUP == 2 - confirm against the
// definitions elsewhere in the file.
//
// Rows at or beyond ne01 are skipped: the host rounds the number of threadgroups up, so the last
// threadgroup may cover rows that do not exist. Without the guard those rows would be read from
// past the end of the weight matrix and stored into the next column of dst.
void kernel_mul_mv_iq3_s_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;   // quantization blocks per weight row
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // r2 / r3 map the src1 batch indices onto the (possibly smaller) src0 batch dimensions.
    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST]={0.f}, all_sum;

    const int nb32 = nb * (QK_K / 32);   // groups of 32 weights per row

    // Stage the grid in threadgroup memory. The barrier is reached by every thread: the row
    // guards below never return early.
    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
    {
        int nval = 8;
        int pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

        // Cache the 32 activations that pair with this group of weights.
        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);   // block index within the row
        const int ib  = ib32 % (QK_K / 32);   // group of 32 within that block

        device const block_iq3_s * xr = x + ibl;
        device const uint8_t * qs    = xr->qs + 8 * ib;
        device const uint8_t * qh    = xr->qh + ib;
        device const uint8_t * sc    = xr->scales + (ib/2);
        device const uint8_t * signs = xr->signs + 4 * ib;
        device const half    * dh    = &xr->d;

        // The loop condition is uniform across the simdgroup (first_row depends only on r0 and
        // sgitg), so stopping at ne01 introduces no divergence.
        for (int row = 0; row < N_DST && first_row + row < ne01; row++) {

            const float db = dh[0];
            const float d  = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));

            float2 sum = {0};
            for (int l = 0; l < 4; ++l) {
                // 9-bit grid index: 8 low bits from qs, bit 8 taken from the matching bit of qh[0].
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                }
            }
            sumf[row] += d * (sum[0] + sum[1]);

            // Step every pointer to the same block of the next row: a row is nb blocks long.
            // dh is a half pointer, hence the byte count divided by 2.
            dh    += nb*sizeof(block_iq3_s)/2;
            qs    += nb*sizeof(block_iq3_s);
            qh    += nb*sizeof(block_iq3_s);
            sc    += nb*sizeof(block_iq3_s);
            signs += nb*sizeof(block_iq3_s);
        }

        y4 += 32 * 32;
    }

    for (int row = 0; row < N_DST; ++row) {
        // simd_sum stays unconditional so that all lanes take part in the reduction.
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0 && first_row + row < ne01) {
            // Final 0.5 factor: the same one dequantize_iq3_s folds into its per-group scale.
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
        }
    }
}
|
| 4545 |
+
|
| 4546 |
+
// Kernel entry point for the IQ3_S x f32 matrix-vector product.
// It only forwards to kernel_mul_mv_iq3_s_f32_impl; the byte strides (nb00..nb12) and ne11 are
// part of the argument list but are not passed on.
[[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq3_s_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
| 4574 |
+
|
| 4575 |
void kernel_mul_mv_iq1_s_f32_impl(
|
| 4576 |
device const void * src0,
|
| 4577 |
device const float * src1,
|
|
|
|
| 5163 |
}
|
| 5164 |
}
|
| 5165 |
|
| 5166 |
+
// Dequantizes 16 weights of an IQ3_S block into `reg` (4 rows of 4 values).
// il is 0...15 for QK_K = 256: il/2 selects the group of 32 weights and il%2 selects the
// first or second 16 weights of that group.
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    const float d    = xb->d;
    const int   ib32 = il/2;
    const short sub  = il%2;   // 0 -> first 16 quants of the group, 1 -> second 16

    // Four grid indices and two sign bytes belong to this half of the group.
    device const uint8_t * qs    = xb->qs    + 8*ib32 + 4*sub;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*sub;

    // High index bits for this half end up in the low nibble of qh.
    const uint8_t qh = xb->qh[ib32] >> 4*sub;

    // 4-bit sub-scale of the group (low nibble for even groups, high nibble for odd ones).
    const int   nibble = (xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf;
    const float dl     = d * (0.5f + nibble) * 0.5f;

    // Each pass decodes two grid entries (8 weights) that share one sign byte.
    for (int pair = 0; pair < 2; ++pair) {
        constant uint8_t * g_lo = (constant uint8_t *)(iq3xs_grid + (qs[2*pair+0] | ((qh << (8 - 2*pair)) & 256)));
        constant uint8_t * g_hi = (constant uint8_t *)(iq3xs_grid + (qs[2*pair+1] | ((qh << (7 - 2*pair)) & 256)));
        for (int i = 0; i < 4; ++i) {
            reg[2*pair+0][i] = dl * g_lo[i] * select(1, -1, signs[pair] & kmask_iq2xs[i+0]);
            reg[2*pair+1][i] = dl * g_hi[i] * select(1, -1, signs[pair] & kmask_iq2xs[i+4]);
        }
    }
}
|
| 5190 |
+
|
| 5191 |
template <typename type4x4>
|
| 5192 |
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
| 5193 |
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
|
|
| 5761 |
// Explicit instantiations of the kernel_get_rows template, one per quantized
// block type. Each is exported under the name given in its host_name attribute
// and paired with the matching dequantize_* function. The second template
// argument is QK_NL for every type listed here except iq4_nl, which passes 2
// (its meaning is not shown in this excerpt).
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5762 |
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5763 |
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5764 |
+
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
| 5765 |
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5766 |
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 5767 |
|
|
|
|
| 5803 |
// Explicit instantiations of the kernel_mul_mm template, one per quantized
// block type. Each is exported under the name given in its host_name attribute
// and paired with the matching dequantize_* function. The second template
// argument is QK_NL for every type listed here except iq4_nl, which passes 2
// (its meaning is not shown in this excerpt).
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5804 |
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5805 |
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5806 |
+
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
| 5807 |
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5808 |
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 5809 |
|
|
|
|
| 5857 |
// Explicit instantiations of the kernel_mul_mm_id template, one per quantized
// block type. Each is exported under the name given in its host_name attribute
// and paired with the matching dequantize_* function. The second template
// argument is QK_NL for every type listed here except iq4_nl, which passes 2
// (its meaning is not shown in this excerpt).
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5858 |
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5859 |
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5860 |
+
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
| 5861 |
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5862 |
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 5863 |
|
|
|
|
| 6828 |
sgitg);
|
| 6829 |
}
|
| 6830 |
|
| 6831 |
+
// Kernel entry point exported to the host as "kernel_mul_mv_id_iq3_s_f32".
// It picks one of eight src0 buffers using an int32 read from `ids`, offsets
// src1 and dst by a batch index derived from tgpig.z, and forwards to
// kernel_mul_mv_iq3_s_f32_impl. nb00..nb02, nb10, nb12, nb1, ne11 and tiitg
// are accepted but not used in this body.
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
kernel void kernel_mul_mv_id_iq3_s_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        threadgroup int8_t   * shared_values [[threadgroup(0)]],
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // Collect the eight candidate buffers so one can be chosen by index.
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // tgpig.z encodes two values: the quotient by (ne12*ne13) is the batch
    // index, and the remainder is written back into tgpig.z before the
    // threadgroup position is passed on to the implementation.
    const int64_t nz  = ne12*ne13;
    const int64_t bid = tgpig.z/nz;

    tgpig.z = tgpig.z%nz;

    // Read element `idx` of the int32 array found bid*nbi1 bytes into `ids`;
    // it selects which src0 buffer to use.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    // src1 advances by bid*nb11 bytes, dst by bid*ne0 floats.
    device const float * src1_cur = (device const float *) (src1 + bid*nb11);
    device       float * dst_cur  = dst + bid*ne0;

    kernel_mul_mv_iq3_s_f32_impl(
        src0[id],
        src1_cur,
        dst_cur,
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        shared_values,
        tgpig,
        tiisg,
        sgitg);
}
|
| 6895 |
+
|
| 6896 |
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
| 6897 |
kernel void kernel_mul_mv_id_iq1_s_f32(
|
| 6898 |
device const char * ids,
|
|
@@ -3505,6 +3505,73 @@ static const uint32_t iq3xxs_grid[256] = {
|
|
| 3505 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3506 |
};
|
| 3507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3508 |
#define NGRID_IQ2XXS 512
|
| 3509 |
static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
|
| 3510 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
@@ -3736,6 +3803,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
|
|
| 3736 |
}
|
| 3737 |
}
|
| 3738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3739 |
// ====================== 1.5625 bpw (de)-quantization
|
| 3740 |
|
| 3741 |
void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
|
|
@@ -8806,6 +8916,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
| 8806 |
|
| 8807 |
#endif
|
| 8808 |
|
|
|
|
| 8809 |
static const int8_t keven_signs_q2xs[1024] = {
|
| 8810 |
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
|
| 8811 |
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
|
|
@@ -8840,6 +8951,7 @@ static const int8_t keven_signs_q2xs[1024] = {
|
|
| 8840 |
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
|
| 8841 |
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
|
| 8842 |
};
|
|
|
|
| 8843 |
|
| 8844 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
| 8845 |
assert(n % QK_K == 0);
|
|
@@ -9327,6 +9439,202 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
|
|
| 9327 |
#endif
|
| 9328 |
}
|
| 9329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9330 |
#ifdef __AVX2__
|
| 9331 |
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
| 9332 |
const __m256i ax = _mm256_sign_epi8(x, x);
|
|
@@ -9523,6 +9831,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
| 9523 |
float sumf = 0;
|
| 9524 |
|
| 9525 |
for (int ib = 0; ib < nb; ib += 2) {
|
|
|
|
| 9526 |
q4bits.val[0] = vld1q_u8(x[ib+0].qs);
|
| 9527 |
q4bits.val[1] = vld1q_u8(x[ib+1].qs);
|
| 9528 |
q8b.val[0] = vld1q_s8(y[ib+0].qs);
|
|
@@ -10239,14 +10548,15 @@ typedef struct {
|
|
| 10239 |
uint16_t * neighbours;
|
| 10240 |
} iq3_entry_t;
|
| 10241 |
|
| 10242 |
-
static iq3_entry_t iq3_data[
|
|
|
|
| 10243 |
{NULL, NULL, NULL},
|
| 10244 |
};
|
| 10245 |
|
| 10246 |
static inline int iq3_data_index(int grid_size) {
|
| 10247 |
(void)grid_size;
|
| 10248 |
-
GGML_ASSERT(grid_size == 256);
|
| 10249 |
-
return 0;
|
| 10250 |
}
|
| 10251 |
|
| 10252 |
static int iq3_compare_func(const void * left, const void * right) {
|
|
@@ -10278,9 +10588,44 @@ void iq3xs_init_impl(int grid_size) {
|
|
| 10278 |
3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
|
| 10279 |
3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
|
| 10280 |
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10281 |
const int kmap_size = 4096;
|
| 10282 |
-
const int nwant = 2;
|
| 10283 |
-
const uint16_t * kgrid = kgrid_256;
|
| 10284 |
uint32_t * kgrid_q3xs;
|
| 10285 |
int * kmap_q3xs;
|
| 10286 |
uint16_t * kneighbors_q3xs;
|
|
@@ -10377,7 +10722,7 @@ void iq3xs_init_impl(int grid_size) {
|
|
| 10377 |
}
|
| 10378 |
|
| 10379 |
void iq3xs_free_impl(int grid_size) {
|
| 10380 |
-
GGML_ASSERT(grid_size == 256);
|
| 10381 |
const int gindex = iq3_data_index(grid_size);
|
| 10382 |
if (iq3_data[gindex].grid) {
|
| 10383 |
free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
|
|
@@ -10410,9 +10755,10 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u
|
|
| 10410 |
return grid_index;
|
| 10411 |
}
|
| 10412 |
|
| 10413 |
-
static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n,
|
|
|
|
| 10414 |
|
| 10415 |
-
const int gindex = iq3_data_index(
|
| 10416 |
|
| 10417 |
const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
|
| 10418 |
const int * kmap_q3xs = iq3_data[gindex].map;
|
|
@@ -10426,9 +10772,23 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
|
|
| 10426 |
|
| 10427 |
const int kMaxQ = 8;
|
| 10428 |
|
| 10429 |
-
const int nbl = n/
|
| 10430 |
|
| 10431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10432 |
|
| 10433 |
float scales[QK_K/32];
|
| 10434 |
float weight[32];
|
|
@@ -10439,20 +10799,21 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
|
|
| 10439 |
bool is_on_grid[8];
|
| 10440 |
bool is_on_grid_aux[8];
|
| 10441 |
uint8_t block_signs[8];
|
| 10442 |
-
uint8_t q3[3*(QK_K/8)];
|
| 10443 |
uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
|
|
|
|
| 10444 |
|
| 10445 |
for (int ibl = 0; ibl < nbl; ++ibl) {
|
| 10446 |
|
| 10447 |
-
|
| 10448 |
-
memset(q3, 0, 3*QK_K/8);
|
| 10449 |
|
| 10450 |
float max_scale = 0;
|
| 10451 |
|
| 10452 |
const float * xbl = x + QK_K*ibl;
|
| 10453 |
float sumx2 = 0;
|
| 10454 |
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
| 10455 |
-
float sigma2 = sumx2/QK_K;
|
| 10456 |
|
| 10457 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 10458 |
const float * xb = xbl + 32*ib;
|
|
@@ -10570,7 +10931,13 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
|
|
| 10570 |
printf("\n");
|
| 10571 |
GGML_ASSERT(false);
|
| 10572 |
}
|
| 10573 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10574 |
}
|
| 10575 |
scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
|
| 10576 |
GGML_ASSERT(scale >= 0);
|
|
@@ -10579,63 +10946,25 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
|
|
| 10579 |
}
|
| 10580 |
|
| 10581 |
if (!max_scale) {
|
| 10582 |
-
memset(
|
|
|
|
|
|
|
| 10583 |
continue;
|
| 10584 |
}
|
| 10585 |
|
| 10586 |
float d = max_scale/31;
|
| 10587 |
-
|
| 10588 |
float id = 1/d;
|
| 10589 |
-
float sumqx = 0, sumq2 = 0;
|
| 10590 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 10591 |
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
| 10592 |
l = MAX(0, MIN(15, l));
|
| 10593 |
scales_and_signs[ib] |= ((uint32_t)l << 28);
|
| 10594 |
-
if (false) {
|
| 10595 |
-
const float * xb = xbl + 32*ib;
|
| 10596 |
-
if (quant_weights) {
|
| 10597 |
-
const float * qw = quant_weights + QK_K*ibl + 32*ib;
|
| 10598 |
-
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
| 10599 |
-
} else {
|
| 10600 |
-
for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
|
| 10601 |
-
}
|
| 10602 |
-
const float db = 0.25f * d * (1 + 2*l);
|
| 10603 |
-
for (int k = 0; k < 8; ++k) {
|
| 10604 |
-
const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
|
| 10605 |
-
const float * xk = xb + 4*k;
|
| 10606 |
-
const float * wk = weight + 4*k;
|
| 10607 |
-
//const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
|
| 10608 |
-
const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
|
| 10609 |
-
float best_mse = 0; int best_index = q3[8*ib+k];
|
| 10610 |
-
for (int j = 0; j < 4; ++j) {
|
| 10611 |
-
float diff = db * grid[j] * signs[j] - xk[j];
|
| 10612 |
-
best_mse += wk[j] * diff * diff;
|
| 10613 |
-
}
|
| 10614 |
-
for (int idx = 0; idx < 256; ++idx) {
|
| 10615 |
-
//grid = (const uint8_t *)(kgrid_q3xs + idx);
|
| 10616 |
-
grid = (const uint8_t *)(iq3xxs_grid + idx);
|
| 10617 |
-
float mse = 0;
|
| 10618 |
-
for (int j = 0; j < 4; ++j) {
|
| 10619 |
-
float diff = db * grid[j] * signs[j] - xk[j];
|
| 10620 |
-
mse += wk[j] * diff * diff;
|
| 10621 |
-
}
|
| 10622 |
-
if (mse < best_mse) {
|
| 10623 |
-
best_mse = mse; best_index = idx;
|
| 10624 |
-
}
|
| 10625 |
-
}
|
| 10626 |
-
q3[8*ib+k] = best_index;
|
| 10627 |
-
//grid = (const uint8_t *)(kgrid_q3xs + best_index);
|
| 10628 |
-
grid = (const uint8_t *)(iq3xxs_grid + best_index);
|
| 10629 |
-
for (int j = 0; j < 4; ++j) {
|
| 10630 |
-
float q = db * grid[j] * signs[j];
|
| 10631 |
-
sumqx += wk[j] * q * xk[j];
|
| 10632 |
-
sumq2 += wk[j] * q * q;
|
| 10633 |
-
}
|
| 10634 |
-
}
|
| 10635 |
-
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
|
| 10636 |
-
}
|
| 10637 |
}
|
| 10638 |
-
memcpy(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10639 |
}
|
| 10640 |
}
|
| 10641 |
|
|
@@ -10645,7 +10974,7 @@ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row,
|
|
| 10645 |
int nblock = n_per_row/QK_K;
|
| 10646 |
char * qrow = (char *)dst;
|
| 10647 |
for (int row = 0; row < nrow; ++row) {
|
| 10648 |
-
quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
|
| 10649 |
src += n_per_row;
|
| 10650 |
qrow += nblock*sizeof(block_iq3_xxs);
|
| 10651 |
}
|
|
@@ -10660,9 +10989,226 @@ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
|
|
| 10660 |
|
| 10661 |
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
|
| 10662 |
assert(k % QK_K == 0);
|
| 10663 |
-
quantize_row_iq3_xxs_impl(x, y, k, NULL);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10664 |
}
|
| 10665 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10666 |
// =================================== 1.5 bpw ===================================================
|
| 10667 |
|
| 10668 |
static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
|
|
|
|
| 3505 |
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
| 3506 |
};
|
| 3507 |
|
| 3508 |
+
static const uint32_t iq3xs_grid[512] = {
|
| 3509 |
+
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
|
| 3510 |
+
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
|
| 3511 |
+
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
|
| 3512 |
+
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
|
| 3513 |
+
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
|
| 3514 |
+
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
|
| 3515 |
+
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
|
| 3516 |
+
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
|
| 3517 |
+
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
|
| 3518 |
+
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
|
| 3519 |
+
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
|
| 3520 |
+
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
|
| 3521 |
+
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
|
| 3522 |
+
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
|
| 3523 |
+
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
|
| 3524 |
+
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
|
| 3525 |
+
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
|
| 3526 |
+
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
|
| 3527 |
+
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
|
| 3528 |
+
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
|
| 3529 |
+
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
|
| 3530 |
+
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
|
| 3531 |
+
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
|
| 3532 |
+
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
|
| 3533 |
+
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
|
| 3534 |
+
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
|
| 3535 |
+
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
|
| 3536 |
+
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
|
| 3537 |
+
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
|
| 3538 |
+
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
|
| 3539 |
+
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
|
| 3540 |
+
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
|
| 3541 |
+
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
|
| 3542 |
+
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
|
| 3543 |
+
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
|
| 3544 |
+
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
|
| 3545 |
+
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
|
| 3546 |
+
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
|
| 3547 |
+
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
|
| 3548 |
+
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
|
| 3549 |
+
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
|
| 3550 |
+
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
|
| 3551 |
+
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
|
| 3552 |
+
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
|
| 3553 |
+
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
|
| 3554 |
+
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
|
| 3555 |
+
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
|
| 3556 |
+
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
|
| 3557 |
+
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
|
| 3558 |
+
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
|
| 3559 |
+
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
|
| 3560 |
+
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
|
| 3561 |
+
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
|
| 3562 |
+
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
|
| 3563 |
+
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
|
| 3564 |
+
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
|
| 3565 |
+
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
|
| 3566 |
+
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
|
| 3567 |
+
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
|
| 3568 |
+
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
|
| 3569 |
+
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
|
| 3570 |
+
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
|
| 3571 |
+
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
|
| 3572 |
+
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
|
| 3573 |
+
};
|
| 3574 |
+
|
| 3575 |
#define NGRID_IQ2XXS 512
|
| 3576 |
static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
|
| 3577 |
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
|
|
| 3803 |
}
|
| 3804 |
}
|
| 3805 |
|
| 3806 |
+
// ====================== 3.3125 bpw (de)-quantization
|
| 3807 |
+
|
| 3808 |
+
void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
|
| 3809 |
+
assert(k % QK_K == 0);
|
| 3810 |
+
const int nb = k / QK_K;
|
| 3811 |
+
|
| 3812 |
+
for (int i = 0; i < nb; i++) {
|
| 3813 |
+
|
| 3814 |
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
| 3815 |
+
const uint8_t * qs = x[i].qs;
|
| 3816 |
+
const uint8_t * qh = x[i].qh;
|
| 3817 |
+
const uint8_t * signs = x[i].signs;
|
| 3818 |
+
|
| 3819 |
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
| 3820 |
+
const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f;
|
| 3821 |
+
const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f;
|
| 3822 |
+
for (int l = 0; l < 4; ++l) {
|
| 3823 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
| 3824 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
| 3825 |
+
for (int j = 0; j < 4; ++j) {
|
| 3826 |
+
y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 3827 |
+
y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
| 3828 |
+
}
|
| 3829 |
+
y += 8;
|
| 3830 |
+
}
|
| 3831 |
+
qs += 8;
|
| 3832 |
+
signs += 4;
|
| 3833 |
+
for (int l = 0; l < 4; ++l) {
|
| 3834 |
+
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
|
| 3835 |
+
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
|
| 3836 |
+
for (int j = 0; j < 4; ++j) {
|
| 3837 |
+
y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 3838 |
+
y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
| 3839 |
+
}
|
| 3840 |
+
y += 8;
|
| 3841 |
+
}
|
| 3842 |
+
qh += 2;
|
| 3843 |
+
qs += 8;
|
| 3844 |
+
signs += 4;
|
| 3845 |
+
}
|
| 3846 |
+
}
|
| 3847 |
+
}
|
| 3848 |
+
|
| 3849 |
// ====================== 1.5625 bpw (de)-quantization
|
| 3850 |
|
| 3851 |
void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
|
|
|
|
| 8916 |
|
| 8917 |
#endif
|
| 8918 |
|
| 8919 |
+
#if defined (__AVX2__) || defined (__ARM_NEON)
|
| 8920 |
static const int8_t keven_signs_q2xs[1024] = {
|
| 8921 |
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
|
| 8922 |
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
|
|
|
|
| 8951 |
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
|
| 8952 |
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
|
| 8953 |
};
|
| 8954 |
+
#endif
|
| 8955 |
|
| 8956 |
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
| 8957 |
assert(n % QK_K == 0);
|
|
|
|
| 9439 |
#endif
|
| 9440 |
}
|
| 9441 |
|
| 9442 |
+
// Dot product of one row of IQ3_S blocks (vx) with one row of Q8_K blocks (vy).
//
//   n   - number of values in the row; must be a multiple of QK_K
//   s   - receives the single float result
//   vx  - n/QK_K consecutive block_iq3_s
//   vy  - n/QK_K consecutive block_q8_K
//   nrc - must be 1; bs, bx and by are not used by this implementation
//
// For each group of 4 values a 9-bit index into iq3xs_grid is formed from one
// byte of qs (low 8 bits) and one bit of qh (bit 8); the 32-bit grid entry is
// read as 4 unsigned bytes. A set bit in `signs` negates the corresponding
// product. Each group of 32 values is weighted by 2*ls+1, where ls is a 4-bit
// nibble of `scales`, and each super-block by d(x)*d(y). The final sum is
// multiplied by 0.25f.
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq3_s * restrict x = vx;
    const block_q8_K  * restrict y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)

    // Table-lookup indices: replicate each of the 4 sign bytes 8 times (one lane per value).
    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
    };

    // Per-lane bit selector: lane j within each group of 8 tests bit j of its sign byte.
    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};

    const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
    const uint8x16_t   mask2 = vld1q_u8(k_mask2);

    uint8x16x2_t vs;
    ggml_int8x16x4_t q3s;
    ggml_int8x16x4_t q8b;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t * restrict qs = x[i].qs;
        const uint8_t * restrict qh = x[i].qh;
        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
        const int8_t   * restrict q8 = y[i].qs;
        int sumi1 = 0, sumi2 = 0;
        // Two groups of 32 values (64 values) per iteration.
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
            const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
                                          iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
            const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
                                          iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
            const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
                                          iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
            const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
                                          iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
            qs += 16;

            // Build the 32 sign bits in uint32_t: shifting the int-promoted uint16_t by 16
            // would overflow into the sign bit when its top bit is set.
            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32((uint32_t)signs[0] | ((uint32_t)signs[1] << 16)));
            vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
            vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
            // 0xFF in every lane whose sign bit is set, 0x00 otherwise.
            vs.val[0] = vceqq_u8(vs.val[0], mask2);
            vs.val[1] = vceqq_u8(vs.val[1], mask2);

            // (v ^ m) - m negates v where m == 0xFF and leaves it unchanged where m == 0.
            q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
            q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));

            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32((uint32_t)signs[2] | ((uint32_t)signs[3] << 16)));
            vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
            vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
            vs.val[0] = vceqq_u8(vs.val[0], mask2);
            vs.val[1] = vceqq_u8(vs.val[1], mask2);

            signs += 4;

            q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
            q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));

            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
            // One scale byte covers both groups: low nibble for the first, high nibble for the second.
            sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
            sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >>  4));
        }
        sumf += d*(sumi1 + sumi2);
    }
    *s = 0.25f * sumf;

#elif defined(__AVX2__)

    // Shuffle indices: replicate each of the 4 sign bytes 8 times (one lane per value).
    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
    };

    // Per-lane bit selector: lane j within each group of 8 tests bit j of its sign byte.
    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
    };

    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);

    __m256 accumf = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t * restrict qs = x[i].qs;
        const uint8_t * restrict qh = x[i].qh;
        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
        const int8_t  * restrict q8 = y[i].qs;
        __m256i sumi1 = _mm256_setzero_si256();
        __m256i sumi2 = _mm256_setzero_si256();
        // Two groups of 32 values per iteration.
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            // _mm256_set_epi32 takes its arguments highest element first, hence the reversed order.
            const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)],
                                                  iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)],
                                                  iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)],
                                                  iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)],
                                                  iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)],
                                                  iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)],
                                                  iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)],
                                                  iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]);
            qs += 8;
            const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)],
                                                  iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)],
                                                  iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)],
                                                  iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)],
                                                  iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)],
                                                  iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)],
                                                  iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
                                                  iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
            qs += 8;

            // Build the 32 sign bits in uint32_t: shifting the int-promoted uint16_t by 16
            // would overflow into the sign bit when its top bit is set.
            __m256i aux256 = _mm256_set1_epi32((int)((uint32_t)signs[0] | ((uint32_t)signs[1] << 16)));
            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
            // 0xFF in every lane whose sign bit is set, 0x00 otherwise.
            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
            // Here the sign is applied to the q8 operand: (v ^ m) - m negates v where m == 0xFF.
            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);

            aux256 = _mm256_set1_epi32((int)((uint32_t)signs[2] | ((uint32_t)signs[3] << 16)));
            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);

            signs += 4;

            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
            // One scale byte covers both groups: low nibble for the first, high nibble for the second.
            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
            sumi1 = _mm256_add_epi32(sumi1, p1);
            sumi2 = _mm256_add_epi32(sumi2, p2);
        }

        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);

    }

    *s = 0.25f * hsum_float_8(accumf);

#else

    // Portable scalar implementation.
    float sumf = 0.f;
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t * restrict qs = x[i].qs;
        const uint8_t * restrict qh = x[i].qh;
        const uint8_t * restrict signs = x[i].signs;
        const int8_t  * restrict q8 = y[i].qs;
        int32_t bsum = 0;
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            // One scale byte covers both groups: low nibble for the first, high nibble for the second.
            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
            int32_t sumi = 0;
            // First group of 32 values: 4 steps of 8 values (two grid entries, one sign byte).
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
                const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
                }
                q8 += 8;
            }
            qs += 8;
            signs += 4;
            bsum += sumi * ls1;
            sumi = 0;
            // Second group of 32 values, using the next qh byte.
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
                const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
                }
                q8 += 8;
            }
            qs += 8;
            signs += 4;
            bsum += sumi * ls2;
        }
        sumf += d * bsum;
    }
    *s = 0.25f * sumf;
#endif
}
|
| 9636 |
+
|
| 9637 |
+
|
| 9638 |
#ifdef __AVX2__
|
| 9639 |
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
| 9640 |
const __m256i ax = _mm256_sign_epi8(x, x);
|
|
|
|
| 9831 |
float sumf = 0;
|
| 9832 |
|
| 9833 |
for (int ib = 0; ib < nb; ib += 2) {
|
| 9834 |
+
|
| 9835 |
q4bits.val[0] = vld1q_u8(x[ib+0].qs);
|
| 9836 |
q4bits.val[1] = vld1q_u8(x[ib+1].qs);
|
| 9837 |
q8b.val[0] = vld1q_s8(y[ib+0].qs);
|
|
|
|
| 10548 |
uint16_t * neighbours;
|
| 10549 |
} iq3_entry_t;
|
| 10550 |
|
| 10551 |
+
// Lazily populated lookup tables, one entry per supported grid size.
// iq3_data_index() maps grid size 256 to slot 0 and 512 to slot 1.
static iq3_entry_t iq3_data[2] = {
    [0] = {NULL, NULL, NULL},
    [1] = {NULL, NULL, NULL},
};
|
| 10555 |
|
| 10556 |
// Maps a supported grid size to its slot in iq3_data: 256 -> 0, 512 -> 1.
// Any other grid size trips the assertion.
static inline int iq3_data_index(int grid_size) {
    GGML_ASSERT(grid_size == 256 || grid_size == 512);
    if (grid_size == 256) {
        return 0;
    }
    return 1;
}
|
| 10561 |
|
| 10562 |
static int iq3_compare_func(const void * left, const void * right) {
|
|
|
|
| 10588 |
3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
|
| 10589 |
3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
|
| 10590 |
};
|
| 10591 |
+
static const uint16_t kgrid_512[512] = {
|
| 10592 |
+
0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34,
|
| 10593 |
+
37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77,
|
| 10594 |
+
80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142,
|
| 10595 |
+
145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210,
|
| 10596 |
+
217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288,
|
| 10597 |
+
291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393,
|
| 10598 |
+
395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514,
|
| 10599 |
+
516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576,
|
| 10600 |
+
577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653,
|
| 10601 |
+
655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727,
|
| 10602 |
+
728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833,
|
| 10603 |
+
840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977,
|
| 10604 |
+
989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047,
|
| 10605 |
+
1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103,
|
| 10606 |
+
1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199,
|
| 10607 |
+
1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296,
|
| 10608 |
+
1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415,
|
| 10609 |
+
1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561,
|
| 10610 |
+
1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648,
|
| 10611 |
+
1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761,
|
| 10612 |
+
1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877,
|
| 10613 |
+
1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068,
|
| 10614 |
+
2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177,
|
| 10615 |
+
2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269,
|
| 10616 |
+
2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520,
|
| 10617 |
+
2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634,
|
| 10618 |
+
2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805,
|
| 10619 |
+
2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083,
|
| 10620 |
+
3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276,
|
| 10621 |
+
3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591,
|
| 10622 |
+
3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729,
|
| 10623 |
+
3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032,
|
| 10624 |
+
};
|
| 10625 |
+
|
| 10626 |
const int kmap_size = 4096;
|
| 10627 |
+
const int nwant = grid_size == 256 ? 2 : 3;
|
| 10628 |
+
const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
|
| 10629 |
uint32_t * kgrid_q3xs;
|
| 10630 |
int * kmap_q3xs;
|
| 10631 |
uint16_t * kneighbors_q3xs;
|
|
|
|
| 10722 |
}
|
| 10723 |
|
| 10724 |
void iq3xs_free_impl(int grid_size) {
|
| 10725 |
+
GGML_ASSERT(grid_size == 256 || grid_size == 512);
|
| 10726 |
const int gindex = iq3_data_index(grid_size);
|
| 10727 |
if (iq3_data[gindex].grid) {
|
| 10728 |
free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
|
|
|
|
| 10755 |
return grid_index;
|
| 10756 |
}
|
| 10757 |
|
| 10758 |
+
static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n,
|
| 10759 |
+
const float * restrict quant_weights) {
|
| 10760 |
|
| 10761 |
+
const int gindex = iq3_data_index(grid_size);
|
| 10762 |
|
| 10763 |
const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
|
| 10764 |
const int * kmap_q3xs = iq3_data[gindex].map;
|
|
|
|
| 10772 |
|
| 10773 |
const int kMaxQ = 8;
|
| 10774 |
|
| 10775 |
+
const int nbl = n/QK_K;
|
| 10776 |
|
| 10777 |
+
ggml_fp16_t * dh;
|
| 10778 |
+
uint8_t * qs;
|
| 10779 |
+
int block_size;
|
| 10780 |
+
if (grid_size == 256) {
|
| 10781 |
+
block_iq3_xxs * y = vy;
|
| 10782 |
+
dh = &y->d;
|
| 10783 |
+
qs = y->qs;
|
| 10784 |
+
block_size = sizeof(block_iq3_xxs);
|
| 10785 |
+
} else {
|
| 10786 |
+
block_iq3_s * y = vy;
|
| 10787 |
+
dh = &y->d;
|
| 10788 |
+
qs = y->qs;
|
| 10789 |
+
block_size = sizeof(block_iq3_s);
|
| 10790 |
+
}
|
| 10791 |
+
int quant_size = block_size - sizeof(ggml_fp16_t);
|
| 10792 |
|
| 10793 |
float scales[QK_K/32];
|
| 10794 |
float weight[32];
|
|
|
|
| 10799 |
bool is_on_grid[8];
|
| 10800 |
bool is_on_grid_aux[8];
|
| 10801 |
uint8_t block_signs[8];
|
| 10802 |
+
uint8_t q3[3*(QK_K/8)+QK_K/32];
|
| 10803 |
uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
|
| 10804 |
+
uint8_t * qh = q3 + 3*(QK_K/8);
|
| 10805 |
|
| 10806 |
for (int ibl = 0; ibl < nbl; ++ibl) {
|
| 10807 |
|
| 10808 |
+
dh[0] = GGML_FP32_TO_FP16(0.f);
|
| 10809 |
+
memset(q3, 0, 3*QK_K/8+QK_K/32);
|
| 10810 |
|
| 10811 |
float max_scale = 0;
|
| 10812 |
|
| 10813 |
const float * xbl = x + QK_K*ibl;
|
| 10814 |
float sumx2 = 0;
|
| 10815 |
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
| 10816 |
+
float sigma2 = 2*sumx2/QK_K;
|
| 10817 |
|
| 10818 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 10819 |
const float * xb = xbl + 32*ib;
|
|
|
|
| 10931 |
printf("\n");
|
| 10932 |
GGML_ASSERT(false);
|
| 10933 |
}
|
| 10934 |
+
if (grid_size == 256) {
|
| 10935 |
+
q3[8*ib+k] = grid_index;
|
| 10936 |
+
} else {
|
| 10937 |
+
q3[8*ib+k] = grid_index & 255;
|
| 10938 |
+
qh[ib] |= ((grid_index >> 8) << k);
|
| 10939 |
+
}
|
| 10940 |
+
|
| 10941 |
}
|
| 10942 |
scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
|
| 10943 |
GGML_ASSERT(scale >= 0);
|
|
|
|
| 10946 |
}
|
| 10947 |
|
| 10948 |
if (!max_scale) {
|
| 10949 |
+
memset(qs, 0, quant_size);
|
| 10950 |
+
dh += block_size/sizeof(ggml_fp16_t);
|
| 10951 |
+
qs += block_size;
|
| 10952 |
continue;
|
| 10953 |
}
|
| 10954 |
|
| 10955 |
float d = max_scale/31;
|
| 10956 |
+
dh[0] = GGML_FP32_TO_FP16(d * 1.0125f); // small improvement via this fudge factor
|
| 10957 |
float id = 1/d;
|
|
|
|
| 10958 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 10959 |
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
| 10960 |
l = MAX(0, MIN(15, l));
|
| 10961 |
scales_and_signs[ib] |= ((uint32_t)l << 28);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10962 |
}
|
| 10963 |
+
memcpy(qs, q3, quant_size);
|
| 10964 |
+
|
| 10965 |
+
dh += block_size/sizeof(ggml_fp16_t);
|
| 10966 |
+
qs += block_size;
|
| 10967 |
+
|
| 10968 |
}
|
| 10969 |
}
|
| 10970 |
|
|
|
|
| 10974 |
int nblock = n_per_row/QK_K;
|
| 10975 |
char * qrow = (char *)dst;
|
| 10976 |
for (int row = 0; row < nrow; ++row) {
|
| 10977 |
+
quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
|
| 10978 |
src += n_per_row;
|
| 10979 |
qrow += nblock*sizeof(block_iq3_xxs);
|
| 10980 |
}
|
|
|
|
| 10989 |
|
| 10990 |
// Reference IQ3_XXS row quantizer: quantizes k floats from x into y using the
// 256-entry grid and no importance weights. k must be a multiple of QK_K.
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
    const int grid_size = 256;
    assert(k % QK_K == 0);
    quantize_row_iq3_xxs_impl(grid_size, x, y, k, /*quant_weights=*/NULL);
}
|
| 10994 |
+
|
| 10995 |
+
static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n,
|
| 10996 |
+
const float * restrict quant_weights,
|
| 10997 |
+
float * scales,
|
| 10998 |
+
float * weight,
|
| 10999 |
+
float * xval,
|
| 11000 |
+
int8_t * L,
|
| 11001 |
+
int8_t * Laux,
|
| 11002 |
+
float * waux,
|
| 11003 |
+
bool * is_on_grid,
|
| 11004 |
+
bool * is_on_grid_aux,
|
| 11005 |
+
uint8_t * block_signs) {
|
| 11006 |
+
|
| 11007 |
+
const int gindex = iq3_data_index(512);
|
| 11008 |
+
|
| 11009 |
+
const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
|
| 11010 |
+
const int * kmap_q3xs = iq3_data[gindex].map;
|
| 11011 |
+
const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
|
| 11012 |
+
|
| 11013 |
+
//GGML_ASSERT(quant_weights && "missing quantization weights");
|
| 11014 |
+
GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
|
| 11015 |
+
GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
|
| 11016 |
+
GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
|
| 11017 |
+
GGML_ASSERT(n%QK_K == 0);
|
| 11018 |
+
|
| 11019 |
+
const int kMaxQ = 8;
|
| 11020 |
+
|
| 11021 |
+
const int nbl = n/QK_K;
|
| 11022 |
+
|
| 11023 |
+
block_iq3_s * y = vy;
|
| 11024 |
+
|
| 11025 |
+
const int bs4 = block_size/4;
|
| 11026 |
+
const int bs8 = block_size/8;
|
| 11027 |
+
|
| 11028 |
+
for (int ibl = 0; ibl < nbl; ++ibl) {
|
| 11029 |
+
|
| 11030 |
+
memset(&y[ibl], 0, sizeof(block_iq3_s));
|
| 11031 |
+
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
| 11032 |
+
|
| 11033 |
+
uint8_t * qs = y[ibl].qs;
|
| 11034 |
+
uint8_t * qh = y[ibl].qh;
|
| 11035 |
+
uint8_t * signs = y[ibl].signs;
|
| 11036 |
+
|
| 11037 |
+
float max_scale = 0;
|
| 11038 |
+
|
| 11039 |
+
const float * xbl = x + QK_K*ibl;
|
| 11040 |
+
float sumx2 = 0;
|
| 11041 |
+
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
| 11042 |
+
float sigma2 = 2*sumx2/QK_K;
|
| 11043 |
+
|
| 11044 |
+
for (int ib = 0; ib < QK_K/block_size; ++ib) {
|
| 11045 |
+
const float * xb = xbl + block_size*ib;
|
| 11046 |
+
if (quant_weights) {
|
| 11047 |
+
const float * qw = quant_weights + QK_K*ibl + block_size*ib;
|
| 11048 |
+
for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
| 11049 |
+
} else {
|
| 11050 |
+
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
|
| 11051 |
+
}
|
| 11052 |
+
for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);
|
| 11053 |
+
for (int k = 0; k < bs8; ++k) {
|
| 11054 |
+
uint8_t s = 0;
|
| 11055 |
+
for (int i = 0; i < 8; ++i) {
|
| 11056 |
+
if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
|
| 11057 |
+
else {
|
| 11058 |
+
xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
|
| 11059 |
+
}
|
| 11060 |
+
}
|
| 11061 |
+
block_signs[k] = s;
|
| 11062 |
+
}
|
| 11063 |
+
float max = xval[0];
|
| 11064 |
+
for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
|
| 11065 |
+
if (!max) {
|
| 11066 |
+
scales[ib] = 0;
|
| 11067 |
+
continue;
|
| 11068 |
+
}
|
| 11069 |
+
float best = 0;
|
| 11070 |
+
float scale = max/(2*kMaxQ-1);
|
| 11071 |
+
for (int is = -15; is <= 15; ++is) {
|
| 11072 |
+
float id = (2*kMaxQ-1+is*0.2f)/max;
|
| 11073 |
+
float this_scale = 1/id;
|
| 11074 |
+
for (int k = 0; k < bs4; ++k) {
|
| 11075 |
+
for (int i = 0; i < 4; ++i) {
|
| 11076 |
+
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
| 11077 |
+
Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
|
| 11078 |
+
}
|
| 11079 |
+
uint16_t u = 0;
|
| 11080 |
+
for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
|
| 11081 |
+
int grid_index = kmap_q3xs[u];
|
| 11082 |
+
is_on_grid_aux[k] = true;
|
| 11083 |
+
if (grid_index < 0) {
|
| 11084 |
+
is_on_grid_aux[k] = false;
|
| 11085 |
+
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
|
| 11086 |
+
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
|
| 11087 |
+
}
|
| 11088 |
+
}
|
| 11089 |
+
float sumqx = 0, sumq2 = 0;
|
| 11090 |
+
for (int i = 0; i < block_size; ++i) {
|
| 11091 |
+
float w = weight[i];
|
| 11092 |
+
float q = 2*Laux[i] + 1;
|
| 11093 |
+
sumqx += w*xval[i]*q;
|
| 11094 |
+
sumq2 += w*q*q;
|
| 11095 |
+
}
|
| 11096 |
+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
| 11097 |
+
scale = sumqx/sumq2; best = scale*sumqx;
|
| 11098 |
+
for (int i = 0; i < block_size; ++i) L[i] = Laux[i];
|
| 11099 |
+
for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];
|
| 11100 |
+
}
|
| 11101 |
+
}
|
| 11102 |
+
int n_not_ongrid = 0;
|
| 11103 |
+
for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
|
| 11104 |
+
if (n_not_ongrid > 0 && scale > 0) {
|
| 11105 |
+
float id = 1/scale;
|
| 11106 |
+
for (int k = 0; k < bs4; ++k) {
|
| 11107 |
+
if (is_on_grid[k]) continue;
|
| 11108 |
+
uint16_t u = 0;
|
| 11109 |
+
for (int i = 0; i < 4; ++i) {
|
| 11110 |
+
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
| 11111 |
+
l = MAX(0, MIN(kMaxQ-1, l));
|
| 11112 |
+
u |= (l << 3*i);
|
| 11113 |
+
}
|
| 11114 |
+
int grid_index = kmap_q3xs[u];
|
| 11115 |
+
if (grid_index < 0) {
|
| 11116 |
+
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
|
| 11117 |
+
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
|
| 11118 |
+
}
|
| 11119 |
+
const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
|
| 11120 |
+
for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
|
| 11121 |
+
}
|
| 11122 |
+
float sumqx = 0, sumq2 = 0;
|
| 11123 |
+
for (int i = 0; i < block_size; ++i) {
|
| 11124 |
+
float w = weight[i];
|
| 11125 |
+
float q = 2*L[i] + 1;
|
| 11126 |
+
sumqx += w*xval[i]*q;
|
| 11127 |
+
sumq2 += w*q*q;
|
| 11128 |
+
}
|
| 11129 |
+
if (sumq2 > 0) scale = sumqx/sumq2;
|
| 11130 |
+
}
|
| 11131 |
+
if (scale < 0) {
|
| 11132 |
+
// This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
|
| 11133 |
+
// and correspondingly flip quant signs.
|
| 11134 |
+
scale = -scale;
|
| 11135 |
+
for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];
|
| 11136 |
+
}
|
| 11137 |
+
for (int k = 0; k < bs4; ++k) {
|
| 11138 |
+
uint16_t u = 0;
|
| 11139 |
+
for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
|
| 11140 |
+
int grid_index = kmap_q3xs[u];
|
| 11141 |
+
if (grid_index < 0) {
|
| 11142 |
+
printf("Oops: found point %u not on grid:", u);
|
| 11143 |
+
for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
|
| 11144 |
+
printf("\n");
|
| 11145 |
+
GGML_ASSERT(false);
|
| 11146 |
+
}
|
| 11147 |
+
qs[k] = grid_index & 255;
|
| 11148 |
+
qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
|
| 11149 |
+
}
|
| 11150 |
+
qs += bs4;
|
| 11151 |
+
for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];
|
| 11152 |
+
signs += bs8;
|
| 11153 |
+
GGML_ASSERT(scale >= 0);
|
| 11154 |
+
scales[ib] = scale;
|
| 11155 |
+
max_scale = MAX(max_scale, scale);
|
| 11156 |
+
}
|
| 11157 |
+
|
| 11158 |
+
if (!max_scale) {
|
| 11159 |
+
continue;
|
| 11160 |
+
}
|
| 11161 |
+
|
| 11162 |
+
float d = max_scale/31;
|
| 11163 |
+
y[ibl].d = GGML_FP32_TO_FP16(d);
|
| 11164 |
+
float id = 1/d;
|
| 11165 |
+
for (int ib = 0; ib < QK_K/block_size; ib += 2) {
|
| 11166 |
+
int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
|
| 11167 |
+
l1 = MAX(0, MIN(15, l1));
|
| 11168 |
+
int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
|
| 11169 |
+
l2 = MAX(0, MIN(15, l2));
|
| 11170 |
+
y[ibl].scales[ib/2] = l1 | (l2 << 4);
|
| 11171 |
+
}
|
| 11172 |
+
|
| 11173 |
+
}
|
| 11174 |
+
}
|
| 11175 |
+
|
| 11176 |
+
#define IQ3S_BLOCK_SIZE 32
|
| 11177 |
+
size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
| 11178 |
+
(void)hist;
|
| 11179 |
+
GGML_ASSERT(n_per_row%QK_K == 0);
|
| 11180 |
+
int nblock = n_per_row/QK_K;
|
| 11181 |
+
float scales[QK_K/IQ3S_BLOCK_SIZE];
|
| 11182 |
+
float weight[IQ3S_BLOCK_SIZE];
|
| 11183 |
+
float xval[IQ3S_BLOCK_SIZE];
|
| 11184 |
+
int8_t L[IQ3S_BLOCK_SIZE];
|
| 11185 |
+
int8_t Laux[IQ3S_BLOCK_SIZE];
|
| 11186 |
+
float waux[IQ3S_BLOCK_SIZE];
|
| 11187 |
+
bool is_on_grid[IQ3S_BLOCK_SIZE/4];
|
| 11188 |
+
bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
|
| 11189 |
+
uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
|
| 11190 |
+
char * qrow = (char *)dst;
|
| 11191 |
+
for (int row = 0; row < nrow; ++row) {
|
| 11192 |
+
quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
|
| 11193 |
+
scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
|
| 11194 |
+
src += n_per_row;
|
| 11195 |
+
qrow += nblock*sizeof(block_iq3_s);
|
| 11196 |
+
}
|
| 11197 |
+
return nrow * nblock * sizeof(block_iq3_s);
|
| 11198 |
+
}
|
| 11199 |
+
|
| 11200 |
+
void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) {
|
| 11201 |
+
assert(k % QK_K == 0);
|
| 11202 |
+
block_iq3_s * restrict y = vy;
|
| 11203 |
+
quantize_row_iq3_s_reference(x, y, k);
|
| 11204 |
}
|
| 11205 |
|
| 11206 |
+
// Reference IQ3_S row quantizer: treats x as a single row of k values and quantizes
// it into y without a histogram or importance weights. k must be a multiple of QK_K.
void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) {
    assert(k % QK_K == 0);
    quantize_iq3_s(x, y, /*nrow=*/1, /*n_per_row=*/k, /*hist=*/NULL, /*quant_weights=*/NULL);
}
|
| 11210 |
+
|
| 11211 |
+
|
| 11212 |
// =================================== 1.5 bpw ===================================================
|
| 11213 |
|
| 11214 |
static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
|
|
@@ -191,6 +191,21 @@ typedef struct {
|
|
| 191 |
} block_iq3_xxs;
|
| 192 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
typedef struct {
|
| 195 |
ggml_fp16_t d;
|
| 196 |
uint8_t qs[QK_K/8];
|
|
@@ -226,6 +241,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
|
|
| 226 |
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
|
| 227 |
void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
|
| 228 |
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
|
|
|
|
| 229 |
|
| 230 |
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 231 |
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
|
@@ -242,6 +258,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
|
|
| 242 |
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 243 |
void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 244 |
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
|
|
|
| 245 |
|
| 246 |
// Dequantization
|
| 247 |
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
|
@@ -262,6 +279,7 @@ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_
|
|
| 262 |
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 263 |
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 264 |
void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
|
|
|
| 265 |
|
| 266 |
// Dot product
|
| 267 |
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
|
@@ -280,6 +298,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|
| 280 |
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 281 |
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 282 |
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
|
|
|
| 283 |
|
| 284 |
//
|
| 285 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
|
@@ -289,6 +308,7 @@ size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row,
|
|
| 289 |
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 290 |
size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 291 |
size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
|
|
| 292 |
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 293 |
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 294 |
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
|
|
| 191 |
} block_iq3_xxs;
|
| 192 |
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
| 193 |
|
| 194 |
+
// 3.4375 bpw
|
| 195 |
+
#if QK_K == 64
|
| 196 |
+
#define IQ3S_N_SCALE 2
|
| 197 |
+
#else
|
| 198 |
+
#define IQ3S_N_SCALE QK_K/64
|
| 199 |
+
#endif
|
| 200 |
+
typedef struct {
|
| 201 |
+
ggml_fp16_t d;
|
| 202 |
+
uint8_t qs[QK_K/4];
|
| 203 |
+
uint8_t qh[QK_K/32];
|
| 204 |
+
uint8_t signs[QK_K/8];
|
| 205 |
+
uint8_t scales[IQ3S_N_SCALE];
|
| 206 |
+
} block_iq3_s;
|
| 207 |
+
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
|
| 208 |
+
|
| 209 |
typedef struct {
|
| 210 |
ggml_fp16_t d;
|
| 211 |
uint8_t qs[QK_K/8];
|
|
|
|
| 241 |
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
|
| 242 |
void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
|
| 243 |
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
|
| 244 |
+
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
|
| 245 |
|
| 246 |
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 247 |
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
|
|
|
| 258 |
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 259 |
void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 260 |
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 261 |
+
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
| 262 |
|
| 263 |
// Dequantization
|
| 264 |
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
|
|
|
| 279 |
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 280 |
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 281 |
void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 282 |
+
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
| 283 |
|
| 284 |
// Dot product
|
| 285 |
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
|
|
|
| 298 |
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 299 |
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 300 |
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 301 |
+
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
| 302 |
|
| 303 |
//
|
| 304 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
|
|
|
| 308 |
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 309 |
size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 310 |
size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 311 |
+
size_t quantize_iq3_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 312 |
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 313 |
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
| 314 |
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
|
@@ -682,6 +682,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|
| 682 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 683 |
.nrows = 1,
|
| 684 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 685 |
[GGML_TYPE_IQ1_S] = {
|
| 686 |
.type_name = "iq1_s",
|
| 687 |
.blck_size = QK_K,
|
|
@@ -2308,6 +2320,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
|
| 2308 |
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
| 2309 |
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
|
| 2310 |
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
|
|
|
| 2311 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2312 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2313 |
}
|
|
@@ -7742,6 +7755,7 @@ static void ggml_compute_forward_add(
|
|
| 7742 |
case GGML_TYPE_IQ3_XXS:
|
| 7743 |
case GGML_TYPE_IQ1_S:
|
| 7744 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 7745 |
{
|
| 7746 |
ggml_compute_forward_add_q_f32(params, dst);
|
| 7747 |
} break;
|
|
@@ -8021,6 +8035,7 @@ static void ggml_compute_forward_add1(
|
|
| 8021 |
case GGML_TYPE_IQ3_XXS:
|
| 8022 |
case GGML_TYPE_IQ1_S:
|
| 8023 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 8024 |
{
|
| 8025 |
ggml_compute_forward_add1_q_f32(params, dst);
|
| 8026 |
} break;
|
|
@@ -8145,6 +8160,7 @@ static void ggml_compute_forward_acc(
|
|
| 8145 |
case GGML_TYPE_IQ3_XXS:
|
| 8146 |
case GGML_TYPE_IQ1_S:
|
| 8147 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 8148 |
default:
|
| 8149 |
{
|
| 8150 |
GGML_ASSERT(false);
|
|
@@ -11043,6 +11059,7 @@ static void ggml_compute_forward_out_prod(
|
|
| 11043 |
case GGML_TYPE_IQ3_XXS:
|
| 11044 |
case GGML_TYPE_IQ1_S:
|
| 11045 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 11046 |
{
|
| 11047 |
ggml_compute_forward_out_prod_q_f32(params, dst);
|
| 11048 |
} break;
|
|
@@ -11231,6 +11248,7 @@ static void ggml_compute_forward_set(
|
|
| 11231 |
case GGML_TYPE_IQ3_XXS:
|
| 11232 |
case GGML_TYPE_IQ1_S:
|
| 11233 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 11234 |
default:
|
| 11235 |
{
|
| 11236 |
GGML_ASSERT(false);
|
|
@@ -11433,6 +11451,7 @@ static void ggml_compute_forward_get_rows(
|
|
| 11433 |
case GGML_TYPE_IQ3_XXS:
|
| 11434 |
case GGML_TYPE_IQ1_S:
|
| 11435 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 11436 |
{
|
| 11437 |
ggml_compute_forward_get_rows_q(params, dst);
|
| 11438 |
} break;
|
|
@@ -12133,6 +12152,7 @@ static void ggml_compute_forward_alibi(
|
|
| 12133 |
case GGML_TYPE_IQ3_XXS:
|
| 12134 |
case GGML_TYPE_IQ1_S:
|
| 12135 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 12136 |
case GGML_TYPE_Q8_K:
|
| 12137 |
case GGML_TYPE_I8:
|
| 12138 |
case GGML_TYPE_I16:
|
|
@@ -12216,6 +12236,7 @@ static void ggml_compute_forward_clamp(
|
|
| 12216 |
case GGML_TYPE_IQ3_XXS:
|
| 12217 |
case GGML_TYPE_IQ1_S:
|
| 12218 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
| 12219 |
case GGML_TYPE_Q8_K:
|
| 12220 |
case GGML_TYPE_I8:
|
| 12221 |
case GGML_TYPE_I16:
|
|
@@ -19467,6 +19488,7 @@ void ggml_quantize_init(enum ggml_type type) {
|
|
| 19467 |
case GGML_TYPE_IQ2_XS:
|
| 19468 |
case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
|
| 19469 |
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
|
|
|
|
| 19470 |
default: // nothing
|
| 19471 |
break;
|
| 19472 |
}
|
|
@@ -19741,6 +19763,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
| 19741 |
result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19742 |
GGML_ASSERT(result == row_size * nrows);
|
| 19743 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19744 |
case GGML_TYPE_IQ1_S:
|
| 19745 |
{
|
| 19746 |
GGML_ASSERT(start % QK_K == 0);
|
|
|
|
| 682 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 683 |
.nrows = 1,
|
| 684 |
},
|
| 685 |
+
[GGML_TYPE_IQ3_S] = {
|
| 686 |
+
.type_name = "iq3_s",
|
| 687 |
+
.blck_size = QK_K,
|
| 688 |
+
.type_size = sizeof(block_iq3_s),
|
| 689 |
+
.is_quantized = true,
|
| 690 |
+
.to_float = (ggml_to_float_t) dequantize_row_iq3_s,
|
| 691 |
+
.from_float = quantize_row_iq3_s,
|
| 692 |
+
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference,
|
| 693 |
+
.vec_dot = ggml_vec_dot_iq3_s_q8_K,
|
| 694 |
+
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 695 |
+
.nrows = 1,
|
| 696 |
+
},
|
| 697 |
[GGML_TYPE_IQ1_S] = {
|
| 698 |
.type_name = "iq1_s",
|
| 699 |
.blck_size = QK_K,
|
|
|
|
| 2320 |
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
| 2321 |
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
|
| 2322 |
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
| 2323 |
+
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
|
| 2324 |
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
| 2325 |
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
| 2326 |
}
|
|
|
|
| 7755 |
case GGML_TYPE_IQ3_XXS:
|
| 7756 |
case GGML_TYPE_IQ1_S:
|
| 7757 |
case GGML_TYPE_IQ4_NL:
|
| 7758 |
+
case GGML_TYPE_IQ3_S:
|
| 7759 |
{
|
| 7760 |
ggml_compute_forward_add_q_f32(params, dst);
|
| 7761 |
} break;
|
|
|
|
| 8035 |
case GGML_TYPE_IQ3_XXS:
|
| 8036 |
case GGML_TYPE_IQ1_S:
|
| 8037 |
case GGML_TYPE_IQ4_NL:
|
| 8038 |
+
case GGML_TYPE_IQ3_S:
|
| 8039 |
{
|
| 8040 |
ggml_compute_forward_add1_q_f32(params, dst);
|
| 8041 |
} break;
|
|
|
|
| 8160 |
case GGML_TYPE_IQ3_XXS:
|
| 8161 |
case GGML_TYPE_IQ1_S:
|
| 8162 |
case GGML_TYPE_IQ4_NL:
|
| 8163 |
+
case GGML_TYPE_IQ3_S:
|
| 8164 |
default:
|
| 8165 |
{
|
| 8166 |
GGML_ASSERT(false);
|
|
|
|
| 11059 |
case GGML_TYPE_IQ3_XXS:
|
| 11060 |
case GGML_TYPE_IQ1_S:
|
| 11061 |
case GGML_TYPE_IQ4_NL:
|
| 11062 |
+
case GGML_TYPE_IQ3_S:
|
| 11063 |
{
|
| 11064 |
ggml_compute_forward_out_prod_q_f32(params, dst);
|
| 11065 |
} break;
|
|
|
|
| 11248 |
case GGML_TYPE_IQ3_XXS:
|
| 11249 |
case GGML_TYPE_IQ1_S:
|
| 11250 |
case GGML_TYPE_IQ4_NL:
|
| 11251 |
+
case GGML_TYPE_IQ3_S:
|
| 11252 |
default:
|
| 11253 |
{
|
| 11254 |
GGML_ASSERT(false);
|
|
|
|
| 11451 |
case GGML_TYPE_IQ3_XXS:
|
| 11452 |
case GGML_TYPE_IQ1_S:
|
| 11453 |
case GGML_TYPE_IQ4_NL:
|
| 11454 |
+
case GGML_TYPE_IQ3_S:
|
| 11455 |
{
|
| 11456 |
ggml_compute_forward_get_rows_q(params, dst);
|
| 11457 |
} break;
|
|
|
|
| 12152 |
case GGML_TYPE_IQ3_XXS:
|
| 12153 |
case GGML_TYPE_IQ1_S:
|
| 12154 |
case GGML_TYPE_IQ4_NL:
|
| 12155 |
+
case GGML_TYPE_IQ3_S:
|
| 12156 |
case GGML_TYPE_Q8_K:
|
| 12157 |
case GGML_TYPE_I8:
|
| 12158 |
case GGML_TYPE_I16:
|
|
|
|
| 12236 |
case GGML_TYPE_IQ3_XXS:
|
| 12237 |
case GGML_TYPE_IQ1_S:
|
| 12238 |
case GGML_TYPE_IQ4_NL:
|
| 12239 |
+
case GGML_TYPE_IQ3_S:
|
| 12240 |
case GGML_TYPE_Q8_K:
|
| 12241 |
case GGML_TYPE_I8:
|
| 12242 |
case GGML_TYPE_I16:
|
|
|
|
| 19488 |
case GGML_TYPE_IQ2_XS:
|
| 19489 |
case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
|
| 19490 |
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
|
| 19491 |
+
case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
|
| 19492 |
default: // nothing
|
| 19493 |
break;
|
| 19494 |
}
|
|
|
|
| 19763 |
result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19764 |
GGML_ASSERT(result == row_size * nrows);
|
| 19765 |
} break;
|
| 19766 |
+
case GGML_TYPE_IQ3_S:
|
| 19767 |
+
{
|
| 19768 |
+
GGML_ASSERT(start % QK_K == 0);
|
| 19769 |
+
GGML_ASSERT(start % n_per_row == 0);
|
| 19770 |
+
size_t start_row = start / n_per_row;
|
| 19771 |
+
size_t row_size = ggml_row_size(type, n_per_row);
|
| 19772 |
+
result = quantize_iq3_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19773 |
+
GGML_ASSERT(result == row_size * nrows);
|
| 19774 |
+
} break;
|
| 19775 |
case GGML_TYPE_IQ1_S:
|
| 19776 |
{
|
| 19777 |
GGML_ASSERT(start % QK_K == 0);
|
|
@@ -350,6 +350,7 @@ extern "C" {
|
|
| 350 |
GGML_TYPE_IQ3_XXS = 18,
|
| 351 |
GGML_TYPE_IQ1_S = 19,
|
| 352 |
GGML_TYPE_IQ4_NL = 20,
|
|
|
|
| 353 |
GGML_TYPE_I8,
|
| 354 |
GGML_TYPE_I16,
|
| 355 |
GGML_TYPE_I32,
|
|
@@ -389,6 +390,7 @@ extern "C" {
|
|
| 389 |
GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
|
| 390 |
GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
|
| 391 |
GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
|
|
|
|
| 392 |
};
|
| 393 |
|
| 394 |
// available tensor operations:
|
|
|
|
| 350 |
GGML_TYPE_IQ3_XXS = 18,
|
| 351 |
GGML_TYPE_IQ1_S = 19,
|
| 352 |
GGML_TYPE_IQ4_NL = 20,
|
| 353 |
+
GGML_TYPE_IQ3_S = 21,
|
| 354 |
GGML_TYPE_I8,
|
| 355 |
GGML_TYPE_I16,
|
| 356 |
GGML_TYPE_I32,
|
|
|
|
| 390 |
GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
|
| 391 |
GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
|
| 392 |
GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
|
| 393 |
+
GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
|
| 394 |
};
|
| 395 |
|
| 396 |
// available tensor operations:
|