sync : ggml (Metal fixes, new ops, tests) (#1633)
* sync : ggml (Metal fixes, new ops, tests)
* cuda : fix bin bcast when src1 and dst have different types
- ggml-alloc.h +1 -1
- ggml-cuda.cu +683 -89
- ggml-metal.m +530 -53
- ggml-metal.metal +1497 -169
- ggml-quants.c +2 -2
- ggml.c +264 -99
- ggml.h +21 -7
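
Note on the second change: when a broadcast binary op writes to a destination whose element type differs from src1 (for example f32 src1 into an f16 dst), byte strides have to be converted to element strides using each tensor's own element size, which is what the updated bin_bcast code below does. A minimal sketch of the idea (illustrative only, not code from the commit; unsigned short stands in for ggml's f16 type):

    #include <cstddef>
    #include <cstdio>

    // Byte strides (nb*) come from the tensors; element strides (s*) must be
    // derived with the element size of the tensor they index. Dividing a dst
    // byte stride by sizeof(src1_t) mis-addresses rows whenever the two types
    // differ in size -- which appears to be the mismatch this fix removes.
    template <typename src1_t, typename dst_t>
    void element_strides(size_t nb1_dst, size_t nb11_src1) {
        const size_t s1  = nb1_dst   / sizeof(dst_t);  // dst stride in dst_t elements
        const size_t s11 = nb11_src1 / sizeof(src1_t); // src1 stride in src1_t elements
        const size_t bad = nb1_dst   / sizeof(src1_t); // wrong when the sizes differ
        printf("s1 = %zu, s11 = %zu, mismatched s1 would be %zu\n", s1, s11, bad);
    }

    int main() {
        // one row of 64 elements: dst is f16 (2 bytes), src1 is f32 (4 bytes)
        element_strides<float, unsigned short>(64 * 2, 64 * 4);
        return 0;
    }
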
ggml-alloc.h
CHANGED
@@ -43,7 +43,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
 // ggml-backend v2 API
 //

-//
+// Separate tensor and graph allocator objects
 // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
 // The original API is kept as a wrapper around the new API

ggml-cuda.cu
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
#include <algorithm>
|
|
|
|
|
|
|
|
|
|
| 2 |
#include <cstddef>
|
| 3 |
#include <cstdint>
|
| 4 |
-
#include <cinttypes>
|
| 5 |
#include <float.h>
|
| 6 |
#include <limits>
|
| 7 |
#include <stdint.h>
|
| 8 |
#include <stdio.h>
|
| 9 |
-
#include <
|
| 10 |
-
|
| 11 |
|
| 12 |
#if defined(GGML_USE_HIPBLAS)
|
| 13 |
#include <hip/hip_runtime.h>
|
|
@@ -437,6 +439,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
|
| 437 |
|
| 438 |
#define CUDA_GELU_BLOCK_SIZE 256
|
| 439 |
#define CUDA_SILU_BLOCK_SIZE 256
|
|
|
|
| 440 |
#define CUDA_RELU_BLOCK_SIZE 256
|
| 441 |
#define CUDA_SQR_BLOCK_SIZE 256
|
| 442 |
#define CUDA_CPY_BLOCK_SIZE 32
|
|
@@ -449,6 +452,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
|
| 449 |
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
| 450 |
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
| 451 |
#define CUDA_GET_ROWS_BLOCK_SIZE 256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
// dmmv = dequantize_mul_mat_vec
|
| 454 |
#ifndef GGML_CUDA_DMMV_X
|
|
@@ -610,6 +618,24 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
|
|
| 610 |
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
| 611 |
}
|
| 612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
|
| 614 |
const float GELU_COEF_A = 0.044715f;
|
| 615 |
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
@@ -632,6 +658,23 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
|
| 632 |
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
| 633 |
}
|
| 634 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
| 636 |
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
| 637 |
|
|
@@ -641,6 +684,14 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
|
| 641 |
dst[i] = fmaxf(x[i], 0);
|
| 642 |
}
|
| 643 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
| 645 |
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
| 646 |
|
|
@@ -686,6 +737,132 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
|
|
| 686 |
}
|
| 687 |
}
|
| 688 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
template <int block_size>
|
| 690 |
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
| 691 |
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
|
@@ -1684,31 +1861,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
|
| 1684 |
}
|
| 1685 |
|
| 1686 |
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
| 1687 |
-
static __global__ void k_get_rows(
|
| 1688 |
-
|
| 1689 |
-
|
| 1690 |
-
|
| 1691 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1692 |
return;
|
| 1693 |
}
|
| 1694 |
|
| 1695 |
-
const int
|
| 1696 |
|
| 1697 |
-
|
| 1698 |
-
const
|
| 1699 |
-
const int di = row*ncols + col;
|
| 1700 |
|
| 1701 |
-
const int ib =
|
| 1702 |
-
const int iqs = (
|
| 1703 |
-
const int iybs =
|
| 1704 |
const int y_offset = qr == 1 ? 1 : qk/2;
|
| 1705 |
|
| 1706 |
// dequantize
|
| 1707 |
dfloat2 v;
|
| 1708 |
-
dequantize_kernel(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1709 |
|
| 1710 |
-
|
| 1711 |
-
dst[iybs + iqs + y_offset] = v.y;
|
| 1712 |
}
|
| 1713 |
|
| 1714 |
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
|
@@ -5035,29 +5246,98 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
|
| 5035 |
|
| 5036 |
static __global__ void im2col_f32_f16(
|
| 5037 |
const float * x, half * dst,
|
| 5038 |
-
int
|
| 5039 |
int s0, int s1, int p0, int p1, int d0, int d1) {
|
| 5040 |
-
const int
|
| 5041 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5042 |
|
| 5043 |
const int offset_dst =
|
| 5044 |
-
(
|
| 5045 |
-
(blockIdx.
|
| 5046 |
|
| 5047 |
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
| 5048 |
dst[offset_dst] = __float2half(0.0f);
|
| 5049 |
} else {
|
| 5050 |
-
const int offset_src =
|
| 5051 |
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
| 5052 |
}
|
| 5053 |
}
|
| 5054 |
|
| 5055 |
template<int qk, int qr, dequantize_kernel_t dq>
|
| 5056 |
-
static void get_rows_cuda(const
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5057 |
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
| 5058 |
-
const int block_num_x = (
|
| 5059 |
-
const dim3 block_nums(block_num_x,
|
| 5060 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5061 |
}
|
| 5062 |
|
| 5063 |
template<float (*bin_op)(const float, const float)>
|
|
@@ -5069,7 +5349,6 @@ struct bin_bcast_cuda {
|
|
| 5069 |
|
| 5070 |
GGML_TENSOR_BINARY_OP_LOCALS
|
| 5071 |
|
| 5072 |
-
|
| 5073 |
int nr0 = ne10/ne0;
|
| 5074 |
int nr1 = ne11/ne1;
|
| 5075 |
int nr2 = ne12/ne2;
|
|
@@ -5117,26 +5396,28 @@ struct bin_bcast_cuda {
|
|
| 5117 |
int64_t ne12 = cne1[2];
|
| 5118 |
int64_t ne13 = cne1[3];
|
| 5119 |
|
| 5120 |
-
|
| 5121 |
size_t nb1 = cnb0[1];
|
| 5122 |
size_t nb2 = cnb0[2];
|
| 5123 |
size_t nb3 = cnb0[3];
|
| 5124 |
|
| 5125 |
-
|
| 5126 |
size_t nb11 = cnb1[1];
|
| 5127 |
size_t nb12 = cnb1[2];
|
| 5128 |
size_t nb13 = cnb1[3];
|
| 5129 |
|
| 5130 |
-
|
| 5131 |
-
size_t s1 = nb1 / sizeof(
|
| 5132 |
-
size_t s2 = nb2 / sizeof(
|
| 5133 |
-
size_t s3 = nb3 / sizeof(
|
| 5134 |
|
| 5135 |
-
|
| 5136 |
size_t s11 = nb11 / sizeof(src1_t);
|
| 5137 |
size_t s12 = nb12 / sizeof(src1_t);
|
| 5138 |
size_t s13 = nb13 / sizeof(src1_t);
|
| 5139 |
|
|
|
|
|
|
|
| 5140 |
|
| 5141 |
const int block_size = 128;
|
| 5142 |
|
|
@@ -5174,6 +5455,13 @@ struct bin_bcast_cuda {
|
|
| 5174 |
}
|
| 5175 |
};
|
| 5176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5177 |
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5178 |
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
| 5179 |
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
@@ -5184,11 +5472,26 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
|
| 5184 |
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
| 5185 |
}
|
| 5186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5187 |
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5188 |
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
| 5189 |
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
| 5190 |
}
|
| 5191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5192 |
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5193 |
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
| 5194 |
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
@@ -5205,6 +5508,38 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
|
|
| 5205 |
}
|
| 5206 |
}
|
| 5207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5208 |
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
| 5209 |
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
| 5210 |
if (ncols < 1024) {
|
|
@@ -6167,13 +6502,14 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
|
|
| 6167 |
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 6168 |
}
|
| 6169 |
|
| 6170 |
-
static void im2col_f32_f16_cuda(const float
|
| 6171 |
-
int
|
| 6172 |
-
int
|
| 6173 |
-
int s0,
|
| 6174 |
-
|
| 6175 |
-
|
| 6176 |
-
|
|
|
|
| 6177 |
}
|
| 6178 |
|
| 6179 |
// buffer pool for cuda
|
|
@@ -6447,36 +6783,34 @@ static void ggml_cuda_op_get_rows(
|
|
| 6447 |
|
| 6448 |
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
| 6449 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 6450 |
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 6451 |
-
GGML_ASSERT(ggml_is_contiguous(src1));
|
| 6452 |
-
GGML_ASSERT(ggml_is_contiguous(dst));
|
| 6453 |
|
| 6454 |
-
|
| 6455 |
-
|
|
|
|
| 6456 |
|
| 6457 |
const int32_t * src1_i32 = (const int32_t *) src1_d;
|
| 6458 |
|
| 6459 |
switch (src0->type) {
|
| 6460 |
case GGML_TYPE_F16:
|
| 6461 |
-
|
| 6462 |
break;
|
| 6463 |
case GGML_TYPE_F32:
|
| 6464 |
-
|
| 6465 |
break;
|
| 6466 |
case GGML_TYPE_Q4_0:
|
| 6467 |
-
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(
|
| 6468 |
break;
|
| 6469 |
case GGML_TYPE_Q4_1:
|
| 6470 |
-
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(
|
| 6471 |
break;
|
| 6472 |
case GGML_TYPE_Q5_0:
|
| 6473 |
-
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(
|
| 6474 |
break;
|
| 6475 |
case GGML_TYPE_Q5_1:
|
| 6476 |
-
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(
|
| 6477 |
break;
|
| 6478 |
case GGML_TYPE_Q8_0:
|
| 6479 |
-
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(
|
| 6480 |
break;
|
| 6481 |
default:
|
| 6482 |
// TODO: k-quants
|
|
@@ -6522,6 +6856,25 @@ inline void ggml_cuda_op_add(
|
|
| 6522 |
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 6523 |
}
|
| 6524 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6525 |
inline void ggml_cuda_op_mul(
|
| 6526 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6527 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
@@ -6564,6 +6917,34 @@ inline void ggml_cuda_op_silu(
|
|
| 6564 |
(void) src1_dd;
|
| 6565 |
}
|
| 6566 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6567 |
inline void ggml_cuda_op_relu(
|
| 6568 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6569 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
@@ -6578,6 +6959,23 @@ inline void ggml_cuda_op_relu(
|
|
| 6578 |
(void) src1_dd;
|
| 6579 |
}
|
| 6580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6581 |
inline void ggml_cuda_op_sqr(
|
| 6582 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6583 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
@@ -6612,6 +7010,71 @@ inline void ggml_cuda_op_norm(
|
|
| 6612 |
(void) src1_dd;
|
| 6613 |
}
|
| 6614 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6615 |
inline void ggml_cuda_op_rms_norm(
|
| 6616 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6617 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
@@ -7126,7 +7589,6 @@ inline void ggml_cuda_op_im2col(
|
|
| 7126 |
|
| 7127 |
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
| 7128 |
|
| 7129 |
-
const int64_t N = src1->ne[is_2D ? 3 : 2];
|
| 7130 |
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
| 7131 |
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
| 7132 |
const int64_t IW = src1->ne[0];
|
|
@@ -7137,17 +7599,15 @@ inline void ggml_cuda_op_im2col(
|
|
| 7137 |
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
| 7138 |
const int64_t OW = dst->ne[1];
|
| 7139 |
|
| 7140 |
-
const size_t
|
| 7141 |
-
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
| 7142 |
|
| 7143 |
-
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
|
| 7144 |
-
OH, IW, IH, OW, IC, KH, KW, N,
|
| 7145 |
-
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
|
| 7146 |
|
| 7147 |
(void) src0;
|
| 7148 |
(void) src0_dd;
|
| 7149 |
}
|
| 7150 |
|
|
|
|
| 7151 |
inline void ggml_cuda_op_sum_rows(
|
| 7152 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 7153 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
@@ -7696,6 +8156,10 @@ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
|
| 7696 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
|
| 7697 |
}
|
| 7698 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7699 |
static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 7700 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
|
| 7701 |
}
|
|
@@ -7712,10 +8176,22 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
|
|
| 7712 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
|
| 7713 |
}
|
| 7714 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7715 |
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 7716 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
|
| 7717 |
}
|
| 7718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7719 |
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 7720 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
|
| 7721 |
}
|
|
@@ -7724,6 +8200,22 @@ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, g
|
|
| 7724 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
|
| 7725 |
}
|
| 7726 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7727 |
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 7728 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
|
| 7729 |
}
|
|
@@ -8234,36 +8726,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
|
|
| 8234 |
}
|
| 8235 |
#endif
|
| 8236 |
|
| 8237 |
-
static void ggml_cuda_mul_mat_id(const ggml_tensor *
|
| 8238 |
#if 0
|
| 8239 |
-
//#ifdef CUDA_USE_TENSOR_CORES
|
| 8240 |
-
// const bool use_tensor_cores = true;
|
| 8241 |
-
//#else
|
| 8242 |
-
// const bool use_tensor_cores = false;
|
| 8243 |
-
//#endif
|
| 8244 |
-
|
| 8245 |
ggml_cuda_mul_mat_id_cublas(dst);
|
| 8246 |
-
|
| 8247 |
// TODO: mmq/mmv support
|
| 8248 |
-
#
|
| 8249 |
-
const struct ggml_tensor * ids = dst->src[0];
|
| 8250 |
-
const struct ggml_tensor * src1 = dst->src[1];
|
| 8251 |
-
const int id = dst->op_params[0];
|
| 8252 |
|
| 8253 |
-
|
| 8254 |
|
| 8255 |
-
|
| 8256 |
-
|
| 8257 |
-
|
| 8258 |
|
| 8259 |
-
|
| 8260 |
-
const struct ggml_tensor * src0 = dst->src[a_id + 2];
|
| 8261 |
|
| 8262 |
-
|
| 8263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8264 |
|
| 8265 |
-
|
| 8266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8267 |
}
|
| 8268 |
|
| 8269 |
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
|
@@ -8683,6 +9208,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|
| 8683 |
case GGML_OP_ADD:
|
| 8684 |
func = ggml_cuda_add;
|
| 8685 |
break;
|
|
|
|
|
|
|
|
|
|
| 8686 |
case GGML_OP_MUL:
|
| 8687 |
func = ggml_cuda_mul;
|
| 8688 |
break;
|
|
@@ -8697,6 +9225,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|
| 8697 |
case GGML_UNARY_OP_SILU:
|
| 8698 |
func = ggml_cuda_silu;
|
| 8699 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8700 |
case GGML_UNARY_OP_RELU:
|
| 8701 |
func = ggml_cuda_relu;
|
| 8702 |
break;
|
|
@@ -8707,6 +9241,21 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|
| 8707 |
case GGML_OP_NORM:
|
| 8708 |
func = ggml_cuda_norm;
|
| 8709 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8710 |
case GGML_OP_RMS_NORM:
|
| 8711 |
func = ggml_cuda_rms_norm;
|
| 8712 |
break;
|
|
@@ -8729,9 +9278,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|
| 8729 |
func = ggml_cuda_sqr;
|
| 8730 |
break;
|
| 8731 |
case GGML_OP_CLAMP:
|
| 8732 |
-
if (!any_on_device) {
|
| 8733 |
-
return false;
|
| 8734 |
-
}
|
| 8735 |
func = ggml_cuda_clamp;
|
| 8736 |
break;
|
| 8737 |
case GGML_OP_CPY:
|
|
@@ -8740,6 +9286,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|
| 8740 |
case GGML_OP_CONT:
|
| 8741 |
func = ggml_cuda_dup;
|
| 8742 |
break;
|
|
|
|
| 8743 |
case GGML_OP_RESHAPE:
|
| 8744 |
case GGML_OP_VIEW:
|
| 8745 |
case GGML_OP_PERMUTE:
|
|
@@ -9159,6 +9706,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
|
| 9159 |
case GGML_UNARY_OP_GELU:
|
| 9160 |
case GGML_UNARY_OP_SILU:
|
| 9161 |
case GGML_UNARY_OP_RELU:
|
|
|
|
|
|
|
| 9162 |
return true;
|
| 9163 |
default:
|
| 9164 |
return false;
|
|
@@ -9181,6 +9730,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
|
| 9181 |
}
|
| 9182 |
return true;
|
| 9183 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9184 |
case GGML_OP_NONE:
|
| 9185 |
case GGML_OP_RESHAPE:
|
| 9186 |
case GGML_OP_VIEW:
|
|
@@ -9188,7 +9776,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
|
| 9188 |
case GGML_OP_TRANSPOSE:
|
| 9189 |
case GGML_OP_NORM:
|
| 9190 |
case GGML_OP_REPEAT:
|
| 9191 |
-
case GGML_OP_GET_ROWS:
|
| 9192 |
case GGML_OP_DUP:
|
| 9193 |
case GGML_OP_ADD:
|
| 9194 |
case GGML_OP_MUL:
|
|
@@ -9197,7 +9784,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
|
| 9197 |
case GGML_OP_SCALE:
|
| 9198 |
case GGML_OP_SQR:
|
| 9199 |
case GGML_OP_CLAMP:
|
| 9200 |
-
case GGML_OP_CPY:
|
| 9201 |
case GGML_OP_CONT:
|
| 9202 |
case GGML_OP_DIAG_MASK_INF:
|
| 9203 |
case GGML_OP_SOFT_MAX:
|
|
@@ -9206,6 +9792,12 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
|
| 9206 |
case GGML_OP_IM2COL:
|
| 9207 |
case GGML_OP_SUM_ROWS:
|
| 9208 |
case GGML_OP_ARGSORT:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9209 |
return true;
|
| 9210 |
default:
|
| 9211 |
return false;
|
|
@@ -9264,7 +9856,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
|
|
| 9264 |
UNUSED(params);
|
| 9265 |
}
|
| 9266 |
|
| 9267 |
-
extern "C" int ggml_backend_cuda_reg_devices()
|
|
|
|
|
|
|
| 9268 |
int device_count = ggml_cuda_get_device_count();
|
| 9269 |
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
|
| 9270 |
for (int i = 0; i < device_count; i++) {
|
|
|
|
| 1 |
#include <algorithm>
|
| 2 |
+
#include <assert.h>
|
| 3 |
+
#include <atomic>
|
| 4 |
+
#include <cinttypes>
|
| 5 |
#include <cstddef>
|
| 6 |
#include <cstdint>
|
|
|
|
| 7 |
#include <float.h>
|
| 8 |
#include <limits>
|
| 9 |
#include <stdint.h>
|
| 10 |
#include <stdio.h>
|
| 11 |
+
#include <vector>
|
| 12 |
+
|
| 13 |
|
| 14 |
#if defined(GGML_USE_HIPBLAS)
|
| 15 |
#include <hip/hip_runtime.h>
|
|
|
|
| 439 |
|
| 440 |
#define CUDA_GELU_BLOCK_SIZE 256
|
| 441 |
#define CUDA_SILU_BLOCK_SIZE 256
|
| 442 |
+
#define CUDA_TANH_BLOCK_SIZE 256
|
| 443 |
#define CUDA_RELU_BLOCK_SIZE 256
|
| 444 |
#define CUDA_SQR_BLOCK_SIZE 256
|
| 445 |
#define CUDA_CPY_BLOCK_SIZE 32
|
|
|
|
| 452 |
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
| 453 |
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
| 454 |
#define CUDA_GET_ROWS_BLOCK_SIZE 256
|
| 455 |
+
#define CUDA_UPSCALE_BLOCK_SIZE 256
|
| 456 |
+
#define CUDA_CONCAT_BLOCK_SIZE 256
|
| 457 |
+
#define CUDA_PAD_BLOCK_SIZE 256
|
| 458 |
+
#define CUDA_ACC_BLOCK_SIZE 256
|
| 459 |
+
#define CUDA_IM2COL_BLOCK_SIZE 256
|
| 460 |
|
| 461 |
// dmmv = dequantize_mul_mat_vec
|
| 462 |
#ifndef GGML_CUDA_DMMV_X
|
|
|
|
| 618 |
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
| 619 |
}
|
| 620 |
|
| 621 |
+
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
| 622 |
+
const int ne10, const int ne11, const int ne12,
|
| 623 |
+
const int nb1, const int nb2, int offset) {
|
| 624 |
+
const int i = blockDim.x * blockIdx.x + threadIdx.x;
|
| 625 |
+
if (i >= ne) {
|
| 626 |
+
return;
|
| 627 |
+
}
|
| 628 |
+
int src1_idx = i - offset;
|
| 629 |
+
int oz = src1_idx / nb2;
|
| 630 |
+
int oy = (src1_idx - (oz * nb2)) / nb1;
|
| 631 |
+
int ox = src1_idx % nb1;
|
| 632 |
+
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
|
| 633 |
+
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
|
| 634 |
+
} else {
|
| 635 |
+
dst[i] = x[i];
|
| 636 |
+
}
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
|
| 640 |
const float GELU_COEF_A = 0.044715f;
|
| 641 |
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
|
|
| 658 |
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
| 659 |
}
|
| 660 |
|
| 661 |
+
static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
|
| 662 |
+
const float GELU_QUICK_COEF = -1.702f;
|
| 663 |
+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
| 664 |
+
if (i >= k) {
|
| 665 |
+
return;
|
| 666 |
+
}
|
| 667 |
+
dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
static __global__ void tanh_f32(const float *x, float *dst, int k) {
|
| 671 |
+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
| 672 |
+
if (i >= k) {
|
| 673 |
+
return;
|
| 674 |
+
}
|
| 675 |
+
dst[i] = tanhf(x[i]);
|
| 676 |
+
}
|
| 677 |
+
|
| 678 |
static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
| 679 |
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
| 680 |
|
|
|
|
| 684 |
dst[i] = fmaxf(x[i], 0);
|
| 685 |
}
|
| 686 |
|
| 687 |
+
static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
|
| 688 |
+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
| 689 |
+
if (i >= k) {
|
| 690 |
+
return;
|
| 691 |
+
}
|
| 692 |
+
dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
| 696 |
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
| 697 |
|
|
|
|
| 737 |
}
|
| 738 |
}
|
| 739 |
|
| 740 |
+
static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) {
|
| 741 |
+
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 742 |
+
if (nidx >= ne0) {
|
| 743 |
+
return;
|
| 744 |
+
}
|
| 745 |
+
// operation
|
| 746 |
+
int offset_dst =
|
| 747 |
+
nidx +
|
| 748 |
+
blockIdx.y * ne0 +
|
| 749 |
+
blockIdx.z * ne0 * gridDim.y;
|
| 750 |
+
if (blockIdx.z < ne02) { // src0
|
| 751 |
+
int offset_src =
|
| 752 |
+
nidx +
|
| 753 |
+
blockIdx.y * ne0 +
|
| 754 |
+
blockIdx.z * ne0 * gridDim.y;
|
| 755 |
+
dst[offset_dst] = x[offset_src];
|
| 756 |
+
} else {
|
| 757 |
+
int offset_src =
|
| 758 |
+
nidx +
|
| 759 |
+
blockIdx.y * ne0 +
|
| 760 |
+
(blockIdx.z - ne02) * ne0 * gridDim.y;
|
| 761 |
+
dst[offset_dst] = y[offset_src];
|
| 762 |
+
}
|
| 763 |
+
}
|
| 764 |
+
|
| 765 |
+
static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
|
| 766 |
+
int ne0 = ne00 * scale_factor;
|
| 767 |
+
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 768 |
+
if (nidx >= ne0) {
|
| 769 |
+
return;
|
| 770 |
+
}
|
| 771 |
+
// operation
|
| 772 |
+
int i00 = nidx / scale_factor;
|
| 773 |
+
int i01 = blockIdx.y / scale_factor;
|
| 774 |
+
int offset_src =
|
| 775 |
+
i00 +
|
| 776 |
+
i01 * ne00 +
|
| 777 |
+
blockIdx.z * nb02;
|
| 778 |
+
int offset_dst =
|
| 779 |
+
nidx +
|
| 780 |
+
blockIdx.y * ne0 +
|
| 781 |
+
blockIdx.z * ne0 * gridDim.y;
|
| 782 |
+
dst[offset_dst] = x[offset_src];
|
| 783 |
+
}
|
| 784 |
+
|
| 785 |
+
static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
|
| 786 |
+
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 787 |
+
if (nidx >= ne0) {
|
| 788 |
+
return;
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
// operation
|
| 792 |
+
int offset_dst =
|
| 793 |
+
nidx +
|
| 794 |
+
blockIdx.y * ne0 +
|
| 795 |
+
blockIdx.z * ne0 * gridDim.y;
|
| 796 |
+
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
|
| 797 |
+
int offset_src =
|
| 798 |
+
nidx +
|
| 799 |
+
blockIdx.y * ne00 +
|
| 800 |
+
blockIdx.z * ne00 * ne01;
|
| 801 |
+
dst[offset_dst] = x[offset_src];
|
| 802 |
+
} else {
|
| 803 |
+
dst[offset_dst] = 0.0f;
|
| 804 |
+
}
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
template <int block_size>
|
| 808 |
+
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
|
| 809 |
+
int start = blockIdx.x * group_size;
|
| 810 |
+
int end = start + group_size;
|
| 811 |
+
|
| 812 |
+
start += threadIdx.x;
|
| 813 |
+
|
| 814 |
+
if (end >= ne_elements) {
|
| 815 |
+
end = ne_elements;
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
float tmp = 0.0f; // partial sum for thread in warp
|
| 819 |
+
|
| 820 |
+
for (int j = start; j < end; j += block_size) {
|
| 821 |
+
tmp += x[j];
|
| 822 |
+
}
|
| 823 |
+
|
| 824 |
+
tmp = warp_reduce_sum(tmp);
|
| 825 |
+
if (block_size > WARP_SIZE) {
|
| 826 |
+
__shared__ float s_sum[32];
|
| 827 |
+
int warp_id = threadIdx.x / WARP_SIZE;
|
| 828 |
+
int lane_id = threadIdx.x % WARP_SIZE;
|
| 829 |
+
if (lane_id == 0) {
|
| 830 |
+
s_sum[warp_id] = tmp;
|
| 831 |
+
}
|
| 832 |
+
__syncthreads();
|
| 833 |
+
tmp = s_sum[lane_id];
|
| 834 |
+
tmp = warp_reduce_sum(tmp);
|
| 835 |
+
}
|
| 836 |
+
|
| 837 |
+
float mean = tmp / group_size;
|
| 838 |
+
tmp = 0.0f;
|
| 839 |
+
|
| 840 |
+
for (int j = start; j < end; j += block_size) {
|
| 841 |
+
float xi = x[j] - mean;
|
| 842 |
+
dst[j] = xi;
|
| 843 |
+
tmp += xi * xi;
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
tmp = warp_reduce_sum(tmp);
|
| 847 |
+
if (block_size > WARP_SIZE) {
|
| 848 |
+
__shared__ float s_sum[32];
|
| 849 |
+
int warp_id = threadIdx.x / WARP_SIZE;
|
| 850 |
+
int lane_id = threadIdx.x % WARP_SIZE;
|
| 851 |
+
if (lane_id == 0) {
|
| 852 |
+
s_sum[warp_id] = tmp;
|
| 853 |
+
}
|
| 854 |
+
__syncthreads();
|
| 855 |
+
tmp = s_sum[lane_id];
|
| 856 |
+
tmp = warp_reduce_sum(tmp);
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
float variance = tmp / group_size;
|
| 860 |
+
float scale = rsqrtf(variance + eps);
|
| 861 |
+
for (int j = start; j < end; j += block_size) {
|
| 862 |
+
dst[j] *= scale;
|
| 863 |
+
}
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
template <int block_size>
|
| 867 |
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
| 868 |
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
|
|
|
| 1861 |
}
|
| 1862 |
|
| 1863 |
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
| 1864 |
+
static __global__ void k_get_rows(
|
| 1865 |
+
const void * src0, const int32_t * src1, dst_t * dst,
|
| 1866 |
+
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
|
| 1867 |
+
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
|
| 1868 |
+
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
|
| 1869 |
+
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
|
| 1870 |
+
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
|
| 1871 |
+
|
| 1872 |
+
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
| 1873 |
+
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
| 1874 |
+
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
| 1875 |
+
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
| 1876 |
+
|
| 1877 |
+
if (i00 >= ne00) {
|
| 1878 |
return;
|
| 1879 |
}
|
| 1880 |
|
| 1881 |
+
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
| 1882 |
|
| 1883 |
+
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
| 1884 |
+
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
|
|
|
| 1885 |
|
| 1886 |
+
const int ib = i00/qk; // block index
|
| 1887 |
+
const int iqs = (i00%qk)/qr; // quant index
|
| 1888 |
+
const int iybs = i00 - i00%qk; // dst block start index
|
| 1889 |
const int y_offset = qr == 1 ? 1 : qk/2;
|
| 1890 |
|
| 1891 |
// dequantize
|
| 1892 |
dfloat2 v;
|
| 1893 |
+
dequantize_kernel(src0_row, ib, iqs, v);
|
| 1894 |
+
|
| 1895 |
+
dst_row[iybs + iqs + 0] = v.x;
|
| 1896 |
+
dst_row[iybs + iqs + y_offset] = v.y;
|
| 1897 |
+
}
|
| 1898 |
+
|
| 1899 |
+
template<typename src0_t, typename dst_t>
|
| 1900 |
+
static __global__ void k_get_rows_float(
|
| 1901 |
+
const src0_t * src0, const int32_t * src1, dst_t * dst,
|
| 1902 |
+
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
|
| 1903 |
+
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
|
| 1904 |
+
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
|
| 1905 |
+
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
|
| 1906 |
+
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
|
| 1907 |
+
|
| 1908 |
+
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
|
| 1909 |
+
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
| 1910 |
+
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
| 1911 |
+
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
| 1912 |
+
|
| 1913 |
+
if (i00 >= ne00) {
|
| 1914 |
+
return;
|
| 1915 |
+
}
|
| 1916 |
+
|
| 1917 |
+
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
| 1918 |
+
|
| 1919 |
+
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
| 1920 |
+
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
|
| 1921 |
|
| 1922 |
+
dst_row[i00] = src0_row[i00];
|
|
|
|
| 1923 |
}
|
| 1924 |
|
| 1925 |
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
|
|
|
| 5246 |
|
| 5247 |
static __global__ void im2col_f32_f16(
|
| 5248 |
const float * x, half * dst,
|
| 5249 |
+
int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
|
| 5250 |
int s0, int s1, int p0, int p1, int d0, int d1) {
|
| 5251 |
+
const int i = threadIdx.x + blockIdx.x * blockDim.x;
|
| 5252 |
+
if (i >= pelements) {
|
| 5253 |
+
return;
|
| 5254 |
+
}
|
| 5255 |
+
|
| 5256 |
+
const int ksize = OW * (KH > 1 ? KW : 1);
|
| 5257 |
+
const int kx = i / ksize;
|
| 5258 |
+
const int kd = kx * ksize;
|
| 5259 |
+
const int ky = (i - kd) / OW;
|
| 5260 |
+
const int ix = i % OW;
|
| 5261 |
+
|
| 5262 |
+
const int iiw = ix * s0 + kx * d0 - p0;
|
| 5263 |
+
const int iih = blockIdx.y * s1 + ky * d1 - p1;
|
| 5264 |
|
| 5265 |
const int offset_dst =
|
| 5266 |
+
(blockIdx.y * OW + ix) * CHW +
|
| 5267 |
+
(blockIdx.z * (KW * KH) + ky * KW + kx);
|
| 5268 |
|
| 5269 |
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
| 5270 |
dst[offset_dst] = __float2half(0.0f);
|
| 5271 |
} else {
|
| 5272 |
+
const int offset_src = blockIdx.z * offset_delta;
|
| 5273 |
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
| 5274 |
}
|
| 5275 |
}
|
| 5276 |
|
| 5277 |
template<int qk, int qr, dequantize_kernel_t dq>
|
| 5278 |
+
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 5279 |
+
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
| 5280 |
+
|
| 5281 |
+
GGML_TENSOR_BINARY_OP_LOCALS
|
| 5282 |
+
|
| 5283 |
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
| 5284 |
+
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
| 5285 |
+
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
| 5286 |
+
|
| 5287 |
+
// strides in elements
|
| 5288 |
+
//const size_t s0 = nb0 / ggml_element_size(dst);
|
| 5289 |
+
const size_t s1 = nb1 / ggml_element_size(dst);
|
| 5290 |
+
const size_t s2 = nb2 / ggml_element_size(dst);
|
| 5291 |
+
const size_t s3 = nb3 / ggml_element_size(dst);
|
| 5292 |
+
|
| 5293 |
+
const size_t s10 = nb10 / ggml_element_size(src1);
|
| 5294 |
+
const size_t s11 = nb11 / ggml_element_size(src1);
|
| 5295 |
+
const size_t s12 = nb12 / ggml_element_size(src1);
|
| 5296 |
+
//const size_t s13 = nb13 / ggml_element_size(src1);
|
| 5297 |
+
|
| 5298 |
+
GGML_ASSERT(ne00 % 2 == 0);
|
| 5299 |
+
|
| 5300 |
+
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
| 5301 |
+
src0_dd, src1_dd, dst_dd,
|
| 5302 |
+
ne00, /*ne01, ne02, ne03,*/
|
| 5303 |
+
/*ne10, ne11,*/ ne12, /*ne13,*/
|
| 5304 |
+
/* s0,*/ s1, s2, s3,
|
| 5305 |
+
/* nb00,*/ nb01, nb02, nb03,
|
| 5306 |
+
s10, s11, s12/*, s13*/);
|
| 5307 |
+
|
| 5308 |
+
(void) dst;
|
| 5309 |
+
}
|
| 5310 |
+
|
| 5311 |
+
template<typename src0_t>
|
| 5312 |
+
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 5313 |
+
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
| 5314 |
+
|
| 5315 |
+
GGML_TENSOR_BINARY_OP_LOCALS
|
| 5316 |
+
|
| 5317 |
+
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
| 5318 |
+
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
| 5319 |
+
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
| 5320 |
+
|
| 5321 |
+
// strides in elements
|
| 5322 |
+
//const size_t s0 = nb0 / ggml_element_size(dst);
|
| 5323 |
+
const size_t s1 = nb1 / ggml_element_size(dst);
|
| 5324 |
+
const size_t s2 = nb2 / ggml_element_size(dst);
|
| 5325 |
+
const size_t s3 = nb3 / ggml_element_size(dst);
|
| 5326 |
+
|
| 5327 |
+
const size_t s10 = nb10 / ggml_element_size(src1);
|
| 5328 |
+
const size_t s11 = nb11 / ggml_element_size(src1);
|
| 5329 |
+
const size_t s12 = nb12 / ggml_element_size(src1);
|
| 5330 |
+
//const size_t s13 = nb13 / ggml_element_size(src1);
|
| 5331 |
+
|
| 5332 |
+
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
| 5333 |
+
src0_dd, src1_dd, dst_dd,
|
| 5334 |
+
ne00, /*ne01, ne02, ne03,*/
|
| 5335 |
+
/*ne10, ne11,*/ ne12, /*ne13,*/
|
| 5336 |
+
/* s0,*/ s1, s2, s3,
|
| 5337 |
+
/* nb00,*/ nb01, nb02, nb03,
|
| 5338 |
+
s10, s11, s12/*, s13*/);
|
| 5339 |
+
|
| 5340 |
+
(void) dst;
|
| 5341 |
}
|
| 5342 |
|
| 5343 |
template<float (*bin_op)(const float, const float)>
|
|
|
|
| 5349 |
|
| 5350 |
GGML_TENSOR_BINARY_OP_LOCALS
|
| 5351 |
|
|
|
|
| 5352 |
int nr0 = ne10/ne0;
|
| 5353 |
int nr1 = ne11/ne1;
|
| 5354 |
int nr2 = ne12/ne2;
|
|
|
|
| 5396 |
int64_t ne12 = cne1[2];
|
| 5397 |
int64_t ne13 = cne1[3];
|
| 5398 |
|
| 5399 |
+
size_t nb0 = cnb0[0];
|
| 5400 |
size_t nb1 = cnb0[1];
|
| 5401 |
size_t nb2 = cnb0[2];
|
| 5402 |
size_t nb3 = cnb0[3];
|
| 5403 |
|
| 5404 |
+
size_t nb10 = cnb1[0];
|
| 5405 |
size_t nb11 = cnb1[1];
|
| 5406 |
size_t nb12 = cnb1[2];
|
| 5407 |
size_t nb13 = cnb1[3];
|
| 5408 |
|
| 5409 |
+
size_t s0 = nb0 / sizeof(dst_t);
|
| 5410 |
+
size_t s1 = nb1 / sizeof(dst_t);
|
| 5411 |
+
size_t s2 = nb2 / sizeof(dst_t);
|
| 5412 |
+
size_t s3 = nb3 / sizeof(dst_t);
|
| 5413 |
|
| 5414 |
+
size_t s10 = nb10 / sizeof(src1_t);
|
| 5415 |
size_t s11 = nb11 / sizeof(src1_t);
|
| 5416 |
size_t s12 = nb12 / sizeof(src1_t);
|
| 5417 |
size_t s13 = nb13 / sizeof(src1_t);
|
| 5418 |
|
| 5419 |
+
GGML_ASSERT(s0 == 1);
|
| 5420 |
+
GGML_ASSERT(s10 == 1);
|
| 5421 |
|
| 5422 |
const int block_size = 128;
|
| 5423 |
|
|
|
|
| 5455 |
}
|
| 5456 |
};
|
| 5457 |
|
| 5458 |
+
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
|
| 5459 |
+
const int ne10, const int ne11, const int ne12,
|
| 5460 |
+
const int nb1, const int nb2, const int offset, cudaStream_t stream) {
|
| 5461 |
+
int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
|
| 5462 |
+
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
|
| 5463 |
+
}
|
| 5464 |
+
|
| 5465 |
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5466 |
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
| 5467 |
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
| 5472 |
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
| 5473 |
}
|
| 5474 |
|
| 5475 |
+
static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5476 |
+
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
| 5477 |
+
gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
| 5478 |
+
}
|
| 5479 |
+
|
| 5480 |
+
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5481 |
+
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
|
| 5482 |
+
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
| 5483 |
+
}
|
| 5484 |
+
|
| 5485 |
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5486 |
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
| 5487 |
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
| 5488 |
}
|
| 5489 |
|
| 5490 |
+
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
| 5491 |
+
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
| 5492 |
+
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
| 5493 |
+
}
|
| 5494 |
+
|
| 5495 |
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
| 5496 |
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
| 5497 |
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
| 5508 |
}
|
| 5509 |
}
|
| 5510 |
|
| 5511 |
+
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
|
| 5512 |
+
static const float eps = 1e-6f;
|
| 5513 |
+
if (group_size < 1024) {
|
| 5514 |
+
const dim3 block_dims(WARP_SIZE, 1, 1);
|
| 5515 |
+
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
| 5516 |
+
} else {
|
| 5517 |
+
const dim3 block_dims(1024, 1, 1);
|
| 5518 |
+
group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
| 5519 |
+
}
|
| 5520 |
+
}
|
| 5521 |
+
|
| 5522 |
+
static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
|
| 5523 |
+
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
|
| 5524 |
+
dim3 gridDim(num_blocks, ne1, ne2);
|
| 5525 |
+
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
|
| 5526 |
+
}
|
| 5527 |
+
|
| 5528 |
+
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
|
| 5529 |
+
int ne0 = (ne00 * scale_factor);
|
| 5530 |
+
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
| 5531 |
+
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
|
| 5532 |
+
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
|
| 5533 |
+
}
|
| 5534 |
+
|
| 5535 |
+
static void pad_f32_cuda(const float * x, float * dst,
|
| 5536 |
+
const int ne00, const int ne01, const int ne02,
|
| 5537 |
+
const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
|
| 5538 |
+
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
| 5539 |
+
dim3 gridDim(num_blocks, ne1, ne2);
|
| 5540 |
+
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
|
| 5541 |
+
}
|
| 5542 |
+
|
| 5543 |
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
| 5544 |
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
| 5545 |
if (ncols < 1024) {
|
|
|
|
| 6502 |
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 6503 |
}
|
| 6504 |
|
| 6505 |
+
static void im2col_f32_f16_cuda(const float* x, half* dst,
|
| 6506 |
+
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
|
| 6507 |
+
int offset_delta,
|
| 6508 |
+
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
|
| 6509 |
+
const int parallel_elements = OW * KW * KH;
|
| 6510 |
+
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
|
| 6511 |
+
dim3 block_nums(num_blocks, OH, IC);
|
| 6512 |
+
im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
| 6513 |
}
|
| 6514 |
|
| 6515 |
// buffer pool for cuda
|
|
|
|
| 6783 |
|
| 6784 |
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
| 6785 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
|
|
|
|
|
|
|
|
| 6786 |
|
| 6787 |
+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
| 6788 |
+
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
|
| 6789 |
+
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
|
| 6790 |
|
| 6791 |
const int32_t * src1_i32 = (const int32_t *) src1_d;
|
| 6792 |
|
| 6793 |
switch (src0->type) {
|
| 6794 |
case GGML_TYPE_F16:
|
| 6795 |
+
get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
|
| 6796 |
break;
|
| 6797 |
case GGML_TYPE_F32:
|
| 6798 |
+
get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 6799 |
break;
|
| 6800 |
case GGML_TYPE_Q4_0:
|
| 6801 |
+
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 6802 |
break;
|
| 6803 |
case GGML_TYPE_Q4_1:
|
| 6804 |
+
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 6805 |
break;
|
| 6806 |
case GGML_TYPE_Q5_0:
|
| 6807 |
+
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 6808 |
break;
|
| 6809 |
case GGML_TYPE_Q5_1:
|
| 6810 |
+
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 6811 |
break;
|
| 6812 |
case GGML_TYPE_Q8_0:
|
| 6813 |
+
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 6814 |
break;
|
| 6815 |
default:
|
| 6816 |
// TODO: k-quants
|
|
|
|
| 6856 |
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 6857 |
}
|
| 6858 |
|
| 6859 |
+
inline void ggml_cuda_op_acc(
|
| 6860 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6861 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 6862 |
+
|
| 6863 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 6864 |
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 6865 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 6866 |
+
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
| 6867 |
+
|
| 6868 |
+
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
| 6869 |
+
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
| 6870 |
+
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
| 6871 |
+
int offset = dst->op_params[3] / 4; // offset in bytes
|
| 6872 |
+
|
| 6873 |
+
acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
|
| 6874 |
+
|
| 6875 |
+
(void) dst;
|
| 6876 |
+
}
|
| 6877 |
+
|
| 6878 |
inline void ggml_cuda_op_mul(
|
| 6879 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6880 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
|
|
| 6917 |
(void) src1_dd;
|
| 6918 |
}
|
| 6919 |
|
| 6920 |
+
inline void ggml_cuda_op_gelu_quick(
|
| 6921 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6922 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 6923 |
+
|
| 6924 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 6925 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 6926 |
+
|
| 6927 |
+
gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 6928 |
+
|
| 6929 |
+
(void) src1;
|
| 6930 |
+
(void) dst;
|
| 6931 |
+
(void) src1_dd;
|
| 6932 |
+
}
|
| 6933 |
+
|
| 6934 |
+
inline void ggml_cuda_op_tanh(
|
| 6935 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6936 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 6937 |
+
|
| 6938 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 6939 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 6940 |
+
|
| 6941 |
+
tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 6942 |
+
|
| 6943 |
+
(void) src1;
|
| 6944 |
+
(void) dst;
|
| 6945 |
+
(void) src1_dd;
|
| 6946 |
+
}
|
| 6947 |
+
|
| 6948 |
inline void ggml_cuda_op_relu(
|
| 6949 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6950 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
|
|
| 6959 |
(void) src1_dd;
|
| 6960 |
}
|
| 6961 |
|
| 6962 |
+
inline void ggml_cuda_op_leaky_relu(
|
| 6963 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6964 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 6965 |
+
|
| 6966 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 6967 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 6968 |
+
|
| 6969 |
+
float negative_slope;
|
| 6970 |
+
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
| 6971 |
+
|
| 6972 |
+
leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
|
| 6973 |
+
|
| 6974 |
+
(void) src1;
|
| 6975 |
+
(void) dst;
|
| 6976 |
+
(void) src1_dd;
|
| 6977 |
+
}
|
| 6978 |
+
|
| 6979 |
inline void ggml_cuda_op_sqr(
|
| 6980 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 6981 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
|
|
| 7010 |
(void) src1_dd;
|
| 7011 |
}
|
| 7012 |
|
| 7013 |
+
|
| 7014 |
+
inline void ggml_cuda_op_group_norm(
|
| 7015 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 7016 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 7017 |
+
|
| 7018 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 7019 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 7020 |
+
|
| 7021 |
+
int num_groups = dst->op_params[0];
|
| 7022 |
+
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
| 7023 |
+
group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
|
| 7024 |
+
|
| 7025 |
+
(void) src1;
|
| 7026 |
+
(void) dst;
|
| 7027 |
+
(void) src1_dd;
|
| 7028 |
+
}
|
| 7029 |
+
|
| 7030 |
+
inline void ggml_cuda_op_concat(
|
| 7031 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 7032 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 7033 |
+
|
| 7034 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 7035 |
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 7036 |
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 7037 |
+
|
| 7038 |
+
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
|
| 7039 |
+
concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
|
| 7040 |
+
}
|
| 7041 |
+
|
| 7042 |
+
(void) src1;
|
| 7043 |
+
(void) dst;
|
| 7044 |
+
}
|
| 7045 |
+
|
| 7046 |
+
inline void ggml_cuda_op_upscale(
|
| 7047 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 7048 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 7049 |
+
|
| 7050 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 7051 |
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 7052 |
+
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
| 7053 |
+
|
| 7054 |
+
const int scale_factor = dst->op_params[0];
|
| 7055 |
+
|
| 7056 |
+
upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
|
| 7057 |
+
|
| 7058 |
+
(void) src1;
|
| 7059 |
+
(void) dst;
|
| 7060 |
+
}
|
| 7061 |
+
|
| 7062 |
+
inline void ggml_cuda_op_pad(
|
| 7063 |
+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 7064 |
+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
| 7065 |
+
|
| 7066 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 7067 |
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 7068 |
+
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
| 7069 |
+
|
| 7070 |
+
pad_f32_cuda(src0_dd, dst_dd,
|
| 7071 |
+
src0->ne[0], src0->ne[1], src0->ne[2],
|
| 7072 |
+
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
|
| 7073 |
+
|
| 7074 |
+
(void) src1;
|
| 7075 |
+
(void) dst;
|
| 7076 |
+
}
|
| 7077 |
+
|
| 7078 |
inline void ggml_cuda_op_rms_norm(
|
| 7079 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 7080 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
|
|
| 7589 |
|
| 7590 |
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
| 7591 |
|
|
|
|
| 7592 |
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
| 7593 |
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
| 7594 |
const int64_t IW = src1->ne[0];
|
|
|
|
| 7599 |
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
| 7600 |
const int64_t OW = dst->ne[1];
|
| 7601 |
|
| 7602 |
+
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
|
|
|
| 7603 |
|
| 7604 |
+
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
|
|
|
|
|
|
| 7605 |
|
| 7606 |
(void) src0;
|
| 7607 |
(void) src0_dd;
|
| 7608 |
}
|
| 7609 |
|
| 7610 |
+
|
| 7611 |
inline void ggml_cuda_op_sum_rows(
|
| 7612 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
| 7613 |
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
|
|
|
| 8156 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
|
| 8157 |
}
|
| 8158 |
|
| 8159 |
+
static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8160 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
|
| 8161 |
+
}
|
| 8162 |
+
|
| 8163 |
static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8164 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
|
| 8165 |
}
|
|
|
|
| 8176 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
|
| 8177 |
}
|
| 8178 |
|
| 8179 |
+
static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8180 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
|
| 8181 |
+
}
|
| 8182 |
+
|
| 8183 |
+
static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8184 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh);
|
| 8185 |
+
}
|
| 8186 |
+
|
| 8187 |
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8188 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
|
| 8189 |
}
|
| 8190 |
|
| 8191 |
+
static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8192 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
|
| 8193 |
+
}
|
| 8194 |
+
|
| 8195 |
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8196 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
|
| 8197 |
}
|
|
|
|
| 8200 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
|
| 8201 |
}
|
| 8202 |
|
| 8203 |
+
static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8204 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm);
|
| 8205 |
+
}
|
| 8206 |
+
|
| 8207 |
+
static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8208 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat);
|
| 8209 |
+
}
|
| 8210 |
+
|
| 8211 |
+
static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8212 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale);
|
| 8213 |
+
}
|
| 8214 |
+
|
| 8215 |
+
static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8216 |
+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
|
| 8217 |
+
}
|
| 8218 |
+
|
| 8219 |
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 8220 |
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
|
| 8221 |
}
|
|
|
|
}
#endif

+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#if 0
    ggml_cuda_mul_mat_id_cublas(dst);
    // TODO: mmq/mmv support
+#endif

+    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
+
+    const struct ggml_tensor * ids = src0;
+    const int32_t id = ((int32_t *) dst->op_params)[0];
+    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+
+    if (ids->backend == GGML_BACKEND_GPU) {
+        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    } else {
+        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+    }
+
+    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+    ggml_tensor_extra_gpu src1_row_extra;
+    ggml_tensor_extra_gpu dst_row_extra;
+
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row = *dst;
+
+    src1_row.ne[1] = 1;
+    dst_row.ne[1] = 1;
+
+    src1_row.nb[2] = src1_row.nb[1];
+    dst_row.nb[2] = dst_row.nb[1];
+
+    src1_row.nb[3] = src1_row.nb[1];
+    dst_row.nb[3] = dst_row.nb[1];
+
+    src1_row.extra = &src1_row_extra;
+    dst_row.extra = &dst_row_extra;
+
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        //int32_t row_id;
+        //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+        const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);

+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+        dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+        dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+        ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+    }
}

static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
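(The mul_mat_id body above copies the ids tensor to the host, then dispatches one per-row matmul against the expert matrix each row selects. A minimal CPU-side sketch of that routing idea, with simplified, hypothetical types; the real code builds per-row tensor views and calls ggml_cuda_mul_mat:)

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // For each of n_rows tokens, ids picks one of n_as expert matrices;
    // the selected expert then multiplies just that token's row.
    void mul_mat_id_sketch(const std::vector<int32_t> & ids, // one expert id per row
                           int n_as, int n_rows,
                           void (*mul_mat_row)(int expert, int row)) {
        assert((int) ids.size() == n_rows);
        for (int i = 0; i < n_rows; ++i) {
            const int32_t row_id = ids[i];
            assert(row_id >= 0 && row_id < n_as);
            mul_mat_row(row_id, i); // dispatch a 1-row matmul against expert row_id
        }
    }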
        case GGML_OP_ADD:
            func = ggml_cuda_add;
            break;
+        case GGML_OP_ACC:
+            func = ggml_cuda_acc;
+            break;
        case GGML_OP_MUL:
            func = ggml_cuda_mul;
            break;

        case GGML_UNARY_OP_SILU:
            func = ggml_cuda_silu;
            break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            func = ggml_cuda_gelu_quick;
+            break;
+        case GGML_UNARY_OP_TANH:
+            func = ggml_cuda_tanh;
+            break;
        case GGML_UNARY_OP_RELU:
            func = ggml_cuda_relu;
            break;
        case GGML_OP_NORM:
            func = ggml_cuda_norm;
            break;
+        case GGML_OP_GROUP_NORM:
+            func = ggml_cuda_group_norm;
+            break;
+        case GGML_OP_CONCAT:
+            func = ggml_cuda_concat;
+            break;
+        case GGML_OP_UPSCALE:
+            func = ggml_cuda_upscale;
+            break;
+        case GGML_OP_PAD:
+            func = ggml_cuda_pad;
+            break;
+        case GGML_OP_LEAKY_RELU:
+            func = ggml_cuda_leaky_relu;
+            break;
        case GGML_OP_RMS_NORM:
            func = ggml_cuda_rms_norm;
            break;
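(The cases above populate a single function pointer that is invoked once per graph node. A minimal sketch of the same op-to-handler dispatch pattern, with hypothetical names:)

    typedef void (*ggml_cuda_func_t)(const void * src0, const void * src1, void * dst);

    enum sketch_op { OP_NORM, OP_GROUP_NORM, OP_RMS_NORM };

    static void run_norm      (const void *, const void *, void *) {}
    static void run_group_norm(const void *, const void *, void *) {}
    static void run_rms_norm  (const void *, const void *, void *) {}

    // mirrors the switch: each supported op maps to one handler, others to nullptr
    static ggml_cuda_func_t select_func(sketch_op op) {
        switch (op) {
            case OP_NORM:       return run_norm;
            case OP_GROUP_NORM: return run_group_norm;
            case OP_RMS_NORM:   return run_rms_norm;
        }
        return nullptr;
    }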
            func = ggml_cuda_sqr;
            break;
        case GGML_OP_CLAMP:
            func = ggml_cuda_clamp;
            break;
        case GGML_OP_CPY:

        case GGML_OP_CONT:
            func = ggml_cuda_dup;
            break;
+        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:

        case GGML_UNARY_OP_GELU:
        case GGML_UNARY_OP_SILU:
        case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_GELU_QUICK:
+        case GGML_UNARY_OP_TANH:
            return true;
        default:
            return false;
            }
            return true;
        } break;
+        case GGML_OP_GET_ROWS:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                return false;
+            } break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:

        case GGML_OP_TRANSPOSE:
        case GGML_OP_NORM:
        case GGML_OP_REPEAT:

        case GGML_OP_DUP:
        case GGML_OP_ADD:
        case GGML_OP_MUL:

        case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_CLAMP:

        case GGML_OP_CONT:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:

        case GGML_OP_IM2COL:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_ARGSORT:
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_LEAKY_RELU:
            return true;
        default:
            return false;
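(The supports checks added above gate GPU offloading on operand types. A compact sketch of the same kind of whitelist, with hypothetical enum names, mirroring the GGML_OP_CPY logic:)

    enum sketch_type { T_F32, T_F16, T_Q8_0, T_Q4_0, T_Q4_1 };

    // only a whitelisted set of (src, dst) type pairs has a device copy kernel
    static bool cpy_supported(sketch_type src, sketch_type dst) {
        if (src == T_F32 && (dst == T_F32 || dst == T_F16 ||
                             dst == T_Q8_0 || dst == T_Q4_0 || dst == T_Q4_1)) {
            return true;
        }
        if (src == T_F16 && (dst == T_F16 || dst == T_F32)) {
            return true;
        }
        return false;
    }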
    UNUSED(params);
}

+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
    int device_count = ggml_cuda_get_device_count();
    //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
    for (int i = 0; i < device_count; i++) {
ggml-metal.m
CHANGED
@@ -66,9 +66,11 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(div_row);
    GGML_METAL_DECL_KERNEL(scale);
    GGML_METAL_DECL_KERNEL(scale_4);
-    GGML_METAL_DECL_KERNEL(silu);
    GGML_METAL_DECL_KERNEL(relu);
    GGML_METAL_DECL_KERNEL(gelu);
    GGML_METAL_DECL_KERNEL(soft_max);
    GGML_METAL_DECL_KERNEL(soft_max_4);
    GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(get_rows_q5_K);
    GGML_METAL_DECL_KERNEL(get_rows_q6_K);
    GGML_METAL_DECL_KERNEL(rms_norm);
    GGML_METAL_DECL_KERNEL(norm);
    GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(rope_f16);
    GGML_METAL_DECL_KERNEL(alibi_f32);
    GGML_METAL_DECL_KERNEL(im2col_f16);
    GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
    GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
    GGML_METAL_DECL_KERNEL(cpy_f32_f32);
    GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct ggml_metal_context {
    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
    GGML_METAL_DECL_KERNEL(concat);
    GGML_METAL_DECL_KERNEL(sqr);
    GGML_METAL_DECL_KERNEL(sum_rows);
@@ -318,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(div_row);
        GGML_METAL_ADD_KERNEL(scale);
        GGML_METAL_ADD_KERNEL(scale_4);
-        GGML_METAL_ADD_KERNEL(silu);
        GGML_METAL_ADD_KERNEL(relu);
        GGML_METAL_ADD_KERNEL(gelu);
        GGML_METAL_ADD_KERNEL(soft_max);
        GGML_METAL_ADD_KERNEL(soft_max_4);
        GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -338,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
        GGML_METAL_ADD_KERNEL(rms_norm);
        GGML_METAL_ADD_KERNEL(norm);
        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -354,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -384,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(rope_f16);
        GGML_METAL_ADD_KERNEL(alibi_f32);
        GGML_METAL_ADD_KERNEL(im2col_f16);
        GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
        GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
        GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -394,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
        GGML_METAL_ADD_KERNEL(concat);
        GGML_METAL_ADD_KERNEL(sqr);
        GGML_METAL_ADD_KERNEL(sum_rows);
@@ -418,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(div_row);
    GGML_METAL_DEL_KERNEL(scale);
    GGML_METAL_DEL_KERNEL(scale_4);
-    GGML_METAL_DEL_KERNEL(silu);
    GGML_METAL_DEL_KERNEL(relu);
    GGML_METAL_DEL_KERNEL(gelu);
    GGML_METAL_DEL_KERNEL(soft_max);
    GGML_METAL_DEL_KERNEL(soft_max_4);
    GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -438,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
    GGML_METAL_DEL_KERNEL(rms_norm);
    GGML_METAL_DEL_KERNEL(norm);
    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -454,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -484,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(rope_f16);
    GGML_METAL_DEL_KERNEL(alibi_f32);
    GGML_METAL_DEL_KERNEL(im2col_f16);
    GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
    GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
    GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -494,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
    GGML_METAL_DEL_KERNEL(concat);
    GGML_METAL_DEL_KERNEL(sqr);
    GGML_METAL_DEL_KERNEL(sum_rows);
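(The DECL/ADD/DEL hunks above keep three hand-written lists of the same kernels in sync. An X-macro is one common way to express such a list once; a sketch of the pattern, not how ggml-metal.m actually does it:)

    #include <cstdio>

    #define KERNEL_LIST \
        X(tanh)         \
        X(relu)         \
        X(gelu)         \
        X(gelu_quick)   \
        X(silu)

    // expand the same list for declaration and for registration
    #define X(name) int pipeline_##name;
    struct kernels { KERNEL_LIST };
    #undef X

    void add_all(kernels & k) {
    #define X(name) k.pipeline_##name = 1; std::printf("added %s\n", #name);
        KERNEL_LIST
    #undef X
    }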
@@ -795,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_GELU:
                    return true;
                default:
                    return false;
@@ -809,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
        case GGML_OP_PERMUTE:
        case GGML_OP_CONCAT:
        case GGML_OP_ADD:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SCALE:
@@ -816,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
        case GGML_OP_SUM_ROWS:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
        case GGML_OP_NORM:
        case GGML_OP_ALIBI:
        case GGML_OP_ROPE:
        case GGML_OP_IM2COL:
        case GGML_OP_ARGSORT:
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            return true;
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_GET_ROWS:
            {
-                return op->ne[
            }
        default:
            return false;
@@ -906,7 +1004,10 @@ void ggml_metal_graph_compute(
            } break;
        }

-

        const int64_t ne00 = src0 ? src0->ne[0] : 0;
        const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1003,34 +1104,39 @@ void ggml_metal_graph_compute(
            case GGML_OP_MUL:
            case GGML_OP_DIV:
                {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-                    GGML_ASSERT(ggml_is_contiguous(src1));

                    bool bcast_row = false;

                    int64_t nb = ne00;

-
                        // src1 is a row
                        GGML_ASSERT(ne11 == 1);

                        nb = ne00 / 4;
                        switch (dst->op) {
-                            case GGML_OP_ADD:
-                            case GGML_OP_MUL:
-                            case GGML_OP_DIV:
                            default: GGML_ASSERT(false);
                        }

                        bcast_row = true;
                    } else {
                        switch (dst->op) {
-                            case GGML_OP_ADD:
-                            case GGML_OP_MUL:
-                            case GGML_OP_DIV:
                            default: GGML_ASSERT(false);
                        }
                    }
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1058,18 +1164,99 @@ void ggml_metal_graph_compute(
                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                    [encoder setBytes:&

                    if (bcast_row) {
                        const int64_t n = ggml_nelements(dst)/4;

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } else {
-                        const int nth = MIN(

                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    }
                } break;
            case GGML_OP_SCALE:
                {
                    GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1093,16 +1280,15 @@ void ggml_metal_graph_compute(
                } break;
            case GGML_OP_UNARY:
                switch (ggml_get_unary_op(gf->nodes[i])) {
-                    case GGML_UNARY_OP_SILU:
                        {
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);

-                            [encoder dispatchThreadgroups:MTLSizeMake(n
                        } break;
                    case GGML_UNARY_OP_RELU:
                        {
@@ -1123,6 +1309,28 @@ void ggml_metal_graph_compute(
                            const int64_t n = ggml_nelements(dst);
                            GGML_ASSERT(n % 4 == 0);

                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    default:
@@ -1197,6 +1405,8 @@ void ggml_metal_graph_compute(
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    if (id_src1) {
                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    }
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
@@ -1448,7 +1658,7 @@ void ggml_metal_graph_compute(
                    else if (src0t == GGML_TYPE_Q6_K) {
                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    } else {
-                        int64_t ny = (ne11 + nrows - 1)/nrows;
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                }
@@ -1460,7 +1670,7 @@ void ggml_metal_graph_compute(

                    GGML_ASSERT(src0t == GGML_TYPE_I32);

-                    const int n_as =

                    // TODO: make this more general
                    GGML_ASSERT(n_as <= 8);
@@ -1492,14 +1702,22 @@ void ggml_metal_graph_compute(

                    // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                    // to the matrix-vector kernel
-                    int ne11_mm_min =

                    const int idx = ((int32_t *) dst->op_params)[0];

                    // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                    // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-
-
                        switch (src2->type) {
                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1518,19 +1736,22 @@ void ggml_metal_graph_compute(
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
                        // TODO: how to make this an array? read Metal docs
                        for (int j = 0; j < n_as; ++j) {
                            struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1538,11 +1759,157 @@ void ggml_metal_graph_compute(
                            size_t offs_src_cur = 0;
                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);

-                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:
                        }

                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-
                }
            } break;
        case GGML_OP_GET_ROWS:
@@ -1563,16 +1930,19 @@ void ggml_metal_graph_compute(
                    default: GGML_ASSERT(false && "not implemented");
                }

-                [encoder setBuffer:id_src0
-                [encoder setBuffer:id_src1
-                [encoder setBuffer:id_dst
                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                [encoder setBytes:&
-
-
-
-                [encoder
            } break;
        case GGML_OP_RMS_NORM:
            {
@@ -1599,6 +1969,38 @@ void ggml_metal_graph_compute(

                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        case GGML_OP_NORM:
            {
                float eps;
@@ -1768,6 +2170,65 @@ void ggml_metal_graph_compute(

                [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
            } break;
        case GGML_OP_ARGSORT:
            {
                GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1789,6 +2250,22 @@ void ggml_metal_graph_compute(

                [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
            } break;
        case GGML_OP_DUP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
@@ -1817,7 +2294,7 @@ void ggml_metal_graph_compute(
                {
                    switch (dstt) {
                        case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                        case GGML_TYPE_F32:
                        default: GGML_ASSERT(false && "not implemented");
                    };
                } break;
    GGML_METAL_DECL_KERNEL(div_row);
    GGML_METAL_DECL_KERNEL(scale);
    GGML_METAL_DECL_KERNEL(scale_4);
+    GGML_METAL_DECL_KERNEL(tanh);
    GGML_METAL_DECL_KERNEL(relu);
    GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(gelu_quick);
+    GGML_METAL_DECL_KERNEL(silu);
    GGML_METAL_DECL_KERNEL(soft_max);
    GGML_METAL_DECL_KERNEL(soft_max_4);
    GGML_METAL_DECL_KERNEL(diag_mask_inf);

    GGML_METAL_DECL_KERNEL(get_rows_q5_K);
    GGML_METAL_DECL_KERNEL(get_rows_q6_K);
    GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(group_norm);
    GGML_METAL_DECL_KERNEL(norm);
    GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);

    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);

    GGML_METAL_DECL_KERNEL(rope_f16);
    GGML_METAL_DECL_KERNEL(alibi_f32);
    GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(upscale_f32);
+    GGML_METAL_DECL_KERNEL(pad_f32);
    GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
    GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DECL_KERNEL(leaky_relu_f32);
    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
    GGML_METAL_DECL_KERNEL(cpy_f32_f32);
    GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);

    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
    GGML_METAL_DECL_KERNEL(concat);
    GGML_METAL_DECL_KERNEL(sqr);
    GGML_METAL_DECL_KERNEL(sum_rows);
        GGML_METAL_ADD_KERNEL(div_row);
        GGML_METAL_ADD_KERNEL(scale);
        GGML_METAL_ADD_KERNEL(scale_4);
+        GGML_METAL_ADD_KERNEL(tanh);
        GGML_METAL_ADD_KERNEL(relu);
        GGML_METAL_ADD_KERNEL(gelu);
+        GGML_METAL_ADD_KERNEL(gelu_quick);
+        GGML_METAL_ADD_KERNEL(silu);
        GGML_METAL_ADD_KERNEL(soft_max);
        GGML_METAL_ADD_KERNEL(soft_max_4);
        GGML_METAL_ADD_KERNEL(diag_mask_inf);

        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
        GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(group_norm);
        GGML_METAL_ADD_KERNEL(norm);
        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);

        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);

        GGML_METAL_ADD_KERNEL(rope_f16);
        GGML_METAL_ADD_KERNEL(alibi_f32);
        GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(upscale_f32);
+        GGML_METAL_ADD_KERNEL(pad_f32);
        GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
        GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+        GGML_METAL_ADD_KERNEL(leaky_relu_f32);
        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
        GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);

        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f32);
        GGML_METAL_ADD_KERNEL(concat);
        GGML_METAL_ADD_KERNEL(sqr);
        GGML_METAL_ADD_KERNEL(sum_rows);
    GGML_METAL_DEL_KERNEL(div_row);
    GGML_METAL_DEL_KERNEL(scale);
    GGML_METAL_DEL_KERNEL(scale_4);
+    GGML_METAL_DEL_KERNEL(tanh);
    GGML_METAL_DEL_KERNEL(relu);
    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(gelu_quick);
+    GGML_METAL_DEL_KERNEL(silu);
    GGML_METAL_DEL_KERNEL(soft_max);
    GGML_METAL_DEL_KERNEL(soft_max_4);
    GGML_METAL_DEL_KERNEL(diag_mask_inf);

    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(group_norm);
    GGML_METAL_DEL_KERNEL(norm);
    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);

    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);

    GGML_METAL_DEL_KERNEL(rope_f16);
    GGML_METAL_DEL_KERNEL(alibi_f32);
    GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(upscale_f32);
+    GGML_METAL_DEL_KERNEL(pad_f32);
    GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
    GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DEL_KERNEL(leaky_relu_f32);
    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
    GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);

    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
    GGML_METAL_DEL_KERNEL(concat);
    GGML_METAL_DEL_KERNEL(sqr);
    GGML_METAL_DEL_KERNEL(sum_rows);
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
                    return true;
                default:
                    return false;

        case GGML_OP_PERMUTE:
        case GGML_OP_CONCAT:
        case GGML_OP_ADD:
+        case GGML_OP_ACC:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SCALE:

        case GGML_OP_SUM_ROWS:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
+        case GGML_OP_GROUP_NORM:
        case GGML_OP_NORM:
        case GGML_OP_ALIBI:
        case GGML_OP_ROPE:
        case GGML_OP_IM2COL:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
        case GGML_OP_ARGSORT:
+        case GGML_OP_LEAKY_RELU:
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            return true;
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                        switch (op->type) {
+                            case GGML_TYPE_F16:
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_Q8_0:
+                            case GGML_TYPE_Q4_0:
+                            case GGML_TYPE_Q4_1:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    case GGML_TYPE_F16:
+                        switch (op->type) {
+                            case GGML_TYPE_F16:
+                            case GGML_TYPE_F32:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    default:
+                        return false;
+                };
+            }
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_GET_ROWS:
            {
+                return op->ne[3] == 1;
            }
        default:
            return false;
            } break;
        }

+        if (!ggml_metal_supports_op(dst)) {
+            GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+            GGML_ASSERT(!"unsupported op");
+        }

        const int64_t ne00 = src0 ? src0->ne[0] : 0;
        const int64_t ne01 = src0 ? src0->ne[1] : 0;
            case GGML_OP_MUL:
            case GGML_OP_DIV:
                {
+                    const size_t offs = 0;

                    bool bcast_row = false;

                    int64_t nb = ne00;

+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
                        // src1 is a row
                        GGML_ASSERT(ne11 == 1);

                        nb = ne00 / 4;
                        switch (dst->op) {
+                            case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+                            case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+                            case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
                            default: GGML_ASSERT(false);
                        }

                        bcast_row = true;
                    } else {
                        switch (dst->op) {
+                            case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+                            case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+                            case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
                            default: GGML_ASSERT(false);
                        }
                    }
+
+                    [encoder setComputePipelineState:pipeline];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];

                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+                    [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                    [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];

                    if (bcast_row) {
                        const int64_t n = ggml_nelements(dst)/4;

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } else {
+                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    }
                } break;
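(The branch above takes the *_row fast path when src1 is one contiguous row and ne00 % 4 == 0, so the kernel can stream float4 and broadcast that row across src0; nb = ne00/4 is the row length in float4 units. A CPU sketch of the broadcast itself:)

    #include <vector>

    // dst[r][c] = src0[r][c] + row[c]  (src1 broadcast across all rows)
    void add_row_sketch(std::vector<float> & dst, const std::vector<float> & src0,
                        const std::vector<float> & row, size_t ncols) {
        const size_t nrows = src0.size()/ncols;
        dst.resize(src0.size());
        for (size_t r = 0; r < nrows; ++r) {
            for (size_t c = 0; c < ncols; ++c) {
                dst[r*ncols + c] = src0[r*ncols + c] + row[c];
            }
        }
    }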
+            case GGML_OP_ACC:
+                {
+                    GGML_ASSERT(src0t == GGML_TYPE_F32);
+                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+                    GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+                    GGML_ASSERT(ggml_is_contiguous(src1));
+
+                    const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+                    const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+                    const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+                    const size_t offs = ((int32_t *) dst->op_params)[3];
+
+                    const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                    if (!inplace) {
+                        // run a separete kernel to cpy src->dst
+                        // not sure how to avoid this
+                        // TODO: make a simpler cpy_bytes kernel
+
+                        const int nth = MIN(1024, ne00);
+
+                        [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                        [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                        [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                        [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                        [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_add];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                    [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                    [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                    [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                    [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                    [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                    [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                    [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                    [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                    [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                    [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                    [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                    const int nth = MIN(1024, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
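(GGML_OP_ACC above first copies src0 into dst unless the op is in-place, then adds src1 into a strided window of dst described by the pnb1/pnb2/pnb3/offs op params. A CPU sketch of those semantics for the 2D case, strides in elements rather than bytes:)

    #include <vector>

    // dst = src0; then dst[offs + i1*nb1 + i0] += src1[i1*ne10 + i0]
    void acc_sketch(std::vector<float> & dst, const std::vector<float> & src0,
                    const std::vector<float> & src1,
                    size_t ne10, size_t ne11, size_t nb1, size_t offs) {
        dst = src0; // the non-inplace copy step (the cpy kernel in the Metal path)
        for (size_t i1 = 0; i1 < ne11; ++i1) {
            for (size_t i0 = 0; i0 < ne10; ++i0) {
                dst[offs + i1*nb1 + i0] += src1[i1*ne10 + i0];
            }
        }
    }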
case GGML_OP_SCALE:
|
| 1261 |
{
|
| 1262 |
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
| 1280 |
} break;
|
| 1281 |
case GGML_OP_UNARY:
|
| 1282 |
switch (ggml_get_unary_op(gf->nodes[i])) {
|
| 1283 |
+
case GGML_UNARY_OP_TANH:
|
| 1284 |
{
|
| 1285 |
+
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
| 1286 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1287 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1288 |
|
| 1289 |
const int64_t n = ggml_nelements(dst);
|
|
|
|
| 1290 |
|
| 1291 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1292 |
} break;
|
| 1293 |
case GGML_UNARY_OP_RELU:
|
| 1294 |
{
|
|
|
|
| 1309 |
const int64_t n = ggml_nelements(dst);
|
| 1310 |
GGML_ASSERT(n % 4 == 0);
|
| 1311 |
|
| 1312 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1313 |
+
} break;
|
| 1314 |
+
case GGML_UNARY_OP_GELU_QUICK:
|
| 1315 |
+
{
|
| 1316 |
+
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
| 1317 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1318 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1319 |
+
|
| 1320 |
+
const int64_t n = ggml_nelements(dst);
|
| 1321 |
+
GGML_ASSERT(n % 4 == 0);
|
| 1322 |
+
|
| 1323 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1324 |
+
} break;
|
| 1325 |
+
case GGML_UNARY_OP_SILU:
|
| 1326 |
+
{
|
| 1327 |
+
[encoder setComputePipelineState:ctx->pipeline_silu];
|
| 1328 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1329 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1330 |
+
|
| 1331 |
+
const int64_t n = ggml_nelements(dst);
|
| 1332 |
+
GGML_ASSERT(n % 4 == 0);
|
| 1333 |
+
|
| 1334 |
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1335 |
} break;
|
| 1336 |
default:
|
|
|
|
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    if (id_src1) {
                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                    }
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                    else if (src0t == GGML_TYPE_Q6_K) {
                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    } else {
+                        const int64_t ny = (ne11 + nrows - 1)/nrows;
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                }

                    GGML_ASSERT(src0t == GGML_TYPE_I32);

+                    const int n_as = ((int32_t *) dst->op_params)[1];

                    // TODO: make this more general
                    GGML_ASSERT(n_as <= 8);
                    // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                    // to the matrix-vector kernel
+                    int ne11_mm_min = 1;

                    const int idx = ((int32_t *) dst->op_params)[0];

+                    // batch size
+                    GGML_ASSERT(ne01 == ne11);
+
+                    const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
                    // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                    // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                    // !!!
+                    // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                    //       indirect matrix multiplication
+                    // !!!
+                    if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
                        switch (src2->type) {
                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;

                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                        [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                        [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                        [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                        [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                        [encoder setBytes:&r2   length:sizeof(r2)   atIndex:16];
+                        [encoder setBytes:&r3   length:sizeof(r3)   atIndex:17];
+                        [encoder setBytes:&idx  length:sizeof(idx)  atIndex:18];
                        // TODO: how to make this an array? read Metal docs
                        for (int j = 0; j < n_as; ++j) {
                            struct ggml_tensor * src_cur = dst->src[2 + j];

                            size_t offs_src_cur = 0;
                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);

+                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                        }

                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+
+                        // TODO: processing one row at a time (ne11 -> 1) is not efficient
+                        [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    } else {
+                        int nth0 = 32;
+                        int nth1 = 1;
+                        int nrows = 1;
+                        //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                        // use custom matrix x vector kernel
+                        switch (src2t) {
+                            case GGML_TYPE_F32:
+                                {
+                                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                                } break;
+                            case GGML_TYPE_F16:
+                                {
+                                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                    nth0 = 32;
+                                    nth1 = 1;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                                } break;
+                            case GGML_TYPE_Q4_0:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                                } break;
+                            case GGML_TYPE_Q4_1:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                                } break;
+                            case GGML_TYPE_Q5_0:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                                } break;
+                            case GGML_TYPE_Q5_1:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                                } break;
+                            case GGML_TYPE_Q8_0:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                                } break;
+                            case GGML_TYPE_Q2_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                                } break;
+                            case GGML_TYPE_Q3_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                                } break;
+                            case GGML_TYPE_Q4_K:
+                                {
+                                    nth0 = 4; //1;
+                                    nth1 = 8; //32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                                } break;
+                            case GGML_TYPE_Q5_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                                } break;
+                            case GGML_TYPE_Q6_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                                } break;
+                            default:
+                                {
+                                    GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                    GGML_ASSERT(false && "not implemented");
+                                }
+                        };
+
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                        [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                        [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                        [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+                        [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+                        [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                        [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                        [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
+                        [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
+                        [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
+                        [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
+                        // TODO: how to make this an array? read Metal docs
+                        for (int j = 0; j < n_as; ++j) {
+                            struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                            size_t offs_src_cur = 0;
+                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+                        }
+
+                        if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                            src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                            src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src2t == GGML_TYPE_Q4_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                        }
+                        else if (src2t == GGML_TYPE_Q5_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src2t == GGML_TYPE_Q6_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        } else {
+                            const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
                    }
                } break;
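(The dispatch calls above size the grid with ceiling division, e.g. (ne21 + 7)/8 threadgroups when each group covers 8 rows. The general helper, as a sketch:)

    #include <cstdint>

    // number of threadgroups needed so that ngroups * rows_per_group >= nrows
    static int64_t ceil_div(int64_t nrows, int64_t rows_per_group) {
        return (nrows + rows_per_group - 1) / rows_per_group;
    }

    // e.g. ceil_div(ne21, 8) reproduces (ne21 + 7)/8 from the Q8_0/Q4_0 path above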
            case GGML_OP_GET_ROWS:
                {
                        default: GGML_ASSERT(false && "not implemented");
                    }

+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                    [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                    [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                    [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                } break;
            case GGML_OP_RMS_NORM:
                {
|
| 1969 |
|
| 1970 |
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 1971 |
} break;
|
| 1972 |
+
case GGML_OP_GROUP_NORM:
|
| 1973 |
+
{
|
| 1974 |
+
GGML_ASSERT(ne00 % 4 == 0);
|
| 1975 |
+
|
| 1976 |
+
//float eps;
|
| 1977 |
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
| 1978 |
+
|
| 1979 |
+
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
| 1980 |
+
|
| 1981 |
+
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
| 1982 |
+
|
| 1983 |
+
int nth = 32; // SIMD width
|
| 1984 |
+
|
| 1985 |
+
//while (nth < ne00/4 && nth < 1024) {
|
| 1986 |
+
// nth *= 2;
|
| 1987 |
+
//}
|
| 1988 |
+
|
| 1989 |
+
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
| 1990 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1991 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1992 |
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
| 1993 |
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
| 1994 |
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
| 1995 |
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
| 1996 |
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
| 1997 |
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
| 1998 |
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
| 1999 |
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
| 2000 |
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 2001 |
+
|
| 2002 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2003 |
+
} break;
|
| 2004 |
case GGML_OP_NORM:
|
| 2005 |
{
|
| 2006 |
float eps;
|
|
|
|
| 2170 |
|
| 2171 |
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
| 2172 |
} break;
|
| 2173 |
+
case GGML_OP_UPSCALE:
|
| 2174 |
+
{
|
| 2175 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 2176 |
+
|
| 2177 |
+
const int sf = dst->op_params[0];
|
| 2178 |
+
|
| 2179 |
+
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
| 2180 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2181 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2182 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
| 2183 |
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 2184 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 2185 |
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 2186 |
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 2187 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 2188 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 2189 |
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
| 2190 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
| 2191 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
| 2192 |
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
| 2193 |
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
| 2194 |
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
| 2195 |
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
| 2196 |
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
| 2197 |
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
| 2198 |
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
| 2199 |
+
|
| 2200 |
+
const int nth = MIN(1024, ne0);
|
| 2201 |
+
|
| 2202 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2203 |
+
} break;
|
| 2204 |
+
case GGML_OP_PAD:
|
| 2205 |
+
{
|
| 2206 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 2207 |
+
|
| 2208 |
+
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
| 2209 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2210 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2211 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
| 2212 |
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 2213 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 2214 |
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 2215 |
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 2216 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 2217 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 2218 |
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
| 2219 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
| 2220 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
| 2221 |
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
| 2222 |
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
| 2223 |
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
| 2224 |
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
| 2225 |
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
| 2226 |
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
| 2227 |
+
|
| 2228 |
+
const int nth = MIN(1024, ne0);
|
| 2229 |
+
|
| 2230 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2231 |
+
} break;
|
| 2232 |
case GGML_OP_ARGSORT:
|
| 2233 |
{
|
| 2234 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
| 2250 |
|
| 2251 |
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
| 2252 |
} break;
|
| 2253 |
+
case GGML_OP_LEAKY_RELU:
|
| 2254 |
+
{
|
| 2255 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 2256 |
+
|
| 2257 |
+
float slope;
|
| 2258 |
+
memcpy(&slope, dst->op_params, sizeof(float));
|
| 2259 |
+
|
| 2260 |
+
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
|
| 2261 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2262 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2263 |
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
| 2264 |
+
|
| 2265 |
+
const int64_t n = ggml_nelements(dst);
|
| 2266 |
+
|
| 2267 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2268 |
+
} break;
|
| 2269 |
case GGML_OP_DUP:
|
| 2270 |
case GGML_OP_CPY:
|
| 2271 |
case GGML_OP_CONT:
|
|
|
|
| 2294 |
{
|
| 2295 |
switch (dstt) {
|
| 2296 |
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
| 2297 |
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
|
| 2298 |
default: GGML_ASSERT(false && "not implemented");
|
| 2299 |
};
|
| 2300 |
} break;
|
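Note on the dispatch arithmetic above: expressions such as (ne21 + 7)/8 are integer ceiling divisions, launching enough threadgroups to cover all rows even when the row count is not a multiple of the per-threadgroup batch. A minimal C++ sketch of the same computation (names are illustrative, not from this commit):

#include <cassert>
#include <cstdint>

// Ceiling division: smallest g such that g*rows_per_tg >= nrows.
// This mirrors expressions like (ne21 + 7)/8 in the dispatch code above.
static int64_t ceil_div(int64_t nrows, int64_t rows_per_tg) {
    return (nrows + rows_per_tg - 1) / rows_per_tg;
}

int main() {
    assert(ceil_div(24, 8) == 3); // exact fit
    assert(ceil_div(25, 8) == 4); // one extra threadgroup covers the tail
    assert(ceil_div( 1, 8) == 1);
    return 0;
}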
ggml-metal.metal
CHANGED
@@ -79,6 +79,7 @@ kernel void kernel_add(
         constant  int64_t & nb1,
         constant  int64_t & nb2,
         constant  int64_t & nb3,
+        constant  int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   int64_t  & nb [[buffer(27)]],
+        constant   int64_t  & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   int64_t  & nb [[buffer(27)]],
+        constant   int64_t  & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   int64_t  & nb [[buffer(27)]],
+        constant   int64_t  & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }
 
-kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
+    dst[tpig] = max(0.0f, src0[tpig]);
 }
 
-kernel void kernel_relu(
+kernel void kernel_tanh(
         device const float * src0,
         device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
 }
 
 kernel void kernel_sqr(
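The float4 activation kernels above apply their formulas element-wise; the GELU variant uses the tanh approximation, and the comment warns that the fast tanh can return NaNs, which is why precise::tanh is required. As a reference, the same formulas in scalar C++ (helper names are hypothetical, not from this commit):

#include <cmath>

// Scalar versions of the activation formulas used by the Metal kernels
// above; constants match GELU_COEF_A, GELU_QUICK_COEF and SQRT_2_OVER_PI.
static float gelu(float x) {
    const float a = 0.044715f;
    const float s = 0.79788456080286535587989211986876f; // sqrt(2/pi)
    return 0.5f*x*(1.0f + std::tanh(s*x*(1.0f + a*x*x)));
}

static float gelu_quick(float x) {
    return x*(1.0f/(1.0f + std::exp(-1.702f*x)));
}

static float silu(float x) {
    return x/(1.0f + std::exp(-x));
}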
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
-constant float GELU_COEF_A    = 0.044715f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
 kernel void kernel_soft_max(
         device const float * src0,
         device const float * src1,
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
-    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
     float lmax = -INFINITY;
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
-    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
     float4 lmax4 = -INFINITY;
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
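Both soft_max kernels follow the usual numerically stable scheme: find the row maximum, exponentiate shifted values, then normalize by the sum, with simdgroup reductions doing the parallel max and sum. A scalar C++ sketch of that scheme, assuming an optional additive mask like pmask (names are illustrative):

#include <algorithm>
#include <cmath>

// Scalar reference of the masked, max-subtracted softmax computed per row
// by kernel_soft_max; "mask" may be null, matching pmask == nullptr above.
static void soft_max_row(const float * x, const float * mask, float * y, int n) {
    float mx = -INFINITY;
    for (int i = 0; i < n; ++i) {
        mx = std::max(mx, x[i] + (mask ? mask[i] : 0.0f));
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = std::exp(x[i] + (mask ? mask[i] : 0.0f) - mx); // shift avoids overflow
        sum += y[i];
    }
    for (int i = 0; i < n; ++i) {
        y[i] /= sum;
    }
}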
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 // il indicates where the q4 quants begin (0 or QK4_0/4)
 // we assume that the yl's have been multiplied with the appropriate scale factor
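kernel_group_norm above normalizes each group of gs consecutive values in two reduction passes: first the mean, then the variance of the centered values, followed by a rescale with 1/sqrt(variance + eps). A sequential C++ reference of the same math (a sketch, not the kernel itself):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference of the per-group normalization performed by kernel_group_norm:
// pass 1 accumulates the mean, pass 2 the variance, then x is rescaled.
// Like the kernel, it divides by the nominal group size gs.
static void group_norm(const float * src, float * dst, int64_t ne, int64_t gs, float eps) {
    for (int64_t start = 0; start < ne; start += gs) {
        const int64_t end = std::min(start + gs, ne);

        double sum = 0.0;
        for (int64_t j = start; j < end; ++j) {
            sum += src[j];
        }
        const float mean = (float)(sum / gs);

        double sum2 = 0.0;
        for (int64_t j = start; j < end; ++j) {
            const float xi = src[j] - mean;
            dst[j] = xi;               // store centered value, as the kernel does
            sum2 += (double)xi*xi;
        }
        const float variance = (float)(sum2 / gs);

        const float scale = 1.0f/std::sqrt(variance + eps);
        for (int64_t j = start; j < end; ++j) {
            dst[j] *= scale;
        }
    }
}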
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // giard against the number of rows not being divisible by
 // N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(
+void mul_vec_q_n_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int nr  = N_DST;
    const int nsg = N_SIMDGROUP;
     const int nw  = N_SIMDWIDTH;
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2 [[buffer(17)]],
+        constant   uint    & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
 #define N_F32_F32 4
 
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
     constant  uint64_t & nb12,
     constant   int64_t & ne0,
     constant   int64_t & ne1,
-    constant   uint    & r2 [[buffer(17)]],
-    constant   uint    & r3 [[buffer(18)]],
+    constant   uint    & r2,
+    constant   uint    & r3,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiisg[[thread_index_in_simdgroup]]) {
@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2 [[buffer(17)]],
+        constant   uint    & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 #define N_F16_F16 4
 
 kernel void kernel_mul_mv_f16_f16(
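The hunks in this region all apply one refactoring: each kernel body becomes a plain *_impl function, and the exported entry point shrinks to a thin [[host_name(...)]] wrapper that forwards its arguments, so a single body can later back several entry points. The same shape in plain C++ (hypothetical names, a sketch of the pattern only):

// The shared body is an ordinary function ...
static void mul_mv_impl(const float * src0, const float * src1, float * dst, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += src0[i]*src1[i];
    }
    *dst = sum;
}

// ... and each exported entry point is a thin forwarding wrapper, playing
// the role of the [[host_name("...")]] kernel declarations in the diff.
void mul_mv_f32_f32(const float * src0, const float * src1, float * dst, int n) {
    mul_mv_impl(src0, src1, dst, n);
}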
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
     }
 }
 
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
     constant  uint64_t & nb12,
     constant   int64_t & ne0,
     constant   int64_t & ne1,
-    constant   uint    & r2 [[buffer(17)]],
-    constant   uint    & r3 [[buffer(18)]],
+    constant   uint    & r2,
+    constant   uint    & r3,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiisg[[thread_index_in_simdgroup]]) {
@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
         dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
     }
 }
+}
 
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2 [[buffer(17)]],
+        constant   uint    & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 }
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
     constant  uint64_t & nb12,
     constant   int64_t & ne0,
     constant   int64_t & ne1,
-    constant   uint    & r2 [[buffer(17)]],
-    constant   uint    & r3 [[buffer(18)]],
+    constant   uint    & r2,
+    constant   uint    & r3,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint  tiisg[[thread_index_in_simdgroup]]) {
@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2 [[buffer(17)]],
+        constant   uint    & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 // Assumes row size (ne00) is a multiple of 4
 kernel void kernel_mul_mv_f16_f32_l4(
         device const char * src0,
@@ -1548,25 +1763,116 @@ kernel void kernel_im2col_f16(
     }
 }
 
-// bitonic sort implementation following the CUDA kernels as reference
-typedef void (argsort_t)(
-        device const float  * x,
-        device       int32_t * dst,
-        constant     int64_t & ncols,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]);
-
-template<ggml_sort_order order>
-kernel void kernel_argsort_f32_i32(
-        device const float   * x,
-        device       int32_t * dst,
-        constant     int64_t & ncols,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]) {
-    // bitonic sort
-    int col = tpitg[0];
-    int row = tgpig[1];
-
+kernel void kernel_upscale_f32(
+        device  const char * src0,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant   int32_t & sf,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+        device  const char * src0,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float  * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
     if (col >= ncols) return;
 
     device const float * x_row = x + row * ncols;
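kernel_upscale_f32 performs nearest-neighbor upscaling by the integer factor sf (every destination element reads src0_ptr[i0/sf]), and kernel_pad_f32 copies the source extent and zero-fills everything outside it. A compact C++ reference over single rows (a sketch under those assumptions, names illustrative):

#include <cstdint>

// Nearest-neighbor upscale of one row by integer factor sf, as in
// kernel_upscale_f32: dst[i0] = src[i0/sf].
static void upscale_row(const float * src, float * dst, int64_t ne0, int sf) {
    for (int64_t i0 = 0; i0 < ne0; ++i0) {
        dst[i0] = src[i0/sf];
    }
}

// Zero-padding of one row, as in kernel_pad_f32: copy the first ne00
// elements, fill the remainder of the ne0-wide destination with 0.
static void pad_row(const float * src, float * dst, int64_t ne00, int64_t ne0) {
    for (int64_t i0 = 0; i0 < ne0; ++i0) {
        dst[i0] = i0 < ne00 ? src[i0] : 0.0f;
    }
}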
@@ -1600,9 +1906,17 @@
 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
 
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
 kernel void kernel_cpy_f16_f16(
-        device const half * src0,
-        device       half * dst,
+        device  const half * src0,
+        device        half * dst,
         constant  int64_t & ne00,
         constant  int64_t & ne01,
         constant  int64_t & ne02,
@@ -1641,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
     }
 }
 
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device        half * dst,
@@ -1917,9 +2272,9 @@ kernel void kernel_cpy_f32_q4_1(
 }
 
 kernel void kernel_concat(
-    device const char * src0,
-    device const char * src1,
-    device       char * dst,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
@@ -1956,7 +2311,7 @@ kernel void kernel_concat(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
@@ -2064,19 +2419,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,8 +2569,8 @@ kernel void kernel_mul_mv_q2_K_f32(
     }
 }
 
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2229,8 +2584,29 @@ kernel void kernel_mul_mv_q3_K_f32(
     constant      uint & r2 [[buffer(17)]],
     constant      uint & r3 [[buffer(18)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
-    uint tiisg[[thread_index_in_simdgroup]],
-    uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q2_K_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
@@ -2373,19 +2749,19 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2826,41 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
 }
 #endif
 
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2 [[buffer(17)]],
+        constant      uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q3_K_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
 #if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,19 +2961,19 @@ kernel void kernel_mul_mv_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2660,7 +3057,8 @@ kernel void kernel_mul_mv_q4_K_f32(
     }
 }
 #endif
 
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2677,6 +3075,26 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
+    kernel_mul_mv_q4_K_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -2836,10 +3254,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
     }
-
 }
 
-kernel void kernel_mul_mv_q6_K_f32(
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2853,18 +3271,38 @@ kernel void kernel_mul_mv_q6_K_f32(
     constant      uint & r2 [[buffer(17)]],
     constant      uint & r3 [[buffer(18)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
-    uint tiisg[[thread_index_in_simdgroup]],
-    uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    const uint8_t kmask1 = 0x03;
-    const uint8_t kmask2 = 0x0C;
-    const uint8_t kmask3 = 0x30;
-    const uint8_t kmask4 = 0xC0;
-
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q5_K_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
     const int im = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
@@ -2945,6 +3383,27 @@ kernel void kernel_mul_mv_q6_K_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2 [[buffer(17)]],
+        constant      uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q6_K_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3062,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
-    const half d = xb->d;
-    const half min = xb->dmin;
+    const float d = xb->d;
+    const float min = xb->dmin;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
-    half dl, ml;
+    float dl, ml;
     uint8_t sc = xb->scales[il];
 
 #if QK_K == 256
@@ -3135,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     q = q + (il/4) * 32 + 16 * (il&1);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d   = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl  = d * sc[0];
-    const half ml  = min * sc[1];
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl  = d * sc[0];
+    const float ml  = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
@@ -3165,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     uint8_t ul = 1 << (il/2);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d   = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl  = d * sc[0];
-    const half ml  = min * sc[1];
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl  = d * sc[0];
+    const float ml  = min * sc[1];
 
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const half qh_val  = il<4 ? 16.h : 256.h;
+    const ushort mask   = il<2 ? 0x0F : 0xF0;
+    const float  qh_val = il<4 ? 16.f : 256.f;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -3219,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
     device const  void * src0,
-    device const   int * src1,
+    device const  char * src1,
     device       float * dst,
     constant   int64_t & ne00,
     constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant   int64_t & ne10,
+    constant  uint64_t & nb10,
+    constant  uint64_t & nb11,
     constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
-    uint                 tgpig[[threadgroup_position_in_grid]],
+    uint3                tgpig[[threadgroup_position_in_grid]],
     uint                 tiitg[[thread_index_in_threadgroup]],
-    uint                 tptg [[threads_per_threadgroup]]) {
-    const int64_t i = tgpig;
-    const int64_t r = ((device int32_t *) src1)[i];
+    uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
 
-    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+    const int64_t r = ((device const int32_t *) ((device const char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
         dequantize_func(
-        ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
     }
 }
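After the change above, kernel_get_rows addresses src1 as raw bytes through nb10/nb11, so row indices can come from a strided 2-D tensor instead of a flat int array, and each selected row of src0 is dequantized into dst. A C++ sketch of the underlying gather for the plain float case (names are hypothetical, strides simplified to elements):

#include <cstdint>
#include <cstring>

// Gather rows of src0 selected by int32 indices, the operation behind
// kernel_get_rows: dst[r,:] = src0[idx[r],:]. Strides here are in elements
// for simplicity; the Metal kernel works with byte strides (nb01, nb1, ...).
static void get_rows_f32(const float * src0, const int32_t * idx, float * dst,
                         int64_t nrows, int64_t ne00) {
    for (int64_t r = 0; r < nrows; ++r) {
        std::memcpy(dst + r*ne00, src0 + (int64_t)idx[r]*ne00, ne00*sizeof(float));
    }
}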
@@ -3426,19 +3953,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-    device const int32_t * ids,
+    device const   uchar * ids,
     device const   uchar * src1,
-    device         float * dst,
+    device         uchar * dst,
+    constant    uint64_t & nbi1,
     constant     int64_t & ne00,
     constant     int64_t & ne02,
     constant     int64_t & nb01,
     constant     int64_t & nb02,
     constant     int64_t & ne12,
+    constant     int64_t & ne13,
     constant     int64_t & nb10,
     constant     int64_t & nb11,
     constant     int64_t & nb12,
     constant     int64_t & ne0,
     constant     int64_t & ne1,
+    constant    uint64_t & nb1,
     constant        uint & r2,
     constant        uint & r3,
     constant         int & idx,
@@ -3456,10 +3986,16 @@ kernel void kernel_mul_mm_id(
     uint sgitg[[simdgroup_index_in_threadgroup]]) {
     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[ids[idx]],
-        src1,
-        dst,
+        src0[id],
+        src1 + bid*nb11,
+        dst  + bid*nb1,
         ne00,
         ne02,
         nb01,
@@ -3484,17 +4020,26 @@ kernel void kernel_mul_mm_id(
 #define QK_NL 4
 #endif
 
 typedef void (get_rows_t)(
         device const void * src0,
-        device const  int * src1,
+        device const char * src1,
         device      float * dst,
         constant  int64_t & ne00,
         constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant  int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
         constant uint64_t & nb1,
-        uint, uint, uint);
+        constant uint64_t & nb2,
+        uint3, uint, uint3);
 
-template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,  1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3506,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
 typedef void (mat_mm_t)(
         device const uchar * src0,
         device const uchar * src1,
@@ -3538,20 +4087,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 
 typedef void (mat_mm_id_t)(
-        device const int32_t * ids,
+        device const   uchar * ids,
         device const   uchar * src1,
-        device         float * dst,
+        device         uchar * dst,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant     int64_t & nb01,
         constant     int64_t & nb02,
         constant     int64_t & ne12,
+        constant     int64_t & ne13,
         constant     int64_t & nb10,
         constant     int64_t & nb11,
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -3578,3 +4134,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
 template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
| 79 |
constant int64_t & nb1,
|
| 80 |
constant int64_t & nb2,
|
| 81 |
constant int64_t & nb3,
|
| 82 |
+
constant int64_t & offs,
|
| 83 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 84 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 85 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
| 91 |
const int64_t i12 = i02 % ne12;
|
| 92 |
const int64_t i11 = i01 % ne11;
|
| 93 |
|
| 94 |
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
| 95 |
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
| 96 |
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
| 97 |
|
| 98 |
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
| 99 |
const int i10 = i0 % ne10;
|
|
|
|
| 205 |
device const float4 * src0,
|
| 206 |
device const float4 * src1,
|
| 207 |
device float4 * dst,
|
| 208 |
+
constant int64_t & nb [[buffer(28)]],
|
| 209 |
uint tpig[[thread_position_in_grid]]) {
|
| 210 |
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
| 211 |
}
|
|
|
|
| 214 |
device const float4 * src0,
|
| 215 |
device const float4 * src1,
|
| 216 |
device float4 * dst,
|
| 217 |
+
constant int64_t & nb [[buffer(28)]],
|
| 218 |
uint tpig[[thread_position_in_grid]]) {
|
| 219 |
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
| 220 |
}
|
|
|
|
| 223 |
device const float4 * src0,
|
| 224 |
device const float4 * src1,
|
| 225 |
device float4 * dst,
|
| 226 |
+
constant int64_t & nb [[buffer(28)]],
|
| 227 |
uint tpig[[thread_position_in_grid]]) {
|
| 228 |
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
| 229 |
}
|
|
|
|
| 244 |
dst[tpig] = src0[tpig] * scale;
|
| 245 |
}
|
| 246 |
|
| 247 |
+
kernel void kernel_relu(
|
| 248 |
+
device const float * src0,
|
| 249 |
+
device float * dst,
|
| 250 |
uint tpig[[thread_position_in_grid]]) {
|
| 251 |
+
dst[tpig] = max(0.0f, src0[tpig]);
|
|
|
|
| 252 |
}
|
| 253 |
|
| 254 |
+
kernel void kernel_tanh(
|
| 255 |
device const float * src0,
|
| 256 |
device float * dst,
|
| 257 |
uint tpig[[thread_position_in_grid]]) {
|
| 258 |
+
device const float & x = src0[tpig];
|
| 259 |
+
dst[tpig] = precise::tanh(x);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
constant float GELU_COEF_A = 0.044715f;
|
| 263 |
+
constant float GELU_QUICK_COEF = -1.702f;
|
| 264 |
+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
| 265 |
+
|
| 266 |
+
kernel void kernel_gelu(
|
| 267 |
+
device const float4 * src0,
|
| 268 |
+
device float4 * dst,
|
| 269 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 270 |
+
device const float4 & x = src0[tpig];
|
| 271 |
+
|
| 272 |
+
// BEWARE !!!
|
| 273 |
+
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
| 274 |
+
// This was observed with Falcon 7B and 40B models
|
| 275 |
+
//
|
| 276 |
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
kernel void kernel_gelu_quick(
|
| 280 |
+
device const float4 * src0,
|
| 281 |
+
device float4 * dst,
|
| 282 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 283 |
+
device const float4 & x = src0[tpig];
|
| 284 |
+
|
| 285 |
+
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
kernel void kernel_silu(
|
| 289 |
+
device const float4 * src0,
|
| 290 |
+
device float4 * dst,
|
| 291 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 292 |
+
device const float4 & x = src0[tpig];
|
| 293 |
+
dst[tpig] = x / (1.0f + exp(-x));
|
| 294 |
}
|
| 295 |
|
| 296 |
kernel void kernel_sqr(
|
|
|
|
| 348 |
dst_row[0] = row_sum;
|
| 349 |
}
|
| 350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
kernel void kernel_soft_max(
|
| 352 |
device const float * src0,
|
| 353 |
device const float * src1,
|
|
|
|
| 366 |
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
| 367 |
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
| 368 |
|
| 369 |
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 370 |
+
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
| 371 |
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 372 |
|
| 373 |
// parallel max
|
| 374 |
float lmax = -INFINITY;
|
|
|
|
| 404 |
pdst[i00] = exp_psrc0;
|
| 405 |
}
|
| 406 |
|
| 407 |
+
// This barrier fixes a failing test
|
| 408 |
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
| 409 |
+
threadgroup_barrier(mem_flags::mem_none);
|
| 410 |
+
|
| 411 |
float sum = simd_sum(lsum);
|
| 412 |
+
|
| 413 |
if (ntg > N_SIMDWIDTH) {
|
| 414 |
if (sgitg == 0) {
|
| 415 |
buf[tiisg] = 0.0f;
|
|
|
|
| 452 |
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
| 453 |
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
| 454 |
|
| 455 |
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 456 |
+
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
| 457 |
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 458 |
|
| 459 |
// parallel max
|
| 460 |
float4 lmax4 = -INFINITY;
|
|
|
|
| 492 |
}
|
| 493 |
|
| 494 |
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
| 495 |
+
|
| 496 |
+
// This barrier fixes a failing test
|
| 497 |
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
| 498 |
+
threadgroup_barrier(mem_flags::mem_none);
|
| 499 |
+
|
| 500 |
float sum = simd_sum(lsum);
|
| 501 |
+
|
| 502 |
if (ntg > N_SIMDWIDTH) {
|
| 503 |
if (sgitg == 0) {
|
| 504 |
buf[tiisg] = 0.0f;
|
|
|
|
| 669 |
}
|
| 670 |
}
|
| 671 |
|
| 672 |
+
kernel void kernel_group_norm(
|
| 673 |
+
device const float * src0,
|
| 674 |
+
device float * dst,
|
| 675 |
+
constant int64_t & ne00,
|
| 676 |
+
constant int64_t & ne01,
|
| 677 |
+
constant int64_t & ne02,
|
| 678 |
+
constant uint64_t & nb00,
|
| 679 |
+
constant uint64_t & nb01,
|
| 680 |
+
constant uint64_t & nb02,
|
| 681 |
+
constant int32_t & n_groups,
|
| 682 |
+
constant float & eps,
|
| 683 |
+
threadgroup float * buf [[threadgroup(0)]],
|
| 684 |
+
uint tgpig[[threadgroup_position_in_grid]],
|
| 685 |
+
uint tpitg[[thread_position_in_threadgroup]],
|
| 686 |
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
| 687 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 688 |
+
uint ntg[[threads_per_threadgroup]]) {
|
| 689 |
+
const int64_t ne = ne00*ne01*ne02;
|
| 690 |
+
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
| 691 |
+
|
| 692 |
+
int start = tgpig * gs;
|
| 693 |
+
int end = start + gs;
|
| 694 |
+
|
| 695 |
+
start += tpitg;
|
| 696 |
+
|
| 697 |
+
if (end >= ne) {
|
| 698 |
+
end = ne;
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
float tmp = 0.0f; // partial sum for thread in warp
|
| 702 |
+
|
| 703 |
+
for (int j = start; j < end; j += ntg) {
|
| 704 |
+
tmp += src0[j];
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 708 |
+
tmp = simd_sum(tmp);
|
| 709 |
+
if (ntg > N_SIMDWIDTH) {
|
| 710 |
+
if (sgitg == 0) {
|
| 711 |
+
buf[tiisg] = 0.0f;
|
| 712 |
+
}
|
| 713 |
+
|
| 714 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 715 |
+
|
| 716 |
+
if (tiisg == 0) {
|
| 717 |
+
buf[sgitg] = tmp;
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 721 |
+
|
| 722 |
+
tmp = buf[tiisg];
|
| 723 |
+
tmp = simd_sum(tmp);
|
| 724 |
+
}
|
| 725 |
+
|
| 726 |
+
const float mean = tmp / gs;
|
| 727 |
+
tmp = 0.0f;
|
| 728 |
+
|
| 729 |
+
for (int j = start; j < end; j += ntg) {
|
| 730 |
+
float xi = src0[j] - mean;
|
| 731 |
+
dst[j] = xi;
|
| 732 |
+
tmp += xi * xi;
|
| 733 |
+
}
|
| 734 |
+
|
| 735 |
+
tmp = simd_sum(tmp);
|
| 736 |
+
if (ntg > N_SIMDWIDTH) {
|
| 737 |
+
if (sgitg == 0) {
|
| 738 |
+
buf[tiisg] = 0.0f;
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 742 |
+
|
| 743 |
+
if (tiisg == 0) {
|
| 744 |
+
buf[sgitg] = tmp;
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 748 |
+
|
| 749 |
+
tmp = buf[tiisg];
|
| 750 |
+
tmp = simd_sum(tmp);
|
| 751 |
+
}
|
| 752 |
+
|
| 753 |
+
const float variance = tmp / gs;
|
| 754 |
+
const float scale = 1.0f/sqrt(variance + eps);
|
| 755 |
+
for (int j = start; j < end; j += ntg) {
|
| 756 |
+
dst[j] *= scale;
|
| 757 |
+
}
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
| 761 |
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
| 762 |
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
|
|
| 849 |
// giard against the number of rows not being divisible by
|
| 850 |
// N_DST, so this is another explicit assumption of the implementation.
|
| 851 |
template<typename block_q_type, int nr, int nsg, int nw>
|
| 852 |
+
void mul_vec_q_n_f32_impl(
|
| 853 |
device const void * src0,
|
| 854 |
device const float * src1,
|
| 855 |
device float * dst,
|
|
|
|
| 931 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 932 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 933 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 934 |
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 935 |
}
|
| 936 |
|
| 937 |
kernel void kernel_mul_mv_q4_1_f32(
|
|
|
|
| 950 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 951 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 952 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 953 |
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 954 |
}
|
| 955 |
|
| 956 |
kernel void kernel_mul_mv_q5_0_f32(
|
|
|
|
| 969 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 970 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 971 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 972 |
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 973 |
}
|
| 974 |
|
| 975 |
kernel void kernel_mul_mv_q5_1_f32(
|
|
|
|
| 988 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 989 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 990 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 991 |
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 992 |
}
|
| 993 |
|
| 994 |
|
| 995 |
#define NB_Q8_0 8
|
| 996 |
|
| 997 |
+
void kernel_mul_mv_q8_0_f32_impl(
|
| 998 |
device const void * src0,
|
| 999 |
device const float * src1,
|
| 1000 |
device float * dst,
|
| 1001 |
constant int64_t & ne00,
|
| 1002 |
+
constant int64_t & ne01,
|
| 1003 |
+
constant int64_t & ne02,
|
| 1004 |
+
constant int64_t & ne10,
|
| 1005 |
+
constant int64_t & ne12,
|
| 1006 |
+
constant int64_t & ne0,
|
| 1007 |
+
constant int64_t & ne1,
|
| 1008 |
+
constant uint & r2,
|
| 1009 |
+
constant uint & r3,
|
| 1010 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1011 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 1012 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1013 |
const int nr = N_DST;
|
| 1014 |
const int nsg = N_SIMDGROUP;
|
| 1015 |
const int nw = N_SIMDWIDTH;
|
|
|
|
| 1063 |
}
|
| 1064 |
}
|
| 1065 |
|
| 1066 |
+
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
| 1067 |
+
kernel void kernel_mul_mv_q8_0_f32(
|
| 1068 |
+
device const void * src0,
|
| 1069 |
+
device const float * src1,
|
| 1070 |
+
device float * dst,
|
| 1071 |
+
constant int64_t & ne00,
|
| 1072 |
+
constant int64_t & ne01,
|
| 1073 |
+
constant int64_t & ne02,
|
| 1074 |
+
constant int64_t & ne10,
|
| 1075 |
+
constant int64_t & ne12,
|
| 1076 |
+
constant int64_t & ne0,
|
| 1077 |
+
constant int64_t & ne1,
|
| 1078 |
+
constant uint & r2 [[buffer(17)]],
|
| 1079 |
+
constant uint & r3 [[buffer(18)]],
|
| 1080 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1081 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 1082 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1083 |
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 1084 |
+
}
|
| 1085 |
+
|
| 1086 |
#define N_F32_F32 4
|
| 1087 |
|
| 1088 |
+
void kernel_mul_mv_f32_f32_impl(
|
| 1089 |
device const char * src0,
|
| 1090 |
device const char * src1,
|
| 1091 |
device float * dst,
|
|
|
|
| 1103 |
constant uint64_t & nb12,
|
| 1104 |
constant int64_t & ne0,
|
| 1105 |
constant int64_t & ne1,
|
| 1106 |
+
constant uint & r2,
|
| 1107 |
+
constant uint & r3,
|
| 1108 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1109 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1110 |
|
|
|
|
| 1163 |
}
|
| 1164 |
}
|
| 1165 |
|
| 1166 |
+
[[host_name("kernel_mul_mv_f32_f32")]]
|
| 1167 |
+
kernel void kernel_mul_mv_f32_f32(
|
| 1168 |
+
device const char * src0,
|
| 1169 |
+
device const char * src1,
|
| 1170 |
+
device float * dst,
|
| 1171 |
+
constant int64_t & ne00,
|
| 1172 |
+
constant int64_t & ne01,
|
| 1173 |
+
constant int64_t & ne02,
|
| 1174 |
+
constant uint64_t & nb00,
|
| 1175 |
+
constant uint64_t & nb01,
|
| 1176 |
+
constant uint64_t & nb02,
|
| 1177 |
+
constant int64_t & ne10,
|
| 1178 |
+
constant int64_t & ne11,
|
| 1179 |
+
constant int64_t & ne12,
|
| 1180 |
+
constant uint64_t & nb10,
|
| 1181 |
+
constant uint64_t & nb11,
|
| 1182 |
+
constant uint64_t & nb12,
|
| 1183 |
+
constant int64_t & ne0,
|
| 1184 |
+
constant int64_t & ne1,
|
| 1185 |
+
constant uint & r2 [[buffer(17)]],
|
| 1186 |
+
constant uint & r3 [[buffer(18)]],
|
| 1187 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1188 |
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1189 |
+
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
| 1190 |
+
}
|
| 1191 |
+
|
| 1192 |
#define N_F16_F16 4
|
| 1193 |
|
| 1194 |
kernel void kernel_mul_mv_f16_f16(
|
|
|
|
| 1269 |
}
|
| 1270 |
}
|
| 1271 |
|
| 1272 |
+
void kernel_mul_mv_f16_f32_1row_impl(
|
| 1273 |
device const char * src0,
|
| 1274 |
device const char * src1,
|
| 1275 |
device float * dst,
|
|
|
|
| 1287 |
constant uint64_t & nb12,
|
| 1288 |
constant int64_t & ne0,
|
| 1289 |
constant int64_t & ne1,
|
| 1290 |
+
constant uint & r2,
|
| 1291 |
+
constant uint & r3,
|
| 1292 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1293 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1294 |
|
|
|
|
| 1325 |
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1326 |
}
|
| 1327 |
}
|
| 1328 |
+
}
|
| 1329 |
|
| 1330 |
+
[[host_name("kernel_mul_mv_f16_f32_1row")]]
|
| 1331 |
+
kernel void kernel_mul_mv_f16_f32_1row(
|
| 1332 |
+
device const char * src0,
|
| 1333 |
+
device const char * src1,
|
| 1334 |
+
device float * dst,
|
| 1335 |
+
constant int64_t & ne00,
|
| 1336 |
+
constant int64_t & ne01,
|
| 1337 |
+
constant int64_t & ne02,
|
| 1338 |
+
constant uint64_t & nb00,
|
| 1339 |
+
constant uint64_t & nb01,
|
| 1340 |
+
constant uint64_t & nb02,
|
| 1341 |
+
constant int64_t & ne10,
|
| 1342 |
+
constant int64_t & ne11,
|
| 1343 |
+
constant int64_t & ne12,
|
| 1344 |
+
constant uint64_t & nb10,
|
| 1345 |
+
constant uint64_t & nb11,
|
| 1346 |
+
constant uint64_t & nb12,
|
| 1347 |
+
constant int64_t & ne0,
|
| 1348 |
+
constant int64_t & ne1,
|
| 1349 |
+
constant uint & r2 [[buffer(17)]],
|
| 1350 |
+
constant uint & r3 [[buffer(18)]],
|
| 1351 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1352 |
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1353 |
+
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
| 1354 |
}
|
| 1355 |
|
| 1356 |
#define N_F16_F32 4
|
| 1357 |
|
| 1358 |
+
void kernel_mul_mv_f16_f32_impl(
|
| 1359 |
device const char * src0,
|
| 1360 |
device const char * src1,
|
| 1361 |
device float * dst,
|
|
|
|
| 1373 |
constant uint64_t & nb12,
|
| 1374 |
constant int64_t & ne0,
|
| 1375 |
constant int64_t & ne1,
|
| 1376 |
+
constant uint & r2,
|
| 1377 |
+
constant uint & r3,
|
| 1378 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1379 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1380 |
|
|
|
|
| 1433 |
}
|
| 1434 |
}
|
| 1435 |
|
| 1436 |
+
[[host_name("kernel_mul_mv_f16_f32")]]
|
| 1437 |
+
kernel void kernel_mul_mv_f16_f32(
|
| 1438 |
+
device const char * src0,
|
| 1439 |
+
device const char * src1,
|
| 1440 |
+
device float * dst,
|
| 1441 |
+
constant int64_t & ne00,
|
| 1442 |
+
constant int64_t & ne01,
|
| 1443 |
+
constant int64_t & ne02,
|
| 1444 |
+
constant uint64_t & nb00,
|
| 1445 |
+
constant uint64_t & nb01,
|
| 1446 |
+
constant uint64_t & nb02,
|
| 1447 |
+
constant int64_t & ne10,
|
| 1448 |
+
constant int64_t & ne11,
|
| 1449 |
+
constant int64_t & ne12,
|
| 1450 |
+
constant uint64_t & nb10,
|
| 1451 |
+
constant uint64_t & nb11,
|
| 1452 |
+
constant uint64_t & nb12,
|
| 1453 |
+
constant int64_t & ne0,
|
| 1454 |
+
constant int64_t & ne1,
|
| 1455 |
+
constant uint & r2 [[buffer(17)]],
|
| 1456 |
+
constant uint & r3 [[buffer(18)]],
|
| 1457 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1458 |
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1459 |
+
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
| 1460 |
+
}
|
| 1461 |
+
|
| 1462 |
// Assumes row size (ne00) is a multiple of 4
|
| 1463 |
kernel void kernel_mul_mv_f16_f32_l4(
|
| 1464 |
device const char * src0,
|
|
|
|
| 1763 |
 }
 }

+kernel void kernel_upscale_f32(
+        device const  char * src0,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant   int32_t & sf,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
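As a reading aid, a minimal CPU sketch of what this kernel computes: nearest-neighbor upscaling of the first two dimensions by the integer factor sf (helper name illustrative, contiguous layout assumed):

// reference semantics: dst[i3][i2][i1][i0] = src[i3][i2][i1/sf][i0/sf]
static void upscale_f32_ref(const float * src, float * dst,
                            int ne00, int ne01, int ne02, int ne03, int sf) {
    const int ne0 = ne00*sf;
    const int ne1 = ne01*sf;
    for (int i3 = 0; i3 < ne03; ++i3)
    for (int i2 = 0; i2 < ne02; ++i2)
    for (int i1 = 0; i1 < ne1;  ++i1)
    for (int i0 = 0; i0 < ne0;  ++i0) {
        dst[((i3*ne02 + i2)*ne1 + i1)*ne0 + i0] =
            src[((i3*ne02 + i2)*ne01 + i1/sf)*ne00 + i0/sf];
    }
}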
+
+kernel void kernel_pad_f32(
+        device const  char * src0,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
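Likewise, a small CPU sketch of the pad semantics, shown in 2-D for brevity (the kernel generalizes to four dimensions; helper name illustrative):

// reference semantics: copy src into the top-left corner of dst, zero the rest
static void pad_f32_ref(const float * src, float * dst,
                        int ne00, int ne01, int ne0, int ne1) {
    for (int i1 = 0; i1 < ne1; ++i1) {
        for (int i0 = 0; i0 < ne0; ++i0) {
            dst[i1*ne0 + i0] = (i0 < ne00 && i1 < ne01) ? src[i1*ne00 + i0] : 0.0f;
        }
    }
}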
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float * x,
+        device     int32_t * dst,
+        constant   int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float * x,
+        device     int32_t * dst,
+        constant   int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
     if (col >= ncols) return;

     device const float * x_row = x + row * ncols;
     ...

 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
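The body of the sort network is collapsed in this view; for reference, a serial C sketch of the ascending bitonic argsort the comment refers to (helper name illustrative; n must be a power of two):

#include <stdint.h>

// full bitonic argsort: sort indices so that x[idx[0]] <= x[idx[1]] <= ...
static void argsort_f32_i32_asc_ref(const float * x, int32_t * idx, int n) {
    for (int i = 0; i < n; ++i) idx[i] = i;
    for (int k = 2; k <= n; k *= 2) {
        for (int j = k/2; j > 0; j /= 2) {
            for (int i = 0; i < n; ++i) {
                const int ixj = i ^ j;
                if (ixj <= i) continue;
                const int asc = (i & k) == 0; // direction of this bitonic half
                if (( asc && x[idx[i]] > x[idx[ixj]]) ||
                    (!asc && x[idx[i]] < x[idx[ixj]])) {
                    const int32_t t = idx[i]; idx[i] = idx[ixj]; idx[ixj] = t;
                }
            }
        }
    }
}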

+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
 kernel void kernel_cpy_f16_f16(
+        device const  half * src0,
+        device        half * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
         ...
 }
 }

+kernel void kernel_cpy_f16_f32(
+        device const  half * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
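The copy kernels first linearize the source row coordinate into a flat element index n, then unravel n into destination coordinates, which is what allows src and dst to have different shapes with equal element counts. A hedged C sketch of the unravel step (helper name illustrative):

#include <stdint.h>

// unravel a flat element index n into 4-D coordinates (i3,i2,i1,i0)
// for a destination of shape ne0 x ne1 x ne2 x ne3
static void unravel4(int64_t n, int64_t ne0, int64_t ne1, int64_t ne2,
                     int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
    *i3 = n / (ne2*ne1*ne0);  n -= *i3 * ne2*ne1*ne0;
    *i2 = n / (ne1*ne0);      n -= *i2 * ne1*ne0;
    *i1 = n / ne0;
    *i0 = n - *i1 * ne0;
}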
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device        half * dst,
         ...
 }

 kernel void kernel_concat(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
         ...
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;

+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
     ...

 //====================================== dot products =========================

+void kernel_mul_mv_q2_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     ...
     }
 }

+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         ...
         constant      uint & r2 [[buffer(17)]],
         constant      uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
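A pattern worth noting, repeated for every quantization type in this file: the math lives in a plain _impl function and the [[host_name(...)]] kernel is a thin forwarding wrapper, so the indirect _id kernels further below can reuse the same body. Sketched generically in C (names illustrative):

// shared body, callable from several entry points
static void mul_mv_impl(const void * src0, const float * src1, float * dst, int ne00) {
    // ... the actual dot-product work would live here ...
    (void)src0; (void)src1; (void)dst; (void)ne00;
}

// thin named entry point that only forwards its arguments
void mul_mv_q2_K_f32(const void * src0, const float * src1, float * dst, int ne00) {
    mul_mv_impl(src0, src1, dst, ne00);
}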
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

     const int nb = ne00/QK_K;
     ...
     }
 }
 #else
+void kernel_mul_mv_q3_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     ...
 }
 #endif

+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant      uint & r2  [[buffer(17)]],
+        constant      uint & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 #if QK_K == 256
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     ...
     }
 }
 #else
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     ...
 }
 #endif

+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         ...
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {

+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

     const int nb = ne00/QK_K;

     const int64_t r0 = tgpig.x;
     ...
         dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
     }
 }
 ...
 }

+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         ...
         constant      uint & r2 [[buffer(17)]],
         constant      uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}

+void kernel_mul_mv_q6_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
     const int     im = tgpig.z;

     const int row = 2 * r0 + sgitg;
     ...
     }
 }

+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant      uint & r2  [[buffer(17)]],
+        constant      uint & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================

 // NOTE: this is not dequantizing - we are simply fitting the template
 ...

 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const float d = xb->d;
+    const float min = xb->dmin;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
+    float dl, ml;
     uint8_t sc = xb->scales[il];

 #if QK_K == 256
     ...
     q = q + (il/4) * 32 + 16 * (il&1);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     ...
     uint8_t ul = 1 << (il/2);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];

+    const ushort mask   = il<2 ? 0x0F : 0xF0;
+    const float  qh_val = il<2 ? 16.f : 256.f;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
     ...

 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
+        device const  char * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
         constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    //const int64_t i = tgpig;
+    //const int64_t r = ((device int32_t *) src1)[i];
+
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;

+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
         dequantize_func(
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+kernel void kernel_get_rows_f32(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
     }
 }
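All of the get-rows variants share one semantics: gather whole rows of src0 at the 32-bit indices stored in src1. A minimal CPU sketch (helper name illustrative, contiguous layout assumed):

#include <stdint.h>

// dst row i becomes src0 row rows[i], copied ne00 floats at a time
static void get_rows_f32_ref(const float * src0, const int32_t * rows,
                             float * dst, int nrows, int ne00) {
    for (int i = 0; i < nrows; ++i) {
        const int32_t r = rows[i];
        for (int j = 0; j < ne00; ++j) {
            dst[i*ne00 + j] = src0[r*ne00 + j];
        }
    }
}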
 ...

 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
+        device const uchar * ids,
         device const uchar * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
         constant   int64_t & ne00,
         constant   int64_t & ne02,
         constant   int64_t & nb01,
         constant   int64_t & nb02,
         constant   int64_t & ne12,
+        constant   int64_t & ne13,
         constant   int64_t & nb10,
         constant   int64_t & nb11,
         constant   int64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   int64_t & nb1,
         constant      uint & r2,
         constant      uint & r3,
         constant       int & idx,
         ...
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
         ne00,
         ne02,
         nb01,
         ...
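Each _id kernel begins with the same dispatch preamble: tgpig.z folds together the batch row bid and the broadcast position, and a 32-bit expert index read from ids selects one of up to eight candidate src0 matrices. A hedged C sketch of that selection (names and the element-stride convention are illustrative):

#include <stdint.h>

// pick the expert matrix for one batch row, mirroring the kernel's indexing;
// ids holds one row of expert indices per batch row, ids_stride elements apart
static const void * select_expert(const void * const src0[8],
                                  const int32_t * ids, int64_t ids_stride,
                                  int64_t bid, int idx) {
    const int32_t id = ids[bid*ids_stride + idx];
    return src0[id];
}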
 #define QK_NL 4
 #endif

+//
+// get rows
+//
+
 typedef void (get_rows_t)(
         device const  void * src0,
+        device const  char * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
         constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3, uint, uint3);

+//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,  1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
 ...
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;

+//
+// matrix-matrix multiplication
+//
+
 typedef void (mat_mm_t)(
         device const uchar * src0,
         device const uchar * src1,
         ...

 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;

+//
+// indirect matrix-matrix multiplication
+//
+
 typedef void (mat_mm_id_t)(
+        device const uchar * ids,
         device const uchar * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
         constant   int64_t & ne00,
         constant   int64_t & ne02,
         constant   int64_t & nb01,
         constant   int64_t & nb02,
         constant   int64_t & ne12,
+        constant   int64_t & ne13,
         constant   int64_t & nb10,
         constant   int64_t & nb11,
         constant   int64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   int64_t & nb1,
         constant      uint & r2,
         constant      uint & r3,
         constant       int & idx,
         ...

 template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
 template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f32_f32_impl(src0[id], src1 + bid*nb11, (device float *) (dst + bid*nb1), ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f16_f32_impl(src0[id], src1 + bid*nb11, (device float *) (dst + bid*nb1), ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q8_0_f32_impl(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q2_K_f32_impl(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q3_K_f32_impl(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q4_K_f32_impl(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q5_K_f32_impl(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+        device const  char * ids,
+        device const  char * src1,
+        device       uchar * dst,
+        constant   int64_t & nbi1,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & nb1,
+        constant      uint & r2,
+        constant      uint & r3,
+        constant       int & idx,
+        device const  char * src00,
+        device const  char * src01,
+        device const  char * src02,
+        device const  char * src03,
+        device const  char * src04,
+        device const  char * src05,
+        device const  char * src06,
+        device const  char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q6_K_f32_impl(src0[id], (device const float *) (src1 + bid*nb11), (device float *) (dst + bid*nb1), ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
ggml-quants.c
CHANGED

@@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy)

     size_t vl = __riscv_vsetvl_e8m1(qk/2);

-    // These tempory registers are for masking and shift operations
+    // These temporary registers are for masking and shift operations
     vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
     vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);

@@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy)

     vl = 16;

-    // retreive lane to multiply with scale
+    // retrieve lane to multiply with scale
     vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
     vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
     vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
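For readers unfamiliar with the RVV intrinsics, the two comments restored above describe building per-lane bit masks; the scalar C equivalent of the vid + vsll pair is simply (helper name illustrative; vl is at most 32 here):

#include <stdint.h>

// scalar equivalent of __riscv_vid + __riscv_vsll: lane i holds the mask 1 << i
static void build_lane_masks(uint32_t * vt, int vl) {
    for (int i = 0; i < vl; ++i) {
        vt[i] = 1u << i;
    }
}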
ggml.c
CHANGED
@@ -1,4 +1,4 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe"
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC

 #include "ggml-impl.h"

@@ -33,7 +33,7 @@
 // we should just be careful :)
 #pragma warning(disable: 4244 4267)

-// disable POSIX deprecation
+// disable POSIX deprecation warnings
 // these functions are never going away, anyway
 #pragma warning(disable: 4996)
 #endif

@@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }

 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
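The new vector op is the usual leaky-ReLU definition in branch-free form: for negative slope ns, y = x when x > 0 and y = ns*x otherwise. A tiny standalone check of the equivalence (illustration only):

#include <assert.h>

// Both forms compute leaky ReLU; the branch-free sum used above and the
// conventional two-case definition agree for every x, including x == 0.
static float leaky_relu_sum(float x, float ns) {
    return ((x > 0.f) ? x : 0.f) + ns * ((x < 0.f) ? x : 0.f);
}

static float leaky_relu_branch(float x, float ns) {
    return x > 0.f ? x : ns * x;
}

int main(void) {
    const float xs[5] = { -2.f, -0.5f, 0.f, 0.5f, 2.f };
    for (int i = 0; i < 5; ++i) {
        assert(leaky_relu_sum(xs[i], 0.1f) == leaky_relu_branch(xs[i], 0.1f));
    }
    return 0;
}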
@@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 1623 |
"POOL_1D",
|
| 1624 |
"POOL_2D",
|
| 1625 |
"UPSCALE",
|
|
|
|
| 1626 |
"ARGSORT",
|
|
|
|
| 1627 |
|
| 1628 |
"FLASH_ATTN",
|
| 1629 |
"FLASH_FF",
|
|
@@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 1650 |
"CROSS_ENTROPY_LOSS_BACK",
|
| 1651 |
};
|
| 1652 |
|
| 1653 |
-
static_assert(GGML_OP_COUNT ==
|
| 1654 |
|
| 1655 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 1656 |
"none",
|
|
@@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1707 |
"pool_1d(x)",
|
| 1708 |
"pool_2d(x)",
|
| 1709 |
"upscale(x)",
|
|
|
|
| 1710 |
"argsort(x)",
|
|
|
|
| 1711 |
|
| 1712 |
"flash_attn(x)",
|
| 1713 |
"flash_ff(x)",
|
|
@@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1734 |
"cross_entropy_loss_back(x,y)",
|
| 1735 |
};
|
| 1736 |
|
| 1737 |
-
static_assert(GGML_OP_COUNT ==
|
| 1738 |
|
| 1739 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 1740 |
|
|
@@ -1750,17 +1754,16 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|
| 1750 |
"GELU",
|
| 1751 |
"GELU_QUICK",
|
| 1752 |
"SILU",
|
| 1753 |
-
"LEAKY",
|
| 1754 |
};
|
| 1755 |
|
| 1756 |
-
static_assert(GGML_UNARY_OP_COUNT ==
|
| 1757 |
|
| 1758 |
|
| 1759 |
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
| 1760 |
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
| 1761 |
|
| 1762 |
// WARN:
|
| 1763 |
-
// Mis-
|
| 1764 |
// * At best it crash or talks nosense.
|
| 1765 |
// * At worst it talks slightly difference but hard to perceive.
|
| 1766 |
//
|
|
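The paired `static_assert`s are what keep the `GGML_OP_NAME` and `GGML_OP_SYMBOL` tables in lockstep with the `ggml_op` enum: adding `GGML_OP_PAD` and `GGML_OP_LEAKY_RELU` bumps the count to 72, so any table that was not extended fails at compile time instead of misbehaving at run time. The same guard pattern in miniature (hypothetical enum and names, for illustration):

// A hypothetical three-entry enum guarded the same way: forgetting to
// extend kNames after adding an enumerator breaks the build.
enum my_op { MY_OP_NONE, MY_OP_ADD, MY_OP_MUL, MY_OP_COUNT };

static const char * kNames[MY_OP_COUNT] = { "none", "add", "mul" };

_Static_assert(sizeof(kNames)/sizeof(kNames[0]) == MY_OP_COUNT,
               "kNames out of sync with my_op");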
@@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }

-//
+// ggml_leaky_relu

-struct ggml_tensor *
+struct ggml_tensor * ggml_leaky_relu(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-
+        struct ggml_tensor  * a, float negative_slope, bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+    result->op   = GGML_OP_LEAKY_RELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
 }

 // ggml_gelu
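Usage is the same as for any other graph op, except that the slope and the in-place flag are passed directly instead of via a separate `_inplace` variant. A minimal sketch (the context size and tensor shape are placeholders):

#include "ggml.h"

// Allocate a context, create an input, and attach the new leaky-ReLU op
// with negative_slope = 0.1f; passing true for the last argument would
// reuse x's buffer instead of allocating a new tensor.
void build_leaky(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * y = ggml_leaky_relu(ctx, x, 0.1f, /*inplace=*/false);
    (void) y;

    ggml_free(ctx);
}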
@@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl(

     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

-    result->op = GGML_OP_GROUP_NORM;
     result->op_params[0] = n_groups;
+
+    result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = NULL; // TODO: maybe store epsilon here?
@@ -4075,17 +4092,18 @@ struct ggml_tensor * ggml_mul_mat(

 struct ggml_tensor * ggml_mul_mat_id(
     struct ggml_context * ctx,
-    struct ggml_tensor * as[],
+    struct ggml_tensor * const as[],
+    int                  n_as,
     struct ggml_tensor * ids,
     int                  id,
     struct ggml_tensor * b) {

-    int64_t n_as = ids->ne[0];
-
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
     GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id <
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]);

     bool is_node = false;

@@ -4097,13 +4115,14 @@ struct ggml_tensor * ggml_mul_mat_id(
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);

     ggml_set_op_params_i32(result, 0, id);
+    ggml_set_op_params_i32(result, 1, n_as);

     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = ids;
     result->src[1] = b;

-    for (
+    for (int i = 0; i < n_as; i++) {
         struct ggml_tensor * a = as[i];
         GGML_ASSERT(ggml_are_same_shape(as[0], a));
         GGML_ASSERT(ggml_can_mul_mat(a, b));
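With the explicit `n_as` parameter, a mixture-of-experts matmul is built by passing the expert weight array plus a per-token I32 routing tensor; the expert for each token comes from row `id` of `ids`. A hedged sketch (shapes and names are illustrative; `ctx`, the expert tensors, `ids`, and `cur` are assumed to exist):

// N_EXPERT expert matrices, all the same shape, multiplied against cur;
// ids holds one expert index per token, chosen by an earlier top-k step.
enum { N_EXPERT = 4 };

struct ggml_tensor * moe_matmul(struct ggml_context * ctx,
                                struct ggml_tensor * const experts[N_EXPERT],
                                struct ggml_tensor * ids, // GGML_TYPE_I32
                                struct ggml_tensor * cur) {
    // roughly: out ~= ggml_mul_mat(experts[ids[id]], cur), per token
    return ggml_mul_mat_id(ctx, experts, N_EXPERT, ids, /*id=*/0, cur);
}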
@@ -4731,7 +4750,9 @@ struct ggml_tensor * ggml_get_rows(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
     struct ggml_tensor * b) {
-    GGML_ASSERT(
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_I32);

     bool is_node = false;

@@ -4741,7 +4762,7 @@ struct ggml_tensor * ggml_get_rows(
     // TODO: implement non F32 return
     //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
-    struct ggml_tensor * result =
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);

     result->op   = GGML_OP_GET_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
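The 4-d result shape encodes the new batched lookup: each of `b->ne[1]` index rows gathers from the matching slice of `a`, which is why dimension 2 of `a` must equal dimension 1 of `b`. A hedged shape sketch with hypothetical sizes:

#include "ggml.h"

// 8 batched 1000x64 lookup tables, 16 I32 row ids per batch slice;
// per the ggml_new_tensor_4d call above, the result is [64, 16, 8, 1].
void get_rows_3d_example(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 1000, 8);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 16, 8); // a->ne[2] == b->ne[1]

    struct ggml_tensor * r = ggml_get_rows(ctx, a, b);
    (void) r; // r->ne = { 64, 16, 8, 1 }
}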
@@ -5519,6 +5540,30 @@ static struct ggml_tensor * ggml_upscale_impl(
     return result;
 }

+struct ggml_tensor * ggml_pad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] + p0,
+            a->ne[1] + p1,
+            a->ne[2] + p2,
+            a->ne[3] + p3);
+
+    result->op   = GGML_OP_PAD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 struct ggml_tensor * ggml_upscale(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
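`ggml_pad` only grows the tensor at the high end of each dimension, appending `p0..p3` zero elements along dimensions 0..3. A quick sketch with hypothetical shapes:

// Pad a [30, 30, 3, 1] image-like tensor to [32, 32, 3, 1] by appending
// two zero columns and two zero rows; the existing data is unchanged.
struct ggml_tensor * padded_example(struct ggml_context * ctx,
                                    struct ggml_tensor * img /* [30,30,3,1] */) {
    return ggml_pad(ctx, img, /*p0=*/2, /*p1=*/2, /*p2=*/0, /*p3=*/0);
}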
@@ -7520,7 +7565,7 @@ static void ggml_compute_forward_acc_f32(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

     // view src0 and dst with these strides and data offset inbytes during acc
-    // nb0 is
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     size_t nb1 = ((int32_t *) dst->op_params)[0];
     size_t nb2 = ((int32_t *) dst->op_params)[1];
     size_t nb3 = ((int32_t *) dst->op_params)[2];

@@ -7714,8 +7759,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;

+    // TODO: OpenCL kernel support broadcast
 #ifdef GGML_USE_CLBLAST
     if (src1->backend == GGML_BACKEND_GPU) {
+        GGML_ASSERT(ggml_are_same_shape(src0, src1));
         if (ith == 0) {
             ggml_cl_mul(src0, src1, dst);
         }
@@ -8981,10 +9028,9 @@ static void ggml_compute_forward_silu(
         } break;
     }
 }
+// ggml_compute_forward_leaky_relu

-
-
-static void ggml_compute_forward_leaky_f32(
+static void ggml_compute_forward_leaky_relu_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     struct ggml_tensor * dst) {

@@ -8998,24 +9044,27 @@ static void ggml_compute_forward_leaky_relu_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];

+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
     assert(dst->nb[0]  == sizeof(float));
     assert(src0->nb[0] == sizeof(float));

     for (int i = 0; i < n; i++) {
-
+        ggml_vec_leaky_relu_f32(nc,
                 (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
     }
 }

-static void
+static void ggml_compute_forward_leaky_relu(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_leaky_relu_f32(params, src0, dst);
             } break;
         default:
             {
@@ -9504,8 +9553,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];

+    // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
+    //       all the experts for each batch element and the processing would become incredibly slow
     // TODO: find the optimal values for these
-    if (
+    if (dst->op != GGML_OP_MUL_MAT_ID &&
+        ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
         //src0->type == GGML_TYPE_F32 &&
         src1->type == GGML_TYPE_F32 &&

@@ -9519,11 +9571,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     }
 #endif

+// off1 = offset in i11 and i1
+// cne1 = ne11 and ne1
+// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
 static void ggml_compute_forward_mul_mat(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
-    struct ggml_tensor * dst
+    struct ggml_tensor * dst,
+    int64_t off1, int64_t cne1) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
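The `off1`/`cne1` pair turns the plain matmul worker into a row-slice worker: it multiplies only `cne1` consecutive src1 rows starting at row `off1`, which is what lets `ggml_compute_forward_mul_mat_id` reuse it one token at a time. A small standalone sketch of the flat-index unpacking used in the hunks below (same arithmetic as the added lines):

#include <stdint.h>

// Unpack a flat src1 row index ir1 in [0, cne1*ne12*ne13) into the tensor
// coordinates (i11, i12, i13), offsetting i11 by off1 so the worker only
// ever touches the requested row slice.
static void unpack_row(int64_t ir1, int64_t cne1, int64_t ne12, int64_t off1,
                       int64_t * i11, int64_t * i12, int64_t * i13) {
    *i13 = ir1/(ne12*cne1);
    *i12 = (ir1 - *i13*ne12*cne1)/cne1;
    *i11 = (ir1 - *i13*ne12*cne1 - *i12*cne1) + off1;
}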
@@ -9591,10 +9648,9 @@ static void ggml_compute_forward_mul_mat(
             const int64_t i03 = i13/r3;
             const int64_t i02 = i12/r2;

-            const void * x = (char *) src0->data +
-            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
-            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+            const void  * x = (char *)            src0->data + i02*nb02  + i03*nb03;
+            const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+            float       * d = (float *) ((char *)  dst->data + off1*nb1  + i12*nb2  + i13*nb3);

             if (type != GGML_TYPE_F32) {
                 float * const wdata = params->wdata;

@@ -9611,10 +9667,10 @@ static void ggml_compute_forward_mul_mat(
             }

             cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-
-
-
-
+                    cne1, ne01, ne10,
+                    1.0f,    y, ne10,
+                             x, ne00,
+                    0.0f,    d, ne01);
         }
     }

@@ -9630,6 +9686,7 @@ static void ggml_compute_forward_mul_mat(
         const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);

         assert(params->wsize >= ne11*ne12*ne13*row_size);
+        assert(src1->type == GGML_TYPE_F32);

         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {

@@ -9652,7 +9709,7 @@ static void ggml_compute_forward_mul_mat(
     const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);

     const int64_t nr0 = ne01;           // src0 rows
-    const int64_t nr1 =
+    const int64_t nr1 = cne1*ne12*ne13; // src1 rows

     //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);

@@ -9694,9 +9751,9 @@ static void ggml_compute_forward_mul_mat(
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                const int64_t i13 = (ir1/(ne12*
-                const int64_t i12 = (ir1 - i13*ne12*
-                const int64_t i11 = (ir1 - i13*ne12*
+                const int64_t i13 = (ir1/(ne12*cne1));
+                const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+                const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;

                 // broadcast src0 into src1
                 const int64_t i03 = i13/r3;
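The updated `cblas_sgemm` call computes the slice product d = y · xᵀ: `cne1` output rows, `ne01` output columns, inner dimension `ne10`. The same mapping onto the standard BLAS signature, isolated (a sketch; `ne00 == ne10` holds in the caller, so the leading dimension of x is written as ne10 here):

#include <cblas.h>

// y is cne1 x ne10 (row-major), x is ne01 x ne10 (hence CblasTrans),
// and d receives the cne1 x ne01 row-major product d = y * x^T.
void slice_gemm(const float * y, const float * x, float * d,
                const int cne1, const int ne01, const int ne10) {
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                cne1, ne01, ne10,
                1.0f, y, ne10,
                      x, ne10,
                0.0f, d, ne01);
}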
@@ -9736,20 +9793,28 @@ static void ggml_compute_forward_mul_mat(

 static void ggml_compute_forward_mul_mat_id(
     const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
     struct ggml_tensor * dst) {

-
-
-
-
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+        ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
+        return;
+    }

-    const
+    const struct ggml_tensor * ids = src0;
+    const int id   = ggml_get_op_params_i32(dst, 0);
+    const int n_as = ggml_get_op_params_i32(dst, 1);

-
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);

-
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);

-
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+        ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+    }
 }

 // ggml_compute_forward_out_prod
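The loop above is the whole MoE dispatch: one `off1 = i01, cne1 = 1` matmul per token, against whichever expert the ids tensor names (experts live in `dst->src[2..]`). A distilled sketch of that control flow with the tensor plumbing stripped out (`mul_mat_slice` is a hypothetical stand-in for the sliced worker):

#include <stdint.h>

// Hypothetical distilled control flow of the expert dispatch above:
// for each token row, read its routed expert and multiply just that row.
void moe_dispatch(int64_t n_tokens, const int32_t * routed_expert,
                  const void * const experts[],
                  void (*mul_mat_slice)(const void * w, int64_t off1, int64_t cne1)) {
    for (int64_t i01 = 0; i01 < n_tokens; i01++) {
        const int32_t row_id = routed_expert[i01];
        mul_mat_slice(experts[row_id], /*off1=*/i01, /*cne1=*/1);
    }
}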
@@ -10161,7 +10226,7 @@ static void ggml_compute_forward_set_f32(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

     // view src0 and dst with these strides and data offset inbytes during set
-    // nb0 is
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     size_t nb1 = ((int32_t *) dst->op_params)[0];
     size_t nb2 = ((int32_t *) dst->op_params)[1];
     size_t nb3 = ((int32_t *) dst->op_params)[2];
@@ -10325,21 +10390,30 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }

-
-
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;

-    assert(
-    assert(
-    assert(
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
+    assert(ggml_nrows(dst) == nr);

-
-
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

-
-
-
+                dequantize_row_q(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
+        }
     }
 }

@@ -10354,19 +10428,26 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }

-
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS

-
-
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);

-
-
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
+    assert(ggml_nrows(dst) == nr);

-
-
-
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_fp16_to_fp32_row(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
         }
     }
 }

@@ -10382,19 +10463,27 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }

-
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS

-
-
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);

-
-
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(dst) == nr);

-
-
-
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_vec_cpy_f32(nc,
+                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+            }
+        }
     }
 }
@@ -12114,6 +12203,7 @@ static void ggml_compute_forward_upscale_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));

     const int ith = params->ith;
+    const int nth = params->nth;

     GGML_TENSOR_UNARY_OP_LOCALS

@@ -12121,16 +12211,17 @@ static void ggml_compute_forward_upscale_f32(

     // TODO: optimize

-    for (
-
-
-
-
-
-    const
-
-    float *
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        const int64_t i03 = i3;
+        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+            const int64_t i02 = i2;
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                const int64_t i01 = i1 / scale_factor;
+                for (int64_t i0 = 0; i0 < ne0; i0++) {
+                    const int64_t i00 = i0 / scale_factor;
+
+                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);

                     *y = *x;
                 }
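Nearest-neighbor upscaling is just integer division on the output coordinates: output element (i0, i1) reads source element (i0/scale, i1/scale). A tiny standalone check of that index math (illustration only):

#include <stdio.h>

// Nearest-neighbor 2x upscale of a 2x2 grid into 4x4: each output cell
// (i0, i1) copies input cell (i0/2, i1/2), mirroring the loop above.
int main(void) {
    const float src[2][2] = { {1, 2}, {3, 4} };
    float dst[4][4];

    const int scale_factor = 2;
    for (int i1 = 0; i1 < 4; i1++) {
        for (int i0 = 0; i0 < 4; i0++) {
            dst[i1][i0] = src[i1/scale_factor][i0/scale_factor];
        }
    }
    printf("%g %g %g %g\n", dst[0][0], dst[0][3], dst[3][0], dst[3][3]); // 1 2 3 4
    return 0;
}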
@@ -12155,6 +12246,64 @@ static void ggml_compute_forward_upscale(
     }
 }

+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+          struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
+                        dst_ptr[dst_idx] = 0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_pad(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_pad_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_argsort

 static void ggml_compute_forward_argsort_f32(
@@ -13362,10 +13511,6 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_silu(params, src0, dst);
             } break;
-        case GGML_UNARY_OP_LEAKY:
-            {
-                ggml_compute_forward_leaky(params, src0, dst);
-            } break;
         default:
             {
                 GGML_ASSERT(false);

@@ -14037,11 +14182,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
         } break;
     case GGML_OP_MUL_MAT:
         {
-            ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+            ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
         } break;
     case GGML_OP_MUL_MAT_ID:
         {
-            ggml_compute_forward_mul_mat_id(params, tensor);
+            ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
         } break;
     case GGML_OP_OUT_PROD:
         {

@@ -14147,10 +14292,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
         {
             ggml_compute_forward_upscale(params, tensor->src[0], tensor);
         } break;
+    case GGML_OP_PAD:
+        {
+            ggml_compute_forward_pad(params, tensor->src[0], tensor);
+        } break;
     case GGML_OP_ARGSORT:
         {
             ggml_compute_forward_argsort(params, tensor->src[0], tensor);
         } break;
+    case GGML_OP_LEAKY_RELU:
+        {
+            ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
+        } break;
     case GGML_OP_FLASH_ATTN:
         {
             const int32_t t = ggml_get_op_params_i32(tensor, 0);

@@ -14475,7 +14628,7 @@ void ggml_build_backward_gradient_checkpointing(
             // insert new tensors recomputing src, reusing already made replacements,
             // remember replacements: remember new tensors with mapping from corresponding gf nodes
             // recurse for input tensors,
-            // unless (i.e. terminating when) input tensors are
+            // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
             node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
         }
         // insert rewritten backward node with replacements made into resulting backward graph gb

@@ -15143,10 +15296,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_PAD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ARGSORT:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;

@@ -15752,6 +15913,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARGMAX:
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT_BACK:
+        case GGML_OP_LEAKY_RELU:
             {
                 n_tasks = 1;
             } break;

@@ -15764,7 +15926,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_LEAKY:
                     {
                         n_tasks = 1;
                     } break;

@@ -15883,6 +16044,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_PAD:
+            {
+                n_tasks = n_threads;
+            } break;
        case GGML_OP_ARGSORT:
            {
                n_tasks = n_threads;
ggml.h
CHANGED
@@ -215,9 +215,9 @@
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this

 #define GGML_MAX_DIMS          4
-#define GGML_MAX_PARAMS
+#define GGML_MAX_PARAMS        2048
 #define GGML_MAX_CONTEXTS      64
-#define GGML_MAX_SRC
+#define GGML_MAX_SRC           10
 #define GGML_MAX_NAME          64
 #define GGML_MAX_OP_PARAMS     64
 #define GGML_DEFAULT_N_THREADS 4

@@ -423,7 +423,9 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_PAD,
         GGML_OP_ARGSORT,
+        GGML_OP_LEAKY_RELU,

         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,

@@ -463,7 +465,6 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_LEAKY,

         GGML_UNARY_OP_COUNT,
     };

@@ -793,6 +794,9 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    // dst = a
+    // view(dst, nb1, nb2, nb3, offset) += b
+    // return dst
     GGML_API struct ggml_tensor * ggml_acc(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

@@ -957,15 +961,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

-    GGML_API struct ggml_tensor *
+    GGML_API struct ggml_tensor * ggml_leaky_relu(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a, float negative_slope, bool inplace);

     GGML_API struct ggml_tensor * ggml_relu_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

-    // TODO: double-check this computation is correct
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

@@ -1051,7 +1054,8 @@ extern "C" {
     // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * as[],
+            struct ggml_tensor  * const as[],
+            int                   n_as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);

@@ -1263,6 +1267,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // supports 3D: a->ne[2] == b->ne[1]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

@@ -1549,6 +1554,15 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);

+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    GGML_API struct ggml_tensor * ggml_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   p0,
+            int                   p1,
+            int                   p2,
+            int                   p3);
+
     // sort rows
     enum ggml_sort_order {
         GGML_SORT_ASC,
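The new comment on `ggml_acc` pins down its semantics: the result starts as a copy of `a`, and `b` is accumulated into a strided view of it. A hedged element-wise sketch of that contract for the 1-D case (plain C over raw arrays, not the ggml implementation):

#include <stddef.h>

// dst starts as a copy of a (dst = a), then b is added into the view of
// dst that begins at `offset` elements and holds nb elements
// (view(dst, ...) += b), matching the comment added above.
void acc_1d(float * dst, const float * a, const float * b,
            size_t n, size_t offset, size_t nb) {
    for (size_t i = 0; i < n; ++i)  dst[i] = a[i];
    for (size_t i = 0; i < nb; ++i) dst[offset + i] += b[i];
}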
|