[Paper Reproduction] xRAG

Model architecture:

Reference project: Hannibal046/xRAG

1 Installation

1.1 Virtual environment

conda create -n xrag python=3.9
conda activate xrag
pip install torch==2.1.1 transformers==4.38.0 accelerate==0.27.2 datasets==2.17.1 deepspeed==0.13.2 sentencepiece wandb "numpy<2" ipykernel
python -m ipykernel install --user --name xrag
jupyter kernelspec list
pip install flash-attn==2.3.4 --no-build-isolation

1.2 Models

huggingface-cli download --resume-download Hannibal046/xrag-7b --local-dir model/xrag-7b
huggingface-cli download --resume-download Salesforce/SFR-Embedding-Mistral --local-dir model/SFR-Embedding-Mistral

1.3 Datasets

Some of the datasets can be downloaded from Google Drive. Each dataset has three splits, train.jsonl, dev.jsonl and test.jsonl, with records in the following format:

{
    "id": str,
    "question": str,
    "answer": [str, ...], # empty in test.jsonl of the fever dataset
    "entity": str, # only present in the tqa dataset
}
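Before training or evaluation, it can help to load a split and check it against this schema. The following is a minimal sketch. The function names are hypothetical helpers rather than part of the xRAG repo. `expect_answers=False` is meant for splits such as fever's test.jsonl, where `answer` is empty. The optional `entity` field (tqa only) is not checked.

```python
import json


def parse_jsonl_lines(lines):
    """Parse JSONL lines (skipping blank ones) into a list of dicts."""
    return [json.loads(line) for line in lines if line.strip()]


def load_split(path):
    """Load one split file, e.g. train.jsonl / dev.jsonl / test.jsonl."""
    with open(path, encoding="utf-8") as f:
        return parse_jsonl_lines(f)


def check_record(rec, expect_answers=True):
    """Return a list of schema problems for one record (empty list = OK)."""
    problems = []
    for key in ("id", "question"):
        if not isinstance(rec.get(key), str):
            problems.append(f"missing or non-string field: {key}")
    answers = rec.get("answer", [])  # a missing field is treated like an empty list
    if not isinstance(answers, list) or not all(isinstance(a, str) for a in answers):
        problems.append("answer must be a list of strings")
    elif expect_answers and not answers:
        problems.append("answer is empty")
    return problems
```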

1.4 Corpus

corpora/wiki/enwiki-dec2021

https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl
https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/infobox.jsonl

infobox.jsonl (4,330,888 entries) contains the infobox data from Wikipedia pages, in the following format:

{
"id": str,
"title": str,
"text": str
}

text-list-100-sec.jsonl (33,176,581 entries) contains text passages extracted from Wikipedia, in the following format:

{
"id": str,
"title": str,
"section": str,
"text": str
}
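With more than 33 million passages, the corpus is best read lazily, one line at a time, rather than loaded whole. Below is a sketch of that approach. It also shows one possible way to turn a passage into a retrieval string. The "title | text" format is an assumption that mirrors the sample documents in section 3.3. The file path in the usage comment is likewise assumed from the directory listed above.

```python
import json
from itertools import islice


def iter_passages(lines):
    """Lazily parse corpus lines; yields one dict per non-empty line."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.loads(line)


def to_retrieval_text(passage):
    # Hypothetical formatting: "title | text" (the section field is ignored here)
    return f"{passage['title']} | {passage['text']}"


# Usage with the real file (path assumed, not taken from the repo):
# with open("corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl", encoding="utf-8") as f:
#     first_docs = [to_retrieval_text(p) for p in islice(iter_passages(f), 1000)]
```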

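2 Projector training

2.1 Paraphrase pretraining

This stage teaches the model the relationship between a document chunk and its embedding. The training data format is: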

{
"id": str,
"text": str
}
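The format above only stores the chunk text. One way a single paraphrase training example could be laid out is sketched below. The prompt contains the <xRAG> placeholder, whose embedding slot would be filled with the chunk's retriever embedding. The chunk text follows as the target. Prompt positions are masked out of the loss with -100, which is the default `ignore_index` of PyTorch's cross-entropy. The function name, prompt wording and toy token ids are hypothetical and are not taken from the xRAG repo.

```python
IGNORE_INDEX = -100  # label value skipped by torch.nn.CrossEntropyLoss by default


def build_paraphrase_example(prompt_ids, target_ids):
    """Concatenate prompt and target; only target tokens contribute to the loss.

    prompt_ids: token ids of an instruction containing the <xRAG> placeholder
                (the wording of that instruction is up to the caller)
    target_ids: token ids of the original chunk text the model should reproduce
    """
    input_ids = list(prompt_ids) + list(target_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(target_ids)
    return input_ids, labels


# Toy vocabulary for illustration only
vocab = {"<xRAG>": 0, "paraphrase:": 1, "motel": 2, "6": 3}
ids, labels = build_paraphrase_example([vocab["<xRAG>"], vocab["paraphrase:"]],
                                       [vocab["motel"], vocab["6"]])
# ids    -> [0, 1, 2, 3]
# labels -> [-100, -100, 2, 3]
```

With Hugging Face causal LMs, the shift between inputs and labels happens inside the model, so `labels` can be aligned one-to-one with `input_ids` as shown.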

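2.2 Context-aware instruction tuning

This works much like knowledge distillation. Given the retrieved context text and the question, the model produces a probability distribution over the output. Training teaches it to produce a similar distribution when the context is supplied only as its embedding, together with the instruction. For preprocessing the training data, see prepare_data.ipynb. The resulting data format is: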

{
    "id": str,
    "message": [
        {
            "role": "user",
            "content": str
        },
        {
            "role": "assistant",
            "content": str
        }
    ],
    "task_type": str, # open_qa (no background), close_qa, summarization, fact_checking (no background)
    "background": str # absent in some records
}
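The distillation idea described above can be sketched as a KL-divergence term between two forward passes. The teacher pass sees the full background text. The student pass sees only the <xRAG> embedding. Both are compared at the answer tokens. This is a sketch of the distillation term alone, under the assumption that logits have already been gathered at the answer positions of each sequence. The function name and temperature parameter are illustrative, and any weighting against other losses used in the paper is not shown.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_answer_logits, teacher_answer_logits, temperature=1.0):
    """KL(teacher || student), averaged over answer tokens.

    Both inputs: [num_answer_tokens, vocab_size], in the same token order.
    The teacher is detached so gradients flow only into the student side.
    """
    s = F.log_softmax(student_answer_logits / temperature, dim=-1)
    t = F.log_softmax(teacher_answer_logits.detach() / temperature, dim=-1)
    # kl_div(input=log q, target=log p, log_target=True) computes sum p * (log p - log q);
    # "batchmean" divides by the number of answer tokens
    kl = F.kl_div(s, t, reduction="batchmean", log_target=True)
    return kl * temperature ** 2


# Gathering answer positions from padded [batch, seq, vocab] logits with a boolean
# [batch, seq] mask yields [num_answer_tokens, vocab]:
# student_answer_logits = student_logits[student_answer_mask]
# teacher_answer_logits = teacher_logits[teacher_answer_mask]
```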

Each task type has its own set of prompt templates, and one of them is chosen at random when the data is generated:

templates_for_qa = [
"Question: {question}?\nAnswer:",
"{question}?",
"Answer the following question:\n\n{question}",
"Answer this question:\n\n{question}?",
"Please answer this question: {question}",
"Answer the question...{question}?",
"What is the answer to this question? {question}\n\n",
"Can you tell me the answer to {question}?",
"Next question: {question}\n\n",
"Q: {question} A:",
"{question}\nWhat is the answer?",
"Write the answer: {question}",
"{question}???",
]

templates_for_sum = [
"Write a short summary for the text\n\nSummary:",
"Briefly summarize this article:\nSummary:",
"What is a shorter version of this:\n\nSummary:",
"Write a brief summary in a sentence or less.",
"What is a very short summary of the above text?",
"Summarize the aforementioned text in a single phrase.",
"Can you generate a short summary of the above paragraph?",
"Summarize the above articles\n\ntl;dr:",
]

template_for_fact_checking = [
"Verify the following claims with \"True\" or \"False\":\n{question}",
]
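Random template selection per task type could look like the sketch below. The template lists are trimmed copies of the ones above so the snippet runs on its own. The `task_type` to template mapping is a hypothetical assumption, not taken from prepare_data.ipynb. Several QA templates append their own "?", so a question that already ends with "?" comes out as "??". Stripping it first (`question.rstrip("?")`) is one option.

```python
import random

# Trimmed copies of the template lists above
QA_TEMPLATES = [
    "Question: {question}?\nAnswer:",
    "Q: {question} A:",
]
SUM_TEMPLATES = [
    "Write a short summary for the text\n\nSummary:",
    "Summarize the above articles\n\ntl;dr:",
]
FACT_CHECK_TEMPLATES = [
    "Verify the following claims with \"True\" or \"False\":\n{question}",
]

# Hypothetical task_type -> template list mapping (assumption)
TEMPLATES_BY_TASK = {
    "open_qa": QA_TEMPLATES,
    "close_qa": QA_TEMPLATES,
    "summarization": SUM_TEMPLATES,
    "fact_checking": FACT_CHECK_TEMPLATES,
}


def build_instruction(task_type, question="", rng=random):
    """Pick one template at random for the task and fill in the question."""
    template = rng.choice(TEMPLATES_BY_TASK[task_type])
    # str.format ignores unused keyword arguments, so summary templates
    # (which contain no {question}) pass through unchanged
    return template.format(question=question)
```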

3 Example

3.1 Setup

1. Import the required packages:
import torch
import torch.nn as nn
import re
from torch import Tensor
from transformers import AutoTokenizer, MistralForCausalLM, MistralModel
2. Load the code from src/model/SFR/modeling_sfr.py:
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """
    Pooling function: take the hidden state of the last valid (non-padding) token of each sequence.

    Args:
        last_hidden_states (Tensor): output of the model's last hidden layer, shape [batch_size, sequence_length, hidden_size]
        attention_mask (Tensor): attention mask, shape [batch_size, sequence_length]

    Returns:
        Tensor: pooled embeddings, shape [batch_size, hidden_size]
    """
    # Detect left padding: if the mask is 1 at the last position of every sequence,
    # that position holds real content rather than a padding token
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        # Left padding: the last token is simply the final position
        return last_hidden_states[:, -1]
    else:
        # Otherwise compute each sequence's actual length and take its last valid token
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


# SFR embedding model built on MistralModel; produces embeddings for documents and queries
class SFR(MistralModel):
    # Embedding dimension (the hidden size)
    def get_embed_dim(self):
        return self.config.hidden_size

    # Number of embedding vectors per input (always 1)
    def get_embed_length(self):
        return 1

    # Compute embeddings
    def get_embedding(self, input_ids, attention_mask):
        # Forward pass to get the model outputs
        outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        # Pool the last hidden state with last_token_pool
        embeddings = last_token_pool(outputs.last_hidden_state, attention_mask)
        return embeddings

    # Document embeddings
    def get_doc_embedding(self, input_ids, attention_mask):
        return self.get_embedding(input_ids, attention_mask)

    # Query embeddings
    def get_query_embedding(self, input_ids, attention_mask):
        return self.get_embedding(input_ids, attention_mask)
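To see which position `last_token_pool` picks under each padding side, here is a toy check. The function is repeated so the snippet runs standalone. The hidden state at `[b, t]` is set to `4*b + t`, so the pooled value reveals the selected position.

```python
import torch


def last_token_pool(last_hidden_states, attention_mask):
    # Same logic as above: left padding -> last column; right padding -> last 1 in each row
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[batch, sequence_lengths]


# Toy hidden states of shape [batch=2, seq=4, hidden=1]; value at [b, t] is 4*b + t
hidden = torch.arange(8, dtype=torch.float32).view(2, 4, 1)
right_padded = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
left_padded = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1]])

print(last_token_pool(hidden, right_padded).squeeze(-1))  # positions 2 and 1
print(last_token_pool(hidden, left_padded).squeeze(-1))   # last position of each row
```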
3. Load the code from src/model/xMistral/modeling_xmistral.py:
# Projector
class Projector(nn.Module):
    def __init__(self, config):
        super().__init__()
        projector_type = config.projector_type  # projector type from the config
        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)  # parse types of the form "mlp{N}x_gelu"
        if mlp_gelu_match:  # if the pattern matches
            mlp_depth = int(mlp_gelu_match.group(1))  # MLP depth
            modules = [nn.Linear(config.retriever_hidden_size, config.hidden_size)]  # first linear layer: retriever space -> LLM space
            for _ in range(1, mlp_depth):  # add the remaining layers
                modules.append(nn.GELU())  # GELU activation
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))  # linear layer
            self.projector = nn.Sequential(*modules)  # stack the modules into the projector
        else:
            # Fail early instead of leaving self.projector undefined
            raise ValueError(f"Unsupported projector_type: {projector_type}")

    # Forward pass: map the context embedding through the projector
    def forward(self, context_embedding):
        return self.projector(context_embedding)


# Inference model based on MistralForCausalLM
class XMistralForCausalLM(MistralForCausalLM):
    def __init__(self, config):
        super().__init__(config)  # initialize the parent MistralForCausalLM
        if hasattr(config, "retriever_hidden_size") and config.retriever_hidden_size > 0:  # only if the config defines a positive retriever_hidden_size
            self.projector = Projector(config)  # build the projector
            self.retriever_hidden_size = config.retriever_hidden_size  # retriever hidden size
        self.post_init()  # parent-class post-initialization

    # Set the ID of the xRAG token
    def set_xrag_token_id(self, token_id):
        self.xrag_token_id = token_id

    # Build the input embeddings
    def prepare_inputs_embeds(self, input_ids, retrieval_embeds):
        inputs_embeds = self.model.embed_tokens(input_ids)  # token IDs -> embeddings
        retrieval_embeds = retrieval_embeds.view(-1, self.retriever_hidden_size)  # reshape to [num_embeds, retriever_hidden_size]

        ## Sanity check
        num_xrag_tokens = torch.sum(input_ids == self.xrag_token_id).item()  # number of xRAG tokens
        num_retrieval_embeds = retrieval_embeds.shape[0]  # number of retrieval embeddings
        assert num_xrag_tokens == num_retrieval_embeds, (num_xrag_tokens, num_retrieval_embeds)  # the two counts must match

        retrieval_embeds = self.projector(retrieval_embeds.to(inputs_embeds.dtype))  # project the retrieval embeddings
        inputs_embeds[input_ids == self.xrag_token_id] = retrieval_embeds  # replace the embeddings at xRAG token positions with the projected ones

        return inputs_embeds  # processed input embeddings

    # Forward pass
    def forward(
        self,
        input_ids=None,  # input token IDs
        retrieval_embeds=None,  # retrieval embeddings, shape [-1, retrieval_hidden_size]
        attention_mask=None,  # attention mask
        **kwargs,  # other arguments
    ):
        ## inputs_embeds is passed in only on the first step of generation
        inputs_embeds = kwargs.pop("inputs_embeds", None)  # take inputs_embeds out of kwargs
        at_the_beginning_of_generation = False  # whether we are at the start of generation
        if inputs_embeds is not None:  # if inputs_embeds is given, retrieval_embeds must not be
            assert not self.training  # must not be in training mode
            assert retrieval_embeds is None  # retrieval_embeds must not be given
            at_the_beginning_of_generation = True  # mark the start of generation

        if not at_the_beginning_of_generation:  # not at the start of generation
            ## Single forward pass
            if retrieval_embeds is not None:  # retrieval embeddings were given
                inputs_embeds = self.prepare_inputs_embeds(input_ids, retrieval_embeds)  # build the input embeddings
                input_ids = None  # use inputs_embeds instead of input_ids
                if attention_mask is not None:  # an attention mask was given
                    assert inputs_embeds.shape[1] == attention_mask.shape[1], (inputs_embeds.shape, attention_mask.shape)  # shapes must match

        return super().forward(  # call the parent forward
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

    # Generation
    @torch.no_grad()  # no gradient computation
    def generate(
        self,
        input_ids=None,  # input token IDs
        retrieval_embeds=None,  # retrieval embeddings
        **kwargs,  # other arguments
    ):
        attention_mask = kwargs.pop("attention_mask", None)  # take attention_mask out of kwargs
        if "inputs_embeds" in kwargs:  # passing inputs_embeds directly is not supported
            raise NotImplementedError("`inputs_embeds` is not supported for generate")

        inputs_embeds = None
        if retrieval_embeds is not None:  # retrieval embeddings were given
            inputs_embeds = self.prepare_inputs_embeds(input_ids, retrieval_embeds)  # build the input embeddings
            input_ids = None  # generate from inputs_embeds instead of input_ids
            if attention_mask is not None:  # an attention mask was given
                assert inputs_embeds.shape[1] == attention_mask.shape[1], (inputs_embeds.shape, attention_mask.shape)  # shapes must match
            return super().generate(  # call the parent generate
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                **kwargs
            )

        else:  # no retrieval embeddings
            return super().generate(  # call the parent generate
                attention_mask=attention_mask,
                input_ids=input_ids,
                **kwargs
            )
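The core of xRAG is the placeholder swap in `prepare_inputs_embeds`. The prompt is embedded normally, and each <xRAG> position is then overwritten with a projected retrieval embedding. The toy re-creation below uses small random layers instead of Mistral. The two-layer "mlp2x_gelu" projector, the sizes and the placeholder id 9 are illustrative assumptions. Boolean-mask assignment fills the placeholder positions in row-major order, so the i-th retrieval embedding lands at the i-th <xRAG> token.

```python
import re
import torch
import torch.nn as nn


def build_projector(projector_type, retriever_hidden_size, hidden_size):
    # Mirrors Projector above: "mlp{N}x_gelu" -> N Linear layers with GELU in between
    match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if match is None:
        raise ValueError(f"unsupported projector_type: {projector_type}")
    depth = int(match.group(1))
    layers = [nn.Linear(retriever_hidden_size, hidden_size)]
    for _ in range(1, depth):
        layers += [nn.GELU(), nn.Linear(hidden_size, hidden_size)]
    return nn.Sequential(*layers)


torch.manual_seed(0)
vocab_size, hidden_size, retriever_hidden_size = 10, 8, 6
XRAG_ID = 9  # hypothetical placeholder id in a toy vocabulary
embed_tokens = nn.Embedding(vocab_size, hidden_size)
projector = build_projector("mlp2x_gelu", retriever_hidden_size, hidden_size)

input_ids = torch.tensor([[1, 2, XRAG_ID, 3], [4, XRAG_ID, 5, 6]])
retrieval_embeds = torch.randn(2, retriever_hidden_size)  # one document embedding per placeholder

with torch.no_grad():
    inputs_embeds = embed_tokens(input_ids)                # [2, 4, hidden_size]
    is_xrag = input_ids == XRAG_ID                         # [2, 4] boolean mask
    assert int(is_xrag.sum()) == retrieval_embeds.shape[0]  # same sanity check as above
    inputs_embeds[is_xrag] = projector(retrieval_embeds)   # overwrite placeholder rows
```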
4. Load the code from src/language_modeling/utils.py:
XRAG_TOKEN = "<xRAG>"

# Compute embeddings of the retrieved text
def get_retrieval_embeds(model, input_ids, attention_mask=None):
    with torch.no_grad():  # no gradients at inference time
        embeds = model.get_doc_embedding(  # call the model's get_doc_embedding
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
    embeds = embeds.view(-1, embeds.shape[-1])  # flatten to [num_embeds, embed_dim]
    return embeds
5. Load the model and tokenizer:
device = torch.device("cuda:0")
llm_name_or_path = "../model/xrag-7b"
llm = XMistralForCausalLM.from_pretrained(llm_name_or_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device).eval()  # low_cpu_mem_usage reduces CPU memory use while loading
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name_or_path, add_eos_token=False, use_fast=False, padding_side='left')  # do not append an EOS token; use the slow tokenizer; pad on the left

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  5.52it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
6. At this point, XRAG_TOKEN is only a placeholder:
llm.set_xrag_token_id(llm_tokenizer.convert_tokens_to_ids(XRAG_TOKEN))  # register the placeholder's token ID with the model
print(XRAG_TOKEN)

<xRAG>

3.2 Without RAG

1. Build the prompt from the question:
question = """What company advertised itself with the slogan "We'll leave a light on for you"?"""  # expected answer: "Motel 6"
template = "[INST] Answer the questions:\n\nQuestion: {question} [/INST] The answer is:"
prompt = template.format_map(dict(question=question))  # format_map inserts the question into the template
print(prompt)

[INST] Answer the questions:

Question: What company advertised itself with the slogan "We'll leave a light on for you"? [/INST] The answer is:
2. Run inference (different numeric precisions can give different answers):
input_ids = llm_tokenizer(prompt, return_tensors='pt').input_ids.to(device)  # return_tensors='pt': return PyTorch tensors
generated_output = llm.generate(
    input_ids=input_ids,
    do_sample=False,  # no sampling: greedy decoding
    max_new_tokens=20,  # maximum number of new tokens to generate
    pad_token_id=llm_tokenizer.pad_token_id,  # padding token ID
)
result = llm_tokenizer.batch_decode(generated_output[:, input_ids.shape[1]:], skip_special_tokens=True)[0]  # keep only the newly generated part
print(result)

Holiday Inn. Holiday Inn is a global hotel chain that has used the slogan "We
3. Measure the running time:
%%time
batch_size = 24
num_batch = 50
batch_input_ids = input_ids.repeat(batch_size, 1)  # repeat the prompt batch_size times; a separate name keeps re-runs from growing the batch
for _ in range(num_batch):  # generate batch by batch
    generated_output = llm.generate(
        input_ids=batch_input_ids,
        do_sample=False,
        max_new_tokens=20,
        pad_token_id=llm_tokenizer.pad_token_id,
    )

CPU times: user 28.6 s, sys: 1.37 s, total: 29.9 s
Wall time: 30 s

3.3 Traditional RAG

1. Mock a document store:
documents = [
'Alvin and the Chipmunks | " Alvin and the Chipmunks, originally David Seville and the Chipmunks or simply The Chipmunks, are an American animated virtual band created by Ross Bagdasarian for a novelty record in 1958. The group consists of three singing animated anthropomorphic chipmunks named Alvin, Simon, and Theodore. They are managed by their human adoptive father, David ""Dave"" Seville. Bagdasarian provided the group\'s voices sped up to create high-pitched squeaky voices (which wasn\'t entirely new to him, having worked on ""Witch Doctor"" earned the record two Grammy Awards for engineering). ""The Chipmunk Song"" became a number-one single in the United States. After Bagdasarian died in 1972, the characters’ voices were provided by his son Ross Bagdasarian Jr. and the latter\'s wife Janice Karman in the subsequent incarnations of "',
"Jamie Lee Curtis | Jamie Lee Curtis (born November 22, 1958) is an American actress and writer. She is the recipient of several accolades, including a British Academy Film Award, two Golden Globe Awards and a star on the Hollywood Walk of Fame in 1998. Curtis made her film acting debut as Laurie Strode in John Carpenter's horror film Halloween (1978), which established her as a scream queen, and she thereafter appeared in a string of horror films, including The Fog, Prom Night, Terror Train (all 1980) and Roadgames (1981). She reprised the role of Laurie in the sequels Halloween II (1981), Halloween H20: 20 Years Later (1998), Halloween: Resurrection (2002), Halloween (2018), and Halloween Kills (2021). Her filmography is largely characterized by independent film that have been box-office successes, with 8 of her lead-actress credits ",
'Sunset Boulevard (musical) | " The American premiere was at the Shubert Theatre in Century City, Los Angeles, California, on 9 December 1993, with Close as Norma and Alan Campbell as Joe. Featured were George Hearn as Max and Judy Kuhn as Betty. Lloyd Webber had reworked both the book and score, tightening the production, better organising the orchestrations, and adding the song ""Every Movie\'s a Circus"". This new production was better received by the critics and was an instant success, running for 369 performances. The Los Angeles production also recorded a new cast album that is well regarded. It is also the only unabridged cast recording of the show, since the original London recording was trimmed by over thirty minutes. A controversy arose with this production after Faye Dunaway was hired to replace Glenn Close. Dunaway went into rehearsals with Rex Smith as Joe and Jon Cypher as Max. Tickets "',
'Arthur Balfour | Balfour was appointed prime minister on 12 July 1902 while the King was recovering from his recent appendicitis operation. Changes to the Cabinet were thus not announced until 9 August, when the King was back in London. The new ministers were received in audience and took their oaths on 11 August.',
'Motel 6 | " Beginning in 1986, Motel 6 has advertised through radio commercials featuring the voice of writer and National Public Radio commentator Tom Bodett, with the tagline "We\'ll leave the light on for you." The ads were created by Dallas advertising agency The Richards Group. They feature a tune composed by Tom Faulkner, performed by him on guitar and Milo Deering on fiddle. The first spots were conceived and written by David Fowler. In 1996, the ads won a Clio Award. The campaign itself has won numerous national and international awards and was selected by Advertising Age magazine as one of the Top 100 Advertising Campaigns of the Twentieth Century."',
]
2. Load the retrieval model:
retriever_name_or_path = "../model/SFR-Embedding-Mistral"
retriever = SFR.from_pretrained(retriever_name_or_path, torch_dtype=torch.bfloat16).eval().to(device)
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_name_or_path)

Loading checkpoint shards: 100%|██████████| 3/3 [00:23<00:00,  7.69s/it]
3. Compute the embedding of each document:
retriever_input = retriever_tokenizer(documents, max_length=180, padding=True, truncation=True, return_tensors='pt').to(device)  # padding=True: pad to the longest document in the batch; truncation: cut documents longer than max_length
with torch.no_grad():
    doc_embeds = retriever.get_doc_embedding(input_ids=retriever_input.input_ids, attention_mask=retriever_input.attention_mask)
doc_embeds.shape

torch.Size([5, 4096])
4. Build the index:
datastore = (documents, doc_embeds)
5. Compute the question embedding:
retriever_input = retriever_tokenizer(question, max_length=180, padding=True, truncation=True, return_tensors='pt').to(device)
with torch.no_grad():
    query_embed = retriever.get_query_embedding(input_ids=retriever_input.input_ids, attention_mask=retriever_input.attention_mask)
query_embed.shape

torch.Size([1, 4096])
6. Get the index of the most similar document:
_, index = torch.topk(torch.matmul(query_embed, doc_embeds.T), k=1)  # dot-product similarity
top1_doc_index = index[0][0].item()
top1_doc_index

4
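The cell above ranks documents by the raw dot product. As far as I recall, the SFR-Embedding-Mistral model card L2-normalises embeddings before scoring. That makes the score a cosine similarity, so a document cannot rank higher merely because its embedding has a larger norm. Below is a sketch of that variant. `topk_cosine` is a hypothetical helper. In the toy data, the dot product prefers the large-norm document while cosine similarity prefers the best-aligned one.

```python
import torch
import torch.nn.functional as F


def topk_cosine(query_embeds, doc_embeds, k=1):
    """Return (scores, indices) of the k most similar documents per query."""
    q = F.normalize(query_embeds.float(), p=2, dim=-1)
    d = F.normalize(doc_embeds.float(), p=2, dim=-1)
    scores = q @ d.T  # [num_queries, num_docs], cosine similarities
    return torch.topk(scores, k=min(k, d.shape[0]), dim=-1)


# Toy example: document 0 has a large norm but points in a different direction
docs = torch.tensor([[20.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
query = torch.tensor([[0.1, 1.0]])
dot_top1 = torch.topk(query @ docs.T, k=1).indices.item()  # favours document 0
scores, idx = topk_cosine(query, docs, k=2)                 # favours document 2, then 1

# With the tensors from the notebook: scores, idx = topk_cosine(query_embed, doc_embeds, k=1)
```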
7. Look up the corresponding document by its index:
relevant_doc = datastore[0][top1_doc_index]
print(relevant_doc)

Motel 6 | " Beginning in 1986, Motel 6 has advertised through radio commercials featuring the voice of writer and National Public Radio commentator Tom Bodett, with the tagline "We'll leave the light on for you." The ads were created by Dallas advertising agency The Richards Group. They feature a tune composed by Tom Faulkner, performed by him on guitar and Milo Deering on fiddle. The first spots were conceived and written by David Fowler. In 1996, the ads won a Clio Award. The campaign itself has won numerous national and international awards and was selected by Advertising Age magazine as one of the Top 100 Advertising Campaigns of the Twentieth Century."
8. Build a new prompt from the question and the retrieved document:
rag_template = """[INST] Refer to the background document and answer the questions:

Background: {document}

Question: {question} [/INST] The answer is:"""
prompt = rag_template.format_map(dict(document=relevant_doc, question=question))
print(prompt)

[INST] Refer to the background document and answer the questions:

Background: Motel 6 | " Beginning in 1986, Motel 6 has advertised through radio commercials featuring the voice of writer and National Public Radio commentator Tom Bodett, with the tagline "We'll leave the light on for you." The ads were created by Dallas advertising agency The Richards Group. They feature a tune composed by Tom Faulkner, performed by him on guitar and Milo Deering on fiddle. The first spots were conceived and written by David Fowler. In 1996, the ads won a Clio Award. The campaign itself has won numerous national and international awards and was selected by Advertising Age magazine as one of the Top 100 Advertising Campaigns of the Twentieth Century."

Question: What company advertised itself with the slogan "We'll leave a light on for you"? [/INST] The answer is:
9. Run inference:
input_ids = llm_tokenizer(prompt, return_tensors='pt').input_ids.to(device)
generated_output = llm.generate(
input_ids=input_ids,
do_sample=False,
max_new_tokens=20,
pad_token_id=llm_tokenizer.pad_token_id,
)
result = llm_tokenizer.batch_decode(generated_output[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
print(result)

Motel 6

Explanation: Motel 6 is the company that advertised
10. Measure the running time:
%%time
batch_size = 24
num_batch = 50
batch_input_ids = input_ids.repeat(batch_size, 1)  # repeat the prompt batch_size times
for _ in range(num_batch):
    generated_output = llm.generate(
        input_ids=batch_input_ids,
        do_sample=False,
        max_new_tokens=20,
        pad_token_id=llm_tokenizer.pad_token_id,
    )

CPU times: user 42.3 s, sys: 9.36 s, total: 51.7 s
Wall time: 51.7 s
11. Check the token lengths of the question and the document:
question_len = llm_tokenizer(question, return_length=True, add_special_tokens=False).length
doc_len = llm_tokenizer(relevant_doc, return_length=True, add_special_tokens=False).length
question_len, doc_len

(20, 163)
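Both prompts share the same template, so the difference in prompt length is roughly the document part alone. Tokenization at the boundary with the surrounding template may shift this by a token or so. The arithmetic below uses the lengths printed above and the batch size from the timing cells.

```python
question_len, doc_len = 20, 163  # token counts printed above (no special tokens)
xrag_doc_len = 1                 # in xRAG the document becomes a single <xRAG> token
batch_size = 24                  # batch size used in the timing cells

tokens_saved_per_prompt = doc_len - xrag_doc_len               # 162
tokens_saved_per_batch = tokens_saved_per_prompt * batch_size  # 3888
context_compression = doc_len / xrag_doc_len                   # 163.0

print(f"{tokens_saved_per_prompt} fewer tokens per prompt, "
      f"{tokens_saved_per_batch} per batch of {batch_size}, "
      f"{context_compression:.0f}x shorter context")
```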

3.4 xRAG

1. Look up the corresponding embedding by index and run inference:
relevant_embedding = datastore[1][top1_doc_index]

prompt = rag_template.format_map(dict(question=question, document=XRAG_TOKEN))  # the retrieved document is replaced by a single token
print(prompt)
input_ids = llm_tokenizer(prompt, return_tensors='pt').input_ids.to(device)
generated_output = llm.generate(
    input_ids=input_ids,
    do_sample=False,
    max_new_tokens=20,
    pad_token_id=llm_tokenizer.pad_token_id,
    retrieval_embeds=relevant_embedding.unsqueeze(0),
)
# generation starts from inputs_embeds here, so the output holds only the new tokens and needs no slicing
result = llm_tokenizer.batch_decode(generated_output, skip_special_tokens=True)[0]
print(result)

[INST] Refer to the background document and answer the questions:

Background: <xRAG>

Question: What company advertised itself with the slogan "We'll leave a light on for you"? [/INST] The answer is:
Motel 6. The slogan was created in 1962 by Tom Bodett
2. Measure the running time:
%%time
batch_size = 24
num_batch = 50
batch_input_ids = input_ids.repeat(batch_size, 1)  # repeat the prompt batch_size times
retrieval_embeds = relevant_embedding.unsqueeze(0).repeat(batch_size, 1)  # one document embedding per prompt in the batch
for _ in range(num_batch):
    generated_output = llm.generate(
        input_ids=batch_input_ids,
        do_sample=False,
        max_new_tokens=20,
        pad_token_id=llm_tokenizer.pad_token_id,
        retrieval_embeds=retrieval_embeds,
    )

CPU times: user 30.5 s, sys: 2.07 s, total: 32.5 s
Wall time: 32.5 s
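Here are the three timing runs side by side. Each run generates 50 × 24 = 1,200 sequences of at most 20 new tokens. These are single wall-clock measurements from the cells above, so treat the ratios as rough.

```python
# Wall-clock times printed by the three %%time cells above (seconds)
wall_time_s = {"no_rag": 30.0, "rag": 51.7, "xrag": 32.5}
num_sequences = 50 * 24  # num_batch * batch_size, each with max_new_tokens=20

throughput = {name: num_sequences / t for name, t in wall_time_s.items()}
speedup_vs_rag = wall_time_s["rag"] / wall_time_s["xrag"]             # ~1.59x faster than RAG
overhead_vs_no_rag = wall_time_s["xrag"] / wall_time_s["no_rag"] - 1  # ~8% slower than no retrieval

for name, seq_per_s in throughput.items():
    print(f"{name:>7}: {seq_per_s:5.1f} sequences/s")
print(f"xRAG vs RAG speedup: {speedup_vs_rag:.2f}x; xRAG overhead vs no RAG: {overhead_vs_no_rag:.1%}")
```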

Source: http://xuan-van.github.io/代码复现/【论文复现】xrag/
Author: 文晋
Published: April 26, 2025