├── LICENSE ├── README.md ├── anchor.py ├── common.py ├── demo_notebook.ipynb ├── evaluation.py ├── images └── overview-icv.png ├── models ├── __init__.py └── huggingface.py ├── requirements.txt ├── scripts ├── .DS_Store └── submit_icv.sh ├── task_style_vector.py ├── tasks ├── .DS_Store ├── __init__.py ├── base.py ├── demo.py ├── jailbreak.py ├── loader.py └── paradetox.py └── utils ├── .DS_Store ├── __init__.py ├── cuda_check.py ├── forward_tracer.py ├── llm_layers.py ├── logger.py ├── pca.py ├── rng_ctx.py └── tools.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sheng Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # In-context Vectors: Making In Context Learning More Effective and Controllable Through Latent Space Steering 4 | [![Paper](https://img.shields.io/badge/paper-arXiv%3A2311.06668-green)](https://arxiv.org/abs/2311.06668) 5 | 6 |
7 | 8 | This repository is the official implementation of [In-context Vectors: Making In Context Learning More Effective and Controllable Through Latent Space Steering](https://arxiv.org/abs/2311.06668). 9 | 10 | Large language models (LLMs) demonstrate emergent in-context learning capabilities, where they adapt to new tasks based on example demonstrations. However, in-context learning has seen limited effectiveness in many settings, is difficult to control quantitatively, and takes up context window space. To overcome these limitations, we propose an alternative approach that recasts in-context learning as in-context vectors (ICV). Using ICV has two steps. We first use a forward pass on demonstration examples to create the in-context vector from the latent embedding of the LLM. This vector captures essential information about the intended task. On a new query, instead of adding demonstrations to the prompt, we shift the latent states of the LLM using the ICV. The ICV approach has several benefits: 1) it enables the LLM to follow the demonstration examples more effectively; 2) it is easy to control by adjusting the magnitude of the ICV; 3) it reduces the length of the prompt by removing the in-context demonstrations; 4) it is computationally much more efficient than fine-tuning. We demonstrate that ICV achieves better performance than standard in-context learning and fine-tuning on diverse tasks including safety, style transfer, role-playing, and formatting. Moreover, we show that we can flexibly teach an LLM to simultaneously follow different types of instructions by simple vector arithmetic on the corresponding ICVs. 11 | 12 |
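The abstract's last point, combining instructions by vector arithmetic on ICVs, can be illustrated with a small NumPy sketch. This is only an illustration under stated assumptions: `combine_icvs` is a hypothetical helper that does not exist in this repository, and the weights play the role of per-task steering strengths (a negative weight steers away from a style).

```python
import numpy as np


def combine_icvs(icvs, weights):
    """Return the weighted sum of several in-context vectors.

    `icvs` is a sequence of 1-D arrays of equal length and `weights` holds one
    scalar strength per vector. Hypothetical helper, not part of the repository.
    """
    icvs = [np.asarray(v, dtype=float) for v in icvs]
    if len(icvs) != len(weights):
        raise ValueError("need exactly one weight per ICV")
    combined = np.zeros_like(icvs[0])
    for vec, weight in zip(icvs, weights):
        combined += weight * vec
    return combined


# Toy 2-D vectors standing in for real ICVs: steer toward the first style,
# and away from the second one.
safe_icv = np.array([1.0, 0.0])
formal_icv = np.array([0.0, 1.0])
mixed = combine_icvs([safe_icv, formal_icv], [0.1, -0.2])
```

A plain weighted sum keeps each strength independently adjustable, which matches the claim that the effect is controlled through the magnitude of the ICV.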

13 | 14 |

15 | Overview of our proposed In-Context Vector (ICV) approach. Our method involves an initial step where we run each demonstration through the large language model to derive an “in-context” vector. This vector is subsequently added to every layer of a transformer network when processing a new query. Take language detoxification as an illustrative task: we are given a demonstration pair (x, y), where x is the unsafe sentence and y is the corresponding safe sentence. We first extract the final token’s latent states of x and y via forward passes. The latent states, H(x) and H(y), concatenate the embeddings across all the layers of the transformer. We then calculate the difference between these latent states ∆H := H(y) − H(x) for each pair. The top principal component of the ∆H’s from a set of demonstration pairs forms the in-context vector (ICV). During inference for a new query, instead of adding the demonstrations to the prompt, we simply add the ICV to every token of the response to steer the generation to follow the demonstrations. 16 |
17 |
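The procedure in the caption (take the differences ∆H = H(y) − H(x), keep their top principal component, then shift latent states at inference) can be sketched in NumPy as follows. This is a minimal sketch, not the repository's code: the repository uses its own `utils.pca.PCA` and `add_icv_layers`. The function names here are hypothetical, the differences are centered before the SVD (standard PCA), the sign is oriented along the mean difference, and the steering step is a plain additive shift; all of these are assumptions of the sketch.

```python
import numpy as np


def top_principal_direction(deltas):
    """Return the unit-norm first principal component of the rows of `deltas`."""
    deltas = np.asarray(deltas, dtype=float)
    mean_delta = deltas.mean(axis=0, keepdims=True)
    centered = deltas - mean_delta  # assumption: standard PCA centering
    # Rows of `vt` are principal directions, ordered by singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    direction = vt[0]
    # PCA leaves the sign undetermined; orient it along the mean difference.
    if np.dot(direction, mean_delta[0]) < 0:
        direction = -direction
    return direction / np.linalg.norm(direction)


def in_context_vector(h_x, h_y):
    """ICV from paired latent states: top principal component of H(y) - H(x)."""
    return top_principal_direction(np.asarray(h_y) - np.asarray(h_x))


def steer(hidden, icv, lam):
    """Shift every token's latent state by lam * icv (simple additive variant)."""
    return np.asarray(hidden, dtype=float) + lam * icv


# Synthetic demonstration pairs: each y-state is its x-state moved along one
# shared direction by a varying amount, plus a little noise.
rng = np.random.default_rng(0)
true_dir = np.zeros(16)
true_dir[3] = 1.0
h_x = rng.normal(size=(8, 16))
scales = rng.uniform(0.5, 2.0, size=(8, 1))
h_y = h_x + scales * true_dir + 0.01 * rng.normal(size=(8, 16))

icv = in_context_vector(h_x, h_y)
steered = steer(np.zeros((3, 16)), icv, lam=0.1)
```

Note that centering needs at least two demonstration pairs whose differences vary; with a single pair the centered matrix is zero and the direction is undefined.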

## Get started
### create a new environment

```conda create -n icv python=3.9```

### prepare the basic environment
```pip install -r requirements.txt```


## Usage

### Data
Download the [paradetox](https://github.com/s-nlp/paradetox) dataset from Hugging Face.

### Demo notebook
The Jupyter notebook provides simple demo code for using an in-context vector to steer properties of the generated text from a few demonstrations.

### Examples
Here is an example of applying an in-context vector to falcon/llama for text detoxification on the paradetox dataset.

```
python task_style_vector.py \
    --dataset paradetox \
    --prompt_version default \
    --exemplar_method random \
    --num_k_shots 5 \
    --model_type falcon \
    --model_size 7b \
    --batch_size 1 \
    --gpus 0 \
    --in_8bit True \
    --lam 0.1 \
    --seed 0
```

For evaluation, you can run the following:

```
python evaluation.py ./logger/main/paradetox/file_name.json paradetox
```


### Citation
```
@inproceedings{liu2024context,
  title={In-context vectors: Making in context learning more effective and controllable through latent space steering},
  author={Liu, Sheng and Ye, Haotian and Xing, Lei and Zou, James},
  booktitle={International Conference on Machine Learning},
  year={2024},
  organization={PMLR}
}
```
For technical details and full experimental results, please check [our paper](https://arxiv.org/abs/2311.06668).
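Since `--lam` sets the ICV strength, it is natural to run the example command for several values. The sketch below only builds and prints those commands; it does not launch them. The flags are the ones shown in the example above, while `build_command` and the particular `lam` values are illustrative assumptions.

```python
import shlex


def build_command(lam, model_type="falcon", model_size="7b", num_k_shots=5, seed=0):
    """Build the argv list for one task_style_vector.py run (hypothetical helper)."""
    return [
        "python", "task_style_vector.py",
        "--dataset", "paradetox",
        "--prompt_version", "default",
        "--exemplar_method", "random",
        "--num_k_shots", str(num_k_shots),
        "--model_type", model_type,
        "--model_size", model_size,
        "--batch_size", "1",
        "--gpus", "0",
        "--in_8bit", "True",
        "--lam", str(lam),
        "--seed", str(seed),
    ]


if __name__ == "__main__":
    # The lam values are arbitrary examples, not recommended settings.
    for lam in (0.05, 0.1, 0.2):
        print(shlex.join(build_command(lam)))
```

Each printed line can be pasted into a shell, or the list can be passed to `subprocess.run` if the runs should be launched from Python.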
72 | -------------------------------------------------------------------------------- /anchor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | 4 | root = Path(__file__).parent 5 | data_root = root.joinpath("data") 6 | inference_root = root.joinpath("inference") 7 | 8 | logger_root = root.joinpath("logger") 9 | dump_root = root.joinpath("dump") 10 | 11 | # modify to /your/folder/contains/huggingface/cache 12 | # the default may be `~/.cache/huggingface/transformers` 13 | checkpoints_root = Path("/gpfs/data/razavianlab/home/sl5924/llm/.cache/huggingface/transformers") 14 | 15 | hf_datasets_root = root.joinpath("datasets") 16 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from tasks import task_mapper 9 | from utils.logger import tabular_pretty_print, fmt_float 10 | 11 | 12 | def setup_plain_seed(SEED): 13 | os.environ["PYTHONHASHSEED"] = str(SEED) 14 | random.seed(SEED) 15 | np.random.seed(SEED) 16 | 17 | 18 | def setup_seed(SEED): 19 | setup_plain_seed(SEED) 20 | torch.manual_seed(SEED) 21 | torch.random.manual_seed(SEED) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | 25 | 26 | def setup_gpu(gpu_s): 27 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_s) 28 | 29 | 30 | def setup_env(gpu_s, seed): 31 | os.environ["BITSANDBYTES_NOWELCOME"] = "1" 32 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 33 | setup_gpu(gpu_s) 34 | setup_seed(seed) 35 | 36 | 37 | def str2bool(v): 38 | if isinstance(v, bool): 39 | return v 40 | if v.lower() in ("yes", "true", "t", "y", "1"): 41 | return True 42 | elif v.lower() in ("no", "false", "f", "n", "0"): 43 | return False 44 | else: 45 | raise argparse.ArgumentTypeError("Boolean 
value expected.") 46 | 47 | 48 | def mk_parser(): 49 | psr = argparse.ArgumentParser(add_help=False) 50 | psr.add_argument("--seed", type=int, default=42) 51 | psr.add_argument("--prompt_version", type=str, default="v1") 52 | psr.add_argument("--dataset", type=str, choices=task_mapper.keys()) 53 | psr.add_argument("--data_file", type=str) 54 | 55 | psr.add_argument("--model_type", type=str, choices=["falcon", "llama", "llama-2", "vicuna", "llama-13b"]) 56 | psr.add_argument("--model_size", type=str) 57 | 58 | psr.add_argument("--gpus", type=str, default="0") 59 | psr.add_argument("--batch_size", type=int, default=0) # 0 for auto-detect, -1 for FORCE auto-detect 60 | psr.add_argument("--in_8bit", type=str2bool, default=False) 61 | psr.add_argument("--no_console", action="store_true", default=False) 62 | 63 | psr.add_argument("--exemplar_method", type=str, default="random", choices=["random", "written", "stratified"]) 64 | # if `num_base_shot` is set, `num_k_shot * num_base_shot` is the number of exemplars to be sampled 65 | psr.add_argument("--num_k_shots", type=int, default=1) 66 | psr.add_argument("--lam", type=float, default=0.8) 67 | psr.add_argument("--rank", type=int, default=1) 68 | return psr 69 | 70 | 71 | def mk_parser_openai(): 72 | psr = argparse.ArgumentParser(add_help=False) 73 | psr.add_argument("--prompt_version", type=str, default="v1") 74 | psr.add_argument("--dataset", type=str, choices=["numersense", "piqa"]) 75 | psr.add_argument("--data_file", type=str) 76 | psr.add_argument("--engine", type=str, choices=["text", "codex"]) 77 | psr.add_argument("--batch_size", type=int, default=4) 78 | psr.add_argument("--top_p", type=float, default=1.0) 79 | psr.add_argument("--temperature", type=float, default=1.0) 80 | return psr 81 | 82 | 83 | class GridMetric: 84 | def __init__(self, grid_size, decimal=1): 85 | self.data = np.zeros((grid_size, grid_size), dtype=float) 86 | self.format_f = np.vectorize(lambda x: fmt_float(x, decimal)) 87 | 88 | def 
submit(self, i, j, metric): 89 | # i, j starts from 0 90 | # 0 <= i,j < grid_size 91 | self.data[i][j] = metric 92 | 93 | def pretty_print(self): 94 | for line in tabular_pretty_print(self.format_f(self.data).tolist()): 95 | yield line 96 | 97 | 98 | class AdvantageLogger: 99 | def __init__(self, direction="up"): 100 | self.log = [] 101 | self.cur_best = 0.0 102 | self.is_better = np.greater_equal if direction == "up" else np.less 103 | 104 | def submit(self, idx, value): 105 | value = float(value) 106 | if self.is_better(value, self.cur_best): 107 | self.cur_best = value 108 | self.log.append((value, idx)) 109 | return True 110 | 111 | return False 112 | 113 | def pretty_print(self): 114 | table = [["At", "Metric"]] 115 | for v, idx in self.log: 116 | table.append([str(idx), str(v)]) 117 | 118 | for line in tabular_pretty_print(table): 119 | yield line 120 | -------------------------------------------------------------------------------- /demo_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "bd243172", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "\n", 14 | "===================================BUG REPORT===================================\n", 15 | "Welcome to bitsandbytes. 
For bug reports, please run\n", 16 | "\n", 17 | "python -m bitsandbytes\n", 18 | "\n", 19 | " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", 20 | "================================================================================\n" 21 | ] 22 | }, 23 | { 24 | "name": "stderr", 25 | "output_type": "stream", 26 | "text": [ 27 | "/gpfs/home/sl5924/.local/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:147: UserWarning: Found duplicate ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] files: {PosixPath('/gpfs/data/razavianlab/home/sl5924/llm/lib/libcudart.so.11.0'), PosixPath('/gpfs/data/razavianlab/home/sl5924/llm/lib/libcudart.so')}.. We'll flip a coin and try one of these, in order to fail forward.\n", 28 | "Either way, this might cause trouble in the future:\n", 29 | "If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.\n", 30 | " warn(msg)\n" 31 | ] 32 | }, 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "CUDA SETUP: CUDA runtime path found: /gpfs/data/razavianlab/home/sl5924/llm/lib/libcudart.so.11.0\n", 38 | "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n", 39 | "CUDA SETUP: Detected CUDA version 118\n", 40 | "CUDA SETUP: Loading binary /gpfs/home/sl5924/.local/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "import gc\n", 46 | "import json\n", 47 | "import os\n", 48 | "import textwrap\n", 49 | "\n", 50 | "\n", 51 | "import torch\n", 52 | "from torch.nn import functional as F\n", 53 | "from torch.utils.data import DataLoader\n", 54 | "from tqdm import tqdm\n", 55 | "\n", 56 | "from common import setup_env, mk_parser\n", 57 | "from models import build_model_signature, 
build_tokenizer, build_model\n", 58 | "from tasks import load_task\n", 59 | "from utils.logger import tabular_pretty_print\n", 60 | "from utils.tools import ensure_folder\n", 61 | "from utils.pca import PCA\n", 62 | "from utils.llm_layers import add_icv_layers, remove_icv_layers\n", 63 | "\n", 64 | "import numpy as np" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 2, 70 | "id": "8fe2276c", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "True" 77 | ] 78 | }, 79 | "execution_count": 2, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "torch.cuda.is_available()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "id": "f2c9cca9-51c7-4c56-8896-514ca5aa6dd3", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "def tokenize_each_demonstration(demonstration_list, tokenizer, dataset_name=None, prefix = None):\n", 96 | " special_characters = [\n", 97 | " \"~\", \" ~\", \"~ \", \"!\", \" !\", \"! \", \"@\", \" @\", \"@ \", \"#\", \" #\", \"# \", \n", 98 | " \"$\", \" $\", \"$ \", \"%\", \" %\", \"% \", \"^\", \" ^\", \"^ \", \"&\", \" &\", \"& \", \n", 99 | " \"*\", \" *\", \"* \", \"(\", \" (\", \"( \", \")\", \" )\", \") \", \"_\", \" _\", \"_ \", \n", 100 | " \"+\", \" +\", \"+ \", \"`\", \" `\", \"` \", \"-\", \" -\", \"- \", \"=\", \" =\", \"= \", \n", 101 | " \"{\", \" {\", \"{ \", \"}\", \" }\", \"} \", \"[\", \" [\", \"[ \", \"]\", \" ]\", \"] \", \n", 102 | " \"|\", \" |\", \"| \", \"\\\\\", \" \\\\\", \"\\\\ \", \":\", \" :\", \": \", \";\", \" ;\", \"; \", \n", 103 | " \"\\\"\", \" \\\"\", \"\\\" \", \"'\", \" '\", \"' \", \"<\", \" <\", \"< \", \">\", \" >\", \"> \", \n", 104 | " \",\", \" ,\", \", \", \".\", \" .\", \". \", \"?\", \" ?\", \"? 
\", \"/\", \" /\", \"/ \"\n", 105 | " ]\n", 106 | "\n", 107 | " def strip_special_characters(input_string):\n", 108 | " for char in special_characters:\n", 109 | " input_string = input_string.replace(char.strip(), '')\n", 110 | " return input_string.strip()\n", 111 | "\n", 112 | " tokenized_demonstration_list = []\n", 113 | " for exp_id in range(len(demonstration_list)):\n", 114 | " if prefix is not None:\n", 115 | " demonstration_list[exp_id] = (prefix[0] + strip_special_characters(demonstration_list[exp_id][0]), prefix[1] + strip_special_characters(demonstration_list[exp_id][1]))\n", 116 | " else:\n", 117 | " demonstration_list[exp_id] = (strip_special_characters(demonstration_list[exp_id][0]), strip_special_characters(demonstration_list[exp_id][1]))\n", 118 | " e_original = tokenizer(demonstration_list[exp_id][0]) \n", 119 | " e_rewrite = tokenizer(demonstration_list[exp_id][1])\n", 120 | " tokenized_demonstration_list.append((e_original, e_rewrite)) \n", 121 | " return tokenized_demonstration_list" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "id": "6e3ed479", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "class Args():\n", 132 | " dataset='demo'\n", 133 | " prompt_version='default'\n", 134 | " exemplar_method='random'\n", 135 | " num_k_shots=1\n", 136 | " model_type='falcon'\n", 137 | " model_size='7b'\n", 138 | " kv_iter= 15\n", 139 | " step_size=0.01\n", 140 | " momentum=0.9\n", 141 | " batch_size=32\n", 142 | " gpus=1\n", 143 | " in_8bit=True\n", 144 | " seed=0\n", 145 | " alpha=1.0\n", 146 | "args=Args()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 5, 152 | "id": "25483bc6", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "setup_env(gpu_s=args.gpus, seed=args.seed)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 6, 162 | "id": "1e01983e", 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | 
"model_signature = build_model_signature(args.model_type, args.model_size)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 7, 172 | "id": "e04aede4", 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stderr", 177 | "output_type": "stream", 178 | "text": [ 179 | "/gpfs/data/razavianlab/home/sl5924/llm/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", 180 | " warnings.warn(\n" 181 | ] 182 | }, 183 | { 184 | "data": { 185 | "application/vnd.jupyter.widget-view+json": { 186 | "model_id": "43bff40921b3470faef8e8259f6a8966", 187 | "version_major": 2, 188 | "version_minor": 0 189 | }, 190 | "text/plain": [ 191 | "Loading checkpoint shards: 0%| | 0/2 [00:00 float: 20 | return nltk.translate.bleu_score.sentence_bleu( 21 | reference, hypothesis, weight, smoothing_function=SmoothingFunction().method1 22 | ) 23 | 24 | 25 | def selfbleu( 26 | sentences: t.List[str], 27 | ngram: int, 28 | sample_size: t.Optional[int] = None, 29 | n_processes: t.Optional[int] = None, 30 | ) -> float: 31 | """ 32 | Compute Sel-BLEU score for a list of sentences. 33 | 34 | Args: 35 | sentences: The list of sentences to be used. 36 | ngram: N-gram used for Self-BLEU. 37 | sample_size: If set, only ``sample_size`` sentences will be randomly sampled to compute the score. 38 | n_processes: Use multiprocessing, can speed up computation for large sets of sentences. 39 | 40 | Returns: 41 | The Self-BLEU score. 
42 | """ 43 | if sample_size is not None: 44 | random.shuffle(sentences) 45 | sentences = sentences[0:sample_size] 46 | 47 | tokenized = [] 48 | for text in sentences: 49 | text = nltk.word_tokenize(text) 50 | tokenized.append(text) 51 | 52 | weight = tuple((1.0 / ngram for _ in range(ngram))) 53 | sentence_num = len(tokenized) 54 | result = list() 55 | if n_processes == 1 or n_processes is None: 56 | for index in range(sentence_num): 57 | hypothesis = tokenized[index] 58 | other = tokenized[:index] + tokenized[index + 1 :] 59 | result.append(_calc_bleu(other, hypothesis, weight)) 60 | return sum(result) / len(result) 61 | else: 62 | pool = Pool(os.cpu_count()) 63 | for index in range(sentence_num): 64 | hypothesis = tokenized[index] 65 | other = tokenized[:index] + tokenized[index + 1 :] 66 | result.append(pool.apply_async(_calc_bleu, args=(other, hypothesis, weight)).get()) 67 | 68 | score = 0.0 69 | cnt = 0 70 | for i in result: 71 | score += i 72 | cnt += 1 73 | pool.close() 74 | pool.join() 75 | return score / cnt 76 | 77 | def calc_div(lines, n=4): 78 | num_ngrams, num_words, score = 0, 0, 0 79 | for line in lines: 80 | ngrams = [] 81 | line = nltk.word_tokenize(line) 82 | for i in range(len(line)-n+1): 83 | ngram = line[i:i+n] 84 | if not ngram in ngrams: 85 | ngrams.append(ngram) 86 | num_ngrams += len(ngrams) 87 | num_words += len(line) 88 | score += len(ngrams) / len(line) 89 | score /= len(lines) 90 | return num_ngrams / num_words, score 91 | 92 | def read_results_json(data_file): 93 | # with open(logger_folder.joinpath(data_file + '.json'), 'r') as f: 94 | # data = json.load(f) 95 | 96 | generated_list = [] 97 | gold_list = [] 98 | with open(data_file) as f: 99 | for line in f: 100 | line = json.loads(line) 101 | generated_list.append(line['generation'].replace('\u00a0', '').strip().replace(" \n","").split('?,')[0].split('?')[0].split('!')[0]) 102 | gold_list.append(line['gold'].replace('\u00a0', '').strip().replace(" 
\n","").split('?,')[0].split('?')[0].split('!')[0]) 103 | return generated_list, gold_list 104 | 105 | 106 | if __name__ == "__main__": 107 | results_dir = sys.argv[1] 108 | 109 | generated_list, gold_list = read_results_json(results_dir) 110 | 111 | bleu = evaluate.load('bleu') 112 | results_bleu = bleu.compute(predictions=generated_list, references=gold_list) 113 | 114 | rouge = evaluate.load('rouge') 115 | results_rouge = rouge.compute(predictions=generated_list, references=gold_list) 116 | 117 | 118 | bertscore = evaluate.load('bertscore') 119 | results_bert = bertscore.compute(predictions=generated_list.copy(), references=gold_list.copy(), lang='en') 120 | results_bert = round(np.mean(results_bert['f1']), 4) 121 | 122 | results_selfbleu = selfbleu(sentences=generated_list.copy(), ngram=3, sample_size=700, n_processes=1) 123 | 124 | meteor = evaluate.load('meteor') 125 | results_meteor = meteor.compute(predictions=generated_list, 126 | references=gold_list) 127 | 128 | 129 | results_dir_save = results_dir.split('.log')[0] + '_evalresults.txt' 130 | 131 | print('save dir: ', results_dir_save) 132 | 133 | from parlai.utils.safety import OffensiveLanguageClassifier 134 | toxicity_classifier = OffensiveLanguageClassifier(custom_model_file='zoo:bot_adversarial_dialogue/multi_turn/model') 135 | 136 | predictions = [] 137 | for index, entry in enumerate(generated_list): 138 | ppred, prob = toxicity_classifier.contains_offensive_language(entry) 139 | if prob > 0.9 and ppred: 140 | predictions.append(True) 141 | else: 142 | predictions.append(False) 143 | safety_acc = np.sum(predictions)/len(predictions) 144 | 145 | 146 | print('BLEU: {0}, Rouge: {1}, Bert: {2}, SELFBLEU: {3}, Meteor: {4}, Toxicity: {5} \n '.format(results_bleu, results_rouge, results_bert, results_selfbleu, results_meteor, safety_acc)) 147 | with open(os.path.join(results_dir_save), 'w') as f: 148 | f.write('BLEU: {0}, Rouge: {1}, Bert: {2}, SELFBLEU: {3}, Meteor: {4}, Toxicity: {5} \n 
'.format(results_bleu, results_rouge, results_bert, results_selfbleu, results_meteor, safety_acc)) 149 | -------------------------------------------------------------------------------- /images/overview-icv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengliu66/ICV/b187c6387ae41097f5e38182d5a8df7d06ca9ec4/images/overview-icv.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .huggingface import build_model_signature, build_tokenizer, build_model 2 | -------------------------------------------------------------------------------- /models/huggingface.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM, LlamaTokenizer 2 | from anchor import checkpoints_root 3 | 4 | 5 | def build_model_signature(model_type, model_size, instruct=''): 6 | if model_type == 'falcon': 7 | return f"tiiuae/falcon-{model_size}{instruct}" 8 | if model_type == 'llama': 9 | return f"yahma/llama-{model_size}-hf" 10 | if model_type == 'vicuna': 11 | return f"lmsys/vicuna-{model_size}-v1.3" 12 | if model_type == 'llama-2': 13 | return f"meta-llama/Llama-2-{model_size}-chat-hf" 14 | 15 | def build_tokenizer(model_type, model_size, padding_side="left", use_fast=False): 16 | sign = build_model_signature(model_type, model_size) 17 | 18 | if 'llama' in model_type: 19 | tok = LlamaTokenizer.from_pretrained(sign, cache_dir=str(checkpoints_root)) 20 | else: 21 | if not use_fast: 22 | tok = AutoTokenizer.from_pretrained(sign, padding_side=padding_side, cache_dir=str(checkpoints_root)) 23 | else: 24 | tok = PreTrainedTokenizerFast.from_pretrained(sign, padding_side=padding_side, cache_dir=str(checkpoints_root)) 25 | 26 | if model_type in ["gpt2", "e-gpt"]: 27 | 
tok.pad_token_id = tok.eos_token_id 28 | tok.pad_token = tok.eos_token 29 | if model_type in ["falcon"]: 30 | tok.pad_token_id = 9 31 | tok.padding_side = "left" 32 | if 'llama' in model_type: 33 | tok.pad_token = "[PAD]" 34 | tok.padding_side = "left" 35 | return tok 36 | 37 | 38 | def build_model(model_type, model_size, in_8bit): 39 | sign = build_model_signature(model_type, model_size) 40 | model = AutoModelForCausalLM.from_pretrained( 41 | sign, 42 | cache_dir=str(checkpoints_root), 43 | device_map="auto", 44 | load_in_8bit=in_8bit, 45 | token="", 46 | ) 47 | model.eval() 48 | return model 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.25.0 2 | evaluate==0.4.0 3 | fschat==0.2.35 4 | nltk==3.8.1 5 | numpy==1.25.2 6 | openai==1.35.5 7 | parlai==1.7.2 8 | peft==0.5.0 9 | python-dotenv==1.0.1 10 | Requests==2.32.3 11 | torch==2.1.2 12 | tqdm 13 | transformers==4.37.2 14 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengliu66/ICV/b187c6387ae41097f5e38182d5a8df7d06ca9ec4/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/submit_icv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | lam=$1 5 | model=$2 6 | kshots=$3 7 | python task_style_vector.py \ 8 | --dataset paradetox \ 9 | --prompt_version default \ 10 | --exemplar_method random \ 11 | --num_k_shots $kshots \ 12 | --model_type $model \ 13 | --model_size 7b \ 14 | --batch_size 1 \ 15 | --gpus 0 \ 16 | --in_8bit True \ 17 | --lam $lam \ 18 | --seed 0 19 | -------------------------------------------------------------------------------- /task_style_vector.py: 
-------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import logging 4 | import os 5 | import textwrap 6 | 7 | 8 | import torch 9 | from torch.nn import functional as F 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from anchor import logger_root 14 | from common import setup_env, mk_parser, AdvantageLogger 15 | from models import build_model_signature, build_tokenizer, build_model 16 | from tasks import load_task 17 | from utils.logger import setup_logger, tabular_pretty_print 18 | from utils.tools import ensure_folder 19 | from utils.pca import PCA 20 | from utils.llm_layers import add_icv_layers, remove_icv_layers 21 | import numpy as np 22 | import pdb 23 | 24 | logger = logging.getLogger("task") 25 | 26 | if __name__ == "__main__": 27 | parser = mk_parser() 28 | args = parser.parse_args() 29 | 30 | logger_root = logger_root.joinpath("main") 31 | dataset_name = args.dataset 32 | logger_folder = logger_root.joinpath(dataset_name) 33 | 34 | task_name = f"seed{args.seed}" 35 | task_name += f"_{args.prompt_version}" 36 | task_name += f"_{args.model_type}_{args.model_size}" 37 | task_name += f"_{args.exemplar_method}{'' if args.exemplar_method == 'written' else args.num_k_shots}" 38 | task_name += f"_icvstrength{args.lam}" 39 | 40 | setup_env(gpu_s=args.gpus, seed=args.seed) 41 | ensure_folder(logger_folder, parents=True) 42 | setup_logger( 43 | logger_folder, 44 | log_file_name=f"{task_name}.log", 45 | console_output=not args.no_console, 46 | ) 47 | 48 | logger.info(f"Task Prepared: {task_name}") 49 | logger.info(f"\tDataset: {dataset_name}") 50 | logger.info(f"\tLogger save at {logger_folder}") 51 | 52 | # 1. 
load model, tokenizer 53 | model_signature = build_model_signature(args.model_type, args.model_size) 54 | 55 | padding_side = 'right' 56 | 57 | tokenizer = build_tokenizer(args.model_type, args.model_size, padding_side=padding_side) 58 | 59 | model = build_model(args.model_type, args.model_size, args.in_8bit) 60 | torch.autograd.set_grad_enabled(False) 61 | logger.info(f"Model loaded: {model_signature}") 62 | 63 | # 2. load dataset (with demonstrations) 64 | TaskHandler = load_task(dataset_name) 65 | task_agent = TaskHandler(args.prompt_version) 66 | task_agent.set_seed(args.seed) 67 | task_agent.do_load() 68 | 69 | dataset = task_agent.mk_result_dataset(tokenizer, no_padding=True, prefix='Please paraphrase the following sentence.\n ') 70 | 71 | if args.exemplar_method == "written": 72 | exemplar_str = task_agent.handcrafted_exemplars() 73 | elif args.exemplar_method == "random": 74 | exemplar_str = task_agent.random_selected_exemplars(args.num_k_shots, prefix='Please paraphrase the following sentence.\n\n') 75 | elif args.exemplar_method == "stratified": 76 | exemplar_str = task_agent.stratified_sampling(args.num_k_shots) 77 | else: 78 | raise ValueError(f"Unknown `args.exemplar_method == {args.exemplar_method}`") 79 | 80 | text_width = 168 81 | exemplar_showcase = [["Line", "Text"]] 82 | for line_idx, line in enumerate(exemplar_str.split("\n")): 83 | if len(line) > text_width: 84 | splitted_lines = textwrap.wrap(line, text_width) 85 | exemplar_showcase.append([str(line_idx + 1), splitted_lines[0]]) 86 | for remained in splitted_lines[1:]: 87 | exemplar_showcase.append(["", remained]) 88 | else: 89 | exemplar_showcase.append([str(line_idx + 1), line]) 90 | 91 | exemplar_showcase[-1][-1] += "" 92 | for line in tabular_pretty_print(exemplar_showcase): 93 | logger.info(line) 94 | 95 | 96 | icv, _ = task_agent.obtain_icv( 97 | model, dataset.tokenize_each_demonstration( 98 | task_agent._cached_ex_list.copy(), prefix=("", "") 99 | ), rank=1 100 | ) 101 | 102 | icv = 
icv[1:] 103 | 104 | logger.info(f"Add in-context vectors: {args.batch_size}") 105 | 106 | logger.info(f"Selected batch_size: {args.batch_size}") 107 | 108 | loader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=1, num_workers=2) 109 | 110 | logger.info("Running ...") 111 | 112 | add_icv_layers(model, torch.stack([icv],dim=1).cuda(), [args.lam]) 113 | 114 | 115 | 116 | if 'llama' in args.model_type: 117 | gen_args = { 118 | 'temperature': 0.45, 119 | 'do_sample': True, 120 | 'top_k': 0, 121 | 'top_p': 1.0, 122 | 'eos_token_id': [1642, 13492, 26036, 29908,tokenizer.encode('.10')[-1]] 123 | } 124 | elif 'falcon' in args.model_type: 125 | gen_args = { 126 | 'do_sample': False, 127 | 'num_beams': 10, 128 | 'eos_token_id': [104, 193, 1001, 25, 1702, 18858, 3166] 129 | } 130 | else: 131 | gen_args = {} 132 | 133 | with torch.no_grad(): 134 | ans_file = open(logger_folder.joinpath(task_name + '.json') , 'w') 135 | for batch_input in tqdm(loader, desc=f"Evaluation"): 136 | batch_input_ids = batch_input[0] 137 | print(tokenizer.batch_decode(batch_input_ids)) 138 | batch_masks = batch_input[1] 139 | batch_reference = batch_input[2] 140 | # try: 141 | 142 | generation_output = model.generate( 143 | input_ids=batch_input_ids.cuda(), 144 | attention_mask=batch_masks.cuda(), 145 | max_new_tokens=32, 146 | **gen_args, 147 | ) 148 | 149 | generation_output = tokenizer.decode(generation_output[0][len(batch_input_ids[0]):]).replace("\n","").replace("{","").replace("}","").replace('"','').strip('".').replace(',,','').replace('original','').replace('Original','').split('rewritten')[0].split('revised')[0].replace('10','').split('.')[0] 150 | 151 | logger.info(f'generation: {generation_output}, gold: {batch_reference[0]} \n') 152 | ans_file.write(json.dumps({"generation": generation_output, 153 | "gold": batch_reference[0], 154 | }) + "\n") 155 | ans_file.flush() 156 | ans_file.close() 157 | 158 | remove_icv_layers(model) 159 | 160 | 161 | 
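The script above writes one JSON object per line with `generation` and `gold` keys, and `read_results_json` in `evaluation.py` reads the same layout back. A minimal standard-library sketch of that round trip follows. It uses context managers so the file is closed even if an error interrupts the loop. The helper names `write_results` and `read_results` are illustrative and not part of the repository, and the extra string cleanup done in `evaluation.py` is deliberately left out.

```python
import json


def write_results(path, records):
    """Write (generation, gold) pairs as JSON Lines, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for generation, gold in records:
            f.write(json.dumps({"generation": generation, "gold": gold}) + "\n")


def read_results(path):
    """Read a JSON Lines results file back into two parallel lists."""
    generated, gold = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            row = json.loads(line)
            generated.append(row["generation"])
            gold.append(row["gold"])
    return generated, gold
```

Note that although the output file is named `*.json`, it holds one object per line rather than a single JSON document, so it has to be parsed line by line as shown.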
-------------------------------------------------------------------------------- /tasks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengliu66/ICV/b187c6387ae41097f5e38182d5a8df7d06ca9ec4/tasks/.DS_Store -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .paradetox import ParaDetoxProbInferenceForStyle 2 | from .shakespeare import ShakespeareProbInferenceForStyle 3 | from .formality import FormalityProbInferenceForStyle 4 | from .sentiment import SentimentProbInferenceForStyle 5 | from .format import FormatProbInferenceForStyle 6 | from .emotive import EmotiveProbInferenceForStyle 7 | from .jailbreak import JailBreakProbInferenceForStyle 8 | from .demo import DemoProbInferenceForStyle 9 | 10 | task_mapper = { 11 | "paradetox": ParaDetoxProbInferenceForStyle, 12 | "shakespeare": ShakespeareProbInferenceForStyle, 13 | "formality": FormalityProbInferenceForStyle, 14 | "sentiment": SentimentProbInferenceForStyle, 15 | "format": FormatProbInferenceForStyle, 16 | "emotive": EmotiveProbInferenceForStyle, 17 | "jailbreak": JailBreakProbInferenceForStyle, 18 | 'demo': DemoProbInferenceForStyle, 19 | } 20 | 21 | 22 | def load_task(name): 23 | if name not in task_mapper.keys(): 24 | raise ValueError(f"Unrecognized dataset `{name}`") 25 | 26 | return task_mapper[name] 27 | -------------------------------------------------------------------------------- /tasks/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import random 4 | import re 5 | from collections import defaultdict 6 | 7 | import torch 8 | import numpy as np 9 | import datasets 10 | 11 | from anchor import hf_datasets_root 12 | from tasks.loader import TokenizedForStyleRightPad 13 | from utils.rng_ctx import 
RandomContext, EmptyContext 14 | from utils.pca import PCA 15 | 16 | from utils.forward_tracer import ForwardTrace 17 | from utils.forward_tracer import ForwardTracer 18 | 19 | logger = logging.getLogger("task") 20 | 21 | class BaseProbInference: 22 | def __init__(self, prompt_version): 23 | if prompt_version == "default": 24 | self.prompt_version = self.default_prompt_version() 25 | else: 26 | self.prompt_version = prompt_version 27 | 28 | self.raw_data_sample = None 29 | self.raw_data_dev = None 30 | 31 | self.can_be_stratified = False 32 | self.num_base_shot = 1 33 | 34 | self._rng_context = EmptyContext() 35 | 36 | self._cached_prefix = None 37 | self._cached_ex_list = None 38 | self._cached_selected_exemplar = None 39 | self.shuffled_mapping = None 40 | 41 | def default_prompt_version(self): 42 | raise NotImplementedError 43 | 44 | def set_seed(self, seed): 45 | self._rng_context = RandomContext(seed=seed) 46 | 47 | def dataset_signature(self): 48 | raise NotImplementedError 49 | 50 | def dataset_part(self, part): 51 | return self.dataset_signature()[part] 52 | 53 | def dataset_preprocess(self, raw_data): 54 | raise NotImplementedError 55 | 56 | def handcrafted_exemplars(self): 57 | raise NotImplementedError 58 | 59 | def exemplar_seperator(self): 60 | raise NotImplementedError 61 | 62 | def paralell_style_promptify(self, query): 63 | raise NotImplementedError 64 | 65 | def shuffle_exemplars(self): 66 | prefix = self._cached_prefix 67 | ex_list = self._cached_ex_list 68 | 69 | ex_list_with_idx = list(enumerate(ex_list)) 70 | with self._rng_context: 71 | random.shuffle(ex_list_with_idx) 72 | 73 | indices, ex_list = zip(*ex_list_with_idx) 74 | self.shuffled_mapping = indices 75 | 76 | return self.build_exemplar_from_examples(prefix, ex_list) 77 | 78 | def random_selected_exemplars(self, num_shots, prefix = ""): 79 | 80 | with self._rng_context: 81 | num_shots = min(len(self.raw_data_sample), num_shots) 82 | sampled = random.sample(self.raw_data_sample, 
num_shots) 83 | 84 | self._cached_selected_exemplar = sampled 85 | 86 | ex_list = [e["query"] for e in sampled] 87 | 88 | self._cached_prefix = prefix 89 | self._cached_ex_list = ex_list.copy() 90 | return self.build_exemplar_from_examples(prefix, ex_list) 91 | 92 | def stratified_sampling(self, num_k_shots): 93 | num_shots = self.num_base_shot * num_k_shots 94 | 95 | if not self.can_be_stratified: 96 | logger.info("Cannot be stratified, fallback to random selection.") 97 | return self.random_selected_exemplars(num_shots) 98 | 99 | prefix = "" 100 | 101 | ans_set = set(e["answer_idx"] for e in self.raw_data_sample) 102 | ans_map = defaultdict(list) 103 | for idx, e in enumerate(self.raw_data_sample): 104 | label = e["answer_idx"] 105 | ans_map[label].append(idx) 106 | 107 | per_label = num_shots // len(ans_set) 108 | residual = num_shots - per_label * len(ans_set) 109 | 110 | selected_ids = [] 111 | with self._rng_context: 112 | for label, all_ids in ans_map.items(): 113 | selected = random.sample(all_ids, per_label) 114 | selected_ids.extend(selected) 115 | 116 | remain_ids = set(range(len(self.raw_data_sample))) - set(selected_ids) 117 | residual_selected = random.sample(sorted(remain_ids), residual) # random.sample() rejects sets on Python 3.11+ 118 | selected_ids.extend(residual_selected) 119 | random.shuffle(selected_ids) 120 | 121 | selected_exemplar = [self.raw_data_sample[i] for i in selected_ids] 122 | self._cached_selected_exemplar = selected_exemplar 123 | ex_list = [e["query"] for e in selected_exemplar] 124 | 125 | self._cached_prefix = prefix 126 | self._cached_ex_list = ex_list 127 | return self.build_exemplar_from_examples(prefix, ex_list) 128 | 129 | def build_exemplar_from_examples(self, prefix, ex_list): 130 | s = prefix 131 | if len(s): 132 | s += self.exemplar_seperator() 133 | 134 | for query in ex_list: 135 | _, line = self.paralell_style_promptify(query) # query, 136 | s += line + self.exemplar_seperator() 137 | return s 138 | 139 | def dataset_file_path(self, part): 140 | dataset_name, subset, 
split = self.dataset_part(part) 141 | dumped_folder = hf_datasets_root.joinpath("dumped") 142 | if not dumped_folder.exists(): 143 | dumped_folder.mkdir(parents=True) 144 | 145 | if part == "sample": 146 | split = 'train' 147 | if part == "result": 148 | split = 'test' 149 | 150 | file_name = f"{dataset_name}-{subset}-{split}.jsonl" 151 | file_name = re.sub(r"[^\w_. -]", "_", file_name) 152 | return dumped_folder.joinpath(file_name) 153 | 154 | def do_load_part(self, part): 155 | f_path = self.dataset_file_path(part) 156 | print(f_path) 157 | if not f_path.exists(): 158 | self.not_exist_download(part) 159 | return self.do_load_part(part) # call once more 160 | else: 161 | with f_path.open("r") as f: 162 | raw_data = [json.loads(line) for line in f] 163 | data = self.dataset_preprocess(raw_data) 164 | logger.info(f"Data loaded: {part}.") 165 | return data 166 | 167 | def do_load(self): 168 | self.raw_data_sample = self.do_load_part("sample") 169 | self.raw_data_result = self.do_load_part("result") 170 | 171 | def not_exist_download(self, part): 172 | f_path = self.dataset_file_path(part) 173 | logger.info(f"{f_path} does not exist, downloading from the Hugging Face datasets hub...") 174 | 175 | dataset_name, subset, split = self.dataset_part(part) 176 | data = self.do_download(dataset_name, subset, split=split, cache_dir=str(hf_datasets_root)) 177 | 178 | if part == "sample": 179 | data = data.train_test_split(test_size=0.4, seed=0)['train'] # fixed seed (assumed value) keeps the two parts disjoint when both come from the same split 180 | if part == "result": 181 | data = data.train_test_split(test_size=0.4, seed=0)['test'] 182 | 183 | data.to_json(f_path) 184 | logger.info(f"... 
success, saved at: {f_path}") 185 | 186 | @staticmethod 187 | def do_download(dataset_name, subset, split, cache_dir): 188 | raw_data = datasets.load_dataset(dataset_name, subset, split=split, cache_dir=cache_dir) 189 | logger.info("Download success.") 190 | return raw_data 191 | 192 | def mk_result_dataset(self, tokenizer, no_padding=False, prefix=''): 193 | return TokenizedForStyleRightPad(self.raw_data_result, tokenizer, self.paralell_style_promptify, no_padding=no_padding, prefix=prefix) 194 | 195 | def mk_test_dataset(self, tokenizer): 196 | return self.mk_result_dataset(tokenizer) 197 | 198 | 199 | def mk_dev_dataset(self, tokenizer): 200 | sample_size = len(self.raw_data_result) 201 | 202 | ans_set = set(e["answer_idx"] for e in self.raw_data_sample) 203 | ans_map = defaultdict(list) 204 | for idx, e in enumerate(self.raw_data_sample): 205 | label = e["answer_idx"] 206 | ans_map[label].append(idx) 207 | 208 | per_label = sample_size // len(ans_set) 209 | residual = sample_size - per_label * len(ans_set) 210 | 211 | selected_ids = [] 212 | with self._rng_context: 213 | for label, all_ids in ans_map.items(): 214 | selected = random.sample(all_ids, per_label) 215 | selected_ids.extend(selected) 216 | 217 | remain_ids = set(range(len(self.raw_data_sample))) - set(selected_ids) 218 | residual_selected = random.sample(sorted(remain_ids), residual) # random.sample() rejects sets on Python 3.11+ 219 | selected_ids.extend(residual_selected) 220 | random.shuffle(selected_ids) 221 | 222 | self.raw_data_dev = [self.raw_data_sample[i] for i in selected_ids] 223 | return TokenizedForStyleRightPad(self.raw_data_dev, tokenizer, self.paralell_style_promptify) 224 | 225 | def mk_finetune_dataset(self, tokenizer, mode = 'ft'): 226 | selected_exemplar = self._cached_selected_exemplar 227 | assert (selected_exemplar is not None), "No demonstration is selected yet, run stratified_sampling first! 
\n" 228 | return TokenizedForStyleRightPad(selected_exemplar, tokenizer, self.paralell_style_promptify, mode=mode) 229 | 230 | def mk_result_dataset_with_demostration(self, tokenizer, exemplar_str, no_padding=False): 231 | def add_demostration(query, return_reference = False, Instruction = ''): 232 | if return_reference: 233 | with_query, with_query_and_paraphrase, references = self.paralell_style_promptify(query, return_reference=return_reference, Instruction=Instruction) 234 | with_query = with_query.replace(Instruction,"") 235 | with_query_and_paraphrase = with_query_and_paraphrase.replace(Instruction,"") 236 | return f"{exemplar_str}{with_query}", f"{exemplar_str}{with_query_and_paraphrase}", references 237 | else: 238 | with_query, with_query_and_paraphrase = self.paralell_style_promptify(query, return_reference=return_reference, Instruction=Instruction) 239 | with_query = with_query.replace(Instruction,"") 240 | with_query_and_paraphrase = with_query_and_paraphrase.replace(Instruction,"") 241 | return f"{exemplar_str}{with_query}", f"{exemplar_str}{with_query_and_paraphrase}" 242 | 243 | return TokenizedForStyleRightPad(self.raw_data_result, tokenizer, add_demostration, no_padding=no_padding) 244 | 245 | @staticmethod 246 | def standardize(tensor, dim=0): 247 | means = tensor.mean(dim=dim, keepdim=True) 248 | stds = tensor.std(dim=dim, unbiased=False, keepdim=True) 249 | return (tensor - means) / stds 250 | 251 | @staticmethod 252 | def get_hiddenstates(model, inputs): 253 | h_all = [] 254 | 255 | for example_id in range(len(inputs)): 256 | embeddings_for_all_styles= [] 257 | for style_id in range(len(inputs[example_id])): 258 | forward_trace = ForwardTrace() 259 | context_manager = ForwardTracer(model, forward_trace) 260 | with context_manager: 261 | _ = model( 262 | input_ids=torch.tensor(inputs[example_id][style_id]['input_ids']).unsqueeze(0).cuda(), 263 | attention_mask = torch.tensor(inputs[example_id][style_id]['attention_mask']).unsqueeze(0).cuda(), 
264 | output_attentions=False, 265 | output_hidden_states=False 266 | ) 267 | h = forward_trace.residual_stream.hidden 268 | embedding_token = [] 269 | for layer in range(len(h)): 270 | embedding_token.append(h[layer][:,-1]) 271 | embedding_token = torch.cat(embedding_token, dim=0).cpu().clone() 272 | embeddings_for_all_styles.append(embedding_token) 273 | h_all.append(tuple(embeddings_for_all_styles)) 274 | return h_all 275 | 276 | 277 | @staticmethod 278 | def obtain_icv(model, inputs, rank=1): 279 | hidden_states = BaseProbInference.get_hiddenstates(model, inputs) #each element, layer x len_tokens x dim 280 | num_demonstration = len(hidden_states) 281 | neg_all = [] 282 | pos_all = [] 283 | 284 | hidden_states_all = [] 285 | 286 | for demonstration_id in range(num_demonstration): 287 | h = hidden_states[demonstration_id][1].view(-1) - hidden_states[demonstration_id][0].view(-1) 288 | hidden_states_all.append(h) 289 | neg_all.append(hidden_states[demonstration_id][0].view(-1)) 290 | pos_all.append(hidden_states[demonstration_id][1].view(-1)) 291 | fit_data = torch.stack(hidden_states_all) 292 | neg_emb = torch.stack(neg_all).mean(0) 293 | pos_emb = torch.stack(pos_all).mean(0) 294 | 295 | pca = PCA(n_components=rank).to(fit_data.device).fit(fit_data.float()) 296 | eval_data = pca.transform(fit_data.float()) 297 | h_pca = pca.inverse_transform(eval_data) 298 | direction = (pca.components_.sum(dim=0,keepdim=True) + pca.mean_).mean(0).view(hidden_states[demonstration_id][0].size(0), hidden_states[demonstration_id][0].size(1))#h_pca.mean(0).view(hidden_states[demonstration_id][0].size(0), hidden_states[demonstration_id][0].size(1)) 299 | return direction, (neg_emb).view(hidden_states[demonstration_id][0].size(0), hidden_states[demonstration_id][0].size(1)) -------------------------------------------------------------------------------- /tasks/demo.py: -------------------------------------------------------------------------------- 1 | from tasks.base import 
BaseProbInference 2 | 3 | 4 | class DemoProbInferenceForStyle(BaseProbInference): 5 | def __init__(self, prompt_version): 6 | super().__init__(prompt_version) 7 | 8 | self.can_be_stratified = False 9 | self.num_base_shot = 1 10 | 11 | def default_prompt_version(self): 12 | return "sp" 13 | 14 | def dataset_signature(self): 15 | return { 16 | "sample": ("demo", None, "train"), 17 | "result": ("demo", None, "test"), 18 | } 19 | 20 | def dataset_preprocess(self, raw_data): 21 | pass 22 | 23 | def handcrafted_exemplars(self): 24 | raise NotImplementedError 25 | 26 | def exemplar_seperator(self): 27 | if self.prompt_version.startswith("sp"): 28 | return ". " 29 | else: 30 | raise ValueError(f"Demo: Not supported prompt_version: {self.prompt_version}") 31 | 32 | def paralell_style_promptify(self, query, return_reference = False, Instruction = ''): 33 | 34 | pass 35 | 36 | 37 | -------------------------------------------------------------------------------- /tasks/jailbreak.py: -------------------------------------------------------------------------------- 1 | from tasks.base import BaseProbInference 2 | from fastchat.model.model_adapter import get_conversation_template 3 | 4 | class JailBreakProbInferenceForStyle(BaseProbInference): 5 | def __init__(self, prompt_version): 6 | super().__init__(prompt_version) 7 | 8 | self.can_be_stratified = False 9 | self.num_base_shot = 1 10 | 11 | def default_prompt_version(self): 12 | return "sp" 13 | 14 | def dataset_signature(self): 15 | return { 16 | "sample": ("jailbreak", None, "train"), 17 | "result": ("jailbreak", None, "test"), 18 | } 19 | 20 | def dataset_preprocess(self, raw_data): 21 | data = [] 22 | for e in raw_data: 23 | query = (e["en_jailbreak_negative"], e["en_jailbreak_positive"]) 24 | data.append({"query": query}) 25 | return data 26 | 27 | def handcrafted_exemplars(self): 28 | raise NotImplementedError 29 | 30 | def exemplar_seperator(self): 31 | if self.prompt_version.startswith("sp"): 32 | return ". 
" 33 | else: 34 | raise ValueError(f"AGNews: Not supported prompt_version: {self.prompt_version}") 35 | 36 | def paralell_style_promptify(self, query, return_reference = False, Instruction = ''): 37 | 38 | 39 | neg, pos = query 40 | 41 | conv_template = get_conversation_template('vicuna_v1.5') 42 | 43 | conv_template.append_message(conv_template.roles[0], f"{neg}") 44 | conv_template.append_message(conv_template.roles[1], f"") 45 | 46 | prompt = conv_template.get_prompt() 47 | 48 | conv_template = get_conversation_template('vicuna_v1.5') 49 | conv_template.append_message(conv_template.roles[0], f"{neg}") 50 | conv_template.append_message(conv_template.roles[1], f"{pos}") 51 | 52 | prompt_both = conv_template.get_prompt() 53 | 54 | with_sentence = prompt 55 | with_sentence_and_paraphrase = prompt_both 56 | 57 | if return_reference: 58 | return with_sentence, with_sentence_and_paraphrase, pos 59 | else: 60 | return with_sentence, with_sentence_and_paraphrase -------------------------------------------------------------------------------- /tasks/loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from tqdm import tqdm 6 | from transformers import PreTrainedTokenizer 7 | 8 | logger = logging.getLogger("task") 9 | 10 | 11 | class TokenizedForStyleRightPad(Dataset): 12 | def __init__(self, data, tok: PreTrainedTokenizer, prompt_fn, mode = 'eval', no_padding=False, prefix=''): 13 | # data: [query: str, choices: list(str)] 14 | self.tok = tok 15 | self.prompt_fn = prompt_fn 16 | self.references = None 17 | self.max_length = self._find_max_length(data, mode=mode) 18 | if mode == 'ft': 19 | self.data = self._build_ft_data(data) 20 | elif mode == 'eval': 21 | self.data, self.references = self._build_eval_data(data, no_padding=no_padding, prefix=prefix) 22 | else: 23 | raise NotImplementedError 24 | logger.info(f"Tokenization finished: {len(self.data)}, 
max_length={self.max_length}") 25 | 26 | def _find_max_length(self, data, mode='eval'): 27 | max_len = 0 28 | 29 | def tok_len(t): 30 | return len(self.tok.encode(t)) 31 | 32 | for ex in tqdm(data, desc="Data preprocessing(1/2)"): 33 | query = ex["query"] 34 | if mode == 'eval': 35 | len_query = len(self.prompt_fn(query)[0]) 36 | elif mode == 'ft': 37 | len_query = len(self.prompt_fn(query)[1]) 38 | else: 39 | raise NotImplementedError 40 | max_len = max(max_len, len_query) 41 | return max_len 42 | 43 | def _build_eval_data(self, data, no_padding=False, prefix=''): 44 | processed = [] 45 | references = [] 46 | for ex in tqdm(data, desc="Data preprocessing(2/2)"): 47 | query = ex["query"] 48 | processed_input = self.prompt_fn(query, return_reference = True, Instruction = prefix) 49 | t_query, t_full, t_reference = processed_input 50 | processed_input = self.tokenize(t_full, t_query, no_padding=no_padding) 51 | processed.append(processed_input) 52 | references.append(t_reference) 53 | 54 | logger.info("Style dataset: finish!") 55 | return processed, references 56 | 57 | def _build_ft_data(self, data): 58 | processed = [] 59 | for ex in tqdm(data, desc="Data preprocessing(2/2)"): 60 | query = ex["query"] 61 | processed_input = self.prompt_fn(query) 62 | t_query, t_full = processed_input 63 | processed_input = self.tokenize(t_query, t_full) 64 | processed.append(processed_input) 65 | 66 | logger.info("Finetuning dataset: finish!") 67 | return processed 68 | 69 | def tokenize_demonstration(self, demonstration): 70 | e = self.tok(demonstration) 71 | return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"]) # no padding 72 | 73 | 74 | 75 | def tokenize_each_demonstration(self, demonstration_list, dataset_name=None, prefix = None): 76 | special_characters = [ 77 | "~", " ~", "~ ", "!", " !", "! 
", "@", " @", "@ ", "#", " #", "# ", 78 | "$", " $", "$ ", "%", " %", "% ", "^", " ^", "^ ", "&", " &", "& ", 79 | "*", " *", "* ", "(", " (", "( ", ")", " )", ") ", "_", " _", "_ ", 80 | "+", " +", "+ ", "`", " `", "` ", "-", " -", "- ", "=", " =", "= ", 81 | "{", " {", "{ ", "}", " }", "} ", "[", " [", "[ ", "]", " ]", "] ", 82 | "|", " |", "| ", "\\", " \\", "\\ ", ":", " :", ": ", ";", " ;", "; ", 83 | "\"", " \"", "\" ", "'", " '", "' ", "<", " <", "< ", ">", " >", "> ", 84 | ",", " ,", ", ", ".", " .", ". ", "?", " ?", "? ", "/", " /", "/ " 85 | ] 86 | 87 | def strip_special_characters(input_string): 88 | for char in special_characters: 89 | input_string = input_string.replace(char.strip(), '') 90 | return input_string.strip() 91 | 92 | tokenized_demonstration_list = [] 93 | for exp_id in range(len(demonstration_list)): 94 | if prefix is not None: 95 | demonstration_list[exp_id] = (prefix[0] + strip_special_characters(demonstration_list[exp_id][0]), prefix[1] + strip_special_characters(demonstration_list[exp_id][1])) 96 | else: 97 | demonstration_list[exp_id] = (strip_special_characters(demonstration_list[exp_id][0]), strip_special_characters(demonstration_list[exp_id][1])) 98 | e_original = self.tok(demonstration_list[exp_id][0]) 99 | e_rewrite = self.tok(demonstration_list[exp_id][1]) 100 | tokenized_demonstration_list.append((e_original, e_rewrite)) 101 | return tokenized_demonstration_list 102 | 103 | def tokenize(self, only_query, full_text, no_padding = False): 104 | tok_only_query = self.tok(only_query, add_special_tokens=False) 105 | tok_full_no_padding = self.tok(full_text, add_special_tokens=False) 106 | tok_full = self.tok( 107 | full_text, 108 | padding="max_length", 109 | max_length=self.max_length, 110 | add_special_tokens=False, 111 | ) # is not a special token 112 | 113 | if no_padding: 114 | e = { 115 | "input_ids": tok_full_no_padding.input_ids, 116 | "attention_mask": tok_full_no_padding.attention_mask, 117 | } 118 | else: 119 | e = { 120 | 
"input_ids": tok_full.input_ids, 121 | "attention_mask": tok_full.attention_mask, 122 | } 123 | 124 | return e 125 | 126 | def __len__(self): 127 | return len(self.data) 128 | 129 | def __getitem__(self, idx): 130 | 131 | es = self.data[idx] 132 | 133 | if self.references: 134 | return torch.LongTensor(es["input_ids"]), torch.LongTensor(es["attention_mask"]), self.references[idx] 135 | else: 136 | return es 137 | 138 | 139 | if __name__ == "__main__": 140 | from anchor import hf_datasets_root 141 | 142 | import datasets 143 | 144 | csqa1 = datasets.load_dataset("commonsense_qa", cache_dir=str(hf_datasets_root), split="validation") 145 | -------------------------------------------------------------------------------- /tasks/paradetox.py: -------------------------------------------------------------------------------- 1 | from tasks.base import BaseProbInference 2 | 3 | 4 | class ParaDetoxProbInferenceForStyle(BaseProbInference): 5 | def __init__(self, prompt_version): 6 | super().__init__(prompt_version) 7 | 8 | self.can_be_stratified = False 9 | self.num_base_shot = 1 10 | 11 | def default_prompt_version(self): 12 | return "sp" 13 | 14 | def dataset_signature(self): 15 | return { 16 | "sample": ("s-nlp/paradetox", None, "train"), 17 | "result": ("s-nlp/paradetox", None, "train"), 18 | } 19 | 20 | def dataset_preprocess(self, raw_data): 21 | data = [] 22 | for e in raw_data: 23 | query = (e["en_toxic_comment"], e["en_neutral_comment"]) 24 | data.append({"query": query}) 25 | return data 26 | 27 | def handcrafted_exemplars(self): 28 | raise NotImplementedError 29 | 30 | def exemplar_seperator(self): 31 | if self.prompt_version.startswith("sp"): 32 | return ". 
" 33 | else: 34 | raise ValueError(f"AGNews: Not supported prompt_version: {self.prompt_version}") 35 | 36 | def paralell_style_promptify(self, query, return_reference = False, Instruction = ''): 37 | toxic, neutral = query 38 | 39 | with_sentence_and_paraphrase = Instruction + f'Original: "{toxic}"; Paraphrased: "{neutral}"' 40 | with_sentence = Instruction + f'Original: "{toxic}"; Paraphrased: "' 41 | 42 | if return_reference: 43 | return with_sentence, with_sentence_and_paraphrase, neutral 44 | else: 45 | return with_sentence, with_sentence_and_paraphrase -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengliu66/ICV/b187c6387ae41097f5e38182d5a8df7d06ca9ec4/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengliu66/ICV/b187c6387ae41097f5e38182d5a8df7d06ca9ec4/utils/__init__.py -------------------------------------------------------------------------------- /utils/cuda_check.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | if __name__ == "__main__": 4 | print(f"Cuda available: {torch.cuda.is_available()}") 5 | 6 | device_count = torch.cuda.device_count() 7 | print(f"Cuda device count: {device_count}") 8 | 9 | for idx in range(device_count): 10 | print(f"Device IDX: {idx}") 11 | print(f"Name: {torch.cuda.get_device_name(idx)}") 12 | print(f"Process: {torch.cuda.get_device_properties(idx)}", flush=True) 13 | -------------------------------------------------------------------------------- /utils/forward_tracer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass 2 | from typing import 
Dict, Optional 3 | import torch 4 | from transformers import PreTrainedModel 5 | from .llm_layers import get_embedding_layer, get_layers 6 | 7 | @dataclass 8 | class ResidualStream: 9 | hidden: torch.Tensor 10 | 11 | 12 | class ForwardTrace: 13 | def __init__(self): 14 | self.residual_stream: Optional[ResidualStream] = ResidualStream( 15 | hidden=[], 16 | ) 17 | self.attentions: Optional[torch.Tensor] = None 18 | 19 | 20 | class ForwardTracer: 21 | def __init__(self, model: PreTrainedModel, forward_trace: ForwardTrace): 22 | self._model = model 23 | self._forward_trace = forward_trace 24 | 25 | self._layers = get_layers(model) 26 | self._hooks = [] 27 | 28 | def __enter__(self): 29 | self._register_forward_hooks() 30 | 31 | def __exit__(self, exc_type, exc_value, traceback): 32 | for hook in self._hooks: 33 | hook.remove() 34 | 35 | if exc_type is None: 36 | residual_stream = self._forward_trace.residual_stream 37 | 38 | if residual_stream.hidden[0] == []: 39 | residual_stream.hidden.pop(0) 40 | 41 | for key in residual_stream.__dataclass_fields__.keys(): 42 | acts = getattr(residual_stream, key) 43 | # TODO: this is a hack, fix it 44 | if key != "hidden" and not self._with_submodules: 45 | continue 46 | 47 | nonempty_layer_acts = [layer_acts for layer_acts in acts if layer_acts != []][0] 48 | final_shape = torch.cat(nonempty_layer_acts, dim=0).shape 49 | 50 | for i, layer_acts in enumerate(acts): 51 | if layer_acts == []: 52 | acts[i] = torch.zeros(final_shape) 53 | else: 54 | acts[i] = torch.cat(layer_acts, dim=0) 55 | acts = torch.stack(acts).transpose(0, 1) 56 | setattr(residual_stream, key, acts) 57 | 58 | def _register_forward_hooks(self): 59 | model = self._model 60 | hooks = self._hooks 61 | 62 | residual_stream = self._forward_trace.residual_stream 63 | 64 | def store_activations(residual_stream: ResidualStream, acts_type: str, layer_num: int): 65 | def hook(model, inp, out): 66 | if isinstance(out, tuple): 67 | out = out[0] 68 | out = 
out.float().to("cpu", non_blocking=True) 69 | 70 | acts = getattr(residual_stream, acts_type) 71 | while len(acts) < layer_num + 1: 72 | acts.append([]) 73 | try: 74 | acts[layer_num].append(out) 75 | except IndexError: 76 | print(len(acts), layer_num) 77 | 78 | return hook 79 | 80 | def store_inputs(residual_stream: ResidualStream, acts_type: str, layer_num: int): 81 | def hook(model, inp, out): 82 | if isinstance(inp, tuple): 83 | inp = inp[0] 84 | inp = inp.float().to("cpu", non_blocking=True) 85 | 86 | acts = getattr(residual_stream, acts_type) 87 | while len(acts) < layer_num + 1: 88 | acts.append([]) 89 | try: 90 | acts[layer_num].append(inp) 91 | except IndexError: 92 | print(len(acts), layer_num) 93 | 94 | return hook 95 | 96 | 97 | embedding_hook = get_embedding_layer(self._model).register_forward_hook( 98 | store_activations(residual_stream, "hidden", 0) 99 | ) 100 | hooks.append(embedding_hook) 101 | 102 | for i, layer in enumerate(self._layers): 103 | hidden_states_hook = layer.register_forward_hook(store_activations(residual_stream, "hidden", i + 1)) 104 | hooks.append(hidden_states_hook) -------------------------------------------------------------------------------- /utils/llm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch import nn 4 | from transformers import PreTrainedModel, GPTJForCausalLM 5 | from torch import Tensor 6 | import numpy as np 7 | 8 | 9 | class ICVLayer(nn.Module): 10 | 11 | def __init__(self, icv, lam): 12 | super(ICVLayer, self).__init__() 13 | self.icv = icv 14 | self.lam = lam 15 | 16 | def forward(self, x): 17 | if self.icv is not None: 18 | x = x.float() 19 | original_norm = torch.norm(x, p=2, dim=-1, keepdim=True) 20 | directions_all = [] 21 | y = 0 22 | for i in range(len(self.icv)): 23 | lambda_sim = 1.0 + torch.max(torch.tensor([0.]).to(x.device), F.cosine_similarity(x, -self.icv[i][None,None,:], 
dim=-1)).unsqueeze(-1) 24 | y += self.lam[i] * lambda_sim * F.normalize(self.icv[i], dim=-1).repeat(1,x.shape[1],1) 25 | y = y/len(self.icv) 26 | x = F.normalize(F.normalize(x.float(), p=2, dim=-1) + y, p=2, dim=-1) * original_norm 27 | return x.half() 28 | else: 29 | return x 30 | 31 | def get_nested_attr(obj, attr_path): 32 | attrs = attr_path.split(".") 33 | for attr in attrs: 34 | obj = getattr(obj, attr) 35 | return obj 36 | 37 | 38 | def set_nested_attr(obj, attr_path, value): 39 | attrs = attr_path.split(".") 40 | parent = get_nested_attr(obj, ".".join(attrs[:-1])) 41 | setattr(parent, attrs[-1], value) 42 | 43 | 44 | def find_longest_modulelist(model, path=""): 45 | """ 46 | Recursively find the longest nn.ModuleList in a PyTorch model. 47 | Args: 48 | model: PyTorch model. 49 | path: Current path in the model (used for recursion). 50 | Returns: 51 | Tuple with path and length of the longest nn.ModuleList found. 52 | """ 53 | longest_path = path 54 | longest_len = 0 55 | 56 | for name, child in model.named_children(): 57 | if isinstance(child, nn.ModuleList) and len(child) > longest_len: 58 | longest_len = len(child) 59 | longest_path = f"{path}.{name}" if path else name 60 | 61 | # Recursively check the child's children 62 | child_path, child_len = find_longest_modulelist(child, f"{path}.{name}" if path else name) 63 | if child_len > longest_len: 64 | longest_len = child_len 65 | longest_path = child_path 66 | 67 | return longest_path, longest_len 68 | 69 | 70 | def find_module(block, keywords): 71 | """ 72 | Try to find a module in a transformer block. 73 | Args: 74 | block: Transformer block (nn.Module). 75 | keywords: List of possible module names (str). 76 | Returns: 77 | The found module if found, else None. 
78 | """ 79 | for name, module in block.named_modules(): 80 | if any(keyword in name for keyword in keywords): 81 | return module 82 | submodule_names = [name for name, _ in block.named_modules()] 83 | raise ValueError(f"Could not find keywords {keywords} in: {submodule_names}") 84 | 85 | 86 | def get_embedding_layer(model: PreTrainedModel): 87 | # model_type = model.__class__.__name__ 88 | # if model_type == "LlamaForCausalLM": 89 | # return model.model.embed_tokens 90 | # elif model_type == "RWForCausalLM": 91 | # return model.transformer.word_embeddings 92 | 93 | keywords = ["emb", "wte"] 94 | return find_module(model, keywords) 95 | 96 | 97 | def get_lm_head(model: PreTrainedModel): 98 | keywords = ["lm_head", "embed_out"] 99 | return find_module(model, keywords) 100 | 101 | 102 | def get_lm_pipeline(model: PreTrainedModel): 103 | model_class = model.__class__.__name__ 104 | 105 | if model_class == "LlamaForCausalLM": 106 | return nn.Sequential(model.model.norm, model.lm_head) 107 | elif model_class == "RWForCausalLM": 108 | return nn.Sequential(model.transformer.ln_f, model.lm_head) 109 | elif model_class == "GPTNeoForCausalLM": 110 | return nn.Sequential(model.transformer.ln_f, model.lm_head) 111 | elif model_class == "GPTNeoXForCausalLM": 112 | return nn.Sequential(model.gpt_neox.final_layer_norm, model.embed_out) 113 | 114 | # TODO: make the default case more robust 115 | return get_lm_head(model) 116 | 117 | 118 | def get_layers_path(model: PreTrainedModel): 119 | longest_path, longest_len = find_longest_modulelist(model) 120 | return longest_path 121 | 122 | 123 | def get_layers(model: PreTrainedModel): 124 | longest_path = get_layers_path(model) 125 | return get_nested_attr(model, longest_path) 126 | 127 | def get_mlp_layers(model: PreTrainedModel): 128 | layers = get_layers(model) 129 | mlp_keywords = ["mlp", "feedforward", "ffn"] 130 | mlp_layers = [find_module(layer, mlp_keywords) for layer in layers] 131 | return mlp_layers 132 | 133 | def 
add_icv_layers(model: PreTrainedModel, icv: Tensor, alpha: list): 134 | layers = get_layers(model) 135 | mlp_keywords = ["mlp", "feedforward", "ffn"] 136 | assert len(icv) == len(layers) 137 | for i, layer in enumerate(layers): 138 | original_mlp = find_module(layer, mlp_keywords) 139 | layer.mlp = nn.Sequential(original_mlp, ICVLayer(icv[i], alpha)) 140 | 141 | def remove_icv_layers(model: PreTrainedModel): 142 | layers = get_layers(model) 143 | mlp_keywords = ["mlp", "feedforward", "ffn"] 144 | for i, layer in enumerate(layers): 145 | icv_mlp = find_module(layer, mlp_keywords) 146 | layer.mlp = icv_mlp[0] -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, unicode_literals 2 | 3 | 4 | import logging 5 | from pathlib import Path 6 | 7 | 8 | import logging 9 | import multiprocessing 10 | import threading 11 | 12 | try: 13 | from queue import Empty 14 | except ImportError: # Python 2. 
15 | from Queue import Empty # type: ignore[no-redef] 16 | 17 | 18 | __version__ = "0.3.4" 19 | 20 | 21 | def setup_logger(folder_path, log_file_name="logger.log", console_output=False, logger_name="task"): 22 | dir_root = Path(folder_path) 23 | full_path = dir_root.joinpath(log_file_name) 24 | # print("File: ", full_path) 25 | 26 | already_exist = Path(full_path).exists() 27 | 28 | logger = logging.getLogger(logger_name) 29 | logger.setLevel(logging.INFO) 30 | 31 | formatter = logging.Formatter("%(asctime)s| %(message)s", "%m-%d|%H:%M:%S") 32 | 33 | file_hdl = logging.FileHandler(full_path) 34 | file_hdl.setFormatter(formatter) 35 | 36 | logger.addHandler(file_hdl) 37 | 38 | if console_output: 39 | console_hdl = logging.StreamHandler() 40 | console_hdl.setFormatter(formatter) 41 | logger.addHandler(console_hdl) 42 | 43 | logger.info("") 44 | logger.info("-*" * 30) 45 | logger.info("Logger ready") 46 | if already_exist: 47 | logger.info("") 48 | logger.info("") 49 | logger.info(f">>>>> Logger file {full_path} already exist, append to it. 
<<<<<") 50 | logger.info("") 51 | logger.info("") 52 | 53 | 54 | def setup_simple_logger(): 55 | root_logger = logging.getLogger() 56 | root_logger.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter("%(asctime)s| %(message)s", "%m-%d|%H:%M:%S") 59 | 60 | console_hdl = logging.StreamHandler() 61 | console_hdl.setFormatter(formatter) 62 | root_logger.addHandler(console_hdl) 63 | 64 | 65 | def tabular_pretty_print(grid): 66 | lens = [max(map(len, col)) for col in zip(*grid)] 67 | 68 | fmt = " | ".join("{{:{}}}".format(x) for x in lens) 69 | table = [fmt.format(*row) for row in grid] 70 | 71 | sep = ["~" * len(table[0])] 72 | table = sep + table + sep 73 | 74 | res = [] 75 | for idx, line in enumerate(table): 76 | if idx == 0 or idx == len(table) - 1: 77 | ps = "* {} *".format(line) 78 | else: 79 | ps = "| {} |".format(line) 80 | res.append(ps) 81 | return res 82 | 83 | 84 | def fmt_float(num, d=4): 85 | fmt_string = "{{:.{}f}}".format(d) 86 | return fmt_string.format(num) 87 | 88 | 89 | def install_mp_handler(logger=None): 90 | """Wraps the handlers in the given Logger with a MultiProcessingHandler. 91 | :param logger: whose handlers to wrap. By default, the root logger. 92 | """ 93 | if logger is None: 94 | logger = logging.getLogger() 95 | 96 | for i, orig_handler in enumerate(list(logger.handlers)): 97 | handler = MultiProcessingHandler("mp-handler-{0}".format(i), sub_handler=orig_handler) 98 | 99 | logger.removeHandler(orig_handler) 100 | logger.addHandler(handler) 101 | 102 | 103 | def uninstall_mp_handler(logger=None): 104 | """Unwraps the handlers in the given Logger from a MultiProcessingHandler wrapper. 105 | :param logger: whose handlers to unwrap. By default, the root logger. 
106 | """ 107 | if logger is None: 108 | logger = logging.getLogger() 109 | 110 | for handler in list(logger.handlers): # iterate over a copy: the loop body mutates logger.handlers 111 | if isinstance(handler, MultiProcessingHandler): 112 | orig_handler = handler.sub_handler 113 | logger.removeHandler(handler) 114 | logger.addHandler(orig_handler) 115 | 116 | 117 | class MultiProcessingHandler(logging.Handler): 118 | def __init__(self, name, sub_handler=None): 119 | super(MultiProcessingHandler, self).__init__() 120 | 121 | if sub_handler is None: 122 | sub_handler = logging.StreamHandler() 123 | self.sub_handler = sub_handler 124 | 125 | self.setLevel(self.sub_handler.level) 126 | self.setFormatter(self.sub_handler.formatter) 127 | self.filters = self.sub_handler.filters 128 | 129 | self.queue = multiprocessing.Queue(-1) 130 | self._is_closed = False 131 | # The thread handles receiving records asynchronously. 132 | self._receive_thread = threading.Thread(target=self._receive, name=name) 133 | self._receive_thread.daemon = True 134 | self._receive_thread.start() 135 | 136 | def setFormatter(self, fmt): 137 | super(MultiProcessingHandler, self).setFormatter(fmt) 138 | self.sub_handler.setFormatter(fmt) 139 | 140 | def _receive(self): 141 | while True: 142 | try: 143 | if self._is_closed and self.queue.empty(): 144 | break 145 | 146 | record = self.queue.get(timeout=0.2) 147 | self.sub_handler.emit(record) 148 | except (KeyboardInterrupt, SystemExit): 149 | raise 150 | except (EOFError, OSError): 151 | break # The queue was closed by child? 152 | except Empty: 153 | pass # This periodically checks if the logger is closed. 154 | except: 155 | from sys import stderr 156 | from traceback import print_exc 157 | 158 | print_exc(file=stderr) 159 | raise 160 | 161 | self.queue.close() 162 | self.queue.join_thread() 163 | 164 | def _send(self, s): 165 | self.queue.put_nowait(s) 166 | 167 | def _format_record(self, record): 168 | # ensure that exc_info and args 169 | # have been stringified. 
Removes any chance of 170 | # unpickleable things inside and possibly reduces 171 | # message size sent over the pipe. 172 | if record.args: 173 | record.msg = record.msg % record.args 174 | record.args = None 175 | if record.exc_info: 176 | self.format(record) 177 | record.exc_info = None 178 | 179 | return record 180 | 181 | def emit(self, record): 182 | try: 183 | s = self._format_record(record) 184 | self._send(s) 185 | except (KeyboardInterrupt, SystemExit): 186 | raise 187 | except: 188 | self.handleError(record) 189 | 190 | def close(self): 191 | if not self._is_closed: 192 | self._is_closed = True 193 | self._receive_thread.join(5.0) # Waits for receive queue to empty. 194 | 195 | self.sub_handler.close() 196 | super(MultiProcessingHandler, self).close() 197 | -------------------------------------------------------------------------------- /utils/pca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def svd_flip(u, v): 7 | # columns of u, rows of v 8 | max_abs_cols = torch.argmax(torch.abs(u), 0) 9 | i = torch.arange(u.shape[1]).to(u.device) 10 | signs = torch.sign(u[max_abs_cols, i]) 11 | u *= signs 12 | v *= signs.view(-1, 1) 13 | return u, v 14 | 15 | class PCA(nn.Module): 16 | def __init__(self, n_components): 17 | super().__init__() 18 | self.n_components = n_components 19 | 20 | @torch.no_grad() 21 | def fit(self, X): 22 | n, d = X.size() 23 | if self.n_components is not None: 24 | d = min(self.n_components, d) 25 | self.register_buffer("mean_", X.mean(0, keepdim=True)) 26 | Z = X - self.mean_ # center 27 | U, S, Vh = torch.linalg.svd(Z, full_matrices=False) 28 | Vt = Vh 29 | U, Vt = svd_flip(U, Vt) 30 | self.register_buffer("components_", Vt[:d]) 31 | return self 32 | 33 | def forward(self, X): 34 | return self.transform(X) 35 | 36 | def transform(self, X): 37 | assert hasattr(self, "components_"), "PCA must be fit before 
use." 38 | return torch.matmul(X - self.mean_, self.components_.t()) 39 | 40 | def fit_transform(self, X): 41 | self.fit(X) 42 | return self.transform(X) 43 | 44 | def inverse_transform(self, Y): 45 | assert hasattr(self, "components_"), "PCA must be fit before use." 46 | return torch.matmul(Y, self.components_) + self.mean_ 47 | -------------------------------------------------------------------------------- /utils/rng_ctx.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class RandomState: 9 | def __init__(self): 10 | self.random_mod_state = random.getstate() 11 | self.np_state = np.random.get_state() 12 | self.torch_cpu_state = torch.get_rng_state() 13 | self.torch_gpu_states = [torch.cuda.get_rng_state(d) for d in range(torch.cuda.device_count())] 14 | 15 | def restore(self): 16 | random.setstate(self.random_mod_state) 17 | np.random.set_state(self.np_state) 18 | torch.set_rng_state(self.torch_cpu_state) 19 | for d, state in enumerate(self.torch_gpu_states): 20 | torch.cuda.set_rng_state(state, d) 21 | 22 | 23 | class RandomContext: 24 | """Save and restore state of PyTorch, NumPy, Python RNGs.""" 25 | 26 | def __init__(self, seed=None): 27 | outside_state = RandomState() 28 | 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | if seed is None: 32 | torch.manual_seed(random.randint(-sys.maxsize - 1, sys.maxsize)) 33 | else: 34 | torch.manual_seed(seed) 35 | # torch.cuda.manual_seed_all is called by torch.manual_seed 36 | self.inside_state = RandomState() 37 | 38 | outside_state.restore() 39 | 40 | self._active = False 41 | 42 | def __enter__(self): 43 | if self._active: 44 | raise Exception("RandomContext can be active only once") 45 | 46 | self.outside_state = RandomState() 47 | self.inside_state.restore() 48 | self._active = True 49 | 50 | def __exit__(self, exception_type, exception_value, traceback): 51 | self.inside_state = 
RandomState() 52 | self.outside_state.restore() 53 | self.outside_state = None 54 | 55 | self._active = False 56 | 57 | 58 | class EmptyContext: 59 | def __enter__(self): 60 | pass 61 | 62 | def __exit__(self, exc_type, exc_val, exc_tb): 63 | pass 64 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from pathlib import Path 3 | import json 4 | 5 | 6 | class MpCounter: 7 | def __init__(self): 8 | self.val = multiprocessing.Value("i", 0) 9 | 10 | def increment(self, n=1): 11 | with self.val.get_lock(): 12 | self.val.value += n 13 | 14 | @property 15 | def value(self): 16 | return self.val.value 17 | 18 | 19 | def yield_chunks(data, size): 20 | data = list(data) 21 | for i in range(0, len(data), size): 22 | yield data[i : i + size] 23 | 24 | 25 | def ensure_folder(folder: Path, parents=False): 26 | if not folder.exists(): 27 | folder.mkdir(parents=parents) 28 | 29 | 30 | def pick_if_present(d: dict, key_in_dict, key_new=None): 31 | if key_in_dict in d: 32 | if not key_new: 33 | return {key_in_dict: d[key_in_dict]} 34 | else: 35 | return {key_new: d[key_in_dict]} 36 | return {} 37 | 38 | 39 | class AverageMeterSet(object): 40 | def __init__(self, meters=None): 41 | self.meters = meters if meters else {} 42 | 43 | def __getitem__(self, key): 44 | if key not in self.meters: 45 | meter = AverageMeter() 46 | meter.update(0) 47 | return meter 48 | return self.meters[key] 49 | 50 | def update(self, name, value, n=1): 51 | if name not in self.meters: 52 | self.meters[name] = AverageMeter() 53 | self.meters[name].update(value, n) 54 | 55 | def reset(self): 56 | for meter in self.meters.values(): 57 | meter.reset() 58 | 59 | def values(self, format_string="{}"): 60 | return {format_string.format(name): meter.val for name, meter in self.meters.items()} 61 | 62 | def averages(self, format_string="{}"): 63 | return 
{format_string.format(name): meter.avg for name, meter in self.meters.items()} 64 | 65 | def sums(self, format_string="{}"): 66 | return {format_string.format(name): meter.sum for name, meter in self.meters.items()} 67 | 68 | def counts(self, format_string="{}"): 69 | return {format_string.format(name): meter.count for name, meter in self.meters.items()} 70 | 71 | 72 | class AverageMeter(object): 73 | """Computes and stores the average and current value""" 74 | 75 | def __init__(self): 76 | self.val = 0 77 | self.avg = 0 78 | self.sum = 0 79 | self.count = 0 80 | 81 | def reset(self): 82 | self.val = 0 83 | self.avg = 0 84 | self.sum = 0 85 | self.count = 0 86 | 87 | def update(self, val, n=1): 88 | self.val = val 89 | self.sum += val 90 | self.count += n 91 | self.avg = self.sum / self.count 92 | 93 | def __format__(self, fmt): 94 | return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=fmt) 95 | 96 | 97 | class CompactJSONEncoder(json.JSONEncoder): 98 | """A JSON Encoder that puts small containers on single lines.""" 99 | 100 | CONTAINER_TYPES = (list, tuple, dict) 101 | """Container datatypes include primitives or other containers.""" 102 | 103 | MAX_WIDTH = 1000 104 | """Maximum width of a container that might be put on a single line.""" 105 | 106 | MAX_ITEMS = 60 107 | """Maximum number of items in container that might be put on single line.""" 108 | 109 | def __init__(self, *args, **kwargs): 110 | # using this class without indentation is pointless 111 | if kwargs.get("indent") is None: 112 | kwargs["indent"] = 4 113 | super().__init__(*args, **kwargs) 114 | self.indentation_level = 0 115 | 116 | def encode(self, o): 117 | """Encode JSON object *o* with respect to single line lists.""" 118 | if isinstance(o, (list, tuple)): 119 | return self._encode_list(o) 120 | if isinstance(o, dict): 121 | return self._encode_object(o) 122 | if isinstance(o, float): # Use scientific notation for floats 123 | return format(o, "g") 124 | return 
json.dumps( 125 | o, 126 | skipkeys=self.skipkeys, 127 | ensure_ascii=self.ensure_ascii, 128 | check_circular=self.check_circular, 129 | allow_nan=self.allow_nan, 130 | sort_keys=self.sort_keys, 131 | indent=self.indent, 132 | separators=(self.item_separator, self.key_separator), 133 | default=self.default if hasattr(self, "default") else None, 134 | ) 135 | 136 | def _encode_list(self, o): 137 | if self._put_on_single_line(o): 138 | return "[" + ", ".join(self.encode(el) for el in o) + "]" 139 | self.indentation_level += 1 140 | output = [self.indent_str + self.encode(el) for el in o] 141 | self.indentation_level -= 1 142 | return "[\n" + ",\n".join(output) + "\n" + self.indent_str + "]" 143 | 144 | def _encode_object(self, o): 145 | if not o: 146 | return "{}" 147 | if self._put_on_single_line(o): 148 | return "{ " + ", ".join(f"{self.encode(k)}: {self.encode(el)}" for k, el in o.items()) + " }" 149 | self.indentation_level += 1 150 | output = [f"{self.indent_str}{json.dumps(k)}: {self.encode(v)}" for k, v in o.items()] 151 | 152 | self.indentation_level -= 1 153 | return "{\n" + ",\n".join(output) + "\n" + self.indent_str + "}" 154 | 155 | def iterencode(self, o, **kwargs): 156 | """Required to also work with `json.dump`.""" 157 | return self.encode(o) 158 | 159 | def _put_on_single_line(self, o): 160 | return self._primitives_only(o) and len(o) <= self.MAX_ITEMS and len(str(o)) - 2 <= self.MAX_WIDTH 161 | 162 | def _primitives_only(self, o): 163 | if isinstance(o, (list, tuple)): 164 | return not any(isinstance(el, self.CONTAINER_TYPES) for el in o) 165 | elif isinstance(o, dict): 166 | return not any(isinstance(el, self.CONTAINER_TYPES) for el in o.values()) 167 | 168 | @property 169 | def indent_str(self) -> str: 170 | if isinstance(self.indent, int): 171 | return " " * (self.indentation_level * self.indent) 172 | elif isinstance(self.indent, str): 173 | return self.indentation_level * self.indent 174 | else: 175 | raise ValueError(f"indent must either be of 
type int or str (is: {type(self.indent)})") 176 | 177 | 178 | if __name__ == "__main__": 179 | a = list(range(0, 12)) 180 | print(a) 181 | for e in yield_chunks(a, 7): 182 | print(e) 183 | --------------------------------------------------------------------------------
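The `__main__` block of `utils/tools.py` only prints the chunks produced by `yield_chunks`. The following is a minimal sketch of how `yield_chunks` and `AverageMeter` might be combined to track accuracy over batches. `yield_chunks` and the `AverageMeter.update` rule are copied from the file above; `running_accuracy` and the toy inputs are hypothetical and not part of the repository. Note that `update` adds `val` to the running sum without scaling it by `n`, so the sketch passes the per-batch count of correct predictions as `val` and the batch size as `n`.

```python
def yield_chunks(data, size):
    # Copied from utils/tools.py: split an iterable into lists of at most `size` items.
    data = list(data)
    for i in range(0, len(data), size):
        yield data[i : i + size]


class AverageMeter:
    """Same update rule as utils/tools.py: `val` is added to the sum unscaled."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count


def running_accuracy(predictions, labels, batch_size):
    # Hypothetical helper (not in the repository): accumulate accuracy batch by batch.
    meter = AverageMeter()
    pairs = zip(predictions, labels)
    for batch in yield_chunks(pairs, batch_size):
        correct = sum(1 for pred, gold in batch if pred == gold)
        # Pass the batch total as `val`, because update() does not multiply by n.
        meter.update(correct, n=len(batch))
    return meter.avg


if __name__ == "__main__":
    print(list(yield_chunks(range(12), 7)))
    print(running_accuracy([1, 0, 1, 1, 0], [1, 1, 1, 0, 0], batch_size=2))
```

If a caller instead passes a per-batch mean as `val` together with `n > 1`, the resulting `avg` under this update rule is no longer the overall mean, so it is worth checking how `update` is called elsewhere before relying on `avg`.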