Diffusers
Safetensors
File size: 4,701 Bytes
1b7e5b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import os
import random

import torch
import torchvision
import torchvision.transforms as TS
from PIL import Image
from ram import inference_ram
from ram.models import ram
from tqdm import tqdm
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    CLIPTextModel,
    CLIPTokenizer,
)


torch.autograd.set_grad_enabled(False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Caption Generation script", add_help=False)
    parser.add_argument("--data_root", type=str, required=True, help="path to COCO")
    parser.add_argument("--save_root", type=str, required=True, help="path to save")
    parser.add_argument("--ram_checkpoint", type=str, required=True, help="path to save")
    args = parser.parse_args()

    # ram_checkpoint = '/root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth'
    # data_root = '/mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017'
    # save_root = '/root/gligen_data'
    box_threshold = 0.25
    text_threshold = 0.2

    import torch.distributed as dist

    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(local_rank)

    ram_model = ram(pretrained=args.ram_checkpoint, image_size=384, vit="swin_l").cuda().eval()
    ram_processor = TS.Compose(
        [TS.Resize((384, 384)), TS.ToTensor(), TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    )

    grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
    grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
        "IDEA-Research/grounding-dino-base"
    ).cuda()

    blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
    blip2_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xxl", torch_dtype=torch.float16
    ).cuda()

    clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").cuda()
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    image_paths = [os.path.join(args.data_root, x) for x in os.listdir(args.data_root)]
    random.shuffle(image_paths)

    for image_path in tqdm.tqdm(image_paths):
        pth_path = os.path.join(args.save_root, os.path.basename(image_path))
        if os.path.exists(pth_path):
            continue

        sample = {"file_path": os.path.basename(image_path), "annos": []}

        raw_image = Image.open(image_path).convert("RGB")

        res = inference_ram(ram_processor(raw_image).unsqueeze(0).cuda(), ram_model)

        text = res[0].replace(" |", ".")

        inputs = grounding_dino_processor(images=raw_image, text=text, return_tensors="pt")
        inputs = {k: v.cuda() for k, v in inputs.items()}
        outputs = grounding_dino_model(**inputs)

        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs["input_ids"],
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[raw_image.size[::-1]],
        )
        boxes = results[0]["boxes"]
        labels = results[0]["labels"]
        scores = results[0]["scores"]
        indices = torchvision.ops.nms(boxes, scores, 0.5)
        boxes = boxes[indices]
        category_names = [labels[i] for i in indices]

        for i, bbox in enumerate(boxes):
            bbox = bbox.tolist()
            inputs = blip2_processor(images=raw_image.crop(bbox), return_tensors="pt")
            inputs = {k: v.cuda().to(torch.float16) for k, v in inputs.items()}
            outputs = blip2_model.generate(**inputs)
            caption = blip2_processor.decode(outputs[0], skip_special_tokens=True)
            inputs = clip_tokenizer(
                caption,
                padding="max_length",
                max_length=clip_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.cuda() for k, v in inputs.items()}
            text_embeddings_before_projection = clip_text_encoder(**inputs).pooler_output.squeeze(0)

            sample["annos"].append(
                {
                    "caption": caption,
                    "bbox": bbox,
                    "text_embeddings_before_projection": text_embeddings_before_projection,
                }
            )
        torch.save(sample, pth_path)