| import sys |
| import os |
|
|
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification |
| from src.sroie_loader import load_sroie |
| from PIL import Image |
| from tqdm import tqdm |
| from seqeval.metrics import f1_score, precision_score, recall_score |
| from pathlib import Path |
| import os |
|
|
| |
| print("Setting up configuration...") |
| label_list = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE', |
| 'B-ADDRESS', 'I-ADDRESS', 'B-TOTAL', 'I-TOTAL'] |
| label2id = {label: idx for idx, label in enumerate(label_list)} |
| id2label = {idx: label for idx, label in enumerate(label_list)} |
|
|
| MODEL_CHECKPOINT = "microsoft/layoutlmv3-base" |
| SROIE_DATA_PATH = os.getenv("SROIE_DATA_PATH", os.path.join("data", "sroie")) |
|
|
| |
class SROIEDataset(Dataset):
    """PyTorch Dataset wrapping SROIE examples for LayoutLMv3.

    Each item is the processor's encoding of one receipt image together
    with its words, boxes normalized to LayoutLM's 0-1000 space, and
    integer NER labels.
    """

    def __init__(self, data, processor, label2id):
        self.data = data            # list of example dicts from load_sroie
        self.processor = processor  # LayoutLMv3Processor (apply_ocr=False)
        self.label2id = label2id    # label string -> integer id

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_normalized_box(box, width, height):
        """Convert an (x, y, w, h) pixel box to a clamped [x0, y0, x1, y1] in 0-1000."""
        x, y, w, h = box
        corners = (x / width, y / height, (x + w) / width, (y + h) / height)
        return [min(max(int(c * 1000), 0), 1000) for c in corners]

    def __getitem__(self, idx):
        example = self.data[idx]

        image = Image.open(example['image_path']).convert("RGB")
        width, height = image.size

        boxes = [self._to_normalized_box(b, width, height) for b in example['bboxes']]
        word_labels = [self.label2id[tag] for tag in example['ner_tags']]

        encoding = self.processor(
            image,
            text=example['words'],
            boxes=boxes,
            word_labels=word_labels,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )

        # Drop the leading batch dimension added by return_tensors="pt".
        return {key: tensor.squeeze(0) for key, tensor in encoding.items()}
|
|
| |
# Single source of truth for the best-checkpoint location, so the save calls
# and the progress messages can never disagree about where the model lives.
BEST_MODEL_DIR = Path("./models/layoutlmv3-sroie-best")


def _train_one_epoch(model, dataloader, optimizer, device, epoch):
    """Run one optimization pass over `dataloader`; return the mean batch loss."""
    model.train()
    total_loss = 0.0
    progress = tqdm(dataloader, desc=f"Training Epoch {epoch+1}")
    for batch in progress:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        progress.set_postfix({'loss': f'{loss.item():.4f}'})
    return total_loss / len(dataloader)


def _evaluate(model, dataloader, device):
    """Evaluate on `dataloader`; return (f1, precision, recall) via seqeval.

    Positions labeled -100 (padding / non-first subword tokens) are excluded
    from both the reference and predicted sequences before scoring.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            batch = {k: v.to(device) for k, v in batch.items()}
            predictions = model(**batch).logits.argmax(dim=-1)
            labels = batch['labels']

            for pred_row, label_row in zip(predictions, labels):
                true_seq = [id2label[l.item()] for l in label_row if l.item() != -100]
                pred_seq = [id2label[p.item()]
                            for p, l in zip(pred_row, label_row) if l.item() != -100]
                all_labels.append(true_seq)
                all_predictions.append(pred_seq)

    return (f1_score(all_labels, all_predictions),
            precision_score(all_labels, all_predictions),
            recall_score(all_labels, all_predictions))


def train():
    """Fine-tune LayoutLMv3 on SROIE, saving the checkpoint with the best F1.

    Loads the dataset from SROIE_DATA_PATH, trains for NUM_EPOCHS with AdamW,
    evaluates after every epoch, and persists model + processor to
    BEST_MODEL_DIR whenever the entity-level F1 improves.
    """
    print("Loading SROIE dataset...")
    raw_dataset = load_sroie(SROIE_DATA_PATH)

    # apply_ocr=False: SROIE provides its own words and boxes.
    print("Creating processor...")
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)

    print("Creating PyTorch datasets and dataloaders...")
    train_dataset = SROIEDataset(raw_dataset['train'], processor, label2id)
    test_dataset = SROIEDataset(raw_dataset['test'], processor, label2id)

    # Pads variable-length sequences per batch; labels are padded with -100.
    data_collator = DataCollatorForTokenClassification(
        tokenizer=processor.tokenizer,
        padding=True,
        return_tensors="pt"
    )

    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=data_collator)
    test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=data_collator)

    print("Loading LayoutLMv3 model for fine-tuning...")
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on: {device}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    best_f1 = 0.0
    NUM_EPOCHS = 10

    for epoch in range(NUM_EPOCHS):
        print(f"\n{'='*60}\nEpoch {epoch + 1}/{NUM_EPOCHS}\n{'='*60}")

        avg_train_loss = _train_one_epoch(model, train_dataloader, optimizer, device, epoch)
        f1, precision, recall = _evaluate(model, test_dataloader, device)

        print(f"\n📊 Epoch {epoch + 1} Results:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  F1 Score: {f1:.4f}")
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall: {recall:.4f}")

        # Checkpoint only on improvement so the saved model is always the best seen.
        if f1 > best_f1:
            best_f1 = f1
            print(f"  🌟 New best F1! Saving model...")
            BEST_MODEL_DIR.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(BEST_MODEL_DIR)
            processor.save_pretrained(BEST_MODEL_DIR)

    print(f"\n🎉 TRAINING COMPLETE! Best F1 Score: {best_f1:.4f}")
    print(f"Model saved to: {BEST_MODEL_DIR}")
|
|
|
|
# Entry point: run fine-tuning only when executed as a script, not on import.
if __name__ == '__main__':
    train()
|
|