File size: 4,092 Bytes
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
58220b6
 
 
 
 
 
 
 
844e617
58220b6
 
 
 
 
 
 
 
 
 
 
0ec1374
ade9bc3
58220b6
fc04dc0
58220b6
0ec1374
58220b6
0ec1374
58220b6
 
 
5ef1601
 
 
 
 
 
 
 
 
 
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
844e617
 
58220b6
 
 
 
0ec1374
58220b6
bc53087
 
58220b6
0ec1374
58220b6
0ec1374
 
58220b6
 
0ec1374
 
5ef1601
 
 
 
0ec1374
 
58220b6
bc53087
 
58220b6
 
0ec1374
58220b6
 
0ec1374
58220b6
 
 
 
 
 
 
 
 
0ec1374
58220b6
 
0ec1374
 
58220b6
 
 
 
 
 
 
 
 
0ec1374
 
ade9bc3
 
58220b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#pragma once

#include "llama-kv-cache-unified.h"

#include <vector>

//
// llama_kv_cache_unified_iswa
//

// utilizes two instances of llama_kv_cache_unified
//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers

class llama_kv_cache_unified_iswa : public llama_memory_i {
public:
    llama_kv_cache_unified_iswa(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                         bool   swa_full,
                         bool   unified,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max,
                     uint32_t   n_ubatch,
                     uint32_t   n_pad);

    ~llama_kv_cache_unified_iswa() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id)                                                          override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;

    //
    // llama_kv_cache_unified_iswa specific API
    //

    llama_kv_cache_unified * get_base() const;
    llama_kv_cache_unified * get_swa () const;

private:
    const llama_hparams & hparams;

    const bool unified;

    std::unique_ptr<llama_kv_cache_unified> kv_base;
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;

    // used for errors
    llama_kv_cache_unified_iswa_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv);

    // used to create an update context
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
            llama_context * lctx,
            bool optimize);

    // used to create a batch processing context from a batch
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
            slot_info_vec_t sinfos_base,
            slot_info_vec_t sinfos_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_context();

    //
    // llama_memory_context_i
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_iswa_context specific API
    //

    const llama_kv_cache_unified_context * get_base() const;
    const llama_kv_cache_unified_context * get_swa()  const;

private:
    //llama_kv_cache_unified_iswa * kv;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_context_ptr ctx_base;
    const llama_memory_context_ptr ctx_swa;

    const llama_memory_status status;
};