Dua Rajper commited on
Commit
ba3eabc
·
verified ·
1 Parent(s): d4efc1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -6,10 +6,15 @@ import streamlit as st
6
  # Load the Stable Diffusion pipeline
7
  @st.cache(allow_output_mutation=True)
8
  def load_pipeline():
9
- pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
10
- pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
11
  return pipeline
12
 
 
13
  def main():
14
  st.title("Stable Diffusion Image Generator")
15
  st.write("Generate images from text prompts using Stable Diffusion 2.1")
@@ -21,17 +26,19 @@ def main():
21
  prompt = st.text_input("Enter your text prompt", "")
22
 
23
  # Generate button
24
- if st.button("Generate"):
25
- if not prompt:
26
- st.warning("Please enter a prompt first.")
27
- return
28
-
29
- st.write("Generating your image...")
30
- with torch.no_grad():
31
- image = pipeline(prompt).images[0]
32
-
33
- st.write("Generated Image:")
34
- st.image(image, use_column_width=True)
 
 
35
 
36
  if __name__ == "__main__":
37
  main()
 
6
  # Load the Stable Diffusion pipeline
7
  @st.cache(allow_output_mutation=True)
8
  def load_pipeline():
9
+ pipeline = StableDiffusionPipeline.from_pretrained(
10
+ "CompVis/stable-diffusion-v1-4",
11
+ torch_dtype=torch.float32 # Change to float32 for CPU
12
+ )
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ pipeline.to(device)
15
  return pipeline
16
 
17
+
18
  def main():
19
  st.title("Stable Diffusion Image Generator")
20
  st.write("Generate images from text prompts using Stable Diffusion 2.1")
 
26
  prompt = st.text_input("Enter your text prompt", "")
27
 
28
  # Generate button
29
+ if st.button("Generate"):
30
+ if not prompt:
31
+ st.warning("Please enter a prompt first.")
32
+ return
33
+
34
+ st.write("Generating your image...")
35
+ with torch.no_grad():
36
+ result = pipeline(prompt) # Returns a dictionary
37
+ image = result.images[0] # Extract the first generated image
38
+
39
+ st.write("Generated Image:")
40
+ st.image(image, use_column_width=True)
41
+
42
 
43
  if __name__ == "__main__":
44
  main()