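# vectorllm_v1 / test_hf.py
# Uploaded by insomnia7 with huggingface_hub ("Upload folder using huggingface_hub"),
# commit bcc6605 (verified), file size 5.68 kB.
# (Hugging Face page controls "Raw / History / Blame / Contribute / Delete" omitted.)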
import argparse
import json
import re
import sys
from pathlib import Path
import torch
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoProcessor, GenerationConfig, StoppingCriteria, StoppingCriteriaList
# Matches one coordinate token of the form "<x12>" or "<y7>": group 1 is the
# axis letter ("x" or "y"), group 2 is the run of decimal digits.
COORD_PATTERN = re.compile(r"<([xy])(\d+)>")

# Bare instruction text, starting with the "<pixel>" placeholder line.
# NOTE(review): presumably this wording must match what the model was trained
# on, so do not "correct" its grammar without confirming.
DEFAULT_PROMPT = "<pixel>\nPlease extract the regular vector contour of the central building in the image, start from the left top corner and in clockwise."

# The same instruction wrapped in the "<|im_start|>"/"<|im_end|>" chat markup,
# ending with an opened assistant turn. Built from DEFAULT_PROMPT so the two
# constants can never drift apart; the resulting string is unchanged.
DEFAULT_RAW_PROMPT = (
    "<|im_start|>user\n"
    + DEFAULT_PROMPT
    + "<|im_end|>\n<|im_start|>assistant\n"
)
class StopWordStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the decoded text ends with a stop word.

    The check is done on text, not token ids: the first sequence of the batch
    is decoded in full, carriage returns and newlines are removed, and the
    trailing ``len(stop_word)`` characters are compared with ``stop_word``.
    """

    def __init__(self, tokenizer, stop_word):
        """Remember the tokenizer used for decoding and the stop word to match.

        Args:
            tokenizer: Object exposing ``decode(ids)`` that returns a string.
            stop_word: Text whose appearance at the end of the output stops
                generation.
        """
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Cached so the tail slice in __call__ does not recompute it each step.
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Return True when the text decoded so far ends with the stop word.

        Only ``input_ids[0]`` is inspected; extra positional and keyword
        arguments are accepted and ignored.
        """
        decoded = self.tokenizer.decode(input_ids[0])
        # Line breaks are dropped before matching, so a stop word is still
        # recognised when the decoded text has "\r" or "\n" inside or after it.
        flattened = decoded.replace("\r", "").replace("\n", "")
        tail = flattened[-self.length:]
        return tail == self.stop_word
def get_stop_criteria(tokenizer, stop_words=None):
    """Build a ``StoppingCriteriaList`` with one text-based criterion per stop word.

    Args:
        tokenizer: Tokenizer handed to every ``StopWordStoppingCriteria``.
        stop_words: Iterable of stop strings; ``None`` (or any falsy value)
            yields an empty criteria list.

    Returns:
        A ``StoppingCriteriaList`` holding the criteria in the order given.
    """
    criteria = StoppingCriteriaList()
    criteria.extend(
        StopWordStoppingCriteria(tokenizer, word) for word in (stop_words or [])
    )
    return criteria
def parse_args(argv=None):
    """Parse the command-line arguments of the inference script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour;
            passing a list makes the parser usable from tests or other code.

    Returns:
        An ``argparse.Namespace`` with ``model_path``, ``image_path`` and
        ``save_dir`` attributes.
    """
    parser = argparse.ArgumentParser(description="Run HF VectorLLM single-image inference.")
    parser.add_argument(
        "model_path",
        help="Local model directory passed to from_pretrained.",
    )
    parser.add_argument(
        "image_path",
        help="Path of the input image.",
    )
    parser.add_argument(
        "--save-dir",
        default="./work_dirs/vectorllm_hf_0407_test",
        help="Directory for overlay.png and report.json (default: %(default)s).",
    )
    return parser.parse_args(argv)
def bootstrap_local_registry(model_path):
    """Import the model directory as a Python package.

    The directory's parent is put at the front of ``sys.path`` (once) and the
    directory name is imported as a top-level package, so whatever the
    package does at import time is executed.
    NOTE(review): presumably that import registers the custom model and
    processor classes with the ``Auto*`` factories, which would explain why
    the loader can run with ``trust_remote_code=False`` -- confirm against the
    package's ``__init__``.

    Args:
        model_path: Path to the model directory; ``~`` is expanded and the
            path is resolved to an absolute one.

    Returns:
        The imported module (the original returned ``None``; existing callers
        that ignore the result are unaffected).

    Raises:
        ImportError: If the directory name cannot be imported as a package.
    """
    # Function-scope import keeps this block self-contained.
    import importlib

    model_dir = Path(model_path).expanduser().resolve()
    search_root = str(model_dir.parent)
    if search_root not in sys.path:
        # Front of the path so the local package wins over same-named installs.
        sys.path.insert(0, search_root)
    # import_module is the documented replacement for calling __import__ directly.
    return importlib.import_module(model_dir.name)
def decode_generated_text(output, model_inputs, tokenizer):
    """Decode the newly generated part of the first sequence to stripped text.

    The prompt is removed only when the returned sequence is longer than the
    prompt AND actually starts with the prompt ids. Otherwise the whole
    sequence is decoded. The original stripped purely by length, which cut
    off the first ``prompt_length`` generated tokens whenever ``generate``
    returned only new tokens and produced more of them than the prompt had.

    Args:
        output: Generation result exposing ``sequences`` (indexable by batch).
        model_inputs: Mapping that may contain ``"input_ids"``.
        tokenizer: Object with ``decode(ids, skip_special_tokens=...)``.

    Returns:
        The decoded text with special tokens kept and surrounding whitespace
        stripped.
    """
    sequence = output.sequences[0]
    generated_ids = sequence

    prompt_ids = model_inputs.get("input_ids")
    if prompt_ids is not None:
        prompt = torch.as_tensor(prompt_ids, device=sequence.device)
        if prompt.dim() > 1:
            # Batched input: the first row pairs with output.sequences[0].
            prompt = prompt[0]
        prompt_length = prompt.shape[-1]
        echoes_prompt = sequence.shape[-1] > prompt_length and bool(
            (sequence[:prompt_length] == prompt).all()
        )
        if echoes_prompt:
            generated_ids = sequence[prompt_length:]

    return tokenizer.decode(generated_ids, skip_special_tokens=False).strip()
def parse_polygon(text):
    """Extract ``(x, y)`` integer pairs from coordinate tokens in ``text``.

    Tokens matched by ``COORD_PATTERN`` are scanned left to right. An ``x``
    token is held until the next ``y`` token, and the two then form one point.
    A later ``x`` replaces a held one, and a ``y`` with no held ``x`` is
    ignored, so malformed output yields fewer points rather than an error.

    Args:
        text: Model output containing tokens such as ``<x12><y7>``.

    Returns:
        List of ``(x, y)`` tuples of ints, in order of appearance.
    """
    vertices = []
    held_x = None
    for match in COORD_PATTERN.finditer(text):
        axis, digits = match.groups()
        number = int(digits)
        if axis == "x":
            held_x = number
            continue
        if held_x is None:
            # A "y" without a preceding unmatched "x" cannot form a point.
            continue
        vertices.append((held_x, number))
        held_x = None
    return vertices
def recover_polygon(points, image_size, grid_size=128):
    """Map grid coordinates to image coordinates.

    Each coordinate is shifted by 0.5, divided by ``grid_size`` and multiplied
    by the matching image dimension, i.e. a grid index maps to the centre of
    its cell.

    Args:
        points: Iterable of ``(x, y)`` grid coordinates.
        image_size: ``(width, height)`` of the target image.
        grid_size: Number of grid cells per axis (default 128).

    Returns:
        List of ``(x, y)`` float tuples in image coordinates.
    """
    width, height = image_size
    return [
        ((grid_x + 0.5) / grid_size * width, (grid_y + 0.5) / grid_size * height)
        for grid_x, grid_y in points
    ]
def draw_polygon(image, polygon):
    """Render ``polygon`` on top of ``image`` and return the result as RGB.

    Drawing happens on a fully transparent RGBA layer that is alpha-composited
    over an RGBA copy of the image, so the semi-transparent fill blends with
    the underlying pixels.

    Args:
        image: PIL image to draw over; it is converted, not modified.
        polygon: Sequence of ``(x, y)`` vertices in image coordinates.

    Returns:
        A new RGB image with the overlay applied.
    """
    base = image.convert("RGBA")
    layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)

    # The outline and fill are drawn only when there are at least 3 vertices.
    if len(polygon) >= 3:
        pen.polygon(polygon, outline=(255, 0, 255, 255), fill=(0, 255, 255, 90), width=2)

    # Every vertex gets a small marker, even when no polygon was drawn.
    for vertex_x, vertex_y in polygon:
        pen.ellipse(
            (vertex_x - 2, vertex_y - 2, vertex_x + 2, vertex_y + 2),
            fill=(255, 165, 0, 255),
        )

    composited = Image.alpha_composite(base, layer)
    return composited.convert("RGB")
def main():
    """Run single-image inference and write an overlay image plus a JSON report.

    Steps: parse the CLI arguments, create the output directory, import the
    local model package, load model and processor, generate with greedy
    decoding, parse the coordinate tokens into a polygon, then save
    ``overlay.png`` and ``report.json`` under the save directory and print
    the report path.
    """
    args = parse_args()
    save_dir = Path(args.save_dir).expanduser().resolve()
    save_dir.mkdir(parents=True, exist_ok=True)

    # Must run before from_pretrained: remote code is disabled below, so the
    # local package has to be imported explicitly.
    bootstrap_local_registry(args.model_path)
    model = AutoModel.from_pretrained(
        args.model_path,
        trust_remote_code=False,
        torch_dtype=torch.bfloat16,
    )
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=False)
    tokenizer = processor.tokenizer
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    # Context manager closes the underlying file; convert() returns an
    # independent in-memory image that stays valid afterwards.
    with Image.open(args.image_path) as source_image:
        image = source_image.convert("RGB")

    model_inputs = processor(text=[DEFAULT_RAW_PROMPT], images=[image], return_tensors="pt")
    model_inputs = {
        key: value.to(model.device) if torch.is_tensor(value) else value
        for key, value in model_inputs.items()
    }

    # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0,
    # so fall back to EOS only when no pad token is defined at all.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    stop_criteria = get_stop_criteria(tokenizer, ["<|im_end|>", "<|endoftext|>"])

    # Greedy decoding: do_sample=False is stated once, in the config.
    # temperature/top_k are sampling-only settings and were dropped; with
    # sampling disabled they should have no effect beyond validation warnings.
    generation_config = GenerationConfig(
        max_new_tokens=640,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )
    # Inference only; no_grad is harmless if generate() already disables grad.
    with torch.no_grad():
        output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True,
        )

    text = decode_generated_text(output, model_inputs, tokenizer)
    grid_polygon = parse_polygon(text)
    polygon = recover_polygon(grid_polygon, image.size)
    overlay = draw_polygon(image, polygon)

    overlay_path = save_dir / "overlay.png"
    report_path = save_dir / "report.json"
    overlay.save(overlay_path)
    report = {
        "text": text,
        "grid_polygon": grid_polygon,
        "polygon": polygon,
        "overlay_path": str(overlay_path),
    }
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(report_path)
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()