| import os |
| import argparse |
| import torch |
|
|
| from resselt import load_from_file |
| from pepeline import read, save, ImgColor, ImgFormat |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the batch upscaler.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, exactly as before; passing an explicit list makes
            the function reusable from other code and from tests.

    Returns:
        An ``argparse.Namespace`` with ``input_dir``, ``output_dir`` and
        ``weights`` (all required strings) and ``device`` (``None`` when the
        option is omitted).
    """
    parser = argparse.ArgumentParser(
        description="Batch image upscaling script"
    )
    parser.add_argument("--input_dir", type=str, required=True, help="Path to input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save results")
    parser.add_argument("--weights", type=str, required=True, help="Path to model weights")
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Torch device such as cuda, cuda:1 or cpu (default: cuda if available, else cpu)",
    )
    return parser.parse_args(argv)
|
|
|
|
def load_model(weights_path: str, device: torch.device):
    """Build the network stored at ``weights_path`` and prepare it for inference.

    The object produced by resselt's ``load_from_file`` is moved to ``device``
    (keeping each parameter's memory format, with a non-blocking copy
    requested) and switched to evaluation mode before being returned.
    """
    network = load_from_file(weights_path)
    network = network.to(device, memory_format=torch.preserve_format, non_blocking=True)
    # `.eval()` returns the module itself, so it can be handed back directly.
    return network.eval()
|
|
|
|
def process_image(model, img_path: str, device: torch.device):
    """Run ``model`` on one image file and return the result as a NumPy array.

    The image is read as RGB float32 with pepeline, moved to ``device`` as a
    batch of one and passed through ``model`` under float16 autocast and
    inference mode.

    Args:
        model: Network called with a ``(1, C, H, W)`` tensor; assumed to return
            a 4-D tensor in the same N, C, H, W layout (TODO confirm for every
            architecture resselt can load).
        img_path: Path of the image file to read.
        device: Device the model lives on; the input tensor is moved there.

    Returns:
        A C-contiguous float32 array of shape ``(H', W', C)`` holding the first
        (and only) item of the model's output batch.
    """
    # pepeline presumably returns HWC; transpose to CHW for the network.
    img = read(img_path, ImgColor.RGB, ImgFormat.F32).transpose(2, 0, 1)
    img = (
        torch.tensor(img)
        .to(
            device,
            memory_format=torch.preserve_format,
            non_blocking=True,
        )
        .unsqueeze(0)
    )

    with torch.autocast(device.type, torch.float16):
        with torch.inference_mode():
            output = model(img)

    # Under float16 autocast the network output is typically float16. Cast it
    # back to float32 so the result matches the F32 format the input was read
    # in, and make it contiguous so callers get a plain C-ordered array rather
    # than a permuted view. No `.detach()` is needed: tensors created under
    # inference mode carry no autograd history.
    output = output[0].permute(1, 2, 0).float().contiguous()
    return output.cpu().numpy()
|
|
|
|
def main():
    """Upscale every file in ``--input_dir`` and write the results to ``--output_dir``.

    Each result is saved under the same file name as its input. Only regular
    files are processed (sub-directories are skipped), in sorted name order,
    and a single-line progress indicator is printed to stdout.
    """
    args = parse_args()

    device = torch.device(
        args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    os.makedirs(args.output_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order and includes directories
    # (e.g. the output directory when it lives inside the input directory),
    # which cannot be read as images. Keep only regular files, sorted.
    img_list = sorted(
        name
        for name in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, name))
    )
    total = len(img_list)

    # Load the model only once the input listing has succeeded.
    model = load_model(args.weights, device)

    prev_len = 0
    for index, img_name in enumerate(img_list, start=1):
        status = f"Processing {index}/{total} | {img_name}"
        # Pad to the previous status length so a shorter file name fully
        # overwrites the leftovers of a longer one on the same terminal line.
        print("\r" + status.ljust(prev_len), end="", flush=True)
        prev_len = len(status)

        img_path = os.path.join(args.input_dir, img_name)
        result = process_image(model, img_path, device)

        # `.copy()` hands `save` a fresh C-contiguous array; presumably `save`
        # requires a contiguous buffer (TODO confirm against pepeline).
        save(result.copy(), os.path.join(args.output_dir, img_name))

    print("\nDone.")
|
|
|
|
if __name__ == "__main__":
    # Entry point: run the batch job only when executed as a script, not when
    # this module is imported.
    main()
|
|