import copy
import transformers
import tokenizers
import torch
from typing import Dict, Optional, Sequence, List
from packaging import version

from llava.mm_utils import tokenizer_image_token
from llava.train.arguments import ModelArguments, TrainingArguments, DataArguments
from llava.constants import IGNORE_INDEX, MM_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava import conversation as conversation_lib

# tokenizers >= 0.14 changed how fast tokenizers handle the space after
# special tokens; several preprocessors below consult this flag to compensate
# for the resulting off-by-one in per-round token counts.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    # Each string is tokenized on its own, so input_ids and labels can share
    # the same tensors at this point; labels are masked later.
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    # Sequence lengths with padding excluded.
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


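# Usage sketch (illustrative, not executed): given N strings, _tokenize_fn
# returns one 1-D id tensor per string plus its unpadded length, e.g.
#   out = _tokenize_fn(["### Human: hi\n", "### Assistant: hello\n"], tokenizer)
#   out["input_ids"][0]        # 1-D tensor of token ids for the first string
#   out["input_ids_lens"]      # [len_0, len_1], padding excluded

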
def _mask_targets(target, tokenized_lens, speakers):
    # The first segment is the conversation header; never supervise it.
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX

    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            # Mask the human turn, leaving its first two tokens (the round
            # separator) supervised.
            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


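# Illustrative sketch: with tokenized_lens = [5, 4, 6] and
# speakers = ["human", "gpt"], the label vector is masked as
#   positions 0-4   -> IGNORE_INDEX          (header)
#   positions 5-6   -> kept, 7-8 -> IGNORE   (human turn, first two kept)
#   positions 9-14  -> kept                  (gpt turn, fully supervised)

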
def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = 'unknown'
        # Note: mutates sentence["value"] in place, so later tokenization of
        # the source sees the speaker signals as well.
        sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
                             sentence["value"] + END_SIGNAL)
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


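# Resulting layout (illustrative; the role names come from the active
# conversation template and may differ):
#   {header}### Human: {question}\n### Assistant: {answer}\n
# followed by a dangling "### " begin signal at the end of the string.

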
def preprocess_multimodal(
    sources: Sequence[Sequence[Dict]],
    data_args: DataArguments
) -> Sequence[Sequence[Dict]]:
    """Wrap <image>/<video> placeholder tokens with optional start/end tokens."""
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_VIDEO_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    raise NotImplementedError
                replace_token = DEFAULT_VIDEO_TOKEN
                if data_args.mm_use_start_end:
                    replace_token = DEFAULT_VIDEO_START_TOKEN + replace_token + DEFAULT_VIDEO_END_TOKEN
                sentence["value"] = sentence["value"].replace(DEFAULT_VIDEO_TOKEN, replace_token)

            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
                replace_token = DEFAULT_IMAGE_TOKEN
                if data_args.mm_use_start_end:
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources


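# Illustrative example: with mm_use_start_end=True, a human turn such as
#   "<image>\nWhat is shown here?"
# becomes
#   DEFAULT_IM_START_TOKEN + "<image>" + DEFAULT_IM_END_TOKEN + "\nWhat is shown here?"
# where the literal placeholder strings are defined in llava.constants.

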
def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Apply the prompt template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first turn if it is not from human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the conversations.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask targets: only the assistant responses stay supervised.
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


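# Prompt layout the masking above assumes (illustrative, LLAMA_2 separator
# style with conv.sep2 == "</s>"):
#   <s>[INST] {q1} [/INST] {a1} </s><s>[INST] {q2} [/INST] {a2} </s>
# Each round is split at "[/INST] "; the instruction half is masked and the
# answer half stays supervised.

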
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Apply the prompt template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first turn if it is not from human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the conversations.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets: only the assistant responses stay supervised.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Non-legacy fast tokenizers (>= 0.14) drop one token per round
            # after the first; compensate so the offsets stay aligned.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


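# Prompt layout the masking above assumes (illustrative, vicuna-v1 TWO style
# with conv.sep == " " and conv.sep2 == "</s>"):
#   {system} USER: {q1} ASSISTANT: {a1}</s>USER: {q2} ASSISTANT: {a2}</s>
# Rounds are split on "</s>", and each round again at "ASSISTANT: " to locate
# where supervision begins.

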
def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Apply the prompt template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first turn if it is not from human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the conversations.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    # Mask targets: only the assistant responses stay supervised.
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        # Re-group: system + first user + first assistant turn, then one
        # user/assistant pair per subsequent round.
        re_rounds = [conv.sep.join(rounds[:3])]
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))
        cur_len = 0  # MPT prompts carry no BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            # Legacy tokenizers on tokenizers >= 0.14 emit one extra token per
            # round after the first; compensate so the offsets stay aligned.
            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len += 1
                instruction_len += 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


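# Prompt layout the re-grouping above assumes (illustrative, ChatML-style MPT
# separators with conv.sep == "<|im_end|>"):
#   {system}<|im_end|><|im_start|>user\n{q1}<|im_end|><|im_start|>assistant\n{a1}<|im_end|>...
# so rounds[:3] bundles the system prompt with the first user and assistant
# turns, and later rounds come in user/assistant pairs.

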
def preprocess_plain(
    sources: Sequence[Sequence[Dict]],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # Plain (pretraining) style: each sample is a single prompt/answer pair.
    conversations = []
    for source in sources:
        assert len(source) == 2
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)
    # Tokenize, then mask out the prompt part of the concatenation.
    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


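# Illustrative example: for a caption-style pretraining sample
#   source = [{"from": "human", "value": "<image>"},
#             {"from": "gpt",   "value": "A dog running on a beach."}]
# the training text is "<image>" + caption + separator, and only the caption
# (plus separator) contributes to the loss; the prompt tokens are masked.

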
def preprocess_gemma(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Apply the prompt template via the tokenizer's chat template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first turn if it is not from human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(use_chat_template=True, tokenizer=tokenizer))

    # Tokenize the conversations.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # Mask targets: only the assistant responses stay supervised.
    sep = conv.sep + conv.roles[1] + '\n'
    sep2 = conv.sep2 + '\n' + conv.sep + conv.roles[0]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            # Re-attach the pieces consumed by the split so lengths line up.
            if i != len(rounds) - 1:
                rou += conv.sep2 + '\n'
            if i >= 1:
                rou = conv.sep + conv.roles[0] + rou

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer)) - 1
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                raise NotImplementedError

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


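# Prompt layout the splitting above assumes (illustrative, Gemma chat
# template, with conv.sep == "<start_of_turn>", conv.sep2 == "<end_of_turn>",
# and roles ("user", "model")):
#   <bos><start_of_turn>user\n{q1}<end_of_turn>\n<start_of_turn>model\n{a1}<end_of_turn>\n...
# sep marks where a model turn starts; sep2 marks the model->user boundary.

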
def preprocess_mistral(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Apply the prompt template via the tokenizer's chat template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first turn if it is not from human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(use_chat_template=True, tokenizer=tokenizer))

    # Tokenize the conversations.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # Mask targets: only the assistant responses stay supervised.
    sep = " [/INST]"
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
            target[cur_len:] = IGNORE_INDEX
            # A round ending in a space tokenizes to one extra token; account
            # for it so the length check below stays consistent.
            if rou[-1] == ' ':
                cur_len += 1

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


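# Prompt layout the masking above assumes (illustrative, Mistral instruct
# template with conv.sep2 == "</s>"):
#   <s>[INST] {q1} [/INST] {a1}</s>[INST] {q2} [/INST] {a2}</s>
# Unlike the LLaMA-2 variant above, sep here keeps its leading space
# (" [/INST]").

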
def preprocess_thoth(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Apply the prompt template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first turn if it is not from human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the conversations.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # Mask targets: only the assistant responses stay supervised.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            # Every round after the first contributes one extra separator token.
            cur_len += round_len + 1
            if i == 0:
                cur_len -= 1
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess(
    sources: Sequence[Sequence[Dict]],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """
    Given a list of sources, each of which is a conversation list, this transform:
    1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversation;
    4. Makes a deepcopy as the target and masks human words with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "mpt":
        return preprocess_mpt(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == 'gemma':
        return preprocess_gemma(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == 'thoth':
        return preprocess_thoth(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == 'mistral':
        return preprocess_mistral(sources, tokenizer, has_image=has_image)

    # Fallback: the legacy "### Role: ..." format built by
    # _add_speaker_and_signal and masked by _mask_targets.
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    # `header` is identical for every source, so reusing the value from the
    # last loop iteration above is safe.
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)
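

# Usage sketch (illustrative; the surrounding dataset code is assumed, not
# shown here). A supervised dataset's __getitem__ typically does:
#
#   sources = [example["conversations"]]                      # one sample
#   sources = preprocess_multimodal(copy.deepcopy(sources), data_args)
#   data_dict = preprocess(sources, tokenizer, has_image="image" in example)
#   input_ids = data_dict["input_ids"][0]
#   labels = data_dict["labels"][0]
#
# input_ids and labels are aligned position-by-position; entries whose label
# is IGNORE_INDEX are excluded from the loss.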