from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM, BeeAttention
register()
import torch
orig_attn_forward = BeeAttention.forward
call_count = 0
def debug_attn_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, **kwargs):
    """Logging wrapper around the original BeeAttention.forward.

    On entry, prints the cached key tensor's shape (or None / 'N/A') and
    the current query length; on exit, prints the attention output shape.
    Each line is tagged with a global, monotonically increasing call number
    so interleaved layer calls can be matched up.
    """
    global call_count
    call_count += 1
    call_no = call_count

    # Describe the cache: None when absent, the key tensor's shape when it
    # exposes one, otherwise the literal 'N/A' placeholder.
    if past_key_value is None:
        cache_desc = None
    elif hasattr(past_key_value[0], 'shape'):
        cache_desc = past_key_value[0].shape
    else:
        cache_desc = 'N/A'
    print(f'[{call_no}] START: past_kv={cache_desc}, q_len={hidden_states.shape[1]}')

    result = orig_attn_forward(self, hidden_states, attention_mask, position_ids, past_key_value, use_cache, **kwargs)

    # Assumes the original forward returns a tuple whose first element is
    # the attention output tensor — TODO confirm against modeling_bee.
    print(f'[{call_no}] END: attn_output={result[0].shape}')
    return result
BeeAttention.forward = debug_attn_forward
cfg = BeeConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2, num_attention_heads=4, intermediate_size=512)
model = BeeForCausalLM(cfg)
input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
try:
outputs = model.generate(input_ids, max_new_tokens=2, do_sample=False)
print('done')
except Exception as e:
print('ERROR:', e)