| import os |
| import time |
| import tarfile |
| import hashlib |
| import shutil |
| import argparse |
| import sys |
| from enum import Enum, auto |
| from pathlib import Path |
| from typing import Optional |
| from dataclasses import dataclass |
| from contextlib import contextmanager |
| import logging |
| from dotenv import load_dotenv |
| from huggingface_hub import CommitScheduler, HfApi |
|
|
class SyncMode(Enum):
    """Operating mode selected via the --mode CLI flag (see parse_args/main)."""

    INIT_ONLY = auto()  # restore data from the remote archive, then exit
    SYNC_ONLY = auto()  # run the periodic sync loop without restoring first
    BOTH = auto()  # restore first, then keep syncing (default)
|
|
@dataclass
class Config:
    """Runtime configuration for the sync service, populated by `from_env`."""

    repo_id: str  # Hugging Face dataset repo id, e.g. "user/dataset"
    sync_interval: int  # minutes between change checks / scheduled commits
    data_path: Path  # live data directory that gets archived
    sync_path: Path  # staging directory uploaded by CommitScheduler
    tmp_path: Path  # scratch directory where archives are built
    archive_name: str  # filename of the tarball placed in sync_path

    @classmethod
    def from_env(cls):
        """Build a Config from environment variables (loading .env if present).

        Returns:
            Config: fully populated configuration.

        Raises:
            ValueError: if HF_DATASET_REPO_ID is unset/empty, or if
                SYNC_INTERVAL is set to a non-integer value.
        """
        load_dotenv()
        repo_id = os.getenv('HF_DATASET_REPO_ID')
        if not repo_id:
            raise ValueError("HF_DATASET_REPO_ID must be set")

        # Fail with a clear message instead of a bare int() traceback.
        try:
            sync_interval = int(os.getenv('SYNC_INTERVAL', '5'))
        except ValueError as err:
            raise ValueError("SYNC_INTERVAL must be an integer number of minutes") from err

        # Paths are now overridable via environment variables while keeping
        # the original container mount points as defaults, so existing
        # deployments behave exactly as before.
        return cls(
            repo_id=repo_id,
            sync_interval=sync_interval,
            data_path=Path(os.getenv('DATA_PATH', '/data')),
            sync_path=Path(os.getenv('SYNC_PATH', '/sync')),
            tmp_path=Path(os.getenv('TMP_SYNC_PATH', '/tmp/sync')),
            archive_name=os.getenv('ARCHIVE_NAME', 'data.tar.gz'),
        )
|
|
class Logger:
    """Configures root logging once and exposes a module-level logger."""

    def __init__(self):
        log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        logging.basicConfig(level=logging.INFO, format=log_format)
        self.logger = logging.getLogger(__name__)
|
|
class DirectoryMonitor:
    """Detects content changes in a directory tree via a deterministic hash."""

    def __init__(self, path: Path):
        self.path = path
        # Digest seen by the previous has_changes() call; None until then.
        self.last_hash: Optional[str] = None

    def get_directory_hash(self) -> str:
        """Return a SHA-256 digest over relative paths and contents of all files.

        Files are visited in sorted path order so the digest is deterministic
        for an unchanged tree. The data directory is live, so a file may
        vanish or become unreadable between listing and opening; such files
        are skipped instead of crashing the caller's sync loop.
        """
        sha256_hash = hashlib.sha256()

        all_files = sorted(
            str(p) for p in self.path.rglob('*') if p.is_file()
        )

        for file_path in all_files:
            rel_path = os.path.relpath(file_path, self.path)

            try:
                with open(file_path, 'rb') as f:
                    # Hash the path first so renames change the digest even
                    # when contents are identical.
                    sha256_hash.update(rel_path.encode())
                    for chunk in iter(lambda: f.read(4096), b''):
                        sha256_hash.update(chunk)
            except OSError:
                # File disappeared or is unreadable mid-scan; ignore it
                # this round rather than aborting the whole hash.
                continue

        return sha256_hash.hexdigest()

    def has_changes(self) -> bool:
        """Return True if the tree's hash differs from the last observed one.

        Note: the first call always returns True (no prior hash), which
        guarantees an initial archive is produced by the sync loop.
        """
        current_hash = self.get_directory_hash()
        if current_hash != self.last_hash:
            self.last_hash = current_hash
            return True
        return False
|
|
| class ArchiveManager: |
| def __init__(self, config: Config, logger: Logger): |
| self.config = config |
| self.logger = logger.logger |
| |
| @contextmanager |
| def safe_archive(self): |
| """安全地创建归档文件的上下文管理器""" |
| self.config.tmp_path.mkdir(parents=True, exist_ok=True) |
| tmp_archive = self.config.tmp_path / self.config.archive_name |
| |
| try: |
| with tarfile.open(tmp_archive, "w:gz") as tar: |
| yield tar |
| |
| |
| self.config.sync_path.mkdir(parents=True, exist_ok=True) |
| shutil.move(tmp_archive, self.config.sync_path / self.config.archive_name) |
| |
| finally: |
| |
| if tmp_archive.exists(): |
| tmp_archive.unlink() |
|
|
| def create_archive(self): |
| """创建压缩包""" |
| self.logger.info("Creating new archive...") |
| with self.safe_archive() as tar: |
| tar.add(self.config.data_path, arcname="data") |
| self.logger.info("Archive created") |
|
|
| def extract_archive(self): |
| """解压现有数据""" |
| api = HfApi() |
| try: |
| self.logger.info("Downloading data archive...") |
| api.hf_hub_download( |
| repo_id=self.config.repo_id, |
| filename=self.config.archive_name, |
| repo_type="dataset", |
| local_dir=self.config.sync_path |
| ) |
| |
| self.logger.info("Extracting archive...") |
| archive_path = self.config.sync_path / self.config.archive_name |
| with tarfile.open(archive_path, "r:gz") as tar: |
| tar.extractall( |
| path=self.config.data_path, |
| filter=self._tar_filter |
| ) |
| return True |
| except Exception as e: |
| self.logger.error(f"No existing archive found or download failed: {e}") |
| self.config.data_path.mkdir(parents=True, exist_ok=True) |
| return False |
|
|
| @staticmethod |
| def _tar_filter(tarinfo, path): |
| """tar 文件过滤器""" |
| if tarinfo.name.startswith('data/'): |
| tarinfo.name = tarinfo.name[5:] |
| return tarinfo |
| return None |
|
|
class SyncService:
    """Coordinates one-time initialization and the periodic archive/sync loop."""

    def __init__(self, config: "Config", logger: "Logger"):
        self.config = config
        self.logger = logger.logger
        self.monitor = DirectoryMonitor(config.data_path)
        self.archive_manager = ArchiveManager(config, logger)

    def init(self) -> bool:
        """Restore local data from the remote archive.

        Returns:
            bool: True if an existing archive was downloaded and extracted;
            False if none was found or initialization failed entirely.
        """
        try:
            self.logger.info("Starting initialization...")
            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            success = self.archive_manager.extract_archive()
            if success:
                self.logger.info("Initialization completed successfully")
            else:
                self.logger.warning("Initialization completed with warnings")
            return success
        except Exception as e:
            self.logger.error(f"Initialization failed: {e}")
            return False

    def sync(self):
        """Run the continuous sync loop until interrupted.

        A CommitScheduler uploads the sync folder every sync_interval
        minutes in the background; this loop re-archives the data
        directory whenever its content hash changes.
        """
        self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
        self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")

        scheduler = CommitScheduler(
            repo_id=self.config.repo_id,
            repo_type="dataset",
            folder_path=str(self.config.sync_path),
            path_in_repo="",
            every=self.config.sync_interval,
            squash_history=True,
            private=True
        )

        try:
            while True:
                if self.monitor.has_changes():
                    self.logger.info("Directory changes detected, creating new archive...")
                    self.archive_manager.create_archive()
                else:
                    self.logger.info("No changes detected")

                self.logger.info(f"Waiting {self.config.sync_interval} minutes until next check...")
                time.sleep(self.config.sync_interval * 60)
        except KeyboardInterrupt:
            self.logger.info("Stopping sync process...")
        finally:
            # Stop the scheduler on ANY exit path (previously it was only
            # stopped on Ctrl-C, so an unexpected exception left the
            # background commit thread running forever).
            scheduler.stop()
|
|
def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings; defaults to sys.argv[1:].
            Accepting an explicit list keeps existing callers unchanged
            while making the parser testable without patching sys.argv.

    Returns:
        argparse.Namespace: with a single `mode` attribute
        ('init', 'sync', or 'both').
    """
    parser = argparse.ArgumentParser(description='Data synchronization service')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['init', 'sync', 'both'],
        default='both',
        help='Operation mode: init (initialization only), sync (synchronization only), both (default)'
    )
    return parser.parse_args(argv)
|
|
def main():
    """Entry point: build the service from env/CLI and run the requested phases."""
    args = parse_args()
    config = Config.from_env()
    logger = Logger()
    service = SyncService(config, logger)

    mode_by_name = {
        'init': SyncMode.INIT_ONLY,
        'sync': SyncMode.SYNC_ONLY,
        'both': SyncMode.BOTH,
    }
    mode = mode_by_name[args.mode]

    # Initialization phase (skipped for sync-only runs).
    if mode is not SyncMode.SYNC_ONLY:
        if not service.init():
            sys.exit(1)
        if mode is SyncMode.INIT_ONLY:
            return

    # Continuous sync phase (skipped for init-only runs).
    if mode is not SyncMode.INIT_ONLY:
        service.sync()
|
|
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|