[Paper Reproduction] SelfElicit

Reference project: ZhiningLiu1998/SelfElicit

Method illustration: (figure from the original post, not reproduced here)

1 Installation

1.1 Virtual environment

conda create -n selfelicit python=3.8 -y
conda activate selfelicit
pip install torch transformers==4.44.1 pandas==1.4.4 seaborn ipykernel
pip install -i https://mirrors.aliyun.com/pypi/simple/ spacy-3.7.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl # https://pypi.tuna.tsinghua.edu.cn/simple/spacy/
pip install en_core_web_sm-3.8.0-py3-none-any.whl # https://github.com/explosion/spacy-models/releases/tag/en_core_web_sm-3.8.0

python -m ipykernel install --user --name selfelicit
jupyter kernelspec list
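
As a quick post-install sanity check (a minimal sketch; the expected versions are the ones pinned above), the key dependencies can be verified from Python:

import torch
import transformers
import spacy

print(torch.cuda.is_available())    # should be True on a GPU machine
print(transformers.__version__)     # should print 4.44.1
nlp = spacy.load("en_core_web_sm")  # fails if the model wheel was not installed
print(len(list(nlp("One sentence. Another sentence.").sents)))  # should print 2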

1.2 Model and dataset

The model is meta-llama/Meta-Llama-3.1-8B-Instruct, and the HotpotQA dev_distractor.json file is used as the demonstration dataset.
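
For reference, a hypothetical download sketch (the local paths are assumptions matching the rest of this walkthrough; the dataset URL is the official HotpotQA download link, and fetching Llama weights requires an authorized Hugging Face token):

import urllib.request
from huggingface_hub import snapshot_download

# Download the model weights into a local directory
snapshot_download(
    repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    local_dir="../model/Llama-3.1-8B-Instruct",
    token="Your HF Token",
)

# Download the HotpotQA dev (distractor) split and save it as dev_distractor.json
urllib.request.urlretrieve(
    "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json",
    "dev_distractor.json",
)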

1.3 NLTK corpora

The following corpora need to be downloaded in advance and saved under the corresponding folders:

nltk_data/corpora/wordnet
nltk_data/corpora/omw-1.4
nltk_data/tokenizers/punkt
nltk_data/tokenizers/punkt_tab
nltk_data/taggers/averaged_perceptron_tagger_eng

Link: nltk data
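
If the machine has internet access, the same resources can be fetched with nltk.download (a minimal sketch; the resource names correspond to the folders listed above):

import nltk

for resource in ["wordnet", "omw-1.4", "punkt", "punkt_tab", "averaged_perceptron_tagger_eng"]:
    nltk.download(resource, download_dir="nltk_data")

nltk.data.path.append("nltk_data")  # make sure NLTK searches the local directory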

2 Overall pipeline

2.1 Argument loading

  1. Import the required libraries:
import warnings
warnings.filterwarnings("ignore")

import os
import re
import json
import tqdm
import argparse
import yaml
import random
import numpy as np
import pandas as pd

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import spacy

import regex, string, unicodedata
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize

import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="whitegrid")
  2. Load the functionality of args.py:
# Load the default configuration from a YAML file
def load_config(config_file="config.yaml"):
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)
    return config

# Get command-line arguments and configuration-file parameters
def get_args(config_file="config.yaml", using_notebook=False, verbose=1):
    if verbose:
        print(f"Loading default configuration from '{config_file}' ...")

    # Load default values from the configuration file
    config = load_config(config_file)

    # Create the argument parser
    parser = argparse.ArgumentParser(description="Configuration for QA and SE Instructions")

    # Supported models, methods, and datasets
    ALL_MODELS = [
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "mistralai/Mistral-7B-Instruct-v0.3",
        "mistralai/Mistral-Nemo-Instruct-2407",
        "Qwen/Qwen2.5-7B-Instruct",
        "Qwen/Qwen2.5-32B-Instruct",
    ]
    ALL_METHODS = ["Base", "COT", "FullElicit", "PromptElicit", "SelfElicit"]
    ALL_DATASETS = ["HotpotQA", "NewsQA", "TQA", "NQ"]

    # Add arguments
    parser.add_argument("--hf_token", type=str, default=config["hf_token"], help="Hugging Face API token")  # API token
    parser.add_argument("--model_id", type=str, default=config["model_id"], help=f"The HuggingFace Model ID, should be one of {ALL_MODELS}")  # model ID
    parser.add_argument("--methods", nargs="+", default=config["methods"], help=f"Method(s) to test, can be a list or a single value from {ALL_METHODS}")  # methods
    parser.add_argument("--datasets", nargs="+", default=config["datasets"], help=f"Dataset(s) to use, can be a list or a single value from {ALL_DATASETS}")  # datasets
    parser.add_argument("--alpha", type=float, default=config["alpha"], help="Threshold for SelfElicit method")  # threshold alpha
    parser.add_argument("--layer_span", type=tuple, default=tuple(config["layer_span"]), help="Layer span for SelfElicit method")  # range of layers
    parser.add_argument("--gpu_ids", nargs="+", default=config["gpu_ids"], help="GPU IDs")  # GPU IDs
    parser.add_argument("--n_samples", type=int, default=config["n_samples"], help="Number of samples")  # number of samples per dataset
    parser.add_argument("--random_state", type=int, default=config["random_state"], help="Random state for reproducibility")  # random seed
    parser.add_argument("--max_ans_tokens", type=int, default=config["max_ans_tokens"], help="Maximum answer length in tokens")  # maximum answer length
    parser.add_argument("--marker_impstart", type=str, default=config["marker_impstart"], help="Marker for the start of important information")
    parser.add_argument("--marker_impend", type=str, default=config["marker_impend"], help="Marker for the end of important information")
    parser.add_argument("--qa_inst", type=str, default=config["qa_inst"], help="QA instruction")  # context-based QA instruction
    parser.add_argument("--se_inst", type=str, default=config["se_inst"], help="QA instruction for SelfElicit")  # context-based QA instruction with SelfElicit highlighting
    parser.add_argument("--cot_inst", type=str, default=config["cot_inst"], help="QA instruction with Chain of Thought prompt")  # context-based QA instruction with a CoT prompt
    parser.add_argument("--pe_inst", type=str, default=config["pe_inst"], help="Instruction for 1st-step extracting evidence in PromptElicit")  # instruction to extract evidence from the context

    # Parse arguments
    if using_notebook:
        if verbose:
            print("Parsing arguments from command line is disabled as using_notebook=True.")
        args = parser.parse_args([])  # pass an empty list to avoid parsing the command line
    else:
        args = parser.parse_args()

    # Set GPU environment variables
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in args.gpu_ids])
    if verbose:
        print("Using GPUs: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))

    # Reject invalid model input
    if args.model_id not in ALL_MODELS:
        raise ValueError(f"Invalid model: {args.model_id}. Must be one of {ALL_MODELS}")

    # Reject invalid method input
    if isinstance(args.methods, list):
        for method in args.methods:
            if method not in ALL_METHODS:
                raise ValueError(f"Invalid method: {method}. Must be one of {ALL_METHODS}")
    elif args.methods not in ALL_METHODS:
        raise ValueError(f"Invalid method: {args.methods}. Must be one of {ALL_METHODS}")

    # Reject invalid dataset input
    if isinstance(args.datasets, list):
        for dataset in args.datasets:
            if dataset not in ALL_DATASETS:
                raise ValueError(f"Invalid dataset: {dataset}. Must be one of {ALL_DATASETS}")
    elif args.datasets not in ALL_DATASETS:
        raise ValueError(f"Invalid dataset: {args.datasets}. Must be one of {ALL_DATASETS}")

    # Fill the markers into the instruction string
    assert ("{MARKER_IMPSTART}" in args.se_inst and "{MARKER_IMPEND}" in args.se_inst), "Instruction for SelfElicit must contain {MARKER_IMPSTART} and {MARKER_IMPEND}"
    args.se_inst = args.se_inst.format(MARKER_IMPSTART=args.marker_impstart, MARKER_IMPEND=args.marker_impend)

    # Print argument information
    if verbose:
        print("Arguments loaded successfully!\nArguments:")
        for key, value in vars(args).items():
            if key == "hf_token":
                print(f"\t{key:<10s}: {'*' * len(value)}")
            else:
                print(f"\t{key:<10s}: {value}")

    return args
  3. Get the arguments:
args = get_args(using_notebook=True)
args.n_samples = 200 # use a subset for quick testing
args.model_id = "../model/Llama-3.1-8B-Instruct"
args.datasets = ['HotpotQA']

Loading default configuration from 'config.yaml' ...
Parsing arguments from command line is disabled as using_notebook=True.
Using GPUs: 0
Arguments loaded successfully!
Arguments:
    hf_token  : *************
    model_id  : meta-llama/Meta-Llama-3.1-8B-Instruct
    methods   : ['Base', 'COT', 'FullElicit', 'PromptElicit', 'SelfElicit']
    datasets  : ['HotpotQA', 'NewsQA', 'TQA', 'NQ']
    alpha     : 0.5
    layer_span: (0.5, 1.0)
    gpu_ids   : [0]
    n_samples : 1000
    random_state: 0
    max_ans_tokens: 100
    marker_impstart: <START_IMPORTANT>
    marker_impend: <END_IMPORTANT>
    qa_inst   : Directly answer the question based on the context passage, no explanation is needed. If the context does not contain any evidence, output 'I cannot answer based on the given context.'
    se_inst   : Directly answer the question based on the context passage, no explanation is needed. Within the context, <START_IMPORTANT> and <END_IMPORTANT> are used to mark the important evidence. Read carefully but still keep your answer short, do not output the markers. If the context does not contain any evidence, output 'I cannot answer based on the given context.'
    cot_inst  : Directly answer the question based on the context passage, no explanation is needed. If the context does not contain any evidence, output 'I cannot answer based on the given context.' Think step by step to provide the answer.
    pe_inst   : Please find the supporting evidence sentences from the context for the question, then copy-paste the original text to output without any additional words. Template for output: '
- [sentence1]
- [sentence2] ...'

2.2 Model loading

  1. Load the functionality of utils.py:
# Load the Hugging Face model and tokenizer, and return the main device
def get_model_tokenizer_device(hf_token, model_id, verbose=True):
    # Check that CUDA is available
    assert torch.cuda.is_available(), "CUDA is not available!"

    # Print CUDA device information
    if verbose:
        print("CUDA is available with devices:")
        for i in range(torch.cuda.device_count()):
            print(f"\t- Device {i}: {torch.cuda.get_device_name(i)}")

    # Log in to Hugging Face (skipped here since a local model is used)
    if verbose:
        print("Logging in to Hugging Face ...")
    # huggingface_hub.login(hf_token)

    # Load the model and tokenizer
    if verbose:
        print("Loading model and tokenizer ... ", end="")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        output_attentions=True,
        trust_remote_code=True,
        device_map="auto",
        attn_implementation="eager",
    )
    if verbose:
        print("Success!")

    main_device = torch.device("cuda:0")  # set the main device
    return model, tokenizer, main_device

# Normalize text
def norm_text(input_string):
    res = re.sub(r"\n+", " ", input_string)  # remove extra line breaks
    res = re.sub(r"\s+", " ", res)  # remove extra whitespace

    # Remove extra spaces around separators
    for sep_char in [".", "?", "!", ",", ":"]:
        res = re.sub(rf"\s+\{sep_char}\s+", f"{sep_char} ", res)

    # Remove special tokens
    for token in ["[CLS]", "[SEP]", "[PAD]", "[MASK]", "[UNK]", "[PAR]", "[DOC]", "[TLE]", "<P>", "</P>", "<Tr>"]:
        res = res.replace(token, "")

    return res.strip()

# Get sentence-level token spans
def get_sentence_token_spans(context_ids, tokenizer):
    context_text = tokenizer.decode(context_ids[0])  # decode the input token IDs back to the original text
    context_tokens_text = [tokenizer.decode([token_id]).replace(" ", "") for token_id in context_ids[0]]  # per-token text (spaces removed)
    sents = [sent.text for sent in spacy.load("en_core_web_sm")(context_text).sents]  # split the text into sentences with spaCy's English model

    # If a sentence is all whitespace or its stripped length is <= 5, merge it into the next one (the last one merges into the previous one)
    for i in range(len(sents)):
        if len(sents[i].strip()) <= 5:
            if i < len(sents) - 1:
                sents[i + 1] = sents[i] + sents[i + 1]
                sents[i] = ""
            else:
                sents[i - 1] = sents[i - 1] + sents[i]
                sents[i] = ""
    sents = [sent for sent in sents if sent != ""]  # drop empty sentences

    # Find the token span of each sentence
    sent_token_spans = []
    tk_start_idx = 0

    for i, sent in enumerate(sents):
        sent = sent.lstrip(" ")  # strip leading spaces from the sentence
        sent_num_tokens = len(tokenizer.encode(sent, add_special_tokens=False))  # number of tokens in the sentence

        # Text covered by the current span
        sent_text = sent.replace(" ", "")
        span_text = tokenizer.decode(context_ids[0, tk_start_idx : tk_start_idx + sent_num_tokens]).replace(" ", "")

        # Check the containment relation between the span and the sentence
        span_include_sent = span_text.find(sent_text) >= 0
        sent_include_span = sent_text.find(span_text) >= 0
        len_span = sent_num_tokens

        if span_include_sent and sent_include_span:  # exact match
            pass
        elif span_include_sent and not sent_include_span:  # the span is longer than the sentence
            while True:
                len_span -= 1  # shrink the span
                del_token = context_tokens_text[tk_start_idx + len_span]
                span_text = span_text.rstrip(del_token)  # remove the token from the right
                if span_text.find(sent_text) < 0:  # the span became shorter than the sentence
                    span_text = span_text + del_token  # add the last token back
                    break
        elif not span_include_sent:  # the span is shorter than the sentence
            while True:
                add_token = context_tokens_text[tk_start_idx + len_span]
                len_span += 1  # grow the span
                span_text = span_text + add_token  # append the token
                if span_text.find(sent_text) >= 0:  # until it contains the full sentence
                    break

        # Compute the sentence-ending token index
        tk_end_idx = tk_start_idx + len_span
        sent_token_spans.append((tk_start_idx, tk_end_idx))
        tk_start_idx = tk_end_idx  # starting position of the next sentence

        if not span_text.endswith(sent_text):  # the last token contains part of the next sentence
            tk_start_idx -= 1  # back off by one token

    assert len(sent_token_spans) == len(sents)  # sanity-check the sentence count

    return sent_token_spans, sents
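
A minimal usage sketch of get_sentence_token_spans (hypothetical example text; it assumes a tokenizer has already been loaded, e.g. the one from section 2.2):

context = "Paris is the capital of France. The Louvre is located in Paris."
context_ids = tokenizer(context, add_special_tokens=False, return_tensors="pt").input_ids
sent_spans, sents = get_sentence_token_spans(context_ids, tokenizer)
print(sents)       # the two sentences found by spaCy
print(sent_spans)  # one (start, end) token range per sentence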
  2. Load the functionality of qa_agent.py:
# Initialize a dictionary of agents for the different tasks
def get_agents_dict(model, tokenizer, device, args):
    # Validate the structure of the input arguments
    assert (type(args) == argparse.Namespace), "args should be an argparse.Namespace object"
    assert hasattr(args, "qa_inst"), "args should have 'qa_inst' attribute"
    assert hasattr(args, "se_inst"), "args should have 'se_inst' attribute"
    assert hasattr(args, "cot_inst"), "args should have 'cot_inst' attribute"
    assert hasattr(args, "pe_inst"), "args should have 'pe_inst' attribute"

    # Parameters shared by all agents
    agent_kwargs = {"model": model, "tokenizer": tokenizer, "device": device, "max_ans_tokens": args.max_ans_tokens}

    # Initialize and return the agent dictionary
    return {
        "qa": ContextQuestionAnsweringAgent(instruction=args.qa_inst, **agent_kwargs),
        "se": ContextQuestionAnsweringAgent(instruction=args.se_inst, **agent_kwargs),
        "cot": ContextQuestionAnsweringAgent(instruction=args.cot_inst, **agent_kwargs),
        "pe": ContextQuestionAnsweringAgent(instruction=args.pe_inst, **agent_kwargs),
    }


class ContextQuestionAnsweringAgent:
    """
    Context-based question-answering agent supporting multiple strategies.

    Attributes:
        model: the initialized language model.
        tokenizer: the initialized tokenizer.
        device: the compute device.
        instruction: the agent instruction.
        max_ans_tokens: the maximum number of answer tokens.
    """

    # Initialize the agent
    def __init__(self, model, tokenizer, device, instruction, max_ans_tokens):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.instruction = instruction
        self.max_ans_tokens = max_ans_tokens

        try:  # set the model's generation pad token
            self.model.generation_config.pad_token_id = self.tokenizer.eos_token_id
        except AttributeError:
            pass

    # Prepare the model's input IDs using the chat template
    def get_chat_template_input_ids(self, context, question, return_tensors=None):
        # Build the message containing the instruction, context, and question
        instruction = self.instruction
        msg = f"Instruction: {instruction} Context: {context} Question: {question}"

        # Tokenize using the chat template
        input_ids = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": msg}],
            add_generation_prompt=True,
            return_tensors=return_tensors,
        )
        return input_ids

    # Generate an answer to the question based on the context
    def get_answer(self, context, question, max_ans_tokens=None, verbose=False, return_n_tokens=False):
        model, tokenizer, device = self.model, self.tokenizer, self.device

        # Set the maximum number of answer tokens
        if max_ans_tokens is None:
            max_ans_tokens = self.max_ans_tokens
        else:
            assert type(max_ans_tokens) == int, "max_ans_tokens should be an integer"

        # Tokenize the input
        input_ids = self.get_chat_template_input_ids(context, question, return_tensors="pt").to(device)
        len_input = input_ids.shape[-1]

        # Generate the answer
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=max_ans_tokens,
                do_sample=False,
                top_p=None,
                top_k=None,
                temperature=None,
            )

        # Decode the answer
        answer_ids = outputs[0][len_input:]
        answer = tokenizer.decode(answer_ids, skip_special_tokens=True)

        # Optional printing
        if verbose:
            print(f"Context: {context}\nQuestion: {question}\nAnswer: {answer}")

        # Optionally return the number of answer tokens
        if return_n_tokens:
            n_tokens = len(answer_ids)
            return answer, n_tokens

        return answer

    # Count the occurrences of a substring in a string
    @staticmethod
    def get_n_match(string, substring):
        all_starts = []
        start = 0
        while True:
            start = string.find(substring, start)
            if start == -1:
                break
            all_starts.append(start)
            start += 1
        return len(all_starts)

    # Locate the token spans of a target text within the tokenized input
    def find_text_token_spans(self, input_ids, target_text, raise_if_not_found=True):
        # Make sure the input IDs are a list
        assert (type(input_ids) == list) and (type(input_ids[0]) == int), "input_ids should be a 1-d list, make sure it's not a tensor."

        # Decode the input and the target text
        tokenizer = self.tokenizer
        source = tokenizer.decode(input_ids)
        target_ids = tokenizer.encode(target_text, add_special_tokens=False)
        target = tokenizer.decode(target_ids)

        # Target text not found
        if raise_if_not_found:
            assert target in source, f"'{target}' not found in input"

        # Initialize the search range
        n_match_left = self.get_n_match(source, target)
        spans = []
        start = 0

        while True:
            start += 1
            source_seg = tokenizer.decode(input_ids[start:])
            n_match_cur = self.get_n_match(source_seg, target)

            if n_match_cur < n_match_left:
                assert (n_match_left - n_match_cur == 1), f"{n_match_left - n_match_cur} matches in a same token"
                n_match_left = n_match_cur
                start -= 1
                end = max(start + len(target_ids) - 5, start)
                while True:
                    end += 1
                    seg_text = tokenizer.decode(input_ids[start:end])
                    if target in seg_text:
                        break

                spans.append((start, end))
                start = end

            if n_match_left == 0 or start >= len(input_ids):
                break

        return spans

    # Get the token span of the context within the tokenized input
    def get_context_token_span(self, context, question):
        input_ids = self.get_chat_template_input_ids(context, question, return_tensors=None)
        context_spans = self.find_text_token_spans(input_ids, context)
        assert (len(context_spans) == 1), f"Multiple/no context spans found: {context_spans}"
        return context_spans[0]
  3. Load the model and tokenizer, and set up the QA agents:
model, tokenizer, device = get_model_tokenizer_device(args.hf_token, args.model_id)
agents_dict = get_agents_dict(model, tokenizer, device, args)
agents_dict

CUDA is available with devices:
    - Device 0: Tesla V100-SXM2-32GB
Logging in to Hugging Face ...
Loading model and tokenizer ... 
Loading checkpoint shards: 100%|██████████| 4/4 [00:35<00:00,  8.92s/it]
Success!

{'qa': <__main__.ContextQuestionAnsweringAgent at 0x7fddf8cd1df0>,
'se': <__main__.ContextQuestionAnsweringAgent at 0x7fddf8cd1fa0>,
'cot': <__main__.ContextQuestionAnsweringAgent at 0x7fddf8c29d60>,
'pe': <__main__.ContextQuestionAnsweringAgent at 0x7fddf8c0dee0>}
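
To see how an agent locates the context inside the full chat-template prompt, a small hypothetical check (the context and question are made up):

context = "Paris is the capital of France."
question = "What is the capital of France?"
span = agents_dict["qa"].get_context_token_span(context, question)
print(span)  # (start, end): token positions of the context within the prompt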

2.3 Experiment execution and evaluation

  1. Load the functionality of dataloader.py:
# Load data: return the processed dataset
def load_data(dataset_name, n_samples=1000, random_state=42, verbose=True):
    kwargs = {"n_samples": n_samples, "random_state": random_state, "verbose": verbose}  # default dataset-loading parameters

    # Match the dataset name to the corresponding class
    # (NewsQA, TQA, and NQ are defined in the original repo; only HotpotQA is shown in this walkthrough)
    if dataset_name == "HotpotQA":
        data = HotpotQA(**kwargs)
    elif dataset_name == "NewsQA":
        data = NewsQA(**kwargs)
    elif dataset_name == "TQA":
        data = TQA(**kwargs)
    elif dataset_name == "NQ":
        data = NQ(**kwargs)
    else:
        raise ValueError(f"Invalid dataset: {dataset_name}")

    return data

# Load a dataset from a JSON file
def load_dataset(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

class HotpotQA:
    """
    Class for loading and processing the HotpotQA dataset.

    Attributes:
        n_samples: int, number of samples to load.
        shuffle: bool, whether to shuffle the dataset.
        random_state: int, random seed used for shuffling.
        verbose: bool, whether to print progress information.
    """
    HF_DATASET = "hotpotqa/hotpot_qa"

    def __init__(self, n_samples=None, shuffle=True, random_state=42, verbose=True):
        self.n_samples = n_samples
        self.shuffle = shuffle
        self.random_state = random_state
        self.verbose = verbose

        # Load the HotpotQA dataset
        if verbose:
            print("Loading the HotpotQA dataset ...", end=" ")
        dataset = load_dataset("dev_distractor.json")

        dataset_length = len(dataset)  # dataset size

        # Check the number of samples
        if n_samples is None:
            dataset = dataset
        elif dataset_length < n_samples:
            warnings.warn(f"The dataset only has {dataset_length} samples that satisfy the filtering criteria.")
            dataset = dataset
        elif dataset_length >= n_samples:
            random.seed(random_state)
            dataset = random.sample(dataset, n_samples)  # create a subset

        if verbose:
            print("Success!")

        self.dataset = dataset

    def __len__(self):  # return the size of the dataset
        return len(self.dataset)

    def __getitem__(self, idx):  # get a sample by index
        return self.dataset[idx]

    # Get the context and question at a given index
    def get_context_question(self, idx, use_gold=True, norm=False):
        if use_gold:
            context = self.get_gold_context(idx)
        else:
            context = self.get_context(idx)
        if norm:
            context = norm_text(context)
        question = self.dataset[idx]["question"]
        return context, question

    # Get the full context text at a given index
    # (in dev_distractor.json, "context" is a list of [title, sentences] pairs, matching get_gold_context below)
    def get_context(self, idx):
        context = self.dataset[idx]["context"]
        title_sent_start_index = {}
        sent_counter = 0
        context_text = ""
        for title, sentences in context:
            title_sent_start_index[title] = sent_counter
            for sent in sentences:
                context_text += sent
                sent_counter += 1
            context_text += "\n"
        return context_text

    # Get the gold context at a given index
    def get_gold_context(self, idx, return_list=False):
        context = self.dataset[idx]["context"]
        gold_facts = self.dataset[idx]["supporting_facts"]
        gold_sents = []

        context_title = [i[0] for i in context]
        for title, sent_id in gold_facts:
            title_id = context_title.index(title)
            sent_text = context[title_id][1][sent_id]
            gold_sents.append(sent_text)

        if return_list:
            return gold_sents
        else:
            gold_context_text = ""
            for sent_text in gold_sents:
                gold_context_text += sent_text + " "
            return gold_context_text

    # Get the list of answers at a given index
    def get_answers(self, idx):
        return [self[idx]["answer"]]

    # Get the context, question, and answers at a given index
    def get_context_question_answer(self, idx, use_gold=True, norm=True):
        context, question = self.get_context_question(idx, use_gold, norm)
        answers = self.get_answers(idx)
        return context, question, answers
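
A short usage sketch of the loader (hypothetical; it assumes dev_distractor.json is in the working directory):

data = load_data("HotpotQA", n_samples=5, random_state=0)
context, question, answers = data.get_context_question_answer(0)
print(question)
print(answers)  # a single-element list holding the gold answer
print(context)  # gold supporting sentences joined together (use_gold=True by default)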
  2. Load the functionality of self_elicit.py:
# Get the Base answer
def get_answer_base(context, question, agents_dict, args):
    return agents_dict["qa"].get_answer(context, question)

# Get the CoT answer
def get_answer_cot(context, question, agents_dict, args):
    return agents_dict["cot"].get_answer(context, question)

# Get the FullElicit answer
def get_answer_fullelicit(context, question, agents_dict, args):
    # Mark the entire context as important by wrapping it in evidence markers
    context = f"{args.marker_impstart} {context} {args.marker_impend}"
    return agents_dict["qa"].get_answer(context, question)

# Get the PromptElicit answer
def get_answer_promptelicit(context, question, agents_dict, args, return_evidence=False):
    # Extract evidence sentences: prompt-based
    def prompt_elicit(agent_elicit, context, question, marker_impstart, marker_impend, max_gen_tokens):
        # Use the "pe" agent to generate evidence sentences from the context
        model_ans_raw = agent_elicit.get_answer(context, question, max_ans_tokens=max_gen_tokens)
        elicited_context = f"{context}"
        evidence_sents = []

        # Parse the model output and identify the evidence sentences in the context
        for sent in [sent.lstrip("- ").lstrip('"').rstrip('"') for sent in model_ans_raw.split("\n")]:
            if context.find(sent) > -1:  # check whether the sentence exists in the original context
                # Locate the sentence in the context
                sent_start = context.find(sent)
                sent_end = sent_start + len(sent)

                # Insert evidence markers around the sentence
                elicited_context = (elicited_context[:sent_start] + f"{marker_impstart} {sent} {marker_impend}" + elicited_context[sent_end:])
                evidence_sents.append(sent)

        return elicited_context, evidence_sents

    elicited_context, evidence_sents = prompt_elicit(agents_dict["pe"], context, question, args.marker_impstart, args.marker_impend, args.max_ans_tokens)
    model_ans = agents_dict["se"].get_answer(elicited_context, question)  # use the "se" agent to get the final answer

    # Optionally return the extracted evidence sentences
    if return_evidence:
        return model_ans, evidence_sents
    else:
        return model_ans

# Get the SelfElicit answer
def get_answer_selfelicit(context, question, agents_dict, device, args, return_evidence=False):
    # Extract evidence sentences: attention-score based
    def self_elicit(output_att, sents, sent_spans, context_span, marker_impstart, marker_impend, layer_span, threshold, verbose=False):
        # Compute attention scores within the specified layer range
        att_layer_scores = np.array([output_att[l][0, :, -1, context_span[0] : context_span[1]].detach().cpu().float().numpy().mean(axis=0) for l in range(layer_span[0], layer_span[1])])
        att_layer_scores /= att_layer_scores.sum(axis=1, keepdims=True)  # normalize the scores over context tokens within each layer

        # Aggregate token-level scores into sentence-level scores
        att_token_scores = att_layer_scores.mean(axis=0)
        sent_scores = np.array([att_token_scores[sent_span[0] : sent_span[1]].mean() for sent_span in sent_spans])

        # Select sentences whose scores exceed the threshold
        target_sent_index = (sent_scores >= sent_scores.max() * threshold).nonzero()[0]

        if verbose:
            print(f"Sentences scores: {sent_scores.round(2)}")
            print(f"Target sentence index: {target_sent_index}")

        elicited_context = ""
        sent_end = "\n"
        evidence_sents = []
        for i, sent in enumerate(sents):
            if i in target_sent_index and len(sent.replace(" ", "")) > 5:  # filter out overly short sentences
                # Add evidence markers around the selected sentence
                elicited_context += (f"{marker_impstart} {sent} {marker_impend} {sent_end}")
                evidence_sents.append(sent)
            else:
                elicited_context += f"{sent} {sent_end}"

        # Collect the token spans of the selected evidence sentences
        evidence_spans = [sent_spans[i] for i in target_sent_index]

        return elicited_context, evidence_sents, evidence_spans

    input_ids = (agents_dict["qa"].get_chat_template_input_ids(context, question, return_tensors="pt").to(device))  # prepare the input tokens
    context_span = agents_dict["qa"].get_context_token_span(context, question)  # get the context token span
    context_ids = input_ids[:, context_span[0] : context_span[1]]
    sent_spans, sents = get_sentence_token_spans(context_ids, agents_dict["qa"].tokenizer)  # tokenize the context and identify sentence spans
    outputs = agents_dict["qa"].model(input_ids, output_attentions=True, attention_mask=torch.ones_like(input_ids))  # run the model and get the attention outputs
    output_att = outputs.attentions
    n_layers = len(output_att)
    layer_span = (int(args.layer_span[0] * n_layers), int(args.layer_span[1] * n_layers))  # layer range used for evidence selection

    # Perform evidence extraction using the computed attention patterns
    elicited_context, evidence_sents, evidence_spans = self_elicit(
        output_att,
        sents,
        sent_spans,
        context_span,
        args.marker_impstart,
        args.marker_impend,
        layer_span=layer_span,
        threshold=args.alpha,
    )

    # Free GPU memory after the computation
    del outputs
    torch.cuda.empty_cache()

    # Use the "se" agent to generate the final answer from the processed context
    model_ans = agents_dict["se"].get_answer(elicited_context, question)

    # Optionally return the extracted evidence sentences
    if return_evidence:
        return model_ans, evidence_sents
    else:
        return model_ans
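
The sentence-selection rule inside self_elicit is a relative threshold against the highest-scoring sentence; a tiny standalone illustration with made-up scores:

import numpy as np

alpha = 0.5  # args.alpha
sent_scores = np.array([0.10, 0.45, 0.80, 0.30])

# Keep every sentence scoring at least alpha * max score (here 0.80 * 0.5 = 0.40)
selected = (sent_scores >= sent_scores.max() * alpha).nonzero()[0]
print(selected)  # [1 2]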
  3. Load the functionality of the qa_metrics library:
lemmatizer = WordNetLemmatizer()  # initialize the WordNet lemmatizer

# Answer normalization
def normalize_answer(text):
    # Fix a single sentence
    def fix_answer(s):
        def remove_articles(text):  # remove articles
            return regex.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):  # fix extra whitespace
            return ' '.join(text.split())

        def remove_punc(text):  # remove punctuation
            return ''.join(ch for ch in text if ch not in set(string.punctuation))

        def lower(text):  # convert to lowercase
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))  # apply in order

    if isinstance(text, list):  # input is a list
        result = []
        for ele in text:
            ele = str(ele)
            ele = ''.join(char for char in ele if not unicodedata.category(char).startswith('P'))  # remove Unicode punctuation
            ele = fix_answer(' '.join(ele.split()))
            result.append(ele.strip().replace("’", "'").lower())  # handle special quotes and append to the result list
        return result
    else:  # input is a single text
        text = str(text)
        text = ''.join(char for char in text if not unicodedata.category(char).startswith('P'))
        text = fix_answer(' '.join(text.split()))
        return text.strip().replace("’", "'").lower()

# Text lemmatization
def lemmatize_text(text):
    # Get the POS tag of a word and map it to the first character accepted by lemmatize()
    def get_wordnet_pos(word):
        tag = nltk.pos_tag([word])[0][1][0].upper()
        tag_dict = {
            "J": wordnet.ADJ,  # adjective
            "N": wordnet.NOUN,  # noun
            "V": wordnet.VERB,  # verb
            "R": wordnet.ADV,  # adverb
        }
        return tag_dict.get(tag, wordnet.NOUN)  # default to noun for unknown POS

    # Tokenize
    words = word_tokenize(text)

    # Lemmatize
    lemmatized_words = [lemmatizer.lemmatize(word, get_wordnet_pos(word)) for word in words]

    # Join the lemmatized words back into a sentence
    lemmatized_sentence = ' '.join(lemmatized_words)

    return lemmatized_sentence

# Exact match: check whether the candidate answer contains the reference answer
def em_match(reference, candidate):
    if len(reference) == 0 or len(candidate) == 0:
        return False

    if isinstance(reference, list) and isinstance(candidate, list):
        reference = [normalize_answer(str(ele)) for ele in reference]
        candidate = [normalize_answer(str(ele)) for ele in candidate]
    elif isinstance(reference, list):
        reference = [normalize_answer(str(ele)) for ele in reference]
        candidate = [normalize_answer(str(candidate))]
    elif isinstance(candidate, list):
        candidate = [normalize_answer(str(ele)) for ele in candidate]
        reference = [normalize_answer(str(reference))]
    else:
        reference = [normalize_answer(str(reference))]
        candidate = [normalize_answer(str(candidate))]

    for ref in reference:
        for can in candidate:
            if ref in can:
                return True

    return False

# Compute the F1 score along with precision and recall
def f1_score_with_precision_recall(reference, candidate):
    # Normalize and lemmatize the reference and candidate answers
    reference = lemmatize_text(normalize_answer(str(reference)))
    candidate = lemmatize_text(normalize_answer(str(candidate)))

    # Tokenize and convert to sets
    words_reference = set(reference.split())
    words_candidate = set(candidate.split())

    # Count true positives, false positives, and false negatives
    # (fp counts reference-only words and fn candidate-only words, following the reproduced implementation)
    tp = len(words_reference.intersection(words_candidate))
    fp = len(words_reference - words_candidate)
    fn = len(words_candidate - words_reference)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {'f1': f1_score, 'precision': precision, 'recall': recall}

# Return the F1 score
def f1_score(reference, candidate):
    f1_stats = f1_score_with_precision_recall(reference, candidate)
    return f1_stats['f1']

# Fuzzy match: return True if the F1 score exceeds the threshold
def f1_match(reference, candidate, threshold=0.5):
    if len(reference) == 0 or len(candidate) == 0:
        return False

    if isinstance(reference, list) and isinstance(candidate, list):
        references = [lemmatize_text(normalize_answer(str(ele))) for ele in reference]
        candidates = [lemmatize_text(normalize_answer(str(ele))) for ele in candidate]
        f1_scores = []
        for reference in references:
            for candidate in candidates:
                f1_scores.append(f1_score(reference, candidate))
        return max(f1_scores) > threshold

    elif isinstance(reference, list):
        references = [lemmatize_text(normalize_answer(str(ele))) for ele in reference]
        candidate = lemmatize_text(normalize_answer(str(candidate)))
        f1_scores = []
        for reference in references:
            f1_scores.append(f1_score(reference, candidate))
        return max(f1_scores) > threshold

    elif isinstance(candidate, list):
        candidates = [lemmatize_text(normalize_answer(str(ele))) for ele in candidate]
        reference = lemmatize_text(normalize_answer(str(reference)))
        f1_scores = []
        for candidate in candidates:
            f1_scores.append(f1_score(reference, candidate))
        return max(f1_scores) > threshold

    else:
        reference = lemmatize_text(normalize_answer(str(reference)))
        candidate = lemmatize_text(normalize_answer(str(candidate)))
        return f1_score(reference, candidate) > threshold
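
A toy check of the matchers (it needs the NLTK resources from section 1.3; the expected values follow directly from the definitions above):

print(em_match("Obama", "Barack Obama"))  # True: the normalized reference is contained in the candidate
print(f1_score_with_precision_recall("Barack Obama", "Obama"))
# tp=1, fp=1, fn=0 under the conventions above -> precision=0.5, recall=1.0, f1=0.667
print(f1_match("Barack Obama", "Obama"))  # True, since 0.667 > 0.5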
  4. Load the functionality of eval.py:
# Evaluate how well the model answer matches a single gold answer
def evaluate_single_ans(true_ans, model_ans):
    f1_pr = f1_score_with_precision_recall(true_ans, model_ans)
    f1, pr, re = f1_pr["f1"], f1_pr["precision"], f1_pr["recall"]
    return {
        "em": em_match(true_ans, model_ans),  # exact match
        "f1m": f1_match(true_ans, model_ans),  # fuzzy match
        "f1": f1,
        "pr": pr,
        "re": re,
    }

# Evaluate how well the model answer matches a set of gold answers
def evaluate(true_ans_list, model_ans, sel_metric="f1"):
    best_score = -1  # initialize the best score
    for true_ans in true_ans_list:
        # Evaluate the model answer against the current gold answer
        eval_res = evaluate_single_ans(true_ans, model_ans)
        # If the current score is higher, update the best score and metrics
        if eval_res[sel_metric] > best_score:
            best_score = eval_res[sel_metric]
            best_ans = true_ans
            best_eval_res = eval_res

    # Compute the overall exact-match score
    best_eval_res["em"] = em_match(true_ans_list, model_ans)

    return best_ans, best_eval_res
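
A quick demonstration of evaluate with a made-up answer list:

true_ans_list = ["Paris", "the city of Paris"]
best_ans, res = evaluate(true_ans_list, "Paris")
print(best_ans)  # "Paris": the gold answer with the highest F1
print(res)       # em/f1m/f1/pr/re; "em" is re-computed against the whole answer list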
  5. Run the experiment:
def run_experiment(device, agents_dict, args):
    methods = args.methods
    datasets_dict = {dataset: load_data(dataset, args.n_samples, args.random_state, True) for dataset in args.datasets}
    qa_res_eval_cols = ["em", "f1", "pr", "re"]
    qa_res_columns = ["dataset", "idx", "true_ans", "model_ans", "method"] + qa_res_eval_cols

    qa_results = []

    for dataset_name, dataset in datasets_dict.items():
        dataset_runstat = {"f1": {method: [] for method in methods}, "em": {method: [] for method in methods}}
        iterator = tqdm.tqdm(range(len(dataset)), desc=f"DATA - {dataset_name:<10s}")

        for idx in iterator:
            context, question, true_ans_list = dataset.get_context_question_answer(idx)
            for method in methods:
                try:
                    if method == "Base":
                        model_ans = get_answer_base(context, question, agents_dict, args)
                    elif method == "COT":
                        model_ans = get_answer_cot(context, question, agents_dict, args)
                    elif method == "FullElicit":
                        model_ans = get_answer_fullelicit(context, question, agents_dict, args)
                    elif method == "PromptElicit":
                        model_ans = get_answer_promptelicit(context, question, agents_dict, args)
                    elif method == "SelfElicit":
                        model_ans, evidence_sents = get_answer_selfelicit(context, question, agents_dict, device, args, return_evidence=True)
                except Exception:  # skip samples that fail for this method
                    continue

                true_ans_used, scores = evaluate(true_ans_list, model_ans, sel_metric="f1")
                qa_results.append([dataset_name, idx, true_ans_used, model_ans, method] + [scores[col] for col in qa_res_eval_cols])

                dataset_runstat["f1"][method].append(scores["f1"] * 100)
                dataset_runstat["em"][method].append(scores["em"] * 100)

            iterator.set_postfix(
                {
                    "f1": {method: np.mean(dataset_runstat["f1"][method]).round(2) for method in methods},
                    "em": {method: np.mean(dataset_runstat["em"][method]).round(2) for method in methods},
                }
            )

    qa_results = pd.DataFrame(qa_results, columns=qa_res_columns)
    return qa_results

qa_results = run_experiment(device, agents_dict, args)
qa_results

Loading the HotpotQA dataset ... Success!
DATA - HotpotQA  : 100%|██████████| 200/200 [20:45<00:00,  6.23s/it, f1={'Base': 57.3, 'COT': 58.46, 'FullElicit': 57.67, 'PromptElicit': 69.02, 'SelfElicit': 71.03}, em={'Base': 55.5, 'COT': 56.0, 'FullElicit': 55.5, 'PromptElicit': 67.0, 'SelfElicit': 68.84}]     
  6. Visualize the results:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

for i, metric in enumerate(["f1", "em"]):
    ax = sns.barplot(x="dataset", hue="method", y=metric, data=qa_results, ax=axes[i], ci=None)
    ax.legend(loc="lower right")
    ax.set_title(f"Metric: {metric.upper()}")
    ax.set_ylim(0.5, ax.get_ylim()[1])

plt.tight_layout()
plt.show()
  7. Inspect the evaluation results:
qa_results.groupby('method')[['f1', 'em', 'pr', 're']].mean() * 100
  8. Save the evaluation results:
path = f"exp_[MODEL]{args.model_id.replace('/', '|')}_[METHOD]{'-'.join(args.methods)}_[DATA]{'-'.join(args.datasets)}.csv"
print(f"Saving results to {path} ...", end="")
qa_results.to_csv(path, index=False)
print("Success!")

Saving results to exp_[MODEL]..|model|Llama-3.1-8B-Instruct_[METHOD]Base-COT-FullElicit-PromptElicit-SelfElicit_[DATA]HotpotQA.csv ...Success!

3 config.yaml

hf_token: "Your HF Token"
model_id: "meta-llama/Meta-Llama-3.1-8B-Instruct"
methods:
- "Base"
- "COT"
- "FullElicit"
- "PromptElicit"
- "SelfElicit"
datasets:
- "HotpotQA"
- "NewsQA"
- "TQA"
- "NQ"
alpha: 0.5
layer_span: [0.5, 1.0]
gpu_ids: [0]
n_samples: 1000
random_state: 0
max_ans_tokens: 100
marker_impstart: "<START_IMPORTANT>"
marker_impend: "<END_IMPORTANT>"
qa_inst: "Directly answer the question based on the context passage, no explanation is needed. If the context does not contain any evidence, output 'I cannot answer based on the given context.'"
se_inst: "Directly answer the question based on the context passage, no explanation is needed. Within the context, {MARKER_IMPSTART} and {MARKER_IMPEND} are used to mark the important evidence. Read carefully but still keep your answer short, do not output the markers. If the context does not contain any evidence, output 'I cannot answer based on the given context.'"
cot_inst: "Directly answer the question based on the context passage, no explanation is needed. If the context does not contain any evidence, output 'I cannot answer based on the given context.' Think step by step to provide the answer."
pe_inst: "Please find the supporting evidence sentences from the context for the question, then copy-paste the original text to output without any additional words. Template for output: '\n- [sentence1]\n- [sentence2] ...'"
