import torch

from bee.agi_config import BeeAGIConfig
from bee.memory import BeeMemoryBank

# Small config so the memory-bank shapes are easy to inspect.
cfg = BeeAGIConfig(
    vocab_size=1000,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=512,
    num_experts=4,
    num_experts_per_tok=2,
    moe_layers=[1, 3],
    state_space_layers=[2],
    state_dim=16,
    memory_slots=64,
    memory_dim=256,
    reasoning_depth=2,
    compression_latent_dim=64,
    domain_expert_count=4,
    domains=['programming', 'quantum', 'general', 'math'],
    max_position_embeddings=512,
)

mem = BeeMemoryBank(cfg)

# Dummy hidden states: (batch, seq_len, hidden_size).
x = torch.randn(2, 16, 256)
batch, seq_len, _ = x.shape
device = x.device

# Expand the memory buffers along the batch dimension; .clone() detaches
# the expanded views so each batch element gets its own storage.
if mem.memory.size(0) != batch:
    mem.memory = mem.memory[:1].expand(batch, -1, -1).clone().to(device)
    mem.memory_age = mem.memory_age[:1].expand(batch, -1).clone().to(device)
    mem.memory_usage = mem.memory_usage[:1].expand(batch, -1).clone().to(device)

# Project hidden states into memory space and compute per-token write gates;
# squeeze(-1) drops the gate head's size-1 output dim, leaving (batch, seq_len).
compressed = mem.write_proj(x)
gates = torch.sigmoid(mem.write_gate(x)).squeeze(-1)

print('memory shape:', mem.memory.shape)
print('memory_usage shape:', mem.memory_usage.shape)
print('gates shape:', gates.shape)

# Probe how a single timestep's gate broadcasts against the usage mask.
t = 0
print('gates[:, t] shape:', gates[:, t].shape)
print('(1.0 - mem.memory_usage) shape:', (1.0 - mem.memory_usage).shape)
print('gates[:, t] unsqueeze(1) shape:', gates[:, t].unsqueeze(1).shape)
print('gates[:, t] unsqueeze(-1) shape:', gates[:, t].unsqueeze(-1).shape)
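
# The prints above probe the broadcasting that a memory write would rely on.
# Below is a minimal sketch of one write step under two assumptions not
# confirmed by the code above: write_proj maps hidden_size -> memory_dim,
# and writes are soft-weighted by free capacity (1 - usage). This only
# illustrates the shape arithmetic, not BeeMemoryBank's actual update rule.
write_strength = gates[:, t].unsqueeze(-1) * (1.0 - mem.memory_usage)  # (batch, memory_slots)
update = write_strength.unsqueeze(-1) * compressed[:, t].unsqueeze(1)  # (batch, memory_slots, memory_dim)
mem.memory = mem.memory + update
mem.memory_usage = torch.clamp(mem.memory_usage + write_strength, 0.0, 1.0)
print('write_strength shape:', write_strength.shape)
print('update shape:', update.shape)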