Justina Cho committed
Commit cd0c122 · 1 parent: c40b574

feat: implemented sigmoid function (ggml/806)


* added sigmoid function

* implemented metal kernel for sigmoid

* implemented cuda kernel for sigmoid

* added sigmoid unary op and incremented count

Files changed (7)
  1. ggml-cuda.cu +4 -0
  2. ggml-cuda/unary.cu +26 -0
  3. ggml-cuda/unary.cuh +3 -0
  4. ggml-metal.m +15 -0
  5. ggml-metal.metal +7 -0
  6. ggml.c +72 -1
  7. ggml.h +9 -0
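For reference, this is roughly how the new op is used from the public API once the commit is applied: build a graph containing a `ggml_sigmoid` node and compute it. This is a minimal sketch, not part of the commit, assuming a plain CPU build and the standard ggml context/graph helpers (`ggml_init`, `ggml_new_graph`, `ggml_graph_compute_with_ctx`):

```c
#include "ggml.h"
#include <stdio.h>

int main(void) {
    // small scratch context; the size is arbitrary for this sketch
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // input filled with zeros -> sigmoid should yield 0.5 everywhere
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_f32(x, 0.0f);

    // the op added by this commit (in-place variant: ggml_sigmoid_inplace)
    struct ggml_tensor * y = ggml_sigmoid(ctx, x);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    for (int i = 0; i < 4; ++i) {
        printf("y[%d] = %f\n", i, ggml_get_f32_1d(y, i)); // expect 0.500000
    }

    ggml_free(ctx);
    return 0;
}
```

On a CUDA or Metal build the same graph would dispatch the corresponding kernels added in the files below.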
ggml-cuda.cu CHANGED
@@ -2115,6 +2115,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_RELU:
             ggml_cuda_op_relu(ctx, dst);
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_sigmoid(ctx, dst);
+            break;
         case GGML_UNARY_OP_HARDSIGMOID:
             ggml_cuda_op_hardsigmoid(ctx, dst);
             break;
@@ -2355,6 +2358,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
         case GGML_UNARY_OP_GELU_QUICK:
ggml-cuda/unary.cu CHANGED
@@ -48,6 +48,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = 1.0f / (1.0f + expf(-x[i]));
+}
+
 static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -108,6 +117,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
+    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
     hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -188,6 +202,18 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
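All three backends compute the same element-wise result, 1 / (1 + exp(-x)), and the CUDA launcher sizes its grid with a ceiling division over CUDA_SIGMOID_BLOCK_SIZE. The following standalone C sketch (illustrative only, not part of the commit; `sigmoid_ref` is a hypothetical helper name) shows the expected values and the block-count arithmetic:

```c
#include <math.h>
#include <stdio.h>

// scalar reference mirroring the formula used by sigmoid_f32 (CUDA),
// kernel_sigmoid (Metal) and ggml_vec_sigmoid_f32 (CPU) in this commit
static float sigmoid_ref(float x) {
    return 1.0f / (1.0f + expf(-x));
}

int main(void) {
    const float xs[5] = { -4.0f, -1.0f, 0.0f, 1.0f, 4.0f };
    for (int i = 0; i < 5; ++i) {
        printf("sigmoid(%+.1f) = %f\n", xs[i], sigmoid_ref(xs[i]));
    }

    // grid sizing as in sigmoid_f32_cuda: ceiling division of k elements
    // into blocks of CUDA_SIGMOID_BLOCK_SIZE (256) threads; out-of-range
    // threads in the last block return early inside the kernel
    const int k = 1000;
    const int block_size = 256;
    const int num_blocks = (k + block_size - 1) / block_size; // 4 blocks
    printf("k=%d -> %d blocks x %d threads\n", k, num_blocks, block_size);
    return 0;
}
```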
ggml-cuda/unary.cuh CHANGED
@@ -4,6 +4,7 @@
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
@@ -18,6 +19,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml-metal.m CHANGED
@@ -39,6 +39,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
     GGML_METAL_KERNEL_TYPE_SILU,
@@ -470,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -695,6 +697,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
@@ -1178,6 +1181,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_UNARY_OP_SIGMOID:
+                    {
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_UNARY_OP_GELU:
ggml-metal.metal CHANGED
@@ -220,6 +220,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sigmoid(
+        device const float * src0,
+        device float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
 kernel void kernel_tanh(
         device const float * src0,
         device float * dst,
ggml.c CHANGED
@@ -1763,6 +1763,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
 inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -2136,6 +2137,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "TANH",
     "ELU",
     "RELU",
+    "SIGMOID",
     "GELU",
     "GELU_QUICK",
     "SILU",
@@ -2143,7 +2145,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -4295,6 +4297,20 @@ struct ggml_tensor * ggml_relu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+struct ggml_tensor * ggml_sigmoid_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
 // ggml_leaky_relu
 
 struct ggml_tensor * ggml_leaky_relu(
@@ -9838,6 +9854,52 @@ static void ggml_compute_forward_relu(
     }
 }
 
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_gelu
 
 static void ggml_compute_forward_gelu_f32(
@@ -15485,6 +15547,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_relu(params, dst);
             } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, dst);
+            } break;
         case GGML_UNARY_OP_GELU:
             {
                 ggml_compute_forward_gelu(params, dst);
@@ -17471,6 +17537,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         zero_table);
                             }
                         } break;
+                    case GGML_UNARY_OP_SIGMOID:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
                     case GGML_UNARY_OP_GELU:
                         {
                             GGML_ASSERT(false); // TODO: not implemented
@@ -18000,6 +18070,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
                 case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
                     {
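The CPU forward path above applies ggml_vec_sigmoid_f32 one row at a time, addressing each row through the tensor's byte stride nb[1] instead of assuming a packed 2-D array. A standalone C sketch of that addressing pattern on a small contiguous buffer (toy sizes, not part of the commit):

```c
#include <math.h>
#include <stdio.h>

int main(void) {
    enum { ROWS = 2, COLS = 3 };
    float src[ROWS][COLS] = { { -1.0f, 0.0f, 1.0f }, { 2.0f, -2.0f, 0.5f } };
    float dst[ROWS][COLS];

    // row stride in bytes; for a contiguous f32 tensor, nb[1] == ne[0]*sizeof(float)
    const size_t nb1 = COLS * sizeof(float);

    for (int i = 0; i < ROWS; i++) {
        // same pointer arithmetic as ggml_compute_forward_sigmoid_f32: base + i*nb[1]
        const float * s = (const float *) ((const char *) src + i*nb1);
        float       * d = (float *)       ((char *)       dst + i*nb1);
        for (int j = 0; j < COLS; j++) {
            d[j] = 1.0f / (1.0f + expf(-s[j]));
        }
    }

    for (int i = 0; i < ROWS; i++) {
        for (int j = 0; j < COLS; j++) {
            printf("%f ", dst[i][j]);
        }
        printf("\n");
    }
    return 0;
}
```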
ggml.h CHANGED
@@ -511,6 +511,7 @@ extern "C" {
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
         GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
@@ -1055,6 +1056,10 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sigmoid(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_leaky_relu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a, float negative_slope, bool inplace);
@@ -1063,6 +1068,10 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);