import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Node encoders
from models.classifiers.nn_classifiers.encoders.mlp_node_encoder import MLPNodeEncoder
from models.classifiers.nn_classifiers.encoders.mha_node_encoder import SelfMHANodeEncoder
# Edge encoders
from models.classifiers.nn_classifiers.encoders.attn_edge_encoder import AttentionEdgeEncoder
from models.classifiers.nn_classifiers.encoders.concat_edge_encoder import ConcatEdgeEncoder
# Decoders
from models.classifiers.nn_classifiers.decoders.lstm_decoder import LSTMDecoder
from models.classifiers.nn_classifiers.decoders.mlp_decoder import MLPDecoder
from models.classifiers.nn_classifiers.decoders.mha_decoder import SelfMHADecoder
# Data loaders
from utils.data_utils.tsptw_dataset import load_tsptw_sequentially
from utils.data_utils.pctsp_dataset import load_pctsp_sequentially
from utils.data_utils.pctsptw_dataset import load_pctsptw_sequentially
from utils.data_utils.cvrp_dataset import load_cvrp_sequentially

NODE_ENC_LIST = ["mlp", "mha"]
EDGE_ENC_LIST = ["concat", "attn"]
DEC_LIST = ["mlp", "lstm", "mha"]

class NNClassifier(nn.Module):
    def __init__(self,
                 problem: str,
                 node_enc_type: str,
                 edge_enc_type: str,
                 dec_type: str,
                 emb_dim: int,
                 num_enc_mlp_layers: int,
                 num_dec_mlp_layers: int,
                 num_classes: int,
                 dropout: float,
                 pos_encoder: str = "sincos"):
        super().__init__()
        self.problem = problem
        self.node_enc_type = node_enc_type
        self.edge_enc_type = edge_enc_type
        self.dec_type = dec_type
        assert node_enc_type in NODE_ENC_LIST, f"Invalid node_enc_type. Select from {NODE_ENC_LIST}"
        assert edge_enc_type in EDGE_ENC_LIST, f"Invalid edge_enc_type. Select from {EDGE_ENC_LIST}"
        assert dec_type in DEC_LIST, f"Invalid dec_type. Select from {DEC_LIST}"
        # LSTM/MHA decoders consume the whole step sequence at once
        self.is_sequential = dec_type in ["lstm", "mha"]
        coord_dim = 2  # only 2D problems are supported
        if problem == "tsptw":
            node_dim = 4   # coords (2) + time window (2)
            state_dim = 1  # current time (1)
        elif problem == "pctsp":
            node_dim = 4   # coords (2) + prize (1) + penalty (1)
            state_dim = 2  # current prize (1) + current penalty (1)
        elif problem == "pctsptw":
            node_dim = 6   # coords (2) + prize (1) + penalty (1) + time window (2)
            state_dim = 3  # current prize (1) + current penalty (1) + current time (1)
        elif problem == "cvrp":
            node_dim = 3   # coords (2) + demand (1)
            state_dim = 1  # remaining capacity (1)
        else:
            raise NotImplementedError

        #----------------
        # Graph encoding
        #----------------
        # Node encoder
        if node_enc_type == "mlp":
            self.node_enc = MLPNodeEncoder(coord_dim, node_dim, emb_dim, num_enc_mlp_layers, dropout)
        elif node_enc_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.node_enc = SelfMHANodeEncoder(coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout)
        else:
            raise NotImplementedError

        # Readout (edge encoder)
        if edge_enc_type == "concat":
            self.readout = ConcatEdgeEncoder(state_dim, emb_dim, dropout)
        elif edge_enc_type == "attn":
            self.readout = AttentionEdgeEncoder(state_dim, emb_dim, dropout)
        else:
            raise NotImplementedError

        #------------------------
        # Classification decoder
        #------------------------
        if dec_type == "mlp":
            self.decoder = MLPDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "lstm":
            self.decoder = LSTMDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.decoder = SelfMHADecoder(emb_dim, num_heads, num_mha_layers, num_classes, dropout, pos_encoder)
        else:
            raise NotImplementedError

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            next_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            node_feat: torch.FloatTensor [batch_size x max_seq_length x num_nodes x node_dim] if self.is_sequential else [batch_size x num_nodes x node_dim]
            mask: torch.LongTensor [batch_size x max_seq_length x num_nodes] if self.is_sequential else [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x max_seq_length x state_dim] if self.is_sequential else [batch_size x state_dim]

        Returns
        -------
        probs: torch.FloatTensor [batch_size x max_seq_length x num_classes] if self.is_sequential else [batch_size x num_classes]
            class probabilities
        """
        #-----------------
        # Encoding graphs
        #-----------------
        if self.is_sequential:
            # Fold the step dimension into the batch so each (instance, step)
            # pair is encoded as an independent graph
            shp = inputs["curr_node_id"].size()
            inputs = {key: value.flatten(0, 1) for key, value in inputs.items()}
        node_emb = self.node_enc(inputs)
        graph_emb = self.readout(inputs, node_emb)  # [(batch_size*max_seq_length) x emb_dim] if self.is_sequential else [batch_size x emb_dim]
        if self.is_sequential:
            graph_emb = graph_emb.view(*shp, -1)  # [batch_size x max_seq_length x emb_dim]

        #----------
        # Decoding
        #----------
        probs = self.decoder(graph_emb)
        return probs

    def get_inputs(self, routes, first_explained_step, node_feats):
        node_feats_ = node_feats.copy()
        node_feats_["tour"] = routes
        if self.problem == "tsptw":
            seq_data = load_tsptw_sequentially(node_feats_)
        elif self.problem == "pctsp":
            seq_data = load_pctsp_sequentially(node_feats_)
        elif self.problem == "pctsptw":
            seq_data = load_pctsptw_sequentially(node_feats_)
        elif self.problem == "cvrp":
            seq_data = load_cvrp_sequentially(node_feats_)
        else:
            raise NotImplementedError

        def pad_seq_length(batch):
            data = {}
            for key in batch[0].keys():
                padding_value = True if key == "mask" else 0.0
                # post-padding along the step dimension
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch], batch_first=True, padding_value=padding_value)
            # pad_mask is True for real steps, False for padded ones
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0),), True) for d in batch], batch_first=True, padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data

        instance = pad_seq_length(seq_data)
        return instance
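

#----------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model): build a
# non-sequential CVRP classifier and run a forward pass on random inputs.
# Shapes and dict keys are taken from the forward() docstring; the sketch
# assumes the encoder/readout/decoder modules accept them as constructed
# above, since their implementations are not shown in this file.
#----------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, num_nodes = 4, 20
    model = NNClassifier(problem="cvrp",
                         node_enc_type="mlp",
                         edge_enc_type="concat",
                         dec_type="mlp",
                         emb_dim=128,
                         num_enc_mlp_layers=2,
                         num_dec_mlp_layers=2,
                         num_classes=2,
                         dropout=0.1)
    inputs = {
        "curr_node_id": torch.randint(num_nodes, (batch_size,)),
        "next_node_id": torch.randint(num_nodes, (batch_size,)),
        "node_feat": torch.rand(batch_size, num_nodes, 3),            # coords (2) + demand (1)
        "mask": torch.ones(batch_size, num_nodes, dtype=torch.long),  # all nodes visitable
        "state": torch.rand(batch_size, 1),                           # remaining capacity
    }
    probs = model(inputs)
    print(probs.shape)  # expected: [batch_size x num_classes]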