| import numpy as np |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| import os |
| import glob2 |
| import seaborn as sns |
|
|
|
|
| def find_pareto_points(obtained_scores, threshold=0.02): |
| n = len(obtained_scores) |
| if n == 1: |
| return obtained_scores |
| pareto_index = [] |
| high_low = np.max(obtained_scores, axis=0) - np.min(obtained_scores, axis=0) |
| for i in range(n): |
| if not any(np.all((obtained_scores - obtained_scores[i] - threshold * high_low) > 0.0, axis=1)): |
| pareto_index.append(i) |
|
|
| points = obtained_scores[np.array(pareto_index)] |
| arg_index = np.argsort(points[:, 0]) |
| points = points[arg_index] |
| print(points) |
| sorted_index = [0] |
| remaining_index = np.ones(len(points)) |
| i = 0 |
| remaining_index[i] = 0 |
| while sum(remaining_index): |
| distance = ((points[np.where(remaining_index)] - points[i]) ** 2 ).sum(axis=1) |
| min_index = np.where(remaining_index > 0)[0][np.argmin(distance)] |
| sorted_index.append(min_index) |
| i = min_index |
| remaining_index[i] = 0 |
| return points[np.array(sorted_index)] |
|
|
|
|
|
|
| index = 1 |
| colors = sns.color_palette('Paired') |
| def plot_points(dir, label, style='-*', color='b', shift=[0,0], txt_color='black', normalize_path=None, reverse=True): |
| threshold = 0.01 |
| desired_scores = [] |
| obtained_scores = [] |
|
|
| paths = [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*.csv'))] |
| paths += [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*', '*.csv'))] |
|
|
| pref_lis = [] |
| for path in paths: |
| if '.csv' in path: |
| full_path = path |
| data = pd.read_csv(full_path) |
| |
| if 'ppo' in path and len(paths) <= 5: |
| threshold = 0.5 |
| obtained_scores.append([np.mean(data['obtained_score1']), np.mean(data['obtained_score2'])]) |
| if 'pref' in path: |
| |
| if 'eval_data_pref' in path: |
| pref = path.split('eval_data_pref')[-1].strip().split('_')[0] |
| pref_lis.append(float(pref)) |
|
|
| print(pref_lis) |
| desired_scores = np.array(desired_scores) |
| obtained_scores = np.array(obtained_scores) |
|
|
| if normalize_path is not None: |
| norm_info = np.load(normalize_path) |
| norm_info = np.array(norm_info).reshape(2, 2) |
| for i in range(2): |
| obtained_scores[:, i] = (obtained_scores[:, i] - norm_info[i][0]) / norm_info[i][1] |
|
|
| global index |
| markersize = 10 if ('*' in style or 'o' in style) else 9 |
| pareto_points = find_pareto_points(obtained_scores, threshold) |
| plt.scatter(obtained_scores[:, 0], obtained_scores[:, 1], marker=style[-1], color=colors[index], s=markersize + 60) |
| if len(pref_lis): |
| for i in range(len(obtained_scores)): |
| plt.annotate('{}'.format(round(pref_lis[i], 1)), (obtained_scores[i, 0] + shift[0], obtained_scores[i, 1] + shift[1]), size=4, color=txt_color) |
|
|
| plt.plot(pareto_points[:, 0], pareto_points[:, 1], style, c=colors[index], markersize=markersize, label=label) |
| index += 2 |
|
|
|
|
| plt.figure(figsize=(5, 4)) |
|
|
| name1 = 'harmless' |
| name2 = 'helpful' |
|
|
| |
| plot_points('./logs_trl/eval_pretrained', 'Llama 2 base', '*') |
| plot_points('./logs_trl/eval_sft_alldata', 'SFT', '*') |
| plot_points('./eval_ppo_pref/', 'MORLHF', '--D', shift=[-0.012, -0.022]) |
| plot_points('./logs_ppo/eval_pposoups_llamma2_klreg0.2', 'Rewarded Soups', style='--s', shift=[-0.012, -0.022]) |
| plot_points('.logs_trl/evalnew_onlinefix_helpful_harmlesshelpful_iter2', 'RiC', style='-o', shift=[-0.012, -0.022], txt_color='white') |
|
|
|
|
| plt.xlabel('$R_1$ ({})'.format(name1), fontsize=12) |
| plt.ylabel('$R_2$ ({})'.format(name2), fontsize=12) |
| plt.legend(fontsize=11, loc='lower left') |
| plt.tight_layout() |
| plt.savefig('ric_assistant_{}_{}.pdf'.format(name1, name2)) |
|
|
|
|
|
|