import io
import sys
from typing import Dict

import torch
import torchvision.transforms as T
import transformers
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler

from .conversation import get_conv_template

IGNORE_TOKEN_ID = -100


def preprocess_qwen(
    template_name,
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    special_prefixs,
    text_only: bool = False,
    group_by_length: bool = False,
    ds_name: str = None,
) -> Dict:
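    """Render, tokenize, and label-mask Qwen-style chat conversations.

    Each source is rendered with the ``template_name`` conversation template,
    with the matching entry of ``special_prefixs`` prepended to its first
    human turn. The rendered prompts are tokenized, and every token that is
    not part of an assistant response is set to ``IGNORE_TOKEN_ID`` in the
    returned labels, so only model outputs contribute to the loss.
    """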
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    assert len(sources) == len(special_prefixs)

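    # Render each conversation with the chat template, stripping inline
    # <image>/<video> placeholders and prepending the per-sample prefix.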
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first message if it does not come from the human side.
            source = source[1:]

        per_prefix = special_prefixs[i]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            sentence['value'] = sentence['value'].replace("<image>", "").strip()
            sentence['value'] = sentence['value'].replace("<video>", "").strip()

            if j == 0:
                sentence['value'] = per_prefix + sentence['value']

            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

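    # If the tokenizer defines a BOS token, prepend it to every rendered
    # conversation.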
    if tokenizer.bos_token is not None:
        new_conversations = []
        for conversation in conversations:
            conversation = tokenizer.bos_token + conversation
            new_conversations.append(conversation)
        conversations = new_conversations

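    # Tokenize the rendered conversations; padding is skipped when samples
    # will later be grouped by length.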
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length else 'longest',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

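    # Mask the targets so that only assistant responses are supervised.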
    # `sep` is the boundary that precedes every assistant response in the
    # rendered prompt (turn separator + assistant role header).
    sep = conv.sep + '\n' + conv.roles[1] + '\n'
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        # Regroup the split pieces into rounds: the first chunk bundles the
        # leading system message with the first user/assistant exchange, and
        # every later chunk is one user/assistant exchange.
        sep2 = conv.sep + '\n'
        turns = conversation.split(sep2)
        re_turns = [sep2.join(turns[:3]) + sep2]
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(sep2.join(turns[conv_idx:conv_idx + 2]) + sep2)
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        # Never supervise '<|endoftext|>' tokens.
        endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        target[target == endoftext_id] = IGNORE_TOKEN_ID

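        # For each round, compute its total token length and mask everything
        # up to (and including) the assistant header; the response keeps its
        # labels.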
        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # Token length of the instruction part (everything up to and
            # including the assistant header).
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instruction of this turn.
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID

            cur_len += turn_len

        # Ignore everything past the last processed turn (padding or the
        # truncated tail).
        target[cur_len:] = IGNORE_TOKEN_ID

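        # Debug helper (disabled): decode the unmasked targets to eyeball the
        # masking.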
        if False:
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

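        # Sanity check: unless the sample was truncated, the tokens walked
        # over must match the number of non-pad tokens; on mismatch, drop the
        # whole sample from the loss.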
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                    f' conversation: {conversation}'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
        conversations=conversations,
    )
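

# Illustrative usage sketch (comments only, not executed). The template name,
# checkpoint, and prefix below are assumptions for the example, not values
# taken from this repository.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-7B-Instruct')  # hypothetical checkpoint
#   sources = [[
#       {'from': 'human', 'value': '<image>\nDescribe the picture.'},
#       {'from': 'gpt', 'value': 'A cat sitting on a windowsill.'},
#   ]]
#   batch = preprocess_qwen(
#       'qwen-chat',                       # hypothetical template registered in .conversation
#       sources,
#       tokenizer,
#       special_prefixs=['<IMG_PREFIX>'],  # hypothetical per-sample prefix
#   )
#   # batch['labels'] matches batch['input_ids'] except that every token outside
#   # the assistant responses is set to IGNORE_TOKEN_ID.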