import random
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from braceexpand import braceexpand
from datasets import Dataset, load_dataset

from .model.text import TextNormalizer


@dataclass
class Dataset:
    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    shard_by_host: bool = False
    blank_caption_prob: float = 0.0
    clip_score_column: str = "clip_score"
    min_clip_score: float = None
    max_clip_score: float = None
    filter_column: str = None
    filter_value: str = None
    multi_eval_ds: bool = False
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    other_eval_datasets: list = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        if self.seed_dataset is None:
            # use a random seed when none is provided
            self.seed_dataset = random.randint(0, 2**32 - 1)
        # numpy rng used for blanking captions
        self.np_rng = np.random.default_rng(self.seed_dataset)
        self.multi_hosts = jax.process_count() > 1
        # blanking captions is currently only supported in streaming mode
        if self.blank_caption_prob:
            assert (
                self.streaming is True
            ), "blank_caption_prob can only be used in streaming mode"
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            # accept braceexpand notation, e.g. "data/shard-{0000..0127}.parquet"
            for k in ["train_file", "validation_file"]:
                f = getattr(self, k)
                if isinstance(f, str):
                    setattr(self, k, list(braceexpand(f)))
            # when sharding by host, each host keeps a subset of the training files
            if (
                isinstance(self.train_file, list)
                and self.multi_hosts
                and self.shard_by_host
            ):
                self.train_file = self.train_file[
                    jax.process_index() :: jax.process_count()
                ]
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # optional multiple evaluation datasets: expects a local directory with
        # one sub-directory of parquet files per split (a "valid" split is
        # renamed to "validation")
        if self.multi_eval_ds:
            assert Path(
                self.dataset_repo_or_path
            ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
            data_files = {
                split.name: [str(f) for f in split.glob("*.parquet")]
                for split in Path(self.dataset_repo_or_path).glob("*")
            }
            if "valid" in data_files:
                data_files["validation"] = data_files["valid"]
                del data_files["valid"]
            # load the files with the generic parquet builder
            self.dataset_repo_or_path = "parquet"

        # load the dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )
            # any remaining split is kept as an extra evaluation dataset
            other_eval_splits = dataset.keys() - {"train", "validation"}
            self.other_eval_datasets = {
                split: dataset[split] for split in other_eval_splits
            }

    def preprocess(self, tokenizer, config):
        # required config variables
        decoder_start_token_id = config.decoder_start_token_id
        normalize_text = config.normalize_text
        max_length = config.max_text_length

        if self.streaming:
            # shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # filter data
        partial_filter_function = partial(
            filter_function,
            filter_column=self.filter_column,
            filter_value=self.filter_value,
            clip_score_column=self.clip_score_column,
            min_clip_score=self.min_clip_score,
            max_clip_score=self.max_clip_score,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).filter(partial_filter_function)
                        if self.streaming
                        else getattr(self, ds).filter(
                            partial_filter_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Filtering datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.filter(partial_filter_function)
                    if self.streaming
                    else ds.filter(
                        partial_filter_function,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Filtering datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )
            if hasattr(self, "other_eval_datasets"):
                self.other_eval_datasets = {
                    split: (
                        ds.map(partial_normalize_function)
                        if self.streaming
                        else ds.map(
                            partial_normalize_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Normalizing datasets",
                        )
                    )
                    for split, ds in self.other_eval_datasets.items()
                }

        # randomly blank some captions
        if self.blank_caption_prob:
            partial_blank_caption_function = partial(
                blank_caption_function,
                text_column=self.text_column,
                blank_caption_prob=self.blank_caption_prob,
                rng=self.np_rng,
            )
            if hasattr(self, "train_dataset"):
                self.train_dataset = (
                    self.train_dataset.map(partial_blank_caption_function)
                    if self.streaming
                    else self.train_dataset.map(
                        partial_blank_caption_function,
                        # use a single process when a seed is set
                        num_proc=None
                        if self.seed_dataset
                        else self.preprocessing_num_workers,
                        load_from_cache_file=False,
                        desc="Blanking some captions",
                    )
                )

        # tokenize captions and prepare decoder inputs
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_length=max_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=[
                                self.text_column,
                                self.encoding_column,
                            ],
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            # `ds` is an attribute name here, so look up the
                            # dataset before reading its column names
                            remove_columns=getattr(self, ds).column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=[
                            self.text_column,
                            self.encoding_column,
                        ],
                    )
                    if self.streaming
                    else ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=ds.column_names,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Preprocessing datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

    def dataloader(self, split, batch_size, epoch=None):
        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Shuffle batches if rng is set.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            # drop the last incomplete batch
            batch_idx = batch_idx[: steps_per_epoch * batch_size]
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch
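        # For example, with len(dataset) == 10 and batch_size == 4, the loop above
        # yields 2 batches (steps_per_epoch == 2) and the last 2 examples are dropped.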

        def _dataloader_datasets_streaming(
            dataset: Dataset,
            epoch: int,
        ):
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            first_loop = True
            while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host setups we keep looping over the training set:
                # hosts may hold different numbers of shards, so none of them
                # should stop and leave the others waiting
                if epoch is not None:
                    assert split == "train"
                    # reshuffle the streaming dataset at each epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            ds = self.other_eval_datasets[split]

        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)
        else:
            # only the training set is shuffled; evaluation batches keep their order
            input_rng = None
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # the length of streaming datasets is unknown,
            # only the requested number of samples is
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    # zeros_like keeps the integer dtype of the token ids
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
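
# For example, with decoder_start_token_id = 0:
#   shift_tokens_right(np.array([[5, 6, 7]]), 0)  ->  array([[0, 5, 6]])
# i.e. the start token is prepended and the last token of each row is dropped.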


def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
    # replace the caption with an empty string with probability `blank_caption_prob`
    if (
        blank_caption_prob
        and (rng.random() if rng is not None else np.random.random())
        < blank_caption_prob
    ):
        example[text_column] = ""
    return example


def normalize_function(example, text_column, text_normalizer):
    example[text_column] = text_normalizer(example[text_column])
    return example


def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    if min_clip_score is not None and example[clip_score_column] < min_clip_score:
        return False
    if max_clip_score is not None and example[clip_score_column] > max_clip_score:
        return False
    if filter_column is not None and example[filter_column] != filter_value:
        return False
    return True


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    inputs = examples[text_column]
    # tokenize captions with a fixed length so batches keep a static shape
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # the targets are the pre-computed image encodings
    labels = examples[encoding_column]
    labels = np.asarray(labels)

    # labels are kept for the loss computation
    model_inputs["labels"] = labels

    # decoder inputs are the labels shifted right with the start token prepended
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs
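

# Example usage (illustrative sketch): the repository name, tokenizer and model
# config below are placeholders, not values defined in this module.
#
#   dataset = Dataset(
#       dataset_repo_or_path="user/encoded-dataset",
#       do_train=True,
#       streaming=True,
#   )
#   dataset.preprocess(tokenizer=tokenizer, config=model.config)
#   for batch in dataset.dataloader("train", batch_size=64, epoch=0):
#       ...  # batch contains input_ids, attention_mask, labels, decoder_input_ids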