| | """ |
| | Deprecated network.py module. This file only exists to support backwards-compatibility |
| | with old pickle files. See lib/__init__.py for more information. |
| | """ |
| |
|
| | from __future__ import print_function |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.autograd import Variable |
| | from torch.nn.parameter import Parameter |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | def choose(matrix, idxs): |
| | if isinstance(idxs, Variable): |
| | idxs = idxs.data |
| | assert(matrix.ndimension() == 2) |
| | unrolled_idxs = idxs + \ |
| | torch.arange(0, matrix.size(0)).type_as(idxs) * matrix.size(1) |
| | return matrix.view(matrix.nelement())[unrolled_idxs] |
| |
|
| |
|
class Network(nn.Module):
    """
    Attentional sequence-to-sequence model over (input, output) example
    pairs: an input encoder, an output encoder that attends over the input
    encoding, and a decoder that attends over the output encoding (see
    run()). Features from multiple examples are pooled with an element-wise
    max before the output softmax.

    Todo:
    - Beam search
    - check if this is right? attend during P->FC rather than during softmax->P?
    - allow length 0 inputs/targets
    - give n_examples as input to FC
    - Initialise new weights randomly, rather than as zeroes
    """

    def __init__(
            self,
            input_vocabulary,
            target_vocabulary,
            hidden_size=512,
            embedding_size=128,
            cell_type="LSTM"):
        """
        :param list input_vocabulary: list of possible inputs
        :param list target_vocabulary: list of possible targets
        :param int hidden_size: hidden-state width shared by both encoders
            and the decoder
        :param int embedding_size: width of the pre-softmax embedding (W)
        :param str cell_type: "GRU" or "LSTM"
        """
        super(Network, self).__init__()
        # All three recurrent components share a single hidden width.
        self.h_input_encoder_size = hidden_size
        self.h_output_encoder_size = hidden_size
        self.h_decoder_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabulary = input_vocabulary
        self.target_vocabulary = target_vocabulary
        # Index v_input / v_target is reserved for EOS, so the one-hot
        # encodings below use v_* + 1 channels.
        self.v_input = len(input_vocabulary)
        self.v_target = len(target_vocabulary)

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.input_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            # Learned initial hidden state, tiled across the batch at run time.
            self.input_encoder_init = Parameter(
                torch.rand(1, self.h_input_encoder_size))
            # The output encoder also consumes the input-attention context,
            # hence the extra h_input_encoder_size input channels.
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input +
                1 +
                self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.decoder_cell = nn.GRUCell(
                input_size=self.v_target + 1,
                hidden_size=self.h_decoder_size,
                bias=True)
        if cell_type == 'LSTM':
            self.input_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            # Learned initial (h, c) pair for the input encoder.
            self.input_encoder_init = nn.ParameterList([Parameter(torch.rand(
                1, self.h_input_encoder_size)), Parameter(torch.rand(1, self.h_input_encoder_size))])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input +
                1 +
                self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            # The h half of the initial output-encoder state comes from the
            # input encoder's final embedding; only c is a learned parameter
            # (see output_encoder_get_init).
            self.output_encoder_init_c = Parameter(
                torch.rand(1, self.h_output_encoder_size))
            self.decoder_cell = nn.LSTMCell(
                input_size=self.v_target + 1,
                hidden_size=self.h_decoder_size,
                bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

        # W then V: pooled features -> embedding -> target-vocabulary logits.
        self.W = nn.Linear(
            self.h_output_encoder_size +
            self.h_decoder_size,
            self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)
        # Bilinear forms used by the 'general' attention inside run().
        self.input_A = nn.Bilinear(
            self.h_input_encoder_size,
            self.h_output_encoder_size,
            1,
            bias=False)
        self.output_A = nn.Bilinear(
            self.h_output_encoder_size,
            self.h_decoder_size,
            1,
            bias=False)
        # One-hot EOS rows. Note output_EOS is sized with v_input because
        # output examples are encoded over the *input* vocabulary (score()
        # runs outputs through inputsToTensors).
        self.input_EOS = torch.zeros(1, self.v_input + 1)
        self.input_EOS[:, -1] = 1
        self.input_EOS = Parameter(self.input_EOS)
        self.output_EOS = torch.zeros(1, self.v_input + 1)
        self.output_EOS[:, -1] = 1
        self.output_EOS = Parameter(self.output_EOS)
        self.target_EOS = torch.zeros(1, self.v_target + 1)
        self.target_EOS[:, -1] = 1
        self.target_EOS = Parameter(self.target_EOS)
| |
|
| | def __getstate__(self): |
| | if hasattr(self, 'opt'): |
| | return dict([(k, v) for k, v in self.__dict__.items( |
| | ) if k is not 'opt'] + [('optstate', self.opt.state_dict())]) |
| | |
| | |
| | else: |
| | return self.__dict__ |
| |
|
| | def __setstate__(self, state): |
| | self.__dict__.update(state) |
| | |
| | if isinstance(self.input_encoder_init, tuple): |
| | self.input_encoder_init = nn.ParameterList( |
| | list(self.input_encoder_init)) |
| |
|
| | def clear_optimiser(self): |
| | if hasattr(self, 'opt'): |
| | del self.opt |
| | if hasattr(self, 'optstate'): |
| | del self.optstate |
| |
|
| | def get_optimiser(self): |
| | self.opt = torch.optim.Adam(self.parameters(), lr=0.001) |
| | if hasattr(self, 'optstate'): |
| | self.opt.load_state_dict(self.optstate) |
| |
|
    def optimiser_step(self, inputs, outputs, target):
        """Take one Adam step on the negative mean log-likelihood of
        target given the (inputs, outputs) examples; returns the mean
        log-likelihood before the step."""
        # Lazily (re)create the optimiser, restoring pickled state if any.
        if not hasattr(self, 'opt'):
            self.get_optimiser()
        score = self.score(inputs, outputs, target, autograd=True).mean()
        (-score).backward()
        self.opt.step()
        self.opt.zero_grad()
        # NOTE(review): `score.data[0]` is the legacy (pre-0.4) way of
        # extracting a Python scalar; modern PyTorch would need
        # `score.item()`. Kept as-is — this module exists only for
        # backwards-compatibility with old pickles.
        return score.data[0]
| |
|
    def set_target_vocabulary(self, target_vocabulary):
        """Switch to a new target vocabulary, reusing learned weights.

        Rows of V (the output projection) and columns of the decoder's
        input weights are copied over for symbols that also exist in the
        old vocabulary; brand-new symbols get zero weights and a large
        negative output bias (-10) so they start out very unlikely. The
        EOS slot (always the last row/column) is preserved.
        """
        if target_vocabulary == self.target_vocabulary:
            return

        V_weight = []
        V_bias = []
        decoder_ih = []

        # Rebuild the weight rows/columns symbol by symbol.
        for i in range(len(target_vocabulary)):
            if target_vocabulary[i] in self.target_vocabulary:
                j = self.target_vocabulary.index(target_vocabulary[i])
                V_weight.append(self.V.weight.data[j:j + 1])
                V_bias.append(self.V.bias.data[j:j + 1])
                decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
            else:
                # Unknown symbol: zero weights, strongly negative bias.
                V_weight.append(torch.zeros(1, self.V.weight.size(1)))
                V_bias.append(torch.ones(1) * -10)
                decoder_ih.append(
                    torch.zeros(
                        self.decoder_cell.weight_ih.data.size(0), 1))

        # Carry over the EOS slot (last index in both old and new layouts).
        V_weight.append(self.V.weight.data[-1:])
        V_bias.append(self.V.bias.data[-1:])
        decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])

        self.target_vocabulary = target_vocabulary
        self.v_target = len(target_vocabulary)
        self.target_EOS.data = torch.zeros(1, self.v_target + 1)
        self.target_EOS.data[:, -1] = 1

        # Splice the rebuilt tensors into the existing modules in place,
        # keeping the modules' bookkeeping attributes consistent.
        self.V.weight.data = torch.cat(V_weight, dim=0)
        self.V.bias.data = torch.cat(V_bias, dim=0)
        self.V.out_features = self.V.bias.data.size(0)

        self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
        self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)

        # Saved optimiser state no longer matches the new parameter shapes.
        self.clear_optimiser()
| |
|
| | def input_encoder_get_init(self, batch_size): |
| | if self.cell_type == "GRU": |
| | return self.input_encoder_init.repeat(batch_size, 1) |
| | if self.cell_type == "LSTM": |
| | return tuple(x.repeat(batch_size, 1) |
| | for x in self.input_encoder_init) |
| |
|
| | def output_encoder_get_init(self, input_encoder_h): |
| | if self.cell_type == "GRU": |
| | return input_encoder_h |
| | if self.cell_type == "LSTM": |
| | return ( |
| | input_encoder_h, |
| | self.output_encoder_init_c.repeat( |
| | input_encoder_h.size(0), |
| | 1)) |
| |
|
| | def decoder_get_init(self, output_encoder_h): |
| | if self.cell_type == "GRU": |
| | return output_encoder_h |
| | if self.cell_type == "LSTM": |
| | return ( |
| | output_encoder_h, |
| | self.decoder_init_c.repeat( |
| | output_encoder_h.size(0), |
| | 1)) |
| |
|
| | def cell_get_h(self, cell_state): |
| | if self.cell_type == "GRU": |
| | return cell_state |
| | if self.cell_type == "LSTM": |
| | return cell_state[0] |
| |
|
    def score(self, inputs, outputs, target, autograd=False):
        """Log-likelihood log p(target | inputs, outputs) per batch item.

        :param inputs: nBatch * nExamples nested lists of input symbols
        :param outputs: nBatch * nExamples nested lists of output symbols
            (encoded with the *input* vocabulary via inputsToTensors)
        :param target: nBatch lists of target symbols
        :param bool autograd: if True return the differentiable Variable,
            otherwise its raw .data tensor
        """
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target = self.targetToTensor(target)
        target, score = self.run(inputs, outputs, target=target, mode="score")

        if autograd:
            return score
        else:
            return score.data
| |
|
    def sample(self, inputs, outputs):
        """Draw one target sample per batch item from p(target | examples).

        The score returned by run() is discarded here; use sampleAndScore
        to keep it.
        """
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target, score = self.run(inputs, outputs, mode="sample")
        target = self.tensorToOutput(target)
        return target
| |
|
    def sampleAndScore(self, inputs, outputs, nRepeats=None):
        """Sample targets and return them with their log-likelihoods.

        With nRepeats=None: one sample per batch item, returning
        (targets, score tensor). With nRepeats=k: the whole batch is
        sampled k times and the results concatenated into plain lists.
        """
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        if nRepeats is None:
            target, score = self.run(inputs, outputs, mode="sample")
            target = self.tensorToOutput(target)
            return target, score.data
        else:
            target = []
            score = []
            for i in range(nRepeats):
                # Each repeat re-runs sampling over the full batch.
                t, s = self.run(inputs, outputs, mode="sample")
                t = self.tensorToOutput(t)
                target.extend(t)
                score.extend(list(s.data))
            return target, score
| |
|
    def run(self, inputs, outputs, target=None, mode="sample"):
        """
        Encode each (input, output) example pair, then decode (or score) a
        target sequence while attending over the output encodings.

        :param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input)
        :param List[LongTensor] inputs: n_examples * (max_length_input * batch_size)
        :param List[LongTensor] target: max_length_target * batch_size
        :return: (target, score) where score is a batch_size Variable of
            accumulated log-probabilities
        """
        assert((mode == "score" and target is not None) or mode == "sample")

        n_examples = len(inputs)
        max_length_input = [inputs[j].size(0) for j in range(n_examples)]
        max_length_output = [outputs[j].size(0) for j in range(n_examples)]
        # Sampling is capped at 10 decoding steps.
        max_length_target = target.size(0) if target is not None else 10
        batch_size = inputs[0].size(1)

        score = Variable(torch.zeros(batch_size))
        # One-hot encodings; channel v_input (resp. v_target) is EOS. Note
        # outputs are one-hot over the *input* vocabulary.
        inputs_scatter = [Variable(torch.zeros(max_length_input[j], batch_size, self.v_input + 1).scatter_(
            2, inputs[j][:, :, None], 1)) for j in range(n_examples)]
        outputs_scatter = [Variable(torch.zeros(max_length_output[j], batch_size, self.v_input + 1).scatter_(
            2, outputs[j][:, :, None], 1)) for j in range(n_examples)]
        if target is not None:
            target_scatter = Variable(torch.zeros(
                max_length_target, batch_size, self.v_target + 1
            ).scatter_(2, target[:, :, None], 1))

        # ---------------- Input encoder ----------------
        input_H = []
        input_embeddings = []
        input_attention_mask = []
        for j in range(n_examples):
            # active[i, b] == 1 while sequence b has not yet passed EOS.
            active = torch.Tensor(max_length_input[j], batch_size).byte()
            active[0, :] = 1
            state = self.input_encoder_get_init(batch_size)
            hs = []
            for i in range(max_length_input[j]):
                state = self.input_encoder_cell(
                    inputs_scatter[j][i, :, :], state)
                if i + 1 < max_length_input[j]:
                    active[i + 1, :] = active[i, :] * \
                        (inputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            input_H.append(torch.cat(hs, 0))
            # Per-sequence embedding: hidden state at the last active step.
            embedding_idx = active.sum(0).long() - 1
            embedding = input_H[j].gather(0, Variable(
                embedding_idx[None, :, None].repeat(1, 1, self.h_input_encoder_size)))[0]
            input_embeddings.append(embedding)
            # log(0) = -inf masks padded positions out of the attention softmax.
            input_attention_mask.append(Variable(active.float().log()))

        def input_attend(j, h_out):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_out: batch_size * h_output_encoder_size
            """
            scores = self.input_A(
                input_H[j].view(
                    max_length_input[j] * batch_size,
                    self.h_input_encoder_size),
                h_out.view(
                    batch_size,
                    self.h_output_encoder_size).repeat(
                    max_length_input[j],
                    1)).view(
                max_length_input[j],
                batch_size) + input_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0)
            return c

        # ---------- Output encoder (attends over input encodings) ----------
        output_H = []
        output_embeddings = []
        output_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_output[j], batch_size).byte()
            active[0, :] = 1
            # Initial state comes from the input encoder's final embedding.
            state = self.output_encoder_get_init(input_embeddings[j])
            hs = []
            h = self.cell_get_h(state)
            for i in range(max_length_output[j]):
                # Each step consumes the one-hot symbol plus the attention
                # context over the input encoding.
                state = self.output_encoder_cell(torch.cat(
                    [outputs_scatter[j][i, :, :], input_attend(j, h)], 1), state)
                if i + 1 < max_length_output[j]:
                    active[i + 1, :] = active[i, :] * \
                        (outputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            output_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = output_H[j].gather(0, Variable(
                embedding_idx[None, :, None].repeat(1, 1, self.h_output_encoder_size)))[0]
            output_embeddings.append(embedding)
            output_attention_mask.append(Variable(active.float().log()))

        def output_attend(j, h_dec):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_dec: batch_size * h_decoder_size
            """
            scores = self.output_A(
                output_H[j].view(
                    max_length_output[j] * batch_size,
                    self.h_output_encoder_size),
                h_dec.view(
                    batch_size,
                    self.h_decoder_size).repeat(
                    max_length_output[j],
                    1)).view(
                max_length_output[j],
                batch_size) + output_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0)
            return c

        # ---------------- Decoder ----------------
        target = target if mode == "score" else torch.zeros(
            max_length_target, batch_size).long()
        decoder_states = [
            self.decoder_get_init(
                output_embeddings[j]) for j in range(n_examples)]
        active = torch.ones(batch_size).byte()
        for i in range(max_length_target):
            FC = []
            for j in range(n_examples):
                h = self.cell_get_h(decoder_states[j])
                p_aug = torch.cat([h, output_attend(j, h)], 1)
                FC.append(F.tanh(self.W(p_aug)[None, :, :]))
            # Pool per-example features with an element-wise max before the
            # output softmax.
            m = torch.max(torch.cat(FC, 0), 0)[0]
            logsoftmax = F.log_softmax(self.V(m), dim=1)
            if mode == "sample":
                target[i, :] = torch.multinomial(
                    logsoftmax.data.exp(), 1)[:, 0]
            # Only sequences that have not yet emitted EOS accumulate score.
            score = score + \
                choose(logsoftmax, target[i, :]) * Variable(active.float())
            active *= (target[i, :] != self.v_target)
            for j in range(n_examples):
                if mode == "score":
                    target_char_scatter = target_scatter[i, :, :]
                elif mode == "sample":
                    target_char_scatter = Variable(torch.zeros(
                        batch_size, self.v_target + 1).scatter_(1, target[i, :, None], 1))
                decoder_states[j] = self.decoder_cell(
                    target_char_scatter, decoder_states[j])
        return target, score
| |
|
| | def inputsToTensors(self, inputss): |
| | """ |
| | :param inputss: size = nBatch * nExamples |
| | """ |
| | tensors = [] |
| | for j in range(len(inputss[0])): |
| | inputs = [x[j] for x in inputss] |
| | maxlen = max(len(s) for s in inputs) |
| | t = torch.ones( |
| | 1 if maxlen == 0 else maxlen + 1, |
| | len(inputs)).long() * self.v_input |
| | for i in range(len(inputs)): |
| | s = inputs[i] |
| | if len(s) > 0: |
| | t[:len(s), i] = torch.LongTensor( |
| | [self.input_vocabulary.index(x) for x in s]) |
| | tensors.append(t) |
| | return tensors |
| |
|
| | def targetToTensor(self, targets): |
| | """ |
| | :param targets: |
| | """ |
| | maxlen = max(len(s) for s in targets) |
| | t = torch.ones( |
| | 1 if maxlen == 0 else maxlen + 1, |
| | len(targets)).long() * self.v_target |
| | for i in range(len(targets)): |
| | s = targets[i] |
| | if len(s) > 0: |
| | t[:len(s), i] = torch.LongTensor( |
| | [self.target_vocabulary.index(x) for x in s]) |
| | return t |
| |
|
| | def tensorToOutput(self, tensor): |
| | """ |
| | :param tensor: max_length * batch_size |
| | """ |
| | out = [] |
| | for i in range(tensor.size(1)): |
| | l = tensor[:, i].tolist() |
| | if l[0] == self.v_target: |
| | out.append([]) |
| | elif self.v_target in l: |
| | final = tensor[:, i].tolist().index(self.v_target) |
| | out.append([self.target_vocabulary[x] |
| | for x in tensor[:final, i]]) |
| | else: |
| | out.append([self.target_vocabulary[x] for x in tensor[:, i]]) |
| | return out |
| |
|