Fast Neural Style Transfer — Starry Night

This repository contains weights for a Fast Neural Style Transfer network based on Johnson et al. It is trained on the COCO val2017 dataset to instantly apply Vincent van Gogh's The Starry Night style to any input image.

Style Transfer Preview

Content Image Stylized Output

How to Use Programmatically

You can run inference using the official huggingface_hub library. The script automatically downloads the weights file from the Hugging Face Hub and applies the same ImageNet normalization that was used during training.

Dependencies

Ensure you have the required packages installed:

pip install torch torchvision pillow huggingface_hub

Inference Script (inference.py)

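Save the following code as inference.py, then run it from a terminal with python inference.py your_image.jpg. The input image is resized to 512×512 and the stylized result is written to output_styled.jpg in the current directory.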

import sys
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from huggingface_hub import hf_hub_download

# ── CONFIG ───────────────────────────────────────────────────
# Hugging Face Hub repository that hosts the trained weights.
REPO_ID = "Rohanify/Brawnz-StyleTransferSN"
# Name of the weights file inside that repository.
FILENAME = "pytorch_model.bin"
# Side length, in pixels, of the square the input image is resized to.
IMG_SIZE = 512
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# ─────────────────────────────────────────────────────────────

# ── NATIVE PYTORCH NETWORK DEFINITION ────────────────────────

def conv_bn_relu(in_c, out_c, k, stride=1, pad=0):
    """Build one reflection-padded convolution stage.

    Despite the ``bn`` in the name, the normalisation layer used here is
    ``nn.InstanceNorm2d``, not batch normalisation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        k: Convolution kernel size.
        stride: Convolution stride.
        pad: Reflection padding added to every border before the convolution.

    Returns:
        An ``nn.Sequential`` holding, in order: ``ReflectionPad2d``,
        ``Conv2d``, ``InstanceNorm2d`` and an in-place ``ReLU``.
    """
    # The position of each layer fixes its index inside the Sequential and
    # therefore its state-dict key, so the order must stay exactly like this.
    stage = [
        nn.ReflectionPad2d(pad),
        nn.Conv2d(in_c, out_c, kernel_size=k, stride=stride),
        nn.InstanceNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*stage)

class ResBlock(nn.Module):
    """Residual block made of two 3x3 reflection-padded convolutions.

    Each convolution is followed by ``InstanceNorm2d``; only the first one is
    followed by a ReLU. The block keeps the channel count and the spatial
    size, and its output is added to the unmodified input.

    Args:
        c: Number of input (and output) channels.
    """

    def __init__(self, c):
        super().__init__()
        # Sub-module indices 0-6 determine the state-dict keys
        # (block.1.* and block.5.*), so the layer order must not change.
        first_half = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, kernel_size=3),
            nn.InstanceNorm2d(c),
            nn.ReLU(inplace=True),
        ]
        second_half = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, kernel_size=3),
            nn.InstanceNorm2d(c),
        ]
        self.block = nn.Sequential(*first_half, *second_half)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

class TransformNet(nn.Module):
    """Feed-forward image transformation network.

    Layout of ``self.net``:
      * three convolution stages (3 -> 32 -> 64 -> 128 channels), the last
        two with stride 2, so the feature map is downsampled by 4 overall;
      * five residual blocks at 128 channels;
      * two nearest-neighbour 2x upsampling steps, each followed by a
        convolution stage (128 -> 64 -> 32 channels);
      * a reflection-padded 9x9 convolution back to 3 channels and a
        ``Tanh``, so output values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # The four groups are concatenated in this order into one flat
        # Sequential; the resulting indices (0-14) are the state-dict keys
        # the checkpoint is loaded against, so nothing may be reordered.
        encoder = [
            conv_bn_relu(3, 32, 9, pad=4),
            conv_bn_relu(32, 64, 3, stride=2, pad=1),
            conv_bn_relu(64, 128, 3, stride=2, pad=1),
        ]
        bottleneck = [ResBlock(128) for _ in range(5)]
        decoder = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            conv_bn_relu(128, 64, 3, pad=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            conv_bn_relu(64, 32, 3, pad=1),
        ]
        head = [
            nn.ReflectionPad2d(4),
            nn.Conv2d(32, 3, 9),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*encoder, *bottleneck, *decoder, *head)

    def forward(self, x):
        """Run ``x`` through the whole network and return the result."""
        stylized = self.net(x)
        return stylized

# ── LOAD INPUT IMAGE ─────────────────────────────────────────
# The input image path is the only command-line argument and is mandatory.
if not sys.argv[1:]:
    print("Usage: python inference.py path_to_input_image.jpg")
    sys.exit(1)

input_path, output_path = sys.argv[1], "output_styled.jpg"

# Preprocessing: resize to an IMG_SIZE x IMG_SIZE square (the aspect ratio is
# not preserved), convert to a tensor, then normalise each channel with the
# ImageNet mean and standard deviation.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Force three channels so the tensor matches the network's 3-channel input.
img = Image.open(input_path).convert("RGB")
# Add a leading batch dimension and move the tensor to the target device.
x = transform(img)
x = x.unsqueeze(0).to(DEVICE)

# ── SECURE FILE DOWNLOAD & STATE LOAD ────────────────────────
# Fetch the checkpoint from the Hugging Face Hub; the returned value is the
# local path that is handed to torch.load below.
print("Downloading weights from Hugging Face Hub...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

model = TransformNet().to(DEVICE)
# SECURITY: the checkpoint is a pickle-based file obtained over the network.
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered file cannot execute arbitrary code while being loaded.
# (The argument requires PyTorch >= 1.13.)
state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
# Put the network in evaluation mode before running inference.
model.eval()
print(f"Weights successfully loaded on: {DEVICE}")

# ── RUN INFERENCE ────────────────────────────────────────────
print("Processing style transfer...")
# Forward pass only, so gradient tracking is switched off.
with torch.no_grad():
    out = model(x)

# Take the first item of the batch and map it with v * 0.5 + 0.5 before
# writing it to disk (this sends the range [-1, 1] to [0, 1]).
stylized = out[0].mul(0.5).add(0.5)
save_image(stylized, output_path)
print(f"Success! Styled image saved to: {output_path}")
Downloads last month
88
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support