File size: 5,681 Bytes
bcc6605 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | import argparse
import json
import re
import sys
from pathlib import Path
import torch
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoProcessor, GenerationConfig, StoppingCriteria, StoppingCriteriaList
# Matches the model's quantized coordinate tokens such as "<x12>" or "<y97>":
# group 1 is the axis letter ("x" or "y"), group 2 the integer grid index.
COORD_PATTERN = re.compile(r"<([xy])(\d+)>")
# Bare user instruction. "<pixel>" is presumably the placeholder the processor
# expands into image tokens — TODO confirm against the processor implementation.
DEFAULT_PROMPT = (
    "<pixel>\nPlease extract the regular vector contour of the central building in the image, "
    "start from the left top corner and in clockwise."
)
# The same instruction wrapped in a ChatML-style user turn followed by an open
# assistant turn. Derived from DEFAULT_PROMPT so the two strings cannot drift apart.
DEFAULT_RAW_PROMPT = f"<|im_start|>user\n{DEFAULT_PROMPT}<|im_end|>\n<|im_start|>assistant\n"
class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded first sequence ends with ``stop_word``.

    The full token sequence of batch row 0 is re-decoded on every call, and all
    carriage returns and newlines are removed from the decoded text before the
    comparison (so a stop word containing either character can never match).
    """

    # Translation table that deletes "\r" and "\n".
    _LINE_BREAKS = str.maketrans("", "", "\r\n")

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Number of trailing characters compared against the stop word.
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        # Only the first row of the batch is inspected.
        flattened = self.tokenizer.decode(input_ids[0]).translate(self._LINE_BREAKS)
        return self.stop_word == flattened[-self.length:]
def get_stop_criteria(tokenizer, stop_words=None):
    """Build a ``StoppingCriteriaList`` holding one ``StopWordStoppingCriteria`` per word.

    Args:
        tokenizer: Tokenizer handed to every criterion for decoding.
        stop_words: Iterable of stop strings; ``None`` or empty yields an empty list.

    Returns:
        The populated ``StoppingCriteriaList``.
    """
    criteria = StoppingCriteriaList()
    criteria.extend(
        StopWordStoppingCriteria(tokenizer, word) for word in (stop_words or [])
    )
    return criteria
def parse_args(argv=None):
    """Parse the command-line options for single-image inference.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with ``model_path``, ``image_path`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(description="Run HF VectorLLM single-image inference.")
    parser.add_argument(
        "model_path",
        help="Local model directory; it is also imported as a Python package before loading.",
    )
    parser.add_argument("image_path", help="Input image to run inference on.")
    parser.add_argument(
        "--save-dir",
        default="./work_dirs/vectorllm_hf_0407_test",
        help="Directory that receives overlay.png and report.json (default: %(default)s).",
    )
    return parser.parse_args(argv)
def bootstrap_local_registry(model_path):
    """Import the model directory as a top-level package so its import-time code runs.

    The directory's parent is prepended to ``sys.path`` (only if not already
    present) and the directory name is then imported. Presumably the package
    registers its custom model/processor classes with the ``Auto*`` factories
    when imported — TODO confirm against the package's ``__init__``.
    """
    resolved = Path(model_path).expanduser().resolve()
    search_root = str(resolved.parent)
    if search_root not in sys.path:
        sys.path.insert(0, search_root)
    __import__(resolved.name)
def decode_generated_text(output, model_inputs, tokenizer):
    """Decode the newly generated tokens of the first sequence in ``output``.

    ``generate`` implementations differ: some return the prompt followed by the
    new tokens, others (typically when generating from ``inputs_embeds``) return
    only the new tokens. The prompt is therefore stripped only when the returned
    sequence really starts with ``model_inputs["input_ids"]``; otherwise the
    whole sequence is treated as generated text. A sequence that is exactly the
    prompt decodes to an empty string.

    Args:
        output: ``generate`` result exposing a ``sequences`` tensor
            (i.e. produced with ``return_dict_in_generate=True``).
        model_inputs: Mapping of processor outputs; ``"input_ids"`` is optional.
        tokenizer: Tokenizer used to turn ids back into text.

    Returns:
        Decoded text with special tokens kept and surrounding whitespace stripped.
    """
    sequence = output.sequences[0]
    input_ids = model_inputs.get("input_ids")
    if input_ids is not None:
        # Compare against the first prompt row (processor output is presumably (batch, seq)).
        prompt = input_ids[0] if input_ids.dim() > 1 else input_ids
        prompt = prompt.to(device=sequence.device, dtype=sequence.dtype)
        prompt_len = prompt.shape[-1]
        # Never slice unconditionally: a new-tokens-only sequence longer than the
        # prompt would otherwise lose its first ``prompt_len`` tokens silently.
        if sequence.shape[-1] >= prompt_len and torch.equal(sequence[:prompt_len], prompt):
            sequence = sequence[prompt_len:]
    return tokenizer.decode(sequence, skip_special_tokens=False).strip()
def parse_polygon(text):
    """Extract integer grid vertices from the model's ``<x…>``/``<y…>`` tokens.

    An ``<x…>`` token is held until the next ``<y…>`` token completes a vertex.
    A ``<y…>`` with no pending ``<x…>`` is ignored, a second ``<x…>`` before any
    ``<y…>`` replaces the first, and a trailing unmatched ``<x…>`` is dropped.

    Returns:
        List of ``(x, y)`` integer tuples in the order they appear in ``text``.
    """
    vertices = []
    x_pending = None
    for match in COORD_PATTERN.finditer(text):
        axis, digits = match.group(1), match.group(2)
        coordinate = int(digits)
        if axis == "x":
            x_pending = coordinate
            continue
        if x_pending is None:
            continue
        vertices.append((x_pending, coordinate))
        x_pending = None
    return vertices
def recover_polygon(points, image_size, grid_size=128):
    """Map grid-cell indices back to pixel coordinates at each cell's centre.

    Args:
        points: Iterable of ``(x, y)`` grid indices.
        image_size: ``(width, height)`` of the target image (``main`` passes ``image.size``).
        grid_size: Number of grid cells along each axis.

    Returns:
        List of ``(x, y)`` float tuples in pixel space.
    """
    width, height = image_size
    return [
        ((gx + 0.5) / grid_size * width, (gy + 0.5) / grid_size * height)
        for gx, gy in points
    ]
def draw_polygon(image, polygon):
    """Return an RGB copy of ``image`` with ``polygon`` drawn on a translucent overlay.

    With at least three vertices the shape is filled semi-transparent cyan with a
    2-px magenta outline; every vertex (for any count) gets a small orange dot.
    """
    base = image.convert("RGBA")
    layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    if len(polygon) >= 3:
        pen.polygon(polygon, outline=(255, 0, 255, 255), fill=(0, 255, 255, 90), width=2)
    dot_radius = 2
    for cx, cy in polygon:
        bbox = (cx - dot_radius, cy - dot_radius, cx + dot_radius, cy + dot_radius)
        pen.ellipse(bbox, fill=(255, 165, 0, 255))
    composited = Image.alpha_composite(base, layer)
    return composited.convert("RGB")
def _load_model_and_processor(model_path):
    """Register the local package, load model + processor, move to GPU if present, set eval mode."""
    bootstrap_local_registry(model_path)
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=False,
        torch_dtype=torch.bfloat16,
    )
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=False)
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    return model, processor


def _generate_text(model, processor, image):
    """Run greedy decoding of DEFAULT_RAW_PROMPT on ``image`` and return the decoded reply."""
    tokenizer = processor.tokenizer
    model_inputs = processor(text=[DEFAULT_RAW_PROMPT], images=[image], return_tensors="pt")
    model_inputs = {
        key: value.to(model.device) if torch.is_tensor(value) else value
        for key, value in model_inputs.items()
    }
    # A pad id of 0 is valid, so fall back to eos only when it is actually missing.
    pad_token_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )
    # Greedy decoding: temperature/top_k are deliberately omitted because they are
    # ignored when do_sample=False (and temperature=0.0 is invalid if sampling is on).
    generation_config = GenerationConfig(
        max_new_tokens=640,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )
    stop_criteria = get_stop_criteria(tokenizer, ["<|im_end|>", "<|endoftext|>"])
    # The model's generate may be custom code, so don't rely on it disabling autograd.
    with torch.no_grad():
        output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True,
            # Kept as an explicit override in case a custom generate ignores generation_config.
            do_sample=False,
        )
    return decode_generated_text(output, model_inputs, tokenizer)


def _write_report(save_dir, text, grid_polygon, polygon, overlay):
    """Save ``overlay.png`` and ``report.json`` into ``save_dir``; return the report path."""
    overlay_path = save_dir / "overlay.png"
    report_path = save_dir / "report.json"
    overlay.save(overlay_path)
    report = {
        "text": text,
        "grid_polygon": grid_polygon,
        "polygon": polygon,
        "overlay_path": str(overlay_path),
    }
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    return report_path


def main():
    """CLI entry point: predict the central building's contour and save an overlay + JSON report.

    Creates ``--save-dir`` if needed, writes ``overlay.png`` and ``report.json``
    there, and prints the path of the report.
    """
    args = parse_args()
    save_dir = Path(args.save_dir).expanduser().resolve()
    save_dir.mkdir(parents=True, exist_ok=True)
    model, processor = _load_model_and_processor(args.model_path)
    # convert() materializes a new image, so the source file can be closed right away.
    with Image.open(args.image_path) as source:
        image = source.convert("RGB")
    text = _generate_text(model, processor, image)
    grid_polygon = parse_polygon(text)
    polygon = recover_polygon(grid_polygon, image.size)
    overlay = draw_polygon(image, polygon)
    report_path = _write_report(save_dir, text, grid_polygon, polygon, overlay)
    print(report_path)
if __name__ == "__main__":
main()
|