| from collections import defaultdict |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| @torch.no_grad() |
| def log_sample_res( |
| text_encoder, vision_encoder, rdt, args, |
| accelerator, weight_dtype, dataset_id2name, dataloader, logger |
| ): |
| logger.info( |
| f"Running sampling for {args.num_sample_batches} batches..." |
| ) |
|
|
| rdt.eval() |
| |
| loss_for_log = defaultdict(float) |
| loss_counter = defaultdict(int) |
| for step, batch in enumerate(dataloader): |
| if step >= args.num_sample_batches: |
| break |
| |
| data_indices = batch["data_indices"] |
| ctrl_freqs = batch["ctrl_freqs"] |
| state_norm = batch["state_norm"].to(dtype=weight_dtype) |
| images = batch["images"].to(dtype=weight_dtype) |
| states = batch["states"].to(dtype=weight_dtype) |
| |
| states = states[:, -1:, :] |
| actions = batch["actions"].to(dtype=weight_dtype) |
| state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype) |
| |
| batch_size, _, C, H, W = images.shape |
| image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach() |
| image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size)) |
| |
| lang_attn_mask = batch["lang_attn_mask"] |
| text_embeds = batch["lang_embeds"].to(dtype=weight_dtype) \ |
| if args.precomp_lang_embed \ |
| else text_encoder( |
| input_ids=batch["input_ids"], |
| attention_mask=lang_attn_mask |
| )["last_hidden_state"].detach() |
| |
| with torch.autocast(device_type='cuda',dtype=torch.bfloat16): |
| pred_actions = rdt.predict_action( |
| lang_tokens=text_embeds, |
| lang_attn_mask=lang_attn_mask, |
| img_tokens=image_embeds, |
| state_tokens=states, |
| action_mask=state_elem_mask.unsqueeze(1), |
| ctrl_freqs=ctrl_freqs |
| ) |
| |
| num_steps = pred_actions.shape[1] |
| expanded_state_elem_mask = state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float() |
| expanded_state_norm = state_norm.unsqueeze(1).tile((1, num_steps, 1)).float() |
| |
| loss = F.mse_loss(pred_actions, actions, reduction='none').float() |
| |
| mse_loss_per_entry = ((loss * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1) |
| / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)) |
| l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3) |
| l2_loss_per_entry = ((l2_loss_per_entry * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1) |
| / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)) |
|
|
| dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics( |
| (torch.LongTensor(data_indices).to(device=pred_actions.device), |
| mse_loss_per_entry, l2_loss_per_entry), |
| ) |
| dataset_indices = dataset_indices.tolist() |
| if accelerator.is_main_process: |
| for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]): |
| for dataset_idx, loss_tensor in zip(dataset_indices, losses): |
| loss_name = dataset_id2name[dataset_idx] + loss_suffix |
| loss_for_log[loss_name] += loss_tensor.item() |
| loss_counter[loss_name] += 1 |
| |
| mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum() |
| mse_loss_scaler = accelerator.gather(mse_loss).mean().item() |
| loss_for_log["overall_avg_sample_mse"] += mse_loss_scaler |
| |
| l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3) |
| l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum() |
| l2_loss_scaler = accelerator.gather(l2_loss).mean().item() |
| loss_for_log["overall_avg_sample_l2err"] += l2_loss_scaler |
|
|
| for name in loss_for_log: |
| if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]: |
| loss_scaler = loss_for_log[name] |
| loss_for_log[name] = round(loss_scaler / (args.num_sample_batches), 4) |
| else: |
| loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4) |
| |
| rdt.train() |
| torch.cuda.empty_cache() |
|
|
| return dict(loss_for_log) |