| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch.nn as nn |
| import pytorch_lightning as pl |
|
|
| from .geometry import index, orthogonal, perspective |
|
|
|
|
class BasePIFuNet(pl.LightningModule):
    """Base class for PIFu-style networks.

    Holds the loss module and the geometry helpers (``index`` and a
    projection function) and wires ``filter`` -> ``query`` together in
    ``forward``. In this class ``filter`` and ``query`` are placeholders
    that return ``None``; presumably concrete networks override them.
    """

    def __init__(
        self,
        projection_mode='orthogonal',
        error_term=None,
    ):
        """
        :param projection_mode:
            Either 'orthogonal' or 'perspective'.
            Selects the corresponding function for projection. Any value
            other than 'orthogonal' selects the perspective projection.
        :param error_term:
            nn Loss between the predicted [B, Res, N] and the label [B, Res, N].
            When None (the default), a new nn.MSELoss() is created for this
            instance.
        """
        super().__init__()
        self.name = 'base'

        # Build the default loss here rather than in the signature: a default
        # argument is evaluated once at definition time, so every model would
        # otherwise register the very same nn.Module object as a submodule.
        self.error_term = nn.MSELoss() if error_term is None else error_term

        self.index = index
        self.projection = orthogonal if projection_mode == 'orthogonal' else perspective

    def forward(self, points, images, calibs, transforms=None):
        """
        Run ``filter`` on the images, then ``query`` the resulting features
        at the given points.

        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        """
        features = self.filter(images)
        preds = self.query(features, points, calibs, transforms)
        return preds

    def filter(self, images):
        """
        Filter the input images and produce the image features.
        Placeholder in the base class: always returns None.

        :param images: [B, C, H, W] input images
        :return: None in this base implementation
        """
        return None

    def query(self, features, points, calibs, transforms=None):
        """
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed (see ``filter``) before this call.
        query() function may behave differently during training/testing.
        Placeholder in the base class: always returns None.

        :param features: image features as returned by ``filter``
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
            (None in this base implementation)
        """
        return None

    def get_error(self, preds, labels):
        """
        Compute the network loss by applying ``error_term`` to the
        predictions and the labels.

        :param preds: [B, Res, N] predictions for each point
        :param labels: [B, Res, N] gt labeling
        :return: loss term
        """
        return self.error_term(preds, labels)
|
|