Shijie ggerganov commited on
Commit
daae175
·
1 Parent(s): 6d1ba81

llama : add qwen2moe (llama/6074)

Browse files

* support qwen2moe

* fix-review

* metal : support unary ops for nelements % 4 != 0

* metal : require contiguousness for float4 unary kernels

* metal : require contiguousness for float4 unary kernels (cont)

* fix-review

* names : for brevity "SHARED_EXP" -> "SHEXP"

* llama : reuse build_moe_ffn()

* llama : add model type name

---------

Co-authored-by: Georgi Gerganov <[email protected]>

Files changed (2) hide show
  1. ggml-metal.m +42 -15
  2. ggml-metal.metal +26 -0
ggml-metal.m CHANGED
@@ -42,8 +42,11 @@ enum ggml_metal_kernel_type {
42
  GGML_METAL_KERNEL_TYPE_RELU,
43
  GGML_METAL_KERNEL_TYPE_SIGMOID,
44
  GGML_METAL_KERNEL_TYPE_GELU,
 
45
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
 
46
  GGML_METAL_KERNEL_TYPE_SILU,
 
47
  GGML_METAL_KERNEL_TYPE_SOFT_MAX,
48
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
49
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -475,8 +478,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
475
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
 
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
 
479
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
 
480
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
481
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
482
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
@@ -1181,6 +1187,9 @@ static enum ggml_status ggml_metal_graph_compute(
1181
  } break;
1182
  case GGML_OP_UNARY:
1183
  switch (ggml_get_unary_op(gf->nodes[i])) {
 
 
 
1184
  case GGML_UNARY_OP_TANH:
1185
  {
1186
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
@@ -1219,42 +1228,60 @@ static enum ggml_status ggml_metal_graph_compute(
1219
  } break;
1220
  case GGML_UNARY_OP_GELU:
1221
  {
1222
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
 
 
 
 
 
 
 
 
 
1223
 
1224
  [encoder setComputePipelineState:pipeline];
1225
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1226
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1227
 
1228
- const int64_t n = ggml_nelements(dst);
1229
- GGML_ASSERT(n % 4 == 0);
1230
-
1231
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1232
  } break;
1233
  case GGML_UNARY_OP_GELU_QUICK:
1234
  {
1235
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
 
 
 
 
 
 
 
 
 
1236
 
1237
  [encoder setComputePipelineState:pipeline];
1238
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1239
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1240
 
1241
- const int64_t n = ggml_nelements(dst);
1242
- GGML_ASSERT(n % 4 == 0);
1243
-
1244
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1245
  } break;
1246
  case GGML_UNARY_OP_SILU:
1247
  {
1248
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
 
 
 
 
 
 
 
 
 
1249
 
1250
  [encoder setComputePipelineState:pipeline];
1251
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1252
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1253
 
1254
- const int64_t n = ggml_nelements(dst);
1255
- GGML_ASSERT(n % 4 == 0);
1256
-
1257
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1258
  } break;
1259
  default:
1260
  {
 
42
  GGML_METAL_KERNEL_TYPE_RELU,
43
  GGML_METAL_KERNEL_TYPE_SIGMOID,
44
  GGML_METAL_KERNEL_TYPE_GELU,
45
+ GGML_METAL_KERNEL_TYPE_GELU_4,
46
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
47
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
48
  GGML_METAL_KERNEL_TYPE_SILU,
49
+ GGML_METAL_KERNEL_TYPE_SILU_4,
50
  GGML_METAL_KERNEL_TYPE_SOFT_MAX,
51
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
52
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
 
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
479
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
480
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
481
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
482
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
483
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
484
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
485
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
486
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
487
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
488
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
 
1187
  } break;
1188
  case GGML_OP_UNARY:
1189
  switch (ggml_get_unary_op(gf->nodes[i])) {
1190
+ // we are not taking into account the strides, so for now require contiguous tensors
1191
+ GGML_ASSERT(ggml_is_contiguous(src0));
1192
+
1193
  case GGML_UNARY_OP_TANH:
1194
  {
1195
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
 
1228
  } break;
1229
  case GGML_UNARY_OP_GELU:
1230
  {
1231
+ int64_t n = ggml_nelements(dst);
1232
+
1233
+ id<MTLComputePipelineState> pipeline = nil;
1234
+
1235
+ if (n % 4 == 0) {
1236
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
1237
+ n /= 4;
1238
+ } else {
1239
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1240
+ }
1241
 
1242
  [encoder setComputePipelineState:pipeline];
1243
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1244
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1245
 
1246
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
 
 
1247
  } break;
1248
  case GGML_UNARY_OP_GELU_QUICK:
1249
  {
1250
+ int64_t n = ggml_nelements(dst);
1251
+
1252
+ id<MTLComputePipelineState> pipeline = nil;
1253
+
1254
+ if (n % 4 == 0) {
1255
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
1256
+ n /= 4;
1257
+ } else {
1258
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1259
+ }
1260
 
1261
  [encoder setComputePipelineState:pipeline];
1262
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1263
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1264
 
1265
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
 
 
1266
  } break;
1267
  case GGML_UNARY_OP_SILU:
1268
  {
1269
+ int64_t n = ggml_nelements(dst);
1270
+
1271
+ id<MTLComputePipelineState> pipeline = nil;
1272
+
1273
+ if (n % 4 == 0) {
1274
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
1275
+ n /= 4;
1276
+ } else {
1277
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1278
+ }
1279
 
1280
  [encoder setComputePipelineState:pipeline];
1281
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1282
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1283
 
1284
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
 
 
1285
  } break;
1286
  default:
1287
  {
ggml-metal.metal CHANGED
@@ -249,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f;
249
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
250
 
251
  kernel void kernel_gelu(
 
 
 
 
 
 
 
 
 
252
  device const float4 * src0,
253
  device float4 * dst,
254
  uint tpig[[thread_position_in_grid]]) {
@@ -262,6 +271,15 @@ kernel void kernel_gelu(
262
  }
263
 
264
  kernel void kernel_gelu_quick(
 
 
 
 
 
 
 
 
 
265
  device const float4 * src0,
266
  device float4 * dst,
267
  uint tpig[[thread_position_in_grid]]) {
@@ -271,6 +289,14 @@ kernel void kernel_gelu_quick(
271
  }
272
 
273
  kernel void kernel_silu(
 
 
 
 
 
 
 
 
274
  device const float4 * src0,
275
  device float4 * dst,
276
  uint tpig[[thread_position_in_grid]]) {
 
249
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
250
 
251
  kernel void kernel_gelu(
252
+ device const float * src0,
253
+ device float * dst,
254
+ uint tpig[[thread_position_in_grid]]) {
255
+ device const float & x = src0[tpig];
256
+
257
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
258
+ }
259
+
260
+ kernel void kernel_gelu_4(
261
  device const float4 * src0,
262
  device float4 * dst,
263
  uint tpig[[thread_position_in_grid]]) {
 
271
  }
272
 
273
  kernel void kernel_gelu_quick(
274
+ device const float * src0,
275
+ device float * dst,
276
+ uint tpig[[thread_position_in_grid]]) {
277
+ device const float & x = src0[tpig];
278
+
279
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
280
+ }
281
+
282
+ kernel void kernel_gelu_quick_4(
283
  device const float4 * src0,
284
  device float4 * dst,
285
  uint tpig[[thread_position_in_grid]]) {
 
289
  }
290
 
291
  kernel void kernel_silu(
292
+ device const float * src0,
293
+ device float * dst,
294
+ uint tpig[[thread_position_in_grid]]) {
295
+ device const float & x = src0[tpig];
296
+ dst[tpig] = x / (1.0f + exp(-x));
297
+ }
298
+
299
+ +kernel void kernel_silu_4(
300
  device const float4 * src0,
301
  device float4 * dst,
302
  uint tpig[[thread_position_in_grid]]) {