"""
S3 Dataset Loader - Streaming JSON/JSONL loader for S3 training data.
S3 is the canonical location for training data - all training data should be loaded from S3.
"""

import contextlib
import json
import logging
import os
from collections.abc import Iterator
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any

try:
    import boto3
    from botocore.exceptions import ClientError as _BotocoreClientError
except ImportError:
    # boto3 is optional at import time; availability is re-checked in __init__.
    boto3 = None
    _BotocoreClientError = None

if TYPE_CHECKING:

    class ClientError(Exception):
        response: dict[str, Any]
else:
    ClientError = (
        _BotocoreClientError if _BotocoreClientError is not None else Exception
    )

BOTO3_AVAILABLE = boto3 is not None

# Load a .env file so OVH_S3_* settings are available when running outside the
# main entrypoint. python-dotenv is optional; skip silently if not installed.
with contextlib.suppress(ImportError):
    from dotenv import load_dotenv

    # Look for a .env at the repository root (two or three levels up), falling
    # back to the module's own directory for shallow layouts.
    module_path = Path(__file__).resolve()
    env_paths = []
    try:
        env_paths.append(module_path.parents[2] / ".env")
        env_paths.append(module_path.parents[3] / ".env")
    except IndexError:
        env_paths.append(module_path.parent / ".env")
        if module_path.parent.name != "ai":
            env_paths.append(module_path.parent.parent / ".env")

    for env_path in env_paths:
        try:
            if env_path.exists() and env_path.is_file():
                load_dotenv(env_path, override=False)
                break
        except Exception:
            continue

logger = logging.getLogger(__name__)


class S3DatasetLoader:
    """
    Load datasets from S3 with streaming support for large files.
    S3 is the canonical training data location.
    """

    def __init__(
        self,
        bucket: str = "pixel-data",
        endpoint_url: str | None = None,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        region_name: str = "us-east-va",
    ):
        """
        Initialize S3 client for dataset loading.

        Args:
            bucket: S3 bucket name (default: pixel-data)
            endpoint_url: S3 endpoint URL (default: OVH S3 endpoint)
            aws_access_key_id: AWS access key (from env if not provided)
            aws_secret_access_key: AWS secret key (from env if not provided)
            region_name: AWS region (default: us-east-va for OVH)
        """
        if boto3 is None:
            raise ImportError(
                "boto3 is required for S3 dataset loading. "
                "Install with: uv pip install boto3"
            )

        # The bucket can be overridden via OVH_S3_BUCKET; the constructor
        # argument is only a fallback.
        self.bucket = os.getenv("OVH_S3_BUCKET", bucket)
        logger.debug(
            f"S3Loader: env OVH_S3_BUCKET={os.getenv('OVH_S3_BUCKET')}, "
            f"input bucket={bucket}, final={self.bucket}"
        )
        self.endpoint_url = endpoint_url or os.getenv(
            "OVH_S3_ENDPOINT", "https://s3.us-east-va.io.cloud.ovh.us"
        )

        # Credentials: explicit arguments first, then OVH-specific env vars,
        # then standard AWS env vars.
        access_key = (
            aws_access_key_id
            or os.getenv("OVH_S3_ACCESS_KEY")
            or os.getenv("OVH_ACCESS_KEY")
            or os.getenv("AWS_ACCESS_KEY_ID")
        )
        secret_key = (
            aws_secret_access_key
            or os.getenv("OVH_S3_SECRET_KEY")
            or os.getenv("OVH_SECRET_KEY")
            or os.getenv("AWS_SECRET_ACCESS_KEY")
        )

        if not access_key or not secret_key:
            raise ValueError(
                "S3 credentials not found. Set OVH_S3_ACCESS_KEY/OVH_S3_SECRET_KEY "
                "(or OVH_ACCESS_KEY/OVH_SECRET_KEY, "
                "or AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)."
            )

        # SSL verification: OVH_S3_CA_BUNDLE may be a path to a CA bundle, a
        # falsy string ("false"/"0"/"no") to disable verification, or unset
        # (verify against the default trust store).
        verify_ssl = os.getenv("OVH_S3_CA_BUNDLE", True)

        if str(verify_ssl).lower() in {"false", "0", "no"}:
            verify_ssl = False

        if verify_ssl is False:
            logger.warning(
                "Initializing S3 client with SSL verification DISABLED (insecure)"
            )

        self.s3_client = boto3.client(
            "s3",
            endpoint_url=self.endpoint_url,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=region_name or os.getenv("OVH_S3_REGION", "us-east-va"),
            verify=verify_ssl,
        )

        logger.info(f"S3DatasetLoader initialized for bucket: {self.bucket}")

    def _parse_s3_path(self, s3_path: str) -> tuple[str, str]:
        """
        Parse S3 path into bucket and key.

        Args:
            s3_path: S3 path (s3://bucket/key or just key)

        Returns:
            Tuple of (bucket, key)
        """
        # Full s3:// URL: split into bucket and key.
        if s3_path.startswith("s3://"):
            s3_path = s3_path.removeprefix("s3://")
            if "/" in s3_path:
                parts = s3_path.split("/", 1)
                return parts[0], parts[1]
            # Bucket-only URL, no key.
            return s3_path, ""

        # Bare key: use the default bucket.
        return self.bucket, s3_path

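    # For reference, a sketch of how paths resolve (the keys below are hypothetical):
    #   "s3://other-bucket/data/train.jsonl" -> ("other-bucket", "data/train.jsonl")
    #   "s3://other-bucket"                  -> ("other-bucket", "")
    #   "gdrive/processed/train.jsonl"       -> (self.bucket, "gdrive/processed/train.jsonl")
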
    def object_exists(self, s3_path: str) -> bool:
        """Check if S3 object exists"""
        try:
            bucket, key = self._parse_s3_path(s3_path)
            self.s3_client.head_object(Bucket=bucket, Key=key)
            return True
        except ClientError as e:
            if e.response["Error"]["Code"] == "404":
                return False
            raise

    def load_json(
        self,
        s3_path: str,
        cache_local: Path | None = None,
    ) -> dict[str, Any]:
        """
        Load JSON dataset from S3.

        Args:
            s3_path: S3 path (s3://bucket/key or just key)
            cache_local: Optional local cache path

        Returns:
            Parsed JSON data
        """
        bucket, key = self._parse_s3_path(s3_path)

        # Serve from the local cache when available.
        if cache_local and cache_local.exists():
            logger.info(f"Loading from local cache: {cache_local}")
            with open(cache_local) as f:
                return json.load(f)

        # Otherwise fetch the object from S3.
        logger.info(f"Loading from S3: s3://{bucket}/{key}")
        try:
            response = self.s3_client.get_object(Bucket=bucket, Key=key)
            data = json.loads(response["Body"].read())

            # Optionally cache the parsed data locally for future runs.
            if cache_local:
                cache_local.parent.mkdir(parents=True, exist_ok=True)
                with open(cache_local, "w") as f:
                    json.dump(data, f)
                logger.info(f"Cached to: {cache_local}")

            return data
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError(
                    f"Dataset not found in S3: s3://{bucket}/{key}"
                ) from e
            raise

    def load_bytes(self, s3_path: str) -> bytes:
        """
        Load raw bytes from S3.

        Args:
            s3_path: S3 path (s3://bucket/key or just key)

        Returns:
            Raw bytes of the object body
        """
        bucket, key = self._parse_s3_path(s3_path)
        logger.info(f"Loading bytes from S3: s3://{bucket}/{key}")

        try:
            response = self.s3_client.get_object(Bucket=bucket, Key=key)
            return response["Body"].read()
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError(
                    f"Dataset not found in S3: s3://{bucket}/{key}"
                ) from e
            raise

    def load_text(
        self,
        s3_path: str,
        *,
        encoding: str = "utf-8",
        errors: str = "replace",
    ) -> str:
        """
        Load a text object from S3.

        This is primarily for transcript corpora (e.g. .txt) that need to be
        converted into ChatML examples.
        """
        data = self.load_bytes(s3_path)
        return data.decode(encoding, errors=errors)

    def _parse_jsonl_line(self, line: bytes) -> dict[str, Any] | None:
        """
        Parse a single JSONL line with robust error handling.

        Args:
            line: Raw bytes of a JSONL line

        Returns:
            Parsed JSON object or None if parsing failed
        """
        try:
            return json.loads(line.decode("utf-8"))
        except UnicodeDecodeError:
            try:
                return json.loads(line.decode("utf-8", errors="replace"))
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSONL line: {e}")
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSONL line: {e}")
        return None

    def _stream_with_iter_lines(self, body) -> Iterator[dict[str, Any]]:
        """
        Stream JSONL using iter_lines() method.

        Args:
            body: S3 response body with iter_lines capability

        Yields:
            Parsed JSON objects
        """
        for raw_line in body.iter_lines():
            if not raw_line:
                continue
            parsed = self._parse_jsonl_line(raw_line)
            if parsed is not None:
                yield parsed

    def _stream_with_manual_buffering(self, body) -> Iterator[dict[str, Any]]:
        """
        Stream JSONL using manual buffering as fallback.

        Args:
            body: S3 response body

        Yields:
            Parsed JSON objects
        """
        buffer = BytesIO()
        for chunk in body.iter_chunks(chunk_size=8192):
            # Append after any buffered partial line instead of overwriting it.
            buffer.seek(0, os.SEEK_END)
            buffer.write(chunk)
            while True:
                buffer.seek(0)
                line = buffer.readline()
                if not line:
                    buffer = BytesIO()
                    break
                if not line.endswith(b"\n"):
                    # Incomplete line: keep it buffered until the next chunk arrives.
                    rest = buffer.read()
                    buffer = BytesIO(line + rest)
                    break

                parsed = self._parse_jsonl_line(line)
                if parsed is not None:
                    yield parsed

                rest = buffer.read()
                buffer = BytesIO(rest)

        # Emit a final line that has no trailing newline.
        remainder = buffer.getvalue()
        if remainder.strip():
            parsed = self._parse_jsonl_line(remainder)
            if parsed is not None:
                yield parsed

    def stream_jsonl(self, s3_path: str) -> Iterator[dict[str, Any]]:
        """
        Stream JSONL dataset from S3 (memory-efficient for large files).

        Args:
            s3_path: S3 path (s3://bucket/key or just key)

        Yields:
            Parsed JSON objects (one per line)
        """
        bucket, key = self._parse_s3_path(s3_path)

        logger.info(f"Streaming JSONL from S3: s3://{bucket}/{key}")
        try:
            response = self.s3_client.get_object(Bucket=bucket, Key=key)
            body = response["Body"]

            with contextlib.closing(body):
                # Prefer the body's native line iterator when available.
                iter_lines = getattr(body, "iter_lines", None)
                if callable(iter_lines):
                    yield from self._stream_with_iter_lines(body)
                else:
                    # Fall back to manual chunk buffering.
                    yield from self._stream_with_manual_buffering(body)

        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError(
                    f"Dataset not found in S3: s3://{bucket}/{key}"
                ) from e
            raise

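    # For reference, a hedged consumption sketch (the key below is hypothetical);
    # iterating lazily keeps memory flat even for multi-GB corpora:
    #   for example in loader.stream_jsonl("gdrive/processed/cot_reasoning/train.jsonl"):
    #       process(example)
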
    def list_datasets(self, prefix: str = "gdrive/processed/") -> list[str]:
        """
        List available datasets in S3.

        Args:
            prefix: S3 prefix to search (default: gdrive/processed/)

        Returns:
            List of S3 paths
        """
        logger.info(f"Listing datasets with prefix: {prefix}")
        datasets: list[str] = []

        try:
            paginator = self.s3_client.get_paginator("list_objects_v2")
            pages = paginator.paginate(Bucket=self.bucket, Prefix=prefix)

            for page in pages:
                if "Contents" in page:
                    datasets.extend(
                        f"s3://{self.bucket}/{obj['Key']}"
                        for obj in page["Contents"]
                        if obj["Key"].endswith((".json", ".jsonl"))
                    )

        except ClientError:
            logger.exception("Failed to list S3 objects")
            raise
        return datasets

    def download_file(self, s3_path: str, local_path: Path | str) -> None:
        """Download a file from S3 to local path"""
        try:
            bucket, key = self._parse_s3_path(s3_path)
            logger.info(f"Downloading s3://{bucket}/{key} to {local_path}")
            self.s3_client.download_file(bucket, key, str(local_path))
        except Exception:
            logger.exception(f"Failed to download {s3_path} to {local_path}")
            raise

    def upload_file(self, local_path: Path | str, s3_key: str) -> None:
        """Upload a local file to S3"""
        try:
            if not isinstance(local_path, Path):
                local_path = Path(local_path)

            bucket, key = self._parse_s3_path(s3_key)

            logger.info(f"Uploading {local_path} to s3://{bucket}/{key}")
            self.s3_client.upload_file(str(local_path), bucket, key)
        except Exception:
            logger.exception(f"Failed to upload {local_path} to {s3_key}")
            raise


def get_s3_dataset_path(
    dataset_name: str,
    category: str | None = None,
    bucket: str = "pixel-data",
    prefer_processed: bool = True,
) -> str:
    """
    Get S3 path for a dataset - S3 is the canonical training data location.

    Args:
        dataset_name: Name of the dataset file
        category: Optional category (cot_reasoning, professional_therapeutic, etc.)
        bucket: S3 bucket name
        prefer_processed: Prefer processed/canonical structure over raw

    Returns:
        S3 path (s3://bucket/path)
    """
    loader = S3DatasetLoader(bucket=bucket)

    # Canonical processed structure first.
    if category and prefer_processed:
        path = f"s3://{bucket}/gdrive/processed/{category}/{dataset_name}"
        if loader.object_exists(path):
            return path

    # Then the raw gdrive sync.
    if prefer_processed:
        path = f"s3://{bucket}/gdrive/raw/{dataset_name}"
        if loader.object_exists(path):
            return path

    # Then acquired datasets.
    path = f"s3://{bucket}/acquired/{dataset_name}"
    if loader.object_exists(path):
        return path

    # Nothing found: fall back to the expected canonical location.
    if category:
        return f"s3://{bucket}/gdrive/processed/{category}/{dataset_name}"

    return f"s3://{bucket}/gdrive/raw/{dataset_name}"


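# For reference, a sketch of the resolution order above (names are illustrative):
#   s3://pixel-data/gdrive/processed/<category>/<dataset_name>   (if category given)
#   s3://pixel-data/gdrive/raw/<dataset_name>
#   s3://pixel-data/acquired/<dataset_name>
# falling back to the expected canonical path when none of the candidates exist.

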
def load_dataset_from_s3(
    dataset_name: str,
    category: str | None = None,
    cache_local: Path | None = None,
    bucket: str = "pixel-data",
) -> dict[str, Any]:
    """
    Load dataset from S3 with automatic path resolution.

    Args:
        dataset_name: Name of the dataset file
        category: Optional category for canonical structure
        cache_local: Optional local cache path
        bucket: S3 bucket name

    Returns:
        Dataset data
    """
    loader = S3DatasetLoader(bucket=bucket)
    s3_path = get_s3_dataset_path(dataset_name, category, bucket)

    if dataset_name.endswith(".jsonl"):
        # JSONL: stream the records and wrap them in a conversations dict.
        return {"conversations": list(loader.stream_jsonl(s3_path))}
    return loader.load_json(s3_path, cache_local)
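

# Minimal usage sketch (not part of the loader API): assumes OVH_S3_* credentials
# are configured and that the bucket/keys below exist; adjust before running.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    loader = S3DatasetLoader()

    # List a few candidate training files under the canonical processed prefix.
    for path in loader.list_datasets(prefix="gdrive/processed/")[:5]:
        print(path)

    # Stream a (hypothetical) JSONL dataset without loading it all into memory.
    example_key = "gdrive/processed/cot_reasoning/example.jsonl"
    if loader.object_exists(example_key):
        for i, record in enumerate(loader.stream_jsonl(example_key)):
            print(record)
            if i >= 2:
                break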