# Source: arithmetic-grpo/tests/utils/test_shared_memory.py
# Repository page header (not part of the module): uploaded by LeTue09,
# commit "initial clean commit" (1faccd4).
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import unittest
from multiprocessing import shared_memory
import torch
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import create_shared_memory, rebuild_shared_memory
class TestSharedMemory(unittest.TestCase):
    """Test cases for shared memory utility functions.

    Every test works on its own uniquely named segment (chosen in ``setUp``)
    and ``tearDown`` unlinks that segment, so a test run does not leave named
    shared memory behind.
    """

    def setUp(self):
        """Set up test fixtures before each test method."""
        # Use short unique names to avoid POSIX shared memory name length limits
        import uuid

        short_id = uuid.uuid4().hex[:8]
        self.test_name = f"shm_{short_id}"

    def tearDown(self):
        """Unlink the segment used by the test, if the test created one."""
        try:
            segment = shared_memory.SharedMemory(name=self.test_name)
        except FileNotFoundError:
            # The test failed before creating the segment; nothing to remove.
            return
        # A named segment outlives the handles that reference it until it is
        # unlinked. This fresh handle has no views taken from it, so closing
        # it is safe even if a test left tensors alive on other handles.
        segment.close()
        segment.unlink()

    def test_create_shared_memory_new(self):
        """Test creating new shared memory with unique name."""
        size = 1024
        shm = create_shared_memory(size, self.test_name)

        # Verify shared memory object is created correctly
        self.assertIsNotNone(shm)
        # Note: shared memory may have system-dependent size rounding
        self.assertGreaterEqual(shm.size, size)
        self.assertEqual(shm.name, self.test_name)

        del shm

    def test_create_shared_memory_attach_existing(self):
        """Test that create_shared_memory attaches to existing shared memory when FileExistsError occurs."""
        size = 2048

        # First, create shared memory
        shm1 = create_shared_memory(size, self.test_name)
        self.assertGreaterEqual(shm1.size, size)

        # Second call should attach to existing memory
        shm2 = create_shared_memory(size, self.test_name)

        # Verify we attached to the same shared memory
        self.assertIsNotNone(shm2)
        self.assertGreaterEqual(shm2.size, size)
        self.assertEqual(shm2.name, self.test_name)
        # Both should reference the same shared memory
        self.assertEqual(shm1.name, shm2.name)

        del shm1, shm2

    def test_rebuild_shared_memory_default_dtype(self):
        """Test rebuilding tensor from shared memory with default dtype (uint8)."""
        size = 1024

        # Create and write to shared memory
        shm = create_shared_memory(size, self.test_name)
        test_data = torch.arange(size, dtype=torch.uint8)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor from shared memory
        tensor, _ = rebuild_shared_memory(self.test_name, size)

        # Verify tensor properties
        self.assertEqual(tensor.dtype, torch.uint8)
        self.assertEqual(len(tensor), size)

        # Verify data integrity
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(tensor, reconstructed))

        # Drop the tensor views before the handles are garbage-collected
        del tensor, reconstructed

    def test_rebuild_shared_memory_custom_dtype(self):
        """Test rebuilding tensor from shared memory with custom dtype."""
        size = 256  # 256 bytes = 64 float32 values

        # Create and write to shared memory
        shm = create_shared_memory(size, self.test_name)
        test_data = torch.arange(64, dtype=torch.float32)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor with custom dtype
        tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=torch.float32)

        # Verify tensor properties
        self.assertEqual(tensor.dtype, torch.float32)
        self.assertEqual(len(tensor), 64)

        # Verify data integrity
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.float32)
        self.assertTrue(torch.equal(tensor, reconstructed))

        # Drop the tensor views before the handles are garbage-collected
        del tensor, reconstructed

    def test_shared_memory_data_integrity(self):
        """Test that data remains intact between create and rebuild operations."""
        size = 512

        # Create test data with various patterns
        test_data = torch.randint(0, 256, (size,), dtype=torch.uint8)

        # Create shared memory and write data
        shm = create_shared_memory(size, self.test_name)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor
        tensor, _ = rebuild_shared_memory(self.test_name, size)

        # The rebuilt tensor itself must carry the written data; checking only
        # the creator's buffer would leave the rebuild path unverified.
        self.assertTrue(torch.equal(test_data, tensor))
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data, reconstructed))

        # Drop the tensor views before the handles are garbage-collected
        del tensor, reconstructed

    def test_shared_memory_different_dtypes(self):
        """Test shared memory operations with different tensor dtypes."""
        test_cases = [
            (torch.float32, 256, 64),  # 256 bytes / 4 bytes = 64 values
            (torch.float64, 256, 32),  # 256 bytes / 8 bytes = 32 values
            (torch.int32, 256, 64),  # 256 bytes / 4 bytes = 64 values
            (torch.int64, 256, 32),  # 256 bytes / 8 bytes = 32 values
            (torch.uint8, 256, 256),  # 256 bytes / 1 byte = 256 values
        ]

        for dtype, size, expected_len in test_cases:
            # subTest reports which dtype failed instead of stopping at the first one
            with self.subTest(dtype=dtype):
                # Create test data
                test_data = torch.arange(expected_len, dtype=dtype)

                # Create shared memory and write data (every iteration reuses
                # the same segment name, exactly as a single test_name implies)
                shm = create_shared_memory(size, self.test_name)
                shm.buf[:size] = test_data.numpy().tobytes()

                # Rebuild tensor
                tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=dtype)

                # Verify properties and data of the rebuilt tensor
                self.assertEqual(tensor.dtype, dtype)
                self.assertEqual(len(tensor), expected_len)
                self.assertTrue(torch.equal(test_data, tensor))

                reconstructed = torch.frombuffer(shm.buf[:size], dtype=dtype)
                self.assertTrue(torch.equal(test_data, reconstructed))

                # Drop the tensor views before the handles are garbage-collected
                del tensor, reconstructed

    def test_shared_memory_multiple_operations(self):
        """Test multiple create/rebuild operations with the same name."""
        size = 512

        # First iteration
        test_data1 = torch.arange(size, dtype=torch.uint8)
        shm1 = create_shared_memory(size, self.test_name)
        shm1.buf[:size] = test_data1.numpy().tobytes()
        tensor1, _ = rebuild_shared_memory(self.test_name, size)
        self.assertTrue(torch.equal(test_data1, tensor1))
        reconstructed1 = torch.frombuffer(shm1.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data1, reconstructed1))
        del tensor1, reconstructed1, shm1

        # Second iteration with different data
        test_data2 = torch.arange(size, dtype=torch.uint8) * 2
        shm2 = create_shared_memory(size, self.test_name)
        shm2.buf[:size] = test_data2.numpy().tobytes()
        tensor2, _ = rebuild_shared_memory(self.test_name, size)
        self.assertTrue(torch.equal(test_data2, tensor2))
        reconstructed2 = torch.frombuffer(shm2.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data2, reconstructed2))
        del tensor2, reconstructed2, shm2
# Module-level function for cross-process testing
def child_process_function(name, size, test_data_bytes):
    """Attach to a shared memory segment and check it holds the expected bytes.

    Intended as a ``multiprocessing.Process`` target. A ``Process`` discards
    its target's return value, so when this runs inside a child process a
    failure is reported through a non-zero exit status instead; that is the
    only signal a parent inspecting ``exitcode`` can observe.

    Args:
        name: Name of the existing shared memory segment to attach to.
        size: Number of bytes at the start of the segment to compare.
        test_data_bytes: Buffer holding the expected contents (read as uint8).

    Returns:
        True when the first ``size`` bytes of the segment equal
        ``test_data_bytes``. False on any failure when called directly in the
        main process.

    Raises:
        SystemExit: With code 1 on any failure when running inside a child
            process started by ``multiprocessing``.
    """
    shm = None
    tensor = None
    test_data = None
    try:
        # Convert bytes back to tensor
        test_data = torch.frombuffer(test_data_bytes, dtype=torch.uint8)
        # Attach to shared memory
        shm = shared_memory.SharedMemory(name=name)
        # Rebuild tensor from shared memory
        tensor = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        # Verify data integrity. Raised explicitly rather than via ``assert``
        # so the check still runs when Python is started with -O.
        if not torch.equal(test_data, tensor):
            raise AssertionError("Data mismatch in child process")
        return True
    except Exception as e:
        print(f"Error in child process: {e}")
        if multiprocessing.parent_process() is not None:
            # Running as a child: make the failure visible in the exit code.
            raise SystemExit(1) from e
        return False
    finally:
        # Clean up shared memory in child process
        # Delete all tensor references first so no view into the segment is
        # still alive when the handle is closed
        del tensor, test_data
        if shm is not None:
            shm.close()
        # Note: Don't unlink in child process, parent will clean up
class TestSharedMemoryIntegration(unittest.TestCase):
    """Integration tests for shared memory operations across process boundaries."""

    @staticmethod
    def _unlink_segment(name):
        """Remove the named shared memory segment if it still exists."""
        try:
            segment = shared_memory.SharedMemory(name=name)
        except FileNotFoundError:
            return
        segment.close()
        segment.unlink()

    def test_cross_process_shared_memory(self):
        """Test shared memory can be created in one process and accessed in another."""
        import uuid

        size = 1024
        # A unique name keeps this test from attaching to a stale segment left
        # by an earlier crashed run or to one used by a concurrent run.
        segment_name = f"shm_{uuid.uuid4().hex[:8]}"
        test_data = torch.arange(size, dtype=torch.uint8)

        # Create shared memory in parent process
        shm = create_shared_memory(size, segment_name)
        # Registered right after creation so the segment is unlinked even if
        # an assertion below fails.
        self.addCleanup(self._unlink_segment, segment_name)

        # Convert tensor to bytes once: written to the segment and also passed
        # to the child as the expected contents
        test_data_bytes = test_data.numpy().tobytes()
        shm.buf[:size] = test_data_bytes

        # Start child process
        process = multiprocessing.Process(
            target=child_process_function, args=(segment_name, size, test_data_bytes)
        )
        process.start()
        process.join(timeout=5)

        if process.is_alive():
            # Do not leave a hung child behind; report the timeout explicitly.
            process.terminate()
            process.join()
            self.fail("Child process did not finish within 5 seconds")

        # Verify child process completed successfully
        self.assertEqual(process.exitcode, 0, "Child process failed")

        # Clean up
        del shm
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()