| --- |
| license: mit |
| pipeline_tag: image-to-image |
| tags: |
| - style-transfer |
| - pytorch |
| --- |
| |
| # Fast Neural Style Transfer — Starry Night |
|
|
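| This repository contains weights for a Fast Neural Style Transfer network based on Johnson et al. (*Perceptual Losses for Real-Time Style Transfer and Super-Resolution*, 2016). The network was trained on the COCO val2017 dataset and applies the style of Vincent van Gogh's *The Starry Night* to any input image in a single forward pass. |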
|
|
| ## Style Transfer Preview |
|
|
|
|
| | Content Image | Stylized Output | |
| | :---: | :---: | |
| | <img src="before.jpg" width="400"> | <img src="after.jpg" width="400"> | |
|
|
| --- |
|
|
| ## How to Use Programmatically |
|
|
| The script below runs inference end to end. It downloads the weights file from the Hugging Face Hub with the official `huggingface_hub` library and applies the same ImageNet normalization that was used during training. |
|
|
| ### Dependencies |
| Ensure you have the required packages installed: |
| ```bash |
| pip install torch torchvision pillow huggingface_hub |
| ``` |
|
|
| ### Inference Script (`inference.py`) |
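| Save the following code as `inference.py` and run it from a terminal with `python inference.py your_image.jpg`. The input image is resized to 512×512 and the stylized result is written to `output_styled.jpg` in the current directory. |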
|
|
| ```python |
| import sys |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| from torchvision import transforms |
| from torchvision.utils import save_image |
| from huggingface_hub import hf_hub_download |
| |
| # ── CONFIG ─────────────────────────────────────────────────── |
| # Hugging Face Hub repository the weights are downloaded from. |
| REPO_ID = "Rohanify/Brawnz-StyleTransferSN" |
| # Name of the weights file inside that repository. |
| FILENAME = "pytorch_model.bin" |
| # The input image is resized to IMG_SIZE x IMG_SIZE before inference. |
| IMG_SIZE = 512 |
| # Run on the GPU when CUDA is available, otherwise on the CPU. |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| # ───────────────────────────────────────────────────────────── |
| |
| # ── NATIVE PYTORCH NETWORK DEFINITION ──────────────────────── |
| |
| def conv_bn_relu(in_c, out_c, k, stride=1, pad=0): |
| """Build a reflection-pad -> conv -> instance-norm -> ReLU block. |
| |
| Despite the "bn" in the name, the normalization layer is |
| ``nn.InstanceNorm2d``, not batch normalization. |
| |
| Args: |
| in_c: Number of input channels of the convolution. |
| out_c: Number of output channels of the convolution. |
| k: Convolution kernel size. |
| stride: Convolution stride. |
| pad: Reflection padding applied on every side before the convolution. |
| |
| Returns: |
| An ``nn.Sequential`` holding the four layers in that order. |
| """ |
| return nn.Sequential( |
| # Padding is done here by reflection; the Conv2d below is given no padding of its own. |
| nn.ReflectionPad2d(pad), |
| nn.Conv2d(in_c, out_c, k, stride), |
| nn.InstanceNorm2d(out_c), |
| nn.ReLU(inplace=True), |
| ) |
| |
| class ResBlock(nn.Module): |
| """Residual block that keeps the channel count ``c`` unchanged. |
| |
| The inner branch is two 3x3 convolutions, each preceded by reflection |
| padding of 1 and followed by instance normalization, with a ReLU after |
| the first one only. ``forward`` adds the branch output to its input. |
| """ |
| def __init__(self, c): |
| super().__init__() |
| self.block = nn.Sequential( |
| nn.ReflectionPad2d(1), |
| nn.Conv2d(c, c, 3), |
| nn.InstanceNorm2d(c), |
| nn.ReLU(inplace=True), |
| nn.ReflectionPad2d(1), |
| nn.Conv2d(c, c, 3), |
| # No activation after the second normalization; the skip connection is added next. |
| nn.InstanceNorm2d(c), |
| ) |
| def forward(self, x): |
| # Skip connection: input plus the output of the convolutional branch. |
| return x + self.block(x) |
| |
| class TransformNet(nn.Module): |
| """Feed-forward image transformation network used for stylization. |
| |
| Layout of ``self.net``: a 9x9 convolution (3 -> 32 channels), two |
| stride-2 convolutions (32 -> 64 -> 128), five ``ResBlock(128)`` blocks, |
| two stages of nearest-neighbour 2x upsampling followed by a 3x3 |
| convolution (128 -> 64 -> 32), and a final 9x9 convolution back to |
| 3 channels followed by ``Tanh``. |
| |
| The script loads the checkpoint with ``load_state_dict``, and |
| ``nn.Sequential`` names parameters by layer position, so the order and |
| nesting of these layers must stay as they are to match the published |
| weights. |
| """ |
| def __init__(self): |
| super().__init__() |
| self.net = nn.Sequential( |
| # Encoder: wide 9x9 stem, then two stride-2 convolutions. |
| conv_bn_relu(3, 32, 9, pad=4), |
| conv_bn_relu(32, 64, 3, stride=2, pad=1), |
| conv_bn_relu(64, 128, 3, stride=2, pad=1), |
| # Five residual blocks at 128 channels. |
| ResBlock(128), ResBlock(128), ResBlock(128), |
| ResBlock(128), ResBlock(128), |
| # Decoder: nearest-neighbour upsampling followed by a convolution, twice. |
| nn.Upsample(scale_factor=2, mode="nearest"), |
| conv_bn_relu(128, 64, 3, pad=1), |
| nn.Upsample(scale_factor=2, mode="nearest"), |
| conv_bn_relu(64, 32, 3, pad=1), |
| # Output head: 9x9 convolution to 3 channels, squashed by Tanh. |
| nn.ReflectionPad2d(4), |
| nn.Conv2d(32, 3, 9), |
| nn.Tanh(), |
| ) |
| def forward(self, x): |
| # Apply the whole pipeline to the input batch. |
| return self.net(x) |
| |
| # ββ LOAD INPUT IMAGE βββββββββββββββββββββββββββββββββββββββββ |
| if len(sys.argv) < 2: |
| print("Usage: python inference.py path_to_input_image.jpg") |
| sys.exit(1) |
| |
| input_path = sys.argv[1] |
| output_path = "output_styled.jpg" |
| |
| transform = transforms.Compose([ |
| transforms.Resize((IMG_SIZE, IMG_SIZE)), |
| transforms.ToTensor(), |
| transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
| ]) |
| |
| img = Image.open(input_path).convert("RGB") |
| x = transform(img).unsqueeze(0).to(DEVICE) |
| |
| # ββ SECURE FILE DOWNLOAD & STATE LOAD ββββββββββββββββββββββββ |
| print("Downloading weights from Hugging Face Hub...") |
| weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME) |
| |
| model = TransformNet().to(DEVICE) |
| model.load_state_dict(torch.load(weights_path, map_location=DEVICE)) |
| model.eval() |
| print(f"Weights successfully loaded on: {DEVICE}") |
| |
| # ββ RUN INFERENCE ββββββββββββββββββββββββββββββββββββββββββββ |
| print("Processing style transfer...") |
| with torch.no_grad(): |
| out = model(x) |
| |
| save_image(out[0] * 0.5 + 0.5, output_path) |
| print(f"Success! Styled image saved to: {output_path}") |
| |
| ``` |
|
|