import os
import random
import sys

import pandas as pd

sys.path.insert(0, "/home/yash/ALIGN-SIM/src")
from utils import mkdir_p, full_path, read_data
from SentencePerturbation.word_replacer import WordReplacer, WordSwapping
from perturbation_args import get_args


def perturb_sentences(dataset_name: str, task: str, target_lang: str = "en", output_dir: str = "./data/perturbed_dataset/", sample_size: int = 3500, save: bool = False) -> None:
| """ |
| perturb_sentences _summary_ |
| |
| Args: |
| dataset_name (str): ["MRPC","PAWS","QQP"] |
| task (str): ["Synonym","Antonym","Jumbling"] |
| target_lang (str, optional): _description_. Defaults to "en". |
| output_dir (str, optional): _description_. Defaults to "./data/perturbed_dataset/". |
| sample_size (int, optional): _description_. Defaults to 3500. |
| save (str, optional): _description_. Defaults to False. |
| """ |

    print("--------------------------------------")
    output_csv = full_path(os.path.join(output_dir, target_lang, task, f"{dataset_name}_{task}_perturbed_{target_lang}.csv"))
    if os.path.exists(output_csv):
        print(f"File already exists at: {output_csv}")
        return

| print("Loading dataset...") |
| data = read_data(dataset_name) |
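    # Drop stray index columns that some dataset exports carry along.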
| if "Unnamed: 0" in data.columns: |
| data.drop("Unnamed: 0", axis=1, inplace=True) |
| |
| if "idx" in data.columns: |
| data.drop("idx", axis=1, inplace=True) |
| |
| print(f"Loaded {dataset_name} dataset") |
| |
| print("--------------------------------------") |

    replacer = WordReplacer()

    # Seed for reproducible word replacements and swaps.
    random.seed(42)

    perturbed_data = pd.DataFrame(columns=["original_sentence"])

| if task in ["Syn","syn","Synonym"]: |
| print("Creating Synonym perturbed data...") |
| sample_data = sampling(data, task, sample_size) |
| perturbed_data["original_sentence"] = sample_data.sentence1 |
| perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "synonyms")) |
| perturbed_data["perturb_n2"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 2, "synonyms")) |
| perturbed_data["perturb_n3"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 3, "synonyms")) |
| |
| assert perturbed_data.shape[1] == 4, "Perturbed data size mismatch" |
| |
| if task in ["paraphrase","Paraphrase","para"]: |
| print("Creating Paraphrase perturbed data...") |
| |
| |
| perturbed_data = sampling(data, task, sample_size) |
| perturbed_data["original_sentence"] = perturbed_data.sentence1 |
| perturbed_data["paraphrased_sentence"] = perturbed_data.sentence2 |
| assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch" |
| |
| if task in ["Anto","anto","Antonym"]: |
| print("Creating Antonym perturbed data...") |
| pos_pairs = sampling(data, task, sample_size) |
| |
| perturbed_data["original_sentence"] = pos_pairs.sentence1 |
| perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2 |
| perturbed_data["perturb_n1"] = perturbed_data["original_sentence"].apply(lambda x: replacer.sentence_replacement(x, 1, "antonyms")) |
| assert perturbed_data.shape[1] == 3, "Perturbed data size mismatch" |
| |
| |
| if task in ["jumbling", "Jumbling","jumb"]: |
| print("Creating Jumbling perturbed data...") |
| pos_pairs = sampling(data, task, sample_size) |
| perturbed_data["original_sentence"] = pos_pairs.sentence1 |
| perturbed_data["paraphrased_sentence"] = pos_pairs.sentence2 |
| perturbed_data["perturb_n1"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,1)) |
| perturbed_data["perturb_n2"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,2)) |
| perturbed_data["perturb_n3"]= perturbed_data["original_sentence"].apply(lambda x: WordSwapping.random_swap(x,3)) |
| |
| assert perturbed_data.shape[1] == 5, "Perturbed data size mismatch" |
| |
    if save:
        perturbed_data.to_csv(mkdir_p(output_csv), index=False)
        print("--------------------------------------")
        print(f"Saved at: {output_csv}")
        print("--------------------------------------")


def sampling(data: pd.DataFrame, task: str, sample_size: int, random_state: int = 42) -> pd.DataFrame:
| """ |
| Combines two sampling strategies: |
| |
| 1. sampled_data: Samples from the dataset by first taking all positive pairs and then, |
| if needed, filling the remainder with negative pairs. |
| 2. balanced_data: Constructs a dataset with roughly equal positive and negative pairs, |
| adjusting the numbers if one group is underrepresented. |
| |
| Returns: |
| sampled_data (pd.DataFrame): Dataset sampled by filling negatives if positives are insufficient. |
| positive_data (pd.DataFrame): All positive samples (label == 1). |
| balanced_data (pd.DataFrame): Dataset balanced between positive and negative pairs. |
| """ |

    positive_data = data[data["label"] == 1]
    negative_data = data[data["label"] == 0]

    # Antonym and jumbling perturb the positive pairs directly.
    if task in ["Anto", "anto", "Antonym", "jumbling", "Jumbling", "jumb"]:
        return positive_data

    # Synonym perturbation needs only the positive sentences.
    if sample_size is None or sample_size > len(positive_data):
        sampled_data = positive_data.copy()
    else:
        sampled_data = positive_data.sample(n=sample_size, random_state=random_state)

    if task in ["Syn", "syn", "Synonym"]:
        return sampled_data

    # For the paraphrase task, shuffle sentence2 within the negative pairs
    # so they form clearly non-paraphrase pairs.
    negative_data = negative_data.reset_index(drop=True)
    shuffled_sentence2 = negative_data["sentence2"].sample(frac=1, random_state=random_state).reset_index(drop=True)
    negative_data["sentence2"] = shuffled_sentence2

    if sample_size is None:
        pos_sample_size = len(positive_data)
        neg_sample_size = len(negative_data)
    else:
        # Aim for a 50/50 split, capped by what each class actually has.
        half_size = sample_size // 2
        pos_available = len(positive_data)
        neg_available = len(negative_data)
        pos_sample_size = min(half_size, pos_available)
        neg_sample_size = min(half_size, neg_available)

        # If one class ran short, top up from whichever class has more
        # samples left, without exceeding what is available. For example,
        # sample_size=3500 with 1200 positives and 5000 negatives gives
        # 1200 positive and 1750 + 550 = 2300 negative pairs.
        total_sampled = pos_sample_size + neg_sample_size
        remainder = sample_size - total_sampled
        if remainder > 0:
            if (pos_available - pos_sample_size) >= (neg_available - neg_sample_size):
                pos_sample_size = min(pos_sample_size + remainder, pos_available)
            else:
                neg_sample_size = min(neg_sample_size + remainder, neg_available)

    sampled_positive = positive_data.sample(n=pos_sample_size, random_state=random_state)
    sampled_negative = negative_data.sample(n=neg_sample_size, random_state=random_state)

    sampled_positive["label"] = 1
    sampled_negative["label"] = 0

    # Shuffle so positive and negative pairs are interleaved.
    balanced_data = pd.concat([sampled_positive, sampled_negative]).sample(frac=1, random_state=random_state).reset_index(drop=True)

    # Only the paraphrase task reaches this point.
    return balanced_data


if __name__ == "__main__":
    # Under a debugger, fall back to a fixed config
    # (sample_size stays at its default of 3500).
    if sys.gettrace() is not None:
        config = {
            "dataset_name": "mrpc",
            "task": "syn",
            "target_lang": "en",
            "output_dir": "./data/perturbed_dataset/",
            "save": True,
        }
    else:
        args = get_args()
        config = {
            "dataset_name": args.dataset_name,
            "task": args.task,
            "target_lang": args.target_lang,
            "output_dir": args.output_dir,
            "save": args.save,
            "sample_size": args.sample_size,
        }
    perturb_sentences(**config)
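    # Example CLI usage (hypothetical flag names, assuming perturbation_args
    # exposes flags matching the attributes read above):
    #   python perturb_sentences.py --dataset_name MRPC --task Synonym --target_lang en --sample_size 3500 --save True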