compilade committed
Commit 224fbc2 · unverified · 1 Parent(s): 2e29431

llama : support Mamba Selective State Space Models (llama/5328)


* mamba : begin working on support for Mamba SSM

* mamba : begin figuring out how to (ab)use the kv cache for Mamba

* mamba : recurrent inference almost works, but incoherent

* mamba : recurrent inference WORKS!!!

* convert : optionally use d_conv and d_state from config.json for Mamba

* mamba : refactor recurrent conv, resulting in 20% perf increase

It's still slower than I'd like, but I did not really optimize `ggml_exp` yet.

I also refactored `ggml_exp` to work with tensors with more than 2 dimensions.

* ggml : parallelize ggml_exp

This results in 8% faster token generation for Mamba-130M.

* mamba : simplify the conv step with a self-overlapping view

Turns out the conv_state can be made smaller by one column.
Note that this breaks existing GGUFs of Mamba,
because the key_value_length field is tied to the conv_state size.

Convolution with a self-overlapping view is cool!
And it's much simpler than what I initially thought would be necessary
to make the convolution step work with more than 1 token at a time.
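
As a reference for what the conv step computes per token, here is a minimal scalar sketch (illustrative names and layout, not the actual ggml code; it assumes d_conv >= 2):

    #include <stddef.h>

    // Scalar sketch of the per-token depthwise conv step.
    // conv_state is {d_conv - 1, d_inner}, conv_w is {d_conv, d_inner},
    // x and out are {d_inner}; rows are stored contiguously.
    static void conv_step_sketch(float * conv_state, const float * conv_w,
                                 const float * x, float * out,
                                 int d_conv, int d_inner) {
        for (int r = 0; r < d_inner; ++r) {
            float       * s = conv_state + (size_t) r * (d_conv - 1);
            const float * w = conv_w     + (size_t) r * d_conv;
            // dot product over the full window: the d_conv - 1 kept columns plus the new x
            float sum = w[d_conv - 1] * x[r];
            for (int i = 0; i < d_conv - 1; ++i) {
                sum += w[i] * s[i];
            }
            // shift the state left by one column and append x on the last one
            for (int i = 0; i < d_conv - 2; ++i) {
                s[i] = s[i + 1];
            }
            s[d_conv - 2] = x[r];
            out[r] = sum;
        }
    }

The self-overlapping view avoids the explicit per-row shift above: the last d_conv - 1 columns of one token's window serve directly as the source for the next token's window.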

Next step is to make the SSM step work on batches of tokens too,
and thus I need to figure out a way to make a parallel selective scan
which will keep the ssm_state small and won't make it bigger
by a factor of (n_layer * batch_size).

* llama : fix Mamba KV self size wrongly displaying as f16 instead of f32

Relatedly, I also tried types other than f32 for the states, but they don't work
because of the operators used.
It's probably better to keep lots of precision there anyway, since the states are small.

* mamba : fix self-overlapping view depth stride

* mamba : handle batches of more than 1 token

This means running Mamba no longer crashes when using the default settings!
Prompt processing is probably also slightly faster.
Both batched and non-batched processing yield the same output.

Previously, the state was not cleared when starting a sequence.
Next step is to make the KV cache API work as expected for Mamba models.

* ggml: add ggml_ssm_scan to help with parallel selective scan

If the selective scan was implemented without a custom operator,
there would be waaay too many nodes in the graph. For example,
for Mamba-130M, with a batch size of 512 (the default),
a naive selective scan could add at least 24*512=12288 nodes,
which is more than LLAMA_MAX_NODES (8192),
and that's only for the smallest Mamba model.
So it's much cleaner with a custom operator.
Not sure about the name, though.

* ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation

This will help with performance on CPU if ggml_vec_mul_f32
and ggml_vec_add_f32 are ever optimized with SIMD.
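
A minimal self-contained illustration of the row-merging idea (vec_mul_f32 below stands in for ggml's ggml_vec_mul_f32; this is not the actual ggml_ssm_scan code):

    // stand-in for ggml_vec_mul_f32: elementwise multiply of n floats
    static void vec_mul_f32(const int n, float * z, const float * x, const float * y) {
        for (int i = 0; i < n; ++i) {
            z[i] = x[i]*y[i];
        }
    }

    // per-row version: ir calls of length nc each
    static void mul_rows_separate(int ir, int nc, float * z, const float * x, const float * y) {
        for (int i1 = 0; i1 < ir; ++i1) {
            vec_mul_f32(nc, z + i1*nc, x + i1*nc, y + i1*nc);
        }
    }

    // merged version: one call over ir*nc contiguous elements, same result,
    // but a single long vector is a better candidate for SIMD
    static void mul_rows_merged(int ir, int nc, float * z, const float * x, const float * y) {
        vec_mul_f32(ir*nc, z, x, y);
    }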

* mamba : very basic quantization support

Mostly works, but there is currently no difference
between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same).
Most of the SSM-specific weights can be kept in f32 without affecting
the size that much, since they are relatively small.
(the linear projection weights are responsible for most of Mamba's size)

Too much quantization seems to make the state degrade quite fast, and
the model begins to output gibberish.
It seems to affect bigger models to a lesser extent than small models,
but I'm not sure by how much.

Experimentation will be needed to figure out which weights are more important
for the _M (and _L?) variants of k-quants for Mamba.

* convert : fix wrong name for layer norm weight of official Mamba models

I was using Q-bert/Mamba-* models before, which have a slightly different
naming scheme for the weights.
(they start with "model.layers" instead of "backbone.layers")

* mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator

This increases performance on CPU by around 30% for prompt processing,
and by around 20% for text generation.

However, it also makes the ggml_exp and ggml_soft_plus operators unused.
Whether or not they should be kept will be decided later.
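
The fused per-(row, token) update looks roughly like this scalar sketch (it mirrors the inner loop of ggml_compute_forward_ssm_scan_f32 in the diff below; the function name is illustrative):

    #include <math.h>

    // state, A, B, C are per-row slices of length d_state; x and dt are scalars
    // for this (row, token) pair. Returns the corresponding element of y.
    static float ssm_scan_row_sketch(float * state, const float * A,
                                     const float * B, const float * C,
                                     float x, float dt, int d_state) {
        // softplus(dt), with the same large-value cutoff as the reference implementation
        float dt_soft_plus = dt <= 20.0f ? log1pf(expf(dt)) : dt;
        float x_dt = x * dt_soft_plus;
        float y = 0.0f;
        for (int i0 = 0; i0 < d_state; ++i0) {
            // state = prev_state * dA + dB * x, with dA = exp(softplus(dt) * A)
            float s = state[i0] * expf(dt_soft_plus * A[i0]) + B[i0] * x_dt;
            // y is the rowwise dot product of the updated state with C
            y += s * C[i0];
            state[i0] = s;
        }
        return y;
    }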

* convert : for Mamba, also consider the "MambaLMHeadModel" arch name

It's the name of the class of the official implementation,
though they don't use it (yet) in the "architectures" field of config.json.

* mamba : fix vocab size problems with official models

The perplexity was waaaay too high for models with a non-round vocab size.
Not sure why, but it needed to be fixed in the metadata.

Note that this breaks existing GGUF-converted Mamba models,
but **only if** the vocab size was not already rounded.

* ggml : remove ggml_exp and ggml_soft_plus

They did not exist anyway outside of this branch,
and since ggml_ssm_scan fused operations together, they are unused.
It's always possible to bring them back if needed.

* mamba : remove some useless comments

No code change.

* convert : fix flake8 linter errors

* mamba : apply suggestions from code review

* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication

It was previously done to avoid permuting when only one token is processed
at a time (like when generating text), but permuting is cheap,
and dynamically changing the compute graph is not future-proof.

* ggml : in ggml_ssm_scan, use more appropriate asserts

* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32

* mamba : multiple sequences, but one at a time

This is a step towards making this Mamba implementation usable
with the server example (the way the system prompt is kept when clearing
the client slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number
of sequences kept at any single time.
For now, this number is obtained from n_parallel (plus one,
to have an extra sequence to dedicate to the system prompt),
but there might be a better way to do this which won't also
make the main example use 2 cells even if only 1 is really used.
(for this specific case, --parallel 0 helps)

Simultaneous sequence processing will probably require changes to
ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

This (mis)uses the logic around K shifts, because tokens in a state
can't be shifted anyway, and because inp_K_shift has the right shape and type.
Using ggml_get_rows is a nice way to do copies, but copy chains can't work.
Fortunately, copy chains don't really seem to be used in the examples.

Each KV cell is dedicated to the sequence ID corresponding to its own index.

* mamba : use a state mask

It's cleaner than the previous heuristic of
checking for the pos of the first token in the batch.

inp_KQ_mask could not be re-used for this, because it has the wrong shape
and because it seems more suited to the next step of
simultaneous sequence processing (helping with the problem of
remembering which token belongs to which sequence(s)/state(s)).

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok

* mamba : in comments, properly refer to KV cells instead of slots

* mamba : reduce memory usage of ggml_ssm_scan

From 290.37 MiB to 140.68 MiB of CPU compute buffer size
with Mamba 3B with a batch size of 512.

The result tensor of ggml_ssm_scan was previously a big part
of the CPU compute buffer size. To make it smaller,
it does not contain the intermediate ssm states anymore.
Both y and the last ssm state are combined in the result tensor,
because it seems only a single tensor can be returned by an operator
with the way the graph is built.
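
One way to read both parts back out of the concatenated 1-D result is with views at the right offsets. A sketch under the layout described above (y first, then the final states); the helper name is hypothetical and the exact views used upstream may differ:

    #include "ggml.h"

    static void split_ssm_scan_result_sketch(
            struct ggml_context * ctx,
            struct ggml_tensor  * y_and_states, // 1-D result of ggml_ssm_scan
            int64_t d_state, int64_t d_inner, int64_t n_tokens, int64_t n_kv,
            struct ggml_tensor ** y_out,
            struct ggml_tensor ** states_out) {
        const size_t es = ggml_element_size(y_and_states); // f32

        // y: {d_inner, n_tokens}, stored first
        *y_out = ggml_view_2d(ctx, y_and_states, d_inner, n_tokens,
                d_inner*es, 0);

        // final ssm states: {d_state, d_inner, n_kv}, stored right after y
        *states_out = ggml_view_3d(ctx, y_and_states, d_state, d_inner, n_kv,
                d_state*es, d_state*d_inner*es,
                d_inner*n_tokens*es);
    }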

* mamba : simultaneous sequence processing

A batch can now contain tokens from multiple sequences.

This is necessary for at least the parallel example, the server example,
and the HellaSwag test in the perplexity example.

However, for this to be useful, uses of llama_kv_cache_seq_rm/cp
will need to be changed to work on whole sequences.

* ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba

This operator makes it possible to use and update the correct states
for each token of the batch in the same way as ggml_ssm_scan.
Other solutions which use existing operators would need loops which would
add too many nodes to the graph (at least the ones I thought of).

Using this operator further reduces the size of the CPU compute buffer
from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512.
And (at least on CPU), it's a bit faster than before.

Note that "ggml_ssm_conv" is probably not the most appropriate name,
and it could be changed if a better one is found.

* llama : add inp_s_seq as a new input tensor

The most convenient implementation to select the correct state (for Mamba)
for each token is to directly get the correct index from a tensor.
This is why inp_s_seq is storing int32_t and not floats.

The other, less convenient way to select the correct state would be
to have inp_KQ_mask contain 1.0f for each state used by a token
and 0.0f otherwise. This complicates quickly fetching the first used
state of a token, and is also less efficient because a whole row
of the mask would always need to be read for each token.

Using indexes makes it easy to stop searching when there are
no more sequences for a token, and the first sequence assigned
is always very quickly available (it's the first element of each row).
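
A sketch of how a row of inp_s_seq can be consumed (illustrative, not the actual code; it mirrors the copy loops in ggml_ssm_conv/ggml_ssm_scan below):

    #include <stdint.h>

    // sq_row points at one row of inp_s_seq: {n_kv} int32_t entries for one token.
    // The first entry is the state the token reads from; the following entries are
    // extra states to copy the updated state into; a negative entry ends the list.
    static int32_t first_and_extra_states_sketch(const int32_t * sq_row, int n_kv) {
        const int32_t src_state = sq_row[0]; // always valid, very quickly available
        for (int i = 1; i < n_kv; ++i) {
            const int32_t seq = sq_row[i];
            if (seq < 0 || seq >= n_kv) {
                break; // stop at negative or too big seq_ids
            }
            // ... copy the updated state into state slot `seq` ...
        }
        return src_state;
    }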

* mamba : support llama_kv_cache_seq_cp copy chains

* mamba : support shifting and dividing the kv cache pos

* mamba : make the server and parallel examples work with whole sequences

A seq_id is dedicated to the system prompt in both cases.

* llama : make llama_kv_cache_seq_rm return whether it succeeded or not

* mamba : dedicate an input tensor for state copy indices

This is cleaner and makes it easier to adapt when/if token positions
(and by extension, inp_K_shift) are no longer integers.

* mamba : adapt perplexity, batched, and batched-bench examples

* perplexity : limit the max number of sequences

This adapts to what the loaded model can provide.

* llama : add llama_n_max_seq to get the upper limit for seq_ids

Used by the perplexity example.

* batched : pass n_parallel to the model's context params

This should have been there already, but it wasn't.

* batched-bench : reserve sequences to support Mamba

* b

Files changed (2)
  1. ggml.c +377 -2
  2. ggml.h +19 -0
ggml.c CHANGED
@@ -1841,6 +1841,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
+    "SSM_CONV",
+    "SSM_SCAN",
     "WIN_PART",
     "WIN_UNPART",
     "GET_REL_POS",
@@ -1863,7 +1865,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1929,6 +1931,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
+    "ssm_conv(x)",
+    "ssm_scan(x)",
     "win_part(x)",
     "win_unpart(x)",
     "get_rel_pos(x)",
@@ -1951,7 +1955,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -6154,6 +6158,108 @@ struct ggml_tensor * ggml_flash_attn_back(
     return result;
 }

+// ggml_ssm_conv
+
+struct ggml_tensor * ggml_ssm_conv(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * s,
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * c,
+        struct ggml_tensor  * sq) {
+    GGML_ASSERT(ggml_is_3d(s));
+    GGML_ASSERT(ggml_is_matrix(x));
+    GGML_ASSERT(ggml_is_matrix(c));
+    GGML_ASSERT(ggml_is_matrix(sq));
+    GGML_ASSERT(sq->type == GGML_TYPE_I32);
+
+    const int64_t d_conv   = c->ne[0];
+    const int64_t d_inner  = c->ne[1];
+    const int64_t n_tokens = x->ne[1];
+    const int64_t n_kv     = s->ne[2];
+
+    GGML_ASSERT( s->ne[0] == d_conv - 1);
+    GGML_ASSERT( s->ne[1] == d_inner);
+    GGML_ASSERT( x->ne[0] == d_inner);
+    GGML_ASSERT(sq->ne[0] == n_kv);
+    GGML_ASSERT(sq->ne[1] == n_tokens);
+
+    bool is_node = false;
+
+    if (s->grad || x->grad || c->grad || sq->grad) {
+        GGML_ASSERT(false); // TODO: implement
+        is_node = true;
+    }
+
+    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+
+    result->op   = GGML_OP_SSM_CONV;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = s;
+    result->src[1] = x;
+    result->src[2] = c;
+    result->src[3] = sq;
+
+    return result;
+}
+
+// ggml_ssm_scan
+
+struct ggml_tensor * ggml_ssm_scan(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * s,
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * dt,
+        struct ggml_tensor  * A,
+        struct ggml_tensor  * B,
+        struct ggml_tensor  * C,
+        struct ggml_tensor  * sq) {
+    GGML_ASSERT(ggml_is_contiguous(s));
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(dt));
+    GGML_ASSERT(ggml_is_contiguous(A));
+    GGML_ASSERT(sq->type == GGML_TYPE_I32);
+    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
+
+    {
+        const int64_t d_state  = s->ne[0];
+        const int64_t d_inner  = s->ne[1];
+        const int64_t n_tokens = x->ne[1];
+
+        GGML_ASSERT(x->ne[0] == d_inner);
+        GGML_ASSERT(A->ne[0] == d_state);
+        GGML_ASSERT(A->ne[1] == d_inner);
+        GGML_ASSERT(B->ne[0] == d_state);
+        GGML_ASSERT(B->ne[1] == n_tokens);
+        GGML_ASSERT(C->ne[0] == d_state);
+        GGML_ASSERT(C->ne[1] == n_tokens);
+    }
+
+    bool is_node = false;
+
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+        GGML_ASSERT(false); // TODO: implement
+        is_node = true;
+    }
+
+    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+
+    result->op   = GGML_OP_SSM_SCAN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = s;
+    result->src[1] = x;
+    result->src[2] = dt;
+    result->src[3] = A;
+    result->src[4] = B;
+    result->src[5] = C;
+    result->src[6] = sq;
+
+    return result;
+}
+
 // ggml_win_part

 struct ggml_tensor * ggml_win_part(
@@ -14771,6 +14877,257 @@ static void ggml_compute_forward_flash_attn_back(
     }
 }

+// ggml_compute_forward_ssm_conv
+
+static void ggml_compute_forward_ssm_conv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc   = src2->ne[0]; // d_conv
+    const int nr   = src0->ne[1]; // d_inner
+    const int n_t  = src1->ne[1]; // n_tokens
+    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+
+    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    // for use with the destination state offset between sequences
+    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    if (n_kv > 1) {
+        // multiple sequences means it's hard to know when it's the first time a state is read,
+        // so copy them all over to the destination, just to be sure.
+        for (int i3 = 0; i3 < n_kv; ++i3) {
+            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+            float * s  = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
+            // can't use memcpy because of d_conv vs d_conv - 1
+            for (int i1 = 0; i1 < ir; ++i1) {
+                for (int i0 = 0; i0 < nc - 1; ++i0) {
+                    // copy s0 to last (d_conv - 1) columns of s
+                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
+                }
+            }
+        }
+    }
+
+    for (int i2 = 0; i2 < n_t; ++i2) {
+        int32_t * sq = (int32_t *) ((char *) src3->data +  i2*(src3->nb[1])); // {n_kv, n_tokens}
+        float *   x  = (float *)   ((char *)  dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
+        float *   s  = (float *)   ((char *)  dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
+        float *   s0; // {d_conv - 1, d_inner, n_kv}
+        float *   x0 = (float *)   ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+        float *   c  = (float *)   ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
+        int ne0s0;
+
+        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+        // avoid needing to copy the state for the first token
+        if (i2 == 0) {
+            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
+            ne0s0 = src0->ne[0];
+        } else {
+            // the source is the last (d_conv - 1) columns of the destination
+            s0 = s + 1;
+            ne0s0 = nc;
+        }
+
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            // shift state left
+            for (int i0 = 0; i0 < nc - 1; ++i0) {
+                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
+            }
+            // insert x on the last column
+            s[(nc - 1) + i1*nc] = x0[i1];
+        }
+
+        // handle copies when there are multiple output states
+        for (int i3 = 1; i3 < n_kv; ++i3) {
+            int32_t seq = sq[i3];
+            if (0 <= seq && seq < n_kv) {
+                float * s1 = s + (seq - sq[0])*nc*nr;
+                memcpy(s1, s, nc*ir*sizeof(float));
+            } else {
+                // stop at negative or too big seq_ids
+                break;
+            }
+        }
+
+        // it seems a little faster when this is separate from the state shift
+        for (int i1 = 0; i1 < ir; ++i1) {
+            // rowwise dot product
+            float sumf = 0.0f;
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                sumf += s[i] * c[i];
+            }
+            x[i1] = sumf;
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_conv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_conv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const struct ggml_tensor * src0 = dst->src[0]; // s
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // dt
+    const struct ggml_tensor * src3 = dst->src[3]; // A
+    const struct ggml_tensor * src4 = dst->src[4]; // B
+    const struct ggml_tensor * src5 = dst->src[5]; // C
+    const struct ggml_tensor * src6 = dst->src[6]; // sq
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nc   = src0->ne[0]; // d_state
+    const int64_t nr   = src0->ne[1]; // d_inner
+    const int64_t n_t  = src1->ne[1]; // number of tokens in the batch
+    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C, and when copying the states
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[2])
+    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    if (n_kv > 1) {
+        // it's hard to know if the source states have already been copied
+        // when there are multiple, so copy them already.
+        for (int i3 = 0; i3 < n_kv; ++i3) {
+            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+            float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
+            memcpy(s, s0, nc*ir*sizeof(float));
+        }
+    }
+
+    for (int i2 = 0; i2 < n_t; ++i2) {
+        int32_t * sq = (int32_t *) ((char *) src6->data +  i2*(src6->nb[1])); // {n_kv, n_tokens}
+        float *   y  = (float *)   ((char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+        float *   s  = (float *)   ((char *)  dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+        float *   s0;
+        float *   x  = (float *)   ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+        float *   dt = (float *)   ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
+        float *   A  = (float *)   ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float *   B  = (float *)   ((char *) src4->data +  i2*(src4->nb[1])); // {d_state, n_tokens}
+        float *   C  = (float *)   ((char *) src5->data +  i2*(src5->nb[1])); // {d_state, n_tokens}
+
+        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+        // avoid needing to copy the state for the first token
+        if (i2 == 0) {
+            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
+        } else {
+            // otherwise the source is the same as the destination
+            s0 = s;
+        }
+
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+            float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+            float x_dt = x[i1] * dt_soft_plus;
+            float sumf = 0.0f;
+            // d_state
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                // state = prev_state * dA + dB * x
+                float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // y = rowwise_dotprod(state, C)
+                sumf += state * C[i0];
+                s[i] = state;
+            }
+            y[i1] = sumf;
+        }
+
+        // handle copies when there are multiple output states
+        for (int i3 = 1; i3 < n_kv; ++i3) {
+            int32_t seq = sq[i3];
+            if (0 <= seq && seq < n_kv) {
+                float * s1 = s + (seq - sq[0])*nc*nr;
+                memcpy(s1, s, nc*ir*sizeof(float));
+            } else {
+                // stop at negative or too big seq_ids
+                break;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_scan(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_scan_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_win_part

 static void ggml_compute_forward_win_part_f32(
@@ -15830,6 +16187,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, masked, tensor);
             } break;
+        case GGML_OP_SSM_CONV:
+            {
+                ggml_compute_forward_ssm_conv(params, tensor);
+            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                ggml_compute_forward_ssm_scan(params, tensor);
+            } break;
         case GGML_OP_WIN_PART:
             {
                 ggml_compute_forward_win_part(params, tensor);
@@ -16884,6 +17249,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_UNARY:
@@ -17590,6 +17960,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            {
                n_tasks = n_threads;
            } break;
+        case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
+            {
+                n_tasks = n_threads;
+            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
ggml.h CHANGED
@@ -472,6 +472,8 @@ extern "C" {
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_SSM_CONV,
+        GGML_OP_SSM_SCAN,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
@@ -1728,6 +1730,23 @@ extern "C" {
             struct ggml_tensor  * c0,
             struct ggml_tensor  * c1);

+    GGML_API struct ggml_tensor * ggml_ssm_conv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * s,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * c,
+            struct ggml_tensor  * sq);
+
+    GGML_API struct ggml_tensor * ggml_ssm_scan(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * s,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * dt,
+            struct ggml_tensor  * A,
+            struct ggml_tensor  * B,
+            struct ggml_tensor  * C,
+            struct ggml_tensor  * sq);
+
     // partition into non-overlapping windows with padding if needed
     // example:
     // a: 768 64 64 1
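
For reference, a hypothetical usage sketch (not code from llama.cpp): it builds inputs with the shapes required by the asserts in ggml_ssm_conv/ggml_ssm_scan above and adds both ops to a context.

    #include "ggml.h"

    static struct ggml_tensor * build_mamba_ops_sketch(
            struct ggml_context * ctx,
            int64_t d_conv, int64_t d_state, int64_t d_inner,
            int64_t n_tokens, int64_t n_kv) {
        // recurrent states kept between batches
        struct ggml_tensor * conv_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1, d_inner, n_kv);
        struct ggml_tensor * ssm_state  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state,    d_inner, n_kv);

        // per-token activations and model weights
        struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, n_tokens);
        struct ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv,  d_inner); // conv1d.weight
        struct ggml_tensor * dt = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, n_tokens);
        struct ggml_tensor * A  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
        struct ggml_tensor * B  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, n_tokens);
        struct ggml_tensor * C  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, n_tokens);

        // which state each token reads from (first entry) and copies to (following entries)
        struct ggml_tensor * sq = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_kv, n_tokens);

        // both results concatenate the per-token output with the updated states
        struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, conv_state, x, c, sq);
        // in the real model the scan input comes out of conv_out; reusing x keeps the sketch short
        struct ggml_tensor * scan_out = ggml_ssm_scan(ctx, ssm_state, x, dt, A, B, C, sq);
        (void) conv_out;
        return scan_out;
    }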