Justine Tunney committed on
Commit
093eec4
·
1 Parent(s): daae175

ggml : add llamafile sgemm (llama/6414)

Browse files

This change upstreams llamafile's cpu matrix multiplication kernels
which improve image and prompt evaluation speed. For starters, Q4_0
and Q8_0 weights should go ~40% faster on CPU. The biggest benefits
are with data types like f16 / f32, which process prompts 2x faster
thus making them faster than quantized data types for prompt evals.

This change also introduces bona fide AVX512 support since tinyBLAS
is able to exploit the larger register file. For example, on my CPU
llama.cpp llava-cli processes an image prompt at 305 tokens/second,
using the Q4_K and Q4_0 types, which has always been faster than if
we used f16 LLaVA weights, which at HEAD go 188 tokens/second. With
this change, f16 LLaVA performance leapfrogs to 464 tokens/second.

On Intel Core i9-14900K this change improves F16 prompt perf by 5x.
For example, using llama.cpp at HEAD with Mistral 7b f16 to process
a 215 token prompt will go 13 tok/sec. This change has fixes making
it go 52 tok/sec. It's mostly thanks to my vectorized outer product
kernels but also because I added support for correctly counting the
number of cores on Alderlake, so the default thread count discounts
Intel's new efficiency cores. Only Linux right now can count cores.

This work was sponsored by Mozilla, which has given permission to change
the license of this code from Apache 2.0 to MIT. To read more about
what's improved, and how it works, see: https://justine.lol/matmul/

Files changed (3) hide show
  1. ggml-impl.h +1 -1
  2. ggml-quants.c +1 -1
  3. ggml.c +54 -0
ggml-impl.h CHANGED
@@ -95,7 +95,7 @@ typedef uint16_t ggml_fp16_internal_t;
95
  #if defined(_MSC_VER) || defined(__MINGW32__)
96
  #include <intrin.h>
97
  #else
98
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
99
  #if !defined(__riscv)
100
  #include <immintrin.h>
101
  #endif
 
95
  #if defined(_MSC_VER) || defined(__MINGW32__)
96
  #include <intrin.h>
97
  #else
98
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
99
  #if !defined(__riscv)
100
  #include <immintrin.h>
101
  #endif
ggml-quants.c CHANGED
@@ -138,7 +138,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
138
  }
139
 
140
  static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
141
- #if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
142
  const __m256i zero = _mm256_setzero_si256();
143
  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
144
  return _mm256_cvtepi32_ps(summed_pairs);
 
138
  }
139
 
140
  static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
141
+ #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
142
  const __m256i zero = _mm256_setzero_si256();
143
  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
144
  return _mm256_cvtepi32_ps(summed_pairs);
ggml.c CHANGED
@@ -4,6 +4,7 @@
4
  #include "ggml-impl.h"
5
  #include "ggml-quants.h"
6
  #include "ggml.h"
 
7
 
8
  #if defined(_MSC_VER) || defined(__MINGW32__)
9
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -32,6 +33,14 @@
32
  #include <unistd.h>
33
  #endif
34
 
 
 
 
 
 
 
 
 
35
  #if defined(_MSC_VER)
36
  // disable "possible loss of data" to avoid hundreds of casts
37
  // we should just be careful :)
@@ -10872,6 +10881,28 @@ static void ggml_compute_forward_mul_mat(
10872
  }
10873
  #endif
10874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10875
  if (params->type == GGML_TASK_TYPE_INIT) {
10876
  if (ith != 0) {
10877
  return;
@@ -10903,6 +10934,29 @@ static void ggml_compute_forward_mul_mat(
10903
  const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10904
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
10905
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10906
  const int64_t nr0 = ne01; // src0 rows
10907
  const int64_t nr1 = ne1*ne12*ne13; // src1 rows
10908
 
 
4
  #include "ggml-impl.h"
5
  #include "ggml-quants.h"
6
  #include "ggml.h"
7
+ #include "sgemm.h"
8
 
9
  #if defined(_MSC_VER) || defined(__MINGW32__)
10
  #include <malloc.h> // using malloc.h with MSC/MINGW
 
33
  #include <unistd.h>
34
  #endif
35
 
36
+ #ifndef GGML_USE_LLAMAFILE
37
+ #ifdef __ARM_FEATURE_MATMUL_INT8
38
+ #define GGML_USE_LLAMAFILE 0
39
+ #else
40
+ #define GGML_USE_LLAMAFILE 1
41
+ #endif
42
+ #endif
43
+
44
  #if defined(_MSC_VER)
45
  // disable "possible loss of data" to avoid hundreds of casts
46
  // we should just be careful :)
 
10881
  }
10882
  #endif
10883
 
10884
+ #if GGML_USE_LLAMAFILE
10885
+ if (nb10 == ggml_type_size(src1->type)) {
10886
+ for (int64_t i13 = 0; i13 < ne13; i13++)
10887
+ for (int64_t i12 = 0; i12 < ne12; i12++)
10888
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
10889
+ (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
10890
+ nb01/ggml_type_size(src0->type),
10891
+ (const char *)src1->data + i12*nb12 + i13*nb13,
10892
+ nb11/ggml_type_size(src1->type),
10893
+ (char *)dst->data + i12*nb2 + i13*nb3,
10894
+ nb1/ggml_type_size(dst->type),
10895
+ ith, nth,
10896
+ params->type,
10897
+ src0->type,
10898
+ src1->type,
10899
+ dst->type))
10900
+ goto UseGgmlGemm1;
10901
+ return;
10902
+ }
10903
+ UseGgmlGemm1:;
10904
+ #endif
10905
+
10906
  if (params->type == GGML_TASK_TYPE_INIT) {
10907
  if (ith != 0) {
10908
  return;
 
10934
  const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10935
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
10936
 
10937
+ #if GGML_USE_LLAMAFILE
10938
+ if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
10939
+ for (int64_t i13 = 0; i13 < ne13; i13++)
10940
+ for (int64_t i12 = 0; i12 < ne12; i12++)
10941
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
10942
+ (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
10943
+ nb01/ggml_type_size(src0->type),
10944
+ (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
10945
+ nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
10946
+ row_size/ggml_type_size(vec_dot_type),
10947
+ (char *)dst->data + i12*nb2 + i13*nb3,
10948
+ nb1/ggml_type_size(dst->type),
10949
+ ith, nth,
10950
+ params->type,
10951
+ src0->type,
10952
+ vec_dot_type,
10953
+ dst->type))
10954
+ goto UseGgmlGemm2;
10955
+ return;
10956
+ }
10957
+ UseGgmlGemm2:;
10958
+ #endif
10959
+
10960
  const int64_t nr0 = ne01; // src0 rows
10961
  const int64_t nr1 = ne1*ne12*ne13; // src1 rows
10962