Spaces:
Sleeping
Sleeping
ggml : add AVX dot products
Browse files
ggml.c
CHANGED
|
@@ -580,7 +580,63 @@ static inline __m128i packNibbles( __m256i bytes )
|
|
| 580 |
return _mm_packus_epi16( r0, r1 );
|
| 581 |
#endif
|
| 582 |
}
|
| 583 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
| 585 |
{
|
| 586 |
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
|
@@ -2355,7 +2411,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|
| 2355 |
}
|
| 2356 |
|
| 2357 |
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
| 2358 |
-
#elif defined(__AVX2__)
|
| 2359 |
// Initialize accumulator with zeros
|
| 2360 |
__m256 acc = _mm256_setzero_ps();
|
| 2361 |
|
|
@@ -2381,7 +2437,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|
| 2381 |
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
|
| 2382 |
|
| 2383 |
// Accumulate d0*d1*x*y
|
|
|
|
| 2384 |
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
|
|
|
|
|
|
|
|
|
| 2385 |
}
|
| 2386 |
|
| 2387 |
*s = hsum_float_8(acc) + summs;
|
|
@@ -2592,6 +2652,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
|
|
| 2592 |
acc = _mm256_fmadd_ps(d, q, acc);
|
| 2593 |
}
|
| 2594 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2595 |
*s = hsum_float_8(acc);
|
| 2596 |
#else
|
| 2597 |
// scalar
|
|
@@ -2820,6 +2911,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|
| 2820 |
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
| 2821 |
}
|
| 2822 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2823 |
*s = hsum_float_8(acc) + summs;
|
| 2824 |
#else
|
| 2825 |
// scalar
|
|
@@ -2910,7 +3035,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
|
|
| 2910 |
}
|
| 2911 |
|
| 2912 |
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
| 2913 |
-
#elif defined(__AVX2__)
|
| 2914 |
// Initialize accumulator with zeros
|
| 2915 |
__m256 acc = _mm256_setzero_ps();
|
| 2916 |
|
|
@@ -2924,7 +3049,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
|
|
| 2924 |
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
| 2925 |
|
| 2926 |
// Multiply q with scale and accumulate
|
|
|
|
| 2927 |
acc = _mm256_fmadd_ps( d, q, acc );
|
|
|
|
|
|
|
|
|
|
| 2928 |
}
|
| 2929 |
|
| 2930 |
*s = hsum_float_8(acc);
|
|
|
|
| 580 |
return _mm_packus_epi16( r0, r1 );
|
| 581 |
#endif
|
| 582 |
}
|
| 583 |
+
#elif defined(__AVX__)
|
| 584 |
+
// spread 32 bits to 32 bytes { 0x00, 0xFF }
|
| 585 |
+
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
| 586 |
+
uint32_t x32;
|
| 587 |
+
memcpy(&x32, x, sizeof(uint32_t));
|
| 588 |
+
const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
|
| 589 |
+
const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
|
| 590 |
+
__m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
|
| 591 |
+
__m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
|
| 592 |
+
const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
|
| 593 |
+
bytesl = _mm_or_si128(bytesl, bit_mask);
|
| 594 |
+
bytesh = _mm_or_si128(bytesh, bit_mask);
|
| 595 |
+
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
|
| 596 |
+
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
|
| 597 |
+
return _mm256_set_m128i(bytesh, bytesl);
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
// Unpack 32 4-bit fields into 32 bytes
|
| 601 |
+
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
| 602 |
+
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
| 603 |
+
{
|
| 604 |
+
// Load 16 bytes from memory
|
| 605 |
+
__m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
|
| 606 |
+
__m128i tmph = _mm_srli_epi16(tmpl, 4);
|
| 607 |
+
const __m128i lowMask = _mm_set1_epi8(0xF);
|
| 608 |
+
tmpl = _mm_and_si128(lowMask, tmpl);
|
| 609 |
+
tmph = _mm_and_si128(lowMask, tmph);
|
| 610 |
+
return _mm256_set_m128i(tmph, tmpl);
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
// add int16_t pairwise and return as float vector
|
| 614 |
+
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
|
| 615 |
+
const __m128i ones = _mm_set1_epi16(1);
|
| 616 |
+
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
|
| 617 |
+
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
|
| 618 |
+
const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
|
| 619 |
+
return _mm256_cvtepi32_ps(summed_pairs);
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
// multiply int8_t, add results pairwise twice and return as float vector
|
| 623 |
+
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
| 624 |
+
const __m128i xl = _mm256_castsi256_si128(x);
|
| 625 |
+
const __m128i xh = _mm256_extractf128_si256(x, 1);
|
| 626 |
+
const __m128i yl = _mm256_castsi256_si128(y);
|
| 627 |
+
const __m128i yh = _mm256_extractf128_si256(y, 1);
|
| 628 |
+
// Get absolute values of x vectors
|
| 629 |
+
const __m128i axl = _mm_sign_epi8(xl, xl);
|
| 630 |
+
const __m128i axh = _mm_sign_epi8(xh, xh);
|
| 631 |
+
// Sign the values of the y vectors
|
| 632 |
+
const __m128i syl = _mm_sign_epi8(yl, xl);
|
| 633 |
+
const __m128i syh = _mm_sign_epi8(yh, xh);
|
| 634 |
+
// Perform multiplication and create 16-bit values
|
| 635 |
+
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
|
| 636 |
+
const __m128i doth = _mm_maddubs_epi16(axh, syh);
|
| 637 |
+
return sum_i16_pairs_float(doth, dotl);
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
| 641 |
{
|
| 642 |
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
|
|
|
| 2411 |
}
|
| 2412 |
|
| 2413 |
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
| 2414 |
+
#elif defined(__AVX2__) || defined(__AVX__)
|
| 2415 |
// Initialize accumulator with zeros
|
| 2416 |
__m256 acc = _mm256_setzero_ps();
|
| 2417 |
|
|
|
|
| 2437 |
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
|
| 2438 |
|
| 2439 |
// Accumulate d0*d1*x*y
|
| 2440 |
+
#if defined(__AVX2__)
|
| 2441 |
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
| 2442 |
+
#else
|
| 2443 |
+
acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
|
| 2444 |
+
#endif
|
| 2445 |
}
|
| 2446 |
|
| 2447 |
*s = hsum_float_8(acc) + summs;
|
|
|
|
| 2652 |
acc = _mm256_fmadd_ps(d, q, acc);
|
| 2653 |
}
|
| 2654 |
|
| 2655 |
+
*s = hsum_float_8(acc);
|
| 2656 |
+
#elif defined(__AVX__)
|
| 2657 |
+
// Initialize accumulator with zeros
|
| 2658 |
+
__m256 acc = _mm256_setzero_ps();
|
| 2659 |
+
__m128i mask = _mm_set1_epi8((char)0xF0);
|
| 2660 |
+
|
| 2661 |
+
// Main loop
|
| 2662 |
+
for (int i = 0; i < nb; i++) {
|
| 2663 |
+
/* Compute combined scale for the block */
|
| 2664 |
+
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
|
| 2665 |
+
|
| 2666 |
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
| 2667 |
+
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
| 2668 |
+
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
| 2669 |
+
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
| 2670 |
+
bxhil = _mm_andnot_si128(bxhil, mask);
|
| 2671 |
+
bxhih = _mm_andnot_si128(bxhih, mask);
|
| 2672 |
+
__m128i bxl = _mm256_castsi256_si128(bx);
|
| 2673 |
+
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
| 2674 |
+
bxl = _mm_or_si128(bxl, bxhil);
|
| 2675 |
+
bxh = _mm_or_si128(bxh, bxhih);
|
| 2676 |
+
bx = _mm256_set_m128i(bxh, bxl);
|
| 2677 |
+
|
| 2678 |
+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
| 2679 |
+
|
| 2680 |
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
| 2681 |
+
|
| 2682 |
+
/* Multiply q with scale and accumulate */
|
| 2683 |
+
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
|
| 2684 |
+
}
|
| 2685 |
+
|
| 2686 |
*s = hsum_float_8(acc);
|
| 2687 |
#else
|
| 2688 |
// scalar
|
|
|
|
| 2911 |
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
| 2912 |
}
|
| 2913 |
|
| 2914 |
+
*s = hsum_float_8(acc) + summs;
|
| 2915 |
+
#elif defined(__AVX__)
|
| 2916 |
+
// Initialize accumulator with zeros
|
| 2917 |
+
__m256 acc = _mm256_setzero_ps();
|
| 2918 |
+
__m128i mask = _mm_set1_epi8(0x10);
|
| 2919 |
+
|
| 2920 |
+
float summs = 0.0f;
|
| 2921 |
+
|
| 2922 |
+
// Main loop
|
| 2923 |
+
for (int i = 0; i < nb; i++) {
|
| 2924 |
+
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
|
| 2925 |
+
|
| 2926 |
+
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
| 2927 |
+
|
| 2928 |
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
| 2929 |
+
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
| 2930 |
+
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
| 2931 |
+
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
| 2932 |
+
bxhil = _mm_and_si128(bxhil, mask);
|
| 2933 |
+
bxhih = _mm_and_si128(bxhih, mask);
|
| 2934 |
+
__m128i bxl = _mm256_castsi256_si128(bx);
|
| 2935 |
+
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
| 2936 |
+
bxl = _mm_or_si128(bxl, bxhil);
|
| 2937 |
+
bxh = _mm_or_si128(bxh, bxhih);
|
| 2938 |
+
bx = _mm256_set_m128i(bxh, bxl);
|
| 2939 |
+
|
| 2940 |
+
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
|
| 2941 |
+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
| 2942 |
+
|
| 2943 |
+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
| 2944 |
+
|
| 2945 |
+
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
|
| 2946 |
+
}
|
| 2947 |
+
|
| 2948 |
*s = hsum_float_8(acc) + summs;
|
| 2949 |
#else
|
| 2950 |
// scalar
|
|
|
|
| 3035 |
}
|
| 3036 |
|
| 3037 |
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
| 3038 |
+
#elif defined(__AVX2__) || defined(__AVX__)
|
| 3039 |
// Initialize accumulator with zeros
|
| 3040 |
__m256 acc = _mm256_setzero_ps();
|
| 3041 |
|
|
|
|
| 3049 |
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
| 3050 |
|
| 3051 |
// Multiply q with scale and accumulate
|
| 3052 |
+
#if defined(__AVX2__)
|
| 3053 |
acc = _mm256_fmadd_ps( d, q, acc );
|
| 3054 |
+
#else
|
| 3055 |
+
acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
|
| 3056 |
+
#endif
|
| 3057 |
}
|
| 3058 |
|
| 3059 |
*s = hsum_float_8(acc);
|