# SpecForge: tests/test_scripts/test_train_eagle3.py
# (uploaded via the upload-large-folder tool, commit 62dca4c)
import shutil
import unittest
from pathlib import Path
from tests.utils import execute_shell_command
# Repo-root cache directory (three levels up from this test file).
# NOTE(review): not referenced anywhere in this module — presumably kept for
# other tests or future use; confirm before removing.
CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache")
def replace_in_script(script_path: Path, pattern: str, replacement: str) -> None:
    """Replace every occurrence of ``pattern`` with ``replacement`` in a script.

    The file at ``script_path`` is rewritten in place. The substitution is
    performed on the whole file content at once, so (unlike a line-by-line
    replace) it also works for patterns that span newlines.

    Args:
        script_path: Path to the script file to edit.
        pattern: Literal substring to search for (not a regex).
        replacement: Text that replaces each occurrence of ``pattern``.
    """
    text = script_path.read_text()
    script_path.write_text(text.replace(pattern, replacement))
class TestTrainEagle3(unittest.TestCase):
    """Smoke tests for the EAGLE3 example training scripts.

    Each test edits an example shell script in place (smaller model, fewer
    steps, alternate backend) and then runs it, asserting a zero exit code.
    ``setUp`` backs up the pristine scripts and ``tearDown`` restores them,
    so the in-place edits neither leak between tests (e.g. a backend
    replacement from one test silently surviving into the next) nor compound
    across repeated runs (e.g. ``--max-num-steps 10`` being appended twice).
    """

    # Example scripts that the tests edit in place; backed up in setUp.
    _SCRIPT_NAMES = (
        "run_llama3.1_8b_eagle3_online.sh",
        "run_llama3.1_8b_eagle3_offline.sh",
    )

    def setUp(self) -> None:
        self._examples_dir = Path(__file__).parent.parent.parent.joinpath("examples")
        # Snapshot the pristine scripts so tearDown can restore them.
        self._script_backups = {}
        for name in self._SCRIPT_NAMES:
            path = self._examples_dir.joinpath(name)
            self._script_backups[path] = path.read_text()

        # prepare data
        data_process = execute_shell_command(
            "python scripts/prepare_data.py --dataset sharegpt"
        )
        data_process.wait()

        # Modify the online script to only train for 10 steps by appending
        # --max-num-steps 10 to the launch command (the last non-empty line).
        script_path = self._examples_dir.joinpath("run_llama3.1_8b_eagle3_online.sh")
        with open(script_path, "r") as f:
            # drop empty lines so the launch command is the last entry
            lines = [line for line in f if line.strip()]
        lines[-1] = lines[-1].rstrip() + " --max-num-steps 10"
        with open(script_path, "w") as f:
            f.writelines(lines)
        # Replace meta-llama/Llama-3.1-8B-Instruct with the ungated mirror
        # nreHieW/Llama-3.1-8B-Instruct so that no HF token is needed.
        replace_in_script(
            script_path,
            "meta-llama/Llama-3.1-8B-Instruct",
            "nreHieW/Llama-3.1-8B-Instruct",
        )

    def tearDown(self) -> None:
        # Restore every script edited during the test to its original content.
        for path, content in self._script_backups.items():
            path.write_text(content)

    def _run_script(self, script_name: str) -> None:
        """Run an example script with 2 GPUs and assert it exits cleanly."""
        train_process = execute_shell_command(f"bash examples/{script_name} 2")
        train_process.wait()
        self.assertEqual(train_process.returncode, 0)

    def test_online_train_eagle3_with_sglang_backend(self):
        # setUp already configured the online script; just run it.
        self._run_script("run_llama3.1_8b_eagle3_online.sh")

    def test_online_train_eagle3_with_hf_backend(self):
        # switch the target model backend from sglang to hf
        replace_in_script(
            self._examples_dir.joinpath("run_llama3.1_8b_eagle3_online.sh"),
            "--target-model-backend sglang",
            "--target-model-backend hf",
        )
        self._run_script("run_llama3.1_8b_eagle3_online.sh")

    def test_online_train_eagle3_with_custom_backend(self):
        # switch the target model backend from sglang to custom
        replace_in_script(
            self._examples_dir.joinpath("run_llama3.1_8b_eagle3_online.sh"),
            "--target-model-backend sglang",
            "--target-model-backend custom",
        )
        self._run_script("run_llama3.1_8b_eagle3_online.sh")

    def test_offline_train_eagle3(self):
        script_path = self._examples_dir.joinpath("run_llama3.1_8b_eagle3_offline.sh")
        # ungated mirror model, smaller batch, tiny sample count and step
        # count so the offline pipeline finishes quickly
        replace_in_script(
            script_path,
            "meta-llama/Llama-3.1-8B-Instruct",
            "nreHieW/Llama-3.1-8B-Instruct",
        )
        replace_in_script(
            script_path,
            "--batch-size 32",
            "--batch-size 5",
        )
        replace_in_script(
            script_path,
            "scripts/prepare_hidden_states.py",
            "scripts/prepare_hidden_states.py --num-samples 10",
        )
        replace_in_script(
            script_path,
            "$ROOT_DIR/scripts/train_eagle3.py",
            "$ROOT_DIR/scripts/train_eagle3.py --max-num-steps 2",
        )
        # remove stale hidden states so the script regenerates them
        hidden_states_path = Path(__file__).parent.parent.parent.joinpath(
            "cache", "hidden_states", "sharegpt_train_Llama-3.1-8B-Instruct"
        )
        if hidden_states_path.exists():
            shutil.rmtree(hidden_states_path)
        self._run_script("run_llama3.1_8b_eagle3_offline.sh")
if __name__ == "__main__":
    # Allow running this test module directly; verbosity=2 prints each test
    # name and result as it runs.
    unittest.main(verbosity=2)