| import random |
| import re |
| from collections import defaultdict |
| from typing import Iterable, Iterator, List, MutableSet, Optional, Tuple, TypeVar, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from rex.data.collate_fn import GeneralCollateFn |
| from rex.data.transforms.base import CachedTransformBase, CachedTransformOneBase |
| from rex.metrics import calc_p_r_f1_from_tp_fp_fn |
| from rex.utils.io import load_json |
| from rex.utils.iteration import windowed_queue_iter |
| from rex.utils.logging import logger |
| from transformers import AutoTokenizer |
| from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast |
| from transformers.models.deberta_v2.tokenization_deberta_v2_fast import ( |
| DebertaV2TokenizerFast, |
| ) |
| from transformers.tokenization_utils_base import BatchEncoding |
|
|
| from src.utils import ( |
| decode_nnw_nsw_thw_mat, |
| decode_nnw_thw_mat, |
| encode_nnw_nsw_thw_mat, |
| encode_nnw_thw_mat, |
| ) |
|
|
| Filled = TypeVar("Filled") |
|
|
|
|
| class PaddingMixin: |
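    """Pads batches of token sequences and square label matrices to the longest item in the batch."""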
| max_seq_len: int |
|
|
    def pad_seq(self, batch_seqs: List[List[Filled]], fill: Filled) -> List[List[Filled]]:
| max_len = max(len(seq) for seq in batch_seqs) |
| assert max_len <= self.max_seq_len |
| for i in range(len(batch_seqs)): |
| batch_seqs[i] = batch_seqs[i] + [fill] * (max_len - len(batch_seqs[i])) |
| return batch_seqs |
|
|
| def pad_mat( |
| self, mats: List[torch.Tensor], fill: Union[int, float] |
| ) -> List[torch.Tensor]: |
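        # F.pad pads from the last dimension backwards, so (0, 0, 0, num_add, 0, num_add)
        # grows the first two dims; matrices are assumed to be (seq_len, seq_len, channels)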
| max_len = max(mat.shape[0] for mat in mats) |
| assert max_len <= self.max_seq_len |
| for i in range(len(mats)): |
| num_add = max_len - mats[i].shape[0] |
| mats[i] = F.pad( |
| mats[i], (0, 0, 0, num_add, 0, num_add), mode="constant", value=fill |
| ) |
| return mats |
|
|
|
|
| class PointerTransformMixin: |
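    """Shared helpers for building pointer-network MRC inputs.

    `build_ins` lays out the input as
    `[CLS] query [SEP] context [SEP]` (plus `additional context [SEP]` when given),
    truncating the context so that everything fits into `max_seq_len`.
    """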
| tokenizer: BertTokenizerFast |
| max_seq_len: int |
| space_token: str = "[unused1]" |
|
|
| def build_ins( |
| self, |
| query_tokens: list[str], |
| context_tokens: list[str], |
| answer_indexes: list[list[int]], |
        add_context_tokens: Optional[list[str]] = None,
| ) -> Tuple: |
        # reserve 3 positions for [CLS] and the two [SEP] tokens around the query and context
| reserved_seq_len = self.max_seq_len - 3 - len(query_tokens) |
| |
| if reserved_seq_len < 20: |
| raise ValueError( |
| f"Query {query_tokens} too long: {len(query_tokens)} " |
| f"while max seq len is {self.max_seq_len}" |
| ) |
|
|
| input_tokens = [self.tokenizer.cls_token] |
| input_tokens += query_tokens |
| input_tokens += [self.tokenizer.sep_token] |
| offset = len(input_tokens) |
| input_tokens += context_tokens[:reserved_seq_len] |
| available_token_range = range( |
| offset, offset + len(context_tokens[:reserved_seq_len]) |
| ) |
| input_tokens += [self.tokenizer.sep_token] |
|
|
| add_context_len = 0 |
| max_add_context_len = self.max_seq_len - len(input_tokens) - 1 |
| add_context_flag = False |
        if add_context_tokens:
| add_context_flag = True |
| add_context_len = len(add_context_tokens[:max_add_context_len]) |
| input_tokens += add_context_tokens[:max_add_context_len] |
| input_tokens += [self.tokenizer.sep_token] |
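        # swap whitespace-only tokens for the reserved space placeholder before id conversion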
| new_tokens = [] |
| for t in input_tokens: |
| if len(t.strip()) > 0: |
| new_tokens.append(t) |
| else: |
| new_tokens.append(self.space_token) |
| input_tokens = new_tokens |
| input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens) |
|
|
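        # segment-id style mask: 0 = padding, 1 = [CLS], 2 = query, 3 = [SEP],
        # 4 = context, 5 = [SEP], 6 = additional context, 7 = trailing [SEP]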
| mask = [1] |
| mask += [2] * len(query_tokens) |
| mask += [3] |
| mask += [4] * len(context_tokens[:reserved_seq_len]) |
| mask += [5] |
| if add_context_flag: |
| mask += [6] * add_context_len |
| mask += [7] |
| assert len(mask) == len(input_ids) <= self.max_seq_len |
|
|
| available_spans = [tuple(i + offset for i in index) for index in answer_indexes] |
| available_spans = list( |
| filter( |
| lambda index: all(i in available_token_range for i in index), |
| available_spans, |
| ) |
| ) |
|
|
| token_len = len(input_ids) |
| pad_len = self.max_seq_len - token_len |
| input_tokens += pad_len * [self.tokenizer.pad_token] |
| input_ids += pad_len * [self.tokenizer.pad_token_id] |
| mask += pad_len * [0] |
|
|
| return input_tokens, input_ids, mask, offset, available_spans |
|
|
| def update_labels(self, data: dict) -> dict: |
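        """Build (batch, 2, seq_len, seq_len) word-pair label matrices.

        Channel 0 links each token of a span to the next token of that span;
        channel 1 links the span's last token back to its first token.
        Single-token spans set both channels on the diagonal cell.
        """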
| bs = len(data["input_ids"]) |
| seq_len = self.max_seq_len |
| labels = torch.zeros((bs, 2, seq_len, seq_len)) |
| for i, batch_spans in enumerate(data["available_spans"]): |
| for span in batch_spans: |
| if len(span) == 1: |
| labels[i, :, span[0], span[0]] = 1 |
| else: |
| for s, e in windowed_queue_iter(span, 2, 1, drop_last=True): |
| labels[i, 0, s, e] = 1 |
| labels[i, 1, span[-1], span[0]] = 1 |
| data["labels"] = labels |
| return data |
|
|
| def update_consecutive_span_labels(self, data: dict) -> dict: |
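        """Build (batch, 1, seq_len, seq_len) labels marking each span's (start, end) cell.

        Spans are expected to be sorted, duplicate-free runs of token positions;
        only the (first, last) cell of each span is set.
        """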
| bs = len(data["input_ids"]) |
| seq_len = self.max_seq_len |
| labels = torch.zeros((bs, 1, seq_len, seq_len)) |
| for i, batch_spans in enumerate(data["available_spans"]): |
| for span in batch_spans: |
| assert span == tuple(sorted(set(span))) |
| if len(span) == 1: |
| labels[i, 0, span[0], span[0]] = 1 |
| else: |
| labels[i, 0, span[0], span[-1]] = 1 |
| data["labels"] = labels |
| return data |
|
|
|
|
| class CachedPointerTaggingTransform(CachedTransformBase, PointerTransformMixin): |
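    """MRC-style pointer tagging transform for entity recognition.

    Every input sentence is expanded into one instance per entity type, using the
    natural-language query loaded from `ent_type2query_filepath`. During training,
    queries without any gold entity (negatives) are kept with probability
    `negative_sample_prob`.
    """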
| def __init__( |
| self, |
| max_seq_len: int, |
| plm_dir: str, |
| ent_type2query_filepath: str, |
| mode: str = "w2", |
| negative_sample_prob: float = 1.0, |
| ) -> None: |
| super().__init__() |
|
|
| self.max_seq_len: int = max_seq_len |
| self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir) |
| self.ent_type2query: dict = load_json(ent_type2query_filepath) |
| self.negative_sample_prob = negative_sample_prob |
|
|
| self.collate_fn: GeneralCollateFn = GeneralCollateFn( |
| { |
| "input_ids": torch.long, |
| "mask": torch.long, |
| "labels": torch.long, |
| }, |
| guessing=False, |
| missing_key_as_null=True, |
| ) |
| if mode == "w2": |
| self.collate_fn.update_before_tensorify = self.update_labels |
| elif mode == "cons": |
| self.collate_fn.update_before_tensorify = ( |
| self.update_consecutive_span_labels |
| ) |
| else: |
| raise ValueError(f"Mode: {mode} not recognizable") |
|
|
| def transform( |
| self, |
| transform_loader: Iterator, |
        dataset_name: Optional[str] = None,
| **kwargs, |
| ) -> Iterable: |
| final_data = [] |
| |
| for data in transform_loader: |
| ent_type2ents = defaultdict(set) |
| for ent in data["ents"]: |
| ent_type2ents[ent["type"]].add(tuple(ent["index"])) |
| for ent_type in self.ent_type2query: |
| gold_ents = ent_type2ents[ent_type] |
| if ( |
| len(gold_ents) < 1 |
| and dataset_name == "train" |
| and random.random() > self.negative_sample_prob |
| ): |
                    # negative sampling: a no-answer query is kept with prob `negative_sample_prob`
                    continue
| |
| query = self.ent_type2query[ent_type] |
| query_tokens = self.tokenizer.tokenize(query) |
| try: |
| res = self.build_ins(query_tokens, data["tokens"], gold_ents) |
| except (ValueError, AssertionError): |
| continue |
| input_tokens, input_ids, mask, offset, available_spans = res |
| ins = { |
| "id": data.get("id", str(len(final_data))), |
| "ent_type": ent_type, |
| "gold_ents": gold_ents, |
| "raw_tokens": data["tokens"], |
| "input_tokens": input_tokens, |
| "input_ids": input_ids, |
| "mask": mask, |
| "offset": offset, |
| "available_spans": available_spans, |
                    # labels are filled in later by the collate fn's `update_before_tensorify` hook
                    "labels": None,
| } |
| final_data.append(ins) |
| return final_data |
|
|
| def predict_transform(self, texts: List[str]): |
| dataset = [] |
| for text_id, text in enumerate(texts): |
| data_id = f"Prediction#{text_id}" |
| tokens = self.tokenizer.tokenize(text) |
| dataset.append( |
| { |
| "id": data_id, |
| "tokens": tokens, |
| "ents": [], |
| } |
| ) |
| final_data = self(dataset, disable_pbar=True) |
| return final_data |
|
|
|
|
| class CachedPointerMRCTransform(CachedTransformBase, PointerTransformMixin): |
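    """Pointer-network MRC transform for data that is already tokenized into
    query / context / (optional) background token lists with answer token indexes."""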
| def __init__( |
| self, |
| max_seq_len: int, |
| plm_dir: str, |
| mode: str = "w2", |
| ) -> None: |
| super().__init__() |
|
|
| self.max_seq_len: int = max_seq_len |
| self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir) |
|
|
| self.collate_fn: GeneralCollateFn = GeneralCollateFn( |
| { |
| "input_ids": torch.long, |
| "mask": torch.long, |
| "labels": torch.long, |
| }, |
| guessing=False, |
| missing_key_as_null=True, |
| ) |
|
|
| if mode == "w2": |
| self.collate_fn.update_before_tensorify = self.update_labels |
| elif mode == "cons": |
| self.collate_fn.update_before_tensorify = ( |
| self.update_consecutive_span_labels |
| ) |
| else: |
| raise ValueError(f"Mode: {mode} not recognizable") |
|
|
| def transform( |
| self, |
| transform_loader: Iterator, |
        dataset_name: Optional[str] = None,
| **kwargs, |
| ) -> Iterable: |
| final_data = [] |
| for data in transform_loader: |
| try: |
| res = self.build_ins( |
| data["query_tokens"], |
| data["context_tokens"], |
| data["answer_index"], |
| data.get("background_tokens"), |
| ) |
| except (ValueError, AssertionError): |
| continue |
| input_tokens, input_ids, mask, offset, available_spans = res |
| ins = { |
| "id": data.get("id", str(len(final_data))), |
| "gold_spans": sorted(set(tuple(x) for x in data["answer_index"])), |
| "raw_tokens": data["context_tokens"], |
| "input_tokens": input_tokens, |
| "input_ids": input_ids, |
| "mask": mask, |
| "offset": offset, |
| "available_spans": available_spans, |
| "labels": None, |
| } |
| final_data.append(ins) |
|
|
| return final_data |
|
|
| def predict_transform(self, data: list[dict]): |
| """ |
| Args: |
| data: a list of dict with query, context, and background strings |
| """ |
| dataset = [] |
        for idx, ins in enumerate(data):
            data_id = f"Prediction#{idx}"
            dataset.append(
                {
                    "id": data_id,
                    "query_tokens": list(ins["query"]),
                    "context_tokens": list(ins["context"]),
                    "background_tokens": list(ins.get("background") or ""),
| "answer_index": [], |
| } |
| ) |
| final_data = self(dataset, disable_pbar=True, num_samples=0) |
| return final_data |
|
|
|
|
| class CachedLabelPointerTransform(CachedTransformOneBase): |
| """Transform for label-token linking for skip consecutive spans""" |
|
|
| def __init__( |
| self, |
| max_seq_len: int, |
| plm_dir: str, |
| mode: str = "w2", |
| label_span: str = "tag", |
| include_instructions: bool = True, |
| **kwargs, |
| ) -> None: |
| super().__init__() |
|
|
| self.max_seq_len: int = max_seq_len |
| self.mode = mode |
| self.label_span = label_span |
| self.include_instructions = include_instructions |
|
|
| self.tokenizer: DebertaV2TokenizerFast = DebertaV2TokenizerFast.from_pretrained( |
| plm_dir |
| ) |
| self.lc_token = "[LC]" |
| self.lm_token = "[LM]" |
| self.lr_token = "[LR]" |
| self.i_token = "[I]" |
| self.tl_token = "[TL]" |
| self.tp_token = "[TP]" |
| self.b_token = "[B]" |
| num_added = self.tokenizer.add_tokens( |
| [ |
| self.lc_token, |
| self.lm_token, |
| self.lr_token, |
| self.i_token, |
| self.tl_token, |
| self.tp_token, |
| self.b_token, |
| ] |
| ) |
| assert num_added == 7 |
|
|
| self.collate_fn: GeneralCollateFn = GeneralCollateFn( |
| { |
| "input_ids": torch.long, |
| "mask": torch.long, |
| "labels": torch.long, |
| "spans": None, |
| }, |
| guessing=False, |
| missing_key_as_null=True, |
            discard_missing=False,
| ) |
|
|
| self.collate_fn.update_before_tensorify = self.skip_consecutive_span_labels |
|
|
| def transform(self, instance: dict, **kwargs): |
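        """Build one instance: tokens, segment mask, and gold answer spans.

        Input layout: [CLS] ([I] instruction) (label blocks) ([TL]/[TP] text)
        ([B] background) [SEP], where each label block is an [LC]/[LM]/[LR] marker
        followed by the tokenized label string.

        Mask segment ids: 1=[CLS], 2=[I], 3=instruction, 4=label marker,
        5=label text, 6=[TL]/[TP], 7=text, 8=[B], 9=background, 10=[SEP], 0=padding.
        """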
| |
| tokens = [self.tokenizer.cls_token] |
| mask = [1] |
| label_map = {"lc": {}, "lm": {}, "lr": {}} |
        # reverse index: label span -> {"type": label family, "task": source task, "string": raw label}
| span_to_label = {} |
|
|
| def _update_seq( |
| label: str, |
| label_type: str, |
| task: str = "", |
| label_mask: int = 4, |
| content_mask: int = 5, |
| ): |
| if label not in label_map[label_type]: |
| label_token_map = { |
| "lc": self.lc_token, |
| "lm": self.lm_token, |
| "lr": self.lr_token, |
| } |
| label_tag_start_idx = len(tokens) |
| tokens.append(label_token_map[label_type]) |
| mask.append(label_mask) |
| label_tag_end_idx = len(tokens) - 1 |
| label_tokens = self.tokenizer(label, add_special_tokens=False).tokens() |
| label_content_start_idx = len(tokens) |
| tokens.extend(label_tokens) |
| mask.extend([content_mask] * len(label_tokens)) |
| label_content_end_idx = len(tokens) - 1 |
|
|
| if self.label_span == "tag": |
| start_idx = label_tag_start_idx |
| end_idx = label_tag_end_idx |
| elif self.label_span == "content": |
| start_idx = label_content_start_idx |
| end_idx = label_content_end_idx |
| else: |
| raise ValueError(f"label_span={self.label_span} is not supported") |
|
|
| if end_idx == start_idx: |
| label_map[label_type][label] = (start_idx,) |
| else: |
| label_map[label_type][label] = (start_idx, end_idx) |
| span_to_label[label_map[label_type][label]] = { |
| "type": label_type, |
| "task": task, |
| "string": label, |
| } |
| return label_map[label_type][label] |
|
|
| if self.include_instructions: |
| instruction = instance.get("instruction") |
| if not instruction: |
| logger.warning( |
| "include_instructions=True, while the instruction is empty!" |
| ) |
| else: |
| instruction = "" |
| if instruction: |
| tokens.append(self.i_token) |
| mask.append(2) |
| instruction_tokens = self.tokenizer( |
| instruction, add_special_tokens=False |
| ).tokens() |
| tokens.extend(instruction_tokens) |
| mask.extend([3] * len(instruction_tokens)) |
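        # register every label from the task schema in the input sequence; `label_map`
        # records where each label's tokens sit so answers can point back to them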
| types = instance["schema"].get("cls") |
| if types: |
| for t in types: |
| _update_seq(t, "lc", task="cls") |
| mention_types = instance["schema"].get("ent") |
| if mention_types: |
| for mt in mention_types: |
| _update_seq(mt, "lm", task="ent") |
| discon_ent_types = instance["schema"].get("discontinuous_ent") |
| if discon_ent_types: |
| for mt in discon_ent_types: |
| _update_seq(mt, "lm", task="discontinuous_ent") |
| rel_types = instance["schema"].get("rel") |
| if rel_types: |
| for rt in rel_types: |
| _update_seq(rt, "lr", task="rel") |
| hyper_rel_schema = instance["schema"].get("hyper_rel") |
| if hyper_rel_schema: |
| for rel, qualifiers in hyper_rel_schema.items(): |
| _update_seq(rel, "lr", task="hyper_rel") |
| for qualifier in qualifiers: |
| _update_seq(qualifier, "lr", task="hyper_rel") |
| event_schema = instance["schema"].get("event") |
| if event_schema: |
| for event_type, roles in event_schema.items(): |
| _update_seq(event_type, "lm", task="event") |
| for role in roles: |
| _update_seq(role, "lr", task="event") |
|
|
| text = instance.get("text") |
| if text: |
| text_tokenized = self.tokenizer( |
| text, return_offsets_mapping=True, add_special_tokens=False |
| ) |
| if any(val for val in label_map.values()): |
| text_label_token = self.tl_token |
| else: |
| text_label_token = self.tp_token |
| tokens.append(text_label_token) |
| mask.append(6) |
| remain_token_len = self.max_seq_len - 1 - len(tokens) |
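            # if almost no room is left for the text, drop this instance when building the train set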
| if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train": |
| return None |
| text_off = len(tokens) |
| text_tokens = text_tokenized.tokens()[:remain_token_len] |
| tokens.extend(text_tokens) |
| mask.extend([7] * len(text_tokens)) |
| else: |
| text_tokenized = None |
|
|
| bg = instance.get("bg") |
| if bg: |
| bg_tokenized = self.tokenizer( |
| bg, return_offsets_mapping=True, add_special_tokens=False |
| ) |
| tokens.append(self.b_token) |
| mask.append(8) |
| remain_token_len = self.max_seq_len - 1 - len(tokens) |
| if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train": |
| return None |
| bg_tokens = bg_tokenized.tokens()[:remain_token_len] |
| tokens.extend(bg_tokens) |
| mask.extend([9] * len(bg_tokens)) |
| else: |
| bg_tokenized = None |
|
|
| tokens.append(self.tokenizer.sep_token) |
| mask.append(10) |
| |
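        # Collect gold answers as lists of span "parts": each part is a (start,) or
        # (start, end) token tuple, taken either from `label_map` or from the text via
        # `char_to_token_span`.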
| spans = [] |
| if "cls" in instance["ans"]: |
| for t in instance["ans"]["cls"]: |
| part = label_map["lc"][t] |
| spans.append([part]) |
| if "ent" in instance["ans"]: |
| for ent in instance["ans"]["ent"]: |
| label_part = label_map["lm"][ent["type"]] |
| position_seq = self.char_to_token_span( |
| ent["span"], text_tokenized, text_off |
| ) |
| spans.append([label_part, position_seq]) |
| if "discontinuous_ent" in instance["ans"]: |
| for ent in instance["ans"]["discontinuous_ent"]: |
| label_part = label_map["lm"][ent["type"]] |
| ent_span = [label_part] |
| for part in ent["span"]: |
| position_seq = self.char_to_token_span( |
| part, text_tokenized, text_off |
| ) |
| ent_span.append(position_seq) |
| spans.append(ent_span) |
| if "rel" in instance["ans"]: |
| for rel in instance["ans"]["rel"]: |
| label_part = label_map["lr"][rel["relation"]] |
| head_position_seq = self.char_to_token_span( |
| rel["head"]["span"], text_tokenized, text_off |
| ) |
| tail_position_seq = self.char_to_token_span( |
| rel["tail"]["span"], text_tokenized, text_off |
| ) |
| spans.append([label_part, head_position_seq, tail_position_seq]) |
| if "hyper_rel" in instance["ans"]: |
| for rel in instance["ans"]["hyper_rel"]: |
| label_part = label_map["lr"][rel["relation"]] |
| head_position_seq = self.char_to_token_span( |
| rel["head"]["span"], text_tokenized, text_off |
| ) |
| tail_position_seq = self.char_to_token_span( |
| rel["tail"]["span"], text_tokenized, text_off |
| ) |
| |
| for q in rel["qualifiers"]: |
| q_label_part = label_map["lr"][q["label"]] |
| q_position_seq = self.char_to_token_span( |
| q["span"], text_tokenized, text_off |
| ) |
| spans.append( |
| [ |
| label_part, |
| head_position_seq, |
| tail_position_seq, |
| q_label_part, |
| q_position_seq, |
| ] |
| ) |
| if "event" in instance["ans"]: |
| for event in instance["ans"]["event"]: |
| event_type_label_part = label_map["lm"][event["event_type"]] |
| trigger_position_seq = self.char_to_token_span( |
| event["trigger"]["span"], text_tokenized, text_off |
| ) |
| trigger_part = [event_type_label_part, trigger_position_seq] |
| spans.append(trigger_part) |
| for arg in event["args"]: |
| role_label_part = label_map["lr"][arg["role"]] |
| arg_position_seq = self.char_to_token_span( |
| arg["span"], text_tokenized, text_off |
| ) |
| arg_part = [role_label_part, trigger_position_seq, arg_position_seq] |
| spans.append(arg_part) |
| if "span" in instance["ans"]: |
            # plain answer spans with no label part (pure extraction-style answers)
| for span in instance["ans"]["span"]: |
| span_position_seq = self.char_to_token_span( |
| span["span"], text_tokenized, text_off |
| ) |
| spans.append([span_position_seq]) |
|
|
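        # "w2" mode expands every (start, end) part into the full run of token indexes;
        # "span" mode keeps the boundary tuples as they are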
| if self.mode == "w2": |
| new_spans = [] |
| for parts in spans: |
| new_parts = [] |
| for part in parts: |
| new_parts.append(tuple(range(part[0], part[-1] + 1))) |
| new_spans.append(new_parts) |
| spans = new_spans |
        elif self.mode == "span":
            # boundary tuples are already in their final form
            pass
| else: |
| raise ValueError(f"mode={self.mode} is not supported") |
|
|
| ins = { |
| "raw": instance, |
| "tokens": tokens, |
| "input_ids": self.tokenizer.convert_tokens_to_ids(tokens), |
| "mask": mask, |
| "spans": spans, |
| "label_map": label_map, |
| "span_to_label": span_to_label, |
| "labels": None, |
| } |
| return ins |
|
|
    def char_to_token_span(
        self, span: list[int], tokenized: BatchEncoding, offset: int = 0
    ) -> tuple[int, ...]:
        # `span` is a [start, end) character span; it is assumed to fall inside the
        # (possibly truncated) tokenized text, otherwise `char_to_token` returns None
        token_s = tokenized.char_to_token(span[0])
        token_e = tokenized.char_to_token(span[1] - 1)
        if token_e == token_s:
            position_seq = (offset + token_s,)
        else:
            position_seq = (offset + token_s, offset + token_e)
        return position_seq
|
|
| def skip_consecutive_span_labels(self, data: dict) -> dict: |
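        """Pad/truncate every sequence to the batch max length and build span labels.

        Labels are produced by `encode_nnw_nsw_thw_mat`, which encodes each
        (possibly discontinuous) span as word-pair links in a seq_len x seq_len grid.
        """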
| bs = len(data["input_ids"]) |
| max_seq_len = max(len(input_ids) for input_ids in data["input_ids"]) |
| batch_seq_len = min(self.max_seq_len, max_seq_len) |
| for i in range(bs): |
| data["input_ids"][i] = data["input_ids"][i][:batch_seq_len] |
| data["mask"][i] = data["mask"][i][:batch_seq_len] |
| assert len(data["input_ids"][i]) == len(data["mask"][i]) |
| pad_len = batch_seq_len - len(data["mask"][i]) |
| data["input_ids"][i] = ( |
| data["input_ids"][i] + [self.tokenizer.pad_token_id] * pad_len |
| ) |
| data["mask"][i] = data["mask"][i] + [0] * pad_len |
| data["labels"][i] = encode_nnw_nsw_thw_mat(data["spans"][i], batch_seq_len) |
| return data |
|
|