import re
import string
import hashlib

import nltk
from nltk.corpus import stopwords
from nltk import pos_tag
import torch
import torch.nn.functional as F
from torch import nn
from scipy.stats import norm
import gensim
import gensim.downloader as api
import Levenshtein
import spacy
import paddle
from jieba import posseg

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BertForMaskedLM,
    BertTokenizer,
    RobertaForSequenceClassification,
    RobertaTokenizer,
)
from transformers import BertForMaskedLM as WoBertForMaskedLM
from wobert import WoBertTokenizer

# jieba's paddle-backed POS tagging requires static-graph mode.
paddle.enable_static()


def cut_sent(para):
    """Split a Chinese paragraph into sentences at terminal punctuation."""
    para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)               # break after 。!? unless a closing quote follows
    para = re.sub(r'([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)   # break after punctuation followed by a closing quote
    para = re.sub(r'(\.{6}|…{2})([^”’\n])', r'\1\n\2', para)           # break after ellipses ("......" or "……")
    para = re.sub(r'([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)  # break before colon-introduced clauses
    para = re.sub(r'([。!?\?][”’])$', r'\1\n', para)
    para = para.rstrip()
    return para.split("\n")
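
# Illustrative behaviour (example input, not from the repo):
#   cut_sent('今天天气很好。我们去公园吧!')
#   -> ['今天天气很好。', '我们去公园吧!']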


def is_subword(token: str):
    # WordPiece continuation pieces are prefixed with '##'.
    return token.startswith('##')


def binary_encoding_function(token):
    # Deterministically map a string to a single bit via SHA-256, so the same
    # (previous token + current token) string always encodes the same bit.
    hash_value = int(hashlib.sha256(token.encode('utf-8')).hexdigest(), 16)
    random_bit = hash_value % 2
    return random_bit
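
# Because the bit depends only on SHA-256, it is stable across runs and
# machines; e.g. binary_encoding_function('the' + 'cat') always yields the
# same value. The embedder keys this bit on a token bigram, so replacing the
# second token of the bigram can flip the encoded bit.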


def is_similar(x, y, threshold=0.5):
    # Treat two strings as similar when their normalised edit distance is
    # below `threshold`.
    distance = Levenshtein.distance(x, y)
    return distance / max(len(x), len(y)) < threshold
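
# For example, is_similar('cat', 'cats') gives 1/4 = 0.25 < 0.5 -> True,
# while is_similar('cat', 'dog') gives 3/3 = 1.0 -> False.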


class watermark_model:
    def __init__(self, language, mode, tau_word, lamda):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.language = language
        self.mode = mode
        self.tau_word = tau_word  # word-level similarity threshold
        self.tau_sent = 0.8       # sentence-level similarity threshold
        self.lamda = lamda        # weight of contextual vs. global word similarity
        # Chinese POS tags that must never be substituted (function words,
        # numerals, punctuation, proper nouns, and named-entity tags).
        self.cn_tag_black_list = set(['', 'x', 'u', 'j', 'k', 'zg', 'y', 'eng', 'uv', 'uj', 'ud', 'nr', 'nrfg', 'nrt', 'nw', 'nz', 'ns', 'nt', 'm', 'mq', 'r', 'w', 'PER', 'LOC', 'ORG'])
        # English POS tags that may be substituted (verbs, nouns, adjectives, adverbs, ...).
        self.en_tag_white_list = set(['MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS'])
        if language == 'Chinese':
            self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
            self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
            self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
            self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
            self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
        elif language == 'English':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
            self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
            self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
            self.w2v_model = api.load("glove-wiki-gigaword-100")
            nltk.download('stopwords')
            self.stop_words = set(stopwords.words('english'))
            self.nlp = spacy.load('en_core_web_sm')

    def cut(self, ori_text, text_len):
        # Trim the input to at most text_len + 5 units and reject it if it is
        # shorter than text_len - 5 (characters for Chinese, WordPiece tokens
        # for English).
        if self.language == 'Chinese':
            if len(ori_text) > text_len + 5:
                ori_text = ori_text[:text_len + 5]
            if len(ori_text) < text_len - 5:
                return 'Short'
        elif self.language == 'English':
            tokens = self.tokenizer.tokenize(ori_text)
            if len(tokens) > text_len + 5:
                ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len + 5])
            if len(tokens) < text_len - 5:
                return 'Short'
        else:
            print(f'Unsupported Language: {self.language}')
            raise NotImplementedError
        return ori_text

    def sent_tokenize(self, ori_text):
        if self.language == 'Chinese':
            return cut_sent(ori_text)
        elif self.language == 'English':
            return nltk.sent_tokenize(ori_text)

    def pos_filter(self, tokens, masked_token_index, input_text):
        # Decide whether the token at masked_token_index is a legal substitution site.
        if self.language == 'Chinese':
            pairs = posseg.lcut(input_text)
            pos_dict = {word: pos for word, pos in pairs}
            pos = pos_dict.get(tokens[masked_token_index], '')
            return pos not in self.cn_tag_black_list
        elif self.language == 'English':
            pos_tags = pos_tag(tokens)
            pos = pos_tags[masked_token_index][1]
            if pos not in self.en_tag_white_list:
                return False
            if (is_subword(tokens[masked_token_index])
                    or is_subword(tokens[masked_token_index + 1])
                    or tokens[masked_token_index] in self.stop_words
                    or tokens[masked_token_index] in string.punctuation):
                return False
            return True
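
    # For instance, pos_tag(['He', 'runs', 'fast']) tags 'runs' as 'VBZ',
    # which is in the whitelist, so index 1 is a legal site provided the
    # token is not a stopword, punctuation, or a '##' subword piece.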

    def filter_special_candidate(self, top_n_tokens, tokens, masked_token_index, input_text):
        if self.language == 'English':
            filtered_tokens = [
                tok for tok in top_n_tokens
                if tok not in self.stop_words
                and tok not in string.punctuation
                and pos_tag([tok])[0][1] in self.en_tag_white_list
                and not is_subword(tok)
            ]
            # Drop candidates that are merely inflections of the original word.
            base_word = tokens[masked_token_index]
            base_word_lemma = self.nlp(base_word)[0].lemma_
            processed_tokens = [base_word] + [tok for tok in filtered_tokens if self.nlp(tok)[0].lemma_ != base_word_lemma]
            return processed_tokens
        elif self.language == 'Chinese':
            pairs = posseg.lcut(input_text)
            pos_dict = {word: pos for word, pos in pairs}
            pos = pos_dict.get(tokens[masked_token_index], '')
            filtered_tokens = []
            for tok in top_n_tokens:
                # Re-segment the sentence with the candidate in place, and
                # keep the candidate only if its POS tag matches the original
                # word's and is not blacklisted.
                watermarked_text_segtest = self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [tok] + tokens[masked_token_index + 1:-1])
                watermarked_text_segtest = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text_segtest)
                pairs_tok = posseg.lcut(watermarked_text_segtest)
                pos_dict_tok = {word: pos for word, pos in pairs_tok}
                flag = pos_dict_tok.get(tok, '')
                if flag not in self.cn_tag_black_list and flag == pos:
                    filtered_tokens.append(tok)
            processed_tokens = filtered_tokens
            return processed_tokens

    def global_word_sim(self, word, ori_word):
        # Static embedding similarity; out-of-vocabulary pairs score 0.
        try:
            global_score = self.w2v_model.similarity(word, ori_word)
        except KeyError:
            global_score = 0
        return global_score

    def context_word_sim(self, init_candidates, tokens, masked_token_index, input_text):
        # Score each candidate by the cosine similarity between its contextual
        # embedding at the masked position and the original word's, averaged
        # over the last N hidden layers of the MLM.
        original_input_tensor = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)
        batch_input_ids = [[self.tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1] + ['[SEP]'])] for token in init_candidates]
        batch_input_tensors = torch.tensor(batch_input_ids).squeeze(1).to(self.device)  # squeeze(1), not squeeze(): a lone candidate must stay 2-D
        batch_input_tensors = torch.cat((batch_input_tensors, original_input_tensor), dim=0)
        with torch.no_grad():
            outputs = self.model(batch_input_tensors)
        cos_sims = torch.zeros([len(init_candidates)]).to(self.device)
        num_layers = len(outputs[1])
        N = 8
        i = masked_token_index
        cos_sim_sum = 0
        for layer in range(num_layers - N, num_layers):
            # Rows 0 .. len(init_candidates)-1 hold the candidate sentences;
            # the last row is the original sentence.
            ls_hidden_states = outputs[1][layer][0:len(init_candidates), i, :]
            source_hidden_state = outputs[1][layer][len(init_candidates), i, :]
            cos_sim_sum += F.cosine_similarity(source_hidden_state, ls_hidden_states, dim=1)
        cos_sim_avg = cos_sim_sum / N
        cos_sims += cos_sim_avg
        return cos_sims.tolist()

    def sentence_sim(self, init_candidates, tokens, masked_token_index, input_text):
        # Score each candidate sentence against the original with a
        # sentence-pair classifier: a similarity model for Chinese, MNLI
        # entailment for English.
        if self.language == 'Chinese':
            batch_sents = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
            batch_sentences = [re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', sent) for sent in batch_sents]
            roberta_inputs = [input_text + '[SEP]' + s for s in batch_sentences]
        elif self.language == 'English':
            batch_sentences = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
            roberta_inputs = [input_text + '</s></s>' + s for s in batch_sentences]

        encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
            roberta_inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt')

        input_ids = encoded_dict['input_ids'].to(self.device)
        attention_masks = encoded_dict['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
            logits = outputs[0]
        probs = torch.softmax(logits, dim=1)
        if self.language == 'Chinese':
            relatedness_scores = probs[:, 1].tolist()  # probability of the "similar" class
        elif self.language == 'English':
            relatedness_scores = probs[:, 2].tolist()  # roberta-large-mnli: index 2 is "entailment"

        return relatedness_scores

    def candidates_gen(self, tokens, masked_token_index, input_text, topk=64, dropout_prob=0.3):
        input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
        if not self.pos_filter(tokens, masked_token_index, input_text):
            return []
        input_tensor = torch.tensor([input_ids_bert]).to(self.device)

        with torch.no_grad():
            embeddings = self.model.bert.embeddings(input_tensor)
        # Instead of replacing the target with [MASK], partially drop its
        # embedding so the MLM head predicts from a noisy view of the
        # original word.
        dropout = nn.Dropout2d(p=dropout_prob)
        embeddings[:, masked_token_index, :] = dropout(embeddings[:, masked_token_index, :])
        with torch.no_grad():
            outputs = self.model(inputs_embeds=embeddings)

        predicted_logits = outputs[0][0][masked_token_index]
        probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
        top_n_probs, top_n_indices = torch.topk(probs, topk)
        top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
        processed_tokens = self.filter_special_candidate(top_n_tokens, tokens, masked_token_index, input_text)
        return processed_tokens

    def filter_candidates(self, init_candidates, tokens, masked_token_index, input_text):
        # Blend contextual and global word similarity with weight lamda and
        # keep candidates that clear both the word and sentence thresholds.
        context_word_similarity_scores = self.context_word_sim(init_candidates, tokens, masked_token_index, input_text)
        sentence_similarity_scores = self.sentence_sim(init_candidates, tokens, masked_token_index, input_text)
        filtered_candidates = []
        for idx, candidate in enumerate(init_candidates):
            global_word_similarity_score = self.global_word_sim(tokens[masked_token_index], candidate)
            word_similarity_score = self.lamda * context_word_similarity_scores[idx] + (1 - self.lamda) * global_word_similarity_score
            if word_similarity_score >= self.tau_word and sentence_similarity_scores[idx] >= self.tau_sent:
                filtered_candidates.append((candidate, word_similarity_score))
        return filtered_candidates
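
    # Illustrative numbers: with lamda = 0.83, a candidate with contextual
    # similarity 0.9 and global similarity 0.5 scores
    # 0.83 * 0.9 + 0.17 * 0.5 = 0.832, and survives only if 0.832 >= tau_word
    # and its sentence score is >= tau_sent (0.8).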

    def watermark_embed(self, text):
        input_text = text
        tokens = self.tokenizer.tokenize(input_text)
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        start_index = 1
        end_index = len(tokens) - 1
        for masked_token_index in range(start_index + 1, end_index - 1):
            # Positions whose (previous token + current token) bigram already
            # hashes to 1 need no edit.
            binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
            if binary_encoding == 1:
                continue
            init_candidates = self.candidates_gen(tokens, masked_token_index, input_text, 32, 0.3)
            if len(init_candidates) <= 1:
                continue
            enhanced_candidates = self.filter_candidates(init_candidates, tokens, masked_token_index, input_text)
            # Keep only candidates whose bigram hashes to 1, that are not a
            # near-duplicate of the original word, and that do not repeat a
            # neighbouring token.
            hash_top_tokens = enhanced_candidates.copy()
            for tok in enhanced_candidates:
                binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tok[0])
                if binary_encoding != 1 or is_similar(tok[0], tokens[masked_token_index]) or (tokens[masked_token_index - 1] in tok or tokens[masked_token_index + 1] in tok):
                    hash_top_tokens.remove(tok)
            hash_top_tokens.sort(key=lambda x: x[1], reverse=True)
            if len(hash_top_tokens) > 0:
                selected_token = hash_top_tokens[0][0]
            else:
                selected_token = tokens[masked_token_index]

            tokens[masked_token_index] = selected_token
        watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])
        if self.language == 'Chinese':
            # WoBert detokenisation inserts spaces between Chinese characters; remove them.
            watermarked_text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text)

        return watermarked_text

    def embed(self, ori_text):
        # Embed the watermark sentence-pair by sentence-pair.
        sents = self.sent_tokenize(ori_text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        watermarked_text = ''
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]
            watermarked_text = watermarked_text + self.watermark_embed(sent_pair)
        # If the original text yields no encodable positions, there is nothing
        # to detect later, so return an empty result.
        if len(self.get_encodings_fast(ori_text)) == 0:
            return ''
        return watermarked_text

    def get_encodings_fast(self, text):
        # Read the bigram-keyed bits directly from the text, without running
        # the MLM (fast but approximate).
        sents = self.sent_tokenize(text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        encodings = []
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]
            tokens = self.tokenizer.tokenize(sent_pair)
            for index in range(1, len(tokens) - 1):
                if not self.pos_filter(tokens, index, text):
                    continue
                bit = binary_encoding_function(tokens[index - 1] + tokens[index])
                encodings.append(bit)
        return encodings

    def watermark_detector_fast(self, text, alpha=0.05):
        # One-proportion z-test: without a watermark each encoded bit is
        # Bernoulli(p = 0.5), so z = (ones - p*n) / sqrt(n*p*(1-p)).
        p = 0.5
        encodings = self.get_encodings_fast(text)
        n = len(encodings)
        ones = sum(encodings)
        if n == 0:
            z = 0
        else:
            z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
        threshold = norm.ppf(1 - alpha, loc=0, scale=1)
        p_value = norm.sf(z)
        is_watermark = z >= threshold
        return is_watermark, p_value, n, ones, z
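
    # Worked example (illustrative numbers): with n = 200 bits of which
    # ones = 130, z = (130 - 100) / sqrt(200 * 0.25) = 30 / 7.07 ≈ 4.24,
    # well above the alpha = 0.05 threshold norm.ppf(0.95) ≈ 1.645, so the
    # text is flagged as watermarked.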

    def get_encodings_precise(self, text):
        # Slower detector path: re-run candidate generation (without dropout)
        # and count a bit only at positions where a substitution was actually
        # possible at embedding time.
        sents = self.sent_tokenize(text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        encodings = []
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]

            tokens = self.tokenizer.tokenize(sent_pair)
            tokens = ['[CLS]'] + tokens + ['[SEP]']

            start_index = 1
            end_index = len(tokens) - 1

            for masked_token_index in range(start_index + 1, end_index - 1):
                init_candidates = self.candidates_gen(tokens, masked_token_index, sent_pair, 8, 0)
                if len(init_candidates) <= 1:
                    continue
                enhanced_candidates = self.filter_candidates(init_candidates, tokens, masked_token_index, sent_pair)
                if len(enhanced_candidates) > 1:
                    bit = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
                    encodings.append(bit)
        return encodings

    def watermark_detector_precise(self, text, alpha=0.05):
        p = 0.5
        encodings = self.get_encodings_precise(text)
        n = len(encodings)
        ones = sum(encodings)
        if n == 0:
            z = 0
        else:
            z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
        threshold = norm.ppf(1 - alpha, loc=0, scale=1)
        p_value = norm.sf(z)
        is_watermark = z >= threshold
        return is_watermark, p_value, n, ones, z
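

# Minimal usage sketch (hypothetical argument values: tau_word/lamda here are
# illustrative, and the 'embed' mode string is an assumption):
if __name__ == '__main__':
    model = watermark_model(language='English', mode='embed', tau_word=0.8, lamda=0.83)
    watermarked = model.embed('The weather was lovely this morning. We decided to walk to the old harbour.')
    is_wm, p_value, n, ones, z = model.watermark_detector_fast(watermarked)
    print(f'watermarked={is_wm}, z={z:.2f}, p={p_value:.4f}, bits={n}')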