[Paper Reproduction] Heads Up

Method illustration: (figure not included here)

Reference project: OpenDFM/HeadsUp

1 Installation

1.1 Virtual Environment

conda create -n heads python=3.10 -y
conda activate heads

pip install ipykernel seaborn adjustText
pip install rouge_score nltk jieba
pip install torch transformers==4.51.3 datasets==3.0.1 safetensors accelerate flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl # https://github.com/Dao-AILab/flash-attention/releases

python -m ipykernel install --user --name heads
jupyter kernelspec list
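
As a quick sanity check of the environment (a minimal sketch, not part of the original write-up), confirm that PyTorch sees the GPU and that the pre-built flash-attn wheel matches the installed torch/CUDA build before moving on:

# Minimal environment check (assumes the packages above installed successfully)
import torch
import flash_attn
import transformers

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
print(flash_attn.__version__)    # the wheel above is flash-attn 2.7.3
print(transformers.__version__)  # pinned to 4.51.3 above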

1.2 Project Structure

dataset/                          # datasets
    function_vectors/             # task datasets for training and evaluation
        abstractive/              # zero-shot
            adjective_v_verb_3.json
            antonym.json
            verb_v_adjective_3.json
        fewshot/                  # few-shot
    iwslt2017/                    # machine-translation dataset for evaluation
        en-zh-test.parquet
    XNLI-15way/                   # machine-translation dataset for training and evaluation
        xnli.15way.orig.tsv

models/                           # model classes for the different model families
    modeling_gemma2.py
    modeling_mistral.py
    modeling_qwen2.py
    modeling_llama.py
    modeling_phi3.py

output/                           # trained mask matrices
    llama/
        fv/
            adjective_v_verb_3/
            antonym/
            verb_v_adjective_3/
        xnli/
            en_zh/
            zh_en/

HeadsUp.ipynb                     # reproduces train_mask.py, utils.py, eval.ipynb and playground.ipynb

1.3 Models and Mask Matrices

Models:

huggingface-cli download --token Your_token meta-llama/Meta-Llama-3.1-8B-Instruct --local-dir model/Llama-3.1-8B-Instruct
huggingface-cli download --token Your_token BAAI/bge-reranker-v2-m3 --local-dir model/bge-reranker-v2-m3

Mask matrices: download the trained masks under fv/adjective_v_verb_3/ and fv/verb_v_adjective_3/.

2 Overall Workflow

2.1 Training the Mask Matrices

  1. Import the required libraries:
import argparse
import copy
import os
import re
from dataclasses import dataclass, field
from tqdm import tqdm

import torch
from torch import nn
import safetensors.torch
import itertools
import numpy as np
import pandas as pd
import datasets
from datasets import Dataset, concatenate_datasets
from models.modeling_llama import LlamaForCausalLM
from models.modeling_phi3 import Phi3ForCausalLM
from models.modeling_mistral import MistralForCausalLM
from models.modeling_qwen2 import Qwen2ForCausalLM
from models.modeling_gemma2 import Gemma2ForCausalLM
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments, HfArgumentParser
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import jieba
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from adjustText import adjust_text
  2. Redefine the trainer:
# Parametric model that learns the weight tensor
class SimpleTensorModel(nn.Module):
    def __init__(self, tensor_length=1056):
        super().__init__()
        self.tensor = nn.Parameter(torch.zeros(tensor_length))  # learnable parameter tensor
        nn.init.normal_(self.tensor, mean=4, std=0.02)  # initialize the tensor

    # Return the parameter tensor
    def forward(self, x=None):
        return self.tensor


# Data collator for batching
class CustomDataCollator(DataCollatorForLanguageModeling):
    def __init__(self, lm_tokenizer, padding=True):
        self.lm_tokenizer = lm_tokenizer
        self.padding = padding

    # Pad and batch a list of features
    def torch_call(self, features):
        lm_input_ids = [f["lm_input_ids"] for f in features]
        lm_labels = [torch.tensor(f["lm_labels"]) for f in features]

        # Pad the input IDs and the labels
        lm_batch = self.lm_tokenizer.pad({"input_ids": lm_input_ids}, padding=self.padding, return_tensors="pt")
        lm_labels = torch.nn.utils.rnn.pad_sequence(lm_labels, batch_first=True, padding_value=-100)

        return {"lm_input_ids": lm_batch["input_ids"], "lm_attention_mask": lm_batch["attention_mask"], "lm_labels": lm_labels}


# Sample from a Gumbel-Sigmoid distribution, optionally discretized
def gumbel_sigmoid(logits, tau=1, hard=False, threshold=0.5):
    gumbels = (torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log())  # sample from an exponential distribution and take the log to obtain Gumbel(0, 1) noise
    gumbels = (logits + gumbels) / tau  # add the logits and apply the temperature
    y_soft = gumbels.sigmoid()  # apply the sigmoid to get probabilities

    # Optional discretization
    if hard:
        # Straight-through estimator
        indices = (y_soft > threshold).nonzero(as_tuple=True)  # positions above the threshold
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)  # all-zero hard output tensor
        y_hard[indices[0], indices[1]] = 1.0  # set the positions above the threshold to 1
        ret = y_hard - y_soft.detach() + y_soft  # straight-through trick: discrete values in the forward pass, continuous gradients in the backward pass
    else:
        # No discretization, return the soft probabilities directly
        ret = y_soft  # reparameterization trick: gradients flow through y_soft

    return ret


# Custom trainer with task-specific loss computation and prediction logic
class CustomTrainer(Trainer):
    def __init__(self, *args, lm_model, lm_tokenizer, **kwargs):
        super().__init__(*args, **kwargs)
        self.lm_model = lm_model
        self.lm_tokenizer = lm_tokenizer

    # Compute the loss
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Move the inputs to the device
        lm_input_ids = inputs["lm_input_ids"].to(self.args.device)
        lm_attention_mask = inputs["lm_attention_mask"].to(self.args.device)
        lm_labels = inputs["lm_labels"].to(self.args.device)
        bsz = lm_input_ids.size(0)  # batch size

        # Get the weight tensor from the mask model
        weights_logit = model(None).unsqueeze(0).repeat(bsz, 1)

        # Compute the temperature-decay schedule
        if self.args.tau_decay_steps < 1:  # a value below 1 is interpreted as a fraction of the total training steps
            tau_decay_end_step = int(self.args.tau_decay_steps * self.state.max_steps)
        else:
            tau_decay_end_step = int(self.args.tau_decay_steps)

        # Adjust the temperature according to the current training step
        if self.state.global_step >= tau_decay_end_step:
            tau_temp = self.args.tau_temp_end
        else:
            decay_ratio = self.state.global_step / tau_decay_end_step
            tau_temp = self.args.tau_temp_begin - decay_ratio * (self.args.tau_temp_begin - self.args.tau_temp_end)

        # Apply Gumbel-Sigmoid or a plain sigmoid
        if self.args.use_gumbel:
            weights_tensor = gumbel_sigmoid(weights_logit, tau=tau_temp, hard=self.args.gumbel_hard)
        else:
            weights_tensor = torch.sigmoid(weights_logit)
        weights_tensor = weights_tensor.to(self.lm_model.device)

        # Run the language model with the weight tensor
        pred_outputs = self.lm_model(lm_input_ids, attention_mask=lm_attention_mask, weight_tensor=weights_tensor, labels=lm_labels, use_cache=False)
        pred_output_loss = pred_outputs.loss

        # Add the regularization term
        norm_lambda = self.args.norm_lambda
        normalizer = torch.sum(weights_tensor, dim=1).mean()
        loss = pred_output_loss + norm_lambda * normalizer

        # Logging
        self.log({"pred_output_loss": pred_output_loss.item(), "normalizer": normalizer.item(), "tau_temp": tau_temp, "total_loss": loss.item()})

        return (loss, pred_outputs) if return_outputs else loss

    # Prediction step, used to generate text
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Move the inputs to the device
        lm_input_ids = inputs["lm_input_ids"].to(self.args.device)
        lm_labels = inputs["lm_labels"].to(self.args.device)
        bsz = lm_input_ids.size(0)

        with torch.no_grad():
            weights_logit = model(None).unsqueeze(0).repeat(bsz, 1)  # get the weight tensor
            weights_tensor = (weights_logit.sigmoid() >= 0.5).to(weights_logit.dtype)  # binarize the weights at a 0.5 threshold
            weights_tensor = weights_tensor.to(self.lm_model.device)
            pred_output_loss = None

            # Generate text
            for one_lm_input_ids, one_lm_labels, one_weights_tensor in zip(lm_input_ids, lm_labels, weights_tensor):
                # Extract the prompt part
                one_lm_scr_input_ids = one_lm_input_ids[one_lm_labels == -100].unsqueeze(0)
                one_lm_attention_mask = torch.ones_like(one_lm_scr_input_ids)

                # Generate
                generate_ids = self.lm_model.generate(one_lm_scr_input_ids, attention_mask=one_lm_attention_mask, max_new_tokens=30, weight_tensor=one_weights_tensor.unsqueeze(0), do_sample=False)
                pred_str = self.lm_tokenizer.decode(generate_ids[0], skip_special_tokens=True)
                print("[GENERATED]", pred_str)

        if prediction_loss_only:
            return (pred_output_loss, None, None)

        lm_logits = None
        return (pred_output_loss, lm_logits, lm_labels)
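
As a quick check of the straight-through behaviour of the gumbel_sigmoid defined above (a minimal sketch, not part of the original notebook), you can confirm that the hard sample is exactly binary in the forward pass while gradients still reach the logits:

# Toy demo of the straight-through Gumbel-Sigmoid (requires the gumbel_sigmoid defined above).
import torch

logits = torch.tensor([[4.0, -4.0, 0.5, -0.5],
                       [2.0, -2.0, 1.0, -1.0]], requires_grad=True)

hard_sample = gumbel_sigmoid(logits, tau=0.05, hard=True)
print(hard_sample)        # entries are exactly 0.0 or 1.0

hard_sample.sum().backward()
print(logits.grad)        # non-zero: the backward pass goes through the soft sigmoid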
  3. Define the template strings and the preprocessing functions:
# Prompt templates for the different models
LLAMA_TEMPLATE = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{src}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
LLAMA_PLM_TEMPLATE = "<|begin_of_text|>{src}\n\n"
PHI3_TEMPLATE = "<|user|>\n{src}<|end|>\n<|assistant|>\n"
MISTRAL_TEMPLATE = "<s>[INST] {src}[/INST]"
QWEN2_TEMPLATE = "<|im_start|>user\n{src}<|im_end|>\n<|im_start|>assistant\n"
QWEN2_PLM_TEMPLATE = "{src}\n\n"
GEMMA2_TEMPLATE = "<bos><start_of_turn>user\n{src}<end_of_turn>\n<start_of_turn>model\n"


# Preprocess a batch: format the text according to the model type and build the labels
def preprocess_batch(samples, lm_tokenizer, model_type, fewshot_dataset=None):
    assert model_type in ["llama", "llama-plm", "phi3", "mistral", "qwen2", "qwen2-plm", "gemma2"]  # make sure the model type is valid
    template = ""

    # Pick the template for the model type and format the text
    if model_type == "llama":
        template = LLAMA_TEMPLATE
        combined_strs = [LLAMA_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "llama-plm":
        template = LLAMA_PLM_TEMPLATE
        combined_strs = [LLAMA_PLM_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|end_of_text|>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "phi3":
        template = PHI3_TEMPLATE
        combined_strs = [PHI3_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|end|>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "mistral":
        template = MISTRAL_TEMPLATE
        combined_strs = [MISTRAL_TEMPLATE.format(src=input_str.strip()) + f" {target_str.strip()}</s>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "qwen2":
        template = QWEN2_TEMPLATE
        combined_strs = [QWEN2_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|im_end|>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "qwen2-plm":
        template = QWEN2_PLM_TEMPLATE
        combined_strs = [QWEN2_PLM_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|endoftext|>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "gemma2":
        template = GEMMA2_TEMPLATE
        combined_strs = [GEMMA2_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<end_of_turn>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]

    # Tokenize into model inputs
    lm_inputs = lm_tokenizer(combined_strs, max_length=32768, truncation=True, padding=False, add_special_tokens=False, return_tensors="np")
    input_str_lens = [len(lm_tokenizer(template.format(src=input_str.strip()), add_special_tokens=False)["input_ids"]) for input_str in samples["input_str"]]  # length of the prompt part, used to build the label mask

    # Build the labels, masking the prompt part with -100
    labels = copy.deepcopy(lm_inputs["input_ids"])
    for i, input_str_len in enumerate(input_str_lens):
        labels[i][:input_str_len] = -100

    return {"lm_input_ids": lm_inputs["input_ids"], "lm_labels": labels}


# Preprocess a few-shot batch: build inputs that include in-context examples
def preprocess_fewshot_batch(samples, lm_tokenizer, model_type, fewshot_dataset):
    assert model_type in ["llama"]  # only Llama models are supported

    # Build the template containing the few-shot examples
    fewshot_input_str = "<|begin_of_text|>"
    for sample in fewshot_dataset:
        input_str, target_str = sample["input_str"], sample["target_str"]
        fewshot_input_str += f"<|start_header_id|>user<|end_header_id|>\n\n{input_str.strip()}<|eot_id|>"
        fewshot_input_str += f"<|start_header_id|>assistant<|end_header_id|>\n\n{target_str.strip()}<|eot_id|>"
    fewshot_input_str += "<|start_header_id|>user<|end_header_id|>\n\n{src}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    # Format the current inputs
    template = fewshot_input_str
    combined_strs = [template.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>" for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    lm_inputs = lm_tokenizer(combined_strs, max_length=32768, truncation=True, padding=False, add_special_tokens=False, return_tensors="np")

    # Compute the prompt length and build the label mask
    input_str_lens = [len(lm_tokenizer(template.format(src=input_str.strip()), add_special_tokens=False)["input_ids"]) for input_str in samples["input_str"]]
    labels = copy.deepcopy(lm_inputs["input_ids"])
    for i, input_str_len in enumerate(input_str_lens):
        labels[i][:input_str_len] = -100

    return {"lm_input_ids": lm_inputs["input_ids"], "lm_labels": labels}
  4. Define the training procedure:
# Training entry point: search for the head-mask weights
def search_weight_embed(lm_model, task, args, training_args):
    # Output directory
    training_args.output_dir = os.path.join(args.output_dir, task)
    os.makedirs(training_args.output_dir, exist_ok=True)

    # Model and tokenizer setup
    model_name = os.path.basename(args.lm_model_path)
    training_args.run_name = f"{model_name}-train_weight-{task}"
    lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model_path, use_fast=True)
    lm_tokenizer.pad_token = lm_tokenizer.eos_token
    lm_tokenizer.padding_side = "right"
    lm_tokenizer.truncation_side = "right"
    lm_model.generation_config.pad_token_id = lm_tokenizer.pad_token_id

    # Number of layers and attention heads
    n_layers = lm_model.config.num_hidden_layers
    n_heads = lm_model.config.num_attention_heads

    head_mask_model = SimpleTensorModel(tensor_length=n_layers * n_heads + n_layers).to(torch.device("cuda"))  # create the mask model

    # Dataset handling
    if not args.dataset_use_cache:
        datasets.disable_caching()

    train_dataset, dev_dataset = None, None
    fewshot_dataset = None

    # Handle the different dataset formats
    if "XNLI" in args.train_data_path:
        # XNLI dataset
        map_func = preprocess_batch

        LANG_LIST = task.split("_")
        LANG_DICT = {"ar": "Arabic", "fr": "French", "es": "Spanish", "de": "German", "en": "English", "ru": "Russian", "zh": "Chinese"}

        # Preprocess the XNLI dataset
        def _preprocess_xnli(dataset):
            # Build one dataset per language pair
            pair_datasets = []
            for src_lang, tgt_lang in itertools.combinations(LANG_LIST, 2):
                pair_dataset = dataset.map(lambda sample: {"input_str": f"{sample[src_lang]}", "target_str": sample[tgt_lang]}).select_columns(["input_str", "target_str"])
                pair_datasets.append(pair_dataset)

            return concatenate_datasets(pair_datasets)

        # Load and split the dataset
        dataset = Dataset.from_csv(args.train_data_path, sep='\t').select_columns(LANG_LIST)
        train_dataset = dataset.select(range(len(dataset) - 100))
        dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))
        train_dataset = _preprocess_xnli(train_dataset)
        dev_dataset = _preprocess_xnli(dev_dataset)
        train_dataset = train_dataset.shuffle(seed=args.seed)

    elif "function_vectors" in args.train_data_path:
        # Function-vectors task dataset
        dataset = Dataset.from_json(args.train_data_path)

        if "fewshot" in args.train_data_path:
            # Few-shot dataset
            map_func = preprocess_fewshot_batch
            dataset = dataset.map(lambda sample: {"input_str": sample["input"], "target_str": sample["output"]})
            dataset = dataset.remove_columns(["input", "output"])
            dataset = dataset.select(range(min(len(dataset) - 100, 10000))).shuffle(seed=args.seed)
            fewshot_dataset, dataset = dataset.select(range(5)), dataset.select(range(5, len(dataset)))

        else:
            # Regular (zero-shot) dataset
            map_func = preprocess_batch
            dataset = dataset.map(lambda sample: {"input_str": str(sample["input"]), "target_str": str(sample["output"])})
            dataset = dataset.remove_columns(["input", "output"])

        train_dataset = dataset.select(range(min(len(dataset) - 100, 10000)))
        dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))

    else:
        assert False, "Unsupported dataset"

    # Tokenize the datasets
    train_dataset_dir, dev_dataset_dir = os.path.dirname(args.train_data_path), os.path.dirname(args.dev_data_path)
    train_dataset = train_dataset.map(
        lambda sample: map_func(sample, lm_tokenizer, model_type=args.model_type, fewshot_dataset=fewshot_dataset),
        batched=True,
        batch_size=128,
        num_proc=1,
        cache_file_name=os.path.join(train_dataset_dir, ".cache/train_dataset_cache.arrow") if args.dataset_use_cache else None
    ).filter(
        lambda sample: len(sample["lm_input_ids"]) <= args.max_seq_length,
        batched=False
    ).shuffle(seed=args.seed)
    dev_dataset = dev_dataset.map(
        lambda sample: map_func(sample, lm_tokenizer, model_type=args.model_type, fewshot_dataset=fewshot_dataset),
        batched=True,
        batch_size=128,
        num_proc=1,
        cache_file_name=os.path.join(dev_dataset_dir, ".cache/dev_dataset_cache.arrow") if args.dataset_use_cache else None
    ).filter(
        lambda sample: len(sample["lm_input_ids"]) <= args.max_seq_length,
        batched=False
    )

    # Data collator
    data_collator = CustomDataCollator(lm_tokenizer=lm_tokenizer, padding=True)

    # Build the trainer
    trainer = CustomTrainer(
        model=head_mask_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=data_collator,
        compute_metrics=None,
        lm_model=lm_model,
        lm_tokenizer=lm_tokenizer,
    )

    # Train
    trainer.train(resume_from_checkpoint=False)
  5. Define the arguments:
# Model arguments
@dataclass
class ModelArguments:
    embed_model_path: str = field(default="../model/bge-reranker-v2-m3", metadata={"help": "Path to the embedding model"})
    lm_model_path: str = field(default="../model/Llama-3.1-8B-Instruct", metadata={"help": "Path to the language model"})
    model_type: str = field(default="llama", metadata={"help": "Model type (plm means a plain language model)"})

# Data arguments
@dataclass
class DataArguments:
    train_data_path: str = field(default="dataset/XNLI-15way/xnli.15way.orig.tsv", metadata={"help": "Path to the training data"})
    dev_data_path: str = field(default="dataset/XNLI-15way/xnli.15way.orig.tsv", metadata={"help": "Path to the validation data"})
    dataset_use_cache: bool = field(default=False, metadata={"help": "Whether to use the dataset cache"})
    max_seq_length: int = field(default=32768, metadata={"help": "Maximum input sequence length"})

# Training arguments
@dataclass
class CustomTrainingArguments(TrainingArguments):
    tau_temp_begin: float = field(default=4.0, metadata={"help": "Initial Gumbel-Sigmoid temperature"})
    tau_temp_end: float = field(default=0.05, metadata={"help": "Final Gumbel-Sigmoid temperature"})
    tau_decay_steps: float = field(default=0.4, metadata={"help": "Gumbel-Sigmoid temperature decay steps (a value below 1 is a fraction of the total training steps)"})
    norm_lambda: float = field(default=0, metadata={"help": "Lambda coefficient for the weight-tensor norm regularizer"})
    norm_power: float = field(default=1, metadata={"help": "Power used in the weight-tensor norm regularizer"})
    use_gumbel: bool = field(default=True, metadata={"help": "Whether to apply Gumbel-Sigmoid to the weight tensor"})
    gumbel_hard: bool = field(default=True, metadata={"help": "Whether to use the hard (straight-through) Gumbel-Sigmoid"})


# Parse with the Hugging Face argument parser
hfparser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
model_args, data_args, training_args, _ = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
data_args.dataset_use_cache = False
training_args.output_dir = "output/llama/xnli"
training_args.overwrite_output_dir = True
training_args.num_train_epochs = 10
training_args.per_device_train_batch_size = 4
training_args.per_device_eval_batch_size = 1
training_args.remove_unused_columns = False  # keep all columns of the dataset
training_args.save_steps = 31
training_args.save_total_limit = 200
training_args.save_only_model = True
training_args.logging_dir = "logs"
training_args.logging_steps = 500
training_args.learning_rate = 1e-2
training_args.lr_scheduler_type = "cosine_with_min_lr"
training_args.lr_scheduler_kwargs = {"min_lr": 1e-4}  # minimum learning rate of the decay
training_args.bf16 = True
training_args.warmup_ratio = 0.1  # warm up the learning rate over the first 10% of the training steps
training_args.gradient_accumulation_steps = 4

# Merge all arguments
args = argparse.Namespace(**vars(model_args), **vars(training_args), **vars(data_args))
print(args)

Namespace(embed_model_path='../model/bge-reranker-v2-m3', lm_model_path='../model/Llama-3.1-8B-Instruct', model_type='llama', output_dir='output/llama/xnli', overwrite_output_dir=True, do_train=False, do_eval=False, do_predict=False, eval_strategy=<IntervalStrategy.NO: 'no'>, prediction_loss_only=False, per_device_train_batch_size=4, per_device_eval_batch_size=1, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=4, eval_accumulation_steps=None, eval_delay=0, torch_empty_cache_steps=None, learning_rate=0.01, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=10, max_steps=-1, lr_scheduler_type='cosine_with_min_lr', lr_scheduler_kwargs={'min_lr': 0.0001}, warmup_ratio=0.1, warmup_steps=0, log_level='passive', log_level_replica='warning', log_on_each_node=True, logging_dir='logs', logging_strategy=<IntervalStrategy.STEPS: 'steps'>, logging_first_step=False, logging_steps=500, logging_nan_inf_filter=True, save_strategy=<SaveStrategy.STEPS: 'steps'>, save_steps=31, save_total_limit=200, save_safetensors=True, save_on_each_node=False, save_only_model=True, restore_callback_states_from_checkpoint=False, no_cuda=False, use_cpu=False, use_mps_device=False, seed=42, data_seed=None, jit_mode_eval=False, use_ipex=False, bf16=True, fp16=False, fp16_opt_level='O1', half_precision_backend='auto', bf16_full_eval=False, fp16_full_eval=False, tf32=None, local_rank=0, ddp_backend=None, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=None, dataloader_num_workers=0, dataloader_prefetch_factor=None, past_index=-1, run_name='trainer_output', disable_tqdm=False, remove_unused_columns=False, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, fsdp=[], fsdp_min_num_params=0, fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, tp_size=0, fsdp_transformer_layer_cls_to_wrap=None, accelerator_config=AcceleratorConfig(split_batches=False, dispatch_batches=None, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False), deepspeed=None, label_smoothing_factor=0.0, optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>, optim_args=None, adafactor=False, group_by_length=False, length_column_name='length', report_to=[], ddp_find_unused_parameters=None, ddp_bucket_cap_mb=None, ddp_broadcast_buffers=None, dataloader_pin_memory=True, dataloader_persistent_workers=False, skip_memory_metrics=True, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, hub_model_id=None, hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>, hub_token=None, hub_private_repo=None, hub_always_push=False, gradient_checkpointing=False, gradient_checkpointing_kwargs=None, include_inputs_for_metrics=False, include_for_metrics=[], eval_do_concat_batches=True, fp16_backend='auto', push_to_hub_model_id=None, push_to_hub_organization=None, push_to_hub_token=None, mp_parameters='', auto_find_batch_size=False, full_determinism=False, torchdynamo=None, ray_scope='last', ddp_timeout=1800, torch_compile=False, torch_compile_backend=None, torch_compile_mode=None, include_tokens_per_second=False, include_num_input_tokens_seen=False, neftune_noise_alpha=None, optim_target_modules=None, batch_eval_metrics=False, eval_on_start=False, use_liger_kernel=False, eval_use_gather_object=False, 
average_tokens_across_devices=False, tau_temp_begin=4.0, tau_temp_end=0.05, tau_decay_steps=0.4, norm_lambda=0, norm_power=1, use_gumbel=True, gumbel_hard=True, distributed_state=Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
, _n_gpu=1, __cached__setup_devices=device(type='cuda', index=0), deepspeed_plugin=None, train_data_path='dataset/XNLI-15way/xnli.15way.orig.tsv', dev_data_path='dataset/XNLI-15way/xnli.15way.orig.tsv', dataset_use_cache=False, max_seq_length=32768)
  6. Select the model class:
if "phi3" in args.model_type:
casual_lm = Phi3ForCausalLM
elif "llama" in args.model_type:
casual_lm = LlamaForCausalLM
elif "mistral" in args.model_type:
casual_lm = MistralForCausalLM
elif "qwen2" in args.model_type:
casual_lm = Qwen2ForCausalLM
elif "gemma2" in args.model_type:
casual_lm = Gemma2ForCausalLM
else:
assert False, "Unsupported model type"
print(f"Using model type: {casual_lm.__name__}")

Using model type: LlamaForCausalLM
  7. Load the language model:
lm_model = casual_lm.from_pretrained(
    args.lm_model_path,
    local_files_only=True,
    device_map=torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") if args.distributed_state.use_distributed else "auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    max_position_embeddings=32768
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:21<00:00,  5.39s/it]
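
A small check (not in the original notebook) that the loaded model matches the mask length assumed by SimpleTensorModel(tensor_length=1056), i.e. one entry per attention head plus one entry per layer:

n_layers = lm_model.config.num_hidden_layers
n_heads = lm_model.config.num_attention_heads
print(n_layers, n_heads, n_layers * n_heads + n_layers)  # expected for Llama-3.1-8B: 32 32 1056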
  8. Machine-translation training, using English-Chinese as an example:
train_data_dir = args.train_data_path
if "XNLI" in train_data_dir:
    # XNLI dataset: train every language pair
    ALL_LANGS = ["en", "zh"]
    for i, (src_lang, tgt_lang) in enumerate(itertools.permutations(ALL_LANGS, 2)):
        print(f"\n********** [{i+1}/{len(ALL_LANGS) * (len(ALL_LANGS) - 1)}] Training {src_lang} -> {tgt_lang} **********\n")
        search_weight_embed(lm_model, f"{src_lang}_{tgt_lang}", args, training_args)

elif "function_vectors" in train_data_dir:
    # Function-vectors dataset: walk over every JSON file
    for root, dirs, files in os.walk(train_data_dir):
        for file in files:
            if file.endswith(".json"):
                task = os.path.basename(file).split(".")[0]
                print(f"\n********** Training {task} **********\n")
                args.train_data_path = os.path.join(root, file)
                search_weight_embed(lm_model, task, args, training_args)

else:
    assert False, "Unsupported dataset"

********** [1/2] Training en -> zh **********

Map: 100%|██████████| 9900/9900 [00:00<00:00, 24753.96 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 16005.74 examples/s]
Map: 100%|██████████| 9900/9900 [00:01<00:00, 7272.35 examples/s]
Filter: 100%|██████████| 9900/9900 [00:00<00:00, 31385.21 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 5276.52 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 18811.07 examples/s]
Step	Training Loss
500	    2.330100
1000	1.664200
1500	1.343300
2000	1.235700
2500	1.195600
3000	1.178300
3500	1.163300
4000	1.172700
4500	1.192800
5000	1.158600
5500	1.169000
6000	1.143300

********** [2/2] Training zh -> en **********

Map: 100%|██████████| 9900/9900 [00:00<00:00, 32588.55 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 17755.17 examples/s]
Map: 100%|██████████| 9900/9900 [00:01<00:00, 6584.71 examples/s]
Filter: 100%|██████████| 9900/9900 [00:00<00:00, 29741.66 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 5468.88 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 17750.66 examples/s]
Step	Training Loss
500	    1.758600
1000	1.261800
1500	1.209500
2000	1.180300
2500	1.206600
3000	1.192100
3500	1.170300
4000	1.188600
4500	1.199100
5000	1.178500
5500	1.170500
6000	1.145500
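
For reference, the final checkpoint number can be estimated from the effective batch size; a rough sketch (the post-filter training-set size of about 9,880 samples is an assumption) that explains why the XNLI runs end around step 6180, the checkpoints loaded later in Section 2.2:

import math

n_train = 9880                    # approximate training-set size after the length filter
per_device_bs, grad_accum = 4, 4  # values set in the TrainingArguments above
epochs = 10

steps_per_epoch = math.ceil(math.ceil(n_train / per_device_bs) / grad_accum)
print(steps_per_epoch, steps_per_epoch * epochs)  # ~618 steps per epoch, ~6180 steps in total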
  9. Switch to the task-specific arguments:
data_args.train_data_path = "dataset/function_vectors/abstractive"
data_args.dev_data_path = "dataset/function_vectors/abstractive"
training_args.output_dir = "output/llama/fv"
training_args.num_train_epochs = 3
training_args.max_steps = 6250
training_args.eval_steps = 625

# Merge all arguments
args = argparse.Namespace(**vars(model_args), **vars(training_args), **vars(data_args))
print(args)

Namespace(embed_model_path='../model/bge-reranker-v2-m3', lm_model_path='../model/Llama-3.1-8B-Instruct', model_type='llama', output_dir='output/llama/fv', overwrite_output_dir=True, do_train=False, do_eval=False, do_predict=False, eval_strategy=<IntervalStrategy.NO: 'no'>, prediction_loss_only=False, per_device_train_batch_size=4, per_device_eval_batch_size=1, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=4, eval_accumulation_steps=None, eval_delay=0, torch_empty_cache_steps=None, learning_rate=0.01, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=3, max_steps=6250, lr_scheduler_type='cosine_with_min_lr', lr_scheduler_kwargs={'min_lr': 0.0001}, warmup_ratio=0.1, warmup_steps=0, log_level='passive', log_level_replica='warning', log_on_each_node=True, logging_dir='logs', logging_strategy=<IntervalStrategy.STEPS: 'steps'>, logging_first_step=False, logging_steps=500, logging_nan_inf_filter=True, save_strategy=<SaveStrategy.STEPS: 'steps'>, save_steps=31, save_total_limit=200, save_safetensors=True, save_on_each_node=False, save_only_model=True, restore_callback_states_from_checkpoint=False, no_cuda=False, use_cpu=False, use_mps_device=False, seed=42, data_seed=None, jit_mode_eval=False, use_ipex=False, bf16=True, fp16=False, fp16_opt_level='O1', half_precision_backend='auto', bf16_full_eval=False, fp16_full_eval=False, tf32=None, local_rank=0, ddp_backend=None, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=625, dataloader_num_workers=0, dataloader_prefetch_factor=None, past_index=-1, run_name='trainer_output', disable_tqdm=False, remove_unused_columns=False, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, fsdp=[], fsdp_min_num_params=0, fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, tp_size=0, fsdp_transformer_layer_cls_to_wrap=None, accelerator_config=AcceleratorConfig(split_batches=False, dispatch_batches=None, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False), deepspeed=None, label_smoothing_factor=0.0, optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>, optim_args=None, adafactor=False, group_by_length=False, length_column_name='length', report_to=[], ddp_find_unused_parameters=None, ddp_bucket_cap_mb=None, ddp_broadcast_buffers=None, dataloader_pin_memory=True, dataloader_persistent_workers=False, skip_memory_metrics=True, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, hub_model_id=None, hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>, hub_token=None, hub_private_repo=None, hub_always_push=False, gradient_checkpointing=False, gradient_checkpointing_kwargs=None, include_inputs_for_metrics=False, include_for_metrics=[], eval_do_concat_batches=True, fp16_backend='auto', push_to_hub_model_id=None, push_to_hub_organization=None, push_to_hub_token=None, mp_parameters='', auto_find_batch_size=False, full_determinism=False, torchdynamo=None, ray_scope='last', ddp_timeout=1800, torch_compile=False, torch_compile_backend=None, torch_compile_mode=None, include_tokens_per_second=False, include_num_input_tokens_seen=False, neftune_noise_alpha=None, optim_target_modules=None, batch_eval_metrics=False, eval_on_start=False, use_liger_kernel=False, eval_use_gather_object=False, average_tokens_across_devices=False, 
tau_temp_begin=4.0, tau_temp_end=0.05, tau_decay_steps=0.4, norm_lambda=0, norm_power=1, use_gumbel=True, gumbel_hard=True, distributed_state=Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
, _n_gpu=1, __cached__setup_devices=device(type='cuda', index=0), deepspeed_plugin=None, train_data_path='dataset/function_vectors/abstractive', dev_data_path='dataset/function_vectors/abstractive', dataset_use_cache=False, max_seq_length=32768)
  10. Task training, using the antonym task as an example:
train_data_dir = args.train_data_path
if "XNLI" in train_data_dir:
    # XNLI dataset: train every language pair
    ALL_LANGS = ["en", "zh", "fr", "de", "es", "ru", "ar"]
    for i, (src_lang, tgt_lang) in enumerate(itertools.permutations(ALL_LANGS, 2)):
        print(f"\n********** [{i+1}/{len(ALL_LANGS) * (len(ALL_LANGS) - 1)}] Training {src_lang} -> {tgt_lang} **********\n")
        search_weight_embed(lm_model, f"{src_lang}_{tgt_lang}", args, training_args)

elif "function_vectors" in train_data_dir:
    # Function-vectors dataset: walk over every JSON file
    for root, dirs, files in os.walk(train_data_dir):
        for file in files:
            if file.endswith("antonym.json"):  # use ".json" to train every task
                task = os.path.basename(file).split(".")[0]
                print(f"\n********** Training {task} **********\n")
                args.train_data_path = os.path.join(root, file)
                search_weight_embed(lm_model, task, args, training_args)

else:
    assert False, "Unsupported dataset"

********** Training antonym **********

Generating train split: 2398 examples [00:00, 77072.94 examples/s]
Map: 100%|██████████| 2398/2398 [00:00<00:00, 36112.36 examples/s]
Map: 100%|██████████| 2298/2298 [00:00<00:00, 7820.68 examples/s]
Filter: 100%|██████████| 2298/2298 [00:00<00:00, 59810.06 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 3802.15 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 12812.90 examples/s]
Step	Training Loss
500	    6.312800
1000	1.012000
1500	0.716200
2000	0.607200
2500	0.609400
3000	0.585000
3500	0.594200
4000	0.583300
4500	0.586900
5000	0.554000
5500	0.540600
6000	0.546000
  11. After training finishes, the output directory contains a number of folders whose names start with checkpoint-, each containing:
model.safetensors  # weights of the trained mask model
trainer_state.json # record of the training state
training_args.bin  # training configuration and arguments
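
If you prefer not to hard-code a step number when loading a mask later, a minimal sketch (not in the original notebook) that picks the latest checkpoint of one task directory:

# Locate the mask logits of the most recent checkpoint for one task.
import os
import safetensors.torch

task_dir = "output/llama/xnli/en_zh"  # layout from search_weight_embed: <output_dir>/<task>/checkpoint-<step>/
ckpts = [d for d in os.listdir(task_dir) if d.startswith("checkpoint-")]
latest = max(ckpts, key=lambda d: int(d.split("-")[-1]))
mask_logits = safetensors.torch.load_file(os.path.join(task_dir, latest, "model.safetensors"))["tensor"]
print(latest, mask_logits.shape)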

2.2 Evaluation

  1. Load the trained attention-head mask weights for the different tasks:
weight_dict = {}
for root, dirs, files in os.walk("output/llama/xnli"):
    for file in files:
        if file.endswith(".safetensors") and "-6180" in root:
            langs = os.path.basename(os.path.dirname(root))  # name of the parent directory (the language pair)
            path = os.path.join(root, file)  # full file path
            weight_dict[langs] = safetensors.torch.load_file(path)["tensor"]  # load the "tensor" entry from the safetensors file
for root, dirs, files in os.walk("output/llama/fv"):
    for file in files:
        if file.endswith(".safetensors") and "-6250" in root:
            langs = os.path.basename(os.path.dirname(root))
            path = os.path.join(root, file)
            weight_dict[langs] = safetensors.torch.load_file(path)["tensor"]

print("Total tasks:", len(weight_dict))
print("Task name", "\t", "Weight", "\t", "# of up heads")
for k, v in sorted(weight_dict.items()):
    print(k, "\t", v, "\t", (v.sigmoid() >= 0.5).sum().item())  # task name, logits, number of active heads

Total tasks: 5
Task name 	 Weight 	 # of up heads
adjective_v_verb_3 	 tensor([6.4547, 0.8442, 4.9701,  ..., 6.9430, 5.0350, 3.9923]) 	 845
antonym 	 tensor([4.9961, 4.3987, 4.5589,  ..., 7.8902, 3.0385, 3.9649]) 	 864
en_zh 	 tensor([5.2550, 1.9984, 5.9432,  ..., 7.5652, 3.0530, 4.0198]) 	 856
verb_v_adjective_3 	 tensor([6.7996, 2.7995, 6.5897,  ..., 7.7574, 5.2005, 3.9966]) 	 871
zh_en 	 tensor([ 1.6995,  3.6389,  6.9171,  ...,  4.2308, -1.9574,  3.9649]) 	 966
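
To get a feel for which heads a task selects, the binarized logits can be reshaped into a per-layer view and plotted; a minimal sketch (not in the original notebook) that assumes the (32, 33) layout used later in this notebook, i.e. 32 head entries plus one per-layer entry per row:

# Heatmap of the active heads for one task (relies on the matplotlib/seaborn imports from Section 2.1).
mask = (weight_dict["antonym"].sigmoid() >= 0.5).float().view(32, 33)

plt.figure(figsize=(10, 5))
sns.heatmap(mask[:, :-1].numpy(), cbar=False)  # drop the per-layer column, keep the 32x32 head grid
plt.xlabel("Head")
plt.ylabel("Layer")
plt.title("Active heads for the antonym task")
plt.show()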
  2. Load the model and tokenizer, and select the template:
DEVICE = "cuda:0"
MODEL_DIR = "../model/Llama-3.1-8B-Instruct"
model_name = os.path.basename(MODEL_DIR).lower()

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token  # use the EOS token for padding

model = LlamaForCausalLM.from_pretrained(MODEL_DIR, local_files_only=True, device_map=DEVICE, torch_dtype=torch.bfloat16, attn_implementation="eager", max_position_embeddings=2048)
model.generation_config.pad_token_id = tokenizer.pad_token_id  # set the padding token ID in the generation config

# Number of layers and attention heads per layer
n_layers = model.config.num_hidden_layers
n_heads = model.config.num_attention_heads

# Pick the template matching the model name
if "llama" in model_name:
    if "instruct" in model_name:
        template = LLAMA_TEMPLATE
        print("using LLAMA_TEMPLATE")
    else:
        template = LLAMA_PLM_TEMPLATE
        print("using LLAMA_PLM_TEMPLATE")
elif "qwen2" in model_name:
    if "instruct" in model_name:
        template = QWEN2_TEMPLATE
        print("using QWEN2_TEMPLATE")
    else:
        template = QWEN2_PLM_TEMPLATE
        print("using QWEN2_PLM_TEMPLATE")
else:
    print("Unknown model")

Loading checkpoint shards: 100%|██████████| 4/4 [00:19<00:00,  4.79s/it]
using LLAMA_TEMPLATE
  3. Task prompts:
TASK_DICT = {
    "lowercase_first_letter": "Output the first letter of the given word in lowercase.",
    "park-country": "Identify the country where the given national park is located.",
    "synonym": "Identify a synonym for the given word.",
    "ag_news": "Classify the given news headline into one of the categories: Business, Science, Sports, or World. Provide only the category name.",
    "word_length": "Determine the number of letters in the given word and output the count.",
    "present-past": "Convert the given verb from its present tense to its simple past tense.",
    "capitalize": "Output the given word with its first letter capitalized.",
    "landmark-country": "Identify the country where the given landmark is located.",
    "english-german": "Translate the given English word into German.",
    "sentiment": "Determine the sentiment of the given input. Output either 'positive' or 'negative'.",
    "country-capital": "What is the capital of the given country? Provide only the name of the capital.",
    "person-occupation": "Identify the occupation of the given individual.",
    "country-currency": "What is the official currency of the given country?",
    "lowercase_last_letter": "Output the last letter of the given word in lowercase.",
    "person-sport": "Identify the sport associated with the given individual.",
    "person-instrument": "Identify the musical instrument played by the given musician.",
    "antonym": "Identify the antonym of the given word.",
    "capitalize_last_letter": "Output the last letter of the given word in uppercase.",
    "english-french": "Translate the given English word into French.",
    "next_item": "What is the next sequential item following the given input?",
    "singular-plural": "Provide the plural form of the given singular noun.",
    "capitalize_second_letter": "Output the second letter of the given word in uppercase.",
    "prev_item": "What is the item that comes before the given input in a sequential context?",
    "capitalize_first_letter": "Output the first letter of the given word in uppercase.",
    "english-spanish": "Translate the given English word into Spanish.",
    "next_capital_letter": "What is the next uppercase letter in alphabetical order after the given input?",
    "national_parks": "Identify the U.S. state where the given national park is located.",
    "product-company": "Identify the company associated with the given product.",
    "conll2003_organization": "Extract the organization mentioned in the given text.",
    "conll2003_person": "Extract the name of the person mentioned in the given text.",
    "conll2003_location": "Extract the location mentioned in the given text.",
    "adjective_v_verb_3": "From the given words, identify the one that is an adjective.",
    "object_v_concept_3": "From the given words, identify the one that is a object.",
    "verb_v_adjective_3": "From the given words, identify the one that is a verb.",
    "fruit_v_animal_3": "From the given words, identify the one that is a fruit."
}
  4. Evaluate the antonym task:
# Evaluate the function-vectors tasks
def eval_function_vectors(use_instruction, use_mask):
    # Build the evaluation datasets
    dev_datasets = []
    for root, dirs, files in os.walk("dataset/function_vectors/abstractive"):
        for file in files:
            if file.endswith("antonym.json"):
                task = os.path.basename(file).split(".")[0]  # task name from the file name
                data_path = os.path.join(root, file)  # full file path
                dataset = datasets.Dataset.from_json(data_path)  # load the dataset from JSON
                dataset = dataset.map(
                    lambda sample: {
                        "input_str": TASK_DICT[task] + f"\n\nInput:\n\n{sample['input']}\n\nOutput:\n\n" if use_instruction else sample["input"],
                        "target_str": sample["output"]
                    }
                )  # map the dataset into input/target strings
                dataset = dataset.remove_columns(["input", "output"])  # drop the original columns
                dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))  # last 100 samples as the dev set
                dev_datasets.append((task, dev_dataset))  # store (task name, dataset)

    # Run the evaluation
    for task, dataset in tqdm(dev_datasets):
        mask_weight = (weight_dict[task].sigmoid() >= 0.5).float().numpy()  # binarize: sigmoid >= 0.5 becomes 1, otherwise 0
        mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)  # add the batch dimension and move to the model device

        correct = 0
        for sample in dataset:
            input_str, target_str = str(sample["input_str"]), str(sample["target_str"])
            combined_str = template.format(src=input_str.strip())  # format the input with the template
            lm_inputs_src = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)

            with torch.no_grad():
                if use_mask:
                    generate_ids = model.generate(**lm_inputs_src, max_new_tokens=10, weight_tensor=mask_tensor, do_sample=False)
                else:
                    generate_ids = model.generate(**lm_inputs_src, max_new_tokens=10, weight_tensor=None, do_sample=False)
                pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)  # decode the generated text

            # The accuracy criterion depends on the task type
            if task in [
                "capitalize_first_letter", "capitalize_last_letter", "capitalize_second_letter", "capitalize",
                "lowercase_first_letter", "lowercase_last_letter", "next_capital_letter", "next_item", "prev_item", "commonsense_qa",
                "conll2003_organization", "conll2003_person", "conll2003_location",
                "adjective_v_verb_3", "object_v_concept_3", "verb_v_adjective_3", "fruit_v_animal_3",
            ]:
                if pred_str.strip().startswith(target_str.strip()):  # the prediction must start with the target
                    correct += 1
            else:  # otherwise the target must be contained in the prediction
                if target_str.strip() in pred_str.strip():
                    correct += 1
        score = correct / len(dataset)
        print(f"{task}: {score}")

print("With instruction, with mask")
eval_function_vectors(use_instruction=True, use_mask=True)
print("With instruction, without mask")
eval_function_vectors(use_instruction=True, use_mask=False)
print("Without instruction, with mask")
eval_function_vectors(use_instruction=False, use_mask=True)
print("Without instruction, without mask")
eval_function_vectors(use_instruction=False, use_mask=False)

using LLAMA_TEMPLATE
With instruction, with mask
100%|██████████| 1/1 [00:10<00:00, 10.85s/it]
antonym: 0.63
With instruction, without mask
100%|██████████| 1/1 [00:30<00:00, 30.66s/it]
antonym: 0.3
Without instruction, with mask
100%|██████████| 1/1 [00:08<00:00,  8.78s/it]
antonym: 0.76
Without instruction, without mask
100%|██████████| 1/1 [00:30<00:00, 30.58s/it]
antonym: 0.12
  5. Inspect the attention distribution on the contrastive extraction tasks:
dev_datasets = []
for root, dirs, files in os.walk("dataset/function_vectors/abstractive"):
    for file in files:
        if file.endswith("verb_3.json") or file.endswith("adjective_3.json"):
            task = os.path.basename(file).split(".")[0]
            if "_v_" not in task: continue  # only keep the contrastive tasks
            data_path = os.path.join(root, file)
            dataset = datasets.Dataset.from_json(data_path)
            dataset = dataset.map(lambda sample: {"input_str": sample["input"], "target_str": sample["output"]})
            dataset = dataset.remove_columns(["input", "output"])
            dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))
            dev_datasets.append((task, dev_dataset))


# Find the option spans in the token sequence
def find_segments_(l):
    indices_271 = [i for i, x in enumerate(l) if x == 271]  # positions of token 271
    start, end = indices_271[0], indices_271[1]  # take the first two occurrences
    segment = l[start + 1:end]  # the span between them

    segments = []  # final spans
    current_segment = []  # span currently being built
    for idx, value in enumerate(segment, start=start + 1):
        if value == 11:
            if current_segment:
                segments.append(current_segment)  # close the current span
            current_segment = []  # reset
        else:
            current_segment.append(idx)  # add the index to the current span
    if current_segment:
        segments.append(current_segment)  # add the last span

    return segments


# Collect the per-task attention statistics
task_attentions = {}
for task, dataset in tqdm(dev_datasets):
    mask_weight = (weight_dict[task].sigmoid() >= 0.5).float().numpy()
    mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)

    choose_the, choose_from = task.split("_v_")[0], task.split("_v_")[1][:-2]  # parse "choose what" and "choose from what" from the task name
    instructed_attn, used_attn, unused_attn = [], [], []  # attention values for the three settings
    for sample in dataset:
        target_choice_idx = sample["input_str"].split(", ").index(sample["target_str"])  # index of the target option in the input list
        input_inst_str = f"{sample['input_str']}\n\nChoose the {choose_the} out of {choose_from}s:"  # input with the instruction
        input_str = f"{sample['input_str']}"  # input without the instruction

        lm_inputs_inst_src = tokenizer([LLAMA_TEMPLATE.format(src=input_inst_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)  # tokenize the instructed input
        lm_inputs_inst_src_choice_idx = find_segments_(lm_inputs_inst_src.input_ids[0].tolist())  # token indices of the options
        assert len(lm_inputs_inst_src_choice_idx) == 3  # there should be three options

        lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)  # tokenize the bare input
        lm_inputs_src_choice_idx = lm_inputs_inst_src_choice_idx  # reuse the same option indices

        with torch.no_grad():
            # With instruction, no mask
            original_output = model(**lm_inputs_inst_src, weight_tensor=None, output_attentions=True)
            original_attention = torch.stack(original_output.attentions)
            last_token_attn = original_attention[:, -1, :, -1]  # attention of the last token over all other tokens
            original_token_attention = last_token_attn.mean(dim=(0, 1)).float().cpu().numpy()  # average attention
            original_choice_attentions = [original_token_attention[lm_inputs_inst_src_choice_idx[i]].sum() for i in range(3)]  # total attention received by each option
            original_choice_attentions = np.array(original_choice_attentions) / sum(original_choice_attentions)  # normalize to proportions
            instructed_attn.append(original_choice_attentions[target_choice_idx])  # proportion received by the target option

            # Without instruction, with mask
            weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_attentions=True)
            weighted_attention = torch.stack(weighted_output.attentions)
            last_token_attn = weighted_attention[:, -1, :, -1]
            last_token_attn[mask_tensor.view(32, 33)[:, :-1] == 0] = 0  # zero out the attention of the inactive heads
            layer_head_num = (mask_tensor.view(32, 33)[:, :-1] != 0).sum(dim=1)  # number of active heads per layer
            weighted_token_attention = (last_token_attn.sum(dim=(0, 1)) / (mask_tensor.sum() - 32)).float().cpu().numpy()
            weighted_choice_attentions = [weighted_token_attention[lm_inputs_src_choice_idx[i]].sum() for i in range(3)]
            weighted_choice_attentions = np.array(weighted_choice_attentions) / sum(weighted_choice_attentions)
            used_attn.append(weighted_choice_attentions[target_choice_idx])

            # Without instruction, inverted mask
            weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_attentions=True)
            weighted_attention = torch.stack(weighted_output.attentions)
            last_token_unused_attn = weighted_attention[:, -1, :, -1]
            last_token_unused_attn[mask_tensor.view(32, 33)[:, :-1] == 1] = 0  # zero out the attention of the active heads
            weighted_token_unused_attention = (last_token_unused_attn.sum(dim=(0, 1)) / (1056 - mask_tensor.sum())).float().cpu().numpy()
            weighted_choice_unused_attentions = [weighted_token_unused_attention[lm_inputs_src_choice_idx[i]].sum() for i in range(3)]
            weighted_choice_unused_attentions = np.array(weighted_choice_unused_attentions) / sum(weighted_choice_unused_attentions)
            unused_attn.append(weighted_choice_unused_attentions[target_choice_idx])
    task_attentions[task] = {"Instructed": np.mean(instructed_attn), "Used": np.mean(used_attn), "Unused": np.mean(unused_attn)}

# Convert the dictionary into a DataFrame and display it
df = pd.DataFrame.from_dict(task_attentions, orient="index")
display(df)
100%|██████████| 2/2 [00:19<00:00,  9.96s/it]
                    Instructed	Used	    Unused
verb_v_adjective_3	0.345372	0.426140	0.379153
adjective_v_verb_3	0.408563	0.462195	0.483597
  6. Evaluate the English-Chinese translation task, using perplexity (PPL) as the metric:
# Language list and the corresponding language-name dictionary
LANG_LIST = ["en", "zh"]
LANG_DICT = {"ar": "Arabic", "fr": "French", "es": "Spanish", "de": "German", "en": "English", "ru": "Russian", "zh": "Chinese"}


# Load the XNLI-15way dataset, keep only the selected language columns and the last 100 samples as the dev set
dataset = datasets.Dataset.from_csv("dataset/XNLI-15way/xnli.15way.orig.tsv", sep='\t').select_columns(LANG_LIST)
dataset = dataset.select(range(len(dataset) - 100, len(dataset)))


def eval_xnli_ppl(use_random, use_instruction):
    langs_original_ppl, langs_weighted_ppl = {}, {}
    for src_lang, tgt_lang in tqdm(list(itertools.permutations(LANG_LIST, 2))):
        mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy()

        # Optionally shuffle the head entries to build a random mask
        if use_random:
            mask = np.zeros_like(mask_weight, dtype=bool)
            mask[np.arange(n_heads, (n_heads + 1) * n_layers, n_heads + 1)] = True
            random_weight = mask_weight[~mask]
            np.random.shuffle(random_weight)
            mask_weight[~mask] = random_weight

        # Build the dataset, either with the instruction or with the bare source text
        dev_dataset = dataset.map(
            lambda sample: {
                "input_str": sample[src_lang] + f"\n\nTranslate into {LANG_DICT[tgt_lang]}:" if use_instruction else sample[src_lang],
                "target_str": sample[tgt_lang]
            }
        ).select_columns(["input_str", "target_str"])

        # Initialize the result dictionaries; the same-language entry is set to 0
        if src_lang not in langs_original_ppl:
            langs_original_ppl[src_lang] = {src_lang: 0}
        if src_lang not in langs_weighted_ppl:
            langs_weighted_ppl[src_lang] = {src_lang: 0}
        original_ppl, weighted_ppl = [], []  # per-sample loss lists

        for input_str, target_str in zip(*dev_dataset[:100].values()):
            combined_str = LLAMA_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>"
            lm_inputs = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)
            input_str_len = tokenizer(LLAMA_TEMPLATE.format(src=input_str.strip()), add_special_tokens=False, return_tensors="pt")["input_ids"].size(-1)
            labels = copy.deepcopy(lm_inputs["input_ids"]).to(DEVICE)
            labels[:, :input_str_len] = -100

            with torch.no_grad():
                # Without mask
                original_output = model(**lm_inputs, labels=labels)
                original_ppl.append(original_output.loss.item())  # record the loss
                # With mask
                mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
                weighted_output = model(**lm_inputs, labels=labels, weight_tensor=mask_tensor)
                weighted_ppl.append(weighted_output.loss.item())

        # Average the losses
        langs_original_ppl[src_lang][tgt_lang] = np.nanmean(original_ppl)
        langs_weighted_ppl[src_lang][tgt_lang] = np.nanmean(weighted_ppl)

    # Convert the results into DataFrames and display them
    original_df = pd.DataFrame.from_dict(langs_original_ppl, orient="index")
    print("w/o head mask:")
    display(original_df)
    weighted_df = pd.DataFrame.from_dict(langs_weighted_ppl, orient="index")
    print("w/ head mask:")
    display(weighted_df)


print("With instruction")
eval_xnli_ppl(use_random=False, use_instruction=True)
print("With instruction, random mask")
eval_xnli_ppl(use_random=True, use_instruction=True)
print("Without instruction")
eval_xnli_ppl(use_random=False, use_instruction=False)
print("Without instruction, random mask")
eval_xnli_ppl(use_random=True, use_instruction=False)

With instruction
100%|██████████| 2/2 [00:13<00:00,  6.62s/it]
w/o head mask:
    en	        zh
en	0.000000	1.091267
zh	1.002608	0.000000
w/ head mask:
    en	        zh
en	0.000000	1.004982
zh	0.986755	0.000000

With instruction, random mask
100%|██████████| 2/2 [00:13<00:00,  6.61s/it]
w/o head mask:
    en	        zh
en	0.000000	1.091267
zh	1.002608	0.000000
w/ head mask:
    en	        zh
en	0.000000	1.802226
zh	1.260897	0.000000

Without instruction
100%|██████████| 2/2 [00:13<00:00,  6.61s/it]
w/o head mask:
    en	        zh
en	0.000000	2.643437
zh	2.142624	0.000000
w/ head mask:
    en	        zh
en	0.000000	0.99274
zh	0.989524	0.00000

Without instruction, random mask
100%|██████████| 2/2 [00:13<00:00,  6.60s/it]
w/o head mask:
    en	        zh
en	0.000000	2.643437
zh	2.142624	0.000000
w/ head mask:
    en	        zh
en	0.000000	3.069456
zh	2.092448	0.000000
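
Note that the tables above report the mean token-level cross-entropy loss rather than an exponentiated perplexity; if you want perplexity proper, exponentiate the averaged loss (a minimal sketch, giving an approximate corpus-level PPL):

print(np.exp(1.091267), np.exp(1.004982))  # e.g. en->zh with instruction: without vs. with the head mask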
  7. Evaluate the English-Chinese translation task, using ROUGE-L as the metric:
# Initialize the sentence tokenizer and the ROUGE scorer
sentence_tokenizer = AutoTokenizer.from_pretrained("../model/bge-reranker-v2-m3", use_fast=True)
rouge_scorer_instance = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True, tokenizer=sentence_tokenizer)  # with stemming enabled


def eval_xnli_rouge(use_instruction):
    langs_original_rouge, langs_weighted_rouge = {}, {}
    for src_lang, tgt_lang in tqdm(list(itertools.permutations(LANG_LIST, 2))[:12]):  # at most the first 12 language pairs
        mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy()

        dev_dataset = dataset.map(
            lambda sample: {
                "input_str": sample[src_lang] + f"\n\nTranslate into {LANG_DICT[tgt_lang]}:" if use_instruction else sample[src_lang],
                "target_str": sample[tgt_lang]
            }
        ).select_columns(["input_str", "target_str"])

        if src_lang not in langs_original_rouge:
            langs_original_rouge[src_lang] = {src_lang: 0}
        if src_lang not in langs_weighted_rouge:
            langs_weighted_rouge[src_lang] = {src_lang: 0}
        original_rouge, weighted_rouge = [], []

        for input_str, target_str in zip(*dev_dataset[:100].values()):
            combined_str = LLAMA_TEMPLATE.format(src=input_str.strip())
            lm_inputs_src = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)

            with torch.no_grad():
                original_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=None, do_sample=False)
                original_pred_str = tokenizer.decode(original_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)  # decode the generated text
                original_rouge_scores = rouge_scorer_instance.score(original_pred_str, target_str)  # compute the ROUGE-L scores
                original_rouge_l = original_rouge_scores["rougeL"].fmeasure  # take the F1 score
                original_rouge.append(original_rouge_l)

                mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
                weighted_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=mask_tensor, do_sample=False)
                weighted_pred_str = tokenizer.decode(weighted_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
                weighted_rouge_scores = rouge_scorer_instance.score(weighted_pred_str, target_str)
                weighted_rouge_l = weighted_rouge_scores["rougeL"].fmeasure
                weighted_rouge.append(weighted_rouge_l)

        # Average the ROUGE-L scores
        langs_original_rouge[src_lang][tgt_lang] = np.nanmean(original_rouge)
        langs_weighted_rouge[src_lang][tgt_lang] = np.nanmean(weighted_rouge)

    original_df = pd.DataFrame.from_dict(langs_original_rouge, orient="index")
    display(original_df)
    weighted_df = pd.DataFrame.from_dict(langs_weighted_rouge, orient="index")
    display(weighted_df)


print("With instruction")
eval_xnli_rouge(use_instruction=True)
print("Without instruction")
eval_xnli_rouge(use_instruction=False)

With instruction
100%|██████████| 2/2 [02:58<00:00, 89.34s/it] 
    en	        zh
en	0.000000	0.539153
zh	0.668623	0.000000
    en	        zh
en	0.000000	0.60087
zh	0.654566	0.00000

Without instruction
100%|██████████| 2/2 [05:00<00:00, 150.44s/it]
    en	    zh
en	0.00000	0.019667
zh	0.02056	0.000000
    en	        zh
en	0.000000	0.603147
zh	0.655537	0.000000
  8. Ablation of the functional attention heads:
langs_weighted_ppl = {}
lang_pairs = [("en", "zh"), ("zh", "en")]
for src_lang, tgt_lang in tqdm(lang_pairs):
    dev_dataset = dataset.map(
        lambda sample: {
            "input_str": sample[src_lang] + f"\n\nTranslate into {LANG_DICT[tgt_lang]}:",
            "target_str": sample[tgt_lang],
        }
    ).select_columns(["input_str", "target_str"])

    # Dictionary keyed by the scaling factor, values are lists of losses
    weighted_ppl = {v: [] for v in np.arange(0, 1.1, 0.1)}
    for input_str, target_str in tqdm(zip(*dev_dataset[:100].values()), total=100):
        combined_str = LLAMA_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>"
        lm_inputs = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)
        input_str_len = tokenizer(LLAMA_TEMPLATE.format(src=input_str.strip()), add_special_tokens=False, return_tensors="pt")["input_ids"].size(-1)
        labels = copy.deepcopy(lm_inputs["input_ids"]).to(DEVICE)
        labels[:, :input_str_len] = -100

        with torch.no_grad():
            for min_weight in np.arange(0, 1.1, 0.1):  # sweep the scaling factor applied to the masked-out heads
                mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy().clip(min=min_weight)  # masked-out heads are scaled by min_weight instead of 0; at 1.0 this recovers the unmasked model
                mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
                weighted_output = model(**lm_inputs, labels=labels, weight_tensor=mask_tensor)
                weighted_ppl[min_weight].append(weighted_output.loss.item())

    weighted_ppl = {k: np.nanmean(weighted_ppl[k]) for k in weighted_ppl}  # average loss for each scaling factor
    langs_weighted_ppl[f"{src_lang}_{tgt_lang}"] = weighted_ppl

df = pd.DataFrame.from_dict(langs_weighted_ppl, orient="index")
display(df)

0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 100/100 [00:38<00:00,  2.61it/s]
50%|█████     | 1/2 [00:38<00:38, 38.36s/it]
100%|██████████| 100/100 [00:38<00:00,  2.62it/s]
100%|██████████| 2/2 [01:16<00:00, 38.25s/it]

        0.0	        0.1	        0.2	        0.3	        0.4	        0.5	        0.6	        0.7	        0.8	        0.9	        1.0
en_zh	1.004982	1.002170	1.005190	1.010720	1.014762	1.018487	1.022877	1.029847	1.045129	1.063954	1.091267
zh_en	0.986755	0.986517	0.987423	0.988694	0.988407	0.989914	0.989815	0.991908	0.994180	0.999186	1.002608
  9. Plot the ablation results:
colors = sns.color_palette(n_colors=2)  # palette with 2 colors

plt.figure(figsize=(10, 6))  # figure size
plt.ylim(0.9, 1.1)  # y-axis range
plt.xticks(np.arange(0, 1.1, 0.1), fontsize=12)  # x-axis ticks
plt.xlabel(r"Scaling factor $\alpha$", fontsize=18)
plt.ylabel(r"PPL with respective $\mathcal{M}$", fontsize=18)

maskonly_ppl = {"en_zh": 1.016755, "zh_en": 0.997446}  # baseline
for i, (lang_pair, values) in enumerate(df.iterrows()):
    plt.plot(np.arange(0, 1.1, 0.1), values, linestyle='-', label=lang_pair, c=colors[i])
    plt.plot(0, maskonly_ppl[lang_pair], marker="x", c=colors[i])
plt.grid(alpha=0.3)  # grid lines
plt.legend(ncol=3, fontsize=12, loc="lower right")  # legend
plt.tight_layout()  # layout
plt.show()
  1. Evaluate Chinese-English translation in both directions with BLEU on the IWSLT2017 test set:
LANG_DICT = {"fr": "French", "de": "German", "en": "English", "zh": "Chinese"}
LANGPAIR_MIN = {"en_de": 0.3, "en_fr": 0.2, "en_zh": 0.1, "zh_en": 0.3, "fr_en": 0.2, "de_en": 0.4}

langs_original_bleu, langs_weighted_bleu = {}, {}
smoothing = SmoothingFunction().method1


# Generic post-processing
def general_postprocess(text):
    # Strip the instruction part
    answer_text = text.split("\n\n")
    if len(answer_text) > 2:
        answer_text = answer_text[1:2]
    elif len(answer_text) > 1:
        answer_text = answer_text[1:]
    answer_text = "\n\n".join(answer_text)

    no_punctuation = re.sub(r'[^\w\s]', '', answer_text)  # remove punctuation
    cleaned_text = re.sub(r'\s+', ' ', no_punctuation).strip()  # collapse extra whitespace
    cleaned_text = " ".join(jieba.cut(cleaned_text))  # segment Chinese text with jieba

    return cleaned_text


# Post-process the prediction and keep the best-scoring segment
def postprocess_and_score(text, target_str):
    answer_text = text.split("\n\n")
    cleaned_texts = [" ".join(jieba.cut(re.sub(r'\s+', ' ', re.sub(r'[^\w\s]', '', t)).strip())) for t in answer_text]  # clean and segment every part
    scores = [sentence_bleu([target_str], t, smoothing_function=smoothing) for t in cleaned_texts]  # BLEU score of every part

    return max(zip(cleaned_texts, scores), key=lambda x: x[1])  # return the part with the highest score


for src_lang, tgt_lang in [("en", "zh"), ("zh", "en")]:
    # Load the IWSLT2017 test set
    if tgt_lang != "en":
        dev_dataset = datasets.Dataset.from_parquet(f"dataset/iwslt2017/{src_lang}-{tgt_lang}-test.parquet")
    else:
        dev_dataset = datasets.Dataset.from_parquet(f"dataset/iwslt2017/{tgt_lang}-{src_lang}-test.parquet")

    dev_dataset = dev_dataset.map(
        lambda sample: {
            "input_str": f"{sample['translation'][src_lang]}\n\nTranslate into {tgt_lang}:",
            "target_str": sample['translation'][tgt_lang],
        }
    ).select_columns(["input_str", "target_str"])
    dev_dataset = dev_dataset.select(range(len(dev_dataset) - 100, len(dev_dataset)))  # keep the last 100 samples

    mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy().clip(min=LANGPAIR_MIN[f"{src_lang}_{tgt_lang}"])

    original_bleu, weighted_bleu = [], []
    for sample in tqdm(dev_dataset):
        input_str, target_str = sample["input_str"], sample["target_str"]
        target_str = general_postprocess(target_str)
        combined_str = LLAMA_TEMPLATE.format(src=input_str.strip())
        lm_inputs_src = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            # Without the mask
            original_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=None, do_sample=False)
            original_pred_str = tokenizer.decode(original_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
            original_pred_str, original_bleu4 = postprocess_and_score(original_pred_str, target_str)  # post-process and compute the BLEU-4 score
            original_bleu.append(original_bleu4)
            # With the mask
            mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
            weighted_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=mask_tensor, do_sample=False)
            weighted_pred_str = tokenizer.decode(weighted_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
            weighted_pred_str, weighted_bleu4 = postprocess_and_score(weighted_pred_str, target_str)
            weighted_bleu.append(weighted_bleu4)

    # Average BLEU scores
    langs_original_bleu[f"{src_lang}_{tgt_lang}"] = np.nanmean(original_bleu)
    langs_weighted_bleu[f"{src_lang}_{tgt_lang}"] = np.nanmean(weighted_bleu)
    print(f"{src_lang}_{tgt_lang}: {np.nanmean(original_bleu)} {np.nanmean(weighted_bleu)}")


original_df = pd.DataFrame.from_dict(langs_original_bleu, orient="index")
display(original_df)
weighted_df = pd.DataFrame.from_dict(langs_weighted_bleu, orient="index")
display(weighted_df)

Loading model from cache /tmp/jieba.cache
Loading model cost 0.598 seconds.
Prefix dict has been built successfully.
100%|██████████| 100/100 [02:09<00:00,  1.30s/it]
en_zh: 0.3118341977481546 0.3301332102535084
100%|██████████| 100/100 [01:50<00:00,  1.10s/it]
zh_en: 0.5805953950107409 0.5813078377389967

        0
en_zh	0.311834
zh_en	0.580595
        0
en_zh	0.330133
zh_en	0.581308
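
To make the scoring pipeline concrete, here is a standalone run of the same jieba + sentence_bleu steps on a made-up sentence pair; it mirrors postprocess_and_score above but is illustrative only:

import re
import jieba
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smoothing = SmoothingFunction().method1

def clean(text):
    # Same normalization as above: strip punctuation, collapse whitespace, segment with jieba
    text = re.sub(r"\s+", " ", re.sub(r"[^\w\s]", "", text)).strip()
    return " ".join(jieba.cut(text))

reference = clean("我看见一只羊驼在我的后院睡觉。")
candidate = clean("我看到一只羊驼在后院睡觉。")

# Passing strings (as the evaluation code does) makes NLTK treat each character as a token,
# so for Chinese this is effectively a character-level BLEU.
print(sentence_bleu([reference], candidate, smoothing_function=smoothing))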

2.3 Analysis

  1. Compare machine-translation outputs with and without the explicit instruction:
def translate(input_str, src_lang="en", tgt_lang="zh"):
    # Print the number of active attention heads
    mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy().clip(min=0.0)
    print(mask_weight.sum())

    with torch.no_grad():
        lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)+""], add_special_tokens=False, return_tensors="pt").to(DEVICE)
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=None, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        print("[ORIGINAL]", pred_str)

        mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=mask_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        print(f"[{tgt_lang}]", pred_str)

        random_tensor = torch.tensor(mask_weight)[torch.randperm(len(mask_weight))].unsqueeze(0).repeat(1, 1).to(model.device)  # randomly permuted mask
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=random_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        print("[RANDOM]", pred_str)

translate("I have never seen such a beautiful sunset.")
translate("This is expected as the model has lost the majority of its attention heads.\n\nTranslate into Chinese:")

856.0
[ORIGINAL] I'm glad you're enjoying the moment! Sunsets can be truly breathtaking, with their vibrant colors and serene atmosphere. They have a way of evoking feelings of peace and wonder. What made this sunset particularly special to you? Was it the colors
[zh] 我从来没有见过如此美丽的日落。
[RANDOM] As a conversational AI, I don't have personal experiences, but I can tell you that sunsets can be truly breathtaking. The colors of the sky during a sunset are a result of a combination of atmospheric conditions, including the scattering of light,

856.0
[ORIGINAL] 这也算是预期的,因为模型已经失去了大多数注意力头。
[zh] 这预计是因为模型失去了大多数注意力头。
[RANDOM] This is expected as the model has lost the majority of its attention.
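
The weight_tensor argument is consumed inside the patched models/modeling_llama.py, which this post does not reproduce. As a rough mental model only (an assumption about the mechanism, not the actual implementation), a per-head weight can be applied by scaling each head's output before the output projection:

import torch

def scale_heads(attn_output, head_weights):
    # Toy illustration: attn_output has shape (batch, seq, num_heads, head_dim),
    # head_weights has shape (num_heads,); each head's output is scaled before o_proj.
    return attn_output * head_weights.view(1, 1, -1, 1)

attn_output = torch.randn(1, 4, 32, 128)  # made-up shapes for one Llama-3.1-8B layer
head_weights = torch.ones(32)
head_weights[10] = 0.0                    # "remove" head 10 of this layer
print(scale_heads(attn_output, head_weights).shape)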
  1. Compute the similarity of each layer's output and the top-5 tokens predicted at each layer:
src_lang = "en"
tgt_lang = "zh"
input_str = "I see a llama sleeping in my backyard."
input_str_inst = input_str + "\n\nTranslate into Chinese:"
target_str = "我看见一只羊驼在我的后院睡觉。"


with torch.no_grad():
    lm_inputs_src_inst = tokenizer([LLAMA_TEMPLATE.format(src=input_str_inst)+""], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)+""], add_special_tokens=False, return_tensors="pt").to(DEVICE)

    original_output = model(**lm_inputs_src_inst, output_hidden_states=True, output_attentions=True)
    layer_outputs = torch.stack(original_output.hidden_states)  # stack the hidden states of every layer
    layer_token_logits = model.lm_head(layer_outputs[:, 0, -1])  # logits predicted from each layer's last-token hidden state
    original_layer_pred_tokens = tokenizer.batch_decode(layer_token_logits.argmax(-1))  # top-1 token predicted at each layer
    original_layer_pred_tokens_top5 = [tokenizer.batch_decode(layer_token_logits.argsort(dim=-1, descending=True)[:, k]) for k in range(5)]  # top-5 tokens predicted at each layer
    original_res = layer_outputs[1:, 0, -1] - layer_outputs[:-1, 0, -1]  # residual: difference between consecutive layer outputs

    mask_tensor = torch.tensor((weight_dict["en_zh"].sigmoid() >= 0.5).float().numpy()).unsqueeze(0).repeat(1, 1).to(model.device)
    weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_hidden_states=True, output_attentions=True)
    weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
    weighted_layer_token_logits = model.lm_head(weighted_layer_outputs[:, 0, -1])
    weighted_layer_pred_tokens = tokenizer.batch_decode(weighted_layer_token_logits.argmax(-1))
    weighted_layer_pred_tokens_top5 = [tokenizer.batch_decode(weighted_layer_token_logits.argsort(dim=-1, descending=True)[:, k]) for k in range(5)]
    weighted_res = weighted_layer_outputs[1:, 0, -1] - weighted_layer_outputs[:-1, 0, -1]

print("hidden similarity\n", torch.cosine_similarity((layer_outputs[:, 0, -1]), (weighted_layer_outputs[:, 0, -1])))  # cosine similarity of the hidden states
print("lm_head logits similarity\n", torch.cosine_similarity(layer_token_logits, weighted_layer_token_logits))  # cosine similarity of the lm_head logits
print("residual similarity\n", torch.cosine_similarity((layer_outputs[1:, 0, -1] - layer_outputs[:-1, 0, -1]), (weighted_layer_outputs[1:, 0, -1] - weighted_layer_outputs[:-1, 0, -1])))  # cosine similarity of the residuals

pd.set_option("display.max_colwidth", None)  # unlimited column width
pd.set_option("display.max_rows", None)  # show all rows
df = pd.DataFrame({"original": original_layer_pred_tokens, "weighted": weighted_layer_pred_tokens})
df1 = pd.DataFrame(original_layer_pred_tokens_top5).transpose()  # top-5 predicted tokens per layer of the original model
display(df1)

hidden similarity
tensor([1.0000, 0.7852, 0.7617, 0.7852, 0.7422, 0.7539, 0.7266, 0.7422, 0.6875,
        0.5859, 0.5508, 0.5859, 0.5195, 0.5430, 0.4531, 0.5117, 0.5938, 0.6406,
        0.7109, 0.7383, 0.7383, 0.7656, 0.7930, 0.8125, 0.8242, 0.8438, 0.8477,
        0.8477, 0.8438, 0.8359, 0.8320, 0.8125, 0.7500], device='cuda:0',
    dtype=torch.bfloat16)
lm_head logits similarity
tensor([1.0078, 0.7930, 0.7773, 0.8008, 0.7539, 0.8008, 0.8281, 0.8398, 0.7656,
        0.6719, 0.6406, 0.7305, 0.7031, 0.5430, 0.4609, 0.5078, 0.6016, 0.6562,
        0.8281, 0.8516, 0.9062, 0.9219, 0.9609, 0.9688, 0.9727, 0.9766, 0.9766,
        0.9805, 0.9766, 0.9766, 0.9727, 0.9258, 0.7188], device='cuda:0',
    dtype=torch.bfloat16)
residual similarity
tensor([0.7383, 0.6719, 0.6719, 0.5430, 0.6094, 0.5234, 0.5234, 0.4316, 0.2520,
        0.2656, 0.2871, 0.2852, 0.2314, 0.2969, 0.3125, 0.3984, 0.4570, 0.6133,
        0.5977, 0.6133, 0.7031, 0.7578, 0.7188, 0.7578, 0.8086, 0.7305, 0.7109,
        0.7617, 0.7461, 0.7070, 0.7969, 0.7773], device='cuda:0',
    dtype=torch.bfloat16)
    
    0	        1	        2	            3	            4
0	illo	    abil	    incer	        otron	        câ
1	cheng	    the	        čin	            utas	        Alam
2	'gc	        alink		                utas	        .netbeans
3	'gc		                .netbeans	    -toggler	    cheng
4	'gc	        .netbeans	reff	        .CR	            edn
5	'gc	        -toggler	шиб	            ATAB	        \Dependency
6	#ab	        'gc	        #ac	            kke	            ーニ
7	-*-\r\n	    #ad	        #ab	            .netbeans	    'gc
8	LLU	        -LAST	    -*-\r\n	        'gc	            #ab
9	emain	    'gc	        #af	            ektor	        chalk
10	poil	    .SIG	    GetInt	        emain	        >tag
11	'gc	        >tag	    .SIG	        цес	            poil
12	ruz	        감	        корист	        ncy	            >tag
13	'gc	        корист	    pNet	        dime	        Sharper
14	'gc	        pNet	    .Reporting	    Sharper	        isci
15	-wsj	    �	        ContentLoaded	меть	        pNet
16	-wsj	    >tag	    :uint	        lap	            hon
17	RetVal	    hon	        )frame	        -wsj	        Macy
18	hon	        artz	    increasingly	confidently	    RetVal
19	confidently	bes	        increasingly	WSTR	        hon
20	bes	        hon	        increasingly		            (
21	hon	        bes	        in	            increasingly	p
22	I	        my	        in	            c	            p
23	I	        my	        you	            c	            me
24	I	        you	        my	            in	            ..\n
25	my	        I	        ..\n	        in	            you
26	my	        ..\n	    in	            you	            I
27	my	        I	        ..\n	        in	            you
28	in	        my	        you	            ..\n	        …\n
29	…\n	        ..\n	    in	            /stdc	        backyard
30	…\n	        /stdc	    in	            ..\n	        _exempt
31	backyard	_exempt	    我	            in	            ..\n
32	我	        你	        您	            在	            有
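
One detail worth flagging: the first lm_head logits similarity above is printed as 1.0078, even though cosine similarity is bounded by 1. The tensors are in bfloat16, whose next representable value above 1.0 is exactly 1.0078125, so the result simply rounds past 1. If exact values matter, the tensors from the cell above can be upcast before comparing (a small optional tweak, not in the original notebook):

# Run right after the cell above; .float() upcasts the bfloat16 tensors to float32
# so the similarity can no longer round above 1.0.
hidden_sim = torch.cosine_similarity(layer_outputs[:, 0, -1].float(), weighted_layer_outputs[:, 0, -1].float())
logits_sim = torch.cosine_similarity(layer_token_logits.float(), weighted_layer_token_logits.float())
print(hidden_sim)
print(logits_sim)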
  1. Compute and plot the average similarity between layer outputs:
with torch.no_grad():
    lm_inputs_src_inst = tokenizer([LLAMA_TEMPLATE.format(src=input_str_inst)+target_str], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)+target_str], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    mask_tensor = torch.tensor((weight_dict["en_zh"].sigmoid() >= 0.5).float().numpy()).unsqueeze(0).repeat(1, 1).to(model.device)

    original_output = model(**lm_inputs_src, output_hidden_states=True, output_attentions=True)
    layer_outputs = torch.stack(original_output.hidden_states)
    layer_token_logits = model.lm_head(layer_outputs)
    layer_pred_tokens = tokenizer.batch_decode(layer_token_logits[:, 0, -15].argmax(-1))  # decode the token predicted at the 15th-to-last position

    original_inst_output = model(**lm_inputs_src_inst, output_hidden_states=True, output_attentions=True)
    inst_layer_outputs = torch.stack(original_inst_output.hidden_states)
    inst_layer_token_logits = model.lm_head(inst_layer_outputs)
    inst_layer_pred_tokens = tokenizer.batch_decode(inst_layer_token_logits[:, 0, -15].argmax(-1))

    weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_hidden_states=True, output_attentions=True)
    weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
    weighted_layer_token_logits = model.lm_head(weighted_layer_outputs)
    weighted_layer_pred_tokens = tokenizer.batch_decode(weighted_layer_token_logits[:, 0, -15].argmax(-1))

    weighted_inst_output = model(**lm_inputs_src_inst, weight_tensor=mask_tensor, output_hidden_states=True, output_attentions=True)
    weighted_inst_layer_outputs = torch.stack(weighted_inst_output.hidden_states)
    weighted_inst_layer_token_logits = model.lm_head(weighted_inst_layer_outputs)
    weighted_inst_layer_pred_tokens = tokenizer.batch_decode(weighted_inst_layer_token_logits[:, 0, -15].argmax(-1))

# Average similarity of the hidden state at the 15th-to-last position (the [-15:-14] slice covers a single position)
cosine_sim1 = torch.cosine_similarity((inst_layer_outputs[:, 0, -15:-14].mean(dim=-2)), (weighted_layer_outputs[:, 0, -15:-14].mean(dim=-2))).float().cpu().numpy()
cosine_sim2 = torch.cosine_similarity((inst_layer_outputs[:, 0, -15:-14].mean(dim=-2)), (layer_outputs[:, 0, -15:-14].mean(dim=-2))).float().cpu().numpy()
cosine_sim3 = torch.cosine_similarity((weighted_layer_outputs[:, 0, -15:-14].mean(dim=-2)), (weighted_inst_layer_outputs[:, 0, -15:-14].mean(dim=-2))).float().cpu().numpy()
print("Similarity between w/ mask vs. w/ instruction\n", cosine_sim1)  # with instruction vs. with mask (no instruction)
print("Similarity between Original vs. w/ instruction\n", cosine_sim2)  # with instruction vs. original (no mask, no instruction)
print("Similarity between w/ mask vs. w/ mask + instruction\n", cosine_sim3)  # with mask vs. with mask + instruction


# Plot the corresponding figure
(b1, _, b2), (r1, _, r2), (g1, _, g2) = sns.color_palette("Blues", 3), sns.color_palette("Reds", 3), sns.color_palette("Greens", 3)  # color scheme
plt.figure(figsize=(12, 6))
plt.xlabel("Layer No.", fontsize=18)
plt.xticks(np.arange(1, 32, 4), labels=np.arange(0, 32, 4), fontsize=12)
plt.xticks(np.arange(33), minor=True)
plt.ylabel("Cosine similarity of layer output", fontsize=18)
plt.yticks(fontsize=12)
plt.plot(cosine_sim2, label="Original vs. w/ instruction", c=b2)
plt.plot(cosine_sim1, label=r"w/ $\mathcal{M}$ vs. w/ instruction", c=r2)
plt.plot(cosine_sim3, label=r"w/ $\mathcal{M}$ vs. w/ $\mathcal{M}$ + instruction", c=g2)

# Annotate the predicted token at each layer
texts = []
for i, (xi, yi1, yi2, yi3) in enumerate(zip(np.arange(33), cosine_sim1, cosine_sim2, cosine_sim3)):
    texts.append(plt.text(xi, yi1-0.05 if xi<16 else yi1+0.05, weighted_layer_pred_tokens[i], fontsize=9, ha="center"))  # predictions of the masked model
    texts.append(plt.text(xi, yi2+0.05 if xi<16 else yi2-0.05, layer_pred_tokens[i], fontsize=9, ha="center"))  # predictions of the original model

adjust_text(texts, avoid_self=False)  # move the labels to avoid overlap
for (t1, t2), xi, yi1, yi2 in zip([(texts[i], texts[i+1]) for i in range(0, len(texts), 2)], np.arange(33), cosine_sim1, cosine_sim2):
    # Connect the masked-model annotation to its data point
    x_text, y_text = t1.get_position()
    plt.plot([xi, x_text], [yi1, y_text+0.02 if xi<16 else yi1+0.04], color=r2, linestyle="-", linewidth=0.5)
    # Connect the original-model annotation to its data point
    x_text, y_text = t2.get_position()
    plt.plot([xi, x_text], [yi2, y_text-0.02 if xi<16 else yi2-0.02], color=b2, linestyle="-", linewidth=0.5)

plt.ylim(-0.05, 1.1)
plt.legend(fontsize=15)
plt.tight_layout()
plt.show()

Similarity between w/ mask vs. w/ instruction
[1.         0.78515625 0.76171875 0.78515625 0.7421875  0.75390625
0.7265625  0.7421875  0.6875     0.5859375  0.55078125 0.5859375
0.51953125 0.54296875 0.453125   0.51171875 0.59375    0.640625
0.7109375  0.73828125 0.73828125 0.765625   0.79296875 0.8125
0.82421875 0.84375    0.84765625 0.84765625 0.84375    0.8359375
0.83203125 0.8125     0.75      ]
Similarity between Original vs. w/ instruction
[1.         0.96484375 0.90234375 0.90234375 0.8359375  0.81640625
0.76171875 0.75       0.72265625 0.6328125  0.57421875 0.60546875
0.58203125 0.6953125  0.66015625 0.625      0.59765625 0.53515625
0.4765625  0.44921875 0.42578125 0.39257812 0.38867188 0.36914062
0.36914062 0.40234375 0.41601562 0.42773438 0.42578125 0.44726562
0.44921875 0.3984375  0.26367188]
Similarity between w/ mask vs. w/ mask + instruction
[1.         0.984375   0.95703125 0.9296875  0.90234375 0.84375
0.8125     0.80078125 0.7734375  0.74609375 0.640625   0.66015625
0.640625   0.640625   0.5703125  0.58984375 0.62890625 0.68359375
0.74609375 0.7578125  0.765625   0.78515625 0.81640625 0.84765625
0.86328125 0.87890625 0.87890625 0.8828125  0.88671875 0.88671875
0.88671875 0.87890625 0.84375   ]
  1. Compute the per-layer output similarity as the functional attention heads are removed one by one:
with torch.no_grad():
    lm_inputs_src_inst = tokenizer([LLAMA_TEMPLATE.format(src=input_str_inst)], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)

    original_output = model(**lm_inputs_src_inst, output_hidden_states=True, output_attentions=True)
    original_layer_outputs = torch.stack(original_output.hidden_states)
    original_layer_token_logits = model.lm_head(original_layer_outputs[:, 0, -1]).float().cpu()

# English-to-Chinese weight mask
mask_weight = weight_dict["en_zh"].sigmoid().numpy()
mask_weight[32::33] = 0  # zero out the last entry of each 33-wide layer block so it is never selected
order = mask_weight.argsort()[32:]  # head indices sorted by trained weight (ascending); [32:] skips the 32 zeroed entries

# Start from all-ones weights
full_weight = np.ones_like(weight_dict["en_zh"].sigmoid().numpy())
test_tensor = torch.tensor(full_weight).unsqueeze(0).repeat(1, 1).to(model.device)
weighted_outputs, weighted_logits = [], []
test_order = order[:320]

# All attention heads active
with torch.no_grad():
    weighted_output = model(**lm_inputs_src, weight_tensor=test_tensor, output_hidden_states=True, output_attentions=True)
    weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
    weighted_layer_token_logits = model.lm_head(weighted_layer_outputs[:, 0, -1])
    weighted_outputs.append(weighted_layer_outputs[:, 0, -1].float().cpu())
    weighted_logits.append(weighted_layer_token_logits.float().cpu())

# Remove the heads one at a time
for head_idx in tqdm(test_order):
    test_tensor[0, head_idx] = 0  # set the current head's weight to 0
    with torch.no_grad():
        weighted_output = model(**lm_inputs_src, weight_tensor=test_tensor, output_hidden_states=True, output_attentions=True)
        weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
        weighted_layer_token_logits = model.lm_head(weighted_layer_outputs[:, 0, -1])
        weighted_outputs.append(weighted_layer_outputs[:, 0, -1].float().cpu())
        weighted_logits.append(weighted_layer_token_logits.float().cpu())

# Stack the results of every step
weighted_outputs = torch.stack(weighted_outputs)
weighted_logits = torch.stack(weighted_logits)

layers = np.arange(32)
colors = sns.color_palette("Blues", n_colors=36)[-32:]  # color map

# Hidden-state similarity plot
plt.figure(figsize=(12, 6))
plt.xlabel("Removed heads", fontsize=18)
plt.xticks(np.arange(0, 1025, 64), labels=np.arange(0, 1025, 64), fontsize=12)
plt.ylabel("Cosine similarity of FFN layer output", fontsize=18)
plt.yticks(fontsize=12)
plt.ylim(-0.05, 1.05)

# Cosine-similarity curve of every layer over the removal sequence
layer_cosine_sim = [torch.cosine_similarity(original_layer_outputs[l+1, 0, -1].squeeze(0).float().cpu(), weighted_outputs[:, l+1]).float().cpu().numpy() for l in range(len(layers))]
for l in layers:
    plt.plot(layer_cosine_sim[l], c=colors[l], label=f"Layer {l}")

# Build the color bar
norm = mpl.colors.Normalize(vmin=0, vmax=32)
sm = mpl.cm.ScalarMappable(cmap=mpl.colors.ListedColormap(colors), norm=norm)
sm.set_array([])

# Add the color bar
cbar = plt.colorbar(sm, ax=plt.gca(), orientation="vertical", pad=0.02)
cbar.set_label("Layer No.", fontsize=12)
cbar.set_ticks(np.arange(0, 32)+0.5)
cbar.set_ticklabels(np.arange(0, 32), fontsize=12)
cbar.ax.tick_params(labelsize=12)
plt.tight_layout()
plt.show()

# Logits-similarity plot
plt.figure(figsize=(12, 6))
plt.xlabel("# of removed heads", fontsize=12)
plt.ylabel("Similarity", fontsize=12)
for l in layers:
    plt.plot(torch.cosine_similarity(original_layer_token_logits[l+1].squeeze(0).float().cpu(), weighted_logits[:, l+1]).float().cpu().numpy(), c=colors[l], label=f"Layer {l}")
plt.ylim(-0.05, 1.05)
plt.title("Cosine similarities: Layer lm_head logits")
plt.tight_layout()
plt.show()

100%|██████████| 320/320 [01:15<00:00,  4.24it/s]
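
The flat indices in order can be mapped back to (layer, head) positions, assuming the 1056-long weight vector is laid out layer-major with 33 entries per layer, as the view(32, 33)[:, :-1] reshape later in this post suggests. The helper below is illustrative and not part of the notebook:

# Hypothetical helper, assuming a 32 x 33 layer-major layout of the weight vector.
def flat_to_layer_head(idx, entries_per_layer=33):
    return int(idx) // entries_per_layer, int(idx) % entries_per_layer

# e.g. the first few heads scheduled for removal in the cell above
print([flat_to_layer_head(i) for i in order[:5]])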
  1. Track how the logits and ranks of specific tokens change as attention heads are progressively removed:
lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)
original_output_logits = model(**lm_inputs_src).logits[0, -1].detach().cpu().numpy()

# Top-5 most likely tokens predicted by the original model
topk_token_ids = original_output_logits.argsort()[::-1][:5]
topk_tokens = [tokenizer.decode(token_id, skip_special_tokens=True) for token_id in topk_token_ids]
query_token_id = tokenizer.encode("我", add_special_tokens=False)[0]  # query token

pred_strs = []
pred_logits, pred_rank = [], []
test_order = order
test_tensor = torch.tensor(full_weight).unsqueeze(0).repeat(1, 1).to(model.device)  # reset to all-ones so every head starts active

# All attention heads active
with torch.no_grad():
    generate_ids = model.generate(**lm_inputs_src, max_new_tokens=20, weight_tensor=test_tensor, do_sample=False)
    pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
    pred_strs.append(pred_str)
    pd.DataFrame(pred_strs).to_csv("all.csv", encoding="utf-8")
    weighted_output_logits = model(**lm_inputs_src, weight_tensor=test_tensor).logits[0, -1].cpu().numpy()  # logits at the last position
    pred_logits.append((weighted_output_logits[topk_token_ids].tolist() + [weighted_output_logits[query_token_id]]))  # logits of the original top-5 tokens + logit of "我"
    pred_rank.append((weighted_output_logits.argsort()[::-1].argsort()[topk_token_ids].tolist() + [weighted_output_logits.argsort()[::-1].argsort()[query_token_id]]))  # ranks of the original top-5 tokens + rank of "我"

# Remove the heads one at a time
for head_idx in tqdm(test_order):
    test_tensor[0, head_idx] = 0
    with torch.no_grad():
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=20, weight_tensor=test_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        pred_strs.append(pred_str)
        pd.DataFrame(pred_strs).to_csv("remove.csv", encoding="utf-8")
        weighted_output_logits = model(**lm_inputs_src, weight_tensor=test_tensor).logits[0, -1].cpu().numpy()
        pred_logits.append((weighted_output_logits[topk_token_ids].tolist() + [weighted_output_logits[query_token_id]]))
        pred_rank.append((weighted_output_logits.argsort()[::-1].argsort()[topk_token_ids].tolist() + [weighted_output_logits.argsort()[::-1].argsort()[query_token_id]]))

pred_logits = np.array(pred_logits)
pred_rank = np.array(pred_rank)

100%|██████████| 1024/1024 [18:31<00:00,  1.08s/it]
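
The argsort()[::-1].argsort() idiom above converts raw logits into ranks (0 = most probable token). A tiny self-contained example of the same trick:

import numpy as np

logits = np.array([1.2, 3.5, -0.7, 2.1])  # made-up logits for a 4-token vocabulary
order_desc = logits.argsort()[::-1]       # token ids from highest to lowest logit -> [1 3 0 2]
ranks = order_desc.argsort()              # rank of each token id                  -> [2 0 3 1]
print(order_desc, ranks)
print(ranks[1])                           # token 1 has the highest logit, so its rank is 0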
  1. Repeat the ablation with the heads reordered by influence:
# Remove attention heads in order of their influence on the rank of "我"
rank_diff = -(pred_rank[1:, -1] - pred_rank[:-1, -1])  # rank improvement of "我" caused by each removal
rank_diff_order = rank_diff[:320].argsort()[::-1]

reorder_strs = []
reorder_rank = []
rerank_test_order = test_order[rank_diff_order]
test_tensor = torch.tensor(full_weight).unsqueeze(0).repeat(1, 1).to(model.device)  # start again with every head active

# Ablate the heads in order of influence
for head_idx in tqdm(rerank_test_order):
    test_tensor[0, head_idx] = 0
    with torch.no_grad():
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=20, weight_tensor=test_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        reorder_strs.append(pred_str)
        pd.DataFrame(reorder_strs).to_csv("test.csv", encoding="utf-8")
        weighted_output_logits = model(**lm_inputs_src, weight_tensor=test_tensor).logits[0, -1].cpu().numpy()
        reorder_rank.append((weighted_output_logits.argsort()[::-1].argsort()[topk_token_ids].tolist() + [weighted_output_logits.argsort()[::-1].argsort()[query_token_id]]))

reorder_rank = np.array(reorder_rank)

100%|██████████| 320/320 [06:43<00:00,  1.26s/it]
  1. Visualize the results:
# Color scheme
colors = sns.color_palette("Blues", 5)[::-1]
red = sns.color_palette("Reds", 1)[-1]

# Logits over removed heads
fig, ax = plt.subplots(figsize=(14, 7))
for i in range(6):
    if i < 5:
        ax.plot(pred_logits[:, i], label=topk_tokens[i], color=colors[i])
    else:
        ax.plot(pred_logits[:, i], label="我", color=red)
ax.set_xlabel("Remove heads")
ax.set_ylabel("Logits")
ax.set_xticks(range(0, 1025, 32))
plt.legend()
plt.show()

# Token rank over removed heads
fig, ax = plt.subplots(figsize=(14, 7))
for i in [4, 3, 2, 1, 0, 5]:
    if i < 5:
        ax.plot(pred_rank[:, i], label=topk_tokens[i], color=colors[i])
    else:
        ax.plot(pred_rank[:, i], label="我", color=red)
ax.set_xlabel("Removed heads", fontsize=18)
ax.set_ylabel("Token rank", fontsize=18)
ax.set_xticks(np.arange(0, 1025, 64), labels=np.arange(0, 1025, 64), fontsize=12)
ax.set_xticks(np.arange(32, 1025, 64), minor=True)
ax.set_yticks(np.arange(0, 4001, 500), labels=np.arange(0, 4001, 500), fontsize=12)
ax.set_ylim(-100, 4000)
top_ax = ax.twiny()
top_ax.set_xlim(ax.get_xlim())
top_ax.xaxis.tick_top()
top_ax.tick_params(direction="in", which="both")
top_ax.set_xticks(np.arange(0, 1025, 64), labels=[])
top_ax.set_xticks(np.arange(32, 1025, 64), minor=True)
plt.gca().invert_yaxis()
plt.legend([ax.lines[i] for i in [4, 3, 2, 1, 0, 5]], [topk_tokens[i] for i in range(5)] + ["我"], fontsize=12)
plt.show()
  1. Inspect the checkpointed mask tensors:
step_weight_dict = {}
for root, dirs, files in os.walk("output/llama/xnli/en_zh"):
    for file in files:
        if file.endswith(".safetensors"):
            checkpoint = int(os.path.basename(root).split("-")[-1])
            path = os.path.join(root, file)
            step_weight_dict[checkpoint] = safetensors.torch.load_file(path)["tensor"]
step_weight_dict = dict(sorted(step_weight_dict.items()))
len(step_weight_dict)

200
  1. Visualize how the attention-head weights evolve during training:
(b1, b2), (r1, r2) = sns.color_palette("Blues", 2), sns.color_palette("Reds", 2)
step_weight_array = torch.stack([t.view(32, 33)[:, :-1].flatten().sigmoid() for t in step_weight_dict.values()]).numpy()

plt.figure(figsize=(12, 6))
plt.axhline(y=0.5, color="gray", linestyle="--")

# Iterate over all 1024 attention heads
for i in range(1024):
    if (step_weight_array[:100, i] < 0.5).all() or (step_weight_array[-1, i] >= 0.5 and (step_weight_array[:60, i] >= 0.5).any()) or (0.3 < step_weight_array[-1, i] <= 0.7):
        c = b1 if step_weight_array[-1, i] >= 0.5 else r1
        plt.plot(range(101), step_weight_array[:101, i], linewidth=0.5, alpha=0.1, color=c)
    else:
        c = b2 if step_weight_array[-1, i] >= 0.5 else r2
        plt.plot(range(101), step_weight_array[:101, i], linewidth=0.5, alpha=1, color=c)

plt.xlabel("Training progress", fontsize=18)
plt.xticks(np.arange(0, 101, 20), labels=["0%", "10%", "20%", "30%", "40%", "50%"], fontsize=12)
plt.ylabel("Head weight (sigmoid)", fontsize=18)
plt.ylim(-0.02, 1.02)
plt.tight_layout()
plt.show()
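
As a quick follow-up (not part of the original notebook), the final mask can be summarized by counting how many heads stay above the 0.5 threshold in each layer; this reuses step_weight_dict from the cell above and the same 32 x 33 reshape:

# The last checkpoint is taken as the final mask; the 33rd column is dropped as in the plot above.
final_step = max(step_weight_dict)
final_mask = step_weight_dict[final_step].view(32, 33)[:, :-1].sigmoid() >= 0.5

print("total active heads:", int(final_mask.sum()))
print("active heads per layer:", final_mask.sum(dim=1).tolist())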
