【论文复现】SelfRAG

模型结构:

参考项目:AkariAsai/self-rag

1 安装

环境配置:

1
2
3
4
5
cd selfrag
conda env create -f environment.yml
conda activate selfrag
conda install -c conda-forge faiss-gpu
pip install scipy # LoRA 需要使用

问题:flash-attn 2.3.6 需要正确的 CUDA 才能安装。
解决方法:在 flash-attention/releases 中找到对应的 flash-attn 2.3.6 版本,先查看当前环境(Python、CUDA、PyTorch)版本,因此选择下载 flash_attn-2.3.6+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl,然后在 selfrag 虚拟环境中安装 pip install flash_attn-2.3.6+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl

下载模型:

1
2
huggingface-cli download --resume-download selfrag/selfrag_llama2_7b --local-dir model/selfrag_llama2_7b
huggingface-cli download --resume-download meta-llama/Llama-2-7b-hf --local-dir ./model/llama2-7b-hf

目录结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
data/ # 数据集
enwiki_2020_intro_only/
eval_data/
selfrag_train_data/
gpt4_reward_all_0813_train.json

model/ # 训练好的模型,不保存下载的模型
critic_llama2_7b/
train_selfrag_7b/

rerpoduce/ # 复现脚本
evaluation/
evaluate.sh
run_long_form_static.py
metrics.py
run_short_form.py
utils.py
retriever/
generate_embeddings.sh
generate_passage_embeddings.py
passage_retrieval.py
run_retrieval.sh
src/
train_critic/
llama_flash_attn_monkey_patch.py
train_critic.sh
train_special_tokens.py
train_generator/
finetune.py
merge.py
stage3_no_offloading_accelerate.conf
train_generator.sh
start.py
start2.py
start3.py

environment.yml # 环境包
flash_attn-2.3.6+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl # 额外的 whl

2 快速开始

对于推理,使用 vllm 可以显著加快推理速度。

2.1 start.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from vllm import LLM, SamplingParams # 使用 vllm 进行推理

# 加载预训练模型,数据类型为半精度浮点数(half)
model = LLM("../../model/selfrag_llama2_7b", dtype="half")

# 设置生成参数:温度为0.0(无随机性),top_p为1.0(不进行核采样),最大生成token数为100,不跳过特殊token
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

# 定义一个函数,用于格式化输入提示
def format_prompt(input, paragraph=None):
# 构建基本的提示格式,包含指令和响应部分
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
# 如果提供了段落信息,将其添加到提示中
if paragraph is not None:
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
return prompt

# 定义两个查询示例
query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

# 对于不需要检索的查询,生成模型预测
preds = model.generate([format_prompt(query) for query in queries], sampling_params)

# 打印每个查询的模型预测结果
for pred in preds:
print("Model prediction: {0}".format(pred.outputs[0].text))

cd reproduce; python start.py 结果:

Model prediction: Twitter, Instagram, and WhatsApp are all social media platforms.[No Retrieval]However, WhatsApp is a messaging app, while Twitter and Instagram are both primarily used for sharing photos and videos.[No Retrieval]Therefore, WhatsApp is the odd one out in this group.[Utility:5]</s>
Model prediction: Sure![Retrieval]<paragraph>

* Alpaca (left) and llama (right) in the Andes of southern Peru.

Alpacas and llamas are both domesticated species of South American camelids.[Continue to Use Evidence]Alpacas are a much smaller than llamas, with a shoulder height of 3 to 4 feet.[Continue to Use Evidence]They are also bred specifically for their fiber, which is used to make all sorts of textiles and clothing.

当 Self-RAG 不需要检索时,它会在第一个查询中开始生成不需要检索的响应。另一方面,Self-RAG 为第二个问题输出 [Retrieval] 令牌,因为这个问题需要更细粒度的事实基础。

2.2 start2.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from vllm import LLM, SamplingParams

# 加载预训练模型,数据类型为半精度浮点数(half)
model = LLM("../../model/selfrag_llama2_7b", dtype="half")

# 设置生成参数:温度为0.0(无随机性),top_p为1.0(不进行核采样),最大生成token数为100,不跳过特殊token
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

# 定义一个函数,用于格式化输入提示
def format_prompt(input, paragraph=None):
# 构建基本的提示格式,包含指令和响应部分
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
# 如果提供了段落信息,将其添加到提示中
if paragraph is not None:
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
return prompt

prompt = format_prompt("Can you tell me the difference between llamas and alpacas?", "The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds]) # 打印每个查询的模型预测结果

cd reproduce; python start2.py 结果:

['[Relevant]Alpacas are considerably smaller than llamas.[Fully supported][Utility:5]</s>']

Self-RAG 可以在生成时随时检索和插入段落,并且只要它们被上下文标记特殊词元 <paragraph></paragraph> 包围,就可以识别它们。Self-RAG 找到相关的插入文档,并生成完全有证据支持的答案。

2.3 使用 Online Retrieval 模型运行评估

google drive 下载维基百科的子集 enwiki_2020_intro_only.zip(包括维基百科文章的介绍段落),保存在 data 文件夹下。

1
2
3
4
5
cd data
unzip enwiki_2020_intro_only.zip
rm enwiki_2020_intro_only.zip
cd ../reproduce
python start3.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import sys
sys.path.append('reproduce父目录绝对路径/reproduce/retriever') # 模块导入

from vllm import LLM, SamplingParams # 导入LLM模型和采样参数设置
from passage_retrieval import Retriever # 导入用于文档检索的Retriever类

# 初始化检索器,并设置其参数
retriever = Retriever({})
retriever.setup_retriever_demo(
"facebook/contriever-msmarco", # 使用的模型
"data/enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", # 检索数据集
"data/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*", # 索引文件路径
n_docs=5, # 检索文档数量
save_or_load_index=False # 是否保存或加载索引
)

# 加载预训练模型,数据类型为半精度浮点数(half)
model = LLM("../../model/selfrag_llama2_7b", dtype="half")

# 设置生成参数:温度为0.0(无随机性),top_p为1.0(不进行核采样),最大生成token数为100,不跳过特殊token
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

# 定义一个函数,用于格式化输入提示
def format_prompt(input, paragraph=None):
# 构建基本的提示格式,包含指令和响应部分
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
# 如果提供了段落信息,将其添加到提示中
if paragraph is not None:
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
return prompt

# 定义查询问题
query_3 = "When does overfitting occur?"
# 使用检索器搜索相关文档
retrieved_documents = retriever.search_document_demo(query_3, 5)
# 为每个检索到的文档创建格式化的提示
prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents]
# 使用模型生成预测结果
preds = model.generate(prompts, sampling_params)
# 检索最相关的文档
top_doc = retriever.search_document_demo(query_3, 1)[0]
# 打印参考文档和模型预测结果
print("Reference: {0}\nModel prediction: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))

Reference: Overfitting
  In statistics, overfitting is "the production of an analysis that corresponds too closely or exactly to a particular set of data, and may therefore fail to fit additional data or predict future observations reliably". An overfitted model is a statistical model that contains more parameters than can be justified by the data. The essence of overfitting is to have unknowingly extracted some of the residual variation (i.e., the noise) as if that variation represented underlying model structure. Underfitting occurs when a statistical model cannot adequately capture the underlying structure of the data. An under-fitted model is a model where some parameters or terms that would appear in a correctly specified model are 
Model prediction: [Relevant]Overfitting occurs when a statistical model has too many parameters relative to the amount of data available.[Fully supported][Continue to Use Evidence]This can lead to the model performing well on the training data but not on new, unseen data.[Utility:5]</s>

3 检索器设置

默认情况下,使用 Contriever 作为检索组件。

3.1 下载数据

下载 DPR 中使用的预处理过的段落数据:

1
2
3
cd data
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gunzip psgs_w100.tsv.gz

下载生成的段落,使用 Contriever-MSMARCO

1
2
3
4
cd data
wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar
tar -xf wikipedia_embeddings.tar
rm wikipedia_embeddings.tar

在下载之前可以使用 wget --spider 下载地址 来查看文件大小等情况。

3.2 运行检索器

通过以下命令来运行文章检索,见附录 7.1:

1
2
cd reproduce/retriever
bash run_retrieval.sh

输入文件应为 jsonjsonl,每个实例必须包含 questioninstruction,它们将在检索期间用作查询。

3.3 为自己的数据生成 embeddings

通过以下命令为自己的数据生成 embeddings,见附录 7.2:

1
2
cd reproduce/retriever
bash generate_embeddings.sh

4 训练

Self-RAG 训练两个模型 Critic 和 Generator,这两个模型都使用反射词元扩展词元词汇表,并使用标准的下一个词元预测目标进行训练。

4.1 收集反射词元

使用 GPT4 生成 Critic 数据,在 data_creation/critic 上可以找到为每种特殊令牌类型调用 GPT-4 的脚本。训练结果为:gpt4_reward_all_0813_train.json

4.2 Critic 训练

用新的特殊词元训练 Critic, 对 Llama2-7B 进行微调,见附录 7.3:

1
2
cd reproduce/train_critic
sbatch train_critic.sh

4.3 创建 Generator 数据

使用 Critic 和 Retriever 生成 Generator 训练数据,训练结果为:huggingface-cli download --repo-type dataset --resume-download selfrag/selfrag_train_data --local-dir ./data/selfrag_train_data

4.4 Generator 训练

使用新的特殊词元训练 Generator,用 DeepSpeed 来提高训练效率。设置训练数据路径后,通过运行附录 7.4 的脚本来进行训练。

1
2
cd reproduce/train_generator
sbatch train_generator.sh

注意不同的 GPU 架构可能无法使用 bf16,需要改为 fp16,因此需要修改 stage3_no_offloading_accelerate.conf,将

1
2
3
"bf16": {
"enabled": "auto"
}

改为

1
2
3
"fp16": {
"enabled": "true"
}

由于使用了 LoRA 技术,因此需要 python merge.py 来合并模型权重:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_path = "../model/llama2-7b-hf" # 基础模型路径
lora_model_path = "model/train_selfrag_7b" # LoRA 微调后的模型路径
output_dir = "model/reproduce_selfrag_7b" # 合并后的模型输出路径

base_model = AutoModelForCausalLM.from_pretrained(base_model_path) # 加载基础模型
lora_tokenizer = AutoTokenizer.from_pretrained(lora_model_path) # 加载 LoRA 分词器
base_model.resize_token_embeddings(len(lora_tokenizer)) # 扩展模型的词汇表
lora_model = PeftModel.from_pretrained(base_model, lora_model_path) # 加载 LoRA 适配器
model = lora_model.merge_and_unload() # 合并模型

model.save_pretrained(output_dir, safe_serialization=False) # 保存合并后的模型
lora_tokenizer.save_pretrained(output_dir) # 保存扩展后的分词器

print("Model merging and saving completed successfully!")

5 推理

对于任务评估,下载数据集 eval_data.zip,每个文件都已经附带了检索到的文档,因此,如果不想在推理中运行检索器,可以简单地在所有 contexts 中加载检索到的文件。使用附录 7.5 中的命令来评估相应的数据集。

5.1 短格式

通常只为简短的生成任务检索一次,因此提供了一个易于运行的评估脚本,该脚本利用了 Contriever 离线检索的预先给定的文档。--world_size 可使用多个 GPU 进行推理。--mode 有三种参数(两个 QA 数据集会用到):

  • adaptive_retrieval:检索给定的阈值或 Self-RAG 预测。
  • no_retrieval:在推理时禁用检索。
  • always_retrieve:总是检索。

5.2 长格式

对于长篇 QA,可以使用检索模型或预先给定的段落运行评估。DPR / Contriever 与整个英文维基百科嵌入需要 100 GB RAM,因此使用一小组初始检索文档。关键参数:

  • w_rel(默认 1.0):控制符杠搜索过程中对 isRel(对检索到的段落是否相关的批评标记)标记概率的强调。
  • w_sup(默认 1.0):控制在符系搜索过程中对 isSup(对文档是否支持生成)标记概率的强调。
  • w_use(默认 0.5):控制 beam 搜索期间对 isUse(对整体质量的批评标记)标记概率的强调。
  • threshold(默认 0.2):此阈值控制自适应检索的频率。
  • max_depth(默认 6):这对应于论文中的 T,它定义了最大搜索深度。
  • beam_width(默认 2):这控制了分段级光束搜索中光束的大小。

6 常见问题

6.1 CUDA out of memory

  1. LoRA 技术。注意需要合并 LoRA 额外的参数。需要额外 pip install scipy
  2. 增加梯度累积步数 --gradient_accumulation_steps
  3. 清理缓存,在训练循环中适当的地方调用 torch.cuda.empty_cache() 来清理未使用的缓存,但是会增加训练时间
  4. 使用混合精度训练,设置 --fp16(V100)或--bf16(A100) 参数为 true可能反而导致内存不足
  5. 调整 PyTorch 内存分配器,设置 export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 来避免内存碎片化。
  6. 将权重加载到 CPU 上,device_map='cpu'

6.2 加速训练

  1. DeepSpeed:
    • 配置 df_config.json
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    {
    训练批量大小 train_batch_size = train_micro_batch_size_per_gpu * n_gpus * gradient_accumulation_steps
    "train_micro_batch_size_per_gpu": 1, 每块 GPU 的微批次大小
    "gradient_accumulation_steps": 8, 梯度累积步数
    "zero_optimization": { ZeRO 优化器配置
    "stage": 2, 启用 ZeRO Stage 2 优化,模型参数和优化器状态被分片到 CPU 或其他设备
    "offload_optimizer": { 优化器状态的卸载配置:不启用
    "device": "none"
    },
    "offload_param": { 模型参数的卸载配置:不启用
    "device": "none"
    }
    },
    "optimizer": { 优化器配置
    "type": "AdamW", "用于训练深度学习模型
    "params": {
    "lr": "auto", 学习率
    "weight_decay": "auto" 权重衰减,用于 L2 正则化
    }
    },
    "scheduler": { 学习率调度器配置
    "type": "WarmupLR", 用于稳定训练初期
    "params": {
    "warmup_num_steps": "auto" Warmup 步数
    }
    },
    "fp16": { 启用混合精度训练
    "enabled": "auto"
    }
    }
  2. --fsdp "full_shard auto_wrap"无法与 DeepSpeed 同时使用

6.3 训练 Critic 出错

错误情况

在执行 train_special_tokens.py 脚本时,SupervisedDataset 类的初始化过程中出现了 KeyError。此错误是由于尝试从 PROMPT_DICT 字典中访问不存在的键 "prompt_no_input_paragraph" 引起的。

解决方法

PROMPT_DICT 字典中添加 "prompt_no_input_paragraph""prompt_no_input_separated" 两个键:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
PROMPT_DICT = {
"prompt_input": (
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"### Instruction:\n{instruction}\n\n### Response:"
),
"prompt_no_input_paragraph": (
"### Instruction:\n{instruction}\n\n### Context:\n{context}\n\n### Response:"
),
"prompt_no_input_separated": (
"### Instruction:\n{instruction}\n\n### Separated Context:\n{context}\n\n### Response:"
),
}

6.4 评估出错

run_short_form.pycall_model_rerank_w_scores_batch 函数中多出一个参数 max_depth,需要删除。

6.5 其他

  1. 注意使用 bash 执行脚本时,输入通常需要 -- 参数来换行,\ 后不能有空格。
  2. 文件路径处理,用于导入数据和模型:
1
2
import os
os.chdir('selfrag父目录绝对路径/selfrag')
  1. 模块路径处理,用于导入自定义的库:
1
2
import sys
sys.path.append('reproduce父目录绝对路径/reproduce/retriever')

7 附录

7.1 run_retrieval.sh

1
2
3
4
5
6
7
python passage_retrieval.py \
--model_name_or_path facebook/contriever-msmarco \ # 指定要使用的模型
--passages ../../data/psgs_w100.tsv \ # 指定要使用的文档集合
--passages_embeddings "wikipedia_embeddings/*" \ # 指定预先计算的文档嵌入文件路径
--data YOUR_INPUT_FILE \ # 指定输入数据文件的路径
--output_dir YOUR_OUTPUT_FILE \ # 指定输出目录
--n_docs 20 # 指定要检索的文档数量

7.2 generate_embeddings.sh

1
2
3
4
5
6
7
8
9
for i in {0..1}; do # 循环遍历0到1的数字
export CUDA_VISIBLE_DEVICES=${i} # 设置 CUDA_VISIBLE_DEVICES 环境变量为当前循环的数字
python generate_passage_embeddings.py \ # 运行 generate_passage_embeddings.py 脚本,生成段落嵌入
--model_name_or_path facebook/contriever-msmarco \ # 指定使用的模型
--output_dir YOUR_OUTPUT_DIR \ # 指定输出目录
--passages YOUR_PASSAGE_DATA \ # 指定段落数据文件
--shard_id ${i} \ # 指定当前分片的ID
--num_shards 4 \ # 指定总分片数
> ./log/nohup.my_embeddings.${i} 2>&1 & # 将输出重定向到日志文件,并在后台运行

7.3 train_critic.sh

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
torchrun --nproc_per_node=2 \  # 使用torchrun命令运行脚本,每个节点使用2个进程
--master_port=2568 train_special_tokens.py \ # 设置主节点端口为2568,运行train_special_tokens.py脚本
--model_name_or_path ../../../model/llama2-7b-hf \ # 指定模型路径
--data_path ../../data/gpt4_reward_all_0813_train.json \ # 指定数据路径
--output_dir ../../model/critic_llama2_7b \ # 指定输出目录
--num_train_epochs 3 \ # 设置训练轮数为 3
--per_device_train_batch_size 1 --per_device_eval_batch_size 1 \ # 设置每个设备的训练和评估批次大小为 1
--gradient_accumulation_steps 8 \ # 设置梯度累积步数为 8
--evaluation_strategy "no" \ # 设置评估策略为不评估
--save_strategy "steps" \ # 设置保存策略为按步数保存
--save_steps 300 \ # 设置每 300 步保存一次
--save_total_limit 1 \ # 设置保存的总数限制为 1
--learning_rate 2e-5 \ # 设置学习率为 2e-5
--weight_decay 0. \ # 设置权重衰减为 0
--warmup_ratio 0.01 \ # 设置预热比例为 0.01
--lr_scheduler_type "cosine" \ # 设置学习率调度器类型为 cosine
--logging_steps 10 \ # 设置每 10 步记录一次日志
--lora_rank 8 \ # 设置 LoRA 秩为 8
--lora_alpha 16 \ # 设置 LoRA alpha 为 16
--lora_dropout 0.1 # 设置 LoRA dropout 为 0.1

7.4 train_generator.sh

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
export CUDA_VISIBLE_DEVICES=0,1

MODEL_SIZE=7B
NUM_GPUS=2
BATCH_SIZE_PER_GPU=1
TOTAL_BATCH_SIZE=128
GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_GPUS/$BATCH_SIZE_PER_GPU))
echo "Training llama model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"

CUDA_VISIBLE_DEVICES=0,1 # 设置哪些GPU对当前进程可见,这里指定了0号和1号GPU。

accelerate launch \ # 使用Hugging Face的accelerate库来启动分布式训练。
--num_machines 1 \ # 指定使用的机器数量,这里为1台机器。
--num_processes 2 \ # 指定进程数,通常等于GPU的数量。
--use_deepspeed \ # 启用DeepSpeed库来加速训练。
--deepspeed_config_file stage3_no_offloading_accelerate.conf \ # 指定DeepSpeed配置文件。
finetune.py \ # 指定要执行的Python脚本,这里是finetune.py。
--model_name_or_path ../../../model/llama2-7b-hf \ # 指定模型的名称或路径,这里使用相对路径指定模型位置。
--use_flash_attn \ # 使用Flash Attention,这是一种高效的注意力机制实现。
--tokenizer_name ../../../model/llama2-7b-hf \ # 指定分词器的名称或路径。
--use_slow_tokenizer \ # 使用较慢的分词器实现。
--train_file ../../data/selfrag_train_data/train.jsonl \ # 指定训练数据文件的位置。
--max_seq_length 2048 \ # 设置最大序列长度。
--preprocessing_num_workers 16 \ # 设置预处理工作线程数。
--per_device_train_batch_size 1 \ # 每个设备上的批次大小。
--gradient_accumulation_steps 128 \ # 梯度累积步数。
--learning_rate 2e-5 \ # 设置学习率。
--lr_scheduler_type linear \ # 学习率调度器类型,这里使用线性调度器。
--warmup_ratio 0.03 \ # 预热比例,用于学习率预热。
--weight_decay 0. \ # 设置权重衰减,这里为0,表示不使用权重衰减。
--num_train_epochs 3 \ # 设置训练的轮数(epoch)。
--output_dir ../../model/train_selfrag_7b/ \ # 设置输出目录。
--with_tracking \ # 启用跟踪,可能是指使用某种跟踪工具。
--report_to tensorboard \ # 指定报告工具,这里使用TensorBoard。
--logging_steps 1000 \ # 设置日志记录步数,每1000步记录一次。
--use_special_tokens \ # 使用特殊标记。
--use_lora \ # 使用LoRA(Low-Rank Adaptation)技术。
--lora_rank 8 \ # 设置LoRA的秩。
--lora_alpha 16 \ # 设置LoRA的缩放因子。
--lora_dropout 0.1 # 设置LoRA的dropout率。

7.5 evaluate.sh

PopQA

1
2
3
4
5
6
7
8
9
python run_short_form.py \
--model_name ../../model/selfrag_llama2_7b \
--input_file ../../data/eval_data/popqa_longtail_w_gs.jsonl \
--mode adaptive_retrieval \
--max_new_tokens 100 \
--threshold 0.2 \
--output_file result/popqa.json \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half

TriviaQA

1
2
3
4
5
6
7
8
9
python run_short_form.py \
--model_name ../../model/selfrag_llama2_7b \
--input_file ../../data/eval_data/triviaqa_test_w_gs.jsonl \
--mode adaptive_retrieval \
--max_new_tokens 100 \
--threshold 0.2 \
--output_file result/triviaqa.json \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half

ARC-Challenge

1
2
3
4
5
6
7
python run_short_form.py \
--model_name ../../model/selfrag_llama2_7b \
--input_file ../../data/eval_data/arc_challenge_processed.jsonl \
--max_new_tokens 50 --threshold 0.2 \
--output_file result/arc_challenge.json \
--metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
--task arc_c

PubHealth

1
2
3
4
5
6
7
8
python run_short_form.py \
--model_name ../../model/selfrag_llama2_7b \
--input_file ../../data/eval_data/health_claims_processed.jsonl \
--max_new_tokens 50 \
--threshold 0.2 --output_file result/health.json \
--metric match --ndocs 5 \
--use_groundness --use_utility --use_seqscore \
--task fever

ASQA

1
2
3
4
5
6
python run_long_form_static.py \
--model_name ../../model/selfrag_llama2_7b \
--ndocs 5 --max_new_tokens 300 --threshold 0.2 \
--use_grounding --use_utility --use_seqscore \
--task asqa --input_file ../../data/eval_data/asqa_eval_gtr_top100.json \
--output_file result/asqa.json --max_depth 7 --mode always_retrieve \

FactScore

1
2
3
4
5
6
python run_long_form_static.py \
--model_name ../../model/selfrag_llama2_7b \
--ndocs 5 --max_new_tokens 300 --threshold 0.2 \
--use_grounding --use_utility --use_seqscore \
--task factscore --input_file ../../data/eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
--output_file factscore.json --max_depth 7 \

【论文复现】SelfRAG
http://xuan-van.github.io/代码复现/【论文复现】selfrag/
作者
文晋
发布于
2024年12月11日
许可协议