| import copy |
| import random |
| import json |
|
|
|
|
def get_bos_eos_token_ids(tokenizer):
    """Return the BOS and EOS token ids of ``tokenizer`` as lists of ints.

    Qwen tokenizers define no BOS token, so an empty list is returned for
    them; ChatGLM uses a hard-coded two-id prefix (presumably the upstream
    ``[gMASK]``/``sop`` ids — confirm against the ChatGLM tokenizer).  Any
    other tokenizer is read through the standard ``bos_token_id`` /
    ``eos_token_id`` attributes.

    Args:
        tokenizer: a HuggingFace-style tokenizer exposing ``eos_token_id``
            (and, outside the special-cased classes, ``bos_token_id``).

    Returns:
        tuple[list[int], list[int]]: ``(bos_token_id, eos_token_id)``,
        each normalized to a (possibly empty) list of ints.
    """
    tokenizer_cls = tokenizer.__class__.__name__
    if tokenizer_cls in [
            'QWenTokenizer', 'QWen2Tokenizer', 'Qwen2TokenizerFast'
    ]:
        bos_token_id = []
        eos_token_id = tokenizer.eos_token_id
        assert eos_token_id is not None, \
            'Please set eos_token for Qwen tokenizer!'
    elif tokenizer_cls == 'ChatGLMTokenizer':
        bos_token_id = [64790, 64792]
        eos_token_id = tokenizer.eos_token_id
    else:
        bos_token_id = tokenizer.bos_token_id
        eos_token_id = tokenizer.eos_token_id
    if isinstance(bos_token_id, int):
        bos_token_id = [bos_token_id]
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    # Tokenizers that define no BOS/EOS report ``None``; normalize to an
    # empty list so callers can always concatenate the result onto a list
    # of token ids (the previous code leaked ``None`` through).
    if bos_token_id is None:
        bos_token_id = []
    if eos_token_id is None:
        eos_token_id = []
    return bos_token_id, eos_token_id
|
|
|
|
# Label value that contributes no loss; -100 matches PyTorch
# CrossEntropyLoss's default ``ignore_index``.
IGNORE_INDEX = -100
# Fallback padding token id for tokenizers that define none.
DEFAULT_PAD_TOKEN_INDEX = 0
# Sentinel id marking an image position in the token stream; negative so
# it can never collide with a real vocabulary id.
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = '<image>'


# Distinct sentinels for input-side images, output-side images, and query
# tokens — presumably substituted with embeddings downstream; the distinct
# values keep the three roles distinguishable (TODO confirm with consumer).
INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
OUTPUT_IMAGE_TOKEN_INDEX = -300
QUERY_TOKEN_INDEX = -400
QUERY_TOKEN = '<query>'
|
|
|
|
|
|
|
|
def crop2square(pil_img):
    """Crop ``pil_img`` to a square of side ``min(width, height)``.

    The square spans the full shorter dimension; its position along the
    longer dimension is chosen uniformly at random.

    Args:
        pil_img (PIL.Image.Image): image to crop.

    Returns:
        PIL.Image.Image: the randomly positioned square crop.
    """
    w, h = pil_img.width, pil_img.height
    side = min(w, h)

    if w > h:
        # Landscape: slide the square horizontally.
        left, top = random.randint(0, w - h), 0
    else:
        # Portrait (or already square): slide the square vertically.
        left, top = 0, random.randint(0, h - w)

    return pil_img.crop(box=(left, top, left + side, top + side))
|
|
def load_jsonl(json_file):
    """Load a JSON-Lines file.

    Args:
        json_file (str): path to a file containing one JSON document per
            line.

    Returns:
        list: the parsed documents, in file order.

    Raises:
        json.JSONDecodeError: if any line is not valid JSON (including
            empty lines), matching the previous behavior.
    """
    # Iterate the file lazily instead of materializing every line with
    # ``readlines`` first; JSONL is UTF-8 by convention, so pin the
    # encoding rather than depending on the platform default.
    with open(json_file, encoding='utf-8') as f:
        return [json.loads(line) for line in f]
|
|
|
|
def _encode_with_placeholder(tokenizer, text, placeholder, placeholder_index,
                             placeholder_length):
    """Tokenize ``text`` containing exactly one ``placeholder`` substring.

    The text is split on the placeholder, each side is encoded without
    special tokens, and ``placeholder_index`` is repeated
    ``placeholder_length`` times in between.

    Args:
        tokenizer: tokenizer exposing ``encode(text, add_special_tokens=...)``.
        text (str): text containing exactly one occurrence of ``placeholder``.
        placeholder (str): the literal marker to split on.
        placeholder_index (int): sentinel token id to splice in.
        placeholder_length (int): how many sentinel ids to insert.

    Returns:
        list[int]: token ids with the sentinel ids spliced in.
    """
    chunk_encode = [
        tokenizer.encode(chunk, add_special_tokens=False)
        for chunk in text.split(placeholder)
    ]
    # Exactly one placeholder occurrence is supported.
    assert len(chunk_encode) == 2
    encode = []
    for idx, cur_chunk_encode in enumerate(chunk_encode):
        encode.extend(cur_chunk_encode)
        if idx != len(chunk_encode) - 1:
            encode += [placeholder_index] * placeholder_length
    return encode


def encode_fn(example,
              tokenizer,
              max_length=None,
              image_length=1,
              query_length=1,
              input_ids_with_output=True,
              with_image_token=False,
              prompt_template=None,
              truncation='right'):
    """Encode a conversation example into ``input_ids`` and ``labels``.

    Only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]

    Args:
        example (dict): holds a ``'conversation'`` list of turns; each turn
            may additionally carry ``'output_with_loss'``,
            ``'need_eos_token'`` and ``'sep'`` keys.
        tokenizer: tokenizer used for encoding (also consulted for
            BOS/EOS ids).
        max_length (int | None): if set, truncate per ``truncation``.
        image_length (int): number of image sentinel ids per ``<image>``.
        query_length (int): number of query sentinel ids per ``<query>``.
        input_ids_with_output (bool): whether outputs are appended; must be
            True for multi-turn examples.
        with_image_token (bool): if True, ``<image>`` markers are replaced
            by sentinel ids instead of being tokenized literally.
        prompt_template: unused here; kept for interface compatibility.
        truncation (str | None): ``'right'``, ``'left'`` or ``None``.

    Returns:
        dict: ``{'input_ids': list[int], 'labels': list[int]}`` where
        non-supervised positions in ``labels`` are ``IGNORE_INDEX``.
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        # Note: ``input`` renamed to avoid shadowing the builtin.
        input_text = single_turn_conversation['input']
        if DEFAULT_IMAGE_TOKEN in input_text and with_image_token:
            input_encode = _encode_with_placeholder(
                tokenizer, input_text, DEFAULT_IMAGE_TOKEN,
                INPUT_IMAGE_TOKEN_INDEX, image_length)
        else:
            input_encode = tokenizer.encode(
                input_text, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        # Prompt tokens never receive loss.
        labels += [IGNORE_INDEX] * len(input_encode)
        if input_ids_with_output and 'output' in single_turn_conversation:
            # Add output tokens (with loss unless disabled per turn).
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']
            if DEFAULT_IMAGE_TOKEN in output and with_image_token:
                output_encode = _encode_with_placeholder(
                    tokenizer, output, DEFAULT_IMAGE_TOKEN,
                    OUTPUT_IMAGE_TOKEN_INDEX, image_length)
            elif QUERY_TOKEN in output:
                # Query markers are expanded regardless of
                # ``with_image_token`` (mirrors the original behavior).
                output_encode = _encode_with_placeholder(
                    tokenizer, output, QUERY_TOKEN,
                    QUERY_TOKEN_INDEX, query_length)
            else:
                output_encode = tokenizer.encode(
                    output, add_special_tokens=False)
            input_ids += output_encode
            if output_with_loss:
                labels += copy.deepcopy(output_encode)
            else:
                labels += [IGNORE_INDEX] * len(output_encode)
            # Add EOS token; skipping it means the next turn continues the
            # same sequence, so no BOS is re-inserted either.
            if single_turn_conversation.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += copy.deepcopy(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False
            # Add separator (never supervised).
            sep = single_turn_conversation.get('sep', '')
            if sep != '':
                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_encode
                labels += [IGNORE_INDEX] * len(sep_encode)

    if max_length is not None and len(input_ids) > max_length:
        if truncation == 'right':
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == 'left':
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        else:
            assert truncation is None
    return {'input_ids': input_ids, 'labels': labels}
|
|