import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
class CodetteModelLoader:
    """Load a 4-bit quantized causal LM and manage multiple LoRA adapters.

    Loads the base model with NF4 quantization on construction, then lets
    callers attach any number of PEFT (LoRA) adapters and switch between
    them at runtime.

    Args:
        base_model: Hugging Face model id or local path of the base model.
        adapters: Optional mapping of adapter name -> adapter path/dir.
    """

    def __init__(
        self,
        base_model="meta-llama/Llama-3.1-8B-Instruct",
        adapters=None,
    ):
        self.base_model_name = base_model
        # `adapters or {}` avoids the mutable-default pitfall and lets
        # callers pass None.
        self.adapters = adapters or {}
        self.model = None
        self.tokenizer = None
        self.active_adapter = None
        self._load_base_model()

    def _load_base_model(self):
        """Load tokenizer and the 4-bit NF4-quantized base model."""
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True,
        )
        # Llama-family tokenizers ship without a pad token; reuse EOS so
        # padding-dependent code paths don't fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=quant_config,
            device_map="auto",
            trust_remote_code=True,
        )

    def load_adapters(self):
        """Attach every configured LoRA adapter to the base model.

        The first adapter wraps the base model in a ``PeftModel`` and
        becomes the active adapter; the rest are registered alongside it.
        A no-op when no adapters were configured.
        """
        adapter_items = iter(self.adapters.items())
        head = next(adapter_items, None)
        if head is None:
            return  # nothing configured
        name, path = head
        self.model = PeftModel.from_pretrained(
            self.model,
            str(Path(path)),
            adapter_name=name,
            is_trainable=False,
        )
        self.active_adapter = name
        for name, path in adapter_items:
            self.model.load_adapter(
                str(Path(path)),
                adapter_name=name,
            )

    def set_active_adapter(self, name):
        """Switch inference to the named adapter.

        Raises:
            ValueError: if the adapter was never loaded (including when
                ``load_adapters()`` has not been called at all).
        """
        # getattr guard: before load_adapters() the model is a plain
        # transformers model with no `peft_config`, which previously made
        # this raise AttributeError instead of the documented ValueError.
        if name not in getattr(self.model, "peft_config", {}):
            raise ValueError(f"Adapter not loaded: {name}")
        self.model.set_adapter(name)
        self.active_adapter = name

    def format_messages(self, messages):
        """Render chat ``messages`` into a single prompt string via the
        tokenizer's chat template, with the generation prompt appended."""
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def tokenize(self, prompt):
        """Tokenize ``prompt`` to tensors placed on the model's device."""
        return self.tokenizer(
            prompt,
            return_tensors="pt",
        ).to(self.model.device)