| import importlib.util |
| import pathlib |
| import unittest |
| from dataclasses import dataclass |
|
|
| ROOT = pathlib.Path(__file__).resolve().parents[1] |
| MODULE_PATH = ROOT / "app" / "gpt" / "request_chunker.py" |
| spec = importlib.util.spec_from_file_location("request_chunker", MODULE_PATH) |
| if spec is None or spec.loader is None: |
| raise ImportError("request_chunker module spec not found") |
| request_chunker = importlib.util.module_from_spec(spec) |
| spec.loader.exec_module(request_chunker) |
| RequestChunker = request_chunker.RequestChunker |
|
|
|
|
| @dataclass |
| class DummySeg: |
| start: float |
| end: float |
| text: str |
|
|
|
|
| def build_messages(segments, image_urls, **_): |
| content = [{"type": "text", "text": "".join(s.text for s in segments)}] |
| for url in image_urls: |
| content.append({"type": "image_url", "image_url": {"url": url, "detail": "auto"}}) |
| return [{"role": "user", "content": content}] |
|
|
|
|
| def size_estimator(messages): |
| size = 0 |
| for part in messages[0]["content"]: |
| if part["type"] == "text": |
| size += len(part["text"]) |
| else: |
| size += len(part["image_url"]["url"]) |
| return size |
|
|
|
|
| class TestRequestChunker(unittest.TestCase): |
| def test_chunk_segments_preserves_order_and_content(self): |
| segments = [ |
| DummySeg(0, 1, "aaaa"), |
| DummySeg(1, 2, "bbbb"), |
| DummySeg(2, 3, "cccc"), |
| ] |
| chunker = RequestChunker(build_messages, max_bytes=8, size_estimator=size_estimator) |
| chunks = chunker.chunk(segments, []) |
| texts = ["".join(seg.text for seg in c.segments) for c in chunks] |
| self.assertEqual("".join(texts), "aaaabbbbcccc") |
| self.assertTrue(all(texts)) |
|
|
| def test_chunk_images_distributed_across_batches(self): |
| segments = [DummySeg(0, 1, "aa")] |
| images = ["i" * 6, "j" * 6, "k" * 6] |
| chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator) |
| chunks = chunker.chunk(segments, images) |
| all_images = [img for c in chunks for img in c.image_urls] |
| self.assertEqual(all_images, images) |
|
|
| def test_chunk_images_are_not_front_loaded_when_multiple_segment_chunks(self): |
| segments = [ |
| DummySeg(0, 1, "aaaaaa"), |
| DummySeg(1, 2, "bbbbbb"), |
| DummySeg(2, 3, "cccccc"), |
| ] |
| images = ["11111", "22222", "33333"] |
| chunker = RequestChunker(build_messages, max_bytes=12, size_estimator=size_estimator) |
| chunks = chunker.chunk(segments, images) |
|
|
| self.assertGreaterEqual(len(chunks), 3) |
| image_counts = [len(c.image_urls) for c in chunks] |
| self.assertGreater(image_counts[1], 0) |
| self.assertGreater(image_counts[2], 0) |
| all_images = [img for c in chunks for img in c.image_urls] |
| self.assertEqual(all_images, images) |
|
|
| def test_split_oversized_segment(self): |
| segments = [DummySeg(0, 1, "x" * 25)] |
| chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator) |
| chunks = chunker.chunk(segments, []) |
| combined = "".join(seg.text for c in chunks for seg in c.segments) |
| self.assertEqual(combined, "x" * 25) |
|
|
| def test_group_texts_by_budget(self): |
| chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator) |
|
|
| def build_text_messages(texts, *_args, **_kwargs): |
| content = [{"type": "text", "text": "".join(texts)}] |
| return [{"role": "user", "content": content}] |
|
|
| groups = chunker.group_texts_by_budget(["aaaaa", "bbbbb", "ccccc"], build_text_messages) |
| self.assertEqual(groups, [["aaaaa", "bbbbb"], ["ccccc"]]) |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main() |
|
|