ggerganov committed
Commit a0d4b48 · unverified · 1 Parent(s): ed75714

sync : ggml (Metal fixes, new ops, tests) (#1633)

* sync : ggml (Metal fixes, new ops, tests)

* cuda : fix bin bcast when src1 and dst have different types
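
The "bin bcast" note above refers to the templated binary broadcast path in ggml-cuda.cu (k_bin_bcast / bin_bcast_cuda in this diff), where the arithmetic is done in float and the result is cast to the destination element type, so src1 and dst no longer have to share a type. As a rough, minimal sketch of that pattern only — not the kernel from this commit; the names add_bcast, ne0 and ne10 are invented for the example:

#include <cstdio>
#include <cuda_fp16.h>

// Broadcast add where src1 (float) and dst (half) have different element types:
// do the arithmetic in float, cast to dst_t on store.
template<typename src0_t, typename src1_t, typename dst_t>
__global__ void add_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
                          int ne0, int ne10) {
    const int i0 = blockIdx.x*blockDim.x + threadIdx.x;
    if (i0 >= ne0) {
        return;
    }
    const int i10 = i0 % ne10;              // src1 is broadcast along dim 0
    const float v = (float) src0[i0] + (float) src1[i10];
    dst[i0] = (dst_t) v;                    // cast to the destination type
}

int main() {
    const int ne0 = 8, ne10 = 4;
    float * src0; float * src1; half * dst;
    cudaMallocManaged(&src0, ne0*sizeof(float));
    cudaMallocManaged(&src1, ne10*sizeof(float));
    cudaMallocManaged(&dst,  ne0*sizeof(half));
    for (int i = 0; i < ne0;  ++i) src0[i] = (float) i;
    for (int i = 0; i < ne10; ++i) src1[i] = 0.5f;
    add_bcast<float, float, half><<<1, 32>>>(src0, src1, dst, ne0, ne10);
    cudaDeviceSynchronize();
    for (int i = 0; i < ne0; ++i) printf("%g ", __half2float(dst[i]));
    printf("\n");
    cudaFree(src0); cudaFree(src1); cudaFree(dst);
    return 0;
}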

Files changed (7)
  1. ggml-alloc.h +1 -1
  2. ggml-cuda.cu +683 -89
  3. ggml-metal.m +530 -53
  4. ggml-metal.metal +1497 -169
  5. ggml-quants.c +2 -2
  6. ggml.c +264 -99
  7. ggml.h +21 -7
ggml-alloc.h CHANGED
@@ -43,7 +43,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
  // ggml-backend v2 API
  //

- // Seperate tensor and graph allocator objects
+ // Separate tensor and graph allocator objects
  // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
  // The original API is kept as a wrapper around the new API

ggml-cuda.cu CHANGED
@@ -1,13 +1,15 @@
  #include <algorithm>
+ #include <assert.h>
+ #include <atomic>
+ #include <cinttypes>
  #include <cstddef>
  #include <cstdint>
- #include <cinttypes>
  #include <float.h>
  #include <limits>
  #include <stdint.h>
  #include <stdio.h>
- #include <atomic>
- #include <assert.h>
+ #include <vector>
+

  #if defined(GGML_USE_HIPBLAS)
  #include <hip/hip_runtime.h>
@@ -437,6 +439,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_

  #define CUDA_GELU_BLOCK_SIZE 256
  #define CUDA_SILU_BLOCK_SIZE 256
+ #define CUDA_TANH_BLOCK_SIZE 256
  #define CUDA_RELU_BLOCK_SIZE 256
  #define CUDA_SQR_BLOCK_SIZE 256
  #define CUDA_CPY_BLOCK_SIZE 32
@@ -449,6 +452,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
  #define CUDA_QUANTIZE_BLOCK_SIZE 256
  #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
  #define CUDA_GET_ROWS_BLOCK_SIZE 256
+ #define CUDA_UPSCALE_BLOCK_SIZE 256
+ #define CUDA_CONCAT_BLOCK_SIZE 256
+ #define CUDA_PAD_BLOCK_SIZE 256
+ #define CUDA_ACC_BLOCK_SIZE 256
+ #define CUDA_IM2COL_BLOCK_SIZE 256

  // dmmv = dequantize_mul_mat_vec
  #ifndef GGML_CUDA_DMMV_X
@@ -610,6 +618,24 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
610
  dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
611
  }
612
 
613
  static __global__ void gelu_f32(const float * x, float * dst, const int k) {
614
  const float GELU_COEF_A = 0.044715f;
615
  const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -632,6 +658,23 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
632
  dst[i] = x[i] / (1.0f + expf(-x[i]));
633
  }
634
 
635
  static __global__ void relu_f32(const float * x, float * dst, const int k) {
636
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
637
 
@@ -641,6 +684,14 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
641
  dst[i] = fmaxf(x[i], 0);
642
  }
643
 
644
  static __global__ void sqr_f32(const float * x, float * dst, const int k) {
645
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
646
 
@@ -686,6 +737,132 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
686
  }
687
  }
688
 
689
  template <int block_size>
690
  static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
691
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -1684,31 +1861,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
1684
  }
1685
 
1686
  template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1687
- static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
1688
- const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
1689
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
1690
-
1691
- if (col >= ncols) {
1692
  return;
1693
  }
1694
 
1695
- const int r = y[row];
1696
 
1697
- // copy x[r*ncols + col] to dst[row*ncols + col]
1698
- const int xi = r*ncols + col;
1699
- const int di = row*ncols + col;
1700
 
1701
- const int ib = xi/qk; // block index
1702
- const int iqs = (xi%qk)/qr; // quant index
1703
- const int iybs = di - di%qk; // y block start index
1704
  const int y_offset = qr == 1 ? 1 : qk/2;
1705
 
1706
  // dequantize
1707
  dfloat2 v;
1708
- dequantize_kernel(x, ib, iqs, v);
1709
 
1710
- dst[iybs + iqs + 0] = v.x;
1711
- dst[iybs + iqs + y_offset] = v.y;
1712
  }
1713
 
1714
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5035,29 +5246,98 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
5035
 
5036
  static __global__ void im2col_f32_f16(
5037
  const float * x, half * dst,
5038
- int ofs0, int ofs1, int IW, int IH, int CHW,
5039
  int s0, int s1, int p0, int p1, int d0, int d1) {
5040
- const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
5041
- const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
5042
 
5043
  const int offset_dst =
5044
- (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
5045
- (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
5046
 
5047
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
5048
  dst[offset_dst] = __float2half(0.0f);
5049
  } else {
5050
- const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
5051
  dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
5052
  }
5053
  }
5054
 
5055
  template<int qk, int qr, dequantize_kernel_t dq>
5056
- static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
 
 
 
 
5057
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
5058
- const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
5059
- const dim3 block_nums(block_num_x, nrows, 1);
5060
- k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
5061
  }
5062
 
5063
  template<float (*bin_op)(const float, const float)>
@@ -5069,7 +5349,6 @@ struct bin_bcast_cuda {
5069
 
5070
  GGML_TENSOR_BINARY_OP_LOCALS
5071
 
5072
-
5073
  int nr0 = ne10/ne0;
5074
  int nr1 = ne11/ne1;
5075
  int nr2 = ne12/ne2;
@@ -5117,26 +5396,28 @@ struct bin_bcast_cuda {
5117
  int64_t ne12 = cne1[2];
5118
  int64_t ne13 = cne1[3];
5119
 
5120
- //size_t nb0 = cnb0[0];
5121
  size_t nb1 = cnb0[1];
5122
  size_t nb2 = cnb0[2];
5123
  size_t nb3 = cnb0[3];
5124
 
5125
- //size_t nb10 = cnb1[0];
5126
  size_t nb11 = cnb1[1];
5127
  size_t nb12 = cnb1[2];
5128
  size_t nb13 = cnb1[3];
5129
 
5130
- //size_t s0 = nb0 / sizeof(src1_t);
5131
- size_t s1 = nb1 / sizeof(src1_t);
5132
- size_t s2 = nb2 / sizeof(src1_t);
5133
- size_t s3 = nb3 / sizeof(src1_t);
5134
 
5135
- //size_t s10 = nb10 / sizeof(src1_t);
5136
  size_t s11 = nb11 / sizeof(src1_t);
5137
  size_t s12 = nb12 / sizeof(src1_t);
5138
  size_t s13 = nb13 / sizeof(src1_t);
5139
 
 
 
5140
 
5141
  const int block_size = 128;
5142
 
@@ -5174,6 +5455,13 @@ struct bin_bcast_cuda {
5174
  }
5175
  };
5176
 
5177
  static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5178
  const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
5179
  gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5184,11 +5472,26 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
5184
  silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
5185
  }
5186
 
5187
  static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5188
  const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
5189
  relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
5190
  }
5191
 
5192
  static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5193
  const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
5194
  sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5205,6 +5508,38 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
5205
  }
5206
  }
5207
 
5208
  static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
5209
  GGML_ASSERT(ncols % WARP_SIZE == 0);
5210
  if (ncols < 1024) {
@@ -6167,13 +6502,14 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
6167
  soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
6168
  }
6169
 
6170
- static void im2col_f32_f16_cuda(const float * x, half * dst,
6171
- int OH, int IW, int IH, int OW, int IC,
6172
- int KH, int KW, int N, int ofs0, int ofs1,
6173
- int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
6174
- dim3 block_nums(IC, OH, OW);
6175
- dim3 block_dims(N, KH, KW);
6176
- im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 
6177
  }
6178
 
6179
  // buffer pool for cuda
@@ -6447,36 +6783,34 @@ static void ggml_cuda_op_get_rows(
6447
 
6448
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
6449
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
6450
- GGML_ASSERT(ggml_is_contiguous(src0));
6451
- GGML_ASSERT(ggml_is_contiguous(src1));
6452
- GGML_ASSERT(ggml_is_contiguous(dst));
6453
 
6454
- const int ncols = src0->ne[0];
6455
- const int nrows = ggml_nelements(src1);
 
6456
 
6457
  const int32_t * src1_i32 = (const int32_t *) src1_d;
6458
 
6459
  switch (src0->type) {
6460
  case GGML_TYPE_F16:
6461
- get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6462
  break;
6463
  case GGML_TYPE_F32:
6464
- get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6465
  break;
6466
  case GGML_TYPE_Q4_0:
6467
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6468
  break;
6469
  case GGML_TYPE_Q4_1:
6470
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6471
  break;
6472
  case GGML_TYPE_Q5_0:
6473
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6474
  break;
6475
  case GGML_TYPE_Q5_1:
6476
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6477
  break;
6478
  case GGML_TYPE_Q8_0:
6479
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6480
  break;
6481
  default:
6482
  // TODO: k-quants
@@ -6522,6 +6856,25 @@ inline void ggml_cuda_op_add(
6522
  ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
6523
  }
6524
 
6525
  inline void ggml_cuda_op_mul(
6526
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6527
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6564,6 +6917,34 @@ inline void ggml_cuda_op_silu(
6564
  (void) src1_dd;
6565
  }
6566
 
6567
  inline void ggml_cuda_op_relu(
6568
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6569
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6578,6 +6959,23 @@ inline void ggml_cuda_op_relu(
6578
  (void) src1_dd;
6579
  }
6580
 
6581
  inline void ggml_cuda_op_sqr(
6582
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6583
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6612,6 +7010,71 @@ inline void ggml_cuda_op_norm(
6612
  (void) src1_dd;
6613
  }
6614
 
6615
  inline void ggml_cuda_op_rms_norm(
6616
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6617
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7126,7 +7589,6 @@ inline void ggml_cuda_op_im2col(
7126
 
7127
  const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
7128
 
7129
- const int64_t N = src1->ne[is_2D ? 3 : 2];
7130
  const int64_t IC = src1->ne[is_2D ? 2 : 1];
7131
  const int64_t IH = is_2D ? src1->ne[1] : 1;
7132
  const int64_t IW = src1->ne[0];
@@ -7137,17 +7599,15 @@ inline void ggml_cuda_op_im2col(
7137
  const int64_t OH = is_2D ? dst->ne[2] : 1;
7138
  const int64_t OW = dst->ne[1];
7139
 
7140
- const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
7141
- const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
7142
 
7143
- im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
7144
- OH, IW, IH, OW, IC, KH, KW, N,
7145
- ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
7146
 
7147
  (void) src0;
7148
  (void) src0_dd;
7149
  }
7150
 
 
7151
  inline void ggml_cuda_op_sum_rows(
7152
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7153
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7696,6 +8156,10 @@ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, gg
7696
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
7697
  }
7698
 
 
 
 
 
7699
  static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7700
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
7701
  }
@@ -7712,10 +8176,22 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
7712
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
7713
  }
7714
 
7715
  static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7716
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
7717
  }
7718
 
 
 
 
 
7719
  static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7720
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
7721
  }
@@ -7724,6 +8200,22 @@ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, g
7724
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
7725
  }
7726
 
7727
  static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7728
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
7729
  }
@@ -8234,36 +8726,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
8234
  }
8235
  #endif
8236
 
8237
- static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
8238
  #if 0
8239
- //#ifdef CUDA_USE_TENSOR_CORES
8240
- // const bool use_tensor_cores = true;
8241
- //#else
8242
- // const bool use_tensor_cores = false;
8243
- //#endif
8244
-
8245
  ggml_cuda_mul_mat_id_cublas(dst);
8246
-
8247
  // TODO: mmq/mmv support
8248
- #else
8249
- const struct ggml_tensor * ids = dst->src[0];
8250
- const struct ggml_tensor * src1 = dst->src[1];
8251
- const int id = dst->op_params[0];
8252
 
8253
- int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
8254
 
8255
- int32_t a_id;
8256
- CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
8257
- CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
8258
 
8259
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
8260
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
8261
 
8262
- ggml_cuda_mul_mat(src0, src1, dst);
8263
- #endif
8264
 
8265
- (void) _src0;
8266
- (void) _src1;
8267
  }
8268
 
8269
  static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8683,6 +9208,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8683
  case GGML_OP_ADD:
8684
  func = ggml_cuda_add;
8685
  break;
 
 
 
8686
  case GGML_OP_MUL:
8687
  func = ggml_cuda_mul;
8688
  break;
@@ -8697,6 +9225,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8697
  case GGML_UNARY_OP_SILU:
8698
  func = ggml_cuda_silu;
8699
  break;
 
 
 
 
 
 
8700
  case GGML_UNARY_OP_RELU:
8701
  func = ggml_cuda_relu;
8702
  break;
@@ -8707,6 +9241,21 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8707
  case GGML_OP_NORM:
8708
  func = ggml_cuda_norm;
8709
  break;
8710
  case GGML_OP_RMS_NORM:
8711
  func = ggml_cuda_rms_norm;
8712
  break;
@@ -8729,9 +9278,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8729
  func = ggml_cuda_sqr;
8730
  break;
8731
  case GGML_OP_CLAMP:
8732
- if (!any_on_device) {
8733
- return false;
8734
- }
8735
  func = ggml_cuda_clamp;
8736
  break;
8737
  case GGML_OP_CPY:
@@ -8740,6 +9286,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
8740
  case GGML_OP_CONT:
8741
  func = ggml_cuda_dup;
8742
  break;
 
8743
  case GGML_OP_RESHAPE:
8744
  case GGML_OP_VIEW:
8745
  case GGML_OP_PERMUTE:
@@ -9159,6 +9706,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
9159
  case GGML_UNARY_OP_GELU:
9160
  case GGML_UNARY_OP_SILU:
9161
  case GGML_UNARY_OP_RELU:
 
 
9162
  return true;
9163
  default:
9164
  return false;
@@ -9181,6 +9730,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
9181
  }
9182
  return true;
9183
  } break;
9184
  case GGML_OP_NONE:
9185
  case GGML_OP_RESHAPE:
9186
  case GGML_OP_VIEW:
@@ -9188,7 +9776,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
9188
  case GGML_OP_TRANSPOSE:
9189
  case GGML_OP_NORM:
9190
  case GGML_OP_REPEAT:
9191
- case GGML_OP_GET_ROWS:
9192
  case GGML_OP_DUP:
9193
  case GGML_OP_ADD:
9194
  case GGML_OP_MUL:
@@ -9197,7 +9784,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
9197
  case GGML_OP_SCALE:
9198
  case GGML_OP_SQR:
9199
  case GGML_OP_CLAMP:
9200
- case GGML_OP_CPY:
9201
  case GGML_OP_CONT:
9202
  case GGML_OP_DIAG_MASK_INF:
9203
  case GGML_OP_SOFT_MAX:
@@ -9206,6 +9792,12 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
9206
  case GGML_OP_IM2COL:
9207
  case GGML_OP_SUM_ROWS:
9208
  case GGML_OP_ARGSORT:
9209
  return true;
9210
  default:
9211
  return false;
@@ -9264,7 +9856,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
9264
  UNUSED(params);
9265
  }
9266
 
9267
- extern "C" int ggml_backend_cuda_reg_devices() {
 
 
9268
  int device_count = ggml_cuda_get_device_count();
9269
  //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
9270
  for (int i = 0; i < device_count; i++) {
 
618
  dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
619
  }
620
 
621
+ static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
622
+ const int ne10, const int ne11, const int ne12,
623
+ const int nb1, const int nb2, int offset) {
624
+ const int i = blockDim.x * blockIdx.x + threadIdx.x;
625
+ if (i >= ne) {
626
+ return;
627
+ }
628
+ int src1_idx = i - offset;
629
+ int oz = src1_idx / nb2;
630
+ int oy = (src1_idx - (oz * nb2)) / nb1;
631
+ int ox = src1_idx % nb1;
632
+ if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
633
+ dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
634
+ } else {
635
+ dst[i] = x[i];
636
+ }
637
+ }
638
+
639
  static __global__ void gelu_f32(const float * x, float * dst, const int k) {
640
  const float GELU_COEF_A = 0.044715f;
641
  const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
658
  dst[i] = x[i] / (1.0f + expf(-x[i]));
659
  }
660
 
661
+ static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
662
+ const float GELU_QUICK_COEF = -1.702f;
663
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
664
+ if (i >= k) {
665
+ return;
666
+ }
667
+ dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
668
+ }
669
+
670
+ static __global__ void tanh_f32(const float *x, float *dst, int k) {
671
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
672
+ if (i >= k) {
673
+ return;
674
+ }
675
+ dst[i] = tanhf(x[i]);
676
+ }
677
+
678
  static __global__ void relu_f32(const float * x, float * dst, const int k) {
679
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
680
 
 
684
  dst[i] = fmaxf(x[i], 0);
685
  }
686
 
687
+ static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
688
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
689
+ if (i >= k) {
690
+ return;
691
+ }
692
+ dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
693
+ }
694
+
695
  static __global__ void sqr_f32(const float * x, float * dst, const int k) {
696
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
697
 
 
737
  }
738
  }
739
 
740
+ static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) {
741
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
742
+ if (nidx >= ne0) {
743
+ return;
744
+ }
745
+ // operation
746
+ int offset_dst =
747
+ nidx +
748
+ blockIdx.y * ne0 +
749
+ blockIdx.z * ne0 * gridDim.y;
750
+ if (blockIdx.z < ne02) { // src0
751
+ int offset_src =
752
+ nidx +
753
+ blockIdx.y * ne0 +
754
+ blockIdx.z * ne0 * gridDim.y;
755
+ dst[offset_dst] = x[offset_src];
756
+ } else {
757
+ int offset_src =
758
+ nidx +
759
+ blockIdx.y * ne0 +
760
+ (blockIdx.z - ne02) * ne0 * gridDim.y;
761
+ dst[offset_dst] = y[offset_src];
762
+ }
763
+ }
764
+
765
+ static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
766
+ int ne0 = ne00 * scale_factor;
767
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
768
+ if (nidx >= ne0) {
769
+ return;
770
+ }
771
+ // operation
772
+ int i00 = nidx / scale_factor;
773
+ int i01 = blockIdx.y / scale_factor;
774
+ int offset_src =
775
+ i00 +
776
+ i01 * ne00 +
777
+ blockIdx.z * nb02;
778
+ int offset_dst =
779
+ nidx +
780
+ blockIdx.y * ne0 +
781
+ blockIdx.z * ne0 * gridDim.y;
782
+ dst[offset_dst] = x[offset_src];
783
+ }
784
+
785
+ static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
786
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
787
+ if (nidx >= ne0) {
788
+ return;
789
+ }
790
+
791
+ // operation
792
+ int offset_dst =
793
+ nidx +
794
+ blockIdx.y * ne0 +
795
+ blockIdx.z * ne0 * gridDim.y;
796
+ if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
797
+ int offset_src =
798
+ nidx +
799
+ blockIdx.y * ne00 +
800
+ blockIdx.z * ne00 * ne01;
801
+ dst[offset_dst] = x[offset_src];
802
+ } else {
803
+ dst[offset_dst] = 0.0f;
804
+ }
805
+ }
806
+
807
+ template <int block_size>
808
+ static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
809
+ int start = blockIdx.x * group_size;
810
+ int end = start + group_size;
811
+
812
+ start += threadIdx.x;
813
+
814
+ if (end >= ne_elements) {
815
+ end = ne_elements;
816
+ }
817
+
818
+ float tmp = 0.0f; // partial sum for thread in warp
819
+
820
+ for (int j = start; j < end; j += block_size) {
821
+ tmp += x[j];
822
+ }
823
+
824
+ tmp = warp_reduce_sum(tmp);
825
+ if (block_size > WARP_SIZE) {
826
+ __shared__ float s_sum[32];
827
+ int warp_id = threadIdx.x / WARP_SIZE;
828
+ int lane_id = threadIdx.x % WARP_SIZE;
829
+ if (lane_id == 0) {
830
+ s_sum[warp_id] = tmp;
831
+ }
832
+ __syncthreads();
833
+ tmp = s_sum[lane_id];
834
+ tmp = warp_reduce_sum(tmp);
835
+ }
836
+
837
+ float mean = tmp / group_size;
838
+ tmp = 0.0f;
839
+
840
+ for (int j = start; j < end; j += block_size) {
841
+ float xi = x[j] - mean;
842
+ dst[j] = xi;
843
+ tmp += xi * xi;
844
+ }
845
+
846
+ tmp = warp_reduce_sum(tmp);
847
+ if (block_size > WARP_SIZE) {
848
+ __shared__ float s_sum[32];
849
+ int warp_id = threadIdx.x / WARP_SIZE;
850
+ int lane_id = threadIdx.x % WARP_SIZE;
851
+ if (lane_id == 0) {
852
+ s_sum[warp_id] = tmp;
853
+ }
854
+ __syncthreads();
855
+ tmp = s_sum[lane_id];
856
+ tmp = warp_reduce_sum(tmp);
857
+ }
858
+
859
+ float variance = tmp / group_size;
860
+ float scale = rsqrtf(variance + eps);
861
+ for (int j = start; j < end; j += block_size) {
862
+ dst[j] *= scale;
863
+ }
864
+ }
865
+
866
  template <int block_size>
867
  static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
868
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
1861
  }
1862
 
1863
  template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1864
+ static __global__ void k_get_rows(
1865
+ const void * src0, const int32_t * src1, dst_t * dst,
1866
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1867
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1868
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1869
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1870
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
1871
+
1872
+ const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
1873
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
1874
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
1875
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
1876
+
1877
+ if (i00 >= ne00) {
1878
  return;
1879
  }
1880
 
1881
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1882
 
1883
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1884
+ const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
1885
 
1886
+ const int ib = i00/qk; // block index
1887
+ const int iqs = (i00%qk)/qr; // quant index
1888
+ const int iybs = i00 - i00%qk; // dst block start index
1889
  const int y_offset = qr == 1 ? 1 : qk/2;
1890
 
1891
  // dequantize
1892
  dfloat2 v;
1893
+ dequantize_kernel(src0_row, ib, iqs, v);
1894
+
1895
+ dst_row[iybs + iqs + 0] = v.x;
1896
+ dst_row[iybs + iqs + y_offset] = v.y;
1897
+ }
1898
+
1899
+ template<typename src0_t, typename dst_t>
1900
+ static __global__ void k_get_rows_float(
1901
+ const src0_t * src0, const int32_t * src1, dst_t * dst,
1902
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1903
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1904
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1905
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1906
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
1907
+
1908
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
1909
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
1910
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
1911
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
1912
+
1913
+ if (i00 >= ne00) {
1914
+ return;
1915
+ }
1916
+
1917
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1918
+
1919
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1920
+ const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
1921
 
1922
+ dst_row[i00] = src0_row[i00];
 
1923
  }
1924
 
1925
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 
5246
 
5247
  static __global__ void im2col_f32_f16(
5248
  const float * x, half * dst,
5249
+ int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
5250
  int s0, int s1, int p0, int p1, int d0, int d1) {
5251
+ const int i = threadIdx.x + blockIdx.x * blockDim.x;
5252
+ if (i >= pelements) {
5253
+ return;
5254
+ }
5255
+
5256
+ const int ksize = OW * (KH > 1 ? KW : 1);
5257
+ const int kx = i / ksize;
5258
+ const int kd = kx * ksize;
5259
+ const int ky = (i - kd) / OW;
5260
+ const int ix = i % OW;
5261
+
5262
+ const int iiw = ix * s0 + kx * d0 - p0;
5263
+ const int iih = blockIdx.y * s1 + ky * d1 - p1;
5264
 
5265
  const int offset_dst =
5266
+ (blockIdx.y * OW + ix) * CHW +
5267
+ (blockIdx.z * (KW * KH) + ky * KW + kx);
5268
 
5269
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
5270
  dst[offset_dst] = __float2half(0.0f);
5271
  } else {
5272
+ const int offset_src = blockIdx.z * offset_delta;
5273
  dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
5274
  }
5275
  }
5276
 
5277
  template<int qk, int qr, dequantize_kernel_t dq>
5278
+ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
5279
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
5280
+
5281
+ GGML_TENSOR_BINARY_OP_LOCALS
5282
+
5283
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
5284
+ const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
5285
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
5286
+
5287
+ // strides in elements
5288
+ //const size_t s0 = nb0 / ggml_element_size(dst);
5289
+ const size_t s1 = nb1 / ggml_element_size(dst);
5290
+ const size_t s2 = nb2 / ggml_element_size(dst);
5291
+ const size_t s3 = nb3 / ggml_element_size(dst);
5292
+
5293
+ const size_t s10 = nb10 / ggml_element_size(src1);
5294
+ const size_t s11 = nb11 / ggml_element_size(src1);
5295
+ const size_t s12 = nb12 / ggml_element_size(src1);
5296
+ //const size_t s13 = nb13 / ggml_element_size(src1);
5297
+
5298
+ GGML_ASSERT(ne00 % 2 == 0);
5299
+
5300
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
5301
+ src0_dd, src1_dd, dst_dd,
5302
+ ne00, /*ne01, ne02, ne03,*/
5303
+ /*ne10, ne11,*/ ne12, /*ne13,*/
5304
+ /* s0,*/ s1, s2, s3,
5305
+ /* nb00,*/ nb01, nb02, nb03,
5306
+ s10, s11, s12/*, s13*/);
5307
+
5308
+ (void) dst;
5309
+ }
5310
+
5311
+ template<typename src0_t>
5312
+ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
5313
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
5314
+
5315
+ GGML_TENSOR_BINARY_OP_LOCALS
5316
+
5317
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
5318
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
5319
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
5320
+
5321
+ // strides in elements
5322
+ //const size_t s0 = nb0 / ggml_element_size(dst);
5323
+ const size_t s1 = nb1 / ggml_element_size(dst);
5324
+ const size_t s2 = nb2 / ggml_element_size(dst);
5325
+ const size_t s3 = nb3 / ggml_element_size(dst);
5326
+
5327
+ const size_t s10 = nb10 / ggml_element_size(src1);
5328
+ const size_t s11 = nb11 / ggml_element_size(src1);
5329
+ const size_t s12 = nb12 / ggml_element_size(src1);
5330
+ //const size_t s13 = nb13 / ggml_element_size(src1);
5331
+
5332
+ k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
5333
+ src0_dd, src1_dd, dst_dd,
5334
+ ne00, /*ne01, ne02, ne03,*/
5335
+ /*ne10, ne11,*/ ne12, /*ne13,*/
5336
+ /* s0,*/ s1, s2, s3,
5337
+ /* nb00,*/ nb01, nb02, nb03,
5338
+ s10, s11, s12/*, s13*/);
5339
+
5340
+ (void) dst;
5341
  }
5342
 
5343
  template<float (*bin_op)(const float, const float)>
 
5349
 
5350
  GGML_TENSOR_BINARY_OP_LOCALS
5351
 
 
5352
  int nr0 = ne10/ne0;
5353
  int nr1 = ne11/ne1;
5354
  int nr2 = ne12/ne2;
 
5396
  int64_t ne12 = cne1[2];
5397
  int64_t ne13 = cne1[3];
5398
 
5399
+ size_t nb0 = cnb0[0];
5400
  size_t nb1 = cnb0[1];
5401
  size_t nb2 = cnb0[2];
5402
  size_t nb3 = cnb0[3];
5403
 
5404
+ size_t nb10 = cnb1[0];
5405
  size_t nb11 = cnb1[1];
5406
  size_t nb12 = cnb1[2];
5407
  size_t nb13 = cnb1[3];
5408
 
5409
+ size_t s0 = nb0 / sizeof(dst_t);
5410
+ size_t s1 = nb1 / sizeof(dst_t);
5411
+ size_t s2 = nb2 / sizeof(dst_t);
5412
+ size_t s3 = nb3 / sizeof(dst_t);
5413
 
5414
+ size_t s10 = nb10 / sizeof(src1_t);
5415
  size_t s11 = nb11 / sizeof(src1_t);
5416
  size_t s12 = nb12 / sizeof(src1_t);
5417
  size_t s13 = nb13 / sizeof(src1_t);
5418
 
5419
+ GGML_ASSERT(s0 == 1);
5420
+ GGML_ASSERT(s10 == 1);
5421
 
5422
  const int block_size = 128;
5423
 
 
5455
  }
5456
  };
5457
 
5458
+ static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
5459
+ const int ne10, const int ne11, const int ne12,
5460
+ const int nb1, const int nb2, const int offset, cudaStream_t stream) {
5461
+ int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
5462
+ acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
5463
+ }
5464
+
5465
  static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5466
  const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
5467
  gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 
5472
  silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
5473
  }
5474
 
5475
+ static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5476
+ const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
5477
+ gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
5478
+ }
5479
+
5480
+ static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5481
+ const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
5482
+ tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
5483
+ }
5484
+
5485
  static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5486
  const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
5487
  relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
5488
  }
5489
 
5490
+ static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
5491
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
5492
+ leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
5493
+ }
5494
+
5495
  static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
5496
  const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
5497
  sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 
5508
  }
5509
  }
5510
 
5511
+ static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
5512
+ static const float eps = 1e-6f;
5513
+ if (group_size < 1024) {
5514
+ const dim3 block_dims(WARP_SIZE, 1, 1);
5515
+ group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
5516
+ } else {
5517
+ const dim3 block_dims(1024, 1, 1);
5518
+ group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
5519
+ }
5520
+ }
5521
+
5522
+ static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
5523
+ int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
5524
+ dim3 gridDim(num_blocks, ne1, ne2);
5525
+ concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
5526
+ }
5527
+
5528
+ static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
5529
+ int ne0 = (ne00 * scale_factor);
5530
+ int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
5531
+ dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
5532
+ upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
5533
+ }
5534
+
5535
+ static void pad_f32_cuda(const float * x, float * dst,
5536
+ const int ne00, const int ne01, const int ne02,
5537
+ const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
5538
+ int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
5539
+ dim3 gridDim(num_blocks, ne1, ne2);
5540
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
5541
+ }
5542
+
5543
  static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
5544
  GGML_ASSERT(ncols % WARP_SIZE == 0);
5545
  if (ncols < 1024) {
 
6502
  soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
6503
  }
6504
 
6505
+ static void im2col_f32_f16_cuda(const float* x, half* dst,
6506
+ int IW, int IH, int OW, int OH, int KW, int KH, int IC,
6507
+ int offset_delta,
6508
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
6509
+ const int parallel_elements = OW * KW * KH;
6510
+ const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
6511
+ dim3 block_nums(num_blocks, OH, IC);
6512
+ im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
6513
  }
6514
 
6515
  // buffer pool for cuda
 
6783
 
6784
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
6785
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
 
 
6786
 
6787
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
6788
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
6789
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
6790
 
6791
  const int32_t * src1_i32 = (const int32_t *) src1_d;
6792
 
6793
  switch (src0->type) {
6794
  case GGML_TYPE_F16:
6795
+ get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
6796
  break;
6797
  case GGML_TYPE_F32:
6798
+ get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6799
  break;
6800
  case GGML_TYPE_Q4_0:
6801
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6802
  break;
6803
  case GGML_TYPE_Q4_1:
6804
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6805
  break;
6806
  case GGML_TYPE_Q5_0:
6807
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6808
  break;
6809
  case GGML_TYPE_Q5_1:
6810
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6811
  break;
6812
  case GGML_TYPE_Q8_0:
6813
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6814
  break;
6815
  default:
6816
  // TODO: k-quants
 
6856
  ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
6857
  }
6858
 
6859
+ inline void ggml_cuda_op_acc(
6860
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6861
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6862
+
6863
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
6864
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6865
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
6866
+ GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
6867
+
6868
+ int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
6869
+ int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
6870
+ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
6871
+ int offset = dst->op_params[3] / 4; // offset in bytes
6872
+
6873
+ acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
6874
+
6875
+ (void) dst;
6876
+ }
6877
+
6878
  inline void ggml_cuda_op_mul(
6879
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6880
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
6917
  (void) src1_dd;
6918
  }
6919
 
6920
+ inline void ggml_cuda_op_gelu_quick(
6921
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6922
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6923
+
6924
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
6925
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
6926
+
6927
+ gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
6928
+
6929
+ (void) src1;
6930
+ (void) dst;
6931
+ (void) src1_dd;
6932
+ }
6933
+
6934
+ inline void ggml_cuda_op_tanh(
6935
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6936
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6937
+
6938
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
6939
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
6940
+
6941
+ tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
6942
+
6943
+ (void) src1;
6944
+ (void) dst;
6945
+ (void) src1_dd;
6946
+ }
6947
+
6948
  inline void ggml_cuda_op_relu(
6949
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6950
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
6959
  (void) src1_dd;
6960
  }
6961
 
6962
+ inline void ggml_cuda_op_leaky_relu(
6963
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6964
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6965
+
6966
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
6967
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
6968
+
6969
+ float negative_slope;
6970
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
6971
+
6972
+ leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
6973
+
6974
+ (void) src1;
6975
+ (void) dst;
6976
+ (void) src1_dd;
6977
+ }
6978
+
6979
  inline void ggml_cuda_op_sqr(
6980
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6981
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
7010
  (void) src1_dd;
7011
  }
7012
 
7013
+
7014
+ inline void ggml_cuda_op_group_norm(
7015
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7016
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
7017
+
7018
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7019
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
7020
+
7021
+ int num_groups = dst->op_params[0];
7022
+ int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
7023
+ group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
7024
+
7025
+ (void) src1;
7026
+ (void) dst;
7027
+ (void) src1_dd;
7028
+ }
7029
+
7030
+ inline void ggml_cuda_op_concat(
7031
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7032
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
7033
+
7034
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7035
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
7036
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
7037
+
7038
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
7039
+ concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
7040
+ }
7041
+
7042
+ (void) src1;
7043
+ (void) dst;
7044
+ }
7045
+
7046
+ inline void ggml_cuda_op_upscale(
7047
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7048
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
7049
+
7050
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7051
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
7052
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
7053
+
7054
+ const int scale_factor = dst->op_params[0];
7055
+
7056
+ upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
7057
+
7058
+ (void) src1;
7059
+ (void) dst;
7060
+ }
7061
+
7062
+ inline void ggml_cuda_op_pad(
7063
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7064
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
7065
+
7066
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7067
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
7068
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
7069
+
7070
+ pad_f32_cuda(src0_dd, dst_dd,
7071
+ src0->ne[0], src0->ne[1], src0->ne[2],
7072
+ dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
7073
+
7074
+ (void) src1;
7075
+ (void) dst;
7076
+ }
7077
+
7078
  inline void ggml_cuda_op_rms_norm(
7079
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7080
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
7589
 
7590
  const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
7591
 
 
7592
  const int64_t IC = src1->ne[is_2D ? 2 : 1];
7593
  const int64_t IH = is_2D ? src1->ne[1] : 1;
7594
  const int64_t IW = src1->ne[0];
 
7599
  const int64_t OH = is_2D ? dst->ne[2] : 1;
7600
  const int64_t OW = dst->ne[1];
7601
 
7602
+ const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
 
7603
 
7604
+ im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
 
 
7605
 
7606
  (void) src0;
7607
  (void) src0_dd;
7608
  }
7609
 
7610
+
7611
  inline void ggml_cuda_op_sum_rows(
7612
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
7613
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
8156
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
8157
  }
8158
 
8159
+ static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8160
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
8161
+ }
8162
+
8163
  static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8164
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
8165
  }
 
8176
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
8177
  }
8178
 
8179
+ static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8180
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
8181
+ }
8182
+
8183
+ static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8184
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh);
8185
+ }
8186
+
8187
  static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8188
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
8189
  }
8190
 
8191
+ static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8192
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
8193
+ }
8194
+
8195
  static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8196
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
8197
  }
 
8200
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
8201
  }
8202
 
8203
+ static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8204
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm);
8205
+ }
8206
+
8207
+ static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8208
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat);
8209
+ }
8210
+
8211
+ static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8212
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale);
8213
+ }
8214
+
8215
+ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8216
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
8217
+ }
8218
+
8219
  static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8220
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
8221
  }
 
8726
  }
8727
  #endif
8728
 
8729
+ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
8730
  #if 0
 
 
 
 
 
 
8731
  ggml_cuda_mul_mat_id_cublas(dst);
 
8732
  // TODO: mmq/mmv support
8733
+ #endif
 
 
 
8734
 
8735
+ GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
8736
 
8737
+ const struct ggml_tensor * ids = src0;
8738
+ const int32_t id = ((int32_t *) dst->op_params)[0];
8739
+ const int32_t n_as = ((int32_t *) dst->op_params)[1];
8740
 
8741
+ std::vector<char> ids_host(ggml_nbytes(ids));
 
8742
 
8743
+ if (ids->backend == GGML_BACKEND_GPU) {
8744
+ const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
8745
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
8746
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
8747
+ } else {
8748
+ memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
8749
+ }
8750
+
8751
+ const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
8752
+ const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
8753
+
8754
+ ggml_tensor_extra_gpu src1_row_extra;
8755
+ ggml_tensor_extra_gpu dst_row_extra;
8756
+
8757
+ ggml_tensor src1_row = *src1;
8758
+ ggml_tensor dst_row = *dst;
8759
+
8760
+ src1_row.ne[1] = 1;
8761
+ dst_row.ne[1] = 1;
8762
+
8763
+ src1_row.nb[2] = src1_row.nb[1];
8764
+ dst_row.nb[2] = dst_row.nb[1];
8765
+
8766
+ src1_row.nb[3] = src1_row.nb[1];
8767
+ dst_row.nb[3] = dst_row.nb[1];
8768
+
8769
+ src1_row.extra = &src1_row_extra;
8770
+ dst_row.extra = &dst_row_extra;
8771
+
8772
+
8773
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
8774
+ //int32_t row_id;
8775
+ //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
8776
+ //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
8777
+
8778
+ const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
8779
+
8780
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
8781
 
8782
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2];
8783
+
8784
+ src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
8785
+ src1_row.data = (char *) src1->data + i01*src1->nb[1];
8786
+
8787
+ dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
8788
+ dst_row.data = (char *) dst->data + i01*dst->nb[1];
8789
+
8790
+ ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
8791
+ }
8792
  }
8793
 
8794
  static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
9208
  case GGML_OP_ADD:
9209
  func = ggml_cuda_add;
9210
  break;
9211
+ case GGML_OP_ACC:
9212
+ func = ggml_cuda_acc;
9213
+ break;
9214
  case GGML_OP_MUL:
9215
  func = ggml_cuda_mul;
9216
  break;
 
9225
  case GGML_UNARY_OP_SILU:
9226
  func = ggml_cuda_silu;
9227
  break;
9228
+ case GGML_UNARY_OP_GELU_QUICK:
9229
+ func = ggml_cuda_gelu_quick;
9230
+ break;
9231
+ case GGML_UNARY_OP_TANH:
9232
+ func = ggml_cuda_tanh;
9233
+ break;
9234
  case GGML_UNARY_OP_RELU:
9235
  func = ggml_cuda_relu;
9236
  break;
 
9241
  case GGML_OP_NORM:
9242
  func = ggml_cuda_norm;
9243
  break;
9244
+ case GGML_OP_GROUP_NORM:
9245
+ func = ggml_cuda_group_norm;
9246
+ break;
9247
+ case GGML_OP_CONCAT:
9248
+ func = ggml_cuda_concat;
9249
+ break;
9250
+ case GGML_OP_UPSCALE:
9251
+ func = ggml_cuda_upscale;
9252
+ break;
9253
+ case GGML_OP_PAD:
9254
+ func = ggml_cuda_pad;
9255
+ break;
9256
+ case GGML_OP_LEAKY_RELU:
9257
+ func = ggml_cuda_leaky_relu;
9258
+ break;
9259
  case GGML_OP_RMS_NORM:
9260
  func = ggml_cuda_rms_norm;
9261
  break;
 
9278
  func = ggml_cuda_sqr;
9279
  break;
9280
  case GGML_OP_CLAMP:
 
 
 
9281
  func = ggml_cuda_clamp;
9282
  break;
9283
  case GGML_OP_CPY:
 
9286
  case GGML_OP_CONT:
9287
  func = ggml_cuda_dup;
9288
  break;
9289
+ case GGML_OP_NONE:
9290
  case GGML_OP_RESHAPE:
9291
  case GGML_OP_VIEW:
9292
  case GGML_OP_PERMUTE:
 
9706
  case GGML_UNARY_OP_GELU:
9707
  case GGML_UNARY_OP_SILU:
9708
  case GGML_UNARY_OP_RELU:
9709
+ case GGML_UNARY_OP_GELU_QUICK:
9710
+ case GGML_UNARY_OP_TANH:
9711
  return true;
9712
  default:
9713
  return false;
 
9730
  }
9731
  return true;
9732
  } break;
9733
+ case GGML_OP_GET_ROWS:
9734
+ {
9735
+ switch (op->src[0]->type) {
9736
+ case GGML_TYPE_F16:
9737
+ case GGML_TYPE_F32:
9738
+ case GGML_TYPE_Q4_0:
9739
+ case GGML_TYPE_Q4_1:
9740
+ case GGML_TYPE_Q5_0:
9741
+ case GGML_TYPE_Q5_1:
9742
+ case GGML_TYPE_Q8_0:
9743
+ return true;
9744
+ default:
9745
+ return false;
9746
+ }
9747
+ } break;
9748
+ case GGML_OP_CPY:
9749
+ {
9750
+ ggml_type src0_type = op->src[0]->type;
9751
+ ggml_type src1_type = op->src[1]->type;
9752
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
9753
+ return true;
9754
+ }
9755
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
9756
+ return true;
9757
+ }
9758
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
9759
+ return true;
9760
+ }
9761
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
9762
+ return true;
9763
+ }
9764
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
9765
+ return true;
9766
+ }
9767
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
9768
+ return true;
9769
+ }
9770
+ return false;
9771
+ } break;
9772
  case GGML_OP_NONE:
9773
  case GGML_OP_RESHAPE:
9774
  case GGML_OP_VIEW:
 
9776
  case GGML_OP_TRANSPOSE:
9777
  case GGML_OP_NORM:
9778
  case GGML_OP_REPEAT:
 
9779
  case GGML_OP_DUP:
9780
  case GGML_OP_ADD:
9781
  case GGML_OP_MUL:
 
9784
  case GGML_OP_SCALE:
9785
  case GGML_OP_SQR:
9786
  case GGML_OP_CLAMP:
 
9787
  case GGML_OP_CONT:
9788
  case GGML_OP_DIAG_MASK_INF:
9789
  case GGML_OP_SOFT_MAX:
 
9792
  case GGML_OP_IM2COL:
9793
  case GGML_OP_SUM_ROWS:
9794
  case GGML_OP_ARGSORT:
9795
+ case GGML_OP_ACC:
9796
+ case GGML_OP_CONCAT:
9797
+ case GGML_OP_GROUP_NORM:
9798
+ case GGML_OP_UPSCALE:
9799
+ case GGML_OP_PAD:
9800
+ case GGML_OP_LEAKY_RELU:
9801
  return true;
9802
  default:
9803
  return false;
 
9856
  UNUSED(params);
9857
  }
9858
 
9859
+ extern "C" int ggml_backend_cuda_reg_devices();
9860
+
9861
+ int ggml_backend_cuda_reg_devices() {
9862
  int device_count = ggml_cuda_get_device_count();
9863
  //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
9864
  for (int i = 0; i < device_count; i++) {
ggml-metal.m CHANGED
@@ -66,9 +66,11 @@ struct ggml_metal_context {
66
  GGML_METAL_DECL_KERNEL(div_row);
67
  GGML_METAL_DECL_KERNEL(scale);
68
  GGML_METAL_DECL_KERNEL(scale_4);
69
- GGML_METAL_DECL_KERNEL(silu);
70
  GGML_METAL_DECL_KERNEL(relu);
71
  GGML_METAL_DECL_KERNEL(gelu);
 
 
72
  GGML_METAL_DECL_KERNEL(soft_max);
73
  GGML_METAL_DECL_KERNEL(soft_max_4);
74
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
86
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
87
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
88
  GGML_METAL_DECL_KERNEL(rms_norm);
 
89
  GGML_METAL_DECL_KERNEL(norm);
90
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
91
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct ggml_metal_context {
102
  GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
103
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
104
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
106
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
107
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct ggml_metal_context {
130
  GGML_METAL_DECL_KERNEL(rope_f16);
131
  GGML_METAL_DECL_KERNEL(alibi_f32);
132
  GGML_METAL_DECL_KERNEL(im2col_f16);
 
 
133
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
 
135
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
136
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
  GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct ggml_metal_context {
140
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
142
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
 
143
  GGML_METAL_DECL_KERNEL(concat);
144
  GGML_METAL_DECL_KERNEL(sqr);
145
  GGML_METAL_DECL_KERNEL(sum_rows);
@@ -318,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
318
  GGML_METAL_ADD_KERNEL(div_row);
319
  GGML_METAL_ADD_KERNEL(scale);
320
  GGML_METAL_ADD_KERNEL(scale_4);
321
- GGML_METAL_ADD_KERNEL(silu);
322
  GGML_METAL_ADD_KERNEL(relu);
323
  GGML_METAL_ADD_KERNEL(gelu);
 
 
324
  GGML_METAL_ADD_KERNEL(soft_max);
325
  GGML_METAL_ADD_KERNEL(soft_max_4);
326
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -338,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
338
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
339
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
340
  GGML_METAL_ADD_KERNEL(rms_norm);
 
341
  GGML_METAL_ADD_KERNEL(norm);
342
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
343
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -354,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
354
  GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
355
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
356
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
358
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
359
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -384,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
384
  GGML_METAL_ADD_KERNEL(rope_f16);
385
  GGML_METAL_ADD_KERNEL(alibi_f32);
386
  GGML_METAL_ADD_KERNEL(im2col_f16);
 
 
387
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
388
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
 
389
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
390
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
391
  GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -394,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
394
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
395
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
396
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
 
397
  GGML_METAL_ADD_KERNEL(concat);
398
  GGML_METAL_ADD_KERNEL(sqr);
399
  GGML_METAL_ADD_KERNEL(sum_rows);
@@ -418,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
418
  GGML_METAL_DEL_KERNEL(div_row);
419
  GGML_METAL_DEL_KERNEL(scale);
420
  GGML_METAL_DEL_KERNEL(scale_4);
421
- GGML_METAL_DEL_KERNEL(silu);
422
  GGML_METAL_DEL_KERNEL(relu);
423
  GGML_METAL_DEL_KERNEL(gelu);
 
 
424
  GGML_METAL_DEL_KERNEL(soft_max);
425
  GGML_METAL_DEL_KERNEL(soft_max_4);
426
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -438,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
438
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
439
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
440
  GGML_METAL_DEL_KERNEL(rms_norm);
 
441
  GGML_METAL_DEL_KERNEL(norm);
442
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
443
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -454,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
454
  GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
455
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
456
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
458
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
459
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -484,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
484
  GGML_METAL_DEL_KERNEL(rope_f16);
485
  GGML_METAL_DEL_KERNEL(alibi_f32);
486
  GGML_METAL_DEL_KERNEL(im2col_f16);
 
 
487
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
488
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
 
489
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
490
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
491
  GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -494,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
494
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
495
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
496
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
 
497
  GGML_METAL_DEL_KERNEL(concat);
498
  GGML_METAL_DEL_KERNEL(sqr);
499
  GGML_METAL_DEL_KERNEL(sum_rows);
@@ -795,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
795
  switch (op->op) {
796
  case GGML_OP_UNARY:
797
  switch (ggml_get_unary_op(op)) {
798
- case GGML_UNARY_OP_SILU:
799
  case GGML_UNARY_OP_RELU:
800
  case GGML_UNARY_OP_GELU:
 
 
801
  return true;
802
  default:
803
  return false;
@@ -809,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
809
  case GGML_OP_PERMUTE:
810
  case GGML_OP_CONCAT:
811
  case GGML_OP_ADD:
 
812
  case GGML_OP_MUL:
813
  case GGML_OP_DIV:
814
  case GGML_OP_SCALE:
@@ -816,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
816
  case GGML_OP_SUM_ROWS:
817
  case GGML_OP_SOFT_MAX:
818
  case GGML_OP_RMS_NORM:
 
819
  case GGML_OP_NORM:
820
  case GGML_OP_ALIBI:
821
  case GGML_OP_ROPE:
822
  case GGML_OP_IM2COL:
 
 
823
  case GGML_OP_ARGSORT:
824
- case GGML_OP_DUP:
825
- case GGML_OP_CPY:
826
- case GGML_OP_CONT:
827
  case GGML_OP_MUL_MAT:
828
  case GGML_OP_MUL_MAT_ID:
829
  return true;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
830
  case GGML_OP_DIAG_MASK_INF:
831
  case GGML_OP_GET_ROWS:
832
  {
833
- return op->ne[0] % 4 == 0;
834
  }
835
  default:
836
  return false;
@@ -906,7 +1004,10 @@ void ggml_metal_graph_compute(
906
  } break;
907
  }
908
 
909
- GGML_ASSERT(ggml_metal_supports_op(dst));
 
 
 
910
 
911
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
912
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1003,34 +1104,39 @@ void ggml_metal_graph_compute(
1003
  case GGML_OP_MUL:
1004
  case GGML_OP_DIV:
1005
  {
1006
- GGML_ASSERT(ggml_is_contiguous(src0));
1007
- GGML_ASSERT(ggml_is_contiguous(src1));
1008
 
1009
  bool bcast_row = false;
1010
 
1011
  int64_t nb = ne00;
1012
 
1013
- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
 
 
 
 
1014
  // src1 is a row
1015
  GGML_ASSERT(ne11 == 1);
1016
 
1017
  nb = ne00 / 4;
1018
  switch (dst->op) {
1019
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1020
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1021
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1022
  default: GGML_ASSERT(false);
1023
  }
1024
 
1025
  bcast_row = true;
1026
  } else {
1027
  switch (dst->op) {
1028
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1029
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1030
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1031
  default: GGML_ASSERT(false);
1032
  }
1033
  }
 
 
1034
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1035
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1036
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1058,18 +1164,99 @@ void ggml_metal_graph_compute(
1058
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1059
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1060
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1061
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
 
1062
 
1063
  if (bcast_row) {
1064
  const int64_t n = ggml_nelements(dst)/4;
1065
 
1066
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1067
  } else {
1068
- const int nth = MIN(1024, ne0);
1069
 
1070
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1071
  }
1072
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1073
  case GGML_OP_SCALE:
1074
  {
1075
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1093,16 +1280,15 @@ void ggml_metal_graph_compute(
1093
  } break;
1094
  case GGML_OP_UNARY:
1095
  switch (ggml_get_unary_op(gf->nodes[i])) {
1096
- case GGML_UNARY_OP_SILU:
1097
  {
1098
- [encoder setComputePipelineState:ctx->pipeline_silu];
1099
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1100
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1101
 
1102
  const int64_t n = ggml_nelements(dst);
1103
- GGML_ASSERT(n % 4 == 0);
1104
 
1105
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1106
  } break;
1107
  case GGML_UNARY_OP_RELU:
1108
  {
@@ -1123,6 +1309,28 @@ void ggml_metal_graph_compute(
1123
  const int64_t n = ggml_nelements(dst);
1124
  GGML_ASSERT(n % 4 == 0);
1125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1126
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1127
  } break;
1128
  default:
@@ -1197,6 +1405,8 @@ void ggml_metal_graph_compute(
1197
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1198
  if (id_src1) {
1199
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
 
 
1200
  }
1201
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1202
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
@@ -1448,7 +1658,7 @@ void ggml_metal_graph_compute(
1448
  else if (src0t == GGML_TYPE_Q6_K) {
1449
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1450
  } else {
1451
- int64_t ny = (ne11 + nrows - 1)/nrows;
1452
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1453
  }
1454
  }
@@ -1460,7 +1670,7 @@ void ggml_metal_graph_compute(
1460
 
1461
  GGML_ASSERT(src0t == GGML_TYPE_I32);
1462
 
1463
- const int n_as = ne00;
1464
 
1465
  // TODO: make this more general
1466
  GGML_ASSERT(n_as <= 8);
@@ -1492,14 +1702,22 @@ void ggml_metal_graph_compute(
1492
 
1493
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1494
  // to the matrix-vector kernel
1495
- int ne11_mm_min = 0;
1496
 
1497
  const int idx = ((int32_t *) dst->op_params)[0];
1498
 
 
 
 
 
 
1499
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1500
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1501
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1502
- ne11 > ne11_mm_min) {
 
 
 
1503
  switch (src2->type) {
1504
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1505
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1518,19 +1736,22 @@ void ggml_metal_graph_compute(
1518
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1519
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1520
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1521
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1522
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1523
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1524
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1525
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1526
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1527
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1528
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1529
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1530
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1531
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1532
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1533
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
 
 
 
1534
  // TODO: how to make this an array? read Metal docs
1535
  for (int j = 0; j < n_as; ++j) {
1536
  struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1538,11 +1759,157 @@ void ggml_metal_graph_compute(
1538
  size_t offs_src_cur = 0;
1539
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1540
 
1541
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1542
  }
1543
 
1544
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1545
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1546
  }
1547
  } break;
1548
  case GGML_OP_GET_ROWS:
@@ -1563,16 +1930,19 @@ void ggml_metal_graph_compute(
1563
  default: GGML_ASSERT(false && "not implemented");
1564
  }
1565
 
1566
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1567
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1568
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1569
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1570
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1571
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
1572
-
1573
- const int64_t n = ggml_nelements(src1);
1574
-
1575
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
 
 
1576
  } break;
1577
  case GGML_OP_RMS_NORM:
1578
  {
@@ -1599,6 +1969,38 @@ void ggml_metal_graph_compute(
1599
 
1600
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1601
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1602
  case GGML_OP_NORM:
1603
  {
1604
  float eps;
@@ -1768,6 +2170,65 @@ void ggml_metal_graph_compute(
1768
 
1769
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1770
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1771
  case GGML_OP_ARGSORT:
1772
  {
1773
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1789,6 +2250,22 @@ void ggml_metal_graph_compute(
1789
 
1790
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1791
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1792
  case GGML_OP_DUP:
1793
  case GGML_OP_CPY:
1794
  case GGML_OP_CONT:
@@ -1817,7 +2294,7 @@ void ggml_metal_graph_compute(
1817
  {
1818
  switch (dstt) {
1819
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
1820
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
1821
  default: GGML_ASSERT(false && "not implemented");
1822
  };
1823
  } break;
 
66
  GGML_METAL_DECL_KERNEL(div_row);
67
  GGML_METAL_DECL_KERNEL(scale);
68
  GGML_METAL_DECL_KERNEL(scale_4);
69
+ GGML_METAL_DECL_KERNEL(tanh);
70
  GGML_METAL_DECL_KERNEL(relu);
71
  GGML_METAL_DECL_KERNEL(gelu);
72
+ GGML_METAL_DECL_KERNEL(gelu_quick);
73
+ GGML_METAL_DECL_KERNEL(silu);
74
  GGML_METAL_DECL_KERNEL(soft_max);
75
  GGML_METAL_DECL_KERNEL(soft_max_4);
76
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
 
88
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
89
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
90
  GGML_METAL_DECL_KERNEL(rms_norm);
91
+ GGML_METAL_DECL_KERNEL(group_norm);
92
  GGML_METAL_DECL_KERNEL(norm);
93
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
94
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
 
105
  GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
106
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
107
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
108
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
109
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
110
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
111
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
112
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
113
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
114
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
115
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
116
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
117
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
118
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
119
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
120
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
121
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
122
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
123
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
124
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
125
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
 
148
  GGML_METAL_DECL_KERNEL(rope_f16);
149
  GGML_METAL_DECL_KERNEL(alibi_f32);
150
  GGML_METAL_DECL_KERNEL(im2col_f16);
151
+ GGML_METAL_DECL_KERNEL(upscale_f32);
152
+ GGML_METAL_DECL_KERNEL(pad_f32);
153
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
154
  GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
155
+ GGML_METAL_DECL_KERNEL(leaky_relu_f32);
156
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
157
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
158
  GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
 
161
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
162
  //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
163
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
164
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
165
  GGML_METAL_DECL_KERNEL(concat);
166
  GGML_METAL_DECL_KERNEL(sqr);
167
  GGML_METAL_DECL_KERNEL(sum_rows);
 
340
  GGML_METAL_ADD_KERNEL(div_row);
341
  GGML_METAL_ADD_KERNEL(scale);
342
  GGML_METAL_ADD_KERNEL(scale_4);
343
+ GGML_METAL_ADD_KERNEL(tanh);
344
  GGML_METAL_ADD_KERNEL(relu);
345
  GGML_METAL_ADD_KERNEL(gelu);
346
+ GGML_METAL_ADD_KERNEL(gelu_quick);
347
+ GGML_METAL_ADD_KERNEL(silu);
348
  GGML_METAL_ADD_KERNEL(soft_max);
349
  GGML_METAL_ADD_KERNEL(soft_max_4);
350
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
 
362
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
363
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
364
  GGML_METAL_ADD_KERNEL(rms_norm);
365
+ GGML_METAL_ADD_KERNEL(group_norm);
366
  GGML_METAL_ADD_KERNEL(norm);
367
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
368
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
 
379
  GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
380
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
381
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
382
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
383
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
384
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
385
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
386
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
387
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
388
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
389
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
390
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
391
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
392
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
393
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
394
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
395
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
396
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
397
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
398
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
399
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
 
424
  GGML_METAL_ADD_KERNEL(rope_f16);
425
  GGML_METAL_ADD_KERNEL(alibi_f32);
426
  GGML_METAL_ADD_KERNEL(im2col_f16);
427
+ GGML_METAL_ADD_KERNEL(upscale_f32);
428
+ GGML_METAL_ADD_KERNEL(pad_f32);
429
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
430
  GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
431
+ GGML_METAL_ADD_KERNEL(leaky_relu_f32);
432
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
433
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
434
  GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
 
437
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
438
  //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
439
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
440
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
441
  GGML_METAL_ADD_KERNEL(concat);
442
  GGML_METAL_ADD_KERNEL(sqr);
443
  GGML_METAL_ADD_KERNEL(sum_rows);
 
462
  GGML_METAL_DEL_KERNEL(div_row);
463
  GGML_METAL_DEL_KERNEL(scale);
464
  GGML_METAL_DEL_KERNEL(scale_4);
465
+ GGML_METAL_DEL_KERNEL(tanh);
466
  GGML_METAL_DEL_KERNEL(relu);
467
  GGML_METAL_DEL_KERNEL(gelu);
468
+ GGML_METAL_DEL_KERNEL(gelu_quick);
469
+ GGML_METAL_DEL_KERNEL(silu);
470
  GGML_METAL_DEL_KERNEL(soft_max);
471
  GGML_METAL_DEL_KERNEL(soft_max_4);
472
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
 
484
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
485
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
486
  GGML_METAL_DEL_KERNEL(rms_norm);
487
+ GGML_METAL_DEL_KERNEL(group_norm);
488
  GGML_METAL_DEL_KERNEL(norm);
489
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
490
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
 
501
  GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
502
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
503
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
504
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
505
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
506
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
507
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
508
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
509
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
510
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
511
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
512
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
513
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
514
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
515
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
516
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
517
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
518
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
519
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
520
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
521
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
 
546
  GGML_METAL_DEL_KERNEL(rope_f16);
547
  GGML_METAL_DEL_KERNEL(alibi_f32);
548
  GGML_METAL_DEL_KERNEL(im2col_f16);
549
+ GGML_METAL_DEL_KERNEL(upscale_f32);
550
+ GGML_METAL_DEL_KERNEL(pad_f32);
551
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
552
  GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
553
+ GGML_METAL_DEL_KERNEL(leaky_relu_f32);
554
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
555
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
556
  GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
 
559
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
560
  //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
561
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
562
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
563
  GGML_METAL_DEL_KERNEL(concat);
564
  GGML_METAL_DEL_KERNEL(sqr);
565
  GGML_METAL_DEL_KERNEL(sum_rows);
 
861
  switch (op->op) {
862
  case GGML_OP_UNARY:
863
  switch (ggml_get_unary_op(op)) {
864
+ case GGML_UNARY_OP_TANH:
865
  case GGML_UNARY_OP_RELU:
866
  case GGML_UNARY_OP_GELU:
867
+ case GGML_UNARY_OP_GELU_QUICK:
868
+ case GGML_UNARY_OP_SILU:
869
  return true;
870
  default:
871
  return false;
 
877
  case GGML_OP_PERMUTE:
878
  case GGML_OP_CONCAT:
879
  case GGML_OP_ADD:
880
+ case GGML_OP_ACC:
881
  case GGML_OP_MUL:
882
  case GGML_OP_DIV:
883
  case GGML_OP_SCALE:
 
885
  case GGML_OP_SUM_ROWS:
886
  case GGML_OP_SOFT_MAX:
887
  case GGML_OP_RMS_NORM:
888
+ case GGML_OP_GROUP_NORM:
889
  case GGML_OP_NORM:
890
  case GGML_OP_ALIBI:
891
  case GGML_OP_ROPE:
892
  case GGML_OP_IM2COL:
893
+ case GGML_OP_UPSCALE:
894
+ case GGML_OP_PAD:
895
  case GGML_OP_ARGSORT:
896
+ case GGML_OP_LEAKY_RELU:
 
 
897
  case GGML_OP_MUL_MAT:
898
  case GGML_OP_MUL_MAT_ID:
899
  return true;
900
+ case GGML_OP_CPY:
901
+ case GGML_OP_DUP:
902
+ case GGML_OP_CONT:
903
+ {
904
+ switch (op->src[0]->type) {
905
+ case GGML_TYPE_F32:
906
+ switch (op->type) {
907
+ case GGML_TYPE_F16:
908
+ case GGML_TYPE_F32:
909
+ case GGML_TYPE_Q8_0:
910
+ case GGML_TYPE_Q4_0:
911
+ case GGML_TYPE_Q4_1:
912
+ return true;
913
+ default:
914
+ return false;
915
+ }
916
+ case GGML_TYPE_F16:
917
+ switch (op->type) {
918
+ case GGML_TYPE_F16:
919
+ case GGML_TYPE_F32:
920
+ return true;
921
+ default:
922
+ return false;
923
+ }
924
+ default:
925
+ return false;
926
+ };
927
+ }
928
  case GGML_OP_DIAG_MASK_INF:
929
  case GGML_OP_GET_ROWS:
930
  {
931
+ return op->ne[3] == 1;
932
  }
933
  default:
934
  return false;
 
1004
  } break;
1005
  }
1006
 
1007
+ if (!ggml_metal_supports_op(dst)) {
1008
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
1009
+ GGML_ASSERT(!"unsupported op");
1010
+ }
1011
 
1012
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
1013
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
 
1104
  case GGML_OP_MUL:
1105
  case GGML_OP_DIV:
1106
  {
1107
+ const size_t offs = 0;
 
1108
 
1109
  bool bcast_row = false;
1110
 
1111
  int64_t nb = ne00;
1112
 
1113
+ id<MTLComputePipelineState> pipeline = nil;
1114
+
1115
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1116
+ GGML_ASSERT(ggml_is_contiguous(src0));
1117
+
1118
  // src1 is a row
1119
  GGML_ASSERT(ne11 == 1);
1120
 
1121
  nb = ne00 / 4;
1122
  switch (dst->op) {
1123
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
1124
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
1125
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
1126
  default: GGML_ASSERT(false);
1127
  }
1128
 
1129
  bcast_row = true;
1130
  } else {
1131
  switch (dst->op) {
1132
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
1133
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
1134
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
1135
  default: GGML_ASSERT(false);
1136
  }
1137
  }
1138
+
1139
+ [encoder setComputePipelineState:pipeline];
1140
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1141
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1142
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 
1164
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1165
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1166
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1167
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1168
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1169
 
1170
  if (bcast_row) {
1171
  const int64_t n = ggml_nelements(dst)/4;
1172
 
1173
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1174
  } else {
1175
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1176
 
1177
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1178
  }
1179
  } break;
1180
+ case GGML_OP_ACC:
1181
+ {
1182
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1183
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1184
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
1185
+
1186
+ GGML_ASSERT(ggml_is_contiguous(src0));
1187
+ GGML_ASSERT(ggml_is_contiguous(src1));
1188
+
1189
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1190
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1191
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1192
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1193
+
1194
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1195
+
1196
+ if (!inplace) {
1197
+ // run a separate kernel to cpy src->dst
1198
+ // not sure how to avoid this
1199
+ // TODO: make a simpler cpy_bytes kernel
1200
+
1201
+ const int nth = MIN(1024, ne00);
1202
+
1203
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1204
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1205
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1206
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1207
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1208
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1209
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1210
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1211
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1212
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1213
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1214
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1215
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1216
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1217
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1218
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1219
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1220
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1221
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1222
+
1223
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1224
+ }
1225
+
1226
+ [encoder setComputePipelineState:ctx->pipeline_add];
1227
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1228
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1229
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1230
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1231
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1232
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1233
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1234
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1235
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1236
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1237
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1238
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1239
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1240
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1241
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1242
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1243
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1244
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1245
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1246
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1247
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1248
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1249
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1250
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1251
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1252
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1253
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1254
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1255
+
1256
+ const int nth = MIN(1024, ne0);
1257
+
1258
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1259
+ } break;
1260
  case GGML_OP_SCALE:
1261
  {
1262
  GGML_ASSERT(ggml_is_contiguous(src0));
 
1280
  } break;
1281
  case GGML_OP_UNARY:
1282
  switch (ggml_get_unary_op(gf->nodes[i])) {
1283
+ case GGML_UNARY_OP_TANH:
1284
  {
1285
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
1286
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1287
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1288
 
1289
  const int64_t n = ggml_nelements(dst);
 
1290
 
1291
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1292
  } break;
1293
  case GGML_UNARY_OP_RELU:
1294
  {
 
1309
  const int64_t n = ggml_nelements(dst);
1310
  GGML_ASSERT(n % 4 == 0);
1311
 
1312
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1313
+ } break;
1314
+ case GGML_UNARY_OP_GELU_QUICK:
1315
+ {
1316
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
1317
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1318
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1319
+
1320
+ const int64_t n = ggml_nelements(dst);
1321
+ GGML_ASSERT(n % 4 == 0);
1322
+
1323
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1324
+ } break;
1325
+ case GGML_UNARY_OP_SILU:
1326
+ {
1327
+ [encoder setComputePipelineState:ctx->pipeline_silu];
1328
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1329
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1330
+
1331
+ const int64_t n = ggml_nelements(dst);
1332
+ GGML_ASSERT(n % 4 == 0);
1333
+
1334
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1335
  } break;
1336
  default:
 
1405
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1406
  if (id_src1) {
1407
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1408
+ } else {
1409
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1410
  }
1411
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1412
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
1658
  else if (src0t == GGML_TYPE_Q6_K) {
1659
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1660
  } else {
1661
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1662
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1663
  }
1664
  }
 
1670
 
1671
  GGML_ASSERT(src0t == GGML_TYPE_I32);
1672
 
1673
+ const int n_as = ((int32_t *) dst->op_params)[1];
1674
 
1675
  // TODO: make this more general
1676
  GGML_ASSERT(n_as <= 8);
 
1702
 
1703
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1704
  // to the matrix-vector kernel
1705
+ int ne11_mm_min = 1;
1706
 
1707
  const int idx = ((int32_t *) dst->op_params)[0];
1708
 
1709
+ // batch size
1710
+ GGML_ASSERT(ne01 == ne11);
1711
+
1712
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
1713
+
1714
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1715
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1716
+ // !!!
1717
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1718
+ // indirect matrix multiplication
1719
+ // !!!
1720
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
1721
  switch (src2->type) {
1722
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1723
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
 
1736
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1737
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1738
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1739
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1740
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1741
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1742
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1743
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1744
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1745
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1746
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1747
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1748
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1749
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1750
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
1751
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1752
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1753
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1754
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1755
  // TODO: how to make this an array? read Metal docs
1756
  for (int j = 0; j < n_as; ++j) {
1757
  struct ggml_tensor * src_cur = dst->src[2 + j];
 
1759
  size_t offs_src_cur = 0;
1760
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1761
 
1762
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1763
  }
1764
 
1765
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1766
+
1767
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
1768
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1769
+ } else {
1770
+ int nth0 = 32;
1771
+ int nth1 = 1;
1772
+ int nrows = 1;
1773
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1774
+
1775
+ // use custom matrix x vector kernel
1776
+ switch (src2t) {
1777
+ case GGML_TYPE_F32:
1778
+ {
1779
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1780
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1781
+ } break;
1782
+ case GGML_TYPE_F16:
1783
+ {
1784
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1785
+ nth0 = 32;
1786
+ nth1 = 1;
1787
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1788
+ } break;
1789
+ case GGML_TYPE_Q4_0:
1790
+ {
1791
+ nth0 = 8;
1792
+ nth1 = 8;
1793
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1794
+ } break;
1795
+ case GGML_TYPE_Q4_1:
1796
+ {
1797
+ nth0 = 8;
1798
+ nth1 = 8;
1799
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1800
+ } break;
1801
+ case GGML_TYPE_Q5_0:
1802
+ {
1803
+ nth0 = 8;
1804
+ nth1 = 8;
1805
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1806
+ } break;
1807
+ case GGML_TYPE_Q5_1:
1808
+ {
1809
+ nth0 = 8;
1810
+ nth1 = 8;
1811
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1812
+ } break;
1813
+ case GGML_TYPE_Q8_0:
1814
+ {
1815
+ nth0 = 8;
1816
+ nth1 = 8;
1817
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1818
+ } break;
1819
+ case GGML_TYPE_Q2_K:
1820
+ {
1821
+ nth0 = 2;
1822
+ nth1 = 32;
1823
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1824
+ } break;
1825
+ case GGML_TYPE_Q3_K:
1826
+ {
1827
+ nth0 = 2;
1828
+ nth1 = 32;
1829
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1830
+ } break;
1831
+ case GGML_TYPE_Q4_K:
1832
+ {
1833
+ nth0 = 4; //1;
1834
+ nth1 = 8; //32;
1835
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1836
+ } break;
1837
+ case GGML_TYPE_Q5_K:
1838
+ {
1839
+ nth0 = 2;
1840
+ nth1 = 32;
1841
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1842
+ } break;
1843
+ case GGML_TYPE_Q6_K:
1844
+ {
1845
+ nth0 = 2;
1846
+ nth1 = 32;
1847
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1848
+ } break;
1849
+ default:
1850
+ {
1851
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1852
+ GGML_ASSERT(false && "not implemented");
1853
+ }
1854
+ };
1855
+
1856
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1857
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1858
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1859
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1860
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1861
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1862
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1863
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1864
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1865
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1866
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1867
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1868
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1869
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1870
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1871
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1872
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1873
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1874
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1875
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1876
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1877
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1878
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1879
+ // TODO: how to make this an array? read Metal docs
1880
+ for (int j = 0; j < n_as; ++j) {
1881
+ struct ggml_tensor * src_cur = dst->src[2 + j];
1882
+
1883
+ size_t offs_src_cur = 0;
1884
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1885
+
1886
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1887
+ }
1888
+
1889
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1890
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1891
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1892
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1893
+ }
1894
+ else if (src2t == GGML_TYPE_Q4_K) {
1895
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1896
+ }
1897
+ else if (src2t == GGML_TYPE_Q3_K) {
1898
+ #ifdef GGML_QKK_64
1899
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1900
+ #else
1901
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1902
+ #endif
1903
+ }
1904
+ else if (src2t == GGML_TYPE_Q5_K) {
1905
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1906
+ }
1907
+ else if (src2t == GGML_TYPE_Q6_K) {
1908
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1909
+ } else {
1910
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1911
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1912
+ }
1913
  }
1914
  } break;
1915
  case GGML_OP_GET_ROWS:
 
1930
  default: GGML_ASSERT(false && "not implemented");
1931
  }
1932
 
1933
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1934
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1935
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1936
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1937
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1938
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1939
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1940
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
1941
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
1942
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
1943
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
1944
+
1945
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1946
  } break;
1947
  case GGML_OP_RMS_NORM:
1948
  {
 
1969
 
1970
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1971
  } break;
1972
+ case GGML_OP_GROUP_NORM:
1973
+ {
1974
+ GGML_ASSERT(ne00 % 4 == 0);
1975
+
1976
+ //float eps;
1977
+ //memcpy(&eps, dst->op_params, sizeof(float));
1978
+
1979
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
1980
+
1981
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
1982
+
1983
+ int nth = 32; // SIMD width
1984
+
1985
+ //while (nth < ne00/4 && nth < 1024) {
1986
+ // nth *= 2;
1987
+ //}
1988
+
1989
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
1990
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1991
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1992
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1993
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1994
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1995
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
1996
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
1997
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
1998
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
1999
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2000
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2001
+
2002
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2003
+ } break;
2004
  case GGML_OP_NORM:
2005
  {
2006
  float eps;
 
2170
 
2171
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2172
  } break;
2173
+ case GGML_OP_UPSCALE:
2174
+ {
2175
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2176
+
2177
+ const int sf = dst->op_params[0];
2178
+
2179
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2180
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2181
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2182
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2183
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2184
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2185
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2186
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2187
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2188
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2189
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2190
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2191
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2192
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2193
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2194
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2195
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2196
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2197
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2198
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2199
+
2200
+ const int nth = MIN(1024, ne0);
2201
+
2202
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2203
+ } break;
2204
+ case GGML_OP_PAD:
2205
+ {
2206
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2207
+
2208
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2209
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2210
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2211
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2212
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2213
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2214
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2215
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2216
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2217
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2218
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2219
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2220
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2221
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2222
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2223
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2224
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2225
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2226
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2227
+
2228
+ const int nth = MIN(1024, ne0);
2229
+
2230
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2231
+ } break;
2232
  case GGML_OP_ARGSORT:
2233
  {
2234
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
2250
 
2251
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2252
  } break;
2253
+ case GGML_OP_LEAKY_RELU:
2254
+ {
2255
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2256
+
2257
+ float slope;
2258
+ memcpy(&slope, dst->op_params, sizeof(float));
2259
+
2260
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2261
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2262
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2263
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2264
+
2265
+ const int64_t n = ggml_nelements(dst);
2266
+
2267
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2268
+ } break;
2269
  case GGML_OP_DUP:
2270
  case GGML_OP_CPY:
2271
  case GGML_OP_CONT:
 
2294
  {
2295
  switch (dstt) {
2296
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
2297
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
2298
  default: GGML_ASSERT(false && "not implemented");
2299
  };
2300
  } break;
ggml-metal.metal CHANGED
@@ -79,6 +79,7 @@ kernel void kernel_add(
79
  constant int64_t & nb1,
80
  constant int64_t & nb2,
81
  constant int64_t & nb3,
 
82
  uint3 tgpig[[threadgroup_position_in_grid]],
83
  uint3 tpitg[[thread_position_in_threadgroup]],
84
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
90
  const int64_t i12 = i02 % ne12;
91
  const int64_t i11 = i01 % ne11;
92
 
93
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
96
 
97
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
98
  const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
204
  device const float4 * src0,
205
  device const float4 * src1,
206
  device float4 * dst,
207
- constant int64_t & nb [[buffer(27)]],
208
  uint tpig[[thread_position_in_grid]]) {
209
  dst[tpig] = src0[tpig] + src1[tpig % nb];
210
  }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
213
  device const float4 * src0,
214
  device const float4 * src1,
215
  device float4 * dst,
216
- constant int64_t & nb [[buffer(27)]],
217
  uint tpig[[thread_position_in_grid]]) {
218
  dst[tpig] = src0[tpig] * src1[tpig % nb];
219
  }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
222
  device const float4 * src0,
223
  device const float4 * src1,
224
  device float4 * dst,
225
- constant int64_t & nb [[buffer(27)]],
226
  uint tpig[[thread_position_in_grid]]) {
227
  dst[tpig] = src0[tpig] / src1[tpig % nb];
228
  }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
243
  dst[tpig] = src0[tpig] * scale;
244
  }
245
 
246
- kernel void kernel_silu(
247
- device const float4 * src0,
248
- device float4 * dst,
249
  uint tpig[[thread_position_in_grid]]) {
250
- device const float4 & x = src0[tpig];
251
- dst[tpig] = x / (1.0f + exp(-x));
252
  }
253
 
254
- kernel void kernel_relu(
255
  device const float * src0,
256
  device float * dst,
257
  uint tpig[[thread_position_in_grid]]) {
258
- dst[tpig] = max(0.0f, src0[tpig]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  }
260
 
261
  kernel void kernel_sqr(
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
313
  dst_row[0] = row_sum;
314
  }
315
 
316
- constant float GELU_COEF_A = 0.044715f;
317
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
318
-
319
- kernel void kernel_gelu(
320
- device const float4 * src0,
321
- device float4 * dst,
322
- uint tpig[[thread_position_in_grid]]) {
323
- device const float4 & x = src0[tpig];
324
-
325
- // BEWARE !!!
326
- // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
327
- // This was observed with Falcon 7B and 40B models
328
- //
329
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
330
- }
331
-
332
  kernel void kernel_soft_max(
333
  device const float * src0,
334
  device const float * src1,
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
347
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
348
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
349
 
350
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
351
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
352
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
353
 
354
  // parallel max
355
  float lmax = -INFINITY;
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
385
  pdst[i00] = exp_psrc0;
386
  }
387
 
 
 
 
 
388
  float sum = simd_sum(lsum);
 
389
  if (ntg > N_SIMDWIDTH) {
390
  if (sgitg == 0) {
391
  buf[tiisg] = 0.0f;
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
428
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
429
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
430
 
431
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
432
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
433
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
434
 
435
  // parallel max
436
  float4 lmax4 = -INFINITY;
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
468
  }
469
 
470
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
 
 
 
 
471
  float sum = simd_sum(lsum);
 
472
  if (ntg > N_SIMDWIDTH) {
473
  if (sgitg == 0) {
474
  buf[tiisg] = 0.0f;
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
639
  }
640
  }
641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
  // function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
643
  // il indicates where the q4 quants begin (0 or QK4_0/4)
644
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
731
  // guard against the number of rows not being divisible by
732
  // N_DST, so this is another explicit assumption of the implementation.
733
  template<typename block_q_type, int nr, int nsg, int nw>
734
- void mul_vec_q_n_f32(
735
  device const void * src0,
736
  device const float * src1,
737
  device float * dst,
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
813
  uint3 tgpig[[threadgroup_position_in_grid]],
814
  uint tiisg[[thread_index_in_simdgroup]],
815
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
816
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
817
  }
818
 
819
  kernel void kernel_mul_mv_q4_1_f32(
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
832
  uint3 tgpig[[threadgroup_position_in_grid]],
833
  uint tiisg[[thread_index_in_simdgroup]],
834
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
835
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
836
  }
837
 
838
  kernel void kernel_mul_mv_q5_0_f32(
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
851
  uint3 tgpig[[threadgroup_position_in_grid]],
852
  uint tiisg[[thread_index_in_simdgroup]],
853
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
854
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
855
  }
856
 
857
  kernel void kernel_mul_mv_q5_1_f32(
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
870
  uint3 tgpig[[threadgroup_position_in_grid]],
871
  uint tiisg[[thread_index_in_simdgroup]],
872
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
873
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
874
  }
875
 
876
 
877
  #define NB_Q8_0 8
878
 
879
- kernel void kernel_mul_mv_q8_0_f32(
880
  device const void * src0,
881
  device const float * src1,
882
  device float * dst,
883
  constant int64_t & ne00,
884
- constant int64_t & ne01[[buffer(4)]],
885
- constant int64_t & ne02[[buffer(5)]],
886
- constant int64_t & ne10[[buffer(9)]],
887
- constant int64_t & ne12[[buffer(11)]],
888
- constant int64_t & ne0 [[buffer(15)]],
889
- constant int64_t & ne1 [[buffer(16)]],
890
- constant uint & r2 [[buffer(17)]],
891
- constant uint & r3 [[buffer(18)]],
892
  uint3 tgpig[[threadgroup_position_in_grid]],
893
- uint tiisg[[thread_index_in_simdgroup]],
894
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
895
  const int nr = N_DST;
896
  const int nsg = N_SIMDGROUP;
897
  const int nw = N_SIMDWIDTH;
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
945
  }
946
  }
947
 
948
  #define N_F32_F32 4
949
 
950
- kernel void kernel_mul_mv_f32_f32(
951
  device const char * src0,
952
  device const char * src1,
953
  device float * dst,
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
965
  constant uint64_t & nb12,
966
  constant int64_t & ne0,
967
  constant int64_t & ne1,
968
- constant uint & r2 [[buffer(17)]],
969
- constant uint & r3 [[buffer(18)]],
970
  uint3 tgpig[[threadgroup_position_in_grid]],
971
  uint tiisg[[thread_index_in_simdgroup]]) {
972
 
@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
1025
  }
1026
  }
1027
 
1028
  #define N_F16_F16 4
1029
 
1030
  kernel void kernel_mul_mv_f16_f16(
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
1105
  }
1106
  }
1107
 
1108
- kernel void kernel_mul_mv_f16_f32_1row(
1109
  device const char * src0,
1110
  device const char * src1,
1111
  device float * dst,
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1123
  constant uint64_t & nb12,
1124
  constant int64_t & ne0,
1125
  constant int64_t & ne1,
1126
- constant uint & r2 [[buffer(17)]],
1127
- constant uint & r3 [[buffer(18)]],
1128
  uint3 tgpig[[threadgroup_position_in_grid]],
1129
  uint tiisg[[thread_index_in_simdgroup]]) {
1130
 
@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
1161
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1162
  }
1163
  }
 
1164
 
1165
  }
1166
 
1167
  #define N_F16_F32 4
1168
 
1169
- kernel void kernel_mul_mv_f16_f32(
1170
  device const char * src0,
1171
  device const char * src1,
1172
  device float * dst,
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
1184
  constant uint64_t & nb12,
1185
  constant int64_t & ne0,
1186
  constant int64_t & ne1,
1187
- constant uint & r2 [[buffer(17)]],
1188
- constant uint & r3 [[buffer(18)]],
1189
  uint3 tgpig[[threadgroup_position_in_grid]],
1190
  uint tiisg[[thread_index_in_simdgroup]]) {
1191
 
@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
1244
  }
1245
  }
1246
 
1247
  // Assumes row size (ne00) is a multiple of 4
1248
  kernel void kernel_mul_mv_f16_f32_l4(
1249
  device const char * src0,
@@ -1548,25 +1763,116 @@ kernel void kernel_im2col_f16(
1548
  }
1549
  }
1550
 
1551
- // bitonic sort implementation following the CUDA kernels as reference
1552
- typedef void (argsort_t)(
1553
- device const float * x,
1554
- device int32_t * dst,
1555
- constant int64_t & ncols,
1556
- uint3 tgpig[[threadgroup_position_in_grid]],
1557
- uint3 tpitg[[thread_position_in_threadgroup]]);
1558
-
1559
- template<ggml_sort_order order>
1560
- kernel void kernel_argsort_f32_i32(
1561
- device const float * x,
1562
- device int32_t * dst,
1563
- constant int64_t & ncols,
1564
- uint3 tgpig[[threadgroup_position_in_grid]],
1565
- uint3 tpitg[[thread_position_in_threadgroup]]) {
1566
- // bitonic sort
1567
- int col = tpitg[0];
1568
- int row = tgpig[1];
1569
-
1570
  if (col >= ncols) return;
1571
 
1572
  device const float * x_row = x + row * ncols;
@@ -1600,9 +1906,17 @@ kernel void kernel_argsort_f32_i32(
1600
  template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1601
  template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1602
 
1603
  kernel void kernel_cpy_f16_f16(
1604
- device const half * src0,
1605
- device half * dst,
1606
  constant int64_t & ne00,
1607
  constant int64_t & ne01,
1608
  constant int64_t & ne02,
@@ -1641,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
1641
  }
1642
  }
1643
 
1644
  kernel void kernel_cpy_f32_f16(
1645
  device const float * src0,
1646
  device half * dst,
@@ -1917,9 +2272,9 @@ kernel void kernel_cpy_f32_q4_1(
1917
  }
1918
 
1919
  kernel void kernel_concat(
1920
- device const char * src0,
1921
- device const char * src1,
1922
- device char * dst,
1923
  constant int64_t & ne00,
1924
  constant int64_t & ne01,
1925
  constant int64_t & ne02,
@@ -1956,7 +2311,7 @@ kernel void kernel_concat(
1956
  const int64_t i12 = i02 % ne12;
1957
  const int64_t i11 = i01 % ne11;
1958
 
1959
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
1960
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1961
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1962
 
@@ -2064,19 +2419,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2064
 
2065
  //====================================== dot products =========================
2066
 
2067
- kernel void kernel_mul_mv_q2_K_f32(
2068
  device const void * src0,
2069
  device const float * src1,
2070
  device float * dst,
2071
  constant int64_t & ne00,
2072
- constant int64_t & ne01[[buffer(4)]],
2073
- constant int64_t & ne02[[buffer(5)]],
2074
- constant int64_t & ne10[[buffer(9)]],
2075
- constant int64_t & ne12[[buffer(11)]],
2076
- constant int64_t & ne0 [[buffer(15)]],
2077
- constant int64_t & ne1 [[buffer(16)]],
2078
- constant uint & r2 [[buffer(17)]],
2079
- constant uint & r3 [[buffer(18)]],
2080
  uint3 tgpig[[threadgroup_position_in_grid]],
2081
  uint tiisg[[thread_index_in_simdgroup]],
2082
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,8 +2569,8 @@ kernel void kernel_mul_mv_q2_K_f32(
2214
  }
2215
  }
2216
 
2217
- #if QK_K == 256
2218
- kernel void kernel_mul_mv_q3_K_f32(
2219
  device const void * src0,
2220
  device const float * src1,
2221
  device float * dst,
@@ -2229,8 +2584,29 @@ kernel void kernel_mul_mv_q3_K_f32(
2229
  constant uint & r2 [[buffer(17)]],
2230
  constant uint & r3 [[buffer(18)]],
2231
  uint3 tgpig[[threadgroup_position_in_grid]],
2232
- uint tiisg[[thread_index_in_simdgroup]],
2233
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2234
 
2235
  const int nb = ne00/QK_K;
2236
 
@@ -2373,19 +2749,19 @@ kernel void kernel_mul_mv_q3_K_f32(
2373
  }
2374
  }
2375
  #else
2376
- kernel void kernel_mul_mv_q3_K_f32(
2377
  device const void * src0,
2378
  device const float * src1,
2379
  device float * dst,
2380
  constant int64_t & ne00,
2381
- constant int64_t & ne01[[buffer(4)]],
2382
- constant int64_t & ne02[[buffer(5)]],
2383
- constant int64_t & ne10[[buffer(9)]],
2384
- constant int64_t & ne12[[buffer(11)]],
2385
- constant int64_t & ne0 [[buffer(15)]],
2386
- constant int64_t & ne1 [[buffer(16)]],
2387
- constant uint & r2 [[buffer(17)]],
2388
- constant uint & r3 [[buffer(18)]],
2389
  uint3 tgpig[[threadgroup_position_in_grid]],
2390
  uint tiisg[[thread_index_in_simdgroup]],
2391
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2826,41 @@ kernel void kernel_mul_mv_q3_K_f32(
2450
  }
2451
  #endif
2452
 
2453
  #if QK_K == 256
2454
- kernel void kernel_mul_mv_q4_K_f32(
2455
  device const void * src0,
2456
  device const float * src1,
2457
  device float * dst,
2458
  constant int64_t & ne00,
2459
- constant int64_t & ne01 [[buffer(4)]],
2460
- constant int64_t & ne02 [[buffer(5)]],
2461
- constant int64_t & ne10 [[buffer(9)]],
2462
- constant int64_t & ne12 [[buffer(11)]],
2463
- constant int64_t & ne0 [[buffer(15)]],
2464
- constant int64_t & ne1 [[buffer(16)]],
2465
- constant uint & r2 [[buffer(17)]],
2466
- constant uint & r3 [[buffer(18)]],
2467
  uint3 tgpig[[threadgroup_position_in_grid]],
2468
  uint tiisg[[thread_index_in_simdgroup]],
2469
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,19 +2961,19 @@ kernel void kernel_mul_mv_q4_K_f32(
2564
  }
2565
  }
2566
  #else
2567
- kernel void kernel_mul_mv_q4_K_f32(
2568
  device const void * src0,
2569
  device const float * src1,
2570
  device float * dst,
2571
  constant int64_t & ne00,
2572
- constant int64_t & ne01[[buffer(4)]],
2573
- constant int64_t & ne02[[buffer(5)]],
2574
- constant int64_t & ne10[[buffer(9)]],
2575
- constant int64_t & ne12[[buffer(11)]],
2576
- constant int64_t & ne0 [[buffer(15)]],
2577
- constant int64_t & ne1 [[buffer(16)]],
2578
- constant uint & r2 [[buffer(17)]],
2579
- constant uint & r3 [[buffer(18)]],
2580
  uint3 tgpig[[threadgroup_position_in_grid]],
2581
  uint tiisg[[thread_index_in_simdgroup]],
2582
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2660,7 +3057,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2660
  }
2661
  #endif
2662
 
2663
- kernel void kernel_mul_mv_q5_K_f32(
 
2664
  device const void * src0,
2665
  device const float * src1,
2666
  device float * dst,
@@ -2677,6 +3075,26 @@ kernel void kernel_mul_mv_q5_K_f32(
2677
  uint tiisg[[thread_index_in_simdgroup]],
2678
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2679
 
2680
  const int nb = ne00/QK_K;
2681
 
2682
  const int64_t r0 = tgpig.x;
@@ -2836,10 +3254,10 @@ kernel void kernel_mul_mv_q5_K_f32(
2836
  dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2837
  }
2838
  }
2839
-
2840
  }
2841
 
2842
- kernel void kernel_mul_mv_q6_K_f32(
 
2843
  device const void * src0,
2844
  device const float * src1,
2845
  device float * dst,
@@ -2853,18 +3271,38 @@ kernel void kernel_mul_mv_q6_K_f32(
2853
  constant uint & r2 [[buffer(17)]],
2854
  constant uint & r3 [[buffer(18)]],
2855
  uint3 tgpig[[threadgroup_position_in_grid]],
2856
- uint tiisg[[thread_index_in_simdgroup]],
2857
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2858
-
2859
- const uint8_t kmask1 = 0x03;
2860
- const uint8_t kmask2 = 0x0C;
2861
- const uint8_t kmask3 = 0x30;
2862
- const uint8_t kmask4 = 0xC0;
2863
 
2864
- const int nb = ne00/QK_K;
 
2865
 
2866
- const int64_t r0 = tgpig.x;
2867
- const int64_t r1 = tgpig.y;
2868
  const int im = tgpig.z;
2869
 
2870
  const int row = 2 * r0 + sgitg;
@@ -2945,6 +3383,27 @@ kernel void kernel_mul_mv_q6_K_f32(
2945
  }
2946
  }
2947
 
2948
  //============================= templates and their specializations =============================
2949
 
2950
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3062,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3062
 
3063
  template <typename type4x4>
3064
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
3065
- const half d = xb->d;
3066
- const half min = xb->dmin;
3067
  device const uint8_t * q = (device const uint8_t *)xb->qs;
3068
- half dl, ml;
3069
  uint8_t sc = xb->scales[il];
3070
 
3071
  #if QK_K == 256
@@ -3135,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
3135
  q = q + (il/4) * 32 + 16 * (il&1);
3136
  il = il & 3;
3137
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3138
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3139
- const half min = xb->dmin;
3140
- const half dl = d * sc[0];
3141
- const half ml = min * sc[1];
3142
  #else
3143
  q = q + 16 * (il&1);
3144
  device const uint8_t * s = xb->scales;
@@ -3165,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3165
  uint8_t ul = 1 << (il/2);
3166
  il = il & 3;
3167
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3168
- const half d = il < 2 ? xb->d : xb->d / 16.h;
3169
- const half min = xb->dmin;
3170
- const half dl = d * sc[0];
3171
- const half ml = min * sc[1];
3172
 
3173
- const ushort mask = il<2 ? 0x0F : 0xF0;
3174
- const half qh_val = il<2 ? 16.h : 256.h;
3175
  for (int i = 0; i < 16; ++i) {
3176
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
3177
  }
@@ -3219,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3219
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3220
  kernel void kernel_get_rows(
3221
  device const void * src0,
3222
- device const int * src1,
3223
  device float * dst,
3224
  constant int64_t & ne00,
3225
  constant uint64_t & nb01,
3226
  constant uint64_t & nb1,
3227
- uint tgpig[[threadgroup_position_in_grid]],
 
3228
  uint tiitg[[thread_index_in_threadgroup]],
3229
- uint tptg[[threads_per_threadgroup]]) {
3230
- const int i = tgpig;
3231
- const int r = ((device int32_t *) src1)[i];
 
 
 
3232
 
3233
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
3234
  float4x4 temp;
3235
  dequantize_func(
3236
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
3237
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
3238
  }
3239
  }
3240
 
@@ -3426,19 +3953,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
3426
 
3427
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
  kernel void kernel_mul_mm_id(
3429
- device const int32_t * ids,
3430
  device const uchar * src1,
3431
- device float * dst,
 
3432
  constant int64_t & ne00,
3433
  constant int64_t & ne02,
3434
  constant int64_t & nb01,
3435
  constant int64_t & nb02,
3436
  constant int64_t & ne12,
 
3437
  constant int64_t & nb10,
3438
  constant int64_t & nb11,
3439
  constant int64_t & nb12,
3440
  constant int64_t & ne0,
3441
  constant int64_t & ne1,
 
3442
  constant uint & r2,
3443
  constant uint & r3,
3444
  constant int & idx,
@@ -3456,10 +3986,16 @@ kernel void kernel_mul_mm_id(
3456
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
  device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
 
3459
  kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
- src0[ids[idx]],
3461
- src1,
3462
- dst,
3463
  ne00,
3464
  ne02,
3465
  nb01,
@@ -3484,17 +4020,26 @@ kernel void kernel_mul_mm_id(
3484
  #define QK_NL 4
3485
  #endif
3486
 
3487
  typedef void (get_rows_t)(
3488
  device const void * src0,
3489
- device const int * src1,
3490
  device float * dst,
3491
  constant int64_t & ne00,
3492
  constant uint64_t & nb01,
3493
  constant uint64_t & nb1,
3494
- uint, uint, uint);
 
3495
 
3496
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
3497
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
3498
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
3499
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
3500
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3506,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
3506
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
3507
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
3508
 
3509
  typedef void (mat_mm_t)(
3510
  device const uchar * src0,
3511
  device const uchar * src1,
@@ -3538,20 +4087,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
3538
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
3539
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
3540
 
3541
  typedef void (mat_mm_id_t)(
3542
- device const int32_t * ids,
3543
  device const uchar * src1,
3544
- device float * dst,
 
3545
  constant int64_t & ne00,
3546
  constant int64_t & ne02,
3547
  constant int64_t & nb01,
3548
  constant int64_t & nb02,
3549
  constant int64_t & ne12,
 
3550
  constant int64_t & nb10,
3551
  constant int64_t & nb11,
3552
  constant int64_t & nb12,
3553
  constant int64_t & ne0,
3554
  constant int64_t & ne1,
 
3555
  constant uint & r2,
3556
  constant uint & r3,
3557
  constant int & idx,
@@ -3578,3 +4134,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
3578
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
79
  constant int64_t & nb1,
80
  constant int64_t & nb2,
81
  constant int64_t & nb3,
82
+ constant int64_t & offs,
83
  uint3 tgpig[[threadgroup_position_in_grid]],
84
  uint3 tpitg[[thread_position_in_threadgroup]],
85
  uint3 ntg[[threads_per_threadgroup]]) {
 
91
  const int64_t i12 = i02 % ne12;
92
  const int64_t i11 = i01 % ne11;
93
 
94
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
95
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
96
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
97
 
98
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
99
  const int i10 = i0 % ne10;
 
205
  device const float4 * src0,
206
  device const float4 * src1,
207
  device float4 * dst,
208
+ constant int64_t & nb [[buffer(28)]],
209
  uint tpig[[thread_position_in_grid]]) {
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
211
  }
 
214
  device const float4 * src0,
215
  device const float4 * src1,
216
  device float4 * dst,
217
+ constant int64_t & nb [[buffer(28)]],
218
  uint tpig[[thread_position_in_grid]]) {
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
220
  }
 
223
  device const float4 * src0,
224
  device const float4 * src1,
225
  device float4 * dst,
226
+ constant int64_t & nb [[buffer(28)]],
227
  uint tpig[[thread_position_in_grid]]) {
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
229
  }
 
244
  dst[tpig] = src0[tpig] * scale;
245
  }
246
 
247
+ kernel void kernel_relu(
248
+ device const float * src0,
249
+ device float * dst,
250
  uint tpig[[thread_position_in_grid]]) {
251
+ dst[tpig] = max(0.0f, src0[tpig]);
 
252
  }
253
 
254
+ kernel void kernel_tanh(
255
  device const float * src0,
256
  device float * dst,
257
  uint tpig[[thread_position_in_grid]]) {
258
+ device const float & x = src0[tpig];
259
+ dst[tpig] = precise::tanh(x);
260
+ }
261
+
262
+ constant float GELU_COEF_A = 0.044715f;
263
+ constant float GELU_QUICK_COEF = -1.702f;
264
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
265
+
266
+ kernel void kernel_gelu(
267
+ device const float4 * src0,
268
+ device float4 * dst,
269
+ uint tpig[[thread_position_in_grid]]) {
270
+ device const float4 & x = src0[tpig];
271
+
272
+ // BEWARE !!!
273
+ // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
274
+ // This was observed with Falcon 7B and 40B models
275
+ //
276
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
277
+ }
278
+
279
+ kernel void kernel_gelu_quick(
280
+ device const float4 * src0,
281
+ device float4 * dst,
282
+ uint tpig[[thread_position_in_grid]]) {
283
+ device const float4 & x = src0[tpig];
284
+
285
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
286
+ }
287
+
288
+ kernel void kernel_silu(
289
+ device const float4 * src0,
290
+ device float4 * dst,
291
+ uint tpig[[thread_position_in_grid]]) {
292
+ device const float4 & x = src0[tpig];
293
+ dst[tpig] = x / (1.0f + exp(-x));
294
  }
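The three activation kernels above evaluate the tanh-approximated GELU, the quick GELU, and SiLU element-wise on float4 vectors, with precise::tanh used to avoid NaNs. For reference, here is a minimal host-side C++ sketch of the same formulas in scalar form, using the constants defined above; the ref_* names are illustrative only and are not part of ggml.

#include <math.h>
#include <stdio.h>

// Host-side reference of the formulas evaluated by kernel_gelu,
// kernel_gelu_quick and kernel_silu (scalar, not vectorized).
static float ref_gelu(float x) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

static float ref_gelu_quick(float x) {
    const float GELU_QUICK_COEF = -1.702f;
    return x*(1.0f/(1.0f + expf(GELU_QUICK_COEF*x)));
}

static float ref_silu(float x) {
    return x/(1.0f + expf(-x));
}

int main(void) {
    for (float x = -4.0f; x <= 4.0f; x += 2.0f) {
        printf("x=%5.1f gelu=%8.5f gelu_quick=%8.5f silu=%8.5f\n",
               x, ref_gelu(x), ref_gelu_quick(x), ref_silu(x));
    }
    return 0;
}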
295
 
296
  kernel void kernel_sqr(
 
348
  dst_row[0] = row_sum;
349
  }
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  kernel void kernel_soft_max(
352
  device const float * src0,
353
  device const float * src1,
 
366
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
367
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
368
 
369
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
372
 
373
  // parallel max
374
  float lmax = -INFINITY;
 
404
  pdst[i00] = exp_psrc0;
405
  }
406
 
407
+ // This barrier fixes a failing test
408
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
409
+ threadgroup_barrier(mem_flags::mem_none);
410
+
411
  float sum = simd_sum(lsum);
412
+
413
  if (ntg > N_SIMDWIDTH) {
414
  if (sgitg == 0) {
415
  buf[tiisg] = 0.0f;
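For reference, kernel_soft_max computes, per row of ne00 values: add the optional mask row (src1) to the logits, subtract the row max for numerical stability, exponentiate, and normalize by the sum; the GPU kernel parallelizes the max and sum reductions across a simdgroup/threadgroup, which is why the barrier above is needed. A scalar sketch of that per-row math (ref_soft_max_row is an illustrative name, not ggml API):

#include <math.h>
#include <stddef.h>

// Scalar reference of the masked, numerically stable softmax that
// kernel_soft_max applies to each row of ne00 elements.
static void ref_soft_max_row(const float * x, const float * mask, float * y, size_t n) {
    float max_val = -INFINITY;
    for (size_t i = 0; i < n; ++i) {
        const float v = x[i] + (mask ? mask[i] : 0.0f);
        if (v > max_val) max_val = v;
    }
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float e = expf(x[i] + (mask ? mask[i] : 0.0f) - max_val);
        y[i] = e;
        sum += e;
    }
    for (size_t i = 0; i < n; ++i) {
        y[i] /= sum;
    }
}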
 
452
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
453
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
454
 
455
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
457
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
458
 
459
  // parallel max
460
  float4 lmax4 = -INFINITY;
 
492
  }
493
 
494
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
495
+
496
+ // This barrier fixes a failing test
497
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
498
+ threadgroup_barrier(mem_flags::mem_none);
499
+
500
  float sum = simd_sum(lsum);
501
+
502
  if (ntg > N_SIMDWIDTH) {
503
  if (sgitg == 0) {
504
  buf[tiisg] = 0.0f;
 
669
  }
670
  }
671
 
672
+ kernel void kernel_group_norm(
673
+ device const float * src0,
674
+ device float * dst,
675
+ constant int64_t & ne00,
676
+ constant int64_t & ne01,
677
+ constant int64_t & ne02,
678
+ constant uint64_t & nb00,
679
+ constant uint64_t & nb01,
680
+ constant uint64_t & nb02,
681
+ constant int32_t & n_groups,
682
+ constant float & eps,
683
+ threadgroup float * buf [[threadgroup(0)]],
684
+ uint tgpig[[threadgroup_position_in_grid]],
685
+ uint tpitg[[thread_position_in_threadgroup]],
686
+ uint sgitg[[simdgroup_index_in_threadgroup]],
687
+ uint tiisg[[thread_index_in_simdgroup]],
688
+ uint ntg[[threads_per_threadgroup]]) {
689
+ const int64_t ne = ne00*ne01*ne02;
690
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
691
+
692
+ int start = tgpig * gs;
693
+ int end = start + gs;
694
+
695
+ start += tpitg;
696
+
697
+ if (end >= ne) {
698
+ end = ne;
699
+ }
700
+
701
+ float tmp = 0.0f; // partial sum for thread in warp
702
+
703
+ for (int j = start; j < end; j += ntg) {
704
+ tmp += src0[j];
705
+ }
706
+
707
+ threadgroup_barrier(mem_flags::mem_threadgroup);
708
+ tmp = simd_sum(tmp);
709
+ if (ntg > N_SIMDWIDTH) {
710
+ if (sgitg == 0) {
711
+ buf[tiisg] = 0.0f;
712
+ }
713
+
714
+ threadgroup_barrier(mem_flags::mem_threadgroup);
715
+
716
+ if (tiisg == 0) {
717
+ buf[sgitg] = tmp;
718
+ }
719
+
720
+ threadgroup_barrier(mem_flags::mem_threadgroup);
721
+
722
+ tmp = buf[tiisg];
723
+ tmp = simd_sum(tmp);
724
+ }
725
+
726
+ const float mean = tmp / gs;
727
+ tmp = 0.0f;
728
+
729
+ for (int j = start; j < end; j += ntg) {
730
+ float xi = src0[j] - mean;
731
+ dst[j] = xi;
732
+ tmp += xi * xi;
733
+ }
734
+
735
+ tmp = simd_sum(tmp);
736
+ if (ntg > N_SIMDWIDTH) {
737
+ if (sgitg == 0) {
738
+ buf[tiisg] = 0.0f;
739
+ }
740
+
741
+ threadgroup_barrier(mem_flags::mem_threadgroup);
742
+
743
+ if (tiisg == 0) {
744
+ buf[sgitg] = tmp;
745
+ }
746
+
747
+ threadgroup_barrier(mem_flags::mem_threadgroup);
748
+
749
+ tmp = buf[tiisg];
750
+ tmp = simd_sum(tmp);
751
+ }
752
+
753
+ const float variance = tmp / gs;
754
+ const float scale = 1.0f/sqrt(variance + eps);
755
+ for (int j = start; j < end; j += ntg) {
756
+ dst[j] *= scale;
757
+ }
758
+ }
759
+
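kernel_group_norm above splits the ne02 channels into n_groups and normalizes each group to zero mean and unit variance using two simdgroup/threadgroup reductions (sum, then sum of squares). A scalar sketch of the same computation, assuming contiguous data; ref_group_norm is an illustrative name and, like the kernel, it divides by the nominal group size gs:

#include <math.h>
#include <stdint.h>

// Scalar reference of the group-norm math in kernel_group_norm.
static void ref_group_norm(const float * x, float * y,
                           int64_t ne00, int64_t ne01, int64_t ne02,
                           int n_groups, float eps) {
    const int64_t ch_per_group = (ne02 + n_groups - 1) / n_groups;
    const int64_t gs = ne00*ne01*ch_per_group; // nominal elements per group
    const int64_t ne = ne00*ne01*ne02;

    for (int g = 0; g < n_groups; ++g) {
        const int64_t start = g*gs;
        const int64_t end   = (start + gs < ne) ? start + gs : ne;

        float mean = 0.0f;
        for (int64_t j = start; j < end; ++j) mean += x[j];
        mean /= (float)gs; // the kernel divides by gs, not by the clamped group size

        float var = 0.0f;
        for (int64_t j = start; j < end; ++j) {
            const float d = x[j] - mean;
            var += d*d;
        }
        var /= (float)gs;

        const float scale = 1.0f/sqrtf(var + eps);
        for (int64_t j = start; j < end; ++j) {
            y[j] = (x[j] - mean)*scale;
        }
    }
}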
760
  // function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
761
  // il indicates where the q4 quants begin (0 or QK4_0/4)
762
  // we assume that the yl's have been multiplied with the appropriate scale factor
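As a plain C++ illustration of the dot product described in the comment above: a q4_0 block stores 32 4-bit quants in 16 bytes (elements 0..15 in the low nibbles, 16..31 in the high nibbles), each shifted by -8 and scaled by a per-block fp16 scale d. The struct and helper below are illustrative stand-ins, not the ggml definitions:

#include <stdint.h>

#define QK4_0 32

// Reference of the full-block q4_0 dot product that the simdgroup
// kernels accumulate in halves via block_q_n_dot_y.
typedef struct {
    float   d;              // per-block scale (fp16 in ggml, float here)
    uint8_t qs[QK4_0/2];    // packed 4-bit quants
} ref_block_q4_0;

static float ref_q4_0_dot(const ref_block_q4_0 * b, const float * y) {
    float acc = 0.0f;
    for (int j = 0; j < QK4_0/2; ++j) {
        const int lo = (b->qs[j] & 0x0F) - 8;   // elements 0..15
        const int hi = (b->qs[j] >>   4) - 8;   // elements 16..31
        acc += lo*y[j] + hi*y[j + QK4_0/2];
    }
    return b->d * acc;
}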
 
849
  // guard against the number of rows not being divisible by
850
  // N_DST, so this is another explicit assumption of the implementation.
851
  template<typename block_q_type, int nr, int nsg, int nw>
852
+ void mul_vec_q_n_f32_impl(
853
  device const void * src0,
854
  device const float * src1,
855
  device float * dst,
 
931
  uint3 tgpig[[threadgroup_position_in_grid]],
932
  uint tiisg[[thread_index_in_simdgroup]],
933
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
934
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
935
  }
936
 
937
  kernel void kernel_mul_mv_q4_1_f32(
 
950
  uint3 tgpig[[threadgroup_position_in_grid]],
951
  uint tiisg[[thread_index_in_simdgroup]],
952
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
953
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
954
  }
955
 
956
  kernel void kernel_mul_mv_q5_0_f32(
 
969
  uint3 tgpig[[threadgroup_position_in_grid]],
970
  uint tiisg[[thread_index_in_simdgroup]],
971
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
972
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
973
  }
974
 
975
  kernel void kernel_mul_mv_q5_1_f32(
 
988
  uint3 tgpig[[threadgroup_position_in_grid]],
989
  uint tiisg[[thread_index_in_simdgroup]],
990
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
991
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
992
  }
993
 
994
 
995
  #define NB_Q8_0 8
996
 
997
+ void kernel_mul_mv_q8_0_f32_impl(
998
  device const void * src0,
999
  device const float * src1,
1000
  device float * dst,
1001
  constant int64_t & ne00,
1002
+ constant int64_t & ne01,
1003
+ constant int64_t & ne02,
1004
+ constant int64_t & ne10,
1005
+ constant int64_t & ne12,
1006
+ constant int64_t & ne0,
1007
+ constant int64_t & ne1,
1008
+ constant uint & r2,
1009
+ constant uint & r3,
1010
  uint3 tgpig[[threadgroup_position_in_grid]],
1011
+ uint tiisg[[thread_index_in_simdgroup]],
1012
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1013
  const int nr = N_DST;
1014
  const int nsg = N_SIMDGROUP;
1015
  const int nw = N_SIMDWIDTH;
 
1063
  }
1064
  }
1065
 
1066
+ [[host_name("kernel_mul_mv_q8_0_f32")]]
1067
+ kernel void kernel_mul_mv_q8_0_f32(
1068
+ device const void * src0,
1069
+ device const float * src1,
1070
+ device float * dst,
1071
+ constant int64_t & ne00,
1072
+ constant int64_t & ne01,
1073
+ constant int64_t & ne02,
1074
+ constant int64_t & ne10,
1075
+ constant int64_t & ne12,
1076
+ constant int64_t & ne0,
1077
+ constant int64_t & ne1,
1078
+ constant uint & r2 [[buffer(17)]],
1079
+ constant uint & r3 [[buffer(18)]],
1080
+ uint3 tgpig[[threadgroup_position_in_grid]],
1081
+ uint tiisg[[thread_index_in_simdgroup]],
1082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1083
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1084
+ }
1085
+
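The q8_0 path above accumulates, per block, a dot product of 32 int8 quants against the corresponding floats and scales the result by the shared per-block scale d. A C++ sketch of that per-block product, with illustrative types and names rather than the ggml definitions:

#include <stdint.h>

#define QK8_0 32

// Reference of the per-block product computed inside
// kernel_mul_mv_q8_0_f32_impl.
typedef struct {
    float  d;            // per-block scale (fp16 in ggml)
    int8_t qs[QK8_0];    // quantized values
} ref_block_q8_0;

static float ref_q8_0_dot(const ref_block_q8_0 * b, const float * y) {
    float acc = 0.0f;
    for (int j = 0; j < QK8_0; ++j) {
        acc += b->qs[j]*y[j];
    }
    return b->d * acc;
}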
1086
  #define N_F32_F32 4
1087
 
1088
+ void kernel_mul_mv_f32_f32_impl(
1089
  device const char * src0,
1090
  device const char * src1,
1091
  device float * dst,
 
1103
  constant uint64_t & nb12,
1104
  constant int64_t & ne0,
1105
  constant int64_t & ne1,
1106
+ constant uint & r2,
1107
+ constant uint & r3,
1108
  uint3 tgpig[[threadgroup_position_in_grid]],
1109
  uint tiisg[[thread_index_in_simdgroup]]) {
1110
 
 
1163
  }
1164
  }
1165
 
1166
+ [[host_name("kernel_mul_mv_f32_f32")]]
1167
+ kernel void kernel_mul_mv_f32_f32(
1168
+ device const char * src0,
1169
+ device const char * src1,
1170
+ device float * dst,
1171
+ constant int64_t & ne00,
1172
+ constant int64_t & ne01,
1173
+ constant int64_t & ne02,
1174
+ constant uint64_t & nb00,
1175
+ constant uint64_t & nb01,
1176
+ constant uint64_t & nb02,
1177
+ constant int64_t & ne10,
1178
+ constant int64_t & ne11,
1179
+ constant int64_t & ne12,
1180
+ constant uint64_t & nb10,
1181
+ constant uint64_t & nb11,
1182
+ constant uint64_t & nb12,
1183
+ constant int64_t & ne0,
1184
+ constant int64_t & ne1,
1185
+ constant uint & r2 [[buffer(17)]],
1186
+ constant uint & r3 [[buffer(18)]],
1187
+ uint3 tgpig[[threadgroup_position_in_grid]],
1188
+ uint tiisg[[thread_index_in_simdgroup]]) {
1189
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1190
+ }
1191
+
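The change above follows the pattern used throughout this sync: the kernel body moves into a *_impl function and the exported [[host_name]] kernel becomes a thin forwarding wrapper, presumably so the same body can be reused by other entry points such as the *_id kernels added later in the file. A plain C++ analogue of that pattern (names are illustrative):

// The body lives in an *_impl function ...
static void mul_mv_impl(const float * a, const float * b, float * c, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) acc += a[i]*b[i];
    *c = acc;
}

// ... and the exported entry point just forwards to it, mirroring how
// kernel_mul_mv_f32_f32 now calls kernel_mul_mv_f32_f32_impl.
void mul_mv(const float * a, const float * b, float * c, int n) {
    mul_mv_impl(a, b, c, n);
}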
1192
  #define N_F16_F16 4
1193
 
1194
  kernel void kernel_mul_mv_f16_f16(
 
1269
  }
1270
  }
1271
 
1272
+ void kernel_mul_mv_f16_f32_1row_impl(
1273
  device const char * src0,
1274
  device const char * src1,
1275
  device float * dst,
 
1287
  constant uint64_t & nb12,
1288
  constant int64_t & ne0,
1289
  constant int64_t & ne1,
1290
+ constant uint & r2,
1291
+ constant uint & r3,
1292
  uint3 tgpig[[threadgroup_position_in_grid]],
1293
  uint tiisg[[thread_index_in_simdgroup]]) {
1294
 
 
1325
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1326
  }
1327
  }
1328
+ }
1329
 
1330
+ [[host_name("kernel_mul_mv_f16_f32_1row")]]
1331
+ kernel void kernel_mul_mv_f16_f32_1row(
1332
+ device const char * src0,
1333
+ device const char * src1,
1334
+ device float * dst,
1335
+ constant int64_t & ne00,
1336
+ constant int64_t & ne01,
1337
+ constant int64_t & ne02,
1338
+ constant uint64_t & nb00,
1339
+ constant uint64_t & nb01,
1340
+ constant uint64_t & nb02,
1341
+ constant int64_t & ne10,
1342
+ constant int64_t & ne11,
1343
+ constant int64_t & ne12,
1344
+ constant uint64_t & nb10,
1345
+ constant uint64_t & nb11,
1346
+ constant uint64_t & nb12,
1347
+ constant int64_t & ne0,
1348
+ constant int64_t & ne1,
1349
+ constant uint & r2 [[buffer(17)]],
1350
+ constant uint & r3 [[buffer(18)]],
1351
+ uint3 tgpig[[threadgroup_position_in_grid]],
1352
+ uint tiisg[[thread_index_in_simdgroup]]) {
1353
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1354
  }
1355
 
1356
  #define N_F16_F32 4
1357
 
1358
+ void kernel_mul_mv_f16_f32_impl(
1359
  device const char * src0,
1360
  device const char * src1,
1361
  device float * dst,
 
1373
  constant uint64_t & nb12,
1374
  constant int64_t & ne0,
1375
  constant int64_t & ne1,
1376
+ constant uint & r2,
1377
+ constant uint & r3,
1378
  uint3 tgpig[[threadgroup_position_in_grid]],
1379
  uint tiisg[[thread_index_in_simdgroup]]) {
1380
 
 
1433
  }
1434
  }
1435
 
1436
+ [[host_name("kernel_mul_mv_f16_f32")]]
1437
+ kernel void kernel_mul_mv_f16_f32(
1438
+ device const char * src0,
1439
+ device const char * src1,
1440
+ device float * dst,
1441
+ constant int64_t & ne00,
1442
+ constant int64_t & ne01,
1443
+ constant int64_t & ne02,
1444
+ constant uint64_t & nb00,
1445
+ constant uint64_t & nb01,
1446
+ constant uint64_t & nb02,
1447
+ constant int64_t & ne10,
1448
+ constant int64_t & ne11,
1449
+ constant int64_t & ne12,
1450
+ constant uint64_t & nb10,
1451
+ constant uint64_t & nb11,
1452
+ constant uint64_t & nb12,
1453
+ constant int64_t & ne0,
1454
+ constant int64_t & ne1,
1455
+ constant uint & r2 [[buffer(17)]],
1456
+ constant uint & r3 [[buffer(18)]],
1457
+ uint3 tgpig[[threadgroup_position_in_grid]],
1458
+ uint tiisg[[thread_index_in_simdgroup]]) {
1459
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1460
+ }
1461
+
1462
  // Assumes row size (ne00) is a multiple of 4
1463
  kernel void kernel_mul_mv_f16_f32_l4(
1464
  device const char * src0,
 
1763
  }
1764
  }
1765
 
1766
+ kernel void kernel_upscale_f32(
1767
+ device const char * src0,
1768
+ device char * dst,
1769
+ constant int64_t & ne00,
1770
+ constant int64_t & ne01,
1771
+ constant int64_t & ne02,
1772
+ constant int64_t & ne03,
1773
+ constant uint64_t & nb00,
1774
+ constant uint64_t & nb01,
1775
+ constant uint64_t & nb02,
1776
+ constant uint64_t & nb03,
1777
+ constant int64_t & ne0,
1778
+ constant int64_t & ne1,
1779
+ constant int64_t & ne2,
1780
+ constant int64_t & ne3,
1781
+ constant uint64_t & nb0,
1782
+ constant uint64_t & nb1,
1783
+ constant uint64_t & nb2,
1784
+ constant uint64_t & nb3,
1785
+ constant int32_t & sf,
1786
+ uint3 tgpig[[threadgroup_position_in_grid]],
1787
+ uint3 tpitg[[thread_position_in_threadgroup]],
1788
+ uint3 ntg[[threads_per_threadgroup]]) {
1789
+
1790
+ const int64_t i3 = tgpig.z;
1791
+ const int64_t i2 = tgpig.y;
1792
+ const int64_t i1 = tgpig.x;
1793
+
1794
+ const int64_t i03 = i3;
1795
+ const int64_t i02 = i2;
1796
+ const int64_t i01 = i1/sf;
1797
+
1798
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1799
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1800
+
1801
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1802
+ dst_ptr[i0] = src0_ptr[i0/sf];
1803
+ }
1804
+ }
1805
+
1806
+ kernel void kernel_pad_f32(
1807
+ device const char * src0,
1808
+ device char * dst,
1809
+ constant int64_t & ne00,
1810
+ constant int64_t & ne01,
1811
+ constant int64_t & ne02,
1812
+ constant int64_t & ne03,
1813
+ constant uint64_t & nb00,
1814
+ constant uint64_t & nb01,
1815
+ constant uint64_t & nb02,
1816
+ constant uint64_t & nb03,
1817
+ constant int64_t & ne0,
1818
+ constant int64_t & ne1,
1819
+ constant int64_t & ne2,
1820
+ constant int64_t & ne3,
1821
+ constant uint64_t & nb0,
1822
+ constant uint64_t & nb1,
1823
+ constant uint64_t & nb2,
1824
+ constant uint64_t & nb3,
1825
+ uint3 tgpig[[threadgroup_position_in_grid]],
1826
+ uint3 tpitg[[thread_position_in_threadgroup]],
1827
+ uint3 ntg[[threads_per_threadgroup]]) {
1828
+
1829
+ const int64_t i3 = tgpig.z;
1830
+ const int64_t i2 = tgpig.y;
1831
+ const int64_t i1 = tgpig.x;
1832
+
1833
+ const int64_t i03 = i3;
1834
+ const int64_t i02 = i2;
1835
+ const int64_t i01 = i1;
1836
+
1837
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1838
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1839
+
1840
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
1841
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1842
+ if (i0 < ne00) {
1843
+ dst_ptr[i0] = src0_ptr[i0];
1844
+ } else {
1845
+ dst_ptr[i0] = 0.0f;
1846
+ }
1847
+ }
1848
+
1849
+ return;
1850
+ }
1851
+
1852
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1853
+ dst_ptr[i0] = 0.0f;
1854
+ }
1855
+ }
1856
+
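kernel_upscale_f32 and kernel_pad_f32 above implement nearest-neighbor upscaling by an integer factor sf and zero-padding into a larger destination, respectively. A 2D host-side sketch of the same semantics; the Metal kernels additionally handle the i2/i3 dimensions and byte strides, and the ref_* names are illustrative only:

// kernel_upscale_f32: nearest-neighbor upscaling by an integer factor sf.
static void ref_upscale_2d(const float * src, float * dst,
                           int ne00, int ne01, int sf) {
    for (int i1 = 0; i1 < ne01*sf; ++i1) {
        for (int i0 = 0; i0 < ne00*sf; ++i0) {
            dst[i1*(ne00*sf) + i0] = src[(i1/sf)*ne00 + (i0/sf)];
        }
    }
}

// kernel_pad_f32: copy the source into the top-left corner of a larger
// destination and zero-fill the remainder.
static void ref_pad_2d(const float * src, float * dst,
                       int ne00, int ne01,   // source extents
                       int ne0,  int ne1) {  // destination extents
    for (int i1 = 0; i1 < ne1; ++i1) {
        for (int i0 = 0; i0 < ne0; ++i0) {
            dst[i1*ne0 + i0] = (i0 < ne00 && i1 < ne01) ? src[i1*ne00 + i0] : 0.0f;
        }
    }
}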
1857
+ // bitonic sort implementation following the CUDA kernels as reference
1858
+ typedef void (argsort_t)(
1859
+ device const float * x,
1860
+ device int32_t * dst,
1861
+ constant int64_t & ncols,
1862
+ uint3 tgpig[[threadgroup_position_in_grid]],
1863
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1864
+
1865
+ template<ggml_sort_order order>
1866
+ kernel void kernel_argsort_f32_i32(
1867
+ device const float * x,
1868
+ device int32_t * dst,
1869
+ constant int64_t & ncols,
1870
+ uint3 tgpig[[threadgroup_position_in_grid]],
1871
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1872
+ // bitonic sort
1873
+ int col = tpitg[0];
1874
+ int row = tgpig[1];
1875
+
1876
  if (col >= ncols) return;
1877
 
1878
  device const float * x_row = x + row * ncols;
 
1906
  template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1907
  template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1908
 
1909
+ kernel void kernel_leaky_relu_f32(
1910
+ device const float * src0,
1911
+ device float * dst,
1912
+ constant float & slope,
1913
+ uint tpig[[thread_position_in_grid]]) {
1914
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
1915
+ }
1916
+
1917
  kernel void kernel_cpy_f16_f16(
1918
+ device const half * src0,
1919
+ device half * dst,
1920
  constant int64_t & ne00,
1921
  constant int64_t & ne01,
1922
  constant int64_t & ne02,
 
1955
  }
1956
  }
1957
 
1958
+ kernel void kernel_cpy_f16_f32(
1959
+ device const half * src0,
1960
+ device float * dst,
1961
+ constant int64_t & ne00,
1962
+ constant int64_t & ne01,
1963
+ constant int64_t & ne02,
1964
+ constant int64_t & ne03,
1965
+ constant uint64_t & nb00,
1966
+ constant uint64_t & nb01,
1967
+ constant uint64_t & nb02,
1968
+ constant uint64_t & nb03,
1969
+ constant int64_t & ne0,
1970
+ constant int64_t & ne1,
1971
+ constant int64_t & ne2,
1972
+ constant int64_t & ne3,
1973
+ constant uint64_t & nb0,
1974
+ constant uint64_t & nb1,
1975
+ constant uint64_t & nb2,
1976
+ constant uint64_t & nb3,
1977
+ uint3 tgpig[[threadgroup_position_in_grid]],
1978
+ uint3 tpitg[[thread_position_in_threadgroup]],
1979
+ uint3 ntg[[threads_per_threadgroup]]) {
1980
+ const int64_t i03 = tgpig[2];
1981
+ const int64_t i02 = tgpig[1];
1982
+ const int64_t i01 = tgpig[0];
1983
+
1984
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1985
+
1986
+ const int64_t i3 = n / (ne2*ne1*ne0);
1987
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1988
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1989
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1990
+
1991
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1992
+
1993
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1994
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1995
+ dst_data[i00] = src[0];
1996
+ }
1997
+ }
1998
+
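The kernel_cpy_* kernels above flatten the source position (i01,i02,i03) into a global element index n using the source shape, then unravel n into destination coordinates using the destination shape; this is what lets a copy kernel double as a reshape between tensors with the same element count. The arithmetic, as a small C++ helper with an illustrative name:

#include <stdint.h>

// Unravel a flat element index n into destination coordinates
// (i0..i3), exactly as the cpy kernels do with ne0/ne1/ne2.
static void ref_unravel_dst_index(int64_t n,
                                  int64_t ne0, int64_t ne1, int64_t ne2,
                                  int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
    *i3 = n / (ne2*ne1*ne0);
    *i2 = (n - *i3*ne2*ne1*ne0) / (ne1*ne0);
    *i1 = (n - *i3*ne2*ne1*ne0 - *i2*ne1*ne0) / ne0;
    *i0 =  n - *i3*ne2*ne1*ne0 - *i2*ne1*ne0 - *i1*ne0;
}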
1999
  kernel void kernel_cpy_f32_f16(
2000
  device const float * src0,
2001
  device half * dst,
 
2272
  }
2273
 
2274
  kernel void kernel_concat(
2275
+ device const char * src0,
2276
+ device const char * src1,
2277
+ device char * dst,
2278
  constant int64_t & ne00,
2279
  constant int64_t & ne01,
2280
  constant int64_t & ne02,
 
2311
  const int64_t i12 = i02 % ne12;
2312
  const int64_t i11 = i01 % ne11;
2313
 
2314
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
2315
  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
2316
  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
2317
 
 
2419
 
2420
  //====================================== dot products =========================
2421
 
2422
+ void kernel_mul_mv_q2_K_f32_impl(
2423
  device const void * src0,
2424
  device const float * src1,
2425
  device float * dst,
2426
  constant int64_t & ne00,
2427
+ constant int64_t & ne01,
2428
+ constant int64_t & ne02,
2429
+ constant int64_t & ne10,
2430
+ constant int64_t & ne12,
2431
+ constant int64_t & ne0,
2432
+ constant int64_t & ne1,
2433
+ constant uint & r2,
2434
+ constant uint & r3,
2435
  uint3 tgpig[[threadgroup_position_in_grid]],
2436
  uint tiisg[[thread_index_in_simdgroup]],
2437
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2569
  }
2570
  }
2571
 
2572
+ [[host_name("kernel_mul_mv_q2_K_f32")]]
2573
+ kernel void kernel_mul_mv_q2_K_f32(
2574
  device const void * src0,
2575
  device const float * src1,
2576
  device float * dst,
 
2584
  constant uint & r2 [[buffer(17)]],
2585
  constant uint & r3 [[buffer(18)]],
2586
  uint3 tgpig[[threadgroup_position_in_grid]],
2587
+ uint tiisg[[thread_index_in_simdgroup]],
2588
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2589
+
2590
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2591
+ }
2592
+
2593
+ #if QK_K == 256
2594
+ void kernel_mul_mv_q3_K_f32_impl(
2595
+ device const void * src0,
2596
+ device const float * src1,
2597
+ device float * dst,
2598
+ constant int64_t & ne00,
2599
+ constant int64_t & ne01,
2600
+ constant int64_t & ne02,
2601
+ constant int64_t & ne10,
2602
+ constant int64_t & ne12,
2603
+ constant int64_t & ne0,
2604
+ constant int64_t & ne1,
2605
+ constant uint & r2,
2606
+ constant uint & r3,
2607
+ uint3 tgpig[[threadgroup_position_in_grid]],
2608
+ uint tiisg[[thread_index_in_simdgroup]],
2609
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2610
 
2611
  const int nb = ne00/QK_K;
2612
 
 
2749
  }
2750
  }
2751
  #else
2752
+ void kernel_mul_mv_q3_K_f32_impl(
2753
  device const void * src0,
2754
  device const float * src1,
2755
  device float * dst,
2756
  constant int64_t & ne00,
2757
+ constant int64_t & ne01,
2758
+ constant int64_t & ne02,
2759
+ constant int64_t & ne10,
2760
+ constant int64_t & ne12,
2761
+ constant int64_t & ne0,
2762
+ constant int64_t & ne1,
2763
+ constant uint & r2,
2764
+ constant uint & r3,
2765
  uint3 tgpig[[threadgroup_position_in_grid]],
2766
  uint tiisg[[thread_index_in_simdgroup]],
2767
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2826
  }
2827
  #endif
2828
 
2829
+ [[host_name("kernel_mul_mv_q3_K_f32")]]
2830
+ kernel void kernel_mul_mv_q3_K_f32(
2831
+ device const void * src0,
2832
+ device const float * src1,
2833
+ device float * dst,
2834
+ constant int64_t & ne00,
2835
+ constant int64_t & ne01[[buffer(4)]],
2836
+ constant int64_t & ne02[[buffer(5)]],
2837
+ constant int64_t & ne10[[buffer(9)]],
2838
+ constant int64_t & ne12[[buffer(11)]],
2839
+ constant int64_t & ne0 [[buffer(15)]],
2840
+ constant int64_t & ne1 [[buffer(16)]],
2841
+ constant uint & r2 [[buffer(17)]],
2842
+ constant uint & r3 [[buffer(18)]],
2843
+ uint3 tgpig[[threadgroup_position_in_grid]],
2844
+ uint tiisg[[thread_index_in_simdgroup]],
2845
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2846
+
2847
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2848
+ }
2849
+
2850
  #if QK_K == 256
2851
+ void kernel_mul_mv_q4_K_f32_impl(
2852
  device const void * src0,
2853
  device const float * src1,
2854
  device float * dst,
2855
  constant int64_t & ne00,
2856
+ constant int64_t & ne01,
2857
+ constant int64_t & ne02,
2858
+ constant int64_t & ne10,
2859
+ constant int64_t & ne12,
2860
+ constant int64_t & ne0,
2861
+ constant int64_t & ne1,
2862
+ constant uint & r2,
2863
+ constant uint & r3,
2864
  uint3 tgpig[[threadgroup_position_in_grid]],
2865
  uint tiisg[[thread_index_in_simdgroup]],
2866
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2961
  }
2962
  }
2963
  #else
2964
+ void kernel_mul_mv_q4_K_f32_impl(
2965
  device const void * src0,
2966
  device const float * src1,
2967
  device float * dst,
2968
  constant int64_t & ne00,
2969
+ constant int64_t & ne01,
2970
+ constant int64_t & ne02,
2971
+ constant int64_t & ne10,
2972
+ constant int64_t & ne12,
2973
+ constant int64_t & ne0,
2974
+ constant int64_t & ne1,
2975
+ constant uint & r2,
2976
+ constant uint & r3,
2977
  uint3 tgpig[[threadgroup_position_in_grid]],
2978
  uint tiisg[[thread_index_in_simdgroup]],
2979
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3057
  }
3058
  #endif
3059
 
3060
+ [[host_name("kernel_mul_mv_q4_K_f32")]]
3061
+ kernel void kernel_mul_mv_q4_K_f32(
3062
  device const void * src0,
3063
  device const float * src1,
3064
  device float * dst,
 
3075
  uint tiisg[[thread_index_in_simdgroup]],
3076
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3077
 
3078
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3079
+ }
3080
+
3081
+ void kernel_mul_mv_q5_K_f32_impl(
3082
+ device const void * src0,
3083
+ device const float * src1,
3084
+ device float * dst,
3085
+ constant int64_t & ne00,
3086
+ constant int64_t & ne01,
3087
+ constant int64_t & ne02,
3088
+ constant int64_t & ne10,
3089
+ constant int64_t & ne12,
3090
+ constant int64_t & ne0,
3091
+ constant int64_t & ne1,
3092
+ constant uint & r2,
3093
+ constant uint & r3,
3094
+ uint3 tgpig[[threadgroup_position_in_grid]],
3095
+ uint tiisg[[thread_index_in_simdgroup]],
3096
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3097
+
3098
  const int nb = ne00/QK_K;
3099
 
3100
  const int64_t r0 = tgpig.x;
 
3254
  dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
3255
  }
3256
  }
 
3257
  }
3258
 
3259
+ [[host_name("kernel_mul_mv_q5_K_f32")]]
3260
+ kernel void kernel_mul_mv_q5_K_f32(
3261
  device const void * src0,
3262
  device const float * src1,
3263
  device float * dst,
 
3271
  constant uint & r2 [[buffer(17)]],
3272
  constant uint & r3 [[buffer(18)]],
3273
  uint3 tgpig[[threadgroup_position_in_grid]],
3274
+ uint tiisg[[thread_index_in_simdgroup]],
3275
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3276
 
3277
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3278
+ }
3279
 
3280
+ void kernel_mul_mv_q6_K_f32_impl(
3281
+ device const void * src0,
3282
+ device const float * src1,
3283
+ device float * dst,
3284
+ constant int64_t & ne00,
3285
+ constant int64_t & ne01,
3286
+ constant int64_t & ne02,
3287
+ constant int64_t & ne10,
3288
+ constant int64_t & ne12,
3289
+ constant int64_t & ne0,
3290
+ constant int64_t & ne1,
3291
+ constant uint & r2,
3292
+ constant uint & r3,
3293
+ uint3 tgpig[[threadgroup_position_in_grid]],
3294
+ uint tiisg[[thread_index_in_simdgroup]],
3295
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3296
+
3297
+ const uint8_t kmask1 = 0x03;
3298
+ const uint8_t kmask2 = 0x0C;
3299
+ const uint8_t kmask3 = 0x30;
3300
+ const uint8_t kmask4 = 0xC0;
3301
+
3302
+ const int nb = ne00/QK_K;
3303
+
3304
+ const int64_t r0 = tgpig.x;
3305
+ const int64_t r1 = tgpig.y;
3306
  const int im = tgpig.z;
3307
 
3308
  const int row = 2 * r0 + sgitg;
 
3383
  }
3384
  }
3385
 
3386
+ [[host_name("kernel_mul_mv_q6_K_f32")]]
3387
+ kernel void kernel_mul_mv_q6_K_f32(
3388
+ device const void * src0,
3389
+ device const float * src1,
3390
+ device float * dst,
3391
+ constant int64_t & ne00,
3392
+ constant int64_t & ne01[[buffer(4)]],
3393
+ constant int64_t & ne02[[buffer(5)]],
3394
+ constant int64_t & ne10[[buffer(9)]],
3395
+ constant int64_t & ne12[[buffer(11)]],
3396
+ constant int64_t & ne0 [[buffer(15)]],
3397
+ constant int64_t & ne1 [[buffer(16)]],
3398
+ constant uint & r2 [[buffer(17)]],
3399
+ constant uint & r3 [[buffer(18)]],
3400
+ uint3 tgpig[[threadgroup_position_in_grid]],
3401
+ uint tiisg[[thread_index_in_simdgroup]],
3402
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3403
+
3404
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3405
+ }
3406
+
3407
  //============================= templates and their specializations =============================
3408
 
3409
  // NOTE: this is not dequantizing - we are simply fitting the template
 
3521
 
3522
  template <typename type4x4>
3523
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
3524
+ const float d = xb->d;
3525
+ const float min = xb->dmin;
3526
  device const uint8_t * q = (device const uint8_t *)xb->qs;
3527
+ float dl, ml;
3528
  uint8_t sc = xb->scales[il];
3529
 
3530
  #if QK_K == 256
 
3594
  q = q + (il/4) * 32 + 16 * (il&1);
3595
  il = il & 3;
3596
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3597
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3598
+ const float min = xb->dmin;
3599
+ const float dl = d * sc[0];
3600
+ const float ml = min * sc[1];
3601
  #else
3602
  q = q + 16 * (il&1);
3603
  device const uint8_t * s = xb->scales;
 
3624
  uint8_t ul = 1 << (il/2);
3625
  il = il & 3;
3626
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3627
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
3628
+ const float min = xb->dmin;
3629
+ const float dl = d * sc[0];
3630
+ const float ml = min * sc[1];
3631
 
3632
+ const ushort mask = il<2 ? 0x0F : 0xF0;
3633
+ const float qh_val = il<2 ? 16.f : 256.f;
3634
  for (int i = 0; i < 16; ++i) {
3635
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
3636
  }
 
3678
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3679
  kernel void kernel_get_rows(
3680
  device const void * src0,
3681
+ device const char * src1,
3682
  device float * dst,
3683
  constant int64_t & ne00,
3684
  constant uint64_t & nb01,
3685
+ constant uint64_t & nb02,
3686
+ constant int64_t & ne10,
3687
+ constant uint64_t & nb10,
3688
+ constant uint64_t & nb11,
3689
  constant uint64_t & nb1,
3690
+ constant uint64_t & nb2,
3691
+ uint3 tgpig[[threadgroup_position_in_grid]],
3692
  uint tiitg[[thread_index_in_threadgroup]],
3693
+ uint3 tptg [[threads_per_threadgroup]]) {
3694
+ //const int64_t i = tgpig;
3695
+ //const int64_t r = ((device int32_t *) src1)[i];
3696
+
3697
+ const int64_t i10 = tgpig.x;
3698
+ const int64_t i11 = tgpig.y;
3699
 
3700
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3701
+
3702
+ const int64_t i02 = i11;
3703
+
3704
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
3705
  float4x4 temp;
3706
  dequantize_func(
3707
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
3708
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
3709
+ }
3710
+ }
3711
+
3712
+ kernel void kernel_get_rows_f32(
3713
+ device const void * src0,
3714
+ device const char * src1,
3715
+ device float * dst,
3716
+ constant int64_t & ne00,
3717
+ constant uint64_t & nb01,
3718
+ constant uint64_t & nb02,
3719
+ constant int64_t & ne10,
3720
+ constant uint64_t & nb10,
3721
+ constant uint64_t & nb11,
3722
+ constant uint64_t & nb1,
3723
+ constant uint64_t & nb2,
3724
+ uint3 tgpig[[threadgroup_position_in_grid]],
3725
+ uint tiitg[[thread_index_in_threadgroup]],
3726
+ uint3 tptg [[threads_per_threadgroup]]) {
3727
+ const int64_t i10 = tgpig.x;
3728
+ const int64_t i11 = tgpig.y;
3729
+
3730
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3731
+
3732
+ const int64_t i02 = i11;
3733
+
3734
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3735
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3736
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3737
+ }
3738
+ }
3739
+
3740
+ kernel void kernel_get_rows_f16(
3741
+ device const void * src0,
3742
+ device const char * src1,
3743
+ device float * dst,
3744
+ constant int64_t & ne00,
3745
+ constant uint64_t & nb01,
3746
+ constant uint64_t & nb02,
3747
+ constant int64_t & ne10,
3748
+ constant uint64_t & nb10,
3749
+ constant uint64_t & nb11,
3750
+ constant uint64_t & nb1,
3751
+ constant uint64_t & nb2,
3752
+ uint3 tgpig[[threadgroup_position_in_grid]],
3753
+ uint tiitg[[thread_index_in_threadgroup]],
3754
+ uint3 tptg [[threads_per_threadgroup]]) {
3755
+ const int64_t i10 = tgpig.x;
3756
+ const int64_t i11 = tgpig.y;
3757
+
3758
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3759
+
3760
+ const int64_t i02 = i11;
3761
+
3762
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3763
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3764
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3765
  }
3766
  }
3767
 
 
3953
 
3954
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3955
  kernel void kernel_mul_mm_id(
3956
+ device const uchar * ids,
3957
  device const uchar * src1,
3958
+ device uchar * dst,
3959
+ constant int64_t & nbi1,
3960
  constant int64_t & ne00,
3961
  constant int64_t & ne02,
3962
  constant int64_t & nb01,
3963
  constant int64_t & nb02,
3964
  constant int64_t & ne12,
3965
+ constant int64_t & ne13,
3966
  constant int64_t & nb10,
3967
  constant int64_t & nb11,
3968
  constant int64_t & nb12,
3969
  constant int64_t & ne0,
3970
  constant int64_t & ne1,
3971
+ constant int64_t & nb1,
3972
  constant uint & r2,
3973
  constant uint & r3,
3974
  constant int & idx,
 
3986
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3987
  device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3988
 
3989
+ const int64_t bid = tgpig.z/(ne12*ne13);
3990
+
3991
+ tgpig.z = tgpig.z%(ne12*ne13);
3992
+
3993
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
3994
+
3995
  kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3996
+ src0[id],
3997
+ src1 + bid*nb11,
3998
+ (device float *) (dst + bid*nb1),
3999
  ne00,
4000
  ne02,
4001
  nb01,
 
4020
  #define QK_NL 4
4021
  #endif
4022
 
4023
+ //
4024
+ // get rows
4025
+ //
4026
+
4027
  typedef void (get_rows_t)(
4028
  device const void * src0,
4029
+ device const char * src1,
4030
  device float * dst,
4031
  constant int64_t & ne00,
4032
  constant uint64_t & nb01,
4033
+ constant uint64_t & nb02,
4034
+ constant int64_t & ne10,
4035
+ constant uint64_t & nb10,
4036
+ constant uint64_t & nb11,
4037
  constant uint64_t & nb1,
4038
+ constant uint64_t & nb2,
4039
+ uint3, uint, uint3);
4040
 
4041
+ //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
4042
+ //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
4043
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
4044
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
4045
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
 
4051
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4052
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4053
 
4054
+ //
4055
+ // matrix-matrix multiplication
4056
+ //
4057
+
4058
  typedef void (mat_mm_t)(
4059
  device const uchar * src0,
4060
  device const uchar * src1,
 
4087
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4088
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4089
 
4090
+ //
4091
+ // indirect matrix-matrix multiplication
4092
+ //
4093
+
4094
  typedef void (mat_mm_id_t)(
4095
+ device const uchar * ids,
4096
  device const uchar * src1,
4097
+ device uchar * dst,
4098
+ constant int64_t & nbi1,
4099
  constant int64_t & ne00,
4100
  constant int64_t & ne02,
4101
  constant int64_t & nb01,
4102
  constant int64_t & nb02,
4103
  constant int64_t & ne12,
4104
+ constant int64_t & ne13,
4105
  constant int64_t & nb10,
4106
  constant int64_t & nb11,
4107
  constant int64_t & nb12,
4108
  constant int64_t & ne0,
4109
  constant int64_t & ne1,
4110
+ constant int64_t & nb1,
4111
  constant uint & r2,
4112
  constant uint & r3,
4113
  constant int & idx,
 
4134
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4135
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4136
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4137
+
4138
+ //
4139
+ // matrix-vector multiplication
4140
+ //
4141
+
4142
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
4143
+ kernel void kernel_mul_mv_id_f32_f32(
4144
+ device const char * ids,
4145
+ device const char * src1,
4146
+ device uchar * dst,
4147
+ constant int64_t & nbi1,
4148
+ constant int64_t & ne00,
4149
+ constant int64_t & ne01,
4150
+ constant int64_t & ne02,
4151
+ constant uint64_t & nb00,
4152
+ constant uint64_t & nb01,
4153
+ constant uint64_t & nb02,
4154
+ constant int64_t & ne10,
4155
+ constant int64_t & ne11,
4156
+ constant int64_t & ne12,
4157
+ constant int64_t & ne13,
4158
+ constant uint64_t & nb10,
4159
+ constant uint64_t & nb11,
4160
+ constant uint64_t & nb12,
4161
+ constant int64_t & ne0,
4162
+ constant int64_t & ne1,
4163
+ constant int64_t & nb1,
4164
+ constant uint & r2,
4165
+ constant uint & r3,
4166
+ constant int & idx,
4167
+ device const char * src00,
4168
+ device const char * src01,
4169
+ device const char * src02,
4170
+ device const char * src03,
4171
+ device const char * src04,
4172
+ device const char * src05,
4173
+ device const char * src06,
4174
+ device const char * src07,
4175
+ uint3 tgpig[[threadgroup_position_in_grid]],
4176
+ uint tiitg[[thread_index_in_threadgroup]],
4177
+ uint tiisg[[thread_index_in_simdgroup]],
4178
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4179
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4180
+
4181
+ const int64_t bid = tgpig.z/(ne12*ne13);
4182
+
4183
+ tgpig.z = tgpig.z%(ne12*ne13);
4184
+
4185
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4186
+
4187
+ kernel_mul_mv_f32_f32_impl(
4188
+ src0[id],
4189
+ src1 + bid*nb11,
4190
+ (device float *) (dst + bid*nb1),
4191
+ ne00,
4192
+ ne01,
4193
+ ne02,
4194
+ nb00,
4195
+ nb01,
4196
+ nb02,
4197
+ ne10,
4198
+ ne11,
4199
+ ne12,
4200
+ nb10,
4201
+ nb11,
4202
+ nb12,
4203
+ ne0,
4204
+ ne1,
4205
+ r2,
4206
+ r3,
4207
+ tgpig,
4208
+ tiisg);
4209
+ }
4210
+
4211
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
4212
+ kernel void kernel_mul_mv_id_f16_f32(
4213
+ device const char * ids,
4214
+ device const char * src1,
4215
+ device uchar * dst,
4216
+ constant int64_t & nbi1,
4217
+ constant int64_t & ne00,
4218
+ constant int64_t & ne01,
4219
+ constant int64_t & ne02,
4220
+ constant uint64_t & nb00,
4221
+ constant uint64_t & nb01,
4222
+ constant uint64_t & nb02,
4223
+ constant int64_t & ne10,
4224
+ constant int64_t & ne11,
4225
+ constant int64_t & ne12,
4226
+ constant int64_t & ne13,
4227
+ constant uint64_t & nb10,
4228
+ constant uint64_t & nb11,
4229
+ constant uint64_t & nb12,
4230
+ constant int64_t & ne0,
4231
+ constant int64_t & ne1,
4232
+ constant int64_t & nb1,
4233
+ constant uint & r2,
4234
+ constant uint & r3,
4235
+ constant int & idx,
4236
+ device const char * src00,
4237
+ device const char * src01,
4238
+ device const char * src02,
4239
+ device const char * src03,
4240
+ device const char * src04,
4241
+ device const char * src05,
4242
+ device const char * src06,
4243
+ device const char * src07,
4244
+ uint3 tgpig[[threadgroup_position_in_grid]],
4245
+ uint tiitg[[thread_index_in_threadgroup]],
4246
+ uint tiisg[[thread_index_in_simdgroup]],
4247
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4248
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4249
+
4250
+ const int64_t bid = tgpig.z/(ne12*ne13);
4251
+
4252
+ tgpig.z = tgpig.z%(ne12*ne13);
4253
+
4254
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4255
+
4256
+ kernel_mul_mv_f16_f32_impl(
4257
+ src0[id],
4258
+ src1 + bid*nb11,
4259
+ (device float *) (dst + bid*nb1),
4260
+ ne00,
4261
+ ne01,
4262
+ ne02,
4263
+ nb00,
4264
+ nb01,
4265
+ nb02,
4266
+ ne10,
4267
+ ne11,
4268
+ ne12,
4269
+ nb10,
4270
+ nb11,
4271
+ nb12,
4272
+ ne0,
4273
+ ne1,
4274
+ r2,
4275
+ r3,
4276
+ tgpig,
4277
+ tiisg);
4278
+ }
4279
+
4280
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
4281
+ kernel void kernel_mul_mv_id_q8_0_f32(
4282
+ device const char * ids,
4283
+ device const char * src1,
4284
+ device uchar * dst,
4285
+ constant int64_t & nbi1,
4286
+ constant int64_t & ne00,
4287
+ constant int64_t & ne01,
4288
+ constant int64_t & ne02,
4289
+ constant uint64_t & nb00,
4290
+ constant uint64_t & nb01,
4291
+ constant uint64_t & nb02,
4292
+ constant int64_t & ne10,
4293
+ constant int64_t & ne11,
4294
+ constant int64_t & ne12,
4295
+ constant int64_t & ne13,
4296
+ constant uint64_t & nb10,
4297
+ constant uint64_t & nb11,
4298
+ constant uint64_t & nb12,
4299
+ constant int64_t & ne0,
4300
+ constant int64_t & ne1,
4301
+ constant int64_t & nb1,
4302
+ constant uint & r2,
4303
+ constant uint & r3,
4304
+ constant int & idx,
4305
+ device const char * src00,
4306
+ device const char * src01,
4307
+ device const char * src02,
4308
+ device const char * src03,
4309
+ device const char * src04,
4310
+ device const char * src05,
4311
+ device const char * src06,
4312
+ device const char * src07,
4313
+ uint3 tgpig[[threadgroup_position_in_grid]],
4314
+ uint tiitg[[thread_index_in_threadgroup]],
4315
+ uint tiisg[[thread_index_in_simdgroup]],
4316
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4317
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4318
+
4319
+ const int64_t bid = tgpig.z/(ne12*ne13);
4320
+
4321
+ tgpig.z = tgpig.z%(ne12*ne13);
4322
+
4323
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4324
+
4325
+ kernel_mul_mv_q8_0_f32_impl(
4326
+ src0[id],
4327
+ (device const float *) (src1 + bid*nb11),
4328
+ (device float *) ( dst + bid*nb1),
4329
+ ne00,
4330
+ ne01,
4331
+ ne02,
4332
+ ne10,
4333
+ ne12,
4334
+ ne0,
4335
+ ne1,
4336
+ r2,
4337
+ r3,
4338
+ tgpig,
4339
+ tiisg,
4340
+ sgitg);
4341
+ }
4342
+
4343
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
4344
+ kernel void kernel_mul_mv_id_q4_0_f32(
4345
+ device const char * ids,
4346
+ device const char * src1,
4347
+ device uchar * dst,
4348
+ constant int64_t & nbi1,
4349
+ constant int64_t & ne00,
4350
+ constant int64_t & ne01,
4351
+ constant int64_t & ne02,
4352
+ constant uint64_t & nb00,
4353
+ constant uint64_t & nb01,
4354
+ constant uint64_t & nb02,
4355
+ constant int64_t & ne10,
4356
+ constant int64_t & ne11,
4357
+ constant int64_t & ne12,
4358
+ constant int64_t & ne13,
4359
+ constant uint64_t & nb10,
4360
+ constant uint64_t & nb11,
4361
+ constant uint64_t & nb12,
4362
+ constant int64_t & ne0,
4363
+ constant int64_t & ne1,
4364
+ constant int64_t & nb1,
4365
+ constant uint & r2,
4366
+ constant uint & r3,
4367
+ constant int & idx,
4368
+ device const char * src00,
4369
+ device const char * src01,
4370
+ device const char * src02,
4371
+ device const char * src03,
4372
+ device const char * src04,
4373
+ device const char * src05,
4374
+ device const char * src06,
4375
+ device const char * src07,
4376
+ uint3 tgpig[[threadgroup_position_in_grid]],
4377
+ uint tiitg[[thread_index_in_threadgroup]],
4378
+ uint tiisg[[thread_index_in_simdgroup]],
4379
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4380
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4381
+
4382
+ const int64_t bid = tgpig.z/(ne12*ne13);
4383
+
4384
+ tgpig.z = tgpig.z%(ne12*ne13);
4385
+
4386
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4387
+
4388
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4389
+ src0[id],
4390
+ (device const float *) (src1 + bid*nb11),
4391
+ (device float *) ( dst + bid*nb1),
4392
+ ne00,
4393
+ ne01,
4394
+ ne02,
4395
+ ne10,
4396
+ ne12,
4397
+ ne0,
4398
+ ne1,
4399
+ r2,
4400
+ r3,
4401
+ tgpig,
4402
+ tiisg,
4403
+ sgitg);
4404
+ }
4405
+
4406
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
4407
+ kernel void kernel_mul_mv_id_q4_1_f32(
4408
+ device const char * ids,
4409
+ device const char * src1,
4410
+ device uchar * dst,
4411
+ constant int64_t & nbi1,
4412
+ constant int64_t & ne00,
4413
+ constant int64_t & ne01,
4414
+ constant int64_t & ne02,
4415
+ constant uint64_t & nb00,
4416
+ constant uint64_t & nb01,
4417
+ constant uint64_t & nb02,
4418
+ constant int64_t & ne10,
4419
+ constant int64_t & ne11,
4420
+ constant int64_t & ne12,
4421
+ constant int64_t & ne13,
4422
+ constant uint64_t & nb10,
4423
+ constant uint64_t & nb11,
4424
+ constant uint64_t & nb12,
4425
+ constant int64_t & ne0,
4426
+ constant int64_t & ne1,
4427
+ constant int64_t & nb1,
4428
+ constant uint & r2,
4429
+ constant uint & r3,
4430
+ constant int & idx,
4431
+ device const char * src00,
4432
+ device const char * src01,
4433
+ device const char * src02,
4434
+ device const char * src03,
4435
+ device const char * src04,
4436
+ device const char * src05,
4437
+ device const char * src06,
4438
+ device const char * src07,
4439
+ uint3 tgpig[[threadgroup_position_in_grid]],
4440
+ uint tiitg[[thread_index_in_threadgroup]],
4441
+ uint tiisg[[thread_index_in_simdgroup]],
4442
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4443
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4444
+
4445
+ const int64_t bid = tgpig.z/(ne12*ne13);
4446
+
4447
+ tgpig.z = tgpig.z%(ne12*ne13);
4448
+
4449
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4450
+
4451
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4452
+ src0[id],
4453
+ (device const float *) (src1 + bid*nb11),
4454
+ (device float *) ( dst + bid*nb1),
4455
+ ne00,
4456
+ ne01,
4457
+ ne02,
4458
+ ne10,
4459
+ ne12,
4460
+ ne0,
4461
+ ne1,
4462
+ r2,
4463
+ r3,
4464
+ tgpig,
4465
+ tiisg,
4466
+ sgitg);
4467
+ }
4468
+
4469
+ [[host_name("kernel_mul_mv_id_q5_0_f32")]]
4470
+ kernel void kernel_mul_mv_id_q5_0_f32(
4471
+ device const char * ids,
4472
+ device const char * src1,
4473
+ device uchar * dst,
4474
+ constant int64_t & nbi1,
4475
+ constant int64_t & ne00,
4476
+ constant int64_t & ne01,
4477
+ constant int64_t & ne02,
4478
+ constant uint64_t & nb00,
4479
+ constant uint64_t & nb01,
4480
+ constant uint64_t & nb02,
4481
+ constant int64_t & ne10,
4482
+ constant int64_t & ne11,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne13,
4485
+ constant uint64_t & nb10,
4486
+ constant uint64_t & nb11,
4487
+ constant uint64_t & nb12,
4488
+ constant int64_t & ne0,
4489
+ constant int64_t & ne1,
4490
+ constant int64_t & nb1,
4491
+ constant uint & r2,
4492
+ constant uint & r3,
4493
+ constant int & idx,
4494
+ device const char * src00,
4495
+ device const char * src01,
4496
+ device const char * src02,
4497
+ device const char * src03,
4498
+ device const char * src04,
4499
+ device const char * src05,
4500
+ device const char * src06,
4501
+ device const char * src07,
4502
+ uint3 tgpig[[threadgroup_position_in_grid]],
4503
+ uint tiitg[[thread_index_in_threadgroup]],
4504
+ uint tiisg[[thread_index_in_simdgroup]],
4505
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4506
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4507
+
4508
+ const int64_t bid = tgpig.z/(ne12*ne13);
4509
+
4510
+ tgpig.z = tgpig.z%(ne12*ne13);
4511
+
4512
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4513
+
4514
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4515
+ src0[id],
4516
+ (device const float *) (src1 + bid*nb11),
4517
+ (device float *) ( dst + bid*nb1),
4518
+ ne00,
4519
+ ne01,
4520
+ ne02,
4521
+ ne10,
4522
+ ne12,
4523
+ ne0,
4524
+ ne1,
4525
+ r2,
4526
+ r3,
4527
+ tgpig,
4528
+ tiisg,
4529
+ sgitg);
4530
+ }
4531
+
4532
+ [[host_name("kernel_mul_mv_id_q5_1_f32")]]
4533
+ kernel void kernel_mul_mv_id_q5_1_f32(
4534
+ device const char * ids,
4535
+ device const char * src1,
4536
+ device uchar * dst,
4537
+ constant int64_t & nbi1,
4538
+ constant int64_t & ne00,
4539
+ constant int64_t & ne01,
4540
+ constant int64_t & ne02,
4541
+ constant uint64_t & nb00,
4542
+ constant uint64_t & nb01,
4543
+ constant uint64_t & nb02,
4544
+ constant int64_t & ne10,
4545
+ constant int64_t & ne11,
4546
+ constant int64_t & ne12,
4547
+ constant int64_t & ne13,
4548
+ constant uint64_t & nb10,
4549
+ constant uint64_t & nb11,
4550
+ constant uint64_t & nb12,
4551
+ constant int64_t & ne0,
4552
+ constant int64_t & ne1,
4553
+ constant int64_t & nb1,
4554
+ constant uint & r2,
4555
+ constant uint & r3,
4556
+ constant int & idx,
4557
+ device const char * src00,
4558
+ device const char * src01,
4559
+ device const char * src02,
4560
+ device const char * src03,
4561
+ device const char * src04,
4562
+ device const char * src05,
4563
+ device const char * src06,
4564
+ device const char * src07,
4565
+ uint3 tgpig[[threadgroup_position_in_grid]],
4566
+ uint tiitg[[thread_index_in_threadgroup]],
4567
+ uint tiisg[[thread_index_in_simdgroup]],
4568
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4569
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4570
+
4571
+ const int64_t bid = tgpig.z/(ne12*ne13);
4572
+
4573
+ tgpig.z = tgpig.z%(ne12*ne13);
4574
+
4575
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4576
+
4577
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
4578
+ src0[id],
4579
+ (device const float *) (src1 + bid*nb11),
4580
+ (device float *) ( dst + bid*nb1),
4581
+ ne00,
4582
+ ne01,
4583
+ ne02,
4584
+ ne10,
4585
+ ne12,
4586
+ ne0,
4587
+ ne1,
4588
+ r2,
4589
+ r3,
4590
+ tgpig,
4591
+ tiisg,
4592
+ sgitg);
4593
+ }
4594
+
4595
+ [[host_name("kernel_mul_mv_id_q2_K_f32")]]
4596
+ kernel void kernel_mul_mv_id_q2_K_f32(
4597
+ device const char * ids,
4598
+ device const char * src1,
4599
+ device uchar * dst,
4600
+ constant int64_t & nbi1,
4601
+ constant int64_t & ne00,
4602
+ constant int64_t & ne01,
4603
+ constant int64_t & ne02,
4604
+ constant uint64_t & nb00,
4605
+ constant uint64_t & nb01,
4606
+ constant uint64_t & nb02,
4607
+ constant int64_t & ne10,
4608
+ constant int64_t & ne11,
4609
+ constant int64_t & ne12,
4610
+ constant int64_t & ne13,
4611
+ constant uint64_t & nb10,
4612
+ constant uint64_t & nb11,
4613
+ constant uint64_t & nb12,
4614
+ constant int64_t & ne0,
4615
+ constant int64_t & ne1,
4616
+ constant int64_t & nb1,
4617
+ constant uint & r2,
4618
+ constant uint & r3,
4619
+ constant int & idx,
4620
+ device const char * src00,
4621
+ device const char * src01,
4622
+ device const char * src02,
4623
+ device const char * src03,
4624
+ device const char * src04,
4625
+ device const char * src05,
4626
+ device const char * src06,
4627
+ device const char * src07,
4628
+ uint3 tgpig[[threadgroup_position_in_grid]],
4629
+ uint tiitg[[thread_index_in_threadgroup]],
4630
+ uint tiisg[[thread_index_in_simdgroup]],
4631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4632
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4633
+
4634
+ const int64_t bid = tgpig.z/(ne12*ne13);
4635
+
4636
+ tgpig.z = tgpig.z%(ne12*ne13);
4637
+
4638
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4639
+
4640
+ kernel_mul_mv_q2_K_f32_impl(
4641
+ src0[id],
4642
+ (device const float *) (src1 + bid*nb11),
4643
+ (device float *) ( dst + bid*nb1),
4644
+ ne00,
4645
+ ne01,
4646
+ ne02,
4647
+ ne10,
4648
+ ne12,
4649
+ ne0,
4650
+ ne1,
4651
+ r2,
4652
+ r3,
4653
+ tgpig,
4654
+ tiisg,
4655
+ sgitg);
4656
+ }
4657
+
4658
+ [[host_name("kernel_mul_mv_id_q3_K_f32")]]
4659
+ kernel void kernel_mul_mv_id_q3_K_f32(
4660
+ device const char * ids,
4661
+ device const char * src1,
4662
+ device uchar * dst,
4663
+ constant int64_t & nbi1,
4664
+ constant int64_t & ne00,
4665
+ constant int64_t & ne01,
4666
+ constant int64_t & ne02,
4667
+ constant uint64_t & nb00,
4668
+ constant uint64_t & nb01,
4669
+ constant uint64_t & nb02,
4670
+ constant int64_t & ne10,
4671
+ constant int64_t & ne11,
4672
+ constant int64_t & ne12,
4673
+ constant int64_t & ne13,
4674
+ constant uint64_t & nb10,
4675
+ constant uint64_t & nb11,
4676
+ constant uint64_t & nb12,
4677
+ constant int64_t & ne0,
4678
+ constant int64_t & ne1,
4679
+ constant int64_t & nb1,
4680
+ constant uint & r2,
4681
+ constant uint & r3,
4682
+ constant int & idx,
4683
+ device const char * src00,
4684
+ device const char * src01,
4685
+ device const char * src02,
4686
+ device const char * src03,
4687
+ device const char * src04,
4688
+ device const char * src05,
4689
+ device const char * src06,
4690
+ device const char * src07,
4691
+ uint3 tgpig[[threadgroup_position_in_grid]],
4692
+ uint tiitg[[thread_index_in_threadgroup]],
4693
+ uint tiisg[[thread_index_in_simdgroup]],
4694
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4695
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4696
+
4697
+ const int64_t bid = tgpig.z/(ne12*ne13);
4698
+
4699
+ tgpig.z = tgpig.z%(ne12*ne13);
4700
+
4701
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4702
+
4703
+ kernel_mul_mv_q3_K_f32_impl(
4704
+ src0[id],
4705
+ (device const float *) (src1 + bid*nb11),
4706
+ (device float *) ( dst + bid*nb1),
4707
+ ne00,
4708
+ ne01,
4709
+ ne02,
4710
+ ne10,
4711
+ ne12,
4712
+ ne0,
4713
+ ne1,
4714
+ r2,
4715
+ r3,
4716
+ tgpig,
4717
+ tiisg,
4718
+ sgitg);
4719
+ }
4720
+
4721
+ [[host_name("kernel_mul_mv_id_q4_K_f32")]]
4722
+ kernel void kernel_mul_mv_id_q4_K_f32(
4723
+ device const char * ids,
4724
+ device const char * src1,
4725
+ device uchar * dst,
4726
+ constant int64_t & nbi1,
4727
+ constant int64_t & ne00,
4728
+ constant int64_t & ne01,
4729
+ constant int64_t & ne02,
4730
+ constant uint64_t & nb00,
4731
+ constant uint64_t & nb01,
4732
+ constant uint64_t & nb02,
4733
+ constant int64_t & ne10,
4734
+ constant int64_t & ne11,
4735
+ constant int64_t & ne12,
4736
+ constant int64_t & ne13,
4737
+ constant uint64_t & nb10,
4738
+ constant uint64_t & nb11,
4739
+ constant uint64_t & nb12,
4740
+ constant int64_t & ne0,
4741
+ constant int64_t & ne1,
4742
+ constant int64_t & nb1,
4743
+ constant uint & r2,
4744
+ constant uint & r3,
4745
+ constant int & idx,
4746
+ device const char * src00,
4747
+ device const char * src01,
4748
+ device const char * src02,
4749
+ device const char * src03,
4750
+ device const char * src04,
4751
+ device const char * src05,
4752
+ device const char * src06,
4753
+ device const char * src07,
4754
+ uint3 tgpig[[threadgroup_position_in_grid]],
4755
+ uint tiitg[[thread_index_in_threadgroup]],
4756
+ uint tiisg[[thread_index_in_simdgroup]],
4757
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4758
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4759
+
4760
+ const int64_t bid = tgpig.z/(ne12*ne13);
4761
+
4762
+ tgpig.z = tgpig.z%(ne12*ne13);
4763
+
4764
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4765
+
4766
+ kernel_mul_mv_q4_K_f32_impl(
4767
+ src0[id],
4768
+ (device const float *) (src1 + bid*nb11),
4769
+ (device float *) ( dst + bid*nb1),
4770
+ ne00,
4771
+ ne01,
4772
+ ne02,
4773
+ ne10,
4774
+ ne12,
4775
+ ne0,
4776
+ ne1,
4777
+ r2,
4778
+ r3,
4779
+ tgpig,
4780
+ tiisg,
4781
+ sgitg);
4782
+ }
4783
+
4784
+ [[host_name("kernel_mul_mv_id_q5_K_f32")]]
4785
+ kernel void kernel_mul_mv_id_q5_K_f32(
4786
+ device const char * ids,
4787
+ device const char * src1,
4788
+ device uchar * dst,
4789
+ constant int64_t & nbi1,
4790
+ constant int64_t & ne00,
4791
+ constant int64_t & ne01,
4792
+ constant int64_t & ne02,
4793
+ constant uint64_t & nb00,
4794
+ constant uint64_t & nb01,
4795
+ constant uint64_t & nb02,
4796
+ constant int64_t & ne10,
4797
+ constant int64_t & ne11,
4798
+ constant int64_t & ne12,
4799
+ constant int64_t & ne13,
4800
+ constant uint64_t & nb10,
4801
+ constant uint64_t & nb11,
4802
+ constant uint64_t & nb12,
4803
+ constant int64_t & ne0,
4804
+ constant int64_t & ne1,
4805
+ constant int64_t & nb1,
4806
+ constant uint & r2,
4807
+ constant uint & r3,
4808
+ constant int & idx,
4809
+ device const char * src00,
4810
+ device const char * src01,
4811
+ device const char * src02,
4812
+ device const char * src03,
4813
+ device const char * src04,
4814
+ device const char * src05,
4815
+ device const char * src06,
4816
+ device const char * src07,
4817
+ uint3 tgpig[[threadgroup_position_in_grid]],
4818
+ uint tiitg[[thread_index_in_threadgroup]],
4819
+ uint tiisg[[thread_index_in_simdgroup]],
4820
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4821
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4822
+
4823
+ const int64_t bid = tgpig.z/(ne12*ne13);
4824
+
4825
+ tgpig.z = tgpig.z%(ne12*ne13);
4826
+
4827
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4828
+
4829
+ kernel_mul_mv_q5_K_f32_impl(
4830
+ src0[id],
4831
+ (device const float *) (src1 + bid*nb11),
4832
+ (device float *) ( dst + bid*nb1),
4833
+ ne00,
4834
+ ne01,
4835
+ ne02,
4836
+ ne10,
4837
+ ne12,
4838
+ ne0,
4839
+ ne1,
4840
+ r2,
4841
+ r3,
4842
+ tgpig,
4843
+ tiisg,
4844
+ sgitg);
4845
+ }
4846
+
4847
+ [[host_name("kernel_mul_mv_id_q6_K_f32")]]
4848
+ kernel void kernel_mul_mv_id_q6_K_f32(
4849
+ device const char * ids,
4850
+ device const char * src1,
4851
+ device uchar * dst,
4852
+ constant int64_t & nbi1,
4853
+ constant int64_t & ne00,
4854
+ constant int64_t & ne01,
4855
+ constant int64_t & ne02,
4856
+ constant uint64_t & nb00,
4857
+ constant uint64_t & nb01,
4858
+ constant uint64_t & nb02,
4859
+ constant int64_t & ne10,
4860
+ constant int64_t & ne11,
4861
+ constant int64_t & ne12,
4862
+ constant int64_t & ne13,
4863
+ constant uint64_t & nb10,
4864
+ constant uint64_t & nb11,
4865
+ constant uint64_t & nb12,
4866
+ constant int64_t & ne0,
4867
+ constant int64_t & ne1,
4868
+ constant int64_t & nb1,
4869
+ constant uint & r2,
4870
+ constant uint & r3,
4871
+ constant int & idx,
4872
+ device const char * src00,
4873
+ device const char * src01,
4874
+ device const char * src02,
4875
+ device const char * src03,
4876
+ device const char * src04,
4877
+ device const char * src05,
4878
+ device const char * src06,
4879
+ device const char * src07,
4880
+ uint3 tgpig[[threadgroup_position_in_grid]],
4881
+ uint tiitg[[thread_index_in_threadgroup]],
4882
+ uint tiisg[[thread_index_in_simdgroup]],
4883
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4884
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4885
+
4886
+ const int64_t bid = tgpig.z/(ne12*ne13);
4887
+
4888
+ tgpig.z = tgpig.z%(ne12*ne13);
4889
+
4890
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4891
+
4892
+ kernel_mul_mv_q6_K_f32_impl(
4893
+ src0[id],
4894
+ (device const float *) (src1 + bid*nb11),
4895
+ (device float *) ( dst + bid*nb1),
4896
+ ne00,
4897
+ ne01,
4898
+ ne02,
4899
+ ne10,
4900
+ ne12,
4901
+ ne0,
4902
+ ne1,
4903
+ r2,
4904
+ r3,
4905
+ tgpig,
4906
+ tiisg,
4907
+ sgitg);
4908
+ }
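Every kernel_mul_mv_id_* wrapper above repeats the same preamble before calling its type-specific impl: split tgpig.z into a batch row id, read the expert index for that row from ids, then offset src1 and dst by the batch row. A C-style restatement of just that selection step, with illustrative names (ids_stride here is nbi1 expressed in int32 units, not bytes):

    #include <stdint.h>

    typedef struct {
        int64_t batch;   // bid: which src1 column / dst row this threadgroup serves
        int32_t expert;  // id: which src0 matrix (src00..src07) to dispatch to
    } mv_id_dispatch;

    // mirrors: bid = tgpig.z/(ne12*ne13); id = ((int32_t *)(ids + bid*nbi1))[idx];
    static mv_id_dispatch mv_id_select(const int32_t * ids, int64_t ids_stride,
                                       int64_t tgpig_z, int64_t ne12, int64_t ne13, int idx) {
        mv_id_dispatch d;
        d.batch  = tgpig_z / (ne12*ne13);
        d.expert = ids[d.batch*ids_stride + idx];
        return d;
    }

The remainder, tgpig.z % (ne12*ne13), keeps addressing the broadcast dimensions exactly as in the non-id kernels.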
ggml-quants.c CHANGED
@@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
3114
 
3115
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3116
 
3117
- // These tempory registers are for masking and shift operations
3118
  vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3119
  vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3120
 
@@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
4757
 
4758
  vl = 16;
4759
 
4760
- // retreive lane to multiply with scale
4761
  vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
4762
  vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
4763
  vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
 
3114
 
3115
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
3116
 
3117
+ // These temporary registers are for masking and shift operations
3118
  vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3119
  vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3120
 
 
4757
 
4758
  vl = 16;
4759
 
4760
+ // retrieve lane to multiply with scale
4761
  vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
4762
  vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
4763
  vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
ggml.c CHANGED
@@ -1,4 +1,4 @@
1
- #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
2
  #define _USE_MATH_DEFINES // For M_PI on MSVC
3
 
4
  #include "ggml-impl.h"
@@ -33,7 +33,7 @@
33
  // we should just be careful :)
34
  #pragma warning(disable: 4244 4267)
35
 
36
- // disable POSIX deprecation warnigns
37
  // these functions are never going away, anyway
38
  #pragma warning(disable: 4996)
39
  #endif
@@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
1395
  inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
1396
  inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
1397
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
1398
- inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
1399
 
1400
  static const float GELU_COEF_A = 0.044715f;
1401
  static const float GELU_QUICK_COEF = -1.702f;
@@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
1623
  "POOL_1D",
1624
  "POOL_2D",
1625
  "UPSCALE",
 
1626
  "ARGSORT",
 
1627
 
1628
  "FLASH_ATTN",
1629
  "FLASH_FF",
@@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
1650
  "CROSS_ENTROPY_LOSS_BACK",
1651
  };
1652
 
1653
- static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
1654
 
1655
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1656
  "none",
@@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1707
  "pool_1d(x)",
1708
  "pool_2d(x)",
1709
  "upscale(x)",
 
1710
  "argsort(x)",
 
1711
 
1712
  "flash_attn(x)",
1713
  "flash_ff(x)",
@@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1734
  "cross_entropy_loss_back(x,y)",
1735
  };
1736
 
1737
- static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
1738
 
1739
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1740
 
@@ -1750,17 +1754,16 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1750
  "GELU",
1751
  "GELU_QUICK",
1752
  "SILU",
1753
- "LEAKY",
1754
  };
1755
 
1756
- static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
1757
 
1758
 
1759
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1760
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1761
 
1762
  // WARN:
1763
- // Mis-confguration can lead to problem that's hard to reason about:
1764
  // * At best it crashes or talks nonsense.
1765
  // * At worst it talks slightly differently, which is hard to perceive.
1766
  //
@@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace(
3830
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
3831
  }
3832
 
3833
- // ggml_leaky
3834
 
3835
- struct ggml_tensor * ggml_leaky(
3836
  struct ggml_context * ctx,
3837
- struct ggml_tensor * a) {
3838
- return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
 
 
 
 
 
 
 
 
 
 
 
 
 
3839
  }
3840
 
3841
  // ggml_gelu
@@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl(
4022
 
4023
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4024
 
4025
- result->op = GGML_OP_GROUP_NORM;
4026
  result->op_params[0] = n_groups;
 
 
4027
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4028
  result->src[0] = a;
4029
  result->src[1] = NULL; // TODO: maybe store epsilon here?
@@ -4075,17 +4092,18 @@ struct ggml_tensor * ggml_mul_mat(
4075
 
4076
  struct ggml_tensor * ggml_mul_mat_id(
4077
  struct ggml_context * ctx,
4078
- struct ggml_tensor * as[],
 
4079
  struct ggml_tensor * ids,
4080
  int id,
4081
  struct ggml_tensor * b) {
4082
 
4083
- int64_t n_as = ids->ne[0];
4084
-
4085
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
4086
- GGML_ASSERT(ggml_is_vector(ids));
 
 
4087
  GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
4088
- GGML_ASSERT(id >= 0 && id < n_as);
4089
 
4090
  bool is_node = false;
4091
 
@@ -4097,13 +4115,14 @@ struct ggml_tensor * ggml_mul_mat_id(
4097
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
4098
 
4099
  ggml_set_op_params_i32(result, 0, id);
 
4100
 
4101
  result->op = GGML_OP_MUL_MAT_ID;
4102
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4103
  result->src[0] = ids;
4104
  result->src[1] = b;
4105
 
4106
- for (int64_t i = 0; i < n_as; i++) {
4107
  struct ggml_tensor * a = as[i];
4108
  GGML_ASSERT(ggml_are_same_shape(as[0], a));
4109
  GGML_ASSERT(ggml_can_mul_mat(a, b));
@@ -4731,7 +4750,9 @@ struct ggml_tensor * ggml_get_rows(
4731
  struct ggml_context * ctx,
4732
  struct ggml_tensor * a,
4733
  struct ggml_tensor * b) {
4734
- GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
 
 
4735
 
4736
  bool is_node = false;
4737
 
@@ -4741,7 +4762,7 @@ struct ggml_tensor * ggml_get_rows(
4741
 
4742
  // TODO: implement non F32 return
4743
  //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
4744
- struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
4745
 
4746
  result->op = GGML_OP_GET_ROWS;
4747
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5519,6 +5540,30 @@ static struct ggml_tensor * ggml_upscale_impl(
5519
  return result;
5520
  }
5521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5522
  struct ggml_tensor * ggml_upscale(
5523
  struct ggml_context * ctx,
5524
  struct ggml_tensor * a,
@@ -7520,7 +7565,7 @@ static void ggml_compute_forward_acc_f32(
7520
  GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
7521
 
7522
  // view src0 and dst with these strides and data offset inbytes during acc
7523
- // nb0 is implicitely element_size because src0 and dst are contiguous
7524
  size_t nb1 = ((int32_t *) dst->op_params)[0];
7525
  size_t nb2 = ((int32_t *) dst->op_params)[1];
7526
  size_t nb3 = ((int32_t *) dst->op_params)[2];
@@ -7714,8 +7759,10 @@ static void ggml_compute_forward_mul_f32(
7714
  const int ith = params->ith;
7715
  const int nth = params->nth;
7716
 
 
7717
  #ifdef GGML_USE_CLBLAST
7718
  if (src1->backend == GGML_BACKEND_GPU) {
 
7719
  if (ith == 0) {
7720
  ggml_cl_mul(src0, src1, dst);
7721
  }
@@ -8981,10 +9028,9 @@ static void ggml_compute_forward_silu(
8981
  } break;
8982
  }
8983
  }
 
8984
 
8985
- // ggml_compute_forward_leaky
8986
-
8987
- static void ggml_compute_forward_leaky_f32(
8988
  const struct ggml_compute_params * params,
8989
  const struct ggml_tensor * src0,
8990
  struct ggml_tensor * dst) {
@@ -8998,24 +9044,27 @@ static void ggml_compute_forward_leaky_f32(
8998
  const int n = ggml_nrows(src0);
8999
  const int nc = src0->ne[0];
9000
 
 
 
 
9001
  assert(dst->nb[0] == sizeof(float));
9002
  assert(src0->nb[0] == sizeof(float));
9003
 
9004
  for (int i = 0; i < n; i++) {
9005
- ggml_vec_leaky_f32(nc,
9006
  (float *) ((char *) dst->data + i*( dst->nb[1])),
9007
- (float *) ((char *) src0->data + i*(src0->nb[1])));
9008
  }
9009
  }
9010
 
9011
- static void ggml_compute_forward_leaky(
9012
  const struct ggml_compute_params * params,
9013
  const struct ggml_tensor * src0,
9014
  struct ggml_tensor * dst) {
9015
  switch (src0->type) {
9016
  case GGML_TYPE_F32:
9017
  {
9018
- ggml_compute_forward_leaky_f32(params, src0, dst);
9019
  } break;
9020
  default:
9021
  {
@@ -9504,8 +9553,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
9504
  const int64_t ne0 = dst->ne[0];
9505
  const int64_t ne1 = dst->ne[1];
9506
 
 
 
9507
  // TODO: find the optimal values for these
9508
- if (ggml_is_contiguous(src0) &&
 
9509
  ggml_is_contiguous(src1) &&
9510
  //src0->type == GGML_TYPE_F32 &&
9511
  src1->type == GGML_TYPE_F32 &&
@@ -9519,11 +9571,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
9519
  }
9520
  #endif
9521
 
 
 
 
 
9522
  static void ggml_compute_forward_mul_mat(
9523
  const struct ggml_compute_params * params,
9524
  const struct ggml_tensor * src0,
9525
  const struct ggml_tensor * src1,
9526
- struct ggml_tensor * dst) {
 
9527
  int64_t t0 = ggml_perf_time_us();
9528
  UNUSED(t0);
9529
 
@@ -9591,10 +9648,9 @@ static void ggml_compute_forward_mul_mat(
9591
  const int64_t i03 = i13/r3;
9592
  const int64_t i02 = i12/r2;
9593
 
9594
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9595
- const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
9596
-
9597
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
9598
 
9599
  if (type != GGML_TYPE_F32) {
9600
  float * const wdata = params->wdata;
@@ -9611,10 +9667,10 @@ static void ggml_compute_forward_mul_mat(
9611
  }
9612
 
9613
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
9614
- ne11, ne01, ne10,
9615
- 1.0f, y, ne10,
9616
- x, ne00,
9617
- 0.0f, d, ne01);
9618
  }
9619
  }
9620
 
@@ -9630,6 +9686,7 @@ static void ggml_compute_forward_mul_mat(
9630
  const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
9631
 
9632
  assert(params->wsize >= ne11*ne12*ne13*row_size);
 
9633
 
9634
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
9635
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -9652,7 +9709,7 @@ static void ggml_compute_forward_mul_mat(
9652
  const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
9653
 
9654
  const int64_t nr0 = ne01; // src0 rows
9655
- const int64_t nr1 = ne11*ne12*ne13; // src1 rows
9656
 
9657
  //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
9658
 
@@ -9694,9 +9751,9 @@ static void ggml_compute_forward_mul_mat(
9694
  for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9695
  for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9696
  for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
9697
- const int64_t i13 = (ir1/(ne12*ne11));
9698
- const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
9699
- const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
9700
 
9701
  // broadcast src0 into src1
9702
  const int64_t i03 = i13/r3;
@@ -9736,20 +9793,28 @@ static void ggml_compute_forward_mul_mat(
9736
 
9737
  static void ggml_compute_forward_mul_mat_id(
9738
  const struct ggml_compute_params * params,
 
 
9739
  struct ggml_tensor * dst) {
9740
 
9741
- const struct ggml_tensor * ids = dst->src[0];
9742
- const struct ggml_tensor * src1 = dst->src[1];
9743
-
9744
- const int id = ggml_get_op_params_i32(dst, 0);
 
9745
 
9746
- const int a_id = ((int32_t *)ids->data)[id];
 
 
9747
 
9748
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
 
9749
 
9750
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
9751
 
9752
- ggml_compute_forward_mul_mat(params, src0, src1, dst);
 
 
9753
  }
9754
 
9755
  // ggml_compute_forward_out_prod
@@ -10161,7 +10226,7 @@ static void ggml_compute_forward_set_f32(
10161
  GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
10162
 
10163
  // view src0 and dst with these strides and data offset inbytes during set
10164
- // nb0 is implicitely element_size because src0 and dst are contiguous
10165
  size_t nb1 = ((int32_t *) dst->op_params)[0];
10166
  size_t nb2 = ((int32_t *) dst->op_params)[1];
10167
  size_t nb3 = ((int32_t *) dst->op_params)[2];
@@ -10325,21 +10390,30 @@ static void ggml_compute_forward_get_rows_q(
10325
  return;
10326
  }
10327
 
10328
- const int nc = src0->ne[0];
10329
- const int nr = ggml_nelements(src1);
 
 
 
10330
  const enum ggml_type type = src0->type;
10331
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
10332
 
10333
- assert( dst->ne[0] == nc);
10334
- assert( dst->ne[1] == nr);
10335
- assert(src0->nb[0] == ggml_type_size(type));
 
10336
 
10337
- for (int i = 0; i < nr; ++i) {
10338
- const int r = ((int32_t *) src1->data)[i];
 
 
 
10339
 
10340
- dequantize_row_q(
10341
- (const void *) ((char *) src0->data + r*src0->nb[1]),
10342
- (float *) ((char *) dst->data + i*dst->nb[1]), nc);
 
 
10343
  }
10344
  }
10345
 
@@ -10354,19 +10428,26 @@ static void ggml_compute_forward_get_rows_f16(
10354
  return;
10355
  }
10356
 
10357
- const int nc = src0->ne[0];
10358
- const int nr = ggml_nelements(src1);
10359
 
10360
- assert( dst->ne[0] == nc);
10361
- assert( dst->ne[1] == nr);
10362
- assert(src0->nb[0] == sizeof(ggml_fp16_t));
10363
 
10364
- for (int i = 0; i < nr; ++i) {
10365
- const int r = ((int32_t *) src1->data)[i];
 
 
10366
 
10367
- for (int j = 0; j < nc; ++j) {
10368
- ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
10369
- ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
 
 
 
 
 
 
 
10370
  }
10371
  }
10372
  }
@@ -10382,19 +10463,27 @@ static void ggml_compute_forward_get_rows_f32(
10382
  return;
10383
  }
10384
 
10385
- const int nc = src0->ne[0];
10386
- const int nr = ggml_nelements(src1);
10387
 
10388
- assert( dst->ne[0] == nc);
10389
- assert( dst->ne[1] == nr);
10390
- assert(src0->nb[0] == sizeof(float));
10391
 
10392
- for (int i = 0; i < nr; ++i) {
10393
- const int r = ((int32_t *) src1->data)[i];
 
 
10394
 
10395
- ggml_vec_cpy_f32(nc,
10396
- (float *) ((char *) dst->data + i*dst->nb[1]),
10397
- (float *) ((char *) src0->data + r*src0->nb[1]));
 
 
 
 
 
 
 
 
10398
  }
10399
  }
10400
 
@@ -12114,6 +12203,7 @@ static void ggml_compute_forward_upscale_f32(
12114
  GGML_ASSERT(src0->nb[0] == sizeof(float));
12115
 
12116
  const int ith = params->ith;
 
12117
 
12118
  GGML_TENSOR_UNARY_OP_LOCALS
12119
 
@@ -12121,16 +12211,17 @@ static void ggml_compute_forward_upscale_f32(
12121
 
12122
  // TODO: optimize
12123
 
12124
- for (int i03 = 0; i03 < ne03; i03++) {
12125
- for (int i02 = ith; i02 < ne02; i02++) {
12126
- for (int m = 0; m < dst->ne[1]; m++) {
12127
- int i01 = m / scale_factor;
12128
- for (int n = 0; n < dst->ne[0]; n++) {
12129
- int i00 = n / scale_factor;
12130
-
12131
- const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
12132
 
12133
- float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
 
12134
 
12135
  *y = *x;
12136
  }
@@ -12155,6 +12246,64 @@ static void ggml_compute_forward_upscale(
12155
  }
12156
  }
12157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12158
  // ggml_compute_forward_argsort
12159
 
12160
  static void ggml_compute_forward_argsort_f32(
@@ -13362,10 +13511,6 @@ static void ggml_compute_forward_unary(
13362
  {
13363
  ggml_compute_forward_silu(params, src0, dst);
13364
  } break;
13365
- case GGML_UNARY_OP_LEAKY:
13366
- {
13367
- ggml_compute_forward_leaky(params, src0, dst);
13368
- } break;
13369
  default:
13370
  {
13371
  GGML_ASSERT(false);
@@ -14037,11 +14182,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14037
  } break;
14038
  case GGML_OP_MUL_MAT:
14039
  {
14040
- ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14041
  } break;
14042
  case GGML_OP_MUL_MAT_ID:
14043
  {
14044
- ggml_compute_forward_mul_mat_id(params, tensor);
14045
  } break;
14046
  case GGML_OP_OUT_PROD:
14047
  {
@@ -14147,10 +14292,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14147
  {
14148
  ggml_compute_forward_upscale(params, tensor->src[0], tensor);
14149
  } break;
 
 
 
 
14150
  case GGML_OP_ARGSORT:
14151
  {
14152
  ggml_compute_forward_argsort(params, tensor->src[0], tensor);
14153
  } break;
 
 
 
 
14154
  case GGML_OP_FLASH_ATTN:
14155
  {
14156
  const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -14475,7 +14628,7 @@ void ggml_build_backward_gradient_checkpointing(
14475
  // insert new tensors recomputing src, reusing already made replacements,
14476
  // remember replacements: remember new tensors with mapping from corresponding gf nodes
14477
  // recurse for input tensors,
14478
- // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
14479
  node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
14480
  }
14481
  // insert rewritten backward node with replacements made into resulting backward graph gb
@@ -15143,10 +15296,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15143
  {
15144
  GGML_ASSERT(false); // TODO: not implemented
15145
  } break;
 
 
 
 
15146
  case GGML_OP_ARGSORT:
15147
  {
15148
  GGML_ASSERT(false); // TODO: not implemented
15149
  } break;
 
 
 
 
15150
  case GGML_OP_FLASH_ATTN:
15151
  {
15152
  struct ggml_tensor * flash_grad = NULL;
@@ -15752,6 +15913,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15752
  case GGML_OP_ARGMAX:
15753
  case GGML_OP_REPEAT:
15754
  case GGML_OP_REPEAT_BACK:
 
15755
  {
15756
  n_tasks = 1;
15757
  } break;
@@ -15764,7 +15926,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15764
  case GGML_UNARY_OP_TANH:
15765
  case GGML_UNARY_OP_ELU:
15766
  case GGML_UNARY_OP_RELU:
15767
- case GGML_UNARY_OP_LEAKY:
15768
  {
15769
  n_tasks = 1;
15770
  } break;
@@ -15883,6 +16044,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15883
  {
15884
  n_tasks = n_threads;
15885
  } break;
 
 
 
 
15886
  case GGML_OP_ARGSORT:
15887
  {
15888
  n_tasks = n_threads;
 
1
+ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
2
  #define _USE_MATH_DEFINES // For M_PI on MSVC
3
 
4
  #include "ggml-impl.h"
 
33
  // we should just be careful :)
34
  #pragma warning(disable: 4244 4267)
35
 
36
+ // disable POSIX deprecation warnings
37
  // these functions are never going away, anyway
38
  #pragma warning(disable: 4996)
39
  #endif
 
1395
  inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
1396
  inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
1397
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
1398
+ inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
1399
 
1400
  static const float GELU_COEF_A = 0.044715f;
1401
  static const float GELU_QUICK_COEF = -1.702f;
 
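The new ggml_vec_leaky_relu_f32 above computes max(x, 0) + ns*min(x, 0), i.e. the usual leaky ReLU with negative slope ns. For example, with ns = 0.1: y(3.0) = 3.0 + 0.1*0.0 = 3.0, and y(-2.0) = 0.0 + 0.1*(-2.0) = -0.2.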
1623
  "POOL_1D",
1624
  "POOL_2D",
1625
  "UPSCALE",
1626
+ "PAD",
1627
  "ARGSORT",
1628
+ "LEAKY_RELU",
1629
 
1630
  "FLASH_ATTN",
1631
  "FLASH_FF",
 
1652
  "CROSS_ENTROPY_LOSS_BACK",
1653
  };
1654
 
1655
+ static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
1656
 
1657
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1658
  "none",
 
1709
  "pool_1d(x)",
1710
  "pool_2d(x)",
1711
  "upscale(x)",
1712
+ "pad(x)",
1713
  "argsort(x)",
1714
+ "leaky_relu(x)",
1715
 
1716
  "flash_attn(x)",
1717
  "flash_ff(x)",
 
1738
  "cross_entropy_loss_back(x,y)",
1739
  };
1740
 
1741
+ static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
1742
 
1743
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1744
 
 
1754
  "GELU",
1755
  "GELU_QUICK",
1756
  "SILU",
 
1757
  };
1758
 
1759
+ static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10");
1760
 
1761
 
1762
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1763
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1764
 
1765
  // WARN:
1766
+ // Mis-configuration can lead to problem that's hard to reason about:
1767
  // * At best it crashes or talks nonsense.
1768
  // * At worst it talks slightly differently, which is hard to perceive.
1769
  //
 
3833
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
3834
  }
3835
 
3836
+ // ggml_leaky_relu
3837
 
3838
+ struct ggml_tensor * ggml_leaky_relu(
3839
  struct ggml_context * ctx,
3840
+ struct ggml_tensor * a, float negative_slope, bool inplace) {
3841
+ bool is_node = false;
3842
+
3843
+ if (!inplace && (a->grad)) {
3844
+ is_node = true;
3845
+ }
3846
+
3847
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3848
+ ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
3849
+
3850
+ result->op = GGML_OP_LEAKY_RELU;
3851
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
3852
+ result->src[0] = a;
3853
+
3854
+ return result;
3855
  }
3856
 
3857
  // ggml_gelu
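A minimal usage sketch of the new ggml_leaky_relu API; ctx and the F32 tensor x are placeholders created elsewhere:

    // out-of-place leaky ReLU with negative slope 0.1
    struct ggml_tensor * y = ggml_leaky_relu(ctx, x, 0.1f, /*inplace =*/ false);

Compared to the removed GGML_UNARY_OP_LEAKY path, the slope is now passed explicitly and stored in op_params instead of being fixed at 0.1f.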
 
4038
 
4039
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4040
 
 
4041
  result->op_params[0] = n_groups;
4042
+
4043
+ result->op = GGML_OP_GROUP_NORM;
4044
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4045
  result->src[0] = a;
4046
  result->src[1] = NULL; // TODO: maybe store epsilon here?
 
4092
 
4093
  struct ggml_tensor * ggml_mul_mat_id(
4094
  struct ggml_context * ctx,
4095
+ struct ggml_tensor * const as[],
4096
+ int n_as,
4097
  struct ggml_tensor * ids,
4098
  int id,
4099
  struct ggml_tensor * b) {
4100
 
 
 
4101
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
4102
+ GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
4103
+ GGML_ASSERT(ids->ne[1] == b->ne[1]);
4104
+ GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
4105
  GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
4106
+ GGML_ASSERT(id >= 0 && id < ids->ne[0]);
4107
 
4108
  bool is_node = false;
4109
 
 
4115
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
4116
 
4117
  ggml_set_op_params_i32(result, 0, id);
4118
+ ggml_set_op_params_i32(result, 1, n_as);
4119
 
4120
  result->op = GGML_OP_MUL_MAT_ID;
4121
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4122
  result->src[0] = ids;
4123
  result->src[1] = b;
4124
 
4125
+ for (int i = 0; i < n_as; i++) {
4126
  struct ggml_tensor * a = as[i];
4127
  GGML_ASSERT(ggml_are_same_shape(as[0], a));
4128
  GGML_ASSERT(ggml_can_mul_mat(a, b));
 
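A hedged sketch of a call with the updated ggml_mul_mat_id signature; experts, n_expert, ids, slot and cur are placeholders for tensors and values built elsewhere:

    // experts[i] : [n_embd, n_ff]  (all experts share one shape)
    // cur        : [n_embd, n_tokens]
    // ids        : I32 [n_ids_per_token, n_tokens]; the id argument selects one entry per token
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, experts, n_expert, ids, /*id =*/ slot, cur);

The new asserts require one ids column per src1 column (ids->ne[1] == b->ne[1]), so the selected expert can now differ per token instead of being fixed for the whole call.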
4750
  struct ggml_context * ctx,
4751
  struct ggml_tensor * a,
4752
  struct ggml_tensor * b) {
4753
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
4754
+ GGML_ASSERT(b->ne[3] == 1);
4755
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
4756
 
4757
  bool is_node = false;
4758
 
 
4762
 
4763
  // TODO: implement non F32 return
4764
  //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
4765
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
4766
 
4767
  result->op = GGML_OP_GET_ROWS;
4768
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
 
5540
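ggml_get_rows is likewise generalized from a plain I32 vector of row ids to batched indices: b may use up to three dimensions (b->ne[3] == 1), a->ne[2] has to match b->ne[1], and the result is now a 4D F32 tensor. A shape-only sketch with illustrative names:

    // a    : [ne00, ne01, ne02]        (one row matrix per slice)
    // b    : I32 [ne10, ne11, ne12]    with ne11 == a->ne[2]
    // rows : F32 [ne00, ne10, ne11, ne12]
    struct ggml_tensor * rows = ggml_get_rows(ctx, a, b);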
  return result;
5541
  }
5542
 
5543
+ struct ggml_tensor * ggml_pad(
5544
+ struct ggml_context * ctx,
5545
+ struct ggml_tensor * a,
5546
+ int p0, int p1, int p2, int p3) {
5547
+ bool is_node = false;
5548
+
5549
+ if (a->grad) {
5550
+ GGML_ASSERT(false); // TODO: implement backward
5551
+ is_node = true;
5552
+ }
5553
+
5554
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
5555
+ a->ne[0] + p0,
5556
+ a->ne[1] + p1,
5557
+ a->ne[2] + p2,
5558
+ a->ne[3] + p3);
5559
+
5560
+ result->op = GGML_OP_PAD;
5561
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5562
+ result->src[0] = a;
5563
+
5564
+ return result;
5565
+ }
5566
+
5567
  struct ggml_tensor * ggml_upscale(
5568
  struct ggml_context * ctx,
5569
  struct ggml_tensor * a,
 
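A minimal sketch of the new ggml_pad operator, which grows each dimension of a by the given amount (ctx and a are placeholders):

    struct ggml_tensor * padded = ggml_pad(ctx, a, 2, 0, 1, 0);
    // padded->ne = { a->ne[0] + 2, a->ne[1], a->ne[2] + 1, a->ne[3] }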
7565
  GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
7566
 
7567
  // view src0 and dst with these strides and data offset inbytes during acc
7568
+ // nb0 is implicitly element_size because src0 and dst are contiguous
7569
  size_t nb1 = ((int32_t *) dst->op_params)[0];
7570
  size_t nb2 = ((int32_t *) dst->op_params)[1];
7571
  size_t nb3 = ((int32_t *) dst->op_params)[2];
 
7759
  const int ith = params->ith;
7760
  const int nth = params->nth;
7761
 
7762
+ // TODO: OpenCL kernel support broadcast
7763
  #ifdef GGML_USE_CLBLAST
7764
  if (src1->backend == GGML_BACKEND_GPU) {
7765
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
7766
  if (ith == 0) {
7767
  ggml_cl_mul(src0, src1, dst);
7768
  }
 
9028
  } break;
9029
  }
9030
  }
9031
+ // ggml_compute_forward_leaky_relu
9032
 
9033
+ static void ggml_compute_forward_leaky_relu_f32(
 
 
9034
  const struct ggml_compute_params * params,
9035
  const struct ggml_tensor * src0,
9036
  struct ggml_tensor * dst) {
 
9044
  const int n = ggml_nrows(src0);
9045
  const int nc = src0->ne[0];
9046
 
9047
+ float negative_slope;
9048
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
9049
+
9050
  assert(dst->nb[0] == sizeof(float));
9051
  assert(src0->nb[0] == sizeof(float));
9052
 
9053
  for (int i = 0; i < n; i++) {
9054
+ ggml_vec_leaky_relu_f32(nc,
9055
  (float *) ((char *) dst->data + i*( dst->nb[1])),
9056
+ (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
9057
  }
9058
  }
9059
 
9060
+ static void ggml_compute_forward_leaky_relu(
9061
  const struct ggml_compute_params * params,
9062
  const struct ggml_tensor * src0,
9063
  struct ggml_tensor * dst) {
9064
  switch (src0->type) {
9065
  case GGML_TYPE_F32:
9066
  {
9067
+ ggml_compute_forward_leaky_relu_f32(params, src0, dst);
9068
  } break;
9069
  default:
9070
  {
 
9553
  const int64_t ne0 = dst->ne[0];
9554
  const int64_t ne1 = dst->ne[1];
9555
 
9556
+ // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
9557
+ // all the experts for each batch element and the processing would become incredibly slow
9558
  // TODO: find the optimal values for these
9559
+ if (dst->op != GGML_OP_MUL_MAT_ID &&
9560
+ ggml_is_contiguous(src0) &&
9561
  ggml_is_contiguous(src1) &&
9562
  //src0->type == GGML_TYPE_F32 &&
9563
  src1->type == GGML_TYPE_F32 &&
 
9571
  }
9572
  #endif
9573
 
9574
+ // off1 = offset in i11 and i1
9575
+ // cne1 = ne11 and ne1
9576
+ // in a normal matrix multiplication, off1 = 0 and cne1 = ne1
9577
+ // during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
9578
  static void ggml_compute_forward_mul_mat(
9579
  const struct ggml_compute_params * params,
9580
  const struct ggml_tensor * src0,
9581
  const struct ggml_tensor * src1,
9582
+ struct ggml_tensor * dst,
9583
+ int64_t off1, int64_t cne1) {
9584
  int64_t t0 = ggml_perf_time_us();
9585
  UNUSED(t0);
9586
 
 
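The off1/cne1 comment above covers both ways the reworked ggml_compute_forward_mul_mat is invoked; restated as the two call patterns that appear later in this diff:

    // normal matrix multiplication: whole src1/dst, off1 = 0, cne1 = ne1
    ggml_compute_forward_mul_mat(params, src0, src1, dst, 0, dst->ne[1]);

    // mul_mat_id: a single row i01 of src1/dst against the selected expert
    ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);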
9648
  const int64_t i03 = i13/r3;
9649
  const int64_t i02 = i12/r2;
9650
 
9651
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9652
+ const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
9653
+ float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
 
9654
 
9655
  if (type != GGML_TYPE_F32) {
9656
  float * const wdata = params->wdata;
 
9667
  }
9668
 
9669
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
9670
+ cne1, ne01, ne10,
9671
+ 1.0f, y, ne10,
9672
+ x, ne00,
9673
+ 0.0f, d, ne01);
9674
  }
9675
  }
9676
 
 
9686
  const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
9687
 
9688
  assert(params->wsize >= ne11*ne12*ne13*row_size);
9689
+ assert(src1->type == GGML_TYPE_F32);
9690
 
9691
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
9692
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
 
9709
  const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
9710
 
9711
  const int64_t nr0 = ne01; // src0 rows
9712
+ const int64_t nr1 = cne1*ne12*ne13; // src1 rows
9713
 
9714
  //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
9715
 
 
9751
  for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9752
  for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9753
  for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
9754
+ const int64_t i13 = (ir1/(ne12*cne1));
9755
+ const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
9756
+ const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
9757
 
9758
  // broadcast src0 into src1
9759
  const int64_t i03 = i13/r3;
 
9793
 
9794
  static void ggml_compute_forward_mul_mat_id(
9795
  const struct ggml_compute_params * params,
9796
+ const struct ggml_tensor * src0,
9797
+ const struct ggml_tensor * src1,
9798
  struct ggml_tensor * dst) {
9799
 
9800
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9801
+ // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
9802
+ ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
9803
+ return;
9804
+ }
9805
 
9806
+ const struct ggml_tensor * ids = src0;
9807
+ const int id = ggml_get_op_params_i32(dst, 0);
9808
+ const int n_as = ggml_get_op_params_i32(dst, 1);
9809
 
9810
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
9811
+ const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
9812
 
9813
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
9814
 
9815
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2];
9816
+ ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
9817
+ }
9818
  }
9819
 
9820
  // ggml_compute_forward_out_prod
 
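To make the per-row dispatch in ggml_compute_forward_mul_mat_id above concrete, a small hypothetical trace (the ids values are invented for illustration):

    // n_as = 4 experts, id = 0, ids has 3 rows: [2], [0], [2]
    //   i01 = 0: row_id = 2 -> src0_row = dst->src[2 + 2], mul_mat(off1 = 0, cne1 = 1)
    //   i01 = 1: row_id = 0 -> src0_row = dst->src[0 + 2], mul_mat(off1 = 1, cne1 = 1)
    //   i01 = 2: row_id = 2 -> src0_row = dst->src[2 + 2], mul_mat(off1 = 2, cne1 = 1)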
10226
  GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
10227
 
10228
  // view src0 and dst with these strides and data offset inbytes during set
10229
+ // nb0 is implicitly element_size because src0 and dst are contiguous
10230
  size_t nb1 = ((int32_t *) dst->op_params)[0];
10231
  size_t nb2 = ((int32_t *) dst->op_params)[1];
10232
  size_t nb3 = ((int32_t *) dst->op_params)[2];
 
10390
  return;
10391
  }
10392
 
10393
+ GGML_TENSOR_BINARY_OP_LOCALS
10394
+
10395
+ const int64_t nc = ne00;
10396
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
10397
+
10398
  const enum ggml_type type = src0->type;
10399
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
10400
 
10401
+ assert(ne0 == nc);
10402
+ assert(ne02 == ne11);
10403
+ assert(nb00 == ggml_type_size(type));
10404
+ assert(ggml_nrows(dst) == nr);
10405
 
10406
+ // TODO: multi-thread
10407
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
10408
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
10409
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
10410
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
10411
 
10412
+ dequantize_row_q(
10413
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
10414
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
10415
+ }
10416
+ }
10417
  }
10418
  }
10419
 
 
10428
  return;
10429
  }
10430
 
10431
+ GGML_TENSOR_BINARY_OP_LOCALS
 
10432
 
10433
+ const int64_t nc = ne00;
10434
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
10435
 
10436
+ assert(ne0 == nc);
10437
+ assert(ne02 == ne11);
10438
+ assert(nb00 == sizeof(ggml_fp16_t));
10439
+ assert(ggml_nrows(dst) == nr);
10440
 
10441
+ // TODO: multi-thread
10442
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
10443
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
10444
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
10445
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
10446
+
10447
+ ggml_fp16_to_fp32_row(
10448
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
10449
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
10450
+ }
10451
  }
10452
  }
10453
  }
 
10463
  return;
10464
  }
10465
 
10466
+ GGML_TENSOR_BINARY_OP_LOCALS
 
10467
 
10468
+ const int64_t nc = ne00;
10469
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
10470
 
10471
+ assert(ne0 == nc);
10472
+ assert(ne02 == ne11);
10473
+ assert(nb00 == sizeof(float));
10474
+ assert(ggml_nrows(dst) == nr);
10475
 
10476
+ // TODO: multi-thread
10477
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
10478
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
10479
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
10480
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
10481
+
10482
+ ggml_vec_cpy_f32(nc,
10483
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
10484
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
10485
+ }
10486
+ }
10487
  }
10488
  }
10489
 
 
12203
  GGML_ASSERT(src0->nb[0] == sizeof(float));
12204
 
12205
  const int ith = params->ith;
12206
+ const int nth = params->nth;
12207
 
12208
  GGML_TENSOR_UNARY_OP_LOCALS
12209
 
 
12211
 
12212
  // TODO: optimize
12213
 
12214
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
12215
+ const int64_t i03 = i3;
12216
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
12217
+ const int64_t i02 = i2;
12218
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
12219
+ const int64_t i01 = i1 / scale_factor;
12220
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
12221
+ const int64_t i00 = i0 / scale_factor;
12222
 
12223
+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
12224
+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
12225
 
12226
  *y = *x;
12227
  }
 
12246
  }
12247
  }
12248
 
12249
+ // ggml_compute_forward_pad
12250
+
12251
+ static void ggml_compute_forward_pad_f32(
12252
+ const struct ggml_compute_params * params,
12253
+ const struct ggml_tensor * src0,
12254
+ struct ggml_tensor * dst) {
12255
+
12256
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12257
+ return;
12258
+ }
12259
+
12260
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
12261
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
12262
+
12263
+ const int ith = params->ith;
12264
+ const int nth = params->nth;
12265
+
12266
+ GGML_TENSOR_UNARY_OP_LOCALS
12267
+
12268
+ float * dst_ptr = (float *) dst->data;
12269
+
12270
+ // TODO: optimize
12271
+
12272
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
12273
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
12274
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
12275
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
12276
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
12277
+
12278
+ const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
12279
+
12280
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
12281
+ dst_ptr[dst_idx] = *src_ptr;
12282
+ } else {
12283
+ dst_ptr[dst_idx] = 0;
12284
+ }
12285
+ }
12286
+ }
12287
+ }
12288
+ }
12289
+ }
12290
+
12291
+ static void ggml_compute_forward_pad(
12292
+ const struct ggml_compute_params * params,
12293
+ const struct ggml_tensor * src0,
12294
+ struct ggml_tensor * dst) {
12295
+ switch (src0->type) {
12296
+ case GGML_TYPE_F32:
12297
+ {
12298
+ ggml_compute_forward_pad_f32(params, src0, dst);
12299
+ } break;
12300
+ default:
12301
+ {
12302
+ GGML_ASSERT(false);
12303
+ } break;
12304
+ }
12305
+ }
12306
+
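At the graph level the new op is exposed as ggml_pad() (declared in the ggml.h changes below). A hedged usage sketch, assuming a valid ggml_context named ctx:

    // pad a 3x2 f32 tensor to 8x2 by appending five zero columns in dim 0
    struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
    struct ggml_tensor * out = ggml_pad(ctx, a, /*p0=*/5, /*p1=*/0, /*p2=*/0, /*p3=*/0);
    // out->ne is {8, 2, 1, 1}; elements outside the original extent are zero-filled
    // by ggml_compute_forward_pad_f32 above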
12307
  // ggml_compute_forward_argsort
12308
 
12309
  static void ggml_compute_forward_argsort_f32(
 
13511
  {
13512
  ggml_compute_forward_silu(params, src0, dst);
13513
  } break;
 
 
 
 
13514
  default:
13515
  {
13516
  GGML_ASSERT(false);
 
14182
  } break;
14183
  case GGML_OP_MUL_MAT:
14184
  {
14185
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
14186
  } break;
14187
  case GGML_OP_MUL_MAT_ID:
14188
  {
14189
+ ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
14190
  } break;
14191
  case GGML_OP_OUT_PROD:
14192
  {
 
14292
  {
14293
  ggml_compute_forward_upscale(params, tensor->src[0], tensor);
14294
  } break;
14295
+ case GGML_OP_PAD:
14296
+ {
14297
+ ggml_compute_forward_pad(params, tensor->src[0], tensor);
14298
+ } break;
14299
  case GGML_OP_ARGSORT:
14300
  {
14301
  ggml_compute_forward_argsort(params, tensor->src[0], tensor);
14302
  } break;
14303
+ case GGML_OP_LEAKY_RELU:
14304
+ {
14305
+ ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
14306
+ } break;
14307
  case GGML_OP_FLASH_ATTN:
14308
  {
14309
  const int32_t t = ggml_get_op_params_i32(tensor, 0);
 
14628
  // insert new tensors recomputing src, reusing already made replacements,
14629
  // remember replacements: remember new tensors with mapping from corresponding gf nodes
14630
  // recurse for input tensors,
14631
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
14632
  node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
14633
  }
14634
  // insert rewritten backward node with replacements made into resulting backward graph gb
 
15296
  {
15297
  GGML_ASSERT(false); // TODO: not implemented
15298
  } break;
15299
+ case GGML_OP_PAD:
15300
+ {
15301
+ GGML_ASSERT(false); // TODO: not implemented
15302
+ } break;
15303
  case GGML_OP_ARGSORT:
15304
  {
15305
  GGML_ASSERT(false); // TODO: not implemented
15306
  } break;
15307
+ case GGML_OP_LEAKY_RELU:
15308
+ {
15309
+ GGML_ASSERT(false); // TODO: not implemented
15310
+ } break;
15311
  case GGML_OP_FLASH_ATTN:
15312
  {
15313
  struct ggml_tensor * flash_grad = NULL;
 
15913
  case GGML_OP_ARGMAX:
15914
  case GGML_OP_REPEAT:
15915
  case GGML_OP_REPEAT_BACK:
15916
+ case GGML_OP_LEAKY_RELU:
15917
  {
15918
  n_tasks = 1;
15919
  } break;
 
15926
  case GGML_UNARY_OP_TANH:
15927
  case GGML_UNARY_OP_ELU:
15928
  case GGML_UNARY_OP_RELU:
 
15929
  {
15930
  n_tasks = 1;
15931
  } break;
 
16044
  {
16045
  n_tasks = n_threads;
16046
  } break;
16047
+ case GGML_OP_PAD:
16048
+ {
16049
+ n_tasks = n_threads;
16050
+ } break;
16051
  case GGML_OP_ARGSORT:
16052
  {
16053
  n_tasks = n_threads;
ggml.h CHANGED
@@ -215,9 +215,9 @@
215
  #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
216
 
217
  #define GGML_MAX_DIMS 4
218
- #define GGML_MAX_PARAMS 1024
219
  #define GGML_MAX_CONTEXTS 64
220
- #define GGML_MAX_SRC 6
221
  #define GGML_MAX_NAME 64
222
  #define GGML_MAX_OP_PARAMS 64
223
  #define GGML_DEFAULT_N_THREADS 4
@@ -423,7 +423,9 @@ extern "C" {
423
  GGML_OP_POOL_1D,
424
  GGML_OP_POOL_2D,
425
  GGML_OP_UPSCALE, // nearest interpolate
 
426
  GGML_OP_ARGSORT,
 
427
 
428
  GGML_OP_FLASH_ATTN,
429
  GGML_OP_FLASH_FF,
@@ -463,7 +465,6 @@ extern "C" {
463
  GGML_UNARY_OP_GELU,
464
  GGML_UNARY_OP_GELU_QUICK,
465
  GGML_UNARY_OP_SILU,
466
- GGML_UNARY_OP_LEAKY,
467
 
468
  GGML_UNARY_OP_COUNT,
469
  };
@@ -793,6 +794,9 @@ extern "C" {
793
  struct ggml_tensor * a,
794
  struct ggml_tensor * b);
795
 
 
 
 
796
  GGML_API struct ggml_tensor * ggml_acc(
797
  struct ggml_context * ctx,
798
  struct ggml_tensor * a,
@@ -957,15 +961,14 @@ extern "C" {
957
  struct ggml_context * ctx,
958
  struct ggml_tensor * a);
959
 
960
- GGML_API struct ggml_tensor * ggml_leaky(
961
  struct ggml_context * ctx,
962
- struct ggml_tensor * a);
963
 
964
  GGML_API struct ggml_tensor * ggml_relu_inplace(
965
  struct ggml_context * ctx,
966
  struct ggml_tensor * a);
967
 
968
- // TODO: double-check this computation is correct
969
  GGML_API struct ggml_tensor * ggml_gelu(
970
  struct ggml_context * ctx,
971
  struct ggml_tensor * a);
@@ -1051,7 +1054,8 @@ extern "C" {
1051
  // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
1052
  GGML_API struct ggml_tensor * ggml_mul_mat_id(
1053
  struct ggml_context * ctx,
1054
- struct ggml_tensor * as[],
 
1055
  struct ggml_tensor * ids,
1056
  int id,
1057
  struct ggml_tensor * b);
@@ -1263,6 +1267,7 @@ extern "C" {
1263
  struct ggml_context * ctx,
1264
  struct ggml_tensor * a);
1265
 
 
1266
  GGML_API struct ggml_tensor * ggml_get_rows(
1267
  struct ggml_context * ctx,
1268
  struct ggml_tensor * a,
@@ -1549,6 +1554,15 @@ extern "C" {
1549
  struct ggml_tensor * a,
1550
  int scale_factor);
1551
 
 
 
 
 
 
 
 
 
 
1552
  // sort rows
1553
  enum ggml_sort_order {
1554
  GGML_SORT_ASC,
 
215
  #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
216
 
217
  #define GGML_MAX_DIMS 4
218
+ #define GGML_MAX_PARAMS 2048
219
  #define GGML_MAX_CONTEXTS 64
220
+ #define GGML_MAX_SRC 10
221
  #define GGML_MAX_NAME 64
222
  #define GGML_MAX_OP_PARAMS 64
223
  #define GGML_DEFAULT_N_THREADS 4
 
423
  GGML_OP_POOL_1D,
424
  GGML_OP_POOL_2D,
425
  GGML_OP_UPSCALE, // nearest interpolate
426
+ GGML_OP_PAD,
427
  GGML_OP_ARGSORT,
428
+ GGML_OP_LEAKY_RELU,
429
 
430
  GGML_OP_FLASH_ATTN,
431
  GGML_OP_FLASH_FF,
 
465
  GGML_UNARY_OP_GELU,
466
  GGML_UNARY_OP_GELU_QUICK,
467
  GGML_UNARY_OP_SILU,
 
468
 
469
  GGML_UNARY_OP_COUNT,
470
  };
 
794
  struct ggml_tensor * a,
795
  struct ggml_tensor * b);
796
 
797
+ // dst = a
798
+ // view(dst, nb1, nb2, nb3, offset) += b
799
+ // return dst
800
  GGML_API struct ggml_tensor * ggml_acc(
801
  struct ggml_context * ctx,
802
  struct ggml_tensor * a,
 
961
  struct ggml_context * ctx,
962
  struct ggml_tensor * a);
963
 
964
+ GGML_API struct ggml_tensor * ggml_leaky_relu(
965
  struct ggml_context * ctx,
966
+ struct ggml_tensor * a, float negative_slope, bool inplace);
967
 
968
  GGML_API struct ggml_tensor * ggml_relu_inplace(
969
  struct ggml_context * ctx,
970
  struct ggml_tensor * a);
971
 
 
972
  GGML_API struct ggml_tensor * ggml_gelu(
973
  struct ggml_context * ctx,
974
  struct ggml_tensor * a);
 
1054
  // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
1055
  GGML_API struct ggml_tensor * ggml_mul_mat_id(
1056
  struct ggml_context * ctx,
1057
+ struct ggml_tensor * const as[],
1058
+ int n_as,
1059
  struct ggml_tensor * ids,
1060
  int id,
1061
  struct ggml_tensor * b);
 
1267
  struct ggml_context * ctx,
1268
  struct ggml_tensor * a);
1269
 
1270
+ // supports 3D: a->ne[2] == b->ne[1]
1271
  GGML_API struct ggml_tensor * ggml_get_rows(
1272
  struct ggml_context * ctx,
1273
  struct ggml_tensor * a,
 
1554
  struct ggml_tensor * a,
1555
  int scale_factor);
1556
 
1557
+ // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
1558
+ GGML_API struct ggml_tensor * ggml_pad(
1559
+ struct ggml_context * ctx,
1560
+ struct ggml_tensor * a,
1561
+ int p0,
1562
+ int p1,
1563
+ int p2,
1564
+ int p3);
1565
+
1566
  // sort rows
1567
  enum ggml_sort_order {
1568
  GGML_SORT_ASC,
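To close, a hedged sketch of how the new and changed public API above is called. It assumes a valid ggml_context ctx plus tensors as (an array of n_as expert matrices) and ids (per-row expert indices) declared elsewhere; all identifiers here are illustrative:

    struct ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);

    // GGML_OP_LEAKY_RELU replaces the old GGML_UNARY_OP_LEAKY unary op
    struct ggml_tensor * act = ggml_leaky_relu(ctx, cur, /*negative_slope=*/0.1f, /*inplace=*/false);

    // ggml_mul_mat_id now takes the expert array and its length explicitly
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, as, n_as, ids, /*id=*/0, cur);

    // ggml_pad appends zeros along each dimension (here: 8 extra rows in dim 1)
    struct ggml_tensor * padded = ggml_pad(ctx, cur, /*p0=*/0, /*p1=*/8, /*p2=*/0, /*p3=*/0);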