| import argparse |
| import json |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| from PIL import Image, ImageDraw |
| from transformers import AutoModel, AutoProcessor, GenerationConfig, StoppingCriteria, StoppingCriteriaList |
|
|
|
|
# Matches coordinate tokens such as "<x12>" or "<y7>":
# group 1 is the axis letter ("x" or "y"), group 2 the decimal value.
COORD_PATTERN = re.compile(r"<([xy])(\d+)>")

# Bare user instruction. "<pixel>" is presumably the image placeholder that the
# processor expands -- confirm against the model's processor.
DEFAULT_PROMPT = (
    "<pixel>\nPlease extract the regular vector contour of the central building in the image, "
    "start from the left top corner and in clockwise."
)

# The same instruction wrapped in <|im_start|>/<|im_end|> user/assistant turn
# markers, ending with an open assistant turn for the model to complete.
# Built from DEFAULT_PROMPT so the two strings cannot drift apart.
DEFAULT_RAW_PROMPT = "<|im_start|>user\n" + DEFAULT_PROMPT + "<|im_end|>\n<|im_start|>assistant\n"
|
|
|
|
class StopWordStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the decoded text ends with a stop word.

    Only row 0 of ``input_ids`` is decoded. Carriage returns and newlines are
    deleted from the decoded text before the suffix comparison, so a stop word
    still matches if line breaks were emitted in the middle of it.
    """

    # Translation table that deletes "\r" and "\n".
    _LINE_BREAKS = str.maketrans("", "", "\r\n")

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Width of the suffix compared on every generation step.
        self.length = len(stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Return True when the flattened decoded text ends with ``stop_word``."""
        flattened = self.tokenizer.decode(input_ids[0]).translate(self._LINE_BREAKS)
        tail = flattened[-self.length:]
        return tail == self.stop_word
|
|
|
|
def get_stop_criteria(tokenizer, stop_words=None):
    """Build a StoppingCriteriaList holding one StopWordStoppingCriteria per stop word.

    Args:
        tokenizer: Tokenizer used by each criterion to decode generated ids.
        stop_words: Iterable of stop strings; ``None`` or empty yields an empty list.

    Returns:
        StoppingCriteriaList, in the same order as ``stop_words``.
    """
    criteria = [StopWordStoppingCriteria(tokenizer, word) for word in (stop_words or [])]
    return StoppingCriteriaList(criteria)
|
|
|
|
def parse_args(argv=None):
    """Parse the command line for single-image inference.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.
            Passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``model_path``, ``image_path`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(description="Run HF VectorLLM single-image inference.")
    parser.add_argument("model_path", help="Path to the local model directory.")
    parser.add_argument("image_path", help="Path to the input image.")
    parser.add_argument(
        "--save-dir",
        default="./work_dirs/vectorllm_hf_0407_test",
        help="Directory where overlay.png and report.json are written.",
    )
    return parser.parse_args(argv)
|
|
|
|
def bootstrap_local_registry(model_path):
    """Import a local model directory as a Python package.

    The directory's parent is prepended to ``sys.path`` (only if it is not
    already present), and then the directory name is imported, which runs the
    package's ``__init__``. Presumably this registers the package's custom
    classes with the transformers Auto* factories, since ``main`` loads with
    ``trust_remote_code=False``. Confirm against the model package.

    Raises:
        ImportError: propagated if the directory is not an importable package.
    """
    package_dir = Path(model_path).expanduser().resolve()
    parent_dir, module_name = str(package_dir.parent), package_dir.name
    already_searchable = parent_dir in sys.path
    if not already_searchable:
        sys.path.insert(0, parent_dir)
    __import__(module_name)
|
|
|
|
def decode_generated_text(output, model_inputs, tokenizer):
    """Decode the generated continuation of the first sequence in ``output``.

    The first ``input_ids.shape[-1]`` ids of ``output.sequences[0]`` are
    treated as the prompt and skipped (none are skipped if ``model_inputs``
    has no ``input_ids``). If nothing remains after the slice, the whole
    sequence is decoded instead. Presumably this covers models whose
    ``generate()`` returns only the new ids; verify per model.

    Returns:
        The decoded text with special tokens kept and surrounding whitespace stripped.
    """
    prompt_ids = model_inputs.get("input_ids")
    prompt_len = 0 if prompt_ids is None else prompt_ids.shape[-1]
    full_sequence = output.sequences[0]
    continuation = full_sequence[prompt_len:]
    ids_to_decode = continuation if continuation.numel() > 0 else full_sequence
    decoded = tokenizer.decode(ids_to_decode, skip_special_tokens=False)
    return decoded.strip()
|
|
|
|
def parse_polygon(text):
    """Extract ``(x, y)`` grid points from coordinate tokens in ``text``.

    Tokens matched by COORD_PATTERN are scanned in order. An x token records a
    pending x value, replacing any earlier one. A y token completes a point
    only if an x value is pending; a y token without one is ignored. A
    trailing x with no following y is dropped.

    Returns:
        List of ``(x, y)`` integer tuples in the order they appear.
    """
    polygon = []
    last_x = None
    for match in COORD_PATTERN.finditer(text):
        axis, number = match.group(1), int(match.group(2))
        if axis == "x":
            last_x = number
            continue
        if last_x is None:
            continue
        polygon.append((last_x, number))
        last_x = None
    return polygon
|
|
|
|
def recover_polygon(points, image_size, grid_size=128):
    """Map grid-cell indices to pixel coordinates.

    Each ``(x, y)`` index on a ``grid_size`` x ``grid_size`` grid is mapped to
    the centre of its cell in an image of size ``image_size == (width, height)``.

    Returns:
        List of ``(x, y)`` float tuples, one per input point.
    """
    width, height = image_size

    def to_pixels(cell, extent):
        # +0.5 selects the centre of the cell rather than its top-left corner.
        return (cell + 0.5) / grid_size * extent

    return [(to_pixels(gx, width), to_pixels(gy, height)) for gx, gy in points]
|
|
|
|
def draw_polygon(image, polygon):
    """Render ``polygon`` on top of ``image`` and return an RGB copy.

    The shape is drawn on a transparent RGBA layer that is alpha-composited
    over the image. The shape (magenta outline, translucent cyan fill) is only
    drawn when there are at least three vertices. Every vertex is marked with
    a small orange dot.
    """
    base = image.convert("RGBA")
    layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    if len(polygon) > 2:
        pen.polygon(polygon, outline=(255, 0, 255, 255), fill=(0, 255, 255, 90), width=2)
    radius = 2
    for cx, cy in polygon:
        bbox = (cx - radius, cy - radius, cx + radius, cy + radius)
        pen.ellipse(bbox, fill=(255, 165, 0, 255))
    blended = Image.alpha_composite(base, layer)
    return blended.convert("RGB")
|
|
|
|
def _load_model(model_path):
    """Load the model (bf16, moved to CUDA when available, eval mode) and its processor."""
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=False,
        torch_dtype=torch.bfloat16,
    )
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=False)
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    return model, processor


def _generate_text(model, processor, image):
    """Run greedy generation on DEFAULT_RAW_PROMPT plus ``image`` and return the decoded text."""
    tokenizer = processor.tokenizer
    model_inputs = processor(text=[DEFAULT_RAW_PROMPT], images=[image], return_tensors="pt")
    model_inputs = {
        key: value.to(model.device) if torch.is_tensor(value) else value
        for key, value in model_inputs.items()
    }
    # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0,
    # so fall back to EOS only when no pad id is defined at all.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    stop_criteria = get_stop_criteria(tokenizer, ["<|im_end|>", "<|endoftext|>"])
    output = model.generate(
        **model_inputs,
        generation_config=GenerationConfig(
            max_new_tokens=640,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
            temperature=0.0,
            top_k=1,
        ),
        bos_token_id=tokenizer.bos_token_id,
        stopping_criteria=stop_criteria,
        output_hidden_states=False,
        return_dict_in_generate=True,
        # Repeated as kwargs, which generate() applies on top of
        # generation_config; presumably to pin greedy decoding regardless of
        # any defaults shipped with the model.
        do_sample=False,
        temperature=0.0,
        top_k=1,
    )
    return decode_generated_text(output, model_inputs, tokenizer)


def _write_outputs(save_dir, image, text, grid_polygon, polygon):
    """Save overlay.png and report.json into ``save_dir`` and return the report path."""
    overlay_path = save_dir / "overlay.png"
    report_path = save_dir / "report.json"
    draw_polygon(image, polygon).save(overlay_path)
    report = {
        "text": text,
        "grid_polygon": grid_polygon,
        "polygon": polygon,
        "overlay_path": str(overlay_path),
    }
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    return report_path


def main():
    """Run single-image inference and write an overlay image plus a JSON report.

    Steps:
        1. Parse the CLI.
        2. Import the model package.
        3. Load the model and processor.
        4. Generate coordinate tokens for the image.
        5. Parse them into grid points, then map those to pixel coordinates.
        6. Save ``overlay.png`` and ``report.json`` under ``--save-dir``.

    The report path is printed on completion.
    """
    args = parse_args()
    save_dir = Path(args.save_dir).expanduser().resolve()
    save_dir.mkdir(parents=True, exist_ok=True)
    # Must run before from_pretrained: loading uses trust_remote_code=False,
    # so the package is imported here instead.
    bootstrap_local_registry(args.model_path)
    model, processor = _load_model(args.model_path)

    # The context manager closes the source file; convert() returns an
    # independent, fully loaded image that stays valid afterwards.
    with Image.open(args.image_path) as source:
        image = source.convert("RGB")

    text = _generate_text(model, processor, image)
    grid_polygon = parse_polygon(text)
    polygon = recover_polygon(grid_polygon, image.size)
    print(_write_outputs(save_dir, image, text, grid_polygon, polygon))


if __name__ == "__main__":
    main()
|
|