import torch
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from .MemoryCell import MemoryCell
from .RecurrentWrapper import RecurrentWrapper
from .PreTrainedRMTConfig import PreTrainedRMTConfig


class RecurrentMemoryTransformer(PreTrainedModel):
| """ |
| Recurrent Memory Transformer Model Class |
| A transformer model that processes long context in segments and retains information using memory |
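
    Example
    -------
    A minimal construction sketch, assuming PreTrainedRMTConfig accepts its
    fields as keyword arguments (all values below are illustrative)::

        config = PreTrainedRMTConfig(
            base_model_type="gpt2",
            num_mem_tokens=16,
            max_n_segments=4,
            input_seg_len=512,
            output_seg_len=512,
            is_memory_all=False,
            align="left",
        )
        model = RecurrentMemoryTransformer(config)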
| """ |
| |
| config_class = PreTrainedRMTConfig |
| auto_model_class = "AutoModelForCausalLM" |
| |
| |
| _keys_to_ignore_on_load_missing = [r"position_ids"] |
| |
| |
| AUTO_MAP = { |
| "AutoModelForCausalLM": "RecurrentMemoryTransformer", |
| } |
| |
    def __init__(self, config, base_model=None):
        """
        Initialize the model, building the base transformer from the config
        when one is not supplied.

        Parameters
        ----------
        config : PreTrainedRMTConfig
            Model configuration
        base_model : PreTrainedModel, optional
            Base transformer model; if None, it is constructed from
            config.base_model_type
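
        Examples
        --------
        A sketch of wrapping an existing causal LM (the model name is
        illustrative)::

            base = AutoModelForCausalLM.from_pretrained("gpt2")
            model = RecurrentMemoryTransformer(config, base_model=base)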
| """ |
        super().__init__(config)

        if base_model is None:
            if not hasattr(config, "base_model_type"):
                raise ValueError(
                    "base_model_type is not specified in the config. "
                    "An RMT configuration requires a base model type."
                )
            base_model_type = config.base_model_type

            base_config = AutoConfig.from_pretrained(base_model_type)

            # Copy every non-private, non-RMT-specific field onto the base
            # config so overrides stored in the RMT config take effect.
            rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len',
                                   'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type']
            for key, value in config.__dict__.items():
                if key not in rmt_specific_params and not key.startswith('_'):
                    setattr(base_config, key, value)

            base_model = AutoModelForCausalLM.from_config(base_config)

        # Wrap the base model with memory tokens, then add the segment-level
        # recurrence that carries the memory state between segments.
        memory_cell = MemoryCell(base_model, config.num_mem_tokens)
        self.recurrent_wrapper = RecurrentWrapper(
            memory_cell=memory_cell,
            is_memory_all=config.is_memory_all,
            max_n_segments=config.max_n_segments,
            input_seg_len=config.input_seg_len,
            output_seg_len=config.output_seg_len,
            align=config.align
        )

    def get_base_model(self):
        """
        Get the base model.
        """
        return self.recurrent_wrapper.memory_cell.model

    def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None,
                inputs_embeds=None, output_attentions=None, output_hidden_states=None):
        """
        Forward pass of the model.

        Parameters
        ----------
        input_ids : torch.Tensor, optional
            Input token IDs
        attention_mask : torch.Tensor, optional
            Attention mask
        labels : torch.Tensor, optional
            Labels for the language-modeling loss
        labels_mask : torch.Tensor, optional
            Mask applied to the labels
        inputs_embeds : torch.Tensor, optional
            Input embeddings (alternative to input_ids)
        output_attentions : bool, optional
            Whether to output attention weights
        output_hidden_states : bool, optional
            Whether to output hidden states
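
        Example
        -------
        A usage sketch; the wrapper splits the sequence into segments
        internally (shapes and values are illustrative)::

            input_ids = torch.randint(0, 100, (1, 4096))
            out = model(input_ids=input_ids, labels=input_ids)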
| """ |
| forward_kwargs = {} |
| if input_ids is not None: |
| forward_kwargs["input_ids"] = input_ids |
| if labels is not None: |
| forward_kwargs["labels"] = labels |
| if attention_mask is not None: |
| forward_kwargs["attention_mask"] = attention_mask |
| if labels_mask is not None: |
| forward_kwargs["labels_mask"] = labels_mask |
| if inputs_embeds is not None: |
| forward_kwargs["inputs_embeds"] = inputs_embeds |
| if output_attentions is not None: |
| forward_kwargs["output_attentions"] = output_attentions |
| if output_hidden_states is not None: |
| forward_kwargs["output_hidden_states"] = output_hidden_states |
| |
| |
| |
| |
| out = self.recurrent_wrapper.forward(**forward_kwargs) |
| """ |
| # デバッグ出力を削除(または必要に応じてコメント化) |
| # print(out["loss"]) |
| |
| # 分散環境で損失が二重計算されないよう、ワールドサイズで割る |
| # これは処理済みの場合は不要なので、環境変数などで制御することも可能 |
| if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None: |
| # 既にDeepSpeedが処理している可能性があるため、確認が必要 |
| # テスト目的で一時的に追加(実際の環境に合わせて調整が必要) |
| # world_size = torch.distributed.get_world_size() |
| # out["loss"] = out["loss"] / world_size |
| pass |
| """ |
        return out

    def generate(self, **kwargs):
        """
        Generate text by delegating to the recurrent wrapper.
        """
        return self.recurrent_wrapper.generate(**kwargs)

    def generate_with_tokenizer(self, tokenizer, input_text, **kwargs):
        """
        Generate text from a raw input string using the given tokenizer.
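
        Example
        -------
        A sketch (tokenizer choice and prompt are illustrative; extra kwargs
        are forwarded to RecurrentWrapper.generate_with_tokenizer)::

            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            text = model.generate_with_tokenizer(tokenizer, "Once upon a time")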
| """ |
| return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs) |

    def get_input_embeddings(self):
        """
        Get the input embeddings of the base model.
        """
        return self.get_base_model().get_input_embeddings()

    def set_input_embeddings(self, embeddings):
        """
        Set the input embeddings of the base model.
        """
        self.get_base_model().set_input_embeddings(embeddings)

    def get_output_embeddings(self):
        """
        Get the output embeddings of the base model.
        """
        return self.get_base_model().get_output_embeddings()

    def resize_token_embeddings(self, new_num_tokens):
        """
        Resize the base model's token embeddings and return the new
        input embeddings.
        """
        self.get_base_model().resize_token_embeddings(new_num_tokens)
        return self.get_input_embeddings()


RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM")
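
# Usage sketch (illustrative; the checkpoint path is hypothetical): after
# save_pretrained(), the registration above lets the model be reloaded via
# the auto class with custom code enabled:
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained(
#         "path/to/rmt-checkpoint", trust_remote_code=True
#     )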