from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM, BeeAttention

# Register the custom Bee architecture before it is used anywhere else.
# NOTE(review): called immediately after the bee imports on purpose —
# presumably it hooks the model classes into an auto-class registry; confirm
# whether anything below actually requires it for a direct BeeForCausalLM()
# construction.
register()
import torch


# Keep a reference to the unpatched attention forward so the debug wrapper
# (installed further down) can delegate to the real implementation.
orig_attn_forward = BeeAttention.forward


# Global counter of attention-forward invocations; each call gets a unique
# id so START/END log lines can be paired up in the output.
call_count = 0
|
|
def _describe_past_kv(past_key_value):
    """Best-effort, never-raising string for the cached key/value state.

    Returns 'None' when there is no cache, the `.shape` of the first cached
    tensor when it is an indexable (key, value) pair, 'N/A' when the first
    element has no `.shape`, and the container's type name when it cannot be
    indexed at all (e.g. a Cache-style object). Debug logging must not be
    able to crash the forward pass.
    """
    if past_key_value is None:
        return 'None'
    try:
        first = past_key_value[0]
    except (TypeError, IndexError, KeyError):
        # Non-indexable / empty cache object: fall back to its type name.
        return type(past_key_value).__name__
    return getattr(first, 'shape', 'N/A')


def debug_attn_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, **kwargs):
    """Logging wrapper around the original BeeAttention.forward.

    Prints a paired START/END line per call (tagged with a unique call id)
    showing the cached-KV shape, the query length, and the attention output
    shape, then delegates to `orig_attn_forward` unchanged.
    """
    global call_count
    call_count += 1
    # Snapshot the id: nested or interleaved calls must keep their own tag
    # even after the global counter moves on.
    cc = call_count
    print(f'[{cc}] START: past_kv={_describe_past_kv(past_key_value)}, q_len={hidden_states.shape[1]}')
    out = orig_attn_forward(self, hidden_states, attention_mask, position_ids, past_key_value, use_cache, **kwargs)
    print(f'[{cc}] END: attn_output={out[0].shape}')
    return out
|
|
# Install the logging wrapper; every attention layer now prints START/END.
BeeAttention.forward = debug_attn_forward


# Deliberately tiny config so the smoke run is fast and the log is short.
cfg = BeeConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2, num_attention_heads=4, intermediate_size=512)
model = BeeForCausalLM(cfg)
# Random (batch=1, seq=8) prompt; token values only need to be in-vocab.
input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
try:
    # Greedy 2-token generation: enough to observe both the prefill call
    # (past_kv=None) and the cached decode calls in the debug output.
    outputs = model.generate(input_ids, max_new_tokens=2, do_sample=False)
    print('done')
except Exception as e:
    # Broad catch is intentional for a smoke script: report and exit
    # instead of dying with a traceback mid-log.
    print('ERROR:', e)
|
|