| |
| import unittest |
| from collections import OrderedDict |
| import torch |
| from torch import nn |
|
|
| from detectron2.checkpoint.c2_model_loading import align_and_update_state_dicts |
| from detectron2.utils.logger import setup_logger |
|
|
|
|
class TestCheckpointer(unittest.TestCase):
    """Tests for ``align_and_update_state_dicts`` (detectron2 checkpoint loading)."""

    def setUp(self):
        # Initialize detectron2's logger before each test.
        setup_logger()

    def create_complex_model(self):
        """Build a small nested model plus a checkpoint whose keys only partially match it.

        Returns:
            tuple: ``(model, state_dict)``.
            ``model`` is an ``nn.Module`` with parameters ``block1.layer1.*``,
            ``layer2.*`` and ``res.layer2.*``.
            ``state_dict`` is an ``OrderedDict`` of random tensors keyed
            ``layer1.*``, ``layer2.*`` and ``res.layer2.*``.
            The checkpoint's ``layer1.*`` keys lack the model's ``block1.``
            prefix, and ``layer2`` appears both at the top level and under
            ``res``. Loading therefore needs name alignment rather than exact
            key matching.
        """
        m = nn.Module()
        m.block1 = nn.Module()
        m.block1.layer1 = nn.Linear(2, 3)
        m.layer2 = nn.Linear(3, 2)
        m.res = nn.Module()
        m.res.layer2 = nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)
        return m, state_dict

    def test_complex_model_loaded(self):
        """Aligned checkpoint tensors load into the model, with and without DataParallel."""
        for add_data_parallel in [False, True]:
            # subTest names the failing configuration and still runs the other one.
            with self.subTest(add_data_parallel=add_data_parallel):
                model, state_dict = self.create_complex_model()
                if add_data_parallel:
                    # Wrapping changes the model's state-dict key names (PyTorch
                    # adds a "module." prefix), which alignment must tolerate.
                    model = nn.DataParallel(model)
                model_sd = model.state_dict()

                sd_to_load = align_and_update_state_dicts(model_sd, state_dict)
                model.load_state_dict(sd_to_load)

                # zip() silently truncates; make sure every entry gets compared.
                self.assertEqual(len(model_sd), len(state_dict))
                # The pairing below relies on both dicts listing their tensors in
                # the same order: block1.layer1/layer1, layer2, res.layer2.
                for loaded, stored in zip(model_sd.values(), state_dict.values()):
                    # Different tensor objects: the checkpoint's tensors are not aliased...
                    self.assertIsNot(loaded, stored)
                    # ...but the contents are identical.
                    self.assertTrue(torch.equal(loaded.to(stored), stored))
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: run the TestCase classes defined in this module.
    unittest.main(module="__main__")
|
|