| import os |
| import tempfile |
| import unittest |
|
|
| from transformers import TrainingArguments |
|
|
|
|
class TestTrainingArguments(unittest.TestCase):
    """Unit tests for ``TrainingArguments`` output-directory handling and validation."""

    def test_default_output_dir(self):
        """Test that output_dir defaults to 'trainer_output' when not specified."""
        # An explicit None must resolve to the default directory name.
        args = TrainingArguments(output_dir=None)
        self.assertEqual(args.output_dir, "trainer_output")

        # Omitting the argument entirely must resolve to the same default.
        args = TrainingArguments()
        self.assertEqual(args.output_dir, "trainer_output")

    def test_custom_output_dir(self):
        """Test that output_dir is respected when specified."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(output_dir=tmp_dir)
            self.assertEqual(args.output_dir, tmp_dir)

    def test_output_dir_creation(self):
        """Test that output_dir is created only when needed."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = os.path.join(tmp_dir, "test_output")

            # The directory must not exist before the arguments are built.
            self.assertFalse(os.path.exists(output_dir))

            # Building the arguments alone must not create the directory.
            # report_to="none" disables reporting integrations explicitly; some
            # transformers versions treat report_to=None as "all installed
            # integrations" rather than "no reporting".
            args = TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="no",
                report_to="none",
            )
            self.assertFalse(os.path.exists(output_dir))

            # Switching to a saving strategy after construction must not create
            # it either. NOTE(review): creation is presumably left to whatever
            # actually writes checkpoints (e.g. the Trainer); not exercised here.
            args.save_strategy = "steps"
            args.save_steps = 1
            self.assertFalse(os.path.exists(output_dir))

    def test_torch_empty_cache_steps_requirements(self):
        """Test that torch_empty_cache_steps is a positive integer or None."""
        # None is accepted and stored as-is.
        args = TrainingArguments(torch_empty_cache_steps=None)
        self.assertIsNone(args.torch_empty_cache_steps)

        # Non-integers, negative integers and zero are all rejected. subTest
        # reports every failing value instead of stopping at the first one.
        for invalid in (1.0, "none", -1, 0):
            with self.subTest(torch_empty_cache_steps=invalid):
                with self.assertRaises(ValueError):
                    TrainingArguments(torch_empty_cache_steps=invalid)

        # A positive integer is accepted and stored unchanged.
        args = TrainingArguments(torch_empty_cache_steps=1)
        self.assertEqual(args.torch_empty_cache_steps, 1)
|
|