| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import multiprocessing |
| import unittest |
| from multiprocessing import shared_memory |
|
|
| import torch |
|
|
| from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import create_shared_memory, rebuild_shared_memory |
|
|
|
|
class TestSharedMemory(unittest.TestCase):
    """Test cases for the shared memory utility functions.

    Each test works on its own randomly named segment (chosen in ``setUp``),
    and ``tearDown`` unlinks that segment so a test run does not leave named
    shared memory behind.
    """

    def setUp(self):
        """Choose a fresh segment name for the test that is about to run."""
        import uuid

        # Eight hex characters keep the name short while making a clash
        # between two tests (or two runs) very unlikely.
        short_id = uuid.uuid4().hex[:8]
        self.test_name = f"shm_{short_id}"

    def tearDown(self):
        """Unlink the segment used by the finished test, if one exists."""
        try:
            leftover = shared_memory.SharedMemory(name=self.test_name)
        except FileNotFoundError:
            # The test never created the segment, or it is already gone.
            return
        # This handle has no exported views, so closing it cannot fail with
        # BufferError; unlink() then removes the name system-wide.
        leftover.close()
        leftover.unlink()

    def test_create_shared_memory_new(self):
        """Test creating new shared memory with unique name."""
        size = 1024

        shm = create_shared_memory(size, self.test_name)

        self.assertIsNotNone(shm)
        # Only a lower bound is checked: the reported size may exceed the request.
        self.assertGreaterEqual(shm.size, size)
        self.assertEqual(shm.name, self.test_name)

        del shm

    def test_create_shared_memory_attach_existing(self):
        """Test that create_shared_memory attaches to existing shared memory when FileExistsError occurs."""
        size = 2048

        # First call creates the segment.
        shm1 = create_shared_memory(size, self.test_name)
        self.assertGreaterEqual(shm1.size, size)

        # Second call with the same name must succeed instead of raising.
        shm2 = create_shared_memory(size, self.test_name)

        self.assertIsNotNone(shm2)
        self.assertGreaterEqual(shm2.size, size)
        self.assertEqual(shm2.name, self.test_name)

        # Both handles refer to the same named segment.
        self.assertEqual(shm1.name, shm2.name)

        del shm1, shm2

    def test_rebuild_shared_memory_default_dtype(self):
        """Test rebuilding tensor from shared memory with default dtype (uint8)."""
        size = 1024

        shm = create_shared_memory(size, self.test_name)
        test_data = torch.arange(size, dtype=torch.uint8)
        shm.buf[:size] = test_data.numpy().tobytes()

        tensor, _ = rebuild_shared_memory(self.test_name, size)

        self.assertEqual(tensor.dtype, torch.uint8)
        self.assertEqual(len(tensor), size)

        # The rebuilt tensor must match both the raw buffer and the data written.
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(tensor, reconstructed))
        self.assertTrue(torch.equal(tensor, test_data))

        # Drop the tensors viewing the segment before the handles go away.
        del tensor, reconstructed

    def test_rebuild_shared_memory_custom_dtype(self):
        """Test rebuilding tensor from shared memory with custom dtype."""
        size = 256

        shm = create_shared_memory(size, self.test_name)
        # 64 float32 values occupy exactly ``size`` (256) bytes.
        test_data = torch.arange(64, dtype=torch.float32)
        shm.buf[:size] = test_data.numpy().tobytes()

        tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=torch.float32)

        self.assertEqual(tensor.dtype, torch.float32)
        self.assertEqual(len(tensor), 64)

        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.float32)
        self.assertTrue(torch.equal(tensor, reconstructed))
        self.assertTrue(torch.equal(tensor, test_data))

        del tensor, reconstructed

    def test_shared_memory_data_integrity(self):
        """Test that data remains intact between create and rebuild operations."""
        size = 512

        test_data = torch.randint(0, 256, (size,), dtype=torch.uint8)

        shm = create_shared_memory(size, self.test_name)
        shm.buf[:size] = test_data.numpy().tobytes()

        tensor, _ = rebuild_shared_memory(self.test_name, size)

        # Check the tensor returned by rebuild_shared_memory itself, not only
        # a view built locally from the creator's buffer.
        self.assertTrue(torch.equal(test_data, tensor))
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data, reconstructed))

        del tensor, reconstructed

    def test_shared_memory_different_dtypes(self):
        """Test shared memory operations with different tensor dtypes."""
        # (dtype, segment size in bytes, expected number of elements)
        test_cases = [
            (torch.float32, 256, 64),
            (torch.float64, 256, 32),
            (torch.int32, 256, 64),
            (torch.int64, 256, 32),
            (torch.uint8, 256, 256),
        ]

        for dtype, size, expected_len in test_cases:
            # subTest reports which dtype failed and lets the others still run.
            with self.subTest(dtype=dtype):
                test_data = torch.arange(expected_len, dtype=dtype)

                # Every case reuses the same name, so after the first one this
                # attaches to the existing segment and overwrites its content.
                shm = create_shared_memory(size, self.test_name)
                shm.buf[:size] = test_data.numpy().tobytes()

                tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=dtype)

                self.assertEqual(tensor.dtype, dtype)
                self.assertEqual(len(tensor), expected_len)
                self.assertTrue(torch.equal(test_data, tensor))

                reconstructed = torch.frombuffer(shm.buf[:size], dtype=dtype)
                self.assertTrue(torch.equal(test_data, reconstructed))

                del tensor, reconstructed

    def test_shared_memory_multiple_operations(self):
        """Test multiple create/rebuild operations with the same name."""
        size = 512

        # First round: write and read back one pattern.
        test_data1 = torch.arange(size, dtype=torch.uint8)
        shm1 = create_shared_memory(size, self.test_name)
        shm1.buf[:size] = test_data1.numpy().tobytes()
        tensor1, _ = rebuild_shared_memory(self.test_name, size)
        self.assertTrue(torch.equal(test_data1, tensor1))
        reconstructed1 = torch.frombuffer(shm1.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data1, reconstructed1))
        del tensor1, reconstructed1, shm1

        # Second round: same name, different pattern, must see the new data.
        test_data2 = torch.arange(size, dtype=torch.uint8) * 2
        shm2 = create_shared_memory(size, self.test_name)
        shm2.buf[:size] = test_data2.numpy().tobytes()
        tensor2, _ = rebuild_shared_memory(self.test_name, size)
        self.assertTrue(torch.equal(test_data2, tensor2))
        reconstructed2 = torch.frombuffer(shm2.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data2, reconstructed2))
        del tensor2, reconstructed2, shm2
|
|
|
|
| |
def child_process_function(name, size, test_data_bytes):
    """Attach to a named segment and compare its content with the expected bytes.

    Args:
        name: Name of the existing shared memory segment to attach to.
        size: Number of leading bytes of the segment to compare.
        test_data_bytes: Expected content, as a bytes-like object.

    Returns:
        True when the first ``size`` bytes of the segment equal
        ``test_data_bytes``; False when they differ or any error occurs
        (the error is printed rather than raised).
    """
    segment = observed = expected = None
    try:
        expected = torch.frombuffer(test_data_bytes, dtype=torch.uint8)
        # Attach only; the segment must already exist under this name.
        segment = shared_memory.SharedMemory(name=name)
        observed = torch.frombuffer(segment.buf[:size], dtype=torch.uint8)
        assert torch.equal(expected, observed), "Data mismatch in child process"
    except Exception as exc:
        print(f"Error in child process: {exc}")
        return False
    else:
        return True
    finally:
        # Release the tensors that view the buffer before closing the handle.
        del observed, expected
        if segment is not None:
            segment.close()
| |
|
|
|
|
class TestSharedMemoryIntegration(unittest.TestCase):
    """Integration tests for shared memory operations across process boundaries."""

    # Upper bound, in seconds, on how long the parent waits for the child.
    # Deliberately generous: the test returns as soon as the child finishes.
    CHILD_TIMEOUT_SECONDS = 30

    def test_cross_process_shared_memory(self):
        """Test shared memory can be created in one process and accessed in another."""
        import uuid

        size = 1024
        test_data = torch.arange(size, dtype=torch.uint8)
        test_data_bytes = test_data.numpy().tobytes()

        # A per-run name avoids attaching to a stale or concurrently used
        # segment left under a fixed name.
        segment_name = f"shm_{uuid.uuid4().hex[:8]}"

        shm = create_shared_memory(size, segment_name)
        try:
            shm.buf[:size] = test_data_bytes

            # child_process_function reports failure through its return value
            # (it catches its own exceptions), so the process exit code alone
            # cannot reveal a mismatch. A pool hands the return value back.
            with multiprocessing.Pool(processes=1) as pool:
                pending = pool.apply_async(child_process_function, (segment_name, size, test_data_bytes))
                try:
                    child_ok = pending.get(timeout=self.CHILD_TIMEOUT_SECONDS)
                except multiprocessing.TimeoutError:
                    # Leaving the ``with`` block terminates the stuck worker.
                    self.fail("Child process did not finish in time")

            self.assertTrue(child_ok, "Child process failed")
        finally:
            # Release this handle and remove the name so nothing is left behind.
            shm.close()
            shm.unlink()
|
|
|
|
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
|