| | from bidict import bidict |
| | import pickle |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | class Vocabulary(): |
| | """This class maps strings to integers, which also allow many namespaces |
| | """ |
| |
|
| | DEFAULT_PAD_TOKEN = '*@PAD@*' |
| | DEFAULT_UNK_TOKEN = '*@UNK@*' |
| |
|
| | def __init__(self, |
| | counters=dict(), |
| | min_count=dict(), |
| | pretrained_vocab=dict(), |
| | intersection_namespace=dict(), |
| | no_pad_namespace=list(), |
| | no_unk_namespace=list(), |
| | contain_pad_namespace=dict(), |
| | contain_unk_namespace=dict()): |
| | """initialize vocabulary |
| | |
| | Keyword Arguments: |
| | counters {dict} -- multiple counter (default: {dict()}) |
| | min_count {dict} -- min count dict (default: {dict()}) |
| | pretrained_vocab {dict} -- pretrained vocabulary (default: {dict()}) |
| | intersection_namespace {dict} -- intersection namespace correspond to pretrained vocabulary in case of too large pretrained vocabulary (default: {dict()}) |
| | no_pad_namespace {list} -- no paddding namespace (default: {list()}) |
| | no_unk_namespace {list} -- no unknown namespace (default: {list()}) |
| | contain_pad_namespace {dict} -- contain padding token namespace (default: {dict()}) |
| | contain_unk_namespace {dict} -- contain unknown token namespace (default: {dict()}) |
| | """ |
| |
|
| | self.min_count = dict(min_count) |
| | self.intersection_namespace = dict(intersection_namespace) |
| | self.no_pad_namespace = set(no_pad_namespace) |
| | self.no_unk_namespace = set(no_unk_namespace) |
| | self.contain_pad_namespace = dict(contain_pad_namespace) |
| | self.contain_unk_namespace = dict(contain_unk_namespace) |
| | self.vocab = dict() |
| |
|
| | self.extend_from_counter(counters, self.min_count, self.no_pad_namespace, |
| | self.no_unk_namespace) |
| |
|
| | self.extend_from_pretrained_vocab(pretrained_vocab, self.intersection_namespace, |
| | self.no_pad_namespace, self.no_unk_namespace) |
| |
|
| | logger.info("Initialize vocabulary successfully.") |
| |
|
| | def extend_from_pretrained_vocab(self, |
| | pretrained_vocab, |
| | intersection_namespace=dict(), |
| | no_pad_namespace=list(), |
| | no_unk_namespace=list(), |
| | contain_pad_namespace=dict(), |
| | contain_unk_namespace=dict()): |
| | """extend vocabulary from pretrained vocab |
| | |
| | Arguments: |
| | pretrained_vocab {dict} -- pretrained vocabulary |
| | |
| | Keyword Arguments: |
| | intersection_namespace {dict} -- intersection namespace correspond to pretrained vocabulary in case of too large pretrained vocabulary (default: {dict()}) |
| | no_pad_namespace {list} -- no paddding namespace (default: {list()}) |
| | no_unk_namespace {list} -- no unknown namespace (default: {list()}) |
| | contain_pad_namespace {dict} -- contain padding token namespace (default: {dict()}) |
| | contain_unk_namespace {dict} -- contain unknown token namespace (default: {dict()}) |
| | """ |
| |
|
| | self.intersection_namespace.update(dict(intersection_namespace)) |
| | self.no_pad_namespace.update(set(no_pad_namespace)) |
| | self.no_unk_namespace.update(set(no_unk_namespace)) |
| | self.contain_pad_namespace.update(dict(contain_pad_namespace)) |
| | self.contain_unk_namespace.update(dict(contain_unk_namespace)) |
| |
|
| | for namespace, vocab in pretrained_vocab.items(): |
| | self.__namespace_init(namespace) |
| | is_intersection = namespace in self.intersection_namespace |
| | intersection_vocab = self.vocab[ |
| | self.intersection_namespace[namespace]] if is_intersection else [] |
| | for key, value in vocab.items(): |
| | if not is_intersection or key in intersection_vocab: |
| | self.vocab[namespace][key] = value |
| |
|
| | logger.info( |
| | "Vocabulay {} (size: {}) was constructed successfully from pretrained_vocab.". |
| | format(namespace, len(self.vocab[namespace]))) |
| |
|
| | def extend_from_counter(self, |
| | counters, |
| | min_count=dict(), |
| | no_pad_namespace=list(), |
| | no_unk_namespace=list(), |
| | contain_pad_namespace=dict(), |
| | contain_unk_namespace=dict()): |
| | """extend vocabulary from counter |
| | |
| | Arguments: |
| | counters {dict} -- multiply counter |
| | |
| | Keyword Arguments: |
| | min_count {dict} -- min count dict (default: {dict()}) |
| | no_pad_namespace {list} -- no paddding namespace (default: {list()}) |
| | no_unk_namespace {list} -- no unknown namespace (default: {list()}) |
| | contain_pad_namespace {dict} -- contain padding token namespace (default: {dict()}) |
| | contain_unk_namespace {dict} -- contain unknown token namespace (default: {dict()}) |
| | """ |
| |
|
| | self.no_pad_namespace.update(set(no_pad_namespace)) |
| | self.no_unk_namespace.update(set(no_unk_namespace)) |
| | self.contain_pad_namespace.update(dict(contain_pad_namespace)) |
| | self.contain_unk_namespace.update(dict(contain_unk_namespace)) |
| | self.min_count.update(dict(min_count)) |
| |
|
| | for namespace, counter in counters.items(): |
| | self.__namespace_init(namespace) |
| | for key in counter: |
| | minc = min_count[namespace] \ |
| | if min_count and namespace in min_count else 1 |
| | if counter[key] >= minc: |
| | self.vocab[namespace][key] = len(self.vocab[namespace]) |
| |
|
| | logger.info("Vocabulay {} (size: {}) was constructed successfully from counter.".format( |
| | namespace, len(self.vocab[namespace]))) |
| |
|
| | def add_tokens_to_namespace(self, tokens, namespace): |
| | """This function adds tokens to one namespace for extending vocabulary |
| | |
| | Arguments: |
| | tokens {list} -- token list |
| | namespace {str} -- namespace name |
| | """ |
| |
|
| | if namespace not in self.vocab: |
| | self.__namespace_init(namespace) |
| | logger.error('Add Namespace {} into vocabulary.'.format(namespace)) |
| |
|
| | for token in tokens: |
| | if token not in self.vocab[namespace]: |
| | self.vocab[namespace][token] = len(self.vocab[namespace]) |
| |
|
| | def get_token_index(self, token, namespace): |
| | """This function gets token index in one namespace of vocabulary |
| | |
| | Arguments: |
| | token {str} -- token |
| | namespace {str} -- namespace name |
| | |
| | Raises: |
| | RuntimeError: namespace not exists |
| | |
| | Returns: |
| | int -- token index |
| | """ |
| |
|
| | if token in self.vocab[namespace]: |
| | return self.vocab[namespace][token] |
| |
|
| | if namespace not in self.no_unk_namespace: |
| | return self.get_unknown_index(namespace) |
| |
|
| | logger.error("Can not find the index of {} from a no unknown token namespace {}.".format( |
| | token, namespace)) |
| | raise RuntimeError( |
| | "Can not find the index of {} from a no unknown token namespace {}.".format( |
| | token, namespace)) |
| |
|
| | def get_token_from_index(self, index, namespace): |
| | """This function gets token using index in vocabulary |
| | |
| | Arguments: |
| | index {int} -- index |
| | namespace {str} -- namespace name |
| | |
| | Raises: |
| | RuntimeError: index out of range |
| | |
| | Returns: |
| | str -- token |
| | """ |
| |
|
| | if index < len(self.vocab[namespace]): |
| | return self.vocab[namespace].inv[index] |
| |
|
| | logger.error("The index {} is out of vocabulary {} range.".format(index, namespace)) |
| | raise RuntimeError("The index {} is out of vocabulary {} range.".format(index, namespace)) |
| |
|
| | def get_vocab_size(self, namespace): |
| | """This function gets the size of one namespace in vocabulary |
| | |
| | Arguments: |
| | namespace {str} -- namespace name |
| | |
| | Returns: |
| | int -- vocabulary size |
| | """ |
| |
|
| | return len(self.vocab[namespace]) |
| |
|
| | def get_all_namespaces(self): |
| | """This function gets all namespaces |
| | |
| | Returns: |
| | list -- all namespaces vocabulary contained |
| | """ |
| |
|
| | return set(self.vocab) |
| |
|
| | def get_padding_index(self, namespace): |
| | """This function gets padding token index in one namespace of vocabulary |
| | |
| | Arguments: |
| | namespace {str} -- namespace name |
| | |
| | Raises: |
| | RuntimeError: no padding |
| | |
| | Returns: |
| | int -- padding index |
| | """ |
| |
|
| | if namespace not in self.vocab: |
| | raise RuntimeError("Namespace {} doesn't exist.".format(namespace)) |
| |
|
| | if namespace not in self.no_pad_namespace: |
| | if namespace not in self.contain_pad_namespace: |
| | return self.vocab[namespace][Vocabulary.DEFAULT_PAD_TOKEN] |
| | return self.vocab[namespace][self.contain_pad_namespace[namespace]] |
| |
|
| | logger.error("Namespace {} doesn't has paddding token.".format(namespace)) |
| | raise RuntimeError("Namespace {} doesn't has paddding token.".format(namespace)) |
| |
|
| | def get_unknown_index(self, namespace): |
| | """This function gets unknown token index in one namespace of vocabulary |
| | |
| | Arguments: |
| | namespace {str} -- namespace name |
| | |
| | Raises: |
| | RuntimeError: no unknown |
| | |
| | Returns: |
| | int -- unknown index |
| | """ |
| |
|
| | if namespace not in self.vocab: |
| | raise RuntimeError("Namespace {} doesn't exist.".format(namespace)) |
| |
|
| | if namespace not in self.no_unk_namespace: |
| | if namespace not in self.contain_unk_namespace: |
| | return self.vocab[namespace][Vocabulary.DEFAULT_UNK_TOKEN] |
| | return self.vocab[namespace][self.contain_unk_namespace[namespace]] |
| |
|
| | logger.error("Namespace {} doesn't has unknown token.".format(namespace)) |
| | raise RuntimeError("Namespace {} doesn't has unknown token.".format(namespace)) |
| |
|
| | def get_namespace_tokens(self, namesapce): |
| | """This function returns all tokens in one namespace |
| | |
| | Arguments: |
| | namesapce {str} -- namespce name |
| | |
| | Returns: |
| | dict_keys -- all tokens |
| | """ |
| |
|
| | return self.vocab[namesapce] |
| |
|
| | def save(self, file_path): |
| | """This function saves vocabulary into file |
| | |
| | Arguments: |
| | file_path {str} -- file path |
| | """ |
| |
|
| | pickle.dump(self, open(file_path, 'wb')) |
| |
|
| | @classmethod |
| | def load(cls, file_path): |
| | """This function loads vocabulary from file |
| | |
| | Arguments: |
| | file_path {str} -- file path |
| | |
| | Returns: |
| | Vocabulary -- vocabulary |
| | """ |
| |
|
| | return pickle.load(open(file_path, 'rb'), encoding='utf-8') |
| |
|
| | def __namespace_init(self, namespace): |
| | """This function initializes a namespace, |
| | adds pad and unk token to one namespace of vacabulary |
| | |
| | Arguments: |
| | namespace {str} -- namespace |
| | """ |
| |
|
| | self.vocab[namespace] = bidict() |
| |
|
| | if namespace not in self.no_pad_namespace and namespace not in self.contain_pad_namespace: |
| | self.vocab[namespace][Vocabulary.DEFAULT_PAD_TOKEN] = len(self.vocab[namespace]) |
| |
|
| | if namespace not in self.no_unk_namespace and namespace not in self.contain_unk_namespace: |
| | self.vocab[namespace][Vocabulary.DEFAULT_UNK_TOKEN] = len(self.vocab[namespace]) |
| |
|