import re

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def load_tokenizer(model_name: str, is_hf: bool = False):
    """Load the tokenizer to pair with ``model_name``.

    Args:
        model_name: Hub id or local checkpoint path of the model. Only
            consulted when ``is_hf`` is true.
        is_hf: When false, return the ``gpt2`` tokenizer with
            ``model_max_length`` forced to 2048. When true, choose based on
            ``model_name``: names referring to Mamba or MPT models get the
            ``EleutherAI/gpt-neox-20b`` tokenizer; anything else loads the
            tokenizer stored with ``model_name`` itself.

    Returns:
        The loaded tokenizer, with its pad token (and pad token id) set to
        its EOS token.
    """
    if not is_hf:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.model_max_length = 2048
    elif re.search(r"(?<![A-Za-z])(?:mamba|mpt)", model_name):
        # Only match "mamba"/"mpt" at the start of a word (e.g. "mamba-130m",
        # "mosaicml/mpt-7b", ".../mamba_exp"). A plain substring test would
        # also fire on unrelated paths such as ".../prompt_tuning" or
        # ".../attempt3" and load the wrong tokenizer.
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


from fla.models import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel


def _register_causal_lm(config_cls, model_cls):
    """Register a config/model pair with transformers' Auto* factories.

    Prints the config's ``model_type``, registers ``config_cls`` with
    ``AutoConfig`` under that ``model_type``, and registers ``model_cls`` with
    ``AutoModelForCausalLM`` for ``config_cls``.

    The type string is read from ``config_cls.model_type`` rather than
    repeated as a literal, because ``AutoConfig.register`` raises
    ``ValueError`` when the passed string and the class attribute disagree.
    """
    print(config_cls.model_type)
    AutoConfig.register(config_cls.model_type, config_cls)
    AutoModelForCausalLM.register(config_cls, model_cls)


_register_causal_lm(DeltaNetConfig, DeltaNetForCausalLM)

# Imported after the DeltaNet registration, matching the original order.
from opencompass.models.fla2.models import mask_deltanetConfig, mask_deltanetForCausalLM

_register_causal_lm(mask_deltanetConfig, mask_deltanetForCausalLM)

# Local checkpoint directory to load and sample from.
model_path = "/mnt/jfzn/msj/train_exp/mask_deltanet_1B_rank4"

# The config/model classes for this checkpoint are registered with
# AutoConfig / AutoModelForCausalLM earlier in this file, so
# `trust_remote_code` is deliberately left unset. With it enabled, a
# checkpoint whose config.json carries an `auto_map` entry would load (and
# execute) that shipped code instead of the locally registered class.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
tokenizer = load_tokenizer(model_path, is_hf=True)
prompt = "What is the official language of China?"
# Send the encoded prompt to the same device the weights were placed on,
# rather than repeating the device string.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Greedy decoding (do_sample=False) of at most 100 new tokens; EOS doubles
# as the pad token id.
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Decode the first (only) sequence; for a causal LM this includes the prompt
# tokens followed by the generated continuation.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))