# ShiftedBronzes / zest_code / make_transfer_dataset_train_val.py
# AnonymousUser20's picture
# Upload 944 files
# 3e426e9 verified
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel
from rembg import remove
from PIL import Image
import torch
from ip_adapter import IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images
from PIL import Image, ImageChops
from PIL import ImageEnhance
import numpy as np
import glob
import os
import copy
def get_preprocess_image(ip_model, target_image_path, texture_image_path, depth_image_path, target_image_save_root):
    """Generate one material-transferred image and save it as a PNG.

    The target image is passed through ``rembg.remove`` and thresholded into a
    binary mask. Inside the mask the target is replaced by its grayscale
    version, outside the mask the original colours are kept. That composite,
    the mask, the depth map and the texture exemplar are handed to
    ``ip_model.generate``; the first returned image is written to
    ``<target_image_save_root>/<class>/<target>AND<material>.png`` where
    ``<class>`` is the name of the directory that contains the target image.

    Args:
        ip_model: Object exposing ``generate(pil_image=, image=,
            control_image=, mask_image=, ...)``; this script passes an
            ``IPAdapterXL`` instance.
        target_image_path: Path of the image that provides the structure.
        texture_image_path: Path of the material exemplar image.
        depth_image_path: Path of the depth map that belongs to the target.
        target_image_save_root: Root directory for the generated images; it
            and the per-class sub-directory are created when missing.

    Returns:
        None. The only result is the PNG written to disk.
    """
    with Image.open(target_image_path) as source_image:
        target_image = source_image.convert('RGB')

    # Threshold the background-removed image: every channel value >= 1 becomes
    # 255, 0 stays 0. Presumably rembg leaves removed pixels at 0 - TODO confirm.
    rm_bg = remove(target_image)
    target_mask = rm_bg.convert("RGB").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')

    # The noise array itself is unused, but the draw is kept on purpose: it
    # advances the global NumPy RNG, which the seeded material sampling in the
    # calling loop depends on. Removing it would change which pairs are drawn.
    np.random.randint(0, 256, target_image.size + (3,), dtype=np.uint8)

    invert_target_mask = ImageChops.invert(target_mask)

    # Grayscale copy of the target. A factor of 1.0 keeps the original
    # brightness; raise it if the grayscale image turns out too dim.
    factor = 1.0
    gray_target_image = target_image.convert('L').convert('RGB')
    gray_target_image = ImageEnhance.Brightness(gray_target_image).enhance(factor)

    # darker() is a per-pixel minimum, lighter() a per-pixel maximum:
    # grayscale inside the mask, original colours outside of it.
    grayscale_img = ImageChops.darker(gray_target_image, target_mask)
    img_black_mask = ImageChops.darker(target_image, invert_target_mask)
    grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)

    # copy() loads the pixel data so the file can be closed right away.
    with Image.open(texture_image_path) as texture_file:
        ip_image = texture_file.copy()

    with Image.open(depth_image_path) as depth_file:
        np_image = np.array(depth_file)
    # NOTE(review): dividing by 256 assumes 16-bit depth values - confirm
    # against the depth-map producer.
    np_image = (np_image / 256).astype('uint8')
    depth_map = Image.fromarray(np_image).resize((1024, 1024))

    init_img = grayscale_init_img.resize((1024, 1024))
    mask = target_mask.resize((1024, 1024))

    num_samples = 1
    images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=30, seed=42)

    target_image_class = os.path.basename(os.path.dirname(target_image_path))
    target_image_name = os.path.splitext(os.path.basename(target_image_path))[0]
    material_name = os.path.splitext(os.path.basename(texture_image_path))[0]

    # makedirs also creates a missing save root and tolerates an existing
    # directory, unlike the exists()/mkdir() pair it replaces.
    target_save_path = os.path.join(target_image_save_root, target_image_class)
    os.makedirs(target_save_path, exist_ok=True)
    images[0].save(os.path.join(target_save_path, target_image_name + "AND" + material_name + ".png"))
def find_image_files(folder1, image_type=".png"):
    """Recursively collect the files under *folder1* that end with *image_type*.

    Args:
        folder1: Directory to walk; sub-directories are included.
        image_type: File-name suffix to match. The comparison is
            case-sensitive.

    Returns:
        A sorted list of paths, each built by joining the walked directory
        with the file name. The list is empty when nothing matches or when
        *folder1* does not exist.
    """
    matches = [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(folder1)
        for filename in filenames
        if filename.endswith(image_type)
    ]
    # os.walk yields entries in a filesystem-dependent order; sorting makes the
    # result reproducible, which the seeded random sampling in this script
    # relies on.
    matches.sort()
    return matches
def image_grid(imgs, rows, cols):
    """Paste *imgs* row by row onto a single RGB canvas of rows x cols tiles.

    The tile size is taken from the first image; the number of images must
    equal ``rows * cols``.

    Args:
        imgs: Sequence of PIL images.
        rows: Number of tile rows.
        cols: Number of tile columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols*tile_w, rows*tile_h))
    for position, tile in enumerate(imgs):
        # Row-major layout: the quotient is the row, the remainder the column.
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col*tile_w, row*tile_h))
    return canvas
# Local checkout that holds the image encoder and the IP-Adapter checkpoint.
root = "/home/zhourixin/OOD_Folder/CODE/Image-Transfer/zest_code/"

# Model identifiers and checkpoint locations.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
image_encoder_path = root + "models/image_encoder"
ip_ckpt = root + "sdxl_models/ip-adapter_sdxl_vit-h.bin"

device = "cuda"

# Seed NumPy's global RNG; the material sampling further down draws from it.
np.random.seed(2024)
torch.cuda.empty_cache()

# Depth ControlNet in half precision, moved to the GPU.
controlnet = ControlNetModel.from_pretrained(
    controlnet_path,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
controlnet = controlnet.to(device)

# SDXL inpainting pipeline driven by that ControlNet.
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=torch.float16,
    add_watermarker=False,
)
pipe = pipe.to(device)

# Hook the UNet's cross-attention, then wrap the pipeline with the IP-Adapter.
pipe.unet = register_cross_attention_hook(pipe.unet)
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
# Build the image lists used by the generation loop. The bare string below is
# the original section banner ("get the corresponding image lists").
"""
得到相应的image list
"""
# Bronze images, one directory per split.
bronze_root_train = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/bronze_ID_and_OOD/bronze2NotLine/images_split/train"
bronze_root_val = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/bronze_ID_and_OOD/bronze2NotLine/images_split/val"
bronze_root_test = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/bronze_ID_and_OOD/bronze2NotLine/images_split/test"
bronze_root_train_list, bronze_root_val_list, bronze_root_test_list = (
    find_image_files(split_root)
    for split_root in (bronze_root_train, bronze_root_val, bronze_root_test)
)
# Only train and val are combined here; the test list stays separate.
bronze_image_list = [*bronze_root_train_list, *bronze_root_val_list]

# Depth maps for the bronze images.
bronze_depth_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/depth_image/bronze_not_line_Ding_Gui"
bronze_depth_list = find_image_files(bronze_depth_root)

# Container images (stored as .JPEG) and their depth maps.
container_image_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/imagenet-21k-container-refine/images"
container_image_list = find_image_files(container_image_root, image_type=".JPEG")
container_depth_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/depth_image/container"
container_depth_list = find_image_files(container_depth_root)
# Generate images with bronze structure and container material. The bare
# string below is the original section banner; it says: 51024 container images
# and 5711 bronze images (train 2278 + val 572 = 2850, test 2861).
"""
制做青铜器structure,容器material的图像, 容器有51024张,青铜器有5711张(train 2278+val 572=2850, test 2861)
"""
result_save_root = r"/home/zhourixin/OOD_Folder/Dataset/OODdata/images_largescale/transfer_dataset/bronze_structure_container_material/train_val"
structure_list = copy.deepcopy(bronze_image_list)
structure_depth_list = copy.deepcopy(bronze_depth_list)
material_list = copy.deepcopy(container_image_list)
# material_list = material_list[0:int(len(material_list)/2)] # 25512 images
structure_list_np = np.array(structure_list)
structure_depth_list_np = np.array(structure_depth_list)
material_list_np = np.array(material_list)

if not structure_list:
    # Fail with a clear message instead of a ZeroDivisionError below.
    raise RuntimeError("no structure images found; nothing to generate")

# Spread the materials over the structures without reuse: every structure gets
# `sample_ratio` materials and the first `extra_material_count` structures get
# one more. For the counts in the banner (51024 / 2850) this is the same split
# as the previously hard-coded "last 276 structures get one fewer":
# 2574 structures x 18 + 276 structures x 17 = 51024.
sample_ratio, extra_material_count = divmod(len(material_list), len(structure_list))

for index, structure_image in enumerate(structure_list):
    structure_name = os.path.splitext(os.path.basename(structure_image))[0]
    structure_folder = os.path.basename(os.path.dirname(structure_image))
    # The depth map is looked up as <bronze_depth_root>/<class folder>/<name>.png.
    structure_depth_image = os.path.join(bronze_depth_root, structure_folder, structure_name + ".png")

    sample_size = sample_ratio + 1 if index < extra_material_count else sample_ratio
    indices = np.random.choice(len(material_list_np), size=sample_size, replace=False)
    material_list_selected = material_list_np[indices].tolist()
    # Remove the materials that were just drawn so that none is used twice.
    material_list_np = np.delete(material_list_np, indices)

    for select_material_image in material_list_selected:
        # Run the model for this (structure, material) pair.
        get_preprocess_image(ip_model, structure_image, select_material_image, structure_depth_image, result_save_root)

# The message now names the direction this loop actually produces (bronze
# structure, container material); the old text was copied from the disabled
# opposite-direction variant.
print("青铜器structure,容器material图像 生成完成")
# """
# First make the images with container structure and bronze material; there are 51024 container images and 5711 bronze images
# """
# result_save_root = r"E:\zhourixin\OOD_folder\DATA\transfer_dataset\container_structure_bronze_material"
# structure_list = copy.deepcopy(container_image_list)
# structure_depth_list = copy.deepcopy(container_depth_list)
# material_list = copy.deepcopy(bronze_image_list)
# structure_list_np = np.array(structure_list)
# structure_depth_list_np = np.array(structure_depth_list)
# sample_ratio = int(len(structure_list)/len(material_list))
# for index, material_image in enumerate(material_list):
# if index<len(material_list)-375:
# indices = np.random.choice(len(structure_list_np), size=sample_ratio+1, replace=False)
# else:
# indices = np.random.choice(len(structure_list_np), size=sample_ratio, replace=False)
# structure_list_selected = structure_list_np[indices].tolist()
# # Remove the images that have already been drawn
# structure_list_np = np.delete(structure_list_np, indices)
# for select_structure_image in structure_list_selected:
# structure_name = os.path.splitext(os.path.basename(select_structure_image))[0]
# structure_folder = select_structure_image.split("\\")[-2]
# select_structure_depth_image = os.path.join(container_depth_root, structure_folder, structure_name+".png")
# # # Run the model
# get_preprocess_image(ip_model, select_structure_image, material_image, select_structure_depth_image, result_save_root)
# print("容器structure,青铜器material图像 生成完成")
# textures = [tex.split('/')[-1].replace('.png', '') for tex in glob.glob(root+'demo_assets/material_exemplars/*.png')]
# objs = [obj.split('/')[-1].replace('.png', '') for obj in glob.glob(root+'demo_assets/input_imgs/*.png')]
# for texture in textures:
# for obj in objs:
# # target_image_path = 'demo_assets/input_imgs/' + obj + '.png' # Replace with your image path
# target_image_path = 'demo_assets/' + obj + '.png' # Replace with your image path
# target_image = Image.open(target_image_path).convert('RGB')
# rm_bg = remove(target_image)
# # output.save(output_path)
# target_mask = rm_bg.convert("RGB").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')# Convert mask to grayscale
# target_mask.save('demo_assets/temp_file/' + 'target_mask.png')
# # Ensure mask is the same size as image
# # mask = ImageChops.invert(mask)
# # Generate random noise for the size of the image
# noise = np.random.randint(0, 256, target_image.size + (3,), dtype=np.uint8)
# noise_image = Image.fromarray(noise)
# mask_target_img = ImageChops.lighter(target_image, target_mask)
# invert_target_mask = ImageChops.invert(target_mask)
# gray_target_image = target_image.convert('L').convert('RGB')
# gray_target_image = ImageEnhance.Brightness(gray_target_image)
# # Adjust brightness
# # The factor 1.0 means original brightness, greater than 1.0 makes the image brighter. Adjust this if the image is too dim
# factor = 1.0 # Try adjusting this to get the desired brightness
# gray_target_image = gray_target_image.enhance(factor)
# grayscale_img = ImageChops.darker(gray_target_image, target_mask)
# img_black_mask = ImageChops.darker(target_image, invert_target_mask)
# grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)
# init_img = grayscale_init_img
# init_img.save('demo_assets/temp_file/' + 'init_img.png')
# # ip_image = Image.open("demo_assets/material_exemplars/" + texture + ".png")
# ip_image = Image.open("demo_assets/" + texture + ".png")
# # np_image = np.array(Image.open('demo_assets/depths/' + obj + '.png'))
# # np_image = np.array(Image.open(r'demo_assets\depths\101001.png'))
# np_image = np.array(Image.open(r'demo_assets\depths\n02824058_184.png'))
# np_image = (np_image / 256).astype('uint8')
# depth_map = Image.fromarray(np_image).resize((1024,1024))
# init_img = init_img.resize((1024,1024))
# mask = target_mask.resize((1024, 1024))
# num_samples = 1
# images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=30, seed=42)
# # images[0].save('demo_assets/output_images/' + obj + '_' + texture + '.png' )
# images[0].save('demo_assets/output_images/' + 'result.png' )