import shutil
import unittest
from pathlib import Path

from tests.utils import execute_shell_command

# Repository-level cache directory (<repo>/cache). Not referenced in this
# module — NOTE(review): presumably imported by other tests; confirm before
# removing.
CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache")


def replace_in_script(script_path: Path, pattern: str, replacement: str) -> None:
    """Replace every occurrence of *pattern* with *replacement* in a file.

    The file at *script_path* is rewritten in place. The substitution is a
    plain ``str.replace`` over the whole file contents, so *pattern* may
    span line boundaries (a superset of the previous line-by-line behavior).

    Args:
        script_path: Path of the file to edit in place.
        pattern: Literal substring to search for (not a regex).
        replacement: Literal text substituted for each match.
    """
    script_path.write_text(script_path.read_text().replace(pattern, replacement))


class TestTrainEagle3(unittest.TestCase):
    """End-to-end smoke tests for the EAGLE3 example training scripts.

    Each test patches one of the shell scripts under ``examples/`` so it
    performs a short, token-free run (ungated model mirror, few steps),
    launches it, and asserts the training process exits with status 0.
    """

    def setUp(self) -> None:
        # Prepare the ShareGPT dataset consumed by every training run.
        data_process = execute_shell_command(
            "python scripts/prepare_data.py --dataset sharegpt"
        )
        data_process.wait()

        examples_dir = Path(__file__).parent.parent.parent.joinpath("examples")

        # The tests edit the example scripts in place. Snapshot both scripts
        # now and restore them after each test; without this the edits
        # accumulate across tests (e.g. "--max-num-steps 10" gets appended
        # again on every setUp, and a backend swap from a previous test
        # persists) and the working tree is left dirty.
        for script_name in (
            "run_llama3.1_8b_eagle3_online.sh",
            "run_llama3.1_8b_eagle3_offline.sh",
        ):
            path = examples_dir.joinpath(script_name)
            if path.exists():
                self.addCleanup(path.write_text, path.read_text())

        # Modify the online script to only train for 10 steps by appending
        # --max-num-steps 10 to the launch command.
        script_path = examples_dir.joinpath("run_llama3.1_8b_eagle3_online.sh")
        with open(script_path, "r") as f:
            script = f.readlines()

        # Drop empty lines so the launch command is guaranteed to be the
        # last entry before the flag is appended.
        script = [line for line in script if line.strip()]
        script[-1] = script[-1].rstrip() + " --max-num-steps 10"

        # Replace meta-llama/Llama-3.1-8B-Instruct with a mirror repo so no
        # HF token is needed for the gated original.
        script = [
            line.replace(
                "meta-llama/Llama-3.1-8B-Instruct", "nreHieW/Llama-3.1-8B-Instruct"
            )
            for line in script
        ]

        # Write the patched script back to the file.
        with open(script_path, "w") as f:
            f.writelines(script)

    def test_online_train_eagle3_with_sglang_backend(self):
        """Online training with the default sglang target-model backend."""
        train_process = execute_shell_command(
            "bash examples/run_llama3.1_8b_eagle3_online.sh 2"
        )
        train_process.wait()
        self.assertEqual(train_process.returncode, 0)

    def test_online_train_eagle3_with_hf_backend(self):
        """Online training with the HuggingFace target-model backend."""
        # Swap --target-model-backend sglang for --target-model-backend hf.
        script_path = Path(__file__).parent.parent.parent.joinpath(
            "examples", "run_llama3.1_8b_eagle3_online.sh"
        )
        replace_in_script(
            script_path, "--target-model-backend sglang", "--target-model-backend hf"
        )

        # Run training.
        train_process = execute_shell_command(
            "bash examples/run_llama3.1_8b_eagle3_online.sh 2"
        )
        train_process.wait()
        self.assertEqual(train_process.returncode, 0)

    def test_online_train_eagle3_with_custom_backend(self):
        """Online training with the custom target-model backend."""
        # Swap --target-model-backend sglang for --target-model-backend custom.
        script_path = Path(__file__).parent.parent.parent.joinpath(
            "examples", "run_llama3.1_8b_eagle3_online.sh"
        )
        replace_in_script(
            script_path,
            "--target-model-backend sglang",
            "--target-model-backend custom",
        )

        # Run training.
        train_process = execute_shell_command(
            "bash examples/run_llama3.1_8b_eagle3_online.sh 2"
        )
        train_process.wait()
        self.assertEqual(train_process.returncode, 0)

    def test_offline_train_eagle3(self):
        """Offline training: generate hidden states, then train on them."""
        # Patch the offline script for a small, token-free run.
        script_path = Path(__file__).parent.parent.parent.joinpath(
            "examples", "run_llama3.1_8b_eagle3_offline.sh"
        )
        replace_in_script(
            script_path,
            "meta-llama/Llama-3.1-8B-Instruct",
            "nreHieW/Llama-3.1-8B-Instruct",
        )
        replace_in_script(
            script_path,
            "--batch-size 32",
            "--batch-size 5",
        )
        replace_in_script(
            script_path,
            "scripts/prepare_hidden_states.py",
            "scripts/prepare_hidden_states.py --num-samples 10",
        )
        replace_in_script(
            script_path,
            "$ROOT_DIR/scripts/train_eagle3.py",
            "$ROOT_DIR/scripts/train_eagle3.py --max-num-steps 2",
        )

        # Remove previously cached hidden states so the script regenerates
        # them from scratch.
        hidden_states_path = Path(__file__).parent.parent.parent.joinpath(
            "cache", "hidden_states", "sharegpt_train_Llama-3.1-8B-Instruct"
        )
        if hidden_states_path.exists():
            # Delete the whole directory tree.
            shutil.rmtree(hidden_states_path)

        training_process = execute_shell_command(
            "bash examples/run_llama3.1_8b_eagle3_offline.sh 2",
        )
        training_process.wait()
        self.assertEqual(training_process.returncode, 0)


if __name__ == "__main__":
    # Verbose runner so each test name and result is printed when this file
    # is executed directly.
    unittest.main(verbosity=2)