| | |
| |
|
| | import torch |
| | from torch.autograd import Variable |
| |
|
def to_var(x, requires_grad=False, volatile=False):
    """Wrap tensor ``x`` in an autograd ``Variable``.

    The tensor is first moved to the GPU with ``.cuda()`` when CUDA is
    available; otherwise it is wrapped as-is.

    Args:
        x: tensor to wrap.
        requires_grad: forwarded to ``Variable``.
        volatile: forwarded to ``Variable``.
            NOTE(review): ``volatile`` is a legacy autograd flag that recent
            PyTorch releases ignore (``torch.no_grad()`` replaces it) —
            confirm against the pinned PyTorch version.

    Returns:
        The ``Variable`` built from the (possibly GPU-resident) tensor.
    """
    on_gpu = torch.cuda.is_available()
    tensor = x.cuda() if on_gpu else x
    return Variable(tensor, requires_grad=requires_grad, volatile=volatile)
| |
|
def top_k_logits(logits, k, probs=False):
    """
    Keep only the ``k`` largest entries along the last dimension of ``logits``.

    Every other entry is replaced by a large negative number (-1e10) so that,
    after a softmax, e^-1e10 -> 0 and the masked entries do not contribute to
    the sum in the denominator. With ``probs=True`` the masked entries are set
    to 0.0 instead, for callers that pass probabilities rather than logits.

    Args:
        logits: tensor of scores; top-k is taken over its last dimension.
            Any number of leading (batch) dimensions is supported, including
            none (a 1-D tensor).
        k: number of entries to keep per row. ``0`` disables filtering and
            returns ``logits`` itself. Values larger than the size of the last
            dimension are clamped to it, so nothing is masked.
        probs: if True, fill masked entries with 0.0 instead of -1e10.

    Returns:
        A tensor with the same shape and dtype as ``logits``. Entries tied
        with the k-th largest value are all kept, so a row may retain more
        than ``k`` entries.
    """
    if k == 0:
        return logits

    # torch.topk raises when k exceeds the dimension size; treat that as "keep all".
    k = min(k, logits.size(-1))
    values = torch.topk(logits, k)[0]
    # k-th largest value per row, kept with a trailing size-1 dim so it
    # broadcasts against logits of any rank (values[:, -1].view(-1, 1) only
    # worked for 2-D input).
    kth_largest = values[..., -1:]
    fill_value = 0.0 if probs else -1e10
    # Built by multiplication (as before) rather than masked_fill/full_like:
    # for low-precision dtypes such as float16, -1e10 is out of range and those
    # APIs reject the scalar, whereas the product simply overflows to -inf.
    fill = torch.ones_like(logits) * fill_value
    return torch.where(logits < kth_largest, fill, logits)
| |
|