# vectorllm_v1 / convert_to_hf.py
# Uploaded by insomnia7 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit bcc6605 (verified), 13.2 kB.
import argparse
import json
import os
import shutil
import sys
from pathlib import Path
import torch
from PIL import Image
from mmengine.config import Config
from mmengine.fileio import PetrelBackend, get_file_backend
from peft import PeftModel
from transformers import AutoModel, AutoProcessor, GenerationConfig, StoppingCriteria, StoppingCriteriaList
from xtuner.model.utils import guess_load_checkpoint
from xtuner.registry import BUILDER
# Repository root: two directory levels above the folder holding this script.
REPO_ROOT = Path(__file__).resolve().parents[2]

# Prepend the repository root (only once) so the ``projects.*`` imports below
# resolve when this file is run directly as a script.
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]
from projects.vectorllm_hf_0407.configuration_vectorllm import VectorLLMConfig
from projects.vectorllm_hf_0407.image_processing_vectorllm import VectorLLMImageProcessor
from projects.vectorllm_hf_0407.modeling_vectorllm import VectorLLMForCausalLM
from projects.vectorllm_hf_0407.processing_vectorllm import VectorLLMProcessor
# Instruction for the demo/validation run; ``<pixel>`` marks where image
# features are inserted.
DEFAULT_PROMPT = (
    "<pixel>\nPlease extract the regular vector contour of the central building in the image, "
    "start from the left top corner and in clockwise."
)
# The same instruction wrapped as a single ChatML-style user turn, ending with
# the assistant header so generation continues with the assistant's reply.
DEFAULT_RAW_PROMPT = f"<|im_start|>user\n{DEFAULT_PROMPT}<|im_end|>\n<|im_start|>assistant\n"
class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded first sequence ends with ``stop_word``.

    The whole of ``input_ids[0]`` is decoded on every call.  Carriage returns
    and newlines are removed before the suffix comparison.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Cached suffix length used for the comparison in ``__call__``.
        self.length = len(stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        decoded = self.tokenizer.decode(input_ids[0])
        # Delete "\r" and "\n" in one pass before checking the tail.
        flattened = decoded.translate({ord("\r"): None, ord("\n"): None})
        return flattened[-self.length:] == self.stop_word
def get_stop_criteria(tokenizer, stop_words=None):
    """Return a ``StoppingCriteriaList`` with one stop-word criterion per word.

    ``stop_words`` of ``None`` (or an empty sequence) produces an empty list.
    """
    criteria = [StopWordStoppingCriteria(tokenizer, word) for word in (stop_words or [])]
    return StoppingCriteriaList(criteria)
def parse_args(argv=None):
    """Parse the converter's command-line arguments.

    Args:
        argv: Argument list to parse.  ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list allows the
            converter to be driven programmatically or from tests.

    Returns:
        ``argparse.Namespace`` with ``config``, ``pth_model``, ``save_path``
        and ``demo_image``.
    """
    parser = argparse.ArgumentParser(description="Convert xtuner VectorLLM checkpoint to HF format.")
    parser.add_argument("config", help="xtuner config path")
    parser.add_argument("pth_model", help="xtuner checkpoint path")
    parser.add_argument("--save-path", required=True, help="HF export directory")
    parser.add_argument("--demo-image", required=True, help="demo image for validation")
    return parser.parse_args(argv)
def seed_local_transformers_modules(local_model_dir):
    """Copy a local model directory's ``*.py`` files into transformers' dynamic-module cache.

    The destination is ``<modules cache>/transformers_modules``.  The modules
    cache is resolved the same way transformers / huggingface_hub do:

    1. ``HF_MODULES_CACHE`` if set;
    2. otherwise ``$HF_HOME/modules``;
    3. otherwise ``$XDG_CACHE_HOME/huggingface/modules``, falling back to
       ``~/.cache/huggingface/modules``.

    An empty ``__init__.py`` is created if missing.  Files already present in
    the cache are left untouched (never overwritten).  Non-directory inputs are
    ignored.

    Args:
        local_model_dir: Path to a local model directory.  Anything that is not
            an existing directory (e.g. a hub model id) is a no-op.
    """
    model_dir = Path(local_model_dir).expanduser().resolve()
    if not model_dir.is_dir():
        return

    modules_cache = os.environ.get("HF_MODULES_CACHE")
    if modules_cache is None:
        hf_home = os.environ.get(
            "HF_HOME",
            os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"),
        )
        modules_cache = os.path.join(hf_home, "modules")
    cache_root = Path(modules_cache).expanduser() / "transformers_modules"
    cache_root.mkdir(parents=True, exist_ok=True)

    # The cache root must be an importable package.
    init_file = cache_root / "__init__.py"
    if not init_file.exists():
        init_file.write_text("", encoding="utf-8")

    for py_file in model_dir.glob("*.py"):
        target = cache_root / py_file.name
        if not target.exists():
            shutil.copy2(py_file, target)
def build_xtuner_model(config_path, pth_model):
    """Build the xtuner VectorLLM model from its config and load checkpoint weights.

    Args:
        config_path: Path to the xtuner (mmengine) config file.
        pth_model: Checkpoint path.  If ``get_file_backend`` resolves it to a
            ``PetrelBackend``, the checkpoint is read inside xtuner's
            ``patch_fileio()`` context; otherwise it is read directly.

    Returns:
        ``(cfg, model, image_processor)``.  ``model`` is in eval mode and has
        had ``preparing_for_generation`` called on it.

    Side effects:
        Seeds the transformers module cache from the configured visual-encoder
        and LLM directories.  Prints a summary to stderr when the checkpoint's
        keys do not exactly match the model.
    """
    import contextlib

    cfg = Config.fromfile(config_path)
    # NOTE(review): presumably stops the model from loading ``pretrained_pth``
    # at build time, since weights are loaded from ``pth_model`` below.
    cfg.model.pretrained_pth = None
    seed_local_transformers_modules(cfg.model.visual_encoder.pretrained_model_name_or_path)
    seed_local_transformers_modules(cfg.model.llm.pretrained_model_name_or_path)
    image_processor = BUILDER.build(cfg.image_processor)
    model = BUILDER.build(cfg.model)

    backend = get_file_backend(pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio

        fileio_ctx = patch_fileio()
    else:
        fileio_ctx = contextlib.nullcontext()
    with fileio_ctx:
        state_dict = guess_load_checkpoint(pth_model)

    # strict=False is kept, but a partially matching checkpoint should not pass
    # silently: report missing/unexpected keys so a bad export is noticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, "missing_keys", []) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", []) or [])
    if missing or unexpected:
        print(
            f"[convert_to_hf] load_state_dict(strict=False): "
            f"{len(missing)} missing key(s), {len(unexpected)} unexpected key(s)",
            file=sys.stderr,
        )
        for label, keys in (("missing", missing), ("unexpected", unexpected)):
            for key in keys[:10]:
                print(f"[convert_to_hf]   {label}: {key}", file=sys.stderr)

    model.eval()
    model.preparing_for_generation(metainfo={})
    return cfg, model, image_processor
def build_hf_config(cfg, model):
    """Assemble the ``VectorLLMConfig`` written alongside the exported weights.

    Args:
        cfg: xtuner config.  The vision-encoder directory is taken from
            ``cfg.visual_encoder_name_or_path`` when present, else from
            ``cfg.model.visual_encoder.pretrained_model_name_or_path``; its
            ``config.json`` is embedded as the vision config.
        model: The built xtuner model.  Supplies the LLM config, the tokenizer
            (for the ``<pixel>`` id) and the projector input width.

    Returns:
        A ``VectorLLMConfig``.  Vision and LLM sub-configs are marked bfloat16,
        and the remote-code ``auto_map`` is set.

    Raises:
        ValueError: if ``<pixel>`` does not tokenize to exactly one id.
            Previously the first sub-token was used silently, which would make
            ``pixel_idx`` point at the wrong token.
    """
    vision_config_path = Path(
        cfg.visual_encoder_name_or_path
        if hasattr(cfg, "visual_encoder_name_or_path")
        else cfg.model.visual_encoder.pretrained_model_name_or_path
    )
    llm_config = model.llm.config.to_dict()
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    vision_config = json.loads((vision_config_path / "config.json").read_text(encoding="utf-8"))
    # ``vision_args`` aliases the nested dict, so these writes land in vision_config.
    vision_args = vision_config.get("args", {})
    if vision_args:
        vision_args["dtype"] = "bfloat16"
        vision_args["amp_dtype"] = "bfloat16"
    vision_config["torch_dtype"] = "bfloat16"
    llm_config["torch_dtype"] = "bfloat16"

    pixel_token_ids = model.tokenizer("<pixel>", add_special_tokens=False).input_ids
    if len(pixel_token_ids) != 1:
        raise ValueError(
            f"'<pixel>' must map to exactly one token id, got {list(pixel_token_ids)}"
        )
    pixel_token_idx = pixel_token_ids[0]

    return VectorLLMConfig(
        vision_config=vision_config,
        llm_config=llm_config,
        regression_size=cfg.model.regression_size,
        projector_depth=cfg.model.get("projector_depth", 2),
        visual_hidden_size=model.projector.model[0].in_features,
        pixel_idx=pixel_token_idx,
        pre_resize_size=432,
        resized_size=cfg.model.regression_size[0],
        patch_size=16,
        do_normalize=False,
        vision_model_name_or_path="",
        llm_name_or_path="",
        visual_peft_config=None,
        vision_torch_dtype="bfloat16",
        torch_dtype="bfloat16",
        auto_map={
            "AutoConfig": "configuration_vectorllm.VectorLLMConfig",
            "AutoModel": "modeling_vectorllm.VectorLLMForCausalLM",
            "AutoModelForCausalLM": "modeling_vectorllm.VectorLLMForCausalLM",
            "AutoImageProcessor": "image_processing_vectorllm.VectorLLMImageProcessor",
            "AutoProcessor": "processing_vectorllm.VectorLLMProcessor",
        },
    )
def maybe_merge_visual_encoder(visual_encoder):
    """Fold adapter weights into the visual encoder when it supports merging.

    A ``PeftModel`` always gets ``merge_and_unload()`` called.  Any other
    object does only if it exposes a ``merge_and_unload`` attribute.
    Otherwise the encoder is returned unchanged.
    """
    # Short-circuit keeps the original check order: the isinstance test runs first.
    should_merge = isinstance(visual_encoder, PeftModel) or hasattr(visual_encoder, "merge_and_unload")
    return visual_encoder.merge_and_unload() if should_merge else visual_encoder
def copy_remote_code(save_path, src_root=None):
    """Copy the VectorLLM remote-code package into an export directory.

    Top-level ``*.py`` files are copied over (overwriting existing copies).
    ``radio_bundle/`` is replaced wholesale, so no stale files survive from a
    previous export.  ``__pycache__`` directories and ``*.pyc`` files are not
    copied.

    Args:
        save_path: Export directory; created if it does not exist.
        src_root: Package directory to copy from.  Defaults to
            ``REPO_ROOT / "projects" / "vectorllm_hf_0407"``.

    Raises:
        RuntimeError: if an existing ``radio_bundle`` could not be removed.
        OSError: from ``shutil.copytree`` if ``src_root / "radio_bundle"`` is
            missing or unreadable.
    """
    if src_root is None:
        src_root = REPO_ROOT / "projects" / "vectorllm_hf_0407"
    src_root = Path(src_root)
    dst_root = Path(save_path)
    dst_root.mkdir(parents=True, exist_ok=True)

    for src_path in src_root.glob("*.py"):
        shutil.copy2(src_path, dst_root / src_path.name)

    radio_src = src_root / "radio_bundle"
    radio_dst = dst_root / "radio_bundle"
    if radio_dst.exists():
        shutil.rmtree(radio_dst, ignore_errors=True)
    # rmtree errors are ignored above; verify the directory is really gone.
    if radio_dst.exists():
        raise RuntimeError(f"Failed to clean export directory: {radio_dst}")
    # Bytecode caches are interpreter-specific build artifacts, not source.
    shutil.copytree(radio_src, radio_dst, ignore=shutil.ignore_patterns("__pycache__", "*.pyc"))
def bootstrap_local_registry(model_path):
    """Import an export directory as a top-level Python package.

    The directory's parent is prepended to ``sys.path`` (only if not already
    present), and the package named after the directory is imported, running
    its ``__init__``.  NOTE(review): presumably that ``__init__`` registers the
    VectorLLM auto classes so they load with ``trust_remote_code=False`` —
    confirm against the package.
    """
    resolved = Path(model_path).expanduser().resolve()
    search_dir = str(resolved.parent)
    if search_dir not in sys.path:
        sys.path[:0] = [search_dir]
    __import__(resolved.name)
def decode_generated_text(output, model_inputs, tokenizer):
    """Decode the text produced after the prompt in a ``generate`` result.

    Args:
        output: Either an object with a ``.sequences`` attribute (e.g. a
            ``return_dict_in_generate=True`` result) or a tensor of token ids.
            Only the first sequence is used.
        model_inputs: Mapping whose optional ``"input_ids"`` entry gives the
            prompt length (its last dimension); 0 when absent.
        tokenizer: Used for ``decode(..., skip_special_tokens=False)``.

    Returns:
        The decoded, whitespace-stripped continuation.  If removing the prompt
        leaves no tokens, the entire first sequence is decoded instead.
        NOTE(review): presumably this covers models whose ``generate`` returns
        only new tokens — confirm.
    """
    prompt_ids = model_inputs.get("input_ids")
    prompt_len = 0 if prompt_ids is None else prompt_ids.shape[-1]
    sequences = output.sequences if hasattr(output, "sequences") else output
    full_sequence = sequences[0]
    new_tokens = full_sequence[prompt_len:]
    if not new_tokens.numel():
        new_tokens = full_sequence
    return tokenizer.decode(new_tokens, skip_special_tokens=False).strip()
def validate_export(save_path, demo_image_path, expected_text):
    """Reload the export from disk, run greedy generation on the demo image, and compare.

    The export directory is first imported as a package via
    ``bootstrap_local_registry``.  The model and processor are then loaded with
    ``AutoModel`` / ``AutoProcessor`` (``trust_remote_code=False``), the model
    in bfloat16.

    Args:
        save_path: Export directory (string path).
        demo_image_path: Image fed to the processor with ``DEFAULT_RAW_PROMPT``.
        expected_text: Reference text to compare against.

    Returns:
        ``{"expected_text", "actual_text", "match"}``, where ``match`` is exact
        string equality.
    """
    bootstrap_local_registry(save_path)
    model = AutoModel.from_pretrained(
        save_path,
        trust_remote_code=False,
        torch_dtype=torch.bfloat16,
    )
    processor = AutoProcessor.from_pretrained(save_path, trust_remote_code=False)
    tokenizer = processor.tokenizer
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    # Close the image file deterministically; convert() has already loaded the pixels.
    with Image.open(demo_image_path) as raw_image:
        image = raw_image.convert("RGB")
    model_inputs = processor(text=[DEFAULT_RAW_PROMPT], images=[image], return_tensors="pt")
    model_inputs = {
        key: value.to(model.device) if torch.is_tensor(value) else value
        for key, value in model_inputs.items()
    }

    # A pad id of 0 is valid; only fall back to EOS when the tokenizer has none
    # (``pad_token_id or eos_token_id`` would wrongly replace 0).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    generation_config = GenerationConfig(
        max_new_tokens=640,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
        temperature=0.0,
        top_k=1,
    )
    stop_criteria = get_stop_criteria(tokenizer, ["<|im_end|>", "<|endoftext|>"])
    with torch.no_grad():
        output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True,
            do_sample=False,
            temperature=0.0,
            top_k=1,
        )
    actual_text = decode_generated_text(output, model_inputs, tokenizer)
    return {
        "expected_text": expected_text,
        "actual_text": actual_text,
        "match": actual_text == expected_text,
    }
def run_xtuner_reference(
    model,
    image_processor,
    demo_image_path,
    prompt=(
        "<image>\nPlease extract the regular vector contour of the central building in the image, "
        "start from the left top corner and in clockwise."
    ),
    resize_size=432,
):
    """Run the original xtuner model on the demo image and return its prediction.

    The image is converted to RGB and bicubic-resized to
    ``resize_size x resize_size``.  It is then preprocessed with
    ``image_processor`` and passed to ``model.predict_forward`` together with
    ``prompt``.  Model and pixels are moved to CUDA when available.

    Args:
        model: Built xtuner model (see ``build_xtuner_model``).
        image_processor: xtuner image processor providing ``preprocess``.
        demo_image_path: Path to the demo image.
        prompt: Text prompt; the default is the original hard-coded prompt.
        resize_size: Square side length for the pre-resize (default 432, as before).

    Returns:
        ``result["prediction"]`` from ``predict_forward``.
    """
    # Close the image file deterministically; convert() has already loaded the pixels.
    with Image.open(demo_image_path) as raw_image:
        image = raw_image.convert("RGB")
    resized_image = image.resize((resize_size, resize_size), resample=Image.BICUBIC)
    pixel_values = image_processor.preprocess(resized_image, return_tensors="pt")["pixel_values"][0]
    if torch.cuda.is_available():
        pixel_values = pixel_values.cuda()
        model = model.cuda()
    # Pure inference: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        result = model.predict_forward(
            pixel_values=pixel_values,
            text_prompts=prompt,
        )
    return result["prediction"]
def _build_export_image_processor(hf_config):
    """Create the ``VectorLLMImageProcessor`` that is saved with the export.

    Pre-resizes to 432 and takes ``resized_size`` / ``patch_size`` from ``hf_config``.
    """
    return VectorLLMImageProcessor(
        do_resize=True,
        do_rescale=True,
        do_normalize=False,
        do_convert_rgb=True,
        pre_resize_size=432,
        resized_size=hf_config.resized_size,
        patch_size=hf_config.patch_size,
        auto_map={
            "AutoImageProcessor": "image_processing_vectorllm.VectorLLMImageProcessor",
            "AutoProcessor": "processing_vectorllm.VectorLLMProcessor",
        },
    )


def _greedy_generate_text(model, model_inputs, tokenizer):
    """Greedily generate (at most 640 new tokens) for ``model_inputs`` and decode the result.

    Stops on EOS or when the decoded text ends with ``<|im_end|>`` / ``<|endoftext|>``.
    """
    # A pad id of 0 is valid; only fall back to EOS when the tokenizer has none
    # (``pad_token_id or eos_token_id`` would wrongly replace 0).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    generation_config = GenerationConfig(
        max_new_tokens=640,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
        temperature=0.0,
        top_k=1,
    )
    stop_criteria = get_stop_criteria(tokenizer, ["<|im_end|>", "<|endoftext|>"])
    with torch.no_grad():
        output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True,
            do_sample=False,
            temperature=0.0,
            top_k=1,
        )
    return decode_generated_text(output, model_inputs, tokenizer)


def main():
    """Convert an xtuner VectorLLM checkpoint into a Hugging Face export and validate it.

    Steps:
        1. Build the xtuner model and record its prediction on the demo image.
        2. Wrap its modules in ``VectorLLMForCausalLM`` (bf16) and generate once
           in memory.
        3. Save model, tokenizer, image processor, processor and remote code.
        4. Reload the export from disk and generate again.

    All outputs and their comparisons are written to ``conversion_report.json``
    in the export directory and printed as JSON.
    """
    import gc  # only used to release the in-memory models before reloading

    args = parse_args()
    save_path = Path(args.save_path).expanduser().resolve()
    save_path.mkdir(parents=True, exist_ok=True)
    # NOTE(review): presumably a leftover from an earlier export layout; removed
    # so it is not shipped with the new export.
    vision_backbone_dir = save_path / "vision_backbone"
    if vision_backbone_dir.exists():
        shutil.rmtree(vision_backbone_dir)

    cfg, xtuner_model, xtuner_image_processor = build_xtuner_model(args.config, args.pth_model)
    # Reference output is computed before any wrapping / bf16 casting below.
    xtuner_reference_text = run_xtuner_reference(xtuner_model, xtuner_image_processor, args.demo_image)
    hf_config = build_hf_config(cfg, xtuner_model)

    hf_model = VectorLLMForCausalLM(
        config=hf_config,
        vision_model=maybe_merge_visual_encoder(xtuner_model.visual_encoder),
        language_model=xtuner_model.llm,
        projector=xtuner_model.projector,
        pos_embeds=xtuner_model.viusal_pos_embeddings,  # attribute name as spelled on the xtuner model
    )
    hf_model = hf_model.to(dtype=torch.bfloat16)
    hf_model.eval()
    hf_model.generation_config = xtuner_model.llm.generation_config
    hf_model.config.torch_dtype = "bfloat16"

    image_processor = _build_export_image_processor(hf_config)
    tokenizer = xtuner_model.tokenizer
    processor = VectorLLMProcessor(
        image_processor=image_processor,
        tokenizer=tokenizer,
        chat_template=tokenizer.chat_template,
    )

    # Close the image file deterministically; convert() has already loaded the pixels.
    with Image.open(args.demo_image) as raw_image:
        demo_image = raw_image.convert("RGB")
    demo_inputs = processor(
        text=[DEFAULT_RAW_PROMPT],
        images=[demo_image],
        return_tensors="pt",
    )
    if torch.cuda.is_available():
        hf_model = hf_model.cuda()
    demo_inputs = {
        key: value.to(hf_model.device) if torch.is_tensor(value) else value
        for key, value in demo_inputs.items()
    }
    pre_save_text = _greedy_generate_text(hf_model, demo_inputs, tokenizer)

    hf_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    image_processor.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    copy_remote_code(save_path)

    # Drop the in-memory models before validate_export() loads another copy
    # from disk, so the two do not have to coexist in (GPU) memory.
    del hf_model, xtuner_model, demo_inputs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    validation = validate_export(str(save_path), args.demo_image, xtuner_reference_text)
    validation["xtuner_reference_text"] = xtuner_reference_text
    validation["pre_save_hf_text"] = pre_save_text
    validation["pre_save_match_xtuner"] = pre_save_text == xtuner_reference_text
    (save_path / "conversion_report.json").write_text(
        json.dumps(validation, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(json.dumps(validation, ensure_ascii=False, indent=2))
# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()