| from datasets import load_dataset |
| from diffusers import StableDiffusionPipeline |
| import torch |
| import os |
| import json |
| from PIL import Image |
| import argparse |
|
|
def parse_args(argv=None):
    """Parse the command-line options for image generation.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        argparse.Namespace with ``output_dir``, ``model_path``, ``seed``,
        ``prompt``, ``mode`` and ``num_train_images``.

    Raises:
        SystemExit: (from argparse) when a required option is missing or a
            value is invalid, e.g. ``--mode`` outside ``train``/``test``.
    """
    parser = argparse.ArgumentParser(description="Generate images from I2P dataset")

    # output_dir and prompt have no meaningful default: without them the
    # script would only fail later (os.makedirs(None) / pipe(prompt=None)),
    # possibly after loading the model, so reject the command line up front.
    parser.add_argument("--output_dir", type=str, help="Output directory", required=True)
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to model checkpoint",
        default="CompVis/stable-diffusion-v1-4",
    )
    parser.add_argument("--seed", type=int, help="Seed for random number generator", default=0)
    parser.add_argument("--prompt", type=str, help="Prompt for image generation", required=True)
    parser.add_argument(
        "--mode",
        type=str,
        help="Mode for image generation",
        choices=["train", "test"],
        default="train",
    )
    parser.add_argument(
        "--num_train_images",
        type=int,
        help="Number of images to generate for training",
        default=1000,
    )
    return parser.parse_args(argv)
|
|
def select_torch_dtype(device):
    """Return the dtype to load pipeline weights in for ``device``.

    float16 halves memory on CUDA, but half precision is poorly supported on
    CPU, so the CPU fallback loads the weights in float32 instead.
    """
    return torch.float16 if device == "cuda" else torch.float32


def safe_filename_component(text):
    """Return ``text`` with filesystem path separators replaced by ``_``.

    The prompt is embedded in output file names; an unescaped ``/`` (or ``\\``
    on Windows) would be interpreted as a directory, making the save fail or
    land outside the intended output directory.
    """
    for sep in (os.sep, os.altsep):
        if sep:  # os.altsep is None on POSIX
            text = text.replace(sep, "_")
    return text


def load_pipeline(model_path, device):
    """Load a Stable Diffusion pipeline from ``model_path`` onto ``device``.

    The safety checker is disabled (``safety_checker=None``).
    """
    return StableDiffusionPipeline.from_pretrained(
        model_path,
        safety_checker=None,
        torch_dtype=select_torch_dtype(device),
    ).to(device)


def generate_image(pipe, prompt, generator=None):
    """Generate one image for ``prompt``, retrying while it is flagged NSFW.

    NOTE: pipelines built by ``load_pipeline`` have no safety checker, and
    diffusers then reports ``nsfw_content_detected`` as ``None``, so the first
    image is always returned. If a checker is ever enabled, this loops (with
    no retry limit) until an unflagged image is produced.
    """
    while True:
        output = pipe(prompt=prompt, generator=generator)
        nsfw = output.nsfw_content_detected
        # One flag per generated image; a single image is requested here.
        if isinstance(nsfw, list):
            nsfw = nsfw[0]
        if not nsfw:
            return output.images[0]


def write_metadata(output_dir, metadata):
    """Write ``metadata`` to ``<output_dir>/metadata.jsonl``, one JSON object per line."""
    path = os.path.join(output_dir, "metadata.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(m) + "\n" for m in metadata)


def main():
    """Generate ``--num_train_images`` images for ``--prompt`` and save them.

    ``test`` mode writes ``<output_dir>/<prompt>_<i>.png``. ``train`` mode
    writes ``<output_dir>/train/<prompt>_<i>.png`` and a ``metadata.jsonl``
    pairing each file (path relative to ``output_dir``) with its prompt.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train = args.mode == "train"
    # makedirs also creates output_dir itself when image_dir is its child.
    image_dir = os.path.join(args.output_dir, "train") if train else args.output_dir
    os.makedirs(image_dir, exist_ok=True)

    pipe = load_pipeline(args.model_path, device)
    # One seeded generator makes the whole run reproducible from --seed while
    # still drawing different latents for each image.
    generator = torch.Generator(device=device).manual_seed(args.seed)

    name = safe_filename_component(args.prompt)
    metadata = []
    for i in range(args.num_train_images):
        file_name = f"{name}_{i}.png"
        image = generate_image(pipe, args.prompt, generator)
        image.save(os.path.join(image_dir, file_name))
        if train:
            metadata.append({"file_name": f"train/{file_name}", "text": args.prompt})

    if train:
        write_metadata(args.output_dir, metadata)


if __name__ == "__main__":
    main()
|
|
|
|
| |