#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include <map>
#include <set>
#include <vector>

//
// llama_memory_recurrent
//

// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context
//       see the implementation of llama_kv_cache_context for an example how to do it
class llama_memory_recurrent : public llama_memory_i {
public:
    llama_memory_recurrent(
            const llama_model & model,
                    ggml_type   type_r,
                    ggml_type   type_s,
                         bool   offload,
                     uint32_t   mem_size,
                     uint32_t   n_seq_max,
        const layer_filter_cb & filter);
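
    // Construction sketch (hypothetical values, not from this codebase;
    // assumes a loaded `llama_model` named `model`, and that a nullptr filter
    // is treated as "keep all layers"):
    //
    //   llama_memory_recurrent mem(
    //           model,
    //           GGML_TYPE_F32, // type_r: rolling/shift states
    //           GGML_TYPE_F32, // type_s: recurrent (ssm) states
    //           false,         // offload: keep the state buffers on the CPU
    //           4,             // mem_size: number of memory cells
    //           4,             // n_seq_max: at most 4 tracked sequences
    //           nullptr);      // filter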

    ~llama_memory_recurrent() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;
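
    // The returned memory context walks the batch one ubatch at a time.
    // A minimal driver sketch (hypothetical names; assumes `balloc` was filled
    // by the caller and graph building/compute happens in the loop body):
    //
    //   auto mctx = mem.init_batch(balloc, n_ubatch, /*embd_all=*/false);
    //   if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
    //       do {
    //           mctx->apply();                        // emplace the current ubatch into the cache
    //           const auto & ub = mctx->get_ubatch(); // build + compute the graph for `ub`
    //       } while (mctx->next());                   // advance to the next ubatch
    //   }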

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
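
    // Sketch: fork sequence 0 into sequence 1 by copying its recurrent state
    // (assumes the usual llama.cpp convention that p0 < 0 / p1 < 0 denote the
    // whole range):
    //
    //   mem.seq_cp(0, 1, -1, -1); // seq 1 now carries a copy of seq 0's state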

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // simulate emplacement of the given ubatches and report whether they all fit
    // (the cache state is restored afterwards)
    bool prepare(const std::vector<llama_ubatch> & ubatches);

    // find a contiguous slot of memory cells and emplace the ubatch there
    bool find_slot(const llama_ubatch & ubatch);

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;
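
    // Save/restore sketch for a single sequence (assumes `io_out`/`io_in`
    // implement the llama-io interfaces; passing seq_id = -1 is assumed to
    // write/read the entire memory instead):
    //
    //   mem.state_write(io_out, 0); // serialize the state of seq 0
    //   mem.state_read (io_in,  0); // later: restore it into seq 0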

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    // first zero-ed state
    int32_t rs_z = -1;

    // TODO: optimize for recurrent state needs
    struct mem_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // used to know where states should be copied from
        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const mem_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    std::vector<mem_cell> cells;

    // per layer
    std::vector<ggml_tensor *> r_l; // rolling/shift states
    std::vector<ggml_tensor *> s_l; // ssm (recurrent state) data

private:
    const llama_hparams & hparams;

    const uint32_t n_seq_max = 1;

    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() const;

    size_t size_r_bytes() const;
    size_t size_s_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

class llama_memory_recurrent_context : public llama_memory_context_i {
public:
    // used for errors
    llama_memory_recurrent_context(llama_memory_status status);

    // used to create a full-cache or update context
    llama_memory_recurrent_context(
            llama_memory_recurrent * mem);

    // used to create a batch processing context from a batch
    llama_memory_recurrent_context(
            llama_memory_recurrent * mem,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_memory_recurrent_context();

    //
    // llama_memory_context_i
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_memory_recurrent_context specific API
    //

    uint32_t get_n_rs() const;
    uint32_t get_head() const;
    int32_t  get_rs_z() const;
    uint32_t get_size() const;

    ggml_tensor * get_r_l(int32_t il) const;
    ggml_tensor * get_s_l(int32_t il) const;

    int32_t s_copy(int i) const;
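
    // s_copy(i) gives, for the i-th slot of the current view, the index of the
    // cell to copy the previous state from. A hypothetical sketch of filling a
    // graph input buffer from it:
    //
    //   std::vector<int32_t> data(mctx->get_n_rs());
    //   for (uint32_t i = 0; i < mctx->get_n_rs(); ++i) {
    //       data[i] = mctx->s_copy(i); // source cell for slot i
    //   }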

private:
    const llama_memory_status status;

    llama_memory_recurrent * mem;

    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    // TODO: extract all the state like `head` and `n` into this context and keep a
    //       constant reference to the llama_memory_recurrent
    //

    const bool is_full = false;
};