File size: 12,012 Bytes
3e426e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel
from rembg import remove
from PIL import Image
import torch
from ip_adapter import IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images
from PIL import Image, ImageChops
from PIL import ImageEnhance
import numpy as np
import glob
import os
import copy


def get_preprocess_image(ip_model, target_image_path, texture_image_path, depth_image_path, target_image_save_root):
    """Re-texture one target image with a material exemplar and save the result.

    The target image is run through ``rembg.remove`` to build a binary mask,
    the masked region is replaced by its grayscale version (the rest keeps
    its original colours), and that image is handed to ``ip_model.generate``
    together with the depth map, the mask and the material exemplar. The
    first generated image is written to
    ``<target_image_save_root>/<class>/<target>AND<material>.png``, where
    ``<class>`` is the name of the directory that contains the target image.

    Args:
        ip_model: Object exposing ``generate(pil_image=..., image=...,
            control_image=..., mask_image=..., ...)`` and returning a
            sequence of images that have a ``save`` method.
        target_image_path: Path of the image that provides the structure.
        texture_image_path: Path of the image that provides the material.
        depth_image_path: Path of the depth map belonging to the target image.
        target_image_save_root: Directory below which the per-class output
            folders are created (missing directories are created).

    Returns:
        None. The result is written to disk.
    """
    target_image = Image.open(target_image_path).convert('RGB')
    rm_bg = remove(target_image)
    # Per channel: 0 stays 0, every other value becomes 255; then collapse to
    # one grayscale channel and expand back to RGB so it can be combined with
    # RGB images below.
    target_mask = rm_bg.convert("RGB").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')
    invert_target_mask = ImageChops.invert(target_mask)

    # NOTE(review): the drawn values are not used, but the call is kept on
    # purpose. It advances NumPy's global RNG, which the seeded sampling loop
    # of this script also draws from, so removing it would change which
    # materials are sampled after the first call.
    np.random.randint(0, 256, target_image.size + (3,), dtype=np.uint8)

    # Brightness factor: 1.0 keeps the original brightness, values above 1.0
    # brighten the grayscale image. Raise it if the result is too dim.
    factor = 1.0
    gray_target_image = target_image.convert('L').convert('RGB')
    gray_target_image = ImageEnhance.Brightness(gray_target_image).enhance(factor)

    # Grayscale where the mask is white, black elsewhere.
    grayscale_img = ImageChops.darker(gray_target_image, target_mask)
    # Original colours where the mask is black, black elsewhere.
    img_black_mask = ImageChops.darker(target_image, invert_target_mask)
    # Union of both: grayscale inside the mask, original colours outside.
    grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)

    ip_image = Image.open(texture_image_path)
    with Image.open(depth_image_path) as depth_file:
        np_image = np.array(depth_file)
    # presumably the depth maps are 16-bit; dividing by 256 maps them to 8-bit
    # -- TODO confirm against the depth export.
    np_image = (np_image / 256).astype('uint8')
    depth_map = Image.fromarray(np_image).resize((1024, 1024))
    init_img = grayscale_init_img.resize((1024, 1024))
    mask = target_mask.resize((1024, 1024))

    num_samples = 1
    images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=30, seed=42)

    # The class is the name of the directory holding the target image;
    # os.path handles both separators and bare filenames.
    target_image_class = os.path.basename(os.path.dirname(target_image_path))
    target_image_name = os.path.splitext(os.path.basename(target_image_path))[0]
    material_name = os.path.splitext(os.path.basename(texture_image_path))[0]
    target_save_path = os.path.join(target_image_save_root, target_image_class)
    # makedirs also creates a missing save root and tolerates an existing folder.
    os.makedirs(target_save_path, exist_ok=True)
    images[0].save(os.path.join(target_save_path, target_image_name+"AND"+material_name+".png"))

def find_image_files(folder1, image_type=".png"):
    """Recursively collect files below ``folder1`` whose name ends with ``image_type``.

    Args:
        folder1: Directory to walk. A directory that does not exist yields an
            empty list, because ``os.walk`` ignores listing errors by default.
        image_type: File-name suffix to match. The comparison is
            case-sensitive (``".png"`` does not match ``".PNG"``).

    Returns:
        A sorted list of paths, each built by joining the walked directory
        (which starts with ``folder1``) and the file name.
    """
    png_files = []  # paths of all matching files found so far
    # Walk the whole directory tree below folder1.
    for root, _dirs, files in os.walk(folder1):
        png_files.extend(
            os.path.join(root, file) for file in files if file.endswith(image_type)
        )
    # os.walk yields entries in an arbitrary, filesystem-dependent order.
    # Sorting makes the result deterministic, so index-based sampling on it
    # gives the same result on every machine.
    png_files.sort()
    return png_files

def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row by row into one RGB image of ``rows`` x ``cols`` cells.

    Every cell has the size of the first image; image ``i`` is pasted into
    row ``i // cols`` and column ``i % cols``.

    Args:
        imgs: Sequence of images; its length must equal ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        The assembled grid image.
    """
    assert len(imgs) == rows*cols

    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))

    for position, tile in enumerate(imgs):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
# --- Model and checkpoint locations -----------------------------------------
root = "/home/zhourixin/OOD_Folder/CODE/Image-Transfer/zest_code/"
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = root+"models/image_encoder"
ip_ckpt = root+"sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
device = "cuda"
# Seed NumPy's global RNG, which the material sampling below draws from.
np.random.seed(2024)

torch.cuda.empty_cache()

# load SDXL pipeline
controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=torch.float16).to(device)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=torch.float16,
    add_watermarker=False,
).to(device)
pipe.unet = register_cross_attention_hook(pipe.unet)
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

# --- Collect the image lists ------------------------------------------------
bronze_root_train = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/bronze_ID_and_OOD/bronze2NotLine/images_split/train"
bronze_root_val = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/bronze_ID_and_OOD/bronze2NotLine/images_split/val"
bronze_root_test = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/bronze_ID_and_OOD/bronze2NotLine/images_split/test"
bronze_root_train_list = find_image_files(bronze_root_train)
bronze_root_val_list = find_image_files(bronze_root_val)
bronze_root_test_list = find_image_files(bronze_root_test)
# Only the train and val splits are used below; the test split is listed but
# not part of bronze_image_list.
bronze_image_list = bronze_root_train_list + bronze_root_val_list
bronze_depth_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/depth_image/bronze_not_line_Ding_Gui"
bronze_depth_list = find_image_files(bronze_depth_root)

container_image_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/imagenet-21k-container-refine/images"
container_image_list = find_image_files(container_image_root, image_type=".JPEG")
container_depth_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/depth_image/container"
container_depth_list = find_image_files(container_depth_root)

# --- Generate bronze-structure / container-material images -------------------
# Dataset sizes noted by the author: 51024 container images and 5711 bronze
# images (train 2278 + val 572 = 2850, test 2861).
result_save_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/transfer_dataset/bronze_structure_container_material/train_val"
structure_list = copy.deepcopy(bronze_image_list)
material_list = copy.deepcopy(container_image_list)

# Optionally restrict to the first half of the materials (25512 images):
# material_list = material_list[0:int(len(material_list)/2)]

if not structure_list:
    # Fail with a clear message instead of a ZeroDivisionError below.
    raise RuntimeError("No structure images found under the bronze train/val roots")

# Create the output root up front so per-class folders can be made inside it.
os.makedirs(result_save_root, exist_ok=True)

# Pool of materials that have not been assigned yet; it shrinks every iteration
# so that each material image is used at most once.
material_list_np = np.array(material_list)
# Spread the materials evenly: every structure gets `sample_ratio` materials
# and the first `remainder` structures get one more, which uses up the pool
# exactly. For 51024 materials and 2850 structures this is 17 each, with the
# first 2574 structures receiving 18.
sample_ratio, remainder = divmod(len(material_list), len(structure_list))

for index, structure_image in enumerate(structure_list):

    structure_name = os.path.splitext(os.path.basename(structure_image))[0]
    structure_folder = os.path.basename(os.path.dirname(structure_image))
    # The depth map is looked up under the same class folder and base name as
    # the structure image, with a .png extension.
    structure_depth_image = os.path.join(bronze_depth_root, structure_folder, structure_name+".png")

    draw_count = sample_ratio + 1 if index < remainder else sample_ratio
    if draw_count == 0:
        # Fewer materials than structures: nothing left for this one.
        continue
    indices = np.random.choice(len(material_list_np), size=draw_count, replace=False)

    material_list_selected = material_list_np[indices].tolist()
    # Remove the materials that were just drawn from the pool.
    material_list_np = np.delete(material_list_np, indices)

    for select_material_image in material_list_selected:
        # Run the model and save the generated image.
        get_preprocess_image(ip_model, structure_image, select_material_image, structure_depth_image, result_save_root)

print("青铜器structure,容器material图像 生成完成")



# """
# First generate the container-structure / bronze-material images: 51024 container images, 5711 bronze images
# """
# result_save_root = r"E:\zhourixin\OOD_folder\DATA\transfer_dataset\container_structure_bronze_material"
# structure_list = copy.deepcopy(container_image_list)
# structure_depth_list = copy.deepcopy(container_depth_list)
# material_list = copy.deepcopy(bronze_image_list)

# structure_list_np = np.array(structure_list)
# structure_depth_list_np  = np.array(structure_depth_list)
# sample_ratio = int(len(structure_list)/len(material_list))

# for index, material_image in enumerate(material_list):
    
#     if index<len(material_list)-375:
#         indices = np.random.choice(len(structure_list_np), size=sample_ratio+1, replace=False)
#     else:
#         indices = np.random.choice(len(structure_list_np), size=sample_ratio, replace=False)
#     structure_list_selected = structure_list_np[indices].tolist()
#     # 移除已经抽取的图像
#     structure_list_np = np.delete(structure_list_np, indices)


#     for select_structure_image in structure_list_selected:
#         structure_name = os.path.splitext(os.path.basename(select_structure_image))[0]
#         structure_folder = select_structure_image.split("\\")[-2]
#         select_structure_depth_image = os.path.join(container_depth_root, structure_folder, structure_name+".png")

#         # # 模型生成
#         get_preprocess_image(ip_model, select_structure_image, material_image, select_structure_depth_image, result_save_root)
# print("容器structure,青铜器material图像 生成完成")
    
    




# textures = [tex.split('/')[-1].replace('.png', '') for tex in glob.glob(root+'demo_assets/material_exemplars/*.png')]
# objs = [obj.split('/')[-1].replace('.png', '') for obj in glob.glob(root+'demo_assets/input_imgs/*.png')]

# for texture in textures:
#     for obj in objs:
#         # target_image_path = 'demo_assets/input_imgs/' + obj + '.png'  # Replace with your image path
#         target_image_path = 'demo_assets/' + obj + '.png'  # Replace with your image path
#         target_image = Image.open(target_image_path).convert('RGB')
#         rm_bg = remove(target_image)
#         # output.save(output_path)
#         target_mask = rm_bg.convert("RGB").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')# Convert mask to grayscale
#         target_mask.save('demo_assets/temp_file/' + 'target_mask.png')

#         # Ensure mask is the same size as image

#         # mask = ImageChops.invert(mask)
#         # Generate random noise for the size of the image
#         noise = np.random.randint(0, 256, target_image.size + (3,), dtype=np.uint8)
#         noise_image = Image.fromarray(noise)
#         mask_target_img = ImageChops.lighter(target_image, target_mask)
#         invert_target_mask = ImageChops.invert(target_mask)


#         gray_target_image = target_image.convert('L').convert('RGB')
#         gray_target_image = ImageEnhance.Brightness(gray_target_image)

#         # Adjust brightness
#         # The factor 1.0 means original brightness, greater than 1.0 makes the image brighter. Adjust this if the image is too dim
#         factor = 1.0  # Try adjusting this to get the desired brightness

#         gray_target_image = gray_target_image.enhance(factor)
#         grayscale_img = ImageChops.darker(gray_target_image, target_mask)
#         img_black_mask = ImageChops.darker(target_image, invert_target_mask)
#         grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)
#         init_img = grayscale_init_img
#         init_img.save('demo_assets/temp_file/' + 'init_img.png')

#         # ip_image = Image.open("demo_assets/material_exemplars/" + texture + ".png")
#         ip_image = Image.open("demo_assets/" + texture + ".png")
#         # np_image = np.array(Image.open('demo_assets/depths/' + obj + '.png'))
#         # np_image = np.array(Image.open(r'demo_assets\depths\101001.png'))
#         np_image = np.array(Image.open(r'demo_assets\depths\n02824058_184.png'))
        
#         np_image = (np_image / 256).astype('uint8')

#         depth_map = Image.fromarray(np_image).resize((1024,1024))

#         init_img = init_img.resize((1024,1024))
#         mask = target_mask.resize((1024, 1024))
        
#         num_samples = 1
#         images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=30, seed=42)


#         # images[0].save('demo_assets/output_images/' + obj + '_' + texture + '.png' )
#         images[0].save('demo_assets/output_images/' + 'result.png' )