import math
import re
from collections import defaultdict
from datetime import datetime
from typing import List

import torch
import torch.optim as optim
from rex import accelerator
from rex.data.data_manager import DataManager
from rex.data.dataset import CachedDataset, StreamReadDataset
from rex.tasks.simple_metric_task import SimpleMetricTask
from rex.utils.batch import decompose_batch_into_instances
from rex.utils.config import ConfigParser
from rex.utils.dict import flatten_dict
from rex.utils.io import load_jsonlines
from rex.utils.registry import register
from torch.utils.tensorboard import SummaryWriter
from transformers.optimization import (
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from .metric import MrcNERMetric, MrcSpanMetric, MultiPartSpanMetric
from .model import (
    MrcGlobalPointerModel,
    MrcPointerMatrixModel,
    SchemaGuidedInstructBertModel,
)
from .transform import (
    CachedLabelPointerTransform,
    CachedPointerMRCTransform,
    CachedPointerTaggingTransform,
)


@register("task")
class MrcTaggingTask(SimpleMetricTask):
    def __init__(self, config, **kwargs) -> None:
        super().__init__(config, **kwargs)

    def after_initialization(self):
        now_string = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.tb_logger: SummaryWriter = SummaryWriter(
            log_dir=self.task_path / "tb_summary" / now_string,
            comment=self.config.comment,
        )

    def after_whole_train(self):
        self.tb_logger.close()

    def get_grad_norm(self):
        """Compute the total L2 norm of gradients over all model parameters."""
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1.0 / 2)
        return total_norm

    def log_loss(
        self, idx: int, loss_item: float, step_or_epoch: str, dataset_name: str
    ):
        self.tb_logger.add_scalar(
            f"loss/{dataset_name}/{step_or_epoch}", loss_item, idx
        )
        self.tb_logger.add_scalar("lr", self.optimizer.param_groups[0]["lr"], idx)
        self.tb_logger.add_scalar("grad_norm_total", self.get_grad_norm(), idx)

    def log_metrics(
        self, idx: int, metrics: dict, step_or_epoch: str, dataset_name: str
    ):
        metrics = flatten_dict(metrics)
        self.tb_logger.add_scalars(f"{dataset_name}/{step_or_epoch}", metrics, idx)

    def init_transform(self):
        return CachedPointerTaggingTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            self.config.ent_type2query_filepath,
            mode=self.config.mode,
            negative_sample_prob=self.config.negative_sample_prob,
        )

    def init_data_manager(self):
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            CachedDataset,
            self.transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=False,
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        m = MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )
        return m

    def init_metric(self):
        return MrcNERMetric()

    def init_optimizer(self):
        # parameters matching `no_decay` get no weight decay
        no_decay = r"(embedding|LayerNorm|\.bias$)"
        # parameters under the pretrained LM keep the base learning rate,
        # everything else uses `other_learning_rate`
        plm_lr = r"^plm\."
        # PLM embeddings and encoder layers 0-3 are frozen
        non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])"

        param_groups = []
        for name, param in self.model.named_parameters():
            lr = self.config.learning_rate
            weight_decay = self.config.weight_decay
            if re.search(non_trainable, name):
                param.requires_grad = False
            if not re.search(plm_lr, name):
                lr = self.config.other_learning_rate
            if re.search(no_decay, name):
                weight_decay = 0.0
            param_groups.append(
                {"params": param, "lr": lr, "weight_decay": weight_decay}
            )
        return optim.AdamW(
            param_groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

    def init_lr_scheduler(self):
        num_training_steps = int(
            len(self.data_manager.train_loader)
            * self.config.num_epochs
            * accelerator.num_processes
        )
        num_warmup_steps = math.floor(
            num_training_steps * self.config.warmup_proportion
        )
        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def predict_api(self, texts: List[str], **kwargs):
        raw_dataset = self.transform.predict_transform(texts)
        text_ids = sorted(list({ins["id"] for ins in raw_dataset}))
        loader = self.data_manager.prepare_loader(raw_dataset)
        loader = accelerator.prepare_data_loader(loader)
        id2ents = defaultdict(set)
        for batch in loader:
            batch_out = self.model(**batch, is_eval=True)
            for _id, _pred in zip(batch["id"], batch_out["pred"]):
                id2ents[_id].update(_pred)
        results = [id2ents[_id] for _id in text_ids]
        return results


@register("task")
class MrcQaTask(MrcTaggingTask):
    def init_transform(self):
        return CachedPointerMRCTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
        )

    def init_model(self):
        m = MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )
        return m

    def init_metric(self):
        return MrcSpanMetric()

    def predict_api(self, data: list[dict], **kwargs):
        """
        Args:
            data: a list of dicts, each with ``query``, ``context``, and
                ``background`` strings
        """
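        # Illustrative input only (hypothetical values; the exact keys consumed by
        # CachedPointerMRCTransform.predict_transform should be checked there):
        # data = [
        #     {
        #         "query": "Who founded the company?",
        #         "context": "The company was founded by Alice Smith in 1999.",
        #         "background": "",
        #     }
        # ]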
        raw_dataset = self.transform.predict_transform(data)
        loader = self.data_manager.prepare_loader(raw_dataset)
        results = []
        for batch in loader:
            batch_out = self.model(**batch, is_eval=True)
            batch["pred"] = batch_out["pred"]
            instances = decompose_batch_into_instances(batch)
            for ins in instances:
                preds = ins["pred"]
                ins_results = []
                for index_list in preds:
                    ins_result = []
                    for i in index_list:
                        ins_result.append(ins["raw_tokens"][i])
                    ins_results.append(("".join(ins_result), tuple(index_list)))
                results.append(ins_results)
        return results


class StreamReadDatasetWithLen(StreamReadDataset):
    def __len__(self):
        # hard-coded length so that components relying on `len()` (e.g. the LR
        # scheduler) still work in stream mode; presumably the training-set size
        return 631346


@register("task")
class SchemaGuidedInstructBertTask(MrcTaggingTask):
    def init_transform(self):
        self.transform: CachedLabelPointerTransform
        return CachedLabelPointerTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
            label_span=self.config.label_span,
            include_instructions=self.config.get("include_instructions", True),
        )

    def init_data_manager(self):
        if self.config.get("stream_mode", False):
            DatasetClass = StreamReadDatasetWithLen
            transform = self.transform.transform
        else:
            DatasetClass = CachedDataset
            transform = self.transform
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            DatasetClass,
            transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=self.config.get("stream_mode", False),
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        self.model = SchemaGuidedInstructBertModel(
            self.config.plm_dir,
            vocab_size=len(self.transform.tokenizer),
            use_rope=self.config.use_rope,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
        )

        if self.config.get("base_model_path"):
            self.load(
                self.config.base_model_path,
                load_config=False,
                load_model=True,
                load_optimizer=False,
                load_history=False,
            )
        return self.model

    def init_optimizer(self):
        no_decay = r"(embedding|LayerNorm|\.bias$)"
        plm_lr = r"^plm\."
        # sentinel pattern that matches no parameter name, i.e. nothing is frozen
        non_trainable = "no_non_trainable"

        param_groups = []
        for name, param in self.model.named_parameters():
            lr = self.config.learning_rate
            weight_decay = self.config.weight_decay
            if re.search(non_trainable, name):
                param.requires_grad = False
            if not re.search(plm_lr, name):
                lr = self.config.other_learning_rate
            if re.search(no_decay, name):
                weight_decay = 0.0
            param_groups.append(
                {"params": param, "lr": lr, "weight_decay": weight_decay}
            )
        return optim.AdamW(
            param_groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

    def init_metric(self):
        return MultiPartSpanMetric()

    def _convert_span_to_string(self, span, token_ids, tokenizer):
        string = ""
        if len(span) == 0 or len(span) > 2:
            pass
        elif len(span) == 1:
            string = tokenizer.decode(token_ids[span[0]])
        elif len(span) == 2:
            string = tokenizer.decode(token_ids[span[0] : span[1] + 1])
        return (string, self.reset_position(token_ids, span))

    def reset_position(self, input_ids: list[int], span: list[int]) -> list[int]:
        """Shift span indices so that they are relative to the segment right
        after the ``tp_token`` marker (or the ``tl_token`` marker if no
        ``tp_token`` is present)."""
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().tolist()
        if len(span) < 1:
            return span

        tp_token_id, tl_token_id = self.transform.tokenizer.convert_tokens_to_ids(
            [self.transform.tp_token, self.transform.tl_token]
        )
        offset = 0
        if tp_token_id in input_ids:
            offset = input_ids.index(tp_token_id) + 1
        elif tl_token_id in input_ids:
            offset = input_ids.index(tl_token_id) + 1
        return [i - offset for i in span]

    def predict_api(self, data: list[dict], **kwargs):
        """
        Args:
            data: a list of dicts in the UDI format:
                {
                    "id": str,
                    "instruction": str,
                    "schema": {
                        "ent": list,
                        "rel": list,
                        "event": dict,
                        "cls": list,
                        "discontinuous_ent": list,
                        "hyper_rel": dict
                    },
                    "text": str,
                    "bg": str,
                    "ans": {},  # empty dict
                }
        """
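        # Illustrative UDI instance (hypothetical values; only the keys above come
        # from the format description, all concrete values are made up):
        # data = [
        #     {
        #         "id": "example-0",
        #         "instruction": "Please extract entities from the text.",
        #         "schema": {"ent": ["person", "location"], "rel": [], "event": {},
        #                    "cls": [], "discontinuous_ent": [], "hyper_rel": {}},
        #         "text": "Alice moved to Paris.",
        #         "bg": "",
        #         "ans": {},
        #     }
        # ]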
        raw_dataset = [self.transform.transform(d) for d in data]
        loader = self.data_manager.prepare_loader(raw_dataset)
        results = []
        for batch in loader:
            batch_out = self.model(**batch, is_eval=True)
            batch["pred"] = batch_out["pred"]
            instances = decompose_batch_into_instances(batch)
            for ins in instances:
                pred_clses = []
                pred_ents = []
                pred_rels = []
                pred_trigger_to_event = defaultdict(
                    lambda: {"event_type": "", "arguments": []}
                )
                pred_events = []
                pred_spans = []
                pred_discon_ents = []
                pred_hyper_rels = []
                raw_schema = ins["raw"]["schema"]
                for multi_part_span in ins["pred"]:
                    span = tuple(multi_part_span)
                    span_to_label = ins["span_to_label"]
                    if span[0] in span_to_label:
                        label = span_to_label[span[0]]
                        if label["task"] == "cls" and len(span) == 1:
                            pred_clses.append(label["string"])
                        elif label["task"] == "ent" and len(span) == 2:
                            string = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_ents.append((label["string"], string))
                        elif label["task"] == "rel" and len(span) == 3:
                            head = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            tail = self._convert_span_to_string(
                                span[2], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_rels.append((label["string"], head, tail))
                        elif label["task"] == "event":
                            if label["type"] == "lm" and len(span) == 2:
                                pred_trigger_to_event[span[1]]["event_type"] = label["string"]
                            elif label["type"] == "lr" and len(span) == 3:
                                arg = self._convert_span_to_string(
                                    span[2], ins["input_ids"], self.transform.tokenizer
                                )
                                pred_trigger_to_event[span[1]]["arguments"].append(
                                    {"argument": arg, "role": label["string"]}
                                )
                        elif label["task"] == "discontinuous_ent" and len(span) > 1:
                            parts = [
                                self._convert_span_to_string(
                                    part, ins["input_ids"], self.transform.tokenizer
                                )
                                for part in span[1:]
                            ]
                            string = " ".join([part[0] for part in parts])
                            position = []
                            for part in parts:
                                position.append(part[1])
                            # each part position is already re-based by
                            # `_convert_span_to_string`, so no further reset is needed
                            pred_discon_ents.append((label["string"], string, position))
                        elif (
                            label["task"] == "hyper_rel"
                            and len(span) == 5
                            and span[3] in span_to_label
                        ):
                            q_label = span_to_label[span[3]]
                            span_1 = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            span_2 = self._convert_span_to_string(
                                span[2], ins["input_ids"], self.transform.tokenizer
                            )
                            span_4 = self._convert_span_to_string(
                                span[4], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_hyper_rels.append(
                                (label["string"], span_1, span_2, q_label["string"], span_4)
                            )
                    else:
                        # span whose first part is not a recognized label position:
                        # keep the decoded string and the re-based token positions
                        pred_token_ids = []
                        for part in span:
                            _pred_token_ids = [ins["input_ids"][i] for i in part]
                            pred_token_ids.extend(_pred_token_ids)
                        span_string = self.transform.tokenizer.decode(pred_token_ids)
                        pred_spans.append(
                            (
                                span_string,
                                tuple(
                                    tuple(
                                        self.reset_position(
                                            ins["input_ids"].cpu().tolist(), part
                                        )
                                    )
                                    for part in span
                                ),
                            )
                        )
                for trigger, item in pred_trigger_to_event.items():
                    trigger = self._convert_span_to_string(
                        trigger, ins["input_ids"], self.transform.tokenizer
                    )
                    if item["event_type"] not in raw_schema["event"]:
                        continue
                    legal_roles = raw_schema["event"][item["event_type"]]
                    pred_events.append(
                        {
                            "trigger": trigger,
                            "event_type": item["event_type"],
                            "arguments": [
                                arg
                                for arg in filter(
                                    lambda arg: arg["role"] in legal_roles,
                                    item["arguments"],
                                )
                            ],
                        }
                    )
                results.append(
                    {
                        "id": ins["raw"]["id"],
                        "results": {
                            "cls": pred_clses,
                            "ent": pred_ents,
                            "rel": pred_rels,
                            "event": pred_events,
                            "span": pred_spans,
                            "discon_ent": pred_discon_ents,
                            "hyper_rel": pred_hyper_rels,
                        },
                    }
                )
        return results


if __name__ == "__main__":
    pass
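    # Usage sketch (hypothetical; the exact rex config/entry-point API may differ):
    #   config = ConfigParser.parse_cmd()
    #   task = SchemaGuidedInstructBertTask(config)
    #   task.train()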