| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from functools import partial |
| | from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple |
| |
|
| | from .processors.feedback import preprocess_feedback_dataset |
| | from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example |
| | from .processors.pretrain import preprocess_pretrain_dataset |
| | from .processors.supervised import ( |
| | preprocess_packed_supervised_dataset, |
| | preprocess_supervised_dataset, |
| | print_supervised_dataset_example, |
| | ) |
| | from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments |
| |
|
| | from ..hparams import DataArguments |
| | from .template import Template |
| |
|
| |
|
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    """Pick the dataset preprocessing and example-printing callables for a stage.

    Both returned callables are ``functools.partial`` objects with the
    tokenizer (and, where the processor accepts them, the template, processor
    and data arguments) already bound.

    Dispatch:
        * ``"pt"``: pretrain preprocessing, unsupervised example printer.
        * ``"sft"`` with ``training_args.predict_with_generate`` falsy:
          packed or plain supervised preprocessing depending on
          ``data_args.packing``, supervised example printer.
        * ``"rm"``: pairwise preprocessing and pairwise example printer.
        * ``"kto"``: feedback preprocessing, supervised example printer.
        * anything else (``"ppo"``, or ``"sft"`` with
          ``predict_with_generate`` set): unsupervised preprocessing and
          unsupervised example printer.

    Side effect:
        When ``data_args.packing`` and ``data_args.neat_packing`` are both
        truthy, ``datasets.arrow_writer.OptimizedTypedSequence.__init__`` is
        replaced process-wide (see the inline comment).

    Returns:
        A ``(preprocess_func, print_function)`` tuple.
    """
    if stage == "pt":
        tokenize = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        return tokenize, partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    # Keyword set shared by every processor that also receives the processor
    # object; the insertion order mirrors the order the keywords are bound in.
    common_kwargs = {
        "template": template,
        "tokenizer": tokenizer,
        "processor": processor,
        "data_args": data_args,
    }

    if stage == "sft" and not training_args.predict_with_generate:
        if not data_args.packing:
            tokenize = partial(preprocess_supervised_dataset, **common_kwargs)
        else:
            if data_args.neat_packing:
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                # Swap in a constructor that forwards only `type`, `try_type`
                # and `optimized_int_type` (each defaulting to None) to
                # TypedSequence.__init__, discarding any other keyword.
                # NOTE(review): presumably this disables the per-column integer
                # type optimisation of `datasets` for neat packing -- confirm
                # against the installed `datasets` version.
                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            # The packed processor is not given the processor object.
            tokenize = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                data_args=data_args,
            )

        return tokenize, partial(print_supervised_dataset_example, tokenizer=tokenizer)

    if stage == "rm":
        tokenize = partial(preprocess_pairwise_dataset, **common_kwargs)
        return tokenize, partial(print_pairwise_dataset_example, tokenizer=tokenizer)

    if stage == "kto":
        tokenize = partial(preprocess_feedback_dataset, **common_kwargs)
        return tokenize, partial(print_supervised_dataset_example, tokenizer=tokenizer)

    # Fallback: every remaining stage / generation-time evaluation.
    tokenize = partial(preprocess_unsupervised_dataset, **common_kwargs)
    return tokenize, partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
| |
|