| | import torch
|
| | import cv2
|
| | import pytesseract
|
| | from PIL import Image, ImageDraw, ImageFont
|
| | from collections import deque
|
| | import numpy as np
|
| | import os
|
| |
|
| |
|
| |
|
| | def get_full_img_path(src_dir):
|
| | """
|
| | input: Đường dẫn đền folder chứa ảnh
|
| | output: Danh sách tên của tất cả các ảnh
|
| | """
|
| | list_img_names = []
|
| | for dirname, _, filenames in os.walk(src_dir):
|
| | for filename in filenames:
|
| | path = os.path.join(dirname, filename).replace(src_dir, '')
|
| | if path[0] == '/':
|
| | path = path[1:]
|
| | list_img_names.append(path)
|
| | return list_img_names
|
| |
|
| |
|
| | def create_text_mask(src_img, detect_text_model, kernel_size=5, iterations=3):
|
| | """
|
| | input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
|
| | output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]
|
| | """
|
| | img = torch.from_numpy(src_img).to(torch.uint8).to(detect_text_model.device)
|
| | imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
|
| |
|
| | detect_text_model.eval()
|
| | with torch.no_grad():
|
| | result = detect_text_model(imgT).squeeze()
|
| | result = (result >= 0.5).detach().cpu().numpy()
|
| |
|
| | mask = ((1-result) * 255).astype(np.uint8)
|
| |
|
| | kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
| | mask = cv2.erode(mask, kernel, iterations=iterations)
|
| | mask = cv2.dilate(mask, kernel, iterations=2*iterations)
|
| | mask = cv2.erode(mask, kernel, iterations=iterations)
|
| |
|
| | mask = (1 - mask // 255).astype(np.uint8)
|
| | return mask
|
| |
|
| |
|
| | def create_wordball_mask(src_img, detect_wordball_model, kernel_size=5, iterations=3):
|
| | """
|
| | input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
|
| | output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]
|
| | """
|
| | img = torch.from_numpy(src_img).to(torch.uint8).to(detect_wordball_model.device)
|
| | imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
|
| |
|
| | detect_wordball_model.eval()
|
| | with torch.no_grad():
|
| | result = detect_wordball_model(imgT).squeeze()
|
| | result = (result >= 0.5).detach().cpu().numpy()
|
| |
|
| | mask = ((1-result) * 255).astype(np.uint8)
|
| |
|
| | kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
| | mask = cv2.erode(mask, kernel, iterations=iterations)
|
| | mask = cv2.dilate(mask, kernel, iterations=2*iterations)
|
| | mask = cv2.erode(mask, kernel, iterations=iterations)
|
| |
|
| | mask = (1 - mask // 255).astype(np.uint8)
|
| | return mask
|
| |
|
| |
|
| | def clear_text(src_img, text_msk, wordball_msk, text_value=0, non_text_value=1, r=5):
|
| | """
|
| | input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
|
| | text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]
|
| | text_value: Giá trị mà trong mặt nạ nó là text
|
| | non_text_value: Giá trị mà trong mặt nạ nó là nền
|
| | r: Bán kính để sử dụng cho việc xoá text và vẽ lại phần bị xoá
|
| | output: Ảnh sau khi xoá text, để dưới định dạng là np.array, shape: [H, W, C]
|
| | """
|
| | MAX = max(text_value, non_text_value)
|
| | MIN = min(text_value, non_text_value)
|
| |
|
| | scale_text_value = (text_value - MIN) / (MAX - MIN)
|
| | scale_non_text_value = (non_text_value - MIN) / (MAX - MIN)
|
| |
|
| | text_msk[text_msk==text_value] = scale_text_value
|
| | text_msk[text_msk==non_text_value] = scale_non_text_value
|
| |
|
| | wordball_msk[wordball_msk==text_value] = scale_text_value
|
| | wordball_msk[wordball_msk==non_text_value] = scale_non_text_value
|
| |
|
| | if scale_text_value == 0:
|
| | text_msk = 1 - text_msk
|
| | wordball_msk = 1 - wordball_msk
|
| | text_msk = text_msk * 255
|
| |
|
| | remove_txt = cv2.inpaint(src_img, text_msk, r, cv2.INPAINT_TELEA)
|
| | remove_wordball = remove_txt.copy()
|
| | remove_wordball[wordball_msk==1] = 255
|
| |
|
| | return remove_wordball
|
| |
|
| |
|
| | def dfs(grid, y, x, visited, value):
|
| | """
|
| | Thuật toán tìm miền liên thông, xem thêm về đồ thị nếu không biết nó là gì
|
| | Output: Một HCN bao phủ miền liên thông + Diện tích của miền liên thông
|
| | """
|
| | max_y, max_x = y, x
|
| | min_y, min_x = y+1, x+1
|
| | area = 0
|
| |
|
| | stack = deque([(y, x)])
|
| | while stack:
|
| | y, x = stack.pop()
|
| |
|
| | max_x = max(max_x, x)
|
| | max_y = max(max_y, y)
|
| | min_x = min(min_x, x)
|
| | min_y = min(min_y, y)
|
| |
|
| | if (y, x) not in visited:
|
| | visited.add((y, x))
|
| | area += 1
|
| |
|
| | for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
|
| | nx, ny = x + dx, y + dy
|
| | if 0 <= ny < grid.shape[0] and 0 <= nx < grid.shape[1] and grid[ny, nx] == value and (ny, nx) not in visited:
|
| | stack.append((ny, nx))
|
| |
|
| | return (min_x, min_y, max_x, max_y), area
|
| |
|
| |
|
| | def find_clusters(grid, value):
|
| | """
|
| | Thuật toán tìm danh sách các miền liên thông
|
| | """
|
| | visited = set()
|
| | clusters = []
|
| | areas = []
|
| |
|
| | for y in range(grid.shape[0]):
|
| | for x in range(grid.shape[1]):
|
| | if grid[y, x] == value and (y, x) not in visited:
|
| | cluster, area = dfs(grid, y, x, visited, value)
|
| | clusters.append(cluster)
|
| | areas.append(area)
|
| |
|
| | return clusters, areas
|
| |
|
| | def get_text_positions(text_msk, text_value=0):
|
| | """
|
| | input: text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]
|
| | text_value: Giá trị mà trong mặt nạ nó là text
|
| | min_area: Giả trị tối thiểu của vùng có thể có text
|
| | output: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)
|
| | """
|
| |
|
| | clusters, areas = find_clusters(text_msk, value=text_value)
|
| | return clusters, areas
|
| |
|
| | def filter_text_positions(clusters, areas, min_area=1200, max_area=10000):
|
| | clusters = clusters[(areas >= min_area) & (areas <= max_area)]
|
| | return clusters
|
| |
|
| |
|
| | def get_list_texts(src_img, text_positions, lang='eng'):
|
| | """
|
| | input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
|
| | text_positions: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)
|
| | lang: Ngôn ngữ của text
|
| | output: Danh sách các câu text
|
| | """
|
| | list_texts = []
|
| | for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
|
| | crop_img = src_img[min_y:max_y+1, min_x:max_x+1]
|
| | img_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
|
| | img = Image.fromarray(img_rgb)
|
| | text = pytesseract.image_to_string(img, lang=lang).replace('\n', ' ').strip()
|
| | while ' ' in text:
|
| | text = text.replace(' ', ' ')
|
| | list_texts.append(text)
|
| | return list_texts
|
| |
|
| |
|
| | def translate(list_texts, translator):
|
| | translated_texts = []
|
| | for text in list_texts:
|
| | if not text:
|
| | text = 'a'
|
| | translated_text = translator.translate(text, src='en', dest='vi').text
|
| | translated_texts.append(translated_text)
|
| | return translated_texts
|
| |
|
| |
|
| | def add_centered_multiline_text(image, text, box, font_path="arial.ttf", font_size=36, pad=5, text_color=0):
|
| |
|
| | draw = ImageDraw.Draw(image)
|
| |
|
| |
|
| | min_x, min_y, max_x, max_y = box
|
| |
|
| |
|
| | font = ImageFont.truetype(font_path, font_size)
|
| |
|
| |
|
| | wrapped_lines = wrap_text(text, font, draw, max_x - min_x)
|
| |
|
| |
|
| | total_text_height = sum(get_text_height(line, draw, font) for line in wrapped_lines)
|
| |
|
| |
|
| | start_y = min_y + (max_y - min_y - total_text_height) // 2
|
| |
|
| |
|
| | current_y = start_y
|
| | for line in wrapped_lines:
|
| | text_width, text_height = get_text_dimensions(line, draw, font)
|
| | text_x = min_x + (max_x - min_x - text_width) // 2
|
| | draw.text((text_x, current_y), line, fill=text_color, font=font)
|
| | current_y += text_height + pad
|
| |
|
| |
|
| | return image
|
| |
|
| | def get_text_dimensions(text, draw, font):
|
| | """Trả về (width, height) của văn bản."""
|
| | bbox = draw.textbbox((0, 0), text, font=font)
|
| | width = bbox[2] - bbox[0]
|
| | height = bbox[3] - bbox[1]
|
| | return width, height
|
| |
|
| | def get_text_height(text, draw, font):
|
| | """Trả về chiều cao của văn bản."""
|
| | _, _, _, height = draw.textbbox((0, 0), text, font=font)
|
| | return height
|
| |
|
| | def wrap_text(text, font, draw, max_width):
|
| | """Chia văn bản thành nhiều dòng dựa trên chiều rộng tối đa."""
|
| | words = text.split()
|
| | lines = []
|
| | current_line = ""
|
| |
|
| | for word in words:
|
| |
|
| | test_line = f"{current_line} {word}".strip()
|
| | test_width, _ = get_text_dimensions(test_line, draw, font)
|
| |
|
| | if test_width <= max_width:
|
| | current_line = test_line
|
| | else:
|
| |
|
| | lines.append(current_line)
|
| | current_line = word
|
| |
|
| |
|
| | if current_line:
|
| | lines.append(current_line)
|
| |
|
| | return lines
|
| |
|
| | def insert_text(non_text_src_img, list_translated_texts, text_positions, font=['MTO Astro City.ttf'], font_size=[20], pad=[5], text_color=0, stroke=[3]):
|
| |
|
| | img_bgr = non_text_src_img.copy()
|
| |
|
| |
|
| | for idx, text in enumerate(list_translated_texts):
|
| |
|
| | mask1 = Image.new("L", img_bgr.shape[:2][::-1], 255)
|
| | mask2 = Image.new("L", img_bgr.shape[:2][::-1], 255)
|
| | mask1 = add_centered_multiline_text(mask1, text, text_positions[idx], f'MTO Font/{font[idx]}', font_size[idx], pad=pad[idx], text_color=text_color)
|
| |
|
| |
|
| | mask1 = (np.array(mask1) >= 127).astype(np.uint8) * 255
|
| | mask1 = cv2.cvtColor(mask1, cv2.COLOR_RGB2BGR)
|
| |
|
| | if stroke[idx] > 0:
|
| | mask2 = np.array(mask2).astype(np.uint8)
|
| | mask2 = cv2.cvtColor(mask2, cv2.COLOR_RGB2BGR)
|
| |
|
| | mask2 = mask2 - mask1
|
| | kernel = np.ones((stroke[idx]+1, stroke[idx]+1), np.uint8)
|
| | mask2 = cv2.dilate(mask2, kernel, iterations=1)
|
| | img_bgr[mask2==255] = 255
|
| |
|
| | img_bgr[mask1==text_color] = text_color
|
| | return img_bgr
|
| |
|
| |
|
| | def save_img(path, translated_text_src_img):
|
| | """
|
| | input: path: Đường dẫn đến ảnh gốc ban đầu
|
| | translated_text_src_img: Ảnh sau khi được dịch
|
| | output: Ảnh sau dịch được lưu lại, trong tên có thêm "translated-"
|
| | """
|
| | dot = path.rfind('.')
|
| | last_slash = -1
|
| | if '/' in path:
|
| | last_slash = path.rfind('/')
|
| |
|
| | ext = path[dot:]
|
| | parent_path = path[:last_slash+1]
|
| | name = path[last_slash+1:dot]
|
| |
|
| | if parent_path and not os.path.exists(parent_path):
|
| | os.mkdir(parent_path)
|
| | cv2.imwrite(f'{parent_path}translated-{name}{ext}', translated_text_src_img)
|
| | print(f'Image saved at {parent_path}translated-{name}{ext}') |