from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM, BeeAttention

import torch

# Register the custom Bee model classes before instantiating them.
register()

# Keep a handle on the original attention forward so the patched version can
# delegate to it (and so it can be restored later).
orig_attn_forward = BeeAttention.forward
call_count = 0


def debug_attn_forward(self, hidden_states, attention_mask=None, position_ids=None,
                       past_key_value=None, use_cache=False, **kwargs):
    """Wrapper around BeeAttention.forward that logs KV-cache and output shapes."""
    global call_count
    call_count += 1
    cc = call_count
    if past_key_value is not None:
        pk_shape = past_key_value[0].shape if hasattr(past_key_value[0], 'shape') else 'N/A'
        print(f'[{cc}] START: past_kv={pk_shape}, q_len={hidden_states.shape[1]}')
    else:
        print(f'[{cc}] START: past_kv=None, q_len={hidden_states.shape[1]}')
    out = orig_attn_forward(self, hidden_states, attention_mask, position_ids,
                            past_key_value, use_cache, **kwargs)
    print(f'[{cc}] END: attn_output={out[0].shape}')
    return out


# Monkeypatch the attention layer so every call made during generation is traced.
BeeAttention.forward = debug_attn_forward

# Tiny random-weight model: just enough to exercise the attention path cheaply.
cfg = BeeConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2,
                num_attention_heads=4, intermediate_size=512)
model = BeeForCausalLM(cfg)
input_ids = torch.randint(0, cfg.vocab_size, (1, 8))

try:
    outputs = model.generate(input_ids, max_new_tokens=2, do_sample=False)
    print('done')
except Exception as e:
    print('ERROR:', e)
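
# Minimal cleanup sketch (assumption: the patch is only wanted for this single
# debug run). Restoring the saved forward puts BeeAttention back in its original
# state so later experiments in the same process are not traced.
BeeAttention.forward = orig_attn_forward
call_count = 0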