Roman79 commited on
Commit
6c2301f
·
verified ·
1 Parent(s): 4e11df5

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +93 -0
  2. gen_a2b_fp16.pth +3 -0
  3. gen_b2a_fp16.pth +3 -0
  4. model.py +69 -0
  5. requirements.txt +5 -3
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms as T
6
+ from model import load_generator
7
+
8
+ st.set_page_config(
9
+ page_title="Summer ↔ Winter CycleGAN",
10
+ page_icon="🏔️",
11
+ layout="centered",
12
+ )
13
+
14
+ st.title("🏔️ Summer ↔ Winter Translation")
15
+ st.markdown(
16
+ "Upload a landscape photo and convert it between **summer** and **winter** "
17
+ "using a CycleGAN trained on Yosemite & Alpine datasets."
18
+ )
19
+
20
+ @st.cache_resource
21
+ def get_generators():
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ gen_a2b = load_generator("gen_a2b_fp16.pth", device) # лето → зима
24
+ gen_b2a = load_generator("gen_b2a_fp16.pth", device) # зима → лето
25
+ return gen_a2b, gen_b2a, device
26
+
27
+ with st.spinner("Loading model weights..."):
28
+ gen_a2b, gen_b2a, device = get_generators()
29
+
30
+ st.success(f"Model loaded on **{device}**", icon="✅")
31
+
32
+ MEAN = (0.5, 0.5, 0.5)
33
+ STD = (0.5, 0.5, 0.5)
34
+
35
+ to_tensor = T.Compose([
36
+ T.Resize((256, 256)),
37
+ T.ToTensor(),
38
+ T.Normalize(MEAN, STD),
39
+ ])
40
+
41
+ def to_pil(tensor):
42
+ img = tensor.squeeze(0).cpu().float()
43
+ for i, (m, s) in enumerate(zip(MEAN, STD)):
44
+ img[i] = img[i] * s + m
45
+ img = torch.clamp(img, 0, 1)
46
+ return T.ToPILImage()(img)
47
+
48
+ direction = st.radio(
49
+ "Translation direction",
50
+ ["☀️ Summer → ❄️ Winter", "❄️ Winter → ☀️ Summer"],
51
+ horizontal=True,
52
+ )
53
+
54
+ uploaded = st.file_uploader(
55
+ "Upload your landscape photo (JPG / PNG)",
56
+ type=["jpg", "jpeg", "png"],
57
+ )
58
+
59
+ if uploaded is not None:
60
+ input_img = Image.open(uploaded).convert("RGB")
61
+
62
+ col1, col2 = st.columns(2)
63
+ with col1:
64
+ st.subheader("Input")
65
+ st.image(input_img, use_container_width=True)
66
+
67
+ with st.spinner("Translating..."):
68
+ tensor = to_tensor(input_img).unsqueeze(0).to(device)
69
+ generator = gen_a2b if "Summer" in direction.split("→")[0] else gen_b2a
70
+ with torch.no_grad():
71
+ output_tensor = generator(tensor)
72
+ output_img = to_pil(output_tensor)
73
+
74
+ with col2:
75
+ st.subheader("Output")
76
+ st.image(output_img, use_container_width=True)
77
+
78
+ from io import BytesIO
79
+ buf = BytesIO()
80
+ output_img.save(buf, format="PNG")
81
+ st.download_button(
82
+ label="⬇️ Download result",
83
+ data=buf.getvalue(),
84
+ file_name="translated.png",
85
+ mime="image/png",
86
+ )
87
+
88
+ st.markdown("---")
89
+ st.markdown(
90
+ "**Model:** CycleGAN Generator (ResNet-6 blocks, 64 channels) · "
91
+ "**Training data:** Yosemite summer2winter · "
92
+ "**Test data:** Alpine landscapes (Unsplash) · "
93
+ )
gen_a2b_fp16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34cd9e43b8d797b91fe295fb6cb81569d88eb1246a2838e23681395e85035fc0
3
+ size 15686583
gen_b2a_fp16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ef6e8fa18dac09a8df1452e82616c64f77d476c472df794e082ed381abd6f00
3
+ size 15686583
model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class ResidualBlock(nn.Module):
5
+ def __init__(self, channels):
6
+ super().__init__()
7
+ self.block = nn.Sequential(
8
+ nn.ReflectionPad2d(1),
9
+ nn.Conv2d(channels, channels, 3),
10
+ nn.InstanceNorm2d(channels),
11
+ nn.ReLU(inplace=True),
12
+ nn.ReflectionPad2d(1),
13
+ nn.Conv2d(channels, channels, 3),
14
+ nn.InstanceNorm2d(channels),
15
+ )
16
+
17
+ def forward(self, x):
18
+ return x + self.block(x)
19
+
20
+ class ResNetGenerator(nn.Module):
21
+ def __init__(self, in_channels=3, out_channels=3, n_filters=64, n_res_blocks=6):
22
+ super().__init__()
23
+ model = [
24
+ nn.ReflectionPad2d(3),
25
+ nn.Conv2d(in_channels, n_filters, 7),
26
+ nn.InstanceNorm2d(n_filters),
27
+ nn.ReLU(inplace=True),
28
+
29
+ nn.Conv2d(n_filters, n_filters * 2, 3, stride=2, padding=1),
30
+ nn.InstanceNorm2d(n_filters * 2),
31
+ nn.ReLU(inplace=True),
32
+
33
+ nn.Conv2d(n_filters * 2, n_filters * 4, 3, stride=2, padding=1),
34
+ nn.InstanceNorm2d(n_filters * 4),
35
+ nn.ReLU(inplace=True),
36
+ ]
37
+
38
+ for _ in range(n_res_blocks):
39
+ model.append(ResidualBlock(n_filters * 4))
40
+
41
+ model += [
42
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
43
+ nn.Conv2d(n_filters * 4, n_filters * 2, 3, padding=1),
44
+ nn.InstanceNorm2d(n_filters * 2),
45
+ nn.ReLU(inplace=True),
46
+
47
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
48
+ nn.Conv2d(n_filters * 2, n_filters, 3, padding=1),
49
+ nn.InstanceNorm2d(n_filters),
50
+ nn.ReLU(inplace=True),
51
+
52
+ nn.ReflectionPad2d(3),
53
+ nn.Conv2d(n_filters, out_channels, 7),
54
+ nn.Tanh()
55
+ ]
56
+
57
+ self.model = nn.Sequential(*model)
58
+
59
+ def forward(self, x):
60
+ return self.model(x)
61
+
62
+ @torch.no_grad()
63
+ def load_generator(path, device="cpu"):
64
+ gen = ResNetGenerator()
65
+ state_dict = torch.load(path, map_location="cpu", weights_only=True)
66
+ state_dict = {k: v.float() for k, v in state_dict.items()}
67
+ gen.load_state_dict(state_dict)
68
+ gen.to(device).eval()
69
+ return gen
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
1
+ streamlit>=1.32.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ Pillow>=9.0.0
5
+ numpy>=1.24.0