import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import os
import numpy as np
from bs4 import BeautifulSoup
import argparse
import logging
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
import matplotlib.pyplot as plt


def get_ground_truth(image, cells, otsl, split_width=5):
    """
    Parse OTSL to derive row/column split positions.
    This is the ground truth for split model training.

    Args:
        image: PIL Image
        cells: nested list - cells[0] contains the actual cell data
        otsl: OTSL token sequence
        split_width: width of split regions in pixels (default: 5)
    """
    orig_width, orig_height = image.size
    target_size = 960

    cells_flat = cells[0]

    # Rebuild the logical grid from the OTSL token sequence.
    grid = []
    current_row = []
    cell_idx = 0

    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ['fcel', 'ecel']:
            # New cell (filled or empty): advances the cell index.
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ['lcel', 'ucel', 'xcel']:
            # Span continuation tokens do not map to a cell bbox.
            current_row.append({'type': token, 'cell_idx': None})

    if current_row:
        grid.append(current_row)

    # Row splits: bottom edge (max y2) of the cells in each grid row.
    row_splits = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            max_y = max(cells_flat[i]['bbox'][3] for i in row_cell_indices)
            row_splits.append(max_y)

    # Column splits: right edge (max x2) of the cells in each grid column,
    # skipping cells whose horizontal span continues into the next column.
    num_cols = len(grid[0]) if grid else 0
    col_splits = []
    for col_idx in range(num_cols):
        col_max_x = []
        for row in grid:
            if col_idx < len(row) and row[col_idx]['cell_idx'] is not None:
                next_is_lcel = (col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel')
                if not next_is_lcel:
                    cell_id = row[col_idx]['cell_idx']
                    col_max_x.append(cells_flat[cell_id]['bbox'][2])
        if col_max_x:
            col_splits.append(max(col_max_x))

    # Drop the last row/column split: the table's outer edges are marked
    # separately from the table bounding box below.
    row_splits = row_splits[:-1]
    col_splits = col_splits[:-1]

    # Scale split positions from original image size to the 960x960 target.
    y_scaled = [(y * target_size / orig_height) for y in row_splits]
    x_scaled = [(x * target_size / orig_width) for x in col_splits]

    # 1D ground-truth masks: 1 = split region, 0 = content region.
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size

    # Table bounding box in original coordinates, then scaled to the target size.
    all_x1 = [c['bbox'][0] for c in cells_flat]
    all_y1 = [c['bbox'][1] for c in cells_flat]
    all_x2 = [c['bbox'][2] for c in cells_flat]
    all_y2 = [c['bbox'][3] for c in cells_flat]
    table_bbox = [min(all_x1), min(all_y1), max(all_x2), max(all_y2)]
    table_y1 = int(round(table_bbox[1] * target_size / orig_height))
    table_y2 = int(round(table_bbox[3] * target_size / orig_height))
    table_x1 = int(round(table_bbox[0] * target_size / orig_width))
    table_x2 = int(round(table_bbox[2] * target_size / orig_width))

    # Mark the table's top edge.
    for offset in range(split_width):
        pos = table_y1 + offset
        if 0 <= pos < target_size:
            horizontal_gt[pos] = 1

    # Mark the table's bottom edge.
    for offset in range(split_width):
        pos = table_y2 - offset
        if 0 <= pos < target_size:
            horizontal_gt[pos] = 1

    # Mark the table's left edge.
    for offset in range(split_width):
        pos = table_x1 + offset
        if 0 <= pos < target_size:
            vertical_gt[pos] = 1

    # Mark the table's right edge.
    for offset in range(split_width):
        pos = table_x2 - offset
        if 0 <= pos < target_size:
            vertical_gt[pos] = 1

    # Mark the interior row splits.
    for y in y_scaled:
        y_int = int(round(y))
        if 0 <= y_int < target_size:
            for offset in range(split_width):
                pos = y_int + offset
                if 0 <= pos < target_size:
                    horizontal_gt[pos] = 1

    # Mark the interior column splits.
    for x in x_scaled:
        x_int = int(round(x))
        if 0 <= x_int < target_size:
            for offset in range(split_width):
                pos = x_int + offset
                if 0 <= pos < target_size:
                    vertical_gt[pos] = 1

    return horizontal_gt, vertical_gt
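

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal example of calling get_ground_truth on a made-up 2x2 table. The
# image, cell boxes and OTSL tokens below are invented for illustration; real
# samples come from the dataset with 'bbox' values in original image pixels.
def _demo_get_ground_truth():
    toy_image = Image.new('RGB', (200, 100), 'white')
    toy_cells = [[
        {'bbox': [10, 10, 95, 45]},    # row 0, col 0
        {'bbox': [105, 10, 190, 45]},  # row 0, col 1
        {'bbox': [10, 55, 95, 90]},    # row 1, col 0
        {'bbox': [105, 55, 190, 90]},  # row 1, col 1
    ]]
    toy_otsl = ['fcel', 'fcel', 'nl', 'fcel', 'fcel', 'nl']

    h_gt, v_gt = get_ground_truth(toy_image, toy_cells, toy_otsl)
    # h_gt/v_gt are 960-element 0/1 lists: the table's outer edges plus one
    # interior row/column split, each marked as a 5-pixel band.
    return sum(h_gt), sum(v_gt)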
|
|
|
|
def get_ground_truth_auto_gap(image, cells, otsl):
    """
    Parse OTSL to derive row/column split positions with DYNAMIC gap widths.
    This creates ground truth for the split model training.

    Args:
        image: PIL Image
        cells: nested list - cells[0] contains the actual cell data
        otsl: OTSL token sequence
    """
    orig_width, orig_height = image.size
    target_size = 960

    cells_flat = cells[0]

    # Rebuild the logical grid from the OTSL token sequence.
    grid = []
    current_row = []
    cell_idx = 0

    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ['fcel', 'ecel']:
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ['lcel', 'ucel', 'xcel']:
            # Span continuation tokens do not map to a cell bbox.
            current_row.append({'type': token, 'cell_idx': None})

    if current_row:
        grid.append(current_row)

    # Vertical extent (min y1, max y2) of each grid row.
    row_boundaries = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            min_y1 = min(cells_flat[i]['bbox'][1] for i in row_cell_indices)
            max_y2 = max(cells_flat[i]['bbox'][3] for i in row_cell_indices)
            row_boundaries.append({'min_y': min_y1, 'max_y': max_y2})

    # Horizontal extent (min x1, max x2) of each grid column, skipping cells
    # whose horizontal span continues into the next column.
    num_cols = len(grid[0]) if grid else 0
    col_boundaries = []
    for col_idx in range(num_cols):
        col_cells = []
        for row in grid:
            if col_idx < len(row) and row[col_idx]['cell_idx'] is not None:
                next_is_lcel = (col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel')
                if not next_is_lcel:
                    cell_id = row[col_idx]['cell_idx']
                    col_cells.append(cell_id)
        if col_cells:
            min_x1 = min(cells_flat[i]['bbox'][0] for i in col_cells)
            max_x2 = max(cells_flat[i]['bbox'][2] for i in col_cells)
            col_boundaries.append({'min_x': min_x1, 'max_x': max_x2})

    # Table bounding box (computed for reference; not used by this variant).
    all_x1 = [c['bbox'][0] for c in cells_flat]
    all_y1 = [c['bbox'][1] for c in cells_flat]
    all_x2 = [c['bbox'][2] for c in cells_flat]
    all_y2 = [c['bbox'][3] for c in cells_flat]
    table_bbox = [min(all_x1), min(all_y1), max(all_x2), max(all_y2)]

    # 1D ground-truth masks: 1 = split region, 0 = content region.
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size

    def mark_range(gt_array, start, end, orig_dim):
        """Mark all pixels from start to end (scaled to target_size)."""
        start_scaled = int(round(start * target_size / orig_dim))
        end_scaled = int(round(end * target_size / orig_dim))
        for pos in range(start_scaled, min(end_scaled + 1, target_size)):
            if 0 <= pos < target_size:
                gt_array[pos] = 1

    # Horizontal splits: region above the first row.
    if row_boundaries:
        mark_range(horizontal_gt, 0, row_boundaries[0]['min_y'], orig_height)

    # Gaps between consecutive rows.
    for i in range(len(row_boundaries) - 1):
        gap_start = row_boundaries[i]['max_y']
        gap_end = row_boundaries[i + 1]['min_y']
        if gap_end > gap_start:
            mark_range(horizontal_gt, gap_start, gap_end, orig_height)

    # Region below the last row.
    if row_boundaries:
        mark_range(horizontal_gt, row_boundaries[-1]['max_y'], orig_height, orig_height)

    # Vertical splits: region left of the first column.
    if col_boundaries:
        mark_range(vertical_gt, 0, col_boundaries[0]['min_x'], orig_width)

    # Gaps between consecutive columns.
    for i in range(len(col_boundaries) - 1):
        gap_start = col_boundaries[i]['max_x']
        gap_end = col_boundaries[i + 1]['min_x']
        if gap_end > gap_start:
            mark_range(vertical_gt, gap_start, gap_end, orig_width)

    # Region right of the last column.
    if col_boundaries:
        mark_range(vertical_gt, col_boundaries[-1]['max_x'], orig_width, orig_width)

    return horizontal_gt, vertical_gt
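

# --- Comparison sketch (illustrative only) ------------------------------------
# On the same made-up 2x2 table as in the sketch above, get_ground_truth marks
# fixed 5-pixel bands while the auto-gap variant marks the entire whitespace
# between rows/columns, so its masks cover many more pixels.
def _demo_compare_gap_variants():
    toy_image = Image.new('RGB', (200, 100), 'white')
    toy_cells = [[
        {'bbox': [10, 10, 95, 45]},
        {'bbox': [105, 10, 190, 45]},
        {'bbox': [10, 55, 95, 90]},
        {'bbox': [105, 55, 190, 90]},
    ]]
    toy_otsl = ['fcel', 'fcel', 'nl', 'fcel', 'fcel', 'nl']

    h_fixed, _ = get_ground_truth(toy_image, toy_cells, toy_otsl)
    h_auto, _ = get_ground_truth_auto_gap(toy_image, toy_cells, toy_otsl)
    return sum(h_fixed), sum(h_auto)  # the auto-gap mask covers far more pixels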
|
|
|
|
def get_ground_truth_auto_gap_expand_min5pix_overlap_cells(image, cells, otsl, split_width=5):
    """
    Parse OTSL to derive row/column split positions with DYNAMIC gap widths.
    This creates ground truth for the split model training. When adjacent
    rows/columns touch or overlap (no gap), a fixed-width split centered
    between them is used instead.

    Args:
        image: PIL Image
        cells: nested list - cells[0] contains the actual cell data
        otsl: OTSL token sequence
        split_width: width of split when there's no gap (default: 5)
    """
    orig_width, orig_height = image.size
    target_size = 960

    cells_flat = cells[0]

    # Rebuild the logical grid from the OTSL token sequence.
    grid = []
    current_row = []
    cell_idx = 0

    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ['fcel', 'ecel']:
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ['lcel', 'ucel', 'xcel']:
            # Span continuation tokens do not map to a cell bbox.
            current_row.append({'type': token, 'cell_idx': None})

    if current_row:
        grid.append(current_row)

    # Vertical extent of each grid row, plus the cell indices it contains.
    row_boundaries = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            min_y1 = min(cells_flat[i]['bbox'][1] for i in row_cell_indices)
            max_y2 = max(cells_flat[i]['bbox'][3] for i in row_cell_indices)
            row_boundaries.append({'min_y': min_y1, 'max_y': max_y2, 'row_cells': row_cell_indices})

    # Horizontal extent of each grid column, plus the cell indices it contains.
    num_cols = len(grid[0]) if grid else 0
    col_boundaries = []
    for col_idx in range(num_cols):
        col_cells = []
        for row in grid:
            if col_idx < len(row) and row[col_idx]['cell_idx'] is not None:
                # Skip cells whose horizontal span continues into the next column.
                next_is_lcel = (col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel')
                if not next_is_lcel:
                    cell_id = row[col_idx]['cell_idx']
                    col_cells.append(cell_id)
        if col_cells:
            min_x1 = min(cells_flat[i]['bbox'][0] for i in col_cells)
            max_x2 = max(cells_flat[i]['bbox'][2] for i in col_cells)
            col_boundaries.append({'min_x': min_x1, 'max_x': max_x2, 'col_cells': col_cells})

    # Table bounding box (computed for reference; not used by this variant).
    all_x1 = [c['bbox'][0] for c in cells_flat]
    all_y1 = [c['bbox'][1] for c in cells_flat]
    all_x2 = [c['bbox'][2] for c in cells_flat]
    all_y2 = [c['bbox'][3] for c in cells_flat]
    table_bbox = [min(all_x1), min(all_y1), max(all_x2), max(all_y2)]

    # 1D ground-truth masks: 1 = split region, 0 = content region.
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size

    def mark_range(gt_array, start, end, orig_dim):
        """Mark all pixels from start to end (scaled to target_size)."""
        start_scaled = int(round(start * target_size / orig_dim))
        end_scaled = int(round(end * target_size / orig_dim))
        for pos in range(start_scaled, min(end_scaled + 1, target_size)):
            if 0 <= pos < target_size:
                gt_array[pos] = 1

    # Horizontal splits: region above the first row.
    if row_boundaries:
        mark_range(horizontal_gt, 0, row_boundaries[0]['min_y'], orig_height)

    # Gaps between consecutive rows.
    for i in range(len(row_boundaries) - 1):
        gap_start = row_boundaries[i]['max_y']
        gap_end = row_boundaries[i + 1]['min_y']
        if gap_end > gap_start:
            mark_range(horizontal_gt, gap_start, gap_end, orig_height)
        else:
            # Rows touch or overlap: recheck the extreme edges of the two rows.
            curr_row_y2 = [cells_flat[cell_id]['bbox'][3] for cell_id in row_boundaries[i]['row_cells']]
            next_row_y1 = [cells_flat[cell_id]['bbox'][1] for cell_id in row_boundaries[i + 1]['row_cells']]

            max_curr_y2 = max(curr_row_y2)
            min_next_y1 = min(next_row_y1)

            if min_next_y1 > max_curr_y2:
                mark_range(horizontal_gt, max_curr_y2, min_next_y1, orig_height)
            else:
                # Still overlapping: fall back to a fixed-width split centered between them.
                split_pos = (max_curr_y2 + min_next_y1) / 2
                mark_range(horizontal_gt, split_pos - split_width / 2, split_pos + split_width / 2, orig_height)

    # Region below the last row.
    if row_boundaries:
        mark_range(horizontal_gt, row_boundaries[-1]['max_y'], orig_height, orig_height)

    # Vertical splits: region left of the first column.
    if col_boundaries:
        mark_range(vertical_gt, 0, col_boundaries[0]['min_x'], orig_width)

    # Gaps between consecutive columns.
    for i in range(len(col_boundaries) - 1):
        gap_start = col_boundaries[i]['max_x']
        gap_end = col_boundaries[i + 1]['min_x']

        if gap_end > gap_start:
            mark_range(vertical_gt, gap_start, gap_end, orig_width)
        else:
            # Columns touch or overlap: recheck the extreme edges of the two columns.
            curr_col_x2 = [cells_flat[cell_id]['bbox'][2] for cell_id in col_boundaries[i]['col_cells']]
            next_col_x1 = [cells_flat[cell_id]['bbox'][0] for cell_id in col_boundaries[i + 1]['col_cells']]

            max_curr_x2 = max(curr_col_x2)
            min_next_x1 = min(next_col_x1)

            if min_next_x1 > max_curr_x2:
                mark_range(vertical_gt, max_curr_x2, min_next_x1, orig_width)
            else:
                # Still overlapping: fall back to a fixed-width split centered between them.
                split_pos = (max_curr_x2 + min_next_x1) / 2
                mark_range(vertical_gt, split_pos - split_width / 2, split_pos + split_width / 2, orig_width)

    # Region right of the last column.
    if col_boundaries:
        mark_range(vertical_gt, col_boundaries[-1]['max_x'], orig_width, orig_width)

    return horizontal_gt, vertical_gt
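

# --- Overlap-handling sketch (illustrative only) -------------------------------
# A made-up example where two rows touch (no whitespace between them): the plain
# auto-gap variant leaves no interior split there, while this variant inserts a
# split_width band centered on the shared boundary.
def _demo_overlapping_rows():
    toy_image = Image.new('RGB', (200, 100), 'white')
    toy_cells = [[
        {'bbox': [10, 10, 190, 50]},  # row 0 ends exactly where row 1 starts
        {'bbox': [10, 50, 190, 90]},
    ]]
    toy_otsl = ['fcel', 'nl', 'fcel', 'nl']

    h_auto, _ = get_ground_truth_auto_gap(toy_image, toy_cells, toy_otsl)
    h_min5, _ = get_ground_truth_auto_gap_expand_min5pix_overlap_cells(toy_image, toy_cells, toy_otsl)
    return sum(h_auto), sum(h_min5)  # only the second mask covers the shared boundary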
|
|
|
|
class BasicBlock(nn.Module):
    """Basic ResNet block with halved channels"""
    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.downsample = None
        if stride != 1 or inplanes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
|
|
class ModifiedResNet18(nn.Module):
    """ResNet-18 with the stem maxpool removed and all channel counts halved"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        # No maxpool after the stem, so layer1 runs at H/2 x W/2.
        self.layer1 = self._make_layer(32, 32, 2, stride=1)
        self.layer2 = self._make_layer(32, 64, 2, stride=2)
        self.layer3 = self._make_layer(64, 128, 2, stride=2)
        self.layer4 = self._make_layer(128, 256, 2, stride=2)

    def _make_layer(self, inplanes, planes, blocks, stride=1):
        layers = []
        layers.append(BasicBlock(inplanes, planes, stride))
        for _ in range(1, blocks):
            layers.append(BasicBlock(planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
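

# --- Shape-check sketch (illustrative only) ------------------------------------
# With the stem at stride 2, no maxpool, and three stride-2 stages, a 960x960
# input comes out at 1/16 resolution with 256 channels.
def _demo_backbone_shapes():
    backbone = ModifiedResNet18()
    with torch.no_grad():
        feats = backbone(torch.zeros(1, 3, 960, 960))
    # Expected: torch.Size([1, 256, 60, 60])
    return feats.shape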
|
|
class FPN(nn.Module):
    """Single-level FPN-style head: projects the 256-channel backbone output to
    128 channels and upsamples it to H/2 x W/2 (480x480)."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(256, 128, kernel_size=1)

    def forward(self, x):
        # 1x1 channel reduction followed by bilinear upsampling to 480x480.
        x = self.conv(x)
        x = F.interpolate(x, size=(480, 480), mode='bilinear', align_corners=False)
        return x
|
|
class SplitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = ModifiedResNet18()
        self.fpn = FPN()

        # Learned weights for the global (weighted-sum) projections over width/height.
        self.h_global_weight = nn.Parameter(torch.randn(480))
        self.v_global_weight = nn.Parameter(torch.randn(480))

        # 1x1 convs for the local (pooled) projections.
        self.h_local_conv = nn.Conv2d(128, 1, kernel_size=1)
        self.v_local_conv = nn.Conv2d(128, 1, kernel_size=1)

        # 128 global channels + 120 local positions (480 / 4 after pooling).
        feature_dim = 128 + 120

        # Learned positional embeddings for the 480 rows/columns.
        self.h_pos_embed = nn.Parameter(torch.randn(480, feature_dim))
        self.v_pos_embed = nn.Parameter(torch.randn(480, feature_dim))

        # Per-axis transformer encoders over the 480 row/column features.
        self.h_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=feature_dim, nhead=8, dim_feedforward=2048,
                dropout=0.1, batch_first=True
            ),
            num_layers=3
        )
        self.v_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=feature_dim, nhead=8, dim_feedforward=2048,
                dropout=0.1, batch_first=True
            ),
            num_layers=3
        )

        # Per-position binary classifiers (split / no split).
        self.h_classifier = nn.Linear(feature_dim, 1)
        self.v_classifier = nn.Linear(feature_dim, 1)

    def forward(self, x):
        # Backbone + projection head: (B, 3, 960, 960) -> (B, 128, 480, 480).
        features = self.backbone(x)
        F_half = self.fpn(features)

        B, C, H, W = F_half.shape

        # --- Row (horizontal split) features ---
        # Global: learned weighted sum over the width -> (B, 480, 128).
        F_RG = torch.einsum('bchw,w->bch', F_half, self.h_global_weight)
        F_RG = F_RG.transpose(1, 2)

        # Local: average-pool the width by 4, project to 1 channel -> (B, 480, 120).
        F_RL_pooled = F.avg_pool2d(F_half, kernel_size=(1, 4))
        F_RL = self.h_local_conv(F_RL_pooled)
        F_RL = F_RL.squeeze(1)

        # Concatenate global and local features -> (B, 480, 248).
        F_RG_L = torch.cat([F_RG, F_RL], dim=2)

        # Add row positional embeddings.
        F_RG_L = F_RG_L + self.h_pos_embed

        # --- Column (vertical split) features ---
        # Global: learned weighted sum over the height -> (B, 480, 128).
        F_CG = torch.einsum('bchw,h->bcw', F_half, self.v_global_weight)
        F_CG = F_CG.transpose(1, 2)

        # Local: average-pool the height by 4, project to 1 channel -> (B, 480, 120).
        F_CL_pooled = F.avg_pool2d(F_half, kernel_size=(4, 1))
        F_CL = self.v_local_conv(F_CL_pooled)
        F_CL = F_CL.squeeze(1)
        F_CL = F_CL.transpose(1, 2)

        # Concatenate global and local features -> (B, 480, 248).
        F_CG_L = torch.cat([F_CG, F_CL], dim=2)

        # Add column positional embeddings.
        F_CG_L = F_CG_L + self.v_pos_embed

        # Per-axis transformer encoders.
        F_R = self.h_transformer(F_RG_L)
        F_C = self.v_transformer(F_CG_L)

        # Per-position logits, squeezed to (B, 480).
        h_logits = self.h_classifier(F_R).squeeze(-1)
        v_logits = self.v_classifier(F_C).squeeze(-1)

        # Return split probabilities for each of the 480 rows and columns.
        return torch.sigmoid(h_logits), torch.sigmoid(v_logits)
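

# --- Forward-pass sketch (illustrative only) ------------------------------------
# Runs the split model on a dummy batch to show the output shapes: one split
# probability per row and per column of the 480x480 feature grid.
def _demo_split_model_shapes():
    model = SplitModel()
    model.eval()
    with torch.no_grad():
        h_prob, v_prob = model(torch.zeros(1, 3, 960, 960))
    # Expected: both torch.Size([1, 480]), values in [0, 1].
    return h_prob.shape, v_prob.shape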
|
|
def focal_loss(predictions, targets, alpha=1.0, gamma=2.0):
    """Focal loss exactly as specified in the paper"""
    ce_loss = F.binary_cross_entropy(predictions, targets, reduction='none')
    pt = torch.where(targets == 1, predictions, 1 - predictions)
    focal_weight = alpha * (1 - pt) ** gamma
    return (focal_weight * ce_loss).mean()
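

# --- Loss sketch (illustrative only) --------------------------------------------
# Compares focal loss with plain BCE on made-up predictions: confident, correct
# positions get down-weighted by (1 - pt)^gamma, so the focal loss is smaller.
def _demo_focal_loss():
    preds = torch.tensor([0.9, 0.2, 0.6])    # predicted split probabilities
    targets = torch.tensor([1.0, 0.0, 1.0])  # ground-truth split mask
    bce = F.binary_cross_entropy(preds, targets)
    fl = focal_loss(preds, targets, alpha=1.0, gamma=2.0)
    return bce.item(), fl.item()  # fl < bce on this toy example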
|
|
def post_process_predictions(h_pred, v_pred, threshold=0.5):
    """
    Simple post-processing to convert predictions to binary masks
    """
    h_binary = (h_pred > threshold).float()
    v_binary = (v_pred > threshold).float()

    return h_binary, v_binary
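

# --- Possible next step (sketch, not part of the original pipeline) -------------
# Groups consecutive 1s of a thresholded mask into (start, end) split intervals;
# their centers could serve as row/column separator positions.
def _mask_to_intervals(binary_mask):
    # binary_mask: 1D tensor of 0s/1s, e.g. the h_binary/v_binary outputs above.
    intervals = []
    start = None
    for i, v in enumerate(binary_mask.tolist()):
        if v == 1 and start is None:
            start = i
        elif v == 0 and start is not None:
            intervals.append((start, i - 1))
            start = None
    if start is not None:
        intervals.append((start, len(binary_mask) - 1))
    return intervals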
|
|
class TableDataset(Dataset):
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset
        self.transform = transforms.Compose([
            transforms.Resize((960, 960)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]

        image = item['image'].convert('RGB')
        image_transformed = self.transform(image)

        # Ground truth at the full 960 resolution (auto-gap variant).
        h_gt_960, v_gt_960 = get_ground_truth_auto_gap(
            item['image'],
            item['cells'],
            item['otsl'],
        )

        # Downsample to 480 by keeping every other pixel, matching the model output size.
        h_gt_480 = [h_gt_960[i] for i in range(0, 960, 2)]
        v_gt_480 = [v_gt_960[i] for i in range(0, 960, 2)]

        return (
            image_transformed,
            torch.tensor(h_gt_480, dtype=torch.float),
            torch.tensor(v_gt_480, dtype=torch.float),
            torch.tensor(h_gt_960, dtype=torch.float),
            torch.tensor(v_gt_960, dtype=torch.float),
        )
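

# --- Training-step sketch (illustrative only) ------------------------------------
# A minimal sketch of wiring the dataset, model and loss together. The dataset
# argument, batch size, learning rate, optimizer choice and device handling are
# assumptions for illustration, not part of the original script.
def _demo_training_step(hf_split, device='cpu'):
    dataset = TableDataset(hf_split)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    model = SplitModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    model.train()
    for images, h_gt_480, v_gt_480, h_gt_960, v_gt_960 in loader:
        images = images.to(device)
        h_gt_480 = h_gt_480.to(device)
        v_gt_480 = v_gt_480.to(device)

        h_pred, v_pred = model(images)  # each (B, 480), probabilities
        loss = focal_loss(h_pred, h_gt_480) + focal_loss(v_pred, v_gt_480)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        break  # single step is enough for the sketch
    return loss.item()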
|
|
|
|
|
|