File size: 3,947 Bytes
f5d1134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import os
import glob2 
import seaborn as sns 


def find_pareto_points(obtained_scores, threshold=0.02):
    """Select the approximately non-dominated points and order them as a path.

    A point is discarded when some other point exceeds it in *every*
    objective by more than ``threshold`` times that objective's observed
    range (max minus min over all points).  The surviving points are sorted
    by their first coordinate and then chained greedily: starting from the
    point with the smallest first coordinate, each step moves to the closest
    (squared Euclidean distance) point that has not been visited yet.

    Args:
        obtained_scores: array-like of shape (n_points, n_objectives).
        threshold: dominance margin, as a fraction of each objective's range.

    Returns:
        A numpy array holding the selected points in path order.  An input
        with zero or one point is returned unchanged (converted to an array
        if it was not one already).
    """
    scores = np.asarray(obtained_scores)
    n = len(scores)
    # With zero or one point there is nothing to compare; returning early also
    # avoids np.max/np.min raising on a zero-size array.
    if n <= 1:
        return scores

    margin = threshold * (np.max(scores, axis=0) - np.min(scores, axis=0))
    # Keep point i unless some row beats it by more than the margin in all
    # objectives at once.
    pareto_index = [
        i for i in range(n)
        if not np.all(scores - scores[i] - margin > 0.0, axis=1).any()
    ]

    points = scores[np.array(pareto_index)]
    points = points[np.argsort(points[:, 0])]
    # Diagnostic output kept from the original implementation.
    print(points)

    # Greedy nearest-neighbour ordering so that a line drawn through the
    # returned points connects neighbouring points.
    order = [0]
    unvisited = np.ones(len(points), dtype=bool)
    unvisited[0] = False
    current = 0
    while unvisited.any():
        candidates = np.where(unvisited)[0]
        distance = ((points[candidates] - points[current]) ** 2).sum(axis=1)
        current = candidates[np.argmin(distance)]
        order.append(current)
        unvisited[current] = False
    return points[np.array(order)]



# Position in ``colors`` used by the next ``plot_points`` call; it is advanced
# by two after every call so that each plotted series gets a different entry.
index = 1
colors = sns.color_palette('Paired')


def plot_points(dir, label, style='-*', color='b', shift=(0, 0), txt_color='black', normalize_path=None, reverse=True):
    """Scatter one method's evaluation results and draw its Pareto front.

    Every ``*.csv`` file found directly inside ``dir`` or inside one of its
    immediate sub-directories contributes one point: the means of its
    ``obtained_score1`` and ``obtained_score2`` columns.  Files whose path
    contains ``eval_data_pref`` are additionally annotated with the
    preference value parsed from the text following that marker.

    Args:
        dir: directory searched for CSV result files.
        label: legend label for the front line.
        style: matplotlib format string for the front line; its last
            character is also used as the scatter marker.
        color: unused; the colour is taken from the module-level palette.
        shift: ``(dx, dy)`` offset added to each annotation position.
        txt_color: colour of the annotation text.
        normalize_path: optional path passed to ``np.load``; the loaded data
            is reshaped to 2x2, where row ``i`` holds the value subtracted
            from, and then the divisor applied to, objective ``i``.
        reverse: unused.

    Raises:
        ValueError: if no CSV file is found under ``dir``.
    """
    global index

    paths = [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*.csv'))]
    paths += [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*', '*.csv'))]
    if not paths:
        # Fail with a clear message instead of an opaque numpy error later on.
        raise ValueError('no CSV files found under {!r}'.format(dir))

    # MORLHF runs have fewer points, so a larger threshold is used to get a
    # better-looking frontier.  Note that 'ppo' is matched anywhere in the
    # absolute path, including parent directory names.
    threshold = 0.01
    if len(paths) <= 5 and any('ppo' in path for path in paths):
        threshold = 0.5

    obtained_scores = []
    # One entry per point (None when the file carries no preference) so that
    # annotations always stay aligned with the point they belong to.
    prefs = []
    for path in paths:
        data = pd.read_csv(path)
        obtained_scores.append([np.mean(data['obtained_score1']), np.mean(data['obtained_score2'])])
        if 'eval_data_pref' in path:
            pref = path.split('eval_data_pref')[-1].strip().split('_')[0]
            prefs.append(float(pref))
        else:
            prefs.append(None)

    print([pref for pref in prefs if pref is not None])
    obtained_scores = np.array(obtained_scores)

    if normalize_path is not None:
        norm_info = np.array(np.load(normalize_path)).reshape(2, 2)
        for i in range(2):
            obtained_scores[:, i] = (obtained_scores[:, i] - norm_info[i][0]) / norm_info[i][1]

    # Wrap around so that more calls than palette entries do not raise
    # IndexError.
    series_color = colors[index % len(colors)]
    markersize = 10 if ('*' in style or 'o' in style) else 9
    pareto_points = find_pareto_points(obtained_scores, threshold)
    plt.scatter(obtained_scores[:, 0], obtained_scores[:, 1], marker=style[-1], color=series_color, s=markersize + 60)
    for (x, y), pref in zip(obtained_scores, prefs):
        if pref is not None:
            plt.annotate('{}'.format(round(pref, 1)), (x + shift[0], y + shift[1]), size=4, color=txt_color)

    plt.plot(pareto_points[:, 0], pareto_points[:, 1], style, c=series_color, markersize=markersize, label=label)
    index += 2


plt.figure(figsize=(5, 4))

# Objective names used in the axis labels and in the output file name.
name1 = 'harmless'
name2 = 'helpful'

### Replace the paths below with your own evaluation directories.
# Each call adds one series (scatter points plus front line) to the figure.
plot_points('./logs_trl/eval_pretrained', 'Llama 2 base', '*')
plot_points('./logs_trl/eval_sft_alldata', 'SFT', '*')
plot_points('./eval_ppo_pref/', 'MORLHF', '--D', shift=[-0.012, -0.022])
plot_points('./logs_ppo/eval_pposoups_llamma2_klreg0.2', 'Rewarded Soups', style='--s', shift=[-0.012, -0.022])
# Path now starts with './' like the other calls; the original '.logs_trl/...'
# pointed at a hidden directory, which looks like a missing slash.
plot_points('./logs_trl/evalnew_onlinefix_helpful_harmlesshelpful_iter2',  'RiC', style='-o', shift=[-0.012, -0.022], txt_color='white')


plt.xlabel('$R_1$ ({})'.format(name1), fontsize=12)
plt.ylabel('$R_2$ ({})'.format(name2), fontsize=12)
plt.legend(fontsize=11, loc='lower left')
plt.tight_layout()
# The figure is written to the current working directory.
plt.savefig('ric_assistant_{}_{}.pdf'.format(name1, name2))