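"""Re-rank retrieved contexts with a Llama-3-style prompt: each document and the
question are mean-pooled over their token input embeddings, scored against each
other, and the item's contexts are re-sorted by that score."""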
import argparse
import json

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from utils import read_json_data, ensure_directory_exists

def construct_model_input(item):
    # Sort contexts by ascending retrieval score, which places the
    # highest-scoring document closest to the question in the prompt.
    sorted_contexts = sorted(item['contexts'], key=lambda x: x['score'])
    sub_texts = ["<|begin_of_text|><|start_header_id|>user<|end_header_id|>"]
    for i, context in enumerate(sorted_contexts):
        chunk = f"\n\nDocument {i + 1} (Title: {context['title']}): {context['content']}"
        sub_texts.append(chunk)

    sub_texts += [
        "\n\nBased on your knowledge and the provided information, answer the question:",
        f"\n{item['question']}",
        "\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    ]
    return sub_texts, sorted_contexts
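
# For illustration, the joined prompt for an item with two contexts looks like
# this (titles, contents, and the question are placeholders):
#
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
#   Document 1 (Title: <title>): <content>
#
#   Document 2 (Title: <title>): <content>
#
#   Based on your knowledge and the provided information, answer the question:
#   <question>
#
#   <|eot_id|><|start_header_id|>assistant<|end_header_id|>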

def find_subtexts_positions(sub_texts, tokenizer):
    """Map each sub-text to its character span and token span in the joined prompt."""
    result = {}
    input_text = "".join(sub_texts)

    # The prompt string already carries the Llama-3 special tokens, so do not
    # let the tokenizer prepend another <|begin_of_text|>. The same flag must
    # be used when encoding the prompt in calculate_score, or the token indices
    # computed here would be shifted.
    encoding = tokenizer(input_text, return_offsets_mapping=True, add_special_tokens=False)
    offset_mapping = encoding["offset_mapping"]

    # Walk a cursor through the joined text instead of using str.find(), which
    # returns the first occurrence and would mis-locate duplicated sub-texts.
    cursor = 0
    for i, sub_text in enumerate(sub_texts):
        start_char = cursor
        end_char = start_char + len(sub_text)
        cursor = end_char

        start_token = None
        end_token = None

        for j, (start, end) in enumerate(offset_mapping):
            if start <= start_char < end and start_token is None:
                start_token = j
            if start < end_char <= end and end_token is None:
                end_token = j
                break

        if start_token is None or end_token is None:
            raise ValueError(f"Could not locate the token span of sub-text '{i}'")

        result[(start_char, end_char)] = [start_token, end_token]

    return result
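
# Illustrative (hypothetical) output shape: for sub_texts like ["Hello", " world"],
# the mapping would resemble {(0, 5): [0, 0], (5, 11): [1, 1]} -- character spans
# keyed to [first_token, last_token] index pairs; the exact indices depend on the
# tokenizer's segmentation.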

def calculate_score(data, tokenizer, model, method):
    for item in tqdm(data, desc="Processing"):
        sub_texts, sorted_contexts = construct_model_input(item)
        input_text = "".join(sub_texts)
        positions = find_subtexts_positions(sub_texts, tokenizer)

        # Same tokenization settings as in find_subtexts_positions, so the
        # token spans computed there line up with these input ids.
        inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(model.device)
        with torch.no_grad():
            embeddings = model.get_input_embeddings()(inputs.input_ids)

        # Mean-pool the input embeddings of each sub-text into a single vector.
        subtext_embeddings = {}
        for (char_start, char_end), (token_start, token_end) in positions.items():
            sub_embeddings = embeddings[0, token_start:token_end + 1, :]
            subtext_embeddings[(char_start, char_end)] = torch.mean(sub_embeddings, dim=0)

        # Sub-text order: [header, document 1..N, instruction, question, assistant
        # header], so the question is the second-to-last span and the documents
        # start at index 1.
        pos_keys = list(positions.keys())
        question_embedding = subtext_embeddings[pos_keys[-2]]

        for doc_idx in range(len(sorted_contexts)):
            doc_embedding = subtext_embeddings[pos_keys[doc_idx + 1]]
            if method == 'similarity':
                score = F.cosine_similarity(question_embedding.unsqueeze(0), doc_embedding.unsqueeze(0))
            else:
                # 'relevance': Pearson correlation between the two pooled vectors.
                score = torch.corrcoef(torch.stack([question_embedding, doc_embedding]))[0, 1]
            sorted_contexts[doc_idx]['score'] = score.item()

        # Re-rank: highest question-document score first.
        item['contexts'] = sorted(sorted_contexts, key=lambda x: x['score'], reverse=True)

    return data

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input file")
    parser.add_argument("--output_file", type=str, required=True, help="Path to the output file")
    parser.add_argument("--model_path", type=str, default="model/Llama-3-8B-Instruct", help="Path to the LLM")
    parser.add_argument("--method", type=str, default="similarity", choices=['similarity', 'relevance'],
                        help="Re-ranking score: cosine similarity or Pearson correlation")

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    data = read_json_data(args.input_file)
    print(f"Loaded {len(data)} records")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModel.from_pretrained(args.model_path).to(device).eval()

    updated_data = calculate_score(data, tokenizer, model, args.method)

    ensure_directory_exists(args.output_file)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(updated_data, f, indent=4, ensure_ascii=False)
    print(f"Done; results saved to {args.output_file}")
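
# Example invocation (the script name and data paths are illustrative):
#   python rerank_by_embedding.py --input_file data/dev.json \
#       --output_file output/dev_reranked.json \
#       --model_path model/Llama-3-8B-Instruct --method similarity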