Spaces:
Sleeping
Sleeping
Dave
dave-fl
commited on
Commit
·
a06cbc7
1
Parent(s):
6901743
Added support for GGML_OP_CLAMP in Metal (llama/6662)
Browse files* Added support for GGML_OP_CLAMP in Metal
* Corrected size
---------
Co-authored-by: dave-fl <[email protected]>
- ggml-metal.m +22 -0
- ggml-metal.metal +9 -0
ggml-metal.m
CHANGED
|
@@ -37,6 +37,7 @@ enum ggml_metal_kernel_type {
|
|
| 37 |
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
| 38 |
GGML_METAL_KERNEL_TYPE_SCALE,
|
| 39 |
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
|
|
|
| 40 |
GGML_METAL_KERNEL_TYPE_TANH,
|
| 41 |
GGML_METAL_KERNEL_TYPE_RELU,
|
| 42 |
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
|
@@ -469,6 +470,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 469 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
| 470 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
| 471 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
|
|
|
| 472 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
| 473 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
| 474 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
|
@@ -716,6 +718,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
| 716 |
case GGML_OP_MUL:
|
| 717 |
case GGML_OP_DIV:
|
| 718 |
case GGML_OP_SCALE:
|
|
|
|
| 719 |
case GGML_OP_SQR:
|
| 720 |
case GGML_OP_SUM_ROWS:
|
| 721 |
return true;
|
|
@@ -1157,6 +1160,25 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1157 |
|
| 1158 |
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1159 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1160 |
case GGML_OP_UNARY:
|
| 1161 |
switch (ggml_get_unary_op(gf->nodes[i])) {
|
| 1162 |
case GGML_UNARY_OP_TANH:
|
|
|
|
| 37 |
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
| 38 |
GGML_METAL_KERNEL_TYPE_SCALE,
|
| 39 |
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
| 40 |
+
GGML_METAL_KERNEL_TYPE_CLAMP,
|
| 41 |
GGML_METAL_KERNEL_TYPE_TANH,
|
| 42 |
GGML_METAL_KERNEL_TYPE_RELU,
|
| 43 |
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
|
|
|
| 470 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
| 471 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
| 472 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
| 473 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
| 474 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
| 475 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
| 476 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
|
|
|
| 718 |
case GGML_OP_MUL:
|
| 719 |
case GGML_OP_DIV:
|
| 720 |
case GGML_OP_SCALE:
|
| 721 |
+
case GGML_OP_CLAMP:
|
| 722 |
case GGML_OP_SQR:
|
| 723 |
case GGML_OP_SUM_ROWS:
|
| 724 |
return true;
|
|
|
|
| 1160 |
|
| 1161 |
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1162 |
} break;
|
| 1163 |
+
case GGML_OP_CLAMP:
|
| 1164 |
+
{
|
| 1165 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
| 1166 |
+
|
| 1167 |
+
float min;
|
| 1168 |
+
float max;
|
| 1169 |
+
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
| 1170 |
+
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
| 1171 |
+
|
| 1172 |
+
[encoder setComputePipelineState:pipeline];
|
| 1173 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1174 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1175 |
+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
| 1176 |
+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
| 1177 |
+
|
| 1178 |
+
const int64_t n = ggml_nelements(dst);
|
| 1179 |
+
|
| 1180 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1181 |
+
} break;
|
| 1182 |
case GGML_OP_UNARY:
|
| 1183 |
switch (ggml_get_unary_op(gf->nodes[i])) {
|
| 1184 |
case GGML_UNARY_OP_TANH:
|
ggml-metal.metal
CHANGED
|
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
|
|
| 213 |
dst[tpig] = src0[tpig] * scale;
|
| 214 |
}
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
kernel void kernel_relu(
|
| 217 |
device const float * src0,
|
| 218 |
device float * dst,
|
|
|
|
| 213 |
dst[tpig] = src0[tpig] * scale;
|
| 214 |
}
|
| 215 |
|
| 216 |
+
kernel void kernel_clamp(
|
| 217 |
+
device const float * src0,
|
| 218 |
+
device float * dst,
|
| 219 |
+
constant float & min,
|
| 220 |
+
constant float & max,
|
| 221 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 222 |
+
dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
kernel void kernel_relu(
|
| 226 |
device const float * src0,
|
| 227 |
device float * dst,
|