| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | """ |
| | The renderer is a module that takes in rays, decides where to sample along each |
| | ray, and computes pixel colors using the volume rendering equation. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from . import math_utils |
| |
|
| |
|
def generate_planes():
    """
    Build the axis matrices of the three feature planes.

    Each plane is described by a 3x3 matrix whose rows are the three axis
    vectors of that plane. The planes here are axis-aligned permutations of
    the identity, returned as a float32 tensor of shape (3, 3, 3).

    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
    """
    x_axis = [1, 0, 0]
    y_axis = [0, 1, 0]
    z_axis = [0, 0, 1]
    plane_axes = [
        [x_axis, y_axis, z_axis],
        [x_axis, z_axis, y_axis],
        # Third plane uses (z, y, x) ordering per the referenced bugfix.
        [z_axis, y_axis, x_axis],
    ]
    return torch.tensor(plane_axes, dtype=torch.float32)
| |
|
def project_onto_planes(planes, coordinates):
    """
    Project 3D points onto a batch of 2D planes, returning 2D plane coordinates.

    Args:
        planes: plane axis matrices of shape (n_planes, 3, 3).
        coordinates: points of shape (N, M, 3).

    Returns:
        Tensor of shape (N * n_planes, M, 2) with the dtype/device of
        ``coordinates``. Row ``b * n_planes + p`` holds the projections of
        batch element ``b`` onto plane ``p``.
    """
    N, M, _ = coordinates.shape
    n_planes, _, _ = planes.shape
    # Repeat every batch element once per plane: (N, M, 3) -> (N * n_planes, M, 3).
    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N * n_planes, M, 3)
    # Invert in the planes' own dtype, then match the points' dtype/device so
    # bmm does not fail on e.g. float64 inputs or a CPU/GPU mismatch.
    inv_planes = torch.linalg.inv(planes).to(coordinates)
    # Tile the per-plane inverses to line up with the batch-major layout above.
    inv_planes = inv_planes.unsqueeze(0).expand(N, -1, -1, -1).reshape(N * n_planes, 3, 3)
    # Row-vector points times inv(P) gives their coordinates in the basis P.
    projections = torch.bmm(coordinates, inv_planes)
    # Keep the first two components; the third is discarded.
    return projections[..., :2]
| |
|
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    """
    Sample per-plane features at the 2D projections of 3D points.

    Args:
        plane_axes: plane axis matrices of shape (n_planes, 3, 3), forwarded
            to ``project_onto_planes``.
        plane_features: feature planes of shape (N, n_planes, C, H, W).
        coordinates: query points of shape (N, M, 3).
        mode: interpolation mode forwarded to ``grid_sample``.
        padding_mode: only ``'zeros'`` is supported.
        box_warp: coordinates are scaled by ``2 / box_warp`` before sampling,
            i.e. ``[-box_warp/2, box_warp/2]`` maps onto grid_sample's
            ``[-1, 1]``. If None, coordinates are used unscaled (assumed to
            already be normalized).

    Returns:
        Tensor of shape (N, n_planes, M, C).

    Raises:
        ValueError: if ``padding_mode`` is not ``'zeros'``.
    """
    # Explicit check instead of assert so it survives `python -O`.
    if padding_mode != 'zeros':
        raise ValueError(f"only padding_mode='zeros' is supported, got {padding_mode!r}")
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (not view) so non-contiguous feature tensors are accepted too.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)
    dtype = plane_features.dtype

    # Previously a None default crashed with TypeError; treat it as "no scaling".
    if box_warp is not None:
        coordinates = (2 / box_warp) * coordinates

    # (N * n_planes, M, 2) -> (N * n_planes, 1, M, 2): grid_sample needs a 4D grid.
    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
    output_features = torch.nn.functional.grid_sample(
        plane_features,
        projected_coordinates.to(dtype),
        mode=mode,
        padding_mode=padding_mode,
        align_corners=False,
    )
    # grid_sample output is (N * n_planes, C, 1, M) -> (N, n_planes, M, C).
    return output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
| |
|
def sample_from_3dgrid(grid, coordinates):
    """
    Sample a 5D feature grid at a batch of 3D point locations.

    Args:
        grid: tensor of shape (B, channels, S0, S1, S2). A leading batch
            dimension of 1 is broadcast to the batch size of ``coordinates``;
            otherwise it must already equal that batch size.
        coordinates: tensor of shape (batch_size, num_points_per_batch, 3) in
            grid_sample's normalized [-1, 1] space. Out-of-range points read
            zeros.

    Returns:
        Tensor of shape (batch_size, num_points_per_batch, channels).
    """
    n_batch, n_points, n_dims = coordinates.shape
    expanded_grid = grid.expand(n_batch, -1, -1, -1, -1)
    # grid_sample wants a 5D sampling grid, so lay the points out as a
    # 1 x 1 x n_points "volume" per batch element.
    query = coordinates.reshape(n_batch, 1, 1, -1, n_dims)
    samples = torch.nn.functional.grid_sample(
        expanded_grid,
        query,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False,
    )
    # samples: (n_batch, channels, 1, 1, n_points) -> (n_batch, n_points, channels)
    out_batch, out_channels, d0, d1, d2 = samples.shape
    flat = samples.permute(0, 4, 3, 2, 1)
    return flat.reshape(out_batch, d0 * d1 * d2, out_channels)
| |
|