Method illustration:
Reference project: OpenDFM/HeadsUp
1 Installation

1.1 Virtual environment

conda create -n heads python=3.10 -y
conda activate heads
pip install ipykernel seaborn adjustText
pip install rouge_score nltk jieba
pip install torch transformers==4.51.3 datasets==3.0.1 safetensors accelerate
pip install flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
python -m ipykernel install --user --name heads
jupyter kernelspec list
1.2 Project structure

dataset/
    function_vectors/
        abstractive/
            adjective_v_verb_3.json
            antonym.json
            verb_v_adjective_3.json
        fewshot/
    iwslt2017/
        en-zh-test.parquet
    XNLI-15way/
        xnli.15way.orig.tsv
models/
    modeling_gemma2.py
    modeling_mistral.py
    modeling_qwen2.py
    modeling_llama.py
    modeling_phi3.py
output/
    llama/
        fv/
            adjective_v_verb_3/
            antonym/
            verb_v_adjective_3/
        xnli/
            en_zh/
            zh_en/
HeadsUp.ipynb
1.3 Models and mask matrices

Models:
huggingface-cli download --token Your_token meta-llama/Meta-Llama-3.1-8B-Instruct --local-dir model/Llama-3.1-8B-Instruct
huggingface-cli download --token Your_token BAAI/bge-reranker-v2-m3 --local-dir model/bge-reranker-v2-m3
Mask matrices: download fv/adjective_v_verb_3/ and fv/verb_v_adjective_3/.
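If you download the pre-trained mask matrices instead of training them yourself, they are expected under output/llama/fv/ so that the loading loop in section 2.2 can find them. A small hypothetical check (the exact checkpoint-6250 subfolder name is an assumption inferred from that loop):

import os

# Hypothetical sanity check: section 2.2 walks output/llama/fv and looks for a
# "*.safetensors" file inside a folder whose path contains "-6250".
for task in ("adjective_v_verb_3", "verb_v_adjective_3"):
    path = f"output/llama/fv/{task}/checkpoint-6250/model.safetensors"
    print(task, os.path.exists(path))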
2 Overall workflow

2.1 Training the mask matrices
Import the required libraries:
import argparse
import copy
import os
import re
from dataclasses import dataclass, field
from tqdm import tqdm
import torch
from torch import nn
import safetensors.torch
import itertools
import numpy as np
import pandas as pd
import datasets
from datasets import Dataset, concatenate_datasets
from models.modeling_llama import LlamaForCausalLM
from models.modeling_phi3 import Phi3ForCausalLM
from models.modeling_mistral import MistralForCausalLM
from models.modeling_qwen2 import Qwen2ForCausalLM
from models.modeling_gemma2 import Gemma2ForCausalLM
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments, HfArgumentParser
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import jieba
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from adjustText import adjust_text
Redefine the trainer:
class SimpleTensorModel(nn.Module):
    def __init__(self, tensor_length=1056):
        super().__init__()
        self.tensor = nn.Parameter(torch.zeros(tensor_length))
        nn.init.normal_(self.tensor, mean=4, std=0.02)

    def forward(self, x=None):
        return self.tensor


class CustomDataCollator(DataCollatorForLanguageModeling):
    def __init__(self, lm_tokenizer, padding=True):
        self.lm_tokenizer = lm_tokenizer
        self.padding = padding

    def torch_call(self, features):
        lm_input_ids = [f["lm_input_ids"] for f in features]
        lm_labels = [torch.tensor(f["lm_labels"]) for f in features]
        lm_batch = self.lm_tokenizer.pad({"input_ids": lm_input_ids}, padding=self.padding, return_tensors="pt")
        lm_labels = torch.nn.utils.rnn.pad_sequence(lm_labels, batch_first=True, padding_value=-100)
        return {"lm_input_ids": lm_batch["input_ids"], "lm_attention_mask": lm_batch["attention_mask"], "lm_labels": lm_labels}


def gumbel_sigmoid(logits, tau=1, hard=False, threshold=0.5):
    gumbels = (torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log())
    gumbels = (logits + gumbels) / tau
    y_soft = gumbels.sigmoid()
    if hard:
        indices = (y_soft > threshold).nonzero(as_tuple=True)
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
        y_hard[indices[0], indices[1]] = 1.0
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret


class CustomTrainer(Trainer):
    def __init__(self, *args, lm_model, lm_tokenizer, **kwargs):
        super().__init__(*args, **kwargs)
        self.lm_model = lm_model
        self.lm_tokenizer = lm_tokenizer

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        lm_input_ids = inputs["lm_input_ids"].to(self.args.device)
        lm_attention_mask = inputs["lm_attention_mask"].to(self.args.device)
        lm_labels = inputs["lm_labels"].to(self.args.device)
        bsz = lm_input_ids.size(0)
        weights_logit = model(None).unsqueeze(0).repeat(bsz, 1)
        if self.args.tau_decay_steps < 1:
            tau_decay_end_step = int(self.args.tau_decay_steps * self.state.max_steps)
        else:
            tau_decay_end_step = int(self.args.tau_decay_steps)
        if self.state.global_step >= tau_decay_end_step:
            tau_temp = self.args.tau_temp_end
        else:
            decay_ratio = self.state.global_step / tau_decay_end_step
            tau_temp = self.args.tau_temp_begin - decay_ratio * (self.args.tau_temp_begin - self.args.tau_temp_end)
        if self.args.use_gumbel:
            weights_tensor = gumbel_sigmoid(weights_logit, tau=tau_temp, hard=self.args.gumbel_hard)
        else:
            weights_tensor = torch.sigmoid(weights_logit)
        weights_tensor = weights_tensor.to(self.lm_model.device)
        pred_outputs = self.lm_model(lm_input_ids, attention_mask=lm_attention_mask, weight_tensor=weights_tensor, labels=lm_labels, use_cache=False)
        pred_output_loss = pred_outputs.loss
        norm_lambda = self.args.norm_lambda
        normalizer = torch.sum(weights_tensor, dim=1).mean()
        loss = pred_output_loss + norm_lambda * normalizer
        self.log({"pred_output_loss": pred_output_loss.item(), "normalizer": normalizer.item(), "tau_temp": tau_temp, "total_loss": loss.item()})
        return (loss, pred_outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        lm_input_ids = inputs["lm_input_ids"].to(self.args.device)
        lm_labels = inputs["lm_labels"].to(self.args.device)
        bsz = lm_input_ids.size(0)
        with torch.no_grad():
            weights_logit = model(None).unsqueeze(0).repeat(bsz, 1)
            weights_tensor = (weights_logit.sigmoid() >= 0.5).to(weights_logit.dtype)
            weights_tensor = weights_tensor.to(self.lm_model.device)
            pred_output_loss = None
            for one_lm_input_ids, one_lm_labels, one_weights_tensor in zip(lm_input_ids, lm_labels, weights_tensor):
                one_lm_scr_input_ids = one_lm_input_ids[one_lm_labels == -100].unsqueeze(0)
                one_lm_attention_mask = torch.ones_like(one_lm_scr_input_ids)
                generate_ids = self.lm_model.generate(one_lm_scr_input_ids, attention_mask=one_lm_attention_mask, max_new_tokens=30, weight_tensor=one_weights_tensor.unsqueeze(0), do_sample=False)
                pred_str = self.lm_tokenizer.decode(generate_ids[0], skip_special_tokens=True)
                print("[GENERATED]", pred_str)
        if prediction_loss_only:
            return (pred_output_loss, None, None)
        lm_logits = None
        return (pred_output_loss, lm_logits, lm_labels)
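As a quick sanity check (not part of the original notebook), the straight-through behaviour of gumbel_sigmoid can be illustrated on a toy 2-D logit tensor: with hard=True the forward values are binary, while gradients still flow through the soft sigmoid.

import torch

# Toy check of the straight-through Gumbel-Sigmoid defined above.
# Shapes are illustrative; in the real setup the logits have length n_layers * n_heads + n_layers.
logits = torch.zeros(2, 6, requires_grad=True)
hard_sample = gumbel_sigmoid(logits, tau=0.5, hard=True)
print(hard_sample)        # forward values are (numerically) 0/1
hard_sample.sum().backward()
print(logits.grad)        # non-zero: gradients pass through the soft sigmoid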
Define the template strings and the preprocessing functions:
LLAMA_TEMPLATE = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{src}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
LLAMA_PLM_TEMPLATE = "<|begin_of_text|>{src}\n\n"
PHI3_TEMPLATE = "<|user|>\n{src}<|end|>\n<|assistant|>\n"
MISTRAL_TEMPLATE = "<s>[INST] {src}[/INST]"
QWEN2_TEMPLATE = "<|im_start|>user\n{src}<|im_end|>\n<|im_start|>assistant\n"
QWEN2_PLM_TEMPLATE = "{src}\n\n"
GEMMA2_TEMPLATE = "<bos><start_of_turn>user\n{src}<end_of_turn>\n<start_of_turn>model\n"


def preprocess_batch(samples, lm_tokenizer, model_type, fewshot_dataset=None):
    assert model_type in ["llama", "llama-plm", "phi3", "mistral", "qwen2", "qwen2-plm", "gemma2"]
    template = ""
    if model_type == "llama":
        template = LLAMA_TEMPLATE
        combined_strs = [LLAMA_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>"
                         for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "llama-plm":
        template = LLAMA_PLM_TEMPLATE
        combined_strs = [LLAMA_PLM_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|end_of_text|>"
                         for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "phi3":
        template = PHI3_TEMPLATE
        combined_strs = [PHI3_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|end|>"
                         for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "mistral":
        template = MISTRAL_TEMPLATE
        combined_strs = [MISTRAL_TEMPLATE.format(src=input_str.strip()) + f" {target_str.strip()}</s>"
                         for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "qwen2":
        template = QWEN2_TEMPLATE
        combined_strs = [QWEN2_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|im_end|>"
                         for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "qwen2-plm":
        template = QWEN2_PLM_TEMPLATE
        combined_strs = [QWEN2_PLM_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|endoftext|>"
                         for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    elif model_type == "gemma2":
        template = GEMMA2_TEMPLATE
        combined_strs = [GEMMA2_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<end_of_turn>"
                         for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    lm_inputs = lm_tokenizer(combined_strs, max_length=32768, truncation=True, padding=False, add_special_tokens=False, return_tensors="np")
    input_str_lens = [len(lm_tokenizer(template.format(src=input_str.strip()), add_special_tokens=False)["input_ids"]) for input_str in samples["input_str"]]
    labels = copy.deepcopy(lm_inputs["input_ids"])
    for i, input_str_len in enumerate(input_str_lens):
        labels[i][:input_str_len] = -100
    return {"lm_input_ids": lm_inputs["input_ids"], "lm_labels": labels}


def preprocess_fewshot_batch(samples, lm_tokenizer, model_type, fewshot_dataset):
    assert model_type in ["llama"]
    fewshot_input_str = "<|begin_of_text|>"
    for sample in fewshot_dataset:
        input_str, target_str = sample["input_str"], sample["target_str"]
        fewshot_input_str += f"<|start_header_id|>user<|end_header_id|>\n\n{input_str.strip()}<|eot_id|>"
        fewshot_input_str += f"<|start_header_id|>assistant<|end_header_id|>\n\n{target_str.strip()}<|eot_id|>"
    fewshot_input_str += "<|start_header_id|>user<|end_header_id|>\n\n{src}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    template = fewshot_input_str
    combined_strs = [template.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>"
                     for input_str, target_str in zip(samples["input_str"], samples["target_str"])]
    lm_inputs = lm_tokenizer(combined_strs, max_length=32768, truncation=True, padding=False, add_special_tokens=False, return_tensors="np")
    input_str_lens = [len(lm_tokenizer(template.format(src=input_str.strip()), add_special_tokens=False)["input_ids"]) for input_str in samples["input_str"]]
    labels = copy.deepcopy(lm_inputs["input_ids"])
    for i, input_str_len in enumerate(input_str_lens):
        labels[i][:input_str_len] = -100
    return {"lm_input_ids": lm_inputs["input_ids"], "lm_labels": labels}
Define the training procedure:
def search_weight_embed(lm_model, task, args, training_args):
    training_args.output_dir = os.path.join(args.output_dir, task)
    os.makedirs(training_args.output_dir, exist_ok=True)
    model_name = os.path.basename(args.lm_model_path)
    training_args.run_name = f"{model_name}-train_weight-{task}"

    lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model_path, use_fast=True)
    lm_tokenizer.pad_token = lm_tokenizer.eos_token
    lm_tokenizer.padding_side = "right"
    lm_tokenizer.truncation_side = "right"
    lm_model.generation_config.pad_token_id = lm_tokenizer.pad_token_id

    n_layers = lm_model.config.num_hidden_layers
    n_heads = lm_model.config.num_attention_heads
    head_mask_model = SimpleTensorModel(tensor_length=n_layers * n_heads + n_layers).to(torch.device("cuda"))

    if not args.dataset_use_cache:
        datasets.disable_caching()
    train_dataset, dev_dataset = None, None
    fewshot_dataset = None
    if "XNLI" in args.train_data_path:
        map_func = preprocess_batch
        LANG_LIST = task.split("_")
        LANG_DICT = {"ar": "Arabic", "fr": "French", "es": "Spanish", "de": "German", "en": "English", "ru": "Russian", "zh": "Chinese"}

        def _preprocess_xnli(dataset):
            pair_datasets = []
            for src_lang, tgt_lang in itertools.combinations(LANG_LIST, 2):
                pair_dataset = dataset.map(lambda sample: {"input_str": f"{sample[src_lang]}", "target_str": sample[tgt_lang]}).select_columns(["input_str", "target_str"])
                pair_datasets.append(pair_dataset)
            return concatenate_datasets(pair_datasets)

        dataset = Dataset.from_csv(args.train_data_path, sep='\t').select_columns(LANG_LIST)
        train_dataset = dataset.select(range(len(dataset) - 100))
        dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))
        train_dataset = _preprocess_xnli(train_dataset)
        dev_dataset = _preprocess_xnli(dev_dataset)
        train_dataset = train_dataset.shuffle(seed=args.seed)
    elif "function_vectors" in args.train_data_path:
        dataset = Dataset.from_json(args.train_data_path)
        if "fewshot" in args.train_data_path:
            map_func = preprocess_fewshot_batch
            dataset = dataset.map(lambda sample: {"input_str": sample["input"], "target_str": sample["output"]})
            dataset = dataset.remove_columns(["input", "output"])
            dataset = dataset.select(range(min(len(dataset) - 100, 10000))).shuffle(seed=args.seed)
            fewshot_dataset, dataset = dataset.select(range(5)), dataset.select(range(5, len(dataset)))
        else:
            map_func = preprocess_batch
            dataset = dataset.map(lambda sample: {"input_str": str(sample["input"]), "target_str": str(sample["output"])})
            dataset = dataset.remove_columns(["input", "output"])
        train_dataset = dataset.select(range(min(len(dataset) - 100, 10000)))
        dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))
    else:
        assert False, "Unsupported dataset"

    train_dataset_dir, dev_dataset_dir = os.path.dirname(args.train_data_path), os.path.dirname(args.dev_data_path)
    train_dataset = train_dataset.map(
        lambda sample: map_func(sample, lm_tokenizer, model_type=args.model_type, fewshot_dataset=fewshot_dataset),
        batched=True, batch_size=128, num_proc=1,
        cache_file_name=os.path.join(train_dataset_dir, ".cache/train_dataset_cache.arrow") if args.dataset_use_cache else None
    ).filter(
        lambda sample: len(sample["lm_input_ids"]) <= args.max_seq_length, batched=False
    ).shuffle(seed=args.seed)
    dev_dataset = dev_dataset.map(
        lambda sample: map_func(sample, lm_tokenizer, model_type=args.model_type, fewshot_dataset=fewshot_dataset),
        batched=True, batch_size=128, num_proc=1,
        cache_file_name=os.path.join(dev_dataset_dir, ".cache/dev_dataset_cache.arrow") if args.dataset_use_cache else None
    ).filter(
        lambda sample: len(sample["lm_input_ids"]) <= args.max_seq_length, batched=False
    )

    data_collator = CustomDataCollator(lm_tokenizer=lm_tokenizer, padding=True)
    trainer = CustomTrainer(
        model=head_mask_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=data_collator,
        compute_metrics=None,
        lm_model=lm_model,
        lm_tokenizer=lm_tokenizer,
    )
    trainer.train(resume_from_checkpoint=False)
Define the relevant arguments:
@dataclass
class ModelArguments:
    embed_model_path: str = field(default="../model/bge-reranker-v2-m3", metadata={"help": "Path to the embedding model"})
    lm_model_path: str = field(default="../model/Llama-3.1-8B-Instruct", metadata={"help": "Path to the language model"})
    model_type: str = field(default="llama", metadata={"help": "Model type (plm means a plain pre-trained language model)"})


@dataclass
class DataArguments:
    train_data_path: str = field(default="dataset/XNLI-15way/xnli.15way.orig.tsv", metadata={"help": "Path to the training data"})
    dev_data_path: str = field(default="dataset/XNLI-15way/xnli.15way.orig.tsv", metadata={"help": "Path to the validation data"})
    dataset_use_cache: bool = field(default=False, metadata={"help": "Whether to use the dataset cache"})
    max_seq_length: int = field(default=32768, metadata={"help": "Maximum input sequence length"})


@dataclass
class CustomTrainingArguments(TrainingArguments):
    tau_temp_begin: float = field(default=4.0, metadata={"help": "Initial temperature of the Gumbel-Sigmoid"})
    tau_temp_end: float = field(default=0.05, metadata={"help": "Final temperature of the Gumbel-Sigmoid"})
    tau_decay_steps: float = field(default=0.4, metadata={"help": "Decay steps of the Gumbel-Sigmoid temperature (a value < 1 is interpreted as a fraction of the total training steps)"})
    norm_lambda: float = field(default=0, metadata={"help": "Lambda coefficient of the weight-tensor norm penalty"})
    norm_power: float = field(default=1, metadata={"help": "Power of the weight-tensor norm penalty"})
    use_gumbel: bool = field(default=True, metadata={"help": "Whether to apply the Gumbel-Sigmoid to the weight tensor"})
    gumbel_hard: bool = field(default=True, metadata={"help": "Whether to use the hard (straight-through) Gumbel-Sigmoid"})


hfparser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
model_args, data_args, training_args, _ = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)

data_args.dataset_use_cache = False
training_args.output_dir = "output/llama/xnli"
training_args.overwrite_output_dir = True
training_args.num_train_epochs = 10
training_args.per_device_train_batch_size = 4
training_args.per_device_eval_batch_size = 1
training_args.remove_unused_columns = False
training_args.save_steps = 31
training_args.save_total_limit = 200
training_args.save_only_model = True
training_args.logging_dir = "logs"
training_args.logging_steps = 500
training_args.learning_rate = 1e-2
training_args.lr_scheduler_type = "cosine_with_min_lr"
training_args.lr_scheduler_kwargs = {"min_lr": 1e-4}
training_args.bf16 = True
training_args.warmup_ratio = 0.1
training_args.gradient_accumulation_steps = 4

args = argparse.Namespace(**vars(model_args), **vars(training_args), **vars(data_args))
print(args)
Namespace(embed_model_path='../model/bge-reranker-v2-m3', lm_model_path='../model/Llama-3.1-8B-Instruct', model_type='llama', output_dir='output/llama/xnli', overwrite_output_dir=True, do_train=False, do_eval=False, do_predict=False, eval_strategy=<IntervalStrategy.NO: 'no'>, prediction_loss_only=False, per_device_train_batch_size=4, per_device_eval_batch_size=1, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=4, eval_accumulation_steps=None, eval_delay=0, torch_empty_cache_steps=None, learning_rate=0.01, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=10, max_steps=-1, lr_scheduler_type='cosine_with_min_lr', lr_scheduler_kwargs={'min_lr': 0.0001}, warmup_ratio=0.1, warmup_steps=0, log_level='passive', log_level_replica='warning', log_on_each_node=True, logging_dir='logs', logging_strategy=<IntervalStrategy.STEPS: 'steps'>, logging_first_step=False, logging_steps=500, logging_nan_inf_filter=True, save_strategy=<SaveStrategy.STEPS: 'steps'>, save_steps=31, save_total_limit=200, save_safetensors=True, save_on_each_node=False, save_only_model=True, restore_callback_states_from_checkpoint=False, no_cuda=False, use_cpu=False, use_mps_device=False, seed=42, data_seed=None, jit_mode_eval=False, use_ipex=False, bf16=True, fp16=False, fp16_opt_level='O1', half_precision_backend='auto', bf16_full_eval=False, fp16_full_eval=False, tf32=None, local_rank=0, ddp_backend=None, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=None, dataloader_num_workers=0, dataloader_prefetch_factor=None, past_index=-1, run_name='trainer_output', disable_tqdm=False, remove_unused_columns=False, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, fsdp=[], fsdp_min_num_params=0, fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, tp_size=0, fsdp_transformer_layer_cls_to_wrap=None, accelerator_config=AcceleratorConfig(split_batches=False, dispatch_batches=None, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False), deepspeed=None, label_smoothing_factor=0.0, optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>, optim_args=None, adafactor=False, group_by_length=False, length_column_name='length', report_to=[], ddp_find_unused_parameters=None, ddp_bucket_cap_mb=None, ddp_broadcast_buffers=None, dataloader_pin_memory=True, dataloader_persistent_workers=False, skip_memory_metrics=True, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, hub_model_id=None, hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>, hub_token=None, hub_private_repo=None, hub_always_push=False, gradient_checkpointing=False, gradient_checkpointing_kwargs=None, include_inputs_for_metrics=False, include_for_metrics=[], eval_do_concat_batches=True, fp16_backend='auto', push_to_hub_model_id=None, push_to_hub_organization=None, push_to_hub_token=None, mp_parameters='', auto_find_batch_size=False, full_determinism=False, torchdynamo=None, ray_scope='last', ddp_timeout=1800, torch_compile=False, torch_compile_backend=None, torch_compile_mode=None, include_tokens_per_second=False, include_num_input_tokens_seen=False, neftune_noise_alpha=None, optim_target_modules=None, batch_eval_metrics=False, eval_on_start=False, use_liger_kernel=False, eval_use_gather_object=False, 
average_tokens_across_devices=False, tau_temp_begin=4.0, tau_temp_end=0.05, tau_decay_steps=0.4, norm_lambda=0, norm_power=1, use_gumbel=True, gumbel_hard=True, distributed_state=Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
, _n_gpu=1, __cached__setup_devices=device(type='cuda', index=0), deepspeed_plugin=None, train_data_path='dataset/XNLI-15way/xnli.15way.orig.tsv', dev_data_path='dataset/XNLI-15way/xnli.15way.orig.tsv', dataset_use_cache=False, max_seq_length=32768)
Select the model class:
if "phi3" in args.model_type:
    casual_lm = Phi3ForCausalLM
elif "llama" in args.model_type:
    casual_lm = LlamaForCausalLM
elif "mistral" in args.model_type:
    casual_lm = MistralForCausalLM
elif "qwen2" in args.model_type:
    casual_lm = Qwen2ForCausalLM
elif "gemma2" in args.model_type:
    casual_lm = Gemma2ForCausalLM
else:
    assert False, "Unsupported model type"

print(f"Using model type: {casual_lm.__name__}")
Using model type: LlamaForCausalLM
Load the language model:
lm_model = casual_lm.from_pretrained(
    args.lm_model_path,
    local_files_only=True,
    device_map=torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") if args.distributed_state.use_distributed else "auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    max_position_embeddings=32768
)
Loading checkpoint shards: 100%|██████████| 4/4 [00:21<00:00, 5.39s/it]
Machine translation training, using English-Chinese translation as the example:
train_data_dir = args.train_data_path
if "XNLI" in train_data_dir:
    ALL_LANGS = ["en", "zh"]
    for i, (src_lang, tgt_lang) in enumerate(itertools.permutations(ALL_LANGS, 2)):
        print(f"\n********** [{i+1}/{len(ALL_LANGS) * (len(ALL_LANGS) - 1)}] Training {src_lang} -> {tgt_lang} **********\n")
        search_weight_embed(lm_model, f"{src_lang}_{tgt_lang}", args, training_args)
elif "function_vectors" in train_data_dir:
    for root, dirs, files in os.walk(train_data_dir):
        for file in files:
            if file.endswith(".json"):
                task = os.path.basename(file).split(".")[0]
                print(f"\n********** Training {task} **********\n")
                args.train_data_path = os.path.join(root, file)
                search_weight_embed(lm_model, task, args, training_args)
else:
    assert False, "Unsupported dataset"
********** [1/2] Training en -> zh **********
Map: 100%|██████████| 9900/9900 [00:00<00:00, 24753.96 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 16005.74 examples/s]
Map: 100%|██████████| 9900/9900 [00:01<00:00, 7272.35 examples/s]
Filter: 100%|██████████| 9900/9900 [00:00<00:00, 31385.21 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 5276.52 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 18811.07 examples/s]
Step Training Loss
500 2.330100
1000 1.664200
1500 1.343300
2000 1.235700
2500 1.195600
3000 1.178300
3500 1.163300
4000 1.172700
4500 1.192800
5000 1.158600
5500 1.169000
6000 1.143300
********** [2/2] Training zh -> en **********
Map: 100%|██████████| 9900/9900 [00:00<00:00, 32588.55 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 17755.17 examples/s]
Map: 100%|██████████| 9900/9900 [00:01<00:00, 6584.71 examples/s]
Filter: 100%|██████████| 9900/9900 [00:00<00:00, 29741.66 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 5468.88 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 17750.66 examples/s]
Step Training Loss
500 1.758600
1000 1.261800
1500 1.209500
2000 1.180300
2500 1.206600
3000 1.192100
3500 1.170300
4000 1.188600
4500 1.199100
5000 1.178500
5500 1.170500
6000 1.145500
Switch to the arguments for the function-vector tasks:
data_args.train_data_path = "dataset/function_vectors/abstractive"
data_args.dev_data_path = "dataset/function_vectors/abstractive"
training_args.output_dir = "output/llama/fv"
training_args.num_train_epochs = 3
training_args.max_steps = 6250
training_args.eval_steps = 625

args = argparse.Namespace(**vars(model_args), **vars(training_args), **vars(data_args))
print(args)
Namespace(embed_model_path='../model/bge-reranker-v2-m3', lm_model_path='../model/Llama-3.1-8B-Instruct', model_type='llama', output_dir='output/llama/fv', overwrite_output_dir=True, do_train=False, do_eval=False, do_predict=False, eval_strategy=<IntervalStrategy.NO: 'no'>, prediction_loss_only=False, per_device_train_batch_size=4, per_device_eval_batch_size=1, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=4, eval_accumulation_steps=None, eval_delay=0, torch_empty_cache_steps=None, learning_rate=0.01, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=3, max_steps=6250, lr_scheduler_type='cosine_with_min_lr', lr_scheduler_kwargs={'min_lr': 0.0001}, warmup_ratio=0.1, warmup_steps=0, log_level='passive', log_level_replica='warning', log_on_each_node=True, logging_dir='logs', logging_strategy=<IntervalStrategy.STEPS: 'steps'>, logging_first_step=False, logging_steps=500, logging_nan_inf_filter=True, save_strategy=<SaveStrategy.STEPS: 'steps'>, save_steps=31, save_total_limit=200, save_safetensors=True, save_on_each_node=False, save_only_model=True, restore_callback_states_from_checkpoint=False, no_cuda=False, use_cpu=False, use_mps_device=False, seed=42, data_seed=None, jit_mode_eval=False, use_ipex=False, bf16=True, fp16=False, fp16_opt_level='O1', half_precision_backend='auto', bf16_full_eval=False, fp16_full_eval=False, tf32=None, local_rank=0, ddp_backend=None, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=625, dataloader_num_workers=0, dataloader_prefetch_factor=None, past_index=-1, run_name='trainer_output', disable_tqdm=False, remove_unused_columns=False, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, fsdp=[], fsdp_min_num_params=0, fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, tp_size=0, fsdp_transformer_layer_cls_to_wrap=None, accelerator_config=AcceleratorConfig(split_batches=False, dispatch_batches=None, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False), deepspeed=None, label_smoothing_factor=0.0, optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>, optim_args=None, adafactor=False, group_by_length=False, length_column_name='length', report_to=[], ddp_find_unused_parameters=None, ddp_bucket_cap_mb=None, ddp_broadcast_buffers=None, dataloader_pin_memory=True, dataloader_persistent_workers=False, skip_memory_metrics=True, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, hub_model_id=None, hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>, hub_token=None, hub_private_repo=None, hub_always_push=False, gradient_checkpointing=False, gradient_checkpointing_kwargs=None, include_inputs_for_metrics=False, include_for_metrics=[], eval_do_concat_batches=True, fp16_backend='auto', push_to_hub_model_id=None, push_to_hub_organization=None, push_to_hub_token=None, mp_parameters='', auto_find_batch_size=False, full_determinism=False, torchdynamo=None, ray_scope='last', ddp_timeout=1800, torch_compile=False, torch_compile_backend=None, torch_compile_mode=None, include_tokens_per_second=False, include_num_input_tokens_seen=False, neftune_noise_alpha=None, optim_target_modules=None, batch_eval_metrics=False, eval_on_start=False, use_liger_kernel=False, eval_use_gather_object=False, average_tokens_across_devices=False, 
tau_temp_begin=4.0, tau_temp_end=0.05, tau_decay_steps=0.4, norm_lambda=0, norm_power=1, use_gumbel=True, gumbel_hard=True, distributed_state=Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
, _n_gpu=1, __cached__setup_devices=device(type='cuda', index=0), deepspeed_plugin=None, train_data_path='dataset/function_vectors/abstractive', dev_data_path='dataset/function_vectors/abstractive', dataset_use_cache=False, max_seq_length=32768)
Task training, using the antonym task as the example:
train_data_dir = args.train_data_path
if "XNLI" in train_data_dir:
    ALL_LANGS = ["en", "zh", "fr", "de", "es", "ru", "ar"]
    for i, (src_lang, tgt_lang) in enumerate(itertools.permutations(ALL_LANGS, 2)):
        print(f"\n********** [{i+1}/{len(ALL_LANGS) * (len(ALL_LANGS) - 1)}] Training {src_lang} -> {tgt_lang} **********\n")
        search_weight_embed(lm_model, f"{src_lang}_{tgt_lang}", args, training_args)
elif "function_vectors" in train_data_dir:
    for root, dirs, files in os.walk(train_data_dir):
        for file in files:
            if file.endswith("antonym.json"):
                task = os.path.basename(file).split(".")[0]
                print(f"\n********** Training {task} **********\n")
                args.train_data_path = os.path.join(root, file)
                search_weight_embed(lm_model, task, args, training_args)
else:
    assert False, "Unsupported dataset"
********** Training antonym **********
Generating train split: 2398 examples [00:00, 77072.94 examples/s]
Map: 100%|██████████| 2398/2398 [00:00<00:00, 36112.36 examples/s]
Map: 100%|██████████| 2298/2298 [00:00<00:00, 7820.68 examples/s]
Filter: 100%|██████████| 2298/2298 [00:00<00:00, 59810.06 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 3802.15 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 12812.90 examples/s]
Step Training Loss
500 6.312800
1000 1.012000
1500 0.716200
2000 0.607200
2500 0.609400
3000 0.585000
3500 0.594200
4000 0.583300
4500 0.586900
5000 0.554000
5500 0.540600
6000 0.546000
After training, the output directory contains several folders whose names begin with checkpoint-; each of them contains:
model.safetensors
trainer_state.json
training_args.bin
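Section 2.2 below hard-codes the final checkpoint step (-6180 / -6250) when loading the masks. If you prefer not to hard-code it, a small hypothetical helper such as the following can pick the latest checkpoint of a task and return the trained mask logits (stored in model.safetensors under the key "tensor"):

import os
import safetensors.torch

def load_latest_mask(task_dir):
    # Hypothetical helper: pick the checkpoint-* folder with the largest step number.
    ckpts = [d for d in os.listdir(task_dir) if d.startswith("checkpoint-")]
    latest = max(ckpts, key=lambda d: int(d.split("-")[-1]))
    return safetensors.torch.load_file(os.path.join(task_dir, latest, "model.safetensors"))["tensor"]

# e.g. load_latest_mask("output/llama/xnli/en_zh")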
2.2 Evaluation
Load the trained attention-head mask weights for the different tasks:
weight_dict = {}
for root, dirs, files in os.walk("output/llama/xnli"):
    for file in files:
        if file.endswith(".safetensors") and "-6180" in root:
            langs = os.path.basename(os.path.dirname(root))
            path = os.path.join(root, file)
            weight_dict[langs] = safetensors.torch.load_file(path)["tensor"]
for root, dirs, files in os.walk("output/llama/fv"):
    for file in files:
        if file.endswith(".safetensors") and "-6250" in root:
            langs = os.path.basename(os.path.dirname(root))
            path = os.path.join(root, file)
            weight_dict[langs] = safetensors.torch.load_file(path)["tensor"]

print("Total tasks:", len(weight_dict))
print("Task name", "\t", "Weight", "\t", "# of up heads")
for k, v in sorted(weight_dict.items()):
    print(k, "\t", v, "\t", (v.sigmoid() >= 0.5).sum().item())
Total tasks: 5
Task name Weight # of up heads
adjective_v_verb_3 tensor([6.4547, 0.8442, 4.9701, ..., 6.9430, 5.0350, 3.9923]) 845
antonym tensor([4.9961, 4.3987, 4.5589, ..., 7.8902, 3.0385, 3.9649]) 864
en_zh tensor([5.2550, 1.9984, 5.9432, ..., 7.5652, 3.0530, 4.0198]) 856
verb_v_adjective_3 tensor([6.7996, 2.7995, 6.5897, ..., 7.7574, 5.2005, 3.9966]) 871
zh_en tensor([ 1.6995, 3.6389, 6.9171, ..., 4.2308, -1.9574, 3.9649]) 966
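For reference, here is a sketch of how the 1056-dimensional weight vector can be read, inferred from the code above (tensor_length = n_layers * n_heads + n_layers, and the later mask_tensor.view(32, 33)[:, :-1] indexing): each layer owns 33 consecutive slots, the first 32 being its attention heads and the last an extra per-layer weight.

# Sketch (layout inferred from the surrounding code, not an official API):
mask = (weight_dict["en_zh"].sigmoid() >= 0.5).float().view(32, 33)
per_layer_active_heads = mask[:, :-1].sum(dim=1)
print(per_layer_active_heads)   # how many of the 32 heads stay active in each layer
print(mask[:, -1])              # the extra per-layer slot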
Load the model and tokenizer, and select the template:
DEVICE = "cuda:0"
MODEL_DIR = "../model/Llama-3.1-8B-Instruct"
model_name = os.path.basename(MODEL_DIR).lower()

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
model = LlamaForCausalLM.from_pretrained(MODEL_DIR, local_files_only=True, device_map=DEVICE, torch_dtype=torch.bfloat16, attn_implementation="eager", max_position_embeddings=2048)
model.generation_config.pad_token_id = tokenizer.pad_token_id
n_layers = model.config.num_hidden_layers
n_heads = model.config.num_attention_heads

if "llama" in model_name:
    if "instruct" in model_name:
        template = LLAMA_TEMPLATE
        print("using LLAMA_TEMPLATE")
    else:
        template = LLAMA_PLM_TEMPLATE
        print("using LLAMA_PLM_TEMPLATE")
elif "qwen2" in model_name:
    if "instruct" in model_name:
        template = QWEN2_TEMPLATE
        print("using QWEN2_TEMPLATE")
    else:
        template = QWEN2_PLM_TEMPLATE
        print("using QWEN2_PLM_TEMPLATE")
else:
    print("Unknown model")
Loading checkpoint shards: 100%|██████████| 4/4 [00:19<00:00, 4.79s/it]
using LLAMA_TEMPLATE
Task prompts:
TASK_DICT = {
    "lowercase_first_letter": "Output the first letter of the given word in lowercase.",
    "park-country": "Identify the country where the given national park is located.",
    "synonym": "Identify a synonym for the given word.",
    "ag_news": "Classify the given news headline into one of the categories: Business, Science, Sports, or World. Provide only the category name.",
    "word_length": "Determine the number of letters in the given word and output the count.",
    "present-past": "Convert the given verb from its present tense to its simple past tense.",
    "capitalize": "Output the given word with its first letter capitalized.",
    "landmark-country": "Identify the country where the given landmark is located.",
    "english-german": "Translate the given English word into German.",
    "sentiment": "Determine the sentiment of the given input. Output either 'positive' or 'negative'.",
    "country-capital": "What is the capital of the given country? Provide only the name of the capital.",
    "person-occupation": "Identify the occupation of the given individual.",
    "country-currency": "What is the official currency of the given country?",
    "lowercase_last_letter": "Output the last letter of the given word in lowercase.",
    "person-sport": "Identify the sport associated with the given individual.",
    "person-instrument": "Identify the musical instrument played by the given musician.",
    "antonym": "Identify the antonym of the given word.",
    "capitalize_last_letter": "Output the last letter of the given word in uppercase.",
    "english-french": "Translate the given English word into French.",
    "next_item": "What is the next sequential item following the given input?",
    "singular-plural": "Provide the plural form of the given singular noun.",
    "capitalize_second_letter": "Output the second letter of the given word in uppercase.",
    "prev_item": "What is the item that comes before the given input in a sequential context?",
    "capitalize_first_letter": "Output the first letter of the given word in uppercase.",
    "english-spanish": "Translate the given English word into Spanish.",
    "next_capital_letter": "What is the next uppercase letter in alphabetical order after the given input?",
    "national_parks": "Identify the U.S. state where the given national park is located.",
    "product-company": "Identify the company associated with the given product.",
    "conll2003_organization": "Extract the organization mentioned in the given text.",
    "conll2003_person": "Extract the name of the person mentioned in the given text.",
    "conll2003_location": "Extract the location mentioned in the given text.",
    "adjective_v_verb_3": "From the given words, identify the one that is an adjective.",
    "object_v_concept_3": "From the given words, identify the one that is a object.",
    "verb_v_adjective_3": "From the given words, identify the one that is a verb.",
    "fruit_v_animal_3": "From the given words, identify the one that is a fruit."
}
Evaluate the antonym task:
def eval_function_vectors(use_instruction, use_mask):
    dev_datasets = []
    for root, dirs, files in os.walk("dataset/function_vectors/abstractive"):
        for file in files:
            if file.endswith("antonym.json"):
                task = os.path.basename(file).split(".")[0]
                data_path = os.path.join(root, file)
                dataset = datasets.Dataset.from_json(data_path)
                dataset = dataset.map(
                    lambda sample: {
                        "input_str": TASK_DICT[task] + f"\n\nInput:\n\n{sample['input']}\n\nOutput:\n\n" if use_instruction else sample["input"],
                        "target_str": sample["output"]
                    }
                )
                dataset = dataset.remove_columns(["input", "output"])
                dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))
                dev_datasets.append((task, dev_dataset))

    for task, dataset in tqdm(dev_datasets):
        mask_weight = (weight_dict[task].sigmoid() >= 0.5).float().numpy()
        mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
        correct = 0
        for sample in dataset:
            input_str, target_str = str(sample["input_str"]), str(sample["target_str"])
            combined_str = template.format(src=input_str.strip())
            lm_inputs_src = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                if use_mask:
                    generate_ids = model.generate(**lm_inputs_src, max_new_tokens=10, weight_tensor=mask_tensor, do_sample=False)
                else:
                    generate_ids = model.generate(**lm_inputs_src, max_new_tokens=10, weight_tensor=None, do_sample=False)
            pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
            if task in [
                "capitalize_first_letter", "capitalize_last_letter", "capitalize_second_letter", "capitalize",
                "lowercase_first_letter", "lowercase_last_letter", "next_capital_letter", "next_item", "prev_item",
                "commonsense_qa", "conll2003_organization", "conll2003_person", "conll2003_location",
                "adjective_v_verb_3", "object_v_concept_3", "verb_v_adjective_3", "fruit_v_animal_3",
            ]:
                if pred_str.strip().startswith(target_str.strip()):
                    correct += 1
            else:
                if target_str.strip() in pred_str.strip():
                    correct += 1
        score = correct / len(dataset)
        print(f"{task}: {score}")


print("使用指令,使用Mask")
eval_function_vectors(use_instruction=True, use_mask=True)
print("使用指令,不使用Mask")
eval_function_vectors(use_instruction=True, use_mask=False)
print("不使用指令,使用Mask")
eval_function_vectors(use_instruction=False, use_mask=True)
print("不使用指令,不使用Mask")
eval_function_vectors(use_instruction=False, use_mask=False)
使用指令,使用Mask
100%|██████████| 1/1 [00:10<00:00, 10.85s/it]
antonym: 0.63
使用指令,不使用Mask
100%|██████████| 1/1 [00:30<00:00, 30.66s/it]
antonym: 0.3
不使用指令,使用Mask
100%|██████████| 1/1 [00:08<00:00, 8.78s/it]
antonym: 0.76
不使用指令,不使用Mask
100%|██████████| 1/1 [00:30<00:00, 30.58s/it]
antonym: 0.12
Inspect the attention distributions on the contrastive extraction (*_v_*) tasks:
dev_datasets = []
for root, dirs, files in os.walk("dataset/function_vectors/abstractive"):
    for file in files:
        if file.endswith("verb_3.json") or file.endswith("adjective_3.json"):
            task = os.path.basename(file).split(".")[0]
            if "_v_" not in task:
                continue
            data_path = os.path.join(root, file)
            dataset = datasets.Dataset.from_json(data_path)
            dataset = dataset.map(lambda sample: {"input_str": sample["input"], "target_str": sample["output"]})
            dataset = dataset.remove_columns(["input", "output"])
            dev_dataset = dataset.select(range(len(dataset) - 100, len(dataset)))
            dev_datasets.append((task, dev_dataset))


def find_segments_(l):
    indices_271 = [i for i, x in enumerate(l) if x == 271]
    start, end = indices_271[0], indices_271[1]
    segment = l[start + 1:end]
    segments = []
    current_segment = []
    for idx, value in enumerate(segment, start=start + 1):
        if value == 11:
            if current_segment:
                segments.append(current_segment)
                current_segment = []
        else:
            current_segment.append(idx)
    if current_segment:
        segments.append(current_segment)
    return segments


task_attentions = {}
for task, dataset in tqdm(dev_datasets):
    mask_weight = (weight_dict[task].sigmoid() >= 0.5).float().numpy()
    mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
    choose_the, choose_from = task.split("_v_")[0], task.split("_v_")[1][:-2]
    instructed_attn, used_attn, unused_attn = [], [], []
    for sample in dataset:
        target_choice_idx = sample["input_str"].split(", ").index(sample["target_str"])
        input_inst_str = f"{sample['input_str']}\n\nChoose the {choose_the} out of {choose_from}s:"
        input_str = f"{sample['input_str']}"
        lm_inputs_inst_src = tokenizer([LLAMA_TEMPLATE.format(src=input_inst_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)
        lm_inputs_inst_src_choice_idx = find_segments_(lm_inputs_inst_src.input_ids[0].tolist())
        assert len(lm_inputs_inst_src_choice_idx) == 3
        lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)
        lm_inputs_src_choice_idx = lm_inputs_inst_src_choice_idx
        with torch.no_grad():
            original_output = model(**lm_inputs_inst_src, weight_tensor=None, output_attentions=True)
            original_attention = torch.stack(original_output.attentions)
            last_token_attn = original_attention[:, -1, :, -1]
            original_token_attention = last_token_attn.mean(dim=(0, 1)).float().cpu().numpy()
            original_choice_attentions = [original_token_attention[lm_inputs_inst_src_choice_idx[i]].sum() for i in range(3)]
            original_choice_attentions = np.array(original_choice_attentions) / sum(original_choice_attentions)
            instructed_attn.append(original_choice_attentions[target_choice_idx])

            weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_attentions=True)
            weighted_attention = torch.stack(weighted_output.attentions)
            last_token_attn = weighted_attention[:, -1, :, -1]
            last_token_attn[mask_tensor.view(32, 33)[:, :-1] == 0] = 0
            layer_head_num = (mask_tensor.view(32, 33)[:, :-1] != 0).sum(dim=1)
            weighted_token_attention = (last_token_attn.sum(dim=(0, 1)) / (mask_tensor.sum() - 32)).float().cpu().numpy()
            weighted_choice_attentions = [weighted_token_attention[lm_inputs_src_choice_idx[i]].sum() for i in range(3)]
            weighted_choice_attentions = np.array(weighted_choice_attentions) / sum(weighted_choice_attentions)
            used_attn.append(weighted_choice_attentions[target_choice_idx])

            weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_attentions=True)
            weighted_attention = torch.stack(weighted_output.attentions)
            last_token_unused_attn = weighted_attention[:, -1, :, -1]
            last_token_unused_attn[mask_tensor.view(32, 33)[:, :-1] == 1] = 0
            weighted_token_unused_attention = (last_token_unused_attn.sum(dim=(0, 1)) / (1056 - mask_tensor.sum())).float().cpu().numpy()
            weighted_choice_unused_attentions = [weighted_token_unused_attention[lm_inputs_src_choice_idx[i]].sum() for i in range(3)]
            weighted_choice_unused_attentions = np.array(weighted_choice_unused_attentions) / sum(weighted_choice_unused_attentions)
            unused_attn.append(weighted_choice_unused_attentions[target_choice_idx])
    task_attentions[task] = {"Instructed": np.mean(instructed_attn), "Used": np.mean(used_attn), "Ununsed": np.mean(unused_attn)}

df = pd.DataFrame.from_dict(task_attentions, orient="index")
display(df)
100%|██████████| 2/2 [00:19<00:00, 9.96s/it]
Instructed Used Ununsed
verb_v_adjective_3 0.345372 0.426140 0.379153
adjective_v_verb_3 0.408563 0.462195 0.483597
Evaluate the Chinese-English translation tasks, using perplexity (PPL) as the metric:
LANG_LIST = ["en", "zh"]
LANG_DICT = {"ar": "Arabic", "fr": "French", "es": "Spanish", "de": "German", "en": "English", "ru": "Russian", "zh": "Chinese"}
dataset = datasets.Dataset.from_csv("dataset/XNLI-15way/xnli.15way.orig.tsv", sep='\t').select_columns(LANG_LIST)
dataset = dataset.select(range(len(dataset) - 100, len(dataset)))


def eval_xnli_ppl(use_random, use_instruction):
    langs_original_ppl, langs_weighted_ppl = {}, {}
    for src_lang, tgt_lang in tqdm(list(itertools.permutations(LANG_LIST, 2))):
        mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy()
        if use_random:
            mask = np.zeros_like(mask_weight, dtype=bool)
            mask[np.arange(n_heads, (n_heads + 1) * n_layers, n_heads + 1)] = True
            random_weight = mask_weight[~mask]
            np.random.shuffle(random_weight)
            mask_weight[~mask] = random_weight
        dev_dataset = dataset.map(
            lambda sample: {
                "input_str": sample[src_lang] + f"\n\nTranslate into {LANG_DICT[tgt_lang]}:" if use_instruction else sample[src_lang],
                "target_str": sample[tgt_lang]
            }
        ).select_columns(["input_str", "target_str"])
        if src_lang not in langs_original_ppl:
            langs_original_ppl[src_lang] = {src_lang: 0}
        if src_lang not in langs_weighted_ppl:
            langs_weighted_ppl[src_lang] = {src_lang: 0}
        original_ppl, weighted_ppl = [], []
        for input_str, target_str in zip(*dev_dataset[:100].values()):
            combined_str = LLAMA_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>"
            lm_inputs = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)
            input_str_len = tokenizer(LLAMA_TEMPLATE.format(src=input_str.strip()), add_special_tokens=False, return_tensors="pt")["input_ids"].size(-1)
            labels = copy.deepcopy(lm_inputs["input_ids"]).to(DEVICE)
            labels[:, :input_str_len] = -100
            with torch.no_grad():
                original_output = model(**lm_inputs, labels=labels)
                original_ppl.append(original_output.loss.item())
                mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
                weighted_output = model(**lm_inputs, labels=labels, weight_tensor=mask_tensor)
                weighted_ppl.append(weighted_output.loss.item())
        langs_original_ppl[src_lang][tgt_lang] = np.nanmean(original_ppl)
        langs_weighted_ppl[src_lang][tgt_lang] = np.nanmean(weighted_ppl)
    original_df = pd.DataFrame.from_dict(langs_original_ppl, orient="index")
    print("w/o head mask:")
    display(original_df)
    weighted_df = pd.DataFrame.from_dict(langs_weighted_ppl, orient="index")
    print("w/ head mask:")
    display(weighted_df)


print("使用指令")
eval_xnli_ppl(use_random=False, use_instruction=True)
print("使用指令,随机Mask")
eval_xnli_ppl(use_random=True, use_instruction=True)
print("不使用指令")
eval_xnli_ppl(use_random=False, use_instruction=False)
print("不使用指令,随机Mask")
eval_xnli_ppl(use_random=True, use_instruction=False)
使用指令
100%|██████████| 2/2 [00:13<00:00, 6.62s/it]
w/o head mask:
en zh
en 0.000000 1.091267
zh 1.002608 0.000000
w/ head mask:
en zh
en 0.000000 1.004982
zh 0.986755 0.000000
使用指令,随机Mask
100%|██████████| 2/2 [00:13<00:00, 6.61s/it]
w/o head mask:
en zh
en 0.000000 1.091267
zh 1.002608 0.000000
w/ head mask:
en zh
en 0.000000 1.802226
zh 1.260897 0.000000
不使用指令
100%|██████████| 2/2 [00:13<00:00, 6.61s/it]
w/o head mask:
en zh
en 0.000000 2.643437
zh 2.142624 0.000000
w/ head mask:
en zh
en 0.000000 0.99274
zh 0.989524 0.00000
不使用指令,随机Mask
100%|██████████| 2/2 [00:13<00:00, 6.60s/it]
w/o head mask:
en zh
en 0.000000 2.643437
zh 2.142624 0.000000
w/ head mask:
en zh
en 0.000000 3.069456
zh 2.092448 0.000000
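Note that the tables above report the mean token-level cross-entropy loss returned by model(...).loss; if a conventional perplexity is wanted, it is the exponential of that value. For example, for en -> zh with the mask:

import numpy as np

# Illustration only: converting the reported mean loss into a conventional perplexity.
mean_loss = 1.004982
print(np.exp(mean_loss))   # ~2.73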
Evaluate the Chinese-English translation tasks, using ROUGE-L as the metric:
sentence_tokenizer = AutoTokenizer.from_pretrained("../model/bge-reranker-v2-m3", use_fast=True)
rouge_scorer_instance = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True, tokenizer=sentence_tokenizer)


def eval_xnli_rouge(use_instruction):
    langs_original_rouge, langs_weighted_rouge = {}, {}
    for src_lang, tgt_lang in tqdm(list(itertools.permutations(LANG_LIST, 2))[:12]):
        mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy()
        dev_dataset = dataset.map(
            lambda sample: {
                "input_str": sample[src_lang] + f"\n\nTranslate into {LANG_DICT[tgt_lang]}:" if use_instruction else sample[src_lang],
                "target_str": sample[tgt_lang]
            }
        ).select_columns(["input_str", "target_str"])
        if src_lang not in langs_original_rouge:
            langs_original_rouge[src_lang] = {src_lang: 0}
        if src_lang not in langs_weighted_rouge:
            langs_weighted_rouge[src_lang] = {src_lang: 0}
        original_rouge, weighted_rouge = [], []
        for input_str, target_str in zip(*dev_dataset[:100].values()):
            combined_str = LLAMA_TEMPLATE.format(src=input_str.strip())
            lm_inputs_src = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                original_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=None, do_sample=False)
                original_pred_str = tokenizer.decode(original_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
                original_rouge_scores = rouge_scorer_instance.score(original_pred_str, target_str)
                original_rouge_l = original_rouge_scores["rougeL"].fmeasure
                original_rouge.append(original_rouge_l)
                mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
                weighted_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=mask_tensor, do_sample=False)
                weighted_pred_str = tokenizer.decode(weighted_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
                weighted_rouge_scores = rouge_scorer_instance.score(weighted_pred_str, target_str)
                weighted_rouge_l = weighted_rouge_scores["rougeL"].fmeasure
                weighted_rouge.append(weighted_rouge_l)
        langs_original_rouge[src_lang][tgt_lang] = np.nanmean(original_rouge)
        langs_weighted_rouge[src_lang][tgt_lang] = np.nanmean(weighted_rouge)
    original_df = pd.DataFrame.from_dict(langs_original_rouge, orient="index")
    display(original_df)
    weighted_df = pd.DataFrame.from_dict(langs_weighted_rouge, orient="index")
    display(weighted_df)


print("使用指令")
eval_xnli_rouge(use_instruction=True)
print("不使用指令")
eval_xnli_rouge(use_instruction=False)
使用指令
100%|██████████| 2/2 [02:58<00:00, 89.34s/it]
en zh
en 0.000000 0.539153
zh 0.668623 0.000000
en zh
en 0.000000 0.60087
zh 0.654566 0.00000
不使用指令
100%|██████████| 2/2 [05:00<00:00, 150.44s/it]
en zh
en 0.00000 0.019667
zh 0.02056 0.000000
en zh
en 0.000000 0.603147
zh 0.655537 0.000000
Ablation of the functional attention heads:
langs_weighted_ppl = {}
lang_pairs = [("en", "zh"), ("zh", "en")]
for src_lang, tgt_lang in tqdm(lang_pairs):
    dev_dataset = dataset.map(
        lambda sample: {
            "input_str": sample[src_lang] + f"\n\nTranslate into {LANG_DICT[tgt_lang]}:",
            "target_str": sample[tgt_lang],
        }
    ).select_columns(["input_str", "target_str"])
    weighted_ppl = {v: [] for v in np.arange(0, 1.1, 0.1)}
    for input_str, target_str in tqdm(zip(*dev_dataset[:100].values()), total=100):
        combined_str = LLAMA_TEMPLATE.format(src=input_str.strip()) + f"{target_str.strip()}<|eot_id|>"
        lm_inputs = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)
        input_str_len = tokenizer(LLAMA_TEMPLATE.format(src=input_str.strip()), add_special_tokens=False, return_tensors="pt")["input_ids"].size(-1)
        labels = copy.deepcopy(lm_inputs["input_ids"]).to(DEVICE)
        labels[:, :input_str_len] = -100
        with torch.no_grad():
            for min_weight in np.arange(0, 1.1, 0.1):
                mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy().clip(min=min_weight)
                mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
                weighted_output = model(**lm_inputs, labels=labels, weight_tensor=mask_tensor)
                weighted_ppl[min_weight].append(weighted_output.loss.item())
    weighted_ppl = {k: np.nanmean(weighted_ppl[k]) for k in weighted_ppl}
    langs_weighted_ppl[f"{src_lang}_{tgt_lang}"] = weighted_ppl

df = pd.DataFrame.from_dict(langs_weighted_ppl, orient="index")
display(df)
0%| | 0/2 [00:00<?, ?it/s]
100%|██████████| 100/100 [00:38<00:00, 2.61it/s]
50%|█████ | 1/2 [00:38<00:38, 38.36s/it]
100%|██████████| 100/100 [00:38<00:00, 2.62it/s]
100%|██████████| 2/2 [01:16<00:00, 38.25s/it]
0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
en_zh 1.004982 1.002170 1.005190 1.010720 1.014762 1.018487 1.022877 1.029847 1.045129 1.063954 1.091267
zh_en 0.986755 0.986517 0.987423 0.988694 0.988407 0.989914 0.989815 0.991908 0.994180 0.999186 1.002608
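As an illustration of what the columns mean (not part of the notebook): the sweep above applies .clip(min=α) to the binary mask, so heads outside the mask are scaled by α instead of being removed entirely, and α = 1.0 gives every head weight 1, i.e. the unmodified model (matching the w/o-mask losses reported earlier).

# Illustration only: masked-out heads get weight alpha instead of 0.
binary = (weight_dict["en_zh"].sigmoid() >= 0.5).float().numpy()
for alpha in (0.0, 0.5, 1.0):
    scaled = binary.clip(min=alpha)
    print(alpha, np.unique(scaled))   # values are {alpha, 1.0}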
Plot the ablation results:
colors = sns.color_palette(n_colors=2)
plt.figure(figsize=(10, 6))
plt.ylim(0.9, 1.1)
plt.xticks(np.arange(0, 1.1, 0.1), fontsize=12)
plt.xlabel(r"Scaling factor $\alpha$", fontsize=18)
plt.ylabel(r"PPL with respective $\mathcal{M}$", fontsize=18)
maskonly_ppl = {"en_zh": 1.016755, "zh_en": 0.997446}
for i, (lang_pair, values) in enumerate(df.iterrows()):
    plt.plot(np.arange(0, 1.1, 0.1), values, linestyle='-', label=lang_pair, c=colors[i])
    plt.plot(0, maskonly_ppl[lang_pair], marker="x", c=colors[i])
plt.grid(alpha=0.3)
plt.legend(ncol=3, fontsize=12, loc="lower right")
plt.tight_layout()
plt.show()
Evaluate the Chinese-English translation tasks with BLEU on the IWSLT2017 dataset:
LANG_DICT = {"fr": "French", "de": "German", "en": "English", "zh": "Chinese"}
LANGPAIR_MIN = {"en_de": 0.3, "en_fr": 0.2, "en_zh": 0.1, "zh_en": 0.3, "fr_en": 0.2, "de_en": 0.4}
langs_original_bleu, langs_weighted_bleu = {}, {}
smoothing = SmoothingFunction().method1


def general_postprocess(text):
    answer_text = text.split("\n\n")
    if len(answer_text) > 2:
        answer_text = answer_text[1:2]
    elif len(answer_text) > 1:
        answer_text = answer_text[1:]
    answer_text = "\n\n".join(answer_text)
    no_punctuation = re.sub(r'[^\w\s]', '', answer_text)
    cleaned_text = re.sub(r'\s+', ' ', no_punctuation).strip()
    cleaned_text = " ".join(jieba.cut(cleaned_text))
    return cleaned_text


def postprocess_and_score(text, target_str):
    answer_text = text.split("\n\n")
    cleaned_texts = [" ".join(jieba.cut(re.sub(r'\s+', ' ', re.sub(r'[^\w\s]', '', t)).strip())) for t in answer_text]
    scores = [sentence_bleu([target_str], t, smoothing_function=smoothing) for t in cleaned_texts]
    return max(zip(cleaned_texts, scores), key=lambda x: x[1])


for src_lang, tgt_lang in [("en", "zh"), ("zh", "en")]:
    if tgt_lang != "en":
        dev_dataset = datasets.Dataset.from_parquet(f"dataset/iwslt2017/{src_lang}-{tgt_lang}-test.parquet")
    else:
        dev_dataset = datasets.Dataset.from_parquet(f"dataset/iwslt2017/{tgt_lang}-{src_lang}-test.parquet")
    dev_dataset = dev_dataset.map(
        lambda sample: {
            "input_str": f"{sample['translation'][src_lang]}\n\nTranslate into {tgt_lang}:",
            "target_str": sample['translation'][tgt_lang],
        }
    ).select_columns(["input_str", "target_str"])
    dev_dataset = dev_dataset.select(range(len(dataset) - 100, len(dataset)))
    mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy().clip(min=LANGPAIR_MIN[f"{src_lang}_{tgt_lang}"])
    original_bleu, weighted_bleu = [], []
    for sample in tqdm(dev_dataset):
        input_str, target_str = sample["input_str"], sample["target_str"]
        target_str = general_postprocess(target_str)
        combined_str = LLAMA_TEMPLATE.format(src=input_str.strip())
        lm_inputs_src = tokenizer(combined_str, max_length=2048, truncation=True, padding=False, add_special_tokens=False, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            original_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=None, do_sample=False)
            original_pred_str = tokenizer.decode(original_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
            original_pred_str, original_bleu4 = postprocess_and_score(original_pred_str, target_str)
            original_bleu.append(original_bleu4)
            mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
            weighted_generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=mask_tensor, do_sample=False)
            weighted_pred_str = tokenizer.decode(weighted_generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
            weighted_pred_str, weighted_bleu4 = postprocess_and_score(weighted_pred_str, target_str)
            weighted_bleu.append(weighted_bleu4)
    langs_original_bleu[f"{src_lang}_{tgt_lang}"] = np.nanmean(original_bleu)
    langs_weighted_bleu[f"{src_lang}_{tgt_lang}"] = np.nanmean(weighted_bleu)
    print(f"{src_lang}_{tgt_lang}: {np.nanmean(original_bleu)} {np.nanmean(weighted_bleu)}")

original_df = pd.DataFrame.from_dict(langs_original_bleu, orient="index")
display(original_df)
weighted_df = pd.DataFrame.from_dict(langs_weighted_bleu, orient="index")
display(weighted_df)
Loading model from cache /tmp/jieba.cache
Loading model cost 0.598 seconds.
Prefix dict has been built successfully.
100%|██████████| 100/100 [02:09<00:00, 1.30s/it]
en_zh: 0.3118341977481546 0.3301332102535084
100%|██████████| 100/100 [01:50<00:00, 1.10s/it]
zh_en: 0.5805953950107409 0.5813078377389967
original_df:
           0
en_zh  0.311834
zh_en  0.580595

weighted_df:
           0
en_zh  0.330133
zh_en  0.581308
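For reference, the scoring above reduces to: strip punctuation, segment with jieba, then call NLTK's sentence_bleu with method1 smoothing. The cell above passes the space-joined segmented strings straight to sentence_bleu; the minimal standalone sketch below splits them into token lists first, which is the more common usage, so the numbers are not directly comparable:

import re
import jieba
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smoothing = SmoothingFunction().method1

def segment(text):
    # Strip punctuation, collapse whitespace, then segment with jieba into a token list.
    text = re.sub(r'\s+', ' ', re.sub(r'[^\w\s]', '', text)).strip()
    return list(jieba.cut(text))

hyp = segment("我从来没有见过这么美丽的日落。")
ref = segment("我从未见过如此美丽的日落。")
score = sentence_bleu([ref], hyp, smoothing_function=smoothing)
print(f"BLEU-4 (smoothed): {score:.4f}")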
2.3 Analysis
Compare machine-translation outputs with and without an explicit instruction:
def translate(input_str, src_lang="en", tgt_lang="zh"):
    # Binarized mask for this language pair (no floor value here).
    mask_weight = (weight_dict[f"{src_lang}_{tgt_lang}"].sigmoid() >= 0.5).float().numpy().clip(min=0.0)
    print(mask_weight.sum())
    with torch.no_grad():
        lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str) + ""], add_special_tokens=False, return_tensors="pt").to(DEVICE)
        # No mask: the model behaves as usual.
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=None, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        print("[ORIGINAL]", pred_str)
        # Learned mask: the model translates without being given any instruction.
        mask_tensor = torch.tensor(mask_weight).unsqueeze(0).repeat(1, 1).to(model.device)
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=mask_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        print(f"[{tgt_lang}]", pred_str)
        # Randomly permuted mask as a control.
        random_tensor = torch.tensor(mask_weight)[torch.randperm(len(mask_weight))].unsqueeze(0).repeat(1, 1).to(model.device)
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=50, weight_tensor=random_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        print("[RANDOM]", pred_str)

translate("I have never seen such a beautiful sunset.")
translate("This is expected as the model has lost the majority of its attention heads.\n\nTranslate into Chinese:")
856.0
[ORIGINAL] I'm glad you're enjoying the moment! Sunsets can be truly breathtaking, with their vibrant colors and serene atmosphere. They have a way of evoking feelings of peace and wonder. What made this sunset particularly special to you? Was it the colors
[zh] 我从来没有见过如此美丽的日落。
[RANDOM] As a conversational AI, I don't have personal experiences, but I can tell you that sunsets can be truly breathtaking. The colors of the sky during a sunset are a result of a combination of atmospheric conditions, including the scattering of light,
856.0
[ORIGINAL] 这也算是预期的,因为模型已经失去了大多数注意力头。
[zh] 这预计是因为模型失去了大多数注意力头。
[RANDOM] This is expected as the model has lost the majority of its attention.
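The weight_tensor argument is consumed inside the patched modeling files (models/modeling_llama.py and friends), which is where each head's contribution actually gets scaled. As a rough, self-contained illustration of that idea only, with hypothetical names and a simplified flat layout, not the project's implementation:

import torch

def scale_heads(attn_output, head_weights, layer_idx, num_heads):
    """Schematic: scale each attention head's output by its mask weight.

    attn_output:  [batch, num_heads, seq_len, head_dim] per-head attention output
                  of one layer, before concatenation and the output projection.
    head_weights: [batch, num_layers * num_heads] flat mask (illustrative layout;
                  the real checkpoint uses 33 slots per layer).
    """
    w = head_weights[:, layer_idx * num_heads:(layer_idx + 1) * num_heads]
    # Broadcast over seq_len and head_dim; weight 0 removes the head entirely.
    return attn_output * w[:, :, None, None]

# Toy usage: 2 layers x 4 heads, mask out head 1 of layer 0.
attn_out = torch.randn(1, 4, 6, 8)
weights = torch.ones(1, 2 * 4)
weights[0, 1] = 0.0
masked = scale_heads(attn_out, weights, layer_idx=0, num_heads=4)
print(masked[0, 1].abs().sum())  # tensor(0.) -- head 1 contributes nothing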
Compute the similarity of each layer's output between settings, together with the top-5 tokens predicted at each layer:
src_lang = "en"
tgt_lang = "zh"
input_str = "I see a llama sleeping in my backyard."
input_str_inst = input_str + "\n\nTranslate into Chinese:"
target_str = "我看见一只羊驼在我的后院睡觉。"

with torch.no_grad():
    lm_inputs_src_inst = tokenizer([LLAMA_TEMPLATE.format(src=input_str_inst) + ""], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str) + ""], add_special_tokens=False, return_tensors="pt").to(DEVICE)

    # With instruction, no mask.
    original_output = model(**lm_inputs_src_inst, output_hidden_states=True, output_attentions=True)
    layer_outputs = torch.stack(original_output.hidden_states)
    layer_token_logits = model.lm_head(layer_outputs[:, 0, -1])
    original_layer_pred_tokens = tokenizer.batch_decode(layer_token_logits.argmax(-1))
    original_layer_pred_tokens_top5 = [tokenizer.batch_decode(layer_token_logits.argsort(dim=-1, descending=True)[:, k]) for k in range(5)]
    original_res = layer_outputs[1:, 0, -1] - layer_outputs[:-1, 0, -1]

    # Without instruction, with the en_zh mask.
    mask_tensor = torch.tensor((weight_dict["en_zh"].sigmoid() >= 0.5).float().numpy()).unsqueeze(0).repeat(1, 1).to(model.device)
    weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_hidden_states=True, output_attentions=True)
    weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
    weighted_layer_token_logits = model.lm_head(weighted_layer_outputs[:, 0, -1])
    weighted_layer_pred_tokens = tokenizer.batch_decode(weighted_layer_token_logits.argmax(-1))
    weighted_layer_pred_tokens_top5 = [tokenizer.batch_decode(weighted_layer_token_logits.argsort(dim=-1, descending=True)[:, k]) for k in range(5)]
    weighted_res = weighted_layer_outputs[1:, 0, -1] - weighted_layer_outputs[:-1, 0, -1]

print("hidden similarity\n", torch.cosine_similarity(layer_outputs[:, 0, -1], weighted_layer_outputs[:, 0, -1]))
print("lm_head logits similarity\n", torch.cosine_similarity(layer_token_logits, weighted_layer_token_logits))
print("residual similarity\n", torch.cosine_similarity(layer_outputs[1:, 0, -1] - layer_outputs[:-1, 0, -1], weighted_layer_outputs[1:, 0, -1] - weighted_layer_outputs[:-1, 0, -1]))

pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_rows", None)
df = pd.DataFrame({"original": original_layer_pred_tokens, "weighted": weighted_layer_pred_tokens})
df1 = pd.DataFrame(original_layer_pred_tokens_top5).transpose()
display(df1)
hidden similarity
tensor([1.0000, 0.7852, 0.7617, 0.7852, 0.7422, 0.7539, 0.7266, 0.7422, 0.6875,
0.5859, 0.5508, 0.5859, 0.5195, 0.5430, 0.4531, 0.5117, 0.5938, 0.6406,
0.7109, 0.7383, 0.7383, 0.7656, 0.7930, 0.8125, 0.8242, 0.8438, 0.8477,
0.8477, 0.8438, 0.8359, 0.8320, 0.8125, 0.7500], device='cuda:0',
dtype=torch.bfloat16)
lm_head logits similarity
tensor([1.0078, 0.7930, 0.7773, 0.8008, 0.7539, 0.8008, 0.8281, 0.8398, 0.7656,
0.6719, 0.6406, 0.7305, 0.7031, 0.5430, 0.4609, 0.5078, 0.6016, 0.6562,
0.8281, 0.8516, 0.9062, 0.9219, 0.9609, 0.9688, 0.9727, 0.9766, 0.9766,
0.9805, 0.9766, 0.9766, 0.9727, 0.9258, 0.7188], device='cuda:0',
dtype=torch.bfloat16)
residual similarity
tensor([0.7383, 0.6719, 0.6719, 0.5430, 0.6094, 0.5234, 0.5234, 0.4316, 0.2520,
0.2656, 0.2871, 0.2852, 0.2314, 0.2969, 0.3125, 0.3984, 0.4570, 0.6133,
0.5977, 0.6133, 0.7031, 0.7578, 0.7188, 0.7578, 0.8086, 0.7305, 0.7109,
0.7617, 0.7461, 0.7070, 0.7969, 0.7773], device='cuda:0',
dtype=torch.bfloat16)
0 1 2 3 4
0 illo abil incer otron câ
1 cheng the čin utas Alam
2 'gc alink utas .netbeans
3 'gc .netbeans -toggler cheng
4 'gc .netbeans reff .CR edn
5 'gc -toggler шиб ATAB \Dependency
6 #ab 'gc #ac kke ーニ
7 -*-\r\n #ad #ab .netbeans 'gc
8 LLU -LAST -*-\r\n 'gc #ab
9 emain 'gc #af ektor chalk
10 poil .SIG GetInt emain >tag
11 'gc >tag .SIG цес poil
12 ruz 감 корист ncy >tag
13 'gc корист pNet dime Sharper
14 'gc pNet .Reporting Sharper isci
15 -wsj � ContentLoaded меть pNet
16 -wsj >tag :uint lap hon
17 RetVal hon )frame -wsj Macy
18 hon artz increasingly confidently RetVal
19 confidently bes increasingly WSTR hon
20 bes hon increasingly (
21 hon bes in increasingly p
22 I my in c p
23 I my you c me
24 I you my in ..\n
25 my I ..\n in you
26 my ..\n in you I
27 my I ..\n in you
28 in my you ..\n …\n
29 …\n ..\n in /stdc backyard
30 …\n /stdc in ..\n _exempt
31 backyard _exempt 我 in ..\n
32 我 你 您 在 有
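The per-layer predictions above are a logit-lens style readout: every layer's hidden state at the last position is pushed through the final lm_head. Packaged as a small reusable helper, a sketch that assumes the same model/tokenizer objects as the cells above and that the patched forward accepts weight_tensor:

import torch

def layerwise_top_tokens(model, tokenizer, inputs, weight_tensor=None, k=5):
    """Decode the top-k tokens predicted from every layer's last hidden state."""
    with torch.no_grad():
        out = model(**inputs, weight_tensor=weight_tensor, output_hidden_states=True)
        hidden = torch.stack(out.hidden_states)        # [num_layers + 1, batch, seq, dim]
        logits = model.lm_head(hidden[:, 0, -1])       # [num_layers + 1, vocab]
        topk_ids = logits.argsort(dim=-1, descending=True)[:, :k]
    return [tokenizer.batch_decode(topk_ids[:, j]) for j in range(k)]

# e.g. layerwise_top_tokens(model, tokenizer, lm_inputs_src, weight_tensor=mask_tensor)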
Compute and plot the average similarity between each layer's outputs across the four settings:
with torch.no_grad():
    lm_inputs_src_inst = tokenizer([LLAMA_TEMPLATE.format(src=input_str_inst) + target_str], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str) + target_str], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    mask_tensor = torch.tensor((weight_dict["en_zh"].sigmoid() >= 0.5).float().numpy()).unsqueeze(0).repeat(1, 1).to(model.device)

    # Four settings: {with/without instruction} x {with/without mask}.
    original_output = model(**lm_inputs_src, output_hidden_states=True, output_attentions=True)
    layer_outputs = torch.stack(original_output.hidden_states)
    layer_token_logits = model.lm_head(layer_outputs)
    layer_pred_tokens = tokenizer.batch_decode(layer_token_logits[:, 0, -15].argmax(-1))

    original_inst_output = model(**lm_inputs_src_inst, output_hidden_states=True, output_attentions=True)
    inst_layer_outputs = torch.stack(original_inst_output.hidden_states)
    inst_layer_token_logits = model.lm_head(inst_layer_outputs)
    inst_layer_pred_tokens = tokenizer.batch_decode(inst_layer_token_logits[:, 0, -15].argmax(-1))

    weighted_output = model(**lm_inputs_src, weight_tensor=mask_tensor, output_hidden_states=True, output_attentions=True)
    weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
    weighted_layer_token_logits = model.lm_head(weighted_layer_outputs)
    weighted_layer_pred_tokens = tokenizer.batch_decode(weighted_layer_token_logits[:, 0, -15].argmax(-1))

    weighted_inst_output = model(**lm_inputs_src_inst, weight_tensor=mask_tensor, output_hidden_states=True, output_attentions=True)
    weighted_inst_layer_outputs = torch.stack(weighted_inst_output.hidden_states)
    weighted_inst_layer_token_logits = model.lm_head(weighted_inst_layer_outputs)
    weighted_inst_layer_pred_tokens = tokenizer.batch_decode(weighted_inst_layer_token_logits[:, 0, -15].argmax(-1))

    cosine_sim1 = torch.cosine_similarity(inst_layer_outputs[:, 0, -15:-14].mean(dim=-2), weighted_layer_outputs[:, 0, -15:-14].mean(dim=-2)).float().cpu().numpy()
    cosine_sim2 = torch.cosine_similarity(inst_layer_outputs[:, 0, -15:-14].mean(dim=-2), layer_outputs[:, 0, -15:-14].mean(dim=-2)).float().cpu().numpy()
    cosine_sim3 = torch.cosine_similarity(weighted_layer_outputs[:, 0, -15:-14].mean(dim=-2), weighted_inst_layer_outputs[:, 0, -15:-14].mean(dim=-2)).float().cpu().numpy()

print("Similarity between w/ mask vs. w/ instruction\n", cosine_sim1)
print("Similarity between Original vs. w/ instruction\n", cosine_sim2)
print("Similarity between w/ mask vs. w/ mask + instruction\n", cosine_sim3)

(b1, _, b2), (r1, _, r2), (g1, _, g2) = sns.color_palette("Blues", 3), sns.color_palette("Reds", 3), sns.color_palette("Greens", 3)
plt.figure(figsize=(12, 6))
plt.xlabel("Layer No.", fontsize=18)
plt.xticks(np.arange(1, 32, 4), labels=np.arange(0, 32, 4), fontsize=12)
plt.xticks(np.arange(33), minor=True)
plt.ylabel("Cosine similarity of layer output", fontsize=18)
plt.yticks(fontsize=12)
plt.plot(cosine_sim2, label="Original vs. w/ instruction", c=b2)
plt.plot(cosine_sim1, label=r"w/ $\mathcal{M}$ vs. w/ instruction", c=r2)
plt.plot(cosine_sim3, label=r"w/ $\mathcal{M}$ vs. w/ $\mathcal{M}$ + instruction", c=g2)
texts = []
for i, (xi, yi1, yi2, yi3) in enumerate(zip(np.arange(33), cosine_sim1, cosine_sim2, cosine_sim3)):
    texts.append(plt.text(xi, yi1 - 0.05 if xi < 16 else yi1 + 0.05, weighted_layer_pred_tokens[i], fontsize=9, ha="center"))
    texts.append(plt.text(xi, yi2 + 0.05 if xi < 16 else yi2 - 0.05, layer_pred_tokens[i], fontsize=9, ha="center"))
adjust_text(texts, avoid_self=False)
for (t1, t2), xi, yi1, yi2 in zip([(texts[i], texts[i + 1]) for i in range(0, len(texts), 2)], np.arange(33), cosine_sim1, cosine_sim2):
    x_text, y_text = t1.get_position()
    plt.plot([xi, x_text], [yi1, y_text + 0.02 if xi < 16 else yi1 + 0.04], color=r2, linestyle="-", linewidth=0.5)
    x_text, y_text = t2.get_position()
    plt.plot([xi, x_text], [yi2, y_text - 0.02 if xi < 16 else yi2 - 0.02], color=b2, linestyle="-", linewidth=0.5)
plt.ylim(-0.05, 1.1)
plt.legend(fontsize=15)
plt.tight_layout()
plt.show()
Similarity between w/ mask vs. w/ instruction
[1. 0.78515625 0.76171875 0.78515625 0.7421875 0.75390625
0.7265625 0.7421875 0.6875 0.5859375 0.55078125 0.5859375
0.51953125 0.54296875 0.453125 0.51171875 0.59375 0.640625
0.7109375 0.73828125 0.73828125 0.765625 0.79296875 0.8125
0.82421875 0.84375 0.84765625 0.84765625 0.84375 0.8359375
0.83203125 0.8125 0.75 ]
Similarity between Original vs. w/ instruction
[1. 0.96484375 0.90234375 0.90234375 0.8359375 0.81640625
0.76171875 0.75 0.72265625 0.6328125 0.57421875 0.60546875
0.58203125 0.6953125 0.66015625 0.625 0.59765625 0.53515625
0.4765625 0.44921875 0.42578125 0.39257812 0.38867188 0.36914062
0.36914062 0.40234375 0.41601562 0.42773438 0.42578125 0.44726562
0.44921875 0.3984375 0.26367188]
Similarity between w/ mask vs. w/ mask + instruction
[1. 0.984375 0.95703125 0.9296875 0.90234375 0.84375
0.8125 0.80078125 0.7734375 0.74609375 0.640625 0.66015625
0.640625 0.640625 0.5703125 0.58984375 0.62890625 0.68359375
0.74609375 0.7578125 0.765625 0.78515625 0.81640625 0.84765625
0.86328125 0.87890625 0.87890625 0.8828125 0.88671875 0.88671875
0.88671875 0.87890625 0.84375 ]
Compute the similarity of each layer's output as the functional attention heads are removed one by one:
with torch.no_grad():
    lm_inputs_src_inst = tokenizer([LLAMA_TEMPLATE.format(src=input_str_inst)], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)
    # Reference: instruction prompt, no mask.
    original_output = model(**lm_inputs_src_inst, output_hidden_states=True, output_attentions=True)
    original_layer_outputs = torch.stack(original_output.hidden_states)
    original_layer_token_logits = model.lm_head(original_layer_outputs[:, 0, -1]).float().cpu()

mask_weight = weight_dict["en_zh"].sigmoid().numpy()
mask_weight[32::33] = 0            # zero the extra 33rd slot of each layer block
order = mask_weight.argsort()[32:]  # head indices sorted by ascending weight, extra slots dropped
full_weight = np.ones_like(weight_dict["en_zh"].sigmoid().numpy())
test_tensor = torch.tensor(full_weight).unsqueeze(0).repeat(1, 1).to(model.device)
weighted_outputs, weighted_logits = [], []
test_order = order[:320]

with torch.no_grad():
    weighted_output = model(**lm_inputs_src, weight_tensor=test_tensor, output_hidden_states=True, output_attentions=True)
    weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
    weighted_layer_token_logits = model.lm_head(weighted_layer_outputs[:, 0, -1])
    weighted_outputs.append(weighted_layer_outputs[:, 0, -1].float().cpu())
    weighted_logits.append(weighted_layer_token_logits.float().cpu())

for head_idx in tqdm(test_order):
    test_tensor[0, head_idx] = 0
    with torch.no_grad():
        weighted_output = model(**lm_inputs_src, weight_tensor=test_tensor, output_hidden_states=True, output_attentions=True)
        weighted_layer_outputs = torch.stack(weighted_output.hidden_states)
        weighted_layer_token_logits = model.lm_head(weighted_layer_outputs[:, 0, -1])
        weighted_outputs.append(weighted_layer_outputs[:, 0, -1].float().cpu())
        weighted_logits.append(weighted_layer_token_logits.float().cpu())

weighted_outputs = torch.stack(weighted_outputs)
weighted_logits = torch.stack(weighted_logits)

layers = np.arange(32)
colors = sns.color_palette("Blues", n_colors=36)[-32:]
plt.figure(figsize=(12, 6))
plt.xlabel("Removed heads", fontsize=18)
plt.xticks(np.arange(0, 1025, 64), labels=np.arange(0, 1025, 64), fontsize=12)
plt.ylabel("Cosine similarity of FFN layer output", fontsize=18)
plt.yticks(fontsize=12)
plt.ylim(-0.05, 1.05)
layer_cosine_sim = [torch.cosine_similarity(original_layer_outputs[l + 1, 0, -1].squeeze(0).float().cpu(), weighted_outputs[:, l + 1]).float().cpu().numpy() for l in range(len(layers))]
for l in layers:
    plt.plot(layer_cosine_sim[l], c=colors[l], label=f"Layer {l}")
norm = mpl.colors.Normalize(vmin=0, vmax=32)
sm = mpl.cm.ScalarMappable(cmap=mpl.colors.ListedColormap(colors), norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=plt.gca(), orientation="vertical", pad=0.02)
cbar.set_label("Layer No.", fontsize=12)
cbar.set_ticks(np.arange(0, 32) + 0.5)
cbar.set_ticklabels(np.arange(0, 32), fontsize=12)
cbar.ax.tick_params(labelsize=12)
plt.tight_layout()
plt.show()

plt.figure(figsize=(12, 6))
plt.xlabel("# of removed heads", fontsize=12)
plt.ylabel("Similarity", fontsize=12)
for l in layers:
    plt.plot(torch.cosine_similarity(original_layer_token_logits[l + 1].squeeze(0).float().cpu(), weighted_logits[:, l + 1]).float().cpu().numpy(), c=colors[l], label=f"Layer {l}")
plt.ylim(-0.05, 1.05)
plt.title("Cosine similarities: Layer lm_head logits")
plt.tight_layout()
plt.show()
100%|██████████| 320/320 [01:15<00:00, 4.24it/s]
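The removal order indexes the flat 1056-entry mask. Judging from the `view(32, 33)[:, :-1]` reshape used in the last cell of this section, each layer occupies 33 consecutive slots of which the first 32 are attention heads. A small helper for translating between a flat index and a (layer, head) pair, under that assumed layout:

SLOTS_PER_LAYER = 33  # 32 heads + 1 extra slot per layer (assumed layout)

def flat_to_layer_head(idx):
    """Map a flat mask index to (layer, head); head == 32 is the extra non-head slot."""
    return idx // SLOTS_PER_LAYER, idx % SLOTS_PER_LAYER

def layer_head_to_flat(layer, head):
    return layer * SLOTS_PER_LAYER + head

print(flat_to_layer_head(0))        # (0, 0)
print(flat_to_layer_head(34))       # (1, 1)
print(layer_head_to_flat(31, 31))   # 1054, the last real head slot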
Track how the logit value and rank of a specific token change as attention heads are progressively removed:
lm_inputs_src = tokenizer([LLAMA_TEMPLATE.format(src=input_str)], add_special_tokens=False, return_tensors="pt").to(DEVICE)
original_output_logits = model(**lm_inputs_src).logits[0, -1].detach().cpu().numpy()
topk_token_ids = original_output_logits.argsort()[::-1][:5]
topk_tokens = [tokenizer.decode(token_id, skip_special_tokens=True) for token_id in topk_token_ids]
query_token_id = tokenizer.encode("我", add_special_tokens=False)[0]

pred_strs = []
pred_logits, pred_rank = [], []
test_order = order  # full head ranking from the previous cell; test_tensor is also reused

with torch.no_grad():
    generate_ids = model.generate(**lm_inputs_src, max_new_tokens=20, weight_tensor=test_tensor, do_sample=False)
    pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
    pred_strs.append(pred_str)
    pd.DataFrame(pred_strs).to_csv("all.csv", encoding="utf-8")
    weighted_output_logits = model(**lm_inputs_src, weight_tensor=test_tensor).logits[0, -1].cpu().numpy()
    pred_logits.append(weighted_output_logits[topk_token_ids].tolist() + [weighted_output_logits[query_token_id]])
    pred_rank.append(weighted_output_logits.argsort()[::-1].argsort()[topk_token_ids].tolist() + [weighted_output_logits.argsort()[::-1].argsort()[query_token_id]])

for head_idx in tqdm(test_order):
    test_tensor[0, head_idx] = 0
    with torch.no_grad():
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=20, weight_tensor=test_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        pred_strs.append(pred_str)
        pd.DataFrame(pred_strs).to_csv("remove.csv", encoding="utf-8")
        weighted_output_logits = model(**lm_inputs_src, weight_tensor=test_tensor).logits[0, -1].cpu().numpy()
        pred_logits.append(weighted_output_logits[topk_token_ids].tolist() + [weighted_output_logits[query_token_id]])
        pred_rank.append(weighted_output_logits.argsort()[::-1].argsort()[topk_token_ids].tolist() + [weighted_output_logits.argsort()[::-1].argsort()[query_token_id]])

pred_logits = np.array(pred_logits)
pred_rank = np.array(pred_rank)
100%|██████████| 1024/1024 [18:31<00:00, 1.08s/it]
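The double argsort used above is the standard NumPy idiom for turning logits into ranks: `logits.argsort()[::-1]` lists token ids from highest to lowest logit, and the second `.argsort()` inverts that permutation so that position i holds the rank of token i (0 = highest logit). A tiny check:

import numpy as np

logits = np.array([0.1, 2.5, -1.0, 0.7])
ranks = logits.argsort()[::-1].argsort()
print(ranks)     # [2 0 3 1] -> token 1 has the largest logit (rank 0)
print(ranks[1])  # 0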
Run the ablation again with the removal order re-ranked:
# How much the rank of "我" improved after removing each head (larger = bigger jump towards rank 0).
rank_diff = -(pred_rank[1:, -1] - pred_rank[:-1, -1])
rank_diff_order = rank_diff[:320].argsort()[::-1]
reorder_strs = []
reorder_rank = []
rerank_test_order = test_order[rank_diff_order]
for head_idx in tqdm(rerank_test_order):
    test_tensor[0, head_idx] = 0
    with torch.no_grad():
        generate_ids = model.generate(**lm_inputs_src, max_new_tokens=20, weight_tensor=test_tensor, do_sample=False)
        pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
        reorder_strs.append(pred_str)
        pd.DataFrame(reorder_strs).to_csv("test.csv", encoding="utf-8")
        weighted_output_logits = model(**lm_inputs_src, weight_tensor=test_tensor).logits[0, -1].cpu().numpy()
        reorder_rank.append(weighted_output_logits.argsort()[::-1].argsort()[topk_token_ids].tolist() + [weighted_output_logits.argsort()[::-1].argsort()[query_token_id]])
reorder_rank = np.array(reorder_rank)
100%|██████████| 320/320 [06:43<00:00, 1.26s/it]
Visualize the results:
colors = sns.color_palette("Blues", 5)[::-1]
red = sns.color_palette("Reds", 1)[-1]

# Logit trajectories for the original top-5 tokens and for "我".
fig, ax = plt.subplots(figsize=(14, 7))
for i in range(6):
    if i < 5:
        ax.plot(pred_logits[:, i], label=topk_tokens[i], color=colors[i])
    else:
        ax.plot(pred_logits[:, i], label="我", color=red)
ax.set_xlabel("Remove heads")
ax.set_ylabel("Logits")
ax.set_xticks(range(0, 1025, 32))
plt.legend()
plt.show()

# Rank trajectories; the y-axis is inverted so the best rank (0) sits at the top.
fig, ax = plt.subplots(figsize=(14, 7))
for i in [4, 3, 2, 1, 0, 5]:
    if i < 5:
        ax.plot(pred_rank[:, i], label=topk_tokens[i], color=colors[i])
    else:
        ax.plot(pred_rank[:, i], label="我", color=red)
ax.set_xlabel("Removed heads", fontsize=18)
ax.set_ylabel("Token rank", fontsize=18)
ax.set_xticks(np.arange(0, 1025, 64), labels=np.arange(0, 1025, 64), fontsize=12)
ax.set_xticks(np.arange(32, 1025, 64), minor=True)
ax.set_yticks(np.arange(0, 4001, 500), labels=np.arange(0, 4001, 500), fontsize=12)
ax.set_ylim(-100, 4000)
top_ax = ax.twiny()
top_ax.set_xlim(ax.get_xlim())
top_ax.xaxis.tick_top()
top_ax.tick_params(direction="in", which="both")
top_ax.set_xticks(np.arange(0, 1025, 64), labels=[])
top_ax.set_xticks(np.arange(32, 1025, 64), minor=True)
plt.gca().invert_yaxis()
plt.legend([ax.lines[i] for i in [4, 3, 2, 1, 0, 5]], [topk_tokens[i] for i in range(5)] + ["我"], fontsize=12)
plt.show()
Inspect the checkpoint results:
step_weight_dict = {}
for root, dirs, files in os.walk("output/llama/xnli/en_zh"):
    for file in files:
        if file.endswith(".safetensors"):
            checkpoint = int(os.path.basename(root).split("-")[-1])
            path = os.path.join(root, file)
            step_weight_dict[checkpoint] = safetensors.torch.load_file(path)["tensor"]
step_weight_dict = dict(sorted(step_weight_dict.items()))
len(step_weight_dict)
200
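Before plotting the trajectories, a quick sanity check (a sketch reusing step_weight_dict from the cell above) is to count how many of the 1024 heads are still above the 0.5 threshold at each saved step:

import pandas as pd

active_heads = {
    step: int((t.view(32, 33)[:, :-1].flatten().sigmoid() >= 0.5).sum())
    for step, t in step_weight_dict.items()
}
active_df = pd.Series(active_heads, name="active_heads").to_frame()
display(active_df.head())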
Visualize how the attention-head activations evolve during training:
(b1, b2), (r1, r2) = sns.color_palette("Blues", 2), sns.color_palette("Reds", 2)
# One row per checkpoint: sigmoid weight of each of the 32 x 32 = 1024 heads.
step_weight_array = torch.stack([t.view(32, 33)[:, :-1].flatten().sigmoid() for t in step_weight_dict.values()]).numpy()
plt.figure(figsize=(12, 6))
plt.axhline(y=0.5, color="gray", linestyle="--")
for i in range(1024):
    # Faint lines for heads whose trajectory is uninteresting; solid lines otherwise.
    if (step_weight_array[:100, i] < 0.5).all() or (step_weight_array[-1, i] >= 0.5 and (step_weight_array[:60, i] >= 0.5).any()) or (0.3 < step_weight_array[-1, i] <= 0.7):
        c = b1 if step_weight_array[-1, i] >= 0.5 else r1
        plt.plot(range(101), step_weight_array[:101, i], linewidth=0.5, alpha=0.1, color=c)
    else:
        c = b2 if step_weight_array[-1, i] >= 0.5 else r2
        plt.plot(range(101), step_weight_array[:101, i], linewidth=0.5, alpha=1, color=c)
plt.xlabel("Training progress", fontsize=18)
plt.xticks(np.arange(0, 101, 20), labels=["0%", "10%", "20%", "30%", "40%", "50%"], fontsize=12)
plt.ylabel("Head weight (sigmoid)", fontsize=18)
plt.ylim(-0.02, 1.02)
plt.tight_layout()
plt.show()