| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import os |
| import re |
| import string |
| from difflib import SequenceMatcher |
|
|
| from .log import log |
| import nltk |
| from better_profanity import profanity |
|
|
| from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii |
| from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner |
| from .misc import misc, Color, timer |
|
|
| DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist" |
| CENSOR = Color.red("*") |
|
|
|
|
| class Blocklist(ContentSafetyGuardrail): |
| def __init__( |
| self, |
| checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR, |
| guardrail_partial_match_min_chars: int = 4, |
| guardrail_partial_match_letter_count: float = 0.5, |
| ) -> None: |
| nltk.data.path.append(os.path.join(checkpoint_dir, "nltk_data")) |
| self.lemmatizer = nltk.WordNetLemmatizer() |
| self.profanity = profanity |
| self.checkpoint_dir = checkpoint_dir |
| self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars |
| self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count |
|
|
| |
| self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom")) |
| self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist")) |
| self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match")) |
|
|
| self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words) |
| log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist") |
| log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist") |
| log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist") |
|
|
| def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str: |
| """Explicitly uncensor words that are in the whitelist.""" |
| input_words = input_prompt.split() |
| censored_words = censored_prompt.split() |
| whitelist_words = set(self.whitelist_words) |
| for i, token in enumerate(input_words): |
| if token.strip(string.punctuation).lower() in whitelist_words: |
| censored_words[i] = token |
| censored_prompt = " ".join(censored_words) |
| return censored_prompt |
|
|
| def censor_prompt(self, input_prompt: str) -> tuple[bool, str]: |
| """Censor the prompt using the blocklist with better-profanity fuzzy matching. |
| |
| Args: |
| input_prompt: input prompt to censor |
| |
| Returns: |
| bool: True if the prompt is blocked, False otherwise |
| str: A message indicating why the prompt was blocked |
| """ |
| censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR) |
| |
| censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt) |
| if CENSOR in censored_prompt: |
| return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}" |
| return False, "" |
|
|
| @staticmethod |
| def check_partial_match( |
| normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float |
| ) -> tuple[bool, str]: |
| """ |
| Check robustly if normalized word and the matching target have a difference of up to guardrail_partial_match_letter_count characters. |
| |
| Args: |
| normalized_prompt: a string with many words |
| normalized_word: a string with one or multiple words, its length is smaller than normalized_prompt |
| guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters) |
| |
| Returns: |
| bool: True if a match is found, False otherwise |
| str: A message indicating why the prompt was blocked |
| """ |
| prompt_words = normalized_prompt.split() |
| word_length = len(normalized_word.split()) |
| max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float( |
| len(normalized_word) |
| ) |
|
|
| for i in range(len(prompt_words) - word_length + 1): |
| |
| substring = " ".join(prompt_words[i : i + word_length]) |
| similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio() |
| if similarity_ratio >= max_similarity_ratio: |
| return ( |
| True, |
| f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}", |
| ) |
|
|
| return False, "" |
|
|
| @staticmethod |
| def check_against_whole_word_blocklist( |
| prompt: str, |
| blocklist: list[str], |
| guardrail_partial_match_min_chars: int = 4, |
| guardrail_partial_match_letter_count: float = 0.5, |
| ) -> bool: |
| """ |
| Check if the prompt contains any whole words from the blocklist. |
| The match is case insensitive and robust to multiple spaces between words. |
| |
| Args: |
| prompt: input prompt to check |
| blocklist: list of words to check against |
| guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match |
| guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match |
| |
| Returns: |
| bool: True if a match is found, False otherwise |
| str: A message indicating why the prompt was blocked |
| """ |
| |
| normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower() |
|
|
| for word in blocklist: |
| |
| normalized_word = re.sub(r"\s+", " ", word).strip().lower() |
|
|
| |
| if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt): |
| return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}" |
|
|
| |
| if len(normalized_word) >= guardrail_partial_match_min_chars: |
| match, message = Blocklist.check_partial_match( |
| normalized_prompt, normalized_word, guardrail_partial_match_letter_count |
| ) |
| if match: |
| return True, message |
|
|
| return False, "" |
|
|
| def is_safe(self, input_prompt: str = "") -> tuple[bool, str]: |
| """Check if the input prompt is safe using the blocklist.""" |
| |
| if not input_prompt: |
| return False, "Input is empty" |
| input_prompt = to_ascii(input_prompt) |
|
|
| |
| censored, message = self.censor_prompt(input_prompt) |
| if censored: |
| return False, message |
|
|
| |
| tokens = nltk.word_tokenize(input_prompt) |
| lemmas = [self.lemmatizer.lemmatize(token) for token in tokens] |
| lemmatized_prompt = " ".join(lemmas) |
| censored, message = self.censor_prompt(lemmatized_prompt) |
| if censored: |
| return False, message |
|
|
| |
| censored, message = self.check_against_whole_word_blocklist( |
| input_prompt, |
| self.exact_match_words, |
| self.guardrail_partial_match_min_chars, |
| self.guardrail_partial_match_letter_count, |
| ) |
| if censored: |
| return False, message |
|
|
| |
| return True, "Input is safe" |
|
|
|
|
| def parse_args(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--prompt", type=str, required=True, help="Input prompt") |
| parser.add_argument( |
| "--checkpoint_dir", |
| type=str, |
| help="Path to the Blocklist checkpoint folder", |
| default=DEFAULT_CHECKPOINT_DIR, |
| ) |
| return parser.parse_args() |
|
|
|
|
| def main(args): |
| blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir) |
| runner = GuardrailRunner(safety_models=[blocklist]) |
| with timer("blocklist safety check"): |
| safety, message = runner.run_safety_check(args.prompt) |
| log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") |
| log.info(f"Message: {message}") if not safety else None |
|
|
|
|
| if __name__ == "__main__": |
| args = parse_args() |
| main(args) |
|
|