| | import unittest |
| | import torch |
| | import torch.nn as nn |
| | from omegaconf import OmegaConf |
| | from pytorch_lightning import seed_everything |
| | from model.vae.vqvae import VQAutoEncoder |
| |
|
class TestVQAutoEncoder(unittest.TestCase):
    """Smoke tests for ``VQAutoEncoder`` built from a small in-memory config.

    One model instance is constructed in ``setUpClass`` and shared by every
    test method, so no test should depend on state another test leaves behind.
    """

    # Input geometry used by every test: a batch of random 3-channel images.
    BATCH_SIZE = 2
    CHANNELS = 3
    HEIGHT = 512
    WIDTH = 512

    @classmethod
    def setUpClass(cls):
        """Build the config, seed all RNGs, and construct the shared model once."""
        config = {
            'model': {
                'encoder': {
                    'module_name': 'model.vae.cnn',
                    'class_name': 'Encoder2D',
                    'output_channels': 512,
                },
                'decoder': {
                    'module_name': 'model.vae.cnn',
                    'class_name': 'Decoder2D',
                    'input_dim': 512,
                },
                'latent_dim': 512,
            },
            'optimizer': {
                'lr': 1e-4,
                'weight_decay': 0.0,
                'adam_beta1': 0.9,
                'adam_beta2': 0.999,
                'adam_epsilon': 1e-8,
            },
            'loss': {
                # Each weight must appear once: a repeated key in a dict
                # literal silently overwrites the earlier entry.
                # NOTE(review): if a third loss weight was intended here,
                # add it under its own key.
                'l_w_recon': 1.0,
                'l_w_embedding': 1.0,
            },
        }
        cls.config = OmegaConf.create(config)
        # Seed before construction so weight initialisation is reproducible.
        seed_everything(42)
        cls.model = VQAutoEncoder(cls.config)
        cls.model.configure_model()

    def _image_shape(self):
        """Return the (batch, channels, height, width) shape used by the tests."""
        return (self.BATCH_SIZE, self.CHANNELS, self.HEIGHT, self.WIDTH)

    def _random_images(self):
        """Return a standard-normal tensor with shape ``self._image_shape()``."""
        return torch.randn(*self._image_shape())

    def test_model_initialization(self):
        """Test that the model and its components are initialized correctly."""
        self.assertIsInstance(self.model, VQAutoEncoder)
        self.assertIsInstance(self.model.encoder, nn.Module)
        self.assertIsInstance(self.model.decoder, nn.Module)
        self.assertTrue(hasattr(self.model, 'quantizer'))

    def test_encode_decode(self):
        """Test that encode yields a single latent per sample and decode restores the input shape."""
        x = self._random_images()

        quant, emb_loss, info = self.model.encode(x)
        self.assertEqual(
            quant.shape,
            (self.BATCH_SIZE, 1, self.model.config.model.latent_dim),
        )
        self.assertIsInstance(emb_loss, torch.Tensor)
        self.assertIsInstance(info, tuple)

        dec = self.model.decode(quant)
        self.assertEqual(dec.shape, self._image_shape())

    def test_forward(self):
        """Test the forward pass of the model."""
        x = self._random_images()

        # Call the module itself rather than ``.forward`` so that
        # nn.Module.__call__ (and any registered hooks) runs as in real use.
        dec, emb_loss, info = self.model(x)

        self.assertEqual(dec.shape, self._image_shape())
        self.assertIsInstance(emb_loss, torch.Tensor)
        self.assertIsInstance(info, tuple)

    def test_training_step(self):
        """Test that the training step returns a differentiable loss tensor."""
        batch = {'pixel_values_vid': self._random_images()}

        # NOTE(review): called without a batch_idx; assumes
        # VQAutoEncoder.training_step does not require one -- confirm.
        loss = self.model.training_step(batch)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertTrue(loss.requires_grad)

    def test_validation_step(self):
        """Test that the validation step returns a loss tensor."""
        batch = {'pixel_values_vid': self._random_images()}

        # Lightning runs validation with autograd disabled; mirror that here
        # and avoid keeping a large autograd graph alive.
        with torch.no_grad():
            loss = self.model.validation_step(batch)
        self.assertIsInstance(loss, torch.Tensor)
|
| |
|
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase defined in this module.
    unittest.main(module="__main__")
| |
|