JohannesGaessler committed on
Commit
76aa810
·
1 Parent(s): d1a29c6

test: fix OPT_STEP_ADAMW for test-backend-ops (ggml/974)

Browse files
Files changed (2) hide show
  1. ggml/include/ggml.h +1 -0
  2. ggml/src/ggml.c +6 -4
ggml/include/ggml.h CHANGED
@@ -2052,6 +2052,7 @@ extern "C" {
2052
  GGML_API struct ggml_tensor * ggml_opt_step_adamw(
2053
  struct ggml_context * ctx,
2054
  struct ggml_tensor * a,
 
2055
  float alpha,
2056
  float beta1,
2057
  float beta2,
 
2052
  GGML_API struct ggml_tensor * ggml_opt_step_adamw(
2053
  struct ggml_context * ctx,
2054
  struct ggml_tensor * a,
2055
+ struct ggml_tensor * grad,
2056
  float alpha,
2057
  float beta1,
2058
  float beta2,
ggml/src/ggml.c CHANGED
@@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
7818
  struct ggml_tensor * ggml_opt_step_adamw(
7819
  struct ggml_context * ctx,
7820
  struct ggml_tensor * a,
 
7821
  float alpha,
7822
  float beta1,
7823
  float beta2,
7824
  float eps,
7825
  float wd) {
7826
  GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
 
7827
  GGML_ASSERT(alpha > 0.0f);
7828
  GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
7829
  GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
7842
 
7843
  result->op = GGML_OP_OPT_STEP_ADAMW;
7844
  result->src[0] = a;
7845
- result->src[1] = a->grad;
7846
- result->src[2] = ggml_dup_tensor(ctx, a);
7847
- result->src[3] = ggml_dup_tensor(ctx, a);
7848
 
7849
  return result;
7850
  }
@@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
18769
 
18770
  if (node->flags & GGML_TENSOR_FLAG_PARAM) {
18771
  GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
18772
- struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
18773
  ggml_build_forward_expand(gb, opt_step);
18774
  }
18775
  }
 
7818
  struct ggml_tensor * ggml_opt_step_adamw(
7819
  struct ggml_context * ctx,
7820
  struct ggml_tensor * a,
7821
+ struct ggml_tensor * grad,
7822
  float alpha,
7823
  float beta1,
7824
  float beta2,
7825
  float eps,
7826
  float wd) {
7827
  GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
7828
+ GGML_ASSERT(ggml_are_same_shape(a, grad));
7829
  GGML_ASSERT(alpha > 0.0f);
7830
  GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
7831
  GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
 
7844
 
7845
  result->op = GGML_OP_OPT_STEP_ADAMW;
7846
  result->src[0] = a;
7847
+ result->src[1] = grad;
7848
+ result->src[2] = ggml_dup_tensor(ctx, grad);
7849
+ result->src[3] = ggml_dup_tensor(ctx, grad);
7850
 
7851
  return result;
7852
  }
 
18771
 
18772
  if (node->flags & GGML_TENSOR_FLAG_PARAM) {
18773
  GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
18774
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
18775
  ggml_build_forward_expand(gb, opt_step);
18776
  }
18777
  }