File size: 1,705 Bytes
7a60a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import unittest
from pathlib import Path

from tests.utils import execute_shell_command, wait_for_server

# The "cache" directory located two directories above the one containing
# this file (presumably the repository root -- confirm against the layout).
CACHE_DIR = Path(__file__).parent.parent.parent / "cache"


class TestRegenerateTrainData(unittest.TestCase):
    """End-to-end test of ``scripts/regenerate_train_data.py`` against a live SGLang server."""

    def test_regenerate_sharegpt(self):
        """Prepare ShareGPT data, regenerate 10 samples, and check the output file.

        Steps:
            1. Run ``scripts/prepare_data.py --dataset sharegpt``.
            2. Launch an SGLang server on port 30000 and wait until it is up.
            3. Run ``scripts/regenerate_train_data.py`` against that server.
            4. Assert the script exited with status 0 and that
               ``sharegpt_train_regen.jsonl`` exists under ``CACHE_DIR/dataset``.

        The server process is terminated in a ``finally`` block so that a
        failed wait or a failed assertion does not leave it running and
        holding port 30000 for subsequent runs.
        """
        # prepare data
        data_process = execute_shell_command(
            "python scripts/prepare_data.py --dataset sharegpt"
        )
        data_process.wait()
        # Fail here, with a clear message, rather than later when the
        # regeneration step cannot find its input file.
        self.assertEqual(
            data_process.returncode,
            0,
            "scripts/prepare_data.py exited with a non-zero status",
        )

        # NOTE: the backslash-newline pairs inside these (non-raw) string
        # literals are consumed by Python, so each command reaches the shell
        # as a single line.
        launch_command = """python3 -m sglang.launch_server \
    --model unsloth/Llama-3.2-1B-Instruct \
    --tp 1 \
    --cuda-graph-bs 4 \
    --dtype bfloat16 \
    --mem-frac=0.8 \
    --port 30000
        """
        regenerate_command = """python scripts/regenerate_train_data.py \
    --model unsloth/Llama-3.2-1B-Instruct \
    --concurrency 128 \
    --max-tokens 128 \
    --server-address localhost:30000 \
    --temperature 0.8 \
    --input-file-path ./cache/dataset/sharegpt_train.jsonl \
    --output-file-path ./cache/dataset/sharegpt_train_regen.jsonl \
    --num-samples 10
        """

        # launch sglang
        sglang_process = execute_shell_command(
            launch_command,
            disable_proxy=True,
            enable_hf_mirror=True,
        )
        try:
            wait_for_server("http://localhost:30000", disable_proxy=True)

            regeneration_process = execute_shell_command(
                regenerate_command,
                disable_proxy=True,
                enable_hf_mirror=True,
            )
            regeneration_process.wait()
            self.assertEqual(regeneration_process.returncode, 0)
            self.assertTrue(
                CACHE_DIR.joinpath("dataset", "sharegpt_train_regen.jsonl").exists()
            )
        finally:
            # Always shut the server down, even if a step above failed.
            sglang_process.terminate()
            sglang_process.wait()


# Allow running this test module directly as a script; verbosity=2 makes
# unittest print each test name together with its result.
if __name__ == "__main__":
    unittest.main(verbosity=2)