"""
Custom streaming data loader for AGILLM training.

Pulls data over HTTP from the stream_server on the scraper box and acts as a
drop-in replacement for HuggingFace dataset streaming.
"""
import json
import time
from typing import Any, Dict, Iterator

import requests
|
|
class ScraperStreamDataset:
    """
    Streams training data from the scraper server.
    Compatible with AGILLM's _stream() interface.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8888",
        batch_size: int = 100,
        text_field: str = "text",
        shuffle: bool = True
    ):
        self.server_url = server_url
        self.batch_size = batch_size
        self.text_field = text_field
        self.shuffle = shuffle
        self._buffer = []

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        if not self._buffer:
            self._fetch_batch()
        if not self._buffer:
            raise StopIteration
        return self._buffer.pop(0)
| |
    def _fetch_batch(self):
        """Fetch one batch of newline-delimited JSON records from the stream server."""
        endpoint = "/stream" if self.shuffle else "/sequential"
        try:
            resp = requests.get(
                f"{self.server_url}{endpoint}",
                params={"batch": self.batch_size},
                stream=True,
                timeout=30
            )
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                try:
                    obj = json.loads(line.decode("utf-8"))
                except json.JSONDecodeError:
                    # Skip malformed records rather than aborting the batch.
                    continue
                self._buffer.append({self.text_field: obj.get("text", "")})
        except requests.RequestException as e:
            print(f"[StreamLoader] Fetch error: {e}")
| |
    def get_status(self) -> dict:
        """Get server status."""
        try:
            resp = requests.get(f"{self.server_url}/status", timeout=10)
            return resp.json()
        except (requests.RequestException, ValueError):
            return {"error": "unreachable"}
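

# Illustrative sketch, not part of the original interface: the module docstring
# calls this loader a drop-in replacement for HuggingFace dataset streaming, and
# one way to hand it to code that expects a `datasets.IterableDataset` is via
# `IterableDataset.from_generator`. Assumes the optional `datasets` library
# (>= 2.8) is installed; the helper name below is hypothetical.
def as_hf_iterable_dataset(server_url: str = "http://localhost:8888"):
    """Wrap the scraper stream in a HuggingFace IterableDataset (sketch)."""
    from datasets import IterableDataset  # lazy import; optional dependency

    def _gen():
        # Each dict from ScraperStreamDataset becomes one example with a "text" column.
        yield from ScraperStreamDataset(server_url=server_url)

    return IterableDataset.from_generator(_gen)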
|
|
|
|
def create_stream_iterator(server_url: str = "http://localhost:8888", seed: int = 42):
    """
    Create an iterator compatible with AGILLM's _stream() function.
    Returns an infinite iterator of {"text": "..."} dicts. The seed argument is
    accepted for interface compatibility but is not currently used.
    """
    dataset = ScraperStreamDataset(server_url=server_url)
    while True:
        try:
            yield next(dataset)
        except StopIteration:
            # Buffer ran dry; try another fetch, and back off briefly if the
            # server still has nothing so it is not polled in a tight loop.
            dataset._fetch_batch()
            if dataset._buffer:
                yield dataset._buffer.pop(0)
            else:
                time.sleep(1.0)
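

# Hypothetical consumer sketch: AGILLM's actual _stream() plumbing is not shown
# in this file, so the helper below only illustrates how the infinite iterator
# above could be grouped into fixed-size lists of raw text for a training step.
# The function name and batch_size parameter are assumptions, not existing API.
def iter_text_batches(server_url: str = "http://localhost:8888", batch_size: int = 32):
    """Yield lists of `batch_size` raw text strings from the stream (sketch)."""
    stream = create_stream_iterator(server_url=server_url)
    while True:
        yield [next(stream)["text"] for _ in range(batch_size)]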
|
|
|
|
if __name__ == "__main__":
    import sys

    url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8888"
    print(f"Testing stream from {url}")

    ds = ScraperStreamDataset(server_url=url, batch_size=5)
    print(f"Status: {ds.get_status()}")

    for i, item in enumerate(ds):
        text = item["text"]
        print(f"Sample {i}: {len(text)} chars - {text[:100]}...")
        if i >= 4:
            break
|
|