| import torch |
| from bee.agi_config import BeeAGIConfig |
| from bee.memory import BeeMemoryBank |
|
|
# Minimal Bee AGI configuration for exercising the memory bank in isolation.
# Values are intentionally tiny so the script runs instantly on CPU.
params = dict(
    vocab_size=1000,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=512,
    num_experts=4,
    num_experts_per_tok=2,
    moe_layers=[1, 3],
    state_space_layers=[2],
    state_dim=16,
    memory_slots=64,
    memory_dim=256,
    reasoning_depth=2,
    compression_latent_dim=64,
    domain_expert_count=4,
    domains=['programming', 'quantum', 'general', 'math'],
    max_position_embeddings=512,
)
cfg = BeeAGIConfig(**params)
mem = BeeMemoryBank(cfg)
# Dummy hidden states: batch 2, sequence length 16, hidden size 256.
x = torch.randn(2, 16, 256)
|
|
# Lazily broadcast the memory bank's per-batch buffers to the incoming
# batch size. Each buffer's leading dim is expanded from its first slot,
# then cloned so every batch element owns independent storage on x's device.
batch, seq_len, _ = x.shape
device = x.device
if mem.memory.size(0) != batch:
    for attr in ("memory", "memory_age", "memory_usage"):
        proto = getattr(mem, attr)[:1]
        # -1 keeps every non-batch dim at its original size.
        target = (batch,) + (-1,) * (proto.dim() - 1)
        setattr(mem, attr, proto.expand(*target).clone().to(device))
|
|
# Project token states into memory space and score per-token write gates.
compressed = mem.write_proj(x)
# Gate logits -> probabilities, then drop the trailing singleton dim
# (presumably leaving shape (batch, seq) — confirm against BeeMemoryBank).
gates = mem.write_gate(x).sigmoid().squeeze(-1)
|
|
# Shape diagnostics: inspect how the gate slice broadcasts against the
# memory-usage buffer under the two candidate unsqueeze placements.
print(f"memory shape: {mem.memory.shape}")
print(f"memory_usage shape: {mem.memory_usage.shape}")
print(f"gates shape: {gates.shape}")

t = 0  # probe the first timestep
print(f"gates[:, t] shape: {gates[:, t].shape}")
print(f"(1.0 - mem.memory_usage) shape: {(1.0 - mem.memory_usage).shape}")
print(f"gates[:, t] unsqueeze(1) shape: {gates[:, t].unsqueeze(1).shape}")
print(f"gates[:, t] unsqueeze(-1) shape: {gates[:, t].unsqueeze(-1).shape}")
|
|