ggerganov committed
Commit 7e7b11c · unverified · 1 parent: ef85c02

ggml : add AVX dot products

Files changed (1): ggml.c (+132, −3)
ggml.c CHANGED
@@ -580,7 +580,63 @@ static inline __m128i packNibbles( __m256i bytes )
     return _mm_packus_epi16( r0, r1 );
 #endif
 }
-#else
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return _mm256_set_m128i(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return _mm256_set_m128i(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
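
For reference, a scalar model of the two central helpers (illustrative, not part of the patch): bytes_from_bits_32 broadcasts each of the 32 input bits to a full 0x00/0xFF byte — the OR with 0x7fbfdfeff7fbfdfe sets every bit except the one a given lane tests, so the compare against all-ones fires exactly when that bit was set. mul_sum_i8_pairs_float is a signed int8 dot product built on _mm_maddubs_epi16, which requires one unsigned operand; the _mm_sign_epi8 calls take |x| and move x's signs onto y, leaving the products unchanged. The sketch below collapses the eight partial sums the vector version keeps.

#include <stdint.h>
#include <string.h>

// what bytes_from_bits_32 computes: one output byte per input bit
static void bytes_from_bits_32_scalar(const uint8_t * x, uint8_t out[32]) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    for (int i = 0; i < 32; i++) {
        out[i] = ((x32 >> i) & 1) ? 0xFF : 0x00;
    }
}

// what mul_sum_i8_pairs_float computes, collapsed to a single float
static float mul_sum_i8_pairs_scalar(const int8_t * x, const int8_t * y) {
    int32_t sum = 0;
    for (int i = 0; i < 32; i++) {
        sum += (int32_t) x[i] * (int32_t) y[i];
    }
    return (float) sum;
}
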
@@ -2355,7 +2411,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
@@ -2381,7 +2437,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const __m256 xy = mul_sum_i8_pairs_float(bx, by);
 
         // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
     }
 
     *s = hsum_float_8(acc) + summs;
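
Most AVX-only targets predate FMA3, and the patch keys the fused path off __AVX2__, so the fallback spells the multiply-add out in two instructions, at the cost of one intermediate rounding step. The same two-line pattern reappears in ggml_vec_dot_q8_0_q8_0 below; a hypothetical wrapper, not in ggml.c, would capture it once:

#include <immintrin.h>

// hypothetical helper: fused multiply-add where available, mul+add otherwise;
// the non-fused form may differ in the last bit due to intermediate rounding
static inline __m256 madd256_ps(__m256 a, __m256 b, __m256 c) {
#if defined(__AVX2__)
    return _mm256_fmadd_ps(a, b, c);
#else
    return _mm256_add_ps(_mm256_mul_ps(a, b), c);
#endif
}
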
@@ -2592,6 +2652,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         acc = _mm256_fmadd_ps(d, q, acc);
     }
 
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = _mm256_set_m128i(bxh, bxl);
+
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
+
     *s = hsum_float_8(acc);
 #else
     // scalar
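
The q5_0 path above rebuilds each 5-bit quant from two sources: the low nibble from bytes_from_nibbles_32(x[i].qs) and the fifth bit from bytes_from_bits_32(x[i].qh). The _mm_andnot_si128 with the 0xF0 mask maps "fifth bit clear" to the byte 0xF0, so after the OR a lane holds nibble - 16 as a signed int8 when the bit is clear and the plain nibble when it is set — exactly the [-16, 15] range q5_0 encodes. A scalar model of one weight (illustrative; j indexes the 32 weights of a block, assuming the nibble layout implied by bytes_from_nibbles_32):

#include <stdint.h>

// scalar model of the q5_0 weight reconstruction (illustrative):
// low nibble from qs, bit j of qh as the fifth bit, implicit -16 offset
static inline int8_t q5_0_weight(const uint8_t qs[16], uint32_t qh, int j) {
    const uint8_t nib = (j < 16) ? (qs[j] & 0x0F) : (qs[j - 16] >> 4);
    const uint8_t hi  = (uint8_t) ((qh >> j) & 1);
    return (int8_t) ((nib | (hi << 4)) - 16);   // in [-16, 15]
}

Each block then contributes d_x * d_y * sum_j(weight_j * y_j), which is what acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc) accumulates eight lanes at a time.
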
@@ -2820,6 +2911,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
 
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = _mm256_set_m128i(bxh, bxl);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+    }
+
     *s = hsum_float_8(acc) + summs;
 #else
     // scalar
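
q5_1 differs in two ways: its quants are unsigned, so the fifth bit is masked in as 0x10 (_mm_and_si128 with _mm_set1_epi8(0x10) instead of the andnot/0xF0 trick), and the block minimum m contributes m * sum(y), which the loop folds into summs using the precomputed row sum y[i].s. A matching scalar model (illustrative, same layout assumption as above):

#include <stdint.h>

// scalar model of the q5_1 weight reconstruction (illustrative):
// unsigned 5-bit value; the block minimum m is accounted for separately
// via summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s
static inline uint8_t q5_1_weight(const uint8_t qs[16], uint32_t qh, int j) {
    const uint8_t nib = (j < 16) ? (qs[j] & 0x0F) : (qs[j - 16] >> 4);
    return (uint8_t) (nib | (((qh >> j) & 1) << 4));   // in [0, 31]
}
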
@@ -2910,7 +3035,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
@@ -2924,7 +3049,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
         const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
         // Multiply q with scale and accumulate
+#if defined(__AVX2__)
         acc = _mm256_fmadd_ps( d, q, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
     }
 
     *s = hsum_float_8(acc);
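
All of this dispatch is compile-time: builds with -mavx2 keep the fused path, while plain -mavx now takes the new AVX branch instead of dropping to scalar. A quick way to check which branch a given build compiles is a preprocessor probe mirroring the #elif chain (illustrative, not part of the patch; #pragma message is supported by GCC, Clang, and MSVC):

// illustrative probe mirroring the dispatch in ggml.c
#if defined(__AVX2__)
#pragma message("ggml dot products: AVX2 path")
#elif defined(__AVX__)
#pragma message("ggml dot products: AVX path")
#else
#pragma message("ggml dot products: scalar path")
#endif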