File size: 5,681 Bytes
bcc6605
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import argparse
import json
import re
import sys
from pathlib import Path

import torch
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoProcessor, GenerationConfig, StoppingCriteria, StoppingCriteriaList


# Matches a single coordinate token such as "<x12>" or "<y34>": group 1 is the
# axis letter ("x" or "y"), group 2 is the run of decimal digits.
COORD_PATTERN = re.compile(r"<([xy])(\d+)>")
# Bare instruction text without chat-turn markers. It is not referenced by any
# other code visible in this file.
DEFAULT_PROMPT = "<pixel>\nPlease extract the regular vector contour of the central building in the image, start from the left top corner and in clockwise."
# The same instruction wrapped in <|im_start|>/<|im_end|> turn markers and ending
# with an opened assistant turn; main() feeds this string to the processor.
DEFAULT_RAW_PROMPT = (
    "<|im_start|>user\n<pixel>\nPlease extract the regular vector contour of the central building in the image, "
    "start from the left top corner and in clockwise.<|im_end|>\n<|im_start|>assistant\n"
)


class StopWordStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once the decoded text ends with a stop word.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, and
    carriage returns / newlines are removed from the decoded text before the
    comparison.
    """

    def __init__(self, tokenizer, stop_word):
        """Remember the tokenizer, the stop word, and the stop word's length."""
        self.stop_word = stop_word
        self.length = len(stop_word)
        self.tokenizer = tokenizer

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Return True when the decoded first sequence ends with the stop word."""
        decoded = self.tokenizer.decode(input_ids[0])
        # Line breaks are dropped so they cannot hide a trailing stop word.
        for line_break in ("\r", "\n"):
            decoded = decoded.replace(line_break, "")
        # Compare only the last `length` characters against the stop word.
        tail = decoded[-self.length:]
        return tail == self.stop_word


def get_stop_criteria(tokenizer, stop_words=None):
    """Build a StoppingCriteriaList with one stop-word criterion per entry.

    Args:
        tokenizer: Tokenizer handed to every StopWordStoppingCriteria.
        stop_words: Iterable of stop strings; a falsy value (None, empty)
            yields an empty criteria list.

    Returns:
        A StoppingCriteriaList holding the criteria in the order given.
    """
    criteria_list = StoppingCriteriaList()
    if not stop_words:
        return criteria_list
    for stop_word in stop_words:
        criteria_list.append(StopWordStoppingCriteria(tokenizer, stop_word))
    return criteria_list


def parse_args(argv=None):
    """Parse the command-line arguments for single-image inference.

    Args:
        argv: Optional sequence of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers that
            invoke ``parse_args()`` behave exactly as before.

    Returns:
        argparse.Namespace with ``model_path``, ``image_path`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(description="Run HF VectorLLM single-image inference.")
    parser.add_argument(
        "model_path",
        help="Local model directory; it is also imported as a Python package before loading.",
    )
    parser.add_argument("image_path", help="Path of the input image.")
    parser.add_argument(
        "--save-dir",
        default="./work_dirs/vectorllm_hf_0407_test",
        help="Directory that receives overlay.png and report.json (default: %(default)s).",
    )
    return parser.parse_args(argv)


def bootstrap_local_registry(model_path):
    """Import the model directory as a top-level Python package.

    The directory's parent is put at the front of ``sys.path`` (unless already
    listed) and the directory name is imported as a module.

    Args:
        model_path: Path to the model directory; ``~`` is expanded and the
            path is resolved to an absolute one first.

    Raises:
        ImportError: If the directory cannot be imported as a package.
    """
    # NOTE(review): presumably importing the package registers the custom
    # model/processor classes with the transformers Auto* factories - confirm
    # against the package's __init__.
    package_dir = Path(model_path).expanduser().resolve()
    search_root = str(package_dir.parent)
    already_listed = search_root in sys.path
    if not already_listed:
        sys.path.insert(0, search_root)
    __import__(package_dir.name)


def decode_generated_text(output, model_inputs, tokenizer):
    """Decode the newly generated part of the first returned sequence.

    Args:
        output: Generation result exposing ``sequences``; only
            ``sequences[0]`` is used.
        model_inputs: Mapping of model inputs; the last dimension of its
            ``"input_ids"`` entry (when present) is taken as the prompt length.
        tokenizer: Object whose ``decode`` turns token ids into text.

    Returns:
        The decoded text with special tokens kept and surrounding whitespace
        stripped.
    """
    prompt_ids = model_inputs.get("input_ids")
    prompt_length = 0 if prompt_ids is None else prompt_ids.shape[-1]
    full_sequence = output.sequences[0]
    new_tokens = full_sequence[prompt_length:]
    # An empty slice falls back to decoding the entire sequence.
    # NOTE(review): presumably this covers generate() returning only the new
    # tokens (shorter than the prompt) - confirm against the model.
    if new_tokens.numel() == 0:
        new_tokens = full_sequence
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=False)
    return decoded.strip()


def parse_polygon(text):
    """Extract (x, y) integer pairs from coordinate tokens found in ``text``.

    Tokens are scanned left to right with COORD_PATTERN. An x token is held
    until the next y token completes the pair; a later x token replaces a
    still-pending one, and a y token with no pending x is ignored.

    Args:
        text: String that may contain tokens such as ``<x12><y34>``.

    Returns:
        List of ``(x, y)`` integer tuples in order of appearance.
    """
    vertices = []
    held_x = None
    for match in COORD_PATTERN.finditer(text):
        axis, digits = match.groups()
        number = int(digits)
        if axis == "x":
            held_x = number
            continue
        if held_x is None:
            # A y value without a preceding x cannot form a point.
            continue
        vertices.append((held_x, number))
        held_x = None
    return vertices


def recover_polygon(points, image_size, grid_size=128):
    """Map grid-cell coordinates to image coordinates at the cell centres.

    Each coordinate is shifted by 0.5, divided by ``grid_size`` and multiplied
    by the matching image dimension.

    Args:
        points: Iterable of ``(x, y)`` grid coordinates.
        image_size: ``(width, height)`` pair of the target image.
        grid_size: Number of grid cells along each axis.

    Returns:
        List of ``(x, y)`` float tuples in image coordinates.
    """
    width, height = image_size
    return [
        ((grid_x + 0.5) / grid_size * width, (grid_y + 0.5) / grid_size * height)
        for grid_x, grid_y in points
    ]


def draw_polygon(image, polygon):
    """Render the polygon and its vertices on top of a copy of ``image``.

    Args:
        image: PIL image used as the background.
        polygon: Sequence of ``(x, y)`` vertices in image coordinates.

    Returns:
        A new RGB image with the drawing composited over the background.
    """
    background = image.convert("RGBA")
    # Draw on a fully transparent layer so the translucent fill blends on composite.
    layer = Image.new("RGBA", background.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)

    # The outline and fill are drawn only when there are at least three vertices.
    has_area = len(polygon) >= 3
    if has_area:
        pen.polygon(polygon, outline=(255, 0, 255, 255), fill=(0, 255, 255, 90), width=2)

    # Every vertex gets a small filled marker, whatever the vertex count.
    for vertex_x, vertex_y in polygon:
        marker_box = (vertex_x - 2, vertex_y - 2, vertex_x + 2, vertex_y + 2)
        pen.ellipse(marker_box, fill=(255, 165, 0, 255))

    composited = Image.alpha_composite(background, layer)
    return composited.convert("RGB")


def main():
    """Run single-image inference and write the overlay and JSON report.

    Steps: parse CLI arguments, import the model directory as a package, load
    the model and processor, generate text for DEFAULT_RAW_PROMPT plus the
    image, parse coordinate tokens into a polygon, then save ``overlay.png``
    and ``report.json`` into the save directory and print the report path.
    """
    args = parse_args()
    save_dir = Path(args.save_dir).expanduser().resolve()
    save_dir.mkdir(parents=True, exist_ok=True)
    # Must run before the Auto* loaders, which are called with trust_remote_code=False.
    # NOTE(review): presumably the import registers the custom classes - confirm.
    bootstrap_local_registry(args.model_path)

    model = AutoModel.from_pretrained(
        args.model_path,
        trust_remote_code=False,
        torch_dtype=torch.bfloat16,
    )
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=False)
    tokenizer = processor.tokenizer
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    # Close the file handle right away; convert() returns an independent image.
    with Image.open(args.image_path) as source_image:
        image = source_image.convert("RGB")
    model_inputs = processor(text=[DEFAULT_RAW_PROMPT], images=[image], return_tensors="pt")
    # Move tensors to the model's device; non-tensor entries pass through unchanged.
    model_inputs = {
        key: value.to(model.device) if torch.is_tensor(value) else value
        for key, value in model_inputs.items()
    }
    stop_criteria = get_stop_criteria(tokenizer, ["<|im_end|>", "<|endoftext|>"])

    # Fall back to EOS only when no pad token is configured. A plain `or`
    # would also discard a legitimate pad token id of 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # NOTE(review): do_sample/temperature/top_k are passed both inside the
    # GenerationConfig and as keyword arguments; kept as-is because the
    # model's generate() is defined outside this file.
    output = model.generate(
        **model_inputs,
        generation_config=GenerationConfig(
            max_new_tokens=640,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
            temperature=0.0,
            top_k=1,
        ),
        bos_token_id=tokenizer.bos_token_id,
        stopping_criteria=stop_criteria,
        output_hidden_states=False,
        return_dict_in_generate=True,
        do_sample=False,
        temperature=0.0,
        top_k=1,
    )
    text = decode_generated_text(output, model_inputs, tokenizer)
    grid_polygon = parse_polygon(text)
    polygon = recover_polygon(grid_polygon, image.size)

    overlay = draw_polygon(image, polygon)
    overlay_path = save_dir / "overlay.png"
    report_path = save_dir / "report.json"
    overlay.save(overlay_path)
    report_path.write_text(
        json.dumps(
            {
                "text": text,
                "grid_polygon": grid_polygon,
                "polygon": polygon,
                "overlay_path": str(overlay_path),
            },
            ensure_ascii=False,
            indent=2,
        )
        + "\n",
        encoding="utf-8",
    )
    print(report_path)


# Run inference only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()