Spaces:
Running
Running
ggml : make i-quants work with super-blocks of 64 (CPU,Metal) (llama/5760)
Browse files* WIP: make i-quants work for QK_K = 64
* iq2_xs: attempt to fix AVX dot product for QK_K = 64
Tests pass, but I get gibberish.
* QK_K = 64 tests pass on ARM_NEON and Metal
Sadly, that does not mean it actually works.
* Make CUDA compile with QK_K = 64
Tests don't pass, plus we get misaligned access
* Q2_K: fixed bug in imatrix quantization for QK_K = 64
* iq1_s: turn off SIMD implementation for QK_K = 64 (it does not work)
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
- ggml-cuda.cu +20 -7
- ggml-metal.metal +30 -28
- ggml-quants.c +125 -23
- ggml-quants.h +5 -0
- ggml.c +14 -1
ggml-cuda.cu
CHANGED
|
@@ -544,14 +544,19 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
|
|
| 544 |
|
| 545 |
#define QR3_XS 8
|
| 546 |
#define QI3_XS (QK_K / (4*QR3_XS))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
typedef struct {
|
| 548 |
half d;
|
| 549 |
uint8_t qs[QK_K/4];
|
| 550 |
uint8_t qh[QK_K/32];
|
| 551 |
uint8_t signs[QK_K/8];
|
| 552 |
-
uint8_t scales[
|
| 553 |
} block_iq3_s;
|
| 554 |
-
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) +
|
| 555 |
|
| 556 |
#define QR1_S 8
|
| 557 |
#define QI1_S (QK_K / (4*QR1_S))
|
|
@@ -571,6 +576,11 @@ typedef struct {
|
|
| 571 |
} block_iq4_nl;
|
| 572 |
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
| 573 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
// QR4_XS = 8 is very slightly faster than QR4_XS = 4
|
| 575 |
#define QR4_XS 8
|
| 576 |
#define QI4_XS (QK_K / (4*QR4_XS))
|
|
@@ -581,7 +591,7 @@ typedef struct {
|
|
| 581 |
uint8_t qs[QK_K/2];
|
| 582 |
} block_iq4_xs;
|
| 583 |
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
|
| 584 |
-
|
| 585 |
|
| 586 |
#define WARP_SIZE 32
|
| 587 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
|
@@ -2439,9 +2449,9 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
|
|
| 2439 |
|
| 2440 |
}
|
| 2441 |
|
|
|
|
| 2442 |
template<typename dst_t>
|
| 2443 |
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
| 2444 |
-
|
| 2445 |
const int i = blockIdx.x;
|
| 2446 |
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
| 2447 |
|
|
@@ -2455,8 +2465,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
|
|
| 2455 |
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
|
| 2456 |
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
|
| 2457 |
}
|
| 2458 |
-
|
| 2459 |
}
|
|
|
|
| 2460 |
|
| 2461 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 2462 |
|
|
@@ -5382,8 +5392,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
|
|
| 5382 |
return 0.f;
|
| 5383 |
#endif
|
| 5384 |
#else
|
| 5385 |
-
|
| 5386 |
-
return 0.f;
|
| 5387 |
#endif
|
| 5388 |
}
|
| 5389 |
|
|
@@ -7444,7 +7453,11 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
|
|
| 7444 |
template<typename dst_t>
|
| 7445 |
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
| 7446 |
const int nb = (k + QK_K - 1) / QK_K;
|
|
|
|
|
|
|
|
|
|
| 7447 |
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
|
|
|
| 7448 |
}
|
| 7449 |
|
| 7450 |
template <typename src_t, typename dst_t>
|
|
|
|
| 544 |
|
| 545 |
#define QR3_XS 8
|
| 546 |
#define QI3_XS (QK_K / (4*QR3_XS))
|
| 547 |
+
#if QK_K == 64
|
| 548 |
+
#define IQ3S_N_SCALE 2
|
| 549 |
+
#else
|
| 550 |
+
#define IQ3S_N_SCALE QK_K/64
|
| 551 |
+
#endif
|
| 552 |
typedef struct {
|
| 553 |
half d;
|
| 554 |
uint8_t qs[QK_K/4];
|
| 555 |
uint8_t qh[QK_K/32];
|
| 556 |
uint8_t signs[QK_K/8];
|
| 557 |
+
uint8_t scales[IQ3S_N_SCALE];
|
| 558 |
} block_iq3_s;
|
| 559 |
+
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
|
| 560 |
|
| 561 |
#define QR1_S 8
|
| 562 |
#define QI1_S (QK_K / (4*QR1_S))
|
|
|
|
| 576 |
} block_iq4_nl;
|
| 577 |
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
| 578 |
|
| 579 |
+
#if QK_K == 64
|
| 580 |
+
#define block_iq4_xs block_iq4_nl
|
| 581 |
+
#define QR4_XS QR4_NL
|
| 582 |
+
#define QI4_XS QI4_NL
|
| 583 |
+
#else
|
| 584 |
// QR4_XS = 8 is very slightly faster than QR4_XS = 4
|
| 585 |
#define QR4_XS 8
|
| 586 |
#define QI4_XS (QK_K / (4*QR4_XS))
|
|
|
|
| 591 |
uint8_t qs[QK_K/2];
|
| 592 |
} block_iq4_xs;
|
| 593 |
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
|
| 594 |
+
#endif
|
| 595 |
|
| 596 |
#define WARP_SIZE 32
|
| 597 |
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
|
|
|
| 2449 |
|
| 2450 |
}
|
| 2451 |
|
| 2452 |
+
#if QK_K != 64
|
| 2453 |
template<typename dst_t>
|
| 2454 |
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
|
|
|
| 2455 |
const int i = blockIdx.x;
|
| 2456 |
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
| 2457 |
|
|
|
|
| 2465 |
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
|
| 2466 |
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
|
| 2467 |
}
|
|
|
|
| 2468 |
}
|
| 2469 |
+
#endif
|
| 2470 |
|
| 2471 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
| 2472 |
|
|
|
|
| 5392 |
return 0.f;
|
| 5393 |
#endif
|
| 5394 |
#else
|
| 5395 |
+
return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
|
|
|
|
| 5396 |
#endif
|
| 5397 |
}
|
| 5398 |
|
|
|
|
| 7453 |
template<typename dst_t>
|
| 7454 |
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
| 7455 |
const int nb = (k + QK_K - 1) / QK_K;
|
| 7456 |
+
#if QK_K == 64
|
| 7457 |
+
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
|
| 7458 |
+
#else
|
| 7459 |
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
| 7460 |
+
#endif
|
| 7461 |
}
|
| 7462 |
|
| 7463 |
template <typename src_t, typename dst_t>
|
ggml-metal.metal
CHANGED
|
@@ -2560,12 +2560,16 @@ typedef struct {
|
|
| 2560 |
uint8_t qs[QK4_NL/2];
|
| 2561 |
} block_iq4_nl;
|
| 2562 |
|
|
|
|
|
|
|
|
|
|
| 2563 |
typedef struct {
|
| 2564 |
half d;
|
| 2565 |
uint16_t scales_h;
|
| 2566 |
uint8_t scales_l[QK_K/64];
|
| 2567 |
uint8_t qs[QK_K/2];
|
| 2568 |
} block_iq4_xs;
|
|
|
|
| 2569 |
|
| 2570 |
//====================================== dot products =========================
|
| 2571 |
|
|
@@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
| 4346 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4347 |
}
|
| 4348 |
|
| 4349 |
-
#if QK_K == 256
|
| 4350 |
const int ix = tiisg;
|
| 4351 |
|
| 4352 |
device const float * y4 = y + 32 * ix;
|
|
@@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
| 4387 |
|
| 4388 |
y4 += 32 * 32;
|
| 4389 |
}
|
| 4390 |
-
#else
|
| 4391 |
-
(void) x;
|
| 4392 |
-
(void) y;
|
| 4393 |
-
(void) yl;
|
| 4394 |
-
(void) nb32;
|
| 4395 |
-
#endif
|
| 4396 |
|
| 4397 |
for (int row = 0; row < N_DST; ++row) {
|
| 4398 |
all_sum = simd_sum(sumf[row]);
|
|
@@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
| 4482 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4483 |
}
|
| 4484 |
|
| 4485 |
-
#if QK_K == 256
|
| 4486 |
const int ix = tiisg;
|
| 4487 |
|
| 4488 |
device const float * y4 = y + 32 * ix;
|
|
@@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
| 4533 |
|
| 4534 |
y4 += 32 * 32;
|
| 4535 |
}
|
| 4536 |
-
#else
|
| 4537 |
-
(void) x;
|
| 4538 |
-
(void) y;
|
| 4539 |
-
(void) yl;
|
| 4540 |
-
(void) nb32;
|
| 4541 |
-
#endif
|
| 4542 |
|
| 4543 |
for (int row = 0; row < N_DST; ++row) {
|
| 4544 |
all_sum = simd_sum(sumf[row]);
|
|
@@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 4628 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4629 |
}
|
| 4630 |
|
| 4631 |
-
#if QK_K == 256
|
| 4632 |
const int ix = tiisg;
|
| 4633 |
|
| 4634 |
device const float * y4 = y + 32 * ix;
|
|
@@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 4672 |
|
| 4673 |
y4 += 32 * 32;
|
| 4674 |
}
|
| 4675 |
-
#else
|
| 4676 |
-
(void) x;
|
| 4677 |
-
(void) y;
|
| 4678 |
-
(void) yl;
|
| 4679 |
-
(void) nb32;
|
| 4680 |
-
#endif
|
| 4681 |
|
| 4682 |
for (int row = 0; row < N_DST; ++row) {
|
| 4683 |
all_sum = simd_sum(sumf[row]);
|
|
@@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 5016 |
|
| 5017 |
const int nb32 = nb * (QK_K / 32);
|
| 5018 |
|
| 5019 |
-
#if QK_K == 256
|
| 5020 |
const int ix = tiisg/2;
|
| 5021 |
const int il = tiisg%2;
|
| 5022 |
|
|
@@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 5055 |
|
| 5056 |
y4 += 16 * 32;
|
| 5057 |
}
|
| 5058 |
-
#else
|
| 5059 |
-
(void) x;
|
| 5060 |
-
(void) y;
|
| 5061 |
-
(void) yl;
|
| 5062 |
-
(void) nb32;
|
| 5063 |
-
#endif
|
| 5064 |
|
| 5065 |
for (int row = 0; row < N_DST; ++row) {
|
| 5066 |
all_sum = simd_sum(sumf[row]);
|
|
@@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 5167 |
}
|
| 5168 |
}
|
| 5169 |
|
|
|
|
| 5170 |
void kernel_mul_mv_iq4_xs_f32_impl(
|
| 5171 |
device const void * src0,
|
| 5172 |
device const float * src1,
|
|
@@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5260 |
}
|
| 5261 |
}
|
| 5262 |
}
|
|
|
|
| 5263 |
|
| 5264 |
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
| 5265 |
kernel void kernel_mul_mv_iq1_s_f32(
|
|
@@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
| 5344 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5345 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5346 |
|
|
|
|
|
|
|
|
|
|
| 5347 |
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
|
|
| 5348 |
}
|
| 5349 |
|
| 5350 |
//============================= templates and their specializations =============================
|
|
@@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
| 5770 |
|
| 5771 |
template <typename type4x4>
|
| 5772 |
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
|
|
|
|
|
|
|
|
|
| 5773 |
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
| 5774 |
const int ib32 = il/2;
|
| 5775 |
il = il%2;
|
|
@@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
| 5786 |
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
| 5787 |
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
| 5788 |
}
|
|
|
|
| 5789 |
}
|
| 5790 |
|
| 5791 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
@@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
|
|
| 6334 |
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6335 |
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6336 |
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
|
|
|
|
|
|
|
| 6337 |
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
|
| 6338 |
|
| 6339 |
//
|
| 6340 |
// matrix-matrix multiplication
|
|
@@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
|
|
| 6378 |
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6379 |
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6380 |
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
|
|
|
|
|
|
|
| 6381 |
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
|
| 6382 |
|
| 6383 |
//
|
| 6384 |
// indirect matrix-matrix multiplication
|
|
@@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
|
|
| 6434 |
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6435 |
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6436 |
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
|
|
|
|
|
|
|
| 6437 |
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
|
| 6438 |
|
| 6439 |
//
|
| 6440 |
// matrix-vector multiplication
|
|
@@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
| 7707 |
|
| 7708 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 7709 |
|
|
|
|
|
|
|
|
|
|
| 7710 |
kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
|
| 7711 |
src0[id],
|
| 7712 |
(device const float *) (src1 + bid*nb11),
|
| 7713 |
dst + bid*ne0,
|
|
|
|
| 2560 |
uint8_t qs[QK4_NL/2];
|
| 2561 |
} block_iq4_nl;
|
| 2562 |
|
| 2563 |
+
#if QK_K == 64
|
| 2564 |
+
#define block_iq4_xs block_iq4_nl
|
| 2565 |
+
#else
|
| 2566 |
typedef struct {
|
| 2567 |
half d;
|
| 2568 |
uint16_t scales_h;
|
| 2569 |
uint8_t scales_l[QK_K/64];
|
| 2570 |
uint8_t qs[QK_K/2];
|
| 2571 |
} block_iq4_xs;
|
| 2572 |
+
#endif
|
| 2573 |
|
| 2574 |
//====================================== dot products =========================
|
| 2575 |
|
|
|
|
| 4350 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4351 |
}
|
| 4352 |
|
|
|
|
| 4353 |
const int ix = tiisg;
|
| 4354 |
|
| 4355 |
device const float * y4 = y + 32 * ix;
|
|
|
|
| 4390 |
|
| 4391 |
y4 += 32 * 32;
|
| 4392 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4393 |
|
| 4394 |
for (int row = 0; row < N_DST; ++row) {
|
| 4395 |
all_sum = simd_sum(sumf[row]);
|
|
|
|
| 4479 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4480 |
}
|
| 4481 |
|
|
|
|
| 4482 |
const int ix = tiisg;
|
| 4483 |
|
| 4484 |
device const float * y4 = y + 32 * ix;
|
|
|
|
| 4529 |
|
| 4530 |
y4 += 32 * 32;
|
| 4531 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4532 |
|
| 4533 |
for (int row = 0; row < N_DST; ++row) {
|
| 4534 |
all_sum = simd_sum(sumf[row]);
|
|
|
|
| 4618 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 4619 |
}
|
| 4620 |
|
|
|
|
| 4621 |
const int ix = tiisg;
|
| 4622 |
|
| 4623 |
device const float * y4 = y + 32 * ix;
|
|
|
|
| 4661 |
|
| 4662 |
y4 += 32 * 32;
|
| 4663 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4664 |
|
| 4665 |
for (int row = 0; row < N_DST; ++row) {
|
| 4666 |
all_sum = simd_sum(sumf[row]);
|
|
|
|
| 4999 |
|
| 5000 |
const int nb32 = nb * (QK_K / 32);
|
| 5001 |
|
|
|
|
| 5002 |
const int ix = tiisg/2;
|
| 5003 |
const int il = tiisg%2;
|
| 5004 |
|
|
|
|
| 5037 |
|
| 5038 |
y4 += 16 * 32;
|
| 5039 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5040 |
|
| 5041 |
for (int row = 0; row < N_DST; ++row) {
|
| 5042 |
all_sum = simd_sum(sumf[row]);
|
|
|
|
| 5143 |
}
|
| 5144 |
}
|
| 5145 |
|
| 5146 |
+
#if QK_K != 64
|
| 5147 |
void kernel_mul_mv_iq4_xs_f32_impl(
|
| 5148 |
device const void * src0,
|
| 5149 |
device const float * src1,
|
|
|
|
| 5237 |
}
|
| 5238 |
}
|
| 5239 |
}
|
| 5240 |
+
#endif
|
| 5241 |
|
| 5242 |
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
| 5243 |
kernel void kernel_mul_mv_iq1_s_f32(
|
|
|
|
| 5322 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5323 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5324 |
|
| 5325 |
+
#if QK_K == 64
|
| 5326 |
+
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 5327 |
+
#else
|
| 5328 |
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 5329 |
+
#endif
|
| 5330 |
}
|
| 5331 |
|
| 5332 |
//============================= templates and their specializations =============================
|
|
|
|
| 5752 |
|
| 5753 |
template <typename type4x4>
|
| 5754 |
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
| 5755 |
+
#if QK_K == 64
|
| 5756 |
+
dequantize_iq4_nl(xb, il, reg);
|
| 5757 |
+
#else
|
| 5758 |
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
| 5759 |
const int ib32 = il/2;
|
| 5760 |
il = il%2;
|
|
|
|
| 5771 |
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
| 5772 |
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
| 5773 |
}
|
| 5774 |
+
#endif
|
| 5775 |
}
|
| 5776 |
|
| 5777 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
|
|
| 6320 |
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6321 |
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6322 |
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 6323 |
+
#if QK_K == 64
|
| 6324 |
+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
| 6325 |
+
#else
|
| 6326 |
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
| 6327 |
+
#endif
|
| 6328 |
|
| 6329 |
//
|
| 6330 |
// matrix-matrix multiplication
|
|
|
|
| 6368 |
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6369 |
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6370 |
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 6371 |
+
#if QK_K == 64
|
| 6372 |
+
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
| 6373 |
+
#else
|
| 6374 |
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
| 6375 |
+
#endif
|
| 6376 |
|
| 6377 |
//
|
| 6378 |
// indirect matrix-matrix multiplication
|
|
|
|
| 6428 |
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6429 |
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6430 |
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 6431 |
+
#if QK_K == 64
|
| 6432 |
+
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
| 6433 |
+
#else
|
| 6434 |
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
| 6435 |
+
#endif
|
| 6436 |
|
| 6437 |
//
|
| 6438 |
// matrix-vector multiplication
|
|
|
|
| 7705 |
|
| 7706 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 7707 |
|
| 7708 |
+
#if QK_K == 64
|
| 7709 |
+
kernel_mul_mv_iq4_nl_f32_impl(
|
| 7710 |
+
#else
|
| 7711 |
kernel_mul_mv_iq4_xs_f32_impl(
|
| 7712 |
+
#endif
|
| 7713 |
src0[id],
|
| 7714 |
(device const float *) (src1 + bid*nb11),
|
| 7715 |
dst + bid*ne0,
|
ggml-quants.c
CHANGED
|
@@ -1877,7 +1877,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
|
|
| 1877 |
float mins[QK_K/16];
|
| 1878 |
float scales[QK_K/16];
|
| 1879 |
float sw[QK_K/16];
|
| 1880 |
-
float weight[
|
| 1881 |
uint8_t Ls[QK_K/16], Lm[QK_K/16];
|
| 1882 |
|
| 1883 |
for (int i = 0; i < nb; i++) {
|
|
@@ -1887,13 +1887,42 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
|
|
| 1887 |
float sigma2 = sumx2/QK_K;
|
| 1888 |
for (int j = 0; j < QK_K/16; ++j) {
|
| 1889 |
const float * restrict qw = quant_weights + QK_K * i + 16*j;
|
| 1890 |
-
for (int l = 0; l <
|
| 1891 |
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
|
| 1892 |
-
scales[j] = make_qkx3_quants(
|
| 1893 |
}
|
| 1894 |
|
| 1895 |
-
float dm
|
| 1896 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1897 |
y[i].d = GGML_FP32_TO_FP16(dm);
|
| 1898 |
y[i].dmin = GGML_FP32_TO_FP16(mm);
|
| 1899 |
dm = GGML_FP16_TO_FP32(y[i].d);
|
|
@@ -4227,6 +4256,9 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
|
|
| 4227 |
|
| 4228 |
void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
|
| 4229 |
assert(k % QK_K == 0);
|
|
|
|
|
|
|
|
|
|
| 4230 |
const int nb = k / QK_K;
|
| 4231 |
|
| 4232 |
for (int i = 0; i < nb; i++) {
|
|
@@ -4246,6 +4278,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
|
|
| 4246 |
qs += 16;
|
| 4247 |
}
|
| 4248 |
}
|
|
|
|
| 4249 |
}
|
| 4250 |
|
| 4251 |
//===================================== Q8_K ==============================================
|
|
@@ -6306,7 +6339,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
| 6306 |
|
| 6307 |
float sumf = 0;
|
| 6308 |
|
| 6309 |
-
int isum[
|
| 6310 |
|
| 6311 |
for (int i = 0; i < nb; ++i) {
|
| 6312 |
|
|
@@ -6322,14 +6355,14 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
| 6322 |
const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
| 6323 |
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
| 6324 |
|
| 6325 |
-
isum
|
| 6326 |
for (int l = 0; l < 16; ++l) {
|
| 6327 |
isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
|
| 6328 |
isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
|
| 6329 |
isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
|
| 6330 |
isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
|
| 6331 |
}
|
| 6332 |
-
for (int l = 0; l <
|
| 6333 |
isum[l] *= (sc[l] & 0xF);
|
| 6334 |
}
|
| 6335 |
sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
|
|
@@ -9488,15 +9521,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 9488 |
|
| 9489 |
#elif defined(__AVX2__)
|
| 9490 |
|
| 9491 |
-
const __m128i m4 = _mm_set1_epi8(0xf);
|
| 9492 |
-
const __m128i m1 = _mm_set1_epi8(1);
|
| 9493 |
-
const __m256i m511 = _mm256_set1_epi16(511);
|
| 9494 |
const __m256i mone = _mm256_set1_epi8(1);
|
| 9495 |
-
|
| 9496 |
-
static const uint8_t k_bit_helper[32] = {
|
| 9497 |
-
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
| 9498 |
-
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
| 9499 |
-
};
|
| 9500 |
static const char block_sign_shuffle_mask_1[32] = {
|
| 9501 |
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
|
| 9502 |
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
|
|
@@ -9510,11 +9535,77 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 9510 |
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
| 9511 |
};
|
| 9512 |
|
| 9513 |
-
const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
| 9514 |
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
|
| 9515 |
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
|
| 9516 |
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
|
| 9517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9518 |
uint64_t aux64;
|
| 9519 |
|
| 9520 |
// somewhat hacky, but gives a significant boost in performance
|
|
@@ -9603,6 +9694,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 9603 |
}
|
| 9604 |
|
| 9605 |
*s = 0.125f * hsum_float_8(accumf);
|
|
|
|
| 9606 |
|
| 9607 |
#else
|
| 9608 |
|
|
@@ -10199,7 +10291,8 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|
| 10199 |
|
| 10200 |
const int nb = n / QK_K;
|
| 10201 |
|
| 10202 |
-
|
|
|
|
| 10203 |
|
| 10204 |
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
| 10205 |
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
|
@@ -10256,7 +10349,8 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|
| 10256 |
|
| 10257 |
*s = sumf;
|
| 10258 |
|
| 10259 |
-
|
|
|
|
| 10260 |
|
| 10261 |
const __m128i m8 = _mm_set1_epi8(0x08);
|
| 10262 |
const __m128i m7 = _mm_set1_epi8(0x07);
|
|
@@ -10455,6 +10549,9 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 10455 |
UNUSED(by);
|
| 10456 |
UNUSED(bs);
|
| 10457 |
assert(n % QK_K == 0);
|
|
|
|
|
|
|
|
|
|
| 10458 |
|
| 10459 |
const block_iq4_xs * restrict x = vx;
|
| 10460 |
const block_q8_K * restrict y = vy;
|
|
@@ -10574,6 +10671,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 10574 |
}
|
| 10575 |
*s = sumf;
|
| 10576 |
#endif
|
|
|
|
| 10577 |
}
|
| 10578 |
|
| 10579 |
// ================================ IQ2 quantization =============================================
|
|
@@ -10921,7 +11019,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
| 10921 |
|
| 10922 |
const int kMaxQ = 3;
|
| 10923 |
|
| 10924 |
-
const int nbl = n/
|
| 10925 |
|
| 10926 |
block_iq2_xxs * y = vy;
|
| 10927 |
|
|
@@ -11094,7 +11192,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
|
|
| 11094 |
|
| 11095 |
const int kMaxQ = 3;
|
| 11096 |
|
| 11097 |
-
const int nbl = n/
|
| 11098 |
|
| 11099 |
block_iq2_xs * y = vy;
|
| 11100 |
|
|
@@ -12037,7 +12135,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 12037 |
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
|
| 12038 |
GGML_ASSERT(n%QK_K == 0);
|
| 12039 |
|
| 12040 |
-
const int nbl = n/
|
| 12041 |
|
| 12042 |
block_iq1_s * y = vy;
|
| 12043 |
|
|
@@ -12315,6 +12413,9 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest
|
|
| 12315 |
}
|
| 12316 |
|
| 12317 |
size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
|
|
|
|
|
|
|
|
|
| 12318 |
(void)hist;
|
| 12319 |
GGML_ASSERT(n_per_row%QK_K == 0);
|
| 12320 |
int nblock = n_per_row/QK_K;
|
|
@@ -12333,6 +12434,7 @@ size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, i
|
|
| 12333 |
qrow += nblock*sizeof(block_iq4_xs);
|
| 12334 |
}
|
| 12335 |
return nrow * nblock * sizeof(block_iq4_xs);
|
|
|
|
| 12336 |
}
|
| 12337 |
|
| 12338 |
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
|
|
@@ -12363,7 +12465,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
|
|
| 12363 |
|
| 12364 |
const int kMaxQ = 3;
|
| 12365 |
|
| 12366 |
-
const int nbl = n/
|
| 12367 |
|
| 12368 |
block_iq2_s * y = vy;
|
| 12369 |
|
|
|
|
| 1877 |
float mins[QK_K/16];
|
| 1878 |
float scales[QK_K/16];
|
| 1879 |
float sw[QK_K/16];
|
| 1880 |
+
float weight[16];
|
| 1881 |
uint8_t Ls[QK_K/16], Lm[QK_K/16];
|
| 1882 |
|
| 1883 |
for (int i = 0; i < nb; i++) {
|
|
|
|
| 1887 |
float sigma2 = sumx2/QK_K;
|
| 1888 |
for (int j = 0; j < QK_K/16; ++j) {
|
| 1889 |
const float * restrict qw = quant_weights + QK_K * i + 16*j;
|
| 1890 |
+
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
|
| 1891 |
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
|
| 1892 |
+
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
| 1893 |
}
|
| 1894 |
|
| 1895 |
+
float dm, mm;
|
| 1896 |
+
#if QK_K == 64
|
| 1897 |
+
float max_scale = 0, max_min = 0;
|
| 1898 |
+
for (int j = 0; j < QK_K/16; ++j) {
|
| 1899 |
+
max_scale = MAX(max_scale, scales[j]);
|
| 1900 |
+
max_min = MAX(max_min, mins[j]);
|
| 1901 |
+
}
|
| 1902 |
+
dm = max_scale/15;
|
| 1903 |
+
mm = max_min/15;
|
| 1904 |
+
if (max_scale) {
|
| 1905 |
+
float id = 1/dm;
|
| 1906 |
+
for (int j = 0; j < QK_K/16; ++j) {
|
| 1907 |
+
int l = nearest_int(id*scales[j]);
|
| 1908 |
+
Ls[j] = MAX(0, MIN(15, l));
|
| 1909 |
+
}
|
| 1910 |
+
} else {
|
| 1911 |
+
memset(Ls, 0, QK_K/16);
|
| 1912 |
+
}
|
| 1913 |
+
if (max_min) {
|
| 1914 |
+
float id = 1/mm;
|
| 1915 |
+
for (int j = 0; j < QK_K/16; ++j) {
|
| 1916 |
+
int l = nearest_int(id*mins[j]);
|
| 1917 |
+
Lm[j] = MAX(0, MIN(15, l));
|
| 1918 |
+
}
|
| 1919 |
+
} else {
|
| 1920 |
+
memset(Lm, 0, QK_K/16);
|
| 1921 |
+
}
|
| 1922 |
+
#else
|
| 1923 |
+
dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
|
| 1924 |
+
mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
|
| 1925 |
+
#endif
|
| 1926 |
y[i].d = GGML_FP32_TO_FP16(dm);
|
| 1927 |
y[i].dmin = GGML_FP32_TO_FP16(mm);
|
| 1928 |
dm = GGML_FP16_TO_FP32(y[i].d);
|
|
|
|
| 4256 |
|
| 4257 |
void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
|
| 4258 |
assert(k % QK_K == 0);
|
| 4259 |
+
#if QK_K == 64
|
| 4260 |
+
dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k);
|
| 4261 |
+
#else
|
| 4262 |
const int nb = k / QK_K;
|
| 4263 |
|
| 4264 |
for (int i = 0; i < nb; i++) {
|
|
|
|
| 4278 |
qs += 16;
|
| 4279 |
}
|
| 4280 |
}
|
| 4281 |
+
#endif
|
| 4282 |
}
|
| 4283 |
|
| 4284 |
//===================================== Q8_K ==============================================
|
|
|
|
| 6339 |
|
| 6340 |
float sumf = 0;
|
| 6341 |
|
| 6342 |
+
int isum[QK_K/16];
|
| 6343 |
|
| 6344 |
for (int i = 0; i < nb; ++i) {
|
| 6345 |
|
|
|
|
| 6355 |
const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
| 6356 |
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
| 6357 |
|
| 6358 |
+
memset(isum, 0, (QK_K/16)*sizeof(int));
|
| 6359 |
for (int l = 0; l < 16; ++l) {
|
| 6360 |
isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
|
| 6361 |
isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
|
| 6362 |
isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
|
| 6363 |
isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
|
| 6364 |
}
|
| 6365 |
+
for (int l = 0; l < QK_K/16; ++l) {
|
| 6366 |
isum[l] *= (sc[l] & 0xF);
|
| 6367 |
}
|
| 6368 |
sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
|
|
|
|
| 9521 |
|
| 9522 |
#elif defined(__AVX2__)
|
| 9523 |
|
|
|
|
|
|
|
|
|
|
| 9524 |
const __m256i mone = _mm256_set1_epi8(1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9525 |
static const char block_sign_shuffle_mask_1[32] = {
|
| 9526 |
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
|
| 9527 |
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
|
|
|
|
| 9535 |
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
| 9536 |
};
|
| 9537 |
|
|
|
|
| 9538 |
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
|
| 9539 |
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
|
| 9540 |
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
|
| 9541 |
|
| 9542 |
+
#if QK_K == 64
|
| 9543 |
+
static const uint8_t k_bit_helper[16] = {
|
| 9544 |
+
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
| 9545 |
+
};
|
| 9546 |
+
const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper);
|
| 9547 |
+
const __m128i m511 = _mm_set1_epi16(511);
|
| 9548 |
+
typedef union {
|
| 9549 |
+
__m128i vec_index;
|
| 9550 |
+
uint16_t index[8];
|
| 9551 |
+
} index_t;
|
| 9552 |
+
|
| 9553 |
+
index_t idx;
|
| 9554 |
+
__m256 accumf = _mm256_setzero_ps();
|
| 9555 |
+
for (int i = 0; i < nb; ++i) {
|
| 9556 |
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
| 9557 |
+
const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs);
|
| 9558 |
+
idx.vec_index = _mm_and_si128(q2_data, m511);
|
| 9559 |
+
|
| 9560 |
+
const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9);
|
| 9561 |
+
const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13);
|
| 9562 |
+
const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper);
|
| 9563 |
+
|
| 9564 |
+
const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
|
| 9565 |
+
const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits);
|
| 9566 |
+
const __m256i full_signs = _mm256_set_m128i(full_sign_bits, full_sign_bits);
|
| 9567 |
+
|
| 9568 |
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
| 9569 |
+
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32));
|
| 9570 |
+
|
| 9571 |
+
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]],
|
| 9572 |
+
iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]);
|
| 9573 |
+
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]],
|
| 9574 |
+
iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]);
|
| 9575 |
+
|
| 9576 |
+
__m256i signs;
|
| 9577 |
+
signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1);
|
| 9578 |
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
| 9579 |
+
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
|
| 9580 |
+
|
| 9581 |
+
signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2);
|
| 9582 |
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
| 9583 |
+
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
|
| 9584 |
+
|
| 9585 |
+
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
| 9586 |
+
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
| 9587 |
+
|
| 9588 |
+
const __m256i sc1 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
|
| 9589 |
+
const __m256i sc2 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
|
| 9590 |
+
|
| 9591 |
+
const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2));
|
| 9592 |
+
|
| 9593 |
+
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf);
|
| 9594 |
+
|
| 9595 |
+
}
|
| 9596 |
+
|
| 9597 |
+
*s = 0.125f * hsum_float_8(accumf);
|
| 9598 |
+
#else
|
| 9599 |
+
|
| 9600 |
+
static const uint8_t k_bit_helper[32] = {
|
| 9601 |
+
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
| 9602 |
+
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
| 9603 |
+
};
|
| 9604 |
+
const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
| 9605 |
+
const __m256i m511 = _mm256_set1_epi16(511);
|
| 9606 |
+
const __m128i m4 = _mm_set1_epi8(0xf);
|
| 9607 |
+
const __m128i m1 = _mm_set1_epi8(1);
|
| 9608 |
+
|
| 9609 |
uint64_t aux64;
|
| 9610 |
|
| 9611 |
// somewhat hacky, but gives a significant boost in performance
|
|
|
|
| 9694 |
}
|
| 9695 |
|
| 9696 |
*s = 0.125f * hsum_float_8(accumf);
|
| 9697 |
+
#endif
|
| 9698 |
|
| 9699 |
#else
|
| 9700 |
|
|
|
|
| 10291 |
|
| 10292 |
const int nb = n / QK_K;
|
| 10293 |
|
| 10294 |
+
// TODO: implement for QK_K = 64
|
| 10295 |
+
#if defined __ARM_NEON && QK_K == 256
|
| 10296 |
|
| 10297 |
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
| 10298 |
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
|
|
|
| 10349 |
|
| 10350 |
*s = sumf;
|
| 10351 |
|
| 10352 |
+
// TODO: implement for QK_K = 64
|
| 10353 |
+
#elif defined __AVX2__ && QK_K == 256
|
| 10354 |
|
| 10355 |
const __m128i m8 = _mm_set1_epi8(0x08);
|
| 10356 |
const __m128i m7 = _mm_set1_epi8(0x07);
|
|
|
|
| 10549 |
UNUSED(by);
|
| 10550 |
UNUSED(bs);
|
| 10551 |
assert(n % QK_K == 0);
|
| 10552 |
+
#if QK_K == 64
|
| 10553 |
+
ggml_vec_dot_iq4_nl_q8_0(n, s, bs, vx, bx, vy, by, nrc);
|
| 10554 |
+
#else
|
| 10555 |
|
| 10556 |
const block_iq4_xs * restrict x = vx;
|
| 10557 |
const block_q8_K * restrict y = vy;
|
|
|
|
| 10671 |
}
|
| 10672 |
*s = sumf;
|
| 10673 |
#endif
|
| 10674 |
+
#endif
|
| 10675 |
}
|
| 10676 |
|
| 10677 |
// ================================ IQ2 quantization =============================================
|
|
|
|
| 11019 |
|
| 11020 |
const int kMaxQ = 3;
|
| 11021 |
|
| 11022 |
+
const int nbl = n/QK_K;
|
| 11023 |
|
| 11024 |
block_iq2_xxs * y = vy;
|
| 11025 |
|
|
|
|
| 11192 |
|
| 11193 |
const int kMaxQ = 3;
|
| 11194 |
|
| 11195 |
+
const int nbl = n/QK_K;
|
| 11196 |
|
| 11197 |
block_iq2_xs * y = vy;
|
| 11198 |
|
|
|
|
| 12135 |
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
|
| 12136 |
GGML_ASSERT(n%QK_K == 0);
|
| 12137 |
|
| 12138 |
+
const int nbl = n/QK_K;
|
| 12139 |
|
| 12140 |
block_iq1_s * y = vy;
|
| 12141 |
|
|
|
|
| 12413 |
}
|
| 12414 |
|
| 12415 |
size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
| 12416 |
+
#if QK_K == 64
|
| 12417 |
+
return quantize_iq4_nl(src, dst, nrow, n_per_row, hist, quant_weights);
|
| 12418 |
+
#else
|
| 12419 |
(void)hist;
|
| 12420 |
GGML_ASSERT(n_per_row%QK_K == 0);
|
| 12421 |
int nblock = n_per_row/QK_K;
|
|
|
|
| 12434 |
qrow += nblock*sizeof(block_iq4_xs);
|
| 12435 |
}
|
| 12436 |
return nrow * nblock * sizeof(block_iq4_xs);
|
| 12437 |
+
#endif
|
| 12438 |
}
|
| 12439 |
|
| 12440 |
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
|
|
|
|
| 12465 |
|
| 12466 |
const int kMaxQ = 3;
|
| 12467 |
|
| 12468 |
+
const int nbl = n/QK_K;
|
| 12469 |
|
| 12470 |
block_iq2_s * y = vy;
|
| 12471 |
|
ggml-quants.h
CHANGED
|
@@ -230,6 +230,10 @@ typedef struct {
|
|
| 230 |
} block_iq4_nl;
|
| 231 |
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
typedef struct {
|
| 234 |
ggml_fp16_t d;
|
| 235 |
uint16_t scales_h;
|
|
@@ -237,6 +241,7 @@ typedef struct {
|
|
| 237 |
uint8_t qs[QK_K/2];
|
| 238 |
} block_iq4_xs;
|
| 239 |
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
|
|
|
|
| 240 |
|
| 241 |
#ifdef __cplusplus
|
| 242 |
extern "C" {
|
|
|
|
| 230 |
} block_iq4_nl;
|
| 231 |
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
| 232 |
|
| 233 |
+
#if QK_K == 64
|
| 234 |
+
#define block_iq4_xs block_iq4_nl
|
| 235 |
+
//typedef struct block_iq4_nl block_iq4_xs;
|
| 236 |
+
#else
|
| 237 |
typedef struct {
|
| 238 |
ggml_fp16_t d;
|
| 239 |
uint16_t scales_h;
|
|
|
|
| 241 |
uint8_t qs[QK_K/2];
|
| 242 |
} block_iq4_xs;
|
| 243 |
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
|
| 244 |
+
#endif
|
| 245 |
|
| 246 |
#ifdef __cplusplus
|
| 247 |
extern "C" {
|
ggml.c
CHANGED
|
@@ -732,14 +732,22 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|
| 732 |
},
|
| 733 |
[GGML_TYPE_IQ4_XS] = {
|
| 734 |
.type_name = "iq4_xs",
|
|
|
|
|
|
|
|
|
|
| 735 |
.blck_size = QK_K,
|
|
|
|
| 736 |
.type_size = sizeof(block_iq4_xs),
|
| 737 |
.is_quantized = true,
|
| 738 |
.to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
|
| 739 |
.from_float = quantize_row_iq4_xs,
|
| 740 |
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
|
| 741 |
.vec_dot = ggml_vec_dot_iq4_xs_q8_K,
|
|
|
|
|
|
|
|
|
|
| 742 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
|
|
|
| 743 |
.nrows = 1,
|
| 744 |
},
|
| 745 |
[GGML_TYPE_Q8_K] = {
|
|
@@ -19848,6 +19856,9 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
| 19848 |
GGML_ASSERT(result == row_size * nrows);
|
| 19849 |
} break;
|
| 19850 |
case GGML_TYPE_IQ4_NL:
|
|
|
|
|
|
|
|
|
|
| 19851 |
{
|
| 19852 |
GGML_ASSERT(start % QK4_NL == 0);
|
| 19853 |
GGML_ASSERT(start % n_per_row == 0);
|
|
@@ -19856,15 +19867,17 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|
| 19856 |
result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19857 |
GGML_ASSERT(result == row_size * nrows);
|
| 19858 |
} break;
|
|
|
|
| 19859 |
case GGML_TYPE_IQ4_XS:
|
| 19860 |
{
|
| 19861 |
-
GGML_ASSERT(start %
|
| 19862 |
GGML_ASSERT(start % n_per_row == 0);
|
| 19863 |
size_t start_row = start / n_per_row;
|
| 19864 |
size_t row_size = ggml_row_size(type, n_per_row);
|
| 19865 |
result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19866 |
GGML_ASSERT(result == row_size * nrows);
|
| 19867 |
} break;
|
|
|
|
| 19868 |
case GGML_TYPE_F16:
|
| 19869 |
{
|
| 19870 |
size_t elemsize = sizeof(ggml_fp16_t);
|
|
|
|
| 732 |
},
|
| 733 |
[GGML_TYPE_IQ4_XS] = {
|
| 734 |
.type_name = "iq4_xs",
|
| 735 |
+
#if QK_K == 64
|
| 736 |
+
.blck_size = QK4_NL,
|
| 737 |
+
#else
|
| 738 |
.blck_size = QK_K,
|
| 739 |
+
#endif
|
| 740 |
.type_size = sizeof(block_iq4_xs),
|
| 741 |
.is_quantized = true,
|
| 742 |
.to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
|
| 743 |
.from_float = quantize_row_iq4_xs,
|
| 744 |
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
|
| 745 |
.vec_dot = ggml_vec_dot_iq4_xs_q8_K,
|
| 746 |
+
#if QK_K == 64
|
| 747 |
+
.vec_dot_type = GGML_TYPE_Q8_0,
|
| 748 |
+
#else
|
| 749 |
.vec_dot_type = GGML_TYPE_Q8_K,
|
| 750 |
+
#endif
|
| 751 |
.nrows = 1,
|
| 752 |
},
|
| 753 |
[GGML_TYPE_Q8_K] = {
|
|
|
|
| 19856 |
GGML_ASSERT(result == row_size * nrows);
|
| 19857 |
} break;
|
| 19858 |
case GGML_TYPE_IQ4_NL:
|
| 19859 |
+
#if QK_K == 64
|
| 19860 |
+
case GGML_TYPE_IQ4_XS:
|
| 19861 |
+
#endif
|
| 19862 |
{
|
| 19863 |
GGML_ASSERT(start % QK4_NL == 0);
|
| 19864 |
GGML_ASSERT(start % n_per_row == 0);
|
|
|
|
| 19867 |
result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19868 |
GGML_ASSERT(result == row_size * nrows);
|
| 19869 |
} break;
|
| 19870 |
+
#if QK_K != 64
|
| 19871 |
case GGML_TYPE_IQ4_XS:
|
| 19872 |
{
|
| 19873 |
+
GGML_ASSERT(start % QK_K == 0);
|
| 19874 |
GGML_ASSERT(start % n_per_row == 0);
|
| 19875 |
size_t start_row = start / n_per_row;
|
| 19876 |
size_t row_size = ggml_row_size(type, n_per_row);
|
| 19877 |
result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
| 19878 |
GGML_ASSERT(result == row_size * nrows);
|
| 19879 |
} break;
|
| 19880 |
+
#endif
|
| 19881 |
case GGML_TYPE_F16:
|
| 19882 |
{
|
| 19883 |
size_t elemsize = sizeof(ggml_fp16_t);
|