import argparse

import torch


def get_agents_dict(model, tokenizer, device, args):
    assert type(args) == argparse.Namespace, "args should be an argparse.Namespace object"
    assert hasattr(args, "qa_inst"), "args should have 'qa_inst' attribute"
    assert hasattr(args, "se_inst"), "args should have 'se_inst' attribute"
    assert hasattr(args, "cot_inst"), "args should have 'cot_inst' attribute"
    assert hasattr(args, "pe_inst"), "args should have 'pe_inst' attribute"

    agent_kwargs = {
        "model": model,
        "tokenizer": tokenizer,
        "device": device,
        "max_ans_tokens": args.max_ans_tokens,
    }
    return {
        "qa": ContextQuestionAnsweringAgent(instruction=args.qa_inst, **agent_kwargs),
        "se": ContextQuestionAnsweringAgent(instruction=args.se_inst, **agent_kwargs),
        "cot": ContextQuestionAnsweringAgent(instruction=args.cot_inst, **agent_kwargs),
        "pe": ContextQuestionAnsweringAgent(instruction=args.pe_inst, **agent_kwargs),
    }
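
# Note: `args` is expected to carry the four instruction strings plus the answer
# length budget, e.g. (attribute values here are illustrative only):
#
#   args = argparse.Namespace(
#       qa_inst="...", se_inst="...", cot_inst="...", pe_inst="...",
#       max_ans_tokens=64,
#   )
#
# See the __main__ block at the bottom of the file for a fuller usage sketch.
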
class ContextQuestionAnsweringAgent:
    """Context-based question-answering agent supporting multiple prompting strategies.

    Attributes:
        model: The initialized language model.
        tokenizer: The initialized tokenizer.
        device: The compute device.
        instruction: The agent's instruction string.
        max_ans_tokens: Maximum number of tokens in an answer.
    """
    def __init__(self, model, tokenizer, device, instruction, max_ans_tokens):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.instruction = instruction
        self.max_ans_tokens = max_ans_tokens
        # Use the EOS token for padding during generation, if the model exposes a
        # generation_config (avoids missing pad-token warnings from generate()).
        try:
            self.model.generation_config.pad_token_id = self.tokenizer.eos_token_id
        except AttributeError:
            pass
    def get_chat_template_input_ids(self, context, question, return_tensors=None):
        instruction = self.instruction
        msg = f"Instruction: {instruction} Context: {context} Question: {question}"
        input_ids = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": msg}],
            add_generation_prompt=True,
            return_tensors=return_tensors,
        )
        return input_ids
    def get_answer(self, context, question, max_ans_tokens=None, verbose=False, return_n_tokens=False):
        model, tokenizer, device = self.model, self.tokenizer, self.device

        if max_ans_tokens is None:
            max_ans_tokens = self.max_ans_tokens
        else:
            assert type(max_ans_tokens) == int, "max_ans_tokens should be an integer"

        input_ids = self.get_chat_template_input_ids(context, question, return_tensors="pt").to(device)
        len_input = input_ids.shape[-1]

        # Greedy decoding: sampling is disabled and the sampling parameters are
        # explicitly cleared so they cannot override the deterministic setting.
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=max_ans_tokens,
                do_sample=False,
                top_p=None,
                top_k=None,
                temperature=None,
            )

        # Keep only the newly generated tokens (the prompt is echoed in the output).
        answer_ids = outputs[0][len_input:]
        answer = tokenizer.decode(answer_ids, skip_special_tokens=True)

        if verbose:
            print(f"Context: {context}\nQuestion: {question}\nAnswer: {answer}")

        if return_n_tokens:
            n_tokens = len(answer_ids)
            return answer, n_tokens

        return answer
    @staticmethod
    def get_n_match(string, substring):
        # Count occurrences of `substring` in `string`; overlapping matches are
        # included because the search restarts one character after each hit.
        all_starts = []
        start = 0
        while True:
            start = string.find(substring, start)
            if start == -1:
                break
            all_starts.append(start)
            start += 1
        return len(all_starts)
    def find_text_token_spans(self, input_ids, target_text, raise_if_not_found=True):
        assert (type(input_ids) == list) and (type(input_ids[0]) == int), \
            "input_ids should be a 1-d list, make sure it's not a tensor."

        tokenizer = self.tokenizer
        source = tokenizer.decode(input_ids)
        # Round-trip the target through the tokenizer so it is compared in the same
        # normalized form as the decoded source.
        target_ids = tokenizer.encode(target_text, add_special_tokens=False)
        target = tokenizer.decode(target_ids)
        if raise_if_not_found:
            assert target in source, f"'{target}' not found in input"

        n_match_left = self.get_n_match(source, target)
        spans = []
        start = 0

        while True:
            # Drop tokens from the front one at a time; when the number of remaining
            # matches decreases, the token just dropped was the start of a match.
            start += 1
            source_seg = tokenizer.decode(input_ids[start:])
            n_match_cur = self.get_n_match(source_seg, target)

            if n_match_cur < n_match_left:
                assert (n_match_left - n_match_cur == 1), \
                    f"{n_match_left - n_match_cur} matches in a same token"
                n_match_left = n_match_cur
                start -= 1
                # Grow the span until the decoded token slice contains the target.
                end = max(start + len(target_ids) - 5, start)
                while True:
                    end += 1
                    seg_text = tokenizer.decode(input_ids[start:end])
                    if target in seg_text:
                        break
                spans.append((start, end))
                start = end

            if n_match_left == 0 or start >= len(input_ids):
                break

        return spans
    def get_context_token_span(self, context, question):
        input_ids = self.get_chat_template_input_ids(context, question, return_tensors=None)
        context_spans = self.find_text_token_spans(input_ids, context)
        assert (len(context_spans) == 1), f"Multiple/no context spans found: {context_spans}"
        return context_spans[0]
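

if __name__ == "__main__":
    # Minimal smoke test; a sketch only. The checkpoint name and instruction strings
    # below are placeholders (not part of this module's configuration); any chat model
    # whose tokenizer supports apply_chat_template should work.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative choice only
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    args = argparse.Namespace(
        qa_inst="Answer the question using only the given context.",  # placeholder instructions
        se_inst="Answer the question using only the given context.",
        cot_inst="Answer the question using only the given context.",
        pe_inst="Answer the question using only the given context.",
        max_ans_tokens=32,
    )
    agents = get_agents_dict(model, tokenizer, device, args)

    context = "Paris is the capital of France."
    question = "What is the capital of France?"
    print(agents["qa"].get_answer(context, question, verbose=True))
    # (start, end) token indices of the context inside the chat-templated prompt;
    # exact values depend on the tokenizer and chat template.
    print(agents["qa"].get_context_token_span(context, question))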