cccode / misc.py
WayneW's picture
Upload folder using huggingface_hub
705a8fd verified
raw
history blame
8.44 kB
import yaml
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
# All images are center-cropped to a 4:3 aspect ratio in training.
IMAGE_ASPECT_RATIO = (4 / 3)

# Per-dataset settings, loaded once at import time. The path is relative to the
# current working directory. The encoding is pinned to UTF-8 so that parsing
# does not depend on the platform's default text encoding.
with open("config/data_config.yaml", "r", encoding="utf-8") as f:
    data_config = yaml.safe_load(f)
def get_action_torch(diffusion_output, action_stats):
    """Turn normalized per-step deltas into cumulative actions.

    The input is folded into pairs (``reshape(batch, -1, 2)``), mapped back
    from the normalized range with ``unnormalize_data`` and accumulated along
    dimension 1.

    Args:
        diffusion_output: tensor of normalized deltas; the first dimension is
            kept, everything else is regrouped into pairs.
        action_stats: statistics forwarded unchanged to ``unnormalize_data``.

    Returns:
        Tensor of cumulative sums with the dtype/device of the unnormalized
        deltas.
    """
    batch_size = diffusion_output.shape[0]
    pairs = diffusion_output.reshape(batch_size, -1, 2)
    deltas = unnormalize_data(pairs, action_stats)
    return torch.cumsum(deltas, dim=1).to(deltas)
def log_viz_single(dataset_name, obs_image, goal_image, preds, deltas, loss, min_idx, actions, action_stats, plan_iter=0, output_dir='plot.png'):
    """Visualize a single instance and save the rendering as a PNG.

    Args:
        dataset_name: key forwarded to ``plot_images_and_actions``.
        obs_image: normalized observation tensor; after un-normalizing, the
            last entry along the first dimension is displayed.
        goal_image: normalized goal image tensor.
        preds: unused; kept for interface compatibility.
        deltas: predicted normalized deltas; only the first two components of
            the last dimension are converted to actions.
        loss: per-sample loss tensor.
        min_idx: tensor holding the index of the minimum-loss sample.
        actions: ground-truth (gt) actions.
        action_stats: statistics used to un-normalize ``deltas``.
        plan_iter: unused; kept for interface compatibility.
        output_dir: despite the name, the path of the PNG file to write (it is
            passed straight to ``savefig``).
    """
    viz_obs_image = unnormalize(obs_image.detach().cpu())[-1]  # take last img
    viz_goal_image = unnormalize(goal_image.detach().cpu())
    deltas = deltas.detach().cpu()
    loss = loss.detach().cpu()
    actions = actions.detach().cpu()

    pred_actions = get_action_torch(deltas[:, :, :2], action_stats)
    plot_array = plot_images_and_actions(dataset_name, viz_obs_image, viz_goal_image, pred_actions, actions, min_idx, loss=loss)

    # Use a dedicated figure instead of pyplot's implicit "current" one, and
    # always close it: otherwise repeated calls keep drawing onto the same
    # never-released figure.
    fig, ax = plt.subplots()
    try:
        ax.imshow(plot_array)
        ax.axis('off')  # Hide axes for a cleaner image
        # Save the plot array as a PNG file locally
        fig.savefig(output_dir, format='png', dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
def _draw_sample(ax, waypoints, sample_loss, color, label):
    """Plot one predicted trajectory and annotate its last waypoint with its loss."""
    ax.plot(-waypoints[:, 1], waypoints[:, 0],
            color=color, marker="o", markersize=5, label=label)
    ax.text(-waypoints[-1, 1],
            waypoints[-1, 0],
            round(sample_loss.item(), 3),
            color='black',
            fontsize=10,
            ha='left', va='bottom')  # Adjust position to avoid overlap


def plot_images_and_actions(dataset_name, curr_viz_obs_image, curr_viz_goal_image, curr_viz_pred_actions, curr_viz_actions, min_idx, loss):
    """Render observation, goal and trajectories into one RGB image array.

    The figure has three panels: the condition image, the goal image, and a
    plot of the predicted trajectories (minimum-loss one in green) together
    with the ground-truth trajectory (blue). Each predicted trajectory is
    annotated with its loss rounded to three decimals.

    Args:
        dataset_name: key into ``data_config``; its
            ``metric_waypoint_spacing`` scales actions before plotting.
        curr_viz_obs_image: channel-first image tensor (permuted to
            channel-last for display).
        curr_viz_goal_image: channel-first image tensor.
        curr_viz_pred_actions: tensor indexed as ``[sample, step, coord]``.
        curr_viz_actions: ground-truth tensor indexed as ``[step, coord]``.
        min_idx: tensor whose ``.item()`` is the index of the best sample.
        loss: tensor of per-sample losses.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` with the rendered
        figure.
    """
    curr_viz_obs_image = curr_viz_obs_image.permute(1, 2, 0).cpu().numpy()
    curr_viz_goal_image = curr_viz_goal_image.permute(1, 2, 0).cpu().numpy()

    # scale back to metric space for plotting
    spacing = data_config[dataset_name]['metric_waypoint_spacing']
    curr_viz_pred_actions = curr_viz_pred_actions * spacing
    curr_viz_actions = curr_viz_actions * spacing

    best = min_idx.item()

    # Create the figure with three subplots
    fig, axs = plt.subplots(1, 3, figsize=(9, 3))
    try:
        # Plot condition image
        axs[0].imshow(curr_viz_obs_image)
        axs[0].set_title("Condition Image", fontsize=13)
        axs[0].axis("off")

        # Plot goal image
        axs[1].imshow(curr_viz_goal_image)
        axs[1].set_title("Goal Image", fontsize=13)
        axs[1].axis("off")

        colors = ['red', 'orange', 'cyan']
        # NOTE(review): the loop starts at 1, so sample 0 is drawn only when it
        # is the minimum-loss sample -- confirm this is intended.
        for i in range(1, curr_viz_pred_actions.shape[0]):
            if i == best:
                continue  # drawn separately below, in green
            _draw_sample(axs[2], curr_viz_pred_actions[i], loss[i],
                         colors[(i - 1) % len(colors)], f"{i}")

        # Highlight the minimum loss sample
        _draw_sample(axs[2], curr_viz_pred_actions[best], loss[best], 'green', f"{best}")

        # Plot ground truth actions
        axs[2].plot(-curr_viz_actions[:, 1], curr_viz_actions[:, 0], color='blue', marker="o", label="GT")

        # Set titles and labels with larger font size
        axs[2].set_title(" ", fontsize=13)
        axs[2].set_xlabel("X (m)", fontsize=11)
        axs[2].set_ylabel("Y (m)", fontsize=11)

        # Set equal aspect ratio and make the visible window square around the
        # centre of the current limits.
        axs[2].set_aspect('equal', adjustable='box')
        x_min, x_max = axs[2].get_xlim()
        y_min, y_max = axs[2].get_ylim()
        axis_range = max(x_max - x_min, y_max - y_min) / 2
        x_mid = (x_max + x_min) / 2
        y_mid = (y_max + y_min) / 2
        axs[2].set_xlim(x_mid - axis_range, x_mid + axis_range)
        axs[2].set_ylim(y_mid - axis_range, y_mid + axis_range)
        axs[2].legend(loc='lower left', fontsize=10, frameon=True, bbox_to_anchor=(0, 0))

        fig.tight_layout()

        canvas = FigureCanvas(fig)
        canvas.draw()
        # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later removed;
        # read the RGBA buffer instead and drop the alpha channel. The copy
        # detaches the result from the canvas buffer before the figure closes.
        plot_array = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure even if rendering fails part-way.
        plt.close(fig)
    return plot_array
def normalize_data(data, stats):
    """Linearly rescale ``data`` from ``[stats['min'], stats['max']]`` to ``[-1, 1]``.

    Args:
        data: value(s) to normalize.
        stats: mapping with ``'min'`` and ``'max'`` entries.

    Returns:
        ``data`` mapped so that ``stats['min']`` becomes -1 and
        ``stats['max']`` becomes 1.
    """
    lower, upper = stats['min'], stats['max']
    # First bring the data into [0, 1] ...
    unit_scaled = (data - lower) / (upper - lower)
    # ... then stretch and shift it to [-1, 1].
    return unit_scaled * 2 - 1
def unnormalize_data(ndata, stats):
    """Map values from ``[-1, 1]`` back to ``[stats['min'], stats['max']]``.

    Args:
        ndata: normalized tensor.
        stats: mapping with ``'min'`` and ``'max'`` tensors; both are moved to
            the dtype/device of the rescaled input via ``.to(...)``.

    Returns:
        Tensor in the original (un-normalized) range.
    """
    unit_scaled = (ndata + 1) / 2  # [-1, 1] -> [0, 1]
    lower = stats['min'].to(unit_scaled)
    upper = stats['max'].to(unit_scaled)
    return unit_scaled * (upper - lower) + lower
def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"):
    """Build the path ``<data_folder>/<f>/<time><ext>`` for one data item.

    Args:
        data_folder: root folder.
        f: sub-folder name inside ``data_folder``.
        time: value used (via ``str``) as the file stem.
        data_type: ``"image"`` (``.jpg``) or ``"audio"`` (``.wav``).

    Returns:
        The joined path as a string.

    Raises:
        KeyError: if ``data_type`` is not a known type.
    """
    extensions = {
        "image": ".jpg",
        "audio": ".wav",
        # add more data types here
    }
    file_name = str(time) + extensions[data_type]
    return os.path.join(data_folder, f, file_name)
def yaw_rotmat(yaw: float) -> np.ndarray:
    """Return the 3x3 rotation matrix for a rotation of ``yaw`` about the z axis.

    Args:
        yaw: rotation angle in radians.

    Returns:
        A ``(3, 3)`` array whose upper-left 2x2 block is the planar rotation.
    """
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    rows = [
        [cos_yaw, -sin_yaw, 0.0],
        [sin_yaw, cos_yaw, 0.0],
        [0.0, 0.0, 1.0],
    ]
    return np.array(rows)
def angle_difference(theta1, theta2):
    """Return ``theta2 - theta1`` wrapped into the interval ``[-pi, pi)``.

    Args:
        theta1: start angle(s) in radians.
        theta2: end angle(s) in radians.

    Returns:
        The signed difference after removing whole turns.
    """
    raw_diff = theta2 - theta1
    # Number of whole turns to remove so the result lands in [-pi, pi).
    whole_turns = np.floor((raw_diff + np.pi) / (2 * np.pi))
    return raw_diff - 2 * np.pi * whole_turns
def get_delta_np(actions):
    """Return per-step differences of an unbatched 2-D action array.

    A row of zeros is placed in front of ``actions`` before differencing, so
    the first output row equals the first action and the output has the same
    number of rows as the input.

    Args:
        actions: array of shape ``(steps, dims)``.

    Returns:
        Array of shape ``(steps, dims)`` with consecutive differences.
    """
    zero_row = np.zeros((1, actions.shape[1]))
    return np.diff(actions, axis=0, prepend=zero_row)
def to_local_coords(
    positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float
) -> np.ndarray:
    """
    Convert positions to local coordinates

    The positions are translated by ``-curr_pos`` and then right-multiplied by
    ``yaw_rotmat(curr_yaw)`` (for row vectors this applies the transpose of
    that matrix, i.e. a rotation by ``-curr_yaw``).

    Args:
        positions (np.ndarray): positions to convert; last dimension 2 or 3
        curr_pos (np.ndarray): current position
        curr_yaw (float): current yaw
    Returns:
        np.ndarray: positions in local coordinates
    Raises:
        ValueError: if the last dimension of ``positions`` is neither 2 nor 3
    """
    last_dim = positions.shape[-1]
    if last_dim not in (2, 3):
        raise ValueError(
            f"positions must have a last dimension of 2 or 3, got {last_dim}"
        )

    rotmat = yaw_rotmat(curr_yaw)
    if last_dim == 2:
        # Planar input: keep only the 2x2 in-plane rotation.
        rotmat = rotmat[:2, :2]

    return (positions - curr_pos).dot(rotmat)
def calculate_delta_yaw(unnorm_actions):
    """Compute step-to-step changes of the heading angle ``atan2(y, x)``.

    The heading of every action is taken from its first two components; a
    zero heading is placed in front along dimension 1 so the output keeps the
    same number of steps as the input.

    Args:
        unnorm_actions: tensor indexed as ``[batch, step, coord]`` with
            ``coord`` 0 = x and 1 = y.

    Returns:
        Tensor of shape ``(batch, steps, 1)`` with consecutive heading
        differences (not wrapped).
    """
    heading = torch.atan2(unnorm_actions[..., 1], unnorm_actions[..., 0]).unsqueeze(-1)
    zero_start = torch.zeros(heading.shape[0], 1, heading.shape[2]).to(heading.device)
    padded = torch.cat((zero_start, heading), dim=1)
    return padded[:, 1:, :] - padded[:, :-1, :]
def save_planning_pred(dataset_save_output_dir, B, idxs, obs_image, goal_image, preds, deltas, loss, gt_actions, plan_iter=0):
    """Write one ``preds_<plan_iter>.pth`` file per sample of a batch.

    For the sample at batch position ``k`` with id ``int(idxs.flatten()[k])``
    a folder ``id_<id>`` is created under ``dataset_save_output_dir`` and a
    dict with the ``k``-th entry of every batched input is stored there with
    ``torch.save``.

    Args:
        dataset_save_output_dir: root output directory.
        B: unused; kept for interface compatibility.
        idxs: sample ids; flattened and iterated in order.
        obs_image, goal_image, preds, deltas, loss, gt_actions: batched
            values, indexed by batch position.
        plan_iter: planning iteration used in the file name.
    """
    file_name = f"preds_{plan_iter}.pth"
    batched = {
        'obs_image': obs_image,
        'goal_image': goal_image,
        'preds': preds,
        'deltas': deltas,
        'loss': loss,
        'gt_actions': gt_actions,
    }
    for row, raw_idx in enumerate(idxs.flatten()):
        out_dir = os.path.join(dataset_save_output_dir, f'id_{int(raw_idx)}')
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): if these entries are tensor views, torch.save may
        # serialize the whole backing storage; consider .clone() -- confirm.
        record = {key: values[row] for key, values in batched.items()}
        torch.save(record, os.path.join(out_dir, file_name))
class CenterCropAR:
    """Callable transform that center-crops a PIL image to an aspect ratio.

    Landscape images (width > height) keep their full height and are cropped
    to a width of ``int(height * ar)``; all other images keep their full width
    and are cropped to a height of ``int(width / ar)``.
    """

    def __init__(self, ar: float = IMAGE_ASPECT_RATIO):
        # Target width / height ratio.
        self.ar = ar

    def __call__(self, img: Image.Image):
        width, height = img.size
        # NOTE(review): for a landscape image narrower than ``ar`` the
        # requested crop is wider than the image; torchvision's center_crop
        # presumably pads in that case -- confirm this is intended.
        if width > height:
            crop_size = (height, int(height * self.ar))
        else:
            crop_size = (int(width / self.ar), width)
        # center_crop takes the output size as (height, width).
        return TF.center_crop(img, crop_size)
# Preprocessing pipeline: aspect-ratio center crop, resize to 224x224, convert
# to a tensor, then normalize each of the three channels as (x - 0.5) / 0.5
# (which maps inputs in [0, 1] to [-1, 1]).
transform = transforms.Compose([
    CenterCropAR(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
# Inverse of a Normalize(mean=0.5, std=0.5) step: with mean = -1 and std = 2
# this computes (x + 1) / 2 per channel, mapping [-1, 1] back to [0, 1].
unnormalize = transforms.Normalize(
    mean=[-0.5 / 0.5, -0.5 / 0.5, -0.5 / 0.5],
    std=[1 / 0.5, 1 / 0.5, 1 / 0.5]
)