Model structure:
Reference project: AkariAsai/self-rag
1 Installation
Environment setup:
```bash
cd selfrag
conda env create -f environment.yml
conda activate selfrag
conda install -c conda-forge faiss-gpu
pip install scipy
```
Problem: flash-attn 2.3.6 needs a matching CUDA version to install.
Solution: find the corresponding flash-attn 2.3.6 build on the flash-attention releases page. First check the versions in the current environment (Python, CUDA, PyTorch); in this setup that leads to downloading flash_attn-2.3.6+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl, which is then installed inside the selfrag virtual environment with pip install flash_attn-2.3.6+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl.
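To see which wheel matches your setup, a quick version check like the following is enough (the wheel name encodes the Python tag, PyTorch version, and CUDA build, e.g. cp38 / torch2.1 / cu122 above):

```python
# Print the versions that determine which flash-attn wheel to download
# (Python tag, PyTorch version, and the CUDA version PyTorch was built with).
import sys
import torch

print("Python:", sys.version.split()[0])          # e.g. 3.8.x  -> cp38
print("PyTorch:", torch.__version__)              # e.g. 2.1.x  -> torch2.1
print("CUDA (torch build):", torch.version.cuda)  # e.g. 12.2   -> cu122
```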
Download the models:
```bash
huggingface-cli download --resume-download selfrag/selfrag_llama2_7b --local-dir model/selfrag_llama2_7b
huggingface-cli download --resume-download meta-llama/Llama-2-7b-hf --local-dir ./model/llama2-7b-hf
```
Directory structure:
```
data/
    enwiki_2020_intro_only/
    eval_data/
    selfrag_train_data/
    gpt4_reward_all_0813_train.json
model/
    critic_llama2_7b/
    train_selfrag_7b/
reproduce/
    evaluation/
        evaluate.sh
        run_long_form_static.py
        metrics.py
        run_short_form.py
        utils.py
    retriever/
        generate_embeddings.sh
        generate_passage_embeddings.py
        passage_retrieval.py
        run_retrieval.sh
        src/
    train_critic/
        llama_flash_attn_monkey_patch.py
        train_critic.sh
        train_special_tokens.py
    train_generator/
        finetune.py
        merge.py
        stage3_no_offloading_accelerate.conf
        train_generator.sh
    start.py
    start2.py
    start3.py
environment.yml
flash_attn-2.3.6+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
```
2 Quick start
For inference, using vllm significantly speeds up generation.
2.1 start.py

```python
from vllm import LLM, SamplingParams

model = LLM("../../model/selfrag_llama2_7b", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
    prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
    if paragraph is not None:
        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
    return prompt

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
    print("Model prediction: {0}".format(pred.outputs[0].text))
```
```bash
cd reproduce; python start.py
```
Result:
Model prediction: Twitter, Instagram, and WhatsApp are all social media platforms.[No Retrieval]However, WhatsApp is a messaging app, while Twitter and Instagram are both primarily used for sharing photos and videos.[No Retrieval]Therefore, WhatsApp is the odd one out in this group.[Utility:5]</s>
Model prediction: Sure![Retrieval]<paragraph>
* Alpaca (left) and llama (right) in the Andes of southern Peru.
Alpacas and llamas are both domesticated species of South American camelids.[Continue to Use Evidence]Alpacas are a much smaller than llamas, with a shoulder height of 3 to 4 feet.[Continue to Use Evidence]They are also bred specifically for their fiber, which is used to make all sorts of textiles and clothing.
For the first query, Self-RAG decides that no retrieval is needed and generates the response directly. For the second question, on the other hand, Self-RAG outputs the [Retrieval] token, because this question requires more fine-grained factual grounding.
2.2 start2.py

```python
from vllm import LLM, SamplingParams

model = LLM("../../model/selfrag_llama2_7b", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
    prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
    if paragraph is not None:
        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
    return prompt

prompt = format_prompt(
    "Can you tell me the difference between llamas and alpacas?",
    "The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber."
)
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
```
```bash
cd reproduce; python start2.py
```
Result:
['[Relevant]Alpacas are considerably smaller than llamas.[Fully supported][Utility:5]</s>']
Self-RAG can retrieve and insert passages at any point during generation, and it recognizes them as long as they are wrapped in the special context tokens <paragraph> and </paragraph>. Here it judges the inserted document relevant and generates an answer that is fully supported by the evidence.
2.3 Running evaluation with online retrieval
Download the Wikipedia subset enwiki_2020_intro_only.zip (the introduction paragraphs of Wikipedia articles) from Google Drive and save it under the data folder.
```bash
cd data
unzip enwiki_2020_intro_only.zip
rm enwiki_2020_intro_only.zip
cd ../reproduce
python start3.py
```
```python
import sys
sys.path.append('<absolute path to the parent of reproduce>/reproduce/retriever')

from vllm import LLM, SamplingParams
from passage_retrieval import Retriever

retriever = Retriever({})
retriever.setup_retriever_demo(
    "facebook/contriever-msmarco",
    "data/enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl",
    "data/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*",
    n_docs=5,
    save_or_load_index=False,
)

model = LLM("../../model/selfrag_llama2_7b", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
    prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
    if paragraph is not None:
        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
    return prompt

query_3 = "When does overfitting occur?"
retrieved_documents = retriever.search_document_demo(query_3, 5)
prompts = [format_prompt(query_3, doc["title"] + "\n" + doc["text"]) for doc in retrieved_documents]
preds = model.generate(prompts, sampling_params)
top_doc = retriever.search_document_demo(query_3, 1)[0]
print("Reference: {0}\nModel prediction: {1}".format(
    top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))
```
Reference: Overfitting
In statistics, overfitting is "the production of an analysis that corresponds too closely or exactly to a particular set of data, and may therefore fail to fit additional data or predict future observations reliably". An overfitted model is a statistical model that contains more parameters than can be justified by the data. The essence of overfitting is to have unknowingly extracted some of the residual variation (i.e., the noise) as if that variation represented underlying model structure. Underfitting occurs when a statistical model cannot adequately capture the underlying structure of the data. An under-fitted model is a model where some parameters or terms that would appear in a correctly specified model are
Model prediction: [Relevant]Overfitting occurs when a statistical model has too many parameters relative to the amount of data available.[Fully supported][Continue to Use Evidence]This can lead to the model performing well on the training data but not on new, unseen data.[Utility:5]</s>
3 Retriever setup
By default, Contriever is used as the retrieval component.
3.1 Downloading the data
Download the preprocessed passage data used in DPR:
```bash
cd data
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gunzip psgs_w100.tsv.gz
```
Download the passage embeddings generated with Contriever-MSMARCO:
```bash
cd data
wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar
tar -xf wikipedia_embeddings.tar
rm wikipedia_embeddings.tar
```
Before downloading, you can run wget --spider <URL> (for example, wget --spider https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar) to check the file size and availability first.
3.2 Running the retriever
Run passage retrieval with the following commands (see Appendix 7.1 for run_retrieval.sh):
```bash
cd reproduce/retriever
bash run_retrieval.sh
```
The input file should be json or jsonl. Each instance must contain either a question or an instruction field, which is used as the query during retrieval; a sketch of such an input file follows.
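As an illustration, a hypothetical input file could be built like this (the file name and example queries are made up; only the question / instruction field names come from the description above):

```python
# Write a hypothetical retrieval input file in jsonl format.
# Each line is one instance with a "question" or "instruction" field.
import json

examples = [
    {"question": "When does overfitting occur?"},
    {"instruction": "Can you tell me the difference between llamas and alpacas?"},
]
with open("my_retrieval_input.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")
```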
3.3 Generating embeddings for your own data
Generate embeddings for your own data with the following commands (see Appendix 7.2 for generate_embeddings.sh):
```bash
cd reproduce/retriever
bash generate_embeddings.sh
```
4 Training
Self-RAG trains two models, a Critic and a Generator. Both expand the token vocabulary with reflection tokens and are trained with the standard next-token prediction objective.
4.1 Collecting reflection tokens
The Critic training data is generated with GPT-4; the scripts that call GPT-4 for each special token type can be found under data_creation/critic. The resulting training file is gpt4_reward_all_0813_train.json.
4.2 Critic training
Train the Critic with the new special tokens by fine-tuning Llama2-7B (see Appendix 7.3 for train_critic.sh):
```bash
cd reproduce/train_critic
sbatch train_critic.sh
```
4.3 Creating Generator data
The Generator training data is created with the Critic and the Retriever. The released result can be downloaded with:

```bash
huggingface-cli download --repo-type dataset --resume-download selfrag/selfrag_train_data --local-dir ./data/selfrag_train_data
```
4.4 Generator training
Train the Generator with the new special tokens, using DeepSpeed to improve training efficiency. After setting the training data path, run the script in Appendix 7.4:
```bash
cd reproduce/train_generator
sbatch train_generator.sh
```
Note that some GPU architectures cannot use bf16 and need fp16 instead. In that case, edit stage3_no_offloading_accelerate.conf and change
```json
"bf16": {
    "enabled": "auto"
}
```
to
```json
"fp16": {
    "enabled": "true"
}
```
Because LoRA is used for training, run python merge.py to merge the LoRA weights back into the base model:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_path = "../model/llama2-7b-hf"
lora_model_path = "model/train_selfrag_7b"
output_dir = "model/reproduce_selfrag_7b"

base_model = AutoModelForCausalLM.from_pretrained(base_model_path)
lora_tokenizer = AutoTokenizer.from_pretrained(lora_model_path)
base_model.resize_token_embeddings(len(lora_tokenizer))

lora_model = PeftModel.from_pretrained(base_model, lora_model_path)
model = lora_model.merge_and_unload()

model.save_pretrained(output_dir, safe_serialization=False)
lora_tokenizer.save_pretrained(output_dir)
print("Model merging and saving completed successfully!")
```
5 Inference
For task evaluation, download the dataset eval_data.zip. Each file already comes with retrieved documents, so if you do not want to run the retriever during inference, you can simply load the pre-retrieved documents from the contexts of each instance. Use the commands in Appendix 7.5 to evaluate the corresponding datasets.
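For orientation, a hypothetical way to peek at the pre-retrieved documents shipped in one of the eval files (the per-instance field layout is not spelled out here, so check the downloaded files themselves):

```python
# Inspect one instance of an eval file (after unzipping eval_data.zip)
# to find the field that holds the pre-retrieved documents.
import json

with open("data/eval_data/popqa_longtail_w_gs.jsonl") as f:
    example = json.loads(f.readline())

print(example.keys())  # look for the field containing the retrieved contexts
```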
5.1 Short form
Short generation tasks usually retrieve only once, so an easy-to-run evaluation script is provided that uses documents retrieved offline with Contriever. --world_size enables inference on multiple GPUs. --mode takes three values (used by the two QA datasets; a sketch of the adaptive decision follows this list):
- adaptive_retrieval: retrieve based on the given threshold or the Self-RAG prediction.
- no_retrieval: disable retrieval at inference time.
- always_retrieve: always retrieve.
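As a rough illustration of the adaptive mode (not the repo's exact code), retrieval can be thought of as triggered when the model's probability mass on [Retrieval] versus [No Retrieval] exceeds the --threshold value:

```python
# Illustrative sketch of an adaptive-retrieval decision; the normalization
# used here is an assumption, the default mirrors the --threshold flag.
def should_retrieve(p_retrieval: float, p_no_retrieval: float, threshold: float = 0.2) -> bool:
    """Return True when the normalized [Retrieval] probability exceeds the threshold."""
    return p_retrieval / (p_retrieval + p_no_retrieval) > threshold
```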
5.2 Long form
For long-form QA, evaluation can be run either with a retrieval model or with pre-given passages. DPR / Contriever with the full English Wikipedia embeddings needs about 100 GB of RAM, so a small set of initially retrieved documents is used instead. Key parameters (a scoring sketch follows this list):
- w_rel (default 1.0): controls the emphasis on the isRel token probability (the critique token for whether a retrieved passage is relevant) during beam search.
- w_sup (default 1.0): controls the emphasis on the isSup token probability (the critique token for whether the document supports the generation) during beam search.
- w_use (default 0.5): controls the emphasis on the isUse token probability (the critique token for overall quality) during beam search.
- threshold (default 0.2): controls how often adaptive retrieval is triggered.
- max_depth (default 6): corresponds to T in the paper and defines the maximum search depth.
- beam_width (default 2): controls the beam size in the segment-level beam search.
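The following is a minimal sketch, not the repo's actual implementation, of how these weights could combine critique-token probabilities into a per-segment score during beam search:

```python
# Hypothetical per-segment score for ranking beam candidates: sequence
# log-probability plus weighted critique-token scores. The names mirror the
# w_rel / w_sup / w_use flags above; the exact combination is an assumption.
def segment_score(seq_logprob: float, p_isrel: float, p_issup: float, p_isuse: float,
                  w_rel: float = 1.0, w_sup: float = 1.0, w_use: float = 0.5) -> float:
    return seq_logprob + w_rel * p_isrel + w_sup * p_issup + w_use * p_isuse
```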
6 FAQ
6.1 CUDA out of memory
- Use LoRA. Note that the extra LoRA parameters have to be merged afterwards, and scipy needs to be installed additionally (pip install scipy).
- Increase the gradient accumulation steps via --gradient_accumulation_steps.
- Clear the cache: call torch.cuda.empty_cache() at suitable points in the training loop to release unused cached memory, at the cost of longer training time (see the sketch after this list).
- Use mixed-precision training by setting --fp16 (V100) or --bf16 (A100) to true; in some cases this can still run out of memory.
- Tune the PyTorch memory allocator: export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 to reduce memory fragmentation.
- Load the weights onto the CPU with device_map='cpu'.
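A minimal sketch of where torch.cuda.empty_cache() could be called; the loop itself is illustrative (HF-style loss), not the repo's training code:

```python
import torch

def train_one_epoch(model, loader, optimizer, empty_cache_every=100):
    for step, batch in enumerate(loader):
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % empty_cache_every == 0:
            # Frees unused cached blocks back to the driver; reduces peak
            # fragmentation but adds overhead, so do not call it every step.
            torch.cuda.empty_cache()
```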
6.2 Speeding up training
DeepSpeed:
```
{
    // train_batch_size = train_micro_batch_size_per_gpu * n_gpus * gradient_accumulation_steps
    "train_micro_batch_size_per_gpu": 1,   // micro batch size per GPU
    "gradient_accumulation_steps": 8,      // gradient accumulation steps
    "zero_optimization": {                 // ZeRO optimizer configuration
        "stage": 2,                        // ZeRO Stage 2: optimizer states and gradients are partitioned
        "offload_optimizer": {             // optimizer-state offloading: disabled
            "device": "none"
        },
        "offload_param": {                 // parameter offloading: disabled
            "device": "none"
        }
    },
    "optimizer": {                         // optimizer used for training
        "type": "AdamW",
        "params": {
            "lr": "auto",                  // learning rate
            "weight_decay": "auto"         // weight decay (L2 regularization)
        }
    },
    "scheduler": {                         // learning-rate scheduler
        "type": "WarmupLR",                // warmup stabilizes early training
        "params": {
            "warmup_num_steps": "auto"     // number of warmup steps
        }
    },
    "fp16": {                              // mixed-precision training
        "enabled": "auto"
    }
}
```

(The // annotations are explanatory only; remove them in the actual JSON config file, since JSON does not allow comments.)
FSDP: pass --fsdp "full_shard auto_wrap"; note that it cannot be used together with DeepSpeed.
6.3 Critic training error
Error: when running the train_special_tokens.py script, a KeyError is raised during initialization of the SupervisedDataset class. The error comes from accessing the non-existent key "prompt_no_input_paragraph" in the PROMPT_DICT dictionary.
Solution: add the two keys "prompt_no_input_paragraph" and "prompt_no_input_separated" to the PROMPT_DICT dictionary:
```python
PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
    "prompt_no_input_paragraph": (
        "### Instruction:\n{instruction}\n\n### Context:\n{context}\n\n### Response:"
    ),
    "prompt_no_input_separated": (
        "### Instruction:\n{instruction}\n\n### Separated Context:\n{context}\n\n### Response:"
    ),
}
```
6.4 Evaluation error
The call to the call_model_rerank_w_scores_batch function in run_short_form.py passes an extra max_depth argument, which needs to be removed.
6.5 Miscellaneous
When running the scripts with bash, multi-line commands (one -- argument per line) use a trailing \ for line continuation; make sure there is no space after the \.
File path handling, used for loading data and models:
```python
import os
os.chdir('<absolute path to the parent of selfrag>/selfrag')
```
Module path handling, used for importing the custom modules:
```python
import sys
sys.path.append('<absolute path to the parent of reproduce>/reproduce/retriever')
```
7 Appendix
7.1 run_retrieval.sh

```bash
python passage_retrieval.py \
    --model_name_or_path facebook/contriever-msmarco \
    --passages ../../data/psgs_w100.tsv \
    --passages_embeddings "wikipedia_embeddings/*" \
    --data YOUR_INPUT_FILE \
    --output_dir YOUR_OUTPUT_FILE \
    --n_docs 20
```
7.2 generate_embeddings.sh

```bash
for i in {0..1}; do
    export CUDA_VISIBLE_DEVICES=${i}
    python generate_passage_embeddings.py \
        --model_name_or_path facebook/contriever-msmarco \
        --output_dir YOUR_OUTPUT_DIR \
        --passages YOUR_PASSAGE_DATA \
        --shard_id ${i} \
        --num_shards 4 \
        > ./log/nohup.my_embeddings.${i} 2>&1 &
done
```
7.3 train_critic.sh

```bash
torchrun --nproc_per_node=2 \
    --master_port=2568 train_special_tokens.py \
    --model_name_or_path ../../../model/llama2-7b-hf \
    --data_path ../../data/gpt4_reward_all_0813_train.json \
    --output_dir ../../model/critic_llama2_7b \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 300 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.01 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.1
```
7.4 train_generator.sh

```bash
export CUDA_VISIBLE_DEVICES=0,1
MODEL_SIZE=7B
NUM_GPUS=2
BATCH_SIZE_PER_GPU=1
TOTAL_BATCH_SIZE=128
GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_GPUS/$BATCH_SIZE_PER_GPU))
echo "Training llama model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"

CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --num_machines 1 \
    --num_processes 2 \
    --use_deepspeed \
    --deepspeed_config_file stage3_no_offloading_accelerate.conf \
    finetune.py \
    --model_name_or_path ../../../model/llama2-7b-hf \
    --use_flash_attn \
    --tokenizer_name ../../../model/llama2-7b-hf \
    --use_slow_tokenizer \
    --train_file ../../data/selfrag_train_data/train.jsonl \
    --max_seq_length 2048 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 128 \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --weight_decay 0. \
    --num_train_epochs 3 \
    --output_dir ../../model/train_selfrag_7b/ \
    --with_tracking \
    --report_to tensorboard \
    --logging_steps 1000 \
    --use_special_tokens \
    --use_lora \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.1
```
7.5 evaluate.sh
PopQA

```bash
python run_short_form.py \
    --model_name ../../model/selfrag_llama2_7b \
    --input_file ../../data/eval_data/popqa_longtail_w_gs.jsonl \
    --mode adaptive_retrieval \
    --max_new_tokens 100 \
    --threshold 0.2 \
    --output_file result/popqa.json \
    --metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
    --dtype half
```
TriviaQA

```bash
python run_short_form.py \
    --model_name ../../model/selfrag_llama2_7b \
    --input_file ../../data/eval_data/triviaqa_test_w_gs.jsonl \
    --mode adaptive_retrieval \
    --max_new_tokens 100 \
    --threshold 0.2 \
    --output_file result/triviaqa.json \
    --metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
    --dtype half
```
ARC-Challenge

```bash
python run_short_form.py \
    --model_name ../../model/selfrag_llama2_7b \
    --input_file ../../data/eval_data/arc_challenge_processed.jsonl \
    --max_new_tokens 50 --threshold 0.2 \
    --output_file result/arc_challenge.json \
    --metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
    --task arc_c
```
PubHealth

```bash
python run_short_form.py \
    --model_name ../../model/selfrag_llama2_7b \
    --input_file ../../data/eval_data/health_claims_processed.jsonl \
    --max_new_tokens 50 \
    --threshold 0.2 --output_file result/health.json \
    --metric match --ndocs 5 \
    --use_groundness --use_utility --use_seqscore \
    --task fever
```
ASQA

```bash
python run_long_form_static.py \
    --model_name ../../model/selfrag_llama2_7b \
    --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
    --use_grounding --use_utility --use_seqscore \
    --task asqa --input_file ../../data/eval_data/asqa_eval_gtr_top100.json \
    --output_file result/asqa.json --max_depth 7 --mode always_retrieve
```
FactScore

```bash
python run_long_form_static.py \
    --model_name ../../model/selfrag_llama2_7b \
    --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
    --use_grounding --use_utility --use_seqscore \
    --task factscore --input_file ../../data/eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
    --output_file factscore.json --max_depth 7
```