| | import copy |
| | import random |
| | import argparse |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import Dataset |
| | from tqdm import tqdm |
| | from collections import defaultdict |
| | import torch.distributed as dist |
| | import logging |
| | import re |
| | import pdb |
| | import json |
| | from prompt import sft_prompt, all_prompt |
| | import numpy as np |
| |
|
| |
|
class BaseDataset(Dataset):
    """Common base for the SFT recommendation datasets.

    Holds path/configuration bookkeeping taken from ``args`` and provides
    helpers over the item-index vocabulary: the set of new index tokens, the
    set of full item identifier strings, and a ``prefix_allowed_tokens_fn``
    for constrained decoding.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.dataset = args.dataset
        # Per-dataset data directory: <data_path>/<dataset>
        self.data_path = os.path.join(args.data_path, self.dataset)

        self.max_his_len = args.max_his_len  # max history items kept (<=0 keeps all)
        self.his_sep = args.his_sep          # separator between history entries
        self.index_file = args.index_file
        self.inter_path = args.inter_path
        self.feature_path = args.feature_path
        self.add_prefix = args.add_prefix    # number history entries "1. ", "2. ", ...

        # Lazily-built caches, populated by the getters below.
        self.new_tokens = None
        self.allowed_tokens = None
        self.all_items = None

    def _load_data(self):
        """Load the item-id -> index-token-list mapping from the index file."""
        with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
            self.indices = json.load(f)

    def get_new_tokens(self):
        """Return the sorted list of all distinct index tokens (cached)."""
        if self.new_tokens is not None:
            return self.new_tokens

        self.new_tokens = set()
        for index in self.indices.values():
            for token in index:
                self.new_tokens.add(token)
        self.new_tokens = sorted(list(self.new_tokens))

        return self.new_tokens

    def get_all_items(self):
        """Return the set of full item identifiers (joined index tokens, cached)."""
        if self.all_items is not None:
            return self.all_items

        self.all_items = set()
        for index in self.indices.values():
            self.all_items.add("".join(index))

        return self.all_items

    def get_prefix_allowed_tokens_fn(self, tokenizer):
        """Build a ``prefix_allowed_tokens_fn`` for constrained generation.

        ``self.allowed_tokens`` maps each position within an item index to
        the set of token ids legal at that position, with EOS allowed after
        the final position.
        """
        if self.allowed_tokens is None:
            self.allowed_tokens = {}
            for index in self.indices.values():
                for i, token in enumerate(index):
                    # [1] skips the tokenizer's leading special token
                    # (assumes a sentencepiece-style tokenizer that prepends
                    # one token -- TODO confirm for the tokenizer in use).
                    token_id = tokenizer(token)["input_ids"][1]
                    if i not in self.allowed_tokens.keys():
                        self.allowed_tokens[i] = set()
                    self.allowed_tokens[i].add(token_id)
            self.allowed_tokens[len(self.allowed_tokens.keys())] = set([tokenizer.eos_token_id])
        sep = tokenizer("Response:")["input_ids"][1:]

        def prefix_allowed_tokens_fn(batch_id, sentence):
            sentence = sentence.tolist()
            reversed_sent = sentence[::-1]
            for i in range(len(reversed_sent)):
                # i tokens have been emitted since the "Response:" marker, so
                # position i of the item index is currently being decoded.
                if reversed_sent[i:i + len(sep)] == sep[::-1]:

                    return list(self.allowed_tokens[i])
            # NOTE(review): implicitly returns None when the marker is not
            # found in the sentence -- verify callers tolerate this.

        return prefix_allowed_tokens_fn

    def _process_data(self):
        # Subclasses build their task-specific examples here.
        raise NotImplementedError
| |
|
| |
|
| |
|
class SeqRecDataset(BaseDataset):
    """Sequential-recommendation SFT dataset (next-item prediction).

    Each user's interaction sequence is remapped to index-token strings.
    The last two items per user are held out: ``valid`` targets items[-2],
    ``test`` targets items[-1]; ``train`` uses every prefix of items[:-2].
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        """
        Args:
            args: namespace with dataset paths and history options.
            mode: one of "train" / "valid" / "test".
            prompt_sample_num: prompts drawn per interaction.
            prompt_id: fixed prompt index used in test mode.
            sample_num: test-set subsample size (-1 keeps everything).

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["seqrec"]

        # Load raw data and remap item ids to index-token strings.
        self._load_data()
        self._remap_items()

        # Build the mode-specific examples.
        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        """Load user->items interactions and the item index map."""
        with open(self.inter_path, 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace each raw item id with its joined index-token string."""
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _format_history(self, history):
        """Truncate, optionally number ("1. ", "2. ", ...), and join a history.

        Shared by the train/valid/test builders so the three stay in sync
        (previously this logic was duplicated in each of them).
        """
        if self.max_his_len > 0:
            history = history[-self.max_his_len:]
        if self.add_prefix:
            history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
        return self.his_sep.join(history)

    def _process_train_data(self):
        """One example per (user, position) over items[:-2]; target is items[i]."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid][:-2]
            for i in range(1, len(items)):
                one_data = dict()
                one_data["item"] = items[i]
                one_data["inters"] = self._format_history(items[:i])
                inter_data.append(one_data)

        return inter_data

    def _process_valid_data(self):
        """One example per user; target is the held-out items[-2]."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            one_data["item"] = items[-2]
            one_data["inters"] = self._format_history(items[:-2])
            inter_data.append(one_data)

        return inter_data

    def _process_test_data(self):
        """One example per user; target is items[-1]; optional subsampling."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            one_data["item"] = items[-1]
            one_data["inters"] = self._format_history(items[:-1])
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        """Select the prompt used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render the validation split into (input, label) text pairs."""
        self.valid_text_data = []
        if self.sample_valid:
            # Sample prompt_sample_num distinct prompts per interaction.
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    prompt = self.prompts[prompt_id]
                    input, output = self._get_text_data(d, prompt)
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            # Single fixed validation prompt.
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        """Render (input, output) SFT strings.

        In test mode the labels are the bare response (the generation
        target), not the full prompt+response string.
        """
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        """Return {"input_ids": <input text>, "labels": <target text>}."""
        if self.mode == 'valid':
            return self.valid_text_data[index]

        # prompt_sample_num consecutive indices map to the same interaction.
        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)
| |
|
| |
|
class FusionSeqRecDataset(BaseDataset):
    """Sequential recommendation fused with textual item features.

    Each example carries the index-token history ("inters"), a quoted-title
    history ("inter_titles"), and the target's index string, cleaned title,
    and description. Same leave-last-two split as SeqRecDataset.
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        """
        Args:
            args: namespace with dataset paths and history options.
            mode: one of "train" / "valid" / "test".
            prompt_sample_num: prompts drawn per interaction.
            prompt_id: fixed prompt index used in test mode.
            sample_num: subsample size per split (-1 keeps everything).

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["fusionseqrec"]

        self._load_data()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        """Load interactions, the item index map, and the item feature table."""
        with open(self.inter_path, 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(self.feature_path, 'r') as f:
            self.item_feat = json.load(f)

    def _clean_title(self, iid):
        """Title of raw item id with surrounding whitespace/punctuation stripped."""
        return self.item_feat[str(iid)]["title"].strip().strip(".!?,;:`")

    def _build_example(self, target, history):
        """Build one example dict from a raw target id and a raw-id history list.

        Previously this ~25-line block was duplicated across the three
        ``_process_*_data`` builders.
        """
        one_data = dict()
        one_data["item"] = "".join(self.indices[str(target)])
        one_data["title"] = self._clean_title(target)
        one_data["description"] = self.item_feat[str(target)]["description"]

        if self.max_his_len > 0:
            history = history[-self.max_his_len:]
        inters = ["".join(self.indices[str(j)]) for j in history]
        inter_titles = ["\"" + self._clean_title(j) + "\"" for j in history]

        if self.add_prefix:
            inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
            inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]

        one_data["inters"] = self.his_sep.join(inters)
        one_data["inter_titles"] = self.his_sep.join(inter_titles)
        return one_data

    def _subsample(self, data):
        """Randomly keep sample_num examples when sample_num > 0."""
        if self.sample_num > 0:
            all_idx = range(len(data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            data = np.array(data)[sample_idx].tolist()
        return data

    def _process_train_data(self):
        """One example per (user, position) over items[:-2]; target is items[i]."""
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid][:-2]
            for i in range(1, len(items)):
                inter_data.append(self._build_example(items[i], items[:i]))

        return self._subsample(inter_data)

    def _process_valid_data(self):
        """One example per user; target is the held-out items[-2]."""
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            inter_data.append(self._build_example(items[-2], items[:-2]))

        return self._subsample(inter_data)

    def _process_test_data(self):
        """One example per user; target is items[-1]."""
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            inter_data.append(self._build_example(items[-1], items[:-1]))

        return self._subsample(inter_data)

    def set_prompt(self, prompt_id):
        """Select the prompt used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render the validation split into (input, label) text pairs."""
        self.valid_text_data = []
        if self.sample_valid:
            # Sample prompt_sample_num distinct prompts per interaction.
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    prompt = self.prompts[prompt_id]
                    input, output = self._get_text_data(d, prompt)
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            # Single fixed validation prompt.
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        """Render (input, output); test mode returns the bare response as labels."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        """Return {"input_ids": <input text>, "labels": <target text>}."""
        if self.mode == 'valid':
            return self.valid_text_data[index]

        # prompt_sample_num consecutive indices map to the same interaction.
        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)
| |
|
| |
|
class ItemFeatDataset(BaseDataset):
    """Item feature alignment dataset (e.g. the "item2index" task).

    Pairs each item's feature record (title, description, ...) with its
    index-token identifier string.
    """

    def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.task = task.lower()
        self.prompt_sample_num = prompt_sample_num  # prompts drawn per item
        self.sample_num = sample_num                # subsample size (-1 keeps all)

        self.prompts = all_prompt[self.task]

        self._load_data()
        self.feat_data = self._process_data()

    def _load_data(self):
        """Read the item index map and the item feature table."""
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(self.feature_path, 'r') as f:
            self.item_feat = json.load(f)

    def _process_data(self):
        """Attach the joined index string to each feature dict and clean titles.

        Note: the feature dicts are shared with ``self.item_feat`` and are
        updated in place.
        """
        records = []
        for iid, record in self.item_feat.items():
            record["item"] = "".join(self.indices[iid])
            record["title"] = record["title"].strip().strip(".!?,;:`")
            records.append(record)

        if self.sample_num > 0:
            chosen = np.random.choice(range(len(records)), self.sample_num, replace=False)
            records = np.array(records)[chosen].tolist()

        return records

    def __len__(self):
        return self.prompt_sample_num * len(self.feat_data)

    def _get_text_data(self, data, prompt):
        """Render the SFT (input, output) pair from the prompt templates."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")
        target = sft_prompt.format(instruction=instruction, response=response)

        return source, target

    def __getitem__(self, index):
        """Return {"input_ids": <input text>, "labels": <target text>}."""
        # prompt_sample_num consecutive indices share one item record.
        record = self.feat_data[index // self.prompt_sample_num]

        prompt = self.prompts[random.randint(0, len(self.prompts) - 1)]

        source, target = self._get_text_data(record, prompt)
        return dict(input_ids=source, labels=target)
| |
|
| |
|
class ItemSearchDataset(BaseDataset):
    """Query -> item search dataset.

    Built from per-user explicit preferences and vague intentions stored in
    the dataset's ``.user.json`` file; each example pairs a sampled query
    with the target item's index-token string.
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        """
        Args:
            args: namespace with dataset paths and history options.
            mode: "train" or "test" (anything else raises in _process_data).
            prompt_sample_num: prompts drawn per user (train).
            prompt_id: fixed prompt index used in test mode.
            sample_num: subsample size (-1 keeps everything).
        """
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["itemsearch"]

        self._load_data()
        self.search_data = self._process_data()

    def _load_data(self):
        """Load the item index map and the per-user preference/intention info."""
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)

    def _process_data(self):
        """Build one example per user: preferences, intentions, history, target.

        Raises:
            NotImplementedError: if ``mode`` is neither "train" nor "test".
        """
        search_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]
        user_vague_intention = self.user_info["user_vague_intention"]
        if self.mode == 'train':
            user_vague_intention = user_vague_intention["train"]
        elif self.mode == 'test':
            user_vague_intention = user_vague_intention["test"]
        else:
            raise NotImplementedError

        for uid in user_explicit_preference.keys():
            one_data = {}
            user_ep = user_explicit_preference[uid]
            # querys[0] is user-related, querys[1] is item-related.
            user_vi = user_vague_intention[uid]["querys"]
            one_data["explicit_preferences"] = user_ep
            one_data["user_related_intention"] = user_vi[0]
            one_data["item_related_intention"] = user_vi[1]

            iid = user_vague_intention[uid]["item"]
            inters = user_vague_intention[uid]["inters"]

            # Target item as its joined index-token string.
            one_data["item"] = "".join(self.indices[str(iid)])

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            inters = ["".join(self.indices[str(i)]) for i in inters]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]

            one_data["inters"] = self.his_sep.join(inters)

            search_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(search_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            search_data = np.array(search_data)[sample_idx].tolist()

        return search_data

    def set_prompt(self, prompt_id):
        """Select the prompt used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.search_data) * self.prompt_sample_num
        elif self.mode == 'test':
            return len(self.search_data)
        else:
            # Unreachable in practice (_process_data rejects other modes in
            # __init__); raise instead of silently duplicating the test branch,
            # consistent with the sibling dataset classes.
            raise NotImplementedError

    def _get_text_data(self, data, prompt):
        """Render (input, output); test mode returns the bare response as labels."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        """Return {"input_ids": <input text>, "labels": <target text>}."""
        idx = index // self.prompt_sample_num

        d = self.search_data[idx]
        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        # Re-sample one preference and one query on every access
        # (lightweight data augmentation across epochs).
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
        all_querys = [d["user_related_intention"], d["item_related_intention"]]
        d["query"] = random.choice(all_querys)

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)
| |
|
| |
|
| |
|
class PreferenceObtainDataset(BaseDataset):
    """Dataset for eliciting a user's explicit preference from their history.

    Each example pairs a user's (truncated) interaction history with the
    explicit preference texts recorded in the ``.user.json`` file.
    """

    def __init__(self, args, prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.prompt_sample_num = prompt_sample_num  # prompts drawn per user
        self.sample_num = sample_num                # subsample size (-1 keeps all)

        self.prompts = all_prompt["preferenceobtain"]

        self._load_data()
        self._remap_items()

        self.preference_data = self._process_data()

    def _load_data(self):
        """Read user info, interactions, and the item index map."""
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Map each user's raw item ids to joined index-token strings."""
        self.remapped_inters = {
            uid: ["".join(self.indices[str(i)]) for i in items]
            for uid, items in self.inters.items()
        }

    def _process_data(self):
        """Pair each user's truncated history with their explicit preferences."""
        preference_data = []
        explicit = self.user_info["user_explicit_preference"]

        for uid, user_ep in explicit.items():
            # Drop the last three items, mirroring the held-out evaluation items.
            history = self.remapped_inters[uid][:-3]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = ["%d. %s" % (pos + 1, entry) for pos, entry in enumerate(history)]

            preference_data.append({
                "explicit_preferences": user_ep,
                "inters": self.his_sep.join(history),
            })

        if self.sample_num > 0:
            keep = np.random.choice(range(len(preference_data)), self.sample_num, replace=False)
            preference_data = np.array(preference_data)[keep].tolist()

        return preference_data

    def set_prompt(self, prompt_id):
        """Record a prompt id (kept for API parity with the other datasets)."""
        self.prompt_id = prompt_id

    def __len__(self):
        return self.prompt_sample_num * len(self.preference_data)

    def _get_text_data(self, data, prompt):
        """Render the SFT (input, output) pair from the prompt templates."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")
        target = sft_prompt.format(instruction=instruction, response=response)

        return source, target

    def __getitem__(self, index):
        """Return {"input_ids": <input text>, "labels": <target text>}."""
        record = self.preference_data[index // self.prompt_sample_num]

        prompt = self.prompts[random.randint(0, len(self.prompts) - 1)]

        # Draw (and deep-copy) one preference afresh on every access.
        record["explicit_preference"] = copy.deepcopy(random.choice(record["explicit_preferences"]))

        source, target = self._get_text_data(record, prompt)

        return dict(input_ids=source, labels=target)
| |
|
| |
|
| |
|
| |
|
| |
|
class SeqRecTestDataset(BaseDataset):
    """Test-only sequential recommendation dataset with one fixed prompt.

    Each user contributes a single example: the history is everything but
    the last item, the label is the last item's index-token string.
    """

    def __init__(self, args, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.prompt_id = prompt_id   # which seqrec prompt to render with
        self.sample_num = sample_num  # subsample size (-1 keeps all)

        self.prompt = all_prompt["seqrec"][self.prompt_id]

        self._load_data()
        self._remap_items()

        self.inter_data = self._process_test_data()

    def _load_data(self):
        """Read interactions and the item index map."""
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Map each user's raw item ids to joined index-token strings."""
        self.remapped_inters = {
            uid: ["".join(self.indices[str(i)]) for i in items]
            for uid, items in self.inters.items()
        }

    def _process_test_data(self):
        """One example per user: history is items[:-1], target is items[-1]."""
        inter_data = []
        for seq in self.remapped_inters.values():
            history = seq[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = ["%d. %s" % (pos + 1, entry) for pos, entry in enumerate(history)]
            inter_data.append({
                "item": seq[-1],
                "inters": self.his_sep.join(history),
            })

        if self.sample_num > 0:
            keep = np.random.choice(range(len(inter_data)), self.sample_num, replace=False)
            inter_data = np.array(inter_data)[keep].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        """Switch to a different seqrec prompt."""
        self.prompt_id = prompt_id

        self.prompt = all_prompt["seqrec"][self.prompt_id]

    def __len__(self):
        return len(self.inter_data)

    def _get_text_data(self, data, prompt):
        """Render the generation input and the bare target response."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")

        return source, response

    def __getitem__(self, index):
        """Return {"input_ids": <input text>, "labels": <target item string>}."""
        source, target = self._get_text_data(self.inter_data[index], self.prompt)

        return dict(input_ids=source, labels=target)