| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from transformers import BertTokenizer |
|
|
| from src.models import BertForPunctuation |
|
|
| PUNCTUATION_SIGNS = ['', ',', '.', '?'] |
| PAUSE_TOKEN = 0 |
| MODEL_NAME = "verbit/hebrew_punctuation" |
|
|
|
|
def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Tokenizes words into sub-word token ids and aligns the pause list with them.

    A word may be split into several sub-word tokens; the pause that follows a
    word is attached to its *last* sub-word token, while the preceding sub-word
    tokens get a pause of 0.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds
        tokenizer: tokenizer used to split words into sub-word token ids

    Returns:
        original_word_idx: index into the token list of the last sub-word token
            of each original word (used later to pick one prediction per word)
        x: list of indexed words (token ids)
        pause: pause value aligned with every token id in ``x``

    Raises:
        ValueError: if word_list and pause_list differ in length
    """
    # An assert would be stripped under `python -O`; validate explicitly.
    if len(word_list) != len(pause_list):
        raise ValueError("word_list and pause_list should have the same length")

    x: List[int] = []
    pause: List[float] = []
    original_word_idx: List[int] = []

    for word, word_pause in zip(word_list, pause_list):
        tokens = tokenizer.tokenize(word)
        # Fall back to token id 0 when the tokenizer yields nothing for a word,
        # so every word still occupies at least one position.
        token_ids = tokenizer.convert_tokens_to_ids(tokens) if tokens else [0]

        # Zero-pause padding: only the word's last sub-word carries the pause.
        token_pauses = [0] * (len(token_ids) - 1) + [word_pause]

        x += token_ids
        original_word_idx.append(len(x) - 1)
        pause += token_pauses

    return original_word_idx, x, pause
|
|
|
|
def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
) -> torch.Tensor:
    """
    Builds one fixed-size context window per token for the model.

    Each window holds ``backward_context`` tokens before the target token, the
    target token itself, a pause token inserted right after it, and
    ``forward_context`` following tokens; the sequence is zero-padded at both
    ends so edge tokens get full-size windows.

    Args:
        x: list of indexed words
        pause: list of corresponding pauses
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the predicted token)

    Returns:
        A tensor of model inputs, one row per indexed word in x
    """
    pause_tokens = [PAUSE_TOKEN] * len(pause)
    padded = [0] * backward_context + x + [0] * forward_context
    window = backward_context + forward_context + 1

    inputs = []
    for pos in range(len(x)):
        row = padded[pos : pos + window]
        # Pause marker goes immediately after the predicted token.
        row.insert(backward_context + 1, pause_tokens[pos])
        inputs.append(row)
    return torch.tensor(inputs)
|
|
|
|
def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Appends the most likely punctuation sign to every word of the text.

    Args:
        text: text to insert punctuation into
        punct_prob: matrix of probabilities, one row per word, one column per
            punctuation sign

    Returns:
        text with punctuation
    """
    best_sign_idx = np.argmax(punct_prob, axis=1)
    # The empty sign ('') leaves a word unchanged, so unconditional
    # concatenation also covers the "no punctuation" case.
    punctuated_words = [
        word + PUNCTUATION_SIGNS[sign] for word, sign in zip(text.split(), best_sign_idx)
    ]
    return ' '.join(punctuated_words)
|
|
|
|
def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
) -> str:
    """
    Predicts punctuation for the given text and returns the punctuated text.

    Args:
        model: punctuation model
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: batch size
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds
        device: device to run model on

    Returns:
        text with punctuation
    """
    words = text.split()
    # Without timing information, assume a zero pause after every word.
    if not pause_list:
        pause_list = [0.0] * len(words)

    word_idx, token_ids, pauses = tokenize_text(
        word_list=words, pause_list=pause_list, tokenizer=tokenizer
    )

    inputs = gen_model_inputs(token_ids, pauses, forward_context, backward_context)
    # Keep only the windows centred on the last sub-word token of each word,
    # so we get exactly one prediction per original word.
    inputs = inputs.index_select(0, torch.LongTensor(word_idx)).to(device)

    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            # Slicing clamps at the end, so no explicit min() is needed.
            logits = model(inputs[start : start + batch_size])
            probs = F.softmax(logits, dim=1)
            batch_probs.append(probs.cpu().data.numpy())

    punct_prob_matrix = np.concatenate(batch_probs, axis=0)
    return add_punctuation_to_text(text, punct_prob_matrix)
|
|
|
|
def main():
    """Loads the pretrained punctuation model and prints a punctuated sample text."""
    model = BertForPunctuation.from_pretrained(MODEL_NAME)
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    # Inference only: disable dropout etc.
    model.eval()

    text = """讞讘专转 讜专讘讬讟 驻讬转讞讛 诪注专讻转 诇转诪诇讜诇 讛诪讘讜住住转 注诇 讘讬谞讛 诪诇讗讻讜转讬转 讜讙讜专诐 讗谞讜砖讬 讜砖讜拽讚转 注诇 转诪诇讜诇 注讚讜讬讜转 谞讬爪讜诇讬 砖讜讗讛
讗转 讛转讜爪讗讜转 讗驻砖专 诇专讗讜转 讻讘专 讘专砖转 讘讛谉 讞诇拽讬诐 诪注讚讜转讜 砖诇 讟讜讘讬讛 讘讬讬诇住拽讬 砖讛讬讛 诪驻拽讚 讙讚讜讚 讛驻专讟讬讝谞讬诐 讛讬讛讜讚讬诐 讘讘讬讬诇讜专讜住讬讛"""
    # Context window sizes come from the model's own config so they match training.
    punctuated = get_prediction(
        model=model,
        text=text,
        tokenizer=tokenizer,
        backward_context=model.config.backward_context,
        forward_context=model.config.forward_context,
    )
    print(punctuated)


if __name__ == "__main__":
    main()
|
|