| from collections import defaultdict |
|
|
| import torch |
| from rex.utils.iteration import windowed_queue_iter |
| from rex.utils.position import find_all_positions |
|
|
|
|
def find_paths_from_adj_mat(adj_mat: torch.Tensor) -> list[tuple[int]]:
    """Enumerate node paths encoded by a square adjacency matrix.

    Every non-zero entry ``adj_mat[c, n]`` is read as a directed edge
    ``c -> n``.

    * A non-zero diagonal entry (self-loop) yields the one-node path ``(c,)``
      and is otherwise kept out of the traversal.
    * Traversal starts from every node that has outgoing edges but no
      incoming ones. Each walk emits every prefix it reaches (two nodes or
      longer), not only the maximal paths.
    * Within the traversal launched for one ``(start, first hop)`` pair each
      edge is expanded at most once. This keeps cyclic graphs finite; it also
      means that when branches re-join, only the first branch to reach a
      shared edge continues past it.

    Args:
        adj_mat: 2-D square tensor; non-zero means "edge present".

    Returns:
        Paths as tuples of node indices. The order follows set iteration and
        should not be relied upon.

    Raises:
        AssertionError: if ``adj_mat`` is not a 2-D square tensor.
    """
    # Check the rank first: indexing ``shape[1]`` on a 1-D tensor would raise
    # IndexError before the assertion could report the real problem.
    assert len(adj_mat.shape) == 2 and adj_mat.shape[0] == adj_mat.shape[1]

    self_loops = set()
    adj_map = defaultdict(set)  # current node -> next nodes
    rev_adj_map = defaultdict(set)  # next node -> previous nodes
    for cur, nxt in adj_mat.detach().nonzero().tolist():
        if cur == nxt:
            self_loops.add(cur)
        else:
            adj_map[cur].add(nxt)
            rev_adj_map[nxt].add(cur)

    paths = [(node,) for node in self_loops]

    def track(prefix: tuple[int], node: int) -> None:
        """Iterative DFS from ``node``, recording every prefix reached."""
        expanded: set[tuple[int, int]] = set()
        stack = [(prefix, node)]
        while stack:
            prefix, node = stack.pop()
            walked = prefix + (node,)
            # Membership test (not indexing) so the defaultdict gains no keys.
            if node in adj_map:
                for nxt in adj_map[node]:
                    edge = (node, nxt)
                    if edge in expanded:
                        continue
                    expanded.add(edge)
                    stack.append((walked, nxt))
            # ``prefix`` always holds at least the start node here, so every
            # popped state contributes a path of two or more nodes.
            paths.append(walked)

    # Start only from nodes without an incoming (non-self-loop) edge.
    start_nodes = set(adj_map.keys()) - set(rev_adj_map.keys())
    for start in start_nodes:
        for first_hop in adj_map[start]:
            track((start,), first_hop)

    return paths
|
|
|
|
def encode_nnw_thw_mat(
    spans: list[tuple[int]], seq_len: int, nnw_id: int = 0, thw_id: int = 1
) -> torch.Tensor:
    """Encode spans into a two-channel ``(2, seq_len, seq_len)`` link matrix.

    * A one-index span ``(i,)`` sets ``[i, i]`` in both channels.
    * A longer span sets channel ``nnw_id`` at every ``[s, e]`` pair yielded
      by ``windowed_queue_iter(span, 2, 1, drop_last=True)`` (presumably the
      consecutive index pairs -- confirm against that helper) and sets
      channel ``thw_id`` at ``[last index, first index]``.
    * Empty spans are skipped, matching how ``encode_nnw_nsw_thw_mat``
      ignores spans with nothing left to encode.

    Args:
        spans: spans given as tuples of token indices.
        seq_len: side length of each channel of the matrix.
        nnw_id: channel index receiving the index-to-next-index links.
        thw_id: channel index receiving the last-to-first link.

    Returns:
        Float tensor of zeros and ones with shape ``(2, seq_len, seq_len)``.
    """
    mat = torch.zeros(2, seq_len, seq_len)
    for span in spans:
        span_len = len(span)
        if span_len == 0:
            # Nothing to encode; previously this crashed on ``span[-1]``.
            continue
        if span_len == 1:
            mat[:, span[0], span[0]] = 1
            continue
        for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
            mat[nnw_id, s, e] = 1
        mat[thw_id, span[-1], span[0]] = 1
    return mat
|
|
|
|
def decode_nnw_thw_mat(
    batch_mat: torch.LongTensor,
    nnw_id: int = 0,
    thw_id: int = 1,
    offsets: list[int] = None,
) -> list[list[tuple[int]]]:
    """Decode a batch of NNW/THW matrices into spans.

    For every instance, paths are extracted from channel ``nnw_id`` with
    ``find_paths_from_adj_mat``. Each non-zero entry ``[e, s]`` of channel
    ``thw_id`` then selects the paths that start at ``s`` and end at ``e``.

    Args:
        batch_mat: tensor of shape (batch_size, 2, seq_len, seq_len)
        nnw_id: channel index holding the links followed to build paths
        thw_id: channel index holding the (end, start) markers
        offsets: optional per-instance value subtracted from every index of
            the returned spans; when missing or empty, nothing is subtracted

    Returns:
        One list of spans (tuples of indices) per instance.
    """
    ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
    assert seq_len1 == seq_len2
    assert cls_num == 2

    decoded = []
    for idx in range(ins_num):
        shift = offsets[idx] if offsets else 0
        ins_mat = batch_mat[idx]

        # Group the NNW paths by their (last node, first node) pair so each
        # THW marker can be looked up directly.
        paths_by_end_start = defaultdict(set)
        for nnw_path in find_paths_from_adj_mat(ins_mat[nnw_id, ...]):
            key = (nnw_path[-1], nnw_path[0])
            paths_by_end_start[key].add(nnw_path)

        spans = []
        # THW entries are stored as [end, start].
        for end, start in ins_mat[thw_id, ...].detach().nonzero().tolist():
            spans.extend(
                tuple(pos - shift for pos in nnw_path)
                for nnw_path in paths_by_end_start[(end, start)]
            )
        decoded.append(spans)

    return decoded
|
|
|
|
def decode_pointer_mat(
    batch_mat: torch.LongTensor, offsets: list[int] = None
) -> list[list[tuple[int]]]:
    """Decode (start, end) pointer entries into contiguous index spans.

    Only channel 0 of each instance is read. Every entry equal to 1 at
    ``[s, e]`` becomes ``tuple(range(s - offset, e + 1 - offset))``; an entry
    with ``e < s`` therefore yields an empty tuple.

    Args:
        batch_mat: batch of matrices indexed as ``[instance, channel, s, e]``
        offsets: optional per-instance value subtracted from both ends of
            every span; when missing or empty, nothing is subtracted

    Returns:
        One list of spans (tuples of consecutive indices) per instance.
    """
    decoded = []
    for ins_id in range(len(batch_mat)):
        shift = 0
        if offsets:
            shift = offsets[ins_id]
        hits = (batch_mat[ins_id, 0] == 1).nonzero().tolist()
        decoded.append(
            [tuple(range(start - shift, end + 1 - shift)) for start, end in hits]
        )
    return decoded
|
|
|
|
def encode_nnw_nsw_thw_mat(
    spans: list[list[tuple[int]]],
    seq_len: int,
    nnw_id: int = 0,
    nsw_id: int = 1,
    thw_id: int = 2,
) -> torch.Tensor:
    """Encode multi-part spans into a ``(3, seq_len, seq_len)`` link matrix.

    Each span is a list of parts (tuples of indices). A part holding any
    index outside ``[0, seq_len - 1]`` is dropped; the kept parts are
    concatenated into one flat span.

    * Channel ``nsw_id``: for a kept part that is not the last one, a link
      from its last index to the first index of the following part, written
      whenever that first index is in range (the rest of the following part
      is not checked at this point).
    * A flat span of one index ``i`` sets ``[i, i]`` in all three channels.
    * A longer flat span sets channel ``nnw_id`` at every ``[s, e]`` pair
      yielded by ``windowed_queue_iter(flat, 2, 1, drop_last=True)``.
    * Channel ``thw_id``: ``[last index, first index]`` of any non-empty
      flat span.

    Returns:
        Float tensor of zeros and ones with shape ``(3, seq_len, seq_len)``.
    """
    mat = torch.zeros(3, seq_len, seq_len)
    max_index = seq_len - 1

    for parts in spans:
        final_part = len(parts) - 1
        flat = ()
        for part_idx, part in enumerate(parts):
            if any(not (0 <= el <= max_index) for el in part):
                # Part sticks out of the sequence: leave it out completely.
                continue
            flat = flat + part
            if part_idx == final_part:
                continue
            next_head = parts[part_idx + 1][0]
            if 0 <= next_head <= max_index:
                # Link the end of this part to the start of the next part.
                mat[nsw_id, part[-1], next_head] = 1

        if not flat:
            continue
        if len(flat) == 1:
            mat[:, flat[0], flat[0]] = 1
        else:
            for s, e in windowed_queue_iter(flat, 2, 1, drop_last=True):
                mat[nnw_id, s, e] = 1
        mat[thw_id, flat[-1], flat[0]] = 1

    return mat
|
|
|
|
def split_tuple_by_positions(nums, positions) -> list:
    """Cut ``nums`` into consecutive slices at the given positions.

    Each position is the index at which a new slice begins. Positions are
    sorted before use; a position of 0 or a repeated position produces an
    empty slice.

    Raises:
        ValueError: if any position is not smaller than ``len(nums)``.

    Examples:
        >>> nums = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
        >>> positions = [2, 5, 7]
        >>> split_tuple_by_positions(nums, positions)
        [(1, 2), (3, 4, 5), (6, 7), (8, 9, 10)]
    """
    total = len(nums)
    if any(not (p < total) for p in positions):
        raise ValueError("Invalid positions")

    # Slice boundaries: start of the sequence, every cut point, then the end.
    bounds = [0, *sorted(positions), total]
    return [nums[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
|
|
|
|
def decode_nnw_nsw_thw_mat(
    batch_mat: torch.LongTensor,
    nnw_id: int = 0,
    nsw_id: int = 1,
    thw_id: int = 2,
    offsets: list[int] = None,
) -> list[list[tuple[int]]]:
    """Decode NNW NSW THW matrix into a list of spans
    One span has multiple parts

    For every instance:

    1. Paths are extracted from channel ``nnw_id`` with
       ``find_paths_from_adj_mat``.
    2. Each non-zero entry ``[e, s]`` of channel ``thw_id`` marks a span
       running from ``s`` to ``e``. For every path containing ``s`` and,
       after it, ``e``, the stretch between their first occurrences
       (inclusive) is the chain.
    3. The chain is cut into parts after every place where a non-zero
       ``[part_end, next_part_start]`` entry of channel ``nsw_id`` occurs as
       two adjacent chain elements.

    Args:
        batch_mat: (batch_size, 3, seq_len, seq_len)
        nnw_id: channel index holding the links followed to build paths
        nsw_id: channel index holding the part-boundary links
        thw_id: channel index holding the (end, start) markers
        offsets: optional per-instance value subtracted from every returned
            index; when missing or empty, nothing is subtracted

    Returns:
        Per instance, a de-duplicated list of spans; each span is a tuple of
        parts and each part a tuple of indices. Spans are collected in a set,
        so their order within an instance is not defined.
    """
    ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
    assert seq_len1 == seq_len2
    assert cls_num == 3

    result_batch = []
    for ins_id in range(ins_num):
        offset = offsets[ins_id] if offsets else 0
        ins_mat = batch_mat[ins_id]

        # Shift the part-boundary pairs once per instance rather than once
        # for every candidate path.
        shifted_seps = {
            (part1e - offset, part2s - offset)
            for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()
        }
        nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
        thw_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()

        ins_span_paths = set()
        # THW entries are stored reversed: end -> start.
        for e, s in thw_pairs:
            for path in nnw_paths:
                if s not in path:
                    continue
                sub_path = path[path.index(s) :]
                if e not in sub_path:
                    continue
                sub_path = sub_path[: sub_path.index(e) + 1]
                chain = tuple(i - offset for i in sub_path)

                parts = []
                if len(chain) > 1:
                    all_sep_positions = set()
                    for sep in shifted_seps:
                        positions = find_all_positions(list(chain), list(sep))
                        if positions:
                            # NOTE(review): assumes each match starts with the
                            # index where ``sep`` begins in ``chain`` --
                            # confirm against find_all_positions. The next
                            # part begins one element later.
                            all_sep_positions.update(p[0] + 1 for p in positions)
                    parts = split_tuple_by_positions(chain, all_sep_positions)
                if not parts:
                    parts = [chain]
                ins_span_paths.add(tuple(parts))
        result_batch.append(list(ins_span_paths))

    return result_batch
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|