ngxson HF Staff commited on
Commit
3fe8af8
·
1 Parent(s): c4be6fb

ggml : add ggml_repeat_4d (llama/13824)

Browse files
Files changed (2) hide show
  1. ggml/include/ggml.h +9 -0
  2. ggml/src/ggml.c +20 -0
ggml/include/ggml.h CHANGED
@@ -935,6 +935,15 @@ extern "C" {
935
  struct ggml_tensor * a,
936
  struct ggml_tensor * b);
937
 
 
 
 
 
 
 
 
 
 
938
  // sums repetitions in a into shape of b
939
  GGML_API struct ggml_tensor * ggml_repeat_back(
940
  struct ggml_context * ctx,
 
935
  struct ggml_tensor * a,
936
  struct ggml_tensor * b);
937
 
938
+ // repeat a to the specified shape
939
+ GGML_API struct ggml_tensor * ggml_repeat_4d(
940
+ struct ggml_context * ctx,
941
+ struct ggml_tensor * a,
942
+ int64_t ne0,
943
+ int64_t ne1,
944
+ int64_t ne2,
945
+ int64_t ne3);
946
+
947
  // sums repetitions in a into shape of b
948
  GGML_API struct ggml_tensor * ggml_repeat_back(
949
  struct ggml_context * ctx,
ggml/src/ggml.c CHANGED
@@ -2319,6 +2319,26 @@ struct ggml_tensor * ggml_repeat(
2319
  return result;
2320
  }
2321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2322
  // ggml_repeat_back
2323
 
2324
  struct ggml_tensor * ggml_repeat_back(
 
2319
  return result;
2320
  }
2321
 
2322
+ struct ggml_tensor * ggml_repeat_4d(
2323
+ struct ggml_context * ctx,
2324
+ struct ggml_tensor * a,
2325
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
2326
+ const bool can_repeat = ggml_is_empty(a) || (
2327
+ (ne0 % a->ne[0] == 0) &&
2328
+ (ne1 % a->ne[1] == 0) &&
2329
+ (ne2 % a->ne[2] == 0) &&
2330
+ (ne3 % a->ne[3] == 0)
2331
+ );
2332
+ GGML_ASSERT(can_repeat);
2333
+
2334
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
2335
+
2336
+ result->op = GGML_OP_REPEAT;
2337
+ result->src[0] = a;
2338
+
2339
+ return result;
2340
+ }
2341
+
2342
  // ggml_repeat_back
2343
 
2344
  struct ggml_tensor * ggml_repeat_back(