# STLDM / utilspp.py
# Last commit dc3d7a9 by sqfoo: "Made Improvement"
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.lines import Line2D
import matplotlib.animation as animation
import scienceplots
def resize(seq, size):
    '''
    Spatially resize a batch of single-channel frame sequences.

    input: seq, a tensor of shape (B, T, 1, H, W); the singleton channel axis
        (dim 2) is removed so that T acts as the channel axis for 2-D
        bilinear interpolation.
    size: target spatial size, forwarded unchanged to F.interpolate.
    output: tensor of shape (B, T, 1, H', W') with values clamped to [0, 1].
    '''
    frames = torch.squeeze(seq, 2)  # (B, T, H, W)
    resized = F.interpolate(frames, size=size, mode='bilinear', align_corners=False)
    return torch.clamp(resized, 0, 1).unsqueeze(dim=2)  # (B, T, 1, H', W')
# =======================================================================
# Utils in utils :)
# =======================================================================
def to_cpu_tensor(*args):
    '''
    Convert an arbitrary number of arrays/tensors to CPU torch.Tensors.

    - numpy arrays (including ndarray subclasses such as masked arrays) are
      converted with torch.Tensor(...), which produces the default floating
      point dtype (float32 unless changed) whatever the array's dtype;
    - torch tensors (including subclasses such as nn.Parameter) are moved to
      the CPU with .cpu();
    - any other argument is passed through unchanged.

    output: the single converted value when called with exactly one argument,
        otherwise a list with one entry per argument (empty list for no args).
    '''
    out = []
    for tensor in args:
        if isinstance(tensor, np.ndarray):
            # np.asarray reduces ndarray subclasses to a plain ndarray (and is
            # a no-op for a plain ndarray), so torch.Tensor sees a base array.
            tensor = torch.Tensor(np.asarray(tensor))
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.cpu()
        out.append(tensor)
    # single value input: return single value output
    if len(out) == 1:
        return out[0]
    return out
from tempfile import NamedTemporaryFile
# Module-level side effect: applies this matplotlib style globally as soon as
# the module is imported. The 'science' styles come from the `scienceplots`
# import; 'no-latex' turns off LaTeX text rendering.
plt.style.use(['science', 'no-latex'])
# Discrete colormap as RGB triples in [0, 1], wrapped in ListedColormap by the
# plotting functions in this module. "VIL" is presumably Vertically Integrated
# Liquid (SEVIR-style colour scale) — TODO confirm.
VIL_COLORS = [[0, 0, 0],
[0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
[0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
[0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
[0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
[0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
[0.9607843137254902, 0.9607843137254902, 0.0],
[0.9294117647058824, 0.6745098039215687, 0.0],
[0.9411764705882353, 0.43137254901960786, 0.0],
[0.6274509803921569, 0.0, 0.0],
[0.9058823529411765, 0.0, 1.0]]
# Bin boundaries on a 0-255 scale, passed to BoundaryNorm; the plotting code
# multiplies frames by 255 before colour mapping.
# NOTE(review): 11 boundaries define 10 bins but VIL_COLORS has 11 entries, so
# BoundaryNorm(VIL_LEVELS, 11) spreads the 10 bins over the 11 colours rather
# than mapping them one-to-one — confirm this is intended.
VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
""" Visualize function with colorbar and a line seprate input and output """
def gradio_visualize(sequence):
    '''
    Draw a row of frames with the VIL colormap plus a colorbar; save as PNG.

    input: sequence, a list (or other sized iterable) of frames, e.g. a list of
        numpy/torch arrays or a single (T, C, H, W) array/tensor iterated
        along T. Each frame is squeezed, so singleton dims are dropped.
        Iterating a dict yields its keys, so pass frames, not a dict.
        Frames are multiplied by 255 before colour mapping, so they are
        presumably scaled to [0, 1] — TODO confirm with callers.
    output: path of a PNG file in the system temp directory. The file is
        created with delete=False, so the caller owns it and must remove it.
    raises: ValueError if sequence is empty.
    '''
    n_frames = len(sequence)
    if n_frames == 0:
        raise ValueError('gradio_visualize needs at least one frame')

    # Build the colormap/norm once rather than once per frame.
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    fig_size = 3
    # squeeze=False keeps `axes` a 2-D array even for a single frame, so
    # axes[0, i] always works (plt.subplots(1, 1) would return a bare Axes).
    fig, axes = plt.subplots(1, n_frames, figsize=(fig_size * n_frames, fig_size),
                             tight_layout=True, squeeze=False)
    try:
        fig.subplots_adjust(hspace=0.0, wspace=0.0)  # tight layout
        for i, frame in enumerate(sequence):
            ax = axes[0, i]
            ax.set_xticks([])
            ax.set_yticks([])
            # The last frame is labelled t-0, the first t-(T-1).
            ax.set_xlabel(f'$t-{(n_frames-i)-1}$', fontsize=12)
            if isinstance(frame, torch.Tensor):
                # matplotlib cannot read GPU or grad-tracking tensors directly.
                frame = frame.detach().cpu().numpy()
            frame = frame.squeeze()
            im = ax.imshow(frame * 255, cmap=cmap, norm=norm)

        # Colorbar axes starts at the figure's right edge (x=1), i.e. outside
        # the [0, 1] canvas; every image shares cmap/norm, so `im` from the
        # last frame is representative.
        cax = fig.add_axes([1, 0.05, 0.02, 0.5])
        fig.colorbar(im, cax=cax)

        # Reserve a temp file name, close the handle, then write by path.
        with NamedTemporaryFile(suffix=".png", delete=False) as ff:
            file_path = ff.name
        # bbox_inches='tight' grows the saved image to include the colorbar,
        # which would otherwise be clipped because it lies outside the canvas.
        fig.savefig(file_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)
    return file_path
def gradio_gif(sequences, T):
    '''
    Animate several frame sequences side by side and save the result as a GIF.

    input: sequences, a dict (anything with .items()/.values()) mapping a
        panel label, shown under each panel, to a numpy/torch array indexed
        by time along its FIRST axis, e.g. (T, C, H, W) or (T, H, W). Each
        frame is squeezed, so singleton dims are dropped. Frames are
        multiplied by 255 before colour mapping, so they are presumably
        scaled to [0, 1] — TODO confirm with callers.
    T: number of frames to render (t = 0 .. T-1); every sequence needs at
        least T entries along its first axis, otherwise indexing raises.
    output: path of a GIF file in the system temp directory. The file is
        created with delete=False, so the caller owns it and must remove it.
    raises: ValueError if sequences is empty.
    '''
    if len(sequences) == 0:
        raise ValueError('gradio_gif needs at least one sequence')

    # Build the colormap/norm once; every panel shares them.
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    def frame_at(sequence, t):
        # Frame t of one sequence, squeezed and scaled onto the 0-255 levels.
        frame = sequence[t]
        if isinstance(frame, torch.Tensor):
            # matplotlib cannot read GPU or grad-tracking tensors directly.
            frame = frame.detach().cpu().numpy()
        return frame.squeeze() * 255

    horizontal = len(sequences)
    fig_size = 3
    # squeeze=False gives a 2-D axes array for any panel count, so one code
    # path handles both the single-panel and multi-panel cases.
    fig, axes = plt.subplots(nrows=1, ncols=horizontal,
                             figsize=(fig_size * horizontal, fig_size),
                             tight_layout=True, squeeze=False)
    try:
        fig.subplots_adjust(hspace=0.0, wspace=0.0)  # tight layout
        images = []
        for ax, (key, sequence) in zip(axes[0], sequences.items()):
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel(f'{key}', fontsize=12)
            images.append(ax.imshow(frame_at(sequence, 0), cmap=cmap, norm=norm))
        title = fig.suptitle('', y=0.9, x=0.505, fontsize=16)  # Initialize an empty super title

        def animate(t):
            # Update the existing images in place; calling imshow per frame
            # would stack a new image artist on every axes each frame.
            for im, sequence in zip(images, sequences.values()):
                im.set_data(frame_at(sequence, t))
            title.set_text(f'$t + {t}$')  # update the title text
            return (*images, title)

        # Animation.save always redraws full frames, so blitting brings no
        # benefit here; fps in save() sets the GIF frame rate.
        ani = animation.FuncAnimation(fig, animate, frames=T, interval=750,
                                      blit=False, repeat_delay=50)
        # Reserve a temp file name, close the handle, then let the writer
        # open the path itself.
        with NamedTemporaryFile(suffix=".gif", delete=False) as ff:
            file_path = ff.name
        ani.save(file_path, writer='pillow', fps=5)
    finally:
        # Close this specific figure (plt.close() only closes the current one).
        plt.close(fig)
    return file_path