| | from lightning import Callback |
| | import torch |
| | import matplotlib.pyplot as plt |
| | import os |
| | import numpy as np |
| | import torchvision |
| | from einops import rearrange |
| |
|
class VisualizationCallback(Callback):
    """Lightning callback that saves a visualization image every ``save_freq``
    global steps, on the global-rank-zero process only.

    Subclasses override :meth:`save_visualization` to plot real model output;
    the base implementation draws a placeholder line chart.
    """

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        """
        Args:
            save_freq: Save a visualization every this many global steps.
            output_dir: Directory PNG files are written to; created if missing.
        """
        # Cooperate with Lightning's Callback base (original skipped this).
        super().__init__()
        self.save_freq = save_freq
        self.output_dir = output_dir
        # exist_ok=True avoids the check-then-create race of the original
        # `if not os.path.exists(...): os.makedirs(...)` pattern.
        os.makedirs(self.output_dir, exist_ok=True)

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        """Trigger a save on rank zero at the configured frequency.

        Note: fires at global_step 0 as well, which doubles as a smoke test
        that visualization works before a long run.
        """
        if trainer.is_global_zero:
            global_step = trainer.global_step
            if global_step % self.save_freq == 0:
                self.save_visualization(trainer, model, global_step, batch)

    def save_visualization(self, trainer, model, global_step, batch):
        """Placeholder visualization; subclasses override with real plots."""
        fig, ax = plt.subplots()
        ax.plot([1, 2, 3], [4, 5, 6])
        ax.set_title(f"Visualization at Step {global_step}")
        # Save via the figure handle so we never depend on pyplot's notion
        # of the "current" figure.
        fig.savefig(f"{self.output_dir}/visualization_{global_step}.png")
        plt.close(fig)
        print(f"Saved visualization at step {global_step}")
| |
|
| |
|
class VisualizationVAECallback(VisualizationCallback):
    """Saves side-by-side ground-truth / reconstruction grids for a VAE-style
    model whose forward pass returns ``(x_pred, x_gt)`` for a batch.
    """

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__(save_freq, output_dir)

    def save_visualization(self, trainer, model, global_step, batch):
        """Run one inference pass on ``batch`` and save a 2-panel image grid.

        Assumes the model returns images roughly in [0, 1] (values are
        clamped) -- TODO confirm against the model's output range.
        """
        # BUGFIX: the original called model.eval() and never restored the
        # training flag, silently freezing dropout / batch-norm statistics
        # for the remainder of training. Remember and restore it.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x_pred, x_gt = model(batch)
        finally:
            if was_training:
                model.train()

        x_pred = torch.clamp(x_pred.cpu(), min=0.0, max=1.0)
        x_gt = torch.clamp(x_gt.cpu(), min=0.0, max=1.0)

        # Arrange the batch into a roughly square grid.
        B = x_gt.shape[0]
        rows = int(np.ceil(np.sqrt(B)))

        gt_grid = torchvision.utils.make_grid(x_gt, nrow=rows)
        pred_grid = torchvision.utils.make_grid(x_pred, nrow=rows)

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        axes[0].imshow(gt_grid.permute(1, 2, 0))
        axes[0].axis('off')
        axes[1].imshow(pred_grid.permute(1, 2, 0))
        axes[1].axis('off')

        plt.tight_layout()
        # BUGFIX: dropped the original plt.show() -- it blocks (or is a
        # no-op) in headless training and, with some backends, empties the
        # figure before savefig writes it to disk.
        fig.savefig(f"{self.output_dir}/image_grid_{global_step}.png")
        plt.close(fig)
| |
|
| | |
| |
|
| |
|
class Visualization_HeadAnimator_Callback(VisualizationCallback):
    """Saves reference / ground-truth / reconstruction / abs-difference grids
    for the head-animator model.
    """

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__(save_freq, output_dir)

    def save_visualization(self, trainer, model, global_step, batch):
        """Run the model on one batch and save a 4-panel image grid.

        Expects ``batch`` to hold 'pixel_values_vid' / 'pixel_values_vid_original'
        as (b, t, c, h, w) videos and 'pixel_values_ref_img' / 'ref_img_original'
        as (b, c, h, w) single frames -- TODO confirm against the datamodule.
        """
        masked_target_vid = batch['pixel_values_vid']
        masked_ref_img = batch['pixel_values_ref_img']
        ref_img_original = batch['ref_img_original']
        target_vid_original = batch['pixel_values_vid_original']

        # Broadcast the single reference frame across the time axis, then
        # fold batch and time into one leading dimension for the model.
        masked_ref_img = masked_ref_img[:, None].repeat(1, masked_target_vid.size(1), 1, 1, 1)
        masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w")
        masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w")

        ref_img_original = ref_img_original[:, None].repeat(1, target_vid_original.size(1), 1, 1, 1)
        ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w")
        target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w")

        # Evaluate in inference mode and restore the training flag afterwards,
        # consistent with VisualizationVAECallback (the original ran forward
        # in whatever mode training left the model in).
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                model_out = model.forward(
                    ref_img_original, target_vid_original,
                    masked_ref_img, masked_target_vid,
                )
        finally:
            if was_training:
                model.train()

        x_pred = model_out['recon_img'].cpu()
        x_gt = target_vid_original.cpu()
        x_ref = ref_img_original.cpu()

        # Heuristic: values below -0.5 imply [-1, 1] normalization; map to [0, 1].
        if x_gt.min() < -0.5:
            x_gt = (x_gt + 1) / 2
            x_pred = (x_pred + 1) / 2
            x_ref = (x_ref + 1) / 2

        x_pred = torch.clamp(x_pred, min=0.0, max=1.0)
        x_gt = torch.clamp(x_gt, min=0.0, max=1.0)
        x_ref = torch.clamp(x_ref, min=0.0, max=1.0)

        # Arrange the flattened (b*t) frames into a roughly square grid.
        B = x_gt.shape[0]
        rows = int(np.ceil(np.sqrt(B)))

        ref_grid = torchvision.utils.make_grid(x_ref, nrow=rows)
        gt_grid = torchvision.utils.make_grid(x_gt, nrow=rows)
        pred_grid = torchvision.utils.make_grid(x_pred, nrow=rows)
        diff_grid = torchvision.utils.make_grid((x_pred - x_gt).abs(), nrow=rows)

        fig, axes = plt.subplots(1, 4, figsize=(12, 6))
        for ax, grid in zip(axes[:3], (ref_grid, gt_grid, pred_grid)):
            ax.imshow(grid.permute(1, 2, 0))
            ax.axis('off')
        # NOTE(review): cmap is ignored by imshow for 3-channel input, so
        # 'jet' has no effect here; kept to preserve the original call shape.
        axes[3].imshow(diff_grid.permute(1, 2, 0), cmap='jet')
        axes[3].axis('off')

        plt.tight_layout()
        # BUGFIX: dropped the original plt.show() -- it blocks (or is a
        # no-op) in headless training and, with some backends, empties the
        # figure before savefig writes it to disk.
        fig.savefig(f"{self.output_dir}/image_grid_{global_step}.png")
        plt.close(fig)
| |
|
| |
|
| |
|
| | |