llama : add qwen2moe (llama/6074)
* support qwen2moe
* fix-review
* metal : support unary ops for nelements % 4 != 0
* metal : require contiguousness for float4 unary kernels
* metal : require contiguousness for float4 unary kernels (cont)
* fix-review
* names : for brevity "SHARED_EXP" -> "SHEXP"
* llama : reuse build_moe_ffn()
* llama : add model type name
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml-metal.m +42 -15
- ggml-metal.metal +26 -0
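
In ggml-metal.m the same change is applied to GELU, GELU_QUICK and SILU: the float4 pipeline is used only when the number of elements is divisible by 4, otherwise the newly added scalar pipeline handles the tensor. A condensed sketch of that selection logic, pulled from the GELU hunk below (the other two cases differ only in the kernel type; ctx, encoder, dst and the buffer ids come from the surrounding graph-compute code, so this is not a standalone snippet):

    int64_t n = ggml_nelements(dst);   // total number of elements in dst

    id<MTLComputePipelineState> pipeline = nil;

    if (n % 4 == 0) {
        // vectorized kernel: each thread processes one float4 (4 elements)
        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
        n /= 4;
    } else {
        // scalar fallback: one element per thread, works for any n
        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
    }

    [encoder setComputePipelineState:pipeline];
    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

    // n was already divided by 4 on the vectorized path, so the old
    // GGML_ASSERT(n % 4 == 0) and the fixed n/4 dispatch are gone
    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];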
ggml-metal.m
CHANGED

@@ -42,8 +42,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_RELU,
     GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,

@@ -475,8 +478,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);

@@ -1181,6 +1187,9 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_UNARY:
                 switch (ggml_get_unary_op(gf->nodes[i])) {
+                    // we are not taking into account the strides, so for now require contiguous tensors
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
                     case GGML_UNARY_OP_TANH:
                         {
                             id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;

@@ -1219,42 +1228,60 @@ static enum ggml_status ggml_metal_graph_compute(
                         } break;
                     case GGML_UNARY_OP_GELU:
                         {
-                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                            int64_t n = ggml_nelements(dst);
+
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (n % 4 == 0) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+                                n /= 4;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                            }

                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_UNARY_OP_GELU_QUICK:
                         {
-                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                            int64_t n = ggml_nelements(dst);
+
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (n % 4 == 0) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+                                n /= 4;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                            }

                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_UNARY_OP_SILU:
                         {
-                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                            int64_t n = ggml_nelements(dst);
+
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (n % 4 == 0) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+                                n /= 4;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                            }

                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     default:
                         {
ggml-metal.metal
CHANGED

@@ -249,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

kernel void kernel_gelu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {

@@ -262,6 +271,15 @@ kernel void kernel_gelu(
}

kernel void kernel_gelu_quick(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {

@@ -271,6 +289,14 @@ kernel void kernel_gelu_quick(
}

kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
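
The Metal hunks above add scalar bodies for kernel_gelu, kernel_gelu_quick and kernel_silu and rename the pre-existing float4 versions to kernel_gelu_4, kernel_gelu_quick_4 and kernel_silu_4; the float4 bodies themselves are outside the shown context. For reference, a minimal sketch of such a scalar/vectorized pair using the SILU math from the diff — the float4 body here is an illustration (Metal's exp and arithmetic apply componentwise to float4), not copied from the file:

    #include <metal_stdlib>
    using namespace metal;

    // scalar variant: one element per thread, valid for any element count
    kernel void kernel_silu(
            device const float * src0,
            device       float * dst,
            uint tpig[[thread_position_in_grid]]) {
        device const float & x = src0[tpig];
        dst[tpig] = x / (1.0f + exp(-x));
    }

    // vectorized variant: one float4 (4 consecutive floats) per thread;
    // this is why the host only picks this pipeline when nelements % 4 == 0
    // and now requires contiguous tensors
    kernel void kernel_silu_4(
            device const float4 * src0,
            device       float4 * dst,
            uint tpig[[thread_position_in_grid]]) {
        device const float4 & x = src0[tpig];
        dst[tpig] = x / (1.0f + exp(-x));
    }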