# bee/scripts/debug_mem.py
# From commit db82745 (verified): "chore: deploy Bee API backend (bee/, Dockerfile, requirements)"
import torch
from bee.agi_config import BeeAGIConfig
from bee.memory import BeeMemoryBank
# Minimal BeeAGIConfig for shape debugging: small dims so everything runs fast on CPU.
# NOTE(review): field semantics come from bee.agi_config — values here are arbitrary
# debug sizes, not a trained configuration.
cfg = BeeAGIConfig(
    # core transformer dims
    vocab_size=1000,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=512,
    max_position_embeddings=512,
    # mixture-of-experts layout
    num_experts=4,
    num_experts_per_tok=2,
    moe_layers=[1, 3],
    # state-space / memory components
    state_space_layers=[2],
    state_dim=16,
    memory_slots=64,
    memory_dim=256,
    reasoning_depth=2,
    compression_latent_dim=64,
    # domain routing
    domain_expert_count=4,
    domains=['programming', 'quantum', 'general', 'math'],
)
# Build the memory bank and a dummy activation batch; dims match cfg
# (hidden_size=256, memory_dim=256 — assumed from the config above, TODO confirm).
mem = BeeMemoryBank(cfg)
x = torch.randn(2, 16, 256)
batch, seq_len, _ = x.shape
device = x.device

# Broadcast the bank's buffers from their stored leading dim to the actual
# batch size: slice row 0, expand, then clone so each batch row owns its storage.
if mem.memory.size(0) != batch:
    for attr, tail in (("memory", (-1, -1)), ("memory_age", (-1,)), ("memory_usage", (-1,))):
        buf = getattr(mem, attr)
        setattr(mem, attr, buf[:1].expand(batch, *tail).clone().to(device))
# Project the input into memory space and compute per-token write gates,
# then dump the shapes that matter for the write-step broadcasting.
compressed = mem.write_proj(x)
gates = torch.sigmoid(mem.write_gate(x)).squeeze(-1)
t = 0  # timestep index used in the gate-slicing probes below

_probes = [
    ('memory shape:', mem.memory),
    ('memory_usage shape:', mem.memory_usage),
    ('gates shape:', gates),
    ('gates[:, t] shape:', gates[:, t]),
    ('(1.0 - mem.memory_usage) shape:', 1.0 - mem.memory_usage),
    ('gates[:, t] unsqueeze(1) shape:', gates[:, t].unsqueeze(1)),
    ('gates[:, t] unsqueeze(-1) shape:', gates[:, t].unsqueeze(-1)),
]
for label, tensor in _probes:
    print(label, tensor.shape)