| | import torch |
| | import logging |
| | import numpy as np |
| | import torch.nn as nn |
| | from typing import Callable, List |
| | from accelerate import Accelerator |
| | from sklearn.linear_model import LinearRegression |
| |
|
| |
|
class eval_mode:
    """
    Context manager that puts the given models into eval mode (and optionally
    disables gradient tracking), restoring each model's previous training
    state on exit.

    Args:
        *models: modules whose ``training`` flag is toggled off inside the block.
        no_grad: if True, also wrap the block in ``torch.no_grad()``.
    """

    def __init__(self, *models, no_grad=False):
        self.models = models
        self.no_grad = no_grad
        # Created lazily per entry in __enter__; see note there.
        self.no_grad_context = None

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)
        if self.no_grad:
            # BUG FIX: the original created one torch.no_grad() in __init__ and
            # re-entered it on every `with`. Context-manager instances save the
            # previous grad state on __enter__, so reusing a single instance
            # (especially nested) clobbers that saved state. A fresh context
            # per entry is always safe.
            self.no_grad_context = torch.no_grad()
            self.no_grad_context.__enter__()

    def __exit__(self, *args):
        if self.no_grad:
            self.no_grad_context.__exit__(*args)
            self.no_grad_context = None
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Never suppress exceptions raised inside the block.
        return False
| |
|
| |
|
def embed_trajectory_dataset(
    model,
    dataset,
    obs_only=True,
    device=None,
    embed_goal=False,
):
    """
    Encode every trajectory in `dataset` with `model`.

    Args:
        model: encoder module; DDP-wrapped models are delegated to the
            gather-aware `embed_trajectory_dataset_ddp`.
        dataset: indexable dataset yielding `(obs, *rest)` tuples; if
            `embed_goal`, the last element of `rest` is a goal observation.
        obs_only: return only observation embeddings.
        device: device for the returned tensors (defaults to the
            accelerator's device).
        embed_goal: also run the trailing goal element through the encoder.

    Returns:
        A list with one entry per trajectory: a tensor when `obs_only`,
        otherwise a tuple `(obs_enc, *rest)`.
    """
    # Exact type check on purpose: the DDP path asserts this same type.
    if type(model) is nn.parallel.DistributedDataParallel:
        return embed_trajectory_dataset_ddp(
            model,
            dataset,
            obs_only=obs_only,
            device=device,
            embed_goal=embed_goal,
        )

    accelerator = Accelerator()
    target_device = device or accelerator.device
    encoded = []
    with eval_mode(model, no_grad=True):
        for idx in range(len(dataset)):
            obs, *extras = dataset[idx]
            obs_emb = model(obs.to(accelerator.device)).to(target_device)
            if obs_only:
                encoded.append(obs_emb)
                continue
            if embed_goal:
                # The trailing element of the sample is the goal observation;
                # embed it the same way as the observations.
                *extras, goal = extras
                goal_emb = model(goal.to(accelerator.device)).to(target_device)
                extras = list(extras) + [goal_emb]
            encoded.append((obs_emb, *[t.to(target_device) for t in extras]))
    return encoded
| |
|
| |
|
def embed_trajectory_dataset_ddp(
    model: nn.Module,
    dataset,
    obs_only=True,
    device=None,
    embed_goal=False,
):
    """
    Embed every trajectory of `dataset` with a DistributedDataParallel model,
    gathering results across ranks so every process ends up with the full,
    dataset-ordered list of embeddings.

    Args:
        model: a DDP-wrapped encoder (asserted below).
        dataset: indexable dataset; must expose `get_seq_length(i)`. Each item
            is `(obs, *rest)`; if `embed_goal`, the last element of `rest` is
            a goal observation that is also embedded.
        obs_only: if True, return only observation embeddings.
        device: device for the returned tensors; defaults to the accelerator's.
        embed_goal: also embed the trailing goal element of each item.

    Returns:
        List with one entry per trajectory, truncated to its true length: a
        tensor when `obs_only`, otherwise a list of tensors.
    """
    assert type(model) is nn.parallel.DistributedDataParallel, "Model must be DDP"
    embeddings = []
    accelerator = Accelerator()
    # batch_size=1 and shuffle=False so the cross-rank gather below preserves
    # dataset order.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        pin_memory=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Pad all sequences to a common length so tensors from different ranks can
    # be gathered into one stack. NOTE(review): assumes dim 1 is the time
    # axis of the encoder output — confirm against the model.
    max_T = max(dataset.get_seq_length(i) for i in range(len(dataset)))
    with eval_mode(model, no_grad=True):
        for obs, *rest in dataloader:
            obs = obs.to(accelerator.device)
            obs_enc = model(obs)
            obs_enc = pad_to_length(obs_enc, max_T, dim=1)
            # gather_for_metrics also drops samples duplicated by the sampler
            # in the final batch, keeping exactly len(dataset) rows overall.
            obs_enc = accelerator.gather_for_metrics(obs_enc)
            if obs_only:
                embeddings.append(obs_enc)
            else:
                if embed_goal:
                    # Last element of the sample is the goal; embed it too.
                    goal = rest[-1]
                    rest = rest[:-1]
                    goal = goal.to(accelerator.device)
                    goal_enc = model(goal)
                    rest.append(goal_enc)
                rest = [x.to(accelerator.device) for x in rest]
                rest = [pad_to_length(x, max_T, dim=1) for x in rest]
                rest = [accelerator.gather_for_metrics(x) for x in rest]
                embeddings.append((obs_enc, *rest))

    device = device or accelerator.device

    result = []
    if obs_only:
        embeddings = torch.cat(embeddings, dim=0)
        assert len(embeddings) == len(dataset)
    else:
        # Transpose the list of per-step tuples into per-field tensors,
        # concatenated along the batch dimension.
        embeddings = [torch.cat(x, dim=0) for x in zip(*embeddings)]
        assert len(embeddings[0]) == len(dataset)
    for i in range(len(dataset)):
        # Trim the padding back off using each trajectory's true length.
        T = dataset.get_seq_length(i)
        if obs_only:
            result.append(embeddings[i, :T].to(device))
        else:
            result.append([x[i, :T].to(device) for x in embeddings])
    return result
| |
|
| |
|
def pad_to_length(x: torch.Tensor, length: int, dim: int = 0):
    """
    Zero-pad `x` along `dim` up to `length`, appending zeros at the end.
    Returns `x` unchanged when it is already long enough.
    """
    deficit = length - x.shape[dim]
    if deficit <= 0:
        return x
    filler_shape = list(x.shape)
    filler_shape[dim] = deficit
    # new_zeros keeps the input's device and dtype.
    filler = x.new_zeros(filler_shape)
    return torch.cat((x, filler), dim=dim)
| |
|
| |
|
def repeat_start_to_length(x: torch.Tensor, length: int, dim: int = 0):
    """
    Left-pad `x` along `dim` up to `length` by repeating its first slice.
    Returns `x` unchanged when it is already long enough.
    """
    deficit = length - x.shape[dim]
    if deficit <= 0:
        return x
    # index_select keeps the selected dim (size 1) so repeat can tile it.
    head = x.index_select(dim, torch.tensor(0, device=x.device))
    reps = [1] * x.ndim
    reps[dim] = deficit
    return torch.cat([head.repeat(*reps), x], dim=dim)
| |
|
| |
|
def nn_lookup(
    query: torch.Tensor,
    pool: torch.Tensor,
    metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
):
    """
    Compute all pairwise distances between `query` [N, D] and `pool` [M, D]
    under `metric`, returning per-query distances and pool indices sorted
    ascending by distance (shapes [N, M] each).
    """
    n_query, n_pool = len(query), len(pool)
    # Tile so metric sees every (query, pool) pair as aligned rows.
    tiled_query = query.repeat_interleave(n_pool, dim=0)
    tiled_pool = pool.repeat((n_query, 1))
    distances = metric(tiled_query, tiled_pool).view(n_query, n_pool)
    sorted_dist, sorted_idx = distances.sort(dim=1)
    return sorted_dist, sorted_idx
| |
|
| |
|
def batch_knn(
    query: torch.Tensor,
    pool: torch.Tensor,
    metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    k: int,
    batch_size: int,
):
    """
    Return the k nearest neighbors of query in pool using metric.
    Input:
        query: Tensor[N, D] of query points
        pool: Tensor[M, D] of pool points
        metric: Callable[[Tensor[N, D], Tensor[M, D]], Tensor[N, M]] distance function
        k: int number of neighbors to return
        batch_size: int batch size for computation. Batched over query.
    Output: (distances, indices)
        distances: Tensor[N, k] of distances to the k nearest neighbors
        indices: Tensor[N, k] of indices of the k nearest neighbors
    """
    dist_chunks, idx_chunks = [], []
    for start in range(0, len(query), batch_size):
        # Move only the current slice of queries onto the pool's device.
        chunk = query[start : start + batch_size].to(pool.device)
        chunk_dist, chunk_idx = nn_lookup(chunk, pool, metric)
        # nn_lookup returns fully sorted rows; keep just the k nearest.
        dist_chunks.append(chunk_dist[:, :k])
        idx_chunks.append(chunk_idx[:, :k])
    return torch.cat(dist_chunks), torch.cat(idx_chunks)
| |
|
| |
|
def linear_probe_with_trajectory_split(
    X: torch.Tensor,
    y: torch.Tensor,
    train_idx: List[int],
    val_idx: List[int],
):
    """
    Fit a linear probe on ALL trajectories and report its MSE on the
    train-split and val-split trajectories separately.

    Args:
        X: sequence of per-trajectory feature tensors (indexable by int).
        y: sequence of per-trajectory target tensors, aligned with X.
        train_idx: trajectory indices making up the train split.
        val_idx: trajectory indices making up the val split.

    Returns:
        Dict with keys "linear_probe_mse_train_all" and
        "linear_probe_mse_val_all". NOTE: the probe is deliberately fitted on
        the full dataset (hence the "_all" suffix), not just the train split.
    """
    def _stack(data, idx):
        # Concatenate the selected trajectories into one numpy matrix.
        return torch.cat([data[i] for i in idx]).cpu().numpy()

    X_train, y_train = _stack(X, train_idx), _stack(y, train_idx)
    X_val, y_val = _stack(X, val_idx), _stack(y, val_idx)

    X_all = torch.cat(X).cpu().numpy()
    y_all = torch.cat(y).cpu().numpy()

    probe = LinearRegression()
    probe.fit(X_all, y_all)
    return {
        "linear_probe_mse_train_all": np.mean((probe.predict(X_train) - y_train) ** 2).item(),
        "linear_probe_mse_val_all": np.mean((probe.predict(X_val) - y_val) ** 2).item(),
    }
| |
|
| |
|
def mse(a: torch.Tensor, b: torch.Tensor):
    """Per-row mean squared error between `a` and `b`, reduced over dim 1."""
    diff = a - b
    return diff.pow(2).mean(dim=1)
| |
|
| |
|
def mahalanobis(a, b, VI):
    """
    Mahalanobis distance between `a` and `b` given `VI`, the inverse
    covariance matrix: sqrt((a-b)^T VI (a-b)), reduced over the last dim.
    """
    delta = a - b
    quadratic_form = (delta * (delta @ VI)).sum(dim=-1)
    return quadratic_form.sqrt()
| |
|
| |
|
class OLS:
    """
    OLS (ordinary least squares) in torch.
    NOTE: discrepancy with sklearn's LinearRegression when ill-conditioned; reverting to sklearn for now
    """

    def __init__(self, bias=True, fallback_to_cpu=True):
        # bias: append a constant-ones column to X so an intercept is fitted.
        # fallback_to_cpu: when the default lstsq solution contains NaNs
        # (e.g. rank-deficient input on GPU), retry on CPU with the SVD-based
        # "gelss" driver instead of raising.
        self.bias = bias
        self.w = None  # weight matrix set by fit(); bias row last when bias=True
        self.fallback_to_cpu = fallback_to_cpu

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """
        Fit the model by least squares.

        Args:
            X: design matrix [N, D].
            y: targets [N] or [N, K].
        Returns:
            self (fluent API).
        Raises:
            ValueError: solution has NaNs and fallback_to_cpu is False.
        """
        if self.bias:
            X = torch.cat([X, torch.ones(X.shape[0], 1, device=X.device)], dim=1)
        self.w = torch.linalg.lstsq(X, y).solution
        if torch.isnan(self.w).any():
            cond = torch.linalg.cond(X)
            rank = torch.linalg.matrix_rank(X)
            msg = f"NaNs in OLS solution. Input shape: {X.shape}, cond: {cond}, rank: {rank}"
            if not self.fallback_to_cpu:
                raise ValueError(msg)
            # BUG FIX: logging.warn is a deprecated alias (removed in Python
            # 3.13); use logging.warning.
            logging.warning(f"{msg}; Falling back to CPU with gelss driver.")
            # gelss (SVD) handles rank-deficient systems the default CUDA/gels
            # driver cannot.
            self.w = torch.linalg.lstsq(X.cpu(), y.cpu(), driver="gelss").solution
            self.w = self.w.to(X.device)
        return self

    def predict(self, X: torch.Tensor):
        """
        Predict targets for X with the fitted weights.

        Raises:
            ValueError: if fit() has not been called.
        """
        if self.w is None:
            raise ValueError("Model not fitted")
        if self.bias:
            X = torch.cat([X, torch.ones(X.shape[0], 1, device=X.device)], dim=1)
        return X @ self.w
| |
|
| |
|
class SGDClassifier:
    """
    Minimal linear (softmax) classifier trained with mini-batch AdamW.

    Mirrors a small subset of sklearn's SGDClassifier API: fit / predict /
    score, with a fluent fit() that returns self.
    """

    def __init__(self, lr=1e-4, max_iter=1000, tol=1e-3, batch_size=2048):
        # lr: AdamW learning rate.
        # max_iter: maximum number of epochs over the data.
        # tol: stop early once the mean epoch loss drops below this.
        # batch_size: mini-batch size per optimizer step.
        self.lr = lr
        self.max_iter = max_iter
        self.tol = tol
        self.batch_size = batch_size

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """
        Train the linear head on features X [N, D] and integer labels y [N]
        in [0, C). Returns self.

        Raises:
            ValueError: if X is empty.
        """
        n_samples, input_dim = X.shape
        # Guard: the epoch loop below divides by the batch count, and
        # y.max() on an empty tensor would raise an opaque error anyway.
        if n_samples == 0:
            raise ValueError("Cannot fit SGDClassifier on an empty dataset")
        n_classes = y.max().item() + 1
        self.linear = nn.Linear(input_dim, n_classes).to(X.device)
        optimizer = torch.optim.AdamW(
            self.linear.parameters(), lr=self.lr, weight_decay=0.0
        )
        criterion = nn.CrossEntropyLoss()
        converged = False
        for epoch in range(self.max_iter):
            total_loss = 0.0
            n_batches = 0
            # Fresh shuffle each epoch; sample directly on the data's device
            # instead of allocating on CPU and copying over.
            indices = torch.randperm(n_samples, device=X.device)
            for i in range(0, n_samples, self.batch_size):
                batch_indices = indices[i : i + self.batch_size]
                batch_X, batch_y = X[batch_indices], y[batch_indices]
                optimizer.zero_grad()
                logits = self.linear(batch_X)
                loss = criterion(logits, batch_y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1
            avg_loss = total_loss / n_batches
            if avg_loss < self.tol:
                converged = True
                break
        # BUG FIX: the old check (`epoch + 1 < max_iter`) mislabeled a run
        # that converged exactly on the final epoch as "Max iter reached";
        # track convergence explicitly instead.
        if converged:
            logging.info(f"Converged at epoch {epoch + 1}.")
        else:
            logging.info(f"Max iter reached. Final loss {avg_loss}")
        return self

    def predict(self, X: torch.Tensor):
        """Return the argmax class label for each row of X."""
        with torch.no_grad():
            return torch.argmax(self.linear(X), dim=1)

    def score(self, X: torch.Tensor, y: torch.Tensor):
        """Return mean accuracy of predict(X) against labels y."""
        return (self.predict(X) == y).float().mean().item()
| |
|