Spaces:
Running
Running
ggml : add ggml_repeat_4d (llama/13824)
Browse files- ggml/include/ggml.h +9 -0
- ggml/src/ggml.c +20 -0
ggml/include/ggml.h
CHANGED
|
@@ -935,6 +935,15 @@ extern "C" {
|
|
| 935 |
struct ggml_tensor * a,
|
| 936 |
struct ggml_tensor * b);
|
| 937 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
// sums repetitions in a into shape of b
|
| 939 |
GGML_API struct ggml_tensor * ggml_repeat_back(
|
| 940 |
struct ggml_context * ctx,
|
|
|
|
| 935 |
struct ggml_tensor * a,
|
| 936 |
struct ggml_tensor * b);
|
| 937 |
|
| 938 |
+
// repeat a to the specified shape
|
| 939 |
+
GGML_API struct ggml_tensor * ggml_repeat_4d(
|
| 940 |
+
struct ggml_context * ctx,
|
| 941 |
+
struct ggml_tensor * a,
|
| 942 |
+
int64_t ne0,
|
| 943 |
+
int64_t ne1,
|
| 944 |
+
int64_t ne2,
|
| 945 |
+
int64_t ne3);
|
| 946 |
+
|
| 947 |
// sums repetitions in a into shape of b
|
| 948 |
GGML_API struct ggml_tensor * ggml_repeat_back(
|
| 949 |
struct ggml_context * ctx,
|
ggml/src/ggml.c
CHANGED
|
@@ -2319,6 +2319,26 @@ struct ggml_tensor * ggml_repeat(
|
|
| 2319 |
return result;
|
| 2320 |
}
|
| 2321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2322 |
// ggml_repeat_back
|
| 2323 |
|
| 2324 |
struct ggml_tensor * ggml_repeat_back(
|
|
|
|
| 2319 |
return result;
|
| 2320 |
}
|
| 2321 |
|
| 2322 |
+
struct ggml_tensor * ggml_repeat_4d(
|
| 2323 |
+
struct ggml_context * ctx,
|
| 2324 |
+
struct ggml_tensor * a,
|
| 2325 |
+
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
|
| 2326 |
+
const bool can_repeat = ggml_is_empty(a) || (
|
| 2327 |
+
(ne0 % a->ne[0] == 0) &&
|
| 2328 |
+
(ne1 % a->ne[1] == 0) &&
|
| 2329 |
+
(ne2 % a->ne[2] == 0) &&
|
| 2330 |
+
(ne3 % a->ne[3] == 0)
|
| 2331 |
+
);
|
| 2332 |
+
GGML_ASSERT(can_repeat);
|
| 2333 |
+
|
| 2334 |
+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
|
| 2335 |
+
|
| 2336 |
+
result->op = GGML_OP_REPEAT;
|
| 2337 |
+
result->src[0] = a;
|
| 2338 |
+
|
| 2339 |
+
return result;
|
| 2340 |
+
}
|
| 2341 |
+
|
| 2342 |
// ggml_repeat_back
|
| 2343 |
|
| 2344 |
struct ggml_tensor * ggml_repeat_back(
|