| | import pytest |
| | import torch |
| | import torch.nn as nn |
| | from torch.autograd import gradcheck |
| |
|
| | import kornia |
| | import kornia.testing as utils |
| | from kornia.augmentation import ( |
| | CenterCrop3D, |
| | RandomAffine3D, |
| | RandomCrop, |
| | RandomCrop3D, |
| | RandomDepthicalFlip3D, |
| | RandomEqualize3D, |
| | RandomHorizontalFlip3D, |
| | RandomRotation3D, |
| | RandomVerticalFlip3D, |
| | ) |
| | from kornia.testing import assert_close |
| | from kornia.utils._compat import torch_version_geq |
| |
|
| |
|
class TestRandomHorizontalFlip3D:
    """Tests for ``RandomHorizontalFlip3D``, which reverses the last (width) axis of a volume."""

    # String-representation check; xfail because printed precision can vary by platform.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        # NOTE(review): the expected repr reports ``return_transform=0.5``, i.e. the positional
        # 0.5 appears to bind to ``return_transform`` rather than ``p`` -- confirm this is intended.
        f = RandomHorizontalFlip3D(0.5)
        expected_repr = "RandomHorizontalFlip3D(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=0.5)"
        assert str(f) == expected_repr

    def test_random_hflip(self, device):
        f = RandomHorizontalFlip3D(p=1.0, return_transform=True)
        f1 = RandomHorizontalFlip3D(p=0.0, return_transform=True)
        f2 = RandomHorizontalFlip3D(p=1.0)
        f3 = RandomHorizontalFlip3D(p=0.0)

        # Shape (2, 3, 4); the flip reverses every row along the last axis.
        inp = torch.tensor(
            [
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 2.0]],
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 2.0]],
            ],
            device=device,
        )

        expected = torch.tensor(
            [
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [2.0, 1.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [2.0, 1.0, 0.0, 0.0]],
            ],
            device=device,
        )

        # Maps x -> 3 - x, i.e. (W - 1) - x for the width W = 4 of ``inp``.
        expected_transform = torch.tensor(
            [[-1.0, 0.0, 0.0, 3.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
            device=device,
        )

        identity = torch.tensor(
            [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
            device=device,
        )

        assert (f(inp)[0] == expected).all()
        assert (f(inp)[1] == expected_transform).all()
        assert (f1(inp)[0] == inp).all()
        assert (f1(inp)[1] == identity).all()
        assert (f2(inp) == expected).all()
        assert (f3(inp) == inp).all()

    def test_batch_random_hflip(self, device):
        f = RandomHorizontalFlip3D(p=1.0, return_transform=True)
        f1 = RandomHorizontalFlip3D(p=0.0, return_transform=True)

        inp = torch.tensor([[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]]], device=device)
        expected = torch.tensor([[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]]]], device=device)

        expected_transform = torch.tensor(
            [[[-1.0, 0.0, 0.0, 2.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )
        identity = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )

        # Batch of 5 samples with 3 channels each.
        inp = inp.repeat(5, 3, 1, 1, 1)
        expected = expected.repeat(5, 3, 1, 1, 1)
        expected_transform = expected_transform.repeat(5, 1, 1)
        identity = identity.repeat(5, 1, 1)

        assert (f(inp)[0] == expected).all()
        assert (f(inp)[1] == expected_transform).all()
        assert (f1(inp)[0] == inp).all()
        assert (f1(inp)[1] == identity).all()

    def test_same_on_batch(self, device):
        f = RandomHorizontalFlip3D(p=0.5, same_on_batch=True)
        # Build the input on the parametrized device so this test actually exercises it
        # (previously it always ran on CPU).
        inp = torch.eye(3, device=device).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1, 1)
        res = f(inp)
        assert (res[0] == res[1]).all()

    def test_sequential(self, device):
        f = nn.Sequential(
            RandomHorizontalFlip3D(p=1.0, return_transform=True), RandomHorizontalFlip3D(p=1.0, return_transform=True)
        )
        f1 = nn.Sequential(RandomHorizontalFlip3D(p=1.0, return_transform=True), RandomHorizontalFlip3D(p=1.0))

        inp = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device)

        expected_transform = torch.tensor(
            [[[-1.0, 0.0, 0.0, 2.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )

        # Two flips compose to the identity; the product already lives on ``device``.
        expected_transform_1 = expected_transform @ expected_transform

        assert (f(inp)[0] == inp).all()
        assert (f(inp)[1] == expected_transform_1).all()
        assert (f1(inp)[0] == inp).all()
        assert (f1(inp)[1] == expected_transform).all()

    def test_gradcheck(self, device):
        inp = torch.rand((1, 3, 3)).to(device)
        inp = utils.tensor_to_gradcheck_var(inp)
        assert gradcheck(RandomHorizontalFlip3D(p=1.0), (inp,), raise_exception=True)
| |
|
| |
|
class TestRandomVerticalFlip3D:
    """Tests for ``RandomVerticalFlip3D``, which reverses the height axis of a volume."""

    # String-representation check; xfail because printed precision can vary by platform.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        # NOTE(review): the expected repr reports ``return_transform=0.5``, i.e. the positional
        # 0.5 appears to bind to ``return_transform`` rather than ``p`` -- confirm this is intended.
        f = RandomVerticalFlip3D(0.5)
        expected_repr = "RandomVerticalFlip3D(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=0.5)"
        assert str(f) == expected_repr

    def test_random_vflip(self, device, dtype):
        f = RandomVerticalFlip3D(p=1.0, return_transform=True)
        f1 = RandomVerticalFlip3D(p=0.0, return_transform=True)
        f2 = RandomVerticalFlip3D(p=1.0)
        f3 = RandomVerticalFlip3D(p=0.0)

        inp = torch.tensor(
            [
                [
                    [
                        [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
                        [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )

        # Rows of each slice are reversed (the bottom row moves to the top).
        expected = torch.tensor(
            [
                [
                    [
                        [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                        [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )

        # Maps y -> 2 - y for the height 3 of ``inp``.
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 2.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
            dtype=dtype,
        )

        identity = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
            dtype=dtype,
        )

        assert_close(f(inp)[0], expected)
        assert_close(f(inp)[1], expected_transform)
        assert_close(f1(inp)[0], inp)
        assert_close(f1(inp)[1], identity)
        assert_close(f2(inp), expected)
        assert_close(f3(inp), inp)

    def test_batch_random_vflip(self, device):
        f = RandomVerticalFlip3D(p=1.0, return_transform=True)
        f1 = RandomVerticalFlip3D(p=0.0, return_transform=True)

        inp = torch.tensor([[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]]], device=device)
        expected = torch.tensor([[[[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]]], device=device)

        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 2.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )
        identity = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )

        # Batch of 5 samples with 3 channels each.
        inp = inp.repeat(5, 3, 1, 1, 1)
        expected = expected.repeat(5, 3, 1, 1, 1)
        expected_transform = expected_transform.repeat(5, 1, 1)
        identity = identity.repeat(5, 1, 1)

        assert_close(f(inp)[0], expected)
        assert_close(f(inp)[1], expected_transform)
        assert_close(f1(inp)[0], inp)
        assert_close(f1(inp)[1], identity)

    def test_same_on_batch(self, device):
        f = RandomVerticalFlip3D(p=0.5, same_on_batch=True)
        # Build the input on the parametrized device so this test actually exercises it
        # (previously it always ran on CPU).
        inp = torch.eye(3, device=device).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1, 1)
        res = f(inp)
        assert (res[0] == res[1]).all()

    def test_sequential(self, device):
        f = nn.Sequential(
            RandomVerticalFlip3D(p=1.0, return_transform=True), RandomVerticalFlip3D(p=1.0, return_transform=True)
        )
        f1 = nn.Sequential(RandomVerticalFlip3D(p=1.0, return_transform=True), RandomVerticalFlip3D(p=1.0))

        inp = torch.tensor([[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]]], device=device)

        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 2.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )

        # Two flips compose to the identity.
        expected_transform_1 = expected_transform @ expected_transform

        assert_close(f(inp)[0], inp)
        assert_close(f(inp)[1], expected_transform_1)
        assert_close(f1(inp)[0], inp)
        assert_close(f1(inp)[1], expected_transform)

    def test_gradcheck(self, device):
        inp = torch.rand((1, 3, 3)).to(device)
        inp = utils.tensor_to_gradcheck_var(inp)
        assert gradcheck(RandomVerticalFlip3D(p=1.0), (inp,), raise_exception=True)
| |
|
| |
|
class TestRandomDepthicalFlip3D:
    """Tests for ``RandomDepthicalFlip3D``, which reverses the depth axis of a volume."""

    # String-representation check; xfail because printed precision can vary by platform.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        # NOTE(review): the expected repr reports ``return_transform=0.5``, i.e. the positional
        # 0.5 appears to bind to ``return_transform`` rather than ``p`` -- confirm this is intended.
        f = RandomDepthicalFlip3D(0.5)
        expected_repr = "RandomDepthicalFlip3D(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=0.5)"
        assert str(f) == expected_repr

    def test_random_dflip(self, device, dtype):
        f = RandomDepthicalFlip3D(p=1.0, return_transform=True)
        f1 = RandomDepthicalFlip3D(p=0.0, return_transform=True)
        f2 = RandomDepthicalFlip3D(p=1.0)
        f3 = RandomDepthicalFlip3D(p=0.0)

        inp = torch.tensor(
            [
                [
                    [
                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )

        # The two slices swap places.
        expected = torch.tensor(
            [
                [
                    [
                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]],
                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )

        # Maps z -> 1 - z for the depth 2 of ``inp``.
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
            dtype=dtype,
        )

        identity = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
            dtype=dtype,
        )

        assert_close(f(inp)[0], expected)
        assert_close(f(inp)[1], expected_transform)
        assert_close(f1(inp)[0], inp)
        assert_close(f1(inp)[1], identity)
        assert_close(f2(inp), expected)
        assert_close(f3(inp), inp)

    def test_batch_random_dflip(self, device):
        f = RandomDepthicalFlip3D(p=1.0, return_transform=True)
        f1 = RandomDepthicalFlip3D(p=0.0, return_transform=True)

        inp = torch.tensor(
            [
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]],
            ],
            device=device,
        )

        expected = torch.tensor(
            [
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]],
                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
            ],
            device=device,
        )

        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )
        identity = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )

        # Shape (2, 3, 4) -> batch of 5 samples with 3 channels each.
        inp = inp.repeat(5, 3, 1, 1, 1)
        expected = expected.repeat(5, 3, 1, 1, 1)
        expected_transform = expected_transform.repeat(5, 1, 1)
        identity = identity.repeat(5, 1, 1)

        assert_close(f(inp)[0], expected)
        assert_close(f(inp)[1], expected_transform)
        assert_close(f1(inp)[0], inp)
        assert_close(f1(inp)[1], identity)

    def test_same_on_batch(self, device):
        f = RandomDepthicalFlip3D(p=0.5, same_on_batch=True)
        # Depth 2 so the flip is observable. Build on the parametrized device so this test
        # actually exercises it (previously it always ran on CPU).
        inp = torch.eye(3, device=device).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 2, 1, 1)
        res = f(inp)
        assert (res[0] == res[1]).all()

    def test_sequential(self, device):
        f = nn.Sequential(
            RandomDepthicalFlip3D(p=1.0, return_transform=True), RandomDepthicalFlip3D(p=1.0, return_transform=True)
        )
        f1 = nn.Sequential(RandomDepthicalFlip3D(p=1.0, return_transform=True), RandomDepthicalFlip3D(p=1.0))

        inp = torch.tensor(
            [
                [
                    [
                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]],
                    ]
                ]
            ],
            device=device,
        )

        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.0], [0.0, 0.0, 0.0, 1.0]]],
            device=device,
        )

        # Two flips compose to the identity.
        expected_transform_1 = expected_transform @ expected_transform

        assert_close(f(inp)[0], inp)
        assert_close(f(inp)[1], expected_transform_1)
        assert_close(f1(inp)[0], inp)
        assert_close(f1(inp)[1], expected_transform)

    def test_gradcheck(self, device):
        inp = torch.rand((1, 3, 3)).to(device)
        inp = utils.tensor_to_gradcheck_var(inp)
        assert gradcheck(RandomDepthicalFlip3D(p=1.0), (inp,), raise_exception=True)
| |
|
| |
|
class TestRandomRotation3D:
    """Tests for ``RandomRotation3D``; the expected values are pinned to explicit RNG seeds."""

    # NOTE: this statement runs once, when the class body is executed at import time,
    # not before each test. The seeded tests below re-seed explicitly.
    torch.manual_seed(0)

    @staticmethod
    def _stacked_slices(device, dtype):
        """Return a (3, 4, 4) tensor made of the same 4x4 slice repeated three times."""
        plane = [[1.0, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 1.0, 2.0]]
        return torch.tensor([plane] * 3, device=device, dtype=dtype)

    # String-representation check; xfail because printed precision can vary by platform.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        aug = RandomRotation3D(degrees=45.5)
        expected_repr = (
            "RandomRotation3D(degrees=tensor([[-45.5000, 45.5000],\n"
            "[-45.5000, 45.5000],\n"
            "[-45.5000, 45.5000]]), resample=BILINEAR, align_corners=False, p=0.5, "
            "p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(aug) == expected_repr

    def test_random_rotation(self, device, dtype):
        # The expected tensors below were produced with this seed.
        torch.manual_seed(0)

        aug_with_mat = RandomRotation3D(degrees=45.0, return_transform=True)
        aug_plain = RandomRotation3D(degrees=45.0)

        volume = self._stacked_slices(device, dtype)

        rotated = torch.tensor(
            [
                [
                    [
                        [
                            [0.0000, 0.0000, 0.6810, 0.5250],
                            [0.5052, 0.0000, 0.0000, 0.0613],
                            [0.1159, 0.1072, 0.5324, 0.0870],
                            [0.0000, 0.0000, 0.1927, 0.0000],
                        ],
                        [
                            [0.0000, 0.1683, 0.6963, 0.1131],
                            [0.0566, 0.0000, 0.5215, 0.2796],
                            [0.0694, 0.6039, 1.4519, 1.1240],
                            [0.0000, 0.1325, 0.1542, 0.2510],
                        ],
                        [
                            [0.0000, 0.2054, 0.0000, 0.0000],
                            [0.0026, 0.6088, 0.7358, 0.2319],
                            [0.1261, 1.0830, 1.3687, 1.4940],
                            [0.0000, 0.0416, 0.2012, 0.3124],
                        ],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )

        rotation_mat = torch.tensor(
            [
                [
                    [0.6523, 0.3666, -0.6635, 0.6352],
                    [-0.6185, 0.7634, -0.1862, 1.4689],
                    [0.4382, 0.5318, 0.7247, -1.1797],
                    [0.0000, 0.0000, 0.0000, 1.0000],
                ]
            ],
            device=device,
            dtype=dtype,
        )

        output, matrix = aug_with_mat(volume)
        assert_close(output, rotated, rtol=1e-6, atol=1e-4)
        assert_close(matrix, rotation_mat, rtol=1e-6, atol=1e-4)

        # Re-seed so the module without ``return_transform`` reproduces the same output.
        torch.manual_seed(0)
        assert_close(aug_plain(volume), rotated, rtol=1e-6, atol=1e-4)

    def test_batch_random_rotation(self, device, dtype):
        # The expected tensors below were produced with this seed.
        torch.manual_seed(24)

        aug = RandomRotation3D(degrees=45.0, return_transform=True)

        # Shape (1, 3, 4, 4).
        volume = self._stacked_slices(device, dtype).unsqueeze(0)

        # With this seed the first sample is left untouched and the second one is rotated.
        rotated = torch.tensor(
            [
                [
                    [
                        [0.0000, 0.0726, 0.0000, 0.0000],
                        [0.1038, 1.0134, 0.5566, 0.1519],
                        [0.0000, 1.0849, 1.1068, 0.0000],
                        [0.1242, 1.1065, 0.9681, 0.0000],
                    ],
                    [
                        [0.0000, 0.0047, 0.0166, 0.0000],
                        [0.0579, 0.4459, 0.0000, 0.4728],
                        [0.1864, 1.3349, 0.7530, 0.3251],
                        [0.1431, 1.2481, 0.4471, 0.0000],
                    ],
                    [
                        [0.0000, 0.4840, 0.2314, 0.0000],
                        [0.0000, 0.0328, 0.0000, 0.1434],
                        [0.1899, 0.5580, 0.0000, 0.9170],
                        [0.0000, 0.2042, 0.1571, 0.0855],
                    ],
                ]
            ],
            device=device,
            dtype=dtype,
        )
        expected = torch.stack([volume, rotated])

        rotation_mat = torch.tensor(
            [
                [0.7522, -0.6326, -0.1841, 1.5047],
                [0.6029, 0.5482, 0.5796, -0.8063],
                [-0.2657, -0.5470, 0.7938, 1.4252],
                [0.0000, 0.0000, 0.0000, 1.0000],
            ],
            device=device,
            dtype=dtype,
        )
        expected_mats = torch.stack([torch.eye(4, device=device, dtype=dtype), rotation_mat])

        batch = volume.repeat(2, 1, 1, 1, 1)

        output, matrix = aug(batch)
        assert_close(output, expected, rtol=1e-6, atol=1e-4)
        assert_close(matrix, expected_mats, rtol=1e-6, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        aug = RandomRotation3D(degrees=40, same_on_batch=True)
        volume = torch.eye(6, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 3, 6, 1, 1)
        result = aug(volume)
        assert (result[0] == result[1]).all()

    def test_sequential(self, device, dtype):
        # The expected tensors below were produced with this seed; module construction and
        # call order must stay as-is.
        torch.manual_seed(24)

        chained = nn.Sequential(
            RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
            RandomRotation3D(10.4, return_transform=True),
        )
        chained_last_plain = nn.Sequential(
            RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True), RandomRotation3D(10.4)
        )

        volume = self._stacked_slices(device, dtype)

        expected = torch.tensor(
            [
                [
                    [
                        [
                            [0.3431, 0.1239, 0.0000, 1.0348],
                            [0.0000, 0.2035, 0.1139, 0.1770],
                            [0.0789, 0.9057, 1.7780, 0.0000],
                            [0.0000, 0.2286, 1.2498, 1.2643],
                        ],
                        [
                            [0.5460, 0.2131, 0.0000, 1.1453],
                            [0.0000, 0.0899, 0.0000, 0.4293],
                            [0.0797, 1.0193, 1.6677, 0.0000],
                            [0.0000, 0.2458, 1.2765, 1.0920],
                        ],
                        [
                            [0.6322, 0.2614, 0.0000, 0.9207],
                            [0.0000, 0.0037, 0.0000, 0.6551],
                            [0.0689, 0.9251, 1.3442, 0.0000],
                            [0.0000, 0.2449, 0.9856, 0.6862],
                        ],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )

        chained_mat = torch.tensor(
            [
                [
                    [0.9857, -0.1686, -0.0019, 0.2762],
                    [0.1668, 0.9739, 0.1538, -0.3650],
                    [-0.0241, -0.1520, 0.9881, 0.2760],
                    [0.0000, 0.0000, 0.0000, 1.0000],
                ]
            ],
            device=device,
            dtype=dtype,
        )

        first_only_mat = torch.tensor(
            [
                [
                    [0.2348, -0.1615, 0.9585, 0.4316],
                    [0.1719, 0.9775, 0.1226, -0.3467],
                    [-0.9567, 0.1360, 0.2573, 1.9738],
                    [0.0000, 0.0000, 0.0000, 1.0000],
                ]
            ],
            device=device,
            dtype=dtype,
        )

        output, matrix = chained(volume)
        _, matrix_2 = chained_last_plain(volume)
        assert_close(output, expected, rtol=1e-6, atol=1e-4)
        assert_close(matrix, chained_mat, rtol=1e-6, atol=1e-4)
        assert_close(matrix_2, first_only_mat, rtol=1e-6, atol=1e-4)

    def test_gradcheck(self, device):
        torch.manual_seed(0)

        volume = torch.rand((3, 3, 3)).to(device)
        volume = utils.tensor_to_gradcheck_var(volume)
        assert gradcheck(RandomRotation3D(degrees=(15.0, 15.0), p=1.0), (volume,), raise_exception=True)
| |
|
| |
|
class TestRandomCrop3D:
    """Tests for ``RandomCrop3D``; several expected values are pinned to explicit RNG seeds."""

    # String-representation check; xfail because printed precision can vary by platform.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = RandomCrop3D(size=(2, 3, 4), padding=(0, 1, 2), fill=10, pad_if_needed=False, p=1.0)
        expected_repr = (
            "RandomCrop3D(crop_size=(2, 3, 4), padding=(0, 1, 2), fill=10, pad_if_needed=False, "
            "padding_mode=constant, resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, "
            "return_transform=False)"
        )
        assert str(f) == expected_repr

    @pytest.mark.parametrize("batch_size", [1, 2])
    def test_no_padding(self, batch_size, device, dtype):
        # The expected crop offsets below were produced with this seed.
        torch.manual_seed(42)
        inp = torch.tensor(
            [
                [
                    [
                        [
                            [0.0, 1.0, 2.0, 3.0, 4.0],
                            [5.0, 6.0, 7.0, 8.0, 9.0],
                            [10, 11, 12, 13, 14],
                            [15, 16, 17, 18, 19],
                            [20, 21, 22, 23, 24],
                        ]
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        ).repeat(batch_size, 1, 5, 1, 1)
        f = RandomCrop3D(size=(2, 3, 4), padding=None, align_corners=True, p=1.0)
        out = f(inp)
        if batch_size == 1:
            expected = torch.tensor(
                [[[[[11, 12, 13, 14], [16, 17, 18, 19], [21, 22, 23, 24]]]]], device=device, dtype=dtype
            ).repeat(batch_size, 1, 2, 1, 1)
        if batch_size == 2:
            expected = torch.tensor(
                [
                    [
                        [
                            [
                                [6.0000, 7.0000, 8.0000, 9.0000],
                                [11.0000, 12.0000, 13.0000, 14.0000],
                                [16.0000, 17.0000, 18.0000, 19.0000],
                            ],
                            [
                                [6.0000, 7.0000, 8.0000, 9.0000],
                                [11.0000, 12.0000, 13.0000, 14.0000],
                                [16.0000, 17.0000, 18.0000, 19.0000],
                            ],
                        ]
                    ],
                    [
                        [
                            [
                                [11.0000, 12.0000, 13.0000, 14.0000],
                                [16.0000, 17.0000, 18.0000, 19.0000],
                                [21.0000, 22.0000, 23.0000, 24.0000],
                            ],
                            [
                                [11.0000, 12.0000, 13.0000, 14.0000],
                                [16.0000, 17.0000, 18.0000, 19.0000],
                                [21.0000, 22.0000, 23.0000, 24.0000],
                            ],
                        ]
                    ],
                ],
                device=device,
                dtype=dtype,
            )

        assert_close(out, expected, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        f = RandomCrop3D(size=(2, 3, 4), padding=None, align_corners=True, p=1.0, same_on_batch=True)
        # Build the input on the parametrized device/dtype so this test actually exercises them
        # (previously it always ran on CPU in the default dtype).
        inp = (
            torch.eye(6, device=device, dtype=dtype)
            .unsqueeze(dim=0)
            .unsqueeze(dim=0)
            .unsqueeze(dim=0)
            .repeat(2, 3, 5, 1, 1)
        )
        res = f(inp)
        assert (res[0] == res[1]).all()

    @pytest.mark.parametrize("padding", [1, (1, 1, 1), (1, 1, 1, 1, 1, 1)])
    def test_padding_batch(self, padding, device, dtype):
        # The expected crop offsets below were produced with this seed.
        torch.manual_seed(42)
        batch_size = 2
        inp = torch.tensor([[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype).repeat(
            batch_size, 1, 3, 1, 1
        )
        # Padded positions carry the ``fill`` value 10.
        expected = torch.tensor(
            [
                [
                    [
                        [[0.0, 1.0, 2.0, 10.0], [3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0]],
                        [[0.0, 1.0, 2.0, 10.0], [3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0]],
                    ]
                ],
                [
                    [
                        [[3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0], [10, 10, 10, 10.0]],
                        [[3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0], [10, 10, 10, 10.0]],
                    ]
                ],
            ],
            device=device,
            dtype=dtype,
        )
        f = RandomCrop3D(size=(2, 3, 4), fill=10.0, padding=padding, align_corners=True, p=1.0)
        out = f(inp)

        assert_close(out, expected, atol=1e-4, rtol=1e-4)

    def test_pad_if_needed(self, device, dtype):
        # The expected crop offsets below were produced with this seed.
        torch.manual_seed(42)
        inp = torch.tensor([[[0.0, 1.0, 2.0]]], device=device, dtype=dtype)
        # Padded positions carry the ``fill`` value 9.
        expected = torch.tensor(
            [
                [
                    [
                        [[9.0, 9.0, 9.0, 9.0], [9.0, 9.0, 9.0, 9.0], [9.0, 9.0, 9.0, 9.0]],
                        [[0.0, 1.0, 2.0, 9.0], [9.0, 9.0, 9.0, 9.0], [9.0, 9.0, 9.0, 9.0]],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )
        rc = RandomCrop3D(size=(2, 3, 4), pad_if_needed=True, fill=9, align_corners=True, p=1.0)
        out = rc(inp)

        assert_close(out, expected, atol=1e-4, rtol=1e-4)

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)
        inp = torch.rand((3, 3, 3), device=device, dtype=dtype)
        inp = utils.tensor_to_gradcheck_var(inp)
        assert gradcheck(RandomCrop3D(size=(3, 3, 3), p=1.0), (inp,), raise_exception=True)

    @pytest.mark.skip("Need to fix Union type")
    def test_jit(self, device, dtype):
        # Script the 3D crop (previously this exercised the 2D ``RandomCrop`` on a 4D image).
        op = RandomCrop3D(size=(3, 3, 3), p=1.0).forward
        op_script = torch.jit.script(op)
        img = torch.ones(1, 1, 5, 6, 7, device=device, dtype=dtype)

        actual = op_script(img)
        # Every crop of a constant volume is that constant, so the expected value does not
        # depend on the randomly sampled crop position.
        expected = torch.ones(1, 1, 3, 3, 3, device=device, dtype=dtype)
        assert_close(actual, expected)

    @pytest.mark.skip("Need to fix Union type")
    def test_jit_trace(self, device, dtype):
        # Define script (3D crop on a 5D volume; previously this used the 2D ``RandomCrop``).
        op = RandomCrop3D(size=(3, 3, 3), p=1.0).forward
        op_script = torch.jit.script(op)

        # 1. Trace op.
        img = torch.ones(1, 1, 5, 6, 7, device=device, dtype=dtype)
        op_trace = torch.jit.trace(op_script, (img,))

        # 2. Generate new input.
        img = torch.ones(1, 1, 5, 6, 7, device=device, dtype=dtype)

        # 3. Evaluate the traced op against the eager op.
        actual = op_trace(img)
        expected = op(img)
        assert_close(actual, expected)
| |
|
| |
|
class TestCenterCrop3D:
    """Output-shape and gradient checks for ``CenterCrop3D``."""

    def test_no_transform(self, device, dtype):
        volume = torch.rand(1, 2, 4, 4, 4, device=device, dtype=dtype)
        cropped = CenterCrop3D(2)(volume)
        assert cropped.shape == (1, 2, 2, 2, 2)

    def test_transform(self, device, dtype):
        volume = torch.rand(1, 2, 5, 4, 8, device=device, dtype=dtype)
        result = CenterCrop3D(2, return_transform=True)(volume)
        # With ``return_transform=True`` the call yields (cropped volume, transform matrix).
        assert len(result) == 2
        cropped, transform = result[0], result[1]
        assert cropped.shape == (1, 2, 2, 2, 2)
        assert transform.shape == (1, 4, 4)

    def test_no_transform_tuple(self, device, dtype):
        volume = torch.rand(1, 2, 5, 4, 8, device=device, dtype=dtype)
        crop_size = (3, 4, 5)
        cropped = CenterCrop3D(crop_size)(volume)
        assert cropped.shape == (1, 2, 3, 4, 5)

    def test_gradcheck(self, device, dtype):
        volume = utils.tensor_to_gradcheck_var(torch.rand(1, 2, 3, 4, 5, device=device, dtype=dtype))
        assert gradcheck(CenterCrop3D(3), (volume,), raise_exception=True)
| |
|
| |
|
class TestRandomEqualize3D:
    """Tests for ``RandomEqualize3D`` on volumes whose rows are linear ramps (see ``build_input``)."""

    # Row values expected after equalizing the default ``build_input`` ramp of width 10.
    _EXPECTED_ROW = [0.0000, 0.11764, 0.2353, 0.3529, 0.4706, 0.5882, 0.7059, 0.8235, 0.9412, 1.0000]

    # String-representation check; xfail because printed precision can vary by platform.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self, device, dtype):
        f = RandomEqualize3D(p=0.5)
        expected_repr = "RandomEqualize3D(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(f) == expected_repr

    def test_random_equalize(self, device, dtype):
        f = RandomEqualize3D(p=1.0, return_transform=True)
        f1 = RandomEqualize3D(p=0.0, return_transform=True)
        f2 = RandomEqualize3D(p=1.0)
        f3 = RandomEqualize3D(p=0.0)

        bs, channels, depth, height, width = 1, 3, 6, 10, 10

        inputs3d = self.build_input(channels, depth, height, width, bs, device=device, dtype=dtype)

        row_expected = torch.tensor(self._EXPECTED_ROW, device=device, dtype=dtype)
        expected = self.build_input(channels, depth, height, width, bs=1, row=row_expected, device=device, dtype=dtype)

        # Equalization is an intensity-only op, so the reported transform is the identity.
        identity = kornia.eye_like(4, expected)

        assert_close(f(inputs3d)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs3d)[0], inputs3d, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs3d), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)

    def test_batch_random_equalize(self, device, dtype):
        f = RandomEqualize3D(p=1.0, return_transform=True)
        f1 = RandomEqualize3D(p=0.0, return_transform=True)
        f2 = RandomEqualize3D(p=1.0)
        f3 = RandomEqualize3D(p=0.0)

        bs, channels, depth, height, width = 2, 3, 6, 10, 10

        inputs3d = self.build_input(channels, depth, height, width, bs, device=device, dtype=dtype)

        # Built directly on the target device/dtype, consistent with ``test_random_equalize``
        # (avoids a float32 round-trip of the reference values in float64 runs).
        row_expected = torch.tensor(self._EXPECTED_ROW, device=device, dtype=dtype)
        expected = self.build_input(channels, depth, height, width, bs, row=row_expected, device=device, dtype=dtype)

        # Equalization is an intensity-only op, so the reported transform is the identity.
        identity = kornia.eye_like(4, expected)

        assert_close(f(inputs3d)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs3d)[0], inputs3d, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs3d), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        f = RandomEqualize3D(p=0.5, same_on_batch=True)
        inp = torch.eye(4, device=device, dtype=dtype)
        inp = inp.unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 2, 1, 1)
        res = f(inp)
        assert (res[0] == res[1]).all()

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)

        inputs3d = torch.rand((3, 3, 3), device=device, dtype=dtype)
        inputs3d = utils.tensor_to_gradcheck_var(inputs3d)
        assert gradcheck(RandomEqualize3D(p=0.5), (inputs3d,), raise_exception=True)

    @staticmethod
    def build_input(channels, depth, height, width, bs=1, row=None, device='cpu', dtype=torch.float32):
        """Build a (bs, channels, depth, height, width) tensor whose every row equals ``row``.

        Args:
            channels, depth, height, width: output volume dimensions.
            bs: batch size.
            row: 1D tensor of length ``width``; defaults to the ramp ``arange(width) / width``.
            device, dtype: placement and dtype of the returned tensor.
        """
        if row is None:
            row = torch.arange(width, device=device, dtype=dtype) / float(width)

        channel = torch.stack([row] * height)  # (H, W)
        image = torch.stack([channel] * channels)  # (C, H, W)
        image3d = torch.stack([image] * depth).transpose(0, 1)  # (D, C, H, W) -> (C, D, H, W)
        batch = torch.stack([image3d] * bs)  # (B, C, D, H, W)

        return batch.to(device, dtype)
| |
|
| |
|
class TestRandomAffine3D:
    """Tests for ``RandomAffine3D``."""

    def test_batch_random_affine_3d(self, device, dtype):
        # Known failure on CUDA with float64 for torch >= 1.10 (see the skip message).
        if torch_version_geq(1, 10) and "cuda" in str(device) and dtype == torch.float64:
            pytest.skip("AssertionError: assert tensor(False, device='cuda:0')")

        # With (0, 0, 0) as the first argument the assertions below expect an identity warp.
        aug = RandomAffine3D((0, 0, 0), p=1.0, return_transform=True)

        plane = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        # (1, 1, 1, 3, 3) -> batch of 5 samples with 3 channels each.
        volume = torch.tensor([[[plane]]], device=device, dtype=dtype).repeat(5, 3, 1, 1, 1)
        expected = volume.clone()
        expected_transform = torch.eye(4, device=device, dtype=dtype).repeat(5, 1, 1)

        assert (aug(volume)[0] == expected).all()
        assert (aug(volume)[1] == expected_transform).all()
| |
|