├── README.MD
├── demo.ipynb
├── environment.yml
├── eval
│   ├── cnndm_example_input
│   │   ├── cnndm_1.5_-0.5.jsonl
│   │   ├── cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl
│   │   ├── cnndm_1_0.jsonl
│   │   └── cnndm_1_0.jsonl.output_topp0.9_genlen100.jsonl
│   ├── evaluate_summary.py
│   └── nqswap_example_input
│       ├── nqswap_1_0.jsonl
│       ├── nqswap_1_0.jsonl.output_topp0.0_genlen10.jsonl
│       ├── nqswap_2_-1.jsonl
│       └── nqswap_2_-1.jsonl.output_topp0.0_genlen10.jsonl
├── exp_cnndm.sh
├── exp_nqswap.sh
├── group_decode_fileio.py
├── image
│   ├── cad.png
│   ├── main.png
│   └── no_cad.png
└── run_group_decode_fileio.sh
/README.MD:
--------------------------------------------------------------------------------
1 | ## Context-aware decoding
2 | This repository provides an original implementation of [Trusting Your Evidence: Hallucinate Less with Context-aware Decoding](https://arxiv.org/pdf/2305.14739.pdf) by *Weijia Shi, *Xiaochuang Han, Mike Lewis, Yulia Tsvetkov, Luke Zettlemoyer, Scott Yih.
3 |
4 | ### Overview
5 |
6 | Context-aware decoding is a simple inference-time method to encourage language models to focus more on given contexts.
7 |
8 | With regular decoding, we sample tokens at each position from an autoregressive language model with:
9 |
10 | ![Regular decoding](image/no_cad.png)
11 |
12 | Context-aware decoding samples tokens at each position with:
13 |
14 | ![Context-aware decoding](image/cad.png)
15 |
16 | Here $\boldsymbol{c}$ denotes the context that the model should focus more on. Please see [our paper](https://arxiv.org/abs/2305.14739) for more details. Below is a figure illustrating the difference between regular decoding and context-aware decoding.
17 |
18 | ![Comparison between regular decoding and context-aware decoding](image/main.png)
19 |
20 |
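The adjustment can be written directly against a Hugging Face causal LM. Below is a minimal, self-contained sketch of one context-aware sampling step (the model name, `alpha`, and the `<context>`/`<question>` strings are placeholders; the repository's batched, multi-process implementation lives in `group_decode_fileio.py`):

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; any causal LM on the Hugging Face hub works the same way.
model_name = "huggyllama/llama-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

alpha = 0.5  # context-aware decoding strength
with_context = tokenizer("<context> <question>", return_tensors="pt").input_ids  # c + x
without_context = tokenizer("<question>", return_tensors="pt").input_ids         # x only

with torch.no_grad():
    logits_full = model(with_context).logits[:, -1, :]       # logit(y_t | c, x, y_<t)
    logits_prior = model(without_context).logits[:, -1, :]   # logit(y_t | x, y_<t)

# (1 + alpha) * logit(y | c, x) - alpha * logit(y | x), then sample the next token
adjusted = (1 + alpha) * logits_full - alpha * logits_prior
next_token = torch.multinomial(F.softmax(adjusted, dim=-1), num_samples=1)
```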
21 | While context-aware decoding is based on the setup where one language model is decoded contrastively with different input contexts, *our code more generally supports collaborative or contrastive decoding across multiple language models and input contexts*. You can easily customize the inference setup with your own input `jsonl` file.
22 |
23 | ### Input format
24 | Below is an example input to our system (from NQ-Swap).
25 | ```json
26 | {
27 | "input_index": 0, // instances that decode together should have the same input_index
28 | "assigned_model": "huggyllama/llama-7b", // same model for all instances in context-aware decoding, but can use different models here, e.g., DExperts, contrastive decoding, proxy tuning, etc.
29 | "assigned_process": 0, // which GPU should take this instance
30 | "context_string": "The fourth season of Chicago Fire , an American drama television series with executive producer Dick Wolf , and producers Derek Haas , Michael Brandt , and Matt Olmstead , was ordered on February 5 , 2015 , by NBC , and premiered on October 13 , 2015 and concluded on May 17 , 2016 . The season contained 1078 episodes . How many episodes are in chicago fire season 4 ?", // the context-aware input
31 | "assigned_weight": 2, // weight for current instance/process (1+alpha, weights should add up to 1 by default, but can also incorporate sampling temperature if needed)
32 | "filter_p": 1.0, // optional filtering for low-probability tokens, disabled by default
33 | }
34 | {
35 | "input_index": 0, // instances that decode together should have the same input_index
36 | "assigned_model": "huggyllama/llama-7b", // same model for all instances in context-aware decoding, but can use different models here, e.g., DExperts, contrastive decoding, proxy tuning, etc.
37 | "assigned_process": 1, // which GPU should take this instance
38 | "context_string": "How many episodes are in chicago fire season 4 ?", // the context-unaware input
39 | "assigned_weight": -1, // weight for current instance/process (-alpha, weights should add up to 1 by default, but can also incorporate sampling temperature if needed)
40 | }
41 | ...
42 | ```
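To run context-aware decoding on your own data, write one such pair of lines per example. Below is a small illustrative sketch (the output file name, example data, and `alpha` value are arbitrary placeholders, not part of this repository):

```python
import json

alpha = 0.5  # weights become 1 + alpha and -alpha
examples = [{"context": "<your context>", "question": "<your question>"}]  # hypothetical data

with open("custom_input.jsonl", "w") as f:
    for i, ex in enumerate(examples):
        # process 0: context-aware input, weight 1 + alpha
        f.write(json.dumps({
            "input_index": i,
            "assigned_model": "huggyllama/llama-7b",
            "assigned_process": 0,
            "context_string": f"{ex['context']} {ex['question']}",
            "assigned_weight": 1 + alpha,
            "filter_p": 1.0,
        }) + "\n")
        # process 1: context-unaware input, weight -alpha
        f.write(json.dumps({
            "input_index": i,
            "assigned_model": "huggyllama/llama-7b",
            "assigned_process": 1,
            "context_string": ex["question"],
            "assigned_weight": -alpha,
        }) + "\n")
```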
43 |
44 | ### Running context-aware decoding on CNN-DM and NQ-Swap
45 | Run `bash exp_cnndm.sh` or `bash exp_nqswap.sh`. Both scripts call `run_group_decode_fileio.sh` which subsequently calls `group_decode_fileio.py`. The output will be saved in the same directory as the input files.
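If you want to call the launcher directly, the positional arguments of `run_group_decode_fileio.sh` are: seed, visible GPU ids, input file (prefixed with `fin|`), maximum sequence length, decode truncate length (maximum sequence length minus generation length), generation length, and nucleus sampling top-p. For example, one iteration of `exp_cnndm.sh` is equivalent to:

```bash
bash run_group_decode_fileio.sh 2023 "0,1" "fin|eval/cnndm_example_input/cnndm_1.5_-0.5.jsonl" 2048 1948 100 0.9
```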
46 |
47 | The conda environment we used can be found in `environment.yml`. The main packages used are `pytorch`, `transformers`, and `accelerate`.
48 |
49 | ### Evaluation
50 | After generating the predictions, you can evaluate them against the gold data with the following script.
51 | ```bash
52 | PRED_PATH=./eval/cnndm_example_input/cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl
53 | GOLD_DATA_PATH=./eval/cnndm_example_input/cnndm_1_0.jsonl
54 | python eval/evaluate_summary.py --pred_path $PRED_PATH --data_path $GOLD_DATA_PATH
55 | ```
56 | We provide our outputs for CNN-DM using regular decoding (`cnndm_1_0.jsonl.output_topp0.9_genlen100.jsonl`) and context-aware decoding (`cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl`) in `eval/cnndm_example_input/`.
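For reference, each line of a generated `*.output_*.jsonl` file is a JSON object with the fields written by `group_decode_fileio.py` (`tokens`, `string`, `assigned_process`, `assigned_model`, `output_index`, `input_index`); the values below are placeholders:

```json
{"tokens": [["<token ids>"]], "string": ["<generated text>"], "assigned_process": 0, "assigned_model": "huggyllama/llama-7b", "output_index": 0, "input_index": 0}
```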
57 |
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Context Aware Decoding Demo"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "test_token = \"\""
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "name": "stderr",
26 | "output_type": "stream",
27 | "text": [
28 | "2024-11-16 06:19:19.335650: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
29 | "2024-11-16 06:19:19.506195: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
30 | "2024-11-16 06:19:19.506250: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
31 | "2024-11-16 06:19:19.532940: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
32 | "2024-11-16 06:19:19.583956: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
33 | "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
34 | "2024-11-16 06:19:20.311520: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
35 | "/home/kirekara/anaconda3/lib/python3.11/site-packages/pandas/core/arrays/masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
36 | " from pandas.core import (\n"
37 | ]
38 | },
39 | {
40 | "data": {
41 | "application/vnd.jupyter.widget-view+json": {
42 | "model_id": "ac00eb23a0c74c33b3f837ad0a279351",
43 | "version_major": 2,
44 | "version_minor": 0
45 | },
46 | "text/plain": [
47 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
48 | ]
49 | },
50 | "metadata": {},
51 | "output_type": "display_data"
52 | }
53 | ],
54 | "source": [
55 | "import torch\n",
56 | "from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor\n",
57 | "from torch.nn import functional as F\n",
58 | "\n",
59 | "model_name = \"google/gemma-2-2b-it\"\n",
60 | "tokenizer = AutoTokenizer.from_pretrained(model_name, token = test_token)\n",
61 | "model = AutoModelForCausalLM.from_pretrained(model_name, token = test_token)\n"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "name": "stderr",
71 | "output_type": "stream",
72 | "text": [
73 | "/home/kirekara/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py:128: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)\n",
74 | " return torch._C._cuda_getDeviceCount() > 0\n"
75 | ]
76 | },
77 | {
78 | "data": {
79 | "text/plain": [
80 | "Gemma2ForCausalLM(\n",
81 | " (model): Gemma2Model(\n",
82 | " (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n",
83 | " (layers): ModuleList(\n",
84 | " (0-25): 26 x Gemma2DecoderLayer(\n",
85 | " (self_attn): Gemma2Attention(\n",
86 | " (q_proj): Linear(in_features=2304, out_features=2048, bias=False)\n",
87 | " (k_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
88 | " (v_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
89 | " (o_proj): Linear(in_features=2048, out_features=2304, bias=False)\n",
90 | " (rotary_emb): Gemma2RotaryEmbedding()\n",
91 | " )\n",
92 | " (mlp): Gemma2MLP(\n",
93 | " (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
94 | " (up_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
95 | " (down_proj): Linear(in_features=9216, out_features=2304, bias=False)\n",
96 | " (act_fn): PytorchGELUTanh()\n",
97 | " )\n",
98 | " (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
99 | " (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
100 | " (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
101 | " (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
102 | " )\n",
103 | " )\n",
104 | " (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
105 | " )\n",
106 | " (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n",
107 | ")"
108 | ]
109 | },
110 | "execution_count": 3,
111 | "metadata": {},
112 | "output_type": "execute_result"
113 | }
114 | ],
115 | "source": [
116 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
117 | "model.to(device)"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 4,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "context = \"The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.\"\n",
127 | "question = \"How many world cups has Argentina won?\"\n",
128 | "\n",
129 | "context_input = tokenizer(context, return_tensors=\"pt\").input_ids.to(device)\n",
130 | "question_input = tokenizer(question, return_tensors=\"pt\").input_ids.to(device)\n",
131 | "\n",
132 | "input_ids = torch.cat([context_input, question_input], dim=-1)\n",
133 | "\n",
134 | "\n",
135 | "def standard_decoding(input_ids, max_length=128, temperature=1.0, top_k=50, top_p=0.9):\n",
136 | " output_ids = model.generate(\n",
137 | " input_ids,\n",
138 | " max_length=max_length,\n",
139 | " temperature=temperature,\n",
140 | " top_k=top_k,\n",
141 | " top_p=top_p,\n",
142 | " do_sample=True,\n",
143 | " )\n",
144 | " return tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
145 | "\n",
146 | "def context_aware_sampling(model, tokenizer, input_ids, context_ids, alpha=0.9, max_length=128, temperature=1.0):\n",
147 | " generated_tokens = input_ids.clone()\n",
148 | " \n",
149 | " for _ in range(max_length):\n",
150 | " with torch.no_grad():\n",
151 | " full_context_outputs = model(generated_tokens)\n",
152 | " full_context_logits = full_context_outputs.logits[:, -1, :] \n",
153 | "\n",
154 | "        question_only_input = generated_tokens[:, context_ids.shape[-1]:]  # drop the context tokens, keep question + generated\n",
155 | " question_only_outputs = model(question_only_input)\n",
156 | " question_only_logits = question_only_outputs.logits[:, -1, :] \n",
157 | "\n",
158 | " adjusted_logits = (1 + alpha) * full_context_logits - alpha * question_only_logits\n",
159 | " adjusted_probs = F.softmax(adjusted_logits / temperature, dim=-1)\n",
160 | "\n",
161 | " next_token = torch.multinomial(adjusted_probs, num_samples=1)\n",
162 | "\n",
163 | " generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)\n",
164 | "\n",
165 | " if next_token.item() == tokenizer.eos_token_id:\n",
166 | " break\n",
167 | "\n",
168 | " return generated_tokens"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 5,
174 | "metadata": {},
175 | "outputs": [
176 | {
177 | "name": "stdout",
178 | "output_type": "stream",
179 | "text": [
180 | "Standard Decoding Output:\n",
181 | " The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?\n",
182 | "\n",
183 | "**Answer:** 3\n",
184 | " \n",
185 | "\n",
186 | "____________________________________________________________________________________________________\n",
187 | "Context-Aware Decoding Output:\n",
188 | " The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?\n",
189 | "\n",
190 | "Therefore, Argentina has won **4** World Cup\n",
191 | "\n",
192 | "Argentina won the 2027 World Cup. \n",
193 | "\n",
194 | "\n",
195 | " \n",
196 | "\n"
197 | ]
198 | }
199 | ],
200 | "source": [
201 | "model.eval()\n",
202 | "standard_output = standard_decoding(input_ids)\n",
203 | "output_tokens = context_aware_sampling(\n",
204 | " model,\n",
205 | " tokenizer,\n",
206 | " input_ids,\n",
207 | " context_ids=context_input,\n",
208 | " alpha=0.5,\n",
209 | " max_length=50,\n",
210 | " temperature=1.0,\n",
211 | " )\n",
212 | "\n",
213 | "context_aware_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n",
214 | "\n",
215 | "\n",
216 | "print(\"Standard Decoding Output:\\n\", standard_output)\n",
217 | "print(\"__\" * 50)\n",
218 | "print(\"Context-Aware Decoding Output:\\n\", context_aware_output)\n"
219 | ]
220 | }
221 | ],
222 | "metadata": {
223 | "kernelspec": {
224 | "display_name": "base",
225 | "language": "python",
226 | "name": "python3"
227 | },
228 | "language_info": {
229 | "codemirror_mode": {
230 | "name": "ipython",
231 | "version": 3
232 | },
233 | "file_extension": ".py",
234 | "mimetype": "text/x-python",
235 | "name": "python",
236 | "nbconvert_exporter": "python",
237 | "pygments_lexer": "ipython3",
238 | "version": "3.11.5"
239 | }
240 | },
241 | "nbformat": 4,
242 | "nbformat_minor": 2
243 | }
244 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: cad
2 | channels:
3 | - pytorch-nightly
4 | - pytorch
5 | - nvidia
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - blas=1.0=mkl
11 | - ca-certificates=2023.01.10=h06a4308_0
12 | - certifi=2022.12.7=py38h06a4308_0
13 | - cuda=11.7.1=0
14 | - cuda-cccl=11.7.91=0
15 | - cuda-command-line-tools=11.7.1=0
16 | - cuda-compiler=11.7.1=0
17 | - cuda-cudart=11.7.99=0
18 | - cuda-cudart-dev=11.7.99=0
19 | - cuda-cuobjdump=11.7.91=0
20 | - cuda-cupti=11.7.101=0
21 | - cuda-cuxxfilt=11.7.91=0
22 | - cuda-demo-suite=12.0.140=0
23 | - cuda-documentation=12.0.140=0
24 | - cuda-driver-dev=11.7.99=0
25 | - cuda-gdb=12.0.140=0
26 | - cuda-libraries=11.7.1=0
27 | - cuda-libraries-dev=11.7.1=0
28 | - cuda-memcheck=11.8.86=0
29 | - cuda-nsight=12.0.140=0
30 | - cuda-nsight-compute=12.0.1=0
31 | - cuda-nvcc=11.7.99=0
32 | - cuda-nvdisasm=12.0.140=0
33 | - cuda-nvml-dev=11.7.91=0
34 | - cuda-nvprof=12.0.146=0
35 | - cuda-nvprune=11.7.91=0
36 | - cuda-nvrtc=11.7.99=0
37 | - cuda-nvrtc-dev=11.7.99=0
38 | - cuda-nvtx=11.7.91=0
39 | - cuda-nvvp=12.0.146=0
40 | - cuda-runtime=11.7.1=0
41 | - cuda-sanitizer-api=12.0.140=0
42 | - cuda-toolkit=11.7.1=0
43 | - cuda-tools=11.7.1=0
44 | - cuda-visual-tools=11.7.1=0
45 | - cudatoolkit=11.3.1=h2bc3f7f_2
46 | - gds-tools=1.5.1.14=0
47 | - gmp=6.2.1=h295c915_3
48 | - gmpy2=2.1.2=py38heeb90bb_0
49 | - intel-openmp=2021.4.0=h06a4308_3561
50 | - ld_impl_linux-64=2.38=h1181459_1
51 | - libcublas=11.10.3.66=0
52 | - libcublas-dev=11.10.3.66=0
53 | - libcufft=10.7.2.124=h4fbf590_0
54 | - libcufft-dev=10.7.2.124=h98a8f43_0
55 | - libcufile=1.5.1.14=0
56 | - libcufile-dev=1.5.1.14=0
57 | - libcurand=10.3.1.124=0
58 | - libcurand-dev=10.3.1.124=0
59 | - libcusolver=11.4.0.1=0
60 | - libcusolver-dev=11.4.0.1=0
61 | - libcusparse=11.7.4.91=0
62 | - libcusparse-dev=11.7.4.91=0
63 | - libffi=3.3=he6710b0_2
64 | - libgcc-ng=11.2.0=h1234567_1
65 | - libgomp=11.2.0=h1234567_1
66 | - libnpp=11.7.4.75=0
67 | - libnpp-dev=11.7.4.75=0
68 | - libnvjpeg=11.8.0.2=0
69 | - libnvjpeg-dev=11.8.0.2=0
70 | - libstdcxx-ng=11.2.0=h1234567_1
71 | - libuv=1.40.0=h7b6447c_0
72 | - mkl=2021.4.0=h06a4308_640
73 | - mkl-service=2.4.0=py38h7f8727e_0
74 | - mkl_fft=1.3.1=py38hd3c417c_0
75 | - mkl_random=1.2.2=py38h51133e4_0
76 | - mpc=1.1.0=h10f8cd9_1
77 | - mpfr=4.0.2=hb69a4c5_1
78 | - mpmath=1.2.1=py38h06a4308_0
79 | - ncurses=6.3=h7f8727e_2
80 | - ninja-base=1.10.2=hd09550d_5
81 | - nsight-compute=2022.4.1.6=0
82 | - numpy-base=1.22.3=py38hf524024_0
83 | - openssl=1.1.1t=h7f8727e_0
84 | - pip=21.2.4=py38h06a4308_0
85 | - python=3.8.13=h12debd9_0
86 | - pytorch=1.13.1=py3.8_cuda11.7_cudnn8.5.0_0
87 | - pytorch-cuda=11.7=h778d358_3
88 | - pytorch-mutex=1.0=cuda
89 | - readline=8.1.2=h7f8727e_1
90 | - setuptools=61.2.0=py38h06a4308_0
91 | - six=1.16.0=pyhd3eb1b0_1
92 | - sqlite=3.38.3=hc218d9a_0
93 | - sympy=1.11.1=py38h06a4308_0
94 | - tk=8.6.12=h1ccaba5_0
95 | - torchtriton=2.0.0+0d7e753227=py38
96 | - typing_extensions=4.1.1=pyh06a4308_0
97 | - wheel=0.37.1=pyhd3eb1b0_0
98 | - xz=5.2.5=h7f8727e_1
99 | - zlib=1.2.12=h7f8727e_2
100 | - pip:
101 | - absl-py==1.1.0
102 | - accelerate==0.16.0
103 | - aiohttp==3.8.1
104 | - aiosignal==1.2.0
105 | - amqp==5.1.1
106 | - anyio==3.6.1
107 | - argon2-cffi==21.3.0
108 | - argon2-cffi-bindings==21.2.0
109 | - asttokens==2.0.5
110 | - async-timeout==4.0.2
111 | - atpublic==3.0.1
112 | - attrs==21.4.0
113 | - babel==2.10.3
114 | - backcall==0.2.0
115 | - beautifulsoup4==4.11.1
116 | - bert-score==0.3.11
117 | - bidict==0.22.1
118 | - billiard==3.6.4.0
119 | - bleach==5.0.0
120 | - cachelib==0.10.1
121 | - cachetools==5.2.0
122 | - celery==5.2.7
123 | - cffi==1.15.0
124 | - charset-normalizer==2.0.12
125 | - click==8.1.3
126 | - click-didyoumean==0.3.0
127 | - click-plugins==1.1.1
128 | - click-repl==0.2.0
129 | - colorama==0.4.5
130 | - cycler==0.11.0
131 | - datasets==2.10.0
132 | - debugpy==1.6.0
133 | - decorator==5.1.1
134 | - defusedxml==0.7.1
135 | - dill==0.3.5.1
136 | - entrypoints==0.4
137 | - executing==0.8.3
138 | - faiss-cpu==1.7.2
139 | - fastjsonschema==2.15.3
140 | - filelock==3.7.1
141 | - fire==0.5.0
142 | - flask==2.2.2
143 | - flask-session==0.4.0
144 | - flask-socketio==5.3.2
145 | - flufl-lock==7.0
146 | - fonttools==4.33.3
147 | - frozenlist==1.3.0
148 | - fsspec==2022.5.0
149 | - google-auth==2.8.0
150 | - google-auth-oauthlib==0.4.6
151 | - grpcio==1.46.3
152 | - hjson==3.1.0
153 | - huggingface-hub==0.11.0
154 | - idna==3.3
155 | - importlib-metadata==4.11.4
156 | - importlib-resources==5.8.0
157 | - ipykernel==6.15.0
158 | - ipython==8.4.0
159 | - ipython-genutils==0.2.0
160 | - itsdangerous==2.1.2
161 | - jedi==0.18.1
162 | - jinja2==3.1.2
163 | - joblib==1.1.0
164 | - json5==0.9.8
165 | - jsonschema==4.6.0
166 | - jupyter-client==7.3.4
167 | - jupyter-core==4.10.0
168 | - jupyter-server==1.17.1
169 | - jupyterlab==3.4.3
170 | - jupyterlab-pygments==0.2.2
171 | - jupyterlab-server==2.14.0
172 | - kiwisolver==1.4.3
173 | - kombu==5.2.4
174 | - lxml==4.9.1
175 | - markdown==3.3.7
176 | - markupsafe==2.1.1
177 | - matplotlib==3.5.2
178 | - matplotlib-inline==0.1.3
179 | - mauve-text==0.3.0
180 | - mistune==0.8.4
181 | - multidict==6.0.2
182 | - multiprocess==0.70.13
183 | - nbclassic==0.3.7
184 | - nbclient==0.6.4
185 | - nbconvert==6.5.0
186 | - nbformat==5.4.0
187 | - nest-asyncio==1.5.5
188 | - ninja==1.10.2.3
189 | - nltk==3.7
190 | - notebook==6.4.12
191 | - notebook-shim==0.1.0
192 | - numpy==1.22.4
193 | - oauthlib==3.2.0
194 | - packaging==21.3
195 | - pandas==1.4.2
196 | - pandocfilters==1.5.0
197 | - parso==0.8.3
198 | - pexpect==4.8.0
199 | - pickleshare==0.7.5
200 | - pillow==9.1.1
201 | - portalocker==2.5.1
202 | - prometheus-client==0.14.1
203 | - prompt-toolkit==3.0.29
204 | - protobuf==3.19.4
205 | - psutil==5.9.1
206 | - ptyprocess==0.7.0
207 | - pure-eval==0.2.2
208 | - py-cpuinfo==8.0.0
209 | - pyarrow==8.0.0
210 | - pyasn1==0.4.8
211 | - pyasn1-modules==0.2.8
212 | - pycparser==2.21
213 | - pydantic==1.10.2
214 | - pygments==2.12.0
215 | - pyparsing==3.0.9
216 | - pyrsistent==0.18.1
217 | - pysbd==0.3.4
218 | - python-dateutil==2.8.2
219 | - python-engineio==4.3.4
220 | - python-socketio==5.7.2
221 | - pytz==2022.1
222 | - pyyaml==6.0
223 | - pyzmq==23.1.0
224 | - redis==4.5.3
225 | - regex==2022.6.2
226 | - requests==2.28.0
227 | - requests-oauthlib==1.3.1
228 | - responses==0.18.0
229 | - rouge==1.0.0
230 | - rsa==4.8
231 | - sacrebleu==2.2.1
232 | - scikit-learn==1.1.1
233 | - scipy==1.8.1
234 | - seaborn==0.11.2
235 | - send2trash==1.8.0
236 | - sentencepiece==0.1.96
237 | - sniffio==1.2.0
238 | - soupsieve==2.3.2.post1
239 | - stack-data==0.3.0
240 | - tabulate==0.9.0
241 | - tensorboard==2.9.1
242 | - tensorboard-data-server==0.6.1
243 | - tensorboard-plugin-wit==1.8.1
244 | - tensorboardx==2.5.1
245 | - termcolor==1.1.0
246 | - terminado==0.15.0
247 | - threadpoolctl==3.1.0
248 | - tinycss2==1.1.1
249 | - tokenizers==0.12.1
250 | - tornado==6.1
251 | - tqdm==4.64.0
252 | - traitlets==5.3.0
253 | - transformers==4.26.0
254 | - urllib3==1.26.9
255 | - vine==5.0.0
256 | - wcwidth==0.2.5
257 | - webencodings==0.5.1
258 | - websocket-client==1.3.2
259 | - werkzeug==2.2.2
260 | - xxhash==3.0.0
261 | - yacs==0.1.8
262 | - yarl==1.7.2
263 | - zipp==3.8.0
264 | prefix: /private/home/xhan77/.conda/envs/cad
265 |
--------------------------------------------------------------------------------
/eval/evaluate_summary.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from tqdm import tqdm
4 | from pathlib import Path
5 | # from datasets import load_dataset
6 | # from evaluate import load
7 | import statistics
8 | import json
9 | from collections import defaultdict
10 | import os
11 | import evaluate
12 | from ipdb import set_trace as bp
13 | import torch
14 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
15 | from datasets import load_dataset
16 |
17 | # evaluate FactKB: put your Hugging Face access token below
18 | access_token = None  # placeholder; set to your Hugging Face access token if needed
19 | tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
20 | factkb = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels = 2, use_auth_token=access_token)
21 |
22 | def evaluate_qa(index2ex, eval_file):
23 | print(eval_file)
24 | all_gold = []
25 | all_pred = []
26 | all_doc = []
27 | all_fact_score = []
28 |
29 | if os.path.exists(eval_file) == False:
30 | return 0
31 | with open(eval_file, "r") as f:
32 | output_data = [json.loads(line) for line in f]
33 | cov_em_all = []
34 | category2em = defaultdict(list)
35 | id2ex_output = {}
36 | for i, output in enumerate(output_data):
37 | index = output["input_index"]
38 | pred = output["string"][0]
39 | gold = index2ex[index]["gold_answers"]
40 | if len(pred) < 3:
41 | print(pred)
42 | continue
43 | all_gold.append(gold)
44 | all_pred.append(pred)
45 | if len(pred) < 3:
46 | print(f"pred: {pred}")
47 |
48 | article = index2ex[index]["article"]
49 | summary = pred
50 | input = [[summary, article]]
51 | tokens = tokenizer(input, return_tensors="pt", padding="max_length", truncation=True)
52 | result = torch.softmax(factkb(**tokens).logits, dim = 1)
53 | # bp()
54 | fact_score = result[0][1].item()
55 |
56 | all_fact_score.append(fact_score)
57 | all_doc.append(article)
58 | output_dict = index2ex[index].copy()
59 | output_dict["pred"] = pred
60 | id2ex_output[i] = output_dict
61 |
62 | print("fact_score: ", statistics.mean(all_fact_score))
63 | # print(statistics.mean(cov_em_all))
64 | rouge = evaluate.load('rouge')
65 | results = rouge.compute(predictions=all_pred, references=all_gold)
66 | print("rouge results: ", results)
67 |
68 | bertscore = evaluate.load("bertscore")
69 | results = bertscore.compute(predictions=all_pred, references=all_doc, lang="en")
70 | # print("bertscore: ", results)
71 | print("bertscore: ")
72 | for k, v in results.items():
73 | if k in ["precision", "recall", "f1"]:
74 | print(f"{k}: {statistics.mean(v)}")
75 | return id2ex_output
76 |
77 | # read data
78 | def entity_data(dataset_path):
79 | raw_data = []
80 | with open(dataset_path) as f:
81 | for line in f:
82 | ex = json.loads(line)
83 | if ex["assigned_process"] == 0:
84 | raw_data.append(ex)
85 | # break
86 | # raw_data = json.loads(f.read())
87 | return raw_data
88 |
89 |
90 | if __name__ == "__main__":
91 | # args parse
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument("--data_path", type=str, default="./eval/cnndm_example_input/cnndm_1_0.jsonl")
94 | parser.add_argument("--pred_path", type=str, default="./eval/cnndm_example_input/cnndm_1.5_-0.5.jsonl.output_topp0.9_genlen100.jsonl")
95 | args = parser.parse_args()
96 |
97 | data_path = args.data_path
98 | pred_path = args.pred_path
99 | index2ex = entity_data(data_path)
100 | evaluate_qa(index2ex, pred_path)
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/exp_cnndm.sh:
--------------------------------------------------------------------------------
1 | # cnndm
2 | GLOBALLEN="2048"
3 | MAXCTXLEN="1948"
4 | GENLEN="100"
5 | FN_PREFIX="eval/cnndm_example_input/cnndm"
6 |
7 | for WEIGHT in "1_0" "1.5_-0.5"
8 | do
9 | TOPP="0.9"
10 | TESTFILE="fin|${FN_PREFIX}_${WEIGHT}.jsonl"
11 | bash run_group_decode_fileio.sh 2023 "0,1" $TESTFILE $GLOBALLEN $MAXCTXLEN $GENLEN $TOPP
12 | done
13 |
--------------------------------------------------------------------------------
/exp_nqswap.sh:
--------------------------------------------------------------------------------
1 | # nq-swap
2 | GLOBALLEN="2048"
3 | MAXCTXLEN="2038"
4 | GENLEN="10"
5 | FN_PREFIX="eval/nqswap_example_input/nqswap"
6 |
7 | for WEIGHT in "1_0" "2_-1"
8 | do
9 | TOPP="0.0"
10 | TESTFILE="fin|${FN_PREFIX}_${WEIGHT}.jsonl"
11 | bash run_group_decode_fileio.sh 2023 "0,1" $TESTFILE $GLOBALLEN $MAXCTXLEN $GENLEN $TOPP
12 | done
13 |
--------------------------------------------------------------------------------
/group_decode_fileio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | import datasets
6 | import torch
7 |
8 | import transformers
9 | import accelerate
10 | from accelerate import Accelerator
11 | from transformers import (
12 | CONFIG_MAPPING,
13 | MODEL_MAPPING,
14 | AutoConfig,
15 | AutoModel,
16 | AutoModelForMaskedLM,
17 | AutoModelForCausalLM,
18 | AutoModelForSeq2SeqLM,
19 | AutoTokenizer,
20 | SchedulerType,
21 | )
22 |
23 | import numpy as np
24 | from termcolor import colored
25 | import json
26 | from accelerate import InitProcessGroupKwargs
27 | import datetime
28 |
29 |
30 | logger = logging.getLogger(__name__)
31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
33 |
34 |
35 | def logits_sampling_projection(logits, top_p, one_hot_value):
36 | assert len(logits.size()) == 3
37 |
38 | # get top-p indices
39 | probs = torch.nn.functional.softmax(logits, dim=-1)
40 | sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
41 | cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
42 | nucleus = cum_sum_probs < top_p
43 | nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
44 | valid_indices = nucleus.scatter(2, indices, nucleus)
45 |
46 | filtered_logits = logits.masked_fill(valid_indices == 0, torch.finfo(logits.dtype).min)
47 | m = torch.distributions.categorical.Categorical(logits=filtered_logits)
48 | selected = m.sample()
49 | return (2 * one_hot_value * torch.nn.functional.one_hot(selected, logits.size(2)) - one_hot_value)
50 |
51 |
52 | def filter_logits_top_p(logits, top_p, negative_multiplier=False):
53 | assert len(logits.size()) == 3
54 |
55 | # get top-p indices
56 | probs = torch.nn.functional.softmax(logits, dim=-1)
57 | sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
58 | cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
59 | nucleus = cum_sum_probs < top_p
60 | nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
61 | valid_indices = nucleus.scatter(2, indices, nucleus)
62 |
63 | if negative_multiplier:
64 | filtered_logits = logits.masked_fill(valid_indices == 0, 1000)
65 | else:
66 | filtered_logits = logits.masked_fill(valid_indices == 0, -1000)
67 | return filtered_logits
68 |
69 |
70 | def decode(args, batch_input_ids, dec_depth, model, tokenizer):
71 | batch_size = args.per_device_eval_batch_size
72 | assert batch_input_ids.size(1) == args.context_size
73 | assert args.decode_truncate_len >= 0
74 | assert (args.max_seq_length - args.context_size - args.decode_truncate_len) % dec_depth == 0
75 | unit_seq_len = int((args.max_seq_length - args.context_size - args.decode_truncate_len) / dec_depth)
76 | if args.context_size > 0:
77 | unit_context_input_ids = batch_input_ids[:, :args.context_size].clone()
78 | else:
79 | raise ValueError("context cannot be none")
80 | history_decode_ids = None
81 |
82 | past_key_values = None # necessary for causal models
83 | if args.model_category == 'seq2seq':
84 | model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(
85 | batch_input_ids[:, :args.context_size].clone(), dict(), None
86 | ) # this step includes encoding the context
87 | history_decode_ids = model._prepare_decoder_input_ids_for_generation(
88 | batch_input_ids.size(0),
89 | model_kwargs=model_kwargs,
90 | device=batch_input_ids.device,
91 | ) # create placeholder starter seq for decoding
92 | else:
93 | model_kwargs = None
94 |
95 | for _i in range(dec_depth):
96 | if args.model_category == 'causal':
97 | model_inputs = model.prepare_inputs_for_generation(unit_context_input_ids, past_key_values=past_key_values)
98 | outputs = model(**model_inputs, output_hidden_states=False)
99 | elif args.model_category == 'seq2seq':
100 | model_inputs = model.prepare_inputs_for_generation(history_decode_ids, **model_kwargs) # this incorporates past_key_values
101 | outputs = model(**model_inputs, output_hidden_states=False)
102 | else:
103 | raise ValueError("model category not supported")
104 |
105 | score = outputs.logits[:, -1:, :].clone().contiguous()
106 |
107 | if args.assigned_weight >= 0:
108 | score = filter_logits_top_p(score, top_p=args.filter_top_p)
109 | else:
110 | score = filter_logits_top_p(score, top_p=args.filter_top_p_prior, negative_multiplier=True)
111 |
112 | score = args.assigned_weight * score
113 | torch.distributed.all_reduce(score, group=args.gathering_group)
114 |
115 | projected_logits = logits_sampling_projection(score, top_p=args.projection_top_p, one_hot_value=args.one_hot_value)
116 |
117 | if not args.accelerator.is_main_process:
118 | projected_logits = torch.zeros_like(projected_logits)
119 | torch.distributed.all_reduce(projected_logits, group=args.gathering_group)
120 |
121 | simplex = torch.nn.functional.softmax(projected_logits, dim=-1)
122 | real_token_ids_list = torch.argmax(simplex, dim=-1).view(batch_size, unit_seq_len)
123 |
124 | if args.model_category == 'causal':
125 | unit_context_input_ids = torch.cat((unit_context_input_ids, real_token_ids_list), dim=1) # not really necessary but keeping
126 |
127 | if history_decode_ids is None:
128 | history_decode_ids = real_token_ids_list
129 | else:
130 | history_decode_ids = torch.cat((history_decode_ids, real_token_ids_list), dim=1)
131 |
132 | if args.model_category == 'causal':
133 | past_key_values = outputs.past_key_values
134 | elif args.model_category == 'seq2seq':
135 | model_kwargs["past_key_values"] = outputs.past_key_values
136 |
137 | # HACK: stop when seeing eos token, but asserting batch size is 1, unit_seq_len is 1, optimize later
138 | assert real_token_ids_list.size(0) == 1
139 | assert real_token_ids_list.size(1) == 1
140 | if real_token_ids_list[0][-1] == model.generation_config.eos_token_id:
141 | break
142 |
143 | if args.context_size > 0:
144 | init_context_input_ids = batch_input_ids[:, :args.context_size].clone()
145 | context_sequences = tokenizer.batch_decode(init_context_input_ids.detach().to('cpu'))#, skip_special_tokens=True)
146 | else:
147 | init_context_input_ids = None
148 | context_sequences = None
149 | sampled_sequences = tokenizer.batch_decode(history_decode_ids.clone().detach().to('cpu'), skip_special_tokens=True)
150 | logger.info(f"context: {context_sequences}")
151 | logger.info(f"sampled: {colored(str(sampled_sequences), 'red')}")
152 |
153 | return history_decode_ids, init_context_input_ids, None, sampled_sequences, context_sequences, None
154 |
155 |
156 | def parse_args():
157 |     parser = argparse.ArgumentParser(description="Group decoding with file I/O (context-aware / ensemble decoding)")
158 | parser.add_argument(
159 | "--model_name_or_path",
160 | type=str,
161 | help="Path to pretrained model or model identifier from huggingface.co/models.",
162 | required=True,
163 | )
164 | parser.add_argument(
165 | "--config_name",
166 | type=str,
167 | default=None,
168 | help="Pretrained config name or path if not the same as model_name",
169 | )
170 | parser.add_argument(
171 | "--tokenizer_name",
172 | type=str,
173 | default=None,
174 | help="Pretrained tokenizer name or path if not the same as model_name",
175 | )
176 | parser.add_argument(
177 | "--use_slow_tokenizer",
178 | action="store_true",
179 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
180 | )
181 | parser.add_argument(
182 | "--per_device_eval_batch_size",
183 | type=int,
184 | default=1,
185 | help="Batch size (per device) for the evaluation dataloader.",
186 | )
187 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
188 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
189 | parser.add_argument(
190 | "--model_type",
191 | type=str,
192 | default=None,
193 | help="Model type to use if training from scratch.",
194 | choices=MODEL_TYPES,
195 | )
196 | parser.add_argument(
197 | "--max_seq_length",
198 | type=int,
199 | default=None,
200 | help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
201 | )
202 | parser.add_argument("--init_blank_language_model", action="store_true", help="Whether or not to use a completely blank LM.")
203 | parser.add_argument(
204 | "--file_mode", type=str, default="", help="",
205 | )
206 | parser.add_argument(
207 | "--train_mode", type=str, default="", help="",
208 | )
209 | parser.add_argument(
210 | "--decode_truncate_len", type=int, default=50, help="",
211 | ) # how many to cut from right
212 | parser.add_argument(
213 | "--decode_depth", type=int, default=2, help="",
214 | )
215 | parser.add_argument(
216 | "--projection_top_p", type=float, default=0.2, help="",
217 | )
218 | parser.add_argument(
219 | "--filter_top_p", type=float, default=1.0, help="",
220 | )
221 | parser.add_argument(
222 | "--filter_top_p_prior", type=float, default=1.0, help="",
223 | )
224 | parser.add_argument("--big_model_inference", type=str, default="no")
225 | args = parser.parse_args()
226 |
227 | return args
228 |
229 |
230 | def main():
231 | args = parse_args()
232 |
233 | accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=259200))])
234 | logging.basicConfig(
235 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
236 | datefmt="%m/%d/%Y %H:%M:%S",
237 | level=logging.INFO,
238 | )
239 | logger.info(accelerator.state)
240 |
241 | logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
242 | if accelerator.is_local_main_process:
243 | datasets.utils.logging.set_verbosity_warning()
244 | transformers.utils.logging.set_verbosity_info()
245 | else:
246 | datasets.utils.logging.set_verbosity_error()
247 | transformers.utils.logging.set_verbosity_error()
248 |
249 | if args.seed is not None:
250 | # set_seed(args.seed)
251 | accelerate.utils.set_seed(args.seed, device_specific=True) # differ slightly for each device
252 |
253 | if accelerator.is_main_process:
254 | pass
255 | # if args.output_dir is not None:
256 | # os.makedirs(args.output_dir, exist_ok=True)
257 | accelerator.wait_for_everyone()
258 |
259 | if args.train_mode == "decode":
260 |         if len(args.model_name_or_path.split('|')) > 1:
261 | main_model_name = args.model_name_or_path.split('|')[0]
262 | fallback_model_name = args.model_name_or_path.split('|')[1]
263 | args.model_name_or_path = main_model_name
264 | args.orig_model_name_or_path = fallback_model_name
265 | else:
266 | args.orig_model_name_or_path = args.model_name_or_path
267 | else:
268 | raise ValueError("training should be in a separate file (irrelevant in context-aware decoding)")
269 |
270 | # Han: assign ensemble models
271 | args.file_mode = args.file_mode.split('|')
272 | assert args.file_mode[0] == "fin"
273 | assert os.path.exists(args.file_mode[1])
274 | fin_path = args.file_mode[1]
275 | fin_data = []
276 | with open(fin_path, 'r', encoding='utf-8') as f:
277 | for line in f:
278 | proc_line = line.strip()
279 | if proc_line:
280 | fin_data.append(json.loads(proc_line))
281 | rank2model = dict()
282 | for _fd in fin_data:
283 | if _fd['assigned_process'] in rank2model: # sanity check
284 | assert ' '.join(rank2model[_fd['assigned_process']]) == ' '.join(_fd['assigned_model'].split('|'))
285 | else:
286 | rank2model[_fd['assigned_process']] = _fd['assigned_model'].split('|')
287 |
288 | # Han: add gathering group
289 | default_backend = torch.distributed.get_backend(torch.distributed.distributed_c10d._get_default_group())
290 | args.gathering_group = torch.distributed.new_group(ranks=list(sorted(rank2model.keys())), backend=default_backend)
291 |
292 | if accelerator.process_index not in rank2model.keys(): # Han: exit if not in the ensemble
293 | return
294 | args.model_name_or_path = rank2model[accelerator.process_index][0]
295 |
296 | if args.config_name:
297 | config = AutoConfig.from_pretrained(args.config_name)
298 | elif args.model_name_or_path:
299 | if 'llama' in args.model_name_or_path.lower():
300 | from transformers import LlamaConfig
301 | config = LlamaConfig.from_pretrained(args.model_name_or_path)
302 | else:
303 | config = AutoConfig.from_pretrained(args.model_name_or_path)
304 | else:
305 | config = CONFIG_MAPPING[args.model_type]()
306 | logger.warning("You are instantiating a new config instance from scratch.")
307 |
308 | if 'neox' in args.model_name_or_path.lower(): # Han: gpt-neox doesn't have a slow tokenizer, use GPTNeoXTokenizerFast
309 | from transformers import GPTNeoXTokenizerFast
310 | tokenizer = GPTNeoXTokenizerFast.from_pretrained(args.model_name_or_path)
311 | elif 'llama' in args.model_name_or_path.lower():
312 | from transformers import LlamaTokenizer
313 | tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
314 | else:
315 | assert args.use_slow_tokenizer == True
316 | if args.tokenizer_name:
317 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
318 | elif args.model_name_or_path:
319 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
320 | else:
321 | raise ValueError(
322 | "You are instantiating a new tokenizer from scratch. This is not supported by this script."
323 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
324 | )
325 |
326 | if args.init_blank_language_model:
327 | raise ValueError("disabled")
328 | model = AutoModelForMaskedLM.from_config(config)
329 | elif args.model_name_or_path:
330 | if 't5' in args.model_name_or_path.lower() or 'tk' in args.model_name_or_path.lower():
331 | model = AutoModelForSeq2SeqLM.from_pretrained(
332 | args.model_name_or_path,
333 | from_tf=bool(".ckpt" in args.model_name_or_path),
334 | config=config,
335 | ignore_mismatched_sizes=False,
336 | torch_dtype=torch.float16,
337 | )
338 | args.model_category = 'seq2seq'
339 | model = model.to(accelerator.device)
340 | else:
341 | if 'llama' in args.model_name_or_path.lower(): # llama special case
342 | from transformers import LlamaForCausalLM
343 | if args.big_model_inference == 'no':
344 | model = LlamaForCausalLM.from_pretrained(
345 | args.model_name_or_path,
346 | torch_dtype=torch.float16,
347 | )
348 | model = model.to(accelerator.device)
349 | else:
350 | # Han: we assume 8 GPUs
351 | if accelerator.process_index == 0:
352 | local_devices = [0, 2, 4, 6]
353 | elif accelerator.process_index == 1:
354 | local_devices = [1, 3, 5, 7]
355 | else:
356 | raise ValueError("check accelerator.process_index")
357 | # this is architecture specific
358 | my_device_map = {'model.embed_tokens': local_devices[0],
359 | 'lm_head': local_devices[0],
360 | 'model.norm': local_devices[0]}
361 | for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
362 | for layer_idx in layer_idx_list:
363 | my_device_map[f'model.layers.{layer_idx}'] = local_devices[_device_i]
364 | model = LlamaForCausalLM.from_pretrained(
365 | args.model_name_or_path,
366 | device_map=my_device_map,
367 | torch_dtype=torch.float16,
368 | )
369 | elif args.big_model_inference == 'no':
370 | model = AutoModelForCausalLM.from_pretrained(
371 | args.model_name_or_path,
372 | from_tf=bool(".ckpt" in args.model_name_or_path),
373 | config=config,
374 | ignore_mismatched_sizes=False,
375 | torch_dtype=torch.float16,
376 | )
377 | model = model.to(accelerator.device)
378 | elif args.big_model_inference == 'yes' and 'opt' in args.model_name_or_path.lower():
379 | # Han: we assume 8 GPUs
380 | if accelerator.process_index == 0:
381 | local_devices = [0, 2, 4, 6]
382 | elif accelerator.process_index == 1:
383 | local_devices = [1, 3, 5, 7]
384 | else:
385 | raise ValueError("check accelerator.process_index")
386 | # this is architecture specific
387 | my_device_map = {'model.decoder.embed_tokens': local_devices[0],
388 | 'lm_head': local_devices[0],
389 | 'model.decoder.embed_positions': local_devices[0],
390 | 'model.decoder.final_layer_norm': local_devices[0]}
391 | for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
392 | for layer_idx in layer_idx_list:
393 | my_device_map[f'model.decoder.layers.{layer_idx}'] = local_devices[_device_i]
394 | model = AutoModelForCausalLM.from_pretrained(
395 | args.model_name_or_path,
396 | from_tf=bool(".ckpt" in args.model_name_or_path),
397 | config=config,
398 | ignore_mismatched_sizes=False,
399 | device_map=my_device_map,
400 | torch_dtype=torch.float16,
401 | )
402 | elif args.big_model_inference == 'yes' and 'neox' in args.model_name_or_path.lower():
403 | # Han: we assume 8 GPUs
404 | if accelerator.process_index == 0:
405 | local_devices = [0, 2, 4, 6]
406 | elif accelerator.process_index == 1:
407 | local_devices = [1, 3, 5, 7]
408 | else:
409 | raise ValueError("check accelerator.process_index")
410 | # this is architecture specific
411 | my_device_map = {'gpt_neox.embed_in': local_devices[0],
412 | 'embed_out': local_devices[0],
413 | 'gpt_neox.final_layer_norm': local_devices[0]}
414 | for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
415 | for layer_idx in layer_idx_list:
416 | my_device_map[f'gpt_neox.layers.{layer_idx}'] = local_devices[_device_i]
417 | model = AutoModelForCausalLM.from_pretrained(
418 | args.model_name_or_path,
419 | from_tf=bool(".ckpt" in args.model_name_or_path),
420 | config=config,
421 | ignore_mismatched_sizes=False,
422 | device_map=my_device_map,
423 | torch_dtype=torch.float16,
424 | )
425 | elif args.big_model_inference == 'yes' and 'neo' in args.model_name_or_path.lower():
426 | # Han: we assume 8 GPUs
427 | if accelerator.process_index == 0:
428 | local_devices = [0, 2, 4, 6]
429 | elif accelerator.process_index == 1:
430 | local_devices = [1, 3, 5, 7]
431 | else:
432 | raise ValueError("check accelerator.process_index")
433 | # this is architecture specific
434 | my_device_map = {'transformer.wte': local_devices[0],
435 | 'lm_head': local_devices[0],
436 | 'transformer.wpe': local_devices[0],
437 | 'transformer.drop': local_devices[0],
438 | 'transformer.ln_f': local_devices[0]}
439 | for _device_i, layer_idx_list in enumerate(np.array_split(np.arange(config.num_hidden_layers), len(local_devices))):
440 | for layer_idx in layer_idx_list:
441 | my_device_map[f'transformer.h.{layer_idx}'] = local_devices[_device_i]
442 | model = AutoModelForCausalLM.from_pretrained(
443 | args.model_name_or_path,
444 | from_tf=bool(".ckpt" in args.model_name_or_path),
445 | config=config,
446 | ignore_mismatched_sizes=False,
447 | device_map=my_device_map,
448 | torch_dtype=torch.float16,
449 | )
450 | else:
451 | raise ValueError("check args.big_model_inference")
452 |
453 | args.model_category = 'causal'
454 | model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward) # referred to https://github.com/huggingface/accelerate/blob/38fd30e764ea87ef86e7d69fcba559c3605925b1/src/accelerate/accelerator.py#L1138
455 | model.forward = accelerate.utils.convert_outputs_to_fp32(model.forward)
456 | else:
457 | raise ValueError("specify --init_blank_language_model")
458 |
459 | model.resize_token_embeddings(len(tokenizer))
460 |
461 | logger.info(f"model size: {sum(p.numel() for p in model.parameters())}")
462 | vocab_size = model.get_input_embeddings().weight.size(0)
463 | hidden_size = model.get_input_embeddings().weight.size(1)
464 | one_hot_value = 5.0 # unused
465 |
466 | ##########################################
467 |
468 | # change the output file name later
469 | out_json_fn = f"{fin_path}.output_topp{args.projection_top_p}_genlen{args.decode_depth}.jsonl"
470 | if accelerator.is_main_process:
471 | with open(out_json_fn, 'w') as f:
472 | f.write('placeholder, program not finished ...\n')
473 |
474 | args.tokenizer = tokenizer
475 |
476 | if args.train_mode == "decode":
477 | model.eval()
478 |
479 | args.one_hot_value = one_hot_value
480 | args.vocab_size = vocab_size
481 | args.hidden_size = hidden_size
482 | args.accelerator = accelerator
483 |
484 | export_list = []
485 | args.orig_decode_truncate_len = args.decode_truncate_len
486 | with torch.no_grad():
487 | for _fd in fin_data: # only support batch size 1 for now since the context size can be different across lines
488 | if _fd['assigned_process'] != args.accelerator.process_index: # remember to unblock barriers before this line
489 | continue
490 | args.assigned_weight = _fd['assigned_weight']
491 |
492 | ctx_field_name = 'context_string'
493 | assert ctx_field_name in _fd
494 | assert args.per_device_eval_batch_size == 1
495 |
496 | input_ids = torch.LongTensor(tokenizer.encode(_fd[ctx_field_name], add_special_tokens=True)).unsqueeze(0).to(args.accelerator.device)
497 | args.context_size = input_ids.size(1)
498 | args.decode_truncate_len = args.orig_decode_truncate_len - args.context_size # Han: this compensates for the unknown input context size
499 |
500 | if 'filter_p' in _fd: # token filtering
501 | args.filter_top_p = _fd['filter_p']
502 | if 'filter_p_prior' in _fd:
503 | args.filter_top_p_prior = _fd['filter_p_prior']
504 |
505 | if args.decode_truncate_len < 0:
506 | continue # skipping very long examples
507 | logger.info(f"idx: {_fd['input_index']}")
508 |
509 | repeat_sample = 1 # change here manually if necessary
510 | for _r in range(repeat_sample):
511 | history_decode_ids, _, _, sampled_sequences, _, _ = \
512 | decode(args, input_ids, args.decode_depth, model, tokenizer)
513 | if _r == 0: # first sample
514 | # export to jsonl
515 | for _i in range(args.per_device_eval_batch_size):
516 | export_dict = dict()
517 | export_dict['tokens'] = [history_decode_ids.tolist()[_i]]
518 | export_dict['string'] = [sampled_sequences[_i]]
519 | export_dict['assigned_process'] = _fd['assigned_process']
520 | export_dict['assigned_model'] = args.model_name_or_path
521 | export_dict['output_index'] = len(export_list)
522 | export_dict['input_index'] = _fd['input_index']
523 | export_list.append(export_dict)
524 | else:
525 | for _i in range(args.per_device_eval_batch_size):
526 | export_list[-(args.per_device_eval_batch_size - _i)]['tokens'].append(history_decode_ids.tolist()[_i])
527 | export_list[-(args.per_device_eval_batch_size - _i)]['string'].append(sampled_sequences[_i])
528 |
529 | if accelerator.is_main_process:
530 | if os.path.exists(out_json_fn):
531 | os.remove(out_json_fn)
532 | logger.info(f"Cleaning existing {out_json_fn}")
533 | with open(out_json_fn, mode="w") as f_out: # use mode 'a' if several processes are writing to the same file
534 | for export in export_list:
535 | f_out.write(json.dumps(export))
536 | f_out.write("\n")
537 |
538 |
539 | if __name__ == "__main__":
540 | main()
541 |
--------------------------------------------------------------------------------
/image/cad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xhan77/context-aware-decoding/1281d7fcd0e7e49d786c78d3170a3048dd9f6dc2/image/cad.png
--------------------------------------------------------------------------------
/image/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xhan77/context-aware-decoding/1281d7fcd0e7e49d786c78d3170a3048dd9f6dc2/image/main.png
--------------------------------------------------------------------------------
/image/no_cad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xhan77/context-aware-decoding/1281d7fcd0e7e49d786c78d3170a3048dd9f6dc2/image/no_cad.png
--------------------------------------------------------------------------------
/run_group_decode_fileio.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # trap "kill 0" EXIT
3 |
4 | hf_cache="/private/home/xhan77/.cache/huggingface" # CHANGE THIS TO YOUR OWN CACHE PATH
5 |
6 | numgpu=2 # should match the number of processes in the input jsonl file, default to 2 for context-aware decoding
7 | available_port=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')
8 | global_seed=$1
9 | multi_device_cuda=$2 # "0,1,2,3", "0", etc
10 | core_lm_name="specify_in_input_jsonl|n/a" # facebook/opt-1.3b, google/flan-t5-xl, etc.
11 |
12 | file_mode=$3
13 | global_max_seq_len=$4 # should be consistent with (<=) the specified models' max_seq_len
14 | decode_truncate_len=$5 # effective gen len is global_max_seq_len - decode_truncate_len
15 | decode_depth=$6
16 | projection_top_p=$7
17 |
18 | ################ START ################
19 |
20 | CUDA_VISIBLE_DEVICES=${multi_device_cuda} HF_HOME=${hf_cache} accelerate launch \
21 | --multi_gpu --mixed_precision fp16 \
22 | --num_processes ${numgpu} --num_machines 1 --machine_rank 0 \
23 | --main_process_port ${available_port} \
24 | --num_cpu_threads_per_process 10 \
25 | group_decode_fileio.py \
26 | --max_seq_length ${global_max_seq_len} \
27 | --model_name_or_path ${core_lm_name} \
28 | --seed ${global_seed} \
29 | --use_slow_tokenizer \
30 | --file_mode ${file_mode} \
31 | --decode_truncate_len ${decode_truncate_len} \
32 | --decode_depth ${decode_depth} \
33 | --train_mode decode \
34 | --projection_top_p ${projection_top_p}
35 |
--------------------------------------------------------------------------------