├── .DS_Store
├── AugData
│   └── nlpaug_explore.py
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Config
├── LICENSE
├── Notebook
│   ├── .DS_Store
│   ├── .ipynb_checkpoints
│   │   ├── SCCL-ExplicitSubmit-checkpoint.ipynb
│   │   ├── SCCL-VirtualSubmit-checkpoint.ipynb
│   │   ├── run_explicit-checkpoint.sh
│   │   ├── run_virtual-checkpoint.sh
│   │   └── utils-checkpoint.py
│   ├── run_explicit.sh
│   └── run_virtual.sh
├── README.md
├── dataloader
│   ├── .DS_Store
│   ├── .ipynb_checkpoints
│   │   └── dataloader-checkpoint.py
│   ├── __pycache__
│   │   └── dataloader.cpython-36.pyc
│   └── dataloader.py
├── learner
│   ├── .DS_Store
│   ├── .ipynb_checkpoints
│   │   ├── cluster_utils-checkpoint.py
│   │   └── contrastive_utils-checkpoint.py
│   ├── __pycache__
│   │   ├── cluster_utils.cpython-36.pyc
│   │   └── contrastive_utils.cpython-36.pyc
│   ├── cluster_utils.py
│   └── contrastive_utils.py
├── logs
│   └── .DS_Store
├── main.py
├── models
│   ├── .DS_Store
│   ├── .ipynb_checkpoints
│   │   └── Transformers-checkpoint.py
│   ├── Transformers.py
│   └── __pycache__
│       └── Transformers.cpython-36.pyc
├── training.py
└── utils
    ├── .DS_Store
    ├── .ipynb_checkpoints
    │   ├── kmeans-checkpoint.py
    │   ├── logger-checkpoint.py
    │   └── optimizer-checkpoint.py
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── kmeans.cpython-36.pyc
    │   ├── logger.cpython-36.pyc
    │   ├── metric.cpython-36.pyc
    │   └── optimizer.cpython-36.pyc
    ├── kmeans.py
    ├── logger.py
    ├── metric.py
    └── optimizer.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/sccl/46e700fe170625227dca1ba08ddbbe4665eae0b6/.DS_Store
--------------------------------------------------------------------------------
/AugData/nlpaug_explore.py:
--------------------------------------------------------------------------------
import os
import time
import random
import argparse
import numpy as np
import pandas as pd

import torch
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.word as naw
from nlpaug.util import Action

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='stackoverflow')
parser.add_argument('--augtype', type=str, default='ctxt_insertbertroberta')
parser.add_argument('--aug_min', type=int, default=1)
parser.add_argument('--aug_p', type=float, default=0.2)
parser.add_argument('--aug_max', type=int, default=10)
parser.add_argument('--gpuid', type=int, default=0)
args = parser.parse_args()


def set_global_random_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark must stay off for cudnn determinism
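# Note: the functions below write two augmented views of each text into new
# columns named <textcol>1 and <textcol>2, which is the (text, text1, text2)
# format the SCCL training code expects for explicit augmentation. Also be
# aware that, depending on the installed nlpaug version, augment() may return
# a list rather than a plain string (recent releases do); the str(...) casts
# below would then store a stringified list, so unwrap the first element of
# the returned value if that matters for your setup.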
def contextual_augment(data_source, data_target, textcol="text", aug_p=0.2, device1="cuda", device2="cuda"):
    ### contextual augmentation
    print(f"\n-----transformer_augment-----\n")
    augmenter1 = naw.ContextualWordEmbsAug(
        model_path='roberta-base', action="substitute", aug_min=1, aug_p=aug_p, device=device1)

    augmenter2 = naw.ContextualWordEmbsAug(
        model_path='bert-base-uncased', action="substitute", aug_min=1, aug_p=aug_p, device=device2)

    train_data = pd.read_csv(data_source)
    train_text = train_data[textcol].fillna('.').astype(str).values
    print("train_text:", len(train_text), type(train_text[0]))

    auglist1, auglist2 = [], []
    for txt in train_text:
        atxt1 = augmenter1.augment(txt)
        atxt2 = augmenter2.augment(txt)
        auglist1.append(str(atxt1))
        auglist2.append(str(atxt2))

    train_data[textcol+"1"] = pd.Series(auglist1)
    train_data[textcol+"2"] = pd.Series(auglist2)
    train_data.to_csv(data_target, index=False)

    for o, a1, a2 in zip(train_text[:5], auglist1[:5], auglist2[:5]):
        print("-----Original Text: \n", o)
        print("-----Augmented Text1: \n", a1)
        print("-----Augmented Text2: \n", a2)


def word_deletion(data_source, data_target, textcol="text", aug_p=0.2):
    ### random word deletion based data augmentation
    print(f"\n-----word_deletion-----\n")
    aug = naw.RandomWordAug(aug_min=1, aug_p=aug_p)

    train_data = pd.read_csv(data_source)
    train_text = train_data[textcol].fillna('.').astype(str).values
    print("train_text:", len(train_text), type(train_text[0]))

    augtxts1, augtxts2 = [], []
    for txt in train_text:
        atxt = aug.augment(txt, n=2, num_thread=1)
        augtxts1.append(str(atxt[0]))
        augtxts2.append(str(atxt[1]))

    train_data[textcol+"1"] = pd.Series(augtxts1)
    train_data[textcol+"2"] = pd.Series(augtxts2)
    train_data.to_csv(data_target, index=False)

    for o, a1, a2 in zip(train_text[:5], augtxts1[:5], augtxts2[:5]):
        print("-----Original Text: \n", o)
        print("-----Augmentation1: \n", a1)
        print("-----Augmentation2: \n", a2)


def randomchar_augment(data_source, data_target, textcol="text", aug_p=0.2, augstage="post"):
    ### random character swap based data augmentation
    print(f"\n*****random char aug: rate--{aug_p}, stage: {augstage}*****\n")
    aug = nac.RandomCharAug(action="swap", aug_char_p=aug_p, aug_word_p=aug_p)

    train_data = pd.read_csv(data_source)
    if augstage == "init":
        # start from the original text and produce two char-swapped views
        train_text = train_data[textcol].fillna('.').astype(str).values
        print("train_text:", len(train_text), type(train_text[0]))

        augtxts1, augtxts2 = [], []
        for txt in train_text:
            atxt = aug.augment(txt, n=2, num_thread=1)
            augtxts1.append(str(atxt[0]))
            augtxts2.append(str(atxt[1]))

        train_data[textcol+"1"] = pd.Series(augtxts1)
        train_data[textcol+"2"] = pd.Series(augtxts2)
        train_data.to_csv(data_target, index=False)

        for o, a1, a2 in zip(train_text[:5], augtxts1[:5], augtxts2[:5]):
            print("-----Original Text: \n", o)
            print("-----Augmentation1: \n", a1)
            print("-----Augmentation2: \n", a2)
    else:
        # "post" stage: apply char swaps on top of the already augmented columns
        train_text1 = train_data[textcol+"1"].fillna('.').astype(str).values
        train_text2 = train_data[textcol+"2"].fillna('.').astype(str).values

        augtxts1, augtxts2 = [], []
        for txt1, txt2 in zip(train_text1, train_text2):
            atxt1 = aug.augment(txt1, n=1, num_thread=1)
            atxt2 = aug.augment(txt2, n=1, num_thread=1)
            augtxts1.append(str(atxt1))
            augtxts2.append(str(atxt2))

        train_data[textcol+"1"] = pd.Series(augtxts1)
        train_data[textcol+"2"] = pd.Series(augtxts2)
        train_data.to_csv(data_target, index=False)

        for o1, a1, o2, a2 in zip(train_text1[:2], augtxts1[:2], train_text2[:2], augtxts2[:2]):
            print("-----Original Text1: \n", o1)
            print("-----Augmentation1: \n", a1)
            print("-----Original Text2: \n", o2)
            print("-----Augmentation2: \n", a2)



def augment_files(datadir="./", targetdir="./", dataset="wiki1m_unique", aug_p=0.1, augtype="trans_subst"):
    set_global_random_seed(0)
    # torch.cuda.set_device() returns None, so assigning its result (as the
    # original code did) passed device=None to the augmenters; use explicit
    # device strings for the two GPUs instead.
    device1 = "cuda:0"
    device2 = "cuda:1"

    DataSource = os.path.join(datadir, dataset + ".csv")
    DataTarget = os.path.join(targetdir, '{}_{}_{}.csv'.format(dataset, augtype, int(aug_p*100)))

    # the augmentation functions write their output CSV themselves and return None
    if augtype == "word_deletion":
        word_deletion(DataSource, DataTarget, textcol="text", aug_p=aug_p)
    elif augtype == "trans_subst":
        contextual_augment(DataSource, DataTarget, textcol="text", aug_p=aug_p, device1=device1, device2=device2)
    elif augtype == "charswap":
        randomchar_augment(DataSource, DataTarget, textcol="text", aug_p=aug_p, augstage="post")
    else:
        print("Please specify AugType!!")


if __name__ == '__main__':

    datadir = "path-to-the-original-datasets"
    targetdir = "path-to-store-the-augmented-datasets"

    # datasets = ["agnews", "searchsnippets", "stackoverflow", "biomedical", "googlenews_TS", "googlenews_T", "googlenews_S"]

    # for dataset in datasets:
    #     augment_files(datadir=datadir, targetdir=targetdir, dataset=dataset, aug_p=0.1, augtype="trans_subst")
    #     augment_files(datadir=datadir, targetdir=targetdir, dataset=dataset, aug_p=0.2, augtype="trans_subst")

    # for dataset in datasets:
    #     augment_files(datadir=datadir, targetdir=targetdir, dataset=dataset, aug_p=0.1, augtype="word_deletion")
    #     augment_files(datadir=datadir, targetdir=targetdir, dataset=dataset, aug_p=0.2, augtype="word_deletion")

    datasets = ["agnews_trans_subst_20", "searchsnippets_trans_subst_20", "stackoverflow_trans_subst_20", "biomedical_trans_subst_20", "googlenews_TS_trans_subst_20", "googlenews_T_trans_subst_20", "googlenews_S_trans_subst_20"]

    for dataset in datasets:
        augment_files(datadir=datadir, targetdir=targetdir, dataset=dataset, aug_p=0.2, augtype="charswap")
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing Guidelines

Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.

Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.


## Reporting Bugs/Feature Requests

We welcome you to use the GitHub issue tracker to report bugs or suggest features.

When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can.
Details like these are incredibly useful:

* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment


## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.

GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).


## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.


## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.


## Security issue notifications
If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.


## Licensing

See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
--------------------------------------------------------------------------------
/Config:
--------------------------------------------------------------------------------
package.SCCLBert = {
    interfaces = (1.0);

    # Use NoOpBuild. See https://w.amazon.com/index.php/BrazilBuildSystem/NoOpBuild
    build-system = no-op;
    build-tools = {
        1.0 = {
            NoOpBuild = 1.0;
        };
    };

    # Use runtime-dependencies for when you want to bring in additional
    # packages when deploying.
    # Use dependencies instead if you intend for these dependencies to
    # be exported to other packages that build against you.
16 | dependencies = { 17 | 1.0 = { 18 | }; 19 | }; 20 | 21 | runtime-dependencies = { 22 | 1.0 = { 23 | }; 24 | }; 25 | 26 | }; 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 | -------------------------------------------------------------------------------- /Notebook/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/sccl/46e700fe170625227dca1ba08ddbbe4665eae0b6/Notebook/.DS_Store -------------------------------------------------------------------------------- /Notebook/.ipynb_checkpoints/SCCL-ExplicitSubmit-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "inclusive-adjustment", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import time\n", 11 | "from sagemaker.pytorch import PyTorch\n", 12 | "from utils import (wait_till_all_done, CLUSTER_Augmented_DATASETS_CTXT_20, CLUSTER_Augmented_DATASETS_CTXT_10, CLUSTER_Augmented_DATASETS_CTXT_CHAR_10,\n", 13 | " CLUSTER_Augmented_DATASETS_CTXT_CHAR_20, CLUSTER_Augmented_DATASETS_WDEL_20, CLUSTER_Augmented_DATASETS_WDEL_10)\n", 14 | "\n", 15 | "role = 'arn:aws:iam::157264205850:role/dejiao-sagemaker-run'" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 3, 21 | "id": "guilty-fifty", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "bert_models = [\"distilbert\"]\n", 26 | "lr_params = [(5e-06, 100), (1e-05, 100)]\n", 27 | "contrast_types = [\"Orig\"]\n", 28 | "temps = [0.5]\n", 29 | "objectives = [\"contrastive\", \"SCCL\"]\n", 30 | "datasets = [\"agnews\", \"searchsnippets\", \"stackoverflow\", \"biomedical\", \"tweet\", \"googleT\", \"googleS\", \"googleTS\"]\n", 31 | "\n", 32 | "use_pretrain=\"SBERT\"\n", 33 | "augtype=\"explicit\"\n", 34 | "batch_size = 400\n", 35 | "maxlen = 32\n", 36 | "maxiter = 3000\n", 37 | "eta = 10\n", 38 | "alpha = 1.0\n", 39 | "base_job_name = \"SCCLv2-distil-exp-strategy-hpo-long\"\n", 40 | "s3_dataroot = \"s3://dejiao-experiment-east1/datasets/psc_shorttext/\"\n", 41 | "s3_resdir = \"s3://dejiao-experiment-east1/train/SCCL-SBERT-EXP-ALL-LONG/\"\n", 42 | "\n", 43 | "# augmentation_stratgies = [\n", 44 | "# CLUSTER_Augmented_DATASETS_CTXT_20, \n", 45 | "# CLUSTER_Augmented_DATASETS_CTXT_CHAR_20,\n", 46 | "# CLUSTER_Augmented_DATASETS_WDEL_20, 
\n", 47 | "# CLUSTER_Augmented_DATASETS_WDEL_10, \n", 48 | "# CLUSTER_Augmented_DATASETS_CTXT_10, \n", 49 | "# ]\n", 50 | "\n", 51 | "augmentation_stratgies = [\n", 52 | " CLUSTER_Augmented_DATASETS_CTXT_CHAR_10,\n", 53 | "]" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "afraid-ranch", 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n", 67 | "submit: 1\n", 68 | "distilbert \t lr: 5e-06\n", 69 | "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n", 70 | "submit: 2\n", 71 | "distilbert \t lr: 5e-06\n", 72 | "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n", 73 | "submit: 3\n", 74 | "distilbert \t lr: 1e-05\n", 75 | "agnews_trans_subst_20 \t 4 \t text \t alpha:1.0 \n", 76 | "submit: 4\n", 77 | "distilbert \t lr: 1e-05\n", 78 | "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n", 79 | "submit: 5\n", 80 | "distilbert \t lr: 5e-06\n", 81 | "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n", 82 | "submit: 6\n", 83 | "distilbert \t lr: 5e-06\n", 84 | "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n", 85 | "submit: 7\n", 86 | "distilbert \t lr: 1e-05\n", 87 | "searchsnippets_trans_subst_20 \t 8 \t text \t alpha:1.0 \n", 88 | "submit: 8\n", 89 | "distilbert \t lr: 1e-05\n", 90 | "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 91 | "submit: 9\n", 92 | "distilbert \t lr: 5e-06\n", 93 | "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 94 | "submit: 10\n", 95 | "distilbert \t lr: 5e-06\n", 96 | "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 97 | "submit: 11\n", 98 | "distilbert \t lr: 1e-05\n", 99 | "stackoverflow_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 100 | "submit: 12\n", 101 | "distilbert \t lr: 1e-05\n", 102 | "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 103 | "submit: 13\n", 104 | "distilbert \t lr: 5e-06\n", 105 | "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 106 | "submit: 14\n", 107 | "distilbert \t lr: 5e-06\n", 108 | "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 109 | "submit: 15\n", 110 | "distilbert \t lr: 1e-05\n", 111 | "biomedical_trans_subst_20 \t 20 \t text \t alpha:10.0 \n", 112 | "submit: 16\n", 113 | "distilbert \t lr: 1e-05\n", 114 | "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n", 115 | "submit: 17\n", 116 | "distilbert \t lr: 5e-06\n", 117 | "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n", 118 | "submit: 18\n", 119 | "distilbert \t lr: 5e-06\n", 120 | "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n", 121 | "submit: 19\n", 122 | "distilbert \t lr: 1e-05\n", 123 | "tweet89_trans_subst_20 \t 89 \t text \t alpha:1.0 \n", 124 | "submit: 20\n", 125 | "distilbert \t lr: 1e-05\n", 126 | "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 127 | "submit: 21\n", 128 | "distilbert \t lr: 5e-06\n", 129 | "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 130 | "submit: 22\n", 131 | "distilbert \t lr: 5e-06\n", 132 | "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 133 | "submit: 23\n", 134 | "distilbert \t lr: 1e-05\n", 135 | "googlenews_T_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 136 | "submit: 24\n", 137 | "distilbert \t lr: 1e-05\n", 138 | "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 139 | "submit: 25\n", 140 | "distilbert \t lr: 5e-06\n", 141 | "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 142 | "submit: 
26\n", 143 | "distilbert \t lr: 5e-06\n", 144 | "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 145 | "submit: 27\n", 146 | "distilbert \t lr: 1e-05\n", 147 | "googlenews_S_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 148 | "submit: 28\n", 149 | "distilbert \t lr: 1e-05\n", 150 | "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 151 | "submit: 29\n", 152 | "distilbert \t lr: 5e-06\n", 153 | "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 154 | "submit: 30\n", 155 | "distilbert \t lr: 5e-06\n", 156 | "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 157 | "submit: 31\n", 158 | "distilbert \t lr: 1e-05\n", 159 | "googlenews_TS_trans_subst_20 \t 152 \t text \t alpha:1.0 \n", 160 | "submit: 32\n", 161 | "distilbert \t lr: 1e-05\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "idx = 1 \n", 167 | "\n", 168 | "for CLUSTER_Augmented_DATASETS in augmentation_stratgies:\n", 169 | " wait_till_all_done(base_job_name) \n", 170 | " for datakey in datasets:\n", 171 | " \n", 172 | " for lr, lr_scale in lr_params:\n", 173 | " for temperature in temps:\n", 174 | " for objective in objectives:\n", 175 | " for ctype in contrast_types:\n", 176 | " for bert in bert_models:\n", 177 | " \n", 178 | " dataname, num_classes, text, label = CLUSTER_Augmented_DATASETS[datakey]\n", 179 | " \n", 180 | " if datakey in [\"stackoverflow\", \"biomedical\"]:\n", 181 | " alpha = 10.0\n", 182 | " else:\n", 183 | " alpha = 1.0\n", 184 | " \n", 185 | " print(f\"{dataname} \\t {num_classes} \\t {text} \\t alpha:{alpha} \")\n", 186 | "\n", 187 | " hyperparameters = {\n", 188 | " 'train_instance': \"sagemaker\",\n", 189 | " 'use_pretrain': use_pretrain,\n", 190 | " 'datapath': s3_dataroot,\n", 191 | " 'dataname': dataname, \n", 192 | " 'text': text,\n", 193 | " 'label': label,\n", 194 | " 'num_classes': num_classes,\n", 195 | " 'bert': bert,\n", 196 | " 'objective': objective,\n", 197 | " 'alpha': alpha,\n", 198 | " 'eta': eta, \n", 199 | " 'augtype': augtype,\n", 200 | " 'contrast_type': ctype,\n", 201 | " 'lr': lr,\n", 202 | " 'lr_scale': lr_scale,\n", 203 | " 'lr_scale_contrast': '100',\n", 204 | " 'batch_size': batch_size,\n", 205 | " 'max_length': maxlen,\n", 206 | " 'temperature': temperature,\n", 207 | " 'max_iter': maxiter,\n", 208 | " 'print_freq': '100',\n", 209 | " 'seed': '0',\n", 210 | " 'gpuid': '0',\n", 211 | " 'resdir': '/tmp/resnli/PaperTempRes/',\n", 212 | " 's3_resdir': s3_resdir,\n", 213 | " }\n", 214 | "\n", 215 | " try:\n", 216 | " estimator = PyTorch(entry_point='main.py',\n", 217 | " source_dir='/home/ec2-user/efs/dejiao-explore/code/SCCL/',\n", 218 | " role=role,\n", 219 | " instance_count=1,\n", 220 | " instance_type='ml.p3.2xlarge',\n", 221 | " image_uri='157264205850.dkr.ecr.us-east-1.amazonaws.com/vncl-transformers-p17',\n", 222 | " base_job_name = base_job_name,\n", 223 | " hyperparameters=hyperparameters,\n", 224 | " output_path='s3://dejiao-sagemaker-east1/SCCL/',\n", 225 | " framework_version='1.8.1',\n", 226 | " py_version = 'py3',\n", 227 | " debugger_hook_config=False,\n", 228 | " max_run=3 * 24 * 60 * 60,\n", 229 | " volume_size = 500,\n", 230 | " )\n", 231 | "\n", 232 | " estimator.fit(wait=False)\n", 233 | " print(\"submit: {}\".format(idx))\n", 234 | " except:\n", 235 | " print(\"submit: {} failed\".format(idx))\n", 236 | "\n", 237 | " time.sleep(2)\n", 238 | " idx += 1\n", 239 | "\n", 240 | " print(bert, \"\\t lr:\", lr)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": 
"accepting-insight", 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "Environment (conda_pytorch_latest_p37)", 255 | "language": "python", 256 | "name": "conda_pytorch_latest_p37" 257 | }, 258 | "language_info": { 259 | "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.7.12" 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 5 273 | } 274 | -------------------------------------------------------------------------------- /Notebook/.ipynb_checkpoints/SCCL-VirtualSubmit-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "inclusive-adjustment", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import time\n", 11 | "from sagemaker.pytorch import PyTorch\n", 12 | "from utils import wait_till_all_done, CLUSTER_DATASETS\n", 13 | "\n", 14 | "role = 'arn:aws:iam::157264205850:role/dejiao-sagemaker-run'" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "guilty-fifty", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "bert_models = [\"distilroberta\", \"distilbert\"]\n", 25 | "lr_params = [(1e-05, 100), (3e-05, 100)]\n", 26 | "contrast_types = [\"HardNeg\", \"Orig\"]\n", 27 | "temps = [0.5, 0.1]\n", 28 | "objectives = [\"contrastive\", \"SCCL\"]\n", 29 | "datasets = [\"agnews\", \"searchsnippets\", \"stackoverflow\", \"biomedical\", \"tweet\", \"googleT\", \"googleS\", \"googleTS\"]\n", 30 | "\n", 31 | "use_pretrain=\"SBERT\"\n", 32 | "augtype=\"virtual\"\n", 33 | "batch_size = 512\n", 34 | "maxlen = 32\n", 35 | "maxiter = 1000\n", 36 | "eta = 10\n", 37 | "base_job_name = \"SCCLv2-distil-hpo\"\n", 38 | "s3_dataroot = \"s3://dejiao-experiment-east1/datasets/psc_shorttext/\"\n", 39 | "s3_resdir = \"s3://dejiao-experiment-east1/train/SCCL-SBERT/\"" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "afraid-ranch", 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "submit: 1\n", 53 | "distilroberta \t lr: 1e-05\n", 54 | "submit: 2\n", 55 | "distilbert \t lr: 1e-05\n", 56 | "submit: 3\n", 57 | "distilroberta \t lr: 1e-05\n", 58 | "submit: 4\n", 59 | "distilbert \t lr: 1e-05\n", 60 | "submit: 5\n", 61 | "distilroberta \t lr: 1e-05\n", 62 | "submit: 6\n", 63 | "distilbert \t lr: 1e-05\n", 64 | "submit: 7\n", 65 | "distilroberta \t lr: 1e-05\n", 66 | "submit: 8\n", 67 | "distilbert \t lr: 1e-05\n", 68 | "submit: 9\n", 69 | "distilroberta \t lr: 1e-05\n", 70 | "submit: 10\n", 71 | "distilbert \t lr: 1e-05\n", 72 | "submit: 11\n", 73 | "distilroberta \t lr: 1e-05\n", 74 | "submit: 12\n", 75 | "distilbert \t lr: 1e-05\n", 76 | "submit: 13\n", 77 | "distilroberta \t lr: 1e-05\n", 78 | "submit: 14\n", 79 | "distilbert \t lr: 1e-05\n", 80 | "submit: 15\n", 81 | "distilroberta \t lr: 1e-05\n", 82 | "submit: 16\n", 83 | "distilbert \t lr: 1e-05\n", 84 | "submit: 17\n", 85 | "distilroberta \t lr: 1e-05\n", 86 | "submit: 18\n", 87 | "distilbert \t lr: 1e-05\n", 88 | "submit: 19\n", 89 | "distilroberta \t lr: 1e-05\n", 90 | "submit: 20\n", 91 | "distilbert \t lr: 1e-05\n", 92 | "submit: 21\n", 93 | "distilroberta 
\t lr: 1e-05\n", 94 | "submit: 22\n", 95 | "distilbert \t lr: 1e-05\n", 96 | "submit: 23\n", 97 | "distilroberta \t lr: 1e-05\n", 98 | "submit: 24\n", 99 | "distilbert \t lr: 1e-05\n", 100 | "submit: 25\n", 101 | "distilroberta \t lr: 1e-05\n", 102 | "submit: 26\n", 103 | "distilbert \t lr: 1e-05\n", 104 | "submit: 27\n", 105 | "distilroberta \t lr: 1e-05\n", 106 | "submit: 28\n", 107 | "distilbert \t lr: 1e-05\n", 108 | "submit: 29\n", 109 | "distilroberta \t lr: 1e-05\n", 110 | "submit: 30\n", 111 | "distilbert \t lr: 1e-05\n", 112 | "submit: 31\n", 113 | "distilroberta \t lr: 1e-05\n", 114 | "submit: 32\n", 115 | "distilbert \t lr: 1e-05\n", 116 | "submit: 33\n", 117 | "distilroberta \t lr: 1e-05\n", 118 | "submit: 34\n", 119 | "distilbert \t lr: 1e-05\n", 120 | "submit: 35\n", 121 | "distilroberta \t lr: 1e-05\n", 122 | "submit: 36\n", 123 | "distilbert \t lr: 1e-05\n", 124 | "submit: 37\n", 125 | "distilroberta \t lr: 1e-05\n", 126 | "submit: 38\n", 127 | "distilbert \t lr: 1e-05\n", 128 | "submit: 39\n", 129 | "distilroberta \t lr: 1e-05\n", 130 | "submit: 40\n", 131 | "distilbert \t lr: 1e-05\n", 132 | "submit: 41\n", 133 | "distilroberta \t lr: 1e-05\n", 134 | "submit: 42\n", 135 | "distilbert \t lr: 1e-05\n", 136 | "submit: 43\n", 137 | "distilroberta \t lr: 1e-05\n", 138 | "submit: 44\n", 139 | "distilbert \t lr: 1e-05\n", 140 | "submit: 45\n", 141 | "distilroberta \t lr: 1e-05\n", 142 | "submit: 46\n", 143 | "distilbert \t lr: 1e-05\n", 144 | "submit: 47\n", 145 | "distilroberta \t lr: 1e-05\n", 146 | "submit: 48\n", 147 | "distilbert \t lr: 1e-05\n", 148 | "submit: 49\n", 149 | "distilroberta \t lr: 1e-05\n", 150 | "submit: 50\n", 151 | "distilbert \t lr: 1e-05\n", 152 | "submit: 51\n", 153 | "distilroberta \t lr: 1e-05\n", 154 | "submit: 52\n", 155 | "distilbert \t lr: 1e-05\n", 156 | "submit: 53\n", 157 | "distilroberta \t lr: 1e-05\n", 158 | "submit: 54\n", 159 | "distilbert \t lr: 1e-05\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "idx = 1 \n", 165 | "for lr, lr_scale in lr_params:\n", 166 | " for temperature in temps:\n", 167 | " wait_till_all_done(base_job_name)\n", 168 | " \n", 169 | " for datakey in datasets: \n", 170 | " for objective in objectives:\n", 171 | " for ctype in contrast_types:\n", 172 | " for bert in bert_models:\n", 173 | "\n", 174 | " dataname, num_classes, text, label = CLUSTER_DATASETS[datakey]\n", 175 | "\n", 176 | " hyperparameters = {\n", 177 | " 'train_instance': \"sagemaker\",\n", 178 | " 'use_pretrain': use_pretrain,\n", 179 | " 'datapath': s3_dataroot,\n", 180 | " 'dataname': dataname, \n", 181 | " 'text': text,\n", 182 | " 'label': label,\n", 183 | " 'num_classes': num_classes,\n", 184 | " 'bert': bert,\n", 185 | " 'objective': objective,\n", 186 | " 'eta': eta, \n", 187 | " 'augtype': 'virtual',\n", 188 | " 'contrast_type': ctype,\n", 189 | " 'lr': lr,\n", 190 | " 'lr_scale': lr_scale,\n", 191 | " 'lr_scale_contrast': '100',\n", 192 | " 'batch_size': batch_size,\n", 193 | " 'max_length': maxlen,\n", 194 | " 'temperature': temperature,\n", 195 | " 'max_iter': maxiter,\n", 196 | " 'print_freq': '50',\n", 197 | " 'seed': '0',\n", 198 | " 'gpuid': '0',\n", 199 | " 'resdir': '/tmp/resnli/PaperTempRes/',\n", 200 | " 's3_resdir': s3_resdir,\n", 201 | " }\n", 202 | "\n", 203 | " try:\n", 204 | " estimator = PyTorch(entry_point='main.py',\n", 205 | " source_dir='/home/ec2-user/efs/dejiao-explore/code/SCCL/',\n", 206 | " role=role,\n", 207 | " instance_count=1,\n", 208 | " instance_type='ml.p3.2xlarge',\n", 209 | " 
image_uri='157264205850.dkr.ecr.us-east-1.amazonaws.com/vncl-transformers-p17',\n", 210 | " base_job_name = base_job_name,\n", 211 | " hyperparameters=hyperparameters,\n", 212 | " output_path='s3://dejiao-sagemaker-east1/SCCL/',\n", 213 | " framework_version='1.8.1',\n", 214 | " py_version = 'py3',\n", 215 | " debugger_hook_config=False,\n", 216 | " max_run=3 * 24 * 60 * 60,\n", 217 | " volume_size = 500,\n", 218 | " )\n", 219 | "\n", 220 | " estimator.fit(wait=False)\n", 221 | " print(\"submit: {}\".format(idx))\n", 222 | " except:\n", 223 | " print(\"submit: {} failed\".format(idx))\n", 224 | "\n", 225 | " time.sleep(2)\n", 226 | " idx += 1\n", 227 | "\n", 228 | " print(bert, \"\\t lr:\", lr)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "accepting-insight", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [] 238 | } 239 | ], 240 | "metadata": { 241 | "kernelspec": { 242 | "display_name": "Environment (conda_pytorch_latest_p37)", 243 | "language": "python", 244 | "name": "conda_pytorch_latest_p37" 245 | }, 246 | "language_info": { 247 | "codemirror_mode": { 248 | "name": "ipython", 249 | "version": 3 250 | }, 251 | "file_extension": ".py", 252 | "mimetype": "text/x-python", 253 | "name": "python", 254 | "nbconvert_exporter": "python", 255 | "pygments_lexer": "ipython3", 256 | "version": "3.7.12" 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 5 261 | } 262 | -------------------------------------------------------------------------------- /Notebook/.ipynb_checkpoints/run_explicit-checkpoint.sh: -------------------------------------------------------------------------------- 1 | resdir="/home/ec2-user/efs/dejiao-explore/experiments/SCCL-ec2/" 2 | datapath="s3://dejiao-experiment-east1/datasets/psc_shorttext/" 3 | 4 | bsize=400 5 | maxiter=1000 6 | 7 | 8 | python3 main.py \ 9 | --resdir $resdir/ \ 10 | --use_pretrain SBERT \ 11 | --bert distilbert \ 12 | --datapath $datapath \ 13 | --dataname searchsnippets_trans_subst_20 \ 14 | --num_classes 8 \ 15 | --text text \ 16 | --label label \ 17 | --objective SCCL \ 18 | --augtype explicit \ 19 | --temperature 0.5 \ 20 | --eta 10 \ 21 | --lr 1e-05 \ 22 | --lr_scale 100 \ 23 | --max_length 32 \ 24 | --batch_size $bsize \ 25 | --max_iter $maxiter \ 26 | --print_freq 100 \ 27 | --gpuid 7 & 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /Notebook/.ipynb_checkpoints/run_virtual-checkpoint.sh: -------------------------------------------------------------------------------- 1 | resdir="/home/ec2-user/efs/dejiao-explore/experiments/SCCL-ec2/" 2 | datapath="s3://dejiao-experiment-east1/datasets/psc_shorttext/" 3 | 4 | maxiter=100 5 | bsize=400 6 | 7 | 8 | 9 | python3 main.py \ 10 | --resdir $resdir/ \ 11 | --use_pretrain SBERT \ 12 | --bert distilbert \ 13 | --datapath $datapath \ 14 | --dataname searchsnippets \ 15 | --num_classes 8 \ 16 | --text text \ 17 | --label label \ 18 | --objective SCCL \ 19 | --augtype virtual \ 20 | --temperature 0.5 \ 21 | --eta 10 \ 22 | --lr 1e-05 \ 23 | --lr_scale 100 \ 24 | --max_length 32 \ 25 | --batch_size 400 \ 26 | --max_iter $maxiter \ 27 | --print_freq 100 \ 28 | --gpuid 1 & 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /Notebook/.ipynb_checkpoints/utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | import time 2 | import boto3 3 | client = 
boto3.client('sagemaker') 4 | 5 | CLUSTER_DATASETS = { 6 | "appen_human":('appen_human_asr', 30, 'text', 'label'), 7 | "appen_asr":('appen_human_asr', 30, 'transcribed', 'label'), 8 | "agnews":('agnews', 4, 'text', 'label'), 9 | "searchsnippets":('searchsnippets', 8, 'text', 'label'), 10 | "stackoverflow":('stackoverflow', 20, 'text', 'label'), 11 | "biomedical":('biomedical', 20, 'text', 'label'), 12 | "tweet":('tweet89', 89, 'text', 'label'), 13 | "googleT":('googlenews_T', 152, 'text', 'label'), 14 | "googleS":('googlenews_S', 152, 'text', 'label'), 15 | "googleTS":('googlenews_TS', 152, 'text', 'label'), 16 | "googlenews_TS_ctxt_substbertroberta_char_02":('googlenews_TS_ctxt_substbertroberta_char_02', 152, 'text', 'label'), 17 | "searchsnippets_ctxt_substbertroberta_char_02":('searchsnippets_ctxt_substbertroberta_char_02', 8, 'text', 'label'), 18 | "stackoverflow_ctxt_substbertroberta_01":('stackoverflow_ctxt_substbertroberta_01', 20, 'text', 'label'), 19 | } 20 | 21 | 22 | CLUSTER_Augmented_DATASETS_CTXT_20 = { 23 | "agnews":('agnews_trans_subst_20', 4, 'text', 'label'), 24 | "searchsnippets":('searchsnippets_trans_subst_20', 8, 'text', 'label'), 25 | "stackoverflow":('stackoverflow_trans_subst_20', 20, 'text', 'label'), 26 | "biomedical":('biomedical_trans_subst_20', 20, 'text', 'label'), 27 | "tweet":('tweet89_trans_subst_20', 89, 'text', 'label'), 28 | "googleT":('googlenews_T_trans_subst_20', 152, 'text', 'label'), 29 | "googleS":('googlenews_S_trans_subst_20', 152, 'text', 'label'), 30 | "googleTS":('googlenews_TS_trans_subst_20', 152, 'text', 'label'), 31 | } 32 | 33 | CLUSTER_Augmented_DATASETS_CTXT_CHAR_20 = { 34 | "agnews":('agnews_trans_subst_20_charswap_20', 4, 'text', 'label'), 35 | "searchsnippets":('searchsnippets_trans_subst_20_charswap_20', 8, 'text', 'label'), 36 | "stackoverflow":('stackoverflow_trans_subst_20_charswap_20', 20, 'text', 'label'), 37 | "biomedical":('biomedical_trans_subst_20_charswap_20', 20, 'text', 'label'), 38 | "tweet":('tweet89_trans_subst_20_charswap_20', 89, 'text', 'label'), 39 | "googleT":('googlenews_T_trans_subst_20_charswap_20', 152, 'text', 'label'), 40 | "googleS":('googlenews_S_trans_subst_20_charswap_20', 152, 'text', 'label'), 41 | "googleTS":('googlenews_TS_trans_subst_20_charswap_20', 152, 'text', 'label'), 42 | } 43 | 44 | 45 | CLUSTER_Augmented_DATASETS_CTXT_CHAR_10 = { 46 | "agnews":('agnews_trans_subst_10_charswap_20', 4, 'text', 'label'), 47 | "searchsnippets":('searchsnippets_trans_subst_10_charswap_20', 8, 'text', 'label'), 48 | "stackoverflow":('stackoverflow_trans_subst_10_charswap_20', 20, 'text', 'label'), 49 | "biomedical":('biomedical_trans_subst_10_charswap_20', 20, 'text', 'label'), 50 | "tweet":('tweet89_trans_subst_10_charswap_20', 89, 'text', 'label'), 51 | "googleT":('googlenews_T_trans_subst_10_charswap_20', 152, 'text', 'label'), 52 | "googleS":('googlenews_S_trans_subst_10_charswap_20', 152, 'text', 'label'), 53 | "googleTS":('googlenews_TS_trans_subst_10_charswap_20', 152, 'text', 'label'), 54 | } 55 | 56 | 57 | CLUSTER_Augmented_DATASETS_CTXT_10 = { 58 | "agnews":('agnews_trans_subst_10', 4, 'text', 'label'), 59 | "searchsnippets":('searchsnippets_trans_subst_10', 8, 'text', 'label'), 60 | "stackoverflow":('stackoverflow_trans_subst_10', 20, 'text', 'label'), 61 | "biomedical":('biomedical_trans_subst_10', 20, 'text', 'label'), 62 | "tweet":('tweet89_trans_subst_10', 89, 'text', 'label'), 63 | "googleT":('googlenews_T_trans_subst_10', 152, 'text', 'label'), 64 | "googleS":('googlenews_S_trans_subst_10', 152, 
'text', 'label'),
    "googleTS":('googlenews_TS_trans_subst_10', 152, 'text', 'label'),
}


CLUSTER_Augmented_DATASETS_WDEL_20 = {
    "agnews":('agnews_word_deletion_20', 4, 'text', 'label'),
    "searchsnippets":('searchsnippets_word_deletion_20', 8, 'text', 'label'),
    "stackoverflow":('stackoverflow_word_deletion_20', 20, 'text', 'label'),
    "biomedical":('biomedical_word_deletion_20', 20, 'text', 'label'),
    "tweet":('tweet89_word_deletion_20', 89, 'text', 'label'),
    "googleT":('googlenews_T_word_deletion_20', 152, 'text', 'label'),
    "googleS":('googlenews_S_word_deletion_20', 152, 'text', 'label'),
    "googleTS":('googlenews_TS_word_deletion_20', 152, 'text', 'label'),
}


CLUSTER_Augmented_DATASETS_WDEL_10 = {
    "agnews":('agnews_word_deletion_10', 4, 'text', 'label'),
    "searchsnippets":('searchsnippets_word_deletion_10', 8, 'text', 'label'),
    "stackoverflow":('stackoverflow_word_deletion_10', 20, 'text', 'label'),
    "biomedical":('biomedical_word_deletion_10', 20, 'text', 'label'),
    "tweet":('tweet89_word_deletion_10', 89, 'text', 'label'),
    "googleT":('googlenews_T_word_deletion_10', 152, 'text', 'label'),
    "googleS":('googlenews_S_word_deletion_10', 152, 'text', 'label'),
    "googleTS":('googlenews_TS_word_deletion_10', 152, 'text', 'label'),
}



def wait_till_all_done(base_job_name):
    while True:
        response = client.list_training_jobs(MaxResults=100)
        all_status = [
            item['TrainingJobStatus'] for item in response['TrainingJobSummaries'] if item["TrainingJobName"].startswith(base_job_name)
        ]
        go = all([item != 'InProgress' for item in all_status])
        if go:
            break
        else:
            time.sleep(600)  # wait 10 mins
--------------------------------------------------------------------------------
/Notebook/run_explicit.sh:
--------------------------------------------------------------------------------
resdir="path-to-store-your-results"
datapath="path-to-your-data"

bsize=400
maxiter=1000


python3 main.py \
        --resdir $resdir/ \
        --use_pretrain SBERT \
        --bert distilbert \
        --datapath $datapath \
        --dataname searchsnippets_trans_subst_20 \
        --num_classes 8 \
        --text text \
        --label label \
        --objective SCCL \
        --augtype explicit \
        --temperature 0.5 \
        --eta 10 \
        --lr 1e-05 \
        --lr_scale 100 \
        --max_length 32 \
        --batch_size $bsize \
        --max_iter $maxiter \
        --print_freq 100 \
        --gpuid 7 &
--------------------------------------------------------------------------------
/Notebook/run_virtual.sh:
--------------------------------------------------------------------------------
resdir="path-to-store-your-results"
datapath="path-to-your-data"

maxiter=100
bsize=400


python3 main.py \
        --resdir $resdir/ \
        --use_pretrain SBERT \
        --bert distilbert \
        --datapath $datapath \
        --dataname searchsnippets \
        --num_classes 8 \
        --text text \
        --label label \
        --objective SCCL \
        --augtype virtual \
        --temperature 0.5 \
        --eta 10 \
        --lr 1e-05 \
        --lr_scale 100 \
        --max_length 32 \
        --batch_size 400 \
        --max_iter $maxiter \
        --print_freq 100 \
        --gpuid 1 &
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SCCL: Supporting Clustering with Contrastive Learning

This repository contains the code for our paper [Supporting Clustering with Contrastive Learning (NAACL 2021)](https://aclanthology.org/2021.naacl-main.427.pdf) by Dejiao Zhang, Feng Nan, Xiaokai Wei, Shangwen Li, Henghui Zhu, Kathleen McKeown, Ramesh Nallapati, Andrew Arnold, and Bing Xiang.

**************************** **Updates** ****************************
* 12/11/2021: We updated our code. Now you can run SCCL with virtual augmentations only.
* 05/28/2021: We released our initial code for SCCL, which requires explicit data augmentations.


## Getting Started

### Dependencies:
python==3.6.13
pytorch==1.6.0
sentence-transformers==2.0.0
transformers==4.8.1
tensorboardX==2.4.1
pandas==1.1.5
sklearn==0.24.1
numpy==1.19.5


### SCCL with explicit augmentations

In addition to the original data, SCCL requires a pair of augmented texts for each instance.

The data format is (text, text1, text2), where text1 and text2 are the column names of the augmented pairs.
See our NAACL paper for details about the learning objective.
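For illustration only, a CSV in this format might look like the sketch below. The rows are invented, and the column names follow the defaults of the `--text`, `--augmentation_1`, `--augmentation_2`, and `--label` arguments in main.py:

```
text,text1,text2,label
"how to sort a list in python","how to sort an array in python","ways of sorting a list in python",3
"best pizza in new york","best pizza in nyc","greatest pizza in new york",1
```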
Step 1: download the original datasets from https://github.com/rashadulrakib/short-text-clustering-enhancement/tree/master/data

Step 2: obtain the augmented data using the code in ./AugData/

Step 3: run the code as follows:

```python
python3 main.py \
        --resdir $path-to-store-your-results \
        --use_pretrain SBERT \
        --bert distilbert \
        --datapath $path-to-your-data \
        --dataname searchsnippets_trans_subst_20 \
        --num_classes 8 \
        --text text \
        --label label \
        --objective SCCL \
        --augtype explicit \
        --temperature 0.5 \
        --eta 10 \
        --lr 1e-05 \
        --lr_scale 100 \
        --max_length 32 \
        --batch_size 400 \
        --max_iter 3000 \
        --print_freq 100 \
        --gpuid 0 &

```


### SCCL with virtual augmentation

Download the original datasets from
https://github.com/rashadulrakib/short-text-clustering-enhancement/tree/master/data

```python
python3 main.py \
        --resdir $path-to-store-your-results \
        --use_pretrain SBERT \
        --bert distilbert \
        --datapath $path-to-your-data \
        --dataname searchsnippets \
        --num_classes 8 \
        --text text \
        --label label \
        --objective SCCL \
        --augtype virtual \
        --temperature 0.5 \
        --eta 10 \
        --lr 1e-05 \
        --lr_scale 100 \
        --max_length 32 \
        --batch_size 400 \
        --max_iter 1000 \
        --print_freq 100 \
        --gpuid 1 &

```



## Citation:

```bibtex
@inproceedings{zhang-etal-2021-supporting,
    title = "Supporting Clustering with Contrastive Learning",
    author = "Zhang, Dejiao and
      Nan, Feng and
      Wei, Xiaokai and
      Li, Shang-Wen and
      Zhu, Henghui and
      McKeown, Kathleen and
      Nallapati, Ramesh and
      Arnold, Andrew O.
and 105 | Xiang, Bing", 106 | booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", 107 | month = jun, 108 | year = "2021", 109 | address = "Online", 110 | publisher = "Association for Computational Linguistics", 111 | url = "https://aclanthology.org/2021.naacl-main.427", 112 | doi = "10.18653/v1/2021.naacl-main.427", 113 | pages = "5419--5430", 114 | abstract = "Unsupervised clustering aims at discovering the semantic categories of data according to some distance measured in the representation space. However, different categories often overlap with each other in the representation space at the beginning of the learning process, which poses a significant challenge for distance-based clustering in achieving good separation between different categories. To this end, we propose Supporting Clustering with Contrastive Learning (SCCL) {--} a novel framework to leverage contrastive learning to promote better separation. We assess the performance of SCCL on short text clustering and show that SCCL significantly advances the state-of-the-art results on most benchmark datasets with 3{\%}-11{\%} improvement on Accuracy and 4{\%}-15{\%} improvement on Normalized Mutual Information. Furthermore, our quantitative analysis demonstrates the effectiveness of SCCL in leveraging the strengths of both bottom-up instance discrimination and top-down clustering to achieve better intra-cluster and inter-cluster distances when evaluated with the ground truth cluster labels.",} 115 | 116 | ``` 117 | 118 | -------------------------------------------------------------------------------- /dataloader/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/sccl/46e700fe170625227dca1ba08ddbbe4665eae0b6/dataloader/.DS_Store -------------------------------------------------------------------------------- /dataloader/.ipynb_checkpoints/dataloader-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved 3 | 4 | Author: Dejiao Zhang (dejiaoz@amazon.com) 5 | Date: 02/26/2021 6 | """ 7 | 8 | import os 9 | import pandas as pd 10 | import torch.utils.data as util_data 11 | from torch.utils.data import Dataset 12 | 13 | class VirtualAugSamples(Dataset): 14 | def __init__(self, train_x, train_y): 15 | assert len(train_x) == len(train_y) 16 | self.train_x = train_x 17 | self.train_y = train_y 18 | 19 | def __len__(self): 20 | return len(self.train_x) 21 | 22 | def __getitem__(self, idx): 23 | return {'text': self.train_x[idx], 'label': self.train_y[idx]} 24 | 25 | 26 | class ExplitAugSamples(Dataset): 27 | def __init__(self, train_x, train_x1, train_x2, train_y): 28 | assert len(train_y) == len(train_x) == len(train_x1) == len(train_x2) 29 | self.train_x = train_x 30 | self.train_x1 = train_x1 31 | self.train_x2 = train_x2 32 | self.train_y = train_y 33 | 34 | def __len__(self): 35 | return len(self.train_y) 36 | 37 | def __getitem__(self, idx): 38 | return {'text': self.train_x[idx], 'augmentation_1': self.train_x1[idx], 'augmentation_2': self.train_x2[idx], 'label': self.train_y[idx]} 39 | 40 | 41 | def explict_augmentation_loader(args): 42 | train_data = pd.read_csv(os.path.join(args.datapath, args.dataname+".csv")) 43 | train_text = train_data[args.text].fillna('.').values 44 | train_text1 = train_data[args.augmentation_1].fillna('.').values 45 | train_text2 = train_data[args.augmentation_2].fillna('.').values 46 | train_label = train_data[args.label].astype(int).values 47 | 48 | train_dataset = ExplitAugSamples(train_text, train_text1, train_text2, train_label) 49 | train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) 50 | return train_loader 51 | 52 | 53 | def virtual_augmentation_loader(args): 54 | train_data = pd.read_csv(os.path.join(args.datapath, args.dataname+".csv")) 55 | train_text = train_data[args.text].fillna('.').values 56 | train_label = train_data[args.label].astype(int).values 57 | 58 | train_dataset = VirtualAugSamples(train_text, train_label) 59 | train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) 60 | return train_loader 61 | 62 | 63 | def unshuffle_loader(args): 64 | train_data = pd.read_csv(os.path.join(args.datapath, args.dataname+".csv")) 65 | train_text = train_data[args.text].fillna('.').values 66 | train_label = train_data[args.label].astype(int).values 67 | 68 | train_dataset = VirtualAugSamples(train_text, train_label) 69 | train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=1) 70 | return train_loader 71 | 72 | -------------------------------------------------------------------------------- /dataloader/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/sccl/46e700fe170625227dca1ba08ddbbe4665eae0b6/dataloader/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved 3 | 4 | Author: Dejiao Zhang (dejiaoz@amazon.com) 5 | Date: 02/26/2021 6 | """ 7 | 8 | import os 9 | import pandas as pd 10 | import torch.utils.data as util_data 11 | from torch.utils.data import Dataset 12 | 13 | class VirtualAugSamples(Dataset): 14 | def __init__(self, train_x, train_y): 15 | assert len(train_x) == len(train_y) 16 | self.train_x = train_x 17 | self.train_y = train_y 18 | 19 | def __len__(self): 20 | return len(self.train_x) 21 | 22 | def __getitem__(self, idx): 23 | return {'text': self.train_x[idx], 'label': self.train_y[idx]} 24 | 25 | 26 | class ExplitAugSamples(Dataset): 27 | def __init__(self, train_x, train_x1, train_x2, train_y): 28 | assert len(train_y) == len(train_x) == len(train_x1) == len(train_x2) 29 | self.train_x = train_x 30 | self.train_x1 = train_x1 31 | self.train_x2 = train_x2 32 | self.train_y = train_y 33 | 34 | def __len__(self): 35 | return len(self.train_y) 36 | 37 | def __getitem__(self, idx): 38 | return {'text': self.train_x[idx], 'augmentation_1': self.train_x1[idx], 'augmentation_2': self.train_x2[idx], 'label': self.train_y[idx]} 39 | 40 | 41 | def explict_augmentation_loader(args): 42 | train_data = pd.read_csv(os.path.join(args.datapath, args.dataname+".csv")) 43 | train_text = train_data[args.text].fillna('.').values 44 | train_text1 = train_data[args.augmentation_1].fillna('.').values 45 | train_text2 = train_data[args.augmentation_2].fillna('.').values 46 | train_label = train_data[args.label].astype(int).values 47 | 48 | train_dataset = ExplitAugSamples(train_text, train_text1, train_text2, train_label) 49 | train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) 50 | return train_loader 51 | 52 | 53 | def virtual_augmentation_loader(args): 54 | train_data = pd.read_csv(os.path.join(args.datapath, args.dataname+".csv")) 55 | train_text = train_data[args.text].fillna('.').values 56 | train_label = train_data[args.label].astype(int).values 57 | 58 | train_dataset = VirtualAugSamples(train_text, train_label) 59 | train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) 60 | return train_loader 61 | 62 | 63 | def unshuffle_loader(args): 64 | train_data = pd.read_csv(os.path.join(args.datapath, args.dataname+".csv")) 65 | train_text = train_data[args.text].fillna('.').values 66 | train_label = train_data[args.label].astype(int).values 67 | 68 | train_dataset = VirtualAugSamples(train_text, train_label) 69 | train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=1) 70 | return train_loader 71 | 72 | -------------------------------------------------------------------------------- /learner/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/sccl/46e700fe170625227dca1ba08ddbbe4665eae0b6/learner/.DS_Store -------------------------------------------------------------------------------- /learner/.ipynb_checkpoints/cluster_utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved 3 | 4 | Author: Dejiao Zhang (dejiaoz@amazon.com) 5 | Date: 02/26/2021 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | eps = 1e-8 12 | 13 | class KLDiv(nn.Module): 14 | def forward(self, predict, target): 15 | assert predict.ndimension()==2,'Input dimension must be 2' 16 | target = target.detach() 17 | p1 = predict + eps 18 | t1 = target + eps 19 | logI = p1.log() 20 | logT = t1.log() 21 | TlogTdI = target * (logT - logI) 22 | kld = TlogTdI.sum(1) 23 | return kld 24 | 25 | class KCL(nn.Module): 26 | def __init__(self): 27 | super(KCL,self).__init__() 28 | self.kld = KLDiv() 29 | 30 | def forward(self, prob1, prob2): 31 | kld = self.kld(prob1, prob2) 32 | return kld.mean() 33 | 34 | def target_distribution(batch: torch.Tensor) -> torch.Tensor: 35 | weight = (batch ** 2) / (torch.sum(batch, 0) + 1e-9) 36 | return (weight.t() / torch.sum(weight, 1)).t() -------------------------------------------------------------------------------- /learner/.ipynb_checkpoints/contrastive_utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved 3 | 4 | Author: Dejiao Zhang (dejiaoz@amazon.com) 5 | Date: 12/12/2021 6 | """ 7 | 8 | 9 | from __future__ import print_function 10 | import torch 11 | import torch.nn as nn 12 | import numpy as np 13 | 14 | 15 | class PairConLoss(nn.Module): 16 | def __init__(self, temperature=0.05): 17 | super(PairConLoss, self).__init__() 18 | self.temperature = temperature 19 | self.eps = 1e-08 20 | print(f"\n Initializing PairConLoss \n") 21 | 22 | def forward(self, features_1, features_2): 23 | device = features_1.device 24 | batch_size = features_1.shape[0] 25 | features= torch.cat([features_1, features_2], dim=0) 26 | mask = torch.eye(batch_size, dtype=torch.bool).to(device) 27 | mask = mask.repeat(2, 2) 28 | mask = ~mask 29 | 30 | pos = torch.exp(torch.sum(features_1*features_2, dim=-1) / self.temperature) 31 | pos = torch.cat([pos, pos], dim=0) 32 | neg = torch.exp(torch.mm(features, features.t().contiguous()) / self.temperature) 33 | neg = neg.masked_select(mask).view(2*batch_size, -1) 34 | 35 | neg_mean = torch.mean(neg) 36 | pos_n = torch.mean(pos) 37 | Ng = neg.sum(dim=-1) 38 | 39 | loss_pos = (- torch.log(pos / (Ng+pos))).mean() 40 | 41 | return {"loss":loss_pos, "pos_mean":pos_n.detach().cpu().numpy(), "neg_mean":neg_mean.detach().cpu().numpy(), "pos":pos.detach().cpu().numpy(), "neg":neg.detach().cpu().numpy()} 42 | 43 | 44 | -------------------------------------------------------------------------------- /learner/__pycache__/cluster_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/sccl/46e700fe170625227dca1ba08ddbbe4665eae0b6/learner/__pycache__/cluster_utils.cpython-36.pyc -------------------------------------------------------------------------------- /learner/__pycache__/contrastive_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/sccl/46e700fe170625227dca1ba08ddbbe4665eae0b6/learner/__pycache__/contrastive_utils.cpython-36.pyc -------------------------------------------------------------------------------- /learner/cluster_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Amazon.com, Inc. or its affiliates. 
--------------------------------------------------------------------------------
/learner/cluster_utils.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

import torch
import torch.nn as nn

eps = 1e-8

class KLDiv(nn.Module):
    def forward(self, predict, target):
        assert predict.ndimension() == 2, 'Input dimension must be 2'
        target = target.detach()
        p1 = predict + eps
        t1 = target + eps
        logI = p1.log()
        logT = t1.log()
        TlogTdI = target * (logT - logI)
        kld = TlogTdI.sum(1)
        return kld

class KCL(nn.Module):
    def __init__(self):
        super(KCL, self).__init__()
        self.kld = KLDiv()

    def forward(self, prob1, prob2):
        kld = self.kld(prob1, prob2)
        return kld.mean()

def target_distribution(batch: torch.Tensor) -> torch.Tensor:
    # DEC-style auxiliary target: square the soft assignments, normalize by
    # per-cluster frequency, then renormalize each row.
    weight = (batch ** 2) / (torch.sum(batch, 0) + 1e-9)
    return (weight.t() / torch.sum(weight, 1)).t()

--------------------------------------------------------------------------------
/learner/contrastive_utils.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 12/12/2021
"""


from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np


class PairConLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super(PairConLoss, self).__init__()
        self.temperature = temperature
        self.eps = 1e-08
        print("\n Initializing PairConLoss \n")

    def forward(self, features_1, features_2):
        device = features_1.device
        batch_size = features_1.shape[0]
        features = torch.cat([features_1, features_2], dim=0)
        mask = torch.eye(batch_size, dtype=torch.bool).to(device)
        mask = mask.repeat(2, 2)
        mask = ~mask

        pos = torch.exp(torch.sum(features_1*features_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        neg = torch.exp(torch.mm(features, features.t().contiguous()) / self.temperature)
        neg = neg.masked_select(mask).view(2*batch_size, -1)

        neg_mean = torch.mean(neg)
        pos_n = torch.mean(pos)
        Ng = neg.sum(dim=-1)

        loss_pos = (- torch.log(pos / (Ng+pos))).mean()

        return {"loss": loss_pos, "pos_mean": pos_n.detach().cpu().numpy(), "neg_mean": neg_mean.detach().cpu().numpy(),
                "pos": pos.detach().cpu().numpy(), "neg": neg.detach().cpu().numpy()}

--------------------------------------------------------------------------------
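
A quick sanity check of PairConLoss (illustrative only, not part of the repo): with L2-normalized features, identical pairs should score a far lower loss than unrelated random pairs, since the positive term dominates the denominator.

import torch
import torch.nn.functional as F
from learner.contrastive_utils import PairConLoss

torch.manual_seed(0)
criterion = PairConLoss(temperature=0.05)

z1 = F.normalize(torch.randn(8, 128), dim=1)
z2 = F.normalize(torch.randn(8, 128), dim=1)
print(criterion(z1, z1)["loss"].item())  # near 0: exp(1/T) positives swamp the negatives
print(criterion(z1, z2)["loss"].item())  # clearly larger for unrelated pairs
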
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

import os
import sys
sys.path.append('./')
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import argparse
from models.Transformers import SCCLBert
import dataloader.dataloader as dataloader
from training import SCCLvTrainer
from utils.kmeans import get_kmeans_centers
from utils.logger import setup_path, set_global_random_seed
from utils.optimizer import get_optimizer, get_bert
import numpy as np


def run(args):
    args.resPath, args.tensorboard = setup_path(args)
    set_global_random_seed(args.seed)

    # dataset loader
    train_loader = dataloader.explict_augmentation_loader(args) if args.augtype == "explicit" else dataloader.virtual_augmentation_loader(args)

    # model
    torch.cuda.set_device(args.gpuid[0])
    bert, tokenizer = get_bert(args)

    # initialize cluster centers
    cluster_centers = get_kmeans_centers(bert, tokenizer, train_loader, args.num_classes, args.max_length)

    model = SCCLBert(bert, tokenizer, cluster_centers=cluster_centers, alpha=args.alpha)
    model = model.cuda()

    # optimizer
    optimizer = get_optimizer(model, args)

    trainer = SCCLvTrainer(model, tokenizer, optimizer, train_loader, args)
    trainer.train()

    return None

def get_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_instance', type=str, default='local')
    parser.add_argument('--gpuid', nargs="+", type=int, default=[0], help="The list of gpuid, ex:--gpuid 3 1. Negative value means cpu-only")
    parser.add_argument('--seed', type=int, default=0, help="")
    parser.add_argument('--print_freq', type=float, default=100, help="")
    parser.add_argument('--resdir', type=str, default='./results/')
    parser.add_argument('--s3_resdir', type=str, default='./results')

    parser.add_argument('--bert', type=str, default='distilroberta', help="")
    parser.add_argument('--use_pretrain', type=str, default='BERT', choices=["BERT", "SBERT", "PAIRSUPCON"])

    # Dataset
    parser.add_argument('--datapath', type=str, default='../datasets/')
    parser.add_argument('--dataname', type=str, default='searchsnippets', help="")
    parser.add_argument('--num_classes', type=int, default=8, help="")
    parser.add_argument('--max_length', type=int, default=32)
    parser.add_argument('--label', type=str, default='label')
    parser.add_argument('--text', type=str, default='text')
    parser.add_argument('--augmentation_1', type=str, default='text1')
    parser.add_argument('--augmentation_2', type=str, default='text2')
    # Learning parameters
    parser.add_argument('--lr', type=float, default=1e-5, help="")
    parser.add_argument('--lr_scale', type=int, default=100, help="")
    parser.add_argument('--max_iter', type=int, default=1000)
    # contrastive learning
    parser.add_argument('--objective', type=str, default='contrastive')
    parser.add_argument('--augtype', type=str, default='virtual', choices=['virtual', 'explicit'])
    parser.add_argument('--batch_size', type=int, default=400)
    parser.add_argument('--temperature', type=float, default=0.5, help="temperature required by contrastive loss")
    parser.add_argument('--eta', type=float, default=1, help="")

    # Clustering
    parser.add_argument('--alpha', type=float, default=1.0)

    args = parser.parse_args(argv)
    args.use_gpu = args.gpuid[0] >= 0
    args.resPath = None
    args.tensorboard = None

    return args
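
# Illustrative aside (not in the original file): get_args can be driven
# programmatically, which makes the defaults above easy to inspect without
# launching a full run, e.g.
#   args = get_args(["--dataname", "stackoverflow", "--num_classes", "20",
#                    "--objective", "SCCL", "--augtype", "explicit"])
#   print(args.num_classes, args.objective, args.use_gpu)   # 20 SCCL True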


if __name__ == '__main__':
    import subprocess

    args = get_args(sys.argv[1:])

    if args.train_instance == "sagemaker":
        run(args)
        subprocess.run(["aws", "s3", "cp", "--recursive", args.resdir, args.s3_resdir])
    else:
        run(args)

--------------------------------------------------------------------------------
/models/Transformers.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from transformers import BertPreTrainedModel
# from transformers import AutoModel, AutoTokenizer

class SCCLBert(nn.Module):
    def __init__(self, bert_model, tokenizer, cluster_centers=None, alpha=1.0):
        super(SCCLBert, self).__init__()

        self.tokenizer = tokenizer
        self.bert = bert_model
        self.emb_size = self.bert.config.hidden_size
        self.alpha = alpha

        # Instance-CL head
        self.contrast_head = nn.Sequential(
            nn.Linear(self.emb_size, self.emb_size),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_size, 128))

        # Clustering head
        initial_cluster_centers = torch.tensor(
            cluster_centers, dtype=torch.float, requires_grad=True)
        self.cluster_centers = Parameter(initial_cluster_centers)


    def forward(self, input_ids, attention_mask, task_type="virtual"):
        if task_type == "evaluate":
            return self.get_mean_embeddings(input_ids, attention_mask)

        elif task_type == "virtual":
            input_ids_1, input_ids_2 = torch.unbind(input_ids, dim=1)
            attention_mask_1, attention_mask_2 = torch.unbind(attention_mask, dim=1)

            mean_output_1 = self.get_mean_embeddings(input_ids_1, attention_mask_1)
            mean_output_2 = self.get_mean_embeddings(input_ids_2, attention_mask_2)
            return mean_output_1, mean_output_2

        elif task_type == "explicit":
            input_ids_1, input_ids_2, input_ids_3 = torch.unbind(input_ids, dim=1)
            attention_mask_1, attention_mask_2, attention_mask_3 = torch.unbind(attention_mask, dim=1)

            mean_output_1 = self.get_mean_embeddings(input_ids_1, attention_mask_1)
            mean_output_2 = self.get_mean_embeddings(input_ids_2, attention_mask_2)
            mean_output_3 = self.get_mean_embeddings(input_ids_3, attention_mask_3)
            return mean_output_1, mean_output_2, mean_output_3

        else:
            raise Exception("TRANSFORMER ENCODING TYPE ERROR! OPTIONS: [EVALUATE, VIRTUAL, EXPLICIT]")


    def get_mean_embeddings(self, input_ids, attention_mask):
        bert_output = self.bert.forward(input_ids=input_ids, attention_mask=attention_mask)
        attention_mask = attention_mask.unsqueeze(-1)
        mean_output = torch.sum(bert_output[0]*attention_mask, dim=1) / torch.sum(attention_mask, dim=1)
        return mean_output


    def get_cluster_prob(self, embeddings):
        norm_squared = torch.sum((embeddings.unsqueeze(1) - self.cluster_centers) ** 2, 2)
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = numerator ** power
        return numerator / torch.sum(numerator, dim=1, keepdim=True)

    def local_consistency(self, embd0, embd1, embd2, criterion):
        p0 = self.get_cluster_prob(embd0)
        p1 = self.get_cluster_prob(embd1)
        p2 = self.get_cluster_prob(embd2)

        lds1 = criterion(p1, p0)
        lds2 = criterion(p2, p0)
        return lds1+lds2

    def contrast_logits(self, embd1, embd2=None):
        feat1 = F.normalize(self.contrast_head(embd1), dim=1)
        if embd2 is not None:
            feat2 = F.normalize(self.contrast_head(embd2), dim=1)
            return feat1, feat2
        else:
            return feat1

--------------------------------------------------------------------------------
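
The soft assignment in get_cluster_prob is the Student's t kernel familiar from DEC-style deep clustering: with alpha=1, q_ij is proportional to 1/(1 + ||z_i - mu_j||^2), normalized over clusters. A standalone check (illustrative only) with toy points sitting exactly on the two centers:

import torch

embeddings = torch.tensor([[0.0, 0.0], [4.0, 0.0]])  # two points...
centers = torch.tensor([[0.0, 0.0], [4.0, 0.0]])     # ...placed on two centers
alpha = 1.0

norm_squared = torch.sum((embeddings.unsqueeze(1) - centers) ** 2, 2)
numerator = (1.0 + norm_squared / alpha) ** (-(alpha + 1) / 2)
q = numerator / numerator.sum(dim=1, keepdim=True)
print(q)  # ~[[0.94, 0.06], [0.06, 0.94]]: 1/(1+0) vs 1/(1+16), row-normalized
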
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

import os
import time
import numpy as np
from sklearn import cluster

from utils.logger import statistics_log
from utils.metric import Confusion
from dataloader.dataloader import unshuffle_loader

import torch
import torch.nn as nn
from torch.nn import functional as F
from learner.cluster_utils import target_distribution
from learner.contrastive_utils import PairConLoss

class SCCLvTrainer(nn.Module):
    def __init__(self, model, tokenizer, optimizer, train_loader, args):
        super(SCCLvTrainer, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.args = args
        self.eta = self.args.eta

        self.cluster_loss = nn.KLDivLoss(reduction='sum')
        self.contrast_loss = PairConLoss(temperature=self.args.temperature)

        self.gstep = 0
        print(f"*****Initialize SCCLv, temp:{self.args.temperature}, eta:{self.args.eta}\n")

    def get_batch_token(self, text):
        token_feat = self.tokenizer.batch_encode_plus(
            text,
            max_length=self.args.max_length,
            return_tensors='pt',
            padding='max_length',
            truncation=True
        )
        return token_feat


    def prepare_transformer_input(self, batch):
        if len(batch) == 4:
            text1, text2, text3 = batch['text'], batch['augmentation_1'], batch['augmentation_2']
            feat1 = self.get_batch_token(text1)
            feat2 = self.get_batch_token(text2)
            feat3 = self.get_batch_token(text3)

            input_ids = torch.cat([feat1['input_ids'].unsqueeze(1), feat2['input_ids'].unsqueeze(1), feat3['input_ids'].unsqueeze(1)], dim=1)
            attention_mask = torch.cat([feat1['attention_mask'].unsqueeze(1), feat2['attention_mask'].unsqueeze(1), feat3['attention_mask'].unsqueeze(1)], dim=1)

        elif len(batch) == 2:
            text = batch['text']
            feat1 = self.get_batch_token(text)
            feat2 = self.get_batch_token(text)

            input_ids = torch.cat([feat1['input_ids'].unsqueeze(1), feat2['input_ids'].unsqueeze(1)], dim=1)
            attention_mask = torch.cat([feat1['attention_mask'].unsqueeze(1), feat2['attention_mask'].unsqueeze(1)], dim=1)

        else:
            # guard: batches must come from the explicit (4 keys) or virtual (2 keys) loader
            raise ValueError("Unexpected batch format with {} keys".format(len(batch)))

        return input_ids.cuda(), attention_mask.cuda()


    def train_step_virtual(self, input_ids, attention_mask):

        embd1, embd2 = self.model(input_ids, attention_mask, task_type="virtual")

        # Instance-CL loss
        feat1, feat2 = self.model.contrast_logits(embd1, embd2)
        losses = self.contrast_loss(feat1, feat2)
        loss = self.eta * losses["loss"]

        # Clustering loss
        if self.args.objective == "SCCL":
            output = self.model.get_cluster_prob(embd1)
            target = target_distribution(output).detach()

            cluster_loss = self.cluster_loss((output+1e-08).log(), target)/output.shape[0]
            loss += 0.5*cluster_loss
            losses["cluster_loss"] = cluster_loss.item()

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return losses
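
    # Illustrative aside (not in the original file): target_distribution
    # squares the soft assignments and renormalizes by cluster frequency, so
    # the KL term above both sharpens confident rows and tilts away from the
    # already-crowded cluster, e.g.
    #   q = torch.tensor([[0.6, 0.4], [0.5, 0.5]])
    #   target_distribution(q)
    #   # -> tensor([[0.6480, 0.3520],
    #   #            [0.4500, 0.5500]])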

    def train_step_explicit(self, input_ids, attention_mask):

        embd1, embd2, embd3 = self.model(input_ids, attention_mask, task_type="explicit")

        # Instance-CL loss
        feat1, feat2 = self.model.contrast_logits(embd2, embd3)
        losses = self.contrast_loss(feat1, feat2)
        loss = self.eta * losses["loss"]

        # Clustering loss
        if self.args.objective == "SCCL":
            output = self.model.get_cluster_prob(embd1)
            target = target_distribution(output).detach()

            cluster_loss = self.cluster_loss((output+1e-08).log(), target)/output.shape[0]
            loss += cluster_loss
            losses["cluster_loss"] = cluster_loss.item()

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return losses


    def train(self):
        print('\n={}/{}=Iterations/Batches'.format(self.args.max_iter, len(self.train_loader)))

        self.model.train()
        for i in np.arange(self.args.max_iter+1):
            try:
                batch = next(train_loader_iter)
            except (NameError, StopIteration):
                # first iteration, or the loader is exhausted: (re)start it
                train_loader_iter = iter(self.train_loader)
                batch = next(train_loader_iter)

            input_ids, attention_mask = self.prepare_transformer_input(batch)

            losses = self.train_step_virtual(input_ids, attention_mask) if self.args.augtype == "virtual" else self.train_step_explicit(input_ids, attention_mask)

            if (self.args.print_freq > 0) and ((i % self.args.print_freq == 0) or (i == self.args.max_iter)):
                statistics_log(self.args.tensorboard, losses=losses, global_step=i)
                self.evaluate_embedding(i)
                self.model.train()

        return None
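
    # Illustrative aside (not in the original file): max_iter counts gradient
    # steps, not epochs -- the loader above is simply re-wound when exhausted.
    # With a hypothetical 20,000-sample dataset and the default batch_size=400,
    # one epoch is 50 steps, so the default max_iter=1000 is about 20 epochs.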

    def evaluate_embedding(self, step):
        dataloader = unshuffle_loader(self.args)
        print('---- {} evaluation batches ----'.format(len(dataloader)))

        self.model.eval()
        for i, batch in enumerate(dataloader):
            with torch.no_grad():
                text, label = batch['text'], batch['label']
                feat = self.get_batch_token(text)
                embeddings = self.model(feat['input_ids'].cuda(), feat['attention_mask'].cuda(), task_type="evaluate")

                model_prob = self.model.get_cluster_prob(embeddings)
                if i == 0:
                    all_labels = label
                    all_embeddings = embeddings.detach()
                    all_prob = model_prob
                else:
                    all_labels = torch.cat((all_labels, label), dim=0)
                    all_embeddings = torch.cat((all_embeddings, embeddings.detach()), dim=0)
                    all_prob = torch.cat((all_prob, model_prob), dim=0)

        # Initialize confusion matrices
        confusion, confusion_model = Confusion(self.args.num_classes), Confusion(self.args.num_classes)

        all_pred = all_prob.max(1)[1]
        confusion_model.add(all_pred, all_labels)
        confusion_model.optimal_assignment(self.args.num_classes)
        acc_model = confusion_model.acc()

        kmeans = cluster.KMeans(n_clusters=self.args.num_classes, random_state=self.args.seed)
        embeddings = all_embeddings.cpu().numpy()
        kmeans.fit(embeddings)
        pred_labels = torch.tensor(kmeans.labels_.astype(int))

        # clustering accuracy
        confusion.add(pred_labels, all_labels)
        confusion.optimal_assignment(self.args.num_classes)
        acc = confusion.acc()

        ressave = {"acc": acc, "acc_model": acc_model}
        ressave.update(confusion.clusterscores())
        for key, val in ressave.items():
            self.args.tensorboard.add_scalar('Test/{}'.format(key), val, step)

        np.save(self.args.resPath + 'acc_{}.npy'.format(step), ressave)
        np.save(self.args.resPath + 'scores_{}.npy'.format(step), confusion.clusterscores())
        np.save(self.args.resPath + 'mscores_{}.npy'.format(step), confusion_model.clusterscores())
        # np.save(self.args.resPath + 'mpredlabels_{}.npy'.format(step), all_pred.cpu().numpy())
        # np.save(self.args.resPath + 'predlabels_{}.npy'.format(step), pred_labels.cpu().numpy())
        # np.save(self.args.resPath + 'embeddings_{}.npy'.format(step), embeddings)
        # np.save(self.args.resPath + 'labels_{}.npy'.format(step), all_labels.cpu())

        print('[Representation] Clustering scores:', confusion.clusterscores())
        print('[Representation] ACC: {:.3f}'.format(acc))
        print('[Model] Clustering scores:', confusion_model.clusterscores())
        print('[Model] ACC: {:.3f}'.format(acc_model))
        return None

--------------------------------------------------------------------------------
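
A note on the metrics files written above: ressave is a plain dict passed to np.save, so it is stored as a 0-d object array. A minimal sketch of reading one back (file name and values are stand-ins):

import numpy as np

np.save('acc_100.npy', {"acc": 0.71, "acc_model": 0.69})   # hypothetical values
scores = np.load('acc_100.npy', allow_pickle=True).item()  # unwrap the 0-d object array
print(scores["acc"])
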
--------------------------------------------------------------------------------
/utils/kmeans.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

import torch
import numpy as np
from utils.metric import Confusion
from sklearn.cluster import KMeans


def get_mean_embeddings(bert, input_ids, attention_mask):
    bert_output = bert.forward(input_ids=input_ids, attention_mask=attention_mask)
    attention_mask = attention_mask.unsqueeze(-1)
    mean_output = torch.sum(bert_output[0]*attention_mask, dim=1) / torch.sum(attention_mask, dim=1)
    return mean_output


def get_batch_token(tokenizer, text, max_length):
    token_feat = tokenizer.batch_encode_plus(
        text,
        max_length=max_length,
        return_tensors='pt',
        padding='max_length',
        truncation=True
    )
    return token_feat


def get_kmeans_centers(bert, tokenizer, train_loader, num_classes, max_length):
    for i, batch in enumerate(train_loader):

        text, label = batch['text'], batch['label']
        tokenized_features = get_batch_token(tokenizer, text, max_length)
        corpus_embeddings = get_mean_embeddings(bert, **tokenized_features)

        if i == 0:
            all_labels = label
            all_embeddings = corpus_embeddings.detach().numpy()
        else:
            all_labels = torch.cat((all_labels, label), dim=0)
            all_embeddings = np.concatenate((all_embeddings, corpus_embeddings.detach().numpy()), axis=0)

    # Perform KMeans clustering
    confusion = Confusion(num_classes)
    clustering_model = KMeans(n_clusters=num_classes)
    clustering_model.fit(all_embeddings)
    cluster_assignment = clustering_model.labels_

    true_labels = all_labels
    pred_labels = torch.tensor(cluster_assignment)
    print("all_embeddings:{}, true_labels:{}, pred_labels:{}".format(all_embeddings.shape, len(true_labels), len(pred_labels)))

    confusion.add(pred_labels, true_labels)
    confusion.optimal_assignment(num_classes)
    print("Iterations:{}, Clustering ACC:{:.3f}, centers:{}".format(clustering_model.n_iter_, confusion.acc(), clustering_model.cluster_centers_.shape))

    return clustering_model.cluster_centers_

--------------------------------------------------------------------------------
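
The masked mean pooling in get_mean_embeddings is easy to sanity-check in isolation; a minimal sketch with toy tensors (no model required): padded positions must contribute nothing to the sentence embedding.

import torch

token_embs = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])  # 1 sentence x 3 tokens x 2 dims
mask = torch.tensor([[1, 1, 0]]).unsqueeze(-1)                     # last token is padding

mean = torch.sum(token_embs * mask, dim=1) / torch.sum(mask, dim=1)
print(mean)  # tensor([[2., 2.]]) -- the padded token is excluded from the average
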
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

import os
from tensorboardX import SummaryWriter
import random
import torch
import numpy as np


def set_global_random_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must be off for reproducible runs; leaving it on lets cuDNN
    # pick algorithms nondeterministically and defeats the line above
    torch.backends.cudnn.benchmark = False


def setup_path(args):
    resPath = "SCCL"
    resPath += f'.{args.bert}'
    resPath += f'.{args.use_pretrain}'
    resPath += f'.{args.augtype}'
    resPath += f'.{args.dataname}'
    resPath += f'.{args.text}'
    resPath += f'.lr{args.lr}'
    resPath += f'.lrscale{args.lr_scale}'
    resPath += f'.{args.objective}'
    resPath += f'.eta{args.eta}'
    resPath += f'.tmp{args.temperature}'
    resPath += f'.alpha{args.alpha}'
    resPath += f'.seed{args.seed}/'
    resPath = args.resdir + resPath
    print(f'results path: {resPath}')

    tensorboard = SummaryWriter(resPath)
    return resPath, tensorboard


def statistics_log(tensorboard, losses=None, global_step=0):
    print("[{}]-----".format(global_step))
    if losses is not None:
        for key, val in losses.items():
            if key in ["pos", "neg", "pos_diag", "pos_rand", "neg_offdiag"]:
                tensorboard.add_histogram('train/'+key, val, global_step)
            else:
                try:
                    tensorboard.add_scalar('train/'+key, val.item(), global_step)
                except (AttributeError, ValueError):
                    # val is already a plain Python number; log it as-is
                    tensorboard.add_scalar('train/'+key, val, global_step)
                print("{}:\t {:.3f}".format(key, val))

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

from __future__ import print_function
import time
import torch
import numpy as np
from scipy.optimize import linear_sum_assignment as hungarian
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score, adjusted_mutual_info_score

cluster_nmi = normalized_mutual_info_score

def cluster_acc(y_true, y_pred):
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1

    # ind = sklearn.utils.linear_assignment_.linear_assignment(w.max() - w)
    # row_ind, col_ind = linear_assignment(w.max() - w)
    row_ind, col_ind = hungarian(w.max() - w)
    return sum([w[i, j] for i, j in zip(row_ind, col_ind)]) * 1.0 / y_pred.size

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = float(self.sum) / self.count

class Timer(object):
    """A simple tic/toc wall-clock timer."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.interval = 0
        self.time = time.time()

    def value(self):
        return time.time() - self.time

    def tic(self):
        self.time = time.time()

    def toc(self):
        self.interval = time.time() - self.time
        self.time = time.time()
        return self.interval

class Confusion(object):
    """
    column of confusion matrix: predicted index
    row of confusion matrix: target index
    """
    def __init__(self, k, normalized=False):
        super(Confusion, self).__init__()
        self.k = k
        self.conf = torch.LongTensor(k, k)
        self.normalized = normalized
        self.reset()

    def reset(self):
        self.conf.fill_(0)
        self.gt_n_cluster = None

    def cuda(self):
        self.conf = self.conf.cuda()

    def add(self, output, target):
        output = output.squeeze()
        target = target.squeeze()
        assert output.size(0) == target.size(0), \
            'number of targets and outputs do not match'
        if output.ndimension() > 1:  # it is the raw probabilities over classes
            assert output.size(1) == self.conf.size(0), \
                'number of outputs does not match size of confusion matrix'

            _, pred = output.max(1)  # find the predicted class
        else:  # it is already the predicted class
            pred = output
        indices = (target*self.conf.stride(0) + pred.squeeze_().type_as(target)).type_as(self.conf)
        ones = torch.ones(1).type_as(self.conf).expand(indices.size(0))
        self._conf_flat = self.conf.view(-1)
        self._conf_flat.index_add_(0, indices, ones)

    def classIoU(self, ignore_last=False):
        confusion_tensor = self.conf
        if ignore_last:
            confusion_tensor = self.conf.narrow(0, 0, self.k-1).narrow(1, 0, self.k-1)
        union = confusion_tensor.sum(0).view(-1) + confusion_tensor.sum(1).view(-1) - confusion_tensor.diag().view(-1)
        acc = confusion_tensor.diag().float().view(-1).div(union.float()+1)
        return acc

    def recall(self, clsId):
        i = clsId
        TP = self.conf[i, i].sum().item()
        TPuFN = self.conf[i, :].sum().item()
        if TPuFN == 0:
            return 0
        return float(TP)/TPuFN

    def precision(self, clsId):
        i = clsId
        TP = self.conf[i, i].sum().item()
        TPuFP = self.conf[:, i].sum().item()
        if TPuFP == 0:
            return 0
        return float(TP)/TPuFP

    def f1score(self, clsId):
        r = self.recall(clsId)
        p = self.precision(clsId)
        print("classID:{}, precision:{:.4f}, recall:{:.4f}".format(clsId, p, r))
        if (p+r) == 0:
            return 0
        return 2*float(p*r)/(p+r)

    def acc(self):
        TP = self.conf.diag().sum().item()
        total = self.conf.sum().item()
        if total == 0:
            return 0
        return float(TP)/total

    def optimal_assignment(self, gt_n_cluster=None, assign=None):
        if assign is None:
            mat = -self.conf.cpu().numpy()  # hungarian finds the minimum cost
            r, assign = hungarian(mat)
        self.conf = self.conf[:, assign]
        self.gt_n_cluster = gt_n_cluster
        return assign

    def show(self, width=6, row_labels=None, column_labels=None):
        print("Confusion Matrix:")
        conf = self.conf
        rows = self.gt_n_cluster or conf.size(0)
        cols = conf.size(1)
        if column_labels is not None:
            print(("%" + str(width) + "s") % '', end='')
            for c in column_labels:
                print(("%" + str(width) + "s") % c, end='')
            print('')
        for i in range(0, rows):
            if row_labels is not None:
                print(("%" + str(width) + "s|") % row_labels[i], end='')
            for j in range(0, cols):
                print(("%" + str(width) + ".d") % conf[i, j], end='')
            print('')

    def conf2label(self):
        conf = self.conf
        gt_classes_count = conf.sum(1).squeeze()
        n_sample = gt_classes_count.sum().item()
        gt_label = torch.zeros(n_sample)
        pred_label = torch.zeros(n_sample)
        cur_idx = 0
        for c in range(conf.size(0)):
            if gt_classes_count[c] > 0:
                gt_label[cur_idx:cur_idx+gt_classes_count[c]].fill_(c)
            for p in range(conf.size(1)):
                if conf[c][p] > 0:
                    pred_label[cur_idx:cur_idx+conf[c][p]].fill_(p)
                    cur_idx = cur_idx + conf[c][p]
        return gt_label, pred_label

    def clusterscores(self):
        target, pred = self.conf2label()
        NMI = normalized_mutual_info_score(target, pred)
        ARI = adjusted_rand_score(target, pred)
        AMI = adjusted_mutual_info_score(target, pred)
        return {'NMI': NMI, 'ARI': ARI, 'AMI': AMI}

--------------------------------------------------------------------------------
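
Since cluster indices are arbitrary, optimal_assignment permutes the confusion matrix's columns with the Hungarian algorithm before accuracy is read off. A minimal sketch (illustrative only, assumes the repo root on sys.path):

import torch
from utils.metric import Confusion

confusion = Confusion(2)
# a perfect clustering, just with the two label ids flipped
confusion.add(torch.tensor([1, 1, 0, 0]), torch.tensor([0, 0, 1, 1]))
confusion.optimal_assignment(2)
print(confusion.acc())  # 1.0 once the columns are optimally re-matched
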
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved

Author: Dejiao Zhang (dejiaoz@amazon.com)
Date: 02/26/2021
"""

import os
import torch
from transformers import get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer, AutoConfig
from sentence_transformers import SentenceTransformer

# NOTE: these keys must cover whatever --bert is passed to main.py; only
# "distilbert" is mapped here, while main.py defaults to --bert distilroberta.
BERT_CLASS = {
    "distilbert": 'distilbert-base-uncased',
}

SBERT_CLASS = {
    "distilbert": 'distilbert-base-nli-stsb-mean-tokens',
}


def get_optimizer(model, args):

    optimizer = torch.optim.Adam([
        {'params': model.bert.parameters()},
        {'params': model.contrast_head.parameters(), 'lr': args.lr*args.lr_scale},
        {'params': model.cluster_centers, 'lr': args.lr*args.lr_scale}
    ], lr=args.lr)

    print(optimizer)
    return optimizer


def get_bert(args):

    if args.use_pretrain == "SBERT":
        bert_model = get_sbert(args)
        tokenizer = bert_model[0].tokenizer
        model = bert_model[0].auto_model
        print("..... loading Sentence-BERT !!!")
    else:
        config = AutoConfig.from_pretrained(BERT_CLASS[args.bert])
        model = AutoModel.from_pretrained(BERT_CLASS[args.bert], config=config)
        tokenizer = AutoTokenizer.from_pretrained(BERT_CLASS[args.bert])
        print("..... loading plain BERT !!!")

    return model, tokenizer


def get_sbert(args):
    sbert = SentenceTransformer(SBERT_CLASS[args.bert])
    return sbert

--------------------------------------------------------------------------------
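
The two-tier learning-rate scheme in get_optimizer (backbone at args.lr, contrastive head and cluster centers at args.lr * args.lr_scale) is the only non-default optimizer detail. A self-contained sketch of the same pattern with toy modules, using the repo's default values lr=1e-5 and lr_scale=100:

import torch
import torch.nn as nn

backbone, head = nn.Linear(4, 4), nn.Linear(4, 2)
lr, lr_scale = 1e-5, 100

optimizer = torch.optim.Adam([
    {'params': backbone.parameters()},                   # base lr
    {'params': head.parameters(), 'lr': lr * lr_scale},  # 100x faster head
], lr=lr)
print([group['lr'] for group in optimizer.param_groups])  # [1e-05, 0.001]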