├── README.md
├── datareader.py
├── datareader_cnn.py
├── emnlp_final_experiments
│   ├── claim-detection
│   │   ├── .DS_Store
│   │   ├── train_basic.py
│   │   ├── train_basic_domain_adversarial.py
│   │   ├── train_multi_view.py
│   │   ├── train_multi_view_averaging_individuals.py
│   │   ├── train_multi_view_domain_adversarial.py
│   │   ├── train_multi_view_domain_classifier.py
│   │   ├── train_multi_view_domainclassifier_individuals.py
│   │   └── train_multi_view_selective_weighting.py
│   └── sentiment-analysis
│       ├── .DS_Store
│       ├── analyze_expert_predictions.py
│       ├── analyze_expert_predictions_cnn.py
│       ├── train_basic.py
│       ├── train_basic_domain_adversarial.py
│       ├── train_multi_view.py
│       ├── train_multi_view_averaging_individuals.py
│       ├── train_multi_view_averaging_individuals_cnn.py
│       ├── train_multi_view_domain_adversarial.py
│       ├── train_multi_view_domain_classifier.py
│       ├── train_multi_view_domainclassifier_individuals.py
│       └── train_multi_view_selective_weighting.py
├── metrics.py
├── model.py
├── multisource-domain-adaptation.png
├── requirements.txt
├── run_claim_experiments.sh
├── run_sentiment_experiments.sh
└── setenv.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer Based Multi-Source Domain Adaptation
2 | Dustin Wright and Isabelle Augenstein
3 |
4 | To appear in EMNLP 2020. Read the preprint: https://arxiv.org/abs/2009.07806
5 |
6 |

7 | <img src="multisource-domain-adaptation.png" alt="PUC">
8 |

9 |
10 | In practical machine learning settings, the data on which a model must make predictions often come from a different distribution than the data it was trained on. Here, we investigate the problem of unsupervised multi-source domain adaptation, where a model is trained on labelled data from multiple source domains and must make predictions on a domain for which no labelled data has been seen. Prior work with CNNs and RNNs has demonstrated the benefit of mixture of experts, where the predictions of multiple domain expert classifiers are combined, as well as of domain adversarial training, which induces a domain-agnostic representation space. Inspired by this, we investigate how such methods can be effectively applied to large pretrained transformer models. We find that domain adversarial training has an effect on the learned representations of these models while having little effect on their performance, suggesting that large transformer-based models are already relatively robust across domains. Additionally, we show that mixture of experts leads to significant performance improvements by comparing several variants of mixing functions, including one novel mixture based on attention. Finally, we demonstrate that the predictions of large pretrained transformer-based domain experts are highly homogeneous, making it challenging to learn effective functions for mixing their predictions.
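The attention-based mixture of domain experts mentioned above can be sketched in a few lines. This is a hypothetical, minimal illustration — not the paper's implementation: each expert emits a probability distribution over classes, attention scores (here supplied directly; in the paper they are learned) are normalised with a softmax, and the final prediction is the attention-weighted average of the expert distributions.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def mix_expert_predictions(expert_probs, attention_scores):
    """Attention-weighted mixture of per-domain expert predictions.

    expert_probs: one probability distribution over classes per expert.
    attention_scores: one unnormalised relevance score per expert
    (assumed given here for illustration; a learned attention module
    would produce them from the input representation).
    """
    weights = softmax(attention_scores)
    n_classes = len(expert_probs[0])
    # Weighted average over experts, per class.
    return [
        sum(w * probs[c] for w, probs in zip(weights, expert_probs))
        for c in range(n_classes)
    ]


# Three domain experts on a binary task; the second expert receives
# most of the attention mass, so its prediction dominates the mixture.
mixed = mix_expert_predictions(
    [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]],
    [0.1, 2.0, 0.1],
)
```

Because each expert distribution sums to one and the attention weights sum to one, the mixture is itself a valid probability distribution.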
11 |
12 | # Citing
13 |
14 | ```bib
15 | @inproceedings{wright2020transformer,
16 |     title={{Transformer Based Multi-Source Domain Adaptation}},
17 |     author={Dustin Wright and Isabelle Augenstein},
18 |     booktitle = {Proceedings of EMNLP},
19 |     publisher = {Association for Computational Linguistics},
20 |     year = 2020
21 | }
22 | ```
23 |
24 | # Recreating Results
25 |
26 | To recreate our results, first download the [Amazon Product Reviews](https://www.cs.jhu.edu/~mdredze/datasets/sentiment/) and [PHEME Rumour Detection](https://figshare.com/articles/PHEME_dataset_for_Rumour_Detection_and_Veracity_Classification/6392078) datasets and place them in the 'data/' directory. Place the sentiment data in a directory called 'data/sentiment-dataset' and the PHEME data in a directory called 'data/PHEME'.
27 |
28 | Create a new conda environment:
29 |
30 | ```bash
31 | $ conda create --name xformer-multisource-domain-adaptation python=3.7
32 | $ conda activate xformer-multisource-domain-adaptation
33 | $ pip install -r requirements.txt
34 | ```
35 |
36 | Note that this project uses wandb; if you do not use wandb, set the following flag to store runs only locally:
37 |
38 | ```bash
39 | export WANDB_MODE=dryrun
40 | ```
41 |
42 | ## Running all experiments
43 |
44 | The scripts for running all of the experiments are `run_sentiment_experiments.sh` and `run_claim_experiments.sh`. You can look in these files for the commands to run a particular experiment. Running either of these files will run all 10 variants presented in the paper, 5 times each.
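A single variant can also be launched directly. The command below is a hypothetical invocation sketch — the flag values (dataset location, domain list, GPU count) are assumptions to adapt to your setup, while the flags themselves mirror the argparse options defined in `train_basic.py`:

```bash
# Hypothetical single-run sketch; adjust paths, domains, and GPU count.
export WANDB_MODE=dryrun   # keep the run local if you do not use wandb
python emnlp_final_experiments/claim-detection/train_basic.py \
    --dataset_loc data/PHEME \
    --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege health \
    --n_gpu 1 \
    --batch_size 8 \
    --lr 3e-5 \
    --n_epochs 2 \
    --run_name pheme-baseline
```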
The individual scripts used for each experiment are under `emnlp_final_experiments/claim-detection` and `emnlp_final_experiments/sentiment-analysis` 45 | -------------------------------------------------------------------------------- /datareader.py: -------------------------------------------------------------------------------- 1 | from xml.dom import minidom 2 | from typing import AnyStr 3 | from typing import List 4 | from typing import Tuple 5 | import unicodedata 6 | import pandas as pd 7 | import json 8 | import glob 9 | import ipdb 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | from transformers import PreTrainedTokenizer 14 | 15 | 16 | domain_map = { 17 | 'gourmet_food': 0, 18 | 'jewelry_&_watches': 1, 19 | 'outdoor_living': 2, 20 | 'grocery': 3, 21 | 'computer_&_video_games': 4, 22 | 'beauty': 5, 23 | 'baby': 6, 24 | 'software': 7, 25 | 'magazines': 8, 26 | 'camera_&_photo': 9, 27 | 'music': 10, 28 | 'video': 11, 29 | 'health_&_personal_care': 12, 30 | 'toys_&_games': 13, 31 | 'sports_&_outdoors': 14, 32 | 'apparel': 15, 33 | 'books': 16, 34 | 'kitchen_&_housewares': 17, 35 | 'electronics': 18, 36 | 'dvd': 19 37 | } 38 | 39 | twitter_domain_map = { 40 | 'charliehebdo': 0, 41 | 'ferguson': 1, 42 | 'germanwings-crash': 2, 43 | 'ottawashooting': 3, 44 | 'sydneysiege': 4, 45 | 'health': 5 46 | } 47 | 48 | def text_to_batch_transformer(text: List, tokenizer: PreTrainedTokenizer, text_pair: AnyStr = None) -> Tuple[List, List]: 49 | """Turn a piece of text into a batch for transformer model 50 | 51 | :param text: The text to tokenize and encode 52 | :param tokenizer: The tokenizer to use 53 | :param: text_pair: An optional second string (for multiple sentence sequences) 54 | :return: A list of IDs and a mask 55 | """ 56 | if text_pair is None: 57 | input_ids = [tokenizer.encode(t, add_special_tokens=True, max_length=tokenizer.max_len) for t in text] 58 | else: 59 | input_ids = [tokenizer.encode(t, text_pair=p, add_special_tokens=True, 
max_length=tokenizer.max_len) for t,p in zip(text, text_pair)] 60 | 61 | masks = [[1] * len(i) for i in input_ids] 62 | 63 | return input_ids, masks 64 | 65 | 66 | def collate_batch_transformer(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 67 | input_ids = [i[0][0] for i in input_data] 68 | masks = [i[1][0] for i in input_data] 69 | labels = [i[2] for i in input_data] 70 | domains = [i[3] for i in input_data] 71 | 72 | max_length = max([len(i) for i in input_ids]) 73 | 74 | input_ids = [(i + [0] * (max_length - len(i))) for i in input_ids] 75 | masks = [(m + [0] * (max_length - len(m))) for m in masks] 76 | 77 | assert (all(len(i) == max_length for i in input_ids)) 78 | assert (all(len(m) == max_length for m in masks)) 79 | return torch.tensor(input_ids), torch.tensor(masks), torch.tensor(labels), torch.tensor(domains) 80 | 81 | 82 | def collate_batch_transformer_with_index(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List]: 83 | return collate_batch_transformer(input_data) + ([i[-1] for i in input_data],) 84 | 85 | 86 | def read_xml(dir: AnyStr, domain: AnyStr, split: AnyStr = 'positive'): 87 | """ Convert all of the ratings in an Amazon product review file to dicts 88 | 89 | :param dir: The directory containing the per-domain review files 90 | :return: All of the reviews in the file as dicts 91 | """ 92 | reviews = [] 93 | split_map = {'positive': 1, 'negative': 0, 'unlabelled': -1} 94 | in_review_text = False 95 | with open(f'{dir}/{domain}/{split}.review', encoding='utf8', errors='ignore') as f: 96 | for line in f: 97 | if '<review_text>' in line: 98 | reviews.append({'text': '', 'label': split_map[split], 'domain': domain_map[domain]}) 99 | in_review_text = True 100 | continue 101 | if '</review_text>' in line: 102 | in_review_text = False 103 | reviews[-1]['text'] = reviews[-1]['text'].replace('\n', ' ').strip() 104 | if in_review_text: 105 | reviews[-1]['text'] += line 106 | return reviews 107 | 108 | 109 | class MultiDomainSentimentDataset(Dataset): 110
| """ 111 | Implements a dataset for the multidomain sentiment analysis dataset 112 | """ 113 | def __init__( 114 | self, 115 | dataset_dir: AnyStr, 116 | domains: List, 117 | tokenizer: PreTrainedTokenizer, 118 | domain_ids: List = None 119 | ): 120 | """ 121 | 122 | :param dataset_dir: The base directory for the dataset 123 | :param domains: The set of domains to load data for 124 | :param: tokenizer: The tokenizer to use 125 | :param: domain_ids: A list of ids to override the default domain IDs 126 | """ 127 | super(MultiDomainSentimentDataset, self).__init__() 128 | data = [] 129 | for domain in domains: 130 | data.extend(read_xml(dataset_dir, domain, 'positive')) 131 | data.extend(read_xml(dataset_dir, domain, 'negative')) 132 | 133 | self.dataset = pd.DataFrame(data) 134 | if domain_ids is not None: 135 | for i in range(len(domain_ids)): 136 | data[data['domain'] == domain_map[domains[i]]][2] = domain_ids[i] 137 | self.tokenizer = tokenizer 138 | 139 | def set_domain_id(self, domain_id): 140 | """ 141 | Overrides the domain ID for all data 142 | :param domain_id: 143 | :return: 144 | """ 145 | self.dataset['domain'] = domain_id 146 | 147 | def __len__(self): 148 | return self.dataset.shape[0] 149 | 150 | def __getitem__(self, item) -> Tuple: 151 | row = self.dataset.values[item] 152 | input_ids, mask = text_to_batch_transformer([row[0]], self.tokenizer) 153 | label = row[1] 154 | domain = row[2] 155 | return input_ids, mask, label, domain, item 156 | 157 | 158 | class MultiDomainTwitterDataset(Dataset): 159 | """ 160 | Implements a dataset for the multidomain sentiment analysis dataset 161 | """ 162 | def __init__( 163 | self, 164 | dataset_dir: AnyStr, 165 | domains: List, 166 | tokenizer: PreTrainedTokenizer, 167 | health_data_loc: AnyStr = None, 168 | domain_ids: List = None 169 | ): 170 | """ 171 | 172 | :param dataset_dir: The base directory for the dataset 173 | :param domains: The set of domains to load data for 174 | :param: tokenizer: The tokenizer 
to use 175 | :param: domain_ids: A list of ids to override the default domain IDs 176 | """ 177 | super(MultiDomainTwitterDataset, self).__init__() 178 | rumours = [] 179 | non_rumours = [] 180 | d_ids = [] 181 | self.name = "_".join(domains) 182 | for domain in domains: 183 | if domain != 'health': 184 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/non-rumours/**/source-tweets/*.json'): 185 | with open(source_tweet_file) as js: 186 | tweet = json.load(js) 187 | non_rumours.append(tweet['text']) 188 | d_ids.append(twitter_domain_map[domain]) 189 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/rumours/**/source-tweets/*.json'): 190 | with open(source_tweet_file) as js: 191 | tweet = json.load(js) 192 | rumours.append(tweet['text']) 193 | d_ids.append(twitter_domain_map[domain]) 194 | elif health_data_loc is not None: 195 | health_dataset = pd.read_csv(health_data_loc, sep="\t", header=None) 196 | # Remove unknowns 197 | health_dataset = health_dataset[health_dataset[1] != 0] 198 | # Transform the text 199 | health_dataset[0] = health_dataset[0].apply(lambda x: x[10:] if 'RT @xxxxx ' == x[:10] else x) 200 | # Drop duplicates 201 | health_dataset = health_dataset.drop_duplicates() 202 | statements = [v[0] for v in health_dataset.values] 203 | lblmap = {1: 0, -1: 1} 204 | labels = [lblmap[v[1]] for v in health_dataset.values] 205 | rumours.extend([s for s,l in zip(statements, labels) if l == 1]) 206 | non_rumours.extend([s for s,l in zip(statements, labels) if l == 0]) 207 | d_ids.extend([twitter_domain_map[domain]] * len(labels)) 208 | 209 | 210 | self.dataset = pd.DataFrame(rumours + non_rumours, columns=['statement']) 211 | self.dataset['label'] = [1] * len(rumours) + [0] * len(non_rumours) 212 | self.dataset['statement'] = self.dataset['statement'].str.normalize('NFKD') 213 | self.dataset['domain'] = d_ids 214 | 215 | self.tokenizer = tokenizer 216 | 217 | def set_domain_id(self, domain_id): 218 | """ 219 
| Overrides the domain ID for all data 220 | :param domain_id: 221 | :return: 222 | """ 223 | self.dataset['domain'] = domain_id 224 | 225 | def __len__(self): 226 | return self.dataset.shape[0] 227 | 228 | def __getitem__(self, item) -> Tuple: 229 | row = self.dataset.values[item] 230 | input_ids, mask = text_to_batch_transformer([row[0]], self.tokenizer) 231 | label = row[1] 232 | domain = row[2] 233 | return input_ids, mask, label, domain, item 234 | 235 | 236 | -------------------------------------------------------------------------------- /datareader_cnn.py: -------------------------------------------------------------------------------- 1 | from xml.dom import minidom 2 | from typing import AnyStr 3 | from typing import List 4 | from typing import Tuple 5 | import unicodedata 6 | import pandas as pd 7 | import json 8 | import glob 9 | import ipdb 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | from transformers import PreTrainedTokenizer 14 | from fasttext import tokenize 15 | 16 | 17 | domain_map = { 18 | 'gourmet_food': 0, 19 | 'jewelry_&_watches': 1, 20 | 'outdoor_living': 2, 21 | 'grocery': 3, 22 | 'computer_&_video_games': 4, 23 | 'beauty': 5, 24 | 'baby': 6, 25 | 'software': 7, 26 | 'magazines': 8, 27 | 'camera_&_photo': 9, 28 | 'music': 10, 29 | 'video': 11, 30 | 'health_&_personal_care': 12, 31 | 'toys_&_games': 13, 32 | 'sports_&_outdoors': 14, 33 | 'apparel': 15, 34 | 'books': 16, 35 | 'kitchen_&_housewares': 17, 36 | 'electronics': 18, 37 | 'dvd': 19 38 | } 39 | 40 | twitter_domain_map = { 41 | 'charliehebdo': 0, 42 | 'ferguson': 1, 43 | 'germanwings-crash': 2, 44 | 'ottawashooting': 3, 45 | 'sydneysiege': 4, 46 | 'health': 5 47 | } 48 | 49 | def text_to_batch_cnn(text: List, tokenizer, text_pair: AnyStr = None) -> Tuple[List, List]: 50 | """Turn a piece of text into a batch for transformer model 51 | 52 | :param text: The text to tokenize and encode 53 | :param tokenizer: The tokenizer to use 54 | :param: text_pair: An 
optional second string (for multiple sentence sequences) 55 | :return: A list of IDs and a mask 56 | """ 57 | input_ids = [tokenizer.encode(t) for t in text] 58 | 59 | masks = [[1] * len(i) for i in input_ids] 60 | 61 | return input_ids, masks 62 | 63 | 64 | def collate_batch_cnn(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 65 | input_ids = [i[0][0] for i in input_data] 66 | masks = [i[1][0] for i in input_data] 67 | labels = [i[2] for i in input_data] 68 | domains = [i[3] for i in input_data] 69 | 70 | max_length = max([len(i) for i in input_ids]) 71 | 72 | input_ids = [(i + [0] * (max_length - len(i))) for i in input_ids] 73 | masks = [(m + [0] * (max_length - len(m))) for m in masks] 74 | 75 | assert (all(len(i) == max_length for i in input_ids)) 76 | assert (all(len(m) == max_length for m in masks)) 77 | return torch.tensor(input_ids), torch.tensor(masks), torch.tensor(labels), torch.tensor(domains) 78 | 79 | 80 | class FasttextTokenizer: 81 | 82 | def __init__(self, vocabulary_file): 83 | self.vocab = {} 84 | with open(vocabulary_file) as f: 85 | for j,l in enumerate(f): 86 | self.vocab[l.strip()] = j 87 | 88 | def encode(self, text): 89 | tokens = tokenize(text.lower().replace('\n', ' ') + '\n') 90 | return [self.vocab[t] if t in self.vocab else self.vocab['[UNK]'] for t in tokens] 91 | 92 | 93 | def collate_batch_cnn_with_index(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List]: 94 | return collate_batch_cnn(input_data) + ([i[-1] for i in input_data],) 95 | 96 | 97 | def read_xml(dir: AnyStr, domain: AnyStr, split: AnyStr = 'positive'): 98 | """ Convert all of the ratings in amazon product XML file to dicts 99 | 100 | :param xml_file: The XML file to convert to a dict 101 | :return: All of the rows in the xml file as dicts 102 | """ 103 | reviews = [] 104 | split_map = {'positive': 1, 'negative': 0, 'unlabelled': -1} 105 | in_review_text = False 106 | with open(f'{dir}/{domain}/{split}.review', 
encoding='utf8', errors='ignore') as f: 107 | for line in f: 108 | if '<review_text>' in line: 109 | reviews.append({'text': '', 'label': split_map[split], 'domain': domain_map[domain]}) 110 | in_review_text = True 111 | continue 112 | if '</review_text>' in line: 113 | in_review_text = False 114 | reviews[-1]['text'] = reviews[-1]['text'].replace('\n', ' ').strip() 115 | if in_review_text: 116 | reviews[-1]['text'] += line 117 | return reviews 118 | 119 | 120 | class MultiDomainSentimentDataset(Dataset): 121 | """ 122 | Implements a dataset for the multidomain sentiment analysis dataset 123 | """ 124 | def __init__( 125 | self, 126 | dataset_dir: AnyStr, 127 | domains: List, 128 | tokenizer, 129 | domain_ids: List = None 130 | ): 131 | """ 132 | 133 | :param dataset_dir: The base directory for the dataset 134 | :param domains: The set of domains to load data for 135 | :param: tokenizer: The tokenizer to use 136 | :param: domain_ids: A list of ids to override the default domain IDs 137 | """ 138 | super(MultiDomainSentimentDataset, self).__init__() 139 | data = [] 140 | for domain in domains: 141 | data.extend(read_xml(dataset_dir, domain, 'positive')) 142 | data.extend(read_xml(dataset_dir, domain, 'negative')) 143 | 144 | self.dataset = pd.DataFrame(data) 145 | if domain_ids is not None: 146 | for i in range(len(domain_ids)): 147 | self.dataset.loc[self.dataset['domain'] == domain_map[domains[i]], 'domain'] = domain_ids[i] 148 | self.tokenizer = tokenizer 149 | 150 | def set_domain_id(self, domain_id): 151 | """ 152 | Overrides the domain ID for all data 153 | :param domain_id: 154 | :return: 155 | """ 156 | self.dataset['domain'] = domain_id 157 | 158 | def __len__(self): 159 | return self.dataset.shape[0] 160 | 161 | def __getitem__(self, item) -> Tuple: 162 | row = self.dataset.values[item] 163 | input_ids, mask = text_to_batch_cnn([row[0]], self.tokenizer) 164 | label = row[1] 165 | domain = row[2] 166 | return input_ids, mask, label, domain, item 167 | 168 | 169 | class MultiDomainTwitterDataset(Dataset): 170 |
""" 171 | Implements a dataset for the multidomain sentiment analysis dataset 172 | """ 173 | def __init__( 174 | self, 175 | dataset_dir: AnyStr, 176 | domains: List, 177 | tokenizer: PreTrainedTokenizer, 178 | health_data_loc: AnyStr = None, 179 | domain_ids: List = None 180 | ): 181 | """ 182 | 183 | :param dataset_dir: The base directory for the dataset 184 | :param domains: The set of domains to load data for 185 | :param: tokenizer: The tokenizer to use 186 | :param: domain_ids: A list of ids to override the default domain IDs 187 | """ 188 | super(MultiDomainTwitterDataset, self).__init__() 189 | rumours = [] 190 | non_rumours = [] 191 | d_ids = [] 192 | self.name = "_".join(domains) 193 | for domain in domains: 194 | if domain != 'health': 195 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/non-rumours/**/source-tweets/*.json'): 196 | with open(source_tweet_file) as js: 197 | tweet = json.load(js) 198 | non_rumours.append(tweet['text']) 199 | d_ids.append(twitter_domain_map[domain]) 200 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/rumours/**/source-tweets/*.json'): 201 | with open(source_tweet_file) as js: 202 | tweet = json.load(js) 203 | rumours.append(tweet['text']) 204 | d_ids.append(twitter_domain_map[domain]) 205 | elif health_data_loc is not None: 206 | health_dataset = pd.read_csv(health_data_loc, sep="\t", header=None) 207 | # Remove unknowns 208 | health_dataset = health_dataset[health_dataset[1] != 0] 209 | # Transform the text 210 | health_dataset[0] = health_dataset[0].apply(lambda x: x[10:] if 'RT @xxxxx ' == x[:10] else x) 211 | # Drop duplicates 212 | health_dataset = health_dataset.drop_duplicates() 213 | statements = [v[0] for v in health_dataset.values] 214 | lblmap = {1: 0, -1: 1} 215 | labels = [lblmap[v[1]] for v in health_dataset.values] 216 | rumours.extend([s for s,l in zip(statements, labels) if l == 1]) 217 | non_rumours.extend([s for s,l in zip(statements, labels) if 
l == 0]) 218 | d_ids.extend([twitter_domain_map[domain]] * len(labels)) 219 | 220 | 221 | self.dataset = pd.DataFrame(rumours + non_rumours, columns=['statement']) 222 | self.dataset['label'] = [1] * len(rumours) + [0] * len(non_rumours) 223 | self.dataset['statement'] = self.dataset['statement'].str.normalize('NFKD') 224 | self.dataset['domain'] = d_ids 225 | 226 | self.tokenizer = tokenizer 227 | 228 | def set_domain_id(self, domain_id): 229 | """ 230 | Overrides the domain ID for all data 231 | :param domain_id: 232 | :return: 233 | """ 234 | self.dataset['domain'] = domain_id 235 | 236 | def __len__(self): 237 | return self.dataset.shape[0] 238 | 239 | def __getitem__(self, item) -> Tuple: 240 | row = self.dataset.values[item] 241 | input_ids, mask = text_to_batch_cnn([row[0]], self.tokenizer) 242 | label = row[1] 243 | domain = row[2] 244 | return input_ids, mask, label, domain, item 245 | 246 | -------------------------------------------------------------------------------- /emnlp_final_experiments/claim-detection/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/copenlu/xformer-multi-source-domain-adaptation/be2d1a132298131df82fe40dd4f6c08dec8b3404/emnlp_final_experiments/claim-detection/.DS_Store -------------------------------------------------------------------------------- /emnlp_final_experiments/claim-detection/train_basic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | from collections import defaultdict 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data import Subset 17 | from torch.utils.data import random_split 18 | from torch.optim 
import Adam 19 | from tqdm import tqdm 20 | from transformers import AdamW 21 | from transformers import DistilBertConfig 22 | from transformers import DistilBertTokenizer 23 | from transformers import DistilBertModel 24 | from transformers import BertConfig 25 | from transformers import BertTokenizer 26 | from transformers import BertModel 27 | from transformers import get_linear_schedule_with_warmup 28 | 29 | from datareader import MultiDomainTwitterDataset 30 | from datareader import collate_batch_transformer 31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1 32 | from metrics import ClassificationEvaluator 33 | 34 | from metrics import plot_label_distribution 35 | from model import * 36 | 37 | 38 | def train( 39 | model: torch.nn.Module, 40 | train_dls: List[DataLoader], 41 | optimizer: torch.optim.Optimizer, 42 | scheduler: LambdaLR, 43 | validation_evaluator: MultiDatasetClassificationEvaluator, 44 | n_epochs: int, 45 | device: AnyStr, 46 | log_interval: int = 1, 47 | patience: int = 10, 48 | model_dir: str = "wandb_local", 49 | gradient_accumulation: int = 1, 50 | domain_name: str = '' 51 | ): 52 | #best_loss = float('inf') 53 | best_f1 = 0.0 54 | patience_counter = 0 55 | 56 | epoch_counter = 0 57 | total = sum(len(dl) for dl in train_dls) 58 | 59 | # Main loop 60 | while epoch_counter < n_epochs: 61 | dl_iters = [iter(dl) for dl in train_dls] 62 | dl_idx = list(range(len(dl_iters))) 63 | finished = [0] * len(dl_iters) 64 | i = 0 65 | with tqdm(total=total, desc="Training") as pbar: 66 | while sum(finished) < len(dl_iters): 67 | random.shuffle(dl_idx) 68 | for d in dl_idx: 69 | domain_dl = dl_iters[d] 70 | batches = [] 71 | try: 72 | for j in range(gradient_accumulation): 73 | batches.append(next(domain_dl)) 74 | except StopIteration: 75 | finished[d] = 1 76 | if len(batches) == 0: 77 | continue 78 | optimizer.zero_grad() 79 | for batch in batches: 80 | model.train() 81 | batch = tuple(t.to(device) for t in batch) 82 | input_ids = batch[0] 
83 | masks = batch[1] 84 | labels = batch[2] 85 | # Testing with random domains to see if any effect 86 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 87 | domains = batch[3] 88 | 89 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels) 90 | loss = loss / gradient_accumulation 91 | 92 | if i % log_interval == 0: 93 | wandb.log({ 94 | "Loss": loss.item() 95 | }) 96 | 97 | loss.backward() 98 | i += 1 99 | pbar.update(1) 100 | 101 | optimizer.step() 102 | if scheduler is not None: 103 | scheduler.step() 104 | 105 | gc.collect() 106 | 107 | # Inline evaluation 108 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 109 | print(f"Validation F1: {F1}") 110 | 111 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 112 | 113 | # Saving the best model and early stopping 114 | #if val_loss < best_loss: 115 | if F1 > best_f1: 116 | best_model = model.state_dict() 117 | #best_loss = val_loss 118 | best_f1 = F1 119 | #wandb.run.summary['best_validation_loss'] = best_loss 120 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 121 | patience_counter = 0 122 | # Log to wandb 123 | wandb.log({ 124 | 'Validation accuracy': acc, 125 | 'Validation Precision': P, 126 | 'Validation Recall': R, 127 | 'Validation F1': F1, 128 | 'Validation loss': val_loss}) 129 | else: 130 | patience_counter += 1 131 | # Stop training once we have lost patience 132 | if patience_counter == patience: 133 | break 134 | 135 | gc.collect() 136 | epoch_counter += 1 137 | 138 | 139 | if __name__ == "__main__": 140 | # Define arguments 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 143 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 144 | parser.add_argument("--n_gpu", 
help="The number of GPUs to use", type=int, default=0) 145 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 146 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200) 147 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 148 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None) 149 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 150 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 151 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 152 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 153 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 154 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 155 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768) 156 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8) 157 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5) 158 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 159 | parser.add_argument("--lambd", help="l2 reg", type=float, default=10e-3) 160 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 161 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 162 | parser.add_argument("--full_bert", help="Specify to use full bert model", action="store_true") 163 | 164 | args = parser.parse_args() 165 | 166 | # Set all the seeds 167 | seed = args.seed 168 | random.seed(seed) 169 | 
np.random.seed(seed) 170 | torch.manual_seed(seed) 171 | torch.cuda.manual_seed_all(seed) 172 | torch.backends.cudnn.deterministic = True 173 | torch.backends.cudnn.benchmark = False 174 | 175 | # See if CUDA available 176 | device = torch.device("cpu") 177 | if args.n_gpu > 0 and torch.cuda.is_available(): 178 | print("Training on GPU") 179 | device = torch.device("cuda:0") 180 | 181 | # model configuration 182 | batch_size = args.batch_size 183 | lr = args.lr 184 | weight_decay = args.weight_decay 185 | n_epochs = args.n_epochs 186 | if args.full_bert: 187 | bert_model = 'bert-base-uncased' 188 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2) 189 | tokenizer = BertTokenizer.from_pretrained(bert_model) 190 | else: 191 | bert_model = 'distilbert-base-uncased' 192 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2) 193 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model) 194 | 195 | # wandb initialization 196 | wandb.init( 197 | project="domain-adaptation-twitter-emnlp", 198 | name=args.run_name, 199 | config={ 200 | "epochs": n_epochs, 201 | "learning_rate": lr, 202 | "warmup": args.warmup_steps, 203 | "weight_decay": weight_decay, 204 | "batch_size": batch_size, 205 | "train_split_percentage": args.train_pct, 206 | "bert_model": bert_model, 207 | "seed": seed, 208 | "pretrained_model": args.pretrained_model, 209 | "tags": ",".join(args.tags) 210 | } 211 | ) 212 | # Create save directory for model 213 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 214 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 215 | 216 | # Create the dataset 217 | all_dsets = [MultiDomainTwitterDataset( 218 | args.dataset_loc, 219 | [domain], 220 | tokenizer 221 | ) for domain in args.domains[:-1]] 222 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 223 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 224 | 225 | accs = [] 226 | 
Ps = [] 227 | Rs = [] 228 | F1s = [] 229 | # Store labels and logits for individual splits for micro F1 230 | labels_all = [] 231 | logits_all = [] 232 | 233 | for i in range(len(all_dsets)): 234 | domain = args.domains[i] 235 | test_dset = all_dsets[i] 236 | # Override the domain IDs 237 | k = 0 238 | for j in range(len(all_dsets)): 239 | if j != i: 240 | all_dsets[j].set_domain_id(k) 241 | k += 1 242 | test_dset.set_domain_id(k) 243 | 244 | # Split the data 245 | if args.indices_dir is None: 246 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 247 | for j in range(len(all_dsets)) if j != i] 248 | # Save the indices 249 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/train_idx_{domain}.txt', 'wt') as f, \ 250 | open(f'{args.model_dir}/{Path(wandb.run.dir).name}/val_idx_{domain}.txt', 'wt') as g: 251 | for j, subset in enumerate(subsets): 252 | for idx in subset[0].indices: 253 | f.write(f'{j},{idx}\n') 254 | for idx in subset[1].indices: 255 | g.write(f'{j},{idx}\n') 256 | else: 257 | # load the indices 258 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 259 | subset_indices = defaultdict(lambda: [[], []]) 260 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 261 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 262 | for l in f: 263 | vals = l.strip().split(',') 264 | subset_indices[int(vals[0])][0].append(int(vals[1])) 265 | for l in g: 266 | vals = l.strip().split(',') 267 | subset_indices[int(vals[0])][1].append(int(vals[1])) 268 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] 269 | for d in subset_indices] 270 | 271 | train_dls = [DataLoader( 272 | subset[0], 273 | batch_size=batch_size, 274 | shuffle=True, 275 | collate_fn=collate_batch_transformer 276 | ) for subset in subsets] 277 | 278 | val_ds = [subset[1] for subset in subsets] 279 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 280 | 
281 | # Create the model 282 | if args.full_bert: 283 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 284 | else: 285 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 286 | if args.pretrained_model is not None: 287 | weights = {k: v for k, v in torch.load(args.pretrained_model).items() if "classifier" not in k} 288 | model_dict = bert.state_dict() 289 | model_dict.update(weights) 290 | bert.load_state_dict(model_dict) 291 | model = VanillaBert(bert).to(device) 292 | 293 | # Create the optimizer 294 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 295 | optimizer_grouped_parameters = [ 296 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 297 | 'weight_decay': weight_decay}, 298 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 299 | ] 300 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3) 301 | # scheduler = None 302 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 303 | scheduler = get_linear_schedule_with_warmup( 304 | optimizer, 305 | args.warmup_steps, 306 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 307 | ) 308 | 309 | # Train 310 | train( 311 | model, 312 | train_dls, 313 | optimizer, 314 | scheduler, 315 | validation_evaluator, 316 | n_epochs, 317 | device, 318 | args.log_interval, 319 | model_dir=args.model_dir, 320 | gradient_accumulation=args.gradient_accumulation, 321 | domain_name=domain 322 | ) 323 | 324 | # Load the best weights 325 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 326 | 327 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 328 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 329 | model, 330 | plot_callbacks=[plot_label_distribution], 331 | return_labels_logits=True, 332 | 
return_votes=True 333 | ) 334 | print(f"{domain} F1: {F1}") 335 | print(f"{domain} Accuracy: {acc}") 336 | print() 337 | 338 | wandb.run.summary[f"{domain}-P"] = P 339 | wandb.run.summary[f"{domain}-R"] = R 340 | wandb.run.summary[f"{domain}-F1"] = F1 341 | wandb.run.summary[f"{domain}-Acc"] = acc 342 | # macro and micro F1 are only with respect to the 5 main splits 343 | if i < len(all_dsets): 344 | Ps.append(P) 345 | Rs.append(R) 346 | F1s.append(F1) 347 | accs.append(acc) 348 | labels_all.extend(labels) 349 | logits_all.extend(logits) 350 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 351 | for p, l in zip(np.argmax(logits, axis=-1), labels): 352 | f.write(f'{domain}\t{p}\t{l}\n') 353 | 354 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 355 | # Add to wandb 356 | wandb.run.summary[f'test-loss'] = loss 357 | wandb.run.summary[f'test-micro-acc'] = acc 358 | wandb.run.summary[f'test-micro-P'] = P 359 | wandb.run.summary[f'test-micro-R'] = R 360 | wandb.run.summary[f'test-micro-F1'] = F1 361 | 362 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 363 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 364 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 365 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 366 | 367 | #wandb.log({f"label-distribution-test-{i}": plots[0]}) 368 | -------------------------------------------------------------------------------- /emnlp_final_experiments/claim-detection/train_basic_domain_adversarial.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | from collections import defaultdict 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data import DataLoader 16 | from 
torch.utils.data import Subset 17 | from torch.utils.data import random_split 18 | from torch.optim import Adam 19 | from tqdm import tqdm 20 | from transformers import AdamW 21 | from transformers import DistilBertConfig 22 | from transformers import DistilBertTokenizer 23 | from transformers import DistilBertModel 24 | from transformers import BertConfig 25 | from transformers import BertTokenizer 26 | from transformers import BertModel 27 | from transformers import get_linear_schedule_with_warmup 28 | 29 | from datareader import MultiDomainTwitterDataset 30 | from datareader import collate_batch_transformer 31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1 32 | from metrics import ClassificationEvaluator 33 | 34 | from metrics import plot_label_distribution 35 | from model import * 36 | 37 | 38 | def train( 39 | model: torch.nn.Module, 40 | train_dls: List[DataLoader], 41 | optimizer: torch.optim.Optimizer, 42 | scheduler: LambdaLR, 43 | validation_evaluator: MultiDatasetClassificationEvaluator, 44 | n_epochs: int, 45 | device: AnyStr, 46 | log_interval: int = 1, 47 | patience: int = 10, 48 | model_dir: str = "wandb_local", 49 | gradient_accumulation: int = 1, 50 | domain_name: str = '' 51 | ): 52 | #best_loss = float('inf') 53 | best_f1 = 0.0 54 | patience_counter = 0 55 | 56 | epoch_counter = 0 57 | total = sum(len(dl) for dl in train_dls) 58 | 59 | # Main loop 60 | while epoch_counter < n_epochs: 61 | dl_iters = [iter(dl) for dl in train_dls] 62 | dl_idx = list(range(len(dl_iters))) 63 | finished = [0] * len(dl_iters) 64 | i = 0 65 | with tqdm(total=total, desc="Training") as pbar: 66 | while sum(finished) < len(dl_iters): 67 | random.shuffle(dl_idx) 68 | for d in dl_idx: 69 | domain_dl = dl_iters[d] 70 | batches = [] 71 | try: 72 | for j in range(gradient_accumulation): 73 | batches.append(next(domain_dl)) 74 | except StopIteration: 75 | finished[d] = 1 76 | if len(batches) == 0: 77 | continue 78 | optimizer.zero_grad() 79 | for batch in 
batches: 80 | model.train() 81 | batch = tuple(t.to(device) for t in batch) 82 | input_ids = batch[0] 83 | masks = batch[1] 84 | labels = batch[2] 85 | # Null the labels if it's the test data 86 | if d == len(train_dls) - 1: 87 | labels = None 88 | # Testing with random domains to see if any effect 89 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 90 | domains = batch[3] 91 | 92 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels) 93 | loss = loss.mean() / gradient_accumulation 94 | 95 | if i % log_interval == 0: 96 | wandb.log({ 97 | "Loss": loss.item() 98 | }) 99 | 100 | loss.backward() 101 | i += 1 102 | pbar.update(1) 103 | 104 | optimizer.step() 105 | if scheduler is not None: 106 | scheduler.step() 107 | 108 | gc.collect() 109 | 110 | # Inline evaluation 111 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 112 | print(f"Validation F1: {F1}") 113 | 114 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 115 | 116 | # Saving the best model and early stopping 117 | #if val_loss < best_loss: 118 | if F1 > best_f1: 119 | best_model = model.state_dict() 120 | #best_loss = val_loss 121 | best_f1 = F1 122 | #wandb.run.summary['best_validation_loss'] = best_loss 123 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 124 | patience_counter = 0 125 | # Log to wandb 126 | wandb.log({ 127 | 'Validation accuracy': acc, 128 | 'Validation Precision': P, 129 | 'Validation Recall': R, 130 | 'Validation F1': F1, 131 | 'Validation loss': val_loss}) 132 | else: 133 | patience_counter += 1 134 | # Stop training once we have lost patience 135 | if patience_counter == patience: 136 | break 137 | 138 | gc.collect() 139 | epoch_counter += 1 140 | 141 | 142 | if __name__ == "__main__": 143 | # Define arguments 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--dataset_loc", help="Root
directory of the dataset", required=True, type=str) 146 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 147 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 148 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 149 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200) 150 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 151 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None) 152 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 153 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 154 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 155 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 156 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 157 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 158 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768) 159 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8) 160 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5) 161 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 162 | parser.add_argument("--lambd", help="Weight of the domain adversarial loss", type=float, default=10e-3) 163 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 164 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int) 165
| parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 166 | parser.add_argument("--full_bert", help="Specify to use full bert model", action="store_true") 167 | 168 | 169 | args = parser.parse_args() 170 | 171 | # Set all the seeds 172 | seed = args.seed 173 | random.seed(seed) 174 | np.random.seed(seed) 175 | torch.manual_seed(seed) 176 | torch.cuda.manual_seed_all(seed) 177 | torch.backends.cudnn.deterministic = True 178 | torch.backends.cudnn.benchmark = False 179 | 180 | # See if CUDA available 181 | device = torch.device("cpu") 182 | if args.n_gpu > 0 and torch.cuda.is_available(): 183 | print("Training on GPU") 184 | device = torch.device("cuda:0") 185 | 186 | # model configuration 187 | batch_size = args.batch_size 188 | lr = args.lr 189 | weight_decay = args.weight_decay 190 | n_epochs = args.n_epochs 191 | if args.full_bert: 192 | bert_model = 'bert-base-uncased' 193 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 194 | tokenizer = BertTokenizer.from_pretrained(bert_model) 195 | else: 196 | bert_model = 'distilbert-base-uncased' 197 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 198 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model) 199 | 200 | # wandb initialization 201 | wandb.init( 202 | project="domain-adaptation-twitter-emnlp", 203 | name=args.run_name, 204 | config={ 205 | "epochs": n_epochs, 206 | "learning_rate": lr, 207 | "warmup": args.warmup_steps, 208 | "weight_decay": weight_decay, 209 | "batch_size": batch_size, 210 | "train_split_percentage": args.train_pct, 211 | "bert_model": bert_model, 212 | "seed": seed, 213 | "pretrained_model": args.pretrained_model, 214 | "tags": ",".join(args.tags) 215 | } 216 | ) 217 | # Create save directory for model 218 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 219 | 
os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 220 | 221 | # Create the dataset 222 | all_dsets = [MultiDomainTwitterDataset( 223 | args.dataset_loc, 224 | [domain], 225 | tokenizer 226 | ) for domain in args.domains[:-1]] 227 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 228 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 229 | 230 | accs = [] 231 | Ps = [] 232 | Rs = [] 233 | F1s = [] 234 | # Store labels and logits for individual splits for micro F1 235 | labels_all = [] 236 | logits_all = [] 237 | 238 | for i in range(len(all_dsets)): 239 | domain = args.domains[i] 240 | test_dset = all_dsets[i] 241 | # Override the domain IDs 242 | k = 0 243 | for j in range(len(all_dsets)): 244 | if j != i: 245 | all_dsets[j].set_domain_id(k) 246 | k += 1 247 | test_dset.set_domain_id(k) 248 | 249 | # Split the data 250 | if args.indices_dir is None: 251 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 252 | for j in range(len(all_dsets)) if j != i] 253 | else: 254 | # load the indices 255 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 256 | subset_indices = defaultdict(lambda: [[], []]) 257 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 258 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 259 | for l in f: 260 | vals = l.strip().split(',') 261 | subset_indices[int(vals[0])][0].append(int(vals[1])) 262 | for l in g: 263 | vals = l.strip().split(',') 264 | subset_indices[int(vals[0])][1].append(int(vals[1])) 265 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] 266 | for d in subset_indices] 267 | train_dls = [DataLoader( 268 | subset[0], 269 | batch_size=batch_size, 270 | shuffle=True, 271 | collate_fn=collate_batch_transformer 272 | ) for subset in subsets] 273 | # Add test data for domain adversarial training 274 | train_dls += [DataLoader( 275 | 
test_dset, 276 | batch_size=batch_size, 277 | shuffle=True, 278 | collate_fn=collate_batch_transformer 279 | )] 280 | 281 | val_ds = [subset[1] for subset in subsets] 282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 283 | 284 | # Create the model 285 | if args.full_bert: 286 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 287 | else: 288 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 289 | if args.pretrained_model is not None: 290 | weights = {k: v for k, v in torch.load(args.pretrained_model).items() if "classifier" not in k} 291 | model_dict = bert.state_dict() 292 | model_dict.update(weights) 293 | bert.load_state_dict(model_dict) 294 | model = torch.nn.DataParallel(DomainAdversarialBert( 295 | bert, 296 | n_domains=len(train_dls), supervision_layer=args.supervision_layer 297 | )).to(device) 298 | 299 | # Create the optimizer 300 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 301 | optimizer_grouped_parameters = [ 302 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 303 | 'weight_decay': weight_decay}, 304 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 305 | ] 306 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3) 307 | # scheduler = None 308 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 309 | scheduler = get_linear_schedule_with_warmup( 310 | optimizer, 311 | args.warmup_steps, 312 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 313 | ) 314 | 315 | # Train 316 | train( 317 | model, 318 | train_dls, 319 | optimizer, 320 | scheduler, 321 | validation_evaluator, 322 | n_epochs, 323 | device, 324 | args.log_interval, 325 | model_dir=args.model_dir, 326 | gradient_accumulation=args.gradient_accumulation, 327 | domain_name=domain 328 | ) 329 | 330 | # Load the best weights 331 | 
model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 332 | 333 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 334 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 335 | model, 336 | plot_callbacks=[plot_label_distribution], 337 | return_labels_logits=True, 338 | return_votes=True 339 | ) 340 | print(f"{domain} F1: {F1}") 341 | print(f"{domain} Accuracy: {acc}") 342 | print() 343 | 344 | wandb.run.summary[f"{domain}-P"] = P 345 | wandb.run.summary[f"{domain}-R"] = R 346 | wandb.run.summary[f"{domain}-F1"] = F1 347 | wandb.run.summary[f"{domain}-Acc"] = acc 348 | # macro and micro F1 are only with respect to the 5 main splits 349 | if i < len(all_dsets): 350 | Ps.append(P) 351 | Rs.append(R) 352 | F1s.append(F1) 353 | accs.append(acc) 354 | labels_all.extend(labels) 355 | logits_all.extend(logits) 356 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 357 | for p, l in zip(np.argmax(logits, axis=-1), labels): 358 | f.write(f'{domain}\t{p}\t{l}\n') 359 | 360 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 361 | # Add to wandb 362 | wandb.run.summary[f'test-loss'] = loss 363 | wandb.run.summary[f'test-micro-acc'] = acc 364 | wandb.run.summary[f'test-micro-P'] = P 365 | wandb.run.summary[f'test-micro-R'] = R 366 | wandb.run.summary[f'test-micro-F1'] = F1 367 | 368 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 369 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 370 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 371 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 372 | 373 | #wandb.log({f"label-distribution-test-{i}": plots[0]}) 374 | -------------------------------------------------------------------------------- /emnlp_final_experiments/claim-detection/train_multi_view.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 
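`train_multi_view.py` selects one of three expert-mixing strategies: uniform averaging (`--ensemble_basic`), learned averaging (`--ensemble_avg_learned`), or the default probability-based mixing. A toy illustration of the first two mixing rules, with made-up expert outputs and alpha values — not the repo's `MultiViewTransformerNetwork*` classes:

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

# Toy per-expert class-probability vectors for one example.
expert_probs = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8]]

# Uniform averaging: every domain expert gets the same weight.
avg = [sum(col) / len(expert_probs) for col in zip(*expert_probs)]

# Learned averaging: weights come from trained alpha parameters pushed
# through a softmax; the raw alphas here are hypothetical.
alpha = softmax([2.0, 0.0, 0.0])
weighted = [sum(a * p[c] for a, p in zip(alpha, expert_probs)) for c in range(2)]
```

With the first expert up-weighted, the learned mixture shifts toward that expert's confident prediction, while uniform averaging lets all experts pull equally.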
3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from collections import defaultdict 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import wandb 15 | from torch.optim.lr_scheduler import LambdaLR 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data import Subset 18 | from torch.utils.data import random_split 19 | from torch.optim import Adam 20 | from tqdm import tqdm 21 | from transformers import AdamW 22 | from transformers import DistilBertConfig 23 | from transformers import DistilBertTokenizer 24 | from transformers import DistilBertForSequenceClassification 25 | from transformers import get_linear_schedule_with_warmup 26 | 27 | from datareader import MultiDomainTwitterDataset 28 | from datareader import collate_batch_transformer 29 | from metrics import MultiDatasetClassificationEvaluator 30 | from metrics import ClassificationEvaluator 31 | from metrics import acc_f1 32 | 33 | from metrics import plot_label_distribution 34 | from model import MultiTransformerClassifier 35 | from model import VanillaBert 36 | from model import * 37 | 38 | 39 | def train( 40 | model: torch.nn.Module, 41 | train_dls: List[DataLoader], 42 | optimizer: torch.optim.Optimizer, 43 | scheduler: LambdaLR, 44 | validation_evaluator: MultiDatasetClassificationEvaluator, 45 | n_epochs: int, 46 | device: AnyStr, 47 | log_interval: int = 1, 48 | patience: int = 10, 49 | model_dir: str = "wandb_local", 50 | gradient_accumulation: int = 1, 51 | domain_name: str = '' 52 | ): 53 | #best_loss = float('inf') 54 | best_f1 = 0.0 55 | patience_counter = 0 56 | 57 | epoch_counter = 0 58 | total = sum(len(dl) for dl in train_dls) 59 | 60 | # Main loop 61 | while epoch_counter < n_epochs: 62 | dl_iters = [iter(dl) for dl in train_dls] 63 | dl_idx = list(range(len(dl_iters))) 64 | finished = [0] * len(dl_iters) 65 | i = 0 66 | with tqdm(total=total, desc="Training") 
as pbar: 67 | while sum(finished) < len(dl_iters): 68 | random.shuffle(dl_idx) 69 | for d in dl_idx: 70 | domain_dl = dl_iters[d] 71 | batches = [] 72 | try: 73 | for j in range(gradient_accumulation): 74 | batches.append(next(domain_dl)) 75 | except StopIteration: 76 | finished[d] = 1 77 | if len(batches) == 0: 78 | continue 79 | optimizer.zero_grad() 80 | for batch in batches: 81 | model.train() 82 | batch = tuple(t.to(device) for t in batch) 83 | input_ids = batch[0] 84 | masks = batch[1] 85 | labels = batch[2] 86 | # Testing with random domains to see if any effect 87 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 88 | domains = batch[3] 89 | 90 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True) 91 | loss = loss.mean() / gradient_accumulation 92 | if i % log_interval == 0: 93 | # wandb.log({ 94 | # "Loss": loss.item(), 95 | # "alpha0": alpha[:,0].cpu(), 96 | # "alpha1": alpha[:, 1].cpu(), 97 | # "alpha2": alpha[:, 2].cpu(), 98 | # "alpha_shared": alpha[:, 3].cpu() 99 | # }) 100 | wandb.log({ 101 | "Loss": loss.item() 102 | }) 103 | 104 | loss.backward() 105 | i += 1 106 | pbar.update(1) 107 | 108 | optimizer.step() 109 | if scheduler is not None: 110 | scheduler.step() 111 | 112 | gc.collect() 113 | 114 | # Inline evaluation 115 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 116 | print(f"Validation f1: {F1}") 117 | 118 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 119 | 120 | # Saving the best model and early stopping 121 | #if val_loss < best_loss: 122 | if F1 > best_f1: 123 | best_model = model.state_dict() 124 | #best_loss = val_loss 125 | best_f1 = F1 126 | #wandb.run.summary['best_validation_loss'] = best_loss 127 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 128 | patience_counter = 0 129 | # Log to wandb 130 | wandb.log({ 131 | 
'Validation accuracy': acc, 132 | 'Validation Precision': P, 133 | 'Validation Recall': R, 134 | 'Validation F1': F1, 135 | 'Validation loss': val_loss}) 136 | else: 137 | patience_counter += 1 138 | # Stop training once we have lost patience 139 | if patience_counter == patience: 140 | break 141 | 142 | gc.collect() 143 | epoch_counter += 1 144 | 145 | 146 | if __name__ == "__main__": 147 | # Define arguments 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 150 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 151 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 152 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 153 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200) 154 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 155 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None) 156 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None) 157 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 158 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 159 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 160 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 161 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 162 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16) 163 | parser.add_argument("--lr", 
help="Learning rate", type=float, default=1e-5) 164 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 165 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int) 166 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, type=int) 167 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int) 168 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int) 169 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 170 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 171 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 172 | parser.add_argument("--ensemble_basic", help="Use averaging for the ensembling method", action="store_true") 173 | parser.add_argument("--ensemble_avg_learned", help="Use learned averaging for the ensembling method", action="store_true") 174 | 175 | 176 | args = parser.parse_args() 177 | 178 | # Set all the seeds 179 | seed = args.seed 180 | random.seed(seed) 181 | np.random.seed(seed) 182 | torch.manual_seed(seed) 183 | torch.cuda.manual_seed_all(seed) 184 | torch.backends.cudnn.deterministic = True 185 | torch.backends.cudnn.benchmark = False 186 | 187 | # See if CUDA available 188 | device = torch.device("cpu") 189 | if args.n_gpu > 0 and torch.cuda.is_available(): 190 | print("Training on GPU") 191 | device = torch.device("cuda:0") 192 | 193 | # model configuration 194 | bert_model = 'distilbert-base-uncased' 195 | batch_size = args.batch_size 196 | lr = args.lr 197 | weight_decay = args.weight_decay 198 | n_epochs = args.n_epochs 199 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 200 | 201 | 202 | # wandb initialization 203 | wandb.init( 204 | 
project="domain-adaptation-twitter-emnlp", 205 | name=args.run_name, 206 | config={ 207 | "epochs": n_epochs, 208 | "learning_rate": lr, 209 | "warmup": args.warmup_steps, 210 | "weight_decay": weight_decay, 211 | "batch_size": batch_size, 212 | "train_split_percentage": args.train_pct, 213 | "bert_model": bert_model, 214 | "seed": seed, 215 | "tags": ",".join(args.tags) 216 | } 217 | ) 218 | #wandb.watch(model) 219 | # Create save directory for model 220 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 221 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 222 | 223 | # Create the dataset 224 | all_dsets = [MultiDomainTwitterDataset( 225 | args.dataset_loc, 226 | [domain], 227 | DistilBertTokenizer.from_pretrained(bert_model) 228 | ) for domain in args.domains[:-1]] 229 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 230 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 231 | 232 | accs = [] 233 | Ps = [] 234 | Rs = [] 235 | F1s = [] 236 | # Store labels and logits for individual splits for micro F1 237 | labels_all = [] 238 | logits_all = [] 239 | 240 | for i in range(len(all_dsets)): 241 | domain = args.domains[i] 242 | test_dset = all_dsets[i] 243 | # Override the domain IDs 244 | k = 0 245 | for j in range(len(all_dsets)): 246 | if j != i: 247 | all_dsets[j].set_domain_id(k) 248 | k += 1 249 | test_dset.set_domain_id(k) 250 | # For test 251 | #all_dsets = [all_dsets[0], all_dsets[2]] 252 | 253 | # Split the data 254 | if args.indices_dir is None: 255 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 256 | for j in range(len(all_dsets)) if j != i] 257 | else: 258 | # load the indices 259 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 260 | subset_indices = defaultdict(lambda: [[], []]) 261 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 262 | open(f'{args.indices_dir}/val_idx_{domain}.txt')
as g: 263 | for l in f: 264 | vals = l.strip().split(',') 265 | subset_indices[int(vals[0])][0].append(int(vals[1])) 266 | for l in g: 267 | vals = l.strip().split(',') 268 | subset_indices[int(vals[0])][1].append(int(vals[1])) 269 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in 270 | subset_indices] 271 | 272 | train_dls = [DataLoader( 273 | subset[0], 274 | batch_size=batch_size, 275 | shuffle=True, 276 | collate_fn=collate_batch_transformer 277 | ) for subset in subsets] 278 | 279 | val_ds = [subset[1] for subset in subsets] 280 | # for vds in val_ds: 281 | # print(vds.indices) 282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 283 | 284 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 285 | # Create the model 286 | init_weights = None 287 | if args.pretrained_bert is not None: 288 | init_weights = {k: v for k, v in torch.load(args.pretrained_bert).items() if "classifier" not in k} 289 | model_dict = bert.state_dict() 290 | model_dict.update(init_weights) 291 | bert.load_state_dict(model_dict) 292 | shared_bert = VanillaBert(bert).to(device) 293 | 294 | multi_xformer = MultiDistilBertClassifier( 295 | bert_model, 296 | bert_config, 297 | n_domains=len(train_dls), 298 | init_weights=init_weights 299 | ).to(device) 300 | 301 | if args.ensemble_basic: 302 | model_class = MultiViewTransformerNetworkAveraging 303 | elif args.ensemble_avg_learned: 304 | model_class = MultiViewTransformerNetworkLearnedAveraging 305 | else: 306 | model_class = MultiViewTransformerNetworkProbabilities 307 | 308 | model = torch.nn.DataParallel(model_class( 309 | multi_xformer, 310 | shared_bert 311 | )).to(device) 312 | 313 | # Create the optimizer 314 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 315 | optimizer_grouped_parameters = [ 316 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in 
no_decay)], 317 | 'weight_decay': weight_decay}, 318 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 319 | ] 320 | 321 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 322 | scheduler = get_linear_schedule_with_warmup( 323 | optimizer, 324 | args.warmup_steps, 325 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 326 | ) 327 | 328 | # Train 329 | train( 330 | model, 331 | train_dls, 332 | optimizer, 333 | scheduler, 334 | validation_evaluator, 335 | n_epochs, 336 | device, 337 | args.log_interval, 338 | model_dir=args.model_dir, 339 | gradient_accumulation=args.gradient_accumulation, 340 | domain_name=domain 341 | ) 342 | # Load the best weights 343 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 344 | if args.ensemble_avg_learned: 345 | weights = model.module.alpha_params.cpu().detach().numpy() 346 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f: 347 | f.write(str(weights)) 348 | 349 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 350 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 351 | model, 352 | plot_callbacks=[plot_label_distribution], 353 | return_labels_logits=True, 354 | return_votes=True 355 | ) 356 | print(f"{domain} F1: {F1}") 357 | print(f"{domain} Accuracy: {acc}") 358 | print() 359 | 360 | wandb.run.summary[f"{domain}-P"] = P 361 | wandb.run.summary[f"{domain}-R"] = R 362 | wandb.run.summary[f"{domain}-F1"] = F1 363 | wandb.run.summary[f"{domain}-Acc"] = acc 364 | # macro and micro F1 are only with respect to the 5 main splits 365 | if i < len(all_dsets): 366 | Ps.append(P) 367 | Rs.append(R) 368 | F1s.append(F1) 369 | accs.append(acc) 370 | labels_all.extend(labels) 371 | logits_all.extend(logits) 372 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 373 | for p, l in 
zip(np.argmax(logits, axis=-1), labels): 374 | f.write(f'{domain}\t{p}\t{l}\n') 375 | 376 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 377 | # Add to wandb 378 | wandb.run.summary[f'test-loss'] = loss 379 | wandb.run.summary[f'test-micro-acc'] = acc 380 | wandb.run.summary[f'test-micro-P'] = P 381 | wandb.run.summary[f'test-micro-R'] = R 382 | wandb.run.summary[f'test-micro-F1'] = F1 383 | 384 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 385 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 386 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 387 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 388 | 389 | # wandb.log({f"label-distribution-test-{i}": plots[0]}) 390 | -------------------------------------------------------------------------------- /emnlp_final_experiments/claim-detection/train_multi_view_domain_adversarial.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from collections import defaultdict 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import wandb 15 | from torch.optim.lr_scheduler import LambdaLR 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data import Subset 18 | from torch.utils.data import random_split 19 | from torch.optim import Adam 20 | from tqdm import tqdm 21 | from transformers import AdamW 22 | from transformers import DistilBertConfig 23 | from transformers import DistilBertTokenizer 24 | from transformers import DistilBertForSequenceClassification 25 | from transformers import get_linear_schedule_with_warmup 26 | 27 | from datareader import MultiDomainTwitterDataset 28 | from datareader import collate_batch_transformer 29 | from metrics import MultiDatasetClassificationEvaluator 30 | from metrics import ClassificationEvaluator 31 | 
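Several of these scripts warm-start from a saved checkpoint by copying every entry of the pretrained `state_dict` except the task classifier head into a freshly initialized model (the `"classifier" not in k` filter above). The same dict-filtering idiom, with plain dicts standing in for the tensors returned by `torch.load`:

```python
# Stand-ins for state_dicts; keys mirror the HuggingFace naming scheme,
# the list values are placeholders for real weight tensors.
pretrained = {"distilbert.layer.0.weight": [1.0, 1.0], "classifier.weight": [9.0]}
model_dict = {"distilbert.layer.0.weight": [0.0, 0.0], "classifier.weight": [0.5]}

# Keep everything except the task-specific classifier head...
weights = {k: v for k, v in pretrained.items() if "classifier" not in k}
# ...then overwrite the matching entries of the fresh model's state dict,
# leaving the randomly initialized head untouched.
model_dict.update(weights)
```

The real code finishes with `bert.load_state_dict(model_dict)`, so the encoder starts from the pretrained weights while the classifier is trained from scratch.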
from metrics import acc_f1 32 | 33 | from metrics import plot_label_distribution 34 | from model import MultiTransformerClassifier 35 | from model import VanillaBert 36 | from model import * 37 | 38 | 39 | def train( 40 | model: torch.nn.Module, 41 | train_dls: List[DataLoader], 42 | optimizer: torch.optim.Optimizer, 43 | scheduler: LambdaLR, 44 | validation_evaluator: MultiDatasetClassificationEvaluator, 45 | n_epochs: int, 46 | device: AnyStr, 47 | log_interval: int = 1, 48 | patience: int = 10, 49 | model_dir: str = "wandb_local", 50 | gradient_accumulation: int = 1, 51 | domain_name: str = '' 52 | ): 53 | #best_loss = float('inf') 54 | best_f1 = 0.0 55 | patience_counter = 0 56 | 57 | epoch_counter = 0 58 | total = sum(len(dl) for dl in train_dls) 59 | 60 | # Main loop 61 | while epoch_counter < n_epochs: 62 | dl_iters = [iter(dl) for dl in train_dls] 63 | dl_idx = list(range(len(dl_iters))) 64 | finished = [0] * len(dl_iters) 65 | i = 0 66 | with tqdm(total=total, desc="Training") as pbar: 67 | while sum(finished) < len(dl_iters): 68 | random.shuffle(dl_idx) 69 | for d in dl_idx: 70 | domain_dl = dl_iters[d] 71 | batches = [] 72 | try: 73 | for j in range(gradient_accumulation): 74 | batches.append(next(domain_dl)) 75 | except StopIteration: 76 | finished[d] = 1 77 | if len(batches) == 0: 78 | continue 79 | optimizer.zero_grad() 80 | for batch in batches: 81 | model.train() 82 | batch = tuple(t.to(device) for t in batch) 83 | input_ids = batch[0] 84 | masks = batch[1] 85 | labels = batch[2] 86 | # Null the labels if its the test data 87 | if d == len(train_dls) - 1: 88 | labels = None 89 | # Testing with random domains to see if any effect 90 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 91 | domains = batch[3] 92 | 93 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True) 94 | loss = loss.mean() / gradient_accumulation 95 | if i % log_interval == 0: 96 | # wandb.log({ 
97 | # "Loss": loss.item(), 98 | # "alpha0": alpha[:,0].cpu(), 99 | # "alpha1": alpha[:, 1].cpu(), 100 | # "alpha2": alpha[:, 2].cpu(), 101 | # "alpha_shared": alpha[:, 3].cpu() 102 | # }) 103 | wandb.log({ 104 | "Loss": loss.item() 105 | }) 106 | 107 | loss.backward() 108 | i += 1 109 | pbar.update(1) 110 | 111 | optimizer.step() 112 | if scheduler is not None: 113 | scheduler.step() 114 | 115 | gc.collect() 116 | 117 | # Inline evaluation 118 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 119 | print(f"Validation F1: {F1}") 120 | 121 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 122 | 123 | # Saving the best model and early stopping 124 | #if val_loss < best_loss: 125 | if F1 > best_f1: 126 | best_model = model.state_dict() 127 | #best_loss = val_loss 128 | best_f1 = F1 129 | #wandb.run.summary['best_validation_loss'] = best_loss 130 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 131 | patience_counter = 0 132 | # Log to wandb 133 | wandb.log({ 134 | 'Validation accuracy': acc, 135 | 'Validation Precision': P, 136 | 'Validation Recall': R, 137 | 'Validation F1': F1, 138 | 'Validation loss': val_loss}) 139 | else: 140 | patience_counter += 1 141 | # Stop training once we have lost patience 142 | if patience_counter == patience: 143 | break 144 | 145 | gc.collect() 146 | epoch_counter += 1 147 | 148 | 149 | if __name__ == "__main__": 150 | # Define arguments 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 153 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 154 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 155 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 156 | 
parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200) 157 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 158 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None) 159 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None) 160 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 161 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 162 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 163 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 164 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 165 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16) 166 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-5) 167 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 168 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int) 169 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, type=int) 170 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int) 171 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int) 172 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 173 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 174 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int) 175 | 
parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 176 | 177 | args = parser.parse_args() 178 | 179 | # Set all the seeds 180 | seed = args.seed 181 | random.seed(seed) 182 | np.random.seed(seed) 183 | torch.manual_seed(seed) 184 | torch.cuda.manual_seed_all(seed) 185 | torch.backends.cudnn.deterministic = True 186 | torch.backends.cudnn.benchmark = False 187 | 188 | # See if CUDA available 189 | device = torch.device("cpu") 190 | if args.n_gpu > 0 and torch.cuda.is_available(): 191 | print("Training on GPU") 192 | device = torch.device("cuda:0") 193 | 194 | # model configuration 195 | bert_model = 'distilbert-base-uncased' 196 | batch_size = args.batch_size 197 | lr = args.lr 198 | weight_decay = args.weight_decay 199 | n_epochs = args.n_epochs 200 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 201 | 202 | 203 | # wandb initialization 204 | wandb.init( 205 | project="domain-adaptation-twitter-emnlp", 206 | name=args.run_name, 207 | config={ 208 | "epochs": n_epochs, 209 | "learning_rate": lr, 210 | "warmup": args.warmup_steps, 211 | "weight_decay": weight_decay, 212 | "batch_size": batch_size, 213 | "train_split_percentage": args.train_pct, 214 | "bert_model": bert_model, 215 | "seed": seed, 216 | "tags": ",".join(args.tags) 217 | } 218 | ) 219 | #wandb.watch(model) 220 | #Create save directory for model 221 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 222 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 223 | 224 | # Create the dataset 225 | all_dsets = [MultiDomainTwitterDataset( 226 | args.dataset_loc, 227 | [domain], 228 | DistilBertTokenizer.from_pretrained(bert_model) 229 | ) for domain in args.domains[:-1]] 230 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 231 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 232 | 233 | accs = [] 234 | Ps 
= [] 235 | Rs = [] 236 | F1s = [] 237 | # Store labels and logits for individual splits for micro F1 238 | labels_all = [] 239 | logits_all = [] 240 | 241 | for i in range(len(all_dsets)): 242 | domain = args.domains[i] 243 | test_dset = all_dsets[i] 244 | # Override the domain IDs 245 | k = 0 246 | for j in range(len(all_dsets)): 247 | if j != i: 248 | all_dsets[j].set_domain_id(k) 249 | k += 1 250 | test_dset.set_domain_id(k) 251 | # For test 252 | #all_dsets = [all_dsets[0], all_dsets[2]] 253 | 254 | # Split the data 255 | if args.indices_dir is None: 256 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 257 | for j in range(len(all_dsets)) if j != i] 258 | else: 259 | # load the indices 260 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 261 | subset_indices = defaultdict(lambda: [[], []]) 262 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 263 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 264 | for l in f: 265 | vals = l.strip().split(',') 266 | subset_indices[int(vals[0])][0].append(int(vals[1])) 267 | for l in g: 268 | vals = l.strip().split(',') 269 | subset_indices[int(vals[0])][1].append(int(vals[1])) 270 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in 271 | subset_indices] 272 | 273 | train_dls = [DataLoader( 274 | subset[0], 275 | batch_size=batch_size, 276 | shuffle=True, 277 | collate_fn=collate_batch_transformer 278 | ) for subset in subsets] 279 | # Add test data for domain adversarial training 280 | train_dls += [DataLoader( 281 | test_dset, 282 | batch_size=batch_size, 283 | shuffle=True, 284 | collate_fn=collate_batch_transformer 285 | )] 286 | 287 | val_ds = [subset[1] for subset in subsets] 288 | # for vds in val_ds: 289 | # print(vds.indices) 290 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 291 | 292 | # Create the model 293 | bert = 
DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 294 | init_weights = None 295 | if args.pretrained_bert is not None: 296 | init_weights = {k: v for k, v in torch.load(args.pretrained_bert).items() if "classifier" not in k} 297 | model_dict = bert.state_dict() 298 | model_dict.update(init_weights) 299 | bert.load_state_dict(model_dict) 300 | shared_bert = VanillaBert(bert).to(device) 301 | 302 | multi_xformer = MultiDistilBertClassifier( 303 | bert_model, 304 | bert_config, 305 | n_domains=len(train_dls) - 1, 306 | init_weights=init_weights 307 | ).to(device) 308 | 309 | model = torch.nn.DataParallel(MultiViewTransformerNetworkProbabilitiesAdversarial( 310 | multi_xformer, 311 | shared_bert, 312 | supervision_layer=args.supervision_layer 313 | )).to(device) 314 | 315 | # Create the optimizer 316 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 317 | optimizer_grouped_parameters = [ 318 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 319 | 'weight_decay': weight_decay}, 320 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 321 | ] 322 | 323 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 324 | scheduler = get_linear_schedule_with_warmup( 325 | optimizer, 326 | args.warmup_steps, 327 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 328 | ) 329 | 330 | # Train 331 | train( 332 | model, 333 | train_dls, 334 | optimizer, 335 | scheduler, 336 | validation_evaluator, 337 | n_epochs, 338 | device, 339 | args.log_interval, 340 | model_dir=args.model_dir, 341 | gradient_accumulation=args.gradient_accumulation, 342 | domain_name=domain 343 | ) 344 | # Load the best weights 345 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 346 | 347 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 348 | (loss, acc, P, R, F1), 
plots, (labels, logits), votes = evaluator.evaluate( 349 | model, 350 | plot_callbacks=[plot_label_distribution], 351 | return_labels_logits=True, 352 | return_votes=True 353 | ) 354 | print(f"{domain} F1: {F1}") 355 | print(f"{domain} Accuracy: {acc}") 356 | print() 357 | 358 | wandb.run.summary[f"{domain}-P"] = P 359 | wandb.run.summary[f"{domain}-R"] = R 360 | wandb.run.summary[f"{domain}-F1"] = F1 361 | wandb.run.summary[f"{domain}-Acc"] = acc 362 | # macro and micro F1 are only with respect to the 5 main splits 363 | if i < len(all_dsets): 364 | Ps.append(P) 365 | Rs.append(R) 366 | F1s.append(F1) 367 | accs.append(acc) 368 | labels_all.extend(labels) 369 | logits_all.extend(logits) 370 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 371 | for p, l in zip(np.argmax(logits, axis=-1), labels): 372 | f.write(f'{domain}\t{p}\t{l}\n') 373 | 374 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 375 | # Add to wandb 376 | wandb.run.summary[f'test-loss'] = loss 377 | wandb.run.summary[f'test-micro-acc'] = acc 378 | wandb.run.summary[f'test-micro-P'] = P 379 | wandb.run.summary[f'test-micro-R'] = R 380 | wandb.run.summary[f'test-micro-F1'] = F1 381 | 382 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 383 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 384 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 385 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 386 | 387 | # wandb.log({f"label-distribution-test-{i}": plots[0]}) 388 | -------------------------------------------------------------------------------- /emnlp_final_experiments/claim-detection/train_multi_view_selective_weighting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from collections import defaultdict 10 | from pathlib 
import Path 11 | 12 | import numpy as np 13 | import torch 14 | import wandb 15 | from torch.optim.lr_scheduler import LambdaLR 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data import Subset 18 | from torch.utils.data import random_split 19 | from torch.optim import Adam 20 | from tqdm import tqdm 21 | from transformers import AdamW 22 | from transformers import DistilBertConfig 23 | from transformers import DistilBertTokenizer 24 | from transformers import DistilBertForSequenceClassification 25 | from transformers import get_linear_schedule_with_warmup 26 | 27 | from datareader import MultiDomainTwitterDataset 28 | from datareader import collate_batch_transformer 29 | from metrics import MultiDatasetClassificationEvaluator 30 | from metrics import ClassificationEvaluator 31 | from metrics import acc_f1 32 | 33 | from metrics import plot_label_distribution 34 | from model import MultiTransformerClassifier 35 | from model import VanillaBert 36 | from model import * 37 | from sklearn.model_selection import ParameterSampler 38 | 39 | 40 | def attention_grid_search( 41 | model: torch.nn.Module, 42 | validation_evaluator: MultiDatasetClassificationEvaluator, 43 | n_epochs: int, 44 | seed: int 45 | ): 46 | best_weights = model.module.weights 47 | # initial 48 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 49 | best_f1 = F1 50 | print(F1) 51 | # Create the grid search 52 | param_dict = {1:list(range(0,11)), 2:list(range(0,11)),3:list(range(0,11)),4:list(range(0,11)), 5:list(range(0,11))} 53 | grid_search_params = ParameterSampler(param_dict, n_iter=n_epochs, random_state=seed) 54 | for d in grid_search_params: 55 | weights = [v for k,v in sorted(d.items(), key=lambda x:x[0])] 56 | weights = np.array(weights) / sum(weights) 57 | model.module.weights = weights 58 | # Inline evaluation 59 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 60 | print(f"Weights: {weights}\tValidation F1: {F1}") 61 | 62 | if F1 > 
best_f1: 63 | best_weights = weights 64 | best_f1 = F1 65 | # Log to wandb 66 | wandb.log({ 67 | 'Validation accuracy': acc, 68 | 'Validation Precision': P, 69 | 'Validation Recall': R, 70 | 'Validation F1': F1, 71 | 'Validation loss': val_loss}) 72 | 73 | gc.collect() 74 | 75 | return best_weights 76 | 77 | 78 | if __name__ == "__main__": 79 | # Define arguments 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 82 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 83 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 84 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 85 | parser.add_argument("--pretrained_model", help="The pretrained averaging model", type=str, default=None) 86 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 87 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 88 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 89 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 90 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 91 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 92 | 93 | 94 | args = parser.parse_args() 95 | 96 | # Set all the seeds 97 | seed = args.seed 98 | random.seed(seed) 99 | np.random.seed(seed) 100 | torch.manual_seed(seed) 101 | torch.cuda.manual_seed_all(seed) 102 | torch.backends.cudnn.deterministic = True 103 | torch.backends.cudnn.benchmark = False 104 | 105 | # See if CUDA available 106 | device = torch.device("cpu") 107 | if args.n_gpu > 0 and torch.cuda.is_available(): 108 | print("Training on GPU") 109 | device = 
torch.device("cuda:0") 110 | 111 | # model configuration 112 | bert_model = 'distilbert-base-uncased' 113 | n_epochs = args.n_epochs 114 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 115 | 116 | 117 | # wandb initialization 118 | wandb.init( 119 | project="domain-adaptation-twitter-emnlp", 120 | name=args.run_name, 121 | config={ 122 | "epochs": n_epochs, 123 | "train_split_percentage": args.train_pct, 124 | "bert_model": bert_model, 125 | "seed": seed, 126 | "tags": ",".join(args.tags) 127 | } 128 | ) 129 | #wandb.watch(model) 130 | #Create save directory for model 131 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 132 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 133 | 134 | # Create the dataset 135 | all_dsets = [MultiDomainTwitterDataset( 136 | args.dataset_loc, 137 | [domain], 138 | DistilBertTokenizer.from_pretrained(bert_model) 139 | ) for domain in args.domains[:-1]] 140 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 141 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 142 | 143 | accs = [] 144 | Ps = [] 145 | Rs = [] 146 | F1s = [] 147 | # Store labels and logits for individual splits for micro F1 148 | labels_all = [] 149 | logits_all = [] 150 | 151 | for i in range(len(all_dsets)): 152 | domain = args.domains[i] 153 | test_dset = all_dsets[i] 154 | # Override the domain IDs 155 | k = 0 156 | for j in range(len(all_dsets)): 157 | if j != i: 158 | all_dsets[j].set_domain_id(k) 159 | k += 1 160 | test_dset.set_domain_id(k) 161 | # For test 162 | #all_dsets = [all_dsets[0], all_dsets[2]] 163 | 164 | # Split the data 165 | if args.indices_dir is None: 166 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 167 | for j in range(len(all_dsets)) if j != i] 168 | else: 169 | # load the indices 170 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 171 | 
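The `--indices_dir` branch below reads index files in which each line is a `dataset_index,example_index` pair and groups them into per-dataset train/validation lists. A minimal, file-free sketch of that parsing — the `StringIO` buffers stand in for the real `train_idx_{domain}.txt` / `val_idx_{domain}.txt` files, and the example indices are invented:

```python
from collections import defaultdict
from io import StringIO

# Stand-ins for train_idx_{domain}.txt / val_idx_{domain}.txt:
# each line is "<dataset index>,<example index>"
train_file = StringIO("0,5\n0,7\n1,2\n")
val_file = StringIO("0,1\n1,9\n")

# subset_indices[d] -> [train example ids, val example ids] for dataset d
subset_indices = defaultdict(lambda: [[], []])
for line in train_file:
    d, idx = (int(v) for v in line.strip().split(','))
    subset_indices[d][0].append(idx)
for line in val_file:
    d, idx = (int(v) for v in line.strip().split(','))
    subset_indices[d][1].append(idx)
```

Each resulting list of indices can then be wrapped in a `torch.utils.data.Subset` to reproduce a fixed split, as the script does.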
subset_indices = defaultdict(lambda: [[], []]) 172 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 173 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 174 | for l in f: 175 | vals = l.strip().split(',') 176 | subset_indices[int(vals[0])][0].append(int(vals[1])) 177 | for l in g: 178 | vals = l.strip().split(',') 179 | subset_indices[int(vals[0])][1].append(int(vals[1])) 180 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in 181 | subset_indices] 182 | 183 | train_dls = [DataLoader( 184 | subset[0], 185 | batch_size=8, 186 | shuffle=True, 187 | collate_fn=collate_batch_transformer 188 | ) for subset in subsets] 189 | 190 | val_ds = [subset[1] for subset in subsets] 191 | # for vds in val_ds: 192 | # print(vds.indices) 193 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 194 | 195 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 196 | # Create the model 197 | init_weights = None 198 | shared_bert = VanillaBert(bert).to(device) 199 | 200 | multi_xformer = MultiDistilBertClassifier( 201 | bert_model, 202 | bert_config, 203 | n_domains=len(train_dls), 204 | init_weights=init_weights 205 | ).to(device) 206 | 207 | model = torch.nn.DataParallel(MultiViewTransformerNetworkSelectiveWeight( 208 | multi_xformer, 209 | shared_bert 210 | )).to(device) 211 | 212 | # Load the best weights 213 | model.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}.pth')) 214 | 215 | # Calculate the best attention weights with a grid search 216 | weights = attention_grid_search( 217 | model, 218 | validation_evaluator, 219 | n_epochs, 220 | seed 221 | ) 222 | model.module.weights = weights 223 | print(f"Best weights: {weights}") 224 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f: 225 | f.write(str(weights)) 226 | 227 | evaluator = 
ClassificationEvaluator(test_dset, device, use_domain=False) 228 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 229 | model, 230 | plot_callbacks=[plot_label_distribution], 231 | return_labels_logits=True, 232 | return_votes=True 233 | ) 234 | print(f"{domain} F1: {F1}") 235 | print(f"{domain} Accuracy: {acc}") 236 | print() 237 | 238 | wandb.run.summary[f"{domain}-P"] = P 239 | wandb.run.summary[f"{domain}-R"] = R 240 | wandb.run.summary[f"{domain}-F1"] = F1 241 | wandb.run.summary[f"{domain}-Acc"] = acc 242 | # macro and micro F1 are only with respect to the 5 main splits 243 | if i < len(all_dsets): 244 | Ps.append(P) 245 | Rs.append(R) 246 | F1s.append(F1) 247 | accs.append(acc) 248 | labels_all.extend(labels) 249 | logits_all.extend(logits) 250 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 251 | for p, l in zip(np.argmax(logits, axis=-1), labels): 252 | f.write(f'{domain}\t{p}\t{l}\n') 253 | 254 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 255 | # Add to wandb 256 | wandb.run.summary[f'test-loss'] = loss 257 | wandb.run.summary[f'test-micro-acc'] = acc 258 | wandb.run.summary[f'test-micro-P'] = P 259 | wandb.run.summary[f'test-micro-R'] = R 260 | wandb.run.summary[f'test-micro-F1'] = F1 261 | 262 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 263 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 264 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 265 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 266 | 267 | # wandb.log({f"label-distribution-test-{i}": plots[0]}) 268 | -------------------------------------------------------------------------------- /emnlp_final_experiments/sentiment-analysis/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/copenlu/xformer-multi-source-domain-adaptation/be2d1a132298131df82fe40dd4f6c08dec8b3404/emnlp_final_experiments/sentiment-analysis/.DS_Store -------------------------------------------------------------------------------- /emnlp_final_experiments/sentiment-analysis/analyze_expert_predictions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data import Subset 17 | from torch.utils.data import random_split 18 | from torch.optim import Adam 19 | from tqdm import tqdm 20 | from transformers import AdamW 21 | from transformers import DistilBertConfig 22 | from transformers import DistilBertTokenizer 23 | from transformers import DistilBertForSequenceClassification 24 | from transformers import get_linear_schedule_with_warmup 25 | 26 | from datareader import MultiDomainSentimentDataset 27 | from datareader import collate_batch_transformer 28 | from metrics import MultiDatasetClassificationEvaluator 29 | from metrics import ClassificationEvaluator 30 | from metrics import acc_f1 31 | 32 | from metrics import plot_label_distribution 33 | from model import MultiTransformerClassifier 34 | from model import VanillaBert 35 | from model import * 36 | from sklearn.model_selection import ParameterSampler 37 | from scipy.special import softmax 38 | from scipy.stats import wasserstein_distance 39 | 40 | 41 | if __name__ == "__main__": 42 | # Define arguments 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 45 | 
parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 46 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 47 | parser.add_argument("--pretrained_model", help="Directory with weights to initialize the shared model with", type=str, default=None) 48 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 49 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 50 | 51 | args = parser.parse_args() 52 | 53 | # Set all the seeds 54 | seed = args.seed 55 | random.seed(seed) 56 | np.random.seed(seed) 57 | torch.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | torch.backends.cudnn.deterministic = True 60 | torch.backends.cudnn.benchmark = False 61 | 62 | # See if CUDA available 63 | device = torch.device("cpu") 64 | if args.n_gpu > 0 and torch.cuda.is_available(): 65 | print("Training on GPU") 66 | device = torch.device("cuda:0") 67 | 68 | # model configuration 69 | bert_model = 'distilbert-base-uncased' 70 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 71 | 72 | # Create the dataset 73 | all_dsets = [MultiDomainSentimentDataset( 74 | args.dataset_loc, 75 | [domain], 76 | DistilBertTokenizer.from_pretrained(bert_model) 77 | ) for domain in args.domains] 78 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 79 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 80 | 81 | for i in range(len(all_dsets)): 82 | domain = args.domains[i] 83 | 84 | test_dset = all_dsets[i] 85 | 86 | dataloader = DataLoader( 87 | test_dset, 88 | batch_size=4, 89 | shuffle=True, 90 | collate_fn=collate_batch_transformer 91 | ) 92 | 93 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 94 | # Create the model 95 | 96 | model = 
torch.nn.DataParallel(MultiViewTransformerNetworkAveragingIndividuals( 97 | bert_model, 98 | bert_config, 99 | len(all_dsets) - 1 100 | )).to(device) 101 | model.module.average = True 102 | # load the trained model 103 | 104 | # Load the best weights 105 | for v in range(len(all_dsets)-1): 106 | model.module.domain_experts[v].load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{v}.pth')) 107 | model.module.shared_bert.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{len(all_dsets)-1}.pth')) 108 | 109 | logits_all = [[] for d in range(len(all_dsets))] 110 | for batch in tqdm(dataloader): 111 | model.train() 112 | batch = tuple(t.to(device) for t in batch) 113 | input_ids = batch[0] 114 | masks = batch[1] 115 | labels = batch[2] 116 | # Testing with random domains to see if any effect 117 | # domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 118 | domains = batch[3] 119 | 120 | logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels, return_logits=True) 121 | for k,l in enumerate(logits): 122 | logits_all[k].append(l.detach().cpu().numpy()) 123 | 124 | print(domain) 125 | probs_all = [softmax(np.concatenate(l), axis=-1)[:,1] for l in logits_all] 126 | for i in range(len(probs_all)): 127 | for j in range(i+1, len(probs_all)): 128 | print(wasserstein_distance(probs_all[i], probs_all[j])) 129 | preds_all = np.asarray([np.argmax(np.concatenate(l), axis=-1) for l in logits_all]) 130 | print(f"alpha: {krippendorff.alpha(preds_all, level_of_measurement='nominal')}") 131 | print() -------------------------------------------------------------------------------- /emnlp_final_experiments/sentiment-analysis/analyze_expert_predictions_cnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from 
collections import defaultdict 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data import Subset 17 | from torch.utils.data import random_split 18 | from torch.optim import Adam 19 | from tqdm import tqdm 20 | from transformers import AdamW 21 | from transformers import DistilBertConfig 22 | from transformers import DistilBertTokenizer 23 | from transformers import DistilBertForSequenceClassification 24 | from transformers import get_linear_schedule_with_warmup 25 | 26 | from datareader_cnn import MultiDomainSentimentDataset 27 | from datareader_cnn import collate_batch_cnn 28 | from datareader_cnn import FasttextTokenizer 29 | from metrics import MultiDatasetClassificationEvaluator 30 | from metrics import ClassificationEvaluator 31 | from metrics import acc_f1 32 | 33 | from metrics import plot_label_distribution 34 | from model import MultiTransformerClassifier 35 | from model import VanillaBert 36 | from model import * 37 | from sklearn.model_selection import ParameterSampler 38 | from scipy.special import softmax 39 | from scipy.stats import wasserstein_distance 40 | 41 | 42 | if __name__ == "__main__": 43 | # Define arguments 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 46 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 47 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 48 | parser.add_argument("--pretrained_model", help="Directory with weights to initialize the shared model with", type=str, default=None) 49 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 50 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 51 | parser.add_argument("--dropout", 
help="Dropout probability", default=0.3, type=float) 52 | 53 | 54 | #CNN stuff 55 | parser.add_argument("--embedding_dim", help="Dimension of embeddings", choices=[50, 100, 200, 300], default=100, 56 | type=int) 57 | parser.add_argument("--in_channels", type=int, default=1) 58 | parser.add_argument("--out_channels", type=int, default=100) 59 | parser.add_argument("--kernel_heights", help="filter windows", type=int, nargs='+', default=[2, 4, 5]) 60 | parser.add_argument("--stride", help="stride", type=int, default=1) 61 | parser.add_argument("--padding", help="padding", type=int, default=0) 62 | 63 | args = parser.parse_args() 64 | 65 | # Set all the seeds 66 | seed = args.seed 67 | random.seed(seed) 68 | np.random.seed(seed) 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed_all(seed) 71 | torch.backends.cudnn.deterministic = True 72 | torch.backends.cudnn.benchmark = False 73 | 74 | # See if CUDA available 75 | device = torch.device("cpu") 76 | if args.n_gpu > 0 and torch.cuda.is_available(): 77 | print("Training on GPU") 78 | device = torch.device("cuda:0") 79 | 80 | tokenizer = FasttextTokenizer(f"{args.dataset_loc}/vocabulary.txt") 81 | 82 | # Create the dataset 83 | all_dsets = [MultiDomainSentimentDataset( 84 | args.dataset_loc, 85 | [domain], 86 | tokenizer 87 | ) for domain in args.domains] 88 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 89 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 90 | 91 | for i in range(len(all_dsets)): 92 | domain = args.domains[i] 93 | 94 | test_dset = all_dsets[i] 95 | 96 | dataloader = DataLoader( 97 | test_dset, 98 | batch_size=4, 99 | shuffle=True, 100 | collate_fn=collate_batch_cnn 101 | ) 102 | 103 | embeddings = np.load(f"{args.dataset_loc}/fasttext_embeddings.npy") 104 | model = torch.nn.DataParallel(MultiViewCNNAveragingIndividuals( 105 | args, 106 | embeddings, 107 | len(all_dsets) - 1 108 | 
)).to(device)
model.module.average = True 110 | # load the trained model 111 | 112 | # Load the best weights 113 | for v in range(len(all_dsets)-1): 114 | model.module.domain_experts[v].load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{v}.pth')) 115 | model.module.shared_model.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{len(all_dsets)-1}.pth')) 116 | 117 | logits_all = [[] for d in range(len(all_dsets))] 118 | for batch in tqdm(dataloader): 119 | model.eval()  # inference only; disable dropout 120 | batch = tuple(t.to(device) for t in batch) 121 | input_ids = batch[0] 122 | masks = batch[1] 123 | labels = batch[2] 124 | # Testing with random domains to see if any effect 125 | # domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 126 | domains = batch[3] 127 | 128 | logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels, return_logits=True) 129 | for k, l in enumerate(logits): 130 | logits_all[k].append(l.detach().cpu().numpy()) 131 | 132 | print(domain) 133 | probs_all = [softmax(np.concatenate(l), axis=-1)[:,1] for l in logits_all] 134 | for i in range(len(probs_all)): 135 | for j in range(i+1, len(probs_all)): 136 | print(wasserstein_distance(probs_all[i], probs_all[j])) 137 | preds_all = np.asarray([np.argmax(np.concatenate(l), axis=-1) for l in logits_all]) 138 | print(f"alpha: {krippendorff.alpha(preds_all, level_of_measurement='nominal')}") 139 | print() 
import DataLoader 16 | from torch.utils.data import Subset 17 | from torch.utils.data import random_split 18 | from torch.optim import Adam 19 | from tqdm import tqdm 20 | from transformers import AdamW 21 | from transformers import DistilBertConfig 22 | from transformers import DistilBertTokenizer 23 | from transformers import DistilBertForSequenceClassification 24 | from transformers import BertConfig 25 | from transformers import BertTokenizer 26 | from transformers import BertForSequenceClassification 27 | from transformers import get_linear_schedule_with_warmup 28 | 29 | from datareader import MultiDomainSentimentDataset 30 | from datareader import collate_batch_transformer 31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1 32 | from metrics import ClassificationEvaluator 33 | 34 | from metrics import plot_label_distribution 35 | from model import * 36 | 37 | 38 | def train( 39 | model: torch.nn.Module, 40 | train_dls: List[DataLoader], 41 | optimizer: torch.optim.Optimizer, 42 | scheduler: LambdaLR, 43 | validation_evaluator: MultiDatasetClassificationEvaluator, 44 | n_epochs: int, 45 | device: AnyStr, 46 | log_interval: int = 1, 47 | patience: int = 10, 48 | model_dir: str = "wandb_local", 49 | gradient_accumulation: int = 1, 50 | domain_name: str = '' 51 | ): 52 | #best_loss = float('inf') 53 | best_acc = 0.0 54 | patience_counter = 0 55 | 56 | epoch_counter = 0 57 | total = sum(len(dl) for dl in train_dls) 58 | 59 | # Main loop 60 | while epoch_counter < n_epochs: 61 | dl_iters = [iter(dl) for dl in train_dls] 62 | dl_idx = list(range(len(dl_iters))) 63 | finished = [0] * len(dl_iters) 64 | i = 0 65 | with tqdm(total=total, desc="Training") as pbar: 66 | while sum(finished) < len(dl_iters): 67 | random.shuffle(dl_idx) 68 | for d in dl_idx: 69 | domain_dl = dl_iters[d] 70 | batches = [] 71 | try: 72 | for j in range(gradient_accumulation): 73 | batches.append(next(domain_dl)) 74 | except StopIteration: 75 | finished[d] = 1 76 | if len(batches) == 0: 77 | continue 78 | 
optimizer.zero_grad() 79 | for batch in batches: 80 | model.train() 81 | batch = tuple(t.to(device) for t in batch) 82 | input_ids = batch[0] 83 | masks = batch[1] 84 | labels = batch[2] 85 | # Testing with random domains to see if any effect 86 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 87 | domains = batch[3] 88 | 89 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels) 90 | loss = loss / gradient_accumulation 91 | 92 | if i % log_interval == 0: 93 | wandb.log({ 94 | "Loss": loss.item() 95 | }) 96 | 97 | loss.backward() 98 | i += 1 99 | pbar.update(1) 100 | 101 | optimizer.step() 102 | if scheduler is not None: 103 | scheduler.step() 104 | 105 | gc.collect() 106 | 107 | # Inline evaluation 108 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 109 | print(f"Validation acc: {acc}") 110 | 111 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 112 | 113 | # Saving the best model and early stopping 114 | #if val_loss < best_loss: 115 | if acc > best_acc: 116 | best_model = model.state_dict() 117 | #best_loss = val_loss 118 | best_acc = acc 119 | #wandb.run.summary['best_validation_loss'] = best_loss 120 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 121 | patience_counter = 0 122 | # Log to wandb 123 | wandb.log({ 124 | 'Validation accuracy': acc, 125 | 'Validation Precision': P, 126 | 'Validation Recall': R, 127 | 'Validation F1': F1, 128 | 'Validation loss': val_loss}) 129 | else: 130 | patience_counter += 1 131 | # Stop training once we have lost patience 132 | if patience_counter == patience: 133 | break 134 | 135 | gc.collect() 136 | epoch_counter += 1 137 | 138 | 139 | if __name__ == "__main__": 140 | # Define arguments 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 143 | 
parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 144 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 145 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 146 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200) 147 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 148 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None) 149 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 150 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 151 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 152 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 153 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 154 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 155 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768) 156 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8) 157 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5) 158 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 159 | parser.add_argument("--lambd", help="Loss weighting hyperparameter", type=float, default=10e-3) 160 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 161 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 162 | parser.add_argument("--full_bert", help="Specify to use full bert model", 
action="store_true") 163 | 164 | args = parser.parse_args() 165 | 166 | # Set all the seeds 167 | seed = args.seed 168 | random.seed(seed) 169 | np.random.seed(seed) 170 | torch.manual_seed(seed) 171 | torch.cuda.manual_seed_all(seed) 172 | torch.backends.cudnn.deterministic = True 173 | torch.backends.cudnn.benchmark = False 174 | 175 | # See if CUDA available 176 | device = torch.device("cpu") 177 | if args.n_gpu > 0 and torch.cuda.is_available(): 178 | print("Training on GPU") 179 | device = torch.device("cuda:0") 180 | 181 | # model configuration 182 | batch_size = args.batch_size 183 | lr = args.lr 184 | weight_decay = args.weight_decay 185 | n_epochs = args.n_epochs 186 | if args.full_bert: 187 | bert_model = 'bert-base-uncased' 188 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2) 189 | tokenizer = BertTokenizer.from_pretrained(bert_model) 190 | else: 191 | bert_model = 'distilbert-base-uncased' 192 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2) 193 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model) 194 | 195 | # wandb initialization 196 | wandb.init( 197 | project="domain-adaptation-sentiment-emnlp", 198 | name=args.run_name, 199 | config={ 200 | "epochs": n_epochs, 201 | "learning_rate": lr, 202 | "warmup": args.warmup_steps, 203 | "weight_decay": weight_decay, 204 | "batch_size": batch_size, 205 | "train_split_percentage": args.train_pct, 206 | "bert_model": bert_model, 207 | "seed": seed, 208 | "pretrained_model": args.pretrained_model, 209 | "tags": ",".join(args.tags) 210 | } 211 | ) 212 | # Create save directory for model 213 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 214 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 215 | 216 | # Create the dataset 217 | all_dsets = [MultiDomainSentimentDataset( 218 | args.dataset_loc, 219 | [domain], 220 | tokenizer 221 | ) for domain in args.domains] 222 | train_sizes = [int(len(dset) * args.train_pct) for j, 
dset in enumerate(all_dsets)] 223 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 224 | 225 | accs = [] 226 | Ps = [] 227 | Rs = [] 228 | F1s = [] 229 | # Store labels and logits for individual splits for micro F1 230 | labels_all = [] 231 | logits_all = [] 232 | 233 | for i in range(len(all_dsets)): 234 | domain = args.domains[i] 235 | test_dset = all_dsets[i] 236 | # Override the domain IDs 237 | k = 0 238 | for j in range(len(all_dsets)): 239 | if j != i: 240 | all_dsets[j].set_domain_id(k) 241 | k += 1 242 | test_dset.set_domain_id(k) 243 | 244 | # Split the data 245 | if args.indices_dir is None: 246 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 247 | for j in range(len(all_dsets)) if j != i] 248 | # Save the indices 249 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/train_idx_{domain}.txt', 'wt') as f, \ 250 | open(f'{args.model_dir}/{Path(wandb.run.dir).name}/val_idx_{domain}.txt', 'wt') as g: 251 | for j, subset in enumerate(subsets): 252 | for idx in subset[0].indices: 253 | f.write(f'{j},{idx}\n') 254 | for idx in subset[1].indices: 255 | g.write(f'{j},{idx}\n') 256 | else: 257 | # load the indices 258 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 259 | subset_indices = defaultdict(lambda: [[], []]) 260 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 261 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 262 | for l in f: 263 | vals = l.strip().split(',') 264 | subset_indices[int(vals[0])][0].append(int(vals[1])) 265 | for l in g: 266 | vals = l.strip().split(',') 267 | subset_indices[int(vals[0])][1].append(int(vals[1])) 268 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] 269 | for d in subset_indices] 270 | 271 | train_dls = [DataLoader( 272 | subset[0], 273 | batch_size=batch_size, 274 | shuffle=True, 275 | collate_fn=collate_batch_transformer 276 | ) for subset in 
subsets] 277 | 278 | val_ds = [subset[1] for subset in subsets] 279 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 280 | 281 | 282 | # Create the model 283 | if args.full_bert: 284 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 285 | else: 286 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 287 | model = VanillaBert(bert).to(device) 288 | if args.pretrained_model is not None: 289 | model.load_state_dict(torch.load(f"{args.pretrained_model}/model_{domain}.pth")) 290 | 291 | # Create the optimizer 292 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 293 | optimizer_grouped_parameters = [ 294 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 295 | 'weight_decay': weight_decay}, 296 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 297 | ] 298 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3) 299 | # scheduler = None 300 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 301 | scheduler = get_linear_schedule_with_warmup( 302 | optimizer, 303 | args.warmup_steps, 304 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 305 | ) 306 | 307 | # Train 308 | train( 309 | model, 310 | train_dls, 311 | optimizer, 312 | scheduler, 313 | validation_evaluator, 314 | n_epochs, 315 | device, 316 | args.log_interval, 317 | model_dir=args.model_dir, 318 | gradient_accumulation=args.gradient_accumulation, 319 | domain_name=domain 320 | ) 321 | 322 | # Load the best weights 323 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 324 | 325 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 326 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 327 | model, 328 | plot_callbacks=[plot_label_distribution], 329 | 
return_labels_logits=True, 330 | return_votes=True 331 | ) 332 | print(f"{domain} F1: {F1}") 333 | print(f"{domain} Accuracy: {acc}") 334 | print() 335 | 336 | wandb.run.summary[f"{domain}-P"] = P 337 | wandb.run.summary[f"{domain}-R"] = R 338 | wandb.run.summary[f"{domain}-F1"] = F1 339 | wandb.run.summary[f"{domain}-Acc"] = acc 340 | Ps.append(P) 341 | Rs.append(R) 342 | F1s.append(F1) 343 | accs.append(acc) 344 | labels_all.extend(labels) 345 | logits_all.extend(logits) 346 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 347 | for p, l in zip(np.argmax(logits, axis=-1), labels): 348 | f.write(f'{domain}\t{p}\t{l}\n') 349 | 350 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 351 | # Add to wandb 352 | wandb.run.summary[f'test-loss'] = loss 353 | wandb.run.summary[f'test-micro-acc'] = acc 354 | wandb.run.summary[f'test-micro-P'] = P 355 | wandb.run.summary[f'test-micro-R'] = R 356 | wandb.run.summary[f'test-micro-F1'] = F1 357 | 358 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 359 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 360 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 361 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 362 | 363 | #wandb.log({f"label-distribution-test-{i}": plots[0]}) 364 | -------------------------------------------------------------------------------- /emnlp_final_experiments/sentiment-analysis/train_basic_domain_adversarial.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | from collections import defaultdict 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data import Subset 17 | from torch.utils.data import 
random_split 18 | from torch.optim import Adam 19 | from tqdm import tqdm 20 | from transformers import AdamW 21 | from transformers import DistilBertConfig 22 | from transformers import DistilBertTokenizer 23 | from transformers import DistilBertForSequenceClassification 24 | from transformers import BertConfig 25 | from transformers import BertTokenizer 26 | from transformers import BertForSequenceClassification 27 | from transformers import get_linear_schedule_with_warmup 28 | 29 | from datareader import MultiDomainSentimentDataset 30 | from datareader import collate_batch_transformer 31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1 32 | from metrics import ClassificationEvaluator 33 | 34 | from metrics import plot_label_distribution 35 | from model import * 36 | 37 | 38 | def train( 39 | model: torch.nn.Module, 40 | train_dls: List[DataLoader], 41 | optimizer: torch.optim.Optimizer, 42 | scheduler: LambdaLR, 43 | validation_evaluator: MultiDatasetClassificationEvaluator, 44 | n_epochs: int, 45 | device: AnyStr, 46 | log_interval: int = 1, 47 | patience: int = 10, 48 | model_dir: str = "wandb_local", 49 | gradient_accumulation: int = 1, 50 | domain_name: str = '' 51 | ): 52 | #best_loss = float('inf') 53 | best_acc = 0.0 54 | patience_counter = 0 55 | 56 | epoch_counter = 0 57 | total = sum(len(dl) for dl in train_dls) 58 | 59 | # Main loop 60 | while epoch_counter < n_epochs: 61 | dl_iters = [iter(dl) for dl in train_dls] 62 | dl_idx = list(range(len(dl_iters))) 63 | finished = [0] * len(dl_iters) 64 | i = 0 65 | with tqdm(total=total, desc="Training") as pbar: 66 | while sum(finished) < len(dl_iters): 67 | random.shuffle(dl_idx) 68 | for d in dl_idx: 69 | domain_dl = dl_iters[d] 70 | batches = [] 71 | try: 72 | for j in range(gradient_accumulation): 73 | batches.append(next(domain_dl)) 74 | except StopIteration: 75 | finished[d] = 1 76 | if len(batches) == 0: 77 | continue 78 | optimizer.zero_grad() 79 | for batch in batches: 80 | model.train() 81 | batch = tuple(t.to(device) for 
t in batch) 82 | input_ids = batch[0] 83 | masks = batch[1] 84 | labels = batch[2] 85 | # Null the labels if it's the test data 86 | if d == len(train_dls) - 1: 87 | labels = None 88 | # Testing with random domains to see if any effect 89 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 90 | domains = batch[3] 91 | 92 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels) 93 | loss = loss.mean() / gradient_accumulation 94 | 95 | if i % log_interval == 0: 96 | wandb.log({ 97 | "Loss": loss.item() 98 | }) 99 | 100 | loss.backward() 101 | i += 1 102 | pbar.update(1) 103 | 104 | optimizer.step() 105 | if scheduler is not None: 106 | scheduler.step() 107 | 108 | gc.collect() 109 | 110 | # Inline evaluation 111 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 112 | print(f"Validation acc: {acc}") 113 | 114 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 115 | 116 | # Saving the best model and early stopping 117 | #if val_loss < best_loss: 118 | if acc > best_acc: 119 | best_model = model.state_dict() 120 | #best_loss = val_loss 121 | best_acc = acc 122 | #wandb.run.summary['best_validation_loss'] = best_loss 123 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 124 | patience_counter = 0 125 | # Log to wandb 126 | wandb.log({ 127 | 'Validation accuracy': acc, 128 | 'Validation Precision': P, 129 | 'Validation Recall': R, 130 | 'Validation F1': F1, 131 | 'Validation loss': val_loss}) 132 | else: 133 | patience_counter += 1 134 | # Stop training once we have lost patience 135 | if patience_counter == patience: 136 | break 137 | 138 | gc.collect() 139 | epoch_counter += 1 140 | 141 | 142 | if __name__ == "__main__": 143 | # Define arguments 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 146 | 
parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 147 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 148 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 149 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200) 150 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 151 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None) 152 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 153 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 154 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 155 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 156 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 157 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 158 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768) 159 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8) 160 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5) 161 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 162 | parser.add_argument("--lambd", help="Adversarial loss weight", type=float, default=10e-3) 163 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 164 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int) 165 | parser.add_argument("--indices_dir", help="If standard 
splits are being used", type=str, default=None) 166 | parser.add_argument("--full_bert", help="Specify to use full bert model", action="store_true") 167 | 168 | 169 | args = parser.parse_args() 170 | 171 | # Set all the seeds 172 | seed = args.seed 173 | random.seed(seed) 174 | np.random.seed(seed) 175 | torch.manual_seed(seed) 176 | torch.cuda.manual_seed_all(seed) 177 | torch.backends.cudnn.deterministic = True 178 | torch.backends.cudnn.benchmark = False 179 | 180 | # See if CUDA available 181 | device = torch.device("cpu") 182 | if args.n_gpu > 0 and torch.cuda.is_available(): 183 | print("Training on GPU") 184 | device = torch.device("cuda:0") 185 | 186 | # model configuration 187 | batch_size = args.batch_size 188 | lr = args.lr 189 | weight_decay = args.weight_decay 190 | n_epochs = args.n_epochs 191 | if args.full_bert: 192 | bert_model = 'bert-base-uncased' 193 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 194 | tokenizer = BertTokenizer.from_pretrained(bert_model) 195 | else: 196 | bert_model = 'distilbert-base-uncased' 197 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 198 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model) 199 | 200 | # wandb initialization 201 | wandb.init( 202 | project="domain-adaptation-sentiment-emnlp", 203 | name=args.run_name, 204 | config={ 205 | "epochs": n_epochs, 206 | "learning_rate": lr, 207 | "warmup": args.warmup_steps, 208 | "weight_decay": weight_decay, 209 | "batch_size": batch_size, 210 | "train_split_percentage": args.train_pct, 211 | "bert_model": bert_model, 212 | "seed": seed, 213 | "pretrained_model": args.pretrained_model, 214 | "tags": ",".join(args.tags) 215 | } 216 | ) 217 | # Create save directory for model 218 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 219 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 220 | 221 | # Create the dataset 222 | all_dsets 
= [MultiDomainSentimentDataset( 223 | args.dataset_loc, 224 | [domain], 225 | tokenizer 226 | ) for domain in args.domains] 227 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 228 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 229 | 230 | accs = [] 231 | Ps = [] 232 | Rs = [] 233 | F1s = [] 234 | # Store labels and logits for individual splits for micro F1 235 | labels_all = [] 236 | logits_all = [] 237 | 238 | for i in range(len(all_dsets)): 239 | domain = args.domains[i] 240 | test_dset = all_dsets[i] 241 | # Override the domain IDs 242 | k = 0 243 | for j in range(len(all_dsets)): 244 | if j != i: 245 | all_dsets[j].set_domain_id(k) 246 | k += 1 247 | test_dset.set_domain_id(k) 248 | 249 | # Split the data 250 | if args.indices_dir is None: 251 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 252 | for j in range(len(all_dsets)) if j != i] 253 | else: 254 | # load the indices 255 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 256 | subset_indices = defaultdict(lambda: [[], []]) 257 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 258 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 259 | for l in f: 260 | vals = l.strip().split(',') 261 | subset_indices[int(vals[0])][0].append(int(vals[1])) 262 | for l in g: 263 | vals = l.strip().split(',') 264 | subset_indices[int(vals[0])][1].append(int(vals[1])) 265 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] 266 | for d in subset_indices] 267 | train_dls = [DataLoader( 268 | subset[0], 269 | batch_size=batch_size, 270 | shuffle=True, 271 | collate_fn=collate_batch_transformer 272 | ) for subset in subsets] 273 | # Add test data for domain adversarial training 274 | train_dls += [DataLoader( 275 | test_dset, 276 | batch_size=batch_size, 277 | shuffle=True, 278 | collate_fn=collate_batch_transformer 279 | )] 280 | 
281 | val_ds = [subset[1] for subset in subsets] 282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 283 | 284 | # Create the model 285 | if args.full_bert: 286 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 287 | else: 288 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 289 | model = torch.nn.DataParallel(DomainAdversarialBert( 290 | bert, 291 | n_domains=len(train_dls), 292 | supervision_layer=args.supervision_layer 293 | )).to(device) 294 | if args.pretrained_model is not None: 295 | model.load_state_dict(torch.load(f"{args.pretrained_model}/model_{domain}.pth")) 296 | 297 | # Create the optimizer 298 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 299 | optimizer_grouped_parameters = [ 300 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 301 | 'weight_decay': weight_decay}, 302 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 303 | ] 304 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3) 305 | # scheduler = None 306 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 307 | scheduler = get_linear_schedule_with_warmup( 308 | optimizer, 309 | args.warmup_steps, 310 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 311 | ) 312 | 313 | # Train 314 | train( 315 | model, 316 | train_dls, 317 | optimizer, 318 | scheduler, 319 | validation_evaluator, 320 | n_epochs, 321 | device, 322 | args.log_interval, 323 | model_dir=args.model_dir, 324 | gradient_accumulation=args.gradient_accumulation, 325 | domain_name=domain 326 | ) 327 | 328 | # Load the best weights 329 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 330 | 331 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 332 | (loss, acc, P, R, F1), plots, (labels, 
logits), votes = evaluator.evaluate( 333 | model, 334 | plot_callbacks=[plot_label_distribution], 335 | return_labels_logits=True, 336 | return_votes=True 337 | ) 338 | print(f"{domain} F1: {F1}") 339 | print(f"{domain} Accuracy: {acc}") 340 | print() 341 | 342 | wandb.run.summary[f"{domain}-P"] = P 343 | wandb.run.summary[f"{domain}-R"] = R 344 | wandb.run.summary[f"{domain}-F1"] = F1 345 | wandb.run.summary[f"{domain}-Acc"] = acc 346 | Ps.append(P) 347 | Rs.append(R) 348 | F1s.append(F1) 349 | accs.append(acc) 350 | labels_all.extend(labels) 351 | logits_all.extend(logits) 352 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 353 | for p, l in zip(np.argmax(logits, axis=-1), labels): 354 | f.write(f'{domain}\t{p}\t{l}\n') 355 | 356 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 357 | # Add to wandb 358 | wandb.run.summary[f'test-loss'] = loss 359 | wandb.run.summary[f'test-micro-acc'] = acc 360 | wandb.run.summary[f'test-micro-P'] = P 361 | wandb.run.summary[f'test-micro-R'] = R 362 | wandb.run.summary[f'test-micro-F1'] = F1 363 | 364 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 365 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 366 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 367 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 368 | 369 | #wandb.log({f"label-distribution-test-{i}": plots[0]}) 370 | -------------------------------------------------------------------------------- /emnlp_final_experiments/sentiment-analysis/train_multi_view.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from collections import defaultdict 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import wandb 15 | from torch.optim.lr_scheduler import LambdaLR 16 | from 
torch.utils.data import DataLoader 17 | from torch.utils.data import Subset 18 | from torch.utils.data import random_split 19 | from torch.optim import Adam 20 | from tqdm import tqdm 21 | from transformers import AdamW 22 | from transformers import DistilBertConfig 23 | from transformers import DistilBertTokenizer 24 | from transformers import DistilBertForSequenceClassification 25 | from transformers import get_linear_schedule_with_warmup 26 | 27 | from datareader import MultiDomainSentimentDataset 28 | from datareader import collate_batch_transformer 29 | from metrics import MultiDatasetClassificationEvaluator 30 | from metrics import ClassificationEvaluator 31 | from metrics import acc_f1 32 | 33 | from metrics import plot_label_distribution 34 | from model import MultiTransformerClassifier 35 | from model import VanillaBert 36 | from model import * 37 | 38 | 39 | def train( 40 | model: torch.nn.Module, 41 | train_dls: List[DataLoader], 42 | optimizer: torch.optim.Optimizer, 43 | scheduler: LambdaLR, 44 | validation_evaluator: MultiDatasetClassificationEvaluator, 45 | n_epochs: int, 46 | device: AnyStr, 47 | log_interval: int = 1, 48 | patience: int = 10, 49 | model_dir: str = "wandb_local", 50 | gradient_accumulation: int = 1, 51 | domain_name: str = '' 52 | ): 53 | #best_loss = float('inf') 54 | best_acc = 0.0 55 | patience_counter = 0 56 | 57 | epoch_counter = 0 58 | total = sum(len(dl) for dl in train_dls) 59 | 60 | # Main loop 61 | while epoch_counter < n_epochs: 62 | dl_iters = [iter(dl) for dl in train_dls] 63 | dl_idx = list(range(len(dl_iters))) 64 | finished = [0] * len(dl_iters) 65 | i = 0 66 | with tqdm(total=total, desc="Training") as pbar: 67 | while sum(finished) < len(dl_iters): 68 | random.shuffle(dl_idx) 69 | for d in dl_idx: 70 | domain_dl = dl_iters[d] 71 | batches = [] 72 | try: 73 | for j in range(gradient_accumulation): 74 | batches.append(next(domain_dl)) 75 | except StopIteration: 76 | finished[d] = 1 77 | if len(batches) == 0: 78 | 
continue 79 | optimizer.zero_grad() 80 | for batch in batches: 81 | model.train() 82 | batch = tuple(t.to(device) for t in batch) 83 | input_ids = batch[0] 84 | masks = batch[1] 85 | labels = batch[2] 86 | # Testing with random domains to see if any effect 87 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 88 | domains = batch[3] 89 | 90 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True) 91 | loss = loss.mean() / gradient_accumulation 92 | if i % log_interval == 0: 93 | # wandb.log({ 94 | # "Loss": loss.item(), 95 | # "alpha0": alpha[:,0].cpu(), 96 | # "alpha1": alpha[:, 1].cpu(), 97 | # "alpha2": alpha[:, 2].cpu(), 98 | # "alpha_shared": alpha[:, 3].cpu() 99 | # }) 100 | wandb.log({ 101 | "Loss": loss.item() 102 | }) 103 | 104 | loss.backward() 105 | i += 1 106 | pbar.update(1) 107 | 108 | optimizer.step() 109 | if scheduler is not None: 110 | scheduler.step() 111 | 112 | gc.collect() 113 | 114 | # Inline evaluation 115 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 116 | print(f"Validation acc: {acc}") 117 | 118 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 119 | 120 | # Saving the best model and early stopping 121 | #if val_loss < best_loss: 122 | if acc > best_acc: 123 | best_model = model.state_dict() 124 | #best_loss = val_loss 125 | best_acc = acc 126 | #wandb.run.summary['best_validation_loss'] = best_loss 127 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 128 | patience_counter = 0 129 | # Log to wandb 130 | wandb.log({ 131 | 'Validation accuracy': acc, 132 | 'Validation Precision': P, 133 | 'Validation Recall': R, 134 | 'Validation F1': F1, 135 | 'Validation loss': val_loss}) 136 | else: 137 | patience_counter += 1 138 | # Stop training once we have lost patience 139 | if patience_counter == patience: 140 | break 141 | 142 | 
gc.collect() 143 | epoch_counter += 1 144 | 145 | 146 | if __name__ == "__main__": 147 | # Define arguments 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 150 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 151 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 152 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 153 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200) 154 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 155 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None) 156 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None) 157 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 158 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 159 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 160 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 161 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 162 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16) 163 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-5) 164 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 165 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int) 166 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, 
type=int) 167 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int) 168 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int) 169 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 170 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 171 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 172 | parser.add_argument("--ensemble_basic", help="Use averaging for the ensembling method", action="store_true") 173 | parser.add_argument("--ensemble_avg_learned", help="Use learned averaging for the ensembling method", action="store_true") 174 | 175 | 176 | args = parser.parse_args() 177 | 178 | # Set all the seeds 179 | seed = args.seed 180 | random.seed(seed) 181 | np.random.seed(seed) 182 | torch.manual_seed(seed) 183 | torch.cuda.manual_seed_all(seed) 184 | torch.backends.cudnn.deterministic = True 185 | torch.backends.cudnn.benchmark = False 186 | 187 | # See if CUDA available 188 | device = torch.device("cpu") 189 | if args.n_gpu > 0 and torch.cuda.is_available(): 190 | print("Training on GPU") 191 | device = torch.device("cuda:0") 192 | 193 | # model configuration 194 | bert_model = 'distilbert-base-uncased' 195 | batch_size = args.batch_size 196 | lr = args.lr 197 | weight_decay = args.weight_decay 198 | n_epochs = args.n_epochs 199 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 200 | 201 | 202 | # wandb initialization 203 | wandb.init( 204 | project="domain-adaptation-sentiment-emnlp", 205 | name=args.run_name, 206 | config={ 207 | "epochs": n_epochs, 208 | "learning_rate": lr, 209 | "warmup": args.warmup_steps, 210 | "weight_decay": weight_decay, 211 | "batch_size": batch_size, 212 | "train_split_percentage": args.train_pct, 213 | "bert_model": bert_model, 214 | "seed": seed, 215 | 
"tags": ",".join(args.tags) 216 | } 217 | ) 218 | #wandb.watch(model) 219 | #Create save directory for model 220 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 221 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 222 | 223 | # Create the dataset 224 | all_dsets = [MultiDomainSentimentDataset( 225 | args.dataset_loc, 226 | [domain], 227 | DistilBertTokenizer.from_pretrained(bert_model) 228 | ) for domain in args.domains] 229 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 230 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 231 | 232 | accs = [] 233 | Ps = [] 234 | Rs = [] 235 | F1s = [] 236 | # Store labels and logits for individual splits for micro F1 237 | labels_all = [] 238 | logits_all = [] 239 | 240 | for i in range(len(all_dsets)): 241 | domain = args.domains[i] 242 | test_dset = all_dsets[i] 243 | # Override the domain IDs 244 | k = 0 245 | for j in range(len(all_dsets)): 246 | if j != i: 247 | all_dsets[j].set_domain_id(k) 248 | k += 1 249 | test_dset.set_domain_id(k) 250 | # For test 251 | #all_dsets = [all_dsets[0], all_dsets[2]] 252 | 253 | # Split the data 254 | if args.indices_dir is None: 255 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 256 | for j in range(len(all_dsets)) if j != i] 257 | else: 258 | # load the indices 259 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 260 | subset_indices = defaultdict(lambda: [[], []]) 261 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 262 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 263 | for l in f: 264 | vals = l.strip().split(',') 265 | subset_indices[int(vals[0])][0].append(int(vals[1])) 266 | for l in g: 267 | vals = l.strip().split(',') 268 | subset_indices[int(vals[0])][1].append(int(vals[1])) 269 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in 
270 | subset_indices] 271 | 272 | train_dls = [DataLoader( 273 | subset[0], 274 | batch_size=batch_size, 275 | shuffle=True, 276 | collate_fn=collate_batch_transformer 277 | ) for subset in subsets] 278 | 279 | val_ds = [subset[1] for subset in subsets] 280 | # for vds in val_ds: 281 | # print(vds.indices) 282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 283 | 284 | # Create the model 285 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 286 | multi_xformer = MultiDistilBertClassifier( 287 | bert_model, 288 | bert_config, 289 | n_domains=len(train_dls) 290 | ).to(device) 291 | if args.pretrained_multi_xformer is not None: 292 | multi_xformer.load_state_dict(torch.load(f"{args.pretrained_multi_xformer}/model_{domain}.pth")) 293 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(multi_xformer) 294 | print(f"Validation acc multi-xformer: {acc}") 295 | 296 | 297 | shared_bert = VanillaBert(bert).to(device) 298 | if args.pretrained_bert is not None: 299 | shared_bert.load_state_dict(torch.load(f"{args.pretrained_bert}/model_{domain}.pth")) 300 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(shared_bert) 301 | print(f"Validation acc shared bert: {acc}") 302 | 303 | if args.ensemble_basic: 304 | model_class = MultiViewTransformerNetworkAveraging 305 | elif args.ensemble_avg_learned: 306 | model_class = MultiViewTransformerNetworkLearnedAveraging 307 | else: 308 | model_class = MultiViewTransformerNetworkProbabilities 309 | 310 | model = torch.nn.DataParallel(model_class( 311 | multi_xformer, 312 | shared_bert 313 | )).to(device) 314 | 315 | # (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 316 | # print(f"Validation acc starting: {acc}") 317 | 318 | # Create the optimizer 319 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 320 | optimizer_grouped_parameters = [ 321 | {'params': [p for n, p in model.named_parameters() if not 
any(nd in n for nd in no_decay)], 322 | 'weight_decay': weight_decay}, 323 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 324 | ] 325 | 326 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 327 | scheduler = get_linear_schedule_with_warmup( 328 | optimizer, 329 | args.warmup_steps, 330 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 331 | ) 332 | 333 | # Train 334 | train( 335 | model, 336 | train_dls, 337 | optimizer, 338 | scheduler, 339 | validation_evaluator, 340 | n_epochs, 341 | device, 342 | args.log_interval, 343 | model_dir=args.model_dir, 344 | gradient_accumulation=args.gradient_accumulation, 345 | domain_name=domain 346 | ) 347 | # Load the best weights 348 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 349 | if args.ensemble_avg_learned: 350 | weights = model.module.alpha_params.cpu().detach().numpy() 351 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f: 352 | f.write(str(weights)) 353 | 354 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 355 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 356 | model, 357 | plot_callbacks=[plot_label_distribution], 358 | return_labels_logits=True, 359 | return_votes=True 360 | ) 361 | print(f"{domain} F1: {F1}") 362 | print(f"{domain} Accuracy: {acc}") 363 | print() 364 | 365 | wandb.run.summary[f"{domain}-P"] = P 366 | wandb.run.summary[f"{domain}-R"] = R 367 | wandb.run.summary[f"{domain}-F1"] = F1 368 | wandb.run.summary[f"{domain}-Acc"] = acc 369 | Ps.append(P) 370 | Rs.append(R) 371 | F1s.append(F1) 372 | accs.append(acc) 373 | labels_all.extend(labels) 374 | logits_all.extend(logits) 375 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 376 | for p, l in zip(np.argmax(logits, axis=-1), labels): 377 | f.write(f'{domain}\t{p}\t{l}\n') 378 | 379 
| acc, P, R, F1 = acc_f1(logits_all, labels_all) 380 | # Add to wandb 381 | wandb.run.summary[f'test-loss'] = loss 382 | wandb.run.summary[f'test-micro-acc'] = acc 383 | wandb.run.summary[f'test-micro-P'] = P 384 | wandb.run.summary[f'test-micro-R'] = R 385 | wandb.run.summary[f'test-micro-F1'] = F1 386 | 387 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 388 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 389 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 390 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 391 | 392 | # wandb.log({f"label-distribution-test-{i}": plots[0]}) 393 | -------------------------------------------------------------------------------- /emnlp_final_experiments/sentiment-analysis/train_multi_view_domain_adversarial.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from collections import defaultdict 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import wandb 15 | from torch.optim.lr_scheduler import LambdaLR 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data import Subset 18 | from torch.utils.data import random_split 19 | from torch.optim import Adam 20 | from tqdm import tqdm 21 | from transformers import AdamW 22 | from transformers import DistilBertConfig 23 | from transformers import DistilBertTokenizer 24 | from transformers import DistilBertForSequenceClassification 25 | from transformers import get_linear_schedule_with_warmup 26 | 27 | from datareader import MultiDomainSentimentDataset 28 | from datareader import collate_batch_transformer 29 | from metrics import MultiDatasetClassificationEvaluator 30 | from metrics import ClassificationEvaluator 31 | from metrics import acc_f1 32 | 33 | from metrics import plot_label_distribution 34 
| from model import MultiTransformerClassifier 35 | from model import VanillaBert 36 | from model import * 37 | 38 | 39 | def train( 40 | model: torch.nn.Module, 41 | train_dls: List[DataLoader], 42 | optimizer: torch.optim.Optimizer, 43 | scheduler: LambdaLR, 44 | validation_evaluator: MultiDatasetClassificationEvaluator, 45 | n_epochs: int, 46 | device: AnyStr, 47 | log_interval: int = 1, 48 | patience: int = 10, 49 | model_dir: str = "wandb_local", 50 | gradient_accumulation: int = 1, 51 | domain_name: str = '' 52 | ): 53 | #best_loss = float('inf') 54 | best_acc = 0.0 55 | patience_counter = 0 56 | 57 | epoch_counter = 0 58 | total = sum(len(dl) for dl in train_dls) 59 | 60 | # Main loop 61 | while epoch_counter < n_epochs: 62 | dl_iters = [iter(dl) for dl in train_dls] 63 | dl_idx = list(range(len(dl_iters))) 64 | finished = [0] * len(dl_iters) 65 | i = 0 66 | with tqdm(total=total, desc="Training") as pbar: 67 | while sum(finished) < len(dl_iters): 68 | random.shuffle(dl_idx) 69 | for d in dl_idx: 70 | domain_dl = dl_iters[d] 71 | batches = [] 72 | try: 73 | for j in range(gradient_accumulation): 74 | batches.append(next(domain_dl)) 75 | except StopIteration: 76 | finished[d] = 1 77 | if len(batches) == 0: 78 | continue 79 | optimizer.zero_grad() 80 | for batch in batches: 81 | model.train() 82 | batch = tuple(t.to(device) for t in batch) 83 | input_ids = batch[0] 84 | masks = batch[1] 85 | labels = batch[2] 86 | # Null the labels if its the test data 87 | if d == len(train_dls) - 1: 88 | labels = None 89 | # Testing with random domains to see if any effect 90 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device) 91 | domains = batch[3] 92 | 93 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True) 94 | loss = loss.mean() / gradient_accumulation 95 | if i % log_interval == 0: 96 | # wandb.log({ 97 | # "Loss": loss.item(), 98 | # "alpha0": alpha[:,0].cpu(), 99 | # "alpha1": 
alpha[:, 1].cpu(), 100 | # "alpha2": alpha[:, 2].cpu(), 101 | # "alpha_shared": alpha[:, 3].cpu() 102 | # }) 103 | wandb.log({ 104 | "Loss": loss.item() 105 | }) 106 | 107 | loss.backward() 108 | i += 1 109 | pbar.update(1) 110 | 111 | optimizer.step() 112 | if scheduler is not None: 113 | scheduler.step() 114 | 115 | gc.collect() 116 | 117 | # Inline evaluation 118 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 119 | print(f"Validation acc: {acc}") 120 | 121 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 122 | 123 | # Saving the best model and early stopping 124 | #if val_loss < best_loss: 125 | if acc > best_acc: 126 | best_model = model.state_dict() 127 | #best_loss = val_loss 128 | best_acc = acc 129 | #wandb.run.summary['best_validation_loss'] = best_loss 130 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth') 131 | patience_counter = 0 132 | # Log to wandb 133 | wandb.log({ 134 | 'Validation accuracy': acc, 135 | 'Validation Precision': P, 136 | 'Validation Recall': R, 137 | 'Validation F1': F1, 138 | 'Validation loss': val_loss}) 139 | else: 140 | patience_counter += 1 141 | # Stop training once we have lost patience 142 | if patience_counter == patience: 143 | break 144 | 145 | gc.collect() 146 | epoch_counter += 1 147 | 148 | 149 | if __name__ == "__main__": 150 | # Define arguments 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 153 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 154 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 155 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1) 156 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", 
type=int, default=200) 157 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 158 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None) 159 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None) 160 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 161 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 162 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 163 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 164 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 165 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16) 166 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-5) 167 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01) 168 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int) 169 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, type=int) 170 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int) 171 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int) 172 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int) 173 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert") 174 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int) 175 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 
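The `--supervision_layer` flag above selects the layer at which `MultiViewTransformerNetworkProbabilitiesAdversarial` applies domain-adversarial supervision. That model class lives in `model.py` (not shown here), so the exact mechanism is an assumption, but domain-adversarial training in this style (DANN) is typically implemented with a gradient reversal layer: an identity on the forward pass that negates and scales gradients on the backward pass, so the feature extractor learns to *fool* the domain classifier. A minimal sketch:

```python
import torch

class GradReverse(torch.autograd.Function):
    """Gradient reversal layer: identity forward, negated (scaled) gradient backward.

    Hypothetical stand-in for whatever model.py actually uses; shown only to
    illustrate the standard DANN mechanism behind --supervision_layer.
    """

    @staticmethod
    def forward(ctx, x, lambd=1.0):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed grad for x, None for lambd
        return -ctx.lambd * grad_output, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)
```

In use, the hidden state at the supervision layer would pass through `grad_reverse` before the domain classifier head, leaving the task loss path untouched.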
176 | 177 | args = parser.parse_args() 178 | 179 | # Set all the seeds 180 | seed = args.seed 181 | random.seed(seed) 182 | np.random.seed(seed) 183 | torch.manual_seed(seed) 184 | torch.cuda.manual_seed_all(seed) 185 | torch.backends.cudnn.deterministic = True 186 | torch.backends.cudnn.benchmark = False 187 | 188 | # See if CUDA available 189 | device = torch.device("cpu") 190 | if args.n_gpu > 0 and torch.cuda.is_available(): 191 | print("Training on GPU") 192 | device = torch.device("cuda:0") 193 | 194 | # model configuration 195 | bert_model = 'distilbert-base-uncased' 196 | ################## 197 | # override for now 198 | batch_size = 4#args.batch_size 199 | args.gradient_accumulation = 2 200 | ############### 201 | lr = args.lr 202 | weight_decay = args.weight_decay 203 | n_epochs = args.n_epochs 204 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 205 | 206 | # wandb initialization 207 | wandb.init( 208 | project="domain-adaptation-sentiment-emnlp", 209 | name=args.run_name, 210 | config={ 211 | "epochs": n_epochs, 212 | "learning_rate": lr, 213 | "warmup": args.warmup_steps, 214 | "weight_decay": weight_decay, 215 | "batch_size": batch_size, 216 | "train_split_percentage": args.train_pct, 217 | "bert_model": bert_model, 218 | "seed": seed, 219 | "tags": ",".join(args.tags) 220 | } 221 | ) 222 | #wandb.watch(model) 223 | #Create save directory for model 224 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 225 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 226 | 227 | # Create the dataset 228 | all_dsets = [MultiDomainSentimentDataset( 229 | args.dataset_loc, 230 | [domain], 231 | DistilBertTokenizer.from_pretrained(bert_model) 232 | ) for domain in args.domains] 233 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 234 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 235 | 236 | accs = [] 237 | 
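The inner loop of `train()` above interleaves the per-domain `DataLoader`s: each pass shuffles the domain order, pulls up to `gradient_accumulation` batches from each domain, and repeats until every loader is exhausted. Isolated as a generator (a simplified sketch that drops the model, optimizer, and wandb logging, with plain iterables standing in for `DataLoader`s):

```python
import random

def interleave_batches(loaders, gradient_accumulation=1, seed=0):
    """Yield (domain_index, batches) pairs, visiting domains in a freshly
    shuffled order each pass and drawing up to `gradient_accumulation`
    batches per visit, until all loaders are exhausted."""
    rng = random.Random(seed)
    iters = [iter(dl) for dl in loaders]
    finished = [False] * len(iters)
    order = list(range(len(iters)))
    while not all(finished):
        rng.shuffle(order)
        for d in order:
            batches = []
            try:
                for _ in range(gradient_accumulation):
                    batches.append(next(iters[d]))
            except StopIteration:
                finished[d] = True
            if batches:  # a partial accumulation window is still trained on
                yield d, batches
```

Note that, as in `train()`, a domain's final short window (fewer than `gradient_accumulation` batches) is still yielded rather than dropped.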
Ps = [] 238 | Rs = [] 239 | F1s = [] 240 | # Store labels and logits for individual splits for micro F1 241 | labels_all = [] 242 | logits_all = [] 243 | 244 | for i in range(len(all_dsets)): 245 | domain = args.domains[i] 246 | test_dset = all_dsets[i] 247 | # Override the domain IDs 248 | k = 0 249 | for j in range(len(all_dsets)): 250 | if j != i: 251 | all_dsets[j].set_domain_id(k) 252 | k += 1 253 | test_dset.set_domain_id(k) 254 | # For test 255 | #all_dsets = [all_dsets[0], all_dsets[2]] 256 | 257 | # Split the data 258 | if args.indices_dir is None: 259 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 260 | for j in range(len(all_dsets)) if j != i] 261 | else: 262 | # load the indices 263 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 264 | subset_indices = defaultdict(lambda: [[], []]) 265 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 266 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 267 | for l in f: 268 | vals = l.strip().split(',') 269 | subset_indices[int(vals[0])][0].append(int(vals[1])) 270 | for l in g: 271 | vals = l.strip().split(',') 272 | subset_indices[int(vals[0])][1].append(int(vals[1])) 273 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in 274 | subset_indices] 275 | 276 | train_dls = [DataLoader( 277 | subset[0], 278 | batch_size=batch_size, 279 | shuffle=True, 280 | collate_fn=collate_batch_transformer 281 | ) for subset in subsets] 282 | # Add test data for domain adversarial training 283 | train_dls += [DataLoader( 284 | test_dset, 285 | batch_size=batch_size, 286 | shuffle=True, 287 | collate_fn=collate_batch_transformer 288 | )] 289 | 290 | val_ds = [subset[1] for subset in subsets] 291 | # for vds in val_ds: 292 | # print(vds.indices) 293 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 294 | 295 | # Create the model 296 | bert = 
DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 297 | multi_xformer = MultiDistilBertClassifier( 298 | bert_model, 299 | bert_config, 300 | n_domains=len(train_dls) - 1 301 | ).to(device) 302 | if args.pretrained_multi_xformer is not None: 303 | multi_xformer.load_state_dict(torch.load(f"{args.pretrained_multi_xformer}/model_{domain}.pth")) 304 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(multi_xformer) 305 | print(f"Validation acc multi-xformer: {acc}") 306 | 307 | shared_bert = VanillaBert(bert).to(device) 308 | if args.pretrained_bert is not None: 309 | shared_bert.load_state_dict(torch.load(f"{args.pretrained_bert}/model_{domain}.pth")) 310 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(shared_bert) 311 | print(f"Validation acc shared bert: {acc}") 312 | 313 | model = torch.nn.DataParallel(MultiViewTransformerNetworkProbabilitiesAdversarial( 314 | multi_xformer, 315 | shared_bert, 316 | supervision_layer=args.supervision_layer 317 | )).to(device) 318 | # (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 319 | # print(f"Validation acc starting: {acc}") 320 | 321 | # Create the optimizer 322 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 323 | optimizer_grouped_parameters = [ 324 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 325 | 'weight_decay': weight_decay}, 326 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 327 | ] 328 | 329 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 330 | scheduler = get_linear_schedule_with_warmup( 331 | optimizer, 332 | args.warmup_steps, 333 | n_epochs * sum([len(train_dl) for train_dl in train_dls]) 334 | ) 335 | 336 | # Train 337 | train( 338 | model, 339 | train_dls, 340 | optimizer, 341 | scheduler, 342 | validation_evaluator, 343 | n_epochs, 344 | device, 345 | args.log_interval, 346 | 
model_dir=args.model_dir, 347 | gradient_accumulation=args.gradient_accumulation, 348 | domain_name=domain 349 | ) 350 | # Load the best weights 351 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth')) 352 | 353 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 354 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 355 | model, 356 | plot_callbacks=[plot_label_distribution], 357 | return_labels_logits=True, 358 | return_votes=True 359 | ) 360 | print(f"{domain} F1: {F1}") 361 | print(f"{domain} Accuracy: {acc}") 362 | print() 363 | 364 | wandb.run.summary[f"{domain}-P"] = P 365 | wandb.run.summary[f"{domain}-R"] = R 366 | wandb.run.summary[f"{domain}-F1"] = F1 367 | wandb.run.summary[f"{domain}-Acc"] = acc 368 | Ps.append(P) 369 | Rs.append(R) 370 | F1s.append(F1) 371 | accs.append(acc) 372 | labels_all.extend(labels) 373 | logits_all.extend(logits) 374 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 375 | for p, l in zip(np.argmax(logits, axis=-1), labels): 376 | f.write(f'{domain}\t{p}\t{l}\n') 377 | 378 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 379 | # Add to wandb 380 | wandb.run.summary[f'test-loss'] = loss 381 | wandb.run.summary[f'test-micro-acc'] = acc 382 | wandb.run.summary[f'test-micro-P'] = P 383 | wandb.run.summary[f'test-micro-R'] = R 384 | wandb.run.summary[f'test-micro-F1'] = F1 385 | 386 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 387 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 388 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 389 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 390 | 391 | # wandb.log({f"label-distribution-test-{i}": plots[0]}) 392 | -------------------------------------------------------------------------------- /emnlp_final_experiments/sentiment-analysis/train_multi_view_selective_weighting.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | from typing import AnyStr 6 | from typing import List 7 | import ipdb 8 | import krippendorff 9 | from collections import defaultdict 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import wandb 15 | from torch.optim.lr_scheduler import LambdaLR 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data import Subset 18 | from torch.utils.data import random_split 19 | from torch.optim import Adam 20 | from tqdm import tqdm 21 | from transformers import AdamW 22 | from transformers import DistilBertConfig 23 | from transformers import DistilBertTokenizer 24 | from transformers import DistilBertForSequenceClassification 25 | from transformers import get_linear_schedule_with_warmup 26 | 27 | from datareader import MultiDomainSentimentDataset 28 | from datareader import collate_batch_transformer 29 | from metrics import MultiDatasetClassificationEvaluator 30 | from metrics import ClassificationEvaluator 31 | from metrics import acc_f1 32 | 33 | from metrics import plot_label_distribution 34 | from model import MultiTransformerClassifier 35 | from model import VanillaBert 36 | from model import * 37 | from sklearn.model_selection import ParameterSampler 38 | 39 | 40 | def attention_grid_search( 41 | model: torch.nn.Module, 42 | validation_evaluator: MultiDatasetClassificationEvaluator, 43 | n_epochs: int, 44 | seed: int 45 | ): 46 | best_weights = model.module.weights 47 | # initial 48 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 49 | best_acc = acc 50 | print(acc) 51 | # Create the grid search 52 | param_dict = {1:list(range(0,11)), 2:list(range(0,11)),3:list(range(0,11)),4:list(range(0,11))} 53 | grid_search_params = ParameterSampler(param_dict, n_iter=n_epochs, random_state=seed) 54 | for d in grid_search_params: 55 | weights = [v for k,v in 
sorted(d.items(), key=lambda x:x[0])] 56 | weights = np.array(weights) / sum(weights) 57 | model.module.weights = weights 58 | # Inline evaluation 59 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model) 60 | print(f"Weights: {weights}\tValidation acc: {acc}") 61 | 62 | if acc > best_acc: 63 | best_weights = weights 64 | best_acc = acc 65 | # Log to wandb 66 | wandb.log({ 67 | 'Validation accuracy': acc, 68 | 'Validation Precision': P, 69 | 'Validation Recall': R, 70 | 'Validation F1': F1, 71 | 'Validation loss': val_loss}) 72 | 73 | gc.collect() 74 | 75 | return best_weights 76 | 77 | 78 | if __name__ == "__main__": 79 | # Define arguments 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str) 82 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8) 83 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0) 84 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2) 85 | parser.add_argument("--pretrained_model", help="The pretrained averaging model", type=str, default=None) 86 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[]) 87 | parser.add_argument("--seed", type=int, help="Random seed", default=1000) 88 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline") 89 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str) 90 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[]) 91 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None) 92 | 93 | args = parser.parse_args() 94 | 95 | # Set all the seeds 96 | seed = args.seed 97 | random.seed(seed) 98 | np.random.seed(seed) 99 | torch.manual_seed(seed) 100 | 
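`attention_grid_search` above randomly samples integer scores in [0, 10] for each of the four views and normalizes them into a convex combination of expert weights. The sampling step can be sketched with the stdlib in place of sklearn's `ParameterSampler` (note the normalization divides by the score sum, so an all-zero draw would divide by zero in the original; this sketch falls back to uniform weights in that case):

```python
import random

def sample_normalized_weights(n_views, n_iter, seed=0):
    """Yield `n_iter` random weight vectors over `n_views` views, each a
    normalized draw of integer scores in [0, 10] (stdlib stand-in for
    sklearn.model_selection.ParameterSampler as used above)."""
    rng = random.Random(seed)
    for _ in range(n_iter):
        scores = [rng.randint(0, 10) for _ in range(n_views)]
        total = sum(scores)
        if total == 0:
            # Degenerate all-zero draw: fall back to uniform weighting
            yield [1.0 / n_views] * n_views
        else:
            yield [s / total for s in scores]
```

Each candidate vector would then be assigned to `model.module.weights` and scored on the validation evaluator, keeping the best-performing combination.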
torch.cuda.manual_seed_all(seed) 101 | torch.backends.cudnn.deterministic = True 102 | torch.backends.cudnn.benchmark = False 103 | 104 | # See if CUDA available 105 | device = torch.device("cpu") 106 | if args.n_gpu > 0 and torch.cuda.is_available(): 107 | print("Training on GPU") 108 | device = torch.device("cuda:0") 109 | 110 | # model configuration 111 | bert_model = 'distilbert-base-uncased' 112 | n_epochs = args.n_epochs 113 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True) 114 | 115 | 116 | # wandb initialization 117 | wandb.init( 118 | project="domain-adaptation-sentiment-emnlp", 119 | name=args.run_name, 120 | config={ 121 | "epochs": n_epochs, 122 | "bert_model": bert_model, 123 | "seed": seed, 124 | "tags": ",".join(args.tags) 125 | } 126 | ) 127 | #wandb.watch(model) 128 | #Create save directory for model 129 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"): 130 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}") 131 | 132 | # Create the dataset 133 | all_dsets = [MultiDomainSentimentDataset( 134 | args.dataset_loc, 135 | [domain], 136 | DistilBertTokenizer.from_pretrained(bert_model) 137 | ) for domain in args.domains] 138 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)] 139 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))] 140 | 141 | accs = [] 142 | Ps = [] 143 | Rs = [] 144 | F1s = [] 145 | # Store labels and logits for individual splits for micro F1 146 | labels_all = [] 147 | logits_all = [] 148 | 149 | for i in range(len(all_dsets)): 150 | domain = args.domains[i] 151 | test_dset = all_dsets[i] 152 | # Override the domain IDs 153 | k = 0 154 | for j in range(len(all_dsets)): 155 | if j != i: 156 | all_dsets[j].set_domain_id(k) 157 | k += 1 158 | test_dset.set_domain_id(k) 159 | # For test 160 | #all_dsets = [all_dsets[0], all_dsets[2]] 161 | 162 | # Split the data 163 | if args.indices_dir is 
None: 164 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]]) 165 | for j in range(len(all_dsets)) if j != i] 166 | else: 167 | # load the indices 168 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i] 169 | subset_indices = defaultdict(lambda: [[], []]) 170 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \ 171 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g: 172 | for l in f: 173 | vals = l.strip().split(',') 174 | subset_indices[int(vals[0])][0].append(int(vals[1])) 175 | for l in g: 176 | vals = l.strip().split(',') 177 | subset_indices[int(vals[0])][1].append(int(vals[1])) 178 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in 179 | subset_indices] 180 | 181 | train_dls = [DataLoader( 182 | subset[0], 183 | batch_size=8, 184 | shuffle=True, 185 | collate_fn=collate_batch_transformer 186 | ) for subset in subsets] 187 | 188 | val_ds = [subset[1] for subset in subsets] 189 | # for vds in val_ds: 190 | # print(vds.indices) 191 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device) 192 | 193 | # Create the model 194 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device) 195 | multi_xformer = MultiDistilBertClassifier( 196 | bert_model, 197 | bert_config, 198 | n_domains=len(train_dls) 199 | ).to(device) 200 | 201 | shared_bert = VanillaBert(bert).to(device) 202 | 203 | model = torch.nn.DataParallel(MultiViewTransformerNetworkSelectiveWeight( 204 | multi_xformer, 205 | shared_bert 206 | )).to(device) 207 | # Load the best weights 208 | model.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}.pth')) 209 | 210 | # Calculate the best attention weights with a grid search 211 | weights = attention_grid_search( 212 | model, 213 | validation_evaluator, 214 | n_epochs, 215 | seed 216 | ) 217 | model.module.weights = weights 218 | print(f"Best weights: 
{weights}") 219 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f: 220 | f.write(str(weights)) 221 | 222 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False) 223 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate( 224 | model, 225 | plot_callbacks=[plot_label_distribution], 226 | return_labels_logits=True, 227 | return_votes=True 228 | ) 229 | print(f"{domain} F1: {F1}") 230 | print(f"{domain} Accuracy: {acc}") 231 | print() 232 | 233 | wandb.run.summary[f"{domain}-P"] = P 234 | wandb.run.summary[f"{domain}-R"] = R 235 | wandb.run.summary[f"{domain}-F1"] = F1 236 | wandb.run.summary[f"{domain}-Acc"] = acc 237 | Ps.append(P) 238 | Rs.append(R) 239 | F1s.append(F1) 240 | accs.append(acc) 241 | labels_all.extend(labels) 242 | logits_all.extend(logits) 243 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f: 244 | for p, l in zip(np.argmax(logits, axis=-1), labels): 245 | f.write(f'{domain}\t{p}\t{l}\n') 246 | 247 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 248 | # Add to wandb 249 | wandb.run.summary[f'test-loss'] = loss 250 | wandb.run.summary[f'test-micro-acc'] = acc 251 | wandb.run.summary[f'test-micro-P'] = P 252 | wandb.run.summary[f'test-micro-R'] = R 253 | wandb.run.summary[f'test-micro-F1'] = F1 254 | 255 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs) 256 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps) 257 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs) 258 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s) 259 | 260 | # wandb.log({f"label-distribution-test-{i}": plots[0]}) 261 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | from typing import Tuple, List, Callable, 
AnyStr 6 | from tqdm import tqdm 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | from sklearn.metrics import precision_recall_fscore_support 10 | from collections import Counter 11 | from scipy.stats import entropy 12 | import ipdb 13 | 14 | from datareader import collate_batch_transformer 15 | 16 | 17 | def accuracy(logits: np.ndarray, labels: np.ndarray) -> float: 18 | return np.sum(np.argmax(logits, axis=-1) == labels).astype(np.float32) / float(labels.shape[0]) 19 | 20 | 21 | def acc_f1(logits: List, labels: List) -> Tuple[float, float, float, float]: 22 | logits = np.asarray(logits).reshape(-1, len(logits[0])) 23 | labels = np.asarray(labels).reshape(-1) 24 | acc = accuracy(logits, labels) 25 | average = 'binary' if logits.shape[1] == 2 else None 26 | P, R, F1, _ = precision_recall_fscore_support(labels, np.argmax(logits, axis=-1), average=average) 27 | return acc, P, R, F1 28 | 29 | 30 | def plot_label_distribution(labels: np.ndarray, logits: np.ndarray) -> matplotlib.figure.Figure: 31 | """ Plots the distribution of labels in the prediction 32 | 33 | :param labels: Gold labels 34 | :param logits: Logits from the model 35 | :return: The matplotlib figure with the plot 36 | """ 37 | predictions = np.argmax(logits, axis=-1) 38 | labs, counts = zip(*list(sorted(Counter(predictions).items(), key=lambda x: x[0]))) 39 | 40 | fig, ax = plt.subplots(figsize=(12, 9)) 41 | ax.bar(labs, counts, width=0.2) 42 | ax.set_xticks(labs); ax.set_xticklabels([str(l) for l in labs])  # set_xticks takes no label list in matplotlib 3.1 43 | ax.set_ylabel('Count') 44 | ax.set_xlabel("Label") 45 | ax.set_title("Prediction distribution") 46 | return fig 47 | 48 | 49 | class ClassificationEvaluator: 50 | """Wrapper to evaluate a classification model on a single dataset 51 | 52 | """ 53 | 54 | def __init__(self, dataset: Dataset, device: torch.device, use_domain: bool = True, use_labels: bool = True): 55 | self.dataset = dataset 56 | self.dataloader = DataLoader( 57 | dataset, 58 | batch_size=8, 59 | collate_fn=collate_batch_transformer 60 | ) 61 | 
self.device = device 62 | self.stored_labels = [] 63 | self.stored_logits = [] 64 | self.use_domain = use_domain 65 | self.use_labels = use_labels 66 | 67 | def micro_f1(self) -> Tuple[float, float, float, float]: 68 | labels_all = self.stored_labels 69 | logits_all = self.stored_logits 70 | 71 | logits_all = np.asarray(logits_all).reshape(-1, len(logits_all[0])) 72 | labels_all = np.asarray(labels_all).reshape(-1) 73 | acc = accuracy(logits_all, labels_all) 74 | P, R, F1, _ = precision_recall_fscore_support(labels_all, np.argmax(logits_all, axis=-1), average='binary') 75 | 76 | return acc, P, R, F1 77 | 78 | def evaluate( 79 | self, 80 | model: torch.nn.Module, 81 | plot_callbacks: List[Callable] = [], 82 | return_labels_logits: bool = False, 83 | return_votes: bool = False 84 | ) -> Tuple: 85 | """Collect evaluation metrics on this dataset 86 | 87 | :param model: The pytorch model to evaluate 88 | :param plot_callbacks: Optional function callbacks for plotting various things 89 | :return: (Loss, Accuracy, Precision, Recall, F1) 90 | """ 91 | model.eval() 92 | with torch.no_grad(): 93 | labels_all = [] 94 | logits_all = [] 95 | losses_all = [] 96 | votes_all = [] 97 | for batch in tqdm(self.dataloader, desc="Evaluation"): 98 | batch = tuple(t.to(self.device) for t in batch) 99 | input_ids = batch[0] 100 | masks = batch[1] 101 | labels = batch[2] 102 | domains = batch[3] if self.use_domain else None 103 | if self.use_labels: 104 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels) 105 | if len(loss.size()) > 0: 106 | loss = loss.mean() 107 | else: 108 | (logits,) = model(input_ids, attention_mask=masks, domains=domains) 109 | loss = torch.FloatTensor([-1.]) 110 | labels_all.extend(list(labels.detach().cpu().numpy())) 111 | logits_all.extend(list(logits.detach().cpu().numpy())) 112 | losses_all.append(loss.item()) 113 | if hasattr(model, 'votes'): 114 | votes_all.extend(model.votes.detach().cpu().numpy()) 115 | 116 | if not 
self.use_labels: 117 | return logits_all 118 | 119 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 120 | loss = sum(losses_all) / len(losses_all) 121 | 122 | # Plotting 123 | plots = [] 124 | for f in plot_callbacks: 125 | plots.append(f(labels_all, logits_all)) 126 | 127 | ret_vals = (loss, acc, P, R, F1), plots 128 | if return_labels_logits: 129 | ret_vals = ret_vals + ((labels_all, logits_all),) 130 | 131 | if return_votes: 132 | if len(votes_all) > 0: 133 | ret_vals += (votes_all,) 134 | else: 135 | ret_vals += (list(np.argmax(np.asarray(logits_all), axis=1)),) 136 | 137 | return ret_vals 138 | 139 | class MultiDatasetClassificationEvaluator: 140 | """Wrapper to evaluate a classification model across multiple datasets 141 | 142 | """ 143 | 144 | def __init__(self, datasets: List[Dataset], device: torch.device, use_domain: bool = True): 145 | self.datasets = datasets 146 | self.dataloaders = [DataLoader( 147 | dataset, 148 | batch_size=8, 149 | collate_fn=collate_batch_transformer 150 | ) for dataset in datasets] 151 | 152 | self.device = device 153 | self.stored_labels = [] 154 | self.stored_logits = [] 155 | self.use_domain = use_domain 156 | 157 | def micro_f1(self) -> Tuple[float, float, float, float]: 158 | labels_all = self.stored_labels 159 | logits_all = self.stored_logits 160 | 161 | logits_all = np.asarray(logits_all).reshape(-1, len(logits_all[0])) 162 | labels_all = np.asarray(labels_all).reshape(-1) 163 | acc = accuracy(logits_all, labels_all) 164 | P, R, F1, _ = precision_recall_fscore_support(labels_all, np.argmax(logits_all, axis=-1), average='binary') 165 | 166 | return acc, P, R, F1 167 | 168 | def evaluate( 169 | self, 170 | model: torch.nn.Module, 171 | plot_callbacks: List[Callable] = [], 172 | return_labels_logits: bool = False, 173 | return_votes: bool = False 174 | ) -> Tuple: 175 | """Collect evaluation metrics on these datasets 176 | 177 | :param model: The pytorch model to evaluate 178 | :param plot_callbacks: Optional function callbacks for 
plotting various things 179 | :return: (Loss, Accuracy, Precision, Recall, F1) 180 | """ 181 | model.eval() 182 | with torch.no_grad(): 183 | labels_all = [] 184 | logits_all = [] 185 | losses_all = [] 186 | votes_all = [] 187 | for dataloader in self.dataloaders: 188 | for batch in tqdm(dataloader, desc="Evaluation"): 189 | batch = tuple(t.to(self.device) for t in batch) 190 | input_ids = batch[0] 191 | masks = batch[1] 192 | labels = batch[2] 193 | domains = batch[3] if self.use_domain else None 194 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels) 195 | if len(loss.size()) > 0: 196 | loss = loss.mean() 197 | labels_all.extend(list(labels.detach().cpu().numpy())) 198 | logits_all.extend(list(logits.detach().cpu().numpy())) 199 | losses_all.append(loss.item()) 200 | if hasattr(model, 'votes'): 201 | votes_all.extend(model.votes.detach().cpu().numpy()) 202 | 203 | # Score with the votes of the domain whose predictions have the highest average entropy 204 | if len(votes_all) > 0: 205 | votes_all = np.asarray(votes_all).transpose(0,1) 206 | entropies = [np.mean([entropy(v) for v in votes]) for votes in votes_all] 207 | domain = np.argmax(entropies) 208 | logits_all = votes_all[domain] 209 | 210 | acc, P, R, F1 = acc_f1(logits_all, labels_all) 211 | loss = sum(losses_all) / len(losses_all) 212 | 213 | # Plotting 214 | plots = [] 215 | for f in plot_callbacks: 216 | plots.append(f(labels_all, logits_all)) 217 | 218 | ret_vals = (loss, acc, P, R, F1), plots 219 | if return_labels_logits: 220 | ret_vals = ret_vals + ((labels_all, logits_all),) 221 | 222 | if return_votes: 223 | if len(votes_all) > 0: 224 | ret_vals += (votes_all,) 225 | else: 226 | ret_vals += (list(np.argmax(np.asarray(logits_all), axis=1)),) 227 | 228 | return ret_vals 229 | 230 | 231 | class DomainClassifierEvaluator: 232 | """Wrapper to evaluate a model for the task of domain classification 233 | 234 | """ 235 | 236 | def __init__(self, dataset: Dataset, device: torch.device): 237 | self.dataset = dataset 238 | 
self.dataloader = DataLoader( 239 | dataset, 240 | batch_size=8, 241 | collate_fn=collate_batch_transformer 242 | ) 243 | self.device = device 244 | self.stored_labels = [] 245 | self.stored_logits = [] 246 | 247 | def micro_f1(self) -> Tuple[float, float, float, float]: 248 | labels_all = self.stored_labels 249 | logits_all = self.stored_logits 250 | 251 | logits_all = np.asarray(logits_all).reshape(-1, len(logits_all[0])) 252 | labels_all = np.asarray(labels_all).reshape(-1) 253 | acc = accuracy(logits_all, labels_all) 254 | P, R, F1, _ = precision_recall_fscore_support(labels_all, np.argmax(logits_all, axis=-1), average='binary') 255 | 256 | return acc, P, R, F1 257 | 258 | def evaluate( 259 | self, 260 | model: torch.nn.Module, 261 | plot_callbacks: List[Callable] = [], 262 | return_labels_logits: bool = False, 263 | return_votes: bool = False 264 | ) -> Tuple: 265 | """Collect evaluation metrics on this dataset 266 | 267 | :param model: The pytorch model to evaluate 268 | :param plot_callbacks: Optional function callbacks for plotting various things 269 | :return: (Loss, Accuracy, Precision, Recall, F1) 270 | """ 271 | model.eval() 272 | with torch.no_grad(): 273 | labels_all = [] 274 | logits_all = [] 275 | losses_all = [] 276 | votes_all = [] 277 | for batch in tqdm(self.dataloader, desc="Evaluation"): 278 | batch = tuple(t.to(self.device) for t in batch) 279 | input_ids = batch[0] 280 | masks = batch[1] 281 | labels = batch[2] 282 | domains = batch[3] 283 | loss, logits = model(input_ids, attention_mask=masks, labels=domains) 284 | 285 | labels_all.extend(list(domains.detach().cpu().numpy())) 286 | logits_all.extend(list(logits.detach().cpu().numpy())) 287 | losses_all.append(loss.item()) 288 | if hasattr(model, 'votes'): 289 | votes_all.extend(model.votes.detach().cpu().numpy()) 290 | 291 | acc,P,R,F1 = acc_f1(logits_all, labels_all) 292 | loss = sum(losses_all) / len(losses_all) 293 | 294 | # Plotting 295 | plots = [] 296 | for f in plot_callbacks: 297 | 
plots.append(f(labels_all, logits_all)) 298 | 299 | ret_vals = (loss, acc, P, R, F1), plots 300 | if return_labels_logits: 301 | ret_vals = ret_vals + ((labels_all, logits_all),) 302 | 303 | if return_votes: 304 | if len(votes_all) > 0: 305 | ret_vals += (votes_all,) 306 | else: 307 | ret_vals += (list(np.argmax(np.asarray(logits_all), axis=1)),) 308 | 309 | return ret_vals -------------------------------------------------------------------------------- /multisource-domain-adaptation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/copenlu/xformer-multi-source-domain-adaptation/be2d1a132298131df82fe40dd4f6c08dec8b3404/multisource-domain-adaptation.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fasttext==0.9.1 2 | ipdb==0.12.3 3 | ipython==7.12.0 4 | krippendorff==0.3.2 5 | matplotlib==3.1.3 6 | numpy==1.18.1 7 | pandas==1.0.1 8 | scikit-learn==0.22.1 9 | torch==1.4.0 10 | torchvision==0.5.0 11 | tqdm==4.42.1 12 | transformers==2.8.0 13 | wandb==0.8.27 14 | -------------------------------------------------------------------------------- /run_claim_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | . activate xformer-multisource-domain-adaptation 4 | 5 | . 
setenv.sh 6 | 7 | run_name="(emnlp-claim)" 8 | model_dir="wandb_local/emnlp_claim_experiments" 9 | tags="emnlp claim experiments" 10 | for i in 1000,1 1001,2 666,3 7,4 50,5; do IFS=","; set -- $i; 11 | j=`expr ${2} - 1` 12 | 13 | # 1) Basic 14 | python emnlp_final_experiments/claim-detection/train_basic.py \ 15 | --dataset_loc data/PHEME \ 16 | --train_pct 0.9 \ 17 | --n_gpu 1 \ 18 | --n_epochs 5 \ 19 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 20 | --seed ${1} \ 21 | --run_name "basic-distilbert-${2}" \ 22 | --model_dir ${model_dir}/basic_distilbert \ 23 | --tags ${tags} \ 24 | --batch_size 8 \ 25 | --lr 0.00003 26 | indices_dir=`ls -d -t ${model_dir}/basic_distilbert/*/ | head -1` 27 | 28 | # 2) Adv-6 29 | python emnlp_final_experiments/claim-detection/train_basic_domain_adversarial.py \ 30 | --dataset_loc data/PHEME \ 31 | --train_pct 0.9 \ 32 | --n_gpu 1 \ 33 | --n_epochs 5 \ 34 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 35 | --seed ${1} \ 36 | --run_name "distilbert-adversarial-6-${2}" \ 37 | --model_dir ${model_dir}/distilbert_adversarial_6 \ 38 | --tags ${tags} \ 39 | --batch_size 8 \ 40 | --lr 0.00003 \ 41 | --supervision_layer 6 \ 42 | --indices_dir ${indices_dir} 43 | 44 | # 3) Adv-3 45 | python emnlp_final_experiments/claim-detection/train_basic_domain_adversarial.py \ 46 | --dataset_loc data/PHEME \ 47 | --train_pct 0.9 \ 48 | --n_gpu 1 \ 49 | --n_epochs 5 \ 50 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 51 | --seed ${1} \ 52 | --run_name "distilbert-adversarial-3-${2}" \ 53 | --model_dir ${model_dir}/distilbert_adversarial_3 \ 54 | --tags ${tags} \ 55 | --batch_size 8 \ 56 | --lr 0.00003 \ 57 | --supervision_layer 3 \ 58 | --indices_dir ${indices_dir} 59 | 60 | # 4) Independent-Avg 61 | python emnlp_final_experiments/claim-detection/train_multi_view_averaging_individuals.py \ 62 | --dataset_loc data/PHEME \ 63 | --train_pct 0.9 \ 64 | 
--n_gpu 1 \ 65 | --n_epochs 5 \ 66 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 67 | --seed ${1} \ 68 | --run_name "distilbert-ensemble-averaging-individuals-${2}" \ 69 | --model_dir ${model_dir}/distilbert_ensemble_averaging_individuals \ 70 | --tags ${tags} \ 71 | --batch_size 8 \ 72 | --lr 0.00003 \ 73 | --indices_dir ${indices_dir} 74 | avg_model=`ls -d -t ${model_dir}/distilbert_ensemble_averaging_individuals/*/ | head -1` 75 | 76 | # 5) Independent-Ft 77 | python emnlp_final_experiments/claim-detection/train_multi_view_selective_weighting.py \ 78 | --dataset_loc data/PHEME \ 79 | --train_pct 0.9 \ 80 | --n_gpu 1 \ 81 | --n_epochs 30 \ 82 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 83 | --seed ${1} \ 84 | --run_name "distilbert-ensemble-selective-attention-${2}" \ 85 | --model_dir ${model_dir}/distilbert_ensemble_selective_attention \ 86 | --tags ${tags} \ 87 | --pretrained_model ${avg_model} \ 88 | --indices_dir ${indices_dir} 89 | 90 | # 6) MoE-DC 91 | python emnlp_final_experiments/claim-detection/train_multi_view_domainclassifier_individuals.py \ 92 | --dataset_loc data/PHEME \ 93 | --train_pct 0.9 \ 94 | --n_gpu 1 \ 95 | --n_epochs 5 \ 96 | --n_dc_epochs 5 \ 97 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 98 | --seed ${1} \ 99 | --run_name "distilbert-ensemble-domainclassifier-individuals-${2}" \ 100 | --model_dir ${model_dir}/distilbert_ensemble_domainclassifier_individuals \ 101 | --tags ${tags} \ 102 | --batch_size 8 \ 103 | --lr 0.00003 \ 104 | --indices_dir ${indices_dir} \ 105 | --pretrained_model ${avg_model} 106 | 107 | # 7) MoE-Avg 108 | python emnlp_final_experiments/claim-detection/train_multi_view.py \ 109 | --dataset_loc data/PHEME \ 110 | --train_pct 0.9 \ 111 | --n_gpu 1 \ 112 | --n_epochs 5 \ 113 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 114 | --seed ${1} \ 115 | --run_name 
"distilbert-ensemble-averaging-${2}" \ 116 | --model_dir ${model_dir}/distilbert_ensemble_averaging \ 117 | --tags ${tags} \ 118 | --batch_size 8 \ 119 | --lr 0.00003 \ 120 | --ensemble_basic \ 121 | --indices_dir ${indices_dir} 122 | 123 | # 8) MoE-Att 124 | python emnlp_final_experiments/claim-detection/train_multi_view.py \ 125 | --dataset_loc data/PHEME \ 126 | --train_pct 0.9 \ 127 | --n_gpu 1 \ 128 | --n_epochs 5 \ 129 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 130 | --seed ${1} \ 131 | --run_name "distilbert-ensemble-attention-${2}" \ 132 | --model_dir ${model_dir}/distilbert_ensemble_attention \ 133 | --tags ${tags} \ 134 | --batch_size 8 \ 135 | --lr 0.00003 \ 136 | --indices_dir ${indices_dir} 137 | 138 | # 9) MoE-Att-Adv-6 139 | python emnlp_final_experiments/claim-detection/train_multi_view_domain_adversarial.py \ 140 | --dataset_loc data/PHEME \ 141 | --train_pct 0.9 \ 142 | --n_gpu 1 \ 143 | --n_epochs 5 \ 144 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 145 | --seed ${1} \ 146 | --run_name "distilbert-ensemble-attention-adversarial-6-${2}" \ 147 | --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_6 \ 148 | --tags ${tags} \ 149 | --batch_size 8 \ 150 | --lr 0.00003 \ 151 | --supervision_layer 6 \ 152 | --indices_dir ${indices_dir} 153 | 154 | # 10) MoE-Att-Adv-3 155 | python emnlp_final_experiments/claim-detection/train_multi_view_domain_adversarial.py \ 156 | --dataset_loc data/PHEME \ 157 | --train_pct 0.9 \ 158 | --n_gpu 1 \ 159 | --n_epochs 5 \ 160 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \ 161 | --seed ${1} \ 162 | --run_name "distilbert-ensemble-attention-adversarial-3-${2}" \ 163 | --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_3 \ 164 | --tags ${tags} \ 165 | --batch_size 8 \ 166 | --lr 0.00003 \ 167 | --supervision_layer 3 \ 168 | --indices_dir ${indices_dir} 169 | 170 | done 171 | 
-------------------------------------------------------------------------------- /run_sentiment_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | . activate xformer-multisource-domain-adaptation 4 | 5 | . setenv.sh 6 | 7 | run_name="(emnlp-sentiment)" 8 | model_dir="wandb_local/emnlp_sentiment_experiments" 9 | tags="emnlp sentiment experiments" 10 | for i in 1000,1 1001,2 666,3 7,4 50,5; do IFS=","; set -- $i; 11 | # 1) Basic 12 | python emnlp_final_experiments/sentiment-analysis/train_basic.py \ 13 | --dataset_loc data/sentiment-dataset \ 14 | --train_pct 0.9 \ 15 | --n_gpu 1 \ 16 | --n_epochs 5 \ 17 | --domains books dvd electronics kitchen_\&_housewares \ 18 | --seed ${1} \ 19 | --run_name "basic-distilbert-${2}" \ 20 | --model_dir ${model_dir}/basic_distilbert \ 21 | --tags ${tags} \ 22 | --batch_size 8 \ 23 | --lr 0.00003 24 | indices_dir=`ls -d -t ${model_dir}/basic_distilbert/*/ | head -1` 25 | 26 | # 2) Adv-6 27 | python emnlp_final_experiments/sentiment-analysis/train_basic_domain_adversarial.py \ 28 | --dataset_loc data/sentiment-dataset \ 29 | --train_pct 0.9 \ 30 | --n_gpu 1 \ 31 | --n_epochs 5 \ 32 | --domains books dvd electronics kitchen_\&_housewares \ 33 | --seed ${1} \ 34 | --run_name "distilbert-adversarial-6-${2}" \ 35 | --model_dir ${model_dir}/distilbert_adversarial_6 \ 36 | --tags ${tags} \ 37 | --batch_size 8 \ 38 | --lr 0.00003 \ 39 | --supervision_layer 6 \ 40 | --indices_dir ${indices_dir} 41 | 42 | # 3) Adv-3 43 | python emnlp_final_experiments/sentiment-analysis/train_basic_domain_adversarial.py \ 44 | --dataset_loc data/sentiment-dataset \ 45 | --train_pct 0.9 \ 46 | --n_gpu 1 \ 47 | --n_epochs 5 \ 48 | --domains books dvd electronics kitchen_\&_housewares \ 49 | --seed ${1} \ 50 | --run_name "distilbert-adversarial-3-${2}" \ 51 | --model_dir ${model_dir}/distilbert_adversarial_3 \ 52 | --tags ${tags} \ 53 | --batch_size 8 \ 54 | --lr 0.00003 \ 55 | 
--supervision_layer 3 \ 56 | --indices_dir ${indices_dir} 57 | 58 | # 4) Independent-Avg 59 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_averaging_individuals.py \ 60 | --dataset_loc data/sentiment-dataset \ 61 | --train_pct 0.9 \ 62 | --n_gpu 1 \ 63 | --n_epochs 5 \ 64 | --domains books dvd electronics kitchen_\&_housewares \ 65 | --seed ${1} \ 66 | --run_name "distilbert-ensemble-averaging-individuals-${2}" \ 67 | --model_dir ${model_dir}/distilbert_ensemble_averaging_individuals \ 68 | --tags ${tags} \ 69 | --batch_size 8 \ 70 | --lr 0.00003 \ 71 | --indices_dir ${indices_dir} 72 | avg_model=`ls -d -t ${model_dir}/distilbert_ensemble_averaging_individuals/*/ | head -1` 73 | 74 | # 5) Independent-Ft 75 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_selective_weighting.py \ 76 | --dataset_loc data/sentiment-dataset \ 77 | --train_pct 0.9 \ 78 | --n_gpu 1 \ 79 | --n_epochs 30 \ 80 | --domains books dvd electronics kitchen_\&_housewares \ 81 | --seed ${1} \ 82 | --run_name "distilbert-ensemble-selective-attention-${2}" \ 83 | --model_dir ${model_dir}/distilbert_ensemble_selective_attention \ 84 | --tags ${tags} \ 85 | --pretrained_model ${avg_model} \ 86 | --indices_dir ${indices_dir} 87 | 88 | # 6) MoE-DC 89 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_domainclassifier_individuals.py \ 90 | --dataset_loc data/sentiment-dataset \ 91 | --train_pct 0.9 \ 92 | --n_gpu 1 \ 93 | --n_epochs 5 \ 94 | --domains books dvd electronics kitchen_\&_housewares \ 95 | --seed ${1} \ 96 | --run_name "distilbert-ensemble-domainclassifier-individuals-${2}" \ 97 | --model_dir ${model_dir}/distilbert_ensemble_domainclassifier_individuals \ 98 | --tags ${tags} \ 99 | --batch_size 8 \ 100 | --lr 0.00003 \ 101 | --indices_dir ${indices_dir} \ 102 | --pretrained_model ${avg_model} 103 | 104 | # 7) MoE-Avg 105 | python emnlp_final_experiments/sentiment-analysis/train_multi_view.py \ 106 | --dataset_loc 
data/sentiment-dataset \ 107 | --train_pct 0.9 \ 108 | --n_gpu 1 \ 109 | --n_epochs 5 \ 110 | --domains books dvd electronics kitchen_\&_housewares \ 111 | --seed ${1} \ 112 | --run_name "distilbert-ensemble-averaging-${2}" \ 113 | --model_dir ${model_dir}/distilbert_ensemble_averaging \ 114 | --tags ${tags} \ 115 | --batch_size 8 \ 116 | --lr 0.00003 \ 117 | --ensemble_basic \ 118 | --indices_dir ${indices_dir} 119 | 120 | # 8) MoE-Att 121 | python emnlp_final_experiments/sentiment-analysis/train_multi_view.py \ 122 | --dataset_loc data/sentiment-dataset \ 123 | --train_pct 0.9 \ 124 | --n_gpu 1 \ 125 | --n_epochs 5 \ 126 | --domains books dvd electronics kitchen_\&_housewares \ 127 | --seed ${1} \ 128 | --run_name "distilbert-ensemble-attention-${2}" \ 129 | --model_dir ${model_dir}/distilbert_ensemble_attention \ 130 | --tags ${tags} \ 131 | --batch_size 8 \ 132 | --lr 0.00003 \ 133 | --indices_dir ${indices_dir} 134 | 135 | # 9) MoE-Att-Adv-6 136 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_domain_adversarial.py \ 137 | --dataset_loc data/sentiment-dataset \ 138 | --train_pct 0.9 \ 139 | --n_gpu 1 \ 140 | --n_epochs 5 \ 141 | --domains books dvd electronics kitchen_\&_housewares \ 142 | --seed ${1} \ 143 | --run_name "distilbert-ensemble-attention-adversarial-6-${2}" \ 144 | --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_6 \ 145 | --tags ${tags} \ 146 | --batch_size 8 \ 147 | --lr 0.00003 \ 148 | --supervision_layer 6 \ 149 | --indices_dir ${indices_dir} 150 | 151 | # 10) MoE-Att-Adv-3 152 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_domain_adversarial.py \ 153 | --dataset_loc data/sentiment-dataset \ 154 | --train_pct 0.9 \ 155 | --n_gpu 1 \ 156 | --n_epochs 5 \ 157 | --domains books dvd electronics kitchen_\&_housewares \ 158 | --seed ${1} \ 159 | --run_name "distilbert-ensemble-attention-adversarial-3-${2}" \ 160 | --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_3 \ 161 | 
--tags ${tags} \ 162 | --batch_size 8 \ 163 | --lr 0.00003 \ 164 | --supervision_layer 3 \ 165 | --indices_dir ${indices_dir} 166 | 167 | done 168 | -------------------------------------------------------------------------------- /setenv.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=`pwd`:$PYTHONPATH 2 | --------------------------------------------------------------------------------
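A note on the reported metrics: the evaluation code above produces macro scores by averaging the per-domain precision/recall/F1 values (e.g. `sum(F1s) / len(F1s)`), and micro scores by pooling every logit and label across domains before a single `acc_f1` call. The sketch below illustrates that distinction; the helper names are hypothetical and plain numpy stands in for the repository's sklearn call.

```python
# Hypothetical sketch (not part of this repository) of macro vs. micro F1
# aggregation over held-out domains: macro averages per-domain F1 scores,
# micro pools all predictions before computing a single F1.
import numpy as np

def binary_f1(labels, preds):
    """Binary F1 for the positive class (label 1)."""
    labels, preds = np.asarray(labels), np.asarray(preds)
    tp = np.sum((preds == 1) & (labels == 1))
    fp = np.sum((preds == 1) & (labels == 0))
    fn = np.sum((preds == 0) & (labels == 1))
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    return 2 * p * r / (p + r) if p + r else 0.0

def macro_micro_f1(per_domain_logits, per_domain_labels):
    """Return (macro_f1, micro_f1) over a list of domains."""
    f1s, all_preds, all_labels = [], [], []
    for logits, labels in zip(per_domain_logits, per_domain_labels):
        preds = np.argmax(np.asarray(logits), axis=-1)  # logits -> hard predictions
        f1s.append(binary_f1(labels, preds))            # per-domain score for the macro average
        all_preds.extend(preds)                         # pooled predictions for the micro score
        all_labels.extend(labels)
    return sum(f1s) / len(f1s), binary_f1(all_labels, all_preds)
```

The two aggregates answer different questions: macro F1 weights each target domain equally regardless of its size, while micro F1 weights each example equally, so a large domain dominates it.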