| | """ |
| | A class to record training dynamics, including: |
| | 1. loss |
| | 2. uncertainty |
| | 3. position |
| | 4. velocity |
| | 5. acceleration |
| | 6. hard samples |
| | 7. training dynamics |
| | """ |
| | import numpy as np |
| | import umap |
| | import matplotlib.pyplot as plt |
| |
|
def softmax(x, axis=None):
    """Return the softmax of *x*, computed in a numerically stable way.

    Args:
        x: Array-like of scores.
        axis: Axis along which the result is normalised. The default
            ``None`` normalises over every element of ``x``.

    Returns:
        An array with the same shape as ``x`` whose entries sum to 1 along
        ``axis`` (or over the whole array when ``axis`` is ``None``).
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(x)
    # Subtracting the maximum leaves the result mathematically unchanged but
    # keeps np.exp from overflowing to inf (which would produce nan).
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
| |
|
def cross_entropy(data, y):
    """Return the per-row cross-entropy between softmax(data) and labels y.

    Args:
        data: Array-like of shape ``(n, num_classes)`` holding one row of
            scores per sample.
        y: Array-like of ``n`` integer class indices, one per row of
            ``data``.

    Returns:
        A float array of length ``n`` with ``-log(softmax(data[i])[y[i]])``
        for every row ``i``.

    Raises:
        ValueError: If ``data`` and ``y`` have different lengths.
    """
    data = np.asarray(data, dtype=np.float64)
    y = np.asarray(y)
    if len(data) != len(y):
        raise ValueError(
            f"data has {len(data)} rows but y has {len(y)} labels"
        )
    if len(data) == 0:
        return np.zeros(0)

    # Log-softmax via the shifted log-sum-exp: avoids exp() overflow and the
    # log(0) = -inf that log(softmax(...)) produces for very small scores.
    shifted = data - np.max(data, axis=1, keepdims=True)
    log_p = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))

    # Index the label column directly. The number of classes comes from the
    # width of `data`, so labels that do not cover every class are handled.
    return -log_p[np.arange(len(y)), y]
| |
|
| |
|
class TD:
    """Record how per-sample quantities evolve over the saved epochs.

    Every ``*_dynamics`` method visits the epochs
    ``range(data_provider.s, data_provider.e + 1, data_provider.p)`` and
    returns one array in which axis 0 is the first axis of the per-epoch
    result (presumably the training samples - confirm against the data
    provider) and axis 1 indexes the visited epochs.
    """

    def __init__(self, data_provider, projector) -> None:
        """Store the data provider and the projector used by the methods.

        Args:
            data_provider: Object exposing ``s``, ``e``, ``p``,
                ``train_labels(epoch)``, ``train_representation(epoch)`` and
                ``get_pred(epoch, representation)``.
            projector: Object exposing ``batch_project(epoch, representation)``.
        """
        self.data_provider = data_provider
        self.projector = projector

    def _epochs(self):
        """Return the range of epochs to visit (end epoch inclusive)."""
        provider = self.data_provider
        return range(provider.s, provider.e + 1, provider.p)

    def _start_labels(self):
        """Return the training labels reported for the first epoch."""
        return self.data_provider.train_labels(self.data_provider.s)

    def _predictions(self, epoch):
        """Return the data provider's predictions for the training set at *epoch*."""
        representation = self.data_provider.train_representation(epoch)
        return self.data_provider.get_pred(epoch, representation)

    def _stack_over_epochs(self, per_epoch):
        """Evaluate ``per_epoch(epoch)`` for each epoch and stack on axis 1.

        Results are gathered in a list and stacked once; concatenating inside
        the loop would re-copy the accumulated array on every epoch.

        Raises:
            ValueError: If the epoch range is empty.
        """
        snapshots = [per_epoch(epoch) for epoch in self._epochs()]
        if not snapshots:
            raise ValueError(
                "no epochs to visit: check data_provider.s, .e and .p"
            )
        return np.stack(snapshots, axis=1)

    def loss_dynamics(self):
        """Return the cross-entropy loss of every sample at every epoch.

        Returns:
            Array whose entry ``[i, t]`` is the loss of sample ``i`` at the
            ``t``-th visited epoch, computed against the first epoch's labels.
        """
        labels = self._start_labels()
        return self._stack_over_epochs(
            lambda epoch: cross_entropy(self._predictions(epoch), labels)
        )

    def uncertainty_dynamics(self):
        """Return each sample's prediction value at its label, per epoch.

        Returns:
            Array whose entry ``[i, t]`` is ``pred[i, labels[i]]`` at the
            ``t``-th visited epoch. Whether ``pred`` holds logits or
            probabilities depends on ``data_provider.get_pred``.
        """
        labels = self._start_labels()
        sample_idx = np.arange(len(labels))
        return self._stack_over_epochs(
            lambda epoch: self._predictions(epoch)[sample_idx, labels]
        )

    def pred_dynamics(self):
        """Return the full prediction of every sample at every epoch.

        Returns:
            Array whose entry ``[i, t]`` is the prediction vector of sample
            ``i`` at the ``t``-th visited epoch.
        """
        return self._stack_over_epochs(self._predictions)

    def dloss_dt_dynamics(self):
        """Placeholder: not implemented yet, always returns ``None``."""
        return

    def position_dynamics(self):
        """Return the projected position of every sample at every epoch.

        Returns:
            Array whose entry ``[i, t]`` is the output of
            ``projector.batch_project`` for sample ``i`` at the ``t``-th
            visited epoch.
        """

        def project(epoch):
            representation = self.data_provider.train_representation(epoch)
            return self.projector.batch_project(epoch, representation)

        return self._stack_over_epochs(project)

    def velocity_dynamics(self):
        """Return the difference between consecutive positions.

        The epoch axis is one shorter than that of ``position_dynamics``.
        """
        positions = self.position_dynamics()
        return positions[:, 1:, :] - positions[:, :-1, :]

    def acceleration_dynamics(self):
        """Return the difference between consecutive velocities.

        The epoch axis is two shorter than that of ``position_dynamics``.
        """
        velocities = self.velocity_dynamics()
        return velocities[:, 1:, :] - velocities[:, :-1, :]

    def show_ground_truth(self, trajectories, noise_idxs, save_path=None):
        """Scatter-plot a 2-D UMAP embedding of the trajectories.

        Each trajectory is flattened to a vector and embedded with UMAP.
        Points are coloured by the first epoch's training labels, and the
        points selected by ``noise_idxs`` are drawn again in black on top.

        Args:
            trajectories: Array with one trajectory per entry along axis 0.
            noise_idxs: Index (or mask) selecting the points to highlight.
            save_path: Where to save the figure. When ``None`` the figure is
                shown interactively instead.
        """
        flat = trajectories.reshape(len(trajectories), -1)
        embeddings = umap.UMAP().fit_transform(flat)
        labels = self._start_labels()

        plt.scatter(
            embeddings[:, 0],
            embeddings[:, 1],
            s=.3,
            c=labels,
            cmap="tab10")

        plt.scatter(
            embeddings[:, 0][noise_idxs],
            embeddings[:, 1][noise_idxs],
            s=.4,
            c='black')

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path)
            # Close the figure so a later call starts from a clean canvas
            # instead of drawing on top of this plot.
            plt.close()
| | |
| |
|
| | |