"""
Example usage of EasyOCR ONNX models for text detection and recognition.
"""
|
|
import argparse
import os
import sys
from typing import List, Optional, Sequence

import cv2
import numpy as np
import onnxruntime as ort
|
|
class EasyOCR_ONNX:
    """ONNX implementation of EasyOCR for text detection and recognition.

    Pipeline: CRAFT detector (``detect_text``) -> simple contour-based region
    extraction (``extract_simple_regions``) -> CRNN-style recognizer decoded
    with greedy CTC (``recognize_text``).
    """

    # All 95 printable ASCII characters; shared by both built-in charsets.
    _BASIC_CHARSET = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '

    # Accented Latin letters, built from code points so the source file's
    # encoding cannot corrupt them (an earlier literal contained mojibake):
    #   U+00E0..U+00FF without U+00F7 (÷):  àáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ
    #   U+0100..U+011B:                    ĀāĂăĄąĆćĈĉĊċČčĎďĐđĒēĔĕĖėĘęĚě
    _LATIN_EXTENSION = (
        "".join(chr(cp) for cp in range(0x00E0, 0x0100) if cp != 0x00F7)
        + "".join(chr(cp) for cp in range(0x0100, 0x011C))
    )

    def __init__(self,
                 detector_path: str = "craft_mlt_25k_jpqd.onnx",
                 recognizer_path: str = "english_g2_jpqd.onnx",
                 providers: Optional[Sequence[str]] = None,
                 charset: Optional[str] = None):
        """
        Initialize EasyOCR ONNX models.

        Args:
            detector_path: Path to CRAFT detection model
            recognizer_path: Path to text recognition model
            providers: onnxruntime execution providers in priority order.
                Defaults to every provider this onnxruntime build reports
                as available (``ort.get_available_providers()``).
            charset: Characters of the recognizer in output-class order
                (class 0 is the CTC blank, class ``i`` is ``charset[i - 1]``).
                When omitted it is guessed from ``recognizer_path``: the
                extended Latin set if the path contains "latin" but not
                "english", otherwise the basic English set.
        """
        if providers is None:
            # GPU-enabled onnxruntime builds raise if no provider list is
            # given; ask for all available ones in default precedence.
            providers = ort.get_available_providers()

        print(f"Loading detector: {detector_path}")
        self.detector = ort.InferenceSession(detector_path, providers=list(providers))

        print(f"Loading recognizer: {recognizer_path}")
        self.recognizer = ort.InferenceSession(recognizer_path, providers=list(providers))

        self.english_charset = self._BASIC_CHARSET
        self.latin_charset = self._get_latin_charset()

        if charset is not None:
            self.charset = charset
        else:
            # The file path is the only hint about the model's alphabet;
            # "english" wins when both words appear.
            lowered = recognizer_path.lower()
            if "latin" in lowered and "english" not in lowered:
                self.charset = self.latin_charset
            else:
                self.charset = self.english_charset
        # NOTE(review): the class order of these built-in charsets is assumed
        # to match the exported recognizer's character list — confirm against
        # the model's training config, or pass ``charset`` explicitly.

    def _get_latin_charset(self) -> str:
        """Return the printable-ASCII set followed by accented Latin letters."""
        return self._BASIC_CHARSET + self._LATIN_EXTENSION

    def preprocess_for_detection(self, image: np.ndarray, target_size: int = 640) -> np.ndarray:
        """Preprocess image for CRAFT text detection.

        Args:
            image: RGB image (H, W, 3). Grayscale (H, W) / (H, W, 1) and RGBA
                (H, W, 4) inputs are converted to RGB first.
            target_size: Side of the square network input; the image is
                stretched to it without preserving aspect ratio.

        Returns:
            float32 array of shape (1, 3, target_size, target_size), pixel
            values divided by 255 (i.e. in [0, 1] for uint8 input).
        """
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        resized = cv2.resize(image, (target_size, target_size))
        # NOTE(review): upstream EasyOCR feeds CRAFT ImageNet mean/variance
        # normalized input rather than plain /255 — confirm what this
        # exported model expects.
        scaled = resized.astype(np.float32) / 255.0
        chw = np.transpose(scaled, (2, 0, 1))
        return np.ascontiguousarray(chw[np.newaxis])

    def preprocess_for_recognition(self, text_region: np.ndarray,
                                   width: int = 100, height: int = 32) -> np.ndarray:
        """Preprocess text region for CRNN recognition.

        Args:
            text_region: Cropped region, RGB (H, W, 3), RGBA (H, W, 4) or
                grayscale (H, W) / (H, W, 1).
            width: Target width of the recognizer input.
            height: Target height of the recognizer input.

        Returns:
            float32 array of shape (1, 1, height, width), values divided by 255.
        """
        if text_region.ndim == 2:
            gray = text_region
        elif text_region.shape[2] == 1:
            gray = text_region[:, :, 0]
        elif text_region.shape[2] == 4:
            gray = cv2.cvtColor(text_region, cv2.COLOR_RGBA2GRAY)
        else:
            gray = cv2.cvtColor(text_region, cv2.COLOR_RGB2GRAY)

        # NOTE(review): EasyOCR's g2 recognizers are usually run at 64 px
        # height with [-1, 1] normalization — confirm the exported model's
        # expected input size/range and adjust width/height accordingly.
        resized = cv2.resize(gray, (width, height))
        normalized = resized.astype(np.float32) / 255.0
        return normalized[np.newaxis, np.newaxis]

    def detect_text(self, image: np.ndarray) -> np.ndarray:
        """
        Detect text regions in image using CRAFT model.

        Args:
            image: Input image (RGB format)

        Returns:
            The detector's first output tensor, as a numpy array.
        """
        # Use the model's declared input name instead of assuming "input".
        input_name = self.detector.get_inputs()[0].name
        outputs = self.detector.run(None, {input_name: self.preprocess_for_detection(image)})
        return np.asarray(outputs[0])

    def recognize_text(self, text_regions: List[np.ndarray]) -> List[str]:
        """
        Recognize text in detected regions.

        Args:
            text_regions: List of cropped text region images

        Returns:
            List of recognized text strings, one per region, in input order.
        """
        input_name = self.recognizer.get_inputs()[0].name
        results = []
        for region in text_regions:
            batch = self.preprocess_for_recognition(region)
            logits = np.asarray(self.recognizer.run(None, {input_name: batch})[0])
            results.append(self._decode_text(logits))
        return results

    def _decode_text(self, output: np.ndarray) -> str:
        """Greedy CTC decoding of one recognizer output.

        Takes the argmax class per time step of ``output[0]`` (only the first
        batch item is decoded; assumes (batch, time, classes) — confirm for
        other exports). Consecutive identical classes are collapsed, then the
        blank (class 0) is dropped, so a letter repeated across a blank
        ("l", blank, "l") is kept twice. Class ``i`` maps to
        ``self.charset[i - 1]``; classes beyond the charset are ignored.
        Leading/trailing whitespace is stripped from the result.
        """
        indices = np.argmax(output[0], axis=-1).tolist()
        n_chars = len(self.charset)

        chars = []
        prev = -1
        for idx in indices:
            if idx != prev and 0 < idx <= n_chars:
                chars.append(self.charset[idx - 1])
            # Track every class, blanks included: a blank must break a run
            # of repeats.
            prev = idx
        return "".join(chars).strip()

    @staticmethod
    def _score_map(detection_output: np.ndarray) -> np.ndarray:
        """Return the 2-D region-score map from a CRAFT output tensor.

        The reference CRAFT network emits its maps channels-last,
        (1, H/2, W/2, 2), but a channels-first (1, C, H, W) export is also
        accepted. A 4-D tensor whose last axis has at most 2 entries and
        whose second axis is larger is treated as channels-last. For
        channels-last, channel 0 is taken. Otherwise ``[0, 0]`` is taken.
        3-D input yields ``[0]`` and anything else is returned unchanged.
        """
        out = np.asarray(detection_output)
        if out.ndim == 4:
            if out.shape[-1] <= 2 < out.shape[1]:
                return out[0, :, :, 0]
            return out[0, 0]
        if out.ndim == 3:
            return out[0]
        return out

    @staticmethod
    def _fallback_regions(image: np.ndarray, window_h: int = 32, window_w: int = 100,
                          max_regions: int = 4) -> List[np.ndarray]:
        """Tile ``image`` with fixed windows and return the first non-blank ones.

        Windows of ``window_h`` x ``window_w`` pixels are placed on a grid
        whose step is a quarter of the image height/width. A window is kept
        when its mean value is below 240 (not nearly white, for 8-bit
        images). At most ``max_regions`` windows are returned, in row-major
        order.
        """
        h, w = image.shape[:2]
        # A zero step (images under 4 px) would make range() raise.
        step_y, step_x = max(1, h // 4), max(1, w // 4)

        regions = []
        # "+ 1" so a window ending exactly at the image border is tried too.
        for y in range(0, h - window_h + 1, step_y):
            for x in range(0, w - window_w + 1, step_x):
                region = image[y:y + window_h, x:x + window_w]
                if region.size > 0 and np.mean(region) < 240:
                    regions.append(region)
                    if len(regions) >= max_regions:
                        return regions
        return regions

    def extract_simple_regions(self, detection_output: np.ndarray,
                               original_image: np.ndarray,
                               threshold: float = 0.3) -> List[np.ndarray]:
        """
        Extract text regions from detection output (simplified version).
        In practice, you'd implement proper CRAFT post-processing.

        The region-score map is scaled to [0, 1] if needed, thresholded,
        resized to the original image and closed with a 3x3 kernel. The
        bounding box of each sufficiently large contour (width > 15,
        height > 8, area > 100) is padded by 2 px and cropped. If nothing
        qualifies, a fixed-window tiling fallback is used.

        Args:
            detection_output: Raw output of ``detect_text``.
            original_image: Image the crops are taken from.
            threshold: Score threshold for the binary text mask.

        Returns:
            List of cropped regions (views into ``original_image``).
        """
        h, w = original_image.shape[:2]

        score_map = self._score_map(detection_output)
        peak = score_map.max()
        if peak > 1.0:
            score_map = score_map / peak

        binary_map = (score_map > threshold).astype(np.uint8) * 255
        binary_map = cv2.resize(binary_map, (w, h))

        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        binary_map = cv2.morphologyEx(binary_map, cv2.MORPH_CLOSE, kernel)

        # [-2] selects the contour list under both OpenCV 3.x (3 return
        # values) and 4.x (2 return values).
        contours = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        text_regions = []
        for contour in contours:
            x, y, box_w, box_h = cv2.boundingRect(contour)
            if box_w <= 15 or box_h <= 8 or cv2.contourArea(contour) <= 100:
                continue

            # Pad by 2 px on each side, clipped to the image.
            x, y = max(0, x - 2), max(0, y - 2)
            box_w, box_h = min(w - x, box_w + 4), min(h - y, box_h + 4)

            region = original_image[y:y + box_h, x:x + box_w]
            if region.size > 0:
                text_regions.append(region)

        if not text_regions:
            print(" No CRAFT regions found, using fallback method...")
            text_regions = self._fallback_regions(original_image)

        return text_regions
|
|
|
|
def main() -> int:
    """Command-line entry point: detect and recognize text in one image.

    Returns:
        Process exit status: 0 on success, 1 when an input file is missing,
        the image cannot be decoded, or the output image cannot be written.
    """
    parser = argparse.ArgumentParser(description="EasyOCR ONNX Example")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--detector", type=str, default="craft_mlt_25k_jpqd.onnx",
                        help="Path to detection model")
    parser.add_argument("--recognizer", type=str, default="english_g2_jpqd.onnx",
                        help="Path to recognition model")
    parser.add_argument("--output", type=str,
                        help="Path to save a copy of the input image (no boxes are drawn)")
    args = parser.parse_args()

    required_files = (
        (args.image, "Image file"),
        (args.detector, "Detector model"),
        (args.recognizer, "Recognizer model"),
    )
    for path, label in required_files:
        if not os.path.exists(path):
            print(f"Error: {label} not found: {path}", file=sys.stderr)
            return 1

    # Decode the image before loading the models, so an unreadable file
    # fails without paying for model initialization.
    print(f"Loading image: {args.image}")
    image = cv2.imread(args.image)
    if image is None:
        print(f"Error: Could not load image: {args.image}", file=sys.stderr)
        return 1

    print("Initializing EasyOCR ONNX...")
    ocr = EasyOCR_ONNX(args.detector, args.recognizer)

    # cv2.imread yields BGR; the OCR pipeline is documented as taking RGB.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    print("Detecting text regions...")
    detection_output = ocr.detect_text(image_rgb)

    text_regions = ocr.extract_simple_regions(detection_output, image_rgb)
    print(f"Found {len(text_regions)} text regions")

    if text_regions:
        print("Recognizing text...")
        recognized_texts = ocr.recognize_text(text_regions)

        print(f"\nRecognized text ({len(recognized_texts)} regions):")
        print("-" * 50)
        for number, text in enumerate(recognized_texts, start=1):
            print(f"Region {number}: '{text}'")
    else:
        print("No text regions detected")

    if args.output and text_regions:
        # extract_simple_regions returns crops only (no coordinates), so
        # there is nothing to draw: this writes an unannotated copy.
        try:
            written = cv2.imwrite(args.output, image)
        except cv2.error:
            # Newer OpenCV raises (rather than returning False) for an
            # unsupported file extension.
            written = False
        if not written:
            print(f"Error: Could not write output image: {args.output}", file=sys.stderr)
            return 1
        print(f"Output saved to: {args.output}")

    return 0
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s status (None is treated as 0) as the exit code.
    raise SystemExit(main())