-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathplot_utils_llm.py
136 lines (105 loc) · 5.47 KB
/
plot_utils_llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Module-wide matplotlib styling, applied once at import time:
#  - large font sizes for labels, titles and figure titles;
#  - matplotlib's built-in text rendering (text.usetex off, no external LaTeX);
#  - pdflatex as the TeX engine should the pgf backend ever be used.
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    "font.size": 18,
    "axes.labelsize": 20,
    "axes.titlesize": 24,
    "figure.titlesize": 28,
    "text.usetex": False,
})
# Maps the internal model identifiers accepted by the plotting functions in
# this module to the human-readable names used in plot titles.
# Insertion order is intentionally kept as originally written.
MODEL_TITLE_DICT = {
    "llama2_7b": "LLaMA-2-7B",
    "mistral_7b": "Mistral-7B",
    "llama2_13b_chat": "LLaMA-2-13B-chat",
    "llama2_70b_chat": "LLaMA-2-70B-chat",
    "llama2_7b_chat": "LLaMA-2-7B-chat",
    "llama2_13b": "LLaMA-2-13B",
    "llama2_70b": "LLaMA-2-70B",
    "mistral_moe": "Mixtral-8x7B",
    "falcon_7b": "Falcon-7B",
    "falcon_40b": "Falcon-40B",
    "phi-2": "Phi-2",
    "opt_7b": "OPT-7B",
    "opt_13b": "OPT-13B",
    "opt_30b": "OPT-30B",
    "opt_66b": "OPT-66B",
    "mpt_7b": "MPT-7B",
    "mpt_30b": "MPT-30B",
    "pythia_7b": "Pythia-7B",
    "pythia_12b": "Pythia-12B",
    "gpt2": "GPT-2",
    "gpt2_large": "GPT-2-Large",
    "gpt2_xl": "GPT-2-XL",
    "gpt2_medium": "GPT-2-Medium",
    "mistral_moe_instruct": "Mixtral-8x7B-Instruct",
    "mistral_7b_instruct": "Mistral-7B-Instruct",
}
def plot_3d_feat_sub(ax, obj, seq_id, layer_id, model_name,
                     channel_ticks=(1415, 2533), zmax=2400):
    """Draw a 3-D wireframe of activation magnitudes (token x channel) on ``ax``.

    Parameters
    ----------
    ax : matplotlib 3-D Axes
        Target axes, created with ``projection='3d'``.
    obj : mapping
        Must contain ``"seq"`` (the list of token strings) and the key
        ``f"{layer_id}"``, whose value must support ``[0].abs().numpy()`` and
        whose ``shape[2]`` is used as the channel count.  Presumably a tensor
        shaped (batch, tokens, channels) -- TODO confirm against the caller.
    seq_id : int
        Index of the example sequence.  For 0 and 1 the fourth x tick label
        is emphasised in addition to the first.
    layer_id : int or str
        Selects ``obj[f"{layer_id}"]``.
    model_name : str
        Looked up in ``MODEL_TITLE_DICT`` for the title.  Unknown names are
        shown verbatim instead of raising ``KeyError``.
    channel_ticks : sequence of int, optional
        Channel indices labelled (in bold) on the y axis.  Defaults to the
        previously hard-coded ``(1415, 2533)``.
    zmax : float, optional
        Upper z-axis limit.  Defaults to the previously hard-coded 2400.
    """
    inp_seq = obj["seq"]
    num_tokens = len(inp_seq)
    feats = obj[f"{layer_id}"]
    num_channels = feats.shape[2]
    # Show the byte token "<0x0A>" (ASCII newline) as a readable "\n" label.
    labels = [r"\n" if tok == "<0x0A>" else tok for tok in inp_seq]

    # x = token position, y = channel index; both grids have shape
    # (num_channels, num_tokens) to match the transposed magnitudes below.
    positions = np.linspace(0, num_tokens - 1, num_tokens)
    xdata, ydata = np.meshgrid(positions, np.arange(num_channels, dtype=float))
    zdata = feats[0].abs().numpy().T
    # rstride=0 skips the row lines, so only lines along the channel axis
    # (one per token position) are drawn.
    ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5)

    ax.set_xticks(positions, labels, rotation=50, fontsize=16)
    ax.set_zticks([0, 1000, 2000], ["0", "1k", "2k"], fontsize=15)
    ax.set_yticks(list(channel_ticks), list(channel_ticks), fontsize=15, fontweight="heavy")

    # Emphasise selected token labels.  The index guards keep short
    # sequences (fewer than 4 tokens) from raising IndexError.
    xticklabels = ax.get_xticklabels()
    if xticklabels:
        xticklabels[0].set_weight("heavy")
    if seq_id in (0, 1) and len(xticklabels) > 3:
        xticklabels[3].set_weight("heavy")

    ax.set_title(MODEL_TITLE_DICT.get(model_name, model_name),
                 fontsize=18, fontweight="bold", y=1.015)
    plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center",
             rotation_mode="anchor")
    plt.setp(ax.get_yticklabels(), ha="left",
             rotation_mode="anchor")
    # Negative pads move the tick labels closer to the axes.
    ax.tick_params(axis='x', which='major', pad=-4)
    ax.tick_params(axis='y', which='major', pad=-5)
    ax.tick_params(axis='z', which='major', pad=-1)
    ax.set_zlim(0, zmax)
def plot_3d_feat(obj, layer_id, model_name, savedir):
    """Save a 3-D activation-magnitude plot for one layer as a PNG.

    Drawing is delegated to ``plot_3d_feat_sub`` with ``seq_id=0``.  The
    image is written to ``<savedir>/<model_name>_layer_<layer_id + 1>.png``
    at 200 dpi.  ``layer_id`` must support ``+ 1`` (an int) because it is
    shown 1-based in the file name.

    The figure is always closed afterwards, even if drawing or saving fails.
    Repeated calls (e.g. one per layer) therefore do not accumulate open
    figures.
    """
    fig = plt.figure(figsize=(14, 6))
    try:
        fig.subplots_adjust(wspace=0.13)
        ax = fig.add_subplot(1, 1, 1, projection='3d')
        plot_3d_feat_sub(ax, obj, 0, layer_id, model_name)
        fig.savefig(os.path.join(savedir, f"{model_name}_layer_{layer_id+1}.png"),
                    bbox_inches="tight", dpi=200)
    finally:
        plt.close(fig)
def plot_layer_ax_sub(ax, mean, model_name):
    """Plot per-layer magnitude curves on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target 2-D axes.
    mean : np.ndarray
        2-D array whose columns are layers, drawn at x = 1..num_layers.
        Rows 0-2 are plotted with the labels "Top 1".."Top 3" and the last
        row with the label "Median".  At least 4 rows are presumably
        expected -- TODO confirm against the caller.
    model_name : str
        Looked up in ``MODEL_TITLE_DICT`` for the title.  Unknown names are
        shown verbatim instead of raising ``KeyError``.
    """
    colors = ["cornflowerblue", "mediumseagreen", "C4", "teal", "dimgrey"]
    num_layers = mean.shape[-1]
    x_axis = np.arange(num_layers) + 1
    for i in range(3):
        ax.plot(x_axis, mean[i], label=f"Top {i+1}", color=colors[i],
                linestyle="-", marker="o", markerfacecolor='none', markersize=5)
    ax.plot(x_axis, mean[-1], label="Median", color=colors[-1],
            linestyle="-", marker="v", markerfacecolor='none', markersize=5)
    ax.set_title(MODEL_TITLE_DICT.get(model_name, model_name), fontsize=18, fontweight="bold")

    # Ticks at the first layer, the quartiles and the last layer.  Clamping
    # to 1 and de-duplicating avoids a tick at 0 (outside the data) and
    # repeated ticks for models with fewer than 8 layers.  The result is
    # unchanged for 8 or more layers.
    candidates = (1, num_layers // 4, num_layers // 2, num_layers * 3 // 4, num_layers)
    xtick_label = sorted({max(1, t) for t in candidates})
    ax.set_xticks(xtick_label, xtick_label, fontsize=16)
    ax.set_xlabel('Layers', fontsize=18, labelpad=0.8)
    ax.set_ylabel("Magnitudes", fontsize=18)
    ax.tick_params(axis='x', which='major', pad=1.0)
    ax.tick_params(axis='y', which='major', pad=0.4)
    ax.grid(axis='x', color='0.75')
def plot_layer_ax(obj, model_name, savedir):
    """Average ``obj`` over axis 0 and save the per-layer magnitude plot.

    ``np.mean(obj, axis=0)`` must produce the 2-D array expected by
    ``plot_layer_ax_sub`` (rows = curves, columns = layers).  ``obj`` is
    presumably (samples, curves, layers) -- TODO confirm against the caller.

    The image is written to ``<savedir>/<model_name>.png`` at 200 dpi, with
    a 4-column legend below the axes.  The figure is always closed
    afterwards, so repeated calls do not accumulate open figures.
    """
    fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(7.5, 4.5))
    try:
        # NOTE: runs before any content is drawn, but it still sets the
        # subplot margins (and so the axes size) of the saved image.  Kept
        # for identical output.
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.13)
        mean = np.mean(obj, axis=0)
        plot_layer_ax_sub(axs, mean, model_name)
        leg = axs.legend(
            loc='center', bbox_to_anchor=(0.5, -0.25),
            ncol=4, fancybox=True, prop={'size': 14}
        )
        leg.get_frame().set_edgecolor('silver')
        leg.get_frame().set_linewidth(1.0)
        fig.savefig(os.path.join(savedir, f"{model_name}.png"), bbox_inches="tight", dpi=200)
    finally:
        plt.close(fig)
def plot_attn_sub(ax, corr, model_name, layer_id):
    """Draw ``corr`` as a lower-triangular heatmap on ``ax``.

    Cells strictly above the main diagonal are masked.  They show the
    "whitesmoke" axes background, presumably because ``corr`` is a causal
    attention map -- TODO confirm.  Tick labels and tick marks are removed,
    the colorbar tick labels are enlarged, and the title reads
    "<model title>, Layer <layer_id + 1>".

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; seaborn adds a colorbar next to it.
    corr : np.ndarray
        2-D array to plot (``np.triu_indices_from`` requires 2-D input).
    model_name : str
        Looked up in ``MODEL_TITLE_DICT`` for the title.  Unknown names are
        shown verbatim instead of raising ``KeyError``.
    layer_id : int
        Shown 1-based in the title.
    """
    # Boolean mask hiding the strictly-upper triangle (k=1 keeps the diagonal).
    mask = np.zeros_like(corr, dtype=bool)
    mask[np.triu_indices_from(mask, k=1)] = True
    sns.heatmap(corr, mask=mask, square=True, ax=ax,
                cmap="YlGnBu", cbar_kws={"shrink": 1.0, "pad": 0.01, "aspect": 50})
    ax.set_facecolor("whitesmoke")
    # Assumes the colorbar axes created by heatmap() is the figure's most
    # recently added axes.
    cax = ax.figure.axes[-1]
    cax.tick_params(labelsize=18)
    ax.set(xticklabels=[])
    ax.set(yticklabels=[])
    ax.tick_params(left=False, bottom=False)
    title = MODEL_TITLE_DICT.get(model_name, model_name)
    ax.set_title(f"{title}, Layer {layer_id+1}", fontsize=24, fontweight="bold")
def plot_attn(attn_logits, model_name, layer_id, savedir):
    """Save a heatmap of averaged attention logits for one layer as a PDF.

    ``attn_logits.numpy()[0].mean(0)`` takes the first entry along the
    leading axis and averages over the next one.  These are presumably the
    batch and head axes, leaving a 2-D (query, key) map -- TODO confirm.
    The result is cast to float64 and drawn by ``plot_attn_sub``.

    The file is written to ``<savedir>/<model_name>_layer<layer_id + 1>.pdf``.
    The figure is always closed afterwards, so repeated calls do not
    accumulate open figures.
    """
    fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(8, 4.75))
    try:
        # NOTE: runs before any content is drawn, but it still sets the
        # subplot margins of the saved image.  Kept for identical output.
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.15)
        corr = attn_logits.numpy()[0].mean(0)
        corr = corr.astype("float64")
        plot_attn_sub(axs, corr, model_name, layer_id)
        fig.savefig(os.path.join(savedir, f"{model_name}_layer{layer_id+1}.pdf"),
                    bbox_inches="tight", dpi=200)
    finally:
        plt.close(fig)