Spaces:
Running
Running
ggml : add mrope kernel for metal (llama/13457)
Browse files
ggml/src/ggml-metal/ggml-metal-impl.h
CHANGED
|
@@ -207,6 +207,10 @@ typedef struct {
|
|
| 207 |
float attn_factor;
|
| 208 |
float beta_fast;
|
| 209 |
float beta_slow;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
} ggml_metal_kargs_rope;
|
| 211 |
|
| 212 |
typedef struct {
|
|
|
|
| 207 |
float attn_factor;
|
| 208 |
float beta_fast;
|
| 209 |
float beta_slow;
|
| 210 |
+
int32_t sect_0;
|
| 211 |
+
int32_t sect_1;
|
| 212 |
+
int32_t sect_2;
|
| 213 |
+
int32_t sect_3;
|
| 214 |
} ggml_metal_kargs_rope;
|
| 215 |
|
| 216 |
typedef struct {
|
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -332,6 +332,10 @@ enum ggml_metal_kernel_type {
|
|
| 332 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
| 333 |
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
| 334 |
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
| 336 |
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
| 337 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
@@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1275 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
| 1276 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
| 1277 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1278 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
| 1279 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
| 1280 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
@@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1637 |
case GGML_OP_NORM:
|
| 1638 |
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
| 1639 |
case GGML_OP_ROPE:
|
| 1640 |
-
|
| 1641 |
-
const int mode = ((const int32_t *) op->op_params)[2];
|
| 1642 |
-
if (mode & GGML_ROPE_TYPE_MROPE) {
|
| 1643 |
-
return false;
|
| 1644 |
-
}
|
| 1645 |
-
if (mode & GGML_ROPE_TYPE_VISION) {
|
| 1646 |
-
return false;
|
| 1647 |
-
}
|
| 1648 |
-
return true;
|
| 1649 |
-
}
|
| 1650 |
case GGML_OP_IM2COL:
|
| 1651 |
return op->src[0]->type == GGML_TYPE_F16;
|
| 1652 |
case GGML_OP_POOL_1D:
|
|
@@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node(
|
|
| 3826 |
} break;
|
| 3827 |
case GGML_OP_ROPE:
|
| 3828 |
{
|
|
|
|
| 3829 |
// make sure we have one or more position id(ne10) per token(ne02)
|
| 3830 |
GGML_ASSERT(ne10 % ne02 == 0);
|
| 3831 |
GGML_ASSERT(ne10 >= ne02);
|
|
@@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node(
|
|
| 3852 |
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
|
| 3853 |
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
|
| 3854 |
|
| 3855 |
-
const bool is_neox
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3856 |
|
| 3857 |
id<MTLComputePipelineState> pipeline = nil;
|
| 3858 |
|
| 3859 |
-
if (
|
| 3860 |
switch (src0->type) {
|
| 3861 |
-
case GGML_TYPE_F32: pipeline = ctx->kernels[
|
| 3862 |
-
case GGML_TYPE_F16: pipeline = ctx->kernels[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3863 |
default: GGML_ABORT("fatal error");
|
| 3864 |
};
|
| 3865 |
} else {
|
| 3866 |
switch (src0->type) {
|
| 3867 |
-
case GGML_TYPE_F32: pipeline = ctx->kernels[
|
| 3868 |
-
case GGML_TYPE_F16: pipeline = ctx->kernels[
|
| 3869 |
default: GGML_ABORT("fatal error");
|
| 3870 |
};
|
| 3871 |
}
|
|
@@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node(
|
|
| 3896 |
/*.attn_factor =*/ attn_factor,
|
| 3897 |
/*.beta_fast =*/ beta_fast,
|
| 3898 |
/*.beta_slow =*/ beta_slow,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3899 |
};
|
| 3900 |
|
| 3901 |
[encoder setComputePipelineState:pipeline];
|
|
|
|
| 332 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
| 333 |
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
| 334 |
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
| 335 |
+
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
|
| 336 |
+
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
|
| 337 |
+
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
|
| 338 |
+
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
|
| 339 |
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
| 340 |
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
| 341 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
|
|
| 1279 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
| 1280 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
| 1281 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
| 1282 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
|
| 1283 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
|
| 1284 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
|
| 1285 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
|
| 1286 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
| 1287 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
| 1288 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
|
|
| 1645 |
case GGML_OP_NORM:
|
| 1646 |
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
| 1647 |
case GGML_OP_ROPE:
|
| 1648 |
+
return true;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1649 |
case GGML_OP_IM2COL:
|
| 1650 |
return op->src[0]->type == GGML_TYPE_F16;
|
| 1651 |
case GGML_OP_POOL_1D:
|
|
|
|
| 3825 |
} break;
|
| 3826 |
case GGML_OP_ROPE:
|
| 3827 |
{
|
| 3828 |
+
|
| 3829 |
// make sure we have one or more position id(ne10) per token(ne02)
|
| 3830 |
GGML_ASSERT(ne10 % ne02 == 0);
|
| 3831 |
GGML_ASSERT(ne10 >= ne02);
|
|
|
|
| 3852 |
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
|
| 3853 |
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
|
| 3854 |
|
| 3855 |
+
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
| 3856 |
+
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
| 3857 |
+
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
| 3858 |
+
|
| 3859 |
+
// mrope
|
| 3860 |
+
const int sect_0 = ((const int32_t *) dst->op_params)[11];
|
| 3861 |
+
const int sect_1 = ((const int32_t *) dst->op_params)[12];
|
| 3862 |
+
const int sect_2 = ((const int32_t *) dst->op_params)[13];
|
| 3863 |
+
const int sect_3 = ((const int32_t *) dst->op_params)[14];
|
| 3864 |
|
| 3865 |
id<MTLComputePipelineState> pipeline = nil;
|
| 3866 |
|
| 3867 |
+
if (is_neox) {
|
| 3868 |
switch (src0->type) {
|
| 3869 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
| 3870 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
| 3871 |
+
default: GGML_ABORT("fatal error");
|
| 3872 |
+
};
|
| 3873 |
+
} else if (is_mrope && !is_vision) {
|
| 3874 |
+
// mrope needs 4 position ids per token: kernel_rope_multi reads pos[i2 + k*ne02] for k = 0..3.
// (the previous check `ne10*4 >= ne02` was already implied by `ne10 >= ne02` and verified nothing)
GGML_ASSERT(ne10 >= 4*ne02); // need at least 4 pos per token
|
| 3875 |
+
switch (src0->type) {
|
| 3876 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
|
| 3877 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
|
| 3878 |
+
default: GGML_ABORT("fatal error");
|
| 3879 |
+
};
|
| 3880 |
+
} else if (is_vision) {
|
| 3881 |
+
// the position tensor is expected to carry 4 ids per token (ne10 >= 4*ne02).
// (the previous check `ne10*4 >= ne02` was already implied by `ne10 >= ne02` and verified nothing)
GGML_ASSERT(ne10 >= 4*ne02); // need at least 4 pos per token
|
| 3882 |
+
switch (src0->type) {
|
| 3883 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
|
| 3884 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
|
| 3885 |
default: GGML_ABORT("fatal error");
|
| 3886 |
};
|
| 3887 |
} else {
|
| 3888 |
switch (src0->type) {
|
| 3889 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
| 3890 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
| 3891 |
default: GGML_ABORT("fatal error");
|
| 3892 |
};
|
| 3893 |
}
|
|
|
|
| 3918 |
/*.attn_factor =*/ attn_factor,
|
| 3919 |
/*.beta_fast =*/ beta_fast,
|
| 3920 |
/*.beta_slow =*/ beta_slow,
|
| 3921 |
+
/* sect_0 =*/ sect_0,
|
| 3922 |
+
/* sect_1 =*/ sect_1,
|
| 3923 |
+
/* sect_2 =*/ sect_2,
|
| 3924 |
+
/* sect_3 =*/ sect_3,
|
| 3925 |
};
|
| 3926 |
|
| 3927 |
[encoder setComputePipelineState:pipeline];
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
|
|
| 2713 |
}
|
| 2714 |
}
|
| 2715 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2716 |
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
| 2717 |
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
|
|
|
|
|
|
| 2718 |
|
| 2719 |
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
| 2720 |
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
|
@@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
|
| 2722 |
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
| 2723 |
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
| 2724 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2725 |
typedef void (im2col_t)(
|
| 2726 |
device const float * x,
|
| 2727 |
device char * dst,
|
|
|
|
| 2713 |
}
|
| 2714 |
}
|
| 2715 |
|
| 2716 |
+
// Multi-section rotary position embedding (M-RoPE).
//
// Indexing, as used below: one threadgroup per (i1, i2, i3) row of src0; the
// threads of the group walk the row two elements at a time.
//
//   args - rope parameters (n_dims, frequency/YaRN settings, section sizes sect_0..sect_3,
//          tensor extents and byte strides)
//   src0 - input, addressed in bytes through args.nb00..nb03, elements of type T
//   src1 - int32 position ids, read at pos[i2 + k*args.ne02] for k = 0..3
//   src2 - float frequency factors indexed by pair; treated as absent when it aliases src0
//   dst  - output, addressed in bytes through args.nb0..nb3, elements of type T
template<typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i1 = tgpig[0];
    const int i2 = tgpig[1];
    const int i3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    // cumulative section boundaries - identical for every element, so computed once
    const int sec_w0    = args.sect_0;
    const int sec_w01   = args.sect_0 + args.sect_1;                             // end of section 1
    const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2;               // end of section 2
    const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; // period of the section pattern

    // distance between the two elements of a rotated pair
    const int pair_stride = args.n_dims/2;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 >= args.n_dims) {
            // outside the rotated range: copy the two elements through unchanged
            device const T * const src_pass = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T *       dst_pass = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_pass[0] = src_pass[0];
            dst_pass[1] = src_pass[1];

            continue;
        }

        const int ic     = i0/2;
        const int sector = ic % sect_dims;

        // pick which of the 4 position rows drives this pair
        int pos_row = 3;
        if (sector < sec_w0) {
            pos_row = 0;
        } else if (sector < sec_w01) {
            pos_row = 1;
        } else if (sector < sec_w012) {
            pos_row = 2;
        }

        const float theta_base = (float) pos[i2 + args.ne02 * pos_row];

        const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

        float cos_theta;
        float sin_theta;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * const src_rot = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
        device       T *       dst_rot = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

        // rotate the pair (ic, ic + n_dims/2)
        const float x0 = src_rot[0];
        const float x1 = src_rot[pair_stride];

        dst_rot[0]           = x0*cos_theta - x1*sin_theta;
        dst_rot[pair_stride] = x0*sin_theta + x1*cos_theta;
    }
}
|
| 2786 |
+
|
| 2787 |
+
// Rotary position embedding, vision variant of M-RoPE (two sections only).
//
// Indexing, as used below: one threadgroup per (i1, i2, i3) row of src0; the
// threads of the group walk the row two elements at a time.
//
//   args - rope parameters (n_dims, frequency/YaRN settings, section sizes sect_0/sect_1,
//          tensor extents and byte strides)
//   src0 - input, addressed in bytes through args.nb00..nb03, elements of type T
//   src1 - int32 position ids, read at pos[i2] and pos[i2 + args.ne02]
//   src2 - float frequency factors indexed by pair; treated as absent when it aliases src0
//   dst  - output, addressed in bytes through args.nb0..nb3, elements of type T
//
// Differences from kernel_rope_multi: the rotated range is 2*n_dims wide, the
// pair partner sits n_dims elements away, and the frequency exponent restarts
// at 0 inside each section.
template<typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            // mrope theta calculations (only support 2 dimensions)
            const int sect_dims = args.sect_0 + args.sect_1;
            const int sector    = ic % sect_dims;

            // p is the frequency index *within* the current section.
            // The first sect_0 sectors belong to section 0 and use the first position row;
            // the remaining sect_1 sectors belong to section 1 and use the second row.
            // The boundary has to be sect_0 so that it agrees with the `- args.sect_0`
            // re-basing in the else branch: testing against sect_1 gives a negative p
            // (sect_0 > sect_1) or an un-rebased p with the wrong position row
            // (sect_0 < sect_1) whenever the two sections differ in size.
            float p;
            float theta_base;
            if (sector < args.sect_0) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            // rotate the pair (ic, ic + n_dims)
            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            // outside the rotated range: copy the two elements through unchanged
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
|
| 2853 |
+
|
| 2854 |
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
| 2855 |
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
| 2856 |
+
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
| 2857 |
+
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
| 2858 |
|
| 2859 |
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
| 2860 |
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
|
|
|
| 2862 |
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
| 2863 |
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
| 2864 |
|
| 2865 |
+
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
| 2866 |
+
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
| 2867 |
+
|
| 2868 |
+
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
| 2869 |
+
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
| 2870 |
+
|
| 2871 |
typedef void (im2col_t)(
|
| 2872 |
device const float * x,
|
| 2873 |
device char * dst,
|