Dave dave-fl committed on
Commit
a06cbc7
·
1 Parent(s): 6901743

Added support for GGML_OP_CLAMP in Metal (llama/6662)

Browse files

* Added support for GGML_OP_CLAMP in Metal

* Corrected size

---------

Co-authored-by: dave-fl <[email protected]>

Files changed (2) hide show
  1. ggml-metal.m +22 -0
  2. ggml-metal.metal +9 -0
ggml-metal.m CHANGED
@@ -37,6 +37,7 @@ enum ggml_metal_kernel_type {
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
  GGML_METAL_KERNEL_TYPE_SCALE,
39
  GGML_METAL_KERNEL_TYPE_SCALE_4,
 
40
  GGML_METAL_KERNEL_TYPE_TANH,
41
  GGML_METAL_KERNEL_TYPE_RELU,
42
  GGML_METAL_KERNEL_TYPE_SIGMOID,
@@ -469,6 +470,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
469
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
470
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
471
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
 
472
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
473
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
474
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
@@ -716,6 +718,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
716
  case GGML_OP_MUL:
717
  case GGML_OP_DIV:
718
  case GGML_OP_SCALE:
 
719
  case GGML_OP_SQR:
720
  case GGML_OP_SUM_ROWS:
721
  return true;
@@ -1157,6 +1160,25 @@ static enum ggml_status ggml_metal_graph_compute(
1157
 
1158
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1159
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1160
  case GGML_OP_UNARY:
1161
  switch (ggml_get_unary_op(gf->nodes[i])) {
1162
  case GGML_UNARY_OP_TANH:
 
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
  GGML_METAL_KERNEL_TYPE_SCALE,
39
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
+ GGML_METAL_KERNEL_TYPE_CLAMP,
41
  GGML_METAL_KERNEL_TYPE_TANH,
42
  GGML_METAL_KERNEL_TYPE_RELU,
43
  GGML_METAL_KERNEL_TYPE_SIGMOID,
 
470
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
471
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
472
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
473
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
474
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
475
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
 
718
  case GGML_OP_MUL:
719
  case GGML_OP_DIV:
720
  case GGML_OP_SCALE:
721
+ case GGML_OP_CLAMP:
722
  case GGML_OP_SQR:
723
  case GGML_OP_SUM_ROWS:
724
  return true;
 
1160
 
1161
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1162
  } break;
1163
+ case GGML_OP_CLAMP:
1164
+ {
1165
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1166
+
1167
+ float min;
1168
+ float max;
1169
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1170
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1171
+
1172
+ [encoder setComputePipelineState:pipeline];
1173
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1174
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1175
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1176
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1177
+
1178
+ const int64_t n = ggml_nelements(dst);
1179
+
1180
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1181
+ } break;
1182
  case GGML_OP_UNARY:
1183
  switch (ggml_get_unary_op(gf->nodes[i])) {
1184
  case GGML_UNARY_OP_TANH:
ggml-metal.metal CHANGED
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
213
  dst[tpig] = src0[tpig] * scale;
214
  }
215
 
 
 
 
 
 
 
 
 
 
216
  kernel void kernel_relu(
217
  device const float * src0,
218
  device float * dst,
 
213
  dst[tpig] = src0[tpig] * scale;
214
  }
215
 
216
+ kernel void kernel_clamp(
217
+ device const float * src0,
218
+ device float * dst,
219
+ constant float & min,
220
+ constant float & max,
221
+ uint tpig[[thread_position_in_grid]]) {
222
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
223
+ }
224
+
225
  kernel void kernel_relu(
226
  device const float * src0,
227
  device float * dst,