| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from torch.nn.utils import weight_norm |
|
|
|
|
def WNConv1d(*args, **kwargs) -> nn.Conv1d:
    """Build a weight-normalized 1-D convolution.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``; the freshly built layer is then passed through
    ``weight_norm`` and returned.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
|
|
|
def WNConvTranspose1d(*args, **kwargs) -> nn.ConvTranspose1d:
    """Build a weight-normalized 1-D transposed convolution.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.ConvTranspose1d``; the freshly built layer is then passed through
    ``weight_norm`` and returned.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
|
|
|
| |
@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Apply the snake activation ``x + sin(alpha * x)**2 / (alpha + 1e-9)``.

    The input is temporarily viewed as ``(shape[0], shape[1], -1)`` so that
    every trailing dimension is collapsed into one, the activation is
    evaluated on that view, and the result is restored to the input's
    original shape.

    Args:
        x: Tensor with at least two dimensions (the first two are indexed
            explicitly when flattening).
        alpha: Tensor that must broadcast against the flattened
            ``(shape[0], shape[1], -1)`` view of ``x``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # The small epsilon keeps the reciprocal finite when alpha is exactly 0.
    inv_alpha = (alpha + 1e-9).reciprocal()
    sin_squared = torch.sin(alpha * flat).pow(2)
    flat = flat + inv_alpha * sin_squared
    return flat.reshape(original_shape)
|
|
|
|
class Snake1d(nn.Module):
    """Module wrapper around ``snake`` with one learnable ``alpha`` per channel."""

    def __init__(self, channels):
        """Create the activation.

        Args:
            channels: Size of the middle dimension of the ``alpha`` parameter,
                which is stored with shape ``(1, channels, 1)``.
        """
        super().__init__()
        # alpha starts at 1 everywhere and is registered as a trainable parameter.
        initial_alpha = torch.ones((1, channels, 1))
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Return ``snake(x, self.alpha)`` for the input tensor ``x``."""
        alpha = self.alpha
        return snake(x, alpha)
|
|