【趣味研究】Embedding Rerank

1 Installation

1.1 Virtual Environment

conda create -n rerank python=3.10 -y
conda activate rerank
conda install -c conda-forge openjdk=21 maven -y
conda install -c pytorch faiss-cpu -y
pip install pyserini numpy==1.26.4 torch sentence_transformers nvitop accelerate vllm
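
This check is not part of the original setup; it is a minimal sketch to confirm that the main dependencies import cleanly and that a GPU is visible (assuming the packages above installed without error):

import torch
import faiss
import pyserini

# If these imports succeed, faiss and pyserini are usable from this environment
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("faiss and pyserini imported successfully")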

1.2 Models and Datasets

huggingface-cli download facebook/dpr-question_encoder-single-nq-base --local-dir model/DPR
huggingface-cli download --resume-download Salesforce/SFR-Embedding-Mistral --local-dir model/SFR-Embedding-Mistral
huggingface-cli download --token Your_token meta-llama/Meta-Llama-3-8B-Instruct --local-dir model/Llama-3-8B-Instruct

For the datasets, see the post 【论文复现】InstructRAG.

2 Data Preparation

2.1 Helper Utilities

  1. Read JSON/JSONL files;
  2. Create the directory of the output file if it does not exist.

import os
import json


# Read a JSON or JSONL file
def read_json_data(file_path):
    if file_path.endswith('.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    elif file_path.endswith('.jsonl'):
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                data.append(json.loads(line))
        return data
    else:
        raise ValueError("Unsupported file format")


# Make sure the output file's directory exists; create it if it does not
def ensure_directory_exists(file_path):
    directory = os.path.dirname(file_path)
    if directory and not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
        print(f"Created directory: {directory}")

2.2 Retrieving Relevant Documents

  1. Preparation: read the dataset, choose the retrieval mode, and initialize the searcher (the prebuilt corpus is downloaded on the first run);
  2. Batch retrieval: for each sample, search the corpus with its question field;
  3. Saving results: each sample gets a contexts list field whose elements contain two sub-fields, id and score, holding the passage ID and the retrieval score (see the sketch below).
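
To make the output format concrete, here is a hypothetical record as it might look after the script below runs (the question, passage IDs, and scores are made up for illustration):

# Hypothetical retrieval record (IDs and scores are invented for illustration)
example_item = {
    "question": "who wrote the play hamlet",
    "answers": ["William Shakespeare"],
    "contexts": [
        {"id": "doc123", "score": 71.52},  # Wikipedia passage ID and hybrid retrieval score
        {"id": "doc456", "score": 69.87},
    ],
}
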
import json
import argparse
import time
from datetime import timedelta
from multiprocessing import cpu_count
from utils import read_json_data, ensure_directory_exists
from pyserini.encode import DprQueryEncoder
from pyserini.search.lucene import LuceneSearcher
from pyserini.search.faiss import FaissSearcher
from pyserini.search.hybrid import HybridSearcher


# Initialize the searcher
def initialize_searchers(mode):
    if mode == 'sparse':
        return LuceneSearcher.from_prebuilt_index('wikipedia-dpr')
    elif mode == 'dense':
        encoder = DprQueryEncoder("model/DPR")
        return FaissSearcher.from_prebuilt_index('wikipedia-dpr-100w.dpr-single-nq', encoder)
    elif mode == 'hybrid':
        sparse_searcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr')
        encoder = DprQueryEncoder("model/DPR")
        dense_searcher = FaissSearcher.from_prebuilt_index('wikipedia-dpr-100w.dpr-single-nq', encoder)
        return HybridSearcher(dense_searcher, sparse_searcher)
    else:
        raise ValueError(f"Unsupported retrieval mode: {mode}")


# Batch retrieval
def batch_search(searcher, data, mode, top_k):
    n = len(data)
    print(f"Read {n} samples")

    start_time = time.time()

    if mode == 'sparse':
        results = searcher.batch_search(
            queries=[item['question'] for item in data],
            qids=[str(i) for i in range(n)],
            k=top_k,
            threads=cpu_count()
        )
    elif mode == 'dense':
        results = searcher.batch_search(
            queries=[item['question'] for item in data],
            q_ids=[str(i) for i in range(n)],
            k=top_k,
            threads=cpu_count()
        )
    else:
        results = searcher.batch_search(
            queries=[item['question'] for item in data],
            q_ids=[str(i) for i in range(n)],
            k0=top_k,
            k=top_k,
            threads=cpu_count()
        )

    end_time = time.time()
    print(f"Retrieval finished in {timedelta(seconds=end_time - start_time)}")
    return results


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file path")
    parser.add_argument("--output_file", type=str, required=True, help="Output file path")
    parser.add_argument("--top_k", type=int, default=200, help="Number of documents retrieved per question")
    parser.add_argument("--mode", type=str, choices=['sparse', 'dense', 'hybrid'], default='hybrid', help="Retrieval mode: sparse, dense, or hybrid")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Initialize the searcher
    searcher = initialize_searchers(args.mode)

    # Read the input data
    data = read_json_data(args.input_file)

    # Batch retrieval
    results = batch_search(searcher, data, args.mode, args.top_k)

    # Attach the retrieval results to the data
    for idx, item in enumerate(data):
        contexts = []
        for hit in results[str(idx)]:
            contexts.append({
                'id': hit.docid,
                'score': hit.score
            })
        item['contexts'] = contexts

    # Save the results to the output file
    ensure_directory_exists(args.output_file)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    print(f"Done. Results saved to {args.output_file}")

2.3 Fetching Document Content

  1. Preparation: read the dataset and initialize the searcher;
  2. Processing: for each element of each sample's contexts field, use the id field to fetch the document's title and content;
  3. Saving results: for each element of each sample's contexts field, keep the title, content, and score fields (see the sketch below).
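
Continuing the hypothetical record from Section 2.2, after this step each context element carries the passage text instead of only the ID (values are again invented for illustration):

# The same hypothetical record after fetching document content
example_item = {
    "question": "who wrote the play hamlet",
    "answers": ["William Shakespeare"],
    "contexts": [
        {"title": "Hamlet", "content": "The Tragedy of Hamlet ... written by William Shakespeare ...", "score": 71.52},
    ],
}
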
import json
import argparse
from tqdm import tqdm
from utils import read_json_data, ensure_directory_exists
from pyserini.search.lucene import LuceneSearcher


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file path")
    parser.add_argument("--output_file", type=str, required=True, help="Output file path")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Initialize the searcher
    searcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr')

    # Read the input data
    data = read_json_data(args.input_file)

    # Iterate over every sample and process its contexts field
    for item in tqdm(data, desc="Preprocess"):
        updated_contexts = []
        for context in item['contexts']:
            content = json.loads(searcher.doc(context['id']).raw())['contents']  # Fetch the document content
            title, content = content.split('\n', 1)  # Split into title and body
            updated_contexts.append({
                'title': title.strip('"'),  # Strip the surrounding quotes
                'content': content,
                'score': context['score']  # Keep the score field
            })
        item['contexts'] = updated_contexts  # Replace the contexts field

    # Save the results to the output file
    ensure_directory_exists(args.output_file)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    print(f"Done. Results saved to {args.output_file}")

3 Relevance Assessment

3.1 Judging Document Relevance

  1. Preparation: read the dataset, load the model, and build the prompts in batch (instruction + question/answers + document);
  2. Batch inference: have the LLM judge, for every document in each sample's contexts field, whether it is relevant to the question and answers;
  3. Saving results: add a relevance field to each element of each sample's contexts list, recording whether the document is relevant.

import json
import torch
import argparse
from utils import read_json_data, ensure_directory_exists
from vllm import LLM, SamplingParams
from tqdm import tqdm


# Build the prompt
def build_prompt(question, answers, context, qa_pairs):
    prompt = (
        "Relevant means the passage contains information that helps to answer the question or supports one or more of the given answer choices.\nDetermine if the context is relevant to the question and answers.\nRespond with exactly one word: 'Yes' or 'No'.\n\n"
        f"Main Question: {question}\n"
        f"Main Answers: {'; '.join(answers)}\n\n"
    )

    # Append sub-questions and sub-answers, if any
    if qa_pairs and len(qa_pairs) > 0:
        prompt += "Main Question has their Sub-questions and answers:\n\n"
        for i, pair in enumerate(qa_pairs, 1):
            prompt += (
                f"Sub-question {i}: {pair['question']}\n"
                f"Sub-answers {i}: {'; '.join(pair['answers'])}\n\n"
            )

    prompt += (
        f"Context Title: {context['title']}\n"
        f"Context Content: {context['content']}\n\n"
        "Response:"
    )

    return prompt


# Check whether the model judged the context as relevant
def parse_relevance(output):
    text = output.strip().lower()
    return text == "yes"


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file path")
    parser.add_argument("--output_file", type=str, required=True, help="Output file path")
    parser.add_argument("--llm_path", type=str, default="model/Llama-3-8B-Instruct", help="Path to the LLM")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    data = read_json_data(args.input_file)

    llm = LLM(model=args.llm_path, tensor_parallel_size=torch.cuda.device_count())
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=1)

    # Collect all (question, context) prompts for batching
    all_requests = []
    request_indices = []  # Record (data_index, context_index) to map results back
    for i, item in enumerate(data):
        question = item["question"]
        answers = item["answers"]
        qa_pairs = item["qa_pairs"]
        for j, context in enumerate(item["contexts"]):
            prompt = build_prompt(question, answers, context, qa_pairs)
            all_requests.append(prompt)
            request_indices.append((i, j))

    # Run inference in batches
    results = []
    for i in tqdm(range(0, len(all_requests), args.batch_size)):
        batch_prompts = all_requests[i:i + args.batch_size]
        outputs = llm.generate(batch_prompts, sampling_params)
        results.extend(outputs)

    # Write the relevance field
    for output, (data_idx, context_idx) in zip(results, request_indices):
        data[data_idx]["contexts"][context_idx]["relevance"] = parse_relevance(output.outputs[0].text)

    # Save the results to the output file
    ensure_directory_exists(args.output_file)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

3.2 Evaluating Answer Recall

  1. Preparation: read the dataset and, for each sample, optionally keep only the documents the LLM judged relevant;
  2. Batch evaluation: for each given Top-K value, check whether any answer appears in the top-K documents;
  3. Printing results: print the answer recall.

import re
import string
import argparse
from multiprocessing import Pool, cpu_count
from functools import partial
from utils import read_json_data


# Normalize answer text
def normalize_answer(s):
    def remove_articles(text):  # Remove articles (a, an, the)
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):  # Collapse whitespace so words are separated by a single space
        return " ".join(text.split())

    def remove_punc(text):  # Remove all punctuation
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):  # Lowercase the text
        return text.lower()

    # Apply in order: lowercase -> remove punctuation -> remove articles -> fix whitespace
    return white_space_fix(remove_articles(remove_punc(lower(s))))


# Build the merged answer list
def merge_answers(item):
    merged_answers = []

    # Add the main answers (if present and non-empty)
    if "answers" in item and item["answers"]:
        merged_answers.extend(item["answers"])

    # Add the answers from qa_pairs (if present and non-empty)
    if "qa_pairs" in item and item["qa_pairs"]:
        for qa_pair in item["qa_pairs"]:
            if "answers" in qa_pair and qa_pair["answers"]:
                merged_answers.extend(qa_pair["answers"])

    return merged_answers


# Check whether any answer appears in any of the top-k passages
def check_answer_in_contexts(item, top_k):
    # Collect all candidate answers
    all_answers = merge_answers(item)
    all_answers = [normalize_answer(ans) for ans in all_answers if ans.strip()]

    # If there are no answers, return False directly
    if not all_answers:
        return False

    # Concatenate the text of the top-k passages (title + content)
    full_text = " ".join([
        f"{p['title']} {p['content']}"
        for p in item["contexts"][:top_k]
    ])
    full_text = normalize_answer(full_text)

    # Check whether any answer occurs in the text
    return any(ans in full_text for ans in all_answers if ans)


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("input_file", type=str, help="Input file path")
    parser.add_argument("--relevant", action="store_true", default=False, help="Only keep documents the LLM judged relevant (default: process all questions)")

    return parser.parse_args()


# Read a list of Top-K values from the user
def get_user_input():
    while True:
        user_input = input("Enter a list of numbers separated by spaces (default: 1 5 10 20 50 100 150 200): ")
        if user_input.strip() == "":
            return [1, 5, 10, 20, 50, 100, 150, 200]
        try:
            numbers = list(map(int, user_input.split()))
            if not all(1 <= num <= 200 for num in numbers):
                print("All numbers must be between 1 and 200; please try again")
                continue
            return sorted(numbers)
        except ValueError:
            print("Invalid input; make sure you enter numbers separated by spaces")


if __name__ == "__main__":
    args = parse_args()

    # Read the input data
    data = read_json_data(args.input_file)

    # Keep only the documents judged relevant
    if args.relevant:
        data = [
            {
                **item,  # Keep the other fields
                "contexts": [context for context in item["contexts"] if context.get("relevance", False)]
            }
            for item in data
            if any(context.get("relevance", False) for context in item["contexts"])
        ]
    print(f"Read {len(data)} samples")

    # Read the list of Top-K values
    numbers = get_user_input()
    total_result = {}

    # Evaluate in parallel across processes
    with Pool(processes=cpu_count()) as pool:
        for num in numbers:
            processor = partial(check_answer_in_contexts, top_k=num)
            results = list(pool.imap(processor, data))
            total_result[num] = sum(results) / len(results)

    # Print the results
    print(f"Analysis finished; checked {len(data)} answerable questions")
    for num, result in total_result.items():
        print(f"Recall@{num}: {result:.2%}")

4 Document Reranking

4.1 Embedding Rerank

Core idea: rerank documents by the cosine similarity / Pearson correlation between hidden states taken from the LLM's input embedding layer.
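
As a minimal sketch of just the scoring step (the vectors below are random stand-ins for the mean-pooled question and document embeddings; the full script below handles tokenization, span lookup, and pooling):

import torch
import torch.nn.functional as F

# Hypothetical mean-pooled input-embedding vectors for the question and one document
q = torch.randn(4096)
d = torch.randn(4096)

# Cosine similarity between the two pooled vectors
cos = F.cosine_similarity(q.unsqueeze(0), d.unsqueeze(0)).item()

# Pearson correlation via the 2x2 correlation matrix of the stacked vectors
pearson = torch.corrcoef(torch.stack([q, d]))[0, 1].item()

print(f"cosine: {cos:.4f}, pearson: {pearson:.4f}")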

  1. Preparation: read the dataset and load the tokenizer and the model;
  2. Processing: for each sample, build the model input (documents + instruction + question), obtain the mean-pooled embedding of each document and of the question, and compute their cosine similarity / Pearson correlation;
  3. Saving results: for each sample, update each document's score in the contexts field and sort contexts in descending order of score.

import torch
import json
import argparse
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
from utils import read_json_data, ensure_directory_exists


# Build the model input
def construct_model_input(item):
    sorted_contexts = sorted(item['contexts'], key=lambda x: x['score'])
    sub_texts = ["<|begin_of_text|><|start_header_id|>user<|end_header_id|>"]
    for i, context in enumerate(sorted_contexts):
        chunk = f"\n\nDocument {i + 1} (Title: {context['title']}): {context['content']}"
        sub_texts.append(chunk)

    sub_texts += [
        "\n\nBased on your knowledge and the provided information, answer the question:",
        f"\n{item['question']}",
        "\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    ]
    return sub_texts, sorted_contexts


# Locate each sub-text's character span and token span within the full input
def find_subtexts_positions(sub_texts, tokenizer):
    result = {}
    input_text = "".join(sub_texts)

    # Get the character offsets of every token in the full input
    encoding = tokenizer(input_text, return_offsets_mapping=True)
    offset_mapping = encoding["offset_mapping"]

    for i, sub_text in enumerate(sub_texts):
        # Make sure sub_text is a substring of input_text
        if sub_text not in input_text:
            raise ValueError(f"Sub-text {i} is not part of the full input text")

        # Find the character span
        start_char = input_text.find(sub_text)
        end_char = start_char + len(sub_text)

        # Find the token span
        start_token = None
        end_token = None

        for j, (start, end) in enumerate(offset_mapping):
            # Check whether this token overlaps the start of the sub-text
            if start <= start_char < end and start_token is None:
                start_token = j
            # Check whether this token overlaps the end of the sub-text
            if start < end_char <= end and end_token is None:
                end_token = j
                break  # Stop once the final token has been found

        # Make sure both ends were located
        if start_token is None or end_token is None:
            raise ValueError(f"Could not locate the token span of sub-text {i}")

        # Key: character span; value: token span
        result[(start_char, end_char)] = [start_token, end_token]

    return result


# Score each document against the question
def calculate_score(data, tokenizer, model, method):
    for item in tqdm(data, desc="Processing"):
        sub_texts, sorted_contexts = construct_model_input(item)
        input_text = "".join(sub_texts)
        positions = find_subtexts_positions(sub_texts, tokenizer)

        # Convert input_text into token ids and look up the input embeddings
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        with torch.no_grad():
            embeddings = model.get_input_embeddings()(inputs.input_ids)

        # Mean-pool the embeddings of each sub-text
        subtext_embeddings = {}
        for (char_start, char_end), (token_start, token_end) in positions.items():
            sub_embeddings = embeddings[0, token_start:token_end + 1, :]
            mean_embedding = torch.mean(sub_embeddings, dim=0)
            subtext_embeddings[(char_start, char_end)] = mean_embedding

        # Separate the question embedding from the document embeddings
        question_pos = list(positions.keys())[-2]
        question_embedding = subtext_embeddings[question_pos]

        for doc_idx in range(len(sorted_contexts)):
            doc_pos = list(positions.keys())[doc_idx + 1]
            doc_embedding = subtext_embeddings[doc_pos]
            if method == 'similarity':
                score = F.cosine_similarity(question_embedding.unsqueeze(0), doc_embedding.unsqueeze(0))
            else:
                score = torch.corrcoef(torch.stack([question_embedding, doc_embedding]))[0, 1]
            sorted_contexts[doc_idx]['score'] = score.item()

        item['contexts'] = sorted(sorted_contexts, key=lambda x: x['score'], reverse=True)

    return data


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file path")
    parser.add_argument("--output_file", type=str, required=True, help="Output file path")
    parser.add_argument("--model_path", type=str, default="model/Llama-3-8B-Instruct", help="Path to the LLM")
    parser.add_argument("--method", type=str, choices=['similarity', 'relevance'], help="Scoring method: 'similarity' (cosine) or 'relevance' (Pearson correlation)")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Read the input data
    data = read_json_data(args.input_file)
    print(f"Read {len(data)} samples")

    # Load the tokenizer and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModel.from_pretrained(args.model_path).to(device).eval()

    # Update the scores
    updated_data = calculate_score(data, tokenizer, model, args.method)

    # Save the results to the output file
    ensure_directory_exists(args.output_file)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(updated_data, f, indent=4, ensure_ascii=False)
    print(f"Done. Results saved to {args.output_file}")

4.2 Attend Embedding Rerank

Core idea: encode the documents and the question as vectors, feed them into the LLM, and rerank using the resulting attention matrix:

  • Why not let the LLM do the encoding: LLM-based encoding is slow, and an LLM without dedicated embedding training produces poor vectors;
  • Why mean pooling: last-token pooling is not robust for long texts and CLS pooling mainly captures a syntactic boundary token, whereas mean pooling covers both syntax and semantics and is supported by off-the-shelf libraries;
  • Why the question is also encoded as a vector rather than left as tokens: to keep the question and the documents encoded in the same way.
  1. Preparation: read the dataset and load the LLM and the retriever model;
  2. Batch processing: the retriever encodes the documents and the question into an embedding matrix that is fed to the LLM; take the last row of each attention head's attention matrix (from the question embedding to every document) and sum over all layers and heads (shape: (1, number of documents), excluding the question itself);
  3. Saving results: for each sample, set each document's score in the contexts field to its attention score and sort contexts in descending order of score.

import torch
import json
import argparse
from tqdm import tqdm
from utils import read_json_data, ensure_directory_exists
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
from multiprocessing import Process, Queue


# Sum the last row of the attention matrices over all layers and heads
def get_attention_sums(outputs):
    num_layers = len(outputs.attentions)
    num_heads = outputs.attentions[0].shape[1]
    layer_head_pairs = [[layer_idx, head_idx]
                        for layer_idx in range(num_layers)
                        for head_idx in range(num_heads)]

    attention_sum = torch.zeros((1, outputs.attentions[0].shape[-1]), device=outputs.attentions[0].device)

    for layer_idx, head_idx in layer_head_pairs:
        attention_sum += outputs.attentions[layer_idx][0, head_idx, -1, :].unsqueeze(0)

    return attention_sum  # Shape: (1, seq_len)


# Worker process: handle one shard of the data on a pair of GPUs
def worker(data_chunk, gpu_pair_idx, retriever_path, llm_path, result_queue):
    device_retriever = f'cuda:{2 * gpu_pair_idx}'
    device_llm = f'cuda:{2 * gpu_pair_idx + 1}'

    retriever = SentenceTransformer(retriever_path, device=device_retriever).eval()
    llm = AutoModel.from_pretrained(llm_path, output_attentions=True, device_map={'': device_llm}).eval()

    results = []
    for item in tqdm(data_chunk, desc=f"Worker {gpu_pair_idx}", position=gpu_pair_idx):
        item['contexts'] = sorted(item['contexts'], key=lambda x: x['score'])
        contexts = [f"Title: {c['title']}\nContent: {c['content']}" for c in item['contexts']]
        contexts.append(f'Question: {item["question"]}')

        embeddings = retriever.encode(contexts, convert_to_tensor=True, batch_size=16).unsqueeze(0)

        with torch.no_grad():
            outputs = llm(inputs_embeds=embeddings.to(device_llm))

        attention_sums = get_attention_sums(outputs)
        context_attention = attention_sums[0, :-1]  # Drop the last position (the question itself)

        for i, context in enumerate(item['contexts']):
            context['score'] = context_attention[i].item()

        item['contexts'] = sorted(item['contexts'], key=lambda x: x['score'], reverse=True)
        results.append(item)

    result_queue.put(results)


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file path")
    parser.add_argument("--output_file", type=str, required=True, help="Output file path")
    parser.add_argument("--llm_path", type=str, default="model/Llama-3-8B-Instruct", help="Path to the LLM")
    parser.add_argument("--retriever_path", type=str, default="model/SFR-Embedding-Mistral", help="Path to the embedding model")
    parser.add_argument("--num_gpus", type=int, default=torch.cuda.device_count(), help="Number of GPUs to use")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Check the number of GPUs
    print(f"Using {args.num_gpus} GPUs")
    if args.num_gpus % 2 != 0:
        raise ValueError("The number of GPUs must be even")

    # Read the input data
    data = read_json_data(args.input_file)
    print(f"Read {len(data)} samples")

    num_proc = args.num_gpus // 2
    chunk_size = (len(data) + num_proc - 1) // num_proc
    result_queue = Queue()
    processes = []

    for i in range(num_proc):
        chunk = data[i * chunk_size: (i + 1) * chunk_size]
        p = Process(target=worker, args=(chunk, i, args.retriever_path, args.llm_path, result_queue))
        p.start()
        processes.append(p)

    all_results = []
    for _ in range(num_proc):
        all_results.extend(result_queue.get())

    for p in processes:
        p.join()

    # Save the results to the output file
    ensure_directory_exists(args.output_file)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=4, ensure_ascii=False)

5 Results

5.1 Hybrid Retrieval Recall

[Figure: Answer recall at different Top-K values on each dataset]
  1. As Top-K increases, answer recall rises on every dataset, but the gains gradually taper off.
  2. For PopQA, NaturalQuestions, and 2WikiMultiHopQA there is a recall gap between the train and test sets.
  3. Differences across datasets are substantial: ASQA and NaturalQuestions have the highest recall, 2WikiMultiHopQA the lowest, with PopQA and TriviaQA in between.

5.2 Hybrid Retrieval vs. Embedding Rerank

[Figure: Answer recall at different Top-K values for different methods on each dataset]

5.3 Hybrid Retrieval vs. Sparse Retrieval

[Figure: Answer recall at different Top-K values for different methods on the test sets]

5.4 Rerank Recall

[Figure: Answer recall at different Top-K values for different reranking methods on the test sets]

5.5 Full Results

| Dataset | Set | Retrieval | Rerank | Recall@1 | Recall@5 | Recall@10 | Recall@20 | Recall@50 | Recall@100 | Recall@150 | Recall@200 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PopQA | Train | Hybrid | None | 0.3269 | 0.5120 | 0.5868 | 0.6495 | 0.7230 | 0.7769 | 0.8122 | 0.8387 |
| PopQA | Train | Hybrid | Embedding Similarity | 0.2952 | 0.4867 | 0.5818 | 0.6724 | 0.7718 | 0.8154 | 0.8307 | 0.8387 |
| PopQA | Test | Hybrid | None | 0.3059 | 0.4496 | 0.4968 | 0.5475 | 0.6090 | 0.6776 | 0.7241 | 0.7548 |
| PopQA | Test | Hybrid | Embedding Similarity | 0.3753 | 0.5540 | 0.6054 | 0.6548 | 0.6998 | 0.7277 | 0.7455 | 0.7548 |
| PopQA | Test | Hybrid | Embedding Relevance | 0.3746 | 0.5540 | 0.6061 | 0.6548 | 0.7005 | 0.7277 | 0.7455 | 0.7548 |
| PopQA | Test | Sparse | None | 0.4196 | 0.5297 | 0.5726 | 0.6069 | 0.6455 | 0.6791 | 0.6941 | 0.7048 |
| PopQA | Test | Sparse | Embedding Similarity | 0.3417 | 0.4889 | 0.5440 | 0.5990 | 0.6512 | 0.6798 | 0.6891 | 0.7048 |
| PopQA | Test | Hybrid | Attention Score | 0.2773 | 0.4425 | 0.5111 | 0.5761 | 0.6576 | 0.7148 | 0.7420 | 0.7548 |
| PopQA | Test | Hybrid | LLM Relevance | 0.5293 | 0.6646 | 0.6978 | 0.7128 | 0.7342 | 0.7373 | 0.7389 | 0.7389 |
| TriviaQA | Train | Hybrid | None | 0.4339 | 0.6228 | 0.6829 | 0.7342 | 0.7923 | 0.8355 | 0.8637 | 0.8794 |
| TriviaQA | Test | Hybrid | None | 0.4302 | 0.6206 | 0.6814 | 0.7339 | 0.7881 | 0.8371 | 0.8671 | 0.8838 |
| TriviaQA | Test | Hybrid | Embedding Similarity | 0.3690 | 0.6174 | 0.7141 | 0.7853 | 0.8423 | 0.8668 | 0.8779 | 0.8838 |
| TriviaQA | Test | Sparse | None | 0.5024 | 0.6949 | 0.7489 | 0.7921 | 0.8321 | 0.8562 | 0.8671 | 0.8744 |
| TriviaQA | Test | Sparse | Embedding Similarity | 0.3273 | 0.5535 | 0.6430 | 0.7239 | 0.8015 | 0.8454 | 0.8630 | 0.8744 |
| NaturalQuestions | Train | Hybrid | None | 0.6035 | 0.7830 | 0.8180 | 0.8418 | 0.8633 | 0.8772 | 0.8868 | 0.8948 |
| NaturalQuestions | Test | Hybrid | None | 0.4831 | 0.6970 | 0.7568 | 0.8036 | 0.8432 | 0.8604 | 0.8778 | 0.8922 |
| NaturalQuestions | Test | Hybrid | Embedding Similarity | 0.1579 | 0.3903 | 0.5155 | 0.6526 | 0.7911 | 0.8504 | 0.8776 | 0.8922 |
| NaturalQuestions | Test | Sparse | None | 0.2366 | 0.4571 | 0.5601 | 0.6449 | 0.7366 | 0.7859 | 0.8091 | 0.8222 |
| NaturalQuestions | Test | Sparse | Embedding Similarity | 0.1233 | 0.2972 | 0.3928 | 0.5064 | 0.6468 | 0.7410 | 0.7881 | 0.8222 |
| 2WikiMultiHopQA | Train | Hybrid | None | 0.2243 | 0.3487 | 0.3876 | 0.4294 | 0.4945 | 0.5686 | 0.6290 | 0.6694 |
| 2WikiMultiHopQA | Test | Hybrid | None | 0.0930 | 0.1811 | 0.2323 | 0.2909 | 0.3931 | 0.5101 | 0.6151 | 0.6759 |
| 2WikiMultiHopQA | Test | Hybrid | Embedding Similarity | 0.1667 | 0.3131 | 0.3929 | 0.4674 | 0.5601 | 0.6212 | 0.6523 | 0.6759 |
| 2WikiMultiHopQA | Test | Sparse | None | 0.1854 | 0.3289 | 0.4095 | 0.4834 | 0.5738 | 0.6367 | 0.6723 | 0.6937 |
| 2WikiMultiHopQA | Test | Sparse | Embedding Similarity | 0.1570 | 0.2930 | 0.3635 | 0.4365 | 0.5339 | 0.6096 | 0.6555 | 0.6935 |
| ASQA | Train | Hybrid | None | 0.5982 | 0.7926 | 0.8390 | 0.8771 | 0.9001 | 0.9173 | 0.9286 | 0.9405 |
| ASQA | Train | Hybrid | Embedding Similarity | 0.1918 | 0.4620 | 0.6010 | 0.7349 | 0.8548 | 0.9079 | 0.9292 | 0.9405 |
| ASQA | Test | Hybrid | None | 0.5243 | 0.7669 | 0.8249 | 0.8629 | 0.8977 | 0.9188 | 0.9357 | 0.9451 |
| ASQA | Test | Hybrid | Embedding Similarity | 0.2068 | 0.4747 | 0.6171 | 0.7479 | 0.8576 | 0.9219 | 0.9378 | 0.9451 |
| ASQA | Test | Hybrid | Embedding Relevance | 0.2068 | 0.4736 | 0.6181 | 0.7479 | 0.8586 | 0.9219 | 0.9378 | 0.9451 |
| ASQA | Test | Sparse | None | 0.3249 | 0.5738 | 0.6867 | 0.7637 | 0.8312 | 0.8797 | 0.8914 | 0.9051 |
| ASQA | Test | Sparse | Embedding Similarity | 0.1582 | 0.3713 | 0.4852 | 0.6086 | 0.7511 | 0.8376 | 0.8861 | 0.9051 |
| ASQA | Test | Hybrid | Attention Score | 0.2289 | 0.5496 | 0.6973 | 0.8143 | 0.9030 | 0.9367 | 0.9451 | 0.9451 |
| ASQA | Test | Hybrid | LLM Relevance | 0.5747 | 0.8206 | 0.8668 | 0.9012 | 0.9227 | 0.9345 | 0.9356 | 0.9356 |

6 Analysis

6.1 Embedding Rerank

  1. Semantically similar documents do not necessarily contain the answer.
  2. The input embedding layer is not optimized for retrieval.
  3. Mean pooling loses key positional information, and noisy content dilutes the semantic representation of the key sentences, reducing the accuracy of the similarity computation.

6.2 Attend Embedding Rerank

  1. Attention is fundamentally designed to model dependencies between tokens, not to judge whether a document contains the answer.
  2. Without task-specific training, the attention distributions of the heads are not retrieval-oriented.
  3. Attention is sensitive to input order, so the ordering of the documents can affect the attention distribution.
