├── README.md
├── datareader.py
├── datareader_cnn.py
├── emnlp_final_experiments
│   ├── claim-detection
│   │   ├── .DS_Store
│   │   ├── train_basic.py
│   │   ├── train_basic_domain_adversarial.py
│   │   ├── train_multi_view.py
│   │   ├── train_multi_view_averaging_individuals.py
│   │   ├── train_multi_view_domain_adversarial.py
│   │   ├── train_multi_view_domain_classifier.py
│   │   ├── train_multi_view_domainclassifier_individuals.py
│   │   └── train_multi_view_selective_weighting.py
│   └── sentiment-analysis
│       ├── .DS_Store
│       ├── analyze_expert_predictions.py
│       ├── analyze_expert_predictions_cnn.py
│       ├── train_basic.py
│       ├── train_basic_domain_adversarial.py
│       ├── train_multi_view.py
│       ├── train_multi_view_averaging_individuals.py
│       ├── train_multi_view_averaging_individuals_cnn.py
│       ├── train_multi_view_domain_adversarial.py
│       ├── train_multi_view_domain_classifier.py
│       ├── train_multi_view_domainclassifier_individuals.py
│       └── train_multi_view_selective_weighting.py
├── metrics.py
├── model.py
├── multisource-domain-adaptation.png
├── requirements.txt
├── run_claim_experiments.sh
├── run_sentiment_experiments.sh
└── setenv.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer Based Multi-Source Domain Adaptation
2 | Dustin Wright and Isabelle Augenstein
3 |
4 | To appear in EMNLP 2020. Read the preprint: https://arxiv.org/abs/2009.07806
5 |
6 |
7 |
8 |
9 |
10 | In practical machine learning settings, the data on which a model must make predictions often come from a different distribution than the data it was trained on. Here, we investigate the problem of unsupervised multi-source domain adaptation, where a model is trained on labelled data from multiple source domains and must make predictions on a domain for which no labelled data has been seen. Prior work with CNNs and RNNs has demonstrated the benefit of mixture of experts, where the predictions of multiple domain expert classifiers are combined; as well as domain adversarial training, to induce a domain agnostic representation space. Inspired by this, we investigate how such methods can be effectively applied to large pretrained transformer models. We find that domain adversarial training has an effect on the learned representations of these models while having little effect on their performance, suggesting that large transformer-based models are already relatively robust across domains. Additionally, we show that mixture of experts leads to significant performance improvements by comparing several variants of mixing functions, including one novel mixture based on attention. Finally, we demonstrate that the predictions of large pretrained transformer based domain experts are highly homogenous, making it challenging to learn effective functions for mixing their predictions.
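As a rough illustration of the idea behind attention-based mixing (a minimal hypothetical sketch, not the paper's exact formulation): each domain expert produces logits for the input, a set of attention scores is normalised with a softmax, and the final prediction is the weighted sum of the expert logits. In the paper the scores come from a learned function of the input representation; here they are fixed numbers purely for illustration.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def mix_expert_predictions(expert_logits, attention_scores):
    # Normalise the per-expert attention scores into mixing weights,
    # then take the weighted sum of each class logit across experts.
    weights = softmax(attention_scores)
    n_classes = len(expert_logits[0])
    return [sum(w * logits[c] for w, logits in zip(weights, expert_logits))
            for c in range(n_classes)]

# Three hypothetical domain experts, two classes; scores are made up.
mixed = mix_expert_predictions(
    [[2.0, -1.0], [0.5, 0.5], [-1.0, 2.0]],
    [1.0, 0.2, -0.5],
)
```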
11 |
12 | # Citing
13 |
14 | ```bib
15 | @inproceedings{wright2020transformer,
16 | title={{Transformer Based Multi-Source Domain Adaptation}},
17 | author={Dustin Wright and Isabelle Augenstein},
18 | booktitle = {Proceedings of EMNLP},
19 | publisher = {Association for Computational Linguistics},
20 | year = 2020
21 | }
22 | ```
23 |
24 | # Recreating Results
25 |
26 | To recreate our results, first download the [Amazon Product Reviews](https://www.cs.jhu.edu/~mdredze/datasets/sentiment/) and [PHEME Rumour Detection](https://figshare.com/articles/PHEME_dataset_for_Rumour_Detection_and_Veracity_Classification/6392078) datasets and place them in the `data/` directory: the sentiment data in a directory called `data/sentiment-dataset` and the PHEME data in a directory called `data/PHEME`.
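Based on the file paths used in `datareader.py`, the code expects roughly the following layout (the domain names shown are examples):

```
data/
├── sentiment-dataset/
│   ├── books/
│   │   ├── positive.review
│   │   └── negative.review
│   └── ...
└── PHEME/
    ├── charliehebdo-all-rnr-threads/
    │   ├── rumours/
    │   └── non-rumours/
    └── ...
```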
27 |
28 | Create a new conda environment:
29 |
30 | ```bash
31 | $ conda create --name xformer-multisource-domain-adaptation python=3.7
32 | $ conda activate xformer-multisource-domain-adaptation
33 | $ pip install -r requirements.txt
34 | ```
35 |
36 | Note that this project uses wandb; if you do not use wandb, set the following flag to store runs only locally:
37 |
38 | ```bash
39 | export WANDB_MODE=dryrun
40 | ```
41 |
42 | ## Running all experiments
43 |
44 | The commands for running all of the experiments are in `run_sentiment_experiments.sh` and `run_claim_experiments.sh`. You can look in these files for the command to run a particular experiment. Running either of these files will run all 10 variants presented in the paper 5 times each. The individual scripts used for each experiment are under `emnlp_final_experiments/claim-detection` and `emnlp_final_experiments/sentiment-analysis`.
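For example, a single baseline claim-detection run can be launched directly (the flags below are defined in `train_basic.py`; the dataset path, domain list, and GPU count are placeholders to adapt to your setup):

```bash
python emnlp_final_experiments/claim-detection/train_basic.py \
    --dataset_loc data/PHEME \
    --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege health \
    --n_gpu 1 \
    --run_name pheme-baseline
```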
45 |
--------------------------------------------------------------------------------
/datareader.py:
--------------------------------------------------------------------------------
1 | from xml.dom import minidom
2 | from typing import AnyStr
3 | from typing import List
4 | from typing import Tuple
5 | import unicodedata
6 | import pandas as pd
7 | import json
8 | import glob
9 | import ipdb
10 |
11 | import torch
12 | from torch.utils.data import Dataset
13 | from transformers import PreTrainedTokenizer
14 |
15 |
16 | domain_map = {
17 | 'gourmet_food': 0,
18 | 'jewelry_&_watches': 1,
19 | 'outdoor_living': 2,
20 | 'grocery': 3,
21 | 'computer_&_video_games': 4,
22 | 'beauty': 5,
23 | 'baby': 6,
24 | 'software': 7,
25 | 'magazines': 8,
26 | 'camera_&_photo': 9,
27 | 'music': 10,
28 | 'video': 11,
29 | 'health_&_personal_care': 12,
30 | 'toys_&_games': 13,
31 | 'sports_&_outdoors': 14,
32 | 'apparel': 15,
33 | 'books': 16,
34 | 'kitchen_&_housewares': 17,
35 | 'electronics': 18,
36 | 'dvd': 19
37 | }
38 |
39 | twitter_domain_map = {
40 | 'charliehebdo': 0,
41 | 'ferguson': 1,
42 | 'germanwings-crash': 2,
43 | 'ottawashooting': 3,
44 | 'sydneysiege': 4,
45 | 'health': 5
46 | }
47 |
48 | def text_to_batch_transformer(text: List, tokenizer: PreTrainedTokenizer, text_pair: AnyStr = None) -> Tuple[List, List]:
49 | """Turn a piece of text into a batch for a transformer model
50 |
51 | :param text: The text to tokenize and encode
52 | :param tokenizer: The tokenizer to use
53 | :param text_pair: An optional second string (for multiple sentence sequences)
54 | :return: A list of IDs and a mask
55 | """
56 | if text_pair is None:
57 | input_ids = [tokenizer.encode(t, add_special_tokens=True, max_length=tokenizer.max_len) for t in text]
58 | else:
59 | input_ids = [tokenizer.encode(t, text_pair=p, add_special_tokens=True, max_length=tokenizer.max_len) for t,p in zip(text, text_pair)]
60 |
61 | masks = [[1] * len(i) for i in input_ids]
62 |
63 | return input_ids, masks
64 |
65 |
66 | def collate_batch_transformer(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67 | input_ids = [i[0][0] for i in input_data]
68 | masks = [i[1][0] for i in input_data]
69 | labels = [i[2] for i in input_data]
70 | domains = [i[3] for i in input_data]
71 |
72 | max_length = max([len(i) for i in input_ids])
73 |
74 | input_ids = [(i + [0] * (max_length - len(i))) for i in input_ids]
75 | masks = [(m + [0] * (max_length - len(m))) for m in masks]
76 |
77 | assert (all(len(i) == max_length for i in input_ids))
78 | assert (all(len(m) == max_length for m in masks))
79 | return torch.tensor(input_ids), torch.tensor(masks), torch.tensor(labels), torch.tensor(domains)
80 |
81 |
82 | def collate_batch_transformer_with_index(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List]:
83 | return collate_batch_transformer(input_data) + ([i[-1] for i in input_data],)
84 |
85 |
86 | def read_xml(dir: AnyStr, domain: AnyStr, split: AnyStr = 'positive'):
87 | """ Convert all of the ratings in amazon product XML file to dicts
88 |
89 | :param dir: The base directory of the dataset, containing one subdirectory of .review files per domain
90 | :return: All of the reviews in the file as dicts
91 | """
92 | reviews = []
93 | split_map = {'positive': 1, 'negative': 0, 'unlabelled': -1}
94 | in_review_text = False
95 | with open(f'{dir}/{domain}/{split}.review', encoding='utf8', errors='ignore') as f:
96 | for line in f:
97 | if '<review_text>' in line:
98 | reviews.append({'text': '', 'label': split_map[split], 'domain': domain_map[domain]})
99 | in_review_text = True
100 | continue
101 | if '</review_text>' in line:
102 | in_review_text = False
103 | reviews[-1]['text'] = reviews[-1]['text'].replace('\n', ' ').strip()
104 | if in_review_text:
105 | reviews[-1]['text'] += line
106 | return reviews
107 |
108 |
109 | class MultiDomainSentimentDataset(Dataset):
110 | """
111 | Implements a dataset for the multidomain sentiment analysis dataset
112 | """
113 | def __init__(
114 | self,
115 | dataset_dir: AnyStr,
116 | domains: List,
117 | tokenizer: PreTrainedTokenizer,
118 | domain_ids: List = None
119 | ):
120 | """
121 |
122 | :param dataset_dir: The base directory for the dataset
123 | :param domains: The set of domains to load data for
124 | :param tokenizer: The tokenizer to use
125 | :param domain_ids: A list of ids to override the default domain IDs
126 | """
127 | super(MultiDomainSentimentDataset, self).__init__()
128 | data = []
129 | for domain in domains:
130 | data.extend(read_xml(dataset_dir, domain, 'positive'))
131 | data.extend(read_xml(dataset_dir, domain, 'negative'))
132 |
133 | self.dataset = pd.DataFrame(data)
134 | if domain_ids is not None:
135 | for i in range(len(domain_ids)):
136 | self.dataset.loc[self.dataset['domain'] == domain_map[domains[i]], 'domain'] = domain_ids[i]
137 | self.tokenizer = tokenizer
138 |
139 | def set_domain_id(self, domain_id):
140 | """
141 | Overrides the domain ID for all data
142 | :param domain_id:
143 | :return:
144 | """
145 | self.dataset['domain'] = domain_id
146 |
147 | def __len__(self):
148 | return self.dataset.shape[0]
149 |
150 | def __getitem__(self, item) -> Tuple:
151 | row = self.dataset.values[item]
152 | input_ids, mask = text_to_batch_transformer([row[0]], self.tokenizer)
153 | label = row[1]
154 | domain = row[2]
155 | return input_ids, mask, label, domain, item
156 |
157 |
158 | class MultiDomainTwitterDataset(Dataset):
159 | """
160 | Implements a dataset for the PHEME multi-domain rumour detection dataset
161 | """
162 | def __init__(
163 | self,
164 | dataset_dir: AnyStr,
165 | domains: List,
166 | tokenizer: PreTrainedTokenizer,
167 | health_data_loc: AnyStr = None,
168 | domain_ids: List = None
169 | ):
170 | """
171 |
172 | :param dataset_dir: The base directory for the dataset
173 | :param domains: The set of domains to load data for
174 | :param tokenizer: The tokenizer to use
175 | :param domain_ids: A list of ids to override the default domain IDs
176 | """
177 | super(MultiDomainTwitterDataset, self).__init__()
178 | rumours = []
179 | non_rumours = []
180 | d_ids = []
181 | self.name = "_".join(domains)
182 | for domain in domains:
183 | if domain != 'health':
184 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/non-rumours/**/source-tweets/*.json'):
185 | with open(source_tweet_file) as js:
186 | tweet = json.load(js)
187 | non_rumours.append(tweet['text'])
188 | d_ids.append(twitter_domain_map[domain])
189 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/rumours/**/source-tweets/*.json'):
190 | with open(source_tweet_file) as js:
191 | tweet = json.load(js)
192 | rumours.append(tweet['text'])
193 | d_ids.append(twitter_domain_map[domain])
194 | elif health_data_loc is not None:
195 | health_dataset = pd.read_csv(health_data_loc, sep="\t", header=None)
196 | # Remove unknowns
197 | health_dataset = health_dataset[health_dataset[1] != 0]
198 | # Transform the text
199 | health_dataset[0] = health_dataset[0].apply(lambda x: x[10:] if 'RT @xxxxx ' == x[:10] else x)
200 | # Drop duplicates
201 | health_dataset = health_dataset.drop_duplicates()
202 | statements = [v[0] for v in health_dataset.values]
203 | lblmap = {1: 0, -1: 1}
204 | labels = [lblmap[v[1]] for v in health_dataset.values]
205 | rumours.extend([s for s,l in zip(statements, labels) if l == 1])
206 | non_rumours.extend([s for s,l in zip(statements, labels) if l == 0])
207 | d_ids.extend([twitter_domain_map[domain]] * len(labels))
208 |
209 |
210 | self.dataset = pd.DataFrame(rumours + non_rumours, columns=['statement'])
211 | self.dataset['label'] = [1] * len(rumours) + [0] * len(non_rumours)
212 | self.dataset['statement'] = self.dataset['statement'].str.normalize('NFKD')
213 | self.dataset['domain'] = d_ids
214 |
215 | self.tokenizer = tokenizer
216 |
217 | def set_domain_id(self, domain_id):
218 | """
219 | Overrides the domain ID for all data
220 | :param domain_id:
221 | :return:
222 | """
223 | self.dataset['domain'] = domain_id
224 |
225 | def __len__(self):
226 | return self.dataset.shape[0]
227 |
228 | def __getitem__(self, item) -> Tuple:
229 | row = self.dataset.values[item]
230 | input_ids, mask = text_to_batch_transformer([row[0]], self.tokenizer)
231 | label = row[1]
232 | domain = row[2]
233 | return input_ids, mask, label, domain, item
234 |
235 |
236 |
--------------------------------------------------------------------------------
/datareader_cnn.py:
--------------------------------------------------------------------------------
1 | from xml.dom import minidom
2 | from typing import AnyStr
3 | from typing import List
4 | from typing import Tuple
5 | import unicodedata
6 | import pandas as pd
7 | import json
8 | import glob
9 | import ipdb
10 |
11 | import torch
12 | from torch.utils.data import Dataset
13 | from transformers import PreTrainedTokenizer
14 | from fasttext import tokenize
15 |
16 |
17 | domain_map = {
18 | 'gourmet_food': 0,
19 | 'jewelry_&_watches': 1,
20 | 'outdoor_living': 2,
21 | 'grocery': 3,
22 | 'computer_&_video_games': 4,
23 | 'beauty': 5,
24 | 'baby': 6,
25 | 'software': 7,
26 | 'magazines': 8,
27 | 'camera_&_photo': 9,
28 | 'music': 10,
29 | 'video': 11,
30 | 'health_&_personal_care': 12,
31 | 'toys_&_games': 13,
32 | 'sports_&_outdoors': 14,
33 | 'apparel': 15,
34 | 'books': 16,
35 | 'kitchen_&_housewares': 17,
36 | 'electronics': 18,
37 | 'dvd': 19
38 | }
39 |
40 | twitter_domain_map = {
41 | 'charliehebdo': 0,
42 | 'ferguson': 1,
43 | 'germanwings-crash': 2,
44 | 'ottawashooting': 3,
45 | 'sydneysiege': 4,
46 | 'health': 5
47 | }
48 |
49 | def text_to_batch_cnn(text: List, tokenizer, text_pair: AnyStr = None) -> Tuple[List, List]:
50 | """Turn a piece of text into a batch for a CNN model
51 |
52 | :param text: The text to tokenize and encode
53 | :param tokenizer: The tokenizer to use
54 | :param text_pair: An optional second string (unused by the CNN tokenizer)
55 | :return: A list of IDs and a mask
56 | """
57 | input_ids = [tokenizer.encode(t) for t in text]
58 |
59 | masks = [[1] * len(i) for i in input_ids]
60 |
61 | return input_ids, masks
62 |
63 |
64 | def collate_batch_cnn(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
65 | input_ids = [i[0][0] for i in input_data]
66 | masks = [i[1][0] for i in input_data]
67 | labels = [i[2] for i in input_data]
68 | domains = [i[3] for i in input_data]
69 |
70 | max_length = max([len(i) for i in input_ids])
71 |
72 | input_ids = [(i + [0] * (max_length - len(i))) for i in input_ids]
73 | masks = [(m + [0] * (max_length - len(m))) for m in masks]
74 |
75 | assert (all(len(i) == max_length for i in input_ids))
76 | assert (all(len(m) == max_length for m in masks))
77 | return torch.tensor(input_ids), torch.tensor(masks), torch.tensor(labels), torch.tensor(domains)
78 |
79 |
80 | class FasttextTokenizer:
81 |
82 | def __init__(self, vocabulary_file):
83 | self.vocab = {}
84 | with open(vocabulary_file) as f:
85 | for j,l in enumerate(f):
86 | self.vocab[l.strip()] = j
87 |
88 | def encode(self, text):
89 | tokens = tokenize(text.lower().replace('\n', ' ') + '\n')
90 | return [self.vocab[t] if t in self.vocab else self.vocab['[UNK]'] for t in tokens]
91 |
92 |
93 | def collate_batch_cnn_with_index(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List]:
94 | return collate_batch_cnn(input_data) + ([i[-1] for i in input_data],)
95 |
96 |
97 | def read_xml(dir: AnyStr, domain: AnyStr, split: AnyStr = 'positive'):
98 | """ Convert all of the ratings in amazon product XML file to dicts
99 |
100 | :param dir: The base directory of the dataset, containing one subdirectory of .review files per domain
101 | :return: All of the reviews in the file as dicts
102 | """
103 | reviews = []
104 | split_map = {'positive': 1, 'negative': 0, 'unlabelled': -1}
105 | in_review_text = False
106 | with open(f'{dir}/{domain}/{split}.review', encoding='utf8', errors='ignore') as f:
107 | for line in f:
108 | if '<review_text>' in line:
109 | reviews.append({'text': '', 'label': split_map[split], 'domain': domain_map[domain]})
110 | in_review_text = True
111 | continue
112 | if '</review_text>' in line:
113 | in_review_text = False
114 | reviews[-1]['text'] = reviews[-1]['text'].replace('\n', ' ').strip()
115 | if in_review_text:
116 | reviews[-1]['text'] += line
117 | return reviews
118 |
119 |
120 | class MultiDomainSentimentDataset(Dataset):
121 | """
122 | Implements a dataset for the multidomain sentiment analysis dataset
123 | """
124 | def __init__(
125 | self,
126 | dataset_dir: AnyStr,
127 | domains: List,
128 | tokenizer,
129 | domain_ids: List = None
130 | ):
131 | """
132 |
133 | :param dataset_dir: The base directory for the dataset
134 | :param domains: The set of domains to load data for
135 | :param tokenizer: The tokenizer to use
136 | :param domain_ids: A list of ids to override the default domain IDs
137 | """
138 | super(MultiDomainSentimentDataset, self).__init__()
139 | data = []
140 | for domain in domains:
141 | data.extend(read_xml(dataset_dir, domain, 'positive'))
142 | data.extend(read_xml(dataset_dir, domain, 'negative'))
143 |
144 | self.dataset = pd.DataFrame(data)
145 | if domain_ids is not None:
146 | for i in range(len(domain_ids)):
147 | self.dataset.loc[self.dataset['domain'] == domain_map[domains[i]], 'domain'] = domain_ids[i]
148 | self.tokenizer = tokenizer
149 |
150 | def set_domain_id(self, domain_id):
151 | """
152 | Overrides the domain ID for all data
153 | :param domain_id:
154 | :return:
155 | """
156 | self.dataset['domain'] = domain_id
157 |
158 | def __len__(self):
159 | return self.dataset.shape[0]
160 |
161 | def __getitem__(self, item) -> Tuple:
162 | row = self.dataset.values[item]
163 | input_ids, mask = text_to_batch_cnn([row[0]], self.tokenizer)
164 | label = row[1]
165 | domain = row[2]
166 | return input_ids, mask, label, domain, item
167 |
168 |
169 | class MultiDomainTwitterDataset(Dataset):
170 | """
171 | Implements a dataset for the PHEME multi-domain rumour detection dataset
172 | """
173 | def __init__(
174 | self,
175 | dataset_dir: AnyStr,
176 | domains: List,
177 | tokenizer: PreTrainedTokenizer,
178 | health_data_loc: AnyStr = None,
179 | domain_ids: List = None
180 | ):
181 | """
182 |
183 | :param dataset_dir: The base directory for the dataset
184 | :param domains: The set of domains to load data for
185 | :param tokenizer: The tokenizer to use
186 | :param domain_ids: A list of ids to override the default domain IDs
187 | """
188 | super(MultiDomainTwitterDataset, self).__init__()
189 | rumours = []
190 | non_rumours = []
191 | d_ids = []
192 | self.name = "_".join(domains)
193 | for domain in domains:
194 | if domain != 'health':
195 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/non-rumours/**/source-tweets/*.json'):
196 | with open(source_tweet_file) as js:
197 | tweet = json.load(js)
198 | non_rumours.append(tweet['text'])
199 | d_ids.append(twitter_domain_map[domain])
200 | for source_tweet_file in glob.glob(f'{dataset_dir}/{domain}-all-rnr-threads/rumours/**/source-tweets/*.json'):
201 | with open(source_tweet_file) as js:
202 | tweet = json.load(js)
203 | rumours.append(tweet['text'])
204 | d_ids.append(twitter_domain_map[domain])
205 | elif health_data_loc is not None:
206 | health_dataset = pd.read_csv(health_data_loc, sep="\t", header=None)
207 | # Remove unknowns
208 | health_dataset = health_dataset[health_dataset[1] != 0]
209 | # Transform the text
210 | health_dataset[0] = health_dataset[0].apply(lambda x: x[10:] if 'RT @xxxxx ' == x[:10] else x)
211 | # Drop duplicates
212 | health_dataset = health_dataset.drop_duplicates()
213 | statements = [v[0] for v in health_dataset.values]
214 | lblmap = {1: 0, -1: 1}
215 | labels = [lblmap[v[1]] for v in health_dataset.values]
216 | rumours.extend([s for s,l in zip(statements, labels) if l == 1])
217 | non_rumours.extend([s for s,l in zip(statements, labels) if l == 0])
218 | d_ids.extend([twitter_domain_map[domain]] * len(labels))
219 |
220 |
221 | self.dataset = pd.DataFrame(rumours + non_rumours, columns=['statement'])
222 | self.dataset['label'] = [1] * len(rumours) + [0] * len(non_rumours)
223 | self.dataset['statement'] = self.dataset['statement'].str.normalize('NFKD')
224 | self.dataset['domain'] = d_ids
225 |
226 | self.tokenizer = tokenizer
227 |
228 | def set_domain_id(self, domain_id):
229 | """
230 | Overrides the domain ID for all data
231 | :param domain_id:
232 | :return:
233 | """
234 | self.dataset['domain'] = domain_id
235 |
236 | def __len__(self):
237 | return self.dataset.shape[0]
238 |
239 | def __getitem__(self, item) -> Tuple:
240 | row = self.dataset.values[item]
241 | input_ids, mask = text_to_batch_cnn([row[0]], self.tokenizer)
242 | label = row[1]
243 | domain = row[2]
244 | return input_ids, mask, label, domain, item
245 |
246 |
--------------------------------------------------------------------------------
/emnlp_final_experiments/claim-detection/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/copenlu/xformer-multi-source-domain-adaptation/be2d1a132298131df82fe40dd4f6c08dec8b3404/emnlp_final_experiments/claim-detection/.DS_Store
--------------------------------------------------------------------------------
/emnlp_final_experiments/claim-detection/train_basic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | from collections import defaultdict
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import torch
13 | import wandb
14 | from torch.optim.lr_scheduler import LambdaLR
15 | from torch.utils.data import DataLoader
16 | from torch.utils.data import Subset
17 | from torch.utils.data import random_split
18 | from torch.optim import Adam
19 | from tqdm import tqdm
20 | from transformers import AdamW
21 | from transformers import DistilBertConfig
22 | from transformers import DistilBertTokenizer
23 | from transformers import DistilBertModel
24 | from transformers import BertConfig
25 | from transformers import BertTokenizer
26 | from transformers import BertModel
27 | from transformers import get_linear_schedule_with_warmup
28 |
29 | from datareader import MultiDomainTwitterDataset
30 | from datareader import collate_batch_transformer
31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1
32 | from metrics import ClassificationEvaluator
33 |
34 | from metrics import plot_label_distribution
35 | from model import *
36 |
37 |
38 | def train(
39 | model: torch.nn.Module,
40 | train_dls: List[DataLoader],
41 | optimizer: torch.optim.Optimizer,
42 | scheduler: LambdaLR,
43 | validation_evaluator: MultiDatasetClassificationEvaluator,
44 | n_epochs: int,
45 | device: AnyStr,
46 | log_interval: int = 1,
47 | patience: int = 10,
48 | model_dir: str = "wandb_local",
49 | gradient_accumulation: int = 1,
50 | domain_name: str = ''
51 | ):
52 | #best_loss = float('inf')
53 | best_f1 = 0.0
54 | patience_counter = 0
55 |
56 | epoch_counter = 0
57 | total = sum(len(dl) for dl in train_dls)
58 |
59 | # Main loop
60 | while epoch_counter < n_epochs:
61 | dl_iters = [iter(dl) for dl in train_dls]
62 | dl_idx = list(range(len(dl_iters)))
63 | finished = [0] * len(dl_iters)
64 | i = 0
65 | with tqdm(total=total, desc="Training") as pbar:
66 | while sum(finished) < len(dl_iters):
67 | random.shuffle(dl_idx)
68 | for d in dl_idx:
69 | domain_dl = dl_iters[d]
70 | batches = []
71 | try:
72 | for j in range(gradient_accumulation):
73 | batches.append(next(domain_dl))
74 | except StopIteration:
75 | finished[d] = 1
76 | if len(batches) == 0:
77 | continue
78 | optimizer.zero_grad()
79 | for batch in batches:
80 | model.train()
81 | batch = tuple(t.to(device) for t in batch)
82 | input_ids = batch[0]
83 | masks = batch[1]
84 | labels = batch[2]
85 | # Testing with random domains to see if any effect
86 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
87 | domains = batch[3]
88 |
89 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels)
90 | loss = loss / gradient_accumulation
91 |
92 | if i % log_interval == 0:
93 | wandb.log({
94 | "Loss": loss.item()
95 | })
96 |
97 | loss.backward()
98 | i += 1
99 | pbar.update(1)
100 |
101 | optimizer.step()
102 | if scheduler is not None:
103 | scheduler.step()
104 |
105 | gc.collect()
106 |
107 | # Inline evaluation
108 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
109 | print(f"Validation F1: {F1}")
110 |
111 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
112 |
113 | # Saving the best model and early stopping
114 | #if val_loss < best_loss:
115 | if F1 > best_f1:
116 | best_model = model.state_dict()
117 | #best_loss = val_loss
118 | best_f1 = F1
119 | #wandb.run.summary['best_validation_loss'] = best_loss
120 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
121 | patience_counter = 0
122 | # Log to wandb
123 | wandb.log({
124 | 'Validation accuracy': acc,
125 | 'Validation Precision': P,
126 | 'Validation Recall': R,
127 | 'Validation F1': F1,
128 | 'Validation loss': val_loss})
129 | else:
130 | patience_counter += 1
131 | # Stop training once we have lost patience
132 | if patience_counter == patience:
133 | break
134 |
135 | gc.collect()
136 | epoch_counter += 1
137 |
138 |
139 | if __name__ == "__main__":
140 | # Define arguments
141 | parser = argparse.ArgumentParser()
142 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
143 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
144 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
145 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1)
146 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
147 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
148 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None)
149 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
150 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
151 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
152 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
153 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
154 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
155 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768)
156 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8)
157 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5)
158 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01)
159 | parser.add_argument("--lambd", help="Regularization coefficient lambda", type=float, default=10e-3)
160 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
161 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
162 | parser.add_argument("--full_bert", help="Specify to use full bert model", action="store_true")
163 |
164 | args = parser.parse_args()
165 |
166 | # Set all the seeds
167 | seed = args.seed
168 | random.seed(seed)
169 | np.random.seed(seed)
170 | torch.manual_seed(seed)
171 | torch.cuda.manual_seed_all(seed)
172 | torch.backends.cudnn.deterministic = True
173 | torch.backends.cudnn.benchmark = False
174 |
175 | # See if CUDA available
176 | device = torch.device("cpu")
177 | if args.n_gpu > 0 and torch.cuda.is_available():
178 | print("Training on GPU")
179 | device = torch.device("cuda:0")
180 |
181 | # model configuration
182 | batch_size = args.batch_size
183 | lr = args.lr
184 | weight_decay = args.weight_decay
185 | n_epochs = args.n_epochs
186 | if args.full_bert:
187 | bert_model = 'bert-base-uncased'
188 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2)
189 | tokenizer = BertTokenizer.from_pretrained(bert_model)
190 | else:
191 | bert_model = 'distilbert-base-uncased'
192 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2)
193 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model)
194 |
195 | # wandb initialization
196 | wandb.init(
197 | project="domain-adaptation-twitter-emnlp",
198 | name=args.run_name,
199 | config={
200 | "epochs": n_epochs,
201 | "learning_rate": lr,
202 | "warmup": args.warmup_steps,
203 | "weight_decay": weight_decay,
204 | "batch_size": batch_size,
205 | "train_split_percentage": args.train_pct,
206 | "bert_model": bert_model,
207 | "seed": seed,
208 | "pretrained_model": args.pretrained_model,
209 | "tags": ",".join(args.tags)
210 | }
211 | )
212 | # Create save directory for model
213 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
214 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
215 |
216 | # Create the dataset
217 | all_dsets = [MultiDomainTwitterDataset(
218 | args.dataset_loc,
219 | [domain],
220 | tokenizer
221 | ) for domain in args.domains[:-1]]
222 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)]
223 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
224 |
225 | accs = []
226 | Ps = []
227 | Rs = []
228 | F1s = []
229 | # Store labels and logits for individual splits for micro F1
230 | labels_all = []
231 | logits_all = []
232 |
233 | for i in range(len(all_dsets)):
234 | domain = args.domains[i]
235 | test_dset = all_dsets[i]
236 | # Override the domain IDs
237 | k = 0
238 | for j in range(len(all_dsets)):
239 | if j != i:
240 | all_dsets[j].set_domain_id(k)
241 | k += 1
242 | test_dset.set_domain_id(k)
243 |
244 | # Split the data
245 | if args.indices_dir is None:
246 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
247 | for j in range(len(all_dsets)) if j != i]
248 | # Save the indices
249 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/train_idx_{domain}.txt', 'wt') as f, \
250 | open(f'{args.model_dir}/{Path(wandb.run.dir).name}/val_idx_{domain}.txt', 'wt') as g:
251 | for j, subset in enumerate(subsets):
252 | for idx in subset[0].indices:
253 | f.write(f'{j},{idx}\n')
254 | for idx in subset[1].indices:
255 | g.write(f'{j},{idx}\n')
256 | else:
257 | # load the indices
258 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
259 | subset_indices = defaultdict(lambda: [[], []])
260 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
261 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
262 | for l in f:
263 | vals = l.strip().split(',')
264 | subset_indices[int(vals[0])][0].append(int(vals[1]))
265 | for l in g:
266 | vals = l.strip().split(',')
267 | subset_indices[int(vals[0])][1].append(int(vals[1]))
268 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])]
269 | for d in subset_indices]
270 |
271 | train_dls = [DataLoader(
272 | subset[0],
273 | batch_size=batch_size,
274 | shuffle=True,
275 | collate_fn=collate_batch_transformer
276 | ) for subset in subsets]
277 |
278 | val_ds = [subset[1] for subset in subsets]
279 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
280 |
281 | # Create the model
282 | if args.full_bert:
283 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
284 | else:
285 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
286 | if args.pretrained_model is not None:
287 | weights = {k: v for k, v in torch.load(args.pretrained_model).items() if "classifier" not in k}
288 | model_dict = bert.state_dict()
289 | model_dict.update(weights)
290 | bert.load_state_dict(model_dict)
291 | model = VanillaBert(bert).to(device)
292 |
293 | # Create the optimizer
294 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
295 | optimizer_grouped_parameters = [
296 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
297 | 'weight_decay': weight_decay},
298 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
299 | ]
300 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3)
301 | # scheduler = None
302 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
303 | scheduler = get_linear_schedule_with_warmup(
304 | optimizer,
305 | args.warmup_steps,
306 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
307 | )
308 |
309 | # Train
310 | train(
311 | model,
312 | train_dls,
313 | optimizer,
314 | scheduler,
315 | validation_evaluator,
316 | n_epochs,
317 | device,
318 | args.log_interval,
319 | model_dir=args.model_dir,
320 | gradient_accumulation=args.gradient_accumulation,
321 | domain_name=domain
322 | )
323 |
324 | # Load the best weights
325 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
326 |
327 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
328 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
329 | model,
330 | plot_callbacks=[plot_label_distribution],
331 | return_labels_logits=True,
332 | return_votes=True
333 | )
334 | print(f"{domain} F1: {F1}")
335 | print(f"{domain} Accuracy: {acc}")
336 | print()
337 |
338 | wandb.run.summary[f"{domain}-P"] = P
339 | wandb.run.summary[f"{domain}-R"] = R
340 | wandb.run.summary[f"{domain}-F1"] = F1
341 | wandb.run.summary[f"{domain}-Acc"] = acc
342 | # macro and micro F1 are only with respect to the 5 main splits
343 | if i < len(all_dsets):
344 | Ps.append(P)
345 | Rs.append(R)
346 | F1s.append(F1)
347 | accs.append(acc)
348 | labels_all.extend(labels)
349 | logits_all.extend(logits)
350 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
351 | for p, l in zip(np.argmax(logits, axis=-1), labels):
352 | f.write(f'{domain}\t{p}\t{l}\n')
353 |
354 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
355 | # Add to wandb
356 | wandb.run.summary[f'test-loss'] = loss  # loss from the last evaluated domain only
357 | wandb.run.summary[f'test-micro-acc'] = acc
358 | wandb.run.summary[f'test-micro-P'] = P
359 | wandb.run.summary[f'test-micro-R'] = R
360 | wandb.run.summary[f'test-micro-F1'] = F1
361 |
362 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs)
363 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps)
364 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs)
365 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s)
366 |
367 | #wandb.log({f"label-distribution-test-{i}": plots[0]})
368 |
--------------------------------------------------------------------------------
/emnlp_final_experiments/claim-detection/train_basic_domain_adversarial.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | from collections import defaultdict
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import torch
13 | import wandb
14 | from torch.optim.lr_scheduler import LambdaLR
15 | from torch.utils.data import DataLoader
16 | from torch.utils.data import Subset
17 | from torch.utils.data import random_split
18 | from torch.optim import Adam
19 | from tqdm import tqdm
20 | from transformers import AdamW
21 | from transformers import DistilBertConfig
22 | from transformers import DistilBertTokenizer
23 | from transformers import DistilBertModel, DistilBertForSequenceClassification
24 | from transformers import BertConfig
25 | from transformers import BertTokenizer
26 | from transformers import BertModel, BertForSequenceClassification
27 | from transformers import get_linear_schedule_with_warmup
28 |
29 | from datareader import MultiDomainTwitterDataset
30 | from datareader import collate_batch_transformer
31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1
32 | from metrics import ClassificationEvaluator
33 |
34 | from metrics import plot_label_distribution
35 | from model import *
36 |
37 |
38 | def train(
39 | model: torch.nn.Module,
40 | train_dls: List[DataLoader],
41 | optimizer: torch.optim.Optimizer,
42 | scheduler: LambdaLR,
43 | validation_evaluator: MultiDatasetClassificationEvaluator,
44 | n_epochs: int,
45 | device: AnyStr,
46 | log_interval: int = 1,
47 | patience: int = 10,
48 | model_dir: str = "wandb_local",
49 | gradient_accumulation: int = 1,
50 | domain_name: str = ''
51 | ):
52 | #best_loss = float('inf')
53 | best_f1 = 0.0
54 | patience_counter = 0
55 |
56 | epoch_counter = 0
57 | total = sum(len(dl) for dl in train_dls)
58 |
59 | # Main loop
60 | while epoch_counter < n_epochs:
61 | dl_iters = [iter(dl) for dl in train_dls]
62 | dl_idx = list(range(len(dl_iters)))
63 | finished = [0] * len(dl_iters)
64 | i = 0
65 | with tqdm(total=total, desc="Training") as pbar:
66 | while sum(finished) < len(dl_iters):
67 | random.shuffle(dl_idx)
68 | for d in dl_idx:
69 | domain_dl = dl_iters[d]
70 | batches = []
71 | try:
72 | for j in range(gradient_accumulation):
73 | batches.append(next(domain_dl))
74 | except StopIteration:
75 | finished[d] = 1
76 | if len(batches) == 0:
77 | continue
78 | optimizer.zero_grad()
79 | for batch in batches:
80 | model.train()
81 | batch = tuple(t.to(device) for t in batch)
82 | input_ids = batch[0]
83 | masks = batch[1]
84 | labels = batch[2]
85 | # Null the labels if it's the test data
86 | if d == len(train_dls) - 1:
87 | labels = None
88 | # Testing with random domains to see if any effect
89 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
90 | domains = batch[3]
91 |
92 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels)
93 | loss = loss.mean() / gradient_accumulation
94 |
95 | if i % log_interval == 0:
96 | wandb.log({
97 | "Loss": loss.item()
98 | })
99 |
100 | loss.backward()
101 | i += 1
102 | pbar.update(1)
103 |
104 | optimizer.step()
105 | if scheduler is not None:
106 | scheduler.step()
107 |
108 | gc.collect()
109 |
110 | # Inline evaluation
111 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
112 | print(f"Validation F1: {F1}")
113 |
114 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
115 |
116 | # Saving the best model and early stopping
117 | #if val_loss < best_loss:
118 | if F1 > best_f1:
119 | best_model = model.state_dict()
120 | #best_loss = val_loss
121 | best_f1 = F1
122 | #wandb.run.summary['best_validation_loss'] = best_loss
123 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
124 | patience_counter = 0
125 | # Log to wandb
126 | wandb.log({
127 | 'Validation accuracy': acc,
128 | 'Validation Precision': P,
129 | 'Validation Recall': R,
130 | 'Validation F1': F1,
131 | 'Validation loss': val_loss})
132 | else:
133 | patience_counter += 1
134 | # Stop training once we have lost patience
135 | if patience_counter == patience:
136 | break
137 |
138 | gc.collect()
139 | epoch_counter += 1
140 |
141 |
142 | if __name__ == "__main__":
143 | # Define arguments
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
146 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
147 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
148 | parser.add_argument("--log_interval", help="Number of steps between logging", type=int, default=1)
149 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
150 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
151 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None)
152 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
153 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
154 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
155 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
156 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
157 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
158 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768)
159 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8)
160 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5)
161 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01)
162 | parser.add_argument("--lambd", help="Domain adversarial loss weight", type=float, default=10e-3)
163 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
164 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int)
165 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
166 | parser.add_argument("--full_bert", help="Specify to use full bert model", action="store_true")
167 |
168 |
169 | args = parser.parse_args()
170 |
171 | # Set all the seeds
172 | seed = args.seed
173 | random.seed(seed)
174 | np.random.seed(seed)
175 | torch.manual_seed(seed)
176 | torch.cuda.manual_seed_all(seed)
177 | torch.backends.cudnn.deterministic = True
178 | torch.backends.cudnn.benchmark = False
179 |
180 | # See if CUDA available
181 | device = torch.device("cpu")
182 | if args.n_gpu > 0 and torch.cuda.is_available():
183 | print("Training on GPU")
184 | device = torch.device("cuda:0")
185 |
186 | # model configuration
187 | batch_size = args.batch_size
188 | lr = args.lr
189 | weight_decay = args.weight_decay
190 | n_epochs = args.n_epochs
191 | if args.full_bert:
192 | bert_model = 'bert-base-uncased'
193 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
194 | tokenizer = BertTokenizer.from_pretrained(bert_model)
195 | else:
196 | bert_model = 'distilbert-base-uncased'
197 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
198 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model)
199 |
200 | # wandb initialization
201 | wandb.init(
202 | project="domain-adaptation-twitter-emnlp",
203 | name=args.run_name,
204 | config={
205 | "epochs": n_epochs,
206 | "learning_rate": lr,
207 | "warmup": args.warmup_steps,
208 | "weight_decay": weight_decay,
209 | "batch_size": batch_size,
210 | "train_split_percentage": args.train_pct,
211 | "bert_model": bert_model,
212 | "seed": seed,
213 | "pretrained_model": args.pretrained_model,
214 | "tags": ",".join(args.tags)
215 | }
216 | )
217 | # Create save directory for model
218 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
219 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
220 |
221 | # Create the dataset
222 | all_dsets = [MultiDomainTwitterDataset(
223 | args.dataset_loc,
224 | [domain],
225 | tokenizer
226 | ) for domain in args.domains[:-1]]
227 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
228 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
229 |
230 | accs = []
231 | Ps = []
232 | Rs = []
233 | F1s = []
234 | # Store labels and logits for individual splits for micro F1
235 | labels_all = []
236 | logits_all = []
237 |
238 | for i in range(len(all_dsets)):
239 | domain = args.domains[i]
240 | test_dset = all_dsets[i]
241 | # Override the domain IDs
242 | k = 0
243 | for j in range(len(all_dsets)):
244 | if j != i:
245 | all_dsets[j].set_domain_id(k)
246 | k += 1
247 | test_dset.set_domain_id(k)
248 |
249 | # Split the data
250 | if args.indices_dir is None:
251 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
252 | for j in range(len(all_dsets)) if j != i]
253 | else:
254 | # load the indices
255 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
256 | subset_indices = defaultdict(lambda: [[], []])
257 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
258 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
259 | for l in f:
260 | vals = l.strip().split(',')
261 | subset_indices[int(vals[0])][0].append(int(vals[1]))
262 | for l in g:
263 | vals = l.strip().split(',')
264 | subset_indices[int(vals[0])][1].append(int(vals[1]))
265 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])]
266 | for d in subset_indices]
267 | train_dls = [DataLoader(
268 | subset[0],
269 | batch_size=batch_size,
270 | shuffle=True,
271 | collate_fn=collate_batch_transformer
272 | ) for subset in subsets]
273 | # Add test data for domain adversarial training
274 | train_dls += [DataLoader(
275 | test_dset,
276 | batch_size=batch_size,
277 | shuffle=True,
278 | collate_fn=collate_batch_transformer
279 | )]
280 |
281 | val_ds = [subset[1] for subset in subsets]
282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
283 |
284 | # Create the model
285 | if args.full_bert:
286 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
287 | else:
288 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
289 | if args.pretrained_model is not None:
290 | weights = {k: v for k, v in torch.load(args.pretrained_model).items() if "classifier" not in k}
291 | model_dict = bert.state_dict()
292 | model_dict.update(weights)
293 | bert.load_state_dict(model_dict)
294 | model = torch.nn.DataParallel(DomainAdversarialBert(
295 | bert,
296 | n_domains=len(train_dls), supervision_layer=args.supervision_layer
297 | )).to(device)
298 |
299 | # Create the optimizer
300 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
301 | optimizer_grouped_parameters = [
302 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
303 | 'weight_decay': weight_decay},
304 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
305 | ]
306 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3)
307 | # scheduler = None
308 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
309 | scheduler = get_linear_schedule_with_warmup(
310 | optimizer,
311 | args.warmup_steps,
312 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
313 | )
314 |
315 | # Train
316 | train(
317 | model,
318 | train_dls,
319 | optimizer,
320 | scheduler,
321 | validation_evaluator,
322 | n_epochs,
323 | device,
324 | args.log_interval,
325 | model_dir=args.model_dir,
326 | gradient_accumulation=args.gradient_accumulation,
327 | domain_name=domain
328 | )
329 |
330 | # Load the best weights
331 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
332 |
333 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
334 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
335 | model,
336 | plot_callbacks=[plot_label_distribution],
337 | return_labels_logits=True,
338 | return_votes=True
339 | )
340 | print(f"{domain} F1: {F1}")
341 | print(f"{domain} Accuracy: {acc}")
342 | print()
343 |
344 | wandb.run.summary[f"{domain}-P"] = P
345 | wandb.run.summary[f"{domain}-R"] = R
346 | wandb.run.summary[f"{domain}-F1"] = F1
347 | wandb.run.summary[f"{domain}-Acc"] = acc
348 | # macro and micro F1 are only with respect to the 5 main splits
349 | if i < len(all_dsets):
350 | Ps.append(P)
351 | Rs.append(R)
352 | F1s.append(F1)
353 | accs.append(acc)
354 | labels_all.extend(labels)
355 | logits_all.extend(logits)
356 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
357 | for p, l in zip(np.argmax(logits, axis=-1), labels):
358 | f.write(f'{domain}\t{p}\t{l}\n')
359 |
360 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
361 | # Add to wandb
362 | wandb.run.summary[f'test-loss'] = loss  # loss from the last evaluated domain only
363 | wandb.run.summary[f'test-micro-acc'] = acc
364 | wandb.run.summary[f'test-micro-P'] = P
365 | wandb.run.summary[f'test-micro-R'] = R
366 | wandb.run.summary[f'test-micro-F1'] = F1
367 |
368 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs)
369 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps)
370 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs)
371 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s)
372 |
373 | #wandb.log({f"label-distribution-test-{i}": plots[0]})
374 |
--------------------------------------------------------------------------------
/emnlp_final_experiments/claim-detection/train_multi_view.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | import wandb
15 | from torch.optim.lr_scheduler import LambdaLR
16 | from torch.utils.data import DataLoader
17 | from torch.utils.data import Subset
18 | from torch.utils.data import random_split
19 | from torch.optim import Adam
20 | from tqdm import tqdm
21 | from transformers import AdamW
22 | from transformers import DistilBertConfig
23 | from transformers import DistilBertTokenizer
24 | from transformers import DistilBertForSequenceClassification
25 | from transformers import get_linear_schedule_with_warmup
26 |
27 | from datareader import MultiDomainTwitterDataset
28 | from datareader import collate_batch_transformer
29 | from metrics import MultiDatasetClassificationEvaluator
30 | from metrics import ClassificationEvaluator
31 | from metrics import acc_f1
32 |
33 | from metrics import plot_label_distribution
34 | from model import MultiTransformerClassifier
35 | from model import VanillaBert
36 | from model import *
37 |
38 |
39 | def train(
40 | model: torch.nn.Module,
41 | train_dls: List[DataLoader],
42 | optimizer: torch.optim.Optimizer,
43 | scheduler: LambdaLR,
44 | validation_evaluator: MultiDatasetClassificationEvaluator,
45 | n_epochs: int,
46 | device: AnyStr,
47 | log_interval: int = 1,
48 | patience: int = 10,
49 | model_dir: str = "wandb_local",
50 | gradient_accumulation: int = 1,
51 | domain_name: str = ''
52 | ):
53 | #best_loss = float('inf')
54 | best_f1 = 0.0
55 | patience_counter = 0
56 |
57 | epoch_counter = 0
58 | total = sum(len(dl) for dl in train_dls)
59 |
60 | # Main loop
61 | while epoch_counter < n_epochs:
62 | dl_iters = [iter(dl) for dl in train_dls]
63 | dl_idx = list(range(len(dl_iters)))
64 | finished = [0] * len(dl_iters)
65 | i = 0
66 | with tqdm(total=total, desc="Training") as pbar:
67 | while sum(finished) < len(dl_iters):
68 | random.shuffle(dl_idx)
69 | for d in dl_idx:
70 | domain_dl = dl_iters[d]
71 | batches = []
72 | try:
73 | for j in range(gradient_accumulation):
74 | batches.append(next(domain_dl))
75 | except StopIteration:
76 | finished[d] = 1
77 | if len(batches) == 0:
78 | continue
79 | optimizer.zero_grad()
80 | for batch in batches:
81 | model.train()
82 | batch = tuple(t.to(device) for t in batch)
83 | input_ids = batch[0]
84 | masks = batch[1]
85 | labels = batch[2]
86 | # Testing with random domains to see if any effect
87 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
88 | domains = batch[3]
89 |
90 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True)
91 | loss = loss.mean() / gradient_accumulation
92 | if i % log_interval == 0:
93 | # wandb.log({
94 | # "Loss": loss.item(),
95 | # "alpha0": alpha[:,0].cpu(),
96 | # "alpha1": alpha[:, 1].cpu(),
97 | # "alpha2": alpha[:, 2].cpu(),
98 | # "alpha_shared": alpha[:, 3].cpu()
99 | # })
100 | wandb.log({
101 | "Loss": loss.item()
102 | })
103 |
104 | loss.backward()
105 | i += 1
106 | pbar.update(1)
107 |
108 | optimizer.step()
109 | if scheduler is not None:
110 | scheduler.step()
111 |
112 | gc.collect()
113 |
114 | # Inline evaluation
115 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
116 | print(f"Validation F1: {F1}")
117 |
118 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
119 |
120 | # Saving the best model and early stopping
121 | #if val_loss < best_loss:
122 | if F1 > best_f1:
123 | best_model = model.state_dict()
124 | #best_loss = val_loss
125 | best_f1 = F1
126 | #wandb.run.summary['best_validation_loss'] = best_loss
127 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
128 | patience_counter = 0
129 | # Log to wandb
130 | wandb.log({
131 | 'Validation accuracy': acc,
132 | 'Validation Precision': P,
133 | 'Validation Recall': R,
134 | 'Validation F1': F1,
135 | 'Validation loss': val_loss})
136 | else:
137 | patience_counter += 1
138 | # Stop training once we have lost patience
139 | if patience_counter == patience:
140 | break
141 |
142 | gc.collect()
143 | epoch_counter += 1
144 |
145 |
146 | if __name__ == "__main__":
147 | # Define arguments
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
150 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
151 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
152 | parser.add_argument("--log_interval", help="Number of steps between logging", type=int, default=1)
153 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
154 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
155 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None)
156 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None)
157 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
158 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
159 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
160 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
161 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
162 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16)
163 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-5)
164 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01)
165 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int)
166 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, type=int)
167 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int)
168 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int)
169 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
170 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
171 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
172 | parser.add_argument("--ensemble_basic", help="Use averaging for the ensembling method", action="store_true")
173 | parser.add_argument("--ensemble_avg_learned", help="Use learned averaging for the ensembling method", action="store_true")
174 |
175 |
176 | args = parser.parse_args()
177 |
178 | # Set all the seeds
179 | seed = args.seed
180 | random.seed(seed)
181 | np.random.seed(seed)
182 | torch.manual_seed(seed)
183 | torch.cuda.manual_seed_all(seed)
184 | torch.backends.cudnn.deterministic = True
185 | torch.backends.cudnn.benchmark = False
186 |
187 | # See if CUDA available
188 | device = torch.device("cpu")
189 | if args.n_gpu > 0 and torch.cuda.is_available():
190 | print("Training on GPU")
191 | device = torch.device("cuda:0")
192 |
193 | # model configuration
194 | bert_model = 'distilbert-base-uncased'
195 | batch_size = args.batch_size
196 | lr = args.lr
197 | weight_decay = args.weight_decay
198 | n_epochs = args.n_epochs
199 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
200 |
201 |
202 | # wandb initialization
203 | wandb.init(
204 | project="domain-adaptation-twitter-emnlp",
205 | name=args.run_name,
206 | config={
207 | "epochs": n_epochs,
208 | "learning_rate": lr,
209 | "warmup": args.warmup_steps,
210 | "weight_decay": weight_decay,
211 | "batch_size": batch_size,
212 | "train_split_percentage": args.train_pct,
213 | "bert_model": bert_model,
214 | "seed": seed,
215 | "tags": ",".join(args.tags)
216 | }
217 | )
218 | #wandb.watch(model)
219 | #Create save directory for model
220 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
221 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
222 |
223 | # Create the dataset
224 | all_dsets = [MultiDomainTwitterDataset(
225 | args.dataset_loc,
226 | [domain],
227 | DistilBertTokenizer.from_pretrained(bert_model)
228 | ) for domain in args.domains[:-1]]
229 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
230 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
231 |
232 | accs = []
233 | Ps = []
234 | Rs = []
235 | F1s = []
236 | # Store labels and logits for individual splits for micro F1
237 | labels_all = []
238 | logits_all = []
239 |
240 | for i in range(len(all_dsets)):  # one cross-validation fold per source domain
241 | domain = args.domains[i]
242 | test_dset = all_dsets[i]
243 | # Override the domain IDs
244 | k = 0
245 | for j in range(len(all_dsets)):
246 | if j != i:
247 | all_dsets[j].set_domain_id(k)
248 | k += 1
249 | test_dset.set_domain_id(k)
250 | # For test
251 | #all_dsets = [all_dsets[0], all_dsets[2]]
252 |
253 | # Split the data
254 | if args.indices_dir is None:
255 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
256 | for j in range(len(all_dsets)) if j != i]
257 | else:
258 | # load the indices
259 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
260 | subset_indices = defaultdict(lambda: [[], []])
261 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
262 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
263 | for l in f:
264 | vals = l.strip().split(',')
265 | subset_indices[int(vals[0])][0].append(int(vals[1]))
266 | for l in g:
267 | vals = l.strip().split(',')
268 | subset_indices[int(vals[0])][1].append(int(vals[1]))
269 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in
270 | subset_indices]
271 |
272 | train_dls = [DataLoader(
273 | subset[0],
274 | batch_size=batch_size,
275 | shuffle=True,
276 | collate_fn=collate_batch_transformer
277 | ) for subset in subsets]
278 |
279 | val_ds = [subset[1] for subset in subsets]
280 | # for vds in val_ds:
281 | # print(vds.indices)
282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
283 |
284 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
285 | # Create the model
286 | init_weights = None
287 | if args.pretrained_bert is not None:
288 | init_weights = {k: v for k, v in torch.load(args.pretrained_bert).items() if "classifier" not in k}
289 | model_dict = bert.state_dict()
290 | model_dict.update(init_weights)
291 | bert.load_state_dict(model_dict)
292 | shared_bert = VanillaBert(bert).to(device)
293 |
294 | multi_xformer = MultiDistilBertClassifier(
295 | bert_model,
296 | bert_config,
297 | n_domains=len(train_dls),
298 | init_weights=init_weights
299 | ).to(device)
300 |
301 | if args.ensemble_basic:
302 | model_class = MultiViewTransformerNetworkAveraging
303 | elif args.ensemble_avg_learned:
304 | model_class = MultiViewTransformerNetworkLearnedAveraging
305 | else:
306 | model_class = MultiViewTransformerNetworkProbabilities
307 |
308 | model = torch.nn.DataParallel(model_class(
309 | multi_xformer,
310 | shared_bert
311 | )).to(device)
312 |
313 | # Create the optimizer
314 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
315 | optimizer_grouped_parameters = [
316 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
317 | 'weight_decay': weight_decay},
318 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
319 | ]
320 |
321 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
322 | scheduler = get_linear_schedule_with_warmup(
323 | optimizer,
324 | args.warmup_steps,
325 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
326 | )
327 |
328 | # Train
329 | train(
330 | model,
331 | train_dls,
332 | optimizer,
333 | scheduler,
334 | validation_evaluator,
335 | n_epochs,
336 | device,
337 | args.log_interval,
338 | model_dir=args.model_dir,
339 | gradient_accumulation=args.gradient_accumulation,
340 | domain_name=domain
341 | )
342 | # Load the best weights
343 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
344 | if args.ensemble_avg_learned:
345 | weights = model.module.alpha_params.cpu().detach().numpy()
346 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f:
347 | f.write(str(weights))
348 |
349 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
350 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
351 | model,
352 | plot_callbacks=[plot_label_distribution],
353 | return_labels_logits=True,
354 | return_votes=True
355 | )
356 | print(f"{domain} F1: {F1}")
357 | print(f"{domain} Accuracy: {acc}")
358 | print()
359 |
360 | wandb.run.summary[f"{domain}-P"] = P
361 | wandb.run.summary[f"{domain}-R"] = R
362 | wandb.run.summary[f"{domain}-F1"] = F1
363 | wandb.run.summary[f"{domain}-Acc"] = acc
364 | # macro and micro F1 are only with respect to the 5 main splits
365 | if i < len(all_dsets):
366 | Ps.append(P)
367 | Rs.append(R)
368 | F1s.append(F1)
369 | accs.append(acc)
370 | labels_all.extend(labels)
371 | logits_all.extend(logits)
372 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
373 | for p, l in zip(np.argmax(logits, axis=-1), labels):
374 | f.write(f'{domain}\t{p}\t{l}\n')
375 |
376 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
377 | # Add to wandb
378 | wandb.run.summary[f'test-loss'] = loss
379 | wandb.run.summary[f'test-micro-acc'] = acc
380 | wandb.run.summary[f'test-micro-P'] = P
381 | wandb.run.summary[f'test-micro-R'] = R
382 | wandb.run.summary[f'test-micro-F1'] = F1
383 |
384 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs)
385 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps)
386 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs)
387 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s)
388 |
389 | # wandb.log({f"label-distribution-test-{i}": plots[0]})
390 |
--------------------------------------------------------------------------------
/emnlp_final_experiments/claim-detection/train_multi_view_domain_adversarial.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | import wandb
15 | from torch.optim.lr_scheduler import LambdaLR
16 | from torch.utils.data import DataLoader
17 | from torch.utils.data import Subset
18 | from torch.utils.data import random_split
19 | from torch.optim import Adam
20 | from tqdm import tqdm
21 | from transformers import AdamW
22 | from transformers import DistilBertConfig
23 | from transformers import DistilBertTokenizer
24 | from transformers import DistilBertForSequenceClassification
25 | from transformers import get_linear_schedule_with_warmup
26 |
27 | from datareader import MultiDomainTwitterDataset
28 | from datareader import collate_batch_transformer
29 | from metrics import MultiDatasetClassificationEvaluator
30 | from metrics import ClassificationEvaluator
31 | from metrics import acc_f1
32 |
33 | from metrics import plot_label_distribution
34 | from model import MultiTransformerClassifier
35 | from model import VanillaBert
36 | from model import *
37 |
38 |
39 | def train(
40 | model: torch.nn.Module,
41 | train_dls: List[DataLoader],
42 | optimizer: torch.optim.Optimizer,
43 | scheduler: LambdaLR,
44 | validation_evaluator: MultiDatasetClassificationEvaluator,
45 | n_epochs: int,
46 | device: AnyStr,
47 | log_interval: int = 1,
48 | patience: int = 10,
49 | model_dir: str = "wandb_local",
50 | gradient_accumulation: int = 1,
51 | domain_name: str = ''
52 | ):
53 | #best_loss = float('inf')
54 | best_f1 = 0.0
55 | patience_counter = 0
56 |
57 | epoch_counter = 0
58 | total = sum(len(dl) for dl in train_dls)
59 |
60 | # Main loop
61 | while epoch_counter < n_epochs:
62 | dl_iters = [iter(dl) for dl in train_dls]
63 | dl_idx = list(range(len(dl_iters)))
64 | finished = [0] * len(dl_iters)
65 | i = 0
66 | with tqdm(total=total, desc="Training") as pbar:
67 | while sum(finished) < len(dl_iters):
68 | random.shuffle(dl_idx)
69 | for d in dl_idx:
70 | domain_dl = dl_iters[d]
71 | batches = []
72 | try:
73 | for j in range(gradient_accumulation):
74 | batches.append(next(domain_dl))
75 | except StopIteration:
76 | finished[d] = 1
77 | if len(batches) == 0:
78 | continue
79 | optimizer.zero_grad()
80 | for batch in batches:
81 | model.train()
82 | batch = tuple(t.to(device) for t in batch)
83 | input_ids = batch[0]
84 | masks = batch[1]
85 | labels = batch[2]
86 | # Null the labels if it's the test data
87 | if d == len(train_dls) - 1:
88 | labels = None
89 | # Testing with random domains to check whether they have any effect
90 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
91 | domains = batch[3]
92 |
93 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True)
94 | loss = loss.mean() / gradient_accumulation
95 | if i % log_interval == 0:
96 | # wandb.log({
97 | # "Loss": loss.item(),
98 | # "alpha0": alpha[:,0].cpu(),
99 | # "alpha1": alpha[:, 1].cpu(),
100 | # "alpha2": alpha[:, 2].cpu(),
101 | # "alpha_shared": alpha[:, 3].cpu()
102 | # })
103 | wandb.log({
104 | "Loss": loss.item()
105 | })
106 |
107 | loss.backward()
108 | i += 1
109 | pbar.update(1)
110 |
111 | optimizer.step()
112 | if scheduler is not None:
113 | scheduler.step()
114 |
115 | gc.collect()
116 |
117 | # Inline evaluation
118 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
119 | print(f"Validation F1: {F1}")
120 |
121 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
122 |
123 | # Saving the best model and early stopping
124 | #if val_loss < best_loss:
125 | if F1 > best_f1:
126 | best_model = model.state_dict()
127 | #best_loss = val_loss
128 | best_f1 = F1
129 | #wandb.run.summary['best_validation_loss'] = best_loss
130 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
131 | patience_counter = 0
132 | # Log to wandb
133 | wandb.log({
134 | 'Validation accuracy': acc,
135 | 'Validation Precision': P,
136 | 'Validation Recall': R,
137 | 'Validation F1': F1,
138 | 'Validation loss': val_loss})
139 | else:
140 | patience_counter += 1
141 | # Stop training once we have lost patience
142 | if patience_counter == patience:
143 | break
144 |
145 | gc.collect()
146 | epoch_counter += 1
147 |
148 |
149 | if __name__ == "__main__":
150 | # Define arguments
151 | parser = argparse.ArgumentParser()
152 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
153 | parser.add_argument("--train_pct", help="Fraction of data to use for training (between 0 and 1)", type=float, default=0.8)
154 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
155 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1)
156 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
157 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
158 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None)
159 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None)
160 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
161 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
162 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
163 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
164 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
165 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16)
166 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-5)
167 | parser.add_argument("--weight_decay", help="L2 regularization (weight decay) coefficient", type=float, default=0.01)
168 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int)
169 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, type=int)
170 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int)
171 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int)
172 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
173 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
174 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int)
175 | parser.add_argument("--indices_dir", help="Directory with precomputed train/validation split indices, if standard splits are used", type=str, default=None)
176 |
177 | args = parser.parse_args()
178 |
179 | # Set all the seeds
180 | seed = args.seed
181 | random.seed(seed)
182 | np.random.seed(seed)
183 | torch.manual_seed(seed)
184 | torch.cuda.manual_seed_all(seed)
185 | torch.backends.cudnn.deterministic = True
186 | torch.backends.cudnn.benchmark = False
187 |
188 | # See if CUDA available
189 | device = torch.device("cpu")
190 | if args.n_gpu > 0 and torch.cuda.is_available():
191 | print("Training on GPU")
192 | device = torch.device("cuda:0")
193 |
194 | # model configuration
195 | bert_model = 'distilbert-base-uncased'
196 | batch_size = args.batch_size
197 | lr = args.lr
198 | weight_decay = args.weight_decay
199 | n_epochs = args.n_epochs
200 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
201 |
202 |
203 | # wandb initialization
204 | wandb.init(
205 | project="domain-adaptation-twitter-emnlp",
206 | name=args.run_name,
207 | config={
208 | "epochs": n_epochs,
209 | "learning_rate": lr,
210 | "warmup": args.warmup_steps,
211 | "weight_decay": weight_decay,
212 | "batch_size": batch_size,
213 | "train_split_percentage": args.train_pct,
214 | "bert_model": bert_model,
215 | "seed": seed,
216 | "tags": ",".join(args.tags)
217 | }
218 | )
219 | #wandb.watch(model)
220 | # Create save directory for the model
221 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
222 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
223 |
224 | # Create the dataset
225 | all_dsets = [MultiDomainTwitterDataset(
226 | args.dataset_loc,
227 | [domain],
228 | DistilBertTokenizer.from_pretrained(bert_model)
229 | ) for domain in args.domains[:-1]]
230 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
231 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
232 |
233 | accs = []
234 | Ps = []
235 | Rs = []
236 | F1s = []
237 | # Store labels and logits for individual splits for micro F1
238 | labels_all = []
239 | logits_all = []
240 |
241 | for i in range(len(all_dsets)):
242 | domain = args.domains[i]
243 | test_dset = all_dsets[i]
244 | # Override the domain IDs
245 | k = 0
246 | for j in range(len(all_dsets)):
247 | if j != i:
248 | all_dsets[j].set_domain_id(k)
249 | k += 1
250 | test_dset.set_domain_id(k)
251 | # For test
252 | #all_dsets = [all_dsets[0], all_dsets[2]]
253 |
254 | # Split the data
255 | if args.indices_dir is None:
256 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
257 | for j in range(len(all_dsets)) if j != i]
258 | else:
259 | # load the indices
260 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
261 | subset_indices = defaultdict(lambda: [[], []])
262 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
263 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
264 | for l in f:
265 | vals = l.strip().split(',')
266 | subset_indices[int(vals[0])][0].append(int(vals[1]))
267 | for l in g:
268 | vals = l.strip().split(',')
269 | subset_indices[int(vals[0])][1].append(int(vals[1]))
270 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in
271 | subset_indices]
272 |
273 | train_dls = [DataLoader(
274 | subset[0],
275 | batch_size=batch_size,
276 | shuffle=True,
277 | collate_fn=collate_batch_transformer
278 | ) for subset in subsets]
279 | # Add test data for domain adversarial training
280 | train_dls += [DataLoader(
281 | test_dset,
282 | batch_size=batch_size,
283 | shuffle=True,
284 | collate_fn=collate_batch_transformer
285 | )]
286 |
287 | val_ds = [subset[1] for subset in subsets]
288 | # for vds in val_ds:
289 | # print(vds.indices)
290 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
291 |
292 | # Create the model
293 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
294 | init_weights = None
295 | if args.pretrained_bert is not None:
296 | init_weights = {k: v for k, v in torch.load(args.pretrained_bert).items() if "classifier" not in k}
297 | model_dict = bert.state_dict()
298 | model_dict.update(init_weights)
299 | bert.load_state_dict(model_dict)
300 | shared_bert = VanillaBert(bert).to(device)
301 |
302 | multi_xformer = MultiDistilBertClassifier(
303 | bert_model,
304 | bert_config,
305 | n_domains=len(train_dls) - 1,
306 | init_weights=init_weights
307 | ).to(device)
308 |
309 | model = torch.nn.DataParallel(MultiViewTransformerNetworkProbabilitiesAdversarial(
310 | multi_xformer,
311 | shared_bert,
312 | supervision_layer=args.supervision_layer
313 | )).to(device)
314 |
315 | # Create the optimizer
316 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
317 | optimizer_grouped_parameters = [
318 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
319 | 'weight_decay': weight_decay},
320 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
321 | ]
322 |
323 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
324 | scheduler = get_linear_schedule_with_warmup(
325 | optimizer,
326 | args.warmup_steps,
327 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
328 | )
329 |
330 | # Train
331 | train(
332 | model,
333 | train_dls,
334 | optimizer,
335 | scheduler,
336 | validation_evaluator,
337 | n_epochs,
338 | device,
339 | args.log_interval,
340 | model_dir=args.model_dir,
341 | gradient_accumulation=args.gradient_accumulation,
342 | domain_name=domain
343 | )
344 | # Load the best weights
345 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
346 |
347 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
348 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
349 | model,
350 | plot_callbacks=[plot_label_distribution],
351 | return_labels_logits=True,
352 | return_votes=True
353 | )
354 | print(f"{domain} F1: {F1}")
355 | print(f"{domain} Accuracy: {acc}")
356 | print()
357 |
358 | wandb.run.summary[f"{domain}-P"] = P
359 | wandb.run.summary[f"{domain}-R"] = R
360 | wandb.run.summary[f"{domain}-F1"] = F1
361 | wandb.run.summary[f"{domain}-Acc"] = acc
362 | # macro and micro F1 are only with respect to the 5 main splits
363 | if i < len(all_dsets):
364 | Ps.append(P)
365 | Rs.append(R)
366 | F1s.append(F1)
367 | accs.append(acc)
368 | labels_all.extend(labels)
369 | logits_all.extend(logits)
370 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
371 | for p, l in zip(np.argmax(logits, axis=-1), labels):
372 | f.write(f'{domain}\t{p}\t{l}\n')
373 |
374 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
375 | # Add to wandb
376 | wandb.run.summary[f'test-loss'] = loss
377 | wandb.run.summary[f'test-micro-acc'] = acc
378 | wandb.run.summary[f'test-micro-P'] = P
379 | wandb.run.summary[f'test-micro-R'] = R
380 | wandb.run.summary[f'test-micro-F1'] = F1
381 |
382 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs)
383 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps)
384 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs)
385 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s)
386 |
387 | # wandb.log({f"label-distribution-test-{i}": plots[0]})
388 |
--------------------------------------------------------------------------------
/emnlp_final_experiments/claim-detection/train_multi_view_selective_weighting.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | import wandb
15 | from torch.optim.lr_scheduler import LambdaLR
16 | from torch.utils.data import DataLoader
17 | from torch.utils.data import Subset
18 | from torch.utils.data import random_split
19 | from torch.optim import Adam
20 | from tqdm import tqdm
21 | from transformers import AdamW
22 | from transformers import DistilBertConfig
23 | from transformers import DistilBertTokenizer
24 | from transformers import DistilBertForSequenceClassification
25 | from transformers import get_linear_schedule_with_warmup
26 |
27 | from datareader import MultiDomainTwitterDataset
28 | from datareader import collate_batch_transformer
29 | from metrics import MultiDatasetClassificationEvaluator
30 | from metrics import ClassificationEvaluator
31 | from metrics import acc_f1
32 |
33 | from metrics import plot_label_distribution
34 | from model import MultiTransformerClassifier
35 | from model import VanillaBert
36 | from model import *
37 | from sklearn.model_selection import ParameterSampler
38 |
39 |
40 | def attention_grid_search(
41 | model: torch.nn.Module,
42 | validation_evaluator: MultiDatasetClassificationEvaluator,
43 | n_epochs: int,
44 | seed: int
45 | ):
46 | best_weights = model.module.weights
47 | # Baseline evaluation with the initial weights
48 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
49 | best_f1 = F1
50 | print(F1)
51 | # Sample candidate weight combinations (a random search via ParameterSampler, despite the function name)
52 | param_dict = {k: list(range(0, 11)) for k in range(1, 6)}
53 | grid_search_params = ParameterSampler(param_dict, n_iter=n_epochs, random_state=seed)
54 | for d in grid_search_params:
55 | weights = [v for k,v in sorted(d.items(), key=lambda x:x[0])]
56 | weights = np.array(weights) / sum(weights)
57 | model.module.weights = weights
58 | # Inline evaluation
59 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
60 | print(f"Weights: {weights}\tValidation F1: {F1}")
61 |
62 | if F1 > best_f1:
63 | best_weights = weights
64 | best_f1 = F1
65 | # Log to wandb
66 | wandb.log({
67 | 'Validation accuracy': acc,
68 | 'Validation Precision': P,
69 | 'Validation Recall': R,
70 | 'Validation F1': F1,
71 | 'Validation loss': val_loss})
72 |
73 | gc.collect()
74 |
75 | return best_weights
76 |
77 |
78 | if __name__ == "__main__":
79 | # Define arguments
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
82 | parser.add_argument("--train_pct", help="Fraction of data to use for training (between 0 and 1)", type=float, default=0.8)
83 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
84 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
85 | parser.add_argument("--pretrained_model", help="The pretrained averaging model", type=str, default=None)
86 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
87 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
88 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
89 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
90 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
91 | parser.add_argument("--indices_dir", help="Directory with precomputed train/validation split indices, if standard splits are used", type=str, default=None)
92 |
93 |
94 | args = parser.parse_args()
95 |
96 | # Set all the seeds
97 | seed = args.seed
98 | random.seed(seed)
99 | np.random.seed(seed)
100 | torch.manual_seed(seed)
101 | torch.cuda.manual_seed_all(seed)
102 | torch.backends.cudnn.deterministic = True
103 | torch.backends.cudnn.benchmark = False
104 |
105 | # See if CUDA available
106 | device = torch.device("cpu")
107 | if args.n_gpu > 0 and torch.cuda.is_available():
108 | print("Training on GPU")
109 | device = torch.device("cuda:0")
110 |
111 | # model configuration
112 | bert_model = 'distilbert-base-uncased'
113 | n_epochs = args.n_epochs
114 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
115 |
116 |
117 | # wandb initialization
118 | wandb.init(
119 | project="domain-adaptation-twitter-emnlp",
120 | name=args.run_name,
121 | config={
122 | "epochs": n_epochs,
123 | "train_split_percentage": args.train_pct,
124 | "bert_model": bert_model,
125 | "seed": seed,
126 | "tags": ",".join(args.tags)
127 | }
128 | )
129 | #wandb.watch(model)
130 | # Create save directory for the model
131 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
132 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
133 |
134 | # Create the dataset
135 | all_dsets = [MultiDomainTwitterDataset(
136 | args.dataset_loc,
137 | [domain],
138 | DistilBertTokenizer.from_pretrained(bert_model)
139 | ) for domain in args.domains[:-1]]
140 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
141 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
142 |
143 | accs = []
144 | Ps = []
145 | Rs = []
146 | F1s = []
147 | # Store labels and logits for individual splits for micro F1
148 | labels_all = []
149 | logits_all = []
150 |
151 | for i in range(len(all_dsets)):
152 | domain = args.domains[i]
153 | test_dset = all_dsets[i]
154 | # Override the domain IDs
155 | k = 0
156 | for j in range(len(all_dsets)):
157 | if j != i:
158 | all_dsets[j].set_domain_id(k)
159 | k += 1
160 | test_dset.set_domain_id(k)
161 | # For test
162 | #all_dsets = [all_dsets[0], all_dsets[2]]
163 |
164 | # Split the data
165 | if args.indices_dir is None:
166 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
167 | for j in range(len(all_dsets)) if j != i]
168 | else:
169 | # load the indices
170 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
171 | subset_indices = defaultdict(lambda: [[], []])
172 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
173 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
174 | for l in f:
175 | vals = l.strip().split(',')
176 | subset_indices[int(vals[0])][0].append(int(vals[1]))
177 | for l in g:
178 | vals = l.strip().split(',')
179 | subset_indices[int(vals[0])][1].append(int(vals[1]))
180 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in
181 | subset_indices]
182 |
183 | train_dls = [DataLoader(
184 | subset[0],
185 | batch_size=8,
186 | shuffle=True,
187 | collate_fn=collate_batch_transformer
188 | ) for subset in subsets]
189 |
190 | val_ds = [subset[1] for subset in subsets]
191 | # for vds in val_ds:
192 | # print(vds.indices)
193 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
194 |
195 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
196 | # Create the model
197 | init_weights = None
198 | shared_bert = VanillaBert(bert).to(device)
199 |
200 | multi_xformer = MultiDistilBertClassifier(
201 | bert_model,
202 | bert_config,
203 | n_domains=len(train_dls),
204 | init_weights=init_weights
205 | ).to(device)
206 |
207 | model = torch.nn.DataParallel(MultiViewTransformerNetworkSelectiveWeight(
208 | multi_xformer,
209 | shared_bert
210 | )).to(device)
211 |
212 | # Load the best weights
213 | model.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}.pth'))
214 |
215 | # Find the best attention weights by randomly sampling weight combinations
216 | weights = attention_grid_search(
217 | model,
218 | validation_evaluator,
219 | n_epochs,
220 | seed
221 | )
222 | model.module.weights = weights
223 | print(f"Best weights: {weights}")
224 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f:
225 | f.write(str(weights))
226 |
227 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
228 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
229 | model,
230 | plot_callbacks=[plot_label_distribution],
231 | return_labels_logits=True,
232 | return_votes=True
233 | )
234 | print(f"{domain} F1: {F1}")
235 | print(f"{domain} Accuracy: {acc}")
236 | print()
237 |
238 | wandb.run.summary[f"{domain}-P"] = P
239 | wandb.run.summary[f"{domain}-R"] = R
240 | wandb.run.summary[f"{domain}-F1"] = F1
241 | wandb.run.summary[f"{domain}-Acc"] = acc
242 | # macro and micro F1 are only with respect to the 5 main splits
243 | if i < len(all_dsets):
244 | Ps.append(P)
245 | Rs.append(R)
246 | F1s.append(F1)
247 | accs.append(acc)
248 | labels_all.extend(labels)
249 | logits_all.extend(logits)
250 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
251 | for p, l in zip(np.argmax(logits, axis=-1), labels):
252 | f.write(f'{domain}\t{p}\t{l}\n')
253 |
254 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
255 | # Add to wandb
256 | wandb.run.summary[f'test-loss'] = loss
257 | wandb.run.summary[f'test-micro-acc'] = acc
258 | wandb.run.summary[f'test-micro-P'] = P
259 | wandb.run.summary[f'test-micro-R'] = R
260 | wandb.run.summary[f'test-micro-F1'] = F1
261 |
262 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs)
263 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps)
264 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs)
265 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s)
266 |
267 | # wandb.log({f"label-distribution-test-{i}": plots[0]})
268 |
--------------------------------------------------------------------------------
/emnlp_final_experiments/sentiment-analysis/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/copenlu/xformer-multi-source-domain-adaptation/be2d1a132298131df82fe40dd4f6c08dec8b3404/emnlp_final_experiments/sentiment-analysis/.DS_Store
--------------------------------------------------------------------------------
/emnlp_final_experiments/sentiment-analysis/analyze_expert_predictions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 |
11 | import numpy as np
12 | import torch
13 | import wandb
14 | from torch.optim.lr_scheduler import LambdaLR
15 | from torch.utils.data import DataLoader
16 | from torch.utils.data import Subset
17 | from torch.utils.data import random_split
18 | from torch.optim import Adam
19 | from tqdm import tqdm
20 | from transformers import AdamW
21 | from transformers import DistilBertConfig
22 | from transformers import DistilBertTokenizer
23 | from transformers import DistilBertForSequenceClassification
24 | from transformers import get_linear_schedule_with_warmup
25 |
26 | from datareader import MultiDomainSentimentDataset
27 | from datareader import collate_batch_transformer
28 | from metrics import MultiDatasetClassificationEvaluator
29 | from metrics import ClassificationEvaluator
30 | from metrics import acc_f1
31 |
32 | from metrics import plot_label_distribution
33 | from model import MultiTransformerClassifier
34 | from model import VanillaBert
35 | from model import *
36 | from sklearn.model_selection import ParameterSampler
37 | from scipy.special import softmax
38 | from scipy.stats import wasserstein_distance
39 |
40 |
41 | if __name__ == "__main__":
42 | # Define arguments
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
45 | parser.add_argument("--train_pct", help="Fraction of data to use for training (between 0 and 1)", type=float, default=0.8)
46 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
47 | parser.add_argument("--pretrained_model", help="Directory with weights to initialize the shared model with", type=str, default=None)
48 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
49 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
50 |
51 | args = parser.parse_args()
52 |
53 | # Set all the seeds
54 | seed = args.seed
55 | random.seed(seed)
56 | np.random.seed(seed)
57 | torch.manual_seed(seed)
58 | torch.cuda.manual_seed_all(seed)
59 | torch.backends.cudnn.deterministic = True
60 | torch.backends.cudnn.benchmark = False
61 |
62 | # See if CUDA available
63 | device = torch.device("cpu")
64 | if args.n_gpu > 0 and torch.cuda.is_available():
65 | print("Training on GPU")
66 | device = torch.device("cuda:0")
67 |
68 | # model configuration
69 | bert_model = 'distilbert-base-uncased'
70 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
71 |
72 | # Create the dataset
73 | all_dsets = [MultiDomainSentimentDataset(
74 | args.dataset_loc,
75 | [domain],
76 | DistilBertTokenizer.from_pretrained(bert_model)
77 | ) for domain in args.domains]
78 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
79 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
80 |
81 | for i in range(len(all_dsets)):
82 | domain = args.domains[i]
83 |
84 | test_dset = all_dsets[i]
85 |
86 | dataloader = DataLoader(
87 | test_dset,
88 | batch_size=4,
89 | shuffle=True,
90 | collate_fn=collate_batch_transformer
91 | )
92 |
93 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
94 | # Create the model
95 |
96 | model = torch.nn.DataParallel(MultiViewTransformerNetworkAveragingIndividuals(
97 | bert_model,
98 | bert_config,
99 | len(all_dsets) - 1
100 | )).to(device)
101 | model.module.average = True
102 | # load the trained model
103 |
104 | # Load the best weights
105 | for v in range(len(all_dsets)-1):
106 | model.module.domain_experts[v].load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{v}.pth'))
107 | model.module.shared_bert.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{len(all_dsets)-1}.pth'))
108 |
109 | logits_all = [[] for d in range(len(all_dsets))]
110 | for batch in tqdm(dataloader):
111 | model.eval()  # evaluation mode: disable dropout when analyzing predictions
112 | batch = tuple(t.to(device) for t in batch)
113 | input_ids = batch[0]
114 | masks = batch[1]
115 | labels = batch[2]
116 | # Testing with random domains to check whether they have any effect
117 | # domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
118 | domains = batch[3]
119 |
120 | logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels, return_logits=True)
121 | for k,l in enumerate(logits):
122 | logits_all[k].append(l.detach().cpu().numpy())
123 |
124 | print(domain)
125 | probs_all = [softmax(np.concatenate(l), axis=-1)[:,1] for l in logits_all]
126 | for i in range(len(probs_all)):
127 | for j in range(i+1, len(probs_all)):
128 | print(wasserstein_distance(probs_all[i], probs_all[j]))
129 | preds_all = np.asarray([np.argmax(np.concatenate(l), axis=-1) for l in logits_all])
130 | print(f"alpha: {krippendorff.alpha(preds_all, level_of_measurement='nominal')}")
131 | print()
--------------------------------------------------------------------------------
/emnlp_final_experiments/sentiment-analysis/analyze_expert_predictions_cnn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 |
11 | import numpy as np
12 | import torch
13 | import wandb
14 | from torch.optim.lr_scheduler import LambdaLR
15 | from torch.utils.data import DataLoader
16 | from torch.utils.data import Subset
17 | from torch.utils.data import random_split
18 | from torch.optim import Adam
19 | from tqdm import tqdm
20 | from transformers import AdamW
21 | from transformers import DistilBertConfig
22 | from transformers import DistilBertTokenizer
23 | from transformers import DistilBertForSequenceClassification
24 | from transformers import get_linear_schedule_with_warmup
25 |
26 | from datareader_cnn import MultiDomainSentimentDataset
27 | from datareader_cnn import collate_batch_cnn
28 | from datareader_cnn import FasttextTokenizer
29 | from metrics import MultiDatasetClassificationEvaluator
30 | from metrics import ClassificationEvaluator
31 | from metrics import acc_f1
32 |
33 | from metrics import plot_label_distribution
34 | from model import MultiTransformerClassifier
35 | from model import VanillaBert
36 | from model import *
37 | from sklearn.model_selection import ParameterSampler
38 | from scipy.special import softmax
39 | from scipy.stats import wasserstein_distance
40 |
41 |
42 | if __name__ == "__main__":
43 | # Define arguments
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
46 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
47 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
48 | parser.add_argument("--pretrained_model", help="Directory with weights to initialize the shared model with", type=str, default=None)
49 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
50 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
51 | parser.add_argument("--dropout", help="Dropout probability", default=0.3, type=float)
52 |
53 |
54 | # CNN hyperparameters
55 | parser.add_argument("--embedding_dim", help="Dimension of embeddings", choices=[50, 100, 200, 300], default=100,
56 | type=int)
57 | parser.add_argument("--in_channels", type=int, default=1)
58 | parser.add_argument("--out_channels", type=int, default=100)
59 | parser.add_argument("--kernel_heights", help="filter windows", type=int, nargs='+', default=[2, 4, 5])
60 | parser.add_argument("--stride", help="stride", type=int, default=1)
61 | parser.add_argument("--padding", help="padding", type=int, default=0)
62 |
63 | args = parser.parse_args()
64 |
65 | # Set all the seeds
66 | seed = args.seed
67 | random.seed(seed)
68 | np.random.seed(seed)
69 | torch.manual_seed(seed)
70 | torch.cuda.manual_seed_all(seed)
71 | torch.backends.cudnn.deterministic = True
72 | torch.backends.cudnn.benchmark = False
73 |
74 | # See if CUDA available
75 | device = torch.device("cpu")
76 | if args.n_gpu > 0 and torch.cuda.is_available():
77 | print("Training on GPU")
78 | device = torch.device("cuda:0")
79 |
80 | tokenizer = FasttextTokenizer(f"{args.dataset_loc}/vocabulary.txt")
81 |
82 | # Create the dataset
83 | all_dsets = [MultiDomainSentimentDataset(
84 | args.dataset_loc,
85 | [domain],
86 | tokenizer
87 | ) for domain in args.domains]
88 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
89 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
90 |
91 | for i in range(len(all_dsets)):
92 | domain = args.domains[i]
93 |
94 | test_dset = all_dsets[i]
95 |
96 | dataloader = DataLoader(
97 | test_dset,
98 | batch_size=4,
99 | shuffle=True,
100 | collate_fn=collate_batch_cnn
101 | )
102 |
103 | embeddings = np.load(f"{args.dataset_loc}/fasttext_embeddings.npy")
104 | model = torch.nn.DataParallel(MultiViewCNNAveragingIndividuals(
105 | args,
106 | embeddings,
107 | len(all_dsets) - 1
108 | )).to(device)
109 | model.module.average = True
110 |
111 |
112 | # Load the best weights
113 | for v in range(len(all_dsets)-1):
114 | model.module.domain_experts[v].load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{v}.pth'))
115 | model.module.shared_model.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}_{len(all_dsets)-1}.pth'))
116 |
117 | logits_all = [[] for d in range(len(all_dsets))]
118 | for batch in tqdm(dataloader):
119 | model.eval()  # eval mode: disable dropout while collecting predictions
120 | batch = tuple(t.to(device) for t in batch)
121 | input_ids = batch[0]
122 | masks = batch[1]
123 | labels = batch[2]
124 | # Testing with random domains to see if any effect
125 | # domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
126 | domains = batch[3]
127 |
128 | logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels, return_logits=True)
129 | for k,l in enumerate(logits):
130 | logits_all[k].append(l.detach().cpu().numpy())
131 |
132 | print(domain)
133 | probs_all = [softmax(np.concatenate(l), axis=-1)[:,1] for l in logits_all]
134 | for i in range(len(probs_all)):
135 | for j in range(i+1, len(probs_all)):
136 | print(wasserstein_distance(probs_all[i], probs_all[j]))
137 | preds_all = np.asarray([np.argmax(np.concatenate(l), axis=-1) for l in logits_all])
138 | print(f"alpha: {krippendorff.alpha(preds_all, level_of_measurement='nominal')}")
139 | print()
--------------------------------------------------------------------------------
/emnlp_final_experiments/sentiment-analysis/train_basic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | from collections import defaultdict
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import torch
13 | import wandb
14 | from torch.optim.lr_scheduler import LambdaLR
15 | from torch.utils.data import DataLoader
16 | from torch.utils.data import Subset
17 | from torch.utils.data import random_split
18 | from torch.optim import Adam
19 | from tqdm import tqdm
20 | from transformers import AdamW
21 | from transformers import DistilBertConfig
22 | from transformers import DistilBertTokenizer
23 | from transformers import DistilBertModel
24 | from transformers import BertConfig
25 | from transformers import BertTokenizer
26 | from transformers import BertModel
27 | from transformers import get_linear_schedule_with_warmup
28 |
29 | from datareader import MultiDomainSentimentDataset
30 | from datareader import collate_batch_transformer
31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1
32 | from metrics import ClassificationEvaluator
33 |
34 | from metrics import plot_label_distribution
35 | from model import *
36 |
37 |
38 | def train(
39 | model: torch.nn.Module,
40 | train_dls: List[DataLoader],
41 | optimizer: torch.optim.Optimizer,
42 | scheduler: LambdaLR,
43 | validation_evaluator: MultiDatasetClassificationEvaluator,
44 | n_epochs: int,
45 | device: AnyStr,
46 | log_interval: int = 1,
47 | patience: int = 10,
48 | model_dir: str = "wandb_local",
49 | gradient_accumulation: int = 1,
50 | domain_name: str = ''
51 | ):
52 | #best_loss = float('inf')
53 | best_acc = 0.0
54 | patience_counter = 0
55 |
56 | epoch_counter = 0
57 | total = sum(len(dl) for dl in train_dls)
58 |
59 | # Main loop
60 | while epoch_counter < n_epochs:
61 | dl_iters = [iter(dl) for dl in train_dls]
62 | dl_idx = list(range(len(dl_iters)))
63 | finished = [0] * len(dl_iters)
64 | i = 0
65 | with tqdm(total=total, desc="Training") as pbar:
66 | while sum(finished) < len(dl_iters):
67 | random.shuffle(dl_idx)
68 | for d in dl_idx:
69 | domain_dl = dl_iters[d]
70 | batches = []
71 | try:
72 | for j in range(gradient_accumulation):
73 | batches.append(next(domain_dl))
74 | except StopIteration:
75 | finished[d] = 1
76 | if len(batches) == 0:
77 | continue
78 | optimizer.zero_grad()
79 | for batch in batches:
80 | model.train()
81 | batch = tuple(t.to(device) for t in batch)
82 | input_ids = batch[0]
83 | masks = batch[1]
84 | labels = batch[2]
85 | # Testing with random domains to see if any effect
86 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
87 | domains = batch[3]
88 |
89 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels)
90 | loss = loss / gradient_accumulation
91 |
92 | if i % log_interval == 0:
93 | wandb.log({
94 | "Loss": loss.item()
95 | })
96 |
97 | loss.backward()
98 | i += 1
99 | pbar.update(1)
100 |
101 | optimizer.step()
102 | if scheduler is not None:
103 | scheduler.step()
104 |
105 | gc.collect()
106 |
107 | # Inline evaluation
108 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
109 | print(f"Validation acc: {acc}")
110 |
111 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
112 |
113 | # Saving the best model and early stopping
114 | #if val_loss < best_loss:
115 | if acc > best_acc:
116 | best_model = model.state_dict()
117 | #best_loss = val_loss
118 | best_acc = acc
119 | #wandb.run.summary['best_validation_loss'] = best_loss
120 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
121 | patience_counter = 0
122 | # Log to wandb
123 | wandb.log({
124 | 'Validation accuracy': acc,
125 | 'Validation Precision': P,
126 | 'Validation Recall': R,
127 | 'Validation F1': F1,
128 | 'Validation loss': val_loss})
129 | else:
130 | patience_counter += 1
131 | # Stop training once we have lost patience
132 | if patience_counter == patience:
133 | break
134 |
135 | gc.collect()
136 | epoch_counter += 1
137 |
138 |
139 | if __name__ == "__main__":
140 | # Define arguments
141 | parser = argparse.ArgumentParser()
142 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
143 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
144 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
145 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1)
146 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
147 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
148 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None)
149 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
150 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
151 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
152 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
153 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
154 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
155 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768)
156 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8)
157 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5)
158 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01)
159 | parser.add_argument("--lambd", help="l2 reg", type=float, default=10e-3)
160 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
161 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
162 | parser.add_argument("--full_bert", help="Specify to use full bert model", action="store_true")
163 |
164 | args = parser.parse_args()
165 |
166 | # Set all the seeds
167 | seed = args.seed
168 | random.seed(seed)
169 | np.random.seed(seed)
170 | torch.manual_seed(seed)
171 | torch.cuda.manual_seed_all(seed)
172 | torch.backends.cudnn.deterministic = True
173 | torch.backends.cudnn.benchmark = False
174 |
175 | # See if CUDA available
176 | device = torch.device("cpu")
177 | if args.n_gpu > 0 and torch.cuda.is_available():
178 | print("Training on GPU")
179 | device = torch.device("cuda:0")
180 |
181 | # model configuration
182 | batch_size = args.batch_size
183 | lr = args.lr
184 | weight_decay = args.weight_decay
185 | n_epochs = args.n_epochs
186 | if args.full_bert:
187 | bert_model = 'bert-base-uncased'
188 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2)
189 | tokenizer = BertTokenizer.from_pretrained(bert_model)
190 | else:
191 | bert_model = 'distilbert-base-uncased'
192 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2)
193 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model)
194 |
195 | # wandb initialization
196 | wandb.init(
197 | project="domain-adaptation-sentiment-emnlp",
198 | name=args.run_name,
199 | config={
200 | "epochs": n_epochs,
201 | "learning_rate": lr,
202 | "warmup": args.warmup_steps,
203 | "weight_decay": weight_decay,
204 | "batch_size": batch_size,
205 | "train_split_percentage": args.train_pct,
206 | "bert_model": bert_model,
207 | "seed": seed,
208 | "pretrained_model": args.pretrained_model,
209 | "tags": ",".join(args.tags)
210 | }
211 | )
212 | # Create save directory for model
213 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
214 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
215 |
216 | # Create the dataset
217 | all_dsets = [MultiDomainSentimentDataset(
218 | args.dataset_loc,
219 | [domain],
220 | tokenizer
221 | ) for domain in args.domains]
222 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
223 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
224 |
225 | accs = []
226 | Ps = []
227 | Rs = []
228 | F1s = []
229 | # Store labels and logits for individual splits for micro F1
230 | labels_all = []
231 | logits_all = []
232 |
233 | for i in range(len(all_dsets)):
234 | domain = args.domains[i]
235 | test_dset = all_dsets[i]
236 | # Override the domain IDs
237 | k = 0
238 | for j in range(len(all_dsets)):
239 | if j != i:
240 | all_dsets[j].set_domain_id(k)
241 | k += 1
242 | test_dset.set_domain_id(k)
243 |
244 | # Split the data
245 | if args.indices_dir is None:
246 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
247 | for j in range(len(all_dsets)) if j != i]
248 | # Save the indices
249 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/train_idx_{domain}.txt', 'wt') as f, \
250 | open(f'{args.model_dir}/{Path(wandb.run.dir).name}/val_idx_{domain}.txt', 'wt') as g:
251 | for j, subset in enumerate(subsets):
252 | for idx in subset[0].indices:
253 | f.write(f'{j},{idx}\n')
254 | for idx in subset[1].indices:
255 | g.write(f'{j},{idx}\n')
256 | else:
257 | # load the indices
258 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
259 | subset_indices = defaultdict(lambda: [[], []])
260 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
261 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
262 | for l in f:
263 | vals = l.strip().split(',')
264 | subset_indices[int(vals[0])][0].append(int(vals[1]))
265 | for l in g:
266 | vals = l.strip().split(',')
267 | subset_indices[int(vals[0])][1].append(int(vals[1]))
268 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])]
269 | for d in subset_indices]
270 |
271 | train_dls = [DataLoader(
272 | subset[0],
273 | batch_size=batch_size,
274 | shuffle=True,
275 | collate_fn=collate_batch_transformer
276 | ) for subset in subsets]
277 |
278 | val_ds = [subset[1] for subset in subsets]
279 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
280 |
281 |
282 | # Create the model
283 | if args.full_bert:
284 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
285 | else:
286 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
287 | model = VanillaBert(bert).to(device)
288 | if args.pretrained_model is not None:
289 | model.load_state_dict(torch.load(f"{args.pretrained_model}/model_{domain}.pth"))
290 |
291 | # Create the optimizer
292 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
293 | optimizer_grouped_parameters = [
294 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
295 | 'weight_decay': weight_decay},
296 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
297 | ]
298 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3)
299 | # scheduler = None
300 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
301 | scheduler = get_linear_schedule_with_warmup(
302 | optimizer,
303 | args.warmup_steps,
304 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
305 | )
306 |
307 | # Train
308 | train(
309 | model,
310 | train_dls,
311 | optimizer,
312 | scheduler,
313 | validation_evaluator,
314 | n_epochs,
315 | device,
316 | args.log_interval,
317 | model_dir=args.model_dir,
318 | gradient_accumulation=args.gradient_accumulation,
319 | domain_name=domain
320 | )
321 |
322 | # Load the best weights
323 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
324 |
325 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
326 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
327 | model,
328 | plot_callbacks=[plot_label_distribution],
329 | return_labels_logits=True,
330 | return_votes=True
331 | )
332 | print(f"{domain} F1: {F1}")
333 | print(f"{domain} Accuracy: {acc}")
334 | print()
335 |
336 | wandb.run.summary[f"{domain}-P"] = P
337 | wandb.run.summary[f"{domain}-R"] = R
338 | wandb.run.summary[f"{domain}-F1"] = F1
339 | wandb.run.summary[f"{domain}-Acc"] = acc
340 | Ps.append(P)
341 | Rs.append(R)
342 | F1s.append(F1)
343 | accs.append(acc)
344 | labels_all.extend(labels)
345 | logits_all.extend(logits)
346 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
347 | for p, l in zip(np.argmax(logits, axis=-1), labels):
348 | f.write(f'{domain}\t{p}\t{l}\n')
349 |
350 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
351 | # Add to wandb
352 | wandb.run.summary[f'test-loss'] = loss
353 | wandb.run.summary[f'test-micro-acc'] = acc
354 | wandb.run.summary[f'test-micro-P'] = P
355 | wandb.run.summary[f'test-micro-R'] = R
356 | wandb.run.summary[f'test-micro-F1'] = F1
357 |
358 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs)
359 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps)
360 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs)
361 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s)
362 |
363 | #wandb.log({f"label-distribution-test-{i}": plots[0]})
364 |
--------------------------------------------------------------------------------
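The optimizer setup in `train_basic.py` builds two AdamW parameter groups so that biases and LayerNorm parameters are exempt from weight decay. A minimal pure-Python sketch of that grouping logic (the `(name, param)` pairs below are illustrative stand-ins for `model.named_parameters()`):

```python
# Sketch of the no_decay grouping used in the training scripts above:
# any parameter whose name contains one of these substrings gets
# weight_decay=0.0; everything else is decayed.
NO_DECAY = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

def group_parameters(named_params, weight_decay=0.01):
    """Split a list of (name, param) pairs into decayed/non-decayed groups."""
    decay = [p for n, p in named_params if not any(nd in n for nd in NO_DECAY)]
    no_decay = [p for n, p in named_params if any(nd in n for nd in NO_DECAY)]
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

if __name__ == "__main__":
    fake = [('encoder.weight', 'W'), ('encoder.bias', 'b'),
            ('LayerNorm.weight', 'g'), ('LayerNorm.bias', 'h')]
    groups = group_parameters(fake)
    print(groups[0]['params'])  # ['W'] -> only the dense weight is decayed
```

The resulting list of dicts is what gets passed to `AdamW(optimizer_grouped_parameters, lr=lr)`; excluding biases and normalization parameters from L2 regularization is the standard recipe for fine-tuning BERT-style models.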
/emnlp_final_experiments/sentiment-analysis/train_basic_domain_adversarial.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | from collections import defaultdict
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import torch
13 | import wandb
14 | from torch.optim.lr_scheduler import LambdaLR
15 | from torch.utils.data import DataLoader
16 | from torch.utils.data import Subset
17 | from torch.utils.data import random_split
18 | from torch.optim import Adam
19 | from tqdm import tqdm
20 | from transformers import AdamW
21 | from transformers import DistilBertConfig
22 | from transformers import DistilBertTokenizer
23 | from transformers import DistilBertModel
24 | from transformers import BertConfig
25 | from transformers import BertTokenizer
26 | from transformers import BertModel
27 | from transformers import get_linear_schedule_with_warmup
28 |
29 | from datareader import MultiDomainSentimentDataset
30 | from datareader import collate_batch_transformer
31 | from metrics import MultiDatasetClassificationEvaluator, acc_f1
32 | from metrics import ClassificationEvaluator
33 |
34 | from metrics import plot_label_distribution
35 | from model import *
36 |
37 |
38 | def train(
39 | model: torch.nn.Module,
40 | train_dls: List[DataLoader],
41 | optimizer: torch.optim.Optimizer,
42 | scheduler: LambdaLR,
43 | validation_evaluator: MultiDatasetClassificationEvaluator,
44 | n_epochs: int,
45 | device: AnyStr,
46 | log_interval: int = 1,
47 | patience: int = 10,
48 | model_dir: str = "wandb_local",
49 | gradient_accumulation: int = 1,
50 | domain_name: str = ''
51 | ):
52 | #best_loss = float('inf')
53 | best_acc = 0.0
54 | patience_counter = 0
55 |
56 | epoch_counter = 0
57 | total = sum(len(dl) for dl in train_dls)
58 |
59 | # Main loop
60 | while epoch_counter < n_epochs:
61 | dl_iters = [iter(dl) for dl in train_dls]
62 | dl_idx = list(range(len(dl_iters)))
63 | finished = [0] * len(dl_iters)
64 | i = 0
65 | with tqdm(total=total, desc="Training") as pbar:
66 | while sum(finished) < len(dl_iters):
67 | random.shuffle(dl_idx)
68 | for d in dl_idx:
69 | domain_dl = dl_iters[d]
70 | batches = []
71 | try:
72 | for j in range(gradient_accumulation):
73 | batches.append(next(domain_dl))
74 | except StopIteration:
75 | finished[d] = 1
76 | if len(batches) == 0:
77 | continue
78 | optimizer.zero_grad()
79 | for batch in batches:
80 | model.train()
81 | batch = tuple(t.to(device) for t in batch)
82 | input_ids = batch[0]
83 | masks = batch[1]
84 | labels = batch[2]
85 | # Null the labels if it's the test data
86 | if d == len(train_dls) - 1:
87 | labels = None
88 | # Testing with random domains to see if any effect
89 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
90 | domains = batch[3]
91 |
92 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels)
93 | loss = loss.mean() / gradient_accumulation
94 |
95 | if i % log_interval == 0:
96 | wandb.log({
97 | "Loss": loss.item()
98 | })
99 |
100 | loss.backward()
101 | i += 1
102 | pbar.update(1)
103 |
104 | optimizer.step()
105 | if scheduler is not None:
106 | scheduler.step()
107 |
108 | gc.collect()
109 |
110 | # Inline evaluation
111 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
112 | print(f"Validation acc: {acc}")
113 |
114 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
115 |
116 | # Saving the best model and early stopping
117 | #if val_loss < best_loss:
118 | if acc > best_acc:
119 | best_model = model.state_dict()
120 | #best_loss = val_loss
121 | best_acc = acc
122 | #wandb.run.summary['best_validation_loss'] = best_loss
123 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
124 | patience_counter = 0
125 | # Log to wandb
126 | wandb.log({
127 | 'Validation accuracy': acc,
128 | 'Validation Precision': P,
129 | 'Validation Recall': R,
130 | 'Validation F1': F1,
131 | 'Validation loss': val_loss})
132 | else:
133 | patience_counter += 1
134 | # Stop training once we have lost patience
135 | if patience_counter == patience:
136 | break
137 |
138 | gc.collect()
139 | epoch_counter += 1
140 |
141 |
142 | if __name__ == "__main__":
143 | # Define arguments
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
146 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
147 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
148 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1)
149 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
150 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
151 | parser.add_argument("--pretrained_model", help="Weights to initialize the model with", type=str, default=None)
152 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
153 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
154 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
155 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
156 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
157 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
158 | parser.add_argument("--ff_dim", help="The dimensionality of the feedforward network in the sluice", type=int, default=768)
159 | parser.add_argument("--batch_size", help="The batch size", type=int, default=8)
160 | parser.add_argument("--lr", help="Learning rate", type=float, default=3e-5)
161 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01)
162 | parser.add_argument("--lambd", help="l2 reg", type=float, default=10e-3)
163 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
164 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int)
165 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
166 | parser.add_argument("--full_bert", help="Specify to use full bert model", action="store_true")
167 |
168 |
169 | args = parser.parse_args()
170 |
171 | # Set all the seeds
172 | seed = args.seed
173 | random.seed(seed)
174 | np.random.seed(seed)
175 | torch.manual_seed(seed)
176 | torch.cuda.manual_seed_all(seed)
177 | torch.backends.cudnn.deterministic = True
178 | torch.backends.cudnn.benchmark = False
179 |
180 | # See if CUDA available
181 | device = torch.device("cpu")
182 | if args.n_gpu > 0 and torch.cuda.is_available():
183 | print("Training on GPU")
184 | device = torch.device("cuda:0")
185 |
186 | # model configuration
187 | batch_size = args.batch_size
188 | lr = args.lr
189 | weight_decay = args.weight_decay
190 | n_epochs = args.n_epochs
191 | if args.full_bert:
192 | bert_model = 'bert-base-uncased'
193 | bert_config = BertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
194 | tokenizer = BertTokenizer.from_pretrained(bert_model)
195 | else:
196 | bert_model = 'distilbert-base-uncased'
197 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
198 | tokenizer = DistilBertTokenizer.from_pretrained(bert_model)
199 |
200 | # wandb initialization
201 | wandb.init(
202 | project="domain-adaptation-sentiment-emnlp",
203 | name=args.run_name,
204 | config={
205 | "epochs": n_epochs,
206 | "learning_rate": lr,
207 | "warmup": args.warmup_steps,
208 | "weight_decay": weight_decay,
209 | "batch_size": batch_size,
210 | "train_split_percentage": args.train_pct,
211 | "bert_model": bert_model,
212 | "seed": seed,
213 | "pretrained_model": args.pretrained_model,
214 | "tags": ",".join(args.tags)
215 | }
216 | )
217 | # Create save directory for model
218 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
219 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
220 |
221 | # Create the dataset
222 | all_dsets = [MultiDomainSentimentDataset(
223 | args.dataset_loc,
224 | [domain],
225 | tokenizer
226 | ) for domain in args.domains]
227 | train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
228 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
229 |
230 | accs = []
231 | Ps = []
232 | Rs = []
233 | F1s = []
234 | # Store labels and logits for individual splits for micro F1
235 | labels_all = []
236 | logits_all = []
237 |
238 | for i in range(len(all_dsets)):
239 | domain = args.domains[i]
240 | test_dset = all_dsets[i]
241 | # Override the domain IDs
242 | k = 0
243 | for j in range(len(all_dsets)):
244 | if j != i:
245 | all_dsets[j].set_domain_id(k)
246 | k += 1
247 | test_dset.set_domain_id(k)
248 |
249 | # Split the data
250 | if args.indices_dir is None:
251 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
252 | for j in range(len(all_dsets)) if j != i]
253 | else:
254 | # load the indices
255 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
256 | subset_indices = defaultdict(lambda: [[], []])
257 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
258 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
259 | for l in f:
260 | vals = l.strip().split(',')
261 | subset_indices[int(vals[0])][0].append(int(vals[1]))
262 | for l in g:
263 | vals = l.strip().split(',')
264 | subset_indices[int(vals[0])][1].append(int(vals[1]))
265 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])]
266 | for d in subset_indices]
267 | train_dls = [DataLoader(
268 | subset[0],
269 | batch_size=batch_size,
270 | shuffle=True,
271 | collate_fn=collate_batch_transformer
272 | ) for subset in subsets]
273 | # Add test data for domain adversarial training
274 | train_dls += [DataLoader(
275 | test_dset,
276 | batch_size=batch_size,
277 | shuffle=True,
278 | collate_fn=collate_batch_transformer
279 | )]
280 |
281 | val_ds = [subset[1] for subset in subsets]
282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
283 |
284 | # Create the model
285 | if args.full_bert:
286 | bert = BertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
287 | else:
288 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
289 | model = torch.nn.DataParallel(DomainAdversarialBert(
290 | bert,
291 | n_domains=len(train_dls),
292 | supervision_layer=args.supervision_layer
293 | )).to(device)
294 | if args.pretrained_model is not None:
295 | model.load_state_dict(torch.load(f"{args.pretrained_model}/model_{domain}.pth"))
296 |
297 | # Create the optimizer
298 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
299 | optimizer_grouped_parameters = [
300 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
301 | 'weight_decay': weight_decay},
302 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
303 | ]
304 | # optimizer = Adam(optimizer_grouped_parameters, lr=1e-3)
305 | # scheduler = None
306 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
307 | scheduler = get_linear_schedule_with_warmup(
308 | optimizer,
309 | args.warmup_steps,
310 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
311 | )
312 |
313 | # Train
314 | train(
315 | model,
316 | train_dls,
317 | optimizer,
318 | scheduler,
319 | validation_evaluator,
320 | n_epochs,
321 | device,
322 | args.log_interval,
323 | model_dir=args.model_dir,
324 | gradient_accumulation=args.gradient_accumulation,
325 | domain_name=domain
326 | )
327 |
328 | # Load the best weights
329 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
330 |
331 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
332 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
333 | model,
334 | plot_callbacks=[plot_label_distribution],
335 | return_labels_logits=True,
336 | return_votes=True
337 | )
338 | print(f"{domain} F1: {F1}")
339 | print(f"{domain} Accuracy: {acc}")
340 | print()
341 |
342 | wandb.run.summary[f"{domain}-P"] = P
343 | wandb.run.summary[f"{domain}-R"] = R
344 | wandb.run.summary[f"{domain}-F1"] = F1
345 | wandb.run.summary[f"{domain}-Acc"] = acc
346 | Ps.append(P)
347 | Rs.append(R)
348 | F1s.append(F1)
349 | accs.append(acc)
350 | labels_all.extend(labels)
351 | logits_all.extend(logits)
352 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
353 | for p, l in zip(np.argmax(logits, axis=-1), labels):
354 | f.write(f'{domain}\t{p}\t{l}\n')
355 |
356 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
357 | # Add to wandb
358 |     wandb.run.summary['test-loss'] = loss  # note: loss comes from the final test domain's evaluator
359 |     wandb.run.summary['test-micro-acc'] = acc
360 |     wandb.run.summary['test-micro-P'] = P
361 |     wandb.run.summary['test-micro-R'] = R
362 |     wandb.run.summary['test-micro-F1'] = F1
363 |
364 |     wandb.run.summary['test-macro-acc'] = sum(accs) / len(accs)
365 |     wandb.run.summary['test-macro-P'] = sum(Ps) / len(Ps)
366 |     wandb.run.summary['test-macro-R'] = sum(Rs) / len(Rs)
367 |     wandb.run.summary['test-macro-F1'] = sum(F1s) / len(F1s)
368 |
369 | #wandb.log({f"label-distribution-test-{i}": plots[0]})
370 |
--------------------------------------------------------------------------------
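Every training script above builds its optimizer with two parameter groups so that biases and LayerNorm weights are exempt from weight decay, then pairs it with a linear warmup schedule. Below is a minimal, self-contained sketch of that pattern. It substitutes a toy module for DistilBERT, and uses `torch.optim.AdamW` with a hand-written `LambdaLR` in place of the `transformers` helpers (`AdamW`, `get_linear_schedule_with_warmup`) that the scripts actually import; the grouping logic itself is the same.

```python
import torch

class TinyModel(torch.nn.Module):
    """Stand-in for DistilBERT; the attribute is named LayerNorm so that
    parameter names match the HF-style no_decay fragments below."""
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)
        self.classifier = torch.nn.Linear(8, 2)

model = TinyModel()

# Exempt biases and LayerNorm weights from weight decay, as in the scripts.
no_decay = ['bias', 'LayerNorm.weight']
grouped = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped, lr=1e-5)

# Linear warmup then linear decay to zero, mirroring
# transformers.get_linear_schedule_with_warmup(optimizer, warmup, total).
warmup_steps, total_steps = 10, 100

def linear_warmup(step: int) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, linear_warmup)
```

Only the weight matrices of `dense` and `classifier` end up in the decayed group; the four bias/LayerNorm tensors get `weight_decay=0.0`.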
/emnlp_final_experiments/sentiment-analysis/train_multi_view.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | import wandb
15 | from torch.optim.lr_scheduler import LambdaLR
16 | from torch.utils.data import DataLoader
17 | from torch.utils.data import Subset
18 | from torch.utils.data import random_split
19 | from torch.optim import Adam
20 | from tqdm import tqdm
21 | from transformers import AdamW
22 | from transformers import DistilBertConfig
23 | from transformers import DistilBertTokenizer
24 | from transformers import DistilBertForSequenceClassification
25 | from transformers import get_linear_schedule_with_warmup
26 |
27 | from datareader import MultiDomainSentimentDataset
28 | from datareader import collate_batch_transformer
29 | from metrics import MultiDatasetClassificationEvaluator
30 | from metrics import ClassificationEvaluator
31 | from metrics import acc_f1
32 |
33 | from metrics import plot_label_distribution
34 | from model import MultiTransformerClassifier
35 | from model import VanillaBert
36 | from model import *
37 |
38 |
39 | def train(
40 | model: torch.nn.Module,
41 | train_dls: List[DataLoader],
42 | optimizer: torch.optim.Optimizer,
43 | scheduler: LambdaLR,
44 | validation_evaluator: MultiDatasetClassificationEvaluator,
45 | n_epochs: int,
46 |         device: torch.device,
47 | log_interval: int = 1,
48 | patience: int = 10,
49 | model_dir: str = "wandb_local",
50 | gradient_accumulation: int = 1,
51 | domain_name: str = ''
52 | ):
53 | #best_loss = float('inf')
54 | best_acc = 0.0
55 | patience_counter = 0
56 |
57 | epoch_counter = 0
58 | total = sum(len(dl) for dl in train_dls)
59 |
60 | # Main loop
61 | while epoch_counter < n_epochs:
62 | dl_iters = [iter(dl) for dl in train_dls]
63 | dl_idx = list(range(len(dl_iters)))
64 | finished = [0] * len(dl_iters)
65 | i = 0
66 | with tqdm(total=total, desc="Training") as pbar:
67 | while sum(finished) < len(dl_iters):
68 | random.shuffle(dl_idx)
69 | for d in dl_idx:
70 | domain_dl = dl_iters[d]
71 | batches = []
72 | try:
73 | for j in range(gradient_accumulation):
74 | batches.append(next(domain_dl))
75 | except StopIteration:
76 | finished[d] = 1
77 | if len(batches) == 0:
78 | continue
79 | optimizer.zero_grad()
80 | for batch in batches:
81 | model.train()
82 | batch = tuple(t.to(device) for t in batch)
83 | input_ids = batch[0]
84 | masks = batch[1]
85 | labels = batch[2]
86 | # Testing with random domains to see if any effect
87 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
88 | domains = batch[3]
89 |
90 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True)
91 | loss = loss.mean() / gradient_accumulation
92 | if i % log_interval == 0:
93 | # wandb.log({
94 | # "Loss": loss.item(),
95 | # "alpha0": alpha[:,0].cpu(),
96 | # "alpha1": alpha[:, 1].cpu(),
97 | # "alpha2": alpha[:, 2].cpu(),
98 | # "alpha_shared": alpha[:, 3].cpu()
99 | # })
100 | wandb.log({
101 | "Loss": loss.item()
102 | })
103 |
104 | loss.backward()
105 | i += 1
106 | pbar.update(1)
107 |
108 | optimizer.step()
109 | if scheduler is not None:
110 | scheduler.step()
111 |
112 | gc.collect()
113 |
114 | # Inline evaluation
115 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
116 | print(f"Validation acc: {acc}")
117 |
118 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
119 |
120 | # Saving the best model and early stopping
121 | #if val_loss < best_loss:
122 | if acc > best_acc:
123 | best_model = model.state_dict()
124 | #best_loss = val_loss
125 | best_acc = acc
126 | #wandb.run.summary['best_validation_loss'] = best_loss
127 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
128 | patience_counter = 0
129 | # Log to wandb
130 | wandb.log({
131 | 'Validation accuracy': acc,
132 | 'Validation Precision': P,
133 | 'Validation Recall': R,
134 | 'Validation F1': F1,
135 | 'Validation loss': val_loss})
136 | else:
137 | patience_counter += 1
138 | # Stop training once we have lost patience
139 | if patience_counter == patience:
140 | break
141 |
142 | gc.collect()
143 | epoch_counter += 1
144 |
145 |
146 | if __name__ == "__main__":
147 | # Define arguments
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
150 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
151 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
152 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1)
153 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
154 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
155 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None)
156 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None)
157 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
158 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
159 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
160 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
161 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
162 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16)
163 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-5)
164 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01)
165 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int)
166 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, type=int)
167 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int)
168 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int)
169 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
170 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
171 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
172 | parser.add_argument("--ensemble_basic", help="Use averaging for the ensembling method", action="store_true")
173 | parser.add_argument("--ensemble_avg_learned", help="Use learned averaging for the ensembling method", action="store_true")
174 |
175 |
176 | args = parser.parse_args()
177 |
178 | # Set all the seeds
179 | seed = args.seed
180 | random.seed(seed)
181 | np.random.seed(seed)
182 | torch.manual_seed(seed)
183 | torch.cuda.manual_seed_all(seed)
184 | torch.backends.cudnn.deterministic = True
185 | torch.backends.cudnn.benchmark = False
186 |
187 | # See if CUDA available
188 | device = torch.device("cpu")
189 | if args.n_gpu > 0 and torch.cuda.is_available():
190 | print("Training on GPU")
191 | device = torch.device("cuda:0")
192 |
193 | # model configuration
194 | bert_model = 'distilbert-base-uncased'
195 | batch_size = args.batch_size
196 | lr = args.lr
197 | weight_decay = args.weight_decay
198 | n_epochs = args.n_epochs
199 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
200 |
201 |
202 | # wandb initialization
203 | wandb.init(
204 | project="domain-adaptation-sentiment-emnlp",
205 | name=args.run_name,
206 | config={
207 | "epochs": n_epochs,
208 | "learning_rate": lr,
209 | "warmup": args.warmup_steps,
210 | "weight_decay": weight_decay,
211 | "batch_size": batch_size,
212 | "train_split_percentage": args.train_pct,
213 | "bert_model": bert_model,
214 | "seed": seed,
215 | "tags": ",".join(args.tags)
216 | }
217 | )
218 | #wandb.watch(model)
219 | #Create save directory for model
220 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
221 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
222 |
223 | # Create the dataset
224 | all_dsets = [MultiDomainSentimentDataset(
225 | args.dataset_loc,
226 | [domain],
227 | DistilBertTokenizer.from_pretrained(bert_model)
228 | ) for domain in args.domains]
229 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)]
230 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
231 |
232 | accs = []
233 | Ps = []
234 | Rs = []
235 | F1s = []
236 | # Store labels and logits for individual splits for micro F1
237 | labels_all = []
238 | logits_all = []
239 |
240 | for i in range(len(all_dsets)):
241 | domain = args.domains[i]
242 | test_dset = all_dsets[i]
243 | # Override the domain IDs
244 | k = 0
245 | for j in range(len(all_dsets)):
246 | if j != i:
247 | all_dsets[j].set_domain_id(k)
248 | k += 1
249 | test_dset.set_domain_id(k)
250 | # For test
251 | #all_dsets = [all_dsets[0], all_dsets[2]]
252 |
253 | # Split the data
254 | if args.indices_dir is None:
255 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
256 | for j in range(len(all_dsets)) if j != i]
257 | else:
258 | # load the indices
259 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
260 | subset_indices = defaultdict(lambda: [[], []])
261 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
262 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
263 | for l in f:
264 | vals = l.strip().split(',')
265 | subset_indices[int(vals[0])][0].append(int(vals[1]))
266 | for l in g:
267 | vals = l.strip().split(',')
268 | subset_indices[int(vals[0])][1].append(int(vals[1]))
269 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in
270 | subset_indices]
271 |
272 | train_dls = [DataLoader(
273 | subset[0],
274 | batch_size=batch_size,
275 | shuffle=True,
276 | collate_fn=collate_batch_transformer
277 | ) for subset in subsets]
278 |
279 | val_ds = [subset[1] for subset in subsets]
280 | # for vds in val_ds:
281 | # print(vds.indices)
282 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
283 |
284 | # Create the model
285 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
286 | multi_xformer = MultiDistilBertClassifier(
287 | bert_model,
288 | bert_config,
289 | n_domains=len(train_dls)
290 | ).to(device)
291 | if args.pretrained_multi_xformer is not None:
292 | multi_xformer.load_state_dict(torch.load(f"{args.pretrained_multi_xformer}/model_{domain}.pth"))
293 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(multi_xformer)
294 | print(f"Validation acc multi-xformer: {acc}")
295 |
296 |
297 | shared_bert = VanillaBert(bert).to(device)
298 | if args.pretrained_bert is not None:
299 | shared_bert.load_state_dict(torch.load(f"{args.pretrained_bert}/model_{domain}.pth"))
300 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(shared_bert)
301 | print(f"Validation acc shared bert: {acc}")
302 |
303 | if args.ensemble_basic:
304 | model_class = MultiViewTransformerNetworkAveraging
305 | elif args.ensemble_avg_learned:
306 | model_class = MultiViewTransformerNetworkLearnedAveraging
307 | else:
308 | model_class = MultiViewTransformerNetworkProbabilities
309 |
310 | model = torch.nn.DataParallel(model_class(
311 | multi_xformer,
312 | shared_bert
313 | )).to(device)
314 |
315 | # (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
316 | # print(f"Validation acc starting: {acc}")
317 |
318 | # Create the optimizer
319 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
320 | optimizer_grouped_parameters = [
321 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
322 | 'weight_decay': weight_decay},
323 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
324 | ]
325 |
326 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
327 | scheduler = get_linear_schedule_with_warmup(
328 | optimizer,
329 | args.warmup_steps,
330 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
331 | )
332 |
333 | # Train
334 | train(
335 | model,
336 | train_dls,
337 | optimizer,
338 | scheduler,
339 | validation_evaluator,
340 | n_epochs,
341 | device,
342 | args.log_interval,
343 | model_dir=args.model_dir,
344 | gradient_accumulation=args.gradient_accumulation,
345 | domain_name=domain
346 | )
347 | # Load the best weights
348 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
349 | if args.ensemble_avg_learned:
350 | weights = model.module.alpha_params.cpu().detach().numpy()
351 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f:
352 | f.write(str(weights))
353 |
354 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
355 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
356 | model,
357 | plot_callbacks=[plot_label_distribution],
358 | return_labels_logits=True,
359 | return_votes=True
360 | )
361 | print(f"{domain} F1: {F1}")
362 | print(f"{domain} Accuracy: {acc}")
363 | print()
364 |
365 | wandb.run.summary[f"{domain}-P"] = P
366 | wandb.run.summary[f"{domain}-R"] = R
367 | wandb.run.summary[f"{domain}-F1"] = F1
368 | wandb.run.summary[f"{domain}-Acc"] = acc
369 | Ps.append(P)
370 | Rs.append(R)
371 | F1s.append(F1)
372 | accs.append(acc)
373 | labels_all.extend(labels)
374 | logits_all.extend(logits)
375 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
376 | for p, l in zip(np.argmax(logits, axis=-1), labels):
377 | f.write(f'{domain}\t{p}\t{l}\n')
378 |
379 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
380 | # Add to wandb
381 |     wandb.run.summary['test-loss'] = loss  # note: loss comes from the final test domain's evaluator
382 |     wandb.run.summary['test-micro-acc'] = acc
383 |     wandb.run.summary['test-micro-P'] = P
384 |     wandb.run.summary['test-micro-R'] = R
385 |     wandb.run.summary['test-micro-F1'] = F1
386 |
387 |     wandb.run.summary['test-macro-acc'] = sum(accs) / len(accs)
388 |     wandb.run.summary['test-macro-P'] = sum(Ps) / len(Ps)
389 |     wandb.run.summary['test-macro-R'] = sum(Rs) / len(Rs)
390 |     wandb.run.summary['test-macro-F1'] = sum(F1s) / len(F1s)
391 |
392 | # wandb.log({f"label-distribution-test-{i}": plots[0]})
393 |
--------------------------------------------------------------------------------
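The `train()` loop in these scripts interleaves batches from several per-domain DataLoaders: each pass it shuffles the loader order, pulls up to `gradient_accumulation` batches from each loader, and keeps going until every loader is exhausted (partial batch groups collected when a loader runs out are still used). The sampling logic can be isolated from the model and optimizer as a small sketch; the toy "loaders" here are plain lists standing in for DataLoaders.

```python
import random

def round_robin_batches(loaders, gradient_accumulation=1, seed=0):
    """Yield (loader_index, batches) with up to `gradient_accumulation`
    batches per group, cycling over the loaders in a freshly shuffled
    order each pass until all loaders are exhausted."""
    rng = random.Random(seed)
    iters = [iter(dl) for dl in loaders]
    idx = list(range(len(iters)))
    finished = [False] * len(iters)
    while not all(finished):
        rng.shuffle(idx)
        for d in idx:
            batches = []
            try:
                for _ in range(gradient_accumulation):
                    batches.append(next(iters[d]))
            except StopIteration:
                finished[d] = True
            if batches:  # a partial group at the end of a loader still counts
                yield d, batches

# Three unevenly sized "domains"; every batch is consumed exactly once.
loaders = [[f"a{i}" for i in range(3)],
           [f"b{i}" for i in range(5)],
           [f"c{i}" for i in range(2)]]
seen = [b for _, group in round_robin_batches(loaders, gradient_accumulation=2)
        for b in group]
```

Because the order is reshuffled on every pass, larger domains dominate the tail of an epoch once smaller ones are exhausted, which is exactly the behaviour of the training loop above.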
/emnlp_final_experiments/sentiment-analysis/train_multi_view_domain_adversarial.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | import wandb
15 | from torch.optim.lr_scheduler import LambdaLR
16 | from torch.utils.data import DataLoader
17 | from torch.utils.data import Subset
18 | from torch.utils.data import random_split
19 | from torch.optim import Adam
20 | from tqdm import tqdm
21 | from transformers import AdamW
22 | from transformers import DistilBertConfig
23 | from transformers import DistilBertTokenizer
24 | from transformers import DistilBertForSequenceClassification
25 | from transformers import get_linear_schedule_with_warmup
26 |
27 | from datareader import MultiDomainSentimentDataset
28 | from datareader import collate_batch_transformer
29 | from metrics import MultiDatasetClassificationEvaluator
30 | from metrics import ClassificationEvaluator
31 | from metrics import acc_f1
32 |
33 | from metrics import plot_label_distribution
34 | from model import MultiTransformerClassifier
35 | from model import VanillaBert
36 | from model import *
37 |
38 |
39 | def train(
40 | model: torch.nn.Module,
41 | train_dls: List[DataLoader],
42 | optimizer: torch.optim.Optimizer,
43 | scheduler: LambdaLR,
44 | validation_evaluator: MultiDatasetClassificationEvaluator,
45 | n_epochs: int,
46 |         device: torch.device,
47 | log_interval: int = 1,
48 | patience: int = 10,
49 | model_dir: str = "wandb_local",
50 | gradient_accumulation: int = 1,
51 | domain_name: str = ''
52 | ):
53 | #best_loss = float('inf')
54 | best_acc = 0.0
55 | patience_counter = 0
56 |
57 | epoch_counter = 0
58 | total = sum(len(dl) for dl in train_dls)
59 |
60 | # Main loop
61 | while epoch_counter < n_epochs:
62 | dl_iters = [iter(dl) for dl in train_dls]
63 | dl_idx = list(range(len(dl_iters)))
64 | finished = [0] * len(dl_iters)
65 | i = 0
66 | with tqdm(total=total, desc="Training") as pbar:
67 | while sum(finished) < len(dl_iters):
68 | random.shuffle(dl_idx)
69 | for d in dl_idx:
70 | domain_dl = dl_iters[d]
71 | batches = []
72 | try:
73 | for j in range(gradient_accumulation):
74 | batches.append(next(domain_dl))
75 | except StopIteration:
76 | finished[d] = 1
77 | if len(batches) == 0:
78 | continue
79 | optimizer.zero_grad()
80 | for batch in batches:
81 | model.train()
82 | batch = tuple(t.to(device) for t in batch)
83 | input_ids = batch[0]
84 | masks = batch[1]
85 | labels = batch[2]
86 |                         # Null the labels if it's the test data
87 | if d == len(train_dls) - 1:
88 | labels = None
89 | # Testing with random domains to see if any effect
90 | #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
91 | domains = batch[3]
92 |
93 | loss, logits, alpha = model(input_ids, attention_mask=masks, domains=domains, labels=labels, ret_alpha = True)
94 | loss = loss.mean() / gradient_accumulation
95 | if i % log_interval == 0:
96 | # wandb.log({
97 | # "Loss": loss.item(),
98 | # "alpha0": alpha[:,0].cpu(),
99 | # "alpha1": alpha[:, 1].cpu(),
100 | # "alpha2": alpha[:, 2].cpu(),
101 | # "alpha_shared": alpha[:, 3].cpu()
102 | # })
103 | wandb.log({
104 | "Loss": loss.item()
105 | })
106 |
107 | loss.backward()
108 | i += 1
109 | pbar.update(1)
110 |
111 | optimizer.step()
112 | if scheduler is not None:
113 | scheduler.step()
114 |
115 | gc.collect()
116 |
117 | # Inline evaluation
118 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
119 | print(f"Validation acc: {acc}")
120 |
121 | #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
122 |
123 | # Saving the best model and early stopping
124 | #if val_loss < best_loss:
125 | if acc > best_acc:
126 | best_model = model.state_dict()
127 | #best_loss = val_loss
128 | best_acc = acc
129 | #wandb.run.summary['best_validation_loss'] = best_loss
130 | torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
131 | patience_counter = 0
132 | # Log to wandb
133 | wandb.log({
134 | 'Validation accuracy': acc,
135 | 'Validation Precision': P,
136 | 'Validation Recall': R,
137 | 'Validation F1': F1,
138 | 'Validation loss': val_loss})
139 | else:
140 | patience_counter += 1
141 | # Stop training once we have lost patience
142 | if patience_counter == patience:
143 | break
144 |
145 | gc.collect()
146 | epoch_counter += 1
147 |
148 |
149 | if __name__ == "__main__":
150 | # Define arguments
151 | parser = argparse.ArgumentParser()
152 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
153 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
154 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
155 | parser.add_argument("--log_interval", help="Number of steps to take between logging steps", type=int, default=1)
156 | parser.add_argument("--warmup_steps", help="Number of steps to warm up Adam", type=int, default=200)
157 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
158 | parser.add_argument("--pretrained_bert", help="Directory with weights to initialize the shared model with", type=str, default=None)
159 | parser.add_argument("--pretrained_multi_xformer", help="Directory with weights to initialize the domain specific models", type=str, default=None)
160 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
161 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
162 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
163 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
164 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
165 | parser.add_argument("--batch_size", help="The batch size", type=int, default=16)
166 | parser.add_argument("--lr", help="Learning rate", type=float, default=1e-5)
167 | parser.add_argument("--weight_decay", help="l2 reg", type=float, default=0.01)
168 | parser.add_argument("--n_heads", help="Number of transformer heads", default=6, type=int)
169 | parser.add_argument("--n_layers", help="Number of transformer layers", default=6, type=int)
170 | parser.add_argument("--d_model", help="Transformer model size", default=768, type=int)
171 | parser.add_argument("--ff_dim", help="Intermediate feedforward size", default=2048, type=int)
172 | parser.add_argument("--gradient_accumulation", help="Number of gradient accumulation steps", default=1, type=int)
173 | parser.add_argument("--model", help="Name of the model to run", default="VanillaBert")
174 | parser.add_argument("--supervision_layer", help="The layer at which to use domain adversarial supervision", default=12, type=int)
175 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
176 |
177 | args = parser.parse_args()
178 |
179 | # Set all the seeds
180 | seed = args.seed
181 | random.seed(seed)
182 | np.random.seed(seed)
183 | torch.manual_seed(seed)
184 | torch.cuda.manual_seed_all(seed)
185 | torch.backends.cudnn.deterministic = True
186 | torch.backends.cudnn.benchmark = False
187 |
188 | # See if CUDA available
189 | device = torch.device("cpu")
190 | if args.n_gpu > 0 and torch.cuda.is_available():
191 | print("Training on GPU")
192 | device = torch.device("cuda:0")
193 |
194 | # model configuration
195 | bert_model = 'distilbert-base-uncased'
196 |     ##################
197 |     # Override for now: use a smaller batch size and compensate with gradient accumulation
198 |     batch_size = 4  # args.batch_size
199 |     args.gradient_accumulation = 2
200 |     ##################
201 | lr = args.lr
202 | weight_decay = args.weight_decay
203 | n_epochs = args.n_epochs
204 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
205 |
206 | # wandb initialization
207 | wandb.init(
208 | project="domain-adaptation-sentiment-emnlp",
209 | name=args.run_name,
210 | config={
211 | "epochs": n_epochs,
212 | "learning_rate": lr,
213 | "warmup": args.warmup_steps,
214 | "weight_decay": weight_decay,
215 | "batch_size": batch_size,
216 | "train_split_percentage": args.train_pct,
217 | "bert_model": bert_model,
218 | "seed": seed,
219 | "tags": ",".join(args.tags)
220 | }
221 | )
222 | #wandb.watch(model)
223 | #Create save directory for model
224 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
225 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
226 |
227 | # Create the dataset
228 | all_dsets = [MultiDomainSentimentDataset(
229 | args.dataset_loc,
230 | [domain],
231 | DistilBertTokenizer.from_pretrained(bert_model)
232 | ) for domain in args.domains]
233 | train_sizes = [int(len(dset) * args.train_pct) for j, dset in enumerate(all_dsets)]
234 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
235 |
236 | accs = []
237 | Ps = []
238 | Rs = []
239 | F1s = []
240 | # Store labels and logits for individual splits for micro F1
241 | labels_all = []
242 | logits_all = []
243 |
244 | for i in range(len(all_dsets)):
245 | domain = args.domains[i]
246 | test_dset = all_dsets[i]
247 | # Override the domain IDs
248 | k = 0
249 | for j in range(len(all_dsets)):
250 | if j != i:
251 | all_dsets[j].set_domain_id(k)
252 | k += 1
253 | test_dset.set_domain_id(k)
254 | # For test
255 | #all_dsets = [all_dsets[0], all_dsets[2]]
256 |
257 | # Split the data
258 | if args.indices_dir is None:
259 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
260 | for j in range(len(all_dsets)) if j != i]
261 | else:
262 | # load the indices
263 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
264 | subset_indices = defaultdict(lambda: [[], []])
265 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
266 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
267 | for l in f:
268 | vals = l.strip().split(',')
269 | subset_indices[int(vals[0])][0].append(int(vals[1]))
270 | for l in g:
271 | vals = l.strip().split(',')
272 | subset_indices[int(vals[0])][1].append(int(vals[1]))
273 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in
274 | subset_indices]
275 |
276 | train_dls = [DataLoader(
277 | subset[0],
278 | batch_size=batch_size,
279 | shuffle=True,
280 | collate_fn=collate_batch_transformer
281 | ) for subset in subsets]
282 | # Add test data for domain adversarial training
283 | train_dls += [DataLoader(
284 | test_dset,
285 | batch_size=batch_size,
286 | shuffle=True,
287 | collate_fn=collate_batch_transformer
288 | )]
289 |
290 | val_ds = [subset[1] for subset in subsets]
291 | # for vds in val_ds:
292 | # print(vds.indices)
293 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
294 |
295 | # Create the model
296 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
297 | multi_xformer = MultiDistilBertClassifier(
298 | bert_model,
299 | bert_config,
300 | n_domains=len(train_dls) - 1
301 | ).to(device)
302 | if args.pretrained_multi_xformer is not None:
303 | multi_xformer.load_state_dict(torch.load(f"{args.pretrained_multi_xformer}/model_{domain}.pth"))
304 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(multi_xformer)
305 | print(f"Validation acc multi-xformer: {acc}")
306 |
307 | shared_bert = VanillaBert(bert).to(device)
308 | if args.pretrained_bert is not None:
309 | shared_bert.load_state_dict(torch.load(f"{args.pretrained_bert}/model_{domain}.pth"))
310 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(shared_bert)
311 | print(f"Validation acc shared bert: {acc}")
312 |
313 | model = torch.nn.DataParallel(MultiViewTransformerNetworkProbabilitiesAdversarial(
314 | multi_xformer,
315 | shared_bert,
316 | supervision_layer=args.supervision_layer
317 | )).to(device)
318 | # (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
319 | # print(f"Validation acc starting: {acc}")
320 |
321 | # Create the optimizer
322 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
323 | optimizer_grouped_parameters = [
324 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
325 | 'weight_decay': weight_decay},
326 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
327 | ]
328 |
329 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
330 | scheduler = get_linear_schedule_with_warmup(
331 | optimizer,
332 | args.warmup_steps,
333 | n_epochs * sum([len(train_dl) for train_dl in train_dls])
334 | )
335 |
336 | # Train
337 | train(
338 | model,
339 | train_dls,
340 | optimizer,
341 | scheduler,
342 | validation_evaluator,
343 | n_epochs,
344 | device,
345 | args.log_interval,
346 | model_dir=args.model_dir,
347 | gradient_accumulation=args.gradient_accumulation,
348 | domain_name=domain
349 | )
350 | # Load the best weights
351 | model.load_state_dict(torch.load(f'{args.model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth'))
352 |
353 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
354 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
355 | model,
356 | plot_callbacks=[plot_label_distribution],
357 | return_labels_logits=True,
358 | return_votes=True
359 | )
360 | print(f"{domain} F1: {F1}")
361 | print(f"{domain} Accuracy: {acc}")
362 | print()
363 |
364 | wandb.run.summary[f"{domain}-P"] = P
365 | wandb.run.summary[f"{domain}-R"] = R
366 | wandb.run.summary[f"{domain}-F1"] = F1
367 | wandb.run.summary[f"{domain}-Acc"] = acc
368 | Ps.append(P)
369 | Rs.append(R)
370 | F1s.append(F1)
371 | accs.append(acc)
372 | labels_all.extend(labels)
373 | logits_all.extend(logits)
374 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
375 | for p, l in zip(np.argmax(logits, axis=-1), labels):
376 | f.write(f'{domain}\t{p}\t{l}\n')
377 |
378 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
379 | # Add to wandb
380 |     wandb.run.summary['test-loss'] = loss  # note: loss comes from the final test domain's evaluator
381 |     wandb.run.summary['test-micro-acc'] = acc
382 |     wandb.run.summary['test-micro-P'] = P
383 |     wandb.run.summary['test-micro-R'] = R
384 |     wandb.run.summary['test-micro-F1'] = F1
385 |
386 |     wandb.run.summary['test-macro-acc'] = sum(accs) / len(accs)
387 |     wandb.run.summary['test-macro-P'] = sum(Ps) / len(Ps)
388 |     wandb.run.summary['test-macro-R'] = sum(Rs) / len(Rs)
389 |     wandb.run.summary['test-macro-F1'] = sum(F1s) / len(F1s)
390 |
391 | # wandb.log({f"label-distribution-test-{i}": plots[0]})
392 |
--------------------------------------------------------------------------------
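The domain-adversarial script above appends the unlabelled test-domain loader to `train_dls` and passes a `supervision_layer` to model classes defined in `model.py` (not shown in this excerpt). A standard mechanism for this kind of domain-adversarial supervision is a gradient reversal layer (the DANN trick): the identity in the forward pass, with the gradient negated and scaled on the way back, so the feature extractor is pushed to confuse the domain classifier. The sketch below shows that mechanism in isolation; it is an illustration of the standard technique, not necessarily the exact implementation in `model.py`.

```python
import torch

class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient
    by -lambd in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for lambd itself, hence the trailing None.
        return -ctx.lambd * grad_output, None

# Gradients flow through, but reversed and scaled by lambd = 0.5.
x = torch.ones(3, requires_grad=True)
y = GradientReversal.apply(x, 0.5).sum()
y.backward()
```

In a domain-adversarial model, the domain classifier sits on top of this layer at the chosen supervision layer: its own loss is minimised as usual, while the reversed gradient trains the shared encoder to make the domains indistinguishable.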
/emnlp_final_experiments/sentiment-analysis/train_multi_view_selective_weighting.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | from typing import AnyStr
6 | from typing import List
7 | import ipdb
8 | import krippendorff
9 | from collections import defaultdict
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | import wandb
15 | from torch.optim.lr_scheduler import LambdaLR
16 | from torch.utils.data import DataLoader
17 | from torch.utils.data import Subset
18 | from torch.utils.data import random_split
19 | from torch.optim import Adam
20 | from tqdm import tqdm
21 | from transformers import AdamW
22 | from transformers import DistilBertConfig
23 | from transformers import DistilBertTokenizer
24 | from transformers import DistilBertForSequenceClassification
25 | from transformers import get_linear_schedule_with_warmup
26 |
27 | from datareader import MultiDomainSentimentDataset
28 | from datareader import collate_batch_transformer
29 | from metrics import MultiDatasetClassificationEvaluator
30 | from metrics import ClassificationEvaluator
31 | from metrics import acc_f1
32 |
33 | from metrics import plot_label_distribution
34 | from model import MultiTransformerClassifier
35 | from model import VanillaBert
36 | from model import *
37 | from sklearn.model_selection import ParameterSampler
38 |
39 |
40 | def attention_grid_search(
41 | model: torch.nn.Module,
42 | validation_evaluator: MultiDatasetClassificationEvaluator,
43 | n_epochs: int,
44 | seed: int
45 | ):
46 | best_weights = model.module.weights
47 | # initial
48 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
49 | best_acc = acc
 50 |     print(f"Initial validation acc: {acc}")
51 | # Create the grid search
 52 |     param_dict = {1: list(range(0, 11)), 2: list(range(0, 11)), 3: list(range(0, 11)), 4: list(range(0, 11))}
53 | grid_search_params = ParameterSampler(param_dict, n_iter=n_epochs, random_state=seed)
54 | for d in grid_search_params:
55 | weights = [v for k,v in sorted(d.items(), key=lambda x:x[0])]
56 | weights = np.array(weights) / sum(weights)
57 | model.module.weights = weights
58 | # Inline evaluation
59 | (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
60 | print(f"Weights: {weights}\tValidation acc: {acc}")
61 |
62 | if acc > best_acc:
63 | best_weights = weights
64 | best_acc = acc
65 | # Log to wandb
66 | wandb.log({
67 | 'Validation accuracy': acc,
68 | 'Validation Precision': P,
69 | 'Validation Recall': R,
70 | 'Validation F1': F1,
71 | 'Validation loss': val_loss})
72 |
73 | gc.collect()
74 |
75 | return best_weights
76 |
77 |
78 | if __name__ == "__main__":
79 | # Define arguments
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument("--dataset_loc", help="Root directory of the dataset", required=True, type=str)
82 | parser.add_argument("--train_pct", help="Percentage of data to use for training", type=float, default=0.8)
83 | parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
84 | parser.add_argument("--n_epochs", help="Number of epochs", type=int, default=2)
85 | parser.add_argument("--pretrained_model", help="The pretrained averaging model", type=str, default=None)
86 | parser.add_argument("--domains", nargs='+', help='A list of domains to use for training', default=[])
87 | parser.add_argument("--seed", type=int, help="Random seed", default=1000)
88 | parser.add_argument("--run_name", type=str, help="A name for the run", default="pheme-baseline")
89 | parser.add_argument("--model_dir", help="Where to store the saved model", default="wandb_local", type=str)
90 | parser.add_argument("--tags", nargs='+', help='A list of tags for this run', default=[])
91 | parser.add_argument("--indices_dir", help="If standard splits are being used", type=str, default=None)
92 |
93 | args = parser.parse_args()
94 |
95 | # Set all the seeds
96 | seed = args.seed
97 | random.seed(seed)
98 | np.random.seed(seed)
99 | torch.manual_seed(seed)
100 | torch.cuda.manual_seed_all(seed)
101 | torch.backends.cudnn.deterministic = True
102 | torch.backends.cudnn.benchmark = False
103 |
104 | # See if CUDA available
105 | device = torch.device("cpu")
106 | if args.n_gpu > 0 and torch.cuda.is_available():
107 | print("Training on GPU")
108 | device = torch.device("cuda:0")
109 |
110 | # model configuration
111 | bert_model = 'distilbert-base-uncased'
112 | n_epochs = args.n_epochs
113 | bert_config = DistilBertConfig.from_pretrained(bert_model, num_labels=2, output_hidden_states=True)
114 |
115 |
116 | # wandb initialization
117 | wandb.init(
118 | project="domain-adaptation-sentiment-emnlp",
119 | name=args.run_name,
120 | config={
121 | "epochs": n_epochs,
122 | "bert_model": bert_model,
123 | "seed": seed,
124 | "tags": ",".join(args.tags)
125 | }
126 | )
127 | #wandb.watch(model)
128 | #Create save directory for model
129 | if not os.path.exists(f"{args.model_dir}/{Path(wandb.run.dir).name}"):
130 | os.makedirs(f"{args.model_dir}/{Path(wandb.run.dir).name}")
131 |
132 | # Create the dataset
133 | all_dsets = [MultiDomainSentimentDataset(
134 | args.dataset_loc,
135 | [domain],
136 | DistilBertTokenizer.from_pretrained(bert_model)
137 | ) for domain in args.domains]
138 |     train_sizes = [int(len(dset) * args.train_pct) for dset in all_dsets]
139 | val_sizes = [len(all_dsets[j]) - train_sizes[j] for j in range(len(train_sizes))]
140 |
141 | accs = []
142 | Ps = []
143 | Rs = []
144 | F1s = []
145 | # Store labels and logits for individual splits for micro F1
146 | labels_all = []
147 | logits_all = []
148 |
149 | for i in range(len(all_dsets)):
150 | domain = args.domains[i]
151 | test_dset = all_dsets[i]
152 | # Override the domain IDs
153 | k = 0
154 | for j in range(len(all_dsets)):
155 | if j != i:
156 | all_dsets[j].set_domain_id(k)
157 | k += 1
158 | test_dset.set_domain_id(k)
159 | # For test
160 | #all_dsets = [all_dsets[0], all_dsets[2]]
161 |
162 | # Split the data
163 | if args.indices_dir is None:
164 | subsets = [random_split(all_dsets[j], [train_sizes[j], val_sizes[j]])
165 | for j in range(len(all_dsets)) if j != i]
166 | else:
167 | # load the indices
168 | dset_choices = [all_dsets[j] for j in range(len(all_dsets)) if j != i]
169 | subset_indices = defaultdict(lambda: [[], []])
170 | with open(f'{args.indices_dir}/train_idx_{domain}.txt') as f, \
171 | open(f'{args.indices_dir}/val_idx_{domain}.txt') as g:
172 | for l in f:
173 | vals = l.strip().split(',')
174 | subset_indices[int(vals[0])][0].append(int(vals[1]))
175 | for l in g:
176 | vals = l.strip().split(',')
177 | subset_indices[int(vals[0])][1].append(int(vals[1]))
178 | subsets = [[Subset(dset_choices[d], subset_indices[d][0]), Subset(dset_choices[d], subset_indices[d][1])] for d in
179 | subset_indices]
180 |
181 | train_dls = [DataLoader(
182 | subset[0],
183 | batch_size=8,
184 | shuffle=True,
185 | collate_fn=collate_batch_transformer
186 | ) for subset in subsets]
187 |
188 | val_ds = [subset[1] for subset in subsets]
189 | # for vds in val_ds:
190 | # print(vds.indices)
191 | validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)
192 |
193 | # Create the model
194 | bert = DistilBertForSequenceClassification.from_pretrained(bert_model, config=bert_config).to(device)
195 | multi_xformer = MultiDistilBertClassifier(
196 | bert_model,
197 | bert_config,
198 | n_domains=len(train_dls)
199 | ).to(device)
200 |
201 | shared_bert = VanillaBert(bert).to(device)
202 |
203 | model = torch.nn.DataParallel(MultiViewTransformerNetworkSelectiveWeight(
204 | multi_xformer,
205 | shared_bert
206 | )).to(device)
207 | # Load the best weights
208 | model.load_state_dict(torch.load(f'{args.pretrained_model}/model_{domain}.pth'))
209 |
210 | # Calculate the best attention weights with a grid search
211 | weights = attention_grid_search(
212 | model,
213 | validation_evaluator,
214 | n_epochs,
215 | seed
216 | )
217 | model.module.weights = weights
218 | print(f"Best weights: {weights}")
219 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt', 'wt') as f:
220 | f.write(str(weights))
221 |
222 | evaluator = ClassificationEvaluator(test_dset, device, use_domain=False)
223 | (loss, acc, P, R, F1), plots, (labels, logits), votes = evaluator.evaluate(
224 | model,
225 | plot_callbacks=[plot_label_distribution],
226 | return_labels_logits=True,
227 | return_votes=True
228 | )
229 | print(f"{domain} F1: {F1}")
230 | print(f"{domain} Accuracy: {acc}")
231 | print()
232 |
233 | wandb.run.summary[f"{domain}-P"] = P
234 | wandb.run.summary[f"{domain}-R"] = R
235 | wandb.run.summary[f"{domain}-F1"] = F1
236 | wandb.run.summary[f"{domain}-Acc"] = acc
237 | Ps.append(P)
238 | Rs.append(R)
239 | F1s.append(F1)
240 | accs.append(acc)
241 | labels_all.extend(labels)
242 | logits_all.extend(logits)
243 | with open(f'{args.model_dir}/{Path(wandb.run.dir).name}/pred_lab.txt', 'a+') as f:
244 | for p, l in zip(np.argmax(logits, axis=-1), labels):
245 | f.write(f'{domain}\t{p}\t{l}\n')
246 |
247 | acc, P, R, F1 = acc_f1(logits_all, labels_all)
248 | # Add to wandb
249 | wandb.run.summary[f'test-loss'] = loss
250 | wandb.run.summary[f'test-micro-acc'] = acc
251 | wandb.run.summary[f'test-micro-P'] = P
252 | wandb.run.summary[f'test-micro-R'] = R
253 | wandb.run.summary[f'test-micro-F1'] = F1
254 |
255 | wandb.run.summary[f'test-macro-acc'] = sum(accs) / len(accs)
256 | wandb.run.summary[f'test-macro-P'] = sum(Ps) / len(Ps)
257 | wandb.run.summary[f'test-macro-R'] = sum(Rs) / len(Rs)
258 | wandb.run.summary[f'test-macro-F1'] = sum(F1s) / len(F1s)
259 |
260 | # wandb.log({f"label-distribution-test-{i}": plots[0]})
261 |
--------------------------------------------------------------------------------
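The `attention_grid_search` routine in the file above uses `sklearn.model_selection.ParameterSampler` to draw random integer weights in [0, 10] for each of the four domain experts, normalizes them to sum to one, and keeps whichever assignment scores best on the validation set. A minimal pure-Python sketch of that loop (the `evaluate` callable and toy objective are hypothetical stand-ins for `validation_evaluator.evaluate`; the all-zero guard is an added robustness tweak):

```python
import random

def grid_search_weights(evaluate, n_experts=4, n_iter=50, seed=1000):
    """Randomly sample integer weights in [0, 10] per expert, normalize
    them to sum to one, and keep the best-scoring assignment."""
    rng = random.Random(seed)
    best_weights, best_acc = None, float('-inf')
    for _ in range(n_iter):
        raw = [rng.randint(0, 10) for _ in range(n_experts)]
        if sum(raw) == 0:
            continue  # an all-zero draw cannot be normalized
        weights = [w / sum(raw) for w in raw]
        acc = evaluate(weights)  # stand-in for a validation pass
        if acc > best_acc:
            best_weights, best_acc = weights, acc
    return best_weights, best_acc

# Toy objective: "accuracy" is just the weight given to expert 0
best_w, best_a = grid_search_weights(lambda w: w[0])
```

The normalized weights are then assigned to `model.module.weights`, exactly as the script does with the winning sample.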
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | from typing import Tuple, List, Callable, AnyStr
6 | from tqdm import tqdm
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | from sklearn.metrics import precision_recall_fscore_support
10 | from collections import Counter
11 | from scipy.stats import entropy
12 | import ipdb
13 |
14 | from datareader import collate_batch_transformer
15 |
16 |
17 | def accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
18 | return np.sum(np.argmax(logits, axis=-1) == labels).astype(np.float32) / float(labels.shape[0])
19 |
20 |
21 | def acc_f1(logits: List, labels: List) -> Tuple[float, float, float, float]:
22 | logits = np.asarray(logits).reshape(-1, len(logits[0]))
23 | labels = np.asarray(labels).reshape(-1)
24 | acc = accuracy(logits, labels)
25 | average = 'binary' if logits.shape[1] == 2 else None
26 | P, R, F1, _ = precision_recall_fscore_support(labels, np.argmax(logits, axis=-1), average=average)
27 | return acc,P,R,F1
28 |
29 |
30 | def plot_label_distribution(labels: np.ndarray, logits: np.ndarray) -> matplotlib.figure.Figure:
31 | """ Plots the distribution of labels in the prediction
32 |
33 | :param labels: Gold labels
34 | :param logits: Logits from the model
 35 |     :return: The matplotlib figure with the prediction distribution
36 | """
37 | predictions = np.argmax(logits, axis=-1)
38 | labs, counts = zip(*list(sorted(Counter(predictions).items(), key=lambda x: x[0])))
39 |
40 | fig, ax = plt.subplots(figsize=(12, 9))
41 | ax.bar(labs, counts, width=0.2)
 42 |     ax.set_xticks(labs); ax.set_xticklabels([str(l) for l in labs])  # this matplotlib version's set_xticks takes no labels argument
43 | ax.set_ylabel('Count')
44 | ax.set_xlabel("Label")
45 | ax.set_title("Prediction distribution")
46 | return fig
47 |
48 |
49 | class ClassificationEvaluator:
 50 |     """Wrapper to evaluate a classification model on a single dataset
51 |
52 | """
53 |
54 | def __init__(self, dataset: Dataset, device: torch.device, use_domain: bool = True, use_labels: bool = True):
55 | self.dataset = dataset
56 | self.dataloader = DataLoader(
57 | dataset,
58 | batch_size=8,
59 | collate_fn=collate_batch_transformer
60 | )
61 | self.device = device
62 | self.stored_labels = []
63 | self.stored_logits = []
64 | self.use_domain = use_domain
65 | self.use_labels = use_labels
66 |
67 | def micro_f1(self) -> Tuple[float, float, float, float]:
68 | labels_all = self.stored_labels
69 | logits_all = self.stored_logits
70 |
71 | logits_all = np.asarray(logits_all).reshape(-1, len(logits_all[0]))
72 | labels_all = np.asarray(labels_all).reshape(-1)
73 | acc = accuracy(logits_all, labels_all)
74 | P, R, F1, _ = precision_recall_fscore_support(labels_all, np.argmax(logits_all, axis=-1), average='binary')
75 |
76 | return acc, P, R, F1
77 |
78 | def evaluate(
79 | self,
80 | model: torch.nn.Module,
81 | plot_callbacks: List[Callable] = [],
82 | return_labels_logits: bool = False,
83 | return_votes: bool = False
84 | ) -> Tuple:
85 | """Collect evaluation metrics on this dataset
86 |
87 | :param model: The pytorch model to evaluate
88 | :param plot_callbacks: Optional function callbacks for plotting various things
89 | :return: (Loss, Accuracy, Precision, Recall, F1)
90 | """
91 | model.eval()
92 | with torch.no_grad():
93 | labels_all = []
94 | logits_all = []
95 | losses_all = []
96 | votes_all = []
97 | for batch in tqdm(self.dataloader, desc="Evaluation"):
98 | batch = tuple(t.to(self.device) for t in batch)
99 | input_ids = batch[0]
100 | masks = batch[1]
101 | labels = batch[2]
102 | domains = batch[3] if self.use_domain else None
103 | if self.use_labels:
104 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels)
105 | if len(loss.size()) > 0:
106 | loss = loss.mean()
107 | else:
108 | (logits,) = model(input_ids, attention_mask=masks, domains=domains)
109 | loss = torch.FloatTensor([-1.])
110 | labels_all.extend(list(labels.detach().cpu().numpy()))
111 | logits_all.extend(list(logits.detach().cpu().numpy()))
112 | losses_all.append(loss.item())
113 | if hasattr(model, 'votes'):
114 | votes_all.extend(model.votes.detach().cpu().numpy())
115 |
116 | if not self.use_labels:
117 | return logits_all
118 |
119 | acc,P,R,F1 = acc_f1(logits_all, labels_all)
120 | loss = sum(losses_all) / len(losses_all)
121 |
122 | # Plotting
123 | plots = []
124 | for f in plot_callbacks:
125 | plots.append(f(labels_all, logits_all))
126 |
127 | ret_vals = (loss, acc, P, R, F1), plots
128 | if return_labels_logits:
129 | ret_vals = ret_vals + ((labels_all, logits_all),)
130 |
131 | if return_votes:
132 | if len(votes_all) > 0:
133 | ret_vals += (votes_all,)
134 | else:
135 | ret_vals += (list(np.argmax(np.asarray(logits_all), axis=1)),)
136 |
137 | return ret_vals
138 |
139 | class MultiDatasetClassificationEvaluator:
140 |     """Wrapper to evaluate a classification model across multiple datasets
141 |
142 | """
143 |
144 | def __init__(self, datasets: List[Dataset], device: torch.device, use_domain: bool = True):
145 | self.datasets = datasets
146 | self.dataloaders = [DataLoader(
147 | dataset,
148 | batch_size=8,
149 | collate_fn=collate_batch_transformer
150 | ) for dataset in datasets]
151 |
152 | self.device = device
153 | self.stored_labels = []
154 | self.stored_logits = []
155 | self.use_domain = use_domain
156 |
157 | def micro_f1(self) -> Tuple[float, float, float, float]:
158 | labels_all = self.stored_labels
159 | logits_all = self.stored_logits
160 |
161 | logits_all = np.asarray(logits_all).reshape(-1, len(logits_all[0]))
162 | labels_all = np.asarray(labels_all).reshape(-1)
163 | acc = accuracy(logits_all, labels_all)
164 | P, R, F1, _ = precision_recall_fscore_support(labels_all, np.argmax(logits_all, axis=-1), average='binary')
165 |
166 | return acc, P, R, F1
167 |
168 | def evaluate(
169 | self,
170 | model: torch.nn.Module,
171 | plot_callbacks: List[Callable] = [],
172 | return_labels_logits: bool = False,
173 | return_votes: bool = False
174 | ) -> Tuple:
175 | """Collect evaluation metrics on this dataset
176 |
177 | :param model: The pytorch model to evaluate
178 | :param plot_callbacks: Optional function callbacks for plotting various things
179 | :return: (Loss, Accuracy, Precision, Recall, F1)
180 | """
181 | model.eval()
182 | with torch.no_grad():
183 | labels_all = []
184 | logits_all = []
185 | losses_all = []
186 | votes_all = []
187 | for dataloader in self.dataloaders:
188 | for batch in tqdm(dataloader, desc="Evaluation"):
189 | batch = tuple(t.to(self.device) for t in batch)
190 | input_ids = batch[0]
191 | masks = batch[1]
192 | labels = batch[2]
193 | domains = batch[3] if self.use_domain else None
194 | loss, logits = model(input_ids, attention_mask=masks, domains=domains, labels=labels)
195 | if len(loss.size()) > 0:
196 | loss = loss.mean()
197 | labels_all.extend(list(labels.detach().cpu().numpy()))
198 | logits_all.extend(list(logits.detach().cpu().numpy()))
199 | losses_all.append(loss.item())
200 | if hasattr(model, 'votes'):
201 | votes_all.extend(model.votes.detach().cpu().numpy())
202 |
203 |             # Use the votes of the domain expert with the lowest average entropy (most confident)
204 |             if len(votes_all) > 0:
205 |                 votes_all = np.asarray(votes_all).transpose(1, 0, 2)  # (n_samples, n_domains, n_classes) -> (n_domains, n_samples, n_classes)
206 |                 entropies = [np.mean([entropy(v) for v in votes]) for votes in votes_all]
207 |                 domain = np.argmin(entropies)  # argmin: lowest entropy = most confident expert
208 |                 logits_all = votes_all[domain]
209 |
210 | acc,P,R,F1 = acc_f1(logits_all, labels_all)
211 | loss = sum(losses_all) / len(losses_all)
212 |
213 | # Plotting
214 | plots = []
215 | for f in plot_callbacks:
216 | plots.append(f(labels_all, logits_all))
217 |
218 | ret_vals = (loss, acc, P, R, F1), plots
219 | if return_labels_logits:
220 | ret_vals = ret_vals + ((labels_all, logits_all),)
221 |
222 | if return_votes:
223 | if len(votes_all) > 0:
224 | ret_vals += (votes_all,)
225 | else:
226 |                 ret_vals += (list(np.argmax(np.asarray(logits_all), axis=1)),)
227 |
228 | return ret_vals
229 |
230 |
231 | class DomainClassifierEvaluator:
232 |     """Wrapper to evaluate a domain classifier (predicts the domain label)
233 |
234 | """
235 |
236 | def __init__(self, dataset: Dataset, device: torch.device):
237 | self.dataset = dataset
238 | self.dataloader = DataLoader(
239 | dataset,
240 | batch_size=8,
241 | collate_fn=collate_batch_transformer
242 | )
243 | self.device = device
244 | self.stored_labels = []
245 | self.stored_logits = []
246 |
247 | def micro_f1(self) -> Tuple[float, float, float, float]:
248 | labels_all = self.stored_labels
249 | logits_all = self.stored_logits
250 |
251 | logits_all = np.asarray(logits_all).reshape(-1, len(logits_all[0]))
252 | labels_all = np.asarray(labels_all).reshape(-1)
253 | acc = accuracy(logits_all, labels_all)
254 | P, R, F1, _ = precision_recall_fscore_support(labels_all, np.argmax(logits_all, axis=-1), average='binary')
255 |
256 | return acc, P, R, F1
257 |
258 | def evaluate(
259 | self,
260 | model: torch.nn.Module,
261 | plot_callbacks: List[Callable] = [],
262 | return_labels_logits: bool = False,
263 | return_votes: bool = False
264 | ) -> Tuple:
265 | """Collect evaluation metrics on this dataset
266 |
267 | :param model: The pytorch model to evaluate
268 | :param plot_callbacks: Optional function callbacks for plotting various things
269 | :return: (Loss, Accuracy, Precision, Recall, F1)
270 | """
271 | model.eval()
272 | with torch.no_grad():
273 | labels_all = []
274 | logits_all = []
275 | losses_all = []
276 | votes_all = []
277 | for batch in tqdm(self.dataloader, desc="Evaluation"):
278 | batch = tuple(t.to(self.device) for t in batch)
279 | input_ids = batch[0]
280 | masks = batch[1]
281 | labels = batch[2]
282 | domains = batch[3]
283 | loss, logits = model(input_ids, attention_mask=masks, labels=domains)
284 |
285 | labels_all.extend(list(domains.detach().cpu().numpy()))
286 | logits_all.extend(list(logits.detach().cpu().numpy()))
287 | losses_all.append(loss.item())
288 | if hasattr(model, 'votes'):
289 | votes_all.extend(model.votes.detach().cpu().numpy())
290 |
291 | acc,P,R,F1 = acc_f1(logits_all, labels_all)
292 | loss = sum(losses_all) / len(losses_all)
293 |
294 | # Plotting
295 | plots = []
296 | for f in plot_callbacks:
297 | plots.append(f(labels_all, logits_all))
298 |
299 | ret_vals = (loss, acc, P, R, F1), plots
300 | if return_labels_logits:
301 | ret_vals = ret_vals + ((labels_all, logits_all),)
302 |
303 | if return_votes:
304 | if len(votes_all) > 0:
305 | ret_vals += (votes_all,)
306 | else:
307 | ret_vals += (list(np.argmax(np.asarray(logits_all), axis=1)),)
308 |
309 | return ret_vals
--------------------------------------------------------------------------------
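The training scripts report both aggregates that `metrics.py` supports: macro scores average the per-domain values collected in `accs`/`Ps`/`Rs`/`F1s`, while micro scores are computed once by `acc_f1` over the pooled `logits_all`/`labels_all`. A minimal sketch of the difference, using accuracy only and hypothetical toy data:

```python
def accuracy(preds, labels):
    return sum(p == l for p, l in zip(preds, labels)) / len(labels)

# Hypothetical per-domain predictions and gold labels
domains = {
    'books': ([1, 0, 1, 1], [1, 0, 0, 1]),  # 3/4 correct
    'dvd':   ([0, 1],       [0, 0]),        # 1/2 correct
}

# Macro: average per-domain scores, so each domain counts equally
per_domain = [accuracy(p, l) for p, l in domains.values()]
macro_acc = sum(per_domain) / len(per_domain)  # (0.75 + 0.5) / 2 = 0.625

# Micro: pool all examples first, so each example counts equally
all_preds  = [p for preds, _ in domains.values() for p in preds]
all_labels = [l for _, labels in domains.values() for l in labels]
micro_acc = accuracy(all_preds, all_labels)    # 4/6
```

With unequal domain sizes the two aggregates diverge, which is why the scripts log `test-macro-*` and `test-micro-*` separately.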
/multisource-domain-adaptation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/copenlu/xformer-multi-source-domain-adaptation/be2d1a132298131df82fe40dd4f6c08dec8b3404/multisource-domain-adaptation.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fasttext==0.9.1
2 | ipdb==0.12.3
3 | ipython==7.12.0
4 | krippendorff==0.3.2
5 | matplotlib==3.1.3
6 | numpy==1.18.1
7 | pandas==1.0.1
8 | scikit-learn==0.22.1
9 | torch==1.4.0
10 | torchvision==0.5.0
11 | tqdm==4.42.1
12 | transformers==2.8.0
13 | wandb==0.8.27
14 |
--------------------------------------------------------------------------------
/run_claim_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | . activate xformer-multisource-domain-adaptation
4 |
5 | . setenv.sh
6 |
7 | run_name="(emnlp-claim)"
8 | model_dir="wandb_local/emnlp_claim_experiments"
9 | tags="emnlp claim experiments"
10 | for i in 1000,1 1001,2 666,3 7,4 50,5; do IFS=","; set -- $i;
11 | j=`expr ${2} - 1`
12 |
13 | # 1) Basic
14 | python emnlp_final_experiments/claim-detection/train_basic.py \
15 | --dataset_loc data/PHEME \
16 | --train_pct 0.9 \
17 | --n_gpu 1 \
18 | --n_epochs 5 \
19 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
20 | --seed ${1} \
21 | --run_name "basic-distilbert-${2}" \
22 | --model_dir ${model_dir}/basic_distilbert \
23 | --tags ${tags} \
24 | --batch_size 8 \
 25 |               --lr 0.00003
26 | indices_dir=`ls -d -t ${model_dir}/basic_distilbert/*/ | head -1`
27 |
28 | # 2) Adv-6
29 | python emnlp_final_experiments/claim-detection/train_basic_domain_adversarial.py \
30 | --dataset_loc data/PHEME \
31 | --train_pct 0.9 \
32 | --n_gpu 1 \
33 | --n_epochs 5 \
34 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
35 | --seed ${1} \
36 | --run_name "distilbert-adversarial-6-${2}" \
37 | --model_dir ${model_dir}/distilbert_adversarial_6 \
38 | --tags ${tags} \
39 | --batch_size 8 \
40 | --lr 0.00003 \
41 | --supervision_layer 6 \
42 | --indices_dir ${indices_dir}
43 |
44 | # 3) Adv-3
45 | python emnlp_final_experiments/claim-detection/train_basic_domain_adversarial.py \
46 | --dataset_loc data/PHEME \
47 | --train_pct 0.9 \
48 | --n_gpu 1 \
49 | --n_epochs 5 \
50 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
51 | --seed ${1} \
52 | --run_name "distilbert-adversarial-3-${2}" \
53 | --model_dir ${model_dir}/distilbert_adversarial_3 \
54 | --tags ${tags} \
55 | --batch_size 8 \
56 | --lr 0.00003 \
57 | --supervision_layer 3 \
58 | --indices_dir ${indices_dir}
59 |
60 | # 4) Independent-Avg
61 | python emnlp_final_experiments/claim-detection/train_multi_view_averaging_individuals.py \
62 | --dataset_loc data/PHEME \
63 | --train_pct 0.9 \
64 | --n_gpu 1 \
65 | --n_epochs 5 \
66 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
67 | --seed ${1} \
68 | --run_name "distilbert-ensemble-averaging-individuals-${2}" \
69 | --model_dir ${model_dir}/distilbert_ensemble_averaging_individuals \
70 | --tags ${tags} \
71 | --batch_size 8 \
72 | --lr 0.00003 \
73 | --indices_dir ${indices_dir}
74 | avg_model=`ls -d -t ${model_dir}/distilbert_ensemble_averaging_individuals/*/ | head -1`
75 |
76 | # 5) Independent-Ft
77 | python emnlp_final_experiments/claim-detection/train_multi_view_selective_weighting.py \
78 | --dataset_loc data/PHEME \
79 | --train_pct 0.9 \
80 | --n_gpu 1 \
81 | --n_epochs 30 \
82 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
83 | --seed ${1} \
84 | --run_name "distilbert-ensemble-selective-attention-${2}" \
85 | --model_dir ${model_dir}/distilbert_ensemble_selective_attention \
86 | --tags ${tags} \
87 | --pretrained_model ${avg_model} \
88 | --indices_dir ${indices_dir}
89 |
90 | # 6) MoE-DC
91 | python emnlp_final_experiments/claim-detection/train_multi_view_domainclassifier_individuals.py \
92 | --dataset_loc data/PHEME \
93 | --train_pct 0.9 \
94 | --n_gpu 1 \
95 | --n_epochs 5 \
96 | --n_dc_epochs 5 \
97 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
98 | --seed ${1} \
99 | --run_name "distilbert-ensemble-domainclassifier-individuals-${2}" \
100 | --model_dir ${model_dir}/distilbert_ensemble_domainclassifier_individuals \
101 | --tags ${tags} \
102 | --batch_size 8 \
103 | --lr 0.00003 \
104 | --indices_dir ${indices_dir} \
105 | --pretrained_model ${avg_model}
106 |
107 | # 7) MoE-Avg
108 | python emnlp_final_experiments/claim-detection/train_multi_view.py \
109 | --dataset_loc data/PHEME \
110 | --train_pct 0.9 \
111 | --n_gpu 1 \
112 | --n_epochs 5 \
113 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
114 | --seed ${1} \
115 | --run_name "distilbert-ensemble-averaging-${2}" \
116 | --model_dir ${model_dir}/distilbert_ensemble_averaging \
117 | --tags ${tags} \
118 | --batch_size 8 \
119 | --lr 0.00003 \
120 | --ensemble_basic \
121 | --indices_dir ${indices_dir}
122 |
123 | # 8) MoE-Att
124 | python emnlp_final_experiments/claim-detection/train_multi_view.py \
125 | --dataset_loc data/PHEME \
126 | --train_pct 0.9 \
127 | --n_gpu 1 \
128 | --n_epochs 5 \
129 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
130 | --seed ${1} \
131 | --run_name "distilbert-ensemble-attention-${2}" \
132 | --model_dir ${model_dir}/distilbert_ensemble_attention \
133 | --tags ${tags} \
134 | --batch_size 8 \
135 | --lr 0.00003 \
136 | --indices_dir ${indices_dir}
137 |
138 | # 9) MoE-Att-Adv-6
139 | python emnlp_final_experiments/claim-detection/train_multi_view_domain_adversarial.py \
140 | --dataset_loc data/PHEME \
141 | --train_pct 0.9 \
142 | --n_gpu 1 \
143 | --n_epochs 5 \
144 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
145 | --seed ${1} \
146 | --run_name "distilbert-ensemble-attention-adversarial-6-${2}" \
147 | --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_6 \
148 | --tags ${tags} \
149 | --batch_size 8 \
150 | --lr 0.00003 \
151 | --supervision_layer 6 \
152 | --indices_dir ${indices_dir}
153 |
154 | # 10) MoE-Att-Adv-3
155 | python emnlp_final_experiments/claim-detection/train_multi_view_domain_adversarial.py \
156 | --dataset_loc data/PHEME \
157 | --train_pct 0.9 \
158 | --n_gpu 1 \
159 | --n_epochs 5 \
160 | --domains charliehebdo ferguson germanwings-crash ottawashooting sydneysiege \
161 | --seed ${1} \
162 | --run_name "distilbert-ensemble-attention-adversarial-3-${2}" \
163 | --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_3 \
164 | --tags ${tags} \
165 | --batch_size 8 \
166 | --lr 0.00003 \
167 | --supervision_layer 3 \
168 | --indices_dir ${indices_dir}
169 |
170 | done
171 |
--------------------------------------------------------------------------------
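The experiment scripts chain their stages through the filesystem: each stage writes into a fresh run directory under `model_dir`, and the next stage picks it up with `ls -d -t ${model_dir}/*/ | head -1`, i.e. the most recently modified subdirectory (this is how `indices_dir` and `avg_model` are handed from the basic and averaging runs to the later ones). A small Python sketch of that "newest subdirectory" idiom (the helper name and demo paths are hypothetical):

```python
import os
import tempfile
from pathlib import Path

def newest_subdir(parent):
    """Most recently modified subdirectory of `parent`; mirrors the
    scripts' `ls -d -t ${model_dir}/*/ | head -1` idiom for locating
    the run directory the previous training stage just wrote."""
    subdirs = [p for p in Path(parent).iterdir() if p.is_dir()]
    return max(subdirs, key=lambda p: p.stat().st_mtime)

# Demo: two fake run directories with different modification times
tmp = tempfile.mkdtemp()
older, newer = Path(tmp) / 'run_a', Path(tmp) / 'run_b'
older.mkdir(); newer.mkdir()
os.utime(older, (1_000, 1_000))  # force run_a to look older
os.utime(newer, (2_000, 2_000))
picked = newest_subdir(tmp)
```

Because the handoff keys on modification time, each stage must finish before the next `ls -t` runs, which the sequential script layout guarantees.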
/run_sentiment_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | . activate xformer-multisource-domain-adaptation
4 |
5 | . setenv.sh
6 |
7 | run_name="(emnlp-sentiment)"
8 | model_dir="wandb_local/emnlp_sentiment_experiments"
9 | tags="emnlp sentiment experiments"
10 | for i in 1000,1 1001,2 666,3 7,4 50,5; do IFS=","; set -- $i;
11 | # 1) Basic
12 | python emnlp_final_experiments/sentiment-analysis/train_basic.py \
13 | --dataset_loc data/sentiment-dataset \
14 | --train_pct 0.9 \
15 | --n_gpu 1 \
16 | --n_epochs 5 \
17 | --domains books dvd electronics kitchen_\&_housewares \
18 | --seed ${1} \
19 | --run_name "basic-distilbert-${2}" \
20 | --model_dir ${model_dir}/basic_distilbert \
21 | --tags ${tags} \
22 | --batch_size 8 \
23 | --lr 0.00003
24 | indices_dir=`ls -d -t ${model_dir}/basic_distilbert/*/ | head -1`
25 |
26 | # 2) Adv-6
27 | python emnlp_final_experiments/sentiment-analysis/train_basic_domain_adversarial.py \
28 | --dataset_loc data/sentiment-dataset \
29 | --train_pct 0.9 \
30 | --n_gpu 1 \
31 | --n_epochs 5 \
32 | --domains books dvd electronics kitchen_\&_housewares \
33 | --seed ${1} \
34 | --run_name "distilbert-adversarial-6-${2}" \
35 | --model_dir ${model_dir}/distilbert_adversarial_6 \
36 | --tags ${tags} \
37 | --batch_size 8 \
38 | --lr 0.00003 \
39 | --supervision_layer 6 \
40 | --indices_dir ${indices_dir}
41 |
42 | # 3) Adv-3
43 | python emnlp_final_experiments/sentiment-analysis/train_basic_domain_adversarial.py \
44 | --dataset_loc data/sentiment-dataset \
45 | --train_pct 0.9 \
46 | --n_gpu 1 \
47 | --n_epochs 5 \
48 | --domains books dvd electronics kitchen_\&_housewares \
49 | --seed ${1} \
50 | --run_name "distilbert-adversarial-3-${2}" \
51 | --model_dir ${model_dir}/distilbert_adversarial_3 \
52 | --tags ${tags} \
53 | --batch_size 8 \
54 | --lr 0.00003 \
55 | --supervision_layer 3 \
56 | --indices_dir ${indices_dir}
57 |
58 | # 4) Independent-Avg
59 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_averaging_individuals.py \
60 | --dataset_loc data/sentiment-dataset \
61 | --train_pct 0.9 \
62 | --n_gpu 1 \
63 | --n_epochs 5 \
64 | --domains books dvd electronics kitchen_\&_housewares \
65 | --seed ${1} \
66 | --run_name "distilbert-ensemble-averaging-individuals-${2}" \
67 | --model_dir ${model_dir}/distilbert_ensemble_averaging_individuals \
68 | --tags ${tags} \
69 | --batch_size 8 \
70 | --lr 0.00003 \
71 | --indices_dir ${indices_dir}
72 | avg_model=`ls -d -t ${model_dir}/distilbert_ensemble_averaging_individuals/*/ | head -1`
73 |
74 | # 5) Independent-Ft
75 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_selective_weighting.py \
76 | --dataset_loc data/sentiment-dataset \
77 | --train_pct 0.9 \
78 | --n_gpu 1 \
79 | --n_epochs 30 \
80 | --domains books dvd electronics kitchen_\&_housewares \
81 | --seed ${1} \
82 | --run_name "distilbert-ensemble-selective-attention-${2}" \
83 | --model_dir ${model_dir}/distilbert_ensemble_selective_attention \
84 | --tags ${tags} \
85 | --pretrained_model ${avg_model} \
86 | --indices_dir ${indices_dir}
87 |
88 | # 6) MoE-DC
89 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_domainclassifier_individuals.py \
90 | --dataset_loc data/sentiment-dataset \
91 | --train_pct 0.9 \
92 | --n_gpu 1 \
93 | --n_epochs 5 \
94 | --domains books dvd electronics kitchen_\&_housewares \
95 | --seed ${1} \
96 | --run_name "distilbert-ensemble-domainclassifier-individuals-${2}" \
97 | --model_dir ${model_dir}/distilbert_ensemble_domainclassifier_individuals \
98 | --tags ${tags} \
99 | --batch_size 8 \
100 | --lr 0.00003 \
101 | --indices_dir ${indices_dir} \
102 | --pretrained_model ${avg_model}
103 |
104 | # 7) MoE-Avg
105 | python emnlp_final_experiments/sentiment-analysis/train_multi_view.py \
106 | --dataset_loc data/sentiment-dataset \
107 | --train_pct 0.9 \
108 | --n_gpu 1 \
109 | --n_epochs 5 \
110 | --domains books dvd electronics kitchen_\&_housewares \
111 | --seed ${1} \
112 | --run_name "distilbert-ensemble-averaging-${2}" \
113 | --model_dir ${model_dir}/distilbert_ensemble_averaging \
114 | --tags ${tags} \
115 | --batch_size 8 \
116 | --lr 0.00003 \
117 | --ensemble_basic \
118 | --indices_dir ${indices_dir}
119 |
120 | # 8) MoE-Att
121 | python emnlp_final_experiments/sentiment-analysis/train_multi_view.py \
122 | --dataset_loc data/sentiment-dataset \
123 | --train_pct 0.9 \
124 | --n_gpu 1 \
125 | --n_epochs 5 \
126 | --domains books dvd electronics kitchen_\&_housewares \
127 | --seed ${1} \
128 | --run_name "distilbert-ensemble-attention-${2}" \
129 | --model_dir ${model_dir}/distilbert_ensemble_attention \
130 | --tags ${tags} \
131 | --batch_size 8 \
132 | --lr 0.00003 \
133 | --indices_dir ${indices_dir}
134 |
135 | # 9) MoE-Att-Adv-6
136 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_domain_adversarial.py \
137 | --dataset_loc data/sentiment-dataset \
138 | --train_pct 0.9 \
139 | --n_gpu 1 \
140 | --n_epochs 5 \
141 | --domains books dvd electronics kitchen_\&_housewares \
142 | --seed ${1} \
143 | --run_name "distilbert-ensemble-attention-adversarial-6-${2}" \
144 | --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_6 \
145 | --tags ${tags} \
146 | --batch_size 8 \
147 | --lr 0.00003 \
148 | --supervision_layer 6 \
149 | --indices_dir ${indices_dir}
150 |
151 | # 10) MoE-Att-Adv-3
152 | python emnlp_final_experiments/sentiment-analysis/train_multi_view_domain_adversarial.py \
153 | --dataset_loc data/sentiment-dataset \
154 | --train_pct 0.9 \
155 | --n_gpu 1 \
156 | --n_epochs 5 \
157 | --domains books dvd electronics kitchen_\&_housewares \
158 | --seed ${1} \
159 | --run_name "distilbert-ensemble-attention-adversarial-3-${2}" \
160 |               --model_dir ${model_dir}/distilbert_ensemble_attention_adversarial_3 \
161 | --tags ${tags} \
162 | --batch_size 8 \
163 | --lr 0.00003 \
164 | --supervision_layer 3 \
165 | --indices_dir ${indices_dir}
166 |
167 | done
168 |
--------------------------------------------------------------------------------
/setenv.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=`pwd`:$PYTHONPATH
2 |
--------------------------------------------------------------------------------