| | import os |
| | import time |
| | import tarfile |
| | import hashlib |
| | import shutil |
| | import argparse |
| | import sys |
| | from enum import Enum, auto |
| | from pathlib import Path |
| | from typing import Optional |
| | from dataclasses import dataclass |
| | from contextlib import contextmanager |
| | import logging |
| | from dotenv import load_dotenv |
| | from huggingface_hub import CommitScheduler, HfApi |
| |
|
class SyncMode(Enum):
    """Operating mode selected via the --mode CLI flag (see parse_args)."""

    INIT_ONLY = auto()  # restore existing data from the Hub, then exit
    SYNC_ONLY = auto()  # skip restore; only run the periodic sync loop
    BOTH = auto()       # restore first, then keep syncing (default)
| |
|
@dataclass
class Config:
    """Runtime settings for the sync service.

    The three paths are fixed container mount points; the repo id and the
    sync interval are read from the environment by :meth:`from_env`.
    """

    repo_id: str
    sync_interval: int
    data_path: Path
    sync_path: Path
    tmp_path: Path
    archive_name: str

    @classmethod
    def from_env(cls):
        """Build a Config from environment variables (a .env file is honoured).

        Raises:
            ValueError: if HF_DATASET_REPO_ID is unset or empty.
        """
        load_dotenv()
        repo = os.getenv('HF_DATASET_REPO_ID')
        if not repo:
            raise ValueError("HF_DATASET_REPO_ID must be set")

        interval = int(os.getenv('SYNC_INTERVAL', '5'))
        return cls(
            repo_id=repo,
            sync_interval=interval,
            data_path=Path("/data"),
            sync_path=Path("/sync"),
            tmp_path=Path("/tmp/sync"),
            archive_name="data.tar.gz",
        )
| |
|
class Logger:
    """Configures root logging once and exposes a module-level logger."""

    _FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    def __init__(self):
        logging.basicConfig(level=logging.INFO, format=self._FORMAT)
        self.logger = logging.getLogger(__name__)
| |
|
class DirectoryMonitor:
    """Detects content changes in a directory tree via a SHA-256 digest."""

    def __init__(self, path: Path):
        self.path = path
        # Digest from the previous check; None means "never checked yet".
        self.last_hash: Optional[str] = None

    def get_directory_hash(self) -> str:
        """Return a SHA-256 digest over every file's relative path and bytes.

        Files are visited in sorted order so the digest is deterministic
        for a given tree state.
        """
        digest = hashlib.sha256()
        file_names = sorted(str(p) for p in self.path.rglob('*') if p.is_file())
        for name in file_names:
            digest.update(os.path.relpath(name, self.path).encode())
            with open(name, 'rb') as fh:
                while chunk := fh.read(4096):
                    digest.update(chunk)
        return digest.hexdigest()

    def has_changes(self) -> bool:
        """Recompute the digest; remember and report whether it changed."""
        fresh = self.get_directory_hash()
        changed = fresh != self.last_hash
        if changed:
            self.last_hash = fresh
        return changed
| |
|
class ArchiveManager:
    """Creates the data archive for upload and restores it on startup."""

    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger

    @contextmanager
    def safe_archive(self):
        """Context manager that builds the tarball in a temp dir, then publishes it.

        The archive is written under tmp_path and only moved into sync_path
        (the directory the uploader watches) after the tar file has been
        fully written and closed, so a partially written archive is never
        picked up. Any leftover temp file is removed on failure.
        """
        self.config.tmp_path.mkdir(parents=True, exist_ok=True)
        tmp_archive = self.config.tmp_path / self.config.archive_name

        try:
            with tarfile.open(tmp_archive, "w:gz") as tar:
                yield tar

            # Tar is complete and closed: publish by moving it into the
            # watched sync directory.
            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            shutil.move(tmp_archive, self.config.sync_path / self.config.archive_name)

        finally:
            # On success the move already removed the temp file; this only
            # cleans up a partial archive left behind by an exception.
            if tmp_archive.exists():
                tmp_archive.unlink()

    def create_archive(self):
        """Pack the whole data directory into the archive under a 'data/' prefix."""
        self.logger.info("Creating new archive...")
        with self.safe_archive() as tar:
            tar.add(self.config.data_path, arcname="data")
        self.logger.info("Archive created")

    def extract_archive(self):
        """Download the archive from the Hub and unpack it into data_path.

        Returns:
            bool: True when a remote archive was downloaded and extracted;
            False when the download or extraction failed (data_path is then
            created empty so the service can start fresh).
        """
        api = HfApi()
        try:
            self.logger.info("Downloading data archive...")
            api.hf_hub_download(
                repo_id=self.config.repo_id,
                filename=self.config.archive_name,
                repo_type="dataset",
                local_dir=self.config.sync_path
            )

            self.logger.info("Extracting archive...")
            archive_path = self.config.sync_path / self.config.archive_name
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(
                    path=self.config.data_path,
                    filter=self._tar_filter
                )
            return True
        except Exception as e:
            # Best-effort: a missing remote archive is expected on first run,
            # so any failure degrades to starting with an empty data dir.
            self.logger.error(f"No existing archive found or download failed: {e}")
            self.config.data_path.mkdir(parents=True, exist_ok=True)
            return False

    @staticmethod
    def _tar_filter(tarinfo, path):
        """extractall() filter: keep only members under 'data/', stripping that prefix.

        Returning None drops the member entirely. NOTE(review): this custom
        filter replaces tarfile's built-in safety filters — it assumes the
        archive is one this service produced itself; confirm untrusted
        archives can never reach this path.
        """
        if tarinfo.name.startswith('data/'):
            tarinfo.name = tarinfo.name[5:]
            return tarinfo
        return None
| |
|
class SyncService:
    """Orchestrates the restore-on-startup step and the periodic sync loop."""

    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger
        self.monitor = DirectoryMonitor(config.data_path)
        self.archive_manager = ArchiveManager(config, logger)

    def init(self) -> bool:
        """Restore existing data from the Hub.

        Returns:
            bool: True on full success; False when the restore failed
            (an empty data directory is still created so the service can run).
        """
        try:
            self.logger.info("Starting initialization...")
            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            success = self.archive_manager.extract_archive()
            if success:
                self.logger.info("Initialization completed successfully")
            else:
                self.logger.warning("Initialization completed with warnings")
            return success
        except Exception as e:
            self.logger.error(f"Initialization failed: {e}")
            return False

    def sync(self):
        """Run the change-detect / archive loop until interrupted.

        CommitScheduler uploads sync_path to the Hub on its own background
        thread every `sync_interval` minutes; this loop just refreshes the
        archive in sync_path whenever the data directory changes.
        """
        self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
        self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")

        scheduler = CommitScheduler(
            repo_id=self.config.repo_id,
            repo_type="dataset",
            folder_path=str(self.config.sync_path),
            path_in_repo="",
            every=self.config.sync_interval,
            squash_history=True,
            private=True
        )

        try:
            while True:
                if self.monitor.has_changes():
                    self.logger.info("Directory changes detected, creating new archive...")
                    self.archive_manager.create_archive()
                else:
                    self.logger.info("No changes detected")

                self.logger.info(f"Waiting {self.config.sync_interval} minutes until next check...")
                time.sleep(self.config.sync_interval * 60)
        except KeyboardInterrupt:
            self.logger.info("Stopping sync process...")
        finally:
            # Bug fix: stop() previously ran only on KeyboardInterrupt, so any
            # other exception in the loop leaked the scheduler's background
            # upload thread. Always release it on the way out.
            scheduler.stop()
| |
|
def parse_args():
    """Parse command-line options; --mode selects which phases run."""
    mode_help = (
        'Operation mode: init (initialization only), '
        'sync (synchronization only), both (default)'
    )
    parser = argparse.ArgumentParser(description='Data synchronization service')
    parser.add_argument('--mode', type=str, default='both',
                        choices=['init', 'sync', 'both'], help=mode_help)
    return parser.parse_args()
| |
|
def main():
    """Entry point: build the service, then run init and/or sync per --mode."""
    args = parse_args()
    config = Config.from_env()
    logger = Logger()
    service = SyncService(config, logger)

    mode_by_name = {
        'init': SyncMode.INIT_ONLY,
        'sync': SyncMode.SYNC_ONLY,
        'both': SyncMode.BOTH,
    }
    mode = mode_by_name[args.mode]

    if mode is not SyncMode.SYNC_ONLY:
        # init or both: restore data first; a failed restore aborts the run.
        if not service.init():
            sys.exit(1)
        if mode is SyncMode.INIT_ONLY:
            return

    if mode is not SyncMode.INIT_ONLY:
        service.sync()
| |
|
# Script entry point: only runs when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|