| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import torch |
| | import torch.nn as nn |
| | from functools import partialmethod |
| | from typing import Union, List |
| |
|
| |
|
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int], None]):
        """
        Args:
            r:
                Dropout rate, passed straight to nn.Dropout
            batch_dim:
                Dimension(s) along which the dropout mask is shared. A
                single int is wrapped in a list; any other iterable of
                ints is copied into a new list so later changes to the
                caller's object do not affect this module. None means the
                mask is not shared along any dimension.
        """
        super().__init__()

        self.r = r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        elif batch_dim is not None:
            # Copy so the module never aliases a caller-owned list.
            batch_dim = list(batch_dim)
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim (every entry must be a valid
                index into x.shape)
        Returns:
            x itself when not in training mode; otherwise a new tensor equal
            to x times a dropout mask whose size is 1 along each dimension
            in self.batch_dim, so the same mask is broadcast along those
            dimensions. The input tensor is never modified.
        """
        # nn.Dropout is the identity outside training, so the mask would be
        # all ones; skip building it.
        if not self.training:
            return x

        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        # new_ones inherits x's dtype and device.
        mask = self.dropout(x.new_ones(shape))
        # Multiply out of place: an in-place `x *= mask` would mutate the
        # caller's tensor and raises under autograd for leaf tensors that
        # require grad.
        return x * mask
| |
|
| |
|
class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.

    Behaves like Dropout with batch_dim=-3: a single dropout mask is drawn
    and shared along dimension -3 of the input.
    """

    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -3):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Keyword-only override of the dimension(s) along which the
                mask is shared; defaults to -3
        """
        Dropout.__init__(self, r, batch_dim=batch_dim)
| |
|
| |
|
class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.

    Behaves like Dropout with batch_dim=-2: a single dropout mask is drawn
    and shared along dimension -2 of the input.
    """

    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -2):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Keyword-only override of the dimension(s) along which the
                mask is shared; defaults to -2
        """
        Dropout.__init__(self, r, batch_dim=batch_dim)
| |
|