| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer,AutoConfig |
|
|
def load_tokenizer(model_name: str, is_hf: bool = False):
    """Build the tokenizer to pair with ``model_name``.

    Args:
        model_name: Model identifier or path. It is only consulted when
            ``is_hf`` is true.
        is_hf: When false, the "gpt2" tokenizer is loaded and its
            ``model_max_length`` is set to 2048. When true, names containing
            "mamba" or "mpt" get the "EleutherAI/gpt-neox-20b" tokenizer and
            every other name is loaded from ``model_name`` itself.

    Returns:
        The tokenizer, with its pad token set to its EOS token.
    """
    # NOTE(review): the original indentation was lost; the pad-token setup
    # below is assumed to run on every path -- confirm against the upstream file.
    if is_hf:
        wants_neox_vocab = "mamba" in model_name or "mpt" in model_name
        source = "EleutherAI/gpt-neox-20b" if wants_neox_vocab else model_name
        tok = AutoTokenizer.from_pretrained(source)
    else:
        tok = AutoTokenizer.from_pretrained("gpt2")
        tok.model_max_length = 2048

    # Reuse EOS as the padding token.
    tok.pad_token = tok.eos_token
    tok.pad_token_id = tok.eos_token_id
    return tok
|
|
def _register_causal_lm(model_type, config_cls, model_cls):
    """Print ``config_cls.model_type`` and register the pair with the Auto classes.

    ``config_cls`` is registered with ``AutoConfig`` under ``model_type`` and
    ``model_cls`` is registered with ``AutoModelForCausalLM`` for ``config_cls``.
    """
    print(config_cls.model_type)
    AutoConfig.register(model_type, config_cls)
    AutoModelForCausalLM.register(config_cls, model_cls)


# Each import stays directly ahead of its registration so the order of
# import-time work, prints and registrations is the same as before.
from fla.models import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel

_register_causal_lm("delta_net", DeltaNetConfig, DeltaNetForCausalLM)

from opencompass.models.fla2.models import mask_deltanetConfig, mask_deltanetForCausalLM

_register_causal_lm("mask_deltanet", mask_deltanetConfig, mask_deltanetForCausalLM)
| |
# Checkpoint directory used for both the model weights and the tokenizer.
model_path = "/mnt/jfzn/msj/train_exp/mask_deltanet_1B_rank4"

# Load the checkpoint in bfloat16 with device_map="cuda".
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
tokenizer = load_tokenizer(model_path, is_hf=True)

# Tokenize the single prompt and move the encoded tensors to the GPU.
prompt = "What is the official language of China?"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to("cuda")

# Sampling is disabled; EOS serves as both the stop token and the pad token.
generation_kwargs = {
    "max_new_tokens": 100,
    "do_sample": False,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}
outputs = model.generate(**inputs, **generation_kwargs)

# Decode the first returned sequence, dropping special tokens, and print it.
decoded_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_text)