File size: 1,271 Bytes
db82745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Debug script: wrap BeeAttention.forward so every call logs the KV-cache
# state and query length, then run a short greedy generate() on a tiny,
# freshly constructed BeeForCausalLM.
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM, BeeAttention
# NOTE(review): presumably registers the Bee classes with transformers'
# factories; it runs before any model is built — confirm in bee/register.py.
register()
import torch

# Keep a handle on the unpatched method so the tracing wrapper can delegate to
# it; this must be captured before BeeAttention.forward is reassigned.
orig_attn_forward = BeeAttention.forward

# Global call counter used to number (and pair up) the START/END log lines.
call_count = 0

def _describe_past_kv(past_key_value):
    """Return a short printable summary of ``past_key_value`` for the debug log.

    Best-effort and never raises: the object passed here may be ``None``, a
    legacy ``(key, value)`` tuple, or (depending on the transformers version)
    a ``Cache`` object. A tracing hook that crashes would mask the very error
    being debugged.
    """
    if past_key_value is None:
        return 'None'
    # Cache objects (e.g. transformers' DynamicCache) expose get_seq_length();
    # indexing an empty one during the prefill step can raise, so ask for the
    # length instead of indexing.
    get_seq_length = getattr(past_key_value, 'get_seq_length', None)
    if callable(get_seq_length):
        try:
            return f'{type(past_key_value).__name__}(seq_len={get_seq_length()})'
        except Exception as exc:  # logging only; never break the forward pass
            return f'{type(past_key_value).__name__}(<{type(exc).__name__}>)'
    try:
        first = past_key_value[0]
    except (IndexError, KeyError, TypeError) as exc:
        return f'N/A ({type(exc).__name__})'
    # Some cache types yield a (key, value) pair per index rather than a
    # tensor; report the key's shape in that case.
    if isinstance(first, (tuple, list)) and first:
        first = first[0]
    return str(first.shape) if hasattr(first, 'shape') else 'N/A'


def debug_attn_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, **kwargs):
    """Tracing replacement for ``BeeAttention.forward``.

    Prints a numbered START line (KV-cache summary and ``hidden_states.shape[1]``
    as ``q_len``) before delegating to the saved ``orig_attn_forward``, and a
    matching END line with the shape of ``out[0]`` afterwards. If the original
    raises, a FAILED line with the same call number is printed and the
    exception is re-raised unchanged.

    NOTE(review): assumes ``hidden_states`` is at least 2-D with the query
    length on dim 1 — confirm against BeeAttention.

    Returns:
        Whatever the original forward returns, unchanged.
    """
    global call_count
    call_count += 1
    cc = call_count
    print(f'[{cc}] START: past_kv={_describe_past_kv(past_key_value)}, q_len={hidden_states.shape[1]}')
    try:
        # Forward the named arguments by keyword, not position: if the real
        # signature has other parameters in between (HF Llama-style signatures
        # put output_attentions before use_cache), positional forwarding
        # silently misassigns them or collides with the same name in **kwargs.
        # NOTE(review): assumes BeeAttention.forward uses these parameter names.
        out = orig_attn_forward(
            self,
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )
    except Exception as exc:
        # Pair the failure with its START line so the failing call is obvious.
        print(f'[{cc}] FAILED: {type(exc).__name__}: {exc}')
        raise
    print(f'[{cc}] END: attn_output={out[0].shape}')
    return out

# Install the tracing wrapper on the class itself, so every BeeAttention
# instance created afterwards (or already existing) uses it.
BeeAttention.forward = debug_attn_forward

import traceback

# Fixed seed so the random weights and prompt — and therefore the logged
# trace — are reproducible between debugging runs.
torch.manual_seed(0)

cfg = BeeConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2, num_attention_heads=4, intermediate_size=512)
model = BeeForCausalLM(cfg)
# A freshly constructed nn.Module starts in training mode; switch to eval so
# any dropout is disabled and greedy decoding is deterministic.
model.eval()
input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
try:
    outputs = model.generate(input_ids, max_new_tokens=2, do_sample=False)
    print('done')
except Exception as e:
    # str(e) alone does not say where the failure happened; print the full
    # traceback too (this is a debugging script, so show everything).
    print('ERROR:', e)
    traceback.print_exc()