| import os |
| import torch |
| import evaluate |
| import numpy as np |
| import pandas as pd |
| import glob as glob |
| import torch.optim as optim |
| import matplotlib.pyplot as plt |
| import torchvision.transforms as transforms |
| import subprocess |
|
|
| from flask import Flask, request, jsonify |
| from PIL import Image |
| from zipfile import ZipFile |
| from tqdm.notebook import tqdm |
| from dataclasses import dataclass |
| from torch.utils.data import Dataset |
| from urllib.request import urlretrieve |
| from transformers import ( |
| VisionEncoderDecoderModel, |
| TrOCRProcessor, |
| Seq2SeqTrainer, |
| Seq2SeqTrainingArguments, |
| default_data_collator |
| ) |
| from roboflow import Roboflow |
# NOTE(review): this API key is committed to source, so anyone who can read the
# file can use it. Supply it via the ROBOFLOW_API_KEY environment variable and
# rotate the committed key; the literal is kept only as a fallback so existing
# runs keep working.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "kGIFR6wPmDow2dHnoXoi"))
project = rf.workspace("capstone-design-oyzc3").project("dataset-train-test")
# Download version 1 of the dataset in Roboflow's "folder" export format.
dataset = project.version(1).download("folder")

# Fetch and unpack the label archive (presumably provides the train.txt /
# test.txt files read further below — confirm). check=True stops the script
# here on failure instead of failing later with a confusing missing-file error.
# NOTE(review): --no-check-certificate disables TLS verification for this
# download; drop it unless the environment genuinely needs it.
subprocess.run(
    [
        'wget', '--no-check-certificate',
        'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo',
        '-O', 'filetxt',
    ],
    check=True,
)
# -o overwrites existing files without prompting; without it a re-run blocks on
# unzip's interactive "replace ...?" question.
subprocess.run(['unzip', '-o', 'filetxt'], check=True)
|
|
def seed_everything(seed_value):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all CUDA
    devices), and switches cuDNN to deterministic mode.

    :param seed_value: Integer seed applied to every generator.
    """
    import random  # local import: the module header does not import it

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Trade cuDNN autotuning speed for repeatable kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
def download_and_unzip(url, save_path):
    """Download a zip archive and extract it next to the saved file.

    :param url: URL of the archive (anything ``urlretrieve`` accepts).
    :param save_path: Local path the archive is written to; its contents are
        extracted into the directory containing this path.

    If the downloaded file is not a valid zip archive, a message is printed and
    the file is deleted. A later existence check on ``save_path`` will then
    download it again instead of reusing the bad file. Other errors
    (network, disk) propagate to the caller.
    """
    from zipfile import BadZipFile

    print("Downloading and extracting assets....", end="")
    urlretrieve(url, save_path)
    try:
        with ZipFile(save_path) as archive:
            archive.extractall(os.path.dirname(save_path))
    except BadZipFile as e:
        print("\nInvalid file.", e)
        # Remove the unusable download so it is not mistaken for a cached archive.
        os.remove(save_path)
        return
    print("Done")
|
|
# Roboflow export link for the dataset archive.
# NOTE(review): the query string carries an access key; treat it as a secret.
URL = "https://app.roboflow.com/ds/TZnI5u5spH?key=krcK5FWtuB"
# The archive is cached in the current working directory.
asset_zip_path = os.path.join(os.getcwd(), "capstone-design-oyzc3.zip")

# Only download when the archive is not already on disk.
if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)
|
|
@dataclass(frozen=True)
class TrainingConfig:
    """Hyper-parameters for the fine-tuning run."""

    BATCH_SIZE: int = 25  # per-device batch size
    EPOCHS: int = 20
    LEARNING_RATE: float = 5e-5


@dataclass(frozen=True)
class DatasetConfig:
    """Location of the downloaded dataset."""

    DATA_ROOT: str = 'DATASET-TRAIN-TEST-1'  # relative to the working directory


@dataclass(frozen=True)
class ModelConfig:
    """Pretrained checkpoint to fine-tune."""

    MODEL_NAME: str = 'microsoft/trocr-small-printed'  # Hugging Face hub id
|
|
def visualize(dataset_path):
    """Show up to 15 training images in a 3x5 grid, titled by file name.

    :param dataset_path: Dataset root containing a ``train/train`` image folder.
    """
    image_dir = os.path.join(dataset_path, 'train', 'train')
    # List the directory once (not on every iteration) and cap at the grid
    # size so folders with fewer than 15 images do not raise IndexError.
    file_names = os.listdir(image_dir)[:15]
    plt.figure(figsize=(15, 3))
    for i, file_name in enumerate(file_names):
        plt.subplot(3, 5, i + 1)
        plt.imshow(plt.imread(os.path.join(image_dir, file_name)))
        plt.axis('off')
        plt.title(file_name.split('.')[0])
    plt.show()


visualize(DatasetConfig.DATA_ROOT)
|
|
def _read_split_labels(label_file):
    """Load a header-less label file into a ``file_name`` / ``text`` DataFrame.

    The file is parsed with ``pd.read_fwf`` and its first two inferred
    columns are renamed.
    """
    # NOTE(review): read_fwf infers fixed-width column boundaries, so label
    # text containing spaces may be split into extra columns — confirm format.
    frame = pd.read_fwf(label_file, header=None)
    return frame.rename(columns={0: 'file_name', 1: 'text'})


train_df = _read_split_labels('train.txt')
test_df = _read_split_labels('test.txt')
|
|
| |
# Random augmentation applied to training images before the processor.
train_transforms = transforms.Compose(
    [
        # Random brightness and hue jitter.
        transforms.ColorJitter(brightness=0.5, hue=0.3),
        # Gaussian blur with a (5, 9) kernel and sigma sampled from [0.1, 5].
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
    ]
)
|
|
class CustomOCRDataset(Dataset):
    """Image/text pairs for TrOCR fine-tuning.

    Each item is a dict with ``pixel_values`` (processor output for one image)
    and ``labels`` (token ids padded/truncated to ``max_target_length``, with
    padding positions replaced by -100).

    :param root_dir: Directory containing the image files.
    :param df: DataFrame with ``file_name`` and ``text`` columns.
    :param processor: TrOCR processor used for both images and text.
    :param max_target_length: Fixed length of the ``labels`` sequence.
    :param transform: Optional callable applied to each PIL image before the
        processor. Defaults to the training augmentation; pass ``None`` for
        evaluation data so it is not randomly augmented.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128,
                 transform=train_transforms):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup, so a DataFrame with a non-default index still works.
        row = self.df.iloc[idx]
        file_name = row['file_name']
        # read_fwf may infer a numeric dtype for digit-only labels; the
        # tokenizer needs a str.
        text = str(row['text'])

        # The context manager closes the file handle Image.open keeps open;
        # convert() returns an independent copy.
        with Image.open(os.path.join(self.root_dir, file_name)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values

        # padding='max_length' never shortens long texts; truncation=True keeps
        # every label exactly max_target_length long so batches can be stacked.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids
        pad_id = self.processor.tokenizer.pad_token_id
        # -100 marks positions the loss should ignore.
        labels = [-100 if token == pad_id else token for token in labels]
        # Drop only the leading batch dimension added by return_tensors='pt'.
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}


processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'train/train/'),
    df=train_df,
    processor=processor,
)
# Validation images are evaluated as-is, without random augmentation.
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/'),
    df=test_df,
    processor=processor,
    transform=None,
)
|
|
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)

total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Special tokens: decoding starts at the tokenizer's CLS token, pads with PAD
# and ends at SEP. VisionEncoderDecoderModel reads decoder_start_token_id and
# pad_token_id from model.config when it derives decoder inputs from `labels`.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Decoding settings belong in the generation config, which is what generate()
# (and therefore predict_with_generate evaluation) reads. Some transformers
# releases warn about, relocate, or refuse to save generation parameters that
# are set on model.config instead.
model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
model.generation_config.pad_token_id = model.config.pad_token_id
model.generation_config.eos_token_id = model.config.eos_token_id
model.generation_config.max_length = 64
model.generation_config.early_stopping = True
model.generation_config.no_repeat_ngram_size = 3
model.generation_config.length_penalty = 2.0
model.generation_config.num_beams = 4
|
|
# NOTE(review): this optimizer is never handed to the Seq2SeqTrainer (it gets
# no `optimizers=` argument), so it does not drive training; the trainer builds
# its own optimizer from the training arguments.
optimizer = optim.AdamW(
    model.parameters(),
    weight_decay=0.0005,
    lr=TrainingConfig.LEARNING_RATE,
)

# Character error rate metric consumed by compute_cer.
cer_metric = evaluate.load('cer')
|
|
|
|
def compute_cer(pred):
    """Compute the character error rate for a Seq2SeqTrainer evaluation.

    :param pred: EvalPrediction-like object with ``predictions`` (generated
        token ids) and ``label_ids`` (reference ids, -100 at ignored positions).
    :returns: ``{"cer": value}`` where value is what the ``cer`` metric returns.
    """
    pad_id = processor.tokenizer.pad_token_id
    # -100 is not a decodable token id. It marks ignored label positions, and
    # the trainer also uses it to pad predictions of different lengths when
    # concatenating evaluation batches. np.where returns new arrays, so the
    # caller's arrays are not modified.
    pred_ids = np.where(pred.predictions == -100, pad_id, pred.predictions)
    label_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
|
|
training_args = Seq2SeqTrainingArguments(
    # Evaluate with generate() so compute_cer sees decoded sequences.
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    # The trainer receives no optimizer and builds its own from these
    # arguments, so the learning rate and weight decay must be set here.
    learning_rate=TrainingConfig.LEARNING_RATE,
    weight_decay=0.0005,
    fp16=False,
    output_dir='seq2seq_model_printed/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS,
)
|
|
| |
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    compute_metrics=compute_cer,
    # NOTE(review): the image processor is passed as `tokenizer`, presumably so
    # it is saved with each checkpoint — confirm intent.
    tokenizer=processor.image_processor,
)

res = trainer.train()

# Reload the processor and the checkpoint directory named after the final global step.
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained(
    f'seq2seq_model_printed/checkpoint-{res.global_step}'
).to(device)
|
|
def read_and_show(image_path):
    """Load an image from disk as an RGB PIL image.

    Despite its name, this function only reads the image; it does not display it.

    :param image_path: String, path to the input image.

    Returns:
        image: PIL Image in RGB mode.
    """
    # The context manager closes the file handle Image.open keeps open;
    # convert() returns an independent, fully loaded image.
    with Image.open(image_path) as img:
        return img.convert('RGB')
|
|
def ocr(image, processor, model):
    """Recognize the text in a single image.

    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.

    Returns:
        generated_text: the OCR'd text string.
    """
    inputs = processor(image, return_tensors='pt')
    token_ids = model.generate(inputs.pixel_values.to(device))
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    # One input image produces one decoded sequence.
    return decoded[0]
|
|
def eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test', '*'),
    num_samples=50
):
    """Run OCR on images matching a glob pattern and plot each with its prediction.

    :param data_path: Glob pattern selecting the images.
    :param num_samples: Maximum number of images to process.
    """
    # glob returns files in arbitrary order; sort so the sampled subset is reproducible.
    image_paths = sorted(glob.glob(data_path))[:num_samples]
    for image_path in tqdm(image_paths, total=len(image_paths)):
        image = read_and_show(image_path)
        text = ocr(image, processor, trained_model)
        plt.figure(figsize=(7, 4))
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
        plt.show()
        # Release the figure; otherwise every iteration leaves one open.
        plt.close()


eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/', '*'),
    num_samples=100
)
|
|
|
|