cyrus-spc commited on
Commit
f73e37d
·
1 Parent(s): 166c54b

Adding test

Browse files
Files changed (2) hide show
  1. main.py +39 -0
  2. requirements.txt +18 -0
main.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
3
+ import gradio as gr
4
+
5
+ # Check if CUDA is available
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ print(f"Using device: {device}")
8
+
9
+ if device == "cuda":
10
+ torch.cuda.empty_cache()
11
+
12
+ model_id = "stabilityai/stable-diffusion-2-1"
13
+
14
+ # Use appropriate dtype based on device
15
+ if device == "cuda":
16
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
17
+ else:
18
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
19
+
20
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
21
+ pipe = pipe.to(device)
22
+
23
+ def generate_image(prompt, width, height):
24
+ image = pipe(prompt, width=int(width), height=int(height)).images[0]
25
+ return image
26
+
27
+ iface = gr.Interface(
28
+ fn=generate_image,
29
+ inputs=[
30
+ gr.Textbox(label="Prompt", value="a house in front of the ocean and a dog is running in the field"),
31
+ gr.Number(label="Width", value=1000),
32
+ gr.Number(label="Height", value=1000)
33
+ ],
34
+ outputs=gr.Image(type="pil"),
35
+ title="Stable Diffusion Image Generator"
36
+ )
37
+
38
+ if __name__ == "__main__":
39
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PyTorch with CUDA support (for GPU acceleration)
2
+ # Install from PyTorch official channel for CUDA 11.8 or 12.1
3
+ torch>=2.0.0
4
+ torchvision>=0.15.0
5
+ torchaudio>=2.0.0
6
+
7
+ # Diffusion models and related packages
8
+ diffusers
9
+ transformers
10
+ accelerate
11
+ scipy
12
+ safetensors
13
+
14
+ # Optional for better performance
15
+ xformers
16
+
17
+ # Gradio for web interface
18
+ gradio