| import io |
| import os |
| import re |
| import pandas as pd |
| from typing import Any, Dict, List |
| import requests |
| from PIL import Image as PILImage |
| from scraper import ScraperBot, ScraperBotConfig |
| from helpers import starts_with_quotes, get_start_end_quotes |
| from dataclasses import dataclass |
| from datasets import Image |
|
|
|
|
@dataclass(frozen=True)
class HFDatasetScheme:
    """A single scraped dataset row pairing a caption with an image link.

    Instances are immutable (``frozen=True``). The declaration order of the
    fields below is also the positional order of the generated constructor,
    so it must not be changed.
    """

    # Caption text taken from between the message's quote characters.
    caption: str
    # Annotated with a Hugging Face ``datasets.Image`` feature; parse_fn
    # builds rows with ``image=None``.
    image: Image(decode=True)
    # URL found in the message text.
    link: str
    # Identifier of the originating message (``message["id"]``).
    message_id: str
    # Timestamp of the originating message (``message["timestamp"]``).
    timestamp: str
|
|
|
|
# Matches "http://" or "https://" followed by every non-whitespace character
# up to the next whitespace; trailing punctuation is therefore part of a match.
url_pattern = re.compile(r"https?://\S+")
|
|
|
|
def parse_fn(message: Dict[str, Any]) -> List[HFDatasetScheme]:
    """Converts one scraped message into dataset rows, one per URL it contains.

    The caption is the stripped text between the indices returned by
    ``get_start_end_quotes``. Every URL matched by ``url_pattern`` anywhere in
    the message content yields its own row sharing that caption.

    Parameters
    ----------
    message : Dict[str, Any]
        The message to parse; must provide ``"content"``, ``"timestamp"`` and
        ``"id"`` keys.

    Returns
    -------
    List[HFDatasetScheme]
        One row per URL found, each with ``image=None``; empty if no URL
        matches.
    """
    text = message["content"]

    start, end = get_start_end_quotes(text)
    # Skip the opening quote character itself; ``end`` is exclusive.
    caption = text[start + 1:end].strip()
    urls = url_pattern.findall(text)

    ts = message["timestamp"]
    msg_id = message["id"]

    rows = []
    for url in urls:
        rows.append(
            HFDatasetScheme(
                caption=caption,
                image=None,
                link=url,
                message_id=msg_id,
                timestamp=ts,
            )
        )
    return rows
|
|
|
|
def condition_fn(message: Dict[str, Any]) -> bool:
    """Checks if a message meets the condition to be parsed.

    A message qualifies when its content contains at least one URL matched by
    ``url_pattern`` and ``starts_with_quotes`` accepts it.

    Parameters
    ----------
    message : Dict[str, Any]
        The message to check; must provide a ``"content"`` key.

    Returns
    -------
    bool
        True if the message meets the condition, False otherwise.
    """
    content = message["content"]
    # ``re.search`` returns a Match or None, and ``and`` returns one of its
    # operands, so coerce explicitly to honour the ``bool`` return type.
    # The regex check still runs first and short-circuits the helper call.
    return url_pattern.search(content) is not None and bool(starts_with_quotes(content))
|
|
|
|
def prepare_dataset(messages: List[HFDatasetScheme]) -> pd.DataFrame:
    """Builds a DataFrame with one row per parsed message entry.

    Parameters
    ----------
    messages : List[HFDatasetScheme]
        Parsed rows to tabulate.

    Returns
    -------
    pd.DataFrame
        Columns ``caption``, ``image``, ``link``, ``message_id`` and
        ``timestamp`` in that order. The ``image`` column is always ``None``
        here; no image data is fetched by this function.
    """
    # Insertion order of this dict determines the DataFrame column order.
    columns = {
        "caption": [],
        "image": [],
        "link": [],
        "message_id": [],
        "timestamp": [],
    }
    for entry in messages:
        columns["caption"].append(entry.caption)
        columns["image"].append(None)
        columns["link"].append(entry.link)
        columns["message_id"].append(entry.message_id)
        columns["timestamp"].append(entry.timestamp)
    return pd.DataFrame(columns)
|
|
|
|
def get_image(link: str, timeout: float = 30.0) -> Dict[str, Any]:
    """Downloads an image and re-encodes it as an RGB PNG.

    Parameters
    ----------
    link : str
        URL of the image to download.
    timeout : float, optional
        Seconds ``requests`` waits for the server (connect and read) before
        raising. Defaults to 30.0. Without a timeout the call could block
        forever on an unresponsive host.

    Returns
    -------
    Dict[str, Any]
        ``{"bytes": <PNG-encoded image>, "path": None}``. Presumably this is
        the dict form consumed by ``datasets.Image``; confirm against the
        ScraperBot caller.
    """
    # ``response.content`` reads the full body with any Content-Encoding
    # (e.g. gzip) already decoded, whereas ``response.raw`` yields the
    # undecoded stream. The ``with`` block releases the connection even if
    # reading fails.
    with requests.get(link, timeout=timeout) as response:
        payload = response.content

    # ``convert`` forces the lazy load, so the source handle can be closed
    # as soon as the RGB copy exists.
    with PILImage.open(io.BytesIO(payload)) as source:
        image = source.convert("RGB")

    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format="PNG")
    return {"bytes": img_byte_arr.getvalue(), "path": None}
|
|
|
|
if __name__ == "__main__":
    # The bot configuration is read from config.json located next to this file.
    config_path = os.path.join(os.path.dirname(__file__), "config.json")
    config = ScraperBotConfig.from_json(config_path)

    # Wire this module's schema, parsing, filtering and download hooks into the bot.
    bot = ScraperBot(
        config=config,
        HFDatasetScheme=HFDatasetScheme,
        prepare_dataset=prepare_dataset,
        parse_fn=parse_fn,
        condition_fn=condition_fn,
        download_fn=get_image,
    )

    # FETCH_ALL=true (case-insensitive) enables fetch_all; any other value or
    # an unset variable disables it.
    fetch_all = os.environ.get("FETCH_ALL", "false").lower() == "true"
    bot.scrape(fetch_all=fetch_all)
|
|