| | import os |
| | import shutil |
| | from watchdog.observers import Observer |
| | from watchdog.events import FileSystemEventHandler |
| | import time |
| | import re |
| |
|
class CheckpointHandler(FileSystemEventHandler):
    """Prune old ``global_step_<N>`` checkpoint directories inside a folder.

    The handler keeps the ``max_checkpoints`` highest-numbered checkpoint
    directories and removes the rest, except for checkpoints whose directory
    name corresponds to one of ``protected_steps``.

    Args:
        folder_path: Directory that directly contains the checkpoint folders.
        max_checkpoints: How many of the most recent checkpoints to keep.
            Must be non-negative.
        protected_steps: Step numbers whose ``global_step_<step>`` directory
            is never removed, even when it falls outside the keep window.

    Raises:
        ValueError: If ``max_checkpoints`` is negative.
    """

    # A directory is a checkpoint only if its *entire* name has this shape;
    # the capture group is the step number used for ordering.
    _CHECKPOINT_NAME = re.compile(r'global_step_(\d+)')

    def __init__(self, folder_path, max_checkpoints=2,
                 protected_steps=(45, 90, 135, 180, 220)):
        super().__init__()
        if max_checkpoints < 0:
            raise ValueError("max_checkpoints must be non-negative")
        self.folder_path = folder_path
        self.max_checkpoints = max_checkpoints
        self.protected_steps = tuple(protected_steps)

    def on_created(self, event):
        """Handle a filesystem creation event.

        Non-directory events are ignored. Directory events currently trigger
        no action either: pruning only happens when ``cleanup_checkpoints``
        is called explicitly.
        """
        if not event.is_directory:
            return

    def cleanup_checkpoints(self):
        """Remove checkpoints that fall outside the keep window.

        Checkpoints are ordered by step number. Everything except the last
        ``max_checkpoints`` entries is a removal candidate; candidates whose
        name is protected are skipped. A failure to delete one checkpoint is
        reported and does not stop the remaining candidates from being
        processed.

        Raises:
            OSError: If ``folder_path`` itself cannot be listed.
        """
        checkpoints = self._list_checkpoints()

        # Computing the count explicitly (instead of slicing with
        # ``[:-max_checkpoints]``) keeps ``max_checkpoints == 0`` working:
        # ``[:-0]`` would be an empty slice and nothing would ever be pruned.
        excess = len(checkpoints) - self.max_checkpoints
        if excess <= 0:
            print(f"No need to remove any checkpoints, {len(checkpoints)} checkpoints exist")
            return

        protected_names = {f"global_step_{i}" for i in self.protected_steps}

        for _, checkpoint in checkpoints[:excess]:
            if os.path.basename(checkpoint) in protected_names:
                print(f"Skipped specific checkpoint: {checkpoint}")
                continue
            try:
                shutil.rmtree(checkpoint)
            except OSError as exc:
                # Keep going: one undeletable checkpoint should not block
                # pruning of the others.
                print(f"Failed to remove checkpoint {checkpoint}: {exc}")
            else:
                print(f"Removed old checkpoint: {checkpoint}")

    def _list_checkpoints(self):
        """Return ``(step, path)`` pairs for checkpoint dirs, lowest step first.

        Ordering uses the step number embedded in the directory name rather
        than ``os.path.getctime``, which on Unix reports the last metadata
        change and not the creation time.
        """
        found = []
        for name in os.listdir(self.folder_path):
            matched = self._CHECKPOINT_NAME.fullmatch(name)
            if matched is None:
                continue
            path = os.path.join(self.folder_path, name)
            if os.path.isdir(path):
                found.append((int(matched.group(1)), path))
        found.sort()
        return found
| |
|
def main(folder_path='/data/wuxinrui/easyr1_checkpoints/1_5B_TCMv2_long_short_regular_budget_modified',
         poll_interval=300):
    """Watch a checkpoint folder and prune it periodically until interrupted.

    Starts a watchdog observer on ``folder_path`` (non-recursive) and then
    calls ``cleanup_checkpoints`` on the handler every ``poll_interval``
    seconds. The loop ends on Ctrl-C.

    Args:
        folder_path: Directory containing the checkpoint folders.
        poll_interval: Seconds to sleep between cleanup passes.
    """
    event_handler = CheckpointHandler(folder_path)
    observer = Observer()
    observer.schedule(event_handler, folder_path, recursive=False)
    observer.start()

    try:
        while True:
            event_handler.cleanup_checkpoints()
            time.sleep(poll_interval)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the loop; exit quietly.
        pass
    finally:
        # Shut the observer down on every exit path, including an unexpected
        # exception raised by a cleanup pass.
        observer.stop()
        observer.join()
| |
|
# Run the watcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()