[Paper Reproduction] Retrieval Head

Method illustration: (figure not reproduced here; see the reference project)

Reference project: nightdessert/Retrieval_Head

1 Installation

1.1 Virtual Environment

conda create -n retrieval python=3.8 -y
conda activate retrieval
pip install torch transformers==4.44.1 flash-attn rouge_score accelerate ipykernel

python -m ipykernel install --user --name retrieval
jupyter kernelspec list
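
A quick way to confirm the environment before running anything heavy (a minimal check, assuming a CUDA GPU is present, which flash-attn requires):

import torch
import flash_attn
import transformers

print(transformers.__version__)   # expect 4.44.1
print(flash_attn.__version__)     # a successful import means the wheel matches your torch/CUDA build
print(torch.cuda.is_available())  # should be True for the experiments below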

1.2 Project Structure

haystack_for_detect/  # background context corpus
head_score/           # saved retrieval-head scores
results/graph/        # saved experiment results

retrieval_head_detection.ipynb  # notebook walkthrough of retrieval_head_detection.py
modeling_llama.py               # adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

1.3 LLM

huggingface-cli download nreHieW/Llama-3.1-8B-Instruct --local-dir model/Llama-3.1-8B-Instruct
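
Equivalently, the snapshot can be fetched from a Python session via huggingface_hub (a sketch mirroring the CLI command above):

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="nreHieW/Llama-3.1-8B-Instruct",  # same repo as the CLI command
    local_dir="model/Llama-3.1-8B-Instruct",
)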

2 Overall Workflow

2.1 Preparation

  1. Import the required libraries:
import os
import glob
import json
import torch
import time
from modeling_llama import LlamaForCausalLM
import numpy as np
from rouge_score import rouge_scorer
from collections import defaultdict
from datetime import datetime, timezone
from transformers import AutoTokenizer, AutoConfig
  2. Reset RoPE (used for the long-context llama-2-7b-80k variant; see the sketch after this function):
def reset_rope(model, model_max_train_len, scaling_factor):
    for l in model.model.layers:
        l.self_attn.rotary_emb.scaling_factor = scaling_factor
        l.self_attn.rotary_emb._set_cos_sin_cache(seq_len=model_max_train_len, device=l.self_attn.rotary_emb.inv_freq.device, dtype=torch.float32)
    return
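
For intuition, the scaling_factor implements linear position interpolation: when the cos/sin cache is rebuilt, the position indices are divided by the factor before the rotary frequencies are applied, so positions up to 81920 are squeezed into the range the model saw during training. A simplified sketch with illustrative values (not the exact code in the modified modeling_llama.py):

import torch

dim, base = 128, 10000.0                                            # head dim and RoPE base, illustrative
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # standard RoPE inverse frequencies
t = torch.arange(81920, dtype=torch.float32) / 10                   # scaling_factor=10 shrinks every position index
freqs = torch.outer(t, inv_freq)                                    # rotation angle per (position, frequency) pair
cos_cached, sin_cached = freqs.cos(), freqs.sin()                   # what the rebuilt cache stores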
  3. ROUGE scorer:
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
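
The score > 50 check later in evaluate_and_log is this scorer's ROUGE-1 recall against the real needle; a quick illustration with made-up strings:

ref = "records were broken for greenhouse gas levels"
hyp = "The report shows that records were broken for greenhouse gas levels."
print(scorer.score(ref, hyp)['rouge1'].recall * 100)  # share of reference unigrams recovered, scaled to 0-100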
  4. Define the needle-in-a-haystack tester class:
class LLMNeedleHaystackTester:
    def __init__(
        self,
        needle="\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n",  # the needle
        haystack_dir="./haystack_for_detect",  # directory holding the background context
        retrieval_question="What is the best thing to do in San Francisco?",  # the question
        results_version=1,  # results version
        context_lengths_min=1000,  # minimum context length
        context_lengths_max=50000,  # maximum context length
        context_lengths_num_intervals=20,  # number of context-length intervals
        context_lengths=None,  # explicit list of context lengths
        document_depth_percent_min=0,  # minimum document depth percent
        document_depth_percent_max=100,  # maximum document depth percent
        document_depth_percent_intervals=10,  # number of depth-percent intervals
        document_depth_percents=None,  # explicit list of document depth percents
        document_depth_percent_interval_type="linear",  # depth-percent spacing: linear or sigmoid
        model_provider="OpenAI",  # model provider: OpenAI or Anthropic
        model_name='',  # model name
        model_name_suffix=None,  # model name suffix
        num_concurrent_requests=1,  # number of concurrent requests
        save_results=True,  # whether to save the results to file
        save_contexts=True,  # whether to save the contexts to file
        final_context_length_buffer=200,  # buffer reserved out of the input context
        seconds_to_sleep_between_completions=None,  # seconds to sleep between completions
        print_ongoing_status=True  # whether to print ongoing status
    ):
        if not needle or not haystack_dir or not retrieval_question:
            raise ValueError("Needle, haystack, and retrieval_question must be provided.")

        needles_and_stacks = [json.loads(l) for l in open(f"{haystack_dir}/needles.jsonl")]  # three records
        self.needle_list = [l["needle"] for l in needles_and_stacks]  # list of needles
        self.haystack_dir_list = [f"{haystack_dir}/part{i}" for i in range(1, 4)]  # list of subdirectories
        self.retrieval_question_list = [l["question"] for l in needles_and_stacks]  # list of questions
        self.real_answers_list = [l["real_needle"] for l in needles_and_stacks]  # list of ground-truth answers
        self.results_version = results_version
        self.num_concurrent_requests = num_concurrent_requests
        self.save_results = save_results
        self.final_context_length_buffer = final_context_length_buffer
        self.save_contexts = save_contexts
        self.seconds_to_sleep_between_completions = seconds_to_sleep_between_completions
        self.print_ongoing_status = print_ongoing_status
        self.model_provider = model_provider
        self.testing_results = []
        self.head_counter = defaultdict(list)

        if "/" in model_name:
            self.model_version = model_name.split("/")[-1]
        else:
            self.model_version = model_name
        if model_name_suffix is not None:
            self.model_version += "_" + model_name_suffix

        if context_lengths is None:
            if context_lengths_min is None or context_lengths_max is None or context_lengths_num_intervals is None:
                raise ValueError("Either context_lengths_min, context_lengths_max, context_lengths_num_intervals need to be filled out OR the context_lengths list needs to be supplied.")
            else:  # an evenly spaced list of lengths: round to the nearest integer, endpoint keeps the max, astype casts to int
                self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
        else:
            self.context_lengths = context_lengths

        if document_depth_percents is None:
            if document_depth_percent_min is None or document_depth_percent_max is None or document_depth_percent_intervals is None:
                raise ValueError("Either document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals need to be filled out OR the document_depth_percents needs to be supplied.")
            else:
                if document_depth_percent_interval_type == 'linear':  # evenly spaced percentages between min and max
                    self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
                elif document_depth_percent_interval_type == 'sigmoid':  # S-curve spacing, denser in the middle region
                    self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]
        else:
            self.document_depth_percents = document_depth_percents
        if document_depth_percent_interval_type not in [None, "linear", "sigmoid"]:
            raise ValueError("document_depth_percent_interval_type must be either None, 'linear' or 'sigmoid'. If you'd like your own distribution give a list of ints in via document_depth_percent_intervals")

        self.model_name = model_name
        self.enc = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        print("loading from %s" % model_name)
        config = AutoConfig.from_pretrained(model_name)
        self.layer_num, self.head_num = config.num_hidden_layers, config.num_attention_heads
        print(f"layer number: {self.layer_num}, head number {self.head_num}")
        # Qwen2ForCausalLM, MixtralForCausalLM, MistralForCausalLM and Phi3ForCausalLM come from
        # similarly modified modeling files; this notebook only imports the Llama variant
        if "Qwen" in self.model_version:
            self.model_to_test = Qwen2ForCausalLM.from_pretrained(
                model_name, torch_dtype="auto", device_map='auto', use_flash_attention_2="flash_attention_2"
            ).eval()
        elif "Mixtral" in self.model_version:
            self.model_to_test = MixtralForCausalLM.from_pretrained(
                model_name, torch_dtype="auto", device_map='auto', use_flash_attention_2="flash_attention_2", trust_remote_code=True,
            ).eval()
        elif "Mistral" in self.model_version:
            self.model_to_test = MistralForCausalLM.from_pretrained(
                model_name, torch_dtype="auto", device_map='auto', use_flash_attention_2="flash_attention_2", trust_remote_code=True,
            ).eval()
        elif "Phi3" in self.model_version:
            self.model_to_test = Phi3ForCausalLM.from_pretrained(
                model_name, torch_dtype="auto", device_map='auto', use_flash_attention_2="flash_attention_2", trust_remote_code=True,
            ).eval()
        else:
            self.model_to_test = LlamaForCausalLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map='auto', use_flash_attention_2="flash_attention_2",
            ).eval()

        # special position-encoding configuration to extend the model's long-context ability
        if 'llama-2-7b-80k' in self.model_version:
            scaling_factor = 10
            reset_rope(self.model_to_test, model_max_train_len=81920, scaling_factor=scaling_factor)

        if "CUDA_VISIBLE_DEVICES" in os.environ:
            # count the comma-separated device ids (len() on the raw string would count characters)
            self.multi_gpus = len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) > 1
        else:
            self.multi_gpus = True

        self.model_to_test_description = model_name
        self.evaluation_model = None
        self.debug = 'debug'


    def logistic(self, x, L=100, x0=50, k=.1):
        if x == 0:
            return 0
        if x == 100:
            return 100
        return np.round(L / (1 + np.exp(-k * (x - x0))), 3)


    # start the test
    def start_test(self, args):
        for ni in range(len(self.needle_list)):
            self.needle = self.needle_list[ni]
            self.haystack_dir = self.haystack_dir_list[ni]
            self.real_needle = self.real_answers_list[ni]
            self.retrieval_question = self.retrieval_question_list[ni]
            if self.print_ongoing_status:
                self.print_start_test_summary()
            self.run_test(args)

        # if a score file already exists, accumulate the historical scores
        if os.path.exists(f"head_score/{self.model_version}.json"):
            with open(f"./head_score/{self.model_version}.json", "r") as file:
                head_counter = json.loads(file.readline())
            for k, v in head_counter.items():
                self.head_counter[k] += v

        with open(f"head_score/{self.model_version}.json", 'w') as f:
            json.dump(self.head_counter, f)


    # print the start-of-test summary
    def print_start_test_summary(self):
        print("\n")
        print("Starting Needle In A Haystack Testing...")
        print(f"- Model: {self.model_name}")
        print(f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
        print(f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%")
        print(f"- Needle: {self.needle.strip()}")
        print("\n\n")


    # run the test
    def run_test(self, args):
        tasks = []

        # iterate over every combination of preset context length and document depth percent
        for context_length in self.context_lengths:
            if context_length < args.s_len or context_length > args.e_len:
                continue
            for depth_percent in self.document_depth_percents:
                task = self.bound_evaluate_and_log(context_length, depth_percent)

    # forward to evaluate_and_log
    def bound_evaluate_and_log(self, *args):
        self.evaluate_and_log(*args)


    # evaluate and log
    def evaluate_and_log(self, context_length, depth_percent):
        context = self.generate_context(context_length, depth_percent)  # build the test context
        question = f"Based on the content of the book, Question: {self.retrieval_question}\nAnswer:"  # build the question

        # build the model input
        if self.model_version in ["Mistral-7B-Instruct-v0.2", "Qwen1.5-14B-Chat"]:  # chat template
            prompt = [
                {"role": "user", "content": f"<book>{context}</book>\nBased on the content of the book, Question: {self.retrieval_question}\nAnswer:"},
            ]
            input_ids = self.enc.apply_chat_template(conversation=prompt, tokenize=True, add_generation_prompt=True, return_tensors='pt')
        else:  # plain concatenation
            input_context = context + question
            input_ids = self.enc(input_context, return_tensors="pt")['input_ids']

        test_start_time = time.time()
        self.prompt_ids = input_ids[0, :]
        if not self.multi_gpus:
            input_ids = input_ids.to(self.model_to_test.device)
        self.needle_start, self.needle_end = self.find_needle_idx(self.real_needle)  # locate the needle's token span

        # model inference
        with torch.no_grad():
            q_outputs = self.model_to_test(input_ids=input_ids[:, :-1], use_cache=True, return_dict=True)
            output, retrieval_score = self.decode(q_outputs, input_ids[:, -1], 50)
            response = self.enc.decode(output, skip_special_tokens=True).strip()

        test_end_time = time.time()
        test_elapsed_time = test_end_time - test_start_time

        # evaluate the answer
        score = scorer.score(self.real_needle, response)['rouge1'].recall * 100
        if score > 50:  # if the answer is correct, update the attention heads' retrieval scores
            self.retrieval_head_accumulate(retrieval_score)
            head_score = [(i[0], np.mean(i[1])) for i in self.head_counter.items()]
            head_score = sorted(head_score, key=lambda x: x[1], reverse=True)
            print([[i[0]] for i in head_score][:20])

        # collect the result
        results = {
            'model': self.model_to_test_description,
            'context_length': int(context_length),
            'depth_percent': float(depth_percent),
            'version': self.results_version,
            'needle': self.needle,
            'model_response': response,
            'score': score,
            'test_duration_seconds': test_elapsed_time,
            'test_timestamp_utc': datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S%z')
        }
        self.testing_results.append(results)

        # print progress
        if self.print_ongoing_status:
            print("-- Test Summary --")
            print(f"Duration: {test_elapsed_time:.1f} seconds")
            print(f"Context: {context_length} tokens")
            print(f"Depth: {depth_percent}%")
            print(f"Score: {score}")
            print(f"Response: {response}\n")

        # save the context and the result
        context_file_location = f'{self.model_version.replace(".", "_")}_len_{context_length}_depth_{int(depth_percent*100)}'

        if self.save_contexts:
            results['file_name'] = context_file_location

            if not os.path.exists('contexts'):
                os.makedirs('contexts')

            if not os.path.exists(f'contexts/{self.model_version}'):
                os.makedirs(f'contexts/{self.model_version}')

            with open(f'contexts/{self.model_version}/{context_file_location}_context.txt', 'w') as f:
                f.write(context)

        if self.save_results:
            if not os.path.exists(f'results/graph/{self.model_version}'):
                os.makedirs(f'results/graph/{self.model_version}')

            p = f'results/graph/{self.model_version}/{context_file_location}_results.json'
            print("Writing at %s" % p)
            with open(p, 'w') as f:
                json.dump(results, f)


    # build the context
    def generate_context(self, context_length, depth_percent):
        context = self.read_context_files()  # read the raw context files
        context = self.encode_and_trim(context, context_length)  # encode and trim the context
        context = self.insert_needle(context, depth_percent, context_length)  # insert the needle
        return context

    # locate the needle's token span
    def find_needle_idx(self, needle):
        needle_ids = self.enc(needle, add_special_tokens=False)["input_ids"]
        print(self.enc.decode(needle_ids, skip_special_tokens=False))
        span_len = len(needle_ids)
        for i in range(len(self.prompt_ids)):  # sliding-window search
            token_span = self.prompt_ids[i: i + span_len]  # the token-id subsequence in the current window
            span_ids = set(token_span.tolist())  # compare the window and the needle as sets to compute their overlap
            overlap = float(len(span_ids.intersection(set(needle_ids)))) / len(set(needle_ids))
            if overlap > 0.9:  # if the overlap exceeds 90%, treat it as the needle and return the window's start and end
                return i, i + span_len
        return -1, -1


    # autoregressively decode the model output and compute each attention head's retrieval score
    def decode(self, q_outputs, inp, decode_len, block_list=None):
        output, retrieval_score = [], [[[0, ''] for _ in range(self.head_num)] for _ in range(self.layer_num)]
        past_kv = q_outputs.past_key_values
        for step_i in range(decode_len):
            inp = inp.view(1, 1)
            # attn_mode="torch" is an argument added by the modified modeling_llama so that attention maps are returned
            outputs = self.model_to_test(input_ids=inp, past_key_values=past_kv, use_cache=True, output_attentions=True, attn_mode="torch")
            past_kv = outputs.past_key_values  # update the KV cache
            inp = outputs.logits[0, -1].argmax()  # greedy decoding
            step_token = self.enc.convert_ids_to_tokens(inp.item())  # map the token id back to a text token
            output.append(inp.item())
            self.retrieval_calculate(outputs.attentions, retrieval_score, inp, step_token)  # update the heads' retrieval scores
            if step_token == '<0x0A>' or inp.item() == 144:  # stop once a newline is generated
                break
        return output, retrieval_score

    # accumulate each attention head's score
    def retrieval_head_accumulate(self, retrieval_score):
        for layer_idx in range(self.layer_num):
            for head_idx in range(self.head_num):
                self.head_counter[f"{layer_idx}-{head_idx}"].append(retrieval_score[layer_idx][head_idx][0])


    # read the raw context files
    def read_context_files(self):
        context = ""
        max_context_length = max(self.context_lengths)

        while len(context.split()) < max_context_length:
            for file in glob.glob(f"{self.haystack_dir}/*.txt"):
                with open(file, 'r') as f:
                    context += f.read()
        return context

    # encode and trim the context
    def encode_and_trim(self, context, context_length):
        tokens = self.encode_text_to_tokens(context)
        if len(tokens) > context_length:
            context = self.decode_tokens(tokens, context_length)
        return context


    # insert the needle
    def insert_needle(self, context, depth_percent, context_length):
        tokens_needle = self.encode_text_to_tokens(self.needle)
        tokens_context = self.encode_text_to_tokens(context)

        # reserve buffer space for the system message, the user question and the answer
        context_length -= self.final_context_length_buffer

        # if context + needle exceeds the limit, truncate the context
        if len(tokens_context) + len(tokens_needle) > context_length:
            tokens_context = tokens_context[:context_length - len(tokens_needle)]

        if depth_percent == 100:  # append directly at the end
            tokens_new_context = tokens_context + tokens_needle
        else:
            # compute the initial insertion point
            insertion_point = int(len(tokens_context) * (depth_percent / 100))
            tokens_new_context = tokens_context[:insertion_point]

            # determine the period tokens, then back up to the nearest sentence boundary
            if self.model_provider in ["LLaMA", "LongLLaMA"]:
                period_tokens = [29889, 869]
            elif self.model_provider == "Mistral":
                period_tokens = [842, 28723]
            elif self.model_provider == "GLM":
                period_tokens = [918, 30930]
            else:
                period_tokens = self.encode_text_to_tokens('.')
            while tokens_new_context and tokens_new_context[-1] not in period_tokens:
                insertion_point -= 1
                tokens_new_context = tokens_context[:insertion_point]

            print("insertion at %d" % insertion_point)
            tokens_new_context += tokens_needle + tokens_context[insertion_point:]

        # decode the token sequence back to text
        new_context = self.decode_tokens(tokens_new_context)
        return new_context


    # compute and update each attention head's score on the needle
    def retrieval_calculate(self, attention_matrix, retrieval_score, inp, step_token, topk=1):
        for layer_idx in range(self.layer_num):
            for head_idx in range(self.head_num):
                values, idx = attention_matrix[layer_idx][0][head_idx][-1].topk(topk)
                for v, i in zip(values, idx):
                    # if the head's top attention lands inside the needle span and the emitted token
                    # copies the prompt token there, accumulate its score and record the token
                    if self.needle_start <= i < self.needle_end and inp.item() == self.prompt_ids[i].item():
                        retrieval_score[layer_idx][head_idx][0] += 1 / (self.needle_end - self.needle_start)
                        retrieval_score[layer_idx][head_idx][1] += step_token
                        break


    # text to token sequence
    def encode_text_to_tokens(self, text):
        if self.model_provider in ["OpenAI", "LLaMA", "Mistral", "GLM"]:
            return self.enc.encode(text)
        elif self.model_provider == "Anthropic":
            return self.enc.encode(text).ids
        else:
            raise ValueError("model_provider must be either 'OpenAI' or 'Anthropic'")

    # token sequence to text
    def decode_tokens(self, tokens, context_length=None):
        if self.model_provider in ["OpenAI", "LLaMA", "Mistral", "GLM"]:
            return self.enc.decode(tokens[:context_length])
        elif self.model_provider == "Anthropic":
            return self.enc.decode(tokens[:context_length])
        else:
            raise ValueError("model_provider must be either 'OpenAI' or 'Anthropic'")

2.2 Running the Experiment

  1. Simulate the command-line arguments:
class Args:
    pass

args = Args()

args.model_name = '../model/Llama-3-8B-Instruct'
args.model_path = '../model/Llama-3-8B-Instruct'
args.model_name_suffix = None
args.model_provider = 'LLaMA'
args.s_len = 0
args.e_len = 5000
  2. Run the experiment:
ht = LLMNeedleHaystackTester(
    model_name=args.model_name,
    model_name_suffix=args.model_name_suffix,
    model_provider=args.model_provider,
    save_contexts=False,
    save_results=False,
    context_lengths_min=args.s_len,
    context_lengths_max=args.e_len,
)

ht.start_test(args)

loading from ../model/Llama-3-8B-Instruct
layer number: 32, head number 32
Loading checkpoint shards: 100%|██████████| 4/4 [00:25<00:00,  6.32s/it]


Starting Needle In A Haystack Testing...
- Model: ../model/Llama-3-8B-Instruct
- Context Lengths: 20, Min: 0, Max: 5000
- Document Depths: 10, Min: 0%, Max: 100%
- Needle: A new report from the WMO shows that records were once again broken, and in some cases smashed, for greenhouse gas levels, surface temperatures, ocean heat and acidification.


insertion at 0
records were once again broken, and in some cases smashed, for greenhouse gas levels, surface temperatures, ocean heat and acidification.
[['15-30'], ['16-20'], ['2-22'], ['24-27'], ['20-14'], ['5-8'], ['10-14'], ['15-1'], ['16-1'], ['5-11'], ['16-23'], ['27-7'], ['19-3'], ['20-1'], ['27-5'], ['8-1'], ['27-6'], ['19-13'], ['16-0'], ['20-23']]
-- Test Summary -- 
Duration: 4.9 seconds
Context: 0 tokens
Depth: 0%
Score: 100.0
Response: The report shows that records were once again broken, and in some cases smashed, for greenhouse gas levels, surface temperatures, ocean heat and acidification. This suggests that the report highlights the alarming rate of climate change and the urgent need for action to mitigate

...
  3. Inspect the results:
with open('./head_score/Llama-3-8B-Instruct.json') as file:
    head_list = json.loads(file.readline())

head_score_list = [([int(ll) for ll in l[0].split("-")], np.mean(l[1])) for l in head_list.items()]
head_score_list = sorted(head_score_list, key=lambda x: x[1], reverse=True)
top_retrieval_heads = [[l[0], round(np.mean(l[1]), 2)] for l in head_score_list][:10]
print(top_retrieval_heads)

[[[15, 30], 0.93], [[24, 27], 0.53], [[16, 1], 0.51], [[27, 7], 0.51], [[16, 20], 0.49], [[8, 1], 0.49], [[15, 1], 0.48], [[27, 5], 0.46], [[20, 14], 0.45], [[10, 14], 0.44]]
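
The paper labels a head a retrieval head when its mean score clears a small threshold (0.1 in the paper); a short filter over the head_score_list built above:

# heads whose mean score exceeds the paper's 0.1 threshold
retrieval_heads = [h for h, s in head_score_list if s > 0.1]
print(f"{len(retrieval_heads)} retrieval heads out of {32 * 32}")  # 32 layers x 32 heads, per the loading log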
