| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | import numpy as np |
| |
|
| | from transformers import is_torch_available, is_vision_available |
| | from transformers.processing_utils import _validate_images_text_input_order |
| | from transformers.testing_utils import require_torch, require_vision |
| |
|
| |
|
| | if is_vision_available(): |
| | import PIL |
| |
|
| | if is_torch_available(): |
| | import torch |
| |
|
| |
|
@require_vision
class ProcessingUtilTester(unittest.TestCase):
    """Tests for `_validate_images_text_input_order`.

    Every non-error scenario calls the helper twice: once with the arguments
    in the correct order and once with `images` and `text` swapped. In both
    cases the returned pair is expected to be `(images, text)`.
    """

    def test_validate_images_text_input_order(self):
        """Check input ordering for PIL, NumPy, URL, nested and `None` inputs."""
        # text string and PIL image inputs
        images = PIL.Image.new("RGB", (224, 224))
        text = "text"
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # list of strings and a single NumPy image
        images = np.random.rand(224, 224, 3)
        text = ["text1", "text2"]
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # nested list of strings and list of PIL images
        images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
        text = [["text1", "text2, text3"], ["text3", "text4"]]
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # list of strings and list of NumPy images
        images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
        text = ["text1", "text2"]
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)

        # list of strings and list of image URLs
        images = ["https://url1", "https://url2"]
        text = ["text1", "text2"]
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # list of strings and nested list of NumPy images
        images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
        text = ["text1", "text2"]
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
        self.assertEqual(valid_text, text)

        # nested list of strings and nested list of PIL images
        images = [
            [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
            [PIL.Image.new("RGB", (224, 224))],
        ]
        text = [["text1", "text2, text3"], ["text3", "text4"]]
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # None images: assert on the *returned* values, not on the inputs,
        # otherwise the checks are tautologies that can never fail.
        images = None
        text = "text"
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)

        # None text: same as above, check what the helper returned.
        images = PIL.Image.new("RGB", (224, 224))
        text = None
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)

        # two plain strings: there is no image input, so the helper must raise
        images = "text"
        text = "text"
        with self.assertRaises(ValueError):
            _validate_images_text_input_order(images=images, text=text)

    @require_torch
    def test_validate_images_text_input_order_torch(self):
        """Check input ordering when the images are torch tensors."""
        # text string and a single torch image
        images = torch.rand(224, 224, 3)
        text = "text"
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # list of strings and list of torch images
        images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
        text = ["text1", "text2"]
        # correct order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(torch.equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
        # swapped order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(torch.equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
| |
|