├── README.MD
├── demo.ipynb
├── environment.yml
├── eval
│   ├── cnndm_example_input
│   │   ├── cnndm_1.5_-0.5.jsonl
│   │   ├── cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl
│   │   ├── cnndm_1_0.jsonl
│   │   └── cnndm_1_0.jsonl.output_topp0.9_genlen100.jsonl
│   ├── evaluate_summary.py
│   └── nqswap_example_input
│       ├── nqswap_1_0.jsonl
│       ├── nqswap_1_0.jsonl.output_topp0.0_genlen10.jsonl
│       ├── nqswap_2_-1.jsonl
│       └── nqswap_2_-1.jsonl.output_topp0.0_genlen10.jsonl
├── exp_cnndm.sh
├── exp_nqswap.sh
├── group_decode_fileio.py
├── image
│   ├── cad.png
│   ├── main.png
│   └── no_cad.png
└── run_group_decode_fileio.sh

/README.MD:
--------------------------------------------------------------------------------
1 | ## Context-aware decoding
2 | This repository provides the original implementation of [Trusting Your Evidence: Hallucinate Less with Context-aware Decoding](https://arxiv.org/pdf/2305.14739.pdf) by *Weijia Shi, *Xiaochuang Han, Mike Lewis, Yulia Tsvetkov, Luke Zettlemoyer, Scott Yih.
3 |
4 | ### Overview
5 |
6 | Context-aware decoding is a simple inference-time method that encourages language models to focus more on the given context.
7 |
8 | With regular decoding, we sample tokens at each position from an autoregressive language model with:
9 |
10 | $$y_t \sim p_\theta(y_t \mid \boldsymbol{c}, \boldsymbol{x}, \boldsymbol{y}_{<t})$$
11 |
12 | Context-aware decoding samples tokens at each position with:
13 |
14 | $$y_t \sim \mathrm{softmax}\left[(1+\alpha)\,\mathrm{logit}_\theta(y_t \mid \boldsymbol{c}, \boldsymbol{x}, \boldsymbol{y}_{<t}) - \alpha\,\mathrm{logit}_\theta(y_t \mid \boldsymbol{x}, \boldsymbol{y}_{<t})\right]$$
15 |
16 | Here $\boldsymbol{c}$ denotes the context that the model should focus more on, $\boldsymbol{x}$ the input query, $\boldsymbol{y}_{<t}$ the tokens generated so far, and $\alpha$ the adjustment strength ($\alpha = 0$ recovers regular decoding). Please see more details in [our paper](https://arxiv.org/abs/2305.14739). Here is a figure illustrating the difference between regular decoding and context-aware decoding.
17 |
18 |
19 |
20 |
21 | While context-aware decoding contrasts a single language model's predictions under two different input contexts, *our code more generally supports collaborative or contrastive decoding with multiple language models and different input contexts*. You can customize the inference setup through your own input `jsonl` file.
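Context-aware decoding amounts to a weighted sum of next-token logits followed by a softmax: weight $1+\alpha$ on the logits computed with the context and $-\alpha$ on those computed without it. The pure-Python sketch below illustrates this on a toy three-token vocabulary. The helper names (`softmax`, `combine_logits`, `cad_next_token_probs`) and the logit values are hypothetical. The per-vector weights mirror the per-instance `assigned_weight` idea, but the sketch deliberately omits the top-p filtering and multi-GPU `all_reduce` that `group_decode_fileio.py` performs.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def combine_logits(logit_sets: Sequence[Sequence[float]],
                   weights: Sequence[float]) -> List[float]:
    """Hypothetical helper: weighted sum of per-instance next-token logits.

    Each instance contributes weight * logits, similar in spirit to giving
    every instance an assigned_weight and summing the weighted scores.
    """
    if len(logit_sets) != len(weights):
        raise ValueError("need exactly one weight per logit vector")
    vocab_size = len(logit_sets[0])
    if any(len(logits) != vocab_size for logits in logit_sets):
        raise ValueError("all logit vectors must share one vocabulary")
    return [sum(w * logits[v] for w, logits in zip(weights, logit_sets))
            for v in range(vocab_size)]


def cad_next_token_probs(logits_with_context: Sequence[float],
                         logits_without_context: Sequence[float],
                         alpha: float = 0.5) -> List[float]:
    """softmax((1 + alpha) * logit(y | c, x) - alpha * logit(y | x))."""
    return softmax(combine_logits([logits_with_context, logits_without_context],
                                  [1 + alpha, -alpha]))


if __name__ == "__main__":
    # Toy vocabulary: the context supports "4", the model's prior prefers "3".
    vocab = ["3", "4", "other"]
    with_ctx = [2.0, 1.8, 0.0]     # hypothetical logits given context c and query x
    without_ctx = [3.0, 0.5, 0.0]  # hypothetical logits given query x only
    regular = softmax(with_ctx)
    cad = cad_next_token_probs(with_ctx, without_ctx, alpha=0.5)
    print("regular argmax:", vocab[max(range(3), key=regular.__getitem__)])  # regular argmax: 3
    print("CAD argmax:", vocab[max(range(3), key=cad.__getitem__)])          # CAD argmax: 4
```

With `alpha=0.5` the combined logits are 1.5 * [2.0, 1.8, 0.0] - 0.5 * [3.0, 0.5, 0.0] = [1.5, 2.45, 0.0], so the token supported by the context overtakes the one favoured by the prior, while `alpha=0` reduces to regular decoding. The two weights, 1.5 and -0.5, sum to 1, which matches the weight pair encoded in the `cnndm_1.5_-0.5.jsonl` file name.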
22 |
23 | ### Input format
24 | Below is an example input to our system (from NQ-Swap). The `//` comments are annotations for this README only; the actual `.jsonl` input files contain one comment-free JSON object per line.
25 | ```json
26 | {
27 |   "input_index": 0, // instances that decode together must share the same input_index
28 |   "assigned_model": "huggyllama/llama-7b", // the same model for all instances in context-aware decoding; different models can be used for, e.g., DExperts, contrastive decoding, or proxy tuning
29 |   "assigned_process": 0, // which GPU should take this instance
30 |   "context_string": "The fourth season of Chicago Fire , an American drama television series with executive producer Dick Wolf , and producers Derek Haas , Michael Brandt , and Matt Olmstead , was ordered on February 5 , 2015 , by NBC , and premiered on October 13 , 2015 and concluded on May 17 , 2016 . The season contained 1078 episodes . How many episodes are in chicago fire season 4 ?", // the context-aware input
31 |   "assigned_weight": 2, // weight for the current instance/process (1+alpha); by default the weights of instances sharing an input_index sum to 1, but they can also incorporate a sampling temperature if needed
32 |   "filter_p": 1.0 // optional filtering of low-probability tokens, disabled by default
33 | }
34 | {
35 |   "input_index": 0, // instances that decode together must share the same input_index
36 |   "assigned_model": "huggyllama/llama-7b", // the same model for all instances in context-aware decoding; different models can be used for, e.g., DExperts, contrastive decoding, or proxy tuning
37 |   "assigned_process": 1, // which GPU should take this instance
38 |   "context_string": "How many episodes are in chicago fire season 4 ?", // the context-unaware input
39 |   "assigned_weight": -1 // weight for the current instance/process (-alpha); by default the weights of instances sharing an input_index sum to 1, but they can also incorporate a sampling temperature if needed
40 | }
41 | ...
42 | ```
43 |
44 | ### Running context-aware decoding on CNN-DM and NQ-Swap
45 | Run `bash exp_cnndm.sh` or `bash exp_nqswap.sh`.
Both scripts call `run_group_decode_fileio.sh`, which in turn calls `group_decode_fileio.py`. The outputs are saved in the same directory as the input files, with a suffix recording the top-p and generation-length settings (e.g., `.output_topp0.9_genlen100.jsonl`).
46 |
47 | The conda environment we used is specified in `environment.yml`. The main packages are `pytorch`, `transformers`, and `accelerate`.
48 |
49 | ### Evaluation
50 | After generating predictions, run the following script to compare them with the gold data. It reports a FactKB score, ROUGE, and BERTScore.
51 | ```bash
52 | PRED_PATH=./eval/cnndm_example_input/cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl
53 | GOLD_DATA_PATH=./eval/cnndm_example_input/cnndm_1_0.jsonl
54 | python eval/evaluate_summary.py --pred_path $PRED_PATH --data_path $GOLD_DATA_PATH
55 | ```
56 | We provide our CNN-DM outputs for standard decoding (`cnndm_1_0.jsonl.output_topp0.9_genlen100.jsonl`) and context-aware decoding (`cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl`) in `eval/cnndm_example_input/`.
57 |
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Context Aware Decoding Demo"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "code",
12 |    "execution_count": 1,
13 |    "metadata": {},
14 |    "outputs": [],
15 |    "source": [
16 |     "test_token = \"\""
17 |    ]
18 |   },
19 |   {
20 |    "cell_type": "code",
21 |    "execution_count": 2,
22 |    "metadata": {},
23 |    "outputs": [
24 |     {
25 |      "name": "stderr",
26 |      "output_type": "stream",
27 |      "text": [
28 |       "2024-11-16 06:19:19.335650: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders.
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 29 | "2024-11-16 06:19:19.506195: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", 30 | "2024-11-16 06:19:19.506250: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", 31 | "2024-11-16 06:19:19.532940: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 32 | "2024-11-16 06:19:19.583956: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 33 | "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 34 | "2024-11-16 06:19:20.311520: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", 35 | "/home/kirekara/anaconda3/lib/python3.11/site-packages/pandas/core/arrays/masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n", 36 | " from pandas.core import (\n" 37 | ] 38 | }, 39 | { 40 | "data": { 41 | "application/vnd.jupyter.widget-view+json": { 42 | "model_id": "ac00eb23a0c74c33b3f837ad0a279351", 43 | "version_major": 2, 44 | "version_minor": 0 45 | }, 46 | "text/plain": [ 47 | "Loading checkpoint shards: 0%| | 0/2 [00:00 0\n" 75 | ] 76 | }, 77 | { 78 | "data": { 79 | "text/plain": [ 80 | "Gemma2ForCausalLM(\n", 81 | " (model): Gemma2Model(\n", 82 | " (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n", 83 | " (layers): ModuleList(\n", 84 | " (0-25): 26 x Gemma2DecoderLayer(\n", 85 
| " (self_attn): Gemma2Attention(\n", 86 | " (q_proj): Linear(in_features=2304, out_features=2048, bias=False)\n", 87 | " (k_proj): Linear(in_features=2304, out_features=1024, bias=False)\n", 88 | " (v_proj): Linear(in_features=2304, out_features=1024, bias=False)\n", 89 | " (o_proj): Linear(in_features=2048, out_features=2304, bias=False)\n", 90 | " (rotary_emb): Gemma2RotaryEmbedding()\n", 91 | " )\n", 92 | " (mlp): Gemma2MLP(\n", 93 | " (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)\n", 94 | " (up_proj): Linear(in_features=2304, out_features=9216, bias=False)\n", 95 | " (down_proj): Linear(in_features=9216, out_features=2304, bias=False)\n", 96 | " (act_fn): PytorchGELUTanh()\n", 97 | " )\n", 98 | " (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", 99 | " (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", 100 | " (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", 101 | " (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", 102 | " )\n", 103 | " )\n", 104 | " (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n", 105 | " )\n", 106 | " (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n", 107 | ")" 108 | ] 109 | }, 110 | "execution_count": 3, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 117 | "model.to(device)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "context = \"The current year is 2027. 
Argentina won World Cups in 1978,1986,2022 and 2026.\"\n",
127 |     "question = \"How many world cups has Argentina won?\"\n",
128 |     "\n",
129 |     "context_input = tokenizer(context, return_tensors=\"pt\").input_ids.to(device)\n",
130 |     "question_input = tokenizer(question, return_tensors=\"pt\").input_ids.to(device)\n",
131 |     "\n",
132 |     "input_ids = torch.cat([context_input, question_input], dim=-1)\n",
133 |     "\n",
134 |     "\n",
135 |     "def standard_decoding(input_ids, max_length=128, temperature=1.0, top_k=50, top_p=0.9):\n",
136 |     "    output_ids = model.generate(\n",
137 |     "        input_ids,\n",
138 |     "        max_length=max_length,\n",
139 |     "        temperature=temperature,\n",
140 |     "        top_k=top_k,\n",
141 |     "        top_p=top_p,\n",
142 |     "        do_sample=True,\n",
143 |     "    )\n",
144 |     "    return tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
145 |     "\n",
146 |     "def context_aware_sampling(model, tokenizer, input_ids, context_ids, alpha=0.9, max_length=128, temperature=1.0):\n",
147 |     "    generated_tokens = input_ids.clone()\n",
148 |     "    context_len = context_ids.shape[-1]  # number of context tokens; len(context_ids) would return the batch size (1)\n",
149 |     "    for _ in range(max_length):\n",
150 |     "        with torch.no_grad():\n",
151 |     "            full_context_outputs = model(generated_tokens)\n",
152 |     "            full_context_logits = full_context_outputs.logits[:, -1, :]\n",
153 |     "\n",
154 |     "            question_only_input = generated_tokens[:, context_len:]  # drop the context, keep question + generated tokens\n",
155 |     "            question_only_outputs = model(question_only_input)\n",
156 |     "            question_only_logits = question_only_outputs.logits[:, -1, :]\n",
157 |     "\n",
158 |     "        adjusted_logits = (1 + alpha) * full_context_logits - alpha * question_only_logits\n",
159 |     "        adjusted_probs = F.softmax(adjusted_logits / temperature, dim=-1)\n",
160 |     "\n",
161 |     "        next_token = torch.multinomial(adjusted_probs, num_samples=1)\n",
162 |     "\n",
163 |     "        generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)\n",
164 |     "\n",
165 |     "        if next_token.item() == tokenizer.eos_token_id:\n",
166 |     "            break\n",
167 |     "\n",
168 |     "    return generated_tokens"
169 |    ]
170 |   },
171 |   {
172 |
"cell_type": "code", 173 | "execution_count": 5, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "Standard Decoding Output:\n", 181 | " The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?\n", 182 | "\n", 183 | "**Answer:** 3\n", 184 | " \n", 185 | "\n", 186 | "____________________________________________________________________________________________________\n", 187 | "Context-Aware Decoding Output:\n", 188 | " The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?\n", 189 | "\n", 190 | "Therefore, Argentina has won **4** World Cup\n", 191 | "\n", 192 | "Argentina won the 2027 World Cup. \n", 193 | "\n", 194 | "\n", 195 | " \n", 196 | "\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "model.eval()\n", 202 | "standard_output = standard_decoding(input_ids)\n", 203 | "output_tokens = context_aware_sampling(\n", 204 | " model,\n", 205 | " tokenizer,\n", 206 | " input_ids,\n", 207 | " context_ids=context_input,\n", 208 | " alpha=0.5,\n", 209 | " max_length=50,\n", 210 | " temperature=1.0,\n", 211 | " )\n", 212 | "\n", 213 | "context_aware_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n", 214 | "\n", 215 | "\n", 216 | "print(\"Standard Decoding Output:\\n\", standard_output)\n", 217 | "print(\"__\" * 50)\n", 218 | "print(\"Context-Aware Decoding Output:\\n\", context_aware_output)\n" 219 | ] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "base", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.11.5" 239 | } 240 | 
}, 241 | "nbformat": 4, 242 | "nbformat_minor": 2 243 | } 244 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: cad 2 | channels: 3 | - pytorch-nightly 4 | - pytorch 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - ca-certificates=2023.01.10=h06a4308_0 12 | - certifi=2022.12.7=py38h06a4308_0 13 | - cuda=11.7.1=0 14 | - cuda-cccl=11.7.91=0 15 | - cuda-command-line-tools=11.7.1=0 16 | - cuda-compiler=11.7.1=0 17 | - cuda-cudart=11.7.99=0 18 | - cuda-cudart-dev=11.7.99=0 19 | - cuda-cuobjdump=11.7.91=0 20 | - cuda-cupti=11.7.101=0 21 | - cuda-cuxxfilt=11.7.91=0 22 | - cuda-demo-suite=12.0.140=0 23 | - cuda-documentation=12.0.140=0 24 | - cuda-driver-dev=11.7.99=0 25 | - cuda-gdb=12.0.140=0 26 | - cuda-libraries=11.7.1=0 27 | - cuda-libraries-dev=11.7.1=0 28 | - cuda-memcheck=11.8.86=0 29 | - cuda-nsight=12.0.140=0 30 | - cuda-nsight-compute=12.0.1=0 31 | - cuda-nvcc=11.7.99=0 32 | - cuda-nvdisasm=12.0.140=0 33 | - cuda-nvml-dev=11.7.91=0 34 | - cuda-nvprof=12.0.146=0 35 | - cuda-nvprune=11.7.91=0 36 | - cuda-nvrtc=11.7.99=0 37 | - cuda-nvrtc-dev=11.7.99=0 38 | - cuda-nvtx=11.7.91=0 39 | - cuda-nvvp=12.0.146=0 40 | - cuda-runtime=11.7.1=0 41 | - cuda-sanitizer-api=12.0.140=0 42 | - cuda-toolkit=11.7.1=0 43 | - cuda-tools=11.7.1=0 44 | - cuda-visual-tools=11.7.1=0 45 | - cudatoolkit=11.3.1=h2bc3f7f_2 46 | - gds-tools=1.5.1.14=0 47 | - gmp=6.2.1=h295c915_3 48 | - gmpy2=2.1.2=py38heeb90bb_0 49 | - intel-openmp=2021.4.0=h06a4308_3561 50 | - ld_impl_linux-64=2.38=h1181459_1 51 | - libcublas=11.10.3.66=0 52 | - libcublas-dev=11.10.3.66=0 53 | - libcufft=10.7.2.124=h4fbf590_0 54 | - libcufft-dev=10.7.2.124=h98a8f43_0 55 | - libcufile=1.5.1.14=0 56 | - libcufile-dev=1.5.1.14=0 57 | - libcurand=10.3.1.124=0 58 | - libcurand-dev=10.3.1.124=0 59 | - 
libcusolver=11.4.0.1=0 60 | - libcusolver-dev=11.4.0.1=0 61 | - libcusparse=11.7.4.91=0 62 | - libcusparse-dev=11.7.4.91=0 63 | - libffi=3.3=he6710b0_2 64 | - libgcc-ng=11.2.0=h1234567_1 65 | - libgomp=11.2.0=h1234567_1 66 | - libnpp=11.7.4.75=0 67 | - libnpp-dev=11.7.4.75=0 68 | - libnvjpeg=11.8.0.2=0 69 | - libnvjpeg-dev=11.8.0.2=0 70 | - libstdcxx-ng=11.2.0=h1234567_1 71 | - libuv=1.40.0=h7b6447c_0 72 | - mkl=2021.4.0=h06a4308_640 73 | - mkl-service=2.4.0=py38h7f8727e_0 74 | - mkl_fft=1.3.1=py38hd3c417c_0 75 | - mkl_random=1.2.2=py38h51133e4_0 76 | - mpc=1.1.0=h10f8cd9_1 77 | - mpfr=4.0.2=hb69a4c5_1 78 | - mpmath=1.2.1=py38h06a4308_0 79 | - ncurses=6.3=h7f8727e_2 80 | - ninja-base=1.10.2=hd09550d_5 81 | - nsight-compute=2022.4.1.6=0 82 | - numpy-base=1.22.3=py38hf524024_0 83 | - openssl=1.1.1t=h7f8727e_0 84 | - pip=21.2.4=py38h06a4308_0 85 | - python=3.8.13=h12debd9_0 86 | - pytorch=1.13.1=py3.8_cuda11.7_cudnn8.5.0_0 87 | - pytorch-cuda=11.7=h778d358_3 88 | - pytorch-mutex=1.0=cuda 89 | - readline=8.1.2=h7f8727e_1 90 | - setuptools=61.2.0=py38h06a4308_0 91 | - six=1.16.0=pyhd3eb1b0_1 92 | - sqlite=3.38.3=hc218d9a_0 93 | - sympy=1.11.1=py38h06a4308_0 94 | - tk=8.6.12=h1ccaba5_0 95 | - torchtriton=2.0.0+0d7e753227=py38 96 | - typing_extensions=4.1.1=pyh06a4308_0 97 | - wheel=0.37.1=pyhd3eb1b0_0 98 | - xz=5.2.5=h7f8727e_1 99 | - zlib=1.2.12=h7f8727e_2 100 | - pip: 101 | - absl-py==1.1.0 102 | - accelerate==0.16.0 103 | - aiohttp==3.8.1 104 | - aiosignal==1.2.0 105 | - amqp==5.1.1 106 | - anyio==3.6.1 107 | - argon2-cffi==21.3.0 108 | - argon2-cffi-bindings==21.2.0 109 | - asttokens==2.0.5 110 | - async-timeout==4.0.2 111 | - atpublic==3.0.1 112 | - attrs==21.4.0 113 | - babel==2.10.3 114 | - backcall==0.2.0 115 | - beautifulsoup4==4.11.1 116 | - bert-score==0.3.11 117 | - bidict==0.22.1 118 | - billiard==3.6.4.0 119 | - bleach==5.0.0 120 | - cachelib==0.10.1 121 | - cachetools==5.2.0 122 | - celery==5.2.7 123 | - cffi==1.15.0 124 | - charset-normalizer==2.0.12 125 
| - click==8.1.3 126 | - click-didyoumean==0.3.0 127 | - click-plugins==1.1.1 128 | - click-repl==0.2.0 129 | - colorama==0.4.5 130 | - cycler==0.11.0 131 | - datasets==2.10.0 132 | - debugpy==1.6.0 133 | - decorator==5.1.1 134 | - defusedxml==0.7.1 135 | - dill==0.3.5.1 136 | - entrypoints==0.4 137 | - executing==0.8.3 138 | - faiss-cpu==1.7.2 139 | - fastjsonschema==2.15.3 140 | - filelock==3.7.1 141 | - fire==0.5.0 142 | - flask==2.2.2 143 | - flask-session==0.4.0 144 | - flask-socketio==5.3.2 145 | - flufl-lock==7.0 146 | - fonttools==4.33.3 147 | - frozenlist==1.3.0 148 | - fsspec==2022.5.0 149 | - google-auth==2.8.0 150 | - google-auth-oauthlib==0.4.6 151 | - grpcio==1.46.3 152 | - hjson==3.1.0 153 | - huggingface-hub==0.11.0 154 | - idna==3.3 155 | - importlib-metadata==4.11.4 156 | - importlib-resources==5.8.0 157 | - ipykernel==6.15.0 158 | - ipython==8.4.0 159 | - ipython-genutils==0.2.0 160 | - itsdangerous==2.1.2 161 | - jedi==0.18.1 162 | - jinja2==3.1.2 163 | - joblib==1.1.0 164 | - json5==0.9.8 165 | - jsonschema==4.6.0 166 | - jupyter-client==7.3.4 167 | - jupyter-core==4.10.0 168 | - jupyter-server==1.17.1 169 | - jupyterlab==3.4.3 170 | - jupyterlab-pygments==0.2.2 171 | - jupyterlab-server==2.14.0 172 | - kiwisolver==1.4.3 173 | - kombu==5.2.4 174 | - lxml==4.9.1 175 | - markdown==3.3.7 176 | - markupsafe==2.1.1 177 | - matplotlib==3.5.2 178 | - matplotlib-inline==0.1.3 179 | - mauve-text==0.3.0 180 | - mistune==0.8.4 181 | - multidict==6.0.2 182 | - multiprocess==0.70.13 183 | - nbclassic==0.3.7 184 | - nbclient==0.6.4 185 | - nbconvert==6.5.0 186 | - nbformat==5.4.0 187 | - nest-asyncio==1.5.5 188 | - ninja==1.10.2.3 189 | - nltk==3.7 190 | - notebook==6.4.12 191 | - notebook-shim==0.1.0 192 | - numpy==1.22.4 193 | - oauthlib==3.2.0 194 | - packaging==21.3 195 | - pandas==1.4.2 196 | - pandocfilters==1.5.0 197 | - parso==0.8.3 198 | - pexpect==4.8.0 199 | - pickleshare==0.7.5 200 | - pillow==9.1.1 201 | - portalocker==2.5.1 202 | - 
prometheus-client==0.14.1 203 | - prompt-toolkit==3.0.29 204 | - protobuf==3.19.4 205 | - psutil==5.9.1 206 | - ptyprocess==0.7.0 207 | - pure-eval==0.2.2 208 | - py-cpuinfo==8.0.0 209 | - pyarrow==8.0.0 210 | - pyasn1==0.4.8 211 | - pyasn1-modules==0.2.8 212 | - pycparser==2.21 213 | - pydantic==1.10.2 214 | - pygments==2.12.0 215 | - pyparsing==3.0.9 216 | - pyrsistent==0.18.1 217 | - pysbd==0.3.4 218 | - python-dateutil==2.8.2 219 | - python-engineio==4.3.4 220 | - python-socketio==5.7.2 221 | - pytz==2022.1 222 | - pyyaml==6.0 223 | - pyzmq==23.1.0 224 | - redis==4.5.3 225 | - regex==2022.6.2 226 | - requests==2.28.0 227 | - requests-oauthlib==1.3.1 228 | - responses==0.18.0 229 | - rouge==1.0.0 230 | - rsa==4.8 231 | - sacrebleu==2.2.1 232 | - scikit-learn==1.1.1 233 | - scipy==1.8.1 234 | - seaborn==0.11.2 235 | - send2trash==1.8.0 236 | - sentencepiece==0.1.96 237 | - sniffio==1.2.0 238 | - soupsieve==2.3.2.post1 239 | - stack-data==0.3.0 240 | - tabulate==0.9.0 241 | - tensorboard==2.9.1 242 | - tensorboard-data-server==0.6.1 243 | - tensorboard-plugin-wit==1.8.1 244 | - tensorboardx==2.5.1 245 | - termcolor==1.1.0 246 | - terminado==0.15.0 247 | - threadpoolctl==3.1.0 248 | - tinycss2==1.1.1 249 | - tokenizers==0.12.1 250 | - tornado==6.1 251 | - tqdm==4.64.0 252 | - traitlets==5.3.0 253 | - transformers==4.26.0 254 | - urllib3==1.26.9 255 | - vine==5.0.0 256 | - wcwidth==0.2.5 257 | - webencodings==0.5.1 258 | - websocket-client==1.3.2 259 | - werkzeug==2.2.2 260 | - xxhash==3.0.0 261 | - yacs==0.1.8 262 | - yarl==1.7.2 263 | - zipp==3.8.0 264 | prefix: /private/home/xhan77/.conda/envs/cad 265 | -------------------------------------------------------------------------------- /eval/evaluate_summary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from tqdm import tqdm 4 | from pathlib import Path 5 | # from datasets import load_dataset 6 | # from evaluate import load 7 | import 
statistics
8 | import json
9 | from collections import defaultdict
10 | import os
11 | import evaluate
12 | # from ipdb import set_trace as bp  # optional debugger; ipdb is not listed in environment.yml
13 | import torch
14 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
15 | from datasets import load_dataset
16 |
17 | # FactKB evaluation: put your Hugging Face access token here
18 | access_token = None  # replace None with your Hugging Face access token string
19 | tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
20 | factkb = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels = 2, use_auth_token=access_token)
21 |
22 | def evaluate_qa(index2ex, eval_file):
23 |     print(eval_file)
24 |     all_gold = []
25 |     all_pred = []
26 |     all_doc = []
27 |     all_fact_score = []
28 |
29 |     if not os.path.exists(eval_file):
30 |         return 0
31 |     with open(eval_file, "r") as f:
32 |         output_data = [json.loads(line) for line in f]
33 |     cov_em_all = []
34 |     category2em = defaultdict(list)
35 |     id2ex_output = {}
36 |     for i, output in enumerate(output_data):
37 |         index = output["input_index"]
38 |         pred = output["string"][0]
39 |         gold = index2ex[index]["gold_answers"]
40 |         if len(pred) < 3:
41 |             print(pred)
42 |             continue
43 |         all_gold.append(gold)
44 |         all_pred.append(pred)
45 |         if len(pred) < 3:
46 |             print(f"pred: {pred}")
47 |
48 |         article = index2ex[index]["article"]
49 |         summary = pred
50 |         pair = [[summary, article]]  # one (summary, article) pair scored by FactKB
51 |         tokens = tokenizer(pair, return_tensors="pt", padding="max_length", truncation=True)
52 |         with torch.no_grad():
53 |             result = torch.softmax(factkb(**tokens).logits, dim=1)
54 |         fact_score = result[0][1].item()  # probability of label 1, used as the factuality score
55 |
56 |         all_fact_score.append(fact_score)
57 |         all_doc.append(article)
58 |         output_dict = index2ex[index].copy()
59 |         output_dict["pred"] = pred
60 |         id2ex_output[i] = output_dict
61 |
62 |     print("fact_score: ", statistics.mean(all_fact_score))
63 |     # print(statistics.mean(cov_em_all))
64 |     rouge = evaluate.load('rouge')
65 |     results = rouge.compute(predictions=all_pred, references=all_gold)
66 |
| print("rouge results: ", results) 67 | 68 | bertscore = evaluate.load("bertscore") 69 | results = bertscore.compute(predictions=all_pred, references=all_doc, lang="en") 70 | # print("bertscore: ", results) 71 | print("bertscore: ") 72 | for k, v in results.items(): 73 | if k in ["precision", "recall", "f1"]: 74 | print(f"{k}: {statistics.mean(v)}") 75 | return id2ex_output 76 | 77 | # read data 78 | def entity_data(dataset_path): 79 | raw_data = [] 80 | with open(dataset_path) as f: 81 | for line in f: 82 | ex = json.loads(line) 83 | if ex["assigned_process"] == 0: 84 | raw_data.append(ex) 85 | # break 86 | # raw_data = json.loads(f.read()) 87 | return raw_data 88 | 89 | 90 | if __name__ == "__main__": 91 | # args parse 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--data_path", type=str, default="./eval/cnndm_example_input/cnndm_1_0.jsonl") 94 | parser.add_argument("--pred_path", type=str, default="./eval/cnndm_example_input/cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl") 95 | args = parser.parse_args() 96 | 97 | data_path = args.data_path 98 | pred_path = args.pred_path 99 | index2ex = entity_data(data_path) 100 | evaluate_qa(index2ex, pred_path) 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /exp_cnndm.sh: -------------------------------------------------------------------------------- 1 | # cnndm 2 | GLOBALLEN="2048" 3 | MAXCTXLEN="1948" 4 | GENLEN="100" 5 | FN_PREFIX="eval/cnndm_example_input/cnndm" 6 | 7 | for WEIGHT in "1_0" "1.5_-0.5" 8 | do 9 | TOPP="0.9" 10 | TESTFILE="fin|${FN_PREFIX}_${WEIGHT}.jsonl" 11 | bash run_group_decode_fileio.sh 2023 "0,1" $TESTFILE $GLOBALLEN $MAXCTXLEN $GENLEN $TOPP 12 | done 13 | -------------------------------------------------------------------------------- /exp_nqswap.sh: -------------------------------------------------------------------------------- 1 | # nq-swap 2 | GLOBALLEN="2048" 3 | MAXCTXLEN="2038" 4 | GENLEN="10" 5 | 
FN_PREFIX="eval/nqswap_example_input/nqswap" 6 | 7 | for WEIGHT in "1_0" "2_-1" 8 | do 9 | TOPP="0.0" 10 | TESTFILE="fin|${FN_PREFIX}_${WEIGHT}.jsonl" 11 | bash run_group_decode_fileio.sh 2023 "0,1" $TESTFILE $GLOBALLEN $MAXCTXLEN $GENLEN $TOPP 12 | done 13 | -------------------------------------------------------------------------------- /group_decode_fileio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import datasets 6 | import torch 7 | 8 | import transformers 9 | import accelerate 10 | from accelerate import Accelerator 11 | from transformers import ( 12 | CONFIG_MAPPING, 13 | MODEL_MAPPING, 14 | AutoConfig, 15 | AutoModel, 16 | AutoModelForMaskedLM, 17 | AutoModelForCausalLM, 18 | AutoModelForSeq2SeqLM, 19 | AutoTokenizer, 20 | SchedulerType, 21 | ) 22 | 23 | import numpy as np 24 | from termcolor import colored 25 | import json 26 | from accelerate import InitProcessGroupKwargs 27 | import datetime 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def logits_sampling_projection(logits, top_p, one_hot_value): 36 | assert len(logits.size()) == 3 37 | 38 | # get top-p indices 39 | probs = torch.nn.functional.softmax(logits, dim=-1) 40 | sorted_probs, indices = torch.sort(probs, dim=-1, descending=True) 41 | cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) 42 | nucleus = cum_sum_probs < top_p 43 | nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1) 44 | valid_indices = nucleus.scatter(2, indices, nucleus) 45 | 46 | filtered_logits = logits.masked_fill(valid_indices == 0, torch.finfo(logits.dtype).min) 47 | m = torch.distributions.categorical.Categorical(logits=filtered_logits) 48 | selected = m.sample() 49 | return (2 * one_hot_value * torch.nn.functional.one_hot(selected, 
logits.size(2)) - one_hot_value) 50 | 51 | 52 | def filter_logits_top_p(logits, top_p, negative_multiplier=False): 53 | assert len(logits.size()) == 3 54 | 55 | # get top-p indices 56 | probs = torch.nn.functional.softmax(logits, dim=-1) 57 | sorted_probs, indices = torch.sort(probs, dim=-1, descending=True) 58 | cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) 59 | nucleus = cum_sum_probs < top_p 60 | nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1) 61 | valid_indices = nucleus.scatter(2, indices, nucleus) 62 | 63 | if negative_multiplier: 64 | filtered_logits = logits.masked_fill(valid_indices == 0, 1000) 65 | else: 66 | filtered_logits = logits.masked_fill(valid_indices == 0, -1000) 67 | return filtered_logits 68 | 69 | 70 | def decode(args, batch_input_ids, dec_depth, model, tokenizer): 71 | batch_size = args.per_device_eval_batch_size 72 | assert batch_input_ids.size(1) == args.context_size 73 | assert args.decode_truncate_len >= 0 74 | assert (args.max_seq_length - args.context_size - args.decode_truncate_len) % dec_depth == 0 75 | unit_seq_len = int((args.max_seq_length - args.context_size - args.decode_truncate_len) / dec_depth) 76 | if args.context_size > 0: 77 | unit_context_input_ids = batch_input_ids[:, :args.context_size].clone() 78 | else: 79 | raise ValueError("context cannot be none") 80 | history_decode_ids = None 81 | 82 | past_key_values = None # necessary for causal models 83 | if args.model_category == 'seq2seq': 84 | model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation( 85 | batch_input_ids[:, :args.context_size].clone(), dict(), None 86 | ) # this step includes encoding the context 87 | history_decode_ids = model._prepare_decoder_input_ids_for_generation( 88 | batch_input_ids.size(0), 89 | model_kwargs=model_kwargs, 90 | device=batch_input_ids.device, 91 | ) # create placeholder starter seq for decoding 92 | else: 93 | model_kwargs = None 94 | 95 | for _i in range(dec_depth): 96 | 
if args.model_category == 'causal': 97 | model_inputs = model.prepare_inputs_for_generation(unit_context_input_ids, past_key_values=past_key_values) 98 | outputs = model(**model_inputs, output_hidden_states=False) 99 | elif args.model_category == 'seq2seq': 100 | model_inputs = model.prepare_inputs_for_generation(history_decode_ids, **model_kwargs) # this incorporates past_key_values 101 | outputs = model(**model_inputs, output_hidden_states=False) 102 | else: 103 | raise ValueError("model category not supported") 104 | 105 | score = outputs.logits[:, -1:, :].clone().contiguous() 106 | 107 | if args.assigned_weight >= 0: 108 | score = filter_logits_top_p(score, top_p=args.filter_top_p) 109 | else: 110 | score = filter_logits_top_p(score, top_p=args.filter_top_p_prior, negative_multiplier=True) 111 | 112 | score = args.assigned_weight * score 113 | torch.distributed.all_reduce(score, group=args.gathering_group) 114 | 115 | projected_logits = logits_sampling_projection(score, top_p=args.projection_top_p, one_hot_value=args.one_hot_value) 116 | 117 | if not args.accelerator.is_main_process: 118 | projected_logits = torch.zeros_like(projected_logits) 119 | torch.distributed.all_reduce(projected_logits, group=args.gathering_group) 120 | 121 | simplex = torch.nn.functional.softmax(projected_logits, dim=-1) 122 | real_token_ids_list = torch.argmax(simplex, dim=-1).view(batch_size, unit_seq_len) 123 | 124 | if args.model_category == 'causal': 125 | unit_context_input_ids = torch.cat((unit_context_input_ids, real_token_ids_list), dim=1) # not really necessary but keeping 126 | 127 | if history_decode_ids is None: 128 | history_decode_ids = real_token_ids_list 129 | else: 130 | history_decode_ids = torch.cat((history_decode_ids, real_token_ids_list), dim=1) 131 | 132 | if args.model_category == 'causal': 133 | past_key_values = outputs.past_key_values 134 | elif args.model_category == 'seq2seq': 135 | model_kwargs["past_key_values"] = outputs.past_key_values 136 | 137 | # 
HACK: stop when seeing eos token, but asserting batch size is 1, unit_seq_len is 1, optimize later
138 |         assert real_token_ids_list.size(0) == 1
139 |         assert real_token_ids_list.size(1) == 1
140 |         if real_token_ids_list[0][-1] == model.generation_config.eos_token_id:
141 |             break
142 |
143 |     if args.context_size > 0:
144 |         init_context_input_ids = batch_input_ids[:, :args.context_size].clone()
145 |         context_sequences = tokenizer.batch_decode(init_context_input_ids.detach().to('cpu'))#, skip_special_tokens=True)
146 |     else:
147 |         init_context_input_ids = None
148 |         context_sequences = None
149 |     sampled_sequences = tokenizer.batch_decode(history_decode_ids.clone().detach().to('cpu'), skip_special_tokens=True)
150 |     logger.info(f"context: {context_sequences}")
151 |     logger.info(f"sampled: {colored(str(sampled_sequences), 'red')}")
152 |
153 |     return history_decode_ids, init_context_input_ids, None, sampled_sequences, context_sequences, None
154 |
155 |
156 | def parse_args():
157 |     parser = argparse.ArgumentParser(description="Decode with one or more weighted language models from a JSONL input file")
158 |     parser.add_argument(
159 |         "--model_name_or_path",
160 |         type=str,
161 |         help="Path to pretrained model or model identifier from huggingface.co/models.",
162 |         required=True,
163 |     )
164 |     parser.add_argument(
165 |         "--config_name",
166 |         type=str,
167 |         default=None,
168 |         help="Pretrained config name or path if not the same as model_name",
169 |     )
170 |     parser.add_argument(
171 |         "--tokenizer_name",
172 |         type=str,
173 |         default=None,
174 |         help="Pretrained tokenizer name or path if not the same as model_name",
175 |     )
176 |     parser.add_argument(
177 |         "--use_slow_tokenizer",
178 |         action="store_true",
179 |         help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
180 |     )
181 |     parser.add_argument(
182 |         "--per_device_eval_batch_size",
183 |         type=int,
184 |         default=1,
185 |         help="Batch size (per device) for the evaluation dataloader.",
186 |     )
    parser.add_argument("--output_dir", type=str, default=None, help="Unused in decoding; outputs are written next to the input jsonl file.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible decoding.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=None,
        help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
    )
    parser.add_argument("--init_blank_language_model", action="store_true", help="Whether or not to use a completely blank LM.")
    parser.add_argument(
        "--file_mode", type=str, default="", help="'fin|<path to input jsonl>'",
    )
    parser.add_argument(
        "--train_mode", type=str, default="", help="Must be 'decode'.",
    )
    parser.add_argument(
        "--decode_truncate_len", type=int, default=50,
        help="How many positions to cut from the right of max_seq_length; the input context length is subtracted per example.",
    )
    parser.add_argument(
        "--decode_depth", type=int, default=2, help="Generation length (the 'genlen' in the output file name).",
    )
    parser.add_argument(
        "--projection_top_p", type=float, default=0.2, help="Top-p threshold for decoding (the 'topp' in the output file name).",
    )
    parser.add_argument(
        "--filter_top_p", type=float, default=1.0, help="Default token-filtering top-p; overridden by 'filter_p' in the input jsonl.",
    )
    parser.add_argument(
        "--filter_top_p_prior", type=float, default=1.0, help="Default for 'filter_p_prior'; overridden by that field in the input jsonl.",
    )
    parser.add_argument(
        "--big_model_inference", type=str, default="no",
        help="'yes' to split each model over four GPUs (assumes 8 GPUs and two processes).",
    )
    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=259200))])
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        # set_seed(args.seed)
        accelerate.utils.set_seed(args.seed, device_specific=True) # differ slightly for each device

    if accelerator.is_main_process:
        pass
        # if args.output_dir is not None:
        #     os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    if args.train_mode == "decode":
        if len(args.model_name_or_path.split('|')) > 1: # "main|fallback"; a single name has no fallback
            main_model_name = args.model_name_or_path.split('|')[0]
            fallback_model_name = args.model_name_or_path.split('|')[1]
            args.model_name_or_path = main_model_name
            args.orig_model_name_or_path = fallback_model_name
        else:
            args.orig_model_name_or_path = args.model_name_or_path
    else:
        raise ValueError("training should be in a separate file (irrelevant in context-aware decoding)")

    # Han: assign ensemble models
    args.file_mode = args.file_mode.split('|')
    assert args.file_mode[0] == "fin"
    assert os.path.exists(args.file_mode[1])
    fin_path = args.file_mode[1]
    fin_data = []
    with open(fin_path, 'r', encoding='utf-8') as f:
        for line in f:
            proc_line = line.strip()
            if proc_line:
                fin_data.append(json.loads(proc_line))
    rank2model = dict()
    for _fd in fin_data:
        if _fd['assigned_process'] in rank2model: # sanity check
            assert ' '.join(rank2model[_fd['assigned_process']]) == ' '.join(_fd['assigned_model'].split('|'))
        else:
            rank2model[_fd['assigned_process']] = _fd['assigned_model'].split('|')

    # Han: add gathering group
    default_backend = torch.distributed.get_backend(torch.distributed.distributed_c10d._get_default_group())
    args.gathering_group = torch.distributed.new_group(ranks=list(sorted(rank2model.keys())), backend=default_backend)

    if accelerator.process_index not in rank2model.keys(): # Han: exit if not in the ensemble
        return
    args.model_name_or_path = rank2model[accelerator.process_index][0]

    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        if 'llama' in args.model_name_or_path.lower():
            from transformers import LlamaConfig
            config = LlamaConfig.from_pretrained(args.model_name_or_path)
        else:
            config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if 'neox' in args.model_name_or_path.lower(): # Han: gpt-neox doesn't have a slow tokenizer, use GPTNeoXTokenizerFast
        from transformers import GPTNeoXTokenizerFast
        tokenizer = GPTNeoXTokenizerFast.from_pretrained(args.model_name_or_path)
    elif 'llama' in args.model_name_or_path.lower():
        from transformers import LlamaTokenizer
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    else:
        assert args.use_slow_tokenizer == True
        if args.tokenizer_name:
            tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
        elif args.model_name_or_path:
            tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
        else:
            raise ValueError(
                "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
                "You can do it from another script, save it, and load it from here, using --tokenizer_name."
            )

    if args.init_blank_language_model:
        raise ValueError("disabled")
        model = AutoModelForMaskedLM.from_config(config)
    elif args.model_name_or_path:
        if 't5' in args.model_name_or_path.lower() or 'tk' in args.model_name_or_path.lower():
            model = AutoModelForSeq2SeqLM.from_pretrained(
                args.model_name_or_path,
                from_tf=bool(".ckpt" in args.model_name_or_path),
                config=config,
                ignore_mismatched_sizes=False,
                torch_dtype=torch.float16,
            )
            args.model_category = 'seq2seq'
            model = model.to(accelerator.device)
        else:
            if 'llama' in args.model_name_or_path.lower(): # llama special case
                from transformers import LlamaForCausalLM
                if args.big_model_inference == 'no':
                    model = LlamaForCausalLM.from_pretrained(
                        args.model_name_or_path,
                        torch_dtype=torch.float16,
                    )
                    model = model.to(accelerator.device)
                else:
                    # Han: we assume 8 GPUs
                    if accelerator.process_index == 0:
                        local_devices = [0, 2, 4, 6]
                    elif accelerator.process_index == 1:
                        local_devices = [1, 3, 5, 7]
                    else:
                        raise ValueError("check accelerator.process_index")
                    # this is architecture specific
                    my_device_map = {'model.embed_tokens': local_devices[0],
                                     'lm_head': local_devices[0],
                                     'model.norm': local_devices[0]}
                    for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
                        for layer_idx in layer_idx_list:
                            my_device_map[f'model.layers.{layer_idx}'] = local_devices[_device_i]
                    model = LlamaForCausalLM.from_pretrained(
                        args.model_name_or_path,
                        device_map=my_device_map,
                        torch_dtype=torch.float16,
                    )
            elif args.big_model_inference == 'no':
                model = AutoModelForCausalLM.from_pretrained(
                    args.model_name_or_path,
                    from_tf=bool(".ckpt" in args.model_name_or_path),
                    config=config,
                    ignore_mismatched_sizes=False,
                    torch_dtype=torch.float16,
                )
                model = model.to(accelerator.device)
            elif args.big_model_inference == 'yes' and 'opt' in args.model_name_or_path.lower():
                # Han: we assume 8 GPUs
                if accelerator.process_index == 0:
                    local_devices = [0, 2, 4, 6]
                elif accelerator.process_index == 1:
                    local_devices = [1, 3, 5, 7]
                else:
                    raise ValueError("check accelerator.process_index")
                # this is architecture specific
                my_device_map = {'model.decoder.embed_tokens': local_devices[0],
                                 'lm_head': local_devices[0],
                                 'model.decoder.embed_positions': local_devices[0],
                                 'model.decoder.final_layer_norm': local_devices[0]}
                for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
                    for layer_idx in layer_idx_list:
                        my_device_map[f'model.decoder.layers.{layer_idx}'] = local_devices[_device_i]
                model = AutoModelForCausalLM.from_pretrained(
                    args.model_name_or_path,
                    from_tf=bool(".ckpt" in args.model_name_or_path),
                    config=config,
                    ignore_mismatched_sizes=False,
                    device_map=my_device_map,
                    torch_dtype=torch.float16,
                )
            elif args.big_model_inference == 'yes' and 'neox' in args.model_name_or_path.lower():
                # Han: we assume 8 GPUs
                if accelerator.process_index == 0:
                    local_devices = [0, 2, 4, 6]
                elif accelerator.process_index == 1:
                    local_devices = [1, 3, 5, 7]
                else:
                    raise ValueError("check accelerator.process_index")
                # this is architecture specific
                my_device_map = {'gpt_neox.embed_in': local_devices[0],
                                 'embed_out': local_devices[0],
                                 'gpt_neox.final_layer_norm': local_devices[0]}
                for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
                    for layer_idx in layer_idx_list:
                        my_device_map[f'gpt_neox.layers.{layer_idx}'] = local_devices[_device_i]
                model = AutoModelForCausalLM.from_pretrained(
                    args.model_name_or_path,
                    from_tf=bool(".ckpt" in args.model_name_or_path),
                    config=config,
                    ignore_mismatched_sizes=False,
                    device_map=my_device_map,
                    torch_dtype=torch.float16,
                )
            elif args.big_model_inference == 'yes' and 'neo' in args.model_name_or_path.lower():
                # Han: we assume 8 GPUs
                if accelerator.process_index == 0:
                    local_devices = [0, 2, 4, 6]
                elif accelerator.process_index == 1:
                    local_devices = [1, 3, 5, 7]
                else:
                    raise ValueError("check accelerator.process_index")
                # this is architecture specific
                my_device_map = {'transformer.wte': local_devices[0],
                                 'lm_head': local_devices[0],
                                 'transformer.wpe': local_devices[0],
                                 'transformer.drop': local_devices[0],
                                 'transformer.ln_f': local_devices[0]}
                for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
                    for layer_idx in layer_idx_list:
                        my_device_map[f'transformer.h.{layer_idx}'] = local_devices[_device_i]
                model = AutoModelForCausalLM.from_pretrained(
                    args.model_name_or_path,
                    from_tf=bool(".ckpt" in args.model_name_or_path),
                    config=config,
                    ignore_mismatched_sizes=False,
                    device_map=my_device_map,
                    torch_dtype=torch.float16,
                )
            else:
                raise ValueError("check args.big_model_inference")

            args.model_category = 'causal'
            model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward) # referred to https://github.com/huggingface/accelerate/blob/38fd30e764ea87ef86e7d69fcba559c3605925b1/src/accelerate/accelerator.py#L1138
            model.forward = accelerate.utils.convert_outputs_to_fp32(model.forward)
    else:
        raise ValueError("specify --model_name_or_path")

    model.resize_token_embeddings(len(tokenizer))

    logger.info(f"model size: {sum(p.numel() for p in model.parameters())}")
    vocab_size = model.get_input_embeddings().weight.size(0)
    hidden_size = model.get_input_embeddings().weight.size(1)
    one_hot_value = 5.0 # unused

    ##########################################

    # change the output file name later
    out_json_fn = f"{fin_path}.output_topp{args.projection_top_p}_genlen{args.decode_depth}.jsonl"
    if accelerator.is_main_process:
        with open(out_json_fn, 'w') as f:
            f.write('placeholder, program not finished ...\n')

    args.tokenizer = tokenizer

    if args.train_mode == "decode":
        model.eval()

        args.one_hot_value = one_hot_value
        args.vocab_size = vocab_size
        args.hidden_size = hidden_size
        args.accelerator = accelerator

        export_list = []
        args.orig_decode_truncate_len = args.decode_truncate_len
        with torch.no_grad():
            for _fd in fin_data: # only support batch size 1 for now since the context size can be different across lines
                if _fd['assigned_process'] != args.accelerator.process_index: # remember to unblock barriers before this line
                    continue
                args.assigned_weight = _fd['assigned_weight']

                ctx_field_name = 'context_string'
                assert ctx_field_name in _fd
                assert args.per_device_eval_batch_size == 1

                input_ids = torch.LongTensor(tokenizer.encode(_fd[ctx_field_name], add_special_tokens=True)).unsqueeze(0).to(args.accelerator.device)
                args.context_size = input_ids.size(1)
                args.decode_truncate_len = args.orig_decode_truncate_len - args.context_size # Han: this compensates for the unknown input context size

                if 'filter_p' in _fd: # token filtering
                    args.filter_top_p = _fd['filter_p']
                if 'filter_p_prior' in _fd:
                    args.filter_top_p_prior = _fd['filter_p_prior']

                if args.decode_truncate_len < 0:
                    continue # skipping very long examples
                logger.info(f"idx: {_fd['input_index']}")

                repeat_sample = 1 # change here manually if necessary
                for _r in range(repeat_sample):
                    history_decode_ids, _, _, sampled_sequences, _, _ = \
                        decode(args, input_ids, args.decode_depth, model, tokenizer)
                    if _r == 0: # first sample
                        # export to jsonl
                        for _i in range(args.per_device_eval_batch_size):
                            export_dict = dict()
                            export_dict['tokens'] = [history_decode_ids.tolist()[_i]]
                            export_dict['string'] = [sampled_sequences[_i]]
                            export_dict['assigned_process'] = _fd['assigned_process']
                            export_dict['assigned_model'] = args.model_name_or_path
                            export_dict['output_index'] = len(export_list)
                            export_dict['input_index'] = _fd['input_index']
                            export_list.append(export_dict)
                    else:
                        for _i in range(args.per_device_eval_batch_size):
                            export_list[-(args.per_device_eval_batch_size - _i)]['tokens'].append(history_decode_ids.tolist()[_i])
                            export_list[-(args.per_device_eval_batch_size - _i)]['string'].append(sampled_sequences[_i])

        if accelerator.is_main_process:
            if os.path.exists(out_json_fn):
                os.remove(out_json_fn)
                logger.info(f"Cleaning existing {out_json_fn}")
            with open(out_json_fn, mode="w") as f_out: # use mode 'a' if several processes are writing to the same file
                for export in export_list:
                    f_out.write(json.dumps(export))
                    f_out.write("\n")


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/image/cad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xhan77/context-aware-decoding/1281d7fcd0e7e49d786c78d3170a3048dd9f6dc2/image/cad.png
--------------------------------------------------------------------------------
/image/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xhan77/context-aware-decoding/1281d7fcd0e7e49d786c78d3170a3048dd9f6dc2/image/main.png
--------------------------------------------------------------------------------
/image/no_cad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xhan77/context-aware-decoding/1281d7fcd0e7e49d786c78d3170a3048dd9f6dc2/image/no_cad.png
--------------------------------------------------------------------------------
/run_group_decode_fileio.sh:
--------------------------------------------------------------------------------
#!/usr/bin/bash
# trap "kill 0" EXIT

hf_cache="/private/home/xhan77/.cache/huggingface" # CHANGE THIS TO YOUR OWN CACHE PATH

numgpu=2 # should match the number of processes in the input jsonl file, default to 2 for context-aware decoding
available_port=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')
global_seed=$1
multi_device_cuda=$2 # "0,1,2,3", "0", etc
core_lm_name="specify_in_input_jsonl|n/a" # facebook/opt-1.3b, google/flan-t5-xl, etc.

file_mode=$3
global_max_seq_len=$4 # should be consistent with (<=) the specified models' max_seq_len
decode_truncate_len=$5 # effective gen len is global_max_seq_len - decode_truncate_len
decode_depth=$6
projection_top_p=$7

################ START ################

CUDA_VISIBLE_DEVICES=${multi_device_cuda} HF_HOME=${hf_cache} accelerate launch \
  --multi_gpu --mixed_precision fp16 \
  --num_processes ${numgpu} --num_machines 1 --machine_rank 0 \
  --main_process_port ${available_port} \
  --num_cpu_threads_per_process 10 \
  group_decode_fileio.py \
  --max_seq_length ${global_max_seq_len} \
  --model_name_or_path ${core_lm_name} \
  --seed ${global_seed} \
  --use_slow_tokenizer \
  --file_mode ${file_mode} \
  --decode_truncate_len ${decode_truncate_len} \
  --decode_depth ${decode_depth} \
  --train_mode decode \
  --projection_top_p ${projection_top_p}

--------------------------------------------------------------------------------
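With `--big_model_inference yes`, each process in `group_decode_fileio.py` splits its model over four GPUs (even devices for process 0, odd devices for process 1): the embeddings, final norm, and output head stay on the first local device, and the decoder layers are cut into contiguous chunks with `np.array_split`. That mapping can be checked without loading any weights. The sketch below mirrors the LLaMA branch as a standalone function; `build_llama_device_map` is a hypothetical helper, not part of the repository, and the 32-layer count is only an example.

```python
import numpy as np


def build_llama_device_map(num_hidden_layers, local_devices):
    """Rebuild the LLaMA device map from group_decode_fileio.py (hypothetical helper).

    Embeddings, the final norm, and the LM head go to the first local device;
    decoder layers are split into contiguous, near-equal chunks over local_devices.
    """
    device_map = {
        "model.embed_tokens": local_devices[0],
        "lm_head": local_devices[0],
        "model.norm": local_devices[0],
    }
    chunks = np.array_split(np.arange(num_hidden_layers), len(local_devices))
    for device_i, layer_idx_list in enumerate(chunks):
        for layer_idx in layer_idx_list:
            device_map[f"model.layers.{layer_idx}"] = local_devices[device_i]
    return device_map


# Process 0 in the script uses GPUs [0, 2, 4, 6]; 32 layers become 4 chunks of 8.
rank0_map = build_llama_device_map(32, [0, 2, 4, 6])
print(rank0_map["model.layers.7"], rank0_map["model.layers.8"], rank0_map["model.layers.31"])
# prints: 0 2 6
```

When the layer count does not divide evenly, `np.array_split` gives the earlier chunks one extra layer each, so the first GPU, which also holds the embeddings and LM head, ends up with the largest share.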
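Each input line carries an `assigned_weight`, and the lines that share an `input_index` are decoded together. Context-aware decoding samples from $\mathrm{softmax}[(1+\alpha)\,\mathrm{logit}(y_t \mid \boldsymbol{c}, \boldsymbol{x}) - \alpha\,\mathrm{logit}(y_t \mid \boldsymbol{x})]$, which is what the NQ-Swap example weights 2 and -1 (α = 1) encode. The sketch below applies that weighted sum to next-token logits already collected in one place. It is a NumPy illustration of the formula under that assumption, not the repository's `decode` function, where each weighted input runs in its own process (see `assigned_process` and the `gathering_group` created in `main()`). `combine_weighted_logits` is a hypothetical name.

```python
import numpy as np


def combine_weighted_logits(logits_per_input, weights):
    """Softmax over a weighted sum of next-token logits (hypothetical helper).

    logits_per_input: shape (num_inputs, vocab_size), one row per input line.
    weights: one assigned_weight per row, e.g. [1 + alpha, -alpha] for
    [context-aware input, context-free input].
    """
    logits = np.asarray(logits_per_input, dtype=np.float64)
    w = np.asarray(weights, dtype=np.float64)
    combined = w @ logits  # weighted sum over inputs -> shape (vocab_size,)
    combined = combined - combined.max()  # shift for a numerically stable softmax
    probs = np.exp(combined)
    return probs / probs.sum()


# Toy 3-token vocabulary: the model prefers token 1 either way, but the context raises token 0.
with_context = [2.0, 2.5, 0.0]
without_context = [0.0, 3.0, 0.0]
regular = combine_weighted_logits([with_context], [1.0])
cad = combine_weighted_logits([with_context, without_context], [2.0, -1.0])  # alpha = 1
print(regular.argmax(), cad.argmax())
# prints: 1 0
```

Because the weights sum to 1, a token whose logit is the same with and without the context keeps that logit after combination; only the context-dependent difference is amplified.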
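After a run, only the main process writes `<input jsonl>.output_topp{projection_top_p}_genlen{decode_depth}.jsonl`, so the file holds the rows for `assigned_process` 0: one JSON object per decoded input with `tokens`, `string`, `assigned_process`, `assigned_model`, `output_index`, and `input_index`. Until decoding finishes, the file contains a single placeholder line. A minimal reader, assuming only the first sample per `input_index` is needed; `output_path` and `parse_outputs` are hypothetical helpers, not part of the repository.

```python
import json


def output_path(fin_path, projection_top_p, decode_depth):
    """Rebuild the output file name that group_decode_fileio.py writes."""
    return f"{fin_path}.output_topp{projection_top_p}_genlen{decode_depth}.jsonl"


def parse_outputs(lines):
    """Map input_index -> first decoded string from the output jsonl lines."""
    results = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.startswith("placeholder"):
            raise RuntimeError("output file is still the placeholder; the run did not finish")
        record = json.loads(line)
        results[record["input_index"]] = record["string"][0]
    return results


print(output_path("eval/nqswap_example_input/nqswap_2_-1.jsonl", 0.0, 10))
# prints: eval/nqswap_example_input/nqswap_2_-1.jsonl.output_topp0.0_genlen10.jsonl
```

Typical use: `with open(output_path(path, 0.9, 100), encoding="utf-8") as f: outputs = parse_outputs(f)`.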