├── .gitignore
├── NLI_filter.py
├── README.md
├── SupContrast
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── figures
│ │ ├── .DS_Store
│ │ ├── SimCLR.jpg
│ │ ├── SupCE.jpg
│ │ └── SupContrast.jpg
│ ├── losses.py
│ ├── main_ce.py
│ ├── main_linear.py
│ ├── main_supcon.py
│ ├── networks
│ │ └── resnet_big.py
│ └── util.py
├── bge_joined_model.py
├── data
│ ├── example_test_dataset.jsonl
│ ├── example_train_dataset.jsonl
│ └── example_valid_dataset.jsonl
├── joined_dataset.py
└── train_bge_joined.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | __pycache__/
3 | logs/
4 | OPENAI_API_KEY.txt
5 | connect_openai.py
6 | temp*
7 | .vscode/
8 | roberta_weights/
9 | *.png
10 | test*.sh
11 | bge_weights/
12 | colbertv2.0/
13 | rankllama/
14 | bge/
15 | bce/
16 | runs/
17 | summarize_model/
--------------------------------------------------------------------------------
/NLI_filter.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import jsonlines
4 | import json
5 | import random
6 | import re
7 | import os
8 | import copy
9 | import nltk
10 | import numpy as np
11 | import time
12 | from tenacity import (
13 | retry,
14 | stop_after_attempt,
15 | wait_random_exponential,
16 | RetryError
17 | )
18 | import logging
19 | import signal
20 | from tqdm import tqdm
21 | import requests
22 | from concurrent.futures import ThreadPoolExecutor, TimeoutError
23 | import tenacity
24 |
random.seed(0)

# Input/output locations and the local NLI checkpoint; edit these before running.
INPUT_PATH = "./sample_data/back_trans.jsonl"
OUTPUT_PATH = "./output/back_trans_fliter.jsonl"
MODEL_NAME = "your path to mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
# Only the first MAX_BACK_INSTRUCTIONS augmented queries of each sample are checked.
MAX_BACK_INSTRUCTIONS = 3
# Classifier output index -> label name, as assumed by the original script.
LABEL_NAMES = ["entailment", "neutral", "contradiction"]


def top_label(probabilities, label_names=LABEL_NAMES):
    """Return the label whose probability is highest.

    Labels and probabilities are paired positionally (surplus entries on either
    side are ignored, as with ``zip``); ties resolve to the earliest label.

    Raises:
        ValueError: if there is nothing to rank.
    """
    pairs = list(zip(label_names, probabilities))
    if not pairs:
        raise ValueError("no probabilities to rank")
    return max(pairs, key=lambda pair: pair[1])[0]


def is_consistent(labels):
    """Return True unless any back-translation was judged a contradiction."""
    return "contradiction" not in labels


def load_nli_model(model_name=MODEL_NAME):
    """Load the NLI tokenizer and classifier once, on the first GPU if available.

    Returns:
        ``(tokenizer, model, device)`` with the model switched to eval mode.
    """
    # Imported lazily so the pure helpers above can be used without these packages.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
    model.eval()
    return tokenizer, model, device


def classify_pair(premise, hypothesis, tokenizer, model, device):
    """Return the NLI label predicted for the (premise, hypothesis) pair."""
    import torch

    inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only: do not build an autograd graph
        # Pass every tokenizer output (input_ids, attention_mask, ...) to the model.
        logits = model(**inputs).logits[0]
    return top_label(torch.softmax(logits, dim=-1).tolist())


def main(input_path=INPUT_PATH, output_path=OUTPUT_PATH, model_name=MODEL_NAME):
    """Drop samples whose back-translated instructions contradict the original.

    Each JSONL record must have ``instruction`` (str) and ``back_instruction``
    (list of str). Kept records gain an ``nli_scores`` list of predicted labels
    and are written to ``output_path``; the number of dropped records is printed.
    """
    tokenizer, model, device = load_nli_model(model_name)
    filter_results = []
    filter_count = 0
    with open(input_path, "r", encoding="utf-8") as file:
        for raw_line in tqdm(file):
            if not raw_line.strip():
                continue  # tolerate blank lines such as a trailing newline
            sample = json.loads(raw_line)
            premise = sample["instruction"]
            nli_scores = [
                classify_pair(premise, back_ins, tokenizer, model, device)
                for back_ins in sample["back_instruction"][:MAX_BACK_INSTRUCTIONS]
            ]
            sample["nli_scores"] = nli_scores
            print(nli_scores)
            if is_consistent(nli_scores):
                filter_results.append(sample)
            else:
                filter_count += 1

    print("filter samples nums:", filter_count)

    # Make sure the output directory exists before writing.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with jsonlines.open(output_path, "w") as writer:
        writer.write_all(filter_results)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## [Understand What LLM Needs: Dual Preference Alignment for Retrieval-Augmented Generation](https://arxiv.org/pdf/2406.18676)
3 |
4 |
5 | [](https://paperswithcode.com/sota/knowledge-base-question-answering-on-1?p=understand-what-llm-needs-dual-preference)
6 |
7 | [](https://paperswithcode.com/sota/question-answering-on-natural-questions?p=understand-what-llm-needs-dual-preference)
8 |
9 | [](https://paperswithcode.com/sota/question-answering-on-triviaqa?p=understand-what-llm-needs-dual-preference)
10 |
11 |
12 | *Guanting Dong, Yutao Zhu, Chenghao Zhang, Zechen Wang, Zhicheng Dou and Ji-Rong Wen*
13 |
14 | *Gaoling School of Artificial Intelligence, Renmin University of China.*
15 |
16 | This repository contains the core implementation of **DPA-RAG**, proposed in [Understand What LLM Needs: Dual Preference Alignment for Retrieval-Augmented Generation](https://arxiv.org/abs/2406.18676).
17 |
18 |
19 | If you find this work helpful for your research, please kindly cite it.
20 |
21 |
22 | ```bibtex
23 | @article{dong2024understand,
24 | author = {Guanting Dong and
25 | Yutao Zhu and
26 | Chenghao Zhang and
27 | Zechen Wang and
28 | Zhicheng Dou and
29 | Ji{-}Rong Wen},
30 | title = {Understand What {LLM} Needs: Dual Preference Alignment for Retrieval-Augmented
31 | Generation},
32 | journal = {CoRR},
33 | volume = {abs/2406.18676},
34 | year = {2024},
35 | url = {https://doi.org/10.48550/arXiv.2406.18676},
36 | doi = {10.48550/ARXIV.2406.18676},
37 | eprinttype = {arXiv},
38 | eprint = {2406.18676},
39 | timestamp = {Mon, 22 Jul 2024 14:28:34 +0200},
40 | biburl = {https://dblp.org/rec/journals/corr/abs-2406-18676.bib},
41 | bibsource = {dblp computer science bibliography, https://dblp.org}
42 | }
43 | ```
44 |
45 | ## 💥 News
46 | - [01/2025] 🔥 Our paper has been accepted by **WWW 2025**.
47 |
48 | - [05/2024] 🔥 We open-source all our code and sampled data of ***DPA-RAG***. Please check out our [repository](https://github.com/dongguanting/DPA-RAG).
49 |
50 | - [05/2024] 🔥 We introduced ***DPA-RAG***, a universal framework for aligning diverse preference knowledge within RAG systems. Check out the [paper](https://arxiv.org/abs/2406.18676).
51 |
52 | ---
53 |
54 |
55 | ## 🍯 Overall Framework
56 |
57 |
58 |
59 |
60 | **DPA-RAG** is a universal framework for aligning diverse preference knowledge within RAG systems, consisting of three main components:
61 |
62 | - **Preference Knowledge Construction:** We design a three-step process to identify and synthesize high-quality preference knowledge.
63 |
64 | - **Reranker-LLM Alignment:** We fine-tune a reranker with multi-grained tasks to achieve alignment between the retriever and LLMs.
65 |
66 | - **LLM Self-Alignment:** Before SFT stage, we introduce a pre-alignment phase to help LLMs capture preference-aligned knowledge.
67 |
68 |
69 | ---
70 |
71 |
72 |
73 |
74 |
75 |
76 | ## 💻 Data preparation
77 | We design a three-step method to gradually mine, augment, and filter out high-quality preference knowledge:
78 |
79 | ### 1. Preference Knowledge Construction
80 | For each sample in each dataset, use [DPR](https://github.com/facebookresearch/DPR) to retrieve the top 100 passages.
81 |
82 | Then follow the **Preference Knowledge Construction** process described in the paper to extract documents labeled “Aligned Knowledge” or “Unaligned Knowledge” for the different base models.
83 |
84 |
85 | ### 2. Diverse Query Augmentation
86 | Use **GPT-3.5-turbo** to perform five augmentations for each query; the template and requirements are as follows:
87 |
88 | 
89 |
90 | - Rephrasing. Rephrase the original query with the same intention.
91 | - Complexity. Increase the semantic complexity of the original query.
92 | - Decomposition. Decompose the original query into several sub-problems.
93 | - Constraint. Add more conditional and constrained statements to the original query.
94 | - SPARQL. Rewrite the original query based on the SPARQL syntax and generate it directly
95 |
96 | We also provide **1k samples of the NQ dataset** for each augmentation, which can be downloaded [here](https://drive.google.com/drive/folders/1fbehvvNzas0VitdBky-pDLDZ_vLSHI81).
97 |
98 |
99 | ### 3. NLI Filtering
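100 | Use [mDeBERTa](https://huggingface.co/MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7) as the NLI model for consistency filtering between the original and augmented queries. `NLI_filter.py` compares each original instruction with its first three augmented versions and discards the sample if any of them is classified as a contradiction.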
101 |
102 | **Data format:**
103 |
104 | ```json
105 | {
106 | "instruction": "what does jamaican people speak?",
107 | "back_instruction": ["What language do Jamaican people speak?",
108 | "Given their diverse cultural background and official language policies",
109 | "what language do Jamaican people predominantly speak in everyday life as well as in formal settings?",
110 | "What do Jamaicans speak?"]
111 | }
112 | ```
113 |
114 | **Filter process:**
115 |
116 | ```bash
117 | python NLI_filter.py
118 | ```
119 |
120 |
121 | ---
122 |
123 | ## :sparkles: Reranker-LLM Alignment
124 |
125 | ### 1. Code clone
126 | Run the following command from the command line to clone our code from github to local:
127 |
128 | ```bash
129 | git clone https://github.com/dongguanting/DPA-RAG.git
130 | ```
131 |
132 | ### 2. Start training
133 |
134 | We provide example training, validation, and test data in `data/`. To use your own data, follow the same format and keep all required fields.
135 |
136 | ```bash
137 | python train_bge_joined.py \
138 | --pretrained_model_path=your/pretrained/bge/path \
139 | --train_data_path=your/train/data/path \
140 | --valid_data_path=your/valid/data/path \
141 | --test_data_path=your/test/data/path \
142 | --gpu=0 \
143 | --outdir=your/output/path \
144 | --tensorboard_log_dir=your/tensorboard/log/path \
145 | --cls_loss \
146 | --rank_loss \
147 | --scl_loss
148 | ```
149 |
150 | You can choose which GPU to train and test your model on, but right now we don't support multi-GPU training.
151 |
152 | You can choose which losses to incorporate during training. For example, to use only the classification loss and the contrastive learning loss, run the following command:
153 |
154 | ```bash
155 | python train_bge_joined.py \
156 | --pretrained_model_path=your/pretrained/bge/path \
157 | --train_data_path=your/train/data/path \
158 | --valid_data_path=your/valid/data/path \
159 | --test_data_path=your/test/data/path \
160 | --gpu=0 \
161 | --outdir=your/output/path \
162 | --tensorboard_log_dir=your/tensorboard/log/path \
163 | --cls_loss \
164 | --scl_loss
165 | ```
166 |
167 | Tests are automatically performed on the specified data set after the training is complete.
168 |
169 | ---
170 |
171 |
172 | ## 🌈 LLM Training
173 |
174 |
175 | For LLM training, we use [LLaMA-Factory v0.6.3](https://github.com/hiyouga/LLaMA-Factory/releases/tag/v0.6.3) for Llama2-7B/13B, Mistral-7B, Qwen1.5-0.5B/4B/7B/14B, and Phi2-2.7B, and [LLaMA-Factory v0.8.1](https://github.com/hiyouga/LLaMA-Factory/releases/tag/v0.8.1) for Qwen2-7B and Llama3-8B. Thanks for their excellent work.
176 |
177 |
178 |
179 | ### 1. Pre-aligned Training:
180 |
181 | We provide the NQ data for the pre-aligned stage [here](https://drive.google.com/drive/folders/1JFCGpnmqfMHGh6X9cMJFtTxVduHkiQXi?usp=sharing). Note that this pre-aligned data does not include augmented data; please merge in your augmented data to unlock stronger alignment capabilities. You can construct other datasets yourself following the NQ data format. Replace the parameters marked with `$` with your own values.
182 |
183 | ```bash
184 | deepspeed --num_gpus=8 train_bash.py \
185 | --deepspeed $deepspeed_zero3_config_path \
186 | --stage sft \
187 | --do_train \
188 | --use_fast_tokenizer \
189 | --flash_attn \
190 | --adam_beta1 0.9 \
191 | --adam_beta2 0.95 \
192 | --model_name_or_path $MODEL_PATH \
193 | --dataset $dataset \
194 | --template $Template \
195 | --finetuning_type full \
196 | --output_dir $OUTPUT_PATH \
197 | --overwrite_cache \
198 | --overwrite_output_dir \
199 | --warmup_steps 20 \
200 | --weight_decay 0.1 \
201 | --per_device_train_batch_size 4 \
202 | --gradient_accumulation_steps 4 \
203 | --ddp_timeout 9000 \
204 | --learning_rate 7e-6 \
205 | --lr_scheduler_type "linear" \
206 | --logging_steps 1 \
207 | --cutoff_len 8192 \
208 | --save_steps 200 \
209 | --num_train_epochs 3.0 \
210 | --plot_loss \
211 | --bf16
212 | ```
213 |
214 | ### 2. SFT Training:
215 |
216 | You can find the original training data with top3 passages (w/o data augmentation) [here](https://drive.google.com/drive/folders/1dCCpAVPiwPgjOhuKGcyonwgfr2kntJHZ?usp=sharing).
217 | Please merge your augmentation data to unlock more powerful alignment capabilities.
218 |
219 | ```bash
220 | deepspeed --num_gpus=8 train_bash.py \
221 | --deepspeed $deepspeed_zero3_config_path \
222 | --stage sft \
223 | --do_train \
224 | --use_fast_tokenizer \
225 | --flash_attn \
226 | --adam_beta1 0.9 \
227 | --adam_beta2 0.95 \
228 | --model_name_or_path $MODEL_PATH \
229 | --dataset $dataset \
230 | --template $Template \
231 | --finetuning_type full \
232 | --output_dir $OUTPUT_PATH \
233 | --overwrite_cache \
234 | --overwrite_output_dir \
235 | --warmup_steps 20 \
236 | --weight_decay 0.1 \
237 | --per_device_train_batch_size 4 \
238 | --gradient_accumulation_steps 4 \
239 | --ddp_timeout 9000 \
240 | --learning_rate 7e-6 \
241 | --lr_scheduler_type "linear" \
242 | --logging_steps 1 \
243 | --cutoff_len 8192 \
244 | --save_steps 200 \
245 | --num_train_epochs 3.0 \
246 | --plot_loss \
247 | --bf16
248 | ```
249 |
250 | ### 3. Inference
251 |
252 | You can find our reranked test data with the top-3 passages [here](https://drive.google.com/drive/folders/1HFAEGX5A5aVFNuWMzA1zLoRtfE-FyoTB?usp=sharing). For the WebQSP dataset, we only provide train and test data with the top-2 passages for alignment.
253 |
254 | ```bash
255 | CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
256 | --stage sft \
257 | --model_name_or_path $path_to_llama_model \
258 | --do_predict \
259 | --flash_attn no \
260 | --dataset $dataset \
261 | --template $template \
262 | --output_dir $OUTPUT_PATH \
263 | --per_device_eval_batch_size 8 \
264 | --max_samples 150 \
265 | --cutoff_len 2048 \
266 | --predict_with_generate \
267 | --fp16 \
268 | --quantization_bit 4 \
269 | --max_new_tokens 20
270 | ```
271 |
272 |
273 |
274 |
--------------------------------------------------------------------------------
/SupContrast/.gitignore:
--------------------------------------------------------------------------------
1 | tmp*.py
2 | .idea/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 | save/
--------------------------------------------------------------------------------
/SupContrast/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, Yonglong Tian
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/SupContrast/README.md:
--------------------------------------------------------------------------------
1 | # SupContrast: Supervised Contrastive Learning
2 |
3 |
4 |
5 |
6 | This repo contains a reference implementation of the following papers in PyTorch, using CIFAR as an illustrative example:
7 | (1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362)
8 | (2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709)
9 |
10 | ## Update
11 |
12 | ${\color{red}Note}$: if you found it not easy to parse the supcon loss implementation in this repo, we got you. Supcon loss essentially is just a cross-entropy loss (see eq 4 in the [StableRep](https://arxiv.org/pdf/2306.00984.pdf) paper). So we got a cleaner and simpler implementation [here](https://github.com/google-research/syn-rep-learn/blob/main/StableRep/models/losses.py#L49). Hope it helps.
13 |
14 | ImageNet model (small batch size with the trick of the momentum encoder) is released [here](https://www.dropbox.com/s/l4a69ececk4spdt/supcon.pth?dl=0). It achieved > 79% top-1 accuracy.
15 |
16 | ## Loss Function
17 | The loss function [`SupConLoss`](https://github.com/HobbitLong/SupContrast/blob/master/losses.py#L11) in `losses.py` takes `features` (L2 normalized) and `labels` as input, and returns the loss. If `labels` is `None` or not passed, it degenerates to SimCLR.
18 |
19 | Usage:
20 | ```python
21 | from losses import SupConLoss
22 |
23 | # define loss with a temperature `temp`
24 | criterion = SupConLoss(temperature=temp)
25 |
26 | # features: [bsz, n_views, f_dim]
27 | # `n_views` is the number of crops from each image
28 | # better be L2 normalized in f_dim dimension
29 | features = ...
30 | # labels: [bsz]
31 | labels = ...
32 |
33 | # SupContrast
34 | loss = criterion(features, labels)
35 | # or SimCLR
36 | loss = criterion(features)
37 | ...
38 | ```
39 |
40 | ## Comparison
41 | Results on CIFAR-10:
42 | | |Arch | Setting | Loss | Accuracy(%) |
43 | |----------|:----:|:---:|:---:|:---:|
44 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 95.0 |
45 | | SupContrast | ResNet50 | Supervised | Contrastive | 96.0 |
46 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 93.6 |
47 |
48 | Results on CIFAR-100:
49 | | |Arch | Setting | Loss | Accuracy(%) |
50 | |----------|:----:|:---:|:---:|:---:|
51 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 75.3 |
52 | | SupContrast | ResNet50 | Supervised | Contrastive | 76.5 |
53 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 70.7 |
54 |
55 | Results on ImageNet (Stay tuned):
56 | | |Arch | Setting | Loss | Accuracy(%) |
57 | |----------|:----:|:---:|:---:|:---:|
58 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | - |
59 | | SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) |
60 | | SimCLR | ResNet50 | Unsupervised | Contrastive | - |
61 |
62 | ## Running
63 | You might use `CUDA_VISIBLE_DEVICES` to set proper number of GPUs, and/or switch to CIFAR100 by `--dataset cifar100`.
64 | **(1) Standard Cross-Entropy**
65 | ```
66 | python main_ce.py --batch_size 1024 \
67 | --learning_rate 0.8 \
68 | --cosine --syncBN \
69 | ```
70 | **(2) Supervised Contrastive Learning**
71 | Pretraining stage:
72 | ```
73 | python main_supcon.py --batch_size 1024 \
74 | --learning_rate 0.5 \
75 | --temp 0.1 \
76 | --cosine
77 | ```
78 |
79 | You can also specify `--syncBN` but I found it not crucial for SupContrast (`syncBN` 95.9% v.s. `BN` 96.0%).
80 |
81 | WARN: Currently, `--syncBN` has no effect since the code uses `DataParallel` instead of `DistributedDataParallel`.
82 |
83 | Linear evaluation stage:
84 | ```
85 | python main_linear.py --batch_size 512 \
86 | --learning_rate 5 \
87 | --ckpt /path/to/model.pth
88 | ```
89 | **(3) SimCLR**
90 | Pretraining stage:
91 | ```
92 | python main_supcon.py --batch_size 1024 \
93 | --learning_rate 0.5 \
94 | --temp 0.5 \
95 | --cosine --syncBN \
96 | --method SimCLR
97 | ```
98 | The `--method SimCLR` flag simply stops `labels` from being passed to `SupConLoss` criterion.
99 | Linear evaluation stage:
100 | ```
101 | python main_linear.py --batch_size 512 \
102 | --learning_rate 1 \
103 | --ckpt /path/to/model.pth
104 | ```
105 |
106 | On custom dataset:
107 | ```
108 | python main_supcon.py --batch_size 1024 \
109 | --learning_rate 0.5 \
110 | --temp 0.1 --cosine \
111 | --dataset path \
112 | --data_folder ./path \
113 | --mean "(0.4914, 0.4822, 0.4465)" \
114 | --std "(0.2675, 0.2565, 0.2761)" \
115 | --method SimCLR
116 | ```
117 |
118 | The `--data_folder` must be of the form ./path/label/xxx.png, following the https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convention.
119 |
120 | and
121 | ## t-SNE Visualization
122 |
123 | **(1) Standard Cross-Entropy**
124 |
125 |
126 |
127 |
128 | **(2) Supervised Contrastive Learning**
129 |
130 |
131 |
132 |
133 | **(3) SimCLR**
134 |
135 |
136 |
137 |
138 | ## Reference
139 | ```
140 | @Article{khosla2020supervised,
141 | title = {Supervised Contrastive Learning},
142 | author = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
143 | journal = {arXiv preprint arXiv:2004.11362},
144 | year = {2020},
145 | }
146 | ```
147 |
--------------------------------------------------------------------------------
/SupContrast/figures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/.DS_Store
--------------------------------------------------------------------------------
/SupContrast/figures/SimCLR.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/SimCLR.jpg
--------------------------------------------------------------------------------
/SupContrast/figures/SupCE.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/SupCE.jpg
--------------------------------------------------------------------------------
/SupContrast/figures/SupContrast.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/SupContrast.jpg
--------------------------------------------------------------------------------
/SupContrast/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Yonglong Tian (yonglong@mit.edu)
3 | Date: May 07, 2020
4 | """
5 |
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
class SupConLoss(nn.Module):
    """Supervised contrastive loss (Khosla et al., https://arxiv.org/pdf/2004.11362.pdf).

    With neither ``labels`` nor ``mask`` the loss reduces to the unsupervised
    SimCLR objective (https://arxiv.org/pdf/2002.05709.pdf).

    Args:
        temperature: divisor applied to the anchor/contrast dot products.
        contrast_mode: ``"all"`` treats every view as an anchor; ``"one"`` uses
            only the first view of each sample as the anchor.
        base_temperature: the per-anchor loss is scaled by
            ``temperature / base_temperature``.
    """

    def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute the contrastive loss.

        Args:
            features: tensor of shape [bsz, n_views, ...]; any dimensions after
                the second are flattened into one feature axis.
            labels: class ids of shape [bsz]; samples sharing a label are positives.
            mask: [bsz, bsz] positive mask, ``mask[i, j] = 1`` when sample j is a
                positive for anchor i. May be asymmetric. Mutually exclusive with
                ``labels``.

        Returns:
            A scalar loss tensor (mean over all anchors).

        Raises:
            ValueError: if ``features`` has fewer than 3 dimensions, both
                ``labels`` and ``mask`` are given, the label count differs from
                the batch size, or ``contrast_mode`` is unknown.
        """
        device = features.device

        if features.dim() < 3:
            raise ValueError(
                "`features` needs to be [bsz, n_views, ...],"
                "at least 3 dimensions are required"
            )
        if features.dim() > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size, n_views = features.shape[0], features.shape[1]

        # Build the [bsz, bsz] positive-pair mask.
        if labels is not None and mask is not None:
            raise ValueError("Cannot define both `labels` and `mask`")
        if mask is not None:
            positive_mask = mask.float().to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError("Num of labels does not match num of features")
            positive_mask = torch.eq(labels, labels.T).float().to(device)
        else:
            # SimCLR: each sample is only a positive of its own other views.
            positive_mask = torch.eye(batch_size, dtype=torch.float32).to(device)

        # Stack the views along the batch axis (view-major): [n_views * bsz, dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == "all":
            anchor_feature, anchor_count = contrast_feature, n_views
        elif self.contrast_mode == "one":
            anchor_feature, anchor_count = features[:, 0], 1
        else:
            raise ValueError("Unknown mode: {}".format(self.contrast_mode))

        # Temperature-scaled similarities, shifted by the row max for numerical stability.
        similarity = torch.matmul(anchor_feature, contrast_feature.T) / self.temperature
        row_max = similarity.max(dim=1, keepdim=True).values
        logits = similarity - row_max.detach()

        # Expand the mask to every (anchor view, contrast view) pair and drop the
        # diagonal, so an anchor is never contrasted with itself.
        positive_mask = positive_mask.repeat(anchor_count, n_views)
        n_anchors = batch_size * anchor_count
        self_mask = torch.ones_like(positive_mask).scatter(
            1, torch.arange(n_anchors).view(-1, 1).to(device), 0
        )
        positive_mask = positive_mask * self_mask

        # Log-softmax over all non-self contrasts.
        exp_logits = torch.exp(logits) * self_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Average log-likelihood over positives; anchors without any positive
        # divide by 1 instead of 0 so they contribute 0 rather than NaN.
        positives_per_anchor = positive_mask.sum(1)
        positives_per_anchor = torch.where(positives_per_anchor < 1e-6, 1, positives_per_anchor)
        mean_log_prob_pos = (positive_mask * log_prob).sum(1) / positives_per_anchor

        per_anchor = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return per_anchor.view(anchor_count, batch_size).mean()
108 |
--------------------------------------------------------------------------------
/SupContrast/main_ce.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import time
7 | import math
8 |
9 | import tensorboard_logger as tb_logger
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torchvision import transforms, datasets
13 |
14 | from util import AverageMeter
15 | from util import adjust_learning_rate, warmup_learning_rate, accuracy
16 | from util import set_optimizer, save_model
17 | from networks.resnet_big import SupCEResNet
18 |
19 | try:
20 | import apex
21 | from apex import amp, optimizers
22 | except ImportError:
23 | pass
24 |
25 |
def parse_option(args=None):
    """Parse options for cross-entropy training and derive run paths and names.

    Args:
        args: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed flags plus derived fields:
        ``lr_decay_epochs`` as a list of ints, ``model_name``, warm-up settings
        (``warmup_from``/``warm_epochs``/``warmup_to`` when warm-up is on),
        ``tb_folder`` and ``save_folder`` (both created on disk), and ``n_cls``.

    Raises:
        ValueError: if an entry of ``--lr_decay_epochs`` is not an integer.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=500,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.2,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--syncBN', action='store_true',
                        help='using synchronized batch normalization')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--trial', type=str, default='0',
                        help='id for recording multiple runs')

    opt = parser.parse_args(args)

    # set the path according to the environment
    opt.data_folder = './datasets/'
    opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
    opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)

    # "350,400,450" -> [350, 400, 450]; empty entries (e.g. a trailing comma) are skipped.
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',') if it.strip()]

    opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.\
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size, opt.trial)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.batch_size > 256:
        opt.warm = True
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # Warm up to the value the cosine schedule has at epoch warm_epochs.
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    if opt.dataset == 'cifar10':
        opt.n_cls = 10
    elif opt.dataset == 'cifar100':
        opt.n_cls = 100
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    return opt
116 |
117 |
def set_loader(opt):
    """Build the CIFAR train/validation data loaders for cross-entropy training.

    Training images get random-resized-crop and horizontal-flip augmentation;
    validation images are only converted to tensors and normalised.

    Args:
        opt: parsed options; reads ``dataset``, ``data_folder``,
            ``batch_size`` and ``num_workers``.

    Returns:
        ``(train_loader, val_loader)``, both ``torch.utils.data.DataLoader``.
        The validation loader always uses batch size 256 and 8 workers.

    Raises:
        ValueError: if ``opt.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    # per-channel (mean, std) normalisation statistics per dataset
    channel_stats = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    }
    if opt.dataset not in channel_stats:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    mean, std = channel_stats[opt.dataset]
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    dataset_cls = datasets.CIFAR10 if opt.dataset == 'cifar10' else datasets.CIFAR100
    train_dataset = dataset_cls(root=opt.data_folder,
                                transform=train_transform,
                                download=True)
    val_dataset = dataset_cls(root=opt.data_folder,
                              train=False,
                              transform=val_transform)

    # no custom sampler, so the training loader shuffles on its own
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=train_sampler is None,
        num_workers=opt.num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader, val_loader
168 |
169 |
def set_model(opt):
    """Create the supervised cross-entropy ResNet and its loss function.

    Args:
        opt: parsed options; reads ``model``, ``n_cls`` and ``syncBN``.

    Returns:
        ``(model, criterion)``. When CUDA is available both are moved to the
        GPU; with more than one visible GPU the whole model is wrapped in
        ``torch.nn.DataParallel``. On CPU they are returned unchanged.
    """
    net = SupCEResNet(name=opt.model, num_classes=opt.n_cls)
    loss_fn = torch.nn.CrossEntropyLoss()

    if opt.syncBN:
        # NOTE(review): needs NVIDIA apex; presumably imported in a guarded
        # try/except at the top of this file - TODO confirm
        net = apex.parallel.convert_syncbn_model(net)

    if not torch.cuda.is_available():
        return net, loss_fn

    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    net = net.cuda()
    loss_fn = loss_fn.cuda()
    cudnn.benchmark = True
    return net, loss_fn
186 |
187 |
def train(train_loader, model, criterion, optimizer, epoch, opt):
    """Run one epoch of supervised cross-entropy training.

    Args:
        train_loader: iterable of ``(images, labels)`` batches.
        model: network producing class logits.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: 1-based epoch index, used for LR warm-up and logging.
        opt: parsed options; reads ``print_freq`` plus whatever
            ``warmup_learning_rate`` needs.

    Returns:
        ``(average loss, average top-1 accuracy)`` over the epoch, weighted
        by batch size.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # set_model only moves the model to the GPU when CUDA is available, so
    # batches must follow the same rule (as main_supcon.train already does);
    # an unconditional .cuda() would crash CPU-only runs.
    use_cuda = torch.cuda.is_available()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        if use_cuda:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # per-iteration LR ramp during the warm-up epochs
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        output = model(images)
        loss = criterion(output, labels)

        # update metrics (top-5 is computed by accuracy() but not tracked)
        losses.update(loss.item(), bsz)
        acc1, _ = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        # SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg
238 |
239 |
def validate(val_loader, model, criterion, opt):
    """Evaluate the model on the validation set without tracking gradients.

    Args:
        val_loader: iterable of ``(images, labels)`` batches.
        model: network producing class logits.
        criterion: loss applied to ``(logits, labels)``.
        opt: parsed options; reads ``print_freq``.

    Returns:
        ``(average loss, average top-1 accuracy)`` over the validation set.
    """
    model.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # keep inputs on the same device as the model: set_model only moves it
    # to the GPU when CUDA is available
    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float()
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            output = model(images)
            loss = criterion(output, labels)

            # update metrics
            losses.update(loss.item(), bsz)
            acc1, _ = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time,
                       loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg
278 |
279 |
def main():
    """Train the cross-entropy baseline end to end.

    Logs train/val loss, accuracy and learning rate to TensorBoard every
    epoch, writes ``ckpt_epoch_<N>.pth`` every ``opt.save_freq`` epochs plus
    a final ``last.pth``, and prints the best validation top-1 accuracy.
    """
    opt = parse_option()
    train_loader, val_loader = set_loader(opt)
    model, criterion = set_model(opt)
    optimizer = set_optimizer(opt, model)
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    best_acc = 0
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        started = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, opt)
        elapsed = time.time() - started
        print('epoch {}, total time {:.2f}'.format(epoch, elapsed))

        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        val_loss, val_acc = validate(val_loader, model, criterion, opt)
        logger.log_value('val_loss', val_loss, epoch)
        logger.log_value('val_acc', val_acc, epoch)

        if val_acc > best_acc:
            best_acc = val_acc

        if epoch % opt.save_freq == 0:
            ckpt_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_model(model, optimizer, opt, epoch,
                       os.path.join(opt.save_folder, ckpt_name))

    # always persist the final weights, regardless of save_freq
    save_model(model, optimizer, opt, opt.epochs,
               os.path.join(opt.save_folder, 'last.pth'))

    print('best accuracy: {:.2f}'.format(best_acc))


if __name__ == '__main__':
    main()
334 |
--------------------------------------------------------------------------------
/SupContrast/main_linear.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | import argparse
5 | import time
6 | import math
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 |
11 | from main_ce import set_loader
12 | from util import AverageMeter
13 | from util import adjust_learning_rate, warmup_learning_rate, accuracy
14 | from util import set_optimizer
15 | from networks.resnet_big import SupConResNet, LinearClassifier
16 |
17 | try:
18 | import apex
19 | from apex import amp, optimizers
20 | except ImportError:
21 | pass
22 |
23 |
def parse_option():
    """Parse command-line options for linear evaluation of a SupCon encoder.

    Besides the raw CLI flags, the returned namespace carries derived fields:
    ``data_folder``, ``lr_decay_epochs`` (list of ints), ``model_name``,
    ``n_cls`` and, when ``--warm`` is set, the warm-up schedule
    (``warmup_from``, ``warm_epochs``, ``warmup_to``).

    Raises:
        ValueError: if the dataset has no known class count.
    """
    parser = argparse.ArgumentParser('argument for training')

    # (flag, type, default, help) for the plain typed options, in CLI order
    typed_options = (
        ('--print_freq', int, 10, 'print frequency'),
        ('--save_freq', int, 50, 'save frequency'),
        ('--batch_size', int, 256, 'batch_size'),
        ('--num_workers', int, 16, 'num of workers to use'),
        ('--epochs', int, 100, 'number of training epochs'),
        # optimization
        ('--learning_rate', float, 0.1, 'learning rate'),
        ('--lr_decay_epochs', str, '60,75,90', 'where to decay lr, can be a list'),
        ('--lr_decay_rate', float, 0.2, 'decay rate for learning rate'),
        ('--weight_decay', float, 0, 'weight decay'),
        ('--momentum', float, 0.9, 'momentum'),
    )
    for flag, kind, default, text in typed_options:
        parser.add_argument(flag, type=kind, default=default, help=text)

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--ckpt', type=str, default='',
                        help='path to pre-trained model')

    opt = parser.parse_args()

    # set the path according to the environment
    opt.data_folder = './datasets/'

    opt.lr_decay_epochs = [int(e) for e in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.format(
        opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
        opt.batch_size)
    if opt.cosine:
        opt.model_name += '_cosine'

    # warm-up for large-batch training
    if opt.warm:
        opt.model_name += '_warm'
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if not opt.cosine:
            opt.warmup_to = opt.learning_rate
        else:
            # end the ramp where the cosine schedule would be at warm_epochs
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2

    class_counts = {'cifar10': 10, 'cifar100': 100}
    if opt.dataset not in class_counts:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    opt.n_cls = class_counts[opt.dataset]

    return opt
101 |
102 |
def set_model(opt):
    """Build the pre-trained SupCon backbone, a linear classifier and the loss.

    Backbone weights are restored from ``opt.ckpt``. Every ``'module.'``
    fragment (inserted by ``torch.nn.DataParallel`` when the checkpoint was
    written) is stripped from the keys. The weights are then loaded into the
    *unwrapped* model before the encoder is optionally wrapped in
    DataParallel. This way a checkpoint loads regardless of how many GPUs
    were used at pre-training time versus now. Previously, a single-GPU
    checkpoint failed to load on a multi-GPU machine.

    Args:
        opt: parsed options; reads ``model``, ``n_cls`` and ``ckpt``.

    Returns:
        ``(model, classifier, criterion)``, all moved to the GPU.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
    ckpt = torch.load(opt.ckpt, map_location='cpu')
    state_dict = {k.replace('module.', ''): v for k, v in ckpt['model'].items()}

    if not torch.cuda.is_available():
        raise NotImplementedError('This code requires GPU')

    # load into the plain module first so key names never depend on wrapping
    model.load_state_dict(state_dict)
    if torch.cuda.device_count() > 1:
        model.encoder = torch.nn.DataParallel(model.encoder)
    model = model.cuda()
    classifier = classifier.cuda()
    criterion = criterion.cuda()
    cudnn.benchmark = True

    return model, classifier, criterion
131 |
132 |
def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """Train the linear classifier for one epoch on top of the frozen encoder.

    The backbone stays in eval mode and is run under ``torch.no_grad()``, so
    only ``classifier`` receives gradients.

    Args:
        train_loader: iterable of ``(images, labels)`` batches.
        model: SupCon network; only ``model.encoder`` is used.
        classifier: linear head being trained.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer over ``classifier``'s parameters.
        epoch: 1-based epoch index, used for LR warm-up and logging.
        opt: parsed options; reads ``print_freq`` plus warm-up settings.

    Returns:
        ``(average loss, average top-1 accuracy)`` for the epoch.
    """
    model.eval()
    classifier.train()

    batch_time, data_time, losses, top1 = (AverageMeter() for _ in range(4))
    n_batches = len(train_loader)

    tic = time.time()
    for step, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - tic)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        warmup_learning_rate(opt, epoch, step, n_batches, optimizer)

        # frozen backbone: no autograd graph is built for the encoder
        with torch.no_grad():
            features = model.encoder(images)
        logits = classifier(features.detach())
        loss = criterion(logits, labels)

        losses.update(loss.item(), bsz)
        acc1, _ = accuracy(logits, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tic)
        tic = time.time()

        if (step + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, step + 1, n_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg
186 |
187 |
def validate(val_loader, model, classifier, criterion, opt):
    """Evaluate encoder + linear classifier on the validation set.

    Args:
        val_loader: iterable of ``(images, labels)`` batches.
        model: SupCon network; only ``model.encoder`` is used.
        classifier: trained linear head.
        criterion: loss applied to ``(logits, labels)``.
        opt: parsed options; reads ``print_freq``.

    Returns:
        ``(average loss, average top-1 accuracy)`` over the validation set.
    """
    model.eval()
    classifier.eval()

    batch_time, losses, top1 = AverageMeter(), AverageMeter(), AverageMeter()
    n_batches = len(val_loader)

    with torch.no_grad():
        tic = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images, labels = images.float().cuda(), labels.cuda()
            bsz = labels.shape[0]

            logits = classifier(model.encoder(images))
            loss = criterion(logits, labels)

            losses.update(loss.item(), bsz)
            acc1, _ = accuracy(logits, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)

            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       step, n_batches, batch_time=batch_time,
                       loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg
227 |
228 |
def main():
    """Linear evaluation: fit a classifier on frozen SupCon features.

    Prints per-epoch training time and accuracy, then the best validation
    top-1 accuracy reached.
    """
    opt = parse_option()
    train_loader, val_loader = set_loader(opt)
    model, classifier, criterion = set_model(opt)
    # only the classifier's parameters are handed to the optimizer
    optimizer = set_optimizer(opt, classifier)

    best_acc = 0
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        started = time.time()
        _, train_acc = train(train_loader, model, classifier, criterion,
                             optimizer, epoch, opt)
        elapsed = time.time() - started
        print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, elapsed, train_acc))

        _, val_acc = validate(val_loader, model, classifier, criterion, opt)
        if val_acc > best_acc:
            best_acc = val_acc

    print('best accuracy: {:.2f}'.format(best_acc))


if __name__ == '__main__':
    main()
264 |
--------------------------------------------------------------------------------
/SupContrast/main_supcon.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import time
7 | import math
8 |
9 | import tensorboard_logger as tb_logger
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torchvision import transforms, datasets
13 |
14 | from util import TwoCropTransform, AverageMeter
15 | from util import adjust_learning_rate, warmup_learning_rate
16 | from util import set_optimizer, save_model
17 | from networks.resnet_big import SupConResNet
18 | from losses import SupConLoss
19 |
20 | try:
21 | import apex
22 | from apex import amp, optimizers
23 | except ImportError:
24 | pass
25 |
26 |
def parse_option():
    """Parse command-line options for SupCon/SimCLR pre-training.

    Besides the raw CLI flags, the returned namespace carries derived fields:
    ``data_folder`` (defaulting to ``./datasets/``), ``model_path``,
    ``tb_path``, ``lr_decay_epochs`` (list of ints), ``model_name``,
    ``tb_folder`` and ``save_folder`` (both created on disk if missing) and,
    when warm-up is active, ``warmup_from``/``warm_epochs``/``warmup_to``.
    A batch size above 256 forces warm-up on.

    Raises:
        SystemExit: via ``parser.error`` when ``--dataset path`` is given
            without ``--data_folder``, ``--mean`` and ``--std``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'path'], help='dataset')
    parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
    parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
    parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')

    # method
    parser.add_argument('--method', type=str, default='SupCon',
                        choices=['SupCon', 'SimCLR'], help='choose method')

    # temperature
    parser.add_argument('--temp', type=float, default=0.07,
                        help='temperature for loss function')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--syncBN', action='store_true',
                        help='using synchronized batch normalization')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--trial', type=str, default='0',
                        help='id for recording multiple runs')

    opt = parser.parse_args()

    # A custom ImageFolder dataset needs its location and normalisation
    # statistics. Report a proper usage error instead of using `assert`,
    # which is stripped under `python -O` and gives no message.
    if opt.dataset == 'path':
        required = (('--data_folder', opt.data_folder),
                    ('--mean', opt.mean),
                    ('--std', opt.std))
        missing = [flag for flag, value in required if value is None]
        if missing:
            parser.error('--dataset path requires {}'.format(', '.join(missing)))

    # set the path according to the environment
    if opt.data_folder is None:
        opt.data_folder = './datasets/'
    opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
    opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)

    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\
        format(opt.method, opt.dataset, opt.model, opt.learning_rate,
               opt.weight_decay, opt.batch_size, opt.temp, opt.trial)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training
    if opt.batch_size > 256:
        opt.warm = True
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # end the ramp where the cosine schedule would be at warm_epochs
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt
129 |
130 |
def set_loader(opt):
    """Build the two-view training loader for contrastive pre-training.

    Each sample is augmented twice through ``TwoCropTransform`` (random
    resized crop, horizontal flip, colour jitter, random grayscale). This
    yields the pair of views that ``train`` concatenates.

    Args:
        opt: parsed options; reads ``dataset``, ``data_folder``, ``size``,
            ``batch_size``, ``num_workers`` and, for ``dataset == 'path'``,
            the ``mean``/``std`` strings (e.g. ``"(0.5, 0.5, 0.5)"``).

    Returns:
        A shuffling ``torch.utils.data.DataLoader`` over the training set.

    Raises:
        ValueError: for an unsupported dataset, or a ``--mean``/``--std``
            value that is not a plain Python literal (``ast.literal_eval``
            may also raise ``SyntaxError`` for unparsable text).
    """
    import ast  # local import: only needed to parse --mean/--std safely

    # construct data loader
    if opt.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif opt.dataset == 'path':
        # literal_eval (not eval) so command-line text can only describe a
        # literal such as a tuple of floats, never execute arbitrary code
        mean = ast.literal_eval(opt.mean)
        std = ast.literal_eval(opt.std)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    two_views = TwoCropTransform(train_transform)

    if opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=two_views,
                                         download=True)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=two_views,
                                          download=True)
    elif opt.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=opt.data_folder,
                                             transform=two_views)
    else:
        raise ValueError(opt.dataset)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
        num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)

    return train_loader
177 |
178 |
def set_model(opt):
    """Build the SupCon encoder + projection network and the contrastive loss.

    Args:
        opt: parsed options; reads ``model``, ``temp`` and ``syncBN``.

    Returns:
        ``(model, criterion)``. When CUDA is available both are moved to the
        GPU. With more than one visible GPU, only ``model.encoder`` is wrapped
        in ``torch.nn.DataParallel``; the projection head is not.
    """
    net = SupConResNet(name=opt.model)
    loss_fn = SupConLoss(temperature=opt.temp)

    if opt.syncBN:
        # apex is imported inside a try/except at the top of this file, so
        # this line raises NameError when apex is not installed
        net = apex.parallel.convert_syncbn_model(net)

    if not torch.cuda.is_available():
        return net, loss_fn

    if torch.cuda.device_count() > 1:
        net.encoder = torch.nn.DataParallel(net.encoder)
    net, loss_fn = net.cuda(), loss_fn.cuda()
    cudnn.benchmark = True
    return net, loss_fn
195 |
196 |
def train(train_loader, model, criterion, optimizer, epoch, opt):
    """Run one epoch of contrastive (SupCon or SimCLR) training.

    Each batch carries two augmented views per image (``images[0]`` and
    ``images[1]``). Both are encoded in a single forward pass and regrouped
    as ``(bsz, 2, ...)`` before being passed to the loss.

    Args:
        train_loader: iterable of ``([view1, view2], labels)`` batches.
        model: SupCon network returning one embedding per input image.
        criterion: ``SupConLoss``; called with labels for SupCon, without
            labels for SimCLR.
        optimizer: optimizer over ``model``'s parameters.
        epoch: 1-based epoch index, used for LR warm-up and logging.
        opt: parsed options; reads ``method``, ``print_freq`` and warm-up
            settings.

    Returns:
        Average loss over the epoch (weighted by batch size).

    Raises:
        ValueError: if ``opt.method`` is neither 'SupCon' nor 'SimCLR'.
    """
    model.train()

    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
    n_batches = len(train_loader)
    on_gpu = torch.cuda.is_available()

    tic = time.time()
    for step, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - tic)

        # stack both views along the batch dimension: [view1; view2]
        views = torch.cat((images[0], images[1]), dim=0)
        if on_gpu:
            views = views.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        warmup_learning_rate(opt, epoch, step, n_batches, optimizer)

        embeddings = model(views)
        first, second = torch.split(embeddings, [bsz, bsz], dim=0)
        paired = torch.stack((first, second), dim=1)
        if opt.method == 'SupCon':
            loss = criterion(paired, labels)
        elif opt.method == 'SimCLR':
            loss = criterion(paired)
        else:
            raise ValueError('contrastive method not supported: {}'.format(opt.method))

        losses.update(loss.item(), bsz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tic)
        tic = time.time()

        if (step + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, step + 1, n_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg
253 |
254 |
def main():
    """Pre-train a SupCon/SimCLR model with TensorBoard logging and checkpoints.

    Writes ``ckpt_epoch_<N>.pth`` every ``opt.save_freq`` epochs and a final
    ``last.pth`` into ``opt.save_folder``.
    """
    opt = parse_option()
    train_loader = set_loader(opt)
    model, criterion = set_model(opt)
    optimizer = set_optimizer(opt, model)
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        started = time.time()
        epoch_loss = train(train_loader, model, criterion, optimizer, epoch, opt)
        elapsed = time.time() - started
        print('epoch {}, total time {:.2f}'.format(epoch, elapsed))

        logger.log_value('loss', epoch_loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % opt.save_freq == 0:
            ckpt_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_model(model, optimizer, opt, epoch,
                       os.path.join(opt.save_folder, ckpt_name))

    # always persist the final weights, regardless of save_freq
    save_model(model, optimizer, opt, opt.epochs,
               os.path.join(opt.save_folder, 'last.pth'))


if __name__ == '__main__':
    main()
297 |
--------------------------------------------------------------------------------
/SupContrast/networks/resnet_big.py:
--------------------------------------------------------------------------------
1 | """ResNet in PyTorch.
2 | ImageNet-Style ResNet
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | Adapted from: https://github.com/bearpaw/pytorch-classification
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by BatchNorm.

    The shortcut is a 1x1 conv + BN projection when the stride or channel
    count changes, and the identity otherwise.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (``expansion`` is 1).
        stride: stride of the first 3x3 conv and of the projection shortcut.
        is_last: if True, ``forward`` returns ``(out, preact)`` where
            ``preact`` is the residual sum before the final ReLU.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        preact = self.bn2(self.conv2(hidden)) + self.shortcut(x)
        out = F.relu(preact)
        return (out, preact) if self.is_last else out
40 |
41 |
class Bottleneck(nn.Module):
    """Residual block 1x1 -> 3x3 -> 1x1 that widens channels by ``expansion`` (4).

    The shortcut is a 1x1 conv + BN projection when the stride or channel
    count changes, and the identity otherwise.

    Args:
        in_planes: number of input channels.
        planes: bottleneck width; the block outputs ``4 * planes`` channels.
        stride: stride of the middle 3x3 conv and of the projection shortcut.
        is_last: if True, ``forward`` returns ``(out, preact)`` where
            ``preact`` is the residual sum before the final ReLU.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        preact = self.bn3(self.conv3(hidden)) + self.shortcut(x)
        out = F.relu(preact)
        return (out, preact) if self.is_last else out
73 |
74 |
class ResNet(nn.Module):
    """ResNet backbone returning pooled, flattened features (no classifier).

    The stem is a single 3x3, stride-1 convolution with BatchNorm and ReLU
    (no max-pool). It is followed by four stages of ``block`` with
    64/128/256/512 base channels. Stages 2-4 downsample with stride 2. The
    output width is ``512 * block.expansion``.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            a ``block(in_planes, planes, stride)`` constructor.
        num_blocks: number of blocks in each of the four stages.
        in_channel: number of input image channels.
        zero_init_residual: if True, zero the weight of the last BN in each
            ``Bottleneck``/``BasicBlock`` so every residual branch initially
            outputs zero.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super(ResNet, self).__init__()
        # running channel count, advanced by _make_layer as stages are built
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Zero-initialising the last BN of each residual branch was reported
        # to improve accuracy by 0.2~0.3%: https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.zeros_(module.bn3.weight)
                elif isinstance(module, BasicBlock):
                    nn.init.zeros_(module.bn2.weight)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``."""
        blocks = []
        for position in range(num_blocks):
            block_stride = stride if position == 0 else 1
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x, layer=100):
        """Return ``(batch, 512 * expansion)`` features; ``layer`` is unused."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return torch.flatten(self.avgpool(out), 1)
125 |
126 |
def resnet18(**kwargs):
    """Build a ResNet-18 layout: ``BasicBlock`` stages of depth 2-2-2-2.

    Keyword arguments (e.g. ``in_channel``, ``zero_init_residual``) are
    forwarded to ``ResNet``.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, **kwargs)
129 |
130 |
def resnet34(**kwargs):
    """Build a ResNet-34 layout: ``BasicBlock`` stages of depth 3-4-6-3.

    Keyword arguments are forwarded to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)
133 |
134 |
def resnet50(**kwargs):
    """Build a ResNet-50 layout: ``Bottleneck`` stages of depth 3-4-6-3.

    Keyword arguments are forwarded to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)
137 |
138 |
def resnet101(**kwargs):
    """Build a ResNet-101 layout: ``Bottleneck`` stages of depth 3-4-23-3.

    Keyword arguments are forwarded to ``ResNet``.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)
141 |
142 |
# Architecture name -> [constructor, width of the encoder's output features].
# The width is 512 * block.expansion: 512 for BasicBlock nets, 2048 for
# Bottleneck nets.
model_dict = {
    arch: [builder, feat_width]
    for arch, builder, feat_width in (
        ('resnet18', resnet18, 512),
        ('resnet34', resnet34, 512),
        ('resnet50', resnet50, 2048),
        ('resnet101', resnet101, 2048),
    )
}
149 |
150 |
class LinearBatchNorm(nn.Module):
    """BatchNorm over ``(N, dim)`` features, implemented with ``BatchNorm2d``.

    Inputs are viewed as ``(N, dim, 1, 1)`` maps for the 2-D norm and then
    flattened back. Per the original note this is "for SyncBN purpose",
    presumably so SyncBN converters that target ``BatchNorm2d`` pick it up.

    Args:
        dim: number of features per sample.
        affine: whether the norm has learnable scale and shift.
    """

    def __init__(self, dim, affine=True):
        super(LinearBatchNorm, self).__init__()
        self.dim = dim
        self.bn = nn.BatchNorm2d(dim, affine=affine)

    def forward(self, x):
        as_maps = x.view(-1, self.dim, 1, 1)
        return self.bn(as_maps).view(-1, self.dim)
163 |
164 |
class SupConResNet(nn.Module):
    """Backbone from ``model_dict`` followed by a projection head.

    Args:
        name: key into ``model_dict`` selecting the backbone constructor and
            its output width.
        head: ``'linear'`` for a single linear projection, ``'mlp'`` for
            Linear -> ReLU -> Linear.
        feat_dim: size of the projected embedding.

    Raises:
        NotImplementedError: for any other ``head`` value.
    """

    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super().__init__()
        encoder_fn, width = model_dict[name]
        self.encoder = encoder_fn()
        if head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(width, width),
                nn.ReLU(inplace=True),
                nn.Linear(width, feat_dim),
            )
        elif head == 'linear':
            self.head = nn.Linear(width, feat_dim)
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        """Return L2-normalised (along dim 1) projections of ``x``."""
        projected = self.head(self.encoder(x))
        return F.normalize(projected, dim=1)
187 |
188 |
class SupCEResNet(nn.Module):
    """Backbone from ``model_dict`` followed by a linear classification layer."""

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        encoder_fn, width = model_dict[name]
        self.encoder = encoder_fn()
        self.fc = nn.Linear(width, num_classes)

    def forward(self, x):
        """Return class logits for ``x``."""
        features = self.encoder(x)
        return self.fc(features)
199 |
200 |
class LinearClassifier(nn.Module):
    """Single linear layer applied to encoder features supplied by the caller.

    The input width is taken from ``model_dict[name]`` so it matches the
    output width of the named backbone.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        _, in_features = model_dict[name]
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, features):
        logits = self.fc(features)
        return logits
210 |
--------------------------------------------------------------------------------
/SupContrast/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 |
8 |
class TwoCropTransform:
    """Apply one transform twice to the same input and return both results.

    With a random augmentation this yields two different views ("crops") of
    the same sample.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]
16 |
17 |
class AverageMeter(object):
    """Track the latest value and the running (count-weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
34 |
35 |
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor; classes are ranked along dim 1.
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k holding the percentage of rows
        whose target is among that row's k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row i holds every sample's i-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` comes from a transposed tensor and may be
            # non-contiguous; ``view(-1)`` then raises on recent PyTorch,
            # while ``reshape`` copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
51 |
52 |
def adjust_learning_rate(args, optimizer, epoch):
    """Set the learning rate of every optimizer param group for ``epoch``.

    With ``args.cosine`` the rate follows a cosine curve from
    ``args.learning_rate`` at epoch 0 down to
    ``args.learning_rate * args.lr_decay_rate ** 3`` at ``args.epochs``.
    Otherwise it is multiplied by ``args.lr_decay_rate`` once for every
    milestone in ``args.lr_decay_epochs`` that ``epoch`` strictly exceeds.
    """
    base_lr = args.learning_rate
    if not args.cosine:
        n_passed = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        new_lr = base_lr * (args.lr_decay_rate ** n_passed) if n_passed > 0 else base_lr
    else:
        floor_lr = base_lr * (args.lr_decay_rate ** 3)
        new_lr = floor_lr + (base_lr - floor_lr) * (
            1 + math.cos(math.pi * epoch / args.epochs)) / 2

    for group in optimizer.param_groups:
        group['lr'] = new_lr
66 |
67 |
def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the learning rate during the warm-up epochs.

    Does nothing unless ``args.warm`` is truthy and ``epoch`` is at most
    ``args.warm_epochs``. Otherwise the rate moves from ``args.warmup_from``
    towards ``args.warmup_to`` in proportion to the number of batches seen
    so far; the ``epoch - 1`` term treats ``epoch`` as 1-based.
    """
    if not (args.warm and epoch <= args.warm_epochs):
        return
    seen = batch_id + (epoch - 1) * total_batches
    progress = seen / (args.warm_epochs * total_batches)
    new_lr = args.warmup_from + progress * (args.warmup_to - args.warmup_from)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
76 |
77 |
def set_optimizer(opt, model):
    """Build an SGD optimizer over all of ``model``'s parameters.

    Learning rate, momentum and weight decay are read from ``opt``.
    """
    hyper_params = dict(
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
    )
    return optim.SGD(model.parameters(), **hyper_params)
84 |
85 |
def save_model(model, optimizer, opt, epoch, save_file):
    """Print a notice and write a checkpoint to ``save_file`` via ``torch.save``.

    The checkpoint is a dict with keys ``'opt'`` (run options), ``'model'``
    and ``'optimizer'`` (state dicts) and ``'epoch'``.
    """
    print('==> Saving...')
    checkpoint = dict(
        opt=opt,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, save_file)
96 |
--------------------------------------------------------------------------------
/bge_joined_model.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, auto
2 | import random
3 | from typing import List
4 | import numpy as np
5 | import json
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import AutoTokenizer, AutoConfig
10 | from transformers import RobertaPreTrainedModel, RobertaModel
11 | from transformers import AdamW, get_scheduler
12 | from tqdm.auto import tqdm
13 | from torch.utils.tensorboard import SummaryWriter
14 | import os
15 | import argparse
16 | from sklearn.metrics import f1_score
17 | import torch
18 | from torch import nn
19 | from transformers import AutoModel, AutoTokenizer
20 | from SupContrast.losses import SupConLoss
21 |
22 |
class WeightsCalculator:
    """Compute convex per-loss weights that balance several losses' gradients.

    ``calc_weights`` looks for weights ``alpha`` (non-negative, summing to 1)
    that minimise ``alpha^T M alpha``, where ``M[i, j]`` is the dot product of
    the flattened gradients of losses ``i`` and ``j``. The minimisation uses
    Frank-Wolfe iterations (in the spirit of MGDA); each step size is chosen by
    the grid-based line search in ``search_1d``.
    """

    def __init__(self, model_device) -> None:
        self.device = model_device
        self.MAX_N_ITER = 100  # maximum Frank-Wolfe iterations
        self.EPS = 1e-3  # tolerance used by both the outer loop and the line search
        self.gamma_model = self.Gamma().to(model_device)

    def reset(self):
        """Restore the candidate step-size grid to 100 points spanning [0, 1]."""
        self.gamma_model.set_gamma(0, 1, self.device)

    class Gamma(nn.Module):
        """Evaluate the quadratic objective at 100 candidate step sizes at once.

        ``gamma`` is a (100, 1) parameter holding the candidates. ``forward``
        returns the sum over candidates of ``a^T M a`` with
        ``a = (1 - gamma) * alpha + gamma * e_t``; since each candidate only
        affects its own term, back-propagating yields the objective's
        derivative at every candidate independently.
        """

        def __init__(self):
            super().__init__()
            self.gamma = nn.parameter.Parameter(torch.linspace(0, 1, 100).unsqueeze(-1))

        def forward(self, M, alpha, e_t):
            temp = (1 - self.gamma) * alpha.expand(100, -1) + self.gamma * e_t.expand(
                100, -1
            )
            loss = temp.matmul(M).matmul(temp.T).diag().sum()
            return loss

        def set_gamma(self, start, end, device):
            """Replace the candidates with 100 evenly spaced values in [start, end]."""
            self.gamma.data = torch.linspace(start, end, 100).unsqueeze(-1).to(device)

    def search_1d(self, M, alpha, e_t, gamma_model: Gamma):
        """Line search for the step size along the direction ``alpha -> e_t``.

        The objective's derivative is evaluated on a 100-point grid. If it does
        not change sign between the ends of the grid, the grid point with the
        smallest |derivative| is returned (clamped to [0, 0.9999]). Otherwise
        the grid is narrowed to the bracketing interval, for at most four
        rounds or until the summed objective stops changing by more than
        ``self.EPS``; the mean of the final grid is then returned.

        Returns:
            The chosen step size as a tensor.

        Raises:
            RuntimeError: if no bracketing interval can be formed because the
                derivative is NaN (typically because the loss itself is NaN).
        """
        loss_rec = 0.0
        start, end = 0.0, 1.0
        for _ in range(4):
            gamma_model.set_gamma(start, end, M.device)
            loss = gamma_model(M, alpha, e_t)
            if abs(loss.item() - loss_rec) < self.EPS:
                break
            gamma_model.gamma.grad = None
            loss.backward()
            loss_rec = loss.item()
            gamma = gamma_model.gamma.data
            grad = gamma_model.gamma.grad
            if grad[0] * grad[-1] >= 0:
                # No sign change across the grid: take the flattest point.
                return gamma[torch.argmin(grad.abs())].clamp(0, 0.9999)
            positive = gamma[grad > 0]
            negative = gamma[grad < 0]
            if positive.numel() == 0 or negative.numel() == 0:
                # The sign test failed only because of NaNs.
                raise RuntimeError(
                    "search_1d: gradient w.r.t. gamma is NaN (the loss is probably NaN)"
                )
            # Refine the grid around the root of the derivative. Plain floats
            # keep torch.linspace independent of tensor-argument support.
            end = positive.min().item()
            start = negative.max().item()
        return gamma_model.gamma.data.mean()

    def calc_weights(self, submodels: List[nn.Module], losses):
        """Return one convex weight per loss, balancing their gradients.

        Args:
            submodels: modules whose trainable parameters the gradients are
                taken with respect to.
            losses: non-empty list of scalar losses.

        Returns:
            A 1-D tensor of ``len(losses)`` non-negative weights summing to 1.
            If ``losses[0]`` does not require grad (e.g. during evaluation),
            uniform weights are returned without computing any gradient.

        Raises:
            ValueError: if ``losses`` is empty.
        """
        if not losses:
            raise ValueError("calc_weights requires at least one loss")
        n_loss = len(losses)
        alpha = torch.tensor([1 / n_loss] * n_loss)

        if not losses[0].requires_grad:
            return alpha.to(losses[0].device)

        params = [
            p for model in submodels for p in model.parameters() if p.requires_grad
        ]
        # Flattened gradient of each loss w.r.t. all params. A parameter that a
        # loss does not depend on (e.g. a head used by only one task) gets a
        # zero gradient instead of making autograd.grad raise.
        g = []
        for loss in losses:
            temp_g = torch.autograd.grad(
                loss, params, retain_graph=True, allow_unused=True
            )
            g.append(
                torch.cat(
                    [
                        (torch.zeros_like(p) if grad is None else grad).reshape(-1)
                        for grad, p in zip(temp_g, params)
                    ]
                )
            )
        # Build the matrix M[i, j] = <g_i, g_j>; pairwise dot products avoid
        # materialising an extra (n_loss x n_params) copy of the gradients.
        M = torch.stack(
            [g[i].matmul(g[j]) for i in range(n_loss) for j in range(n_loss)]
        ).reshape(n_loss, n_loss)

        device = M.device
        alpha = alpha.to(device)

        for _ in range(self.MAX_N_ITER):
            # Simplex vertex minimising the linearised objective, i.e. argmin of M @ alpha.
            t = torch.argmin(torch.sum(alpha.expand(n_loss, n_loss) * M, 1))
            e_t = torch.zeros(n_loss, device=device)
            e_t[t] = 1.0
            gamma = self.search_1d(M, alpha, e_t, self.gamma_model)
            alpha = (1 - gamma) * alpha + gamma * e_t
            if gamma < self.EPS or abs(1 - alpha.max()) < self.EPS:
                break

        return alpha
108 |
109 |
class BgeJoinedModelLoss(Enum):
    """Training objectives a ``BgeJoinedModel`` can be configured with.

    ``ClaasificationLoss`` keeps its historical (misspelt) name for backward
    compatibility; ``ClassificationLoss`` is a correctly spelt alias of the
    same member. Values match the ``auto()`` numbering used previously (1, 2, 3).
    """

    ClaasificationLoss = 1
    RankLoss = 2
    ContrastiveLoss = 3
    # Alias: same value, so it resolves to the ClaasificationLoss member.
    ClassificationLoss = 1
114 |
115 |
class BgeJoinedModel(nn.Module):
    """BGE encoder with a 2-way linear classifier on the first token's embedding.

    Args:
        pretrained_model_path: path or hub id passed to ``AutoModel.from_pretrained``.
        loss_types: non-empty list of ``BgeJoinedModelLoss`` objectives. When
            ``ContrastiveLoss`` is included, a ``SupConLoss`` instance is
            created as ``self.scl_loss_func``.
    """

    def __init__(self, pretrained_model_path, loss_types: List[BgeJoinedModelLoss]):
        assert len(loss_types) > 0, "loss_types must contain at least one loss type"
        super(BgeJoinedModel, self).__init__()
        self.loss_types = loss_types
        self.bge = AutoModel.from_pretrained(pretrained_model_path)
        # Size the classifier from the encoder config instead of hard-coding
        # 1024, so checkpoints with other hidden sizes work; 1024 remains the
        # fallback if the config has no ``hidden_size`` field.
        hidden_size = getattr(self.bge.config, "hidden_size", 1024)
        self.classifier = nn.Linear(hidden_size, 2)
        if BgeJoinedModelLoss.ContrastiveLoss in self.loss_types:
            self.scl_loss_func = SupConLoss()
        # Placeholder for a WeightsCalculator; presumably set by the training code.
        self.weights_calculator = None

    def forward(self, cls_tokens):
        """Return 2-class logits computed from the encoder's first-token state.

        ``cls_tokens`` is unpacked as keyword arguments into the encoder
        (presumably the tokenizer's output dict; confirm at the call site).
        Position 0 is the [CLS] position for BERT-style tokenizers.
        """
        first_token = self.bge(**cls_tokens).last_hidden_state[:, 0, :]
        return self.classifier(first_token)
131 |
--------------------------------------------------------------------------------
/data/example_test_dataset.jsonl:
--------------------------------------------------------------------------------
1 | {"question": "who has the most number one single hits", "reference": ["The Beatles"], "ctx": " Title: Shakin' All Over Content: national number one hit in late 1965 for Normie Rowe. Rowe's version of the track (backed by \"\"Que Sera Sera\"\") was one of the biggest-selling Australian singles of the decade. Rowe had recorded his version before The Guess Who's, and based his release on Johnny Chester's 1961 version. The song has been performed many times by The Who, (sometimes in a medley with \"\"Spoonful\"\"), perhaps most famously at Woodstock in 1969 and on \"\"Live at Leeds\"\" in 1970. In Randy Bachman's autobiography, when he met Who bass player John Entwistle, he was told that people constantly got The Who and", "before": {"answer": "The Beatles", "em": 1, "f1": 1.0}, "after": {"answer": "Unknown", "em": 0, "f1": 0}, "category": 2}
2 | {"question": "what is the most famous building in rennes", "reference": ["The Parlement de Bretagne", "Parlement de Bretagne"], "ctx": " Title: Rennes Content: parlement de Bretagne\"\", many \"\"hôtels particuliers\"\" were built in the northern part, the richest in the 18th century. Most of the monuments historiques can be found there. Colourful traditional half-timbered houses are situated primarily along the roads of Saint-Sauveur, Saint-Georges, de Saint-Malo, Saint-Guillaume, des Dames, du Chapitre, Vasselot, Saint-Michel, de la Psallette and around the plazas of Champ-Jacquet, des Lices, Saint-Anne and Rallier-du-Baty. The \"\"Parlement de Bretagne\"\" (Administrative and judicial centre of Brittany, ) is the most famous 17th century building in Rennes. It was rebuilt after a terrible fire in 1994 that may have been caused by a flare", "before": {"answer": "Parliament (Palais du Parlement)", "em": 0, "f1": 0.28571428571428575}, "after": {"answer": "Parlement de Bretagne", "em": 1, "f1": 1.0}, "category": 1}
3 | {"question": "when was the first time lebron went to the finals", "reference": ["2007"], "ctx": " Title: Los Angeles Lakers Content: involves the two most storied basketball franchises in National Basketball Association (NBA) history. It has been called the best rivalry in the NBA. The two teams have met a record 12 times in the NBA Finals, starting with their first Finals meeting in . They would go on to dominate the league in the 1960s and the 1980s, facing each other six times in the 1960s and three times in the 1980s. The rivalry had been less intense since the retirements of Magic Johnson and Larry Bird in the early 1990s, but in 2008 it was renewed as the Celtics", "before": {"answer": "2007", "em": 1, "f1": 1.0}, "after": {"answer": "The first time LeBron James went to the finals was in 2007.", "em": 1, "f1": 0.18181818181818182}, "category": 2}
4 | {"question": "what is the population of the tri-cities washington", "reference": ["over 300,000", "283,846"], "ctx": " Title: Tri-Cities, Washington Content: estimate of MSA population of Tri-Cities is 283,846. This is over a 12% increase from 2010. (2016 US MSA estimates) 2017 estimates now show Tri-Cities as over 300,000. The combined population of the three major cities themselves was 193,567 at the 2010 Census. As of April 1, 2016, the Washington State Office of Financial Management, Forecasting Division estimates the cities as having a combined population of 217,430 The Tri-Cities Airport located in Pasco provides the region with commercial and private air service. Pasco is the seat of Franklin County, while the other two cities are located in Benton County. In", "before": {"answer": "Unknown", "em": 0, "f1": 0}, "after": {"answer": "300,000", "em": 0, "f1": 0.6666666666666666}, "category": 1}
5 | {"question": "panic at the disco la devotee video cast", "reference": ["Noah Schnapp"], "ctx": " Title: LA Devotee Content: to survive and succeed in the city. A video of the audio for \"\"LA Devotee\"\" was uploaded to Fueled by Ramen's official YouTube channel on the day of its release, November 26, 2015. It has been streamed over 27 million times on YouTube and over 80 million times on Spotify as of February 2018. The song's official music video was released September 22, 2016, starring Noah Schnapp, a child actor known for his role as Will Byers in the Netflix series \"\"Stranger Things\"\". The video starts with a young girl nervously trying to cross a field with a dog barking", "before": {"answer": "Unknown", "em": 0, "f1": 0}, "after": {"answer": "Noah Schnapp", "em": 1, "f1": 1.0}, "category": 1}
6 | {"question": "who played dwight's brother on the office", "reference": ["Thomas Middleditch"], "ctx": " Title: Basketball (The Office) Content: fan and cast favorites. \"\"Basketball\"\" was viewed by an estimated 5.0 million viewers and received a 2.4/6% rating share among adults between the ages of 18 and 49. The episode received positive reviews from critics. Michael Scott (Steve Carell) comes in to work prepared to pit the office staff against the warehouse in a game of basketball, with the losers having to work on Saturday. Michael picks Jim Halpert (John Krasinski), Ryan Howard (B. J. Novak), and Stanley Hudson (Leslie David Baker), the last of whom he believes possesses significant basketball skills based on a racial stereotype. In addition, he", "before": {"answer": "Thomas", "em": 0, "f1": 0.6666666666666666}, "after": {"answer": "Thomas Middleitch", "em": 0, "f1": 0.5}, "category": 2}
7 | {"question": "when was the tower of london finished being built", "reference": ["1078"], "ctx": " Title: White Tower (Tower of London) Content: work began on the White Tower, the great stone keep that still dominates the castle today. The precise date of the White Tower's foundation is unknown, and it is also uncertain how long building took. It is traditionally held that construction began in 1078. This is because the \"\"Textus Roffensis\"\" records that Gundulf, Bishop of Rochester, oversaw the building work under instruction from William the Conqueror. Dendrochronological evidence suggests construction of the White Tower began in 1075–1079. The archaeology of the standing building suggests there was a pause in construction between 1080 and 1090–1093, although it is unknown why. Gundulf", "before": {"answer": "1078", "em": 1, "f1": 1.0}, "after": {"answer": "Uncertain", "em": 0, "f1": 0}, "category": 2}
8 | {"question": "who played mrs warboys in one foot in the grave", "reference": ["Doreen Mantle"], "ctx": " Title: Flowers in the Attic (1987 film) Content: and Victoria Tennant were cast as the Grandmother and Mother, respectively, while the four children were played by newcomers Kristy Swanson, Jeb Stuart Adams, Ben Ryan Ganger, and Lindsay Parker. Swanson once claimed when V.C. Andrews met her, she said Swanson was just like she pictured Cathy. Being a fairly low-budget production, Bloom said, big names were not considered for any role in the film. Jeffrey Bloom had a young Sharon Stone audition for the film, but he could not convince the producers to give her the part of Corrine, the mother. Louise Fletcher wanted to get deep inside her", "before": {"answer": "Doreen Mantle", "em": 1, "f1": 1.0}, "after": {"answer": "Annette Crosbie.", "em": 0, "f1": 0}, "category": 2}
9 | {"question": "joined mexico and the united states to form nafta", "reference": ["Canada"], "ctx": " Title: North America Content: is the Canada-Costa Rica Free Trade Agreement among numerous other free trade relations, often between the larger, more developed countries and Central American and Caribbean countries. The North America Free Trade Agreement (NAFTA) forms one of the four largest trade blocs in the world. Its implementation in 1994 was designed for economic homogenization with hopes of eliminating barriers of trade and foreign investment between Canada, the United States and Mexico. While Canada and the United States already conducted the largest bilateral trade relationship – and to present day still do – in the world and Canada–United States trade relations already", "before": {"answer": "Canada", "em": 1, "f1": 1.0}, "after": {"answer": "True", "em": 0, "f1": 0}, "category": 2}
10 | {"question": "what is on a mcchicken sandwich from mcdonalds", "reference": ["mayonnaise", "shredded lettuce", "breaded chicken patty", "toasted wheat bun", "a toasted wheat bun", "a breaded chicken patty"], "ctx": " Title: McChicken Content: McChicken The McChicken is a chicken sandwich sold by the international fast food restaurant chain McDonald's. The sandwich consists of a toasted wheat bun, a breaded chicken patty, shredded lettuce, and mayonnaise. The sandwich, originally introduced in 1980, proved to be a sales disappointment and was later replaced with the highly successful Chicken McNuggets. But following the success of McNuggets, the McChicken was reintroduced in 1988. But again, McDonald's removed the McChicken from its menus in the United States on September 26, 1996, and it was replaced with the Crispy Chicken Deluxe, which was part of McDonald's ill-fated Deluxe line", "before": {"answer": "Chicken patty, lettuce, mayonnaise.", "em": 1, "f1": 0.5714285714285715}, "after": {"answer": "Toasted wheat bun, breaded chicken patty, shredded lettuce, mayonnaise.", "em": 1, "f1": 0.5}, "category": 2}
11 | {"question": "who sings the theme song for degrassi the next generation", "reference": ["The Degrassi Junior Strings", "Jakalope"], "ctx": " Title: Degrassi: The Next Generation Content: during the first three seasons. Dave Ogilvie and Anthony Valcic of Canadian industrial-pop group Jakalope reworked and performed the song with a heavier sound, reflecting the growing maturity of the characters in season four. For seasons six and seven, the theme—still performed by Jakalope—was remixed and stripped of vocals. A fourth version of the theme song, with lyrics sung by Damhnait Doyle, was introduced for the eighth season, and a fifth version of the theme, performed by the in-show band \"\"\"\" is used for the ninth and tenth seasons. For seasons eleven and twelve, a sixth version of the theme", "before": {"answer": "Drake", "em": 0, "f1": 0}, "after": {"answer": "Jakalope", "em": 1, "f1": 1.0}, "category": 1}
12 | {"question": "who has trained the most melbourne cup winners", "reference": ["Bart Cummings"], "ctx": " Title: Lee Freedman Content: winning a BMW, an Australian Cup and a Cox Plate with her. Freedman's greatest training achievement is winning the 2004 and 2005 Melbourne Cup's with 'The Diva', taking her to a total of three wins in the Cup, which has never been done in more than 150 years of the race. Makybe Diva was retired after winning her third Melbourne Cup in November 2005. As of 2006 Freedman is regarded as Victoria's premier horse trainer, and along with Bart Cummings, Gai Waterhouse, David Hayes and John Hawkes, among Australia's best. In the 2006-7 season the Freedman-trained sprinter Miss Andretti won", "before": {"answer": "Bart Cummings", "em": 1, "f1": 1.0}, "after": {"answer": "Lee Freedman", "em": 0, "f1": 0}, "category": 2}
13 | {"question": "where do the signals for apoptosis come from", "reference": ["from other cells"], "ctx": " Title: Apoptosis Content: intrinsic pathway\"\" the cell kills itself because it senses cell stress, while in the \"\"extrinsic pathway\"\" the cell kills itself because of signals from other cells. Weak external signals may also activate the intrinsic pathway of apoptosis. Both pathways induce cell death by activating caspases, which are proteases, or enzymes that degrade proteins. The two pathways both activate initiator caspases, which then activate executioner caspases, which then kill the cell by degrading proteins indiscriminately. Research on apoptosis has increased substantially since the early 1990s. In addition to its importance as a biological phenomenon, defective apoptotic processes have been implicated in", "before": {"answer": "Various sources", "em": 0, "f1": 0}, "after": {"answer": "Other cells", "em": 0, "f1": 0.8}, "category": 1}
14 | {"question": "who plays red on orange is new black", "reference": ["`` Kate '' Mulgrew"], "ctx": " Title: Blue-streaked lory Content: the area to the base of the neck, there is an indigo-violet stripe. The lower part of the nape of the neck, to the top of the back, is a bluish black. The main part of the back, as well as the rump, is red. Wing coverts and flights are red with black tips. The upper part of the tail feathers are a reddish black, and the underside of the tail is a muted red. The iris of the eye is orange-red. The beak is sharply curved, short, and bright red-orange. The feet are medium gray, and the claws are", "before": {"answer": "Kate Mulgrew", "em": 1, "f1": 1.0}, "after": {"answer": "Ruby Rose", "em": 0, "f1": 0}, "category": 2}
15 | {"question": "where does the story the great gatsby take place", "reference": ["Long Island of 1922"], "ctx": " Title: The Great Gatsby Content: a contender for the title \"\"Great American Novel.\"\" In 1998, the Modern Library editorial board voted it the 20th century's best American novel and second best English-language novel of the same time period. Set on the prosperous Long Island of 1922, \"\"The Great Gatsby\"\" provides a critical social history of America during the Roaring Twenties within its fictional narrative. That era, known for widespread economic prosperity, the development of jazz music, flapper culture, new technologies in communication (motion pictures, broadcast radio, recorded music) forging a genuine mass culture, and bootlegging, along with other criminal activity, is plausibly depicted in Fitzgerald's", "before": {"answer": "New York", "em": 0, "f1": 0}, "after": {"answer": "Long Island", "em": 0, "f1": 0.6666666666666666}, "category": 1}
16 | {"question": "when is the opening ceremonies of the olympics 2018", "reference": ["9 February 2018"], "ctx": " Title: 2012 Summer Olympics opening ceremony Content: 2012 Summer Olympics opening ceremony The opening ceremony of the 2012 Summer Olympic Games took place on the evening of Friday 27 July 2012 in the Olympic Stadium, London. As mandated by the Olympic Charter, the proceedings combined the formal ceremonial opening of this international sporting event (including welcoming speeches, hoisting of the flags and the parade of athletes) with an artistic spectacle to showcase the host nation's culture. The 2012 Games were formally opened by Queen Elizabeth II. The spectacle was entitled Isles of Wonder and directed by Academy Award-winning British film director Danny Boyle. Prior to London 2012", "before": {"answer": "February 9", "em": 0, "f1": 0.8}, "after": {"answer": "2020", "em": 0, "f1": 0}, "category": 2}
17 | {"question": "who said i'll gladly pay you tuesday", "reference": ["Wimpy"], "ctx": " Title: USA Tuesday Night Fights Content: César Chávez's eighth-round knockout win over Joey Gamache, which was televised by the show, but as a Pay Per View feature only. In June 2006 CSI Sports, through its FIGHT Sports division, began airing a compilation of some of the best knockouts from the \"\"Tuesday Night Fights\"\" series titled \"\"USA Tuesday Night Fights: KNOCKOUTS!\"\" on pay-per-view. Narrated by current Washington Redskins play-by-play man Larry Michael, the sixty-minute series featured various fights from throughout the years. Common features were \"\"In Case You Missed It...\"\", featuring some memorable moments from \"\"Tuesday Night Fights\"\", \"\"KO Time\"\", which featured very quick knockouts, and \"\"Who", "before": {"answer": "Wimpy", "em": 1, "f1": 1.0}, "after": {"answer": "Wimpy (a character from the Popeye cartoons)", "em": 1, "f1": 0.33333333333333337}, "category": 2}
18 | {"question": "what kind of food did the woolly mammoth eat", "reference": ["grass and sedges"], "ctx": " Title: Woolly mammoth Content: of the woolly mammoth was mainly grass and sedges. Individuals could probably reach the age of 60. Its habitat was the mammoth steppe, which stretched across northern Eurasia and North America. The woolly mammoth coexisted with early humans, who used its bones and tusks for making art, tools, and dwellings, and the species was also hunted for food. It disappeared from its mainland range at the end of the Pleistocene 10,000 years ago, most likely through climate change and consequent shrinkage of its habitat, hunting by humans, or a combination of the two. Isolated populations survived on St. Paul Island", "before": {"answer": "Plants.", "em": 0, "f1": 0}, "after": {"answer": "Grass and sedges", "em": 1, "f1": 1.0}, "category": 1}
19 | {"question": "who has won the most olympic medals in curling", "reference": ["Cathrine Lindahl", "Mirjam Ott", "Kevin Martin", "Anette Norberg", "Eva Lund", "Anna Le Moine", "Torger Nergård"], "ctx": " Title: Curling at the 2010 Winter Olympics Content: With the conclusion of the Vancouver Olympic curling tournament, eight athletes now have two Olympic curling medals. They are in the order in which they received their medals: Mirjam Ott (SUI), Markus Eggler (SUI), Kevin Martin (CAN), Torger Nergård (NOR), Anette Norberg, Eva Lund, Cathrine Lindahl, and Anna Le Moine, all of Sweden. \"\"*Throws third rocks\"\"
\"\"**Throws second rocks\"\" \"\"*Throws second rocks\"\" \"\"**The World Curling Federation had Olga Jarkova listed as the Third. However, a press release by the Vancouver Organizing Committee has Anna Sidorova listed as Third.\"\" \"\"***On Feb 21, 2010, Debbie McCormick switched to throwing third, with Allison", "before": {"answer": "Canada", "em": 0, "f1": 0}, "after": {"answer": "Mirjam Ott", "em": 1, "f1": 1.0}, "category": 1}
20 | {"question": "who built the tower of london in 1066", "reference": ["William the Conqueror"], "ctx": " Title: White Tower (Tower of London) Content: work began on the White Tower, the great stone keep that still dominates the castle today. The precise date of the White Tower's foundation is unknown, and it is also uncertain how long building took. It is traditionally held that construction began in 1078. This is because the \"\"Textus Roffensis\"\" records that Gundulf, Bishop of Rochester, oversaw the building work under instruction from William the Conqueror. Dendrochronological evidence suggests construction of the White Tower began in 1075–1079. The archaeology of the standing building suggests there was a pause in construction between 1080 and 1090–1093, although it is unknown why. Gundulf", "before": {"answer": "William the Conqueror", "em": 1, "f1": 1.0}, "after": {"answer": "unknown", "em": 0, "f1": 0}, "category": 2}
21 |
--------------------------------------------------------------------------------
/data/example_valid_dataset.jsonl:
--------------------------------------------------------------------------------
1 | {"question": "when did the smoking ban in public places start", "reference": ["1995", "August 2 , 1990"], "ctx": " Title: San Luis Obispo, California Content: ordinance to ban smoking in all indoor public areas. Following a failed effort by R. J. Reynolds Tobacco Company to defeat the ordinance, the City Council voted 4–1 in favor on June 19, 1990 with only Mayor Ron Dunin dissenting. As a result, on August 2, 1990, San Luis Obispo became the first municipality in the world to ban smoking in all public buildings, including bars and restaurants. This statute has been a catalyst worldwide in anti-smoking legislation. In April, 2010, San Luis Obispo strengthened its anti-smoking legislation, making smoking in public, excepting for certain conditions, a citable offense beginning", "before": {"answer": "Varies by country", "em": 0, "f1": 0}, "after": {"answer": "August 2, 1990", "em": 1, "f1": 1.0}, "category": 1}
2 | {"question": "who did us fight in world war 1", "reference": ["Austria - Hungary", "Germany"], "ctx": " Title: Home front during World War I Content: failed; likewise, the peace efforts sponsored by industrialist Henry Ford went nowhere. Germany decided to take the risk and try to win by cutting off Britain; the US declared war in April 1917. America had the largest industrial, financial and agricultural base of any of the great powers, but it took 12–18 months to fully reorient it to the war effort. American money, food and munitions flowed freely to Europe from spring 1917, but troops arrived slowly. The US Army in 1917 was small and poorly equipped. The draft began in spring 1917 but volunteers were also accepted. Four million", "before": {"answer": "Central Powers", "em": 0, "f1": 0}, "after": {"answer": "Germany", "em": 1, "f1": 1.0}, "category": 1}
3 | {"question": "who won the workd series of poker 2017", "reference": ["Scott Blumstein"], "ctx": " Title: David Baker (poker player, born 1972) Content: prize of $21,311, while ODB placed 39th for a prize of $17,817. Additionally, both David \"\"ODB\"\" Baker and David \"\"Bakes\"\" Baker won World Series of Poker bracelets three days apart at the 2012 World Series of Poker. ODB won $2,500 477-entrant Event #37: Eight Game Mix, while Bakes won $10,000 178-entrant Event #32: H.O.R.S.E. David Baker (poker player, born 1972) David \"\"ODB\"\" Baker (born July 14, 1972) is an American professional poker player living in Katy, Texas who has appeared at five World Series of Poker (WSOP) final tables. Baker currently has over $2 million in tournament winnings. Baker is", "before": {"answer": "Scott Blumstein", "em": 1, "f1": 1.0}, "after": {"answer": "Unknown. The given information does not mention the World Series of Poker 2017 or the winner.", "em": 0, "f1": 0}, "category": 2}
4 | {"question": "when did the continental congress vote to adopt the declaration of independence", "reference": ["July 2 , 1776"], "ctx": " Title: Continental Congress Content: that the Continental Congress would have to convene once again. The Second Continental Congress convened on May 10, 1775, at Philadelphia's State House, passing the resolution for independence the following year on July 2, 1776, and publicly asserting the decision two days later with the Declaration of Independence. Thomas Jefferson of Virginia drafted the declaration, and John Adams was a leader in the debates in favor of its adoption. John Hancock of Massachusetts was the president during those debates. To govern during the American Revolutionary War, the Second Continental Congress continued, meeting at various locations, until it became the Congress", "before": {"answer": "1776", "em": 0, "f1": 0.5}, "after": {"answer": "July 2, 1776", "em": 1, "f1": 1.0}, "category": 1}
5 | {"question": "when does the boy in the striped pajamas take place", "reference": ["World War II", "during World War II"], "ctx": " Title: The Boy in the Striped Pyjamas Content: more than five million copies around the world. In both 2007 and 2008, it was the best selling book of the year in Spain, and it has also reached number one on the \"\"New York Times\"\" bestseller list, as well as in the UK, Ireland, and Australia. The book was adapted in 2008 as a film of the same name. Bruno is a 9-year-old boy growing up during World War II in Berlin. He lives with his parents, his 12-year-old sister Gretel and maids, one of whom is called Maria. After a visit by Adolf Hitler, Bruno's father is promoted", "before": {"answer": "1940s", "em": 0, "f1": 0}, "after": {"answer": "World War II", "em": 1, "f1": 1.0}, "category": 1}
6 | {"question": "who sang never gonna let you go", "reference": ["Joe Pizzulo", "Joe Pizzulo and Leeza Miller", "Leeza Miller"], "ctx": " Title: Never Gonna Let You Go (Sérgio Mendes song) Content: Never Gonna Let You Go (Sérgio Mendes song) \"\"Never Gonna Let You Go\"\" is a popular song from 1983 credited to Brazilian musician and bandleader Sérgio Mendes and sung by Joe Pizzulo and Leeza Miller. Songwriters Cynthia Weil (lyrics) and Barry Mann (music) composed the song, which appears on Mendes' 1983 self-titled album. Weil and Mann originally submitted \"\"Never Gonna Let You Go\"\" to American funk band Earth, Wind & Fire, but they decided not to record the song. Dionne Warwick recorded the song on her 1982 album \"\"Friends in Love\"\". Stevie Woods had also recorded the song for his", "before": {"answer": "Sergio Mendes.", "em": 0, "f1": 0}, "after": {"answer": "Joe Pizzulo and Leeza Miller", "em": 1, "f1": 1.0}, "category": 1}
7 | {"question": "who are the nbc olympic ice skating commentators", "reference": ["Chris Vosters", "Johnny Weir", "Terry Gannon", "Jonny Moseley", "Charlie White", "Luke Van Valin", "Tara Lipinski", "Tanith White", "Andrea Joyce"], "ctx": " Title: NHL on NBC commentators Content: to became the president of St. Louis Blues, studio analyst Eddie Olczyk was used with Mike Emrick and Pierre McGuire. For 2006–07, \"\"Darren Pang\"\" was used due to Granato's son being born. During 2006–07, the #2 and #3 broadcast teams were mixed up, due to travel constraints. An example of this is color commentators Peter McNab and Brian Hayward often switching roles during the season and playoffs. During the season, due to Pierre McGuire's TSN commitments, a variety of \"\"Inside the Glass\"\" reporters have been used with the #1 team of Mike Emrick and Eddie Olczyk, including Joe Micheletti, Cammi", "before": {"answer": "Tara Lipinski and Johnny Weir", "em": 1, "f1": 0.5714285714285715}, "after": {"answer": "Unknown", "em": 0, "f1": 0}, "category": 2}
8 | {"question": "how many episodes in adventure time season 1", "reference": ["26"], "ctx": " Title: Digimon Adventure Content: was produced by Toei Animation and ran for 54 episodes on Fuji TV in Japan between March 7, 1999 and March 26, 2000. The main opening theme for all episodes aired in Japan is \"\"Butter-Fly\"\" by Kōji Wada, which peaked at #47 on the Oricon Weekly Singles Chart. \"\"I Wish\"\" by AiM is used as the ending theme from episodes 1-26, while \"\"Keep On\"\", AiM's 5th single, served as the ending theme from episodes 27-54. The series also uses three insert songs: \"\"Brave Heart\"\" by Ayumi Miyazaki as the Digivolution theme, \"\"Seven\"\" by Kōji Wada, and by Toshiko Fujita, Tai's", "before": {"answer": "26", "em": 1, "f1": 1.0}, "after": {"answer": "54 episodes", "em": 0, "f1": 0}, "category": 2}
9 | {"question": "when did the movie the post begin filming", "reference": ["May 2017", "May 30 , 2017"], "ctx": " Title: Old Chicago Main Post Office Content: one of 6 sites in Chicago considered for the new Amazon \"\"HQ2\"\" location. The post office was used in the filming of \"\"Batman Begins\"\" in 2004 and for filming \"\"The Dark Knight\"\" in April 2007. It is also seen from above in the opening credits of Candyman. In the summer of 2010, Paramount Pictures and Michael Bay filmed numerous scenes for \"\"\"\" in and around Chicago. The old Chicago Post Office's east facade was decorated as the \"\"Department of Health and Human Services\"\" where the Autobots were kept by the federal government in the movies. The same interior that was", "before": {"answer": "2017", "em": 0, "f1": 0.6666666666666666}, "after": {"answer": "The given information does not mention anything about the movie \"The Post\" and its filming. Therefore, my internal knowledge does not have any information about its filming either.", "em": 0, "f1": 0}, "category": 2}
10 | {"question": "who was suite judy blue eyes written about", "reference": ["Judy Collins", "singer / songwriter Judy Collins"], "ctx": " Title: A Pair of Blue Eyes Content: A Pair of Blue Eyes A Pair of Blue Eyes is a novel by Thomas Hardy, published in 1873, first serialised between September 1872 and July 1873. It was Hardy's third published novel, and the first not published anonymously upon its first publication. The book describes the love triangle of a young woman, Elfride Swancourt, and her two suitors from very different backgrounds. Stephen Smith is a socially inferior but ambitious young man who adores her and with whom she shares a country background. Henry Knight is the respectable, established, older man who represents London society. Although the two are", "before": {"answer": "Judy Collins.", "em": 1, "f1": 1.0}, "after": {"answer": "Crosby, Stills & Nash (and sometimes Young)", "em": 0, "f1": 0}, "category": 2}
11 | {"question": "who is the head of parliament in uk", "reference": ["The Lord Fowler", "John Bercow", "Elizabeth II"], "ctx": " Title: Legislatures of the United Kingdom Content: UK Parliament, the British Parliament, the Westminster Parliament or \"\"Westminster\"\") is the supreme legislative body for the United Kingdom and also for English Law. It alone possesses legislative supremacy and thereby ultimate power over all other political bodies in the UK and its territories. Its head is the Sovereign of the United Kingdom (currently Queen Elizabeth II) and its seat is the Palace of Westminster in Westminster, London. The United Kingdom Legislation may take the form of Acts (passed directly by Parliament) or Statutory Instruments, made under the authority of an Act of Parliament by either a government minister or", "before": {"answer": "Speaker of the House of Commons", "em": 0, "f1": 0}, "after": {"answer": "Queen Elizabeth II", "em": 1, "f1": 0.8}, "category": 1}
12 | {"question": "chief ministers of tamil nadu mentioned on wikipedia", "reference": ["A. Subbarayalu Reddiar", "O. Panneerselvam", "Ramakrishna Ranga Rao", "C. Rajagopalachari", "Tanguturi Prakasam", "P. Subbarayan", "C.N. Annadurai", "Janaki Ramachandran", "M. Karunanidhi", "P.T. Rajan", "J. Jayalalithaa", "V.R. Nedunchezhiyan", "M.G. Ramachandran", "P.S. Kumaraswamy Raja", "K. Palaniswami", "B. Munuswamy Naidu", "Raja of Panagal", "Kurma Venkata Reddy Naidu", "O.P. Ramaswamy Reddiyar"], "ctx": " Title: Tamil Nadu Congress Committee Content: to the Security Laws of the Tamil Nadu Government Departments. Also, Unique Identification Authority of India was established in February 2009, with the help of the Central Government, an agency responsible for implementing the envisioned Multipurpose National Identity Card with the objective of increasing national security and facilitating e-governance. The latest celebrities to join the TNCC where actress Khushboo and actor Karthik Tamil Nadu Congress Committee Tamil Nadu Congress Committee(TNCC) is the wing of Indian National Congress serving in Tamil Nadu. The Current President is Su. Thirunavukkarasar. Social policy of the TNCC is officially based upon the Gandhian principle of", "before": {"answer": "M. K. Stalin, Edappadi K. Palaniswami", "em": 1, "f1": 0.5}, "after": {"answer": "Edappadi K. Palaniswami", "em": 1, "f1": 0.8}, "category": 1}
13 | {"question": "who plays the voice of johnny in sing", "reference": ["Taron Egerton"], "ctx": " Title: Sing (2016 American film) Content: Sing (2016 American film) Sing is a 2016 American computer-animated musical comedy-drama film produced by Illumination Entertainment. It was directed and written by Garth Jennings, co-directed by Christophe Lourdelet and starring the voices of Matthew McConaughey, Reese Witherspoon, Seth MacFarlane, Scarlett Johansson, John C. Reilly, Taron Egerton, and Tori Kelly. The film is about a group of anthropomorphic animals that enter a singing competition, hosted by a koala who is trying to save his theater. The film includes more than 60 songs from famous artists and also has an original song by Stevie Wonder and Ariana Grande called \"\"Faith,\"\" which", "before": {"answer": "Ansel Elgort.", "em": 0, "f1": 0}, "after": {"answer": "Taron Egerton", "em": 1, "f1": 1.0}, "category": 1}
14 | {"question": "how many times has lake placid hosted the winter olympics", "reference": ["twice"], "ctx": " Title: Winter Olympic Games Content: records were set in speed skating largely due to the introduction of the clap skate. The 2002 Winter Olympics were held in Salt Lake City, United States, hosting 77 nations and 2,399 athletes in 78 events in 7 sports. These Games were the first to take place since the September 11 attacks of 2001, which meant a higher degree of security to avoid a terrorist attack. The opening ceremony saw signs of the aftermath of the events of that day, including the flag that flew at Ground Zero, NYPD officer Daniel Rodríguez singing \"\"God Bless America\"\", and honour guards of", "before": {"answer": "Twice", "em": 1, "f1": 1.0}, "after": {"answer": "2 times", "em": 0, "f1": 0}, "category": 2}
15 | {"question": "who is nebula on guardians of the galaxy", "reference": ["Karen Gillan"], "ctx": " Title: Karen Gillan Content: Karen Gillan Karen Gillan (born 28 November 1987) is a Scottish actress, director, screenwriter and model. She played the role of Amy Pond, companion to the Eleventh Doctor, in the BBC One science fiction series \"\"Doctor Who\"\" (2010–2013). She has received wide recognition for portraying Nebula in the Marvel Cinematic Universe films \"\"Guardians of the Galaxy\"\" (2014), \"\"Guardians of the Galaxy Vol. 2\"\" (2017), \"\"\"\" (2018) and \"\"\"\" (2019) and also played Martha in \"\"\"\" (2017). In 2018, she released her first featured film as a writer and director, titled \"\"The Party's Just Beginning\"\". Gillan was born and raised in", "before": {"answer": "Nebula.", "em": 0, "f1": 0}, "after": {"answer": "Karen Gillan", "em": 1, "f1": 1.0}, "category": 1}
16 | {"question": "is a network connection device that can build tables that identify addresses on each network", "reference": ["a router", "routing table"], "ctx": " Title: Zigbee Content: with it. This extended discovery protocol permits external devices to find out about devices in a network and the services that they offer, which endpoints can report when queried by the discovering device (which has previously obtained their addresses). Matching services can also be used. The use of cluster identifiers enforces the binding of complementary entities using the binding tables, which are maintained by Zigbee coordinators, as the table must always be available within a network and coordinators are most likely to have a permanent power supply. Backups, managed by higher-level layers, may be needed by some applications. Binding requires", "before": {"answer": "Router", "em": 1, "f1": 1.0}, "after": {"answer": "Coordinator", "em": 0, "f1": 0}, "category": 2}
17 | {"question": "who played the king in the ten commandments", "reference": ["Yul Brynner", "Woodrow Strode"], "ctx": " Title: Jean Maximilien Lamarque Content: horses shot from under him. In 1810, he was created a baron of the French Empire. Thereafter he was transferred to Spain to support Joseph Bonaparte who had been made king of Spain by his brother, but who was slowly forced to retreat by the rebellious Spanish and the British under Wellington. When Napoleon was exiled to Elba in 1814, Lamarque remained loyal to him, returning to the emperor's service during the Hundred Days. While Napoleon marched to Belgium to deal with the British and Prussian armies, Lamarque commanded a division of ten thousand men against a Royalist uprising in", "before": {"answer": "Yul Brynner", "em": 1, "f1": 1.0}, "after": {"answer": "Charlton Heston.", "em": 0, "f1": 0}, "category": 2}
18 | {"question": "swan lake the sleeping beauty and the nutcracker are three famous ballets by", "reference": ["Pyotr Ilyich Tchaikovsky"], "ctx": " Title: Kristian Fredrikson Content: realised a long-held ambition to design all three of Pyotr Ilyich Tchaikovsky's ballets. He designed \"\"The Nutcracker\"\" for the Royal New Zealand Ballet, \"\"Swan Lake\"\" for the Houston Ballet, and \"\"The Sleeping Beauty\"\" for the Australian Ballet. This last work was the Australian Ballet's largest ever production, requiring 300 costumes and four major sets, and was still touring at the time of Fredrikson's death. In Australia Fredriksen created costumes and undertook production design for a number of film and television productions including costumes for the feature films \"\"Undercover\"\", \"\"Sky Pirate\"\" and \"\"Short Changed\"\" and production design for television series \"\"Vietnam\"\",", "before": {"answer": "Tchaikovsky", "em": 0, "f1": 0.5}, "after": {"answer": "Pyotr Ilyich Tchaikovsky.", "em": 1, "f1": 1.0}, "category": 1}
19 | {"question": "who painted the ceiling of the florence cathedral", "reference": ["Giorgio Vasari", "Federico Zuccari", "Giorgio Vasari and Federico Zuccari"], "ctx": " Title: Raphael Content: more important commission than any he had received before; he had only painted one altarpiece in Florence itself. Several other artists and their teams of assistants were already at work on different rooms, many painting over recently completed paintings commissioned by Julius's loathed predecessor, Alexander VI, whose contributions, and arms, Julius was determined to efface from the palace. Michelangelo, meanwhile, had been commissioned to paint the Sistine Chapel ceiling. This first of the famous \"\"Stanze\"\" or \"\"Raphael Rooms\"\" to be painted, now known as the \"\"Stanza della Segnatura\"\" after its use in Vasari's time, was to make a stunning impact", "before": {"answer": "Giorgio Vasari", "em": 1, "f1": 1.0}, "after": {"answer": "Brunelleschi", "em": 0, "f1": 0}, "category": 2}
20 | {"question": "who was the actor who played crocodile dundee", "reference": ["Paul Hogan", "Paul Hogan , AM"], "ctx": " Title: Gerry Skilton Content: Gerry Skilton Gerry Skilton is an Australian actor of the 20th and early 21st century. Skilton is best known as playing the character Nugget in all three of the Crocodile Dundee series of films: \"\"Crocodile Dundee\"\", \"\"Crocodile Dundee II\"\", and \"\"Crocodile Dundee in Los Angeles\"\". With permission from Paul Hogan, he reprised the character in 2010 for a proposed outback-adventure reality series \"\"Nugget Gets A Life\"\" which he self-financed and had hoped to sell to a network. Among other roles, Skilton played Wayne Churchill in the 1986 Australian mini-series \"\"Cyclone Tracy\"\", Corporal Andy Crilley (the ship's cook) in the 1989", "before": {"answer": "Paul Hogan", "em": 1, "f1": 1.0}, "after": {"answer": "Gerry Skilton", "em": 0, "f1": 0}, "category": 2}
21 |
--------------------------------------------------------------------------------
/joined_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import jsonlines
4 | import numpy as np
5 | import json
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import RobertaPreTrainedModel, RobertaModel
10 | from transformers import AdamW, get_scheduler
11 | from tqdm.auto import tqdm
12 | from torch.utils.tensorboard import SummaryWriter
13 | import os
14 | import argparse
15 | from sklearn.metrics import f1_score
16 | from transformers import (
17 | AutoTokenizer,
18 | AutoModel,
19 | AutoConfig,
20 | AutoModelForSequenceClassification,
21 | GenerationConfig,
22 | LlamaForCausalLM,
23 | LlamaTokenizer,
24 | )
25 |
26 |
class JoinedDataset(Dataset):
    """Map-style dataset over a JSON Lines file: one parsed JSON object per item.

    The whole file is read eagerly when the dataset is constructed.
    """

    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        """Read ``data_file`` and return its records as a list, in file order.

        A list (instead of a dict keyed 0..n-1) accepts any index object that
        implements ``__index__`` (plain ints, NumPy integers, 0-d torch integer
        tensors), negative indices, and slices.
        """
        with jsonlines.open(data_file, "r") as reader:
            return list(reader)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one record for an integer index, or a list of records for a slice."""
        return self.data[idx]
49 |
50 |
class Collater:
    """``collate_fn`` for ``DataLoader`` that tokenizes (query, document) pairs.

    In training mode (``is_train=True``) every batch becomes a dict of four
    tokenized groups; otherwise a batch becomes ``(encodings, labels)``.
    Labels are remapped with ``2 - label``, i.e. category 1 -> 1, category 2 -> 0.
    """

    def __init__(self, tokenizer, is_train) -> None:
        self._tokenizer = tokenizer
        self.is_train = is_train

    def tokenize(self, *args):
        """Run the tokenizer with padding, truncation and PyTorch tensor output."""
        options = dict(padding=True, truncation=True, return_tensors="pt")
        return self._tokenizer(*args, **options)

    @staticmethod
    def _query_doc_pairs(samples, field):
        """Flatten ``[query, doc]`` for every doc in ``sample[field]`` across the batch."""
        pairs = []
        for sample in samples:
            query = sample["query"]
            pairs.extend([query, doc] for doc in sample[field])
        return pairs

    def collate_training_data(self, batch_samples):
        """Collate a batch of training samples.

        Input: a list of dicts shaped like
            {
                "id": ...,
                "query": str,
                "classification": {"text": str, "label": int},
                "rank_list": [doc, doc, ...],
                "positive": [doc, ...],
                "negative": [doc, ...],
            }

        Output: a dict
            {
                "classification": [tokenizer([[query, text]...]), tensor([2 - label...])],
                "rank": tokenizer([[query, doc] for doc in rank_list, per sample]),
                "positive": tokenizer([[query, doc] for doc in positive, per sample]),
                "negative": tokenizer([[query, doc] for doc in negative, per sample]),
            }
        """
        cls_pairs = [
            [sample["query"], sample["classification"]["text"]]
            for sample in batch_samples
        ]
        cls_encoding = self.tokenize(cls_pairs)
        # category 1 -> label 1, category 2 -> label 0
        cls_labels = torch.tensor(
            [2 - sample["classification"]["label"] for sample in batch_samples]
        )
        collated = {"classification": [cls_encoding, cls_labels]}

        for out_key, field in (
            ("rank", "rank_list"),
            ("positive", "positive"),
            ("negative", "negative"),
        ):
            collated[out_key] = self.tokenize(
                self._query_doc_pairs(batch_samples, field)
            )
        return collated

    def collate_valid_and_test_data(self, batch_samples):
        """Collate a batch of evaluation samples.

        Input: a list of dicts ``{"question": str, "ctx": str, "category": int}``.
        Output: ``(tokenizer([[question, ctx]...]), tensor([2 - category...]))``.
        """
        pairs = [[sample["question"], sample["ctx"]] for sample in batch_samples]
        encoding = self.tokenize(pairs)
        labels = torch.tensor([2 - sample["category"] for sample in batch_samples])
        return encoding, labels

    def __call__(self, batch_samples):
        collate = (
            self.collate_training_data
            if self.is_train
            else self.collate_valid_and_test_data
        )
        return collate(batch_samples)
144 |
--------------------------------------------------------------------------------
/train_bge_joined.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | from typing import Dict
4 | import numpy as np
5 | import json
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import Dataset, DataLoader
9 | from transformers import RobertaPreTrainedModel, RobertaModel
10 | from transformers import AdamW, get_scheduler
11 | from tqdm.auto import tqdm
12 | from torch.utils.tensorboard import SummaryWriter
13 | import os
14 | import argparse
15 | from sklearn.metrics import f1_score
16 | from transformers import (
17 | AutoTokenizer,
18 | AutoModel,
19 | AutoConfig,
20 | AutoModelForSequenceClassification,
21 | GenerationConfig,
22 | LlamaForCausalLM,
23 | LlamaTokenizer,
24 | )
25 |
26 | from joined_dataset import JoinedDataset, Collater
27 | from bge_joined_model import BgeJoinedModel, BgeJoinedModelLoss, WeightsCalculator
28 |
29 |
class Trainer:
    """Train a joint BGE model on a weighted combination of selected losses.

    Up to three losses are combined, depending on ``loss_types``:
    classification (cross-entropy), pairwise ranking, and supervised
    contrastive. After every epoch the binary F1 on the train and valid
    splits is printed/logged, and a checkpoint is written whenever the
    validation F1 improves.
    """

    def __init__(
        self,
        model,
        train_dataloader,
        valid_dataloader,
        test_dataloader,
        loss_types,
        optimizer,
        lr_scheduler,
        device,
        writer,
    ) -> None:
        """
        Args:
            model: the joint model; the loss helpers use its ``bge`` encoder,
                ``classifier`` head and ``scl_loss_func``.
            train_dataloader: yields dicts produced by the training collater
                (keys "classification", "rank", "positive", "negative").
            valid_dataloader: yields ``(encodings, labels)`` tuples.
            test_dataloader: yields ``(encodings, labels)`` tuples.
            loss_types: collection of ``BgeJoinedModelLoss`` members to optimize.
            optimizer: optimizer over the model parameters.
            lr_scheduler: optional scheduler stepped once per batch, or None.
            device: device the batches are moved to.
            writer: TensorBoard ``SummaryWriter`` used for scalar logging.
        """
        self.model = model
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.test_dataloader = test_dataloader
        self.loss_types = loss_types
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.writer = writer
        self.weights_calculator = WeightsCalculator(self.device)

    def calc_cls_loss(self, cls_tokens, labels):
        """Cross-entropy of the classifier logits on the encoder's position-0 hidden state."""
        outputs = self.model.bge(**cls_tokens).last_hidden_state[:, 0, :]
        outputs = self.model.classifier(outputs)
        loss_cls = nn.functional.cross_entropy(outputs, labels)
        return loss_cls

    def calc_rank_loss(self, rank_tokens, batch_size):
        """Pairwise ranking loss over consecutive groups of 4 (query, doc) pairs.

        Each pair is scored by the classifier's class-1 logit. For every pair
        of positions i < j inside a group, ``-log(sigmoid(s_i - s_j))`` is
        added, pushing earlier documents to score higher than later ones.
        This assumes ``rank_list`` is ordered from most to least relevant
        (TODO confirm against the data). The sum is averaged over the
        C(4, 2) = 6 pairs per group and the ``batch_size`` groups.
        """
        group_size = 4
        n_pairs = group_size * (group_size - 1) // 2  # C(4, 2) == 6
        loss_rank = 0
        outputs = self.model.bge(**rank_tokens).last_hidden_state[:, 0, :]
        outputs = self.model.classifier(outputs)[:, 1]
        for group in range(batch_size):
            start = group * group_size
            for i in range(start, start + group_size - 1):
                for j in range(i + 1, start + group_size):
                    loss_rank += -nn.functional.logsigmoid(outputs[i] - outputs[j])
        loss_rank /= n_pairs * batch_size
        return loss_rank

    def calc_scl_loss(self, pos_tokens, neg_tokens, batch_size):
        """Supervised contrastive loss over 2 positive + 2 negative samples per query.

        For each group, the L2-normalised position-0 embeddings of its two
        positives (label 1) and two negatives (label 0) are passed to
        ``model.scl_loss_func`` with one view per sample (``unsqueeze(1)``).
        The result is averaged over the ``batch_size`` groups.
        """
        p_outputs = self.model.bge(**pos_tokens).last_hidden_state[:, 0, :]
        n_outputs = self.model.bge(**neg_tokens).last_hidden_state[:, 0, :]
        # Same labels for every group; build the tensor once.
        labels = torch.tensor([1, 1, 0, 0], device=p_outputs.device)
        loss_scl = 0
        for group in range(batch_size):
            features = torch.cat(
                [
                    p_outputs[group * 2 : group * 2 + 2],
                    n_outputs[group * 2 : group * 2 + 2],
                ],
                dim=0,
            ).unsqueeze(1)
            features = nn.functional.normalize(features, dim=-1)
            loss_scl += self.model.scl_loss_func(features, labels)
        loss_scl /= batch_size
        return loss_scl

    def cal_loss(self, classification, rank, positive, negative):
        """Compute the weighted sum of the enabled losses for one training batch.

        Input (as produced by the training collater):
            classification: [tokenizer([[query, doc]...]), tensor([label...])]
            rank: tokenizer([[query, doc]...]) with 4 docs per query
            positive: tokenizer([[query, positive_sample]...]) with 2 per query
            negative: tokenizer([[query, negative_sample]...]) with 2 per query

        The per-loss weights come from ``WeightsCalculator.calc_weights`` over
        the shared encoder modules (presumably gradient-based balancing; TODO
        confirm in bge_joined_model).

        Raises:
            ValueError: if ``loss_types`` enables no loss.
        """
        losses = []
        if BgeJoinedModelLoss.ClaasificationLoss in self.loss_types:
            X, labels = classification
            loss_cls = self.calc_cls_loss(X, labels)
            losses.append(loss_cls)

        if BgeJoinedModelLoss.RankLoss in self.loss_types:
            assert len(rank["input_ids"]) % 4 == 0
            loss_rank = self.calc_rank_loss(
                rank, batch_size=len(rank["input_ids"]) // 4
            )
            losses.append(loss_rank)

        if BgeJoinedModelLoss.ContrastiveLoss in self.loss_types:
            assert (
                len(positive["input_ids"]) == len(negative["input_ids"])
                and len(positive["input_ids"]) % 2 == 0
            )
            loss_scl = self.calc_scl_loss(
                positive,
                negative,
                batch_size=len(positive["input_ids"]) // 2,
            )
            losses.append(loss_scl)

        if not losses:
            # torch.stack([]) would otherwise fail with an opaque error.
            raise ValueError("No loss type enabled; nothing to optimize.")

        self.weights_calculator.reset()
        weights = self.weights_calculator.calc_weights(
            [self.model.bge.embeddings, self.model.bge.encoder], losses
        )
        loss = torch.stack(losses).matmul(weights)
        return loss

    def train_loop(self, epoch, total_loss):
        """Run one training epoch (``epoch`` is 1-based).

        ``total_loss`` is the running sum of batch losses from earlier epochs.
        The progress bar shows the mean loss over all steps so far. Returns
        the updated running sum.
        """
        progress_bar = tqdm(range(len(self.train_dataloader)))
        progress_bar.set_description(f"loss: {0:>7f}")
        finish_step_num = (epoch - 1) * len(self.train_dataloader)

        self.model.train()
        for step, sample in enumerate(self.train_dataloader, start=1):
            for k, v in sample.items():
                if isinstance(v, list):
                    sample[k] = [vi.to(self.device) for vi in v]
                else:
                    sample[k] = v.to(self.device)

            loss = self.cal_loss(**sample)
            # Detach to a Python float once; used for logging and the running sum.
            loss_value = loss.item()
            self.writer.add_scalar("loss", loss_value, step + finish_step_num)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            total_loss += loss_value
            progress_bar.set_description(
                f"loss: {total_loss/(finish_step_num + step):>7f}"
            )
            progress_bar.update(1)
        return total_loss

    def test_loop(self, dataloader, dataset_type="Test"):
        """Evaluate binary precision/recall/F1 (positive class = 1) on ``dataloader``.

        For ``dataset_type == "Train"`` each batch is a training-collater dict
        and only its "classification" entry is used. Otherwise each batch is
        an ``(encodings, labels)`` tuple. Prints the metrics and the number of
        predictions per class.

        Returns:
            The F1 score. Returns ``0`` when precision or recall is undefined
            (no predicted positives or no gold positives).
        """
        assert dataset_type in ["Train", "Valid", "Test"]

        self.model.eval()
        tp, fp, fn = 0, 0, 0
        y0 = y1 = 0  # number of predictions of class 0 / class 1
        with torch.no_grad():
            for sample in dataloader:
                X, y = sample["classification"] if dataset_type == "Train" else sample
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X).argmax(1)
                tp += torch.sum((pred == 1) & (y == 1)).item()
                fp += torch.sum((pred == 1) & (y == 0)).item()
                fn += torch.sum((pred == 0) & (y == 1)).item()
                y0 += torch.sum(pred == 0).item()
                y1 += torch.sum(pred == 1).item()
        try:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
        except ZeroDivisionError as e:
            print(e)
            print(f"n_y == 0: {y0}\nn_y == 1: {y1}")
            return 0
        # With tp == 0 (but fp > 0 and fn > 0) both precision and recall are
        # 0.0; define F1 as 0 instead of dividing 0 by 0.
        if precision + recall == 0:
            f1 = 0.0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)
        print(f"{dataset_type} dataset precision: {(100*precision):>0.1f}%")
        print(f"{dataset_type} dataset recall: {(100*recall):>0.1f}%")
        print(f"{dataset_type} dataset f1: {(100*f1):>0.1f}%")
        print(f"n_y == 0: {y0}\nn_y == 1: {y1}")
        return f1

    def train(self, epoch_num, outdir):
        """Train for ``epoch_num`` epochs, saving weights on each new best valid F1.

        Checkpoints go to ``outdir``, which is created if missing. The final
        test evaluation uses the weights after the last epoch, not the
        best-validation checkpoint.
        """
        total_loss = 0.0
        best_f1 = 0.0
        os.makedirs(outdir, exist_ok=True)
        for t in range(epoch_num):
            print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
            total_loss = self.train_loop(t + 1, total_loss)
            train_f1 = self.test_loop(self.train_dataloader, dataset_type="Train")
            # The logged value is an F1 score despite the "train_acc" tag name.
            self.writer.add_scalar("f1/train_acc", train_f1, t + 1)
            valid_f1 = self.test_loop(self.valid_dataloader, dataset_type="Valid")
            self.writer.add_scalar("f1/valid_f1", valid_f1, t + 1)
            if valid_f1 > best_f1:
                best_f1 = valid_f1
                print("saving new weights...\n")
                torch.save(
                    self.model.state_dict(),
                    os.path.join(
                        outdir,
                        f"epoch_{t+1}_valid_f1_{(100*valid_f1):0.1f}_model_weights.bin",
                    ),
                )
        self.test_loop(self.test_dataloader, dataset_type="Test")
214 |
215 |
def get_args():
    """Parse the training script's command-line options.

    Returns:
        argparse.Namespace with data/model paths, the GPU index ("0" or "1",
        default "0"), output and TensorBoard directories, and one boolean
        flag per optional loss term (all False unless given).
    """
    parser = argparse.ArgumentParser()

    # Path-like string options (default None when omitted).
    for option in (
        "--pretrained_model_path",
        "--train_data_path",
        "--valid_data_path",
        "--test_data_path",
    ):
        parser.add_argument(option, type=str)
    parser.add_argument("--gpu", type=str, choices=["0", "1"], default="0")
    for option in ("--outdir", "--tensorboard_log_dir"):
        parser.add_argument(option, type=str)

    # Each flag enables one term of the joint training objective.
    for flag in ("--cls_loss", "--rank_loss", "--scl_loss"):
        parser.add_argument(flag, action="store_true")

    parsed = parser.parse_args()
    print(f"Using device: gpu:{parsed.gpu}")
    return parsed
233 |
234 |
def seed_everything(seed=1029):
    """Seed every random number generator used during training.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU plus all CUDA
    devices). It also sets the ``PYTHONHASHSEED`` environment variable, which
    only affects interpreters started afterwards. Finally it asks cuDNN for
    deterministic algorithms.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Some cuDNN kernels stay non-deterministic even with fixed seeds unless
    # cuDNN is told to be deterministic.
    torch.backends.cudnn.deterministic = True
245 |
246 |
def main():
    """Train the joint BGE model as configured on the command line.

    Builds the train/valid/test data loaders, the ``BgeJoinedModel`` with the
    selected loss terms and an AdamW optimizer, then runs ``Trainer.train``.
    The seed, learning rate, batch size and epoch count are fixed here.

    Raises:
        ValueError: if none of --cls_loss / --rank_loss / --scl_loss is given.
    """
    args = get_args()
    seed_everything(42)
    learning_rate = 1e-5
    batch_size = 4
    epoch_num = 10

    # Resolve the loss terms first so an empty selection fails before any
    # tokenizer/model loading; training cannot combine an empty loss list.
    loss_types = []
    if args.cls_loss:
        loss_types.append(BgeJoinedModelLoss.ClaasificationLoss)
    if args.rank_loss:
        loss_types.append(BgeJoinedModelLoss.RankLoss)
    if args.scl_loss:
        loss_types.append(BgeJoinedModelLoss.ContrastiveLoss)
    if not loss_types:
        raise ValueError(
            "No loss selected: pass at least one of --cls_loss, --rank_loss, --scl_loss"
        )

    writer = SummaryWriter(args.tensorboard_log_dir)
    device = f"cuda:{args.gpu}"

    checkpoint = args.pretrained_model_path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, model_max_length=512)

    train_data = JoinedDataset(args.train_data_path)
    train_collater = Collater(tokenizer, is_train=True)
    train_dataloader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, collate_fn=train_collater
    )

    # Valid and test batches share the evaluation collate format.
    eval_collater = Collater(tokenizer, is_train=False)
    valid_data = JoinedDataset(args.valid_data_path)
    valid_dataloader = DataLoader(
        valid_data, batch_size=batch_size, shuffle=False, collate_fn=eval_collater
    )

    # The test split comes from --test_data_path. Fall back to the validation
    # file only when it is omitted, so command lines without it keep working.
    test_data_path = args.test_data_path
    if test_data_path is None:
        print("--test_data_path not given; using --valid_data_path for the test loop")
        test_data_path = args.valid_data_path
    test_data = JoinedDataset(test_data_path)
    test_dataloader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False, collate_fn=eval_collater
    )

    model = BgeJoinedModel(checkpoint, loss_types)
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    # No learning-rate schedule by default. A linear decay can be enabled with
    # get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0,
    # num_training_steps=epoch_num * len(train_dataloader)).
    lr_scheduler = None

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        test_dataloader=test_dataloader,
        loss_types=loss_types,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        device=device,
        writer=writer,
    )

    trainer.train(epoch_num=epoch_num, outdir=args.outdir)
    print("Done!")


if __name__ == "__main__":
    main()
311 |
--------------------------------------------------------------------------------