├── .gitignore ├── NLI_filter.py ├── README.md ├── SupContrast ├── .gitignore ├── LICENSE ├── README.md ├── figures │ ├── .DS_Store │ ├── SimCLR.jpg │ ├── SupCE.jpg │ └── SupContrast.jpg ├── losses.py ├── main_ce.py ├── main_linear.py ├── main_supcon.py ├── networks │ └── resnet_big.py └── util.py ├── bge_joined_model.py ├── data ├── example_test_dataset.jsonl ├── example_train_dataset.jsonl └── example_valid_dataset.jsonl ├── joined_dataset.py └── train_bge_joined.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | logs/ 4 | OPENAI_API_KEY.txt 5 | connect_openai.py 6 | temp* 7 | .vscode/ 8 | roberta_weights/ 9 | *.png 10 | test*.sh 11 | bge_weights/ 12 | colbertv2.0/ 13 | rankllama/ 14 | bge/ 15 | bce/ 16 | runs/ 17 | summarize_model/ -------------------------------------------------------------------------------- /NLI_filter.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import jsonlines 4 | import json 5 | import random 6 | import re 7 | import os 8 | import copy 9 | import nltk 10 | import numpy as np 11 | import time 12 | from tenacity import ( 13 | retry, 14 | stop_after_attempt, 15 | wait_random_exponential, 16 | RetryError 17 | ) 18 | import logging 19 | import signal 20 | from tqdm import tqdm 21 | import requests 22 | from concurrent.futures import ThreadPoolExecutor, TimeoutError 23 | import tenacity 24 | 25 | random.seed(0) 26 | 27 | 28 | 29 | filter_results = [] 30 | count = 0 31 | filter_count=0 32 | with open('./sample_data/back_trans.jsonl', 'r') as file: 33 | for i,line in enumerate(tqdm(file)): 34 | line = json.loads(line) 35 | back_instructions = line["back_instruction"] 36 | ori_ins = line['instruction'] 37 | 38 | nli_scores=[] 39 | for back_ins in back_instructions[:3]: 40 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 41 | import torch 42 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 43 | device = "cuda:0" 44 | model_name = "your path to mDeBERTa-v3-base-xnli-multilingual-nli-2mil7" 45 | tokenizer = AutoTokenizer.from_pretrained(model_name) 46 | model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device) 47 | 48 | premise = ori_ins 49 | hypothesis = back_ins 50 | 51 | input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device) 52 | output = model(input["input_ids"].to(device)) 53 | prediction = torch.softmax(output["logits"][0], -1).tolist() 54 | label_names = ["entailment", "neutral", "contradiction"] 55 | prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)} 56 | max_label = max(prediction, key=prediction.get) 57 | nli_scores.append(max_label) 58 | 59 | 60 | 61 | line["nli_scores"] = nli_scores 62 | print(nli_scores) 63 | 64 | if "contradiction" in nli_scores: 65 | filter_count+=1 66 | continue 67 | else: 68 | filter_results.append(line) 69 | 70 | 71 | count+=1 72 | 73 | 74 | print("filter samples nums:",filter_count) 75 | 76 | with jsonlines.open("./output/back_trans_fliter.jsonl", "w") as f: 77 | for each in filter_results: 78 | f.write(each) 79 | 80 | 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## [Understand What LLM Needs: Dual Preference Alignment for Retrieval-Augmented Generation](https://arxiv.org/pdf/2406.18676) 3 | 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/understand-what-llm-needs-dual-preference/knowledge-base-question-answering-on-1)](https://paperswithcode.com/sota/knowledge-base-question-answering-on-1?p=understand-what-llm-needs-dual-preference) 6 | 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/understand-what-llm-needs-dual-preference/question-answering-on-natural-questions)](https://paperswithcode.com/sota/question-answering-on-natural-questions?p=understand-what-llm-needs-dual-preference) 8 | 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/understand-what-llm-needs-dual-preference/question-answering-on-triviaqa)](https://paperswithcode.com/sota/question-answering-on-triviaqa?p=understand-what-llm-needs-dual-preference) 10 | 11 | 12 | *Guanting Dong, Yutao Zhu, Chenghao Zhang, Zechen Wang, Zhicheng Dou and Ji-Rong Wen* 13 | 14 | *Gaoling School of Artificial Intelligence, Renmin University of China.* 15 | 16 | This is the repository contains core implementations of the **DPA-RAG**, proposed by [Understand What LLM Needs: Dual Preference Alignment for Retrieval-Augmented Generation](https://arxiv.org/abs/2406.18676). 17 | 18 | 19 | If you find this work helpful for your research, please kindly cite it. 20 | 21 | 22 | ```bibtex 23 | @article{dong2024understand, 24 | author = {Guanting Dong and 25 | Yutao Zhu and 26 | Chenghao Zhang and 27 | Zechen Wang and 28 | Zhicheng Dou and 29 | Ji{-}Rong Wen}, 30 | title = {Understand What {LLM} Needs: Dual Preference Alignment for Retrieval-Augmented 31 | Generation}, 32 | journal = {CoRR}, 33 | volume = {abs/2406.18676}, 34 | year = {2024}, 35 | url = {https://doi.org/10.48550/arXiv.2406.18676}, 36 | doi = {10.48550/ARXIV.2406.18676}, 37 | eprinttype = {arXiv}, 38 | eprint = {2406.18676}, 39 | timestamp = {Mon, 22 Jul 2024 14:28:34 +0200}, 40 | biburl = {https://dblp.org/rec/journals/corr/abs-2406-18676.bib}, 41 | bibsource = {dblp computer science bibliography, https://dblp.org} 42 | } 43 | ``` 44 | 45 | ## 💥 News 46 | - [01/2025] 🔥 Our paper has been accepted by **WWW 2025**. 47 | 48 | - [05/2024] 🔥 We open-source all our code and sampled data of ***DPA-RAG***. Please check out our [reposity](https://github.com/dongguanting/DPA-RAG). 49 | 50 | - [05/2024] 🔥 We introduced ***DPA-RAG***, a universal framework for aligning diverse preference knowledge within RAG systems. Check out the [paper](https://arxiv.org/abs/2406.18676). 51 | 52 | --- 53 | 54 | 55 | ## 🍯 Overall Framework 56 | 57 | image 58 | 59 | 60 | **DPA-RAG** is a universal framework for aligning diverse preference knowledge within RAG systems, consisting of three main components: 61 | 62 | - **Preference Knowledge Construction:** We design three-step process to identify and synthesize high-quality preference knowledge 63 | 64 | - **Reranker-LLM Alignment:** We fine-tune a reranker with multi-grained tasks to achieve alignment between the retriever and LLMs. 65 | 66 | - **LLM Self-Alignment:** Before SFT stage, we introduce a pre-alignment phase to help LLMs capture preference-aligned knowledge. 67 | 68 | 69 | --- 70 | 71 | 72 | 73 | 74 | 75 | 76 | ## 💻 Data preparation 77 | We design a three-step method to gradually mine, augment, and filter out high-quality preference knowledge: 78 | 79 | ### 1. Preference Knowledge Constructio 80 | For each samples in different datasets, you need to use [DPR](https://github.com/facebookresearch/DPR) to retrieve the top 100 passages. 81 | 82 | Then, please follow the process of **Preference Knowledge Construction** section to extract documents labeled “Aligned Knowledge” or “Unaligned Knowledge” from different base model. 83 | 84 | 85 | ### 2. Diverse Query Augmentation 86 | Please use **GPT-3.5-turbo** to perform five augmentations for each query, the template and requirements are as follow: 87 | 88 | ![image](https://github.com/dongguanting/DPA-RAG/assets/60767110/cbefab86-74c8-46ea-afc7-5f94b5ca100c) 89 | 90 | - Rephrasing. Rephrase the original query with the same intention. 91 | - Complexity. Increase the semantic complexity of the original query. 92 | - Decomposition. Decompose the original query into several sub-problems. 93 | - Constraint. Add more conditional and constrained statements to the original query. 94 | - SPARQL. Rewrite the original query based on the SPARQL syntax and generate it directly 95 | 96 | We also provide **1k samples of NQ dataset** for each augmentation, which can be directly downloaded in [here](https://drive.google.com/drive/folders/1fbehvvNzas0VitdBky-pDLDZ_vLSHI81). 97 | 98 | 99 | ### 3. NLI Filtering 100 | Please use use [mDeBERTa](https://huggingface.co/MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7) as NLI model for consistency filtering between original and augmented samples. 101 | 102 | **Data format:** 103 | 104 | ```json 105 | { 106 | "instruction": "what does jamaican people speak?", 107 | "back_instruction": ["What language do Jamaican people speak?", 108 | "Given their diverse cultural background and official language policies", 109 | "what language do Jamaican people predominantly speak in everyday life as well as in formal settings?", 110 | "What do Jamaicans speak?"] 111 | } 112 | ``` 113 | 114 | **Fliter process:** 115 | 116 | ```bash 117 | python NLI_filter.py 118 | ``` 119 | 120 | 121 | --- 122 | 123 | ## :sparkles: Reranker-LLM Alignment 124 | 125 | ### 1. Code clone 126 | Run the following command from the command line to clone our code from github to local: 127 | 128 | ```bash 129 | git clone https://github.com/dongguanting/DPA-RAG.git 130 | ``` 131 | 132 | ### 3. Start training 133 | 134 | We present some training data and test data. If you need to modify training data and test data, follow a similar data format without missing the necessary fields. 135 | 136 | ```bash 137 | python train_bge_joined.py \ 138 | --pretrained_model_path=your/pretrained/bge/path \ 139 | --train_data_path=your/train/data/path \ 140 | --valid_data_path=your/valid/data/path \ 141 | --test_data_path=your/test/data/path \ 142 | --gpu=0 \ 143 | --outdir=your/output/path \ 144 | --tensorboard_log_dir=your/tensorboard/log/path \ 145 | --cls_loss \ 146 | --rank_loss \ 147 | --scl_loss 148 | ``` 149 | 150 | You can choose which GPU to train and test your model on, but right now we don't support multi-GPU training. 151 | 152 | You can choose which losses to incorporate when training. For example, if you only want to use classification loss and contrast learning loss, you can run the following command to start training: 153 | 154 | ```bash 155 | python train_bge_joined.py \ 156 | --pretrained_model_path=your/pretrained/bge/path \ 157 | --train_data_path=your/train/data/path \ 158 | --valid_data_path=your/valid/data/path \ 159 | --test_data_path=your/test/data/path \ 160 | --gpu=0 \ 161 | --outdir=your/output/path \ 162 | --tensorboard_log_dir=your/tensorboard/log/path \ 163 | --cls_loss \ 164 | --scl_loss 165 | ``` 166 | 167 | Tests are automatically performed on the specified data set after the training is complete. 168 | 169 | --- 170 | 171 | 172 | ## 🌈 LLM Training 173 | 174 | 175 | For LLM Training,we use the [LlaMA-Factory v0.6.3](https://github.com/hiyouga/LLaMA-Factory/releases/tag/v0.6.3) for Llama2-7B/13B, Mistral-7B, Qwen1.5-0.5B/4B/7B/14B, Phi2-2.7B. Moreover, we use the [LlaMA-Factory v0.8.1](https://github.com/hiyouga/LLaMA-Factory/releases/tag/v0.8.1) for Qwen2-7B, Llama3-8B. Thanks for their excellent work. 176 | 177 | 178 | 179 | ### 1. Pre-aligned Training: 180 | 181 | We provide the NQ dataset at the prealigned stage [here](https://drive.google.com/drive/folders/1JFCGpnmqfMHGh6X9cMJFtTxVduHkiQXi?usp=sharing). Note that the prealigned data we provide does not include augmented data, please merge augmentation data to unlock more powerful alignment capabilities. You can construct other datasets on your own following our NQ's data format. Please replace the parameters with $ symbols with your own parameters. 182 | 183 | ```bash 184 | deepspeed --num_gpus=8 train_bash.py \ 185 | --deepspeed $deepspeed_zero3_config_path \ 186 | --stage sft \ 187 | --do_train \ 188 | --use_fast_tokenizer \ 189 | --flash_attn \ 190 | --adam_beta1 0.9 \ 191 | --adam_beta2 0.95 \ 192 | --model_name_or_path $MODEL_PATH \ 193 | --dataset $dataset \ 194 | --template $Template \ 195 | --finetuning_type full \ 196 | --output_dir $OUTPUT_PATH \ 197 | --overwrite_cache \ 198 | --overwrite_output_dir \ 199 | --warmup_steps 20 \ 200 | --weight_decay 0.1 \ 201 | --per_device_train_batch_size 4 \ 202 | --gradient_accumulation_steps 4 \ 203 | --ddp_timeout 9000 \ 204 | --learning_rate 7e-6 \ 205 | --lr_scheduler_type "linear" \ 206 | --logging_steps 1 \ 207 | --cutoff_len 8192 \ 208 | --save_steps 200 \ 209 | --num_train_epochs 3.0 \ 210 | --plot_loss \ 211 | --bf16 212 | ``` 213 | 214 | ### 2. SFT Training: 215 | 216 | You can find the original training data with top3 passages (w/o data augmentation) [here](https://drive.google.com/drive/folders/1dCCpAVPiwPgjOhuKGcyonwgfr2kntJHZ?usp=sharing). 217 | Please merge your augmentation data to unlock more powerful alignment capabilities. 218 | 219 | ```bash 220 | deepspeed --num_gpus=8 train_bash.py \ 221 | --deepspeed $deepspeed_zero3_config_path \ 222 | --stage sft \ 223 | --do_train \ 224 | --use_fast_tokenizer \ 225 | --flash_attn \ 226 | --adam_beta1 0.9 \ 227 | --adam_beta2 0.95 \ 228 | --model_name_or_path $MODEL_PATH \ 229 | --dataset $dataset \ 230 | --template $Template \ 231 | --finetuning_type full \ 232 | --output_dir $OUTPUT_PATH \ 233 | --overwrite_cache \ 234 | --overwrite_output_dir \ 235 | --warmup_steps 20 \ 236 | --weight_decay 0.1 \ 237 | --per_device_train_batch_size 4 \ 238 | --gradient_accumulation_steps 4 \ 239 | --ddp_timeout 9000 \ 240 | --learning_rate 7e-6 \ 241 | --lr_scheduler_type "linear" \ 242 | --logging_steps 1 \ 243 | --cutoff_len 8192 \ 244 | --save_steps 200 \ 245 | --num_train_epochs 3.0 \ 246 | --plot_loss \ 247 | --bf16 248 | ``` 249 | 250 | ### 3. Inference 251 | 252 | You can find our reranked test data with top3 passages [here](https://drive.google.com/drive/folders/1HFAEGX5A5aVFNuWMzA1zLoRtfE-FyoTB?usp=sharing). For WebQSP dataset. we only provide train and test data with top2 passages for aligning. 253 | 254 | ```bash 255 | CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ 256 | --stage sft \ 257 | --model_name_or_path $path_to_llama_model \ 258 | --do_predict \ 259 | --flash_attn no \ 260 | --dataset $dataset \ 261 | --template $template \ 262 | --output_dir $OUTPUT_PATH \ 263 | --per_device_eval_batch_size 8 \ 264 | --max_samples 150 \ 265 | --cutoff_len 2048 \ 266 | --predict_with_generate \ 267 | --fp16 \ 268 | --quantization_bit 4 \ 269 | --max_new_tokens 20 270 | ``` 271 | 272 | 273 | 274 | -------------------------------------------------------------------------------- /SupContrast/.gitignore: -------------------------------------------------------------------------------- 1 | tmp*.py 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | save/ -------------------------------------------------------------------------------- /SupContrast/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Yonglong Tian 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /SupContrast/README.md: -------------------------------------------------------------------------------- 1 | # SupContrast: Supervised Contrastive Learning 2 |

3 | 4 |

5 | 6 | This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example: 7 | (1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362) 8 | (2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709) 9 | 10 | ## Update 11 | 12 | ${\color{red}Note}$: if you found it not easy to parse the supcon loss implementation in this repo, we got you. Supcon loss essentially is just a cross-entropy loss (see eq 4 in the [StableRep](https://arxiv.org/pdf/2306.00984.pdf) paper). So we got a cleaner and simpler implementation [here](https://github.com/google-research/syn-rep-learn/blob/main/StableRep/models/losses.py#L49). Hope it helps. 13 | 14 | ImageNet model (small batch size with the trick of the momentum encoder) is released [here](https://www.dropbox.com/s/l4a69ececk4spdt/supcon.pth?dl=0). It achieved > 79% top-1 accuracy. 15 | 16 | ## Loss Function 17 | The loss function [`SupConLoss`](https://github.com/HobbitLong/SupContrast/blob/master/losses.py#L11) in `losses.py` takes `features` (L2 normalized) and `labels` as input, and return the loss. If `labels` is `None` or not passed to the it, it degenerates to SimCLR. 18 | 19 | Usage: 20 | ```python 21 | from losses import SupConLoss 22 | 23 | # define loss with a temperature `temp` 24 | criterion = SupConLoss(temperature=temp) 25 | 26 | # features: [bsz, n_views, f_dim] 27 | # `n_views` is the number of crops from each image 28 | # better be L2 normalized in f_dim dimension 29 | features = ... 30 | # labels: [bsz] 31 | labels = ... 32 | 33 | # SupContrast 34 | loss = criterion(features, labels) 35 | # or SimCLR 36 | loss = criterion(features) 37 | ... 38 | ``` 39 | 40 | ## Comparison 41 | Results on CIFAR-10: 42 | | |Arch | Setting | Loss | Accuracy(%) | 43 | |----------|:----:|:---:|:---:|:---:| 44 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 95.0 | 45 | | SupContrast | ResNet50 | Supervised | Contrastive | 96.0 | 46 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 93.6 | 47 | 48 | Results on CIFAR-100: 49 | | |Arch | Setting | Loss | Accuracy(%) | 50 | |----------|:----:|:---:|:---:|:---:| 51 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 75.3 | 52 | | SupContrast | ResNet50 | Supervised | Contrastive | 76.5 | 53 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 70.7 | 54 | 55 | Results on ImageNet (Stay tuned): 56 | | |Arch | Setting | Loss | Accuracy(%) | 57 | |----------|:----:|:---:|:---:|:---:| 58 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | - | 59 | | SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) | 60 | | SimCLR | ResNet50 | Unsupervised | Contrastive | - | 61 | 62 | ## Running 63 | You might use `CUDA_VISIBLE_DEVICES` to set proper number of GPUs, and/or switch to CIFAR100 by `--dataset cifar100`. 64 | **(1) Standard Cross-Entropy** 65 | ``` 66 | python main_ce.py --batch_size 1024 \ 67 | --learning_rate 0.8 \ 68 | --cosine --syncBN \ 69 | ``` 70 | **(2) Supervised Contrastive Learning** 71 | Pretraining stage: 72 | ``` 73 | python main_supcon.py --batch_size 1024 \ 74 | --learning_rate 0.5 \ 75 | --temp 0.1 \ 76 | --cosine 77 | ``` 78 | 79 | You can also specify `--syncBN` but I found it not crucial for SupContrast (`syncBN` 95.9% v.s. `BN` 96.0%). 80 | 81 | WARN: Currently, `--syncBN` has no effect since the code is using `DataParallel` instead of `DistributedDataParaleel` 82 | 83 | Linear evaluation stage: 84 | ``` 85 | python main_linear.py --batch_size 512 \ 86 | --learning_rate 5 \ 87 | --ckpt /path/to/model.pth 88 | ``` 89 | **(3) SimCLR** 90 | Pretraining stage: 91 | ``` 92 | python main_supcon.py --batch_size 1024 \ 93 | --learning_rate 0.5 \ 94 | --temp 0.5 \ 95 | --cosine --syncBN \ 96 | --method SimCLR 97 | ``` 98 | The `--method SimCLR` flag simply stops `labels` from being passed to `SupConLoss` criterion. 99 | Linear evaluation stage: 100 | ``` 101 | python main_linear.py --batch_size 512 \ 102 | --learning_rate 1 \ 103 | --ckpt /path/to/model.pth 104 | ``` 105 | 106 | On custom dataset: 107 | ``` 108 | python main_supcon.py --batch_size 1024 \ 109 | --learning_rate 0.5 \ 110 | --temp 0.1 --cosine \ 111 | --dataset path \ 112 | --data_folder ./path \ 113 | --mean "(0.4914, 0.4822, 0.4465)" \ 114 | --std "(0.2675, 0.2565, 0.2761)" \ 115 | --method SimCLR 116 | ``` 117 | 118 | The `--data_folder` must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension. 119 | 120 | and 121 | ## t-SNE Visualization 122 | 123 | **(1) Standard Cross-Entropy** 124 |

125 | 126 |

127 | 128 | **(2) Supervised Contrastive Learning** 129 |

130 | 131 |

132 | 133 | **(3) SimCLR** 134 |

135 | 136 |

137 | 138 | ## Reference 139 | ``` 140 | @Article{khosla2020supervised, 141 | title = {Supervised Contrastive Learning}, 142 | author = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan}, 143 | journal = {arXiv preprint arXiv:2004.11362}, 144 | year = {2020}, 145 | } 146 | ``` 147 | -------------------------------------------------------------------------------- /SupContrast/figures/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/.DS_Store -------------------------------------------------------------------------------- /SupContrast/figures/SimCLR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/SimCLR.jpg -------------------------------------------------------------------------------- /SupContrast/figures/SupCE.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/SupCE.jpg -------------------------------------------------------------------------------- /SupContrast/figures/SupContrast.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongguanting/DPA-RAG/2ebaf7e7f60ef2055162c3429fc0fd40ac857a60/SupContrast/figures/SupContrast.jpg -------------------------------------------------------------------------------- /SupContrast/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class SupConLoss(nn.Module): 13 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 14 | It also supports the unsupervised contrastive loss in SimCLR""" 15 | 16 | def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07): 17 | super(SupConLoss, self).__init__() 18 | self.temperature = temperature 19 | self.contrast_mode = contrast_mode 20 | self.base_temperature = base_temperature 21 | 22 | def forward(self, features, labels=None, mask=None): 23 | """Compute loss for model. If both `labels` and `mask` are None, 24 | it degenerates to SimCLR unsupervised loss: 25 | https://arxiv.org/pdf/2002.05709.pdf 26 | 27 | Args: 28 | features: hidden vector of shape [bsz, n_views, ...]. 29 | labels: ground truth of shape [bsz]. 30 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 31 | has the same class as sample i. Can be asymmetric. 32 | Returns: 33 | A loss scalar. 34 | """ 35 | device = features.device 36 | 37 | if len(features.shape) < 3: 38 | raise ValueError( 39 | "`features` needs to be [bsz, n_views, ...]," 40 | "at least 3 dimensions are required" 41 | ) 42 | if len(features.shape) > 3: 43 | features = features.view(features.shape[0], features.shape[1], -1) 44 | 45 | batch_size = features.shape[0] 46 | if labels is not None and mask is not None: 47 | raise ValueError("Cannot define both `labels` and `mask`") 48 | elif labels is None and mask is None: 49 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 50 | elif labels is not None: 51 | labels = labels.contiguous().view(-1, 1) 52 | if labels.shape[0] != batch_size: 53 | raise ValueError("Num of labels does not match num of features") 54 | mask = torch.eq(labels, labels.T).float().to(device) 55 | else: 56 | mask = mask.float().to(device) 57 | 58 | contrast_count = features.shape[1] 59 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 60 | if self.contrast_mode == "one": 61 | anchor_feature = features[:, 0] 62 | anchor_count = 1 63 | elif self.contrast_mode == "all": 64 | anchor_feature = contrast_feature 65 | anchor_count = contrast_count 66 | else: 67 | raise ValueError("Unknown mode: {}".format(self.contrast_mode)) 68 | 69 | # compute logits 70 | anchor_dot_contrast = torch.div( 71 | torch.matmul(anchor_feature, contrast_feature.T), self.temperature 72 | ) 73 | # for numerical stability 74 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 75 | logits = anchor_dot_contrast - logits_max.detach() 76 | 77 | # tile mask 78 | mask = mask.repeat(anchor_count, contrast_count) 79 | # mask-out self-contrast cases 80 | logits_mask = torch.scatter( 81 | torch.ones_like(mask), 82 | 1, 83 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 84 | 0, 85 | ) 86 | mask = mask * logits_mask 87 | 88 | # compute log_prob 89 | exp_logits = torch.exp(logits) * logits_mask 90 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 91 | 92 | # compute mean of log-likelihood over positive 93 | # modified to handle edge cases when there is no positive pair 94 | # for an anchor point. 95 | # Edge case e.g.:- 96 | # features of shape: [4,1,...] 97 | # labels: [0,1,1,2] 98 | # loss before mean: [nan, ..., ..., nan] 99 | mask_pos_pairs = mask.sum(1) 100 | mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) 101 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs 102 | 103 | # loss 104 | loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos 105 | loss = loss.view(anchor_count, batch_size).mean() 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /SupContrast/main_ce.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import math 8 | 9 | import tensorboard_logger as tb_logger 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | from torchvision import transforms, datasets 13 | 14 | from util import AverageMeter 15 | from util import adjust_learning_rate, warmup_learning_rate, accuracy 16 | from util import set_optimizer, save_model 17 | from networks.resnet_big import SupCEResNet 18 | 19 | try: 20 | import apex 21 | from apex import amp, optimizers 22 | except ImportError: 23 | pass 24 | 25 | 26 | def parse_option(): 27 | parser = argparse.ArgumentParser('argument for training') 28 | 29 | parser.add_argument('--print_freq', type=int, default=10, 30 | help='print frequency') 31 | parser.add_argument('--save_freq', type=int, default=50, 32 | help='save frequency') 33 | parser.add_argument('--batch_size', type=int, default=256, 34 | help='batch_size') 35 | parser.add_argument('--num_workers', type=int, default=16, 36 | help='num of workers to use') 37 | parser.add_argument('--epochs', type=int, default=500, 38 | help='number of training epochs') 39 | 40 | # optimization 41 | parser.add_argument('--learning_rate', type=float, default=0.2, 42 | help='learning rate') 43 | parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450', 44 | help='where to decay lr, can be a list') 45 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, 46 | help='decay rate for learning rate') 47 | parser.add_argument('--weight_decay', type=float, default=1e-4, 48 | help='weight decay') 49 | parser.add_argument('--momentum', type=float, default=0.9, 50 | help='momentum') 51 | 52 | # model dataset 53 | parser.add_argument('--model', type=str, default='resnet50') 54 | parser.add_argument('--dataset', type=str, default='cifar10', 55 | choices=['cifar10', 'cifar100'], help='dataset') 56 | 57 | # other setting 58 | parser.add_argument('--cosine', action='store_true', 59 | help='using cosine annealing') 60 | parser.add_argument('--syncBN', action='store_true', 61 | help='using synchronized batch normalization') 62 | parser.add_argument('--warm', action='store_true', 63 | help='warm-up for large batch training') 64 | parser.add_argument('--trial', type=str, default='0', 65 | help='id for recording multiple runs') 66 | 67 | opt = parser.parse_args() 68 | 69 | # set the path according to the environment 70 | opt.data_folder = './datasets/' 71 | opt.model_path = './save/SupCon/{}_models'.format(opt.dataset) 72 | opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset) 73 | 74 | iterations = opt.lr_decay_epochs.split(',') 75 | opt.lr_decay_epochs = list([]) 76 | for it in iterations: 77 | opt.lr_decay_epochs.append(int(it)) 78 | 79 | opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.\ 80 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, 81 | opt.batch_size, opt.trial) 82 | 83 | if opt.cosine: 84 | opt.model_name = '{}_cosine'.format(opt.model_name) 85 | 86 | # warm-up for large-batch training, 87 | if opt.batch_size > 256: 88 | opt.warm = True 89 | if opt.warm: 90 | opt.model_name = '{}_warm'.format(opt.model_name) 91 | opt.warmup_from = 0.01 92 | opt.warm_epochs = 10 93 | if opt.cosine: 94 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 95 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 96 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 97 | else: 98 | opt.warmup_to = opt.learning_rate 99 | 100 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 101 | if not os.path.isdir(opt.tb_folder): 102 | os.makedirs(opt.tb_folder) 103 | 104 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 105 | if not os.path.isdir(opt.save_folder): 106 | os.makedirs(opt.save_folder) 107 | 108 | if opt.dataset == 'cifar10': 109 | opt.n_cls = 10 110 | elif opt.dataset == 'cifar100': 111 | opt.n_cls = 100 112 | else: 113 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 114 | 115 | return opt 116 | 117 | 118 | def set_loader(opt): 119 | # construct data loader 120 | if opt.dataset == 'cifar10': 121 | mean = (0.4914, 0.4822, 0.4465) 122 | std = (0.2023, 0.1994, 0.2010) 123 | elif opt.dataset == 'cifar100': 124 | mean = (0.5071, 0.4867, 0.4408) 125 | std = (0.2675, 0.2565, 0.2761) 126 | else: 127 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 128 | normalize = transforms.Normalize(mean=mean, std=std) 129 | 130 | train_transform = transforms.Compose([ 131 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), 132 | transforms.RandomHorizontalFlip(), 133 | transforms.ToTensor(), 134 | normalize, 135 | ]) 136 | 137 | val_transform = transforms.Compose([ 138 | transforms.ToTensor(), 139 | normalize, 140 | ]) 141 | 142 | if opt.dataset == 'cifar10': 143 | train_dataset = datasets.CIFAR10(root=opt.data_folder, 144 | transform=train_transform, 145 | download=True) 146 | val_dataset = datasets.CIFAR10(root=opt.data_folder, 147 | train=False, 148 | transform=val_transform) 149 | elif opt.dataset == 'cifar100': 150 | train_dataset = datasets.CIFAR100(root=opt.data_folder, 151 | transform=train_transform, 152 | download=True) 153 | val_dataset = datasets.CIFAR100(root=opt.data_folder, 154 | train=False, 155 | transform=val_transform) 156 | else: 157 | raise ValueError(opt.dataset) 158 | 159 | train_sampler = None 160 | train_loader = torch.utils.data.DataLoader( 161 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), 162 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) 163 | val_loader = torch.utils.data.DataLoader( 164 | val_dataset, batch_size=256, shuffle=False, 165 | num_workers=8, pin_memory=True) 166 | 167 | return train_loader, val_loader 168 | 169 | 170 | def set_model(opt): 171 | model = SupCEResNet(name=opt.model, num_classes=opt.n_cls) 172 | criterion = torch.nn.CrossEntropyLoss() 173 | 174 | # enable synchronized Batch Normalization 175 | if opt.syncBN: 176 | model = apex.parallel.convert_syncbn_model(model) 177 | 178 | if torch.cuda.is_available(): 179 | if torch.cuda.device_count() > 1: 180 | model = torch.nn.DataParallel(model) 181 | model = model.cuda() 182 | criterion = criterion.cuda() 183 | cudnn.benchmark = True 184 | 185 | return model, criterion 186 | 187 | 188 | def train(train_loader, model, criterion, optimizer, epoch, opt): 189 | """one epoch training""" 190 | model.train() 191 | 192 | batch_time = AverageMeter() 193 | data_time = AverageMeter() 194 | losses = AverageMeter() 195 | top1 = AverageMeter() 196 | 197 | end = time.time() 198 | for idx, (images, labels) in enumerate(train_loader): 199 | data_time.update(time.time() - end) 200 | 201 | images = images.cuda(non_blocking=True) 202 | labels = labels.cuda(non_blocking=True) 203 | bsz = labels.shape[0] 204 | 205 | # warm-up learning rate 206 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 207 | 208 | # compute loss 209 | output = model(images) 210 | loss = criterion(output, labels) 211 | 212 | # update metric 213 | losses.update(loss.item(), bsz) 214 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 215 | top1.update(acc1[0], bsz) 216 | 217 | # SGD 218 | optimizer.zero_grad() 219 | loss.backward() 220 | optimizer.step() 221 | 222 | # measure elapsed time 223 | batch_time.update(time.time() - end) 224 | end = time.time() 225 | 226 | # print info 227 | if (idx + 1) % opt.print_freq == 0: 228 | print('Train: [{0}][{1}/{2}]\t' 229 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 230 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 231 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 232 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 233 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 234 | data_time=data_time, loss=losses, top1=top1)) 235 | sys.stdout.flush() 236 | 237 | return losses.avg, top1.avg 238 | 239 | 240 | def validate(val_loader, model, criterion, opt): 241 | """validation""" 242 | model.eval() 243 | 244 | batch_time = AverageMeter() 245 | losses = AverageMeter() 246 | top1 = AverageMeter() 247 | 248 | with torch.no_grad(): 249 | end = time.time() 250 | for idx, (images, labels) in enumerate(val_loader): 251 | images = images.float().cuda() 252 | labels = labels.cuda() 253 | bsz = labels.shape[0] 254 | 255 | # forward 256 | output = model(images) 257 | loss = criterion(output, labels) 258 | 259 | # update metric 260 | losses.update(loss.item(), bsz) 261 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 262 | top1.update(acc1[0], bsz) 263 | 264 | # measure elapsed time 265 | batch_time.update(time.time() - end) 266 | end = time.time() 267 | 268 | if idx % opt.print_freq == 0: 269 | print('Test: [{0}/{1}]\t' 270 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 271 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 272 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 273 | idx, len(val_loader), batch_time=batch_time, 274 | loss=losses, top1=top1)) 275 | 276 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) 277 | return losses.avg, top1.avg 278 | 279 | 280 | def main(): 281 | best_acc = 0 282 | opt = parse_option() 283 | 284 | # build data loader 285 | train_loader, val_loader = set_loader(opt) 286 | 287 | # build model and criterion 288 | model, criterion = set_model(opt) 289 | 290 | # build optimizer 291 | optimizer = set_optimizer(opt, model) 292 | 293 | # tensorboard 294 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) 295 | 296 | # training routine 297 | for epoch in range(1, opt.epochs + 1): 298 | adjust_learning_rate(opt, optimizer, epoch) 299 | 300 | # train for one epoch 301 | time1 = time.time() 302 | loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt) 303 | time2 = time.time() 304 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 305 | 306 | # tensorboard logger 307 | logger.log_value('train_loss', loss, epoch) 308 | logger.log_value('train_acc', train_acc, epoch) 309 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) 310 | 311 | # evaluation 312 | loss, val_acc = validate(val_loader, model, criterion, opt) 313 | logger.log_value('val_loss', loss, epoch) 314 | logger.log_value('val_acc', val_acc, epoch) 315 | 316 | if val_acc > best_acc: 317 | best_acc = val_acc 318 | 319 | if epoch % opt.save_freq == 0: 320 | save_file = os.path.join( 321 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 322 | save_model(model, optimizer, opt, epoch, save_file) 323 | 324 | # save the last model 325 | save_file = os.path.join( 326 | opt.save_folder, 'last.pth') 327 | save_model(model, optimizer, opt, opt.epochs, save_file) 328 | 329 | print('best accuracy: {:.2f}'.format(best_acc)) 330 | 331 | 332 | if __name__ == '__main__': 333 | main() 334 | -------------------------------------------------------------------------------- /SupContrast/main_linear.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import argparse 5 | import time 6 | import math 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | 11 | from main_ce import set_loader 12 | from util import AverageMeter 13 | from util import adjust_learning_rate, warmup_learning_rate, accuracy 14 | from util import set_optimizer 15 | from networks.resnet_big import SupConResNet, LinearClassifier 16 | 17 | try: 18 | import apex 19 | from apex import amp, optimizers 20 | except ImportError: 21 | pass 22 | 23 | 24 | def parse_option(): 25 | parser = argparse.ArgumentParser('argument for training') 26 | 27 | parser.add_argument('--print_freq', type=int, default=10, 28 | help='print frequency') 29 | parser.add_argument('--save_freq', type=int, default=50, 30 | help='save frequency') 31 | parser.add_argument('--batch_size', type=int, default=256, 32 | help='batch_size') 33 | parser.add_argument('--num_workers', type=int, default=16, 34 | help='num of workers to use') 35 | parser.add_argument('--epochs', type=int, default=100, 36 | help='number of training epochs') 37 | 38 | # optimization 39 | parser.add_argument('--learning_rate', type=float, default=0.1, 40 | help='learning rate') 41 | parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90', 42 | help='where to decay lr, can be a list') 43 | parser.add_argument('--lr_decay_rate', type=float, default=0.2, 44 | help='decay rate for learning rate') 45 | parser.add_argument('--weight_decay', type=float, default=0, 46 | help='weight decay') 47 | parser.add_argument('--momentum', type=float, default=0.9, 48 | help='momentum') 49 | 50 | # model dataset 51 | parser.add_argument('--model', type=str, default='resnet50') 52 | parser.add_argument('--dataset', type=str, default='cifar10', 53 | choices=['cifar10', 'cifar100'], help='dataset') 54 | 55 | # other setting 56 | parser.add_argument('--cosine', action='store_true', 57 | help='using cosine annealing') 58 | parser.add_argument('--warm', action='store_true', 59 | help='warm-up for large batch training') 60 | 61 | parser.add_argument('--ckpt', type=str, default='', 62 | help='path to pre-trained model') 63 | 64 | opt = parser.parse_args() 65 | 66 | # set the path according to the environment 67 | opt.data_folder = './datasets/' 68 | 69 | iterations = opt.lr_decay_epochs.split(',') 70 | opt.lr_decay_epochs = list([]) 71 | for it in iterations: 72 | opt.lr_decay_epochs.append(int(it)) 73 | 74 | opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\ 75 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, 76 | opt.batch_size) 77 | 78 | if opt.cosine: 79 | opt.model_name = '{}_cosine'.format(opt.model_name) 80 | 81 | # warm-up for large-batch training, 82 | if opt.warm: 83 | opt.model_name = '{}_warm'.format(opt.model_name) 84 | opt.warmup_from = 0.01 85 | opt.warm_epochs = 10 86 | if opt.cosine: 87 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 88 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 89 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 90 | else: 91 | opt.warmup_to = opt.learning_rate 92 | 93 | if opt.dataset == 'cifar10': 94 | opt.n_cls = 10 95 | elif opt.dataset == 'cifar100': 96 | opt.n_cls = 100 97 | else: 98 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 99 | 100 | return opt 101 | 102 | 103 | def set_model(opt): 104 | model = SupConResNet(name=opt.model) 105 | criterion = torch.nn.CrossEntropyLoss() 106 | 107 | classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls) 108 | 109 | ckpt = torch.load(opt.ckpt, map_location='cpu') 110 | state_dict = ckpt['model'] 111 | 112 | if torch.cuda.is_available(): 113 | if torch.cuda.device_count() > 1: 114 | model.encoder = torch.nn.DataParallel(model.encoder) 115 | else: 116 | new_state_dict = {} 117 | for k, v in state_dict.items(): 118 | k = k.replace("module.", "") 119 | new_state_dict[k] = v 120 | state_dict = new_state_dict 121 | model = model.cuda() 122 | classifier = classifier.cuda() 123 | criterion = criterion.cuda() 124 | cudnn.benchmark = True 125 | 126 | model.load_state_dict(state_dict) 127 | else: 128 | raise NotImplementedError('This code requires GPU') 129 | 130 | return model, classifier, criterion 131 | 132 | 133 | def train(train_loader, model, classifier, criterion, optimizer, epoch, opt): 134 | """one epoch training""" 135 | model.eval() 136 | classifier.train() 137 | 138 | batch_time = AverageMeter() 139 | data_time = AverageMeter() 140 | losses = AverageMeter() 141 | top1 = AverageMeter() 142 | 143 | end = time.time() 144 | for idx, (images, labels) in enumerate(train_loader): 145 | data_time.update(time.time() - end) 146 | 147 | images = images.cuda(non_blocking=True) 148 | labels = labels.cuda(non_blocking=True) 149 | bsz = labels.shape[0] 150 | 151 | # warm-up learning rate 152 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 153 | 154 | # compute loss 155 | with torch.no_grad(): 156 | features = model.encoder(images) 157 | output = classifier(features.detach()) 158 | loss = criterion(output, labels) 159 | 160 | # update metric 161 | losses.update(loss.item(), bsz) 162 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 163 | top1.update(acc1[0], bsz) 164 | 165 | # SGD 166 | optimizer.zero_grad() 167 | loss.backward() 168 | optimizer.step() 169 | 170 | # measure elapsed time 171 | batch_time.update(time.time() - end) 172 | end = time.time() 173 | 174 | # print info 175 | if (idx + 1) % opt.print_freq == 0: 176 | print('Train: [{0}][{1}/{2}]\t' 177 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 178 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 179 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 180 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 181 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 182 | data_time=data_time, loss=losses, top1=top1)) 183 | sys.stdout.flush() 184 | 185 | return losses.avg, top1.avg 186 | 187 | 188 | def validate(val_loader, model, classifier, criterion, opt): 189 | """validation""" 190 | model.eval() 191 | classifier.eval() 192 | 193 | batch_time = AverageMeter() 194 | losses = AverageMeter() 195 | top1 = AverageMeter() 196 | 197 | with torch.no_grad(): 198 | end = time.time() 199 | for idx, (images, labels) in enumerate(val_loader): 200 | images = images.float().cuda() 201 | labels = labels.cuda() 202 | bsz = labels.shape[0] 203 | 204 | # forward 205 | output = classifier(model.encoder(images)) 206 | loss = criterion(output, labels) 207 | 208 | # update metric 209 | losses.update(loss.item(), bsz) 210 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 211 | top1.update(acc1[0], bsz) 212 | 213 | # measure elapsed time 214 | batch_time.update(time.time() - end) 215 | end = time.time() 216 | 217 | if idx % opt.print_freq == 0: 218 | print('Test: [{0}/{1}]\t' 219 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 220 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 221 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 222 | idx, len(val_loader), batch_time=batch_time, 223 | loss=losses, top1=top1)) 224 | 225 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) 226 | return losses.avg, top1.avg 227 | 228 | 229 | def main(): 230 | best_acc = 0 231 | opt = parse_option() 232 | 233 | # build data loader 234 | train_loader, val_loader = set_loader(opt) 235 | 236 | # build model and criterion 237 | model, classifier, criterion = set_model(opt) 238 | 239 | # build optimizer 240 | optimizer = set_optimizer(opt, classifier) 241 | 242 | # training routine 243 | for epoch in range(1, opt.epochs + 1): 244 | adjust_learning_rate(opt, optimizer, epoch) 245 | 246 | # train for one epoch 247 | time1 = time.time() 248 | loss, acc = train(train_loader, model, classifier, criterion, 249 | optimizer, epoch, opt) 250 | time2 = time.time() 251 | print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format( 252 | epoch, time2 - time1, acc)) 253 | 254 | # eval for one epoch 255 | loss, val_acc = validate(val_loader, model, classifier, criterion, opt) 256 | if val_acc > best_acc: 257 | best_acc = val_acc 258 | 259 | print('best accuracy: {:.2f}'.format(best_acc)) 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /SupContrast/main_supcon.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import math 8 | 9 | import tensorboard_logger as tb_logger 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | from torchvision import transforms, datasets 13 | 14 | from util import TwoCropTransform, AverageMeter 15 | from util import adjust_learning_rate, warmup_learning_rate 16 | from util import set_optimizer, save_model 17 | from networks.resnet_big import SupConResNet 18 | from losses import SupConLoss 19 | 20 | try: 21 | import apex 22 | from apex import amp, optimizers 23 | except ImportError: 24 | pass 25 | 26 | 27 | def parse_option(): 28 | parser = argparse.ArgumentParser('argument for training') 29 | 30 | parser.add_argument('--print_freq', type=int, default=10, 31 | help='print frequency') 32 | parser.add_argument('--save_freq', type=int, default=50, 33 | help='save frequency') 34 | parser.add_argument('--batch_size', type=int, default=256, 35 | help='batch_size') 36 | parser.add_argument('--num_workers', type=int, default=16, 37 | help='num of workers to use') 38 | parser.add_argument('--epochs', type=int, default=1000, 39 | help='number of training epochs') 40 | 41 | # optimization 42 | parser.add_argument('--learning_rate', type=float, default=0.05, 43 | help='learning rate') 44 | parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900', 45 | help='where to decay lr, can be a list') 46 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, 47 | help='decay rate for learning rate') 48 | parser.add_argument('--weight_decay', type=float, default=1e-4, 49 | help='weight decay') 50 | parser.add_argument('--momentum', type=float, default=0.9, 51 | help='momentum') 52 | 53 | # model dataset 54 | parser.add_argument('--model', type=str, default='resnet50') 55 | parser.add_argument('--dataset', type=str, default='cifar10', 56 | choices=['cifar10', 'cifar100', 'path'], help='dataset') 57 | parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple') 58 | parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple') 59 | parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset') 60 | parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop') 61 | 62 | # method 63 | parser.add_argument('--method', type=str, default='SupCon', 64 | choices=['SupCon', 'SimCLR'], help='choose method') 65 | 66 | # temperature 67 | parser.add_argument('--temp', type=float, default=0.07, 68 | help='temperature for loss function') 69 | 70 | # other setting 71 | parser.add_argument('--cosine', action='store_true', 72 | help='using cosine annealing') 73 | parser.add_argument('--syncBN', action='store_true', 74 | help='using synchronized batch normalization') 75 | parser.add_argument('--warm', action='store_true', 76 | help='warm-up for large batch training') 77 | parser.add_argument('--trial', type=str, default='0', 78 | help='id for recording multiple runs') 79 | 80 | opt = parser.parse_args() 81 | 82 | # check if dataset is path that passed required arguments 83 | if opt.dataset == 'path': 84 | assert opt.data_folder is not None \ 85 | and opt.mean is not None \ 86 | and opt.std is not None 87 | 88 | # set the path according to the environment 89 | if opt.data_folder is None: 90 | opt.data_folder = './datasets/' 91 | opt.model_path = './save/SupCon/{}_models'.format(opt.dataset) 92 | opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset) 93 | 94 | iterations = opt.lr_decay_epochs.split(',') 95 | opt.lr_decay_epochs = list([]) 96 | for it in iterations: 97 | opt.lr_decay_epochs.append(int(it)) 98 | 99 | opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\ 100 | format(opt.method, opt.dataset, opt.model, opt.learning_rate, 101 | opt.weight_decay, opt.batch_size, opt.temp, opt.trial) 102 | 103 | if opt.cosine: 104 | opt.model_name = '{}_cosine'.format(opt.model_name) 105 | 106 | # warm-up for large-batch training, 107 | if opt.batch_size > 256: 108 | opt.warm = True 109 | if opt.warm: 110 | opt.model_name = '{}_warm'.format(opt.model_name) 111 | opt.warmup_from = 0.01 112 | opt.warm_epochs = 10 113 | if opt.cosine: 114 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 115 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 116 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 117 | else: 118 | opt.warmup_to = opt.learning_rate 119 | 120 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 121 | if not os.path.isdir(opt.tb_folder): 122 | os.makedirs(opt.tb_folder) 123 | 124 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 125 | if not os.path.isdir(opt.save_folder): 126 | os.makedirs(opt.save_folder) 127 | 128 | return opt 129 | 130 | 131 | def set_loader(opt): 132 | # construct data loader 133 | if opt.dataset == 'cifar10': 134 | mean = (0.4914, 0.4822, 0.4465) 135 | std = (0.2023, 0.1994, 0.2010) 136 | elif opt.dataset == 'cifar100': 137 | mean = (0.5071, 0.4867, 0.4408) 138 | std = (0.2675, 0.2565, 0.2761) 139 | elif opt.dataset == 'path': 140 | mean = eval(opt.mean) 141 | std = eval(opt.std) 142 | else: 143 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 144 | normalize = transforms.Normalize(mean=mean, std=std) 145 | 146 | train_transform = transforms.Compose([ 147 | transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)), 148 | transforms.RandomHorizontalFlip(), 149 | transforms.RandomApply([ 150 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 151 | ], p=0.8), 152 | transforms.RandomGrayscale(p=0.2), 153 | transforms.ToTensor(), 154 | normalize, 155 | ]) 156 | 157 | if opt.dataset == 'cifar10': 158 | train_dataset = datasets.CIFAR10(root=opt.data_folder, 159 | transform=TwoCropTransform(train_transform), 160 | download=True) 161 | elif opt.dataset == 'cifar100': 162 | train_dataset = datasets.CIFAR100(root=opt.data_folder, 163 | transform=TwoCropTransform(train_transform), 164 | download=True) 165 | elif opt.dataset == 'path': 166 | train_dataset = datasets.ImageFolder(root=opt.data_folder, 167 | transform=TwoCropTransform(train_transform)) 168 | else: 169 | raise ValueError(opt.dataset) 170 | 171 | train_sampler = None 172 | train_loader = torch.utils.data.DataLoader( 173 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), 174 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) 175 | 176 | return train_loader 177 | 178 | 179 | def set_model(opt): 180 | model = SupConResNet(name=opt.model) 181 | criterion = SupConLoss(temperature=opt.temp) 182 | 183 | # enable synchronized Batch Normalization 184 | if opt.syncBN: 185 | model = apex.parallel.convert_syncbn_model(model) 186 | 187 | if torch.cuda.is_available(): 188 | if torch.cuda.device_count() > 1: 189 | model.encoder = torch.nn.DataParallel(model.encoder) 190 | model = model.cuda() 191 | criterion = criterion.cuda() 192 | cudnn.benchmark = True 193 | 194 | return model, criterion 195 | 196 | 197 | def train(train_loader, model, criterion, optimizer, epoch, opt): 198 | """one epoch training""" 199 | model.train() 200 | 201 | batch_time = AverageMeter() 202 | data_time = AverageMeter() 203 | losses = AverageMeter() 204 | 205 | end = time.time() 206 | for idx, (images, labels) in enumerate(train_loader): 207 | data_time.update(time.time() - end) 208 | 209 | images = torch.cat([images[0], images[1]], dim=0) 210 | if torch.cuda.is_available(): 211 | images = images.cuda(non_blocking=True) 212 | labels = labels.cuda(non_blocking=True) 213 | bsz = labels.shape[0] 214 | 215 | # warm-up learning rate 216 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 217 | 218 | # compute loss 219 | features = model(images) 220 | f1, f2 = torch.split(features, [bsz, bsz], dim=0) 221 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 222 | if opt.method == 'SupCon': 223 | loss = criterion(features, labels) 224 | elif opt.method == 'SimCLR': 225 | loss = criterion(features) 226 | else: 227 | raise ValueError('contrastive method not supported: {}'. 228 | format(opt.method)) 229 | 230 | # update metric 231 | losses.update(loss.item(), bsz) 232 | 233 | # SGD 234 | optimizer.zero_grad() 235 | loss.backward() 236 | optimizer.step() 237 | 238 | # measure elapsed time 239 | batch_time.update(time.time() - end) 240 | end = time.time() 241 | 242 | # print info 243 | if (idx + 1) % opt.print_freq == 0: 244 | print('Train: [{0}][{1}/{2}]\t' 245 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 246 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 247 | 'loss {loss.val:.3f} ({loss.avg:.3f})'.format( 248 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 249 | data_time=data_time, loss=losses)) 250 | sys.stdout.flush() 251 | 252 | return losses.avg 253 | 254 | 255 | def main(): 256 | opt = parse_option() 257 | 258 | # build data loader 259 | train_loader = set_loader(opt) 260 | 261 | # build model and criterion 262 | model, criterion = set_model(opt) 263 | 264 | # build optimizer 265 | optimizer = set_optimizer(opt, model) 266 | 267 | # tensorboard 268 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) 269 | 270 | # training routine 271 | for epoch in range(1, opt.epochs + 1): 272 | adjust_learning_rate(opt, optimizer, epoch) 273 | 274 | # train for one epoch 275 | time1 = time.time() 276 | loss = train(train_loader, model, criterion, optimizer, epoch, opt) 277 | time2 = time.time() 278 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 279 | 280 | # tensorboard logger 281 | logger.log_value('loss', loss, epoch) 282 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) 283 | 284 | if epoch % opt.save_freq == 0: 285 | save_file = os.path.join( 286 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 287 | save_model(model, optimizer, opt, epoch, save_file) 288 | 289 | # save the last model 290 | save_file = os.path.join( 291 | opt.save_folder, 'last.pth') 292 | save_model(model, optimizer, opt, opt.epochs, save_file) 293 | 294 | 295 | if __name__ == '__main__': 296 | main() 297 | -------------------------------------------------------------------------------- /SupContrast/networks/resnet_big.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | ImageNet-Style ResNet 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 5 | Adapted from: https://github.com/bearpaw/pytorch-classification 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, is_last=False): 16 | super(BasicBlock, self).__init__() 17 | self.is_last = is_last 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion * planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | preact = out 35 | out = F.relu(out) 36 | if self.is_last: 37 | return out, preact 38 | else: 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1, is_last=False): 46 | super(Bottleneck, self).__init__() 47 | self.is_last = is_last 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion * planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | preact = out 68 | out = F.relu(out) 69 | if self.is_last: 70 | return out, preact 71 | else: 72 | return out 73 | 74 | 75 | class ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 77 | super(ResNet, self).__init__() 78 | self.in_planes = 64 79 | 80 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 81 | bias=False) 82 | self.bn1 = nn.BatchNorm2d(64) 83 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 84 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 85 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 86 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 87 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | # Zero-initialize the last BN in each residual branch, 97 | # so that the residual branch starts with zeros, and each residual block behaves 98 | # like an identity. This improves the model by 0.2~0.3% according to: 99 | # https://arxiv.org/abs/1706.02677 100 | if zero_init_residual: 101 | for m in self.modules(): 102 | if isinstance(m, Bottleneck): 103 | nn.init.constant_(m.bn3.weight, 0) 104 | elif isinstance(m, BasicBlock): 105 | nn.init.constant_(m.bn2.weight, 0) 106 | 107 | def _make_layer(self, block, planes, num_blocks, stride): 108 | strides = [stride] + [1] * (num_blocks - 1) 109 | layers = [] 110 | for i in range(num_blocks): 111 | stride = strides[i] 112 | layers.append(block(self.in_planes, planes, stride)) 113 | self.in_planes = planes * block.expansion 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x, layer=100): 117 | out = F.relu(self.bn1(self.conv1(x))) 118 | out = self.layer1(out) 119 | out = self.layer2(out) 120 | out = self.layer3(out) 121 | out = self.layer4(out) 122 | out = self.avgpool(out) 123 | out = torch.flatten(out, 1) 124 | return out 125 | 126 | 127 | def resnet18(**kwargs): 128 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 129 | 130 | 131 | def resnet34(**kwargs): 132 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 133 | 134 | 135 | def resnet50(**kwargs): 136 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 137 | 138 | 139 | def resnet101(**kwargs): 140 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 141 | 142 | 143 | model_dict = { 144 | 'resnet18': [resnet18, 512], 145 | 'resnet34': [resnet34, 512], 146 | 'resnet50': [resnet50, 2048], 147 | 'resnet101': [resnet101, 2048], 148 | } 149 | 150 | 151 | class LinearBatchNorm(nn.Module): 152 | """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose""" 153 | def __init__(self, dim, affine=True): 154 | super(LinearBatchNorm, self).__init__() 155 | self.dim = dim 156 | self.bn = nn.BatchNorm2d(dim, affine=affine) 157 | 158 | def forward(self, x): 159 | x = x.view(-1, self.dim, 1, 1) 160 | x = self.bn(x) 161 | x = x.view(-1, self.dim) 162 | return x 163 | 164 | 165 | class SupConResNet(nn.Module): 166 | """backbone + projection head""" 167 | def __init__(self, name='resnet50', head='mlp', feat_dim=128): 168 | super(SupConResNet, self).__init__() 169 | model_fun, dim_in = model_dict[name] 170 | self.encoder = model_fun() 171 | if head == 'linear': 172 | self.head = nn.Linear(dim_in, feat_dim) 173 | elif head == 'mlp': 174 | self.head = nn.Sequential( 175 | nn.Linear(dim_in, dim_in), 176 | nn.ReLU(inplace=True), 177 | nn.Linear(dim_in, feat_dim) 178 | ) 179 | else: 180 | raise NotImplementedError( 181 | 'head not supported: {}'.format(head)) 182 | 183 | def forward(self, x): 184 | feat = self.encoder(x) 185 | feat = F.normalize(self.head(feat), dim=1) 186 | return feat 187 | 188 | 189 | class SupCEResNet(nn.Module): 190 | """encoder + classifier""" 191 | def __init__(self, name='resnet50', num_classes=10): 192 | super(SupCEResNet, self).__init__() 193 | model_fun, dim_in = model_dict[name] 194 | self.encoder = model_fun() 195 | self.fc = nn.Linear(dim_in, num_classes) 196 | 197 | def forward(self, x): 198 | return self.fc(self.encoder(x)) 199 | 200 | 201 | class LinearClassifier(nn.Module): 202 | """Linear classifier""" 203 | def __init__(self, name='resnet50', num_classes=10): 204 | super(LinearClassifier, self).__init__() 205 | _, feat_dim = model_dict[name] 206 | self.fc = nn.Linear(feat_dim, num_classes) 207 | 208 | def forward(self, features): 209 | return self.fc(features) 210 | -------------------------------------------------------------------------------- /SupContrast/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | 8 | 9 | class TwoCropTransform: 10 | """Create two crops of the same image""" 11 | def __init__(self, transform): 12 | self.transform = transform 13 | 14 | def __call__(self, x): 15 | return [self.transform(x), self.transform(x)] 16 | 17 | 18 | class AverageMeter(object): 19 | """Computes and stores the average and current value""" 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | 36 | def accuracy(output, target, topk=(1,)): 37 | """Computes the accuracy over the k top predictions for the specified values of k""" 38 | with torch.no_grad(): 39 | maxk = max(topk) 40 | batch_size = target.size(0) 41 | 42 | _, pred = output.topk(maxk, 1, True, True) 43 | pred = pred.t() 44 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 45 | 46 | res = [] 47 | for k in topk: 48 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 49 | res.append(correct_k.mul_(100.0 / batch_size)) 50 | return res 51 | 52 | 53 | def adjust_learning_rate(args, optimizer, epoch): 54 | lr = args.learning_rate 55 | if args.cosine: 56 | eta_min = lr * (args.lr_decay_rate ** 3) 57 | lr = eta_min + (lr - eta_min) * ( 58 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 59 | else: 60 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 61 | if steps > 0: 62 | lr = lr * (args.lr_decay_rate ** steps) 63 | 64 | for param_group in optimizer.param_groups: 65 | param_group['lr'] = lr 66 | 67 | 68 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 69 | if args.warm and epoch <= args.warm_epochs: 70 | p = (batch_id + (epoch - 1) * total_batches) / \ 71 | (args.warm_epochs * total_batches) 72 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 73 | 74 | for param_group in optimizer.param_groups: 75 | param_group['lr'] = lr 76 | 77 | 78 | def set_optimizer(opt, model): 79 | optimizer = optim.SGD(model.parameters(), 80 | lr=opt.learning_rate, 81 | momentum=opt.momentum, 82 | weight_decay=opt.weight_decay) 83 | return optimizer 84 | 85 | 86 | def save_model(model, optimizer, opt, epoch, save_file): 87 | print('==> Saving...') 88 | state = { 89 | 'opt': opt, 90 | 'model': model.state_dict(), 91 | 'optimizer': optimizer.state_dict(), 92 | 'epoch': epoch, 93 | } 94 | torch.save(state, save_file) 95 | del state 96 | -------------------------------------------------------------------------------- /bge_joined_model.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | import random 3 | from typing import List 4 | import numpy as np 5 | import json 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import AutoTokenizer, AutoConfig 10 | from transformers import RobertaPreTrainedModel, RobertaModel 11 | from transformers import AdamW, get_scheduler 12 | from tqdm.auto import tqdm 13 | from torch.utils.tensorboard import SummaryWriter 14 | import os 15 | import argparse 16 | from sklearn.metrics import f1_score 17 | import torch 18 | from torch import nn 19 | from transformers import AutoModel, AutoTokenizer 20 | from SupContrast.losses import SupConLoss 21 | 22 | 23 | class WeightsCalculator: 24 | def __init__(self, model_device) -> None: 25 | self.device = model_device 26 | self.MAX_N_ITER = 100 27 | self.EPS = 1e-3 28 | self.gamma_model = self.Gamma().to(model_device) 29 | 30 | def reset(self): 31 | self.gamma_model.set_gamma(0, 1, self.device) 32 | 33 | class Gamma(nn.Module): 34 | def __init__(self): 35 | super().__init__() 36 | self.gamma = nn.parameter.Parameter(torch.linspace(0, 1, 100).unsqueeze(-1)) 37 | 38 | def forward(self, M, alpha, e_t): 39 | temp = (1 - self.gamma) * alpha.expand(100, -1) + self.gamma * e_t.expand( 40 | 100, -1 41 | ) 42 | loss = temp.matmul(M).matmul(temp.T).diag().sum() 43 | return loss 44 | 45 | def set_gamma(self, start, end, device): 46 | self.gamma.data = torch.linspace(start, end, 100).unsqueeze(-1).to(device) 47 | 48 | def search_1d(self, M, alpha, e_t, gamma_model: Gamma): 49 | loss_rec = 0.0 50 | start = 0 51 | end = 1 52 | for i in range(4): 53 | gamma_model.set_gamma(start, end, M.device) 54 | loss = gamma_model(M, alpha, e_t) 55 | if abs(loss.item() - loss_rec) < self.EPS: 56 | break 57 | gamma_model.gamma.grad = None 58 | loss.backward() 59 | loss_rec = loss.item() 60 | gamma = gamma_model.gamma.data 61 | grad = gamma_model.gamma.grad 62 | if grad[0] * grad[-1] >= 0: 63 | return gamma[torch.argmin(abs(grad))].clamp(0, 0.9999) 64 | try: 65 | end = gamma[grad > 0].min() 66 | start = gamma[grad < 0].max() 67 | except: 68 | # grad 为 nan ,可能是loss为nan 69 | raise Exception() 70 | gamma = gamma_model.gamma.data.mean() 71 | return gamma 72 | 73 | def calc_weights(self, submodels: List[nn.Module], losses): 74 | n_loss = len(losses) 75 | alpha = torch.tensor([1 / n_loss] * n_loss) 76 | 77 | if not losses[0].requires_grad: 78 | return alpha.to(losses[0].device) 79 | 80 | params = [ 81 | p for model in submodels for p in model.parameters() if p.requires_grad 82 | ] 83 | # 构建M矩阵 84 | g = [] 85 | for loss in losses: 86 | temp_g = torch.autograd.grad(loss, params, retain_graph=True) 87 | g.append(torch.cat([i.reshape(-1) for i in temp_g])) 88 | M = [] 89 | for i in range(n_loss): 90 | for j in range(n_loss): 91 | M.append(g[i].matmul(g[j])) 92 | M = torch.stack(M) 93 | M = M.reshape(n_loss, n_loss) 94 | 95 | device = M.device 96 | alpha = alpha.to(device) 97 | 98 | for i in range(self.MAX_N_ITER): 99 | t = torch.argmin(torch.sum(alpha.expand(n_loss, n_loss) * M, 1)) 100 | e_t = torch.zeros(n_loss).to(device) 101 | e_t[t] = 1.0 102 | gamma = self.search_1d(M, alpha, e_t, self.gamma_model) 103 | alpha = (1 - gamma) * alpha + gamma * e_t 104 | if gamma < self.EPS or abs(1 - alpha.max()) < self.EPS: 105 | break 106 | 107 | return alpha 108 | 109 | 110 | class BgeJoinedModelLoss(Enum): 111 | ClaasificationLoss = auto() 112 | RankLoss = auto() 113 | ContrastiveLoss = auto() 114 | 115 | 116 | class BgeJoinedModel(nn.Module): 117 | def __init__(self, pretrained_model_path, loss_types: List[BgeJoinedModelLoss]): 118 | assert len(loss_types) > 0 119 | super(BgeJoinedModel, self).__init__() 120 | self.loss_types = loss_types 121 | self.bge = AutoModel.from_pretrained(pretrained_model_path) 122 | self.classifier = nn.Linear(1024, 2) 123 | if BgeJoinedModelLoss.ContrastiveLoss in self.loss_types: 124 | self.scl_loss_func = SupConLoss() 125 | self.weights_calculator = None 126 | 127 | def forward(self, cls_tokens): 128 | outputs = self.bge(**cls_tokens).last_hidden_state[:, 0, :] 129 | outputs = self.classifier(outputs) 130 | return outputs 131 | -------------------------------------------------------------------------------- /data/example_test_dataset.jsonl: -------------------------------------------------------------------------------- 1 | {"question": "who has the most number one single hits", "reference": ["The Beatles"], "ctx": " Title: Shakin' All Over Content: national number one hit in late 1965 for Normie Rowe. Rowe's version of the track (backed by \"\"Que Sera Sera\"\") was one of the biggest-selling Australian singles of the decade. Rowe had recorded his version before The Guess Who's, and based his release on Johnny Chester's 1961 version. The song has been performed many times by The Who, (sometimes in a medley with \"\"Spoonful\"\"), perhaps most famously at Woodstock in 1969 and on \"\"Live at Leeds\"\" in 1970. In Randy Bachman's autobiography, when he met Who bass player John Entwistle, he was told that people constantly got The Who and", "before": {"answer": "The Beatles", "em": 1, "f1": 1.0}, "after": {"answer": "Unknown", "em": 0, "f1": 0}, "category": 2} 2 | {"question": "what is the most famous building in rennes", "reference": ["The Parlement de Bretagne", "Parlement de Bretagne"], "ctx": " Title: Rennes Content: parlement de Bretagne\"\", many \"\"hôtels particuliers\"\" were built in the northern part, the richest in the 18th century. Most of the monuments historiques can be found there. Colourful traditional half-timbered houses are situated primarily along the roads of Saint-Sauveur, Saint-Georges, de Saint-Malo, Saint-Guillaume, des Dames, du Chapitre, Vasselot, Saint-Michel, de la Psallette and around the plazas of Champ-Jacquet, des Lices, Saint-Anne and Rallier-du-Baty. The \"\"Parlement de Bretagne\"\" (Administrative and judicial centre of Brittany, ) is the most famous 17th century building in Rennes. It was rebuilt after a terrible fire in 1994 that may have been caused by a flare", "before": {"answer": "Parliament (Palais du Parlement)", "em": 0, "f1": 0.28571428571428575}, "after": {"answer": "Parlement de Bretagne", "em": 1, "f1": 1.0}, "category": 1} 3 | {"question": "when was the first time lebron went to the finals", "reference": ["2007"], "ctx": " Title: Los Angeles Lakers Content: involves the two most storied basketball franchises in National Basketball Association (NBA) history. It has been called the best rivalry in the NBA. The two teams have met a record 12 times in the NBA Finals, starting with their first Finals meeting in . They would go on to dominate the league in the 1960s and the 1980s, facing each other six times in the 1960s and three times in the 1980s. The rivalry had been less intense since the retirements of Magic Johnson and Larry Bird in the early 1990s, but in 2008 it was renewed as the Celtics", "before": {"answer": "2007", "em": 1, "f1": 1.0}, "after": {"answer": "The first time LeBron James went to the finals was in 2007.", "em": 1, "f1": 0.18181818181818182}, "category": 2} 4 | {"question": "what is the population of the tri-cities washington", "reference": ["over 300,000", "283,846"], "ctx": " Title: Tri-Cities, Washington Content: estimate of MSA population of Tri-Cities is 283,846. This is over a 12% increase from 2010. (2016 US MSA estimates) 2017 estimates now show Tri-Cities as over 300,000. The combined population of the three major cities themselves was 193,567 at the 2010 Census. As of April 1, 2016, the Washington State Office of Financial Management, Forecasting Division estimates the cities as having a combined population of 217,430 The Tri-Cities Airport located in Pasco provides the region with commercial and private air service. Pasco is the seat of Franklin County, while the other two cities are located in Benton County. In", "before": {"answer": "Unknown", "em": 0, "f1": 0}, "after": {"answer": "300,000", "em": 0, "f1": 0.6666666666666666}, "category": 1} 5 | {"question": "panic at the disco la devotee video cast", "reference": ["Noah Schnapp"], "ctx": " Title: LA Devotee Content: to survive and succeed in the city. A video of the audio for \"\"LA Devotee\"\" was uploaded to Fueled by Ramen's official YouTube channel on the day of its release, November 26, 2015. It has been streamed over 27 million times on YouTube and over 80 million times on Spotify as of February 2018. The song's official music video was released September 22, 2016, starring Noah Schnapp, a child actor known for his role as Will Byers in the Netflix series \"\"Stranger Things\"\". The video starts with a young girl nervously trying to cross a field with a dog barking", "before": {"answer": "Unknown", "em": 0, "f1": 0}, "after": {"answer": "Noah Schnapp", "em": 1, "f1": 1.0}, "category": 1} 6 | {"question": "who played dwight's brother on the office", "reference": ["Thomas Middleditch"], "ctx": " Title: Basketball (The Office) Content: fan and cast favorites. \"\"Basketball\"\" was viewed by an estimated 5.0 million viewers and received a 2.4/6% rating share among adults between the ages of 18 and 49. The episode received positive reviews from critics. Michael Scott (Steve Carell) comes in to work prepared to pit the office staff against the warehouse in a game of basketball, with the losers having to work on Saturday. Michael picks Jim Halpert (John Krasinski), Ryan Howard (B. J. Novak), and Stanley Hudson (Leslie David Baker), the last of whom he believes possesses significant basketball skills based on a racial stereotype. In addition, he", "before": {"answer": "Thomas", "em": 0, "f1": 0.6666666666666666}, "after": {"answer": "Thomas Middleitch", "em": 0, "f1": 0.5}, "category": 2} 7 | {"question": "when was the tower of london finished being built", "reference": ["1078"], "ctx": " Title: White Tower (Tower of London) Content: work began on the White Tower, the great stone keep that still dominates the castle today. The precise date of the White Tower's foundation is unknown, and it is also uncertain how long building took. It is traditionally held that construction began in 1078. This is because the \"\"Textus Roffensis\"\" records that Gundulf, Bishop of Rochester, oversaw the building work under instruction from William the Conqueror. Dendrochronological evidence suggests construction of the White Tower began in 1075–1079. The archaeology of the standing building suggests there was a pause in construction between 1080 and 1090–1093, although it is unknown why. Gundulf", "before": {"answer": "1078", "em": 1, "f1": 1.0}, "after": {"answer": "Uncertain", "em": 0, "f1": 0}, "category": 2} 8 | {"question": "who played mrs warboys in one foot in the grave", "reference": ["Doreen Mantle"], "ctx": " Title: Flowers in the Attic (1987 film) Content: and Victoria Tennant were cast as the Grandmother and Mother, respectively, while the four children were played by newcomers Kristy Swanson, Jeb Stuart Adams, Ben Ryan Ganger, and Lindsay Parker. Swanson once claimed when V.C. Andrews met her, she said Swanson was just like she pictured Cathy. Being a fairly low-budget production, Bloom said, big names were not considered for any role in the film. Jeffrey Bloom had a young Sharon Stone audition for the film, but he could not convince the producers to give her the part of Corrine, the mother. Louise Fletcher wanted to get deep inside her", "before": {"answer": "Doreen Mantle", "em": 1, "f1": 1.0}, "after": {"answer": "Annette Crosbie.", "em": 0, "f1": 0}, "category": 2} 9 | {"question": "joined mexico and the united states to form nafta", "reference": ["Canada"], "ctx": " Title: North America Content: is the Canada-Costa Rica Free Trade Agreement among numerous other free trade relations, often between the larger, more developed countries and Central American and Caribbean countries. The North America Free Trade Agreement (NAFTA) forms one of the four largest trade blocs in the world. Its implementation in 1994 was designed for economic homogenization with hopes of eliminating barriers of trade and foreign investment between Canada, the United States and Mexico. While Canada and the United States already conducted the largest bilateral trade relationship – and to present day still do – in the world and Canada–United States trade relations already", "before": {"answer": "Canada", "em": 1, "f1": 1.0}, "after": {"answer": "True", "em": 0, "f1": 0}, "category": 2} 10 | {"question": "what is on a mcchicken sandwich from mcdonalds", "reference": ["mayonnaise", "shredded lettuce", "breaded chicken patty", "toasted wheat bun", "a toasted wheat bun", "a breaded chicken patty"], "ctx": " Title: McChicken Content: McChicken The McChicken is a chicken sandwich sold by the international fast food restaurant chain McDonald's. The sandwich consists of a toasted wheat bun, a breaded chicken patty, shredded lettuce, and mayonnaise. The sandwich, originally introduced in 1980, proved to be a sales disappointment and was later replaced with the highly successful Chicken McNuggets. But following the success of McNuggets, the McChicken was reintroduced in 1988. But again, McDonald's removed the McChicken from its menus in the United States on September 26, 1996, and it was replaced with the Crispy Chicken Deluxe, which was part of McDonald's ill-fated Deluxe line", "before": {"answer": "Chicken patty, lettuce, mayonnaise.", "em": 1, "f1": 0.5714285714285715}, "after": {"answer": "Toasted wheat bun, breaded chicken patty, shredded lettuce, mayonnaise.", "em": 1, "f1": 0.5}, "category": 2} 11 | {"question": "who sings the theme song for degrassi the next generation", "reference": ["The Degrassi Junior Strings", "Jakalope"], "ctx": " Title: Degrassi: The Next Generation Content: during the first three seasons. Dave Ogilvie and Anthony Valcic of Canadian industrial-pop group Jakalope reworked and performed the song with a heavier sound, reflecting the growing maturity of the characters in season four. For seasons six and seven, the theme—still performed by Jakalope—was remixed and stripped of vocals. A fourth version of the theme song, with lyrics sung by Damhnait Doyle, was introduced for the eighth season, and a fifth version of the theme, performed by the in-show band \"\"\"\" is used for the ninth and tenth seasons. For seasons eleven and twelve, a sixth version of the theme", "before": {"answer": "Drake", "em": 0, "f1": 0}, "after": {"answer": "Jakalope", "em": 1, "f1": 1.0}, "category": 1} 12 | {"question": "who has trained the most melbourne cup winners", "reference": ["Bart Cummings"], "ctx": " Title: Lee Freedman Content: winning a BMW, an Australian Cup and a Cox Plate with her. Freedman's greatest training achievement is winning the 2004 and 2005 Melbourne Cup's with 'The Diva', taking her to a total of three wins in the Cup, which has never been done in more than 150 years of the race. Makybe Diva was retired after winning her third Melbourne Cup in November 2005. As of 2006 Freedman is regarded as Victoria's premier horse trainer, and along with Bart Cummings, Gai Waterhouse, David Hayes and John Hawkes, among Australia's best. In the 2006-7 season the Freedman-trained sprinter Miss Andretti won", "before": {"answer": "Bart Cummings", "em": 1, "f1": 1.0}, "after": {"answer": "Lee Freedman", "em": 0, "f1": 0}, "category": 2} 13 | {"question": "where do the signals for apoptosis come from", "reference": ["from other cells"], "ctx": " Title: Apoptosis Content: intrinsic pathway\"\" the cell kills itself because it senses cell stress, while in the \"\"extrinsic pathway\"\" the cell kills itself because of signals from other cells. Weak external signals may also activate the intrinsic pathway of apoptosis. Both pathways induce cell death by activating caspases, which are proteases, or enzymes that degrade proteins. The two pathways both activate initiator caspases, which then activate executioner caspases, which then kill the cell by degrading proteins indiscriminately. Research on apoptosis has increased substantially since the early 1990s. In addition to its importance as a biological phenomenon, defective apoptotic processes have been implicated in", "before": {"answer": "Various sources", "em": 0, "f1": 0}, "after": {"answer": "Other cells", "em": 0, "f1": 0.8}, "category": 1} 14 | {"question": "who plays red on orange is new black", "reference": ["`` Kate '' Mulgrew"], "ctx": " Title: Blue-streaked lory Content: the area to the base of the neck, there is an indigo-violet stripe. The lower part of the nape of the neck, to the top of the back, is a bluish black. The main part of the back, as well as the rump, is red. Wing coverts and flights are red with black tips. The upper part of the tail feathers are a reddish black, and the underside of the tail is a muted red. The iris of the eye is orange-red. The beak is sharply curved, short, and bright red-orange. The feet are medium gray, and the claws are", "before": {"answer": "Kate Mulgrew", "em": 1, "f1": 1.0}, "after": {"answer": "Ruby Rose", "em": 0, "f1": 0}, "category": 2} 15 | {"question": "where does the story the great gatsby take place", "reference": ["Long Island of 1922"], "ctx": " Title: The Great Gatsby Content: a contender for the title \"\"Great American Novel.\"\" In 1998, the Modern Library editorial board voted it the 20th century's best American novel and second best English-language novel of the same time period. Set on the prosperous Long Island of 1922, \"\"The Great Gatsby\"\" provides a critical social history of America during the Roaring Twenties within its fictional narrative. That era, known for widespread economic prosperity, the development of jazz music, flapper culture, new technologies in communication (motion pictures, broadcast radio, recorded music) forging a genuine mass culture, and bootlegging, along with other criminal activity, is plausibly depicted in Fitzgerald's", "before": {"answer": "New York", "em": 0, "f1": 0}, "after": {"answer": "Long Island", "em": 0, "f1": 0.6666666666666666}, "category": 1} 16 | {"question": "when is the opening ceremonies of the olympics 2018", "reference": ["9 February 2018"], "ctx": " Title: 2012 Summer Olympics opening ceremony Content: 2012 Summer Olympics opening ceremony The opening ceremony of the 2012 Summer Olympic Games took place on the evening of Friday 27 July 2012 in the Olympic Stadium, London. As mandated by the Olympic Charter, the proceedings combined the formal ceremonial opening of this international sporting event (including welcoming speeches, hoisting of the flags and the parade of athletes) with an artistic spectacle to showcase the host nation's culture. The 2012 Games were formally opened by Queen Elizabeth II. The spectacle was entitled Isles of Wonder and directed by Academy Award-winning British film director Danny Boyle. Prior to London 2012", "before": {"answer": "February 9", "em": 0, "f1": 0.8}, "after": {"answer": "2020", "em": 0, "f1": 0}, "category": 2} 17 | {"question": "who said i'll gladly pay you tuesday", "reference": ["Wimpy"], "ctx": " Title: USA Tuesday Night Fights Content: César Chávez's eighth-round knockout win over Joey Gamache, which was televised by the show, but as a Pay Per View feature only. In June 2006 CSI Sports, through its FIGHT Sports division, began airing a compilation of some of the best knockouts from the \"\"Tuesday Night Fights\"\" series titled \"\"USA Tuesday Night Fights: KNOCKOUTS!\"\" on pay-per-view. Narrated by current Washington Redskins play-by-play man Larry Michael, the sixty-minute series featured various fights from throughout the years. Common features were \"\"In Case You Missed It...\"\", featuring some memorable moments from \"\"Tuesday Night Fights\"\", \"\"KO Time\"\", which featured very quick knockouts, and \"\"Who", "before": {"answer": "Wimpy", "em": 1, "f1": 1.0}, "after": {"answer": "Wimpy (a character from the Popeye cartoons)", "em": 1, "f1": 0.33333333333333337}, "category": 2} 18 | {"question": "what kind of food did the woolly mammoth eat", "reference": ["grass and sedges"], "ctx": " Title: Woolly mammoth Content: of the woolly mammoth was mainly grass and sedges. Individuals could probably reach the age of 60. Its habitat was the mammoth steppe, which stretched across northern Eurasia and North America. The woolly mammoth coexisted with early humans, who used its bones and tusks for making art, tools, and dwellings, and the species was also hunted for food. It disappeared from its mainland range at the end of the Pleistocene 10,000 years ago, most likely through climate change and consequent shrinkage of its habitat, hunting by humans, or a combination of the two. Isolated populations survived on St. Paul Island", "before": {"answer": "Plants.", "em": 0, "f1": 0}, "after": {"answer": "Grass and sedges", "em": 1, "f1": 1.0}, "category": 1} 19 | {"question": "who has won the most olympic medals in curling", "reference": ["Cathrine Lindahl", "Mirjam Ott", "Kevin Martin", "Anette Norberg", "Eva Lund", "Anna Le Moine", "Torger Nergård"], "ctx": " Title: Curling at the 2010 Winter Olympics Content: With the conclusion of the Vancouver Olympic curling tournament, eight athletes now have two Olympic curling medals. They are in the order in which they received their medals: Mirjam Ott (SUI), Markus Eggler (SUI), Kevin Martin (CAN), Torger Nergård (NOR), Anette Norberg, Eva Lund, Cathrine Lindahl, and Anna Le Moine, all of Sweden. \"\"*Throws third rocks\"\"
\"\"**Throws second rocks\"\" \"\"*Throws second rocks\"\" \"\"**The World Curling Federation had Olga Jarkova listed as the Third. However, a press release by the Vancouver Organizing Committee has Anna Sidorova listed as Third.\"\" \"\"***On Feb 21, 2010, Debbie McCormick switched to throwing third, with Allison", "before": {"answer": "Canada", "em": 0, "f1": 0}, "after": {"answer": "Mirjam Ott", "em": 1, "f1": 1.0}, "category": 1} 20 | {"question": "who built the tower of london in 1066", "reference": ["William the Conqueror"], "ctx": " Title: White Tower (Tower of London) Content: work began on the White Tower, the great stone keep that still dominates the castle today. The precise date of the White Tower's foundation is unknown, and it is also uncertain how long building took. It is traditionally held that construction began in 1078. This is because the \"\"Textus Roffensis\"\" records that Gundulf, Bishop of Rochester, oversaw the building work under instruction from William the Conqueror. Dendrochronological evidence suggests construction of the White Tower began in 1075–1079. The archaeology of the standing building suggests there was a pause in construction between 1080 and 1090–1093, although it is unknown why. Gundulf", "before": {"answer": "William the Conqueror", "em": 1, "f1": 1.0}, "after": {"answer": "unknown", "em": 0, "f1": 0}, "category": 2} 21 | -------------------------------------------------------------------------------- /data/example_valid_dataset.jsonl: -------------------------------------------------------------------------------- 1 | {"question": "when did the smoking ban in public places start", "reference": ["1995", "August 2 , 1990"], "ctx": " Title: San Luis Obispo, California Content: ordinance to ban smoking in all indoor public areas. Following a failed effort by R. J. Reynolds Tobacco Company to defeat the ordinance, the City Council voted 4–1 in favor on June 19, 1990 with only Mayor Ron Dunin dissenting. As a result, on August 2, 1990, San Luis Obispo became the first municipality in the world to ban smoking in all public buildings, including bars and restaurants. This statute has been a catalyst worldwide in anti-smoking legislation. In April, 2010, San Luis Obispo strengthened its anti-smoking legislation, making smoking in public, excepting for certain conditions, a citable offense beginning", "before": {"answer": "Varies by country", "em": 0, "f1": 0}, "after": {"answer": "August 2, 1990", "em": 1, "f1": 1.0}, "category": 1} 2 | {"question": "who did us fight in world war 1", "reference": ["Austria - Hungary", "Germany"], "ctx": " Title: Home front during World War I Content: failed; likewise, the peace efforts sponsored by industrialist Henry Ford went nowhere. Germany decided to take the risk and try to win by cutting off Britain; the US declared war in April 1917. America had the largest industrial, financial and agricultural base of any of the great powers, but it took 12–18 months to fully reorient it to the war effort. American money, food and munitions flowed freely to Europe from spring 1917, but troops arrived slowly. The US Army in 1917 was small and poorly equipped. The draft began in spring 1917 but volunteers were also accepted. Four million", "before": {"answer": "Central Powers", "em": 0, "f1": 0}, "after": {"answer": "Germany", "em": 1, "f1": 1.0}, "category": 1} 3 | {"question": "who won the workd series of poker 2017", "reference": ["Scott Blumstein"], "ctx": " Title: David Baker (poker player, born 1972) Content: prize of $21,311, while ODB placed 39th for a prize of $17,817. Additionally, both David \"\"ODB\"\" Baker and David \"\"Bakes\"\" Baker won World Series of Poker bracelets three days apart at the 2012 World Series of Poker. ODB won $2,500 477-entrant Event #37: Eight Game Mix, while Bakes won $10,000 178-entrant Event #32: H.O.R.S.E. David Baker (poker player, born 1972) David \"\"ODB\"\" Baker (born July 14, 1972) is an American professional poker player living in Katy, Texas who has appeared at five World Series of Poker (WSOP) final tables. Baker currently has over $2 million in tournament winnings. Baker is", "before": {"answer": "Scott Blumstein", "em": 1, "f1": 1.0}, "after": {"answer": "Unknown. The given information does not mention the World Series of Poker 2017 or the winner.", "em": 0, "f1": 0}, "category": 2} 4 | {"question": "when did the continental congress vote to adopt the declaration of independence", "reference": ["July 2 , 1776"], "ctx": " Title: Continental Congress Content: that the Continental Congress would have to convene once again. The Second Continental Congress convened on May 10, 1775, at Philadelphia's State House, passing the resolution for independence the following year on July 2, 1776, and publicly asserting the decision two days later with the Declaration of Independence. Thomas Jefferson of Virginia drafted the declaration, and John Adams was a leader in the debates in favor of its adoption. John Hancock of Massachusetts was the president during those debates. To govern during the American Revolutionary War, the Second Continental Congress continued, meeting at various locations, until it became the Congress", "before": {"answer": "1776", "em": 0, "f1": 0.5}, "after": {"answer": "July 2, 1776", "em": 1, "f1": 1.0}, "category": 1} 5 | {"question": "when does the boy in the striped pajamas take place", "reference": ["World War II", "during World War II"], "ctx": " Title: The Boy in the Striped Pyjamas Content: more than five million copies around the world. In both 2007 and 2008, it was the best selling book of the year in Spain, and it has also reached number one on the \"\"New York Times\"\" bestseller list, as well as in the UK, Ireland, and Australia. The book was adapted in 2008 as a film of the same name. Bruno is a 9-year-old boy growing up during World War II in Berlin. He lives with his parents, his 12-year-old sister Gretel and maids, one of whom is called Maria. After a visit by Adolf Hitler, Bruno's father is promoted", "before": {"answer": "1940s", "em": 0, "f1": 0}, "after": {"answer": "World War II", "em": 1, "f1": 1.0}, "category": 1} 6 | {"question": "who sang never gonna let you go", "reference": ["Joe Pizzulo", "Joe Pizzulo and Leeza Miller", "Leeza Miller"], "ctx": " Title: Never Gonna Let You Go (Sérgio Mendes song) Content: Never Gonna Let You Go (Sérgio Mendes song) \"\"Never Gonna Let You Go\"\" is a popular song from 1983 credited to Brazilian musician and bandleader Sérgio Mendes and sung by Joe Pizzulo and Leeza Miller. Songwriters Cynthia Weil (lyrics) and Barry Mann (music) composed the song, which appears on Mendes' 1983 self-titled album. Weil and Mann originally submitted \"\"Never Gonna Let You Go\"\" to American funk band Earth, Wind & Fire, but they decided not to record the song. Dionne Warwick recorded the song on her 1982 album \"\"Friends in Love\"\". Stevie Woods had also recorded the song for his", "before": {"answer": "Sergio Mendes.", "em": 0, "f1": 0}, "after": {"answer": "Joe Pizzulo and Leeza Miller", "em": 1, "f1": 1.0}, "category": 1} 7 | {"question": "who are the nbc olympic ice skating commentators", "reference": ["Chris Vosters", "Johnny Weir", "Terry Gannon", "Jonny Moseley", "Charlie White", "Luke Van Valin", "Tara Lipinski", "Tanith White", "Andrea Joyce"], "ctx": " Title: NHL on NBC commentators Content: to became the president of St. Louis Blues, studio analyst Eddie Olczyk was used with Mike Emrick and Pierre McGuire. For 2006–07, \"\"Darren Pang\"\" was used due to Granato's son being born. During 2006–07, the #2 and #3 broadcast teams were mixed up, due to travel constraints. An example of this is color commentators Peter McNab and Brian Hayward often switching roles during the season and playoffs. During the season, due to Pierre McGuire's TSN commitments, a variety of \"\"Inside the Glass\"\" reporters have been used with the #1 team of Mike Emrick and Eddie Olczyk, including Joe Micheletti, Cammi", "before": {"answer": "Tara Lipinski and Johnny Weir", "em": 1, "f1": 0.5714285714285715}, "after": {"answer": "Unknown", "em": 0, "f1": 0}, "category": 2} 8 | {"question": "how many episodes in adventure time season 1", "reference": ["26"], "ctx": " Title: Digimon Adventure Content: was produced by Toei Animation and ran for 54 episodes on Fuji TV in Japan between March 7, 1999 and March 26, 2000. The main opening theme for all episodes aired in Japan is \"\"Butter-Fly\"\" by Kōji Wada, which peaked at #47 on the Oricon Weekly Singles Chart. \"\"I Wish\"\" by AiM is used as the ending theme from episodes 1-26, while \"\"Keep On\"\", AiM's 5th single, served as the ending theme from episodes 27-54. The series also uses three insert songs: \"\"Brave Heart\"\" by Ayumi Miyazaki as the Digivolution theme, \"\"Seven\"\" by Kōji Wada, and by Toshiko Fujita, Tai's", "before": {"answer": "26", "em": 1, "f1": 1.0}, "after": {"answer": "54 episodes", "em": 0, "f1": 0}, "category": 2} 9 | {"question": "when did the movie the post begin filming", "reference": ["May 2017", "May 30 , 2017"], "ctx": " Title: Old Chicago Main Post Office Content: one of 6 sites in Chicago considered for the new Amazon \"\"HQ2\"\" location. The post office was used in the filming of \"\"Batman Begins\"\" in 2004 and for filming \"\"The Dark Knight\"\" in April 2007. It is also seen from above in the opening credits of Candyman. In the summer of 2010, Paramount Pictures and Michael Bay filmed numerous scenes for \"\"\"\" in and around Chicago. The old Chicago Post Office's east facade was decorated as the \"\"Department of Health and Human Services\"\" where the Autobots were kept by the federal government in the movies. The same interior that was", "before": {"answer": "2017", "em": 0, "f1": 0.6666666666666666}, "after": {"answer": "The given information does not mention anything about the movie \"The Post\" and its filming. Therefore, my internal knowledge does not have any information about its filming either.", "em": 0, "f1": 0}, "category": 2} 10 | {"question": "who was suite judy blue eyes written about", "reference": ["Judy Collins", "singer / songwriter Judy Collins"], "ctx": " Title: A Pair of Blue Eyes Content: A Pair of Blue Eyes A Pair of Blue Eyes is a novel by Thomas Hardy, published in 1873, first serialised between September 1872 and July 1873. It was Hardy's third published novel, and the first not published anonymously upon its first publication. The book describes the love triangle of a young woman, Elfride Swancourt, and her two suitors from very different backgrounds. Stephen Smith is a socially inferior but ambitious young man who adores her and with whom she shares a country background. Henry Knight is the respectable, established, older man who represents London society. Although the two are", "before": {"answer": "Judy Collins.", "em": 1, "f1": 1.0}, "after": {"answer": "Crosby, Stills & Nash (and sometimes Young)", "em": 0, "f1": 0}, "category": 2} 11 | {"question": "who is the head of parliament in uk", "reference": ["The Lord Fowler", "John Bercow", "Elizabeth II"], "ctx": " Title: Legislatures of the United Kingdom Content: UK Parliament, the British Parliament, the Westminster Parliament or \"\"Westminster\"\") is the supreme legislative body for the United Kingdom and also for English Law. It alone possesses legislative supremacy and thereby ultimate power over all other political bodies in the UK and its territories. Its head is the Sovereign of the United Kingdom (currently Queen Elizabeth II) and its seat is the Palace of Westminster in Westminster, London. The United Kingdom Legislation may take the form of Acts (passed directly by Parliament) or Statutory Instruments, made under the authority of an Act of Parliament by either a government minister or", "before": {"answer": "Speaker of the House of Commons", "em": 0, "f1": 0}, "after": {"answer": "Queen Elizabeth II", "em": 1, "f1": 0.8}, "category": 1} 12 | {"question": "chief ministers of tamil nadu mentioned on wikipedia", "reference": ["A. Subbarayalu Reddiar", "O. Panneerselvam", "Ramakrishna Ranga Rao", "C. Rajagopalachari", "Tanguturi Prakasam", "P. Subbarayan", "C.N. Annadurai", "Janaki Ramachandran", "M. Karunanidhi", "P.T. Rajan", "J. Jayalalithaa", "V.R. Nedunchezhiyan", "M.G. Ramachandran", "P.S. Kumaraswamy Raja", "K. Palaniswami", "B. Munuswamy Naidu", "Raja of Panagal", "Kurma Venkata Reddy Naidu", "O.P. Ramaswamy Reddiyar"], "ctx": " Title: Tamil Nadu Congress Committee Content: to the Security Laws of the Tamil Nadu Government Departments. Also, Unique Identification Authority of India was established in February 2009, with the help of the Central Government, an agency responsible for implementing the envisioned Multipurpose National Identity Card with the objective of increasing national security and facilitating e-governance. The latest celebrities to join the TNCC where actress Khushboo and actor Karthik Tamil Nadu Congress Committee Tamil Nadu Congress Committee(TNCC) is the wing of Indian National Congress serving in Tamil Nadu. The Current President is Su. Thirunavukkarasar. Social policy of the TNCC is officially based upon the Gandhian principle of", "before": {"answer": "M. K. Stalin, Edappadi K. Palaniswami", "em": 1, "f1": 0.5}, "after": {"answer": "Edappadi K. Palaniswami", "em": 1, "f1": 0.8}, "category": 1} 13 | {"question": "who plays the voice of johnny in sing", "reference": ["Taron Egerton"], "ctx": " Title: Sing (2016 American film) Content: Sing (2016 American film) Sing is a 2016 American computer-animated musical comedy-drama film produced by Illumination Entertainment. It was directed and written by Garth Jennings, co-directed by Christophe Lourdelet and starring the voices of Matthew McConaughey, Reese Witherspoon, Seth MacFarlane, Scarlett Johansson, John C. Reilly, Taron Egerton, and Tori Kelly. The film is about a group of anthropomorphic animals that enter a singing competition, hosted by a koala who is trying to save his theater. The film includes more than 60 songs from famous artists and also has an original song by Stevie Wonder and Ariana Grande called \"\"Faith,\"\" which", "before": {"answer": "Ansel Elgort.", "em": 0, "f1": 0}, "after": {"answer": "Taron Egerton", "em": 1, "f1": 1.0}, "category": 1} 14 | {"question": "how many times has lake placid hosted the winter olympics", "reference": ["twice"], "ctx": " Title: Winter Olympic Games Content: records were set in speed skating largely due to the introduction of the clap skate. The 2002 Winter Olympics were held in Salt Lake City, United States, hosting 77 nations and 2,399 athletes in 78 events in 7 sports. These Games were the first to take place since the September 11 attacks of 2001, which meant a higher degree of security to avoid a terrorist attack. The opening ceremony saw signs of the aftermath of the events of that day, including the flag that flew at Ground Zero, NYPD officer Daniel Rodríguez singing \"\"God Bless America\"\", and honour guards of", "before": {"answer": "Twice", "em": 1, "f1": 1.0}, "after": {"answer": "2 times", "em": 0, "f1": 0}, "category": 2} 15 | {"question": "who is nebula on guardians of the galaxy", "reference": ["Karen Gillan"], "ctx": " Title: Karen Gillan Content: Karen Gillan Karen Gillan (born 28 November 1987) is a Scottish actress, director, screenwriter and model. She played the role of Amy Pond, companion to the Eleventh Doctor, in the BBC One science fiction series \"\"Doctor Who\"\" (2010–2013). She has received wide recognition for portraying Nebula in the Marvel Cinematic Universe films \"\"Guardians of the Galaxy\"\" (2014), \"\"Guardians of the Galaxy Vol. 2\"\" (2017), \"\"\"\" (2018) and \"\"\"\" (2019) and also played Martha in \"\"\"\" (2017). In 2018, she released her first featured film as a writer and director, titled \"\"The Party's Just Beginning\"\". Gillan was born and raised in", "before": {"answer": "Nebula.", "em": 0, "f1": 0}, "after": {"answer": "Karen Gillan", "em": 1, "f1": 1.0}, "category": 1} 16 | {"question": "is a network connection device that can build tables that identify addresses on each network", "reference": ["a router", "routing table"], "ctx": " Title: Zigbee Content: with it. This extended discovery protocol permits external devices to find out about devices in a network and the services that they offer, which endpoints can report when queried by the discovering device (which has previously obtained their addresses). Matching services can also be used. The use of cluster identifiers enforces the binding of complementary entities using the binding tables, which are maintained by Zigbee coordinators, as the table must always be available within a network and coordinators are most likely to have a permanent power supply. Backups, managed by higher-level layers, may be needed by some applications. Binding requires", "before": {"answer": "Router", "em": 1, "f1": 1.0}, "after": {"answer": "Coordinator", "em": 0, "f1": 0}, "category": 2} 17 | {"question": "who played the king in the ten commandments", "reference": ["Yul Brynner", "Woodrow Strode"], "ctx": " Title: Jean Maximilien Lamarque Content: horses shot from under him. In 1810, he was created a baron of the French Empire. Thereafter he was transferred to Spain to support Joseph Bonaparte who had been made king of Spain by his brother, but who was slowly forced to retreat by the rebellious Spanish and the British under Wellington. When Napoleon was exiled to Elba in 1814, Lamarque remained loyal to him, returning to the emperor's service during the Hundred Days. While Napoleon marched to Belgium to deal with the British and Prussian armies, Lamarque commanded a division of ten thousand men against a Royalist uprising in", "before": {"answer": "Yul Brynner", "em": 1, "f1": 1.0}, "after": {"answer": "Charlton Heston.", "em": 0, "f1": 0}, "category": 2} 18 | {"question": "swan lake the sleeping beauty and the nutcracker are three famous ballets by", "reference": ["Pyotr Ilyich Tchaikovsky"], "ctx": " Title: Kristian Fredrikson Content: realised a long-held ambition to design all three of Pyotr Ilyich Tchaikovsky's ballets. He designed \"\"The Nutcracker\"\" for the Royal New Zealand Ballet, \"\"Swan Lake\"\" for the Houston Ballet, and \"\"The Sleeping Beauty\"\" for the Australian Ballet. This last work was the Australian Ballet's largest ever production, requiring 300 costumes and four major sets, and was still touring at the time of Fredrikson's death. In Australia Fredriksen created costumes and undertook production design for a number of film and television productions including costumes for the feature films \"\"Undercover\"\", \"\"Sky Pirate\"\" and \"\"Short Changed\"\" and production design for television series \"\"Vietnam\"\",", "before": {"answer": "Tchaikovsky", "em": 0, "f1": 0.5}, "after": {"answer": "Pyotr Ilyich Tchaikovsky.", "em": 1, "f1": 1.0}, "category": 1} 19 | {"question": "who painted the ceiling of the florence cathedral", "reference": ["Giorgio Vasari", "Federico Zuccari", "Giorgio Vasari and Federico Zuccari"], "ctx": " Title: Raphael Content: more important commission than any he had received before; he had only painted one altarpiece in Florence itself. Several other artists and their teams of assistants were already at work on different rooms, many painting over recently completed paintings commissioned by Julius's loathed predecessor, Alexander VI, whose contributions, and arms, Julius was determined to efface from the palace. Michelangelo, meanwhile, had been commissioned to paint the Sistine Chapel ceiling. This first of the famous \"\"Stanze\"\" or \"\"Raphael Rooms\"\" to be painted, now known as the \"\"Stanza della Segnatura\"\" after its use in Vasari's time, was to make a stunning impact", "before": {"answer": "Giorgio Vasari", "em": 1, "f1": 1.0}, "after": {"answer": "Brunelleschi", "em": 0, "f1": 0}, "category": 2} 20 | {"question": "who was the actor who played crocodile dundee", "reference": ["Paul Hogan", "Paul Hogan , AM"], "ctx": " Title: Gerry Skilton Content: Gerry Skilton Gerry Skilton is an Australian actor of the 20th and early 21st century. Skilton is best known as playing the character Nugget in all three of the Crocodile Dundee series of films: \"\"Crocodile Dundee\"\", \"\"Crocodile Dundee II\"\", and \"\"Crocodile Dundee in Los Angeles\"\". With permission from Paul Hogan, he reprised the character in 2010 for a proposed outback-adventure reality series \"\"Nugget Gets A Life\"\" which he self-financed and had hoped to sell to a network. Among other roles, Skilton played Wayne Churchill in the 1986 Australian mini-series \"\"Cyclone Tracy\"\", Corporal Andy Crilley (the ship's cook) in the 1989", "before": {"answer": "Paul Hogan", "em": 1, "f1": 1.0}, "after": {"answer": "Gerry Skilton", "em": 0, "f1": 0}, "category": 2} 21 | -------------------------------------------------------------------------------- /joined_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import jsonlines 4 | import numpy as np 5 | import json 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import RobertaPreTrainedModel, RobertaModel 10 | from transformers import AdamW, get_scheduler 11 | from tqdm.auto import tqdm 12 | from torch.utils.tensorboard import SummaryWriter 13 | import os 14 | import argparse 15 | from sklearn.metrics import f1_score 16 | from transformers import ( 17 | AutoTokenizer, 18 | AutoModel, 19 | AutoConfig, 20 | AutoModelForSequenceClassification, 21 | GenerationConfig, 22 | LlamaForCausalLM, 23 | LlamaTokenizer, 24 | ) 25 | 26 | 27 | class JoinedDataset(Dataset): 28 | 29 | def __init__(self, data_file): 30 | self.data = self.load_data(data_file) 31 | 32 | def load_data(self, data_file): 33 | Data = {} 34 | idx = 0 35 | with jsonlines.open(data_file, "r") as fin: 36 | for line in fin.iter(): 37 | Data[idx] = line 38 | idx += 1 39 | return Data 40 | 41 | def __len__(self): 42 | return len(self.data) 43 | 44 | def __getitem__(self, idx): 45 | if isinstance(idx, slice): 46 | start, stop, step = idx.indices(len(self)) 47 | return [self.data[i] for i in range(start, stop, step)] 48 | return self.data[idx] 49 | 50 | 51 | class Collater: 52 | def __init__(self, tokenizer, is_train) -> None: 53 | self._tokenizer = tokenizer 54 | self.is_train = is_train 55 | 56 | def tokenize(self, *args): 57 | return self._tokenizer( 58 | *args, 59 | padding=True, 60 | truncation=True, 61 | return_tensors="pt", 62 | ) 63 | 64 | def collate_training_data(self, batch_samples): 65 | """ 66 | 输入数据形式: 67 | [ 68 | { 69 | "id": xx, 70 | "query": xx, 71 | "classification":{"text": xx, "label": xx}, 72 | "rank_list":[xx, xx, xx, xx], 73 | "positive": [xx...], 74 | "negative": [xx...], 75 | }... 76 | ] 77 | 78 | 输出数据形式: 79 | { 80 | "classification": [tokenizer([[query, doc]...]), tensor([label...])] 81 | "rank": tokenizer([[query, xx]...]) 82 | "positive": tokenizer([[query, positive_sample]...]) 83 | "negative": tokenizer([[query, negative_sample]...]) 84 | } 85 | """ 86 | res = { 87 | "classification": [ 88 | self.tokenize( 89 | [ 90 | [sample["query"], sample["classification"]["text"]] 91 | for sample in batch_samples 92 | ] 93 | ), 94 | # 1->1, 2->0 95 | torch.tensor( 96 | [2 - sample["classification"]["label"] for sample in batch_samples] 97 | ), 98 | ], 99 | "rank": self.tokenize( 100 | [ 101 | [sample["query"], doc] 102 | for sample in batch_samples 103 | for doc in sample["rank_list"] 104 | ] 105 | ), 106 | "positive": self.tokenize( 107 | [ 108 | [sample["query"], pos_doc] 109 | for sample in batch_samples 110 | for pos_doc in sample["positive"] 111 | ] 112 | ), 113 | "negative": self.tokenize( 114 | [ 115 | [sample["query"], neg_doc] 116 | for sample in batch_samples 117 | for neg_doc in sample["negative"] 118 | ] 119 | ), 120 | } 121 | return res 122 | 123 | def collate_valid_and_test_data(self, batch_samples): 124 | """ 125 | 输入数据形式: 126 | [ 127 | {"question":"", "ctx":"", "category": xx} 128 | ] 129 | 输出数据形式: 130 | tokenizer([[query, doc]...]), tensor([label...]) 131 | """ 132 | return ( 133 | self.tokenize( 134 | [[sample["question"], sample["ctx"]] for sample in batch_samples] 135 | ), 136 | torch.tensor([2 - sample["category"] for sample in batch_samples]), 137 | ) 138 | 139 | def __call__(self, batch_samples): 140 | if self.is_train: 141 | return self.collate_training_data(batch_samples) 142 | else: 143 | return self.collate_valid_and_test_data(batch_samples) 144 | -------------------------------------------------------------------------------- /train_bge_joined.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | from typing import Dict 4 | import numpy as np 5 | import json 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import Dataset, DataLoader 9 | from transformers import RobertaPreTrainedModel, RobertaModel 10 | from transformers import AdamW, get_scheduler 11 | from tqdm.auto import tqdm 12 | from torch.utils.tensorboard import SummaryWriter 13 | import os 14 | import argparse 15 | from sklearn.metrics import f1_score 16 | from transformers import ( 17 | AutoTokenizer, 18 | AutoModel, 19 | AutoConfig, 20 | AutoModelForSequenceClassification, 21 | GenerationConfig, 22 | LlamaForCausalLM, 23 | LlamaTokenizer, 24 | ) 25 | 26 | from joined_dataset import JoinedDataset, Collater 27 | from bge_joined_model import BgeJoinedModel, BgeJoinedModelLoss, WeightsCalculator 28 | 29 | 30 | class Trainer: 31 | def __init__( 32 | self, 33 | model, 34 | train_dataloader, 35 | valid_dataloader, 36 | test_dataloader, 37 | loss_types, 38 | optimizer, 39 | lr_scheduler, 40 | device, 41 | writer, 42 | ) -> None: 43 | self.model = model 44 | self.train_dataloader = train_dataloader 45 | self.valid_dataloader = valid_dataloader 46 | self.test_dataloader = test_dataloader 47 | self.loss_types = loss_types 48 | self.optimizer = optimizer 49 | self.lr_scheduler = lr_scheduler 50 | self.device = device 51 | self.writer = writer 52 | self.weights_calculator = WeightsCalculator(self.device) 53 | 54 | def calc_cls_loss(self, cls_tokens, labels): 55 | outputs = self.model.bge(**cls_tokens).last_hidden_state[:, 0, :] 56 | outputs = self.model.classifier(outputs) 57 | loss_cls = nn.functional.cross_entropy(outputs, labels) 58 | return loss_cls 59 | 60 | def calc_rank_loss(self, rank_tokens, batch_size): 61 | loss_rank = 0 62 | outputs = self.model.bge(**rank_tokens).last_hidden_state[:, 0, :] 63 | outputs = self.model.classifier(outputs)[:, 1] 64 | for group in range(batch_size): 65 | start = group * 4 66 | for i in range(start, start + 3): 67 | for j in range(i + 1, start + 4): 68 | loss_rank += -nn.functional.logsigmoid(outputs[i] - outputs[j]) 69 | # C(4,2) * n_groups 70 | loss_rank /= 6 * batch_size 71 | return loss_rank 72 | 73 | def calc_scl_loss(self, pos_tokens, neg_tokens, batch_size): 74 | p_outputs = self.model.bge(**pos_tokens).last_hidden_state[:, 0, :] 75 | n_outputs = self.model.bge(**neg_tokens).last_hidden_state[:, 0, :] 76 | loss_scl = 0 77 | for group in range(batch_size): 78 | features = torch.cat( 79 | [ 80 | p_outputs[group * 2 : group * 2 + 2], 81 | n_outputs[group * 2 : group * 2 + 2], 82 | ], 83 | dim=0, 84 | ).unsqueeze(1) 85 | features = nn.functional.normalize(features, dim=-1) 86 | labels = torch.tensor([1, 1, 0, 0]).to(features.device) 87 | loss_scl += self.model.scl_loss_func(features, labels) 88 | loss_scl /= batch_size 89 | return loss_scl 90 | 91 | def cal_loss(self, classification, rank, positive, negative): 92 | """ 93 | 输入数据形式: 94 | { 95 | "classification": [tokenizer([[query, doc]...]), tensor([label...])] 96 | "rank": tokenizer([[query, xx]...]) 97 | "positive": tokenizer([[query, positive_sample]...]) 98 | "negative": tokenizer([[query, negative_sample]...]) 99 | } 100 | """ 101 | losses = [] 102 | if BgeJoinedModelLoss.ClaasificationLoss in self.loss_types: 103 | X, labels = classification 104 | loss_cls = self.calc_cls_loss(X, labels) 105 | losses.append(loss_cls) 106 | 107 | if BgeJoinedModelLoss.RankLoss in self.loss_types: 108 | assert len(rank["input_ids"]) % 4 == 0 109 | loss_rank = self.calc_rank_loss( 110 | rank, batch_size=len(rank["input_ids"]) // 4 111 | ) 112 | losses.append(loss_rank) 113 | 114 | if BgeJoinedModelLoss.ContrastiveLoss in self.loss_types: 115 | assert ( 116 | len(positive["input_ids"]) == len(negative["input_ids"]) 117 | and len(positive["input_ids"]) % 2 == 0 118 | ) 119 | loss_scl = self.calc_scl_loss( 120 | positive, 121 | negative, 122 | batch_size=len(positive["input_ids"]) // 2, 123 | ) 124 | losses.append(loss_scl) 125 | 126 | self.weights_calculator.reset() 127 | weights = self.weights_calculator.calc_weights( 128 | [self.model.bge.embeddings, self.model.bge.encoder], losses 129 | ) 130 | loss = torch.stack(losses).matmul(weights) 131 | return loss 132 | 133 | def train_loop(self, epoch, total_loss): 134 | progress_bar = tqdm(range(len(self.train_dataloader))) 135 | progress_bar.set_description(f"loss: {0:>7f}") 136 | finish_step_num = (epoch - 1) * len(self.train_dataloader) 137 | 138 | self.model.train() 139 | for step, sample in enumerate(self.train_dataloader, start=1): 140 | for k, v in sample.items(): 141 | if isinstance(v, list): 142 | sample[k] = [vi.to(self.device) for vi in v] 143 | else: 144 | sample[k] = v.to(self.device) 145 | 146 | loss = self.cal_loss(**sample) 147 | self.writer.add_scalar("loss", loss, step + finish_step_num) 148 | 149 | self.optimizer.zero_grad() 150 | loss.backward() 151 | self.optimizer.step() 152 | if self.lr_scheduler is not None: 153 | self.lr_scheduler.step() 154 | 155 | total_loss += loss.item() 156 | progress_bar.set_description( 157 | f"loss: {total_loss/(finish_step_num + step):>7f}" 158 | ) 159 | progress_bar.update(1) 160 | return total_loss 161 | 162 | def test_loop(self, dataloader, dataset_type="Test"): 163 | assert dataset_type in ["Train", "Valid", "Test"] 164 | 165 | self.model.eval() 166 | tp, fp, fn = 0, 0, 0 167 | y0 = y1 = 0 168 | with torch.no_grad(): 169 | for sample in dataloader: 170 | X, y = sample["classification"] if dataset_type == "Train" else sample 171 | X, y = X.to(self.device), y.to(self.device) 172 | pred = self.model(X).argmax(1) 173 | tp += torch.sum((pred == 1) & (y == 1)).item() 174 | fp += torch.sum((pred == 1) & (y == 0)).item() 175 | fn += torch.sum((pred == 0) & (y == 1)).item() 176 | y0 += torch.sum(pred == 0).item() 177 | y1 += torch.sum(pred == 1).item() 178 | try: 179 | precision = tp / (tp + fp) 180 | recall = tp / (tp + fn) 181 | except ZeroDivisionError as e: 182 | print(e) 183 | print(f"n_y == 0: {y0}\nn_y == 1: {y1}") 184 | return 0 185 | f1 = 2 * (precision * recall) / (precision + recall) 186 | print(f"{dataset_type} dataset precision: {(100*precision):>0.1f}%") 187 | print(f"{dataset_type} dataset recall: {(100*recall):>0.1f}%") 188 | print(f"{dataset_type} dataset f1: {(100*f1):>0.1f}%") 189 | print(f"n_y == 0: {y0}\nn_y == 1: {y1}") 190 | return f1 191 | 192 | def train(self, epoch_num, outdir): 193 | total_loss = 0.0 194 | best_f1 = 0.0 195 | os.makedirs(outdir, exist_ok=True) 196 | for t in range(epoch_num): 197 | print(f"Epoch {t+1}/{epoch_num}\n-------------------------------") 198 | total_loss = self.train_loop(t + 1, total_loss) 199 | train_f1 = self.test_loop(self.train_dataloader, dataset_type="Train") 200 | self.writer.add_scalar("f1/train_acc", train_f1, t + 1) 201 | valid_f1 = self.test_loop(self.valid_dataloader, dataset_type="Valid") 202 | self.writer.add_scalar("f1/valid_f1", valid_f1, t + 1) 203 | if valid_f1 > best_f1: 204 | best_f1 = valid_f1 205 | print("saving new weights...\n") 206 | torch.save( 207 | self.model.state_dict(), 208 | os.path.join( 209 | outdir, 210 | f"epoch_{t+1}_valid_f1_{(100*valid_f1):0.1f}_model_weights.bin", 211 | ), 212 | ) 213 | self.test_loop(self.test_dataloader, dataset_type="Test") 214 | 215 | 216 | def get_args(): 217 | parser = argparse.ArgumentParser() 218 | parser.add_argument("--pretrained_model_path", type=str) 219 | parser.add_argument("--train_data_path", type=str) 220 | parser.add_argument("--valid_data_path", type=str) 221 | parser.add_argument("--test_data_path", type=str) 222 | parser.add_argument("--gpu", type=str, choices=["0", "1"], default="0") 223 | parser.add_argument("--outdir", type=str) 224 | parser.add_argument("--tensorboard_log_dir", type=str) 225 | parser.add_argument("--cls_loss", action="store_true") 226 | parser.add_argument("--rank_loss", action="store_true") 227 | parser.add_argument("--scl_loss", action="store_true") 228 | 229 | args = parser.parse_args() 230 | print(f"Using device: gpu:{args.gpu}") 231 | 232 | return args 233 | 234 | 235 | def seed_everything(seed=1029): 236 | random.seed(seed) 237 | os.environ["PYTHONHASHSEED"] = str(seed) 238 | np.random.seed(seed) 239 | torch.manual_seed(seed) 240 | torch.cuda.manual_seed(seed) 241 | torch.cuda.manual_seed_all(seed) 242 | # some cudnn methods can be random even after fixing the seed 243 | # unless you tell it to be deterministic 244 | torch.backends.cudnn.deterministic = True 245 | 246 | 247 | def main(): 248 | args = get_args() 249 | seed_everything(42) 250 | learning_rate = 1e-5 251 | batch_size = 4 252 | epoch_num = 10 253 | writer = SummaryWriter(args.tensorboard_log_dir) 254 | device = f"cuda:{args.gpu}" 255 | 256 | checkpoint = args.pretrained_model_path 257 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, model_max_length=512) 258 | 259 | train_data = JoinedDataset(args.train_data_path) 260 | collater = Collater(tokenizer, is_train=True) 261 | train_dataloader = DataLoader( 262 | train_data, batch_size=batch_size, shuffle=True, collate_fn=collater 263 | ) 264 | 265 | valid_data = JoinedDataset(args.valid_data_path) 266 | collater = Collater(tokenizer, is_train=False) 267 | valid_dataloader = DataLoader( 268 | valid_data, batch_size=batch_size, shuffle=False, collate_fn=collater 269 | ) 270 | 271 | test_data = JoinedDataset(args.valid_data_path) 272 | test_dataloader = DataLoader( 273 | test_data, batch_size=batch_size, shuffle=False, collate_fn=collater 274 | ) 275 | loss_types = [] 276 | if args.cls_loss: 277 | loss_types.append(BgeJoinedModelLoss.ClaasificationLoss) 278 | if args.rank_loss: 279 | loss_types.append(BgeJoinedModelLoss.RankLoss) 280 | if args.scl_loss: 281 | loss_types.append(BgeJoinedModelLoss.ContrastiveLoss) 282 | model = BgeJoinedModel(checkpoint, loss_types) 283 | model.to(device) 284 | optimizer = AdamW(model.parameters(), lr=learning_rate) 285 | lr_scheduler = None 286 | # lr_scheduler = get_scheduler( 287 | # "linear", 288 | # optimizer=optimizer, 289 | # num_warmup_steps=0, 290 | # num_training_steps=epoch_num * len(train_dataloader), 291 | # ) 292 | 293 | trainer = Trainer( 294 | model=model, 295 | train_dataloader=train_dataloader, 296 | valid_dataloader=valid_dataloader, 297 | test_dataloader=test_dataloader, 298 | loss_types=loss_types, 299 | optimizer=optimizer, 300 | lr_scheduler=lr_scheduler, 301 | device=device, 302 | writer=writer, 303 | ) 304 | 305 | trainer.train(epoch_num=epoch_num, outdir=args.outdir) 306 | print("Done!") 307 | 308 | 309 | if __name__ == "__main__": 310 | main() 311 | --------------------------------------------------------------------------------