from tempfile import NamedTemporaryFile

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
# Imported only for its side effect; presumably this is what makes the
# 'science' / 'no-latex' styles used below available to matplotlib.
import scienceplots
import torch
import torch.nn.functional as F
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.lines import Line2D


def resize(seq, size):
    '''
    Bilinearly resize every frame of a batch of single-channel sequences.

    seq:  tensor of shape (B, T, 1, H, W)
    size: target spatial size, forwarded to F.interpolate

    Returns a tensor of shape (B, T, 1, H', W') clamped to [0, 1].
    '''
    # Drop the channel dim so F.interpolate receives a 4-D (B, T, H, W) input.
    seq = F.interpolate(seq.squeeze(dim=2), size=size, mode='bilinear', align_corners=False)
    seq = seq.clamp(0, 1)
    return seq.unsqueeze(2)  # back to (B, T, 1, H, W)


# =======================================================================
# Utils in utils :)
# =======================================================================
def to_cpu_tensor(*args):
    '''
    Input arbitrary number of array/tensors, each will be converted to CPU torch.Tensor.

    numpy arrays are wrapped with torch.Tensor(...), tensors are moved to the
    CPU, and any other object is passed through untouched.

    Returns the single converted value when exactly one argument is given,
    otherwise a list with one entry per argument.
    '''
    out = []
    for tensor in args:
        # isinstance (rather than an exact type match) so that subclasses such
        # as torch.nn.Parameter are moved to the CPU as well.
        if isinstance(tensor, np.ndarray):
            tensor = torch.Tensor(tensor)
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.cpu()
        out.append(tensor)
    # single value input: return single value output
    if len(out) == 1:
        return out[0]
    return out


plt.style.use(['science', 'no-latex'])

# Colour per intensity bin and the bin boundaries (on a 0-255 scale) used by
# every plot in this module.
VIL_COLORS = [[0, 0, 0],
              [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
              [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
              [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
              [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
              [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
              [0.9607843137254902, 0.9607843137254902, 0.0],
              [0.9294117647058824, 0.6745098039215687, 0.0],
              [0.9411764705882353, 0.43137254901960786, 0.0],
              [0.6274509803921569, 0.0, 0.0],
              [0.9058823529411765, 0.0, 1.0]]

VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]


def _vil_imshow_kwargs():
    '''Build a fresh cmap/norm pair for imshow from VIL_COLORS and VIL_LEVELS.'''
    cmap = ListedColormap(VIL_COLORS)
    return {'cmap': cmap, 'norm': BoundaryNorm(VIL_LEVELS, cmap.N)}


def _new_temp_path(suffix):
    '''Create an empty, persistent temporary file and return its path.'''
    # The handle is closed before anyone writes to the path: re-opening a
    # still-open NamedTemporaryFile by name is not portable (fails on Windows).
    with NamedTemporaryFile(suffix=suffix, delete=False) as ff:
        return ff.name


# Visualize function with colorbar and a line to separate input and output
def gradio_visualize(sequence):
    '''
    Plot the frames of one sequence side by side and save them as a PNG.

    input: sequence, an iterable of numpy/torch frames (e.g. an array of shape
           (T, C, H, W)); each frame is squeezed, so C is assumed to be 1.
           Frame values are multiplied by 255 before colour mapping.

    Returns the path of a temporary .png file. The file is not deleted
    automatically; the caller owns it.

    Raises ValueError if the sequence is empty.
    '''
    n_frames = len(sequence)
    if n_frames == 0:
        raise ValueError('gradio_visualize needs at least one frame')

    fig_size = 3
    fig, axes = plt.subplots(1, n_frames, figsize=(fig_size*n_frames, fig_size), tight_layout=True)
    try:
        fig.subplots_adjust(hspace=0.0, wspace=0.0)  # tight layout
        # plt.subplots returns a bare Axes (not an array) for a single frame.
        axes = np.atleast_1d(axes)
        plt.setp(axes, xticks=[], yticks=[])

        for i, frame in enumerate(sequence):
            axes[i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=12)
            frame = frame.squeeze()
            im = axes[i].imshow(frame*255, **_vil_imshow_kwargs())

        # color bar (every panel shares the same colour scale, so the last
        # image is as good a source as any)
        cax = fig.add_axes([1, 0.05, 0.02, 0.5])
        fig.colorbar(im, cax=cax)

        # Save the figure to a temporary file
        file_path = _new_temp_path('.png')
        fig.savefig(file_path)
    finally:
        # It's important to close the figure to prevent memory leaks,
        # including when plotting or saving fails.
        plt.close(fig)

    return file_path


def gradio_gif(sequences, T):
    '''
    Animate several sequences side by side and save the result as a GIF.

    input: sequences, a dict mapping a display label to a numpy/torch array.
           Each array is indexed along its first axis by frame number and the
           frame is squeezed, so it is presumably (T, C, H, W) with C == 1
           (the original note said (B, T, C, H, W) -- TODO confirm with callers).
           Frame values are multiplied by 255 before colour mapping.
           T, the number of frames to render; every sequence must have at
           least T frames.

    Returns the path of a temporary .gif file. The file is not deleted
    automatically; the caller owns it.

    Raises ValueError if sequences is empty.
    '''
    horizontal = len(sequences)
    if horizontal == 0:
        raise ValueError('gradio_gif needs at least one sequence')

    fig_size = 3
    fig, axes = plt.subplots(nrows=1, ncols=horizontal, figsize=(fig_size*horizontal, fig_size), tight_layout=True)
    try:
        fig.subplots_adjust(hspace=0.0, wspace=0.0)  # tight layout
        # plt.subplots returns a bare Axes (not an array) for a single panel.
        axes = np.atleast_1d(axes)
        plt.setp(axes, xticks=[], yticks=[])

        # Draw frame 0 once and keep the image artists so that later frames
        # only swap the pixel data instead of stacking a new image per frame.
        images = []
        for ax, (key, sequence) in zip(axes, sequences.items()):
            ax.set_xlabel(f'{key}', fontsize=12)
            frame = sequence[0].squeeze()
            images.append(ax.imshow(frame*255, animated=True, **_vil_imshow_kwargs()))

        title = fig.suptitle('', y=0.9, x=0.505, fontsize=16)  # Initialize an empty super title

        def animate(t):
            for im, sequence in zip(images, sequences.values()):
                im.set_data(sequence[t].squeeze()*255)
            fig.subplots_adjust(hspace=0.0, wspace=0.0)  # tight layout
            title.set_text(f'$t + {t}$')  # update the title text
            # NOTE: only the images are returned. With blit=True matplotlib
            # flags the returned artists as animated, and a figure-level text
            # flagged that way would be left out of the saved frames.
            return images

        ani = animation.FuncAnimation(fig, animate, frames=T, interval=750, blit=True, repeat_delay=50)

        # Save the animation to a temporary file
        file_path = _new_temp_path('.gif')
        ani.save(file_path, writer='pillow', fps=5)
    finally:
        # Close this specific figure (not just the "current" one), including
        # when rendering or saving fails.
        plt.close(fig)

    return file_path