import os
import tempfile
import unittest

from transformers import TrainingArguments


class TestTrainingArguments(unittest.TestCase):
    """Unit tests for ``TrainingArguments`` defaults and argument validation."""

    def test_default_output_dir(self):
        """Test that output_dir defaults to 'trainer_output' when not specified."""
        args = TrainingArguments(output_dir=None)
        self.assertEqual(args.output_dir, "trainer_output")

    def test_custom_output_dir(self):
        """Test that output_dir is respected when specified."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(output_dir=tmp_dir)
            self.assertEqual(args.output_dir, tmp_dir)

    def test_output_dir_creation(self):
        """Test that output_dir is created only when needed.

        Neither constructing ``TrainingArguments`` nor changing its save
        settings may create the output directory on disk; the directory is
        expected to appear only when something is actually written.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = os.path.join(tmp_dir, "test_output")

            # Precondition: the directory must not exist before the test starts.
            self.assertFalse(os.path.exists(output_dir))

            # Constructing args with saving disabled must not create it.
            # report_to="none" explicitly disables every reporting integration;
            # in transformers 4.x, report_to=None instead meant "all installed
            # integrations", which could add side effects to this test.
            args = TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="no",
                report_to="none",
            )
            self.assertFalse(os.path.exists(output_dir))

            # Changing the save settings on an existing instance must not create it.
            args.save_strategy = "steps"
            args.save_steps = 1
            self.assertFalse(os.path.exists(output_dir))

            # Plain attribute assignment does not re-run construction-time
            # processing, so also check that constructing args with step-based
            # saving enabled leaves the directory absent.
            TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="steps",
                save_steps=1,
                report_to="none",
            )
            self.assertFalse(os.path.exists(output_dir))

    def test_torch_empty_cache_steps_requirements(self):
        """Test that torch_empty_cache_steps is a positive integer or None."""
        # None is accepted and stored unchanged.
        args = TrainingArguments(torch_empty_cache_steps=None)
        self.assertIsNone(args.torch_empty_cache_steps)

        # Non-integers (float, string), negative integers and zero are all
        # rejected; subTest reports each failing value separately.
        for invalid in (1.0, "none", -1, 0):
            with self.subTest(torch_empty_cache_steps=invalid):
                with self.assertRaises(ValueError):
                    TrainingArguments(torch_empty_cache_steps=invalid)

        # A positive integer is accepted and stored unchanged.
        args = TrainingArguments(torch_empty_cache_steps=1)
        self.assertEqual(args.torch_empty_cache_steps, 1)