| | from abc import ABC, abstractmethod |
| | import os |
| | import time |
| | import gc |
| | import json |
| | from tqdm import tqdm |
| | import torch |
| | from singleVis.losses import PositionRecoverLoss |
| | from torch.utils.data import DataLoader, WeightedRandomSampler |
| |
|
| | import copy |
| | import numpy as np |
| | from singleVis.custom_weighted_random_sampler import CustomWeightedRandomSampler, CustomWeightedRandomSamplerVis |
| | from singleVis.spatial_edge_constructor import ActiveLearningEpochSpatialEdgeConstructor |
| | from singleVis.edge_dataset import DVIDataHandler |
| |
|
# Fix the CPU and CUDA RNG seeds so training runs are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
| |
|
| | """ |
| | 1. construct a spatio-temporal complex |
| | 2. construct an edge-dataset |
| | 3. train the network |
| | |
A Trainer should contain
| | 1. train_step function |
| | 2. early stop |
| | 3. ... |
| | """ |
| |
|
class TrainerAbstractClass(ABC):
    """Interface for visualization-model trainers.

    Concrete trainers must provide construction, a per-epoch training
    step, a full training loop with early stopping, checkpoint
    (de)serialization, and JSON timing records.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """Set up model, criterion, optimizer, scheduler and data loader."""

    @property
    @abstractmethod
    def loss(self):
        """Most recent average training loss."""

    @abstractmethod
    def reset_optim(self):
        """Replace the optimizer and learning-rate scheduler together."""

    @abstractmethod
    def update_edge_loader(self):
        """Swap in a new edge data loader."""

    @abstractmethod
    def update_vis_model(self):
        """Copy weights from another visualization model."""

    @abstractmethod
    def update_optimizer(self):
        """Replace the optimizer."""

    @abstractmethod
    def update_lr_scheduler(self):
        """Replace the learning-rate scheduler."""

    @abstractmethod
    def train_step(self):
        """Run one epoch over the edge loader and return the loss."""

    @abstractmethod
    def train(self):
        """Run the full training loop with early stopping."""

    @abstractmethod
    def load(self):
        """Restore trainer state from disk."""

    @abstractmethod
    def save(self):
        """Persist trainer state to disk."""

    @abstractmethod
    def record_time(self):
        """Append a timing entry to a JSON evaluation file."""
| |
|
| |
|
class ActiveLearningEdgeLoader(DataLoader):
    """DataLoader that draws edges with probability proportional to `weights`.

    Samples ``len(dataset)`` items per epoch (with replacement by the
    sampler's default), so that heavily-weighted edges are revisited
    more often during active learning.
    """

    def __init__(self, dataset, weights, batch_size=32, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            sampler=WeightedRandomSampler(weights, len(dataset)),
            **kwargs,
        )
| |
|
class SingleVisTrainer(TrainerAbstractClass):
    """Base trainer for a parametric visualization model.

    Drives an epoch loop over an edge loader using a UMAP +
    reconstruction criterion, with patience-based early stopping,
    checkpointing, and JSON timing records.
    """

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.DEVICE = DEVICE
        self.edge_loader = edge_loader
        # Sentinel start value: large enough that epoch 1 never early-stops.
        self._loss = 100.0

    @property
    def loss(self):
        """Average total loss of the most recent training epoch."""
        return self._loss

    def reset_optim(self, optim, lr_s):
        """Replace both the optimizer and the learning-rate scheduler."""
        self.optimizer = optim
        self.lr_scheduler = lr_s
        print("Successfully reset optimizer!")

    def update_edge_loader(self, edge_loader):
        """Swap in a new edge loader, releasing the previous one first."""
        del self.edge_loader
        gc.collect()
        self.edge_loader = edge_loader

    def update_vis_model(self, model):
        """Copy the weights of `model` into the trainer's model."""
        self.model.load_state_dict(model.state_dict())

    def update_optimizer(self, optimizer):
        """Replace the optimizer."""
        self.optimizer = optimizer

    def update_lr_scheduler(self, lr_scheduler):
        """Replace the learning-rate scheduler."""
        self.lr_scheduler = lr_scheduler

    def train_step(self):
        """Train for one epoch over the edge loader.

        Returns:
            The epoch's average total loss (also cached in ``self._loss``).
        """
        self.model.to(device=self.DEVICE)
        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []

        t = tqdm(self.edge_loader, leave=True, total=len(self.edge_loader))

        for data in t:
            edge_to, edge_from, a_to, a_from = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, outputs)
            all_loss.append(loss.mean().item())
            umap_losses.append(umap_l.mean().item())
            recon_losses.append(recon_l.mean().item())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses),
                                                                sum(recon_losses) / len(recon_losses),
                                                                sum(all_loss) / len(all_loss)))
        return self.loss

    def train(self, PATIENT, MAX_EPOCH_NUMS):
        """Run up to MAX_EPOCH_NUMS epochs with patience-based early stopping.

        Training stops once the loss fails to improve by at least 5e-3
        for PATIENT consecutive epochs.
        """
        patient = PATIENT
        time_start = time.time()
        for epoch in range(MAX_EPOCH_NUMS):
            print("====================\nepoch:{}\n===================".format(epoch + 1))
            prev_loss = self.loss
            loss = self.train_step()
            self.lr_scheduler.step()
            # Early stopping: count epochs without sufficient improvement.
            if prev_loss - loss < 5E-3:
                if patient == 0:
                    break
                else:
                    patient -= 1
            else:
                patient = PATIENT

        time_end = time.time()
        time_spend = time_end - time_start
        print("Time spend: {:.2f} for training vis model...".format(time_spend))

    def load(self, file_path):
        """Restore the model weights and last loss from a checkpoint.

        :param file_path: path to a checkpoint produced by :meth:`save`.
        """
        save_model = torch.load(file_path, map_location="cpu")
        self._loss = save_model["loss"]
        self.model.load_state_dict(save_model["state_dict"])
        self.model.to(self.DEVICE)
        print("Successfully load visualization model...")

    def save(self, save_dir, file_name):
        """Save loss, model and optimizer state to <save_dir>/<file_name>.pth.

        :param save_dir: directory to write into (must exist).
        :param file_name: checkpoint name without extension.
        """
        save_model = {
            "loss": self.loss,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict()}
        save_path = os.path.join(save_dir, file_name + '.pth')
        torch.save(save_model, save_path)
        print("Successfully save visualization model...")

    def record_time(self, save_dir, file_name, key, t):
        """Record elapsed seconds `t` under `key` in <save_dir>/<file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        evaluation[key] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
| |
|
| |
|
| | |
| |
|
| |
|
| |
|
class HybridVisTrainer(SingleVisTrainer):
    """Trainer whose criterion adds a smoothness term over prior embeddings."""

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE):
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE)

    def train_step(self):
        """Train one epoch.

        Each batch carries (edge_to, edge_from, a_to, a_from,
        embedded_to, coeffi_to); the criterion returns umap, recon and
        smoothness components plus the total loss.
        """
        self.model = self.model.to(device=self.DEVICE)
        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []
        smooth_losses = []

        t = tqdm(self.edge_loader, leave=True, total=len(self.edge_loader))

        for data in t:
            edge_to, edge_from, a_to, a_from, embedded_to, coeffi_to = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)
            embedded_to = embedded_to.to(device=self.DEVICE, dtype=torch.float32)
            coeffi_to = coeffi_to.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, smooth_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, embedded_to, coeffi_to, outputs)
            all_loss.append(loss.item())
            umap_losses.append(umap_l.item())
            recon_losses.append(recon_l.item())
            smooth_losses.append(smooth_l.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\tsmooth_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses),
                                                                                 sum(recon_losses) / len(recon_losses),
                                                                                 sum(smooth_losses) / len(smooth_losses),
                                                                                 sum(all_loss) / len(all_loss)))
        return self.loss

    def record_time(self, save_dir, file_name, operation, seg, t):
        """Record `t` under evaluation[operation][str(seg)] in <file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        if operation not in evaluation.keys():
            evaluation[operation] = dict()
        evaluation[operation][str(seg)] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
| |
|
def disable_grad(model):
    """Freeze `model`: turn off gradient tracking for every parameter."""
    for p in model.parameters():
        p.requires_grad = False
| |
|
| |
|
| | |
| | RE_TRAINING_INTERVAL = 10 |
| |
|
class ActiveLearningTrainer(SingleVisTrainer):
    """SingleVisTrainer variant that eagerly moves the model to DEVICE."""

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE):
        # Reuse the base initialization instead of duplicating every
        # assignment; the only extra behavior is the device placement.
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE)
        self.model = self.model.to(device=DEVICE)
| |
|
| | |
| | |
| |
|
class DVIALTrainer(SingleVisTrainer):
    """DVI trainer with an active-learning phase that re-weights edges by loss."""

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE):
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE)
        # Used to drop the learning rate exactly once, on the first
        # active-learning epoch.
        self.is_first_active_learning = True

    def evaluate_loss(self):
        """Score every batch and build a loader that over-samples hard edges.

        Returns:
            (losses, new_loader): per-batch loss values and an
            ActiveLearningEdgeLoader weighted by their normalized inverses.
        """
        print("evaluating")

        losses = []

        self.model.eval()
        with torch.no_grad():
            for data in self.edge_loader:
                edge_to, edge_from, a_to, a_from = data
                edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
                edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
                a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
                a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)
                outputs = self.model(edge_to, edge_from)
                _, _, _, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs)
                losses.append(loss.item())

        # NOTE(review): assumes every batch loss is strictly positive; a
        # zero loss would make its inverse weight infinite — confirm upstream.
        weights = 1.0 / torch.tensor(losses, dtype=torch.float32)
        weights = weights / weights.sum()

        new_loader = ActiveLearningEdgeLoader(self.edge_loader.dataset, weights, batch_size=self.edge_loader.batch_size)
        return losses, new_loader

    def train_step(self, edge_loader):
        """Train one epoch over `edge_loader` (umap/recon/temporal criterion).

        Returns:
            The epoch's average total loss (also cached in ``self._loss``).
        """
        self.model = self.model.to(device=self.DEVICE)

        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []
        temporal_losses = []

        t = tqdm(edge_loader, leave=True, total=len(edge_loader))

        for data in t:
            edge_to, edge_from, a_to, a_from = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, temporal_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs)

            all_loss.append(loss.mean().item())
            umap_losses.append(umap_l.item())
            recon_losses.append(recon_l.item())
            temporal_losses.append(temporal_l.mean().item())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\ttemporal_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses),
                                                                                   sum(recon_losses) / len(recon_losses),
                                                                                   sum(temporal_losses) / len(temporal_losses),
                                                                                   sum(all_loss) / len(all_loss)))
        return self.loss

    def run_epoch(self, epoch, is_active_learning=False, is_full_data=False):
        """Run one epoch, optionally on the loss-re-weighted loader.

        Returns:
            (prev_loss, loss): the losses before and after this epoch.
        """
        print("====================\nepoch:{}\n===================".format(epoch + 1))
        start_time = time.time()

        # BUG FIX: default to the full loader.  Previously `current_loader`
        # was unbound (NameError) when both flags were False.
        current_loader = self.edge_loader
        if is_active_learning and is_full_data == False:
            _, current_loader = self.evaluate_loss()

            if self.is_first_active_learning:
                # One-time LR drop when entering the active-learning phase.
                print("change learning rate")
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] *= 0.1
                self.is_first_active_learning = False

        prev_loss = self.loss

        if is_full_data:
            print("full data")
            loss = self.train_step(self.edge_loader)
        else:
            loss = self.train_step(current_loader)

        self.lr_scheduler.step()

        elapsed_time = time.time() - start_time
        print("Epoch completed in: {:.2f} seconds".format(elapsed_time))

        return prev_loss, loss

    def train(self, PATIENT, MAX_EPOCH_NUMS):
        """Pretrain 10 epochs on full data, then active-learn with early stopping.

        The active-learning phase stops once |improvement| < 5e-3 for
        PATIENT consecutive epochs.
        """
        print("ininin in dvi")
        patient = PATIENT
        time_start = time.time()

        for epoch in range(10):
            print("Pretraining")
            _, _ = self.run_epoch(epoch, is_active_learning=False, is_full_data=True)

        for epoch in range(MAX_EPOCH_NUMS):
            print("In active learning")

            prev_loss, loss = self.run_epoch(epoch, is_active_learning=True, is_full_data=False)

            if abs(prev_loss - loss) < 5E-3:
                if patient == 0:
                    break
                else:
                    patient -= 1
            else:
                patient = PATIENT

        time_end = time.time()
        time_spend = time_end - time_start
        print("Time spend: {:.2f} for training vis model...".format(time_spend))

    def record_time(self, save_dir, file_name, operation, iteration, t):
        """Record `t` under evaluation[operation][iteration] in <file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        if operation not in evaluation.keys():
            evaluation[operation] = dict()
        evaluation[operation][iteration] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
| |
|
class DVITrainer(SingleVisTrainer):
    """DVI trainer: the criterion also receives the model to compute a temporal term."""

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE):
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE)

    def train_step(self):
        """Train one epoch with umap/recon/temporal losses.

        Returns:
            The epoch's average total loss (also cached in ``self._loss``).
        """
        self.model = self.model.to(device=self.DEVICE)
        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []
        temporal_losses = []

        t = tqdm(self.edge_loader, leave=True, total=len(self.edge_loader))

        for data in t:
            edge_to, edge_from, a_to, a_from = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, temporal_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs)

            all_loss.append(loss.mean().item())
            umap_losses.append(umap_l.mean().item())
            recon_losses.append(recon_l.mean().item())
            temporal_losses.append(temporal_l.mean().item())

            self.optimizer.zero_grad()
            # (Removed the redundant `loss_new = loss` alias; the backward
            # target is unchanged.)
            loss.mean().backward()
            self.optimizer.step()
        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\ttemporal_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses),
                                                                                   sum(recon_losses) / len(recon_losses),
                                                                                   sum(temporal_losses) / len(temporal_losses),
                                                                                   sum(all_loss) / len(all_loss)))
        return self.loss

    def radius_loss(self, embeddings, center, alpha=1.0):
        """
        Modified radius loss function that tries to maximize the average distance.
        Args:
            embeddings: the 2D embeddings, tensor of shape (N, 2)
            center: the center of the circle in the 2D space, tensor of shape (2,)
            alpha: a coefficient for the radius loss, controlling its importance.
        Returns:
            A scalar tensor representing the radius loss.
        """
        radii = torch.norm(embeddings - center, dim=1)
        normalized_radii = torch.nn.functional.normalize(radii, dim=0, p=2)
        normalized_mean_radii = torch.mean(normalized_radii)

        # Negated so that minimizing the loss maximizes the mean radius.
        return -alpha * normalized_mean_radii

    def orthogonal_loss(self, embeddings, beta=0.001):
        """
        Orthogonal loss function that tries to decorrelate the embeddings.
        Args:
            embeddings: the 2D embeddings, tensor of shape (N, 2)
            beta: a coefficient for the orthogonal loss, controlling its importance.
        Returns:
            A scalar tensor representing the orthogonal loss.
        """
        gram_matrix = torch.mm(embeddings, embeddings.t())
        identity = torch.eye(embeddings.shape[0]).to(embeddings.device)
        loss = torch.norm(gram_matrix - identity)
        return beta * loss

    def distance_order_loss(self, high_embeddings, low_embeddings, high_center, low_center, beta=0.001):
        """
        Distance order preserving loss function.
        Args:
            high_embeddings: the high-dimensional embeddings, tensor of shape (N, D)
            low_embeddings: the 2D embeddings, tensor of shape (N, 2)
            high_center: the center of the sphere in the high-dimensional space, tensor of shape (D,)
            low_center: the center of the circle in the 2D space, tensor of shape (2,)
            beta: a coefficient for the distance order loss, controlling its importance.
        Returns:
            A scalar tensor representing the distance order loss.
        """
        high_distances = torch.norm(high_embeddings - high_center, dim=1)
        low_distances = torch.norm(low_embeddings - low_center, dim=1)

        # NOTE(review): argsort yields ordering *indices*, not ranks; comparing
        # the two argsorts element-wise is not a rank-correlation measure
        # (ranks would need a second argsort).  Behavior kept as-is — confirm
        # the intended semantics before changing.
        high_order = torch.argsort(high_distances)
        low_order = torch.argsort(low_distances)
        high_order = high_order.float()
        low_order = low_order.float()

        loss = torch.norm(high_order - low_order) / high_order.shape[0]

        return beta * loss

    def record_time(self, save_dir, file_name, operation, iteration, t):
        """Record `t` under evaluation[operation][iteration] in <file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        if operation not in evaluation.keys():
            evaluation[operation] = dict()
        evaluation[operation][iteration] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
class DVIActiveLearningTrainer(SingleVisTrainer):
    """DVI trainer used in active-learning pipelines (umap/recon/temporal criterion)."""

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE):
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE)

    def train_step(self):
        """Train one epoch over the edge loader.

        Returns:
            The epoch's average total loss (also cached in ``self._loss``).
        """
        self.model = self.model.to(device=self.DEVICE)

        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []
        temporal_losses = []

        t = tqdm(self.edge_loader, leave=True, total=len(self.edge_loader))

        for data in t:
            edge_to, edge_from, a_to, a_from = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, temporal_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs)

            all_loss.append(loss.mean().item())
            umap_losses.append(umap_l.mean().item())
            recon_losses.append(recon_l.mean().item())
            temporal_losses.append(temporal_l.mean().item())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\ttemporal_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses),
                                                                                   sum(recon_losses) / len(recon_losses),
                                                                                   sum(temporal_losses) / len(temporal_losses),
                                                                                   sum(all_loss) / len(all_loss)))
        return self.loss

    def record_time(self, save_dir, file_name, operation, iteration, t):
        """Record `t` under evaluation[operation][iteration] in <file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        if operation not in evaluation.keys():
            evaluation[operation] = dict()
        evaluation[operation][iteration] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
| |
|
class TVITrainer(SingleVisTrainer):
    """Two-phase trainer: a full-model pass over `edge_loader`, then an
    adversarial pass over `adv_edge_loader` with the encoder frozen.

    The reported epoch loss comes from the first phase only.
    """

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, adv_edge_loader, DEVICE):
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE)
        self.adv_edge_loader = adv_edge_loader

    def disable_grad(self, model):
        """Freeze all parameters of `model`."""
        for param in model.parameters():
            param.requires_grad = False

    def enable_grad(self, model):
        """Unfreeze all parameters of `model`."""
        for param in model.parameters():
            param.requires_grad = True

    def train_step(self):
        """Run one two-phase epoch and return the phase-1 average loss."""
        self.model = self.model.to(device=self.DEVICE)

        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []
        temporal_losses = []

        # Phase 1: train the whole model (encoder gradients enabled).
        t = tqdm(self.edge_loader, leave=True, total=len(self.edge_loader))
        self.enable_grad(self.model.encoder)
        print("enable")
        for data in t:
            edge_to, edge_from, a_to, a_from = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, temporal_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs)

            all_loss.append(loss.mean().item())
            umap_losses.append(umap_l.item())
            recon_losses.append(recon_l.item())
            temporal_losses.append(temporal_l.mean().item())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Phase 2: adversarial edges update only the non-encoder parameters.
        adv_t = tqdm(self.adv_edge_loader, leave=True, total=len(self.adv_edge_loader))

        self.disable_grad(self.model.encoder)
        print("disable")
        for adv_data in adv_t:

            adv_edge_to, adv_edge_from, adv_a_to, adv_a_from = adv_data

            adv_edge_to = adv_edge_to.to(device=self.DEVICE, dtype=torch.float32)
            adv_edge_from = adv_edge_from.to(device=self.DEVICE, dtype=torch.float32)
            adv_a_to = adv_a_to.to(device=self.DEVICE, dtype=torch.float32)
            adv_a_from = adv_a_from.to(device=self.DEVICE, dtype=torch.float32)

            adv_outputs = self.model(adv_edge_to, adv_edge_from)
            adv_umap_l, adv_recon_l, adv_temporal_l, adv_loss = self.criterion(adv_edge_to, adv_edge_from, adv_a_to, adv_a_from, self.model, adv_outputs)

            self.optimizer.zero_grad()
            adv_loss.mean().backward()
            self.optimizer.step()

        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\ttemporal_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses),
                                                                                   sum(recon_losses) / len(recon_losses),
                                                                                   sum(temporal_losses) / len(temporal_losses),
                                                                                   sum(all_loss) / len(all_loss)))
        return self._loss

    def record_time(self, save_dir, file_name, operation, iteration, t):
        """Record `t` under evaluation[operation][iteration] in <file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        if operation not in evaluation.keys():
            evaluation[operation] = dict()
        evaluation[operation][iteration] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
| |
|
class DVIReFineTrainer(SingleVisTrainer):
    """DVI trainer that adds a position-recovery penalty over reference `data`.

    Optionally freezes the encoder so only the decoder is refined.
    """

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE, data, disable_encoder_grad=False, **kwargs):
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE, **kwargs)
        self.disable_encoder_grad = disable_encoder_grad
        # Reference points whose autoencoder round-trip positions should be
        # preserved during refinement.
        self.data = data

    def train(self, PATIENT, MAX_EPOCH_NUMS):
        """Run up to MAX_EPOCH_NUMS epochs with patience-based early stopping."""
        patient = PATIENT
        print("patient", patient)
        time_start = time.time()
        for epoch in range(MAX_EPOCH_NUMS):
            print("====================\nepoch:{}\n===================".format(epoch + 1))
            prev_loss = self.loss
            loss = self.train_step()
            self.lr_scheduler.step()
            # Early stopping: count epochs without sufficient improvement.
            if prev_loss - loss < 5E-3:
                if patient == 0:
                    break
                else:
                    patient -= 1
            else:
                patient = PATIENT

        time_end = time.time()
        time_spend = time_end - time_start
        print("Time spend: {:.2f} for training vis model...".format(time_spend))

    def train_step(self):
        """Train one epoch; adds a position-recovery term to the criterion loss.

        Returns:
            The epoch's average criterion loss (also cached in ``self._loss``).
        """
        self.model = self.model.to(device=self.DEVICE)

        if self.disable_encoder_grad == True:
            # Freeze the encoder so only the decoder is refined.
            disable_grad(self.model.encoder)

        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []
        temporal_losses = []
        recoverposition_losses = []

        # Hoisted loop invariants: the loss module and the reference tensor
        # do not change between batches.
        pos_recover_loss_fn = PositionRecoverLoss(self.DEVICE)
        ref_data = torch.Tensor(self.data).to(self.DEVICE)

        t = tqdm(self.edge_loader, leave=True, total=len(self.edge_loader))

        for data in t:
            edge_to, edge_from, a_to, a_from = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, temporal_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs)

            # Round-trip the reference data through the autoencoder and
            # penalize position drift.  (Removed redundant torch.Tensor()
            # re-wrapping of tensors that were already on-device.)
            new_emb = self.model.encoder(ref_data)
            grid_high = self.model.decoder(new_emb)
            pos_loss = pos_recover_loss_fn(grid_high, ref_data)

            all_loss.append(loss.mean().item())
            umap_losses.append(umap_l.item())
            recon_losses.append(recon_l.item())
            temporal_losses.append(temporal_l.mean().item())
            recoverposition_losses.append(pos_loss.mean().item())

            # BUG FIX: the recovery term was previously a running average of
            # .item() floats — a detached Python scalar that contributed NO
            # gradient.  Use the live tensor so the penalty actually trains.
            loss_new = loss + 1 * pos_loss
            self.optimizer.zero_grad()
            loss_new.mean().backward()

            self.optimizer.step()
        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\ttemporal_l:{:.4f}\tloss:{:.4f}\trecoverposition_loss:{}'.format(sum(umap_losses) / len(umap_losses),
                                                                                                            sum(recon_losses) / len(recon_losses),
                                                                                                            sum(temporal_losses) / len(temporal_losses),
                                                                                                            sum(all_loss) / len(all_loss),
                                                                                                            sum(recoverposition_losses) / len(recoverposition_losses)))
        return self.loss

    def record_time(self, save_dir, file_name, operation, iteration, t):
        """Record `t` under evaluation[operation][iteration] in <file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        if operation not in evaluation.keys():
            evaluation[operation] = dict()
        evaluation[operation][iteration] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
| |
|
| |
|
class OriginDVITrainer(SingleVisTrainer):
    """Original DVI trainer (umap/recon/temporal criterion, no extra terms)."""

    def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE):
        super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE)

    def train_step(self):
        """Train one epoch over the edge loader.

        Returns:
            The epoch's average total loss (also cached in ``self._loss``).
        """
        self.model = self.model.to(device=self.DEVICE)
        self.model.train()
        all_loss = []
        umap_losses = []
        recon_losses = []
        temporal_losses = []

        t = tqdm(self.edge_loader, leave=True, total=len(self.edge_loader))

        for data in t:
            edge_to, edge_from, a_to, a_from = data

            edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32)
            edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32)
            a_to = a_to.to(device=self.DEVICE, dtype=torch.float32)
            a_from = a_from.to(device=self.DEVICE, dtype=torch.float32)

            outputs = self.model(edge_to, edge_from)
            umap_l, recon_l, temporal_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs)
            all_loss.append(loss.mean().item())
            umap_losses.append(umap_l.item())
            recon_losses.append(recon_l.item())
            temporal_losses.append(temporal_l.mean().item())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
        self._loss = sum(all_loss) / len(all_loss)
        self.model.eval()
        print('umap:{:.4f}\trecon_l:{:.4f}\ttemporal_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses),
                                                                                   sum(recon_losses) / len(recon_losses),
                                                                                   sum(temporal_losses) / len(temporal_losses),
                                                                                   sum(all_loss) / len(all_loss)))
        return self.loss

    def record_time(self, save_dir, file_name, operation, iteration, t):
        """Record `t` under evaluation[operation][iteration] in <file_name>.json."""
        save_file = os.path.join(save_dir, file_name + ".json")
        if not os.path.exists(save_file):
            evaluation = dict()
        else:
            # Context manager closes the handle even if json.load raises.
            with open(save_file, "r") as f:
                evaluation = json.load(f)
        if operation not in evaluation.keys():
            evaluation[operation] = dict()
        evaluation[operation][iteration] = round(t, 3)
        with open(save_file, 'w') as f:
            json.dump(evaluation, f)
| | class DVIALMODITrainer(SingleVisTrainer): |
| | def __init__(self, model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE, grid_high_mask, high_bom, high_rad, iteration, data_provider, prev_model, S_N_EPOCHS, B_N_EPOCHS, N_NEIGHBORS,vis_error_indices=None, **kwargs): |
| | super().__init__(model, criterion, optimizer, lr_scheduler, edge_loader, DEVICE, **kwargs) |
| | self.is_first_active_learning = True |
| | self.grid_high_mask = grid_high_mask |
| | self.high_bom = high_bom |
| | self.high_rad = high_rad |
| | self.iteration = iteration |
| | self.data_provider = data_provider |
| | self.prev_model = prev_model |
| | self.S_N_EPOCHS = S_N_EPOCHS |
| | self.B_N_EPOCHS = B_N_EPOCHS |
| | self.N_NEIGHBORS = N_NEIGHBORS |
| | self.vis_error_indices = vis_error_indices |
| | |
| |
|
| | def al_loader(self): |
| | print("evluating") |
| | |
| | |
| | losses = [] |
| | |
| | |
| | self.model.eval() |
| | |
| | if isinstance(self.grid_high_mask, torch.Tensor): |
| | |
| | self.grid_high_mask = self.grid_high_mask.cpu().detach().numpy() |
| |
|
| | grid_pred = self.data_provider.get_pred(self.iteration, self.grid_high_mask).argmax(axis=1) |
| | self.grid_high_mask = torch.tensor(self.grid_high_mask).to(device=self.DEVICE, dtype=torch.float32) |
| | grid_second_high_mask = self.model(self.grid_high_mask,self.grid_high_mask)['recon'][0] |
| | grid_second_high_mask = grid_second_high_mask.cpu().detach().numpy() |
| | grid_second_pred = self.data_provider.get_pred(self.iteration, grid_second_high_mask).argmax(axis=1) |
| |
|
| | error_indices = [i for i in range(len(grid_pred)) if grid_pred[i] != grid_second_pred[i]] |
| | grid_high_error = [self.grid_high_mask[i] for i in error_indices] |
| |
|
| | |
| | threshold = self.high_rad[0] // 2 |
| |
|
| | |
| | filtered_indices = np.where(self.high_rad < threshold) |
| |
|
| | |
| | filtered_centers = self.high_bom[filtered_indices] |
| | filtered_radius = self.high_rad[filtered_indices] |
| |
|
| | cluster_points = [] |
| | uncluster_points = [] |
| |
|
| | |
| | for point in grid_high_error: |
| | point = point.cpu().detach().numpy() |
| | |
| | distances = np.linalg.norm(point - filtered_centers, axis=1) |
| | |
| | |
| | closest_center_index = np.argmin(distances) |
| | |
| | |
| | if distances[closest_center_index] < filtered_radius[closest_center_index]: |
| | |
| | cluster_points.append(point) |
| | else: |
| | |
| | uncluster_points.append(point) |
| | cluster_points = np.array(cluster_points) |
| | uncluster_points = np.array(uncluster_points) |
| | |
| |
|
| | al_spatial_cons = ActiveLearningEpochSpatialEdgeConstructor(self.data_provider, self.iteration, self.S_N_EPOCHS, self.B_N_EPOCHS, self.N_NEIGHBORS, cluster_points, uncluster_points, self.high_bom) |
| | al_edge_to, al_edge_from, al_probs, al_feature_vectors, al_attention = al_spatial_cons.construct() |
| |
|
| | al_probs = al_probs / (al_probs.max()+1e-3) |
| | eliminate_zeros = al_probs>5e-2 |
| | al_edge_to = al_edge_to[eliminate_zeros] |
| | al_edge_from = al_edge_from[eliminate_zeros] |
| | al_probs = al_probs[eliminate_zeros] |
| | |
| | dataset = DVIDataHandler(al_edge_to, al_edge_from, al_feature_vectors, al_attention) |
| |
|
| |
|
| | n_samples = int(np.sum(self.S_N_EPOCHS * al_probs) // 1) |
| |
|
| | |
| | if len(al_edge_to) > pow(2,24): |
| | sampler = CustomWeightedRandomSampler(al_probs, n_samples, replacement=True) |
| | else: |
| | sampler = WeightedRandomSampler(al_probs, n_samples, replacement=True) |
| | new_loader = DataLoader(dataset, batch_size=2000, sampler=sampler, num_workers=8, prefetch_factor=10) |
| | |
| | if self.vis_error_indices: |
| | lens_edge = len(al_edge_from) |
| | new_edge_to = [] |
| | new_edge_from = [] |
| | new_feature = [] |
| | new_attention = [] |
| | new_probs = [] |
| | mapping = {} |
| |
|
| | for i in range(lens_edge): |
| | if al_edge_from[i] in self.vis_error_indices or al_edge_to[i] in self.vis_error_indices: |
| | new_edge_to.append(al_edge_to[i]) |
| | new_edge_from.append(al_edge_from[i]) |
| | new_probs.append(al_probs[i]) |
| |
|
| | |
| | new_edge_from = np.array(new_edge_from) |
| | new_edge_to = np.array(new_edge_to) |
| | |
| | |
| | new_probs = np.array(new_probs) |
| |
|
| | dataset = DVIDataHandler(new_edge_to, new_edge_from, al_feature_vectors, al_attention) |
| | n_samples = int(np.sum(self.S_N_EPOCHS * new_probs) // 1) |
| |
|
| | |
| | |
| | |
| | if lens_edge > pow(2,24): |
| | sampler = CustomWeightedRandomSampler(new_probs, n_samples, replacement=True) |
| | else: |
| | sampler = WeightedRandomSampler(new_probs, n_samples, replacement=True) |
| | |
| | new_loader = DataLoader(dataset, batch_size=2000, sampler=sampler, num_workers=8, prefetch_factor=10) |
| | |
| |
|
| | |
| | return losses, new_loader |
| | |
| | def train_step(self, edge_loader): |
| | |
| | self.model = self.model.to(device=self.DEVICE) |
| | self.model.train() |
| | all_loss = [] |
| | umap_losses = [] |
| | recon_losses = [] |
| | temporal_losses = [] |
| |
|
| |
|
| | t = tqdm(edge_loader, leave=True, total=len(edge_loader)) |
| | |
| |
|
| | for data in t: |
| |
|
| | edge_to, edge_from, a_to, a_from = data |
| |
|
| | edge_to = edge_to.to(device=self.DEVICE, dtype=torch.float32) |
| | edge_from = edge_from.to(device=self.DEVICE, dtype=torch.float32) |
| | a_to = a_to.to(device=self.DEVICE, dtype=torch.float32) |
| | a_from = a_from.to(device=self.DEVICE, dtype=torch.float32) |
| |
|
| | outputs = self.model(edge_to, edge_from) |
| | umap_l, recon_l, temporal_l, loss = self.criterion(edge_to, edge_from, a_to, a_from, self.model, outputs) |
| | |
| | |
| | all_loss.append(loss.mean().item()) |
| | umap_losses.append(umap_l.mean().item()) |
| | recon_losses.append(recon_l.mean().item()) |
| | temporal_losses.append(temporal_l.mean().item()) |
| | |
| | self.optimizer.zero_grad() |
| | loss.mean().backward() |
| | self.optimizer.step() |
| |
|
| | self._loss = sum(all_loss) / len(all_loss) |
| | self.model.eval() |
| | print('umap:{:.4f}\trecon_l:{:.4f}\ttemporal_l:{:.4f}\tloss:{:.4f}'.format(sum(umap_losses) / len(umap_losses), |
| | sum(recon_losses) / len(recon_losses), |
| | sum(temporal_losses) / len(temporal_losses), |
| | sum(all_loss) / len(all_loss))) |
| | return self.loss |
| | |
| | def run_epoch(self, epoch, current_loader, is_active_learning=False, is_full_data=False): |
| | print("====================\nepoch:{}\n===================".format(epoch+1)) |
| | start_time = time.time() |
| |
|
| | if is_active_learning and is_full_data == False: |
| | _, current_loader = self.al_loader() |
| | |
| | |
| | if self.is_first_active_learning: |
| | print("change learning rate") |
| | for param_group in self.optimizer.param_groups: |
| | param_group['lr'] *= 0.1 |
| | self.is_first_active_learning = False |
| | |
| | prev_loss = self.loss |
| |
|
| | if is_full_data: |
| | print("full data") |
| | loss = self.train_step(self.edge_loader) |
| | else: |
| | loss = self.train_step(current_loader) |
| | |
| | self.lr_scheduler.step() |
| |
|
| | elapsed_time = time.time() - start_time |
| | print("Epoch completed in: {:.2f} seconds".format(elapsed_time)) |
| |
|
| | return prev_loss, loss, current_loader |
| | |
| | def train(self, PATIENT, MAX_EPOCH_NUMS): |
| | start_flag = 1 |
| | if start_flag: |
| | current_loader = self.edge_loader |
| | start_flag = 0 |
| | print("ininin in dvi") |
| | patient = PATIENT |
| | time_start = time.time() |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | for epoch in range(MAX_EPOCH_NUMS): |
| | print("In active learning") |
| | |
| | prev_loss, loss, current_loader = self.run_epoch(epoch, current_loader, is_active_learning=True, is_full_data=False) |
| | |
| | |
| | if abs(prev_loss - loss) < 5E-3: |
| | if patient == 0: |
| | break |
| | else: |
| | patient -= 1 |
| | else: |
| | patient = PATIENT |
| |
|
| | time_end = time.time() |
| | time_spend = time_end - time_start |
| | print("Time spend: {:.2f} for training vis model...".format(time_spend)) |
| |
|
| | self.prev_model.load_state_dict(self.model.state_dict()) |
| | for param in self.prev_model.parameters(): |
| | param.requires_grad = False |
| | w_prev = dict(self.prev_model.named_parameters()) |
| | |
| | |
| | def record_time(self, save_dir, file_name, operation, iteration, t): |
| | |
| | save_file = os.path.join(save_dir, file_name+".json") |
| | if not os.path.exists(save_file): |
| | evaluation = dict() |
| | else: |
| | f = open(save_file, "r") |
| | evaluation = json.load(f) |
| | f.close() |
| | if operation not in evaluation.keys(): |
| | evaluation[operation] = dict() |
| | evaluation[operation][iteration] = round(t, 3) |
| | with open(save_file, 'w') as f: |
| | json.dump(evaluation, f) |