from abc import ABCMeta, abstractmethod

import ase
import torch
import torch.nn as nn
from torch_scatter import scatter

from visnet.models.utils import act_class_mapping
|
|
__all__ = [
    "Scalar",
    "EquivariantScalar",
    "DipoleMoment",
    "EquivariantDipoleMoment",
    "ElectronicSpatialExtent",
    "EquivariantElectronicSpatialExtent",
    "EquivariantVectorOutput",
]
|
|
|
|
| class GatedEquivariantBlock(nn.Module): |
| """ |
| Gated Equivariant Block as defined in Schütt et al. (2021): |
| Equivariant message passing for the prediction of tensorial properties and molecular spectra |
| """ |
    def __init__(
        self,
        hidden_channels,
        out_channels,
        intermediate_channels=None,
        activation="silu",
        scalar_activation=False,
    ):
        super().__init__()
        self.out_channels = out_channels

        if intermediate_channels is None:
            intermediate_channels = hidden_channels

        self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
        self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False)

        act_class = act_class_mapping[activation]
        self.update_net = nn.Sequential(
            nn.Linear(hidden_channels * 2, intermediate_channels),
            act_class(),
            nn.Linear(intermediate_channels, out_channels * 2),
        )

        self.act = act_class() if scalar_activation else None
| |
| def reset_parameters(self): |
| nn.init.xavier_uniform_(self.vec1_proj.weight) |
| nn.init.xavier_uniform_(self.vec2_proj.weight) |
| nn.init.xavier_uniform_(self.update_net[0].weight) |
| self.update_net[0].bias.data.fill_(0) |
| nn.init.xavier_uniform_(self.update_net[2].weight) |
| self.update_net[2].bias.data.fill_(0) |
| |
| def forward(self, x, v): |
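        # x: per-atom scalar features, shape (num_atoms, hidden_channels).
        # v: per-atom vector features, shape (num_atoms, 3 or 8, hidden_channels).
        # Rotation-invariant norms of the projected vectors are mixed with x,
        # and half of the MLP output gates the projected vectors channel-wise.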
        vec1 = torch.norm(self.vec1_proj(v), dim=-2)
        vec2 = self.vec2_proj(v)

        x = torch.cat([x, vec1], dim=-1)
        x, v = torch.split(self.update_net(x), self.out_channels, dim=-1)
        v = v.unsqueeze(1) * vec2

        if self.act is not None:
            x = self.act(x)
        return x, v
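
# A minimal shape sketch for GatedEquivariantBlock (illustrative only; the
# sizes below are arbitrary):
#
#     block = GatedEquivariantBlock(hidden_channels=64, out_channels=32)
#     x = torch.randn(10, 64)      # scalar features
#     v = torch.randn(10, 3, 64)   # l=1 vector features
#     x_out, v_out = block(x, v)   # shapes: (10, 32) and (10, 3, 32)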
|
|
|
|
class OutputModel(nn.Module, metaclass=ABCMeta):
    """Abstract base class for output heads.

    `pre_reduce` maps per-atom features to per-atom predictions; `post_reduce`
    is applied to the result after aggregation over the atoms of a molecule.
    """

    def __init__(self, allow_prior_model):
        super().__init__()
        self.allow_prior_model = allow_prior_model

    def reset_parameters(self):
        pass

    @abstractmethod
    def pre_reduce(self, x, v, z, pos, batch):
        return

    def post_reduce(self, x):
        return x
|
|
|
|
| class Scalar(OutputModel): |
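    """MLP head that maps per-atom scalar features to one scalar per atom."""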
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
        super().__init__(allow_prior_model=allow_prior_model)
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            act_class(),
            nn.Linear(hidden_channels // 2, 1),
        )

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def pre_reduce(self, x, v, z, pos, batch):
        return self.output_network(x)
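
# Typical use of an output head (illustrative; the `head` variable and the
# surrounding pipeline are assumptions, not fixed by this module):
#
#     out = head.pre_reduce(x, v, z, pos, batch)        # per-atom values
#     out = scatter(out, batch, dim=0, reduce="sum")    # per-molecule sums
#     out = head.post_reduce(out)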
|
|
|
|
| class EquivariantScalar(OutputModel): |
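    """Scalar head built from two stacked GatedEquivariantBlocks, so the
    vector features also inform the scalar prediction."""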
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
        super().__init__(allow_prior_model=allow_prior_model)
        self.output_network = nn.ModuleList([
            GatedEquivariantBlock(
                hidden_channels,
                hidden_channels // 2,
                activation=activation,
                scalar_activation=True,
            ),
            GatedEquivariantBlock(
                hidden_channels // 2,
                1,
                activation=activation,
                scalar_activation=False,
            ),
        ])

        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.output_network:
            layer.reset_parameters()

    def pre_reduce(self, x, v, z, pos, batch):
        for layer in self.output_network:
            x, v = layer(x, v)

        # Zero-weighted v term keeps the vector channels in the autograd
        # graph (avoids unused-parameter errors, e.g. under DDP).
        return x + v.sum() * 0
|
|
|
|
| class DipoleMoment(Scalar): |
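    """Dipole moment head: the per-atom scalars act as charges q_i, the dipole
    is sum_i q_i * (r_i - r_cm), and `post_reduce` returns its magnitude."""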
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super().__init__(hidden_channels, activation, allow_prior_model=allow_prior_model)
        atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
        self.register_buffer("atomic_mass", atomic_mass)

    def pre_reduce(self, x, v, z, pos, batch):
        x = self.output_network(x)

        # Treat the per-atom scalar as a partial charge and reference the
        # positions to each molecule's center of mass.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        x = x * (pos - c[batch])
        return x

    def post_reduce(self, x):
        # Magnitude of the aggregated dipole vector.
        return torch.norm(x, dim=-1, keepdim=True)
|
|
|
|
| class EquivariantDipoleMoment(EquivariantScalar): |
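    """Equivariant dipole moment head: the l=1 vector channels contribute an
    atomic dipole term on top of the charge-times-position term."""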
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super().__init__(hidden_channels, activation, allow_prior_model=allow_prior_model)
        atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
        self.register_buffer("atomic_mass", atomic_mass)

    def pre_reduce(self, x, v, z, pos, batch):
        # With lmax=2 the vector features carry 8 spherical channels
        # (3 for l=1, 5 for l=2); only the l=1 part is a Cartesian vector.
        if v.shape[1] == 8:
            l1_v, l2_v = torch.split(v, [3, 5], dim=1)
        else:
            l1_v, l2_v = v, v.new_zeros(v.size(0), 5, v.size(2))

        for layer in self.output_network:
            x, l1_v = layer(x, l1_v)

        # Treat the per-atom scalar as a partial charge and reference the
        # positions to each molecule's center of mass.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        x = x * (pos - c[batch])
        # The zero-weighted l2_v term keeps the l=2 channels in the graph.
        return x + l1_v.squeeze(-1) + l2_v.sum() * 0

    def post_reduce(self, x):
        # Magnitude of the aggregated dipole vector.
        return torch.norm(x, dim=-1, keepdim=True)
|
|
|
|
| class ElectronicSpatialExtent(OutputModel): |
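    """Electronic spatial extent head: predicts sum_i q_i * |r_i - r_cm|^2,
    with per-atom scalars q_i produced by an MLP."""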
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super().__init__(allow_prior_model=allow_prior_model)
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            act_class(),
            nn.Linear(hidden_channels // 2, 1),
        )
        atomic_mass = torch.from_numpy(ase.data.atomic_masses).float()
        self.register_buffer("atomic_mass", atomic_mass)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def pre_reduce(self, x, v, z, pos, batch):
        x = self.output_network(x)

        # Weight each atom's squared distance from its molecule's center of
        # mass by the predicted per-atom scalar.
        mass = self.atomic_mass[z].view(-1, 1)
        c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        x = torch.norm(pos - c[batch], dim=1, keepdim=True) ** 2 * x
        return x
|
|
|
|
| class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent): |
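    """Identical to ElectronicSpatialExtent; provided so equivariant model
    configurations can resolve a matching head name."""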
| pass |
|
|
|
|
| class EquivariantVectorOutput(EquivariantScalar): |
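    """Per-atom vector head (e.g. for forces or gradients); returns the l=1
    Cartesian channels of the vector features."""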
    def __init__(self, hidden_channels, activation="silu", allow_prior_model=False):
        super().__init__(hidden_channels, activation, allow_prior_model=allow_prior_model)

    def pre_reduce(self, x, v, z, pos, batch):
        for layer in self.output_network:
            x, v = layer(x, v)

        # Return the l=1 (Cartesian) part; the zero-weighted terms keep x and
        # the l=2 channels in the autograd graph. squeeze(-1) gives a
        # consistent (num_atoms, 3) output in both branches.
        if v.shape[1] == 8:
            l1_v, l2_v = torch.split(v.squeeze(-1), [3, 5], dim=1)
            return l1_v + x.sum() * 0 + l2_v.sum() * 0
        else:
            return v.squeeze(-1) + x.sum() * 0
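
if __name__ == "__main__":
    # Minimal smoke test (illustrative only): check the output shapes of the
    # scalar and vector heads on random data for one 5-atom molecule.
    x = torch.randn(5, 64)
    v = torch.randn(5, 3, 64)
    z = torch.randint(1, 10, (5,))
    pos = torch.randn(5, 3)
    batch = torch.zeros(5, dtype=torch.long)

    scalar_head = EquivariantScalar(hidden_channels=64)
    print(scalar_head.pre_reduce(x, v, z, pos, batch).shape)  # (5, 1)

    vector_head = EquivariantVectorOutput(hidden_channels=64)
    print(vector_head.pre_reduce(x, v, z, pos, batch).shape)  # (5, 3)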
|
|