├── .DS_Store
├── .idea
│   ├── TextFooler.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   ├── webServers.xml
│   └── workspace.xml
├── BERT
│   ├── __init__.py
│   ├── extract_features.py
│   ├── file_utils.py
│   ├── modeling.py
│   ├── optimization.py
│   ├── run_classifier.py
│   ├── run_classifier_AG.py
│   ├── run_classifier_Fake.py
│   ├── run_classifier_IMDB.py
│   ├── run_classifier_MR.py
│   ├── run_classifier_Yelp.py
│   ├── run_classifier_mnli.py
│   ├── run_classifier_snli.py
│   └── tokenization.py
├── ESIM
│   ├── .DS_Store
│   ├── esim
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── layers.py
│   │   ├── model.py
│   │   └── utils.py
│   ├── scripts
│   │   ├── .DS_Store
│   │   ├── fetch_data.py
│   │   ├── preprocessing
│   │   │   ├── preprocess_bnli.py
│   │   │   ├── preprocess_mnli.py
│   │   │   └── preprocess_snli.py
│   │   ├── testing
│   │   │   ├── test_mnli.py
│   │   │   └── test_snli.py
│   │   └── training
│   │       ├── train_mnli.py
│   │       ├── train_snli.py
│   │       └── utils.py
│   └── setup.py
├── InferSent
│   └── models.py
├── LICENSE
├── README.md
├── attack_classification.py
├── attack_nli.py
├── comp_cos_sim_mat.py
├── criteria.py
├── data
│   ├── ag
│   ├── fake
│   ├── imdb
│   ├── mnli
│   ├── mnli_matched
│   ├── mnli_mismatched
│   ├── mr
│   ├── snli
│   └── yelp
├── dataloader.py
├── modules.py
├── requirements.txt
├── run_attack_classification.py
├── run_attack_nli.py
└── train_classifier.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/TextFooler/6a56cb10f1055e00a8fbc6882f289e91cfe60f4e/.DS_Store
--------------------------------------------------------------------------------
/.idea/ (TextFooler.iml, deployment.xml, encodings.xml, misc.xml, modules.xml, vcs.xml, webServers.xml, workspace.xml):
--------------------------------------------------------------------------------
PyCharm project configuration files; their XML content was not preserved in this dump.
--------------------------------------------------------------------------------
/BERT/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.4.0"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
4 | BertForMaskedLM, BertForNextSentencePrediction,
5 | BertForSequenceClassification, BertForMultipleChoice,
6 | BertForTokenClassification, BertForQuestionAnswering)
7 | from .optimization import BertAdam
8 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
9 |
--------------------------------------------------------------------------------
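The package `__init__` above re-exports the tokenizer, the model heads and the optimizer so callers can import everything from one place. A rough usage sketch, assuming the `pytorch_pretrained_bert` package these files are adapted from is installed and the `bert-base-uncased` weights can be fetched by name; the example sentence is arbitrary:

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# Wrap a single sentence the way convert_examples_to_features does: [CLS] ... [SEP].
tokens = ['[CLS]'] + tokenizer.tokenize('a surprisingly touching film') + ['[SEP]']
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

with torch.no_grad():
    # With the default output_all_encoded_layers=True this returns a list of
    # per-layer hidden states plus the pooled [CLS] vector.
    encoder_layers, pooled_output = model(input_ids)
print(len(encoder_layers), pooled_output.shape)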
/BERT/extract_features.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Extract pre-computed feature vectors from a PyTorch BERT model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import argparse
22 | import collections
23 | import logging
24 | import json
25 | import re
26 |
27 | import torch
28 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
29 | from torch.utils.data.distributed import DistributedSampler
30 |
31 | from pytorch_pretrained_bert.tokenization import BertTokenizer
32 | from pytorch_pretrained_bert.modeling import BertModel
33 |
34 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
35 | datefmt = '%m/%d/%Y %H:%M:%S',
36 | level = logging.INFO)
37 | logger = logging.getLogger(__name__)
38 |
39 |
40 | class InputExample(object):
41 |
42 | def __init__(self, unique_id, text_a, text_b):
43 | self.unique_id = unique_id
44 | self.text_a = text_a
45 | self.text_b = text_b
46 |
47 |
48 | class InputFeatures(object):
49 | """A single set of features of data."""
50 |
51 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
52 | self.unique_id = unique_id
53 | self.tokens = tokens
54 | self.input_ids = input_ids
55 | self.input_mask = input_mask
56 | self.input_type_ids = input_type_ids
57 |
58 |
59 | def convert_examples_to_features(examples, seq_length, tokenizer):
60 | """Loads a data file into a list of `InputBatch`s."""
61 |
62 | features = []
63 | for (ex_index, example) in enumerate(examples):
64 | tokens_a = tokenizer.tokenize(example.text_a)
65 |
66 | tokens_b = None
67 | if example.text_b:
68 | tokens_b = tokenizer.tokenize(example.text_b)
69 |
70 | if tokens_b:
71 | # Modifies `tokens_a` and `tokens_b` in place so that the total
72 | # length is less than the specified length.
73 | # Account for [CLS], [SEP], [SEP] with "- 3"
74 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
75 | else:
76 | # Account for [CLS] and [SEP] with "- 2"
77 | if len(tokens_a) > seq_length - 2:
78 | tokens_a = tokens_a[0:(seq_length - 2)]
79 |
80 | # The convention in BERT is:
81 | # (a) For sequence pairs:
82 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
83 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
84 | # (b) For single sequences:
85 | # tokens: [CLS] the dog is hairy . [SEP]
86 | # type_ids: 0 0 0 0 0 0 0
87 | #
88 | # Where "type_ids" are used to indicate whether this is the first
89 | # sequence or the second sequence. The embedding vectors for `type=0` and
90 | # `type=1` were learned during pre-training and are added to the wordpiece
91 | # embedding vector (and position vector). This is not *strictly* necessary
92 | # since the [SEP] token unambiguously separates the sequences, but it makes
93 | # it easier for the model to learn the concept of sequences.
94 | #
95 | # For classification tasks, the first vector (corresponding to [CLS]) is
96 | # used as the "sentence vector". Note that this only makes sense because
97 | # the entire model is fine-tuned.
98 | tokens = []
99 | input_type_ids = []
100 | tokens.append("[CLS]")
101 | input_type_ids.append(0)
102 | for token in tokens_a:
103 | tokens.append(token)
104 | input_type_ids.append(0)
105 | tokens.append("[SEP]")
106 | input_type_ids.append(0)
107 |
108 | if tokens_b:
109 | for token in tokens_b:
110 | tokens.append(token)
111 | input_type_ids.append(1)
112 | tokens.append("[SEP]")
113 | input_type_ids.append(1)
114 |
115 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
116 |
117 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
118 | # tokens are attended to.
119 | input_mask = [1] * len(input_ids)
120 |
121 | # Zero-pad up to the sequence length.
122 | while len(input_ids) < seq_length:
123 | input_ids.append(0)
124 | input_mask.append(0)
125 | input_type_ids.append(0)
126 |
127 | assert len(input_ids) == seq_length
128 | assert len(input_mask) == seq_length
129 | assert len(input_type_ids) == seq_length
130 |
131 | if ex_index < 5:
132 | logger.info("*** Example ***")
133 | logger.info("unique_id: %s" % (example.unique_id))
134 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
135 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
136 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
137 | logger.info(
138 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
139 |
140 | features.append(
141 | InputFeatures(
142 | unique_id=example.unique_id,
143 | tokens=tokens,
144 | input_ids=input_ids,
145 | input_mask=input_mask,
146 | input_type_ids=input_type_ids))
147 | return features
148 |
149 |
150 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
151 | """Truncates a sequence pair in place to the maximum length."""
152 |
153 | # This is a simple heuristic which will always truncate the longer sequence
154 | # one token at a time. This makes more sense than truncating an equal percent
155 | # of tokens from each, since if one sequence is very short then each token
156 | # that's truncated likely contains more information than a longer sequence.
157 | while True:
158 | total_length = len(tokens_a) + len(tokens_b)
159 | if total_length <= max_length:
160 | break
161 | if len(tokens_a) > len(tokens_b):
162 | tokens_a.pop()
163 | else:
164 | tokens_b.pop()
165 |
166 |
167 | def read_examples(input_file):
168 | """Read a list of `InputExample`s from an input file."""
169 | examples = []
170 | unique_id = 0
171 | with open(input_file, "r", encoding='utf-8') as reader:
172 | while True:
173 | line = reader.readline()
174 | if not line:
175 | break
176 | line = line.strip()
177 | text_a = None
178 | text_b = None
179 | m = re.match(r"^(.*) \|\|\| (.*)$", line)
180 | if m is None:
181 | text_a = line
182 | else:
183 | text_a = m.group(1)
184 | text_b = m.group(2)
185 | examples.append(
186 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
187 | unique_id += 1
188 | return examples
189 |
190 |
191 | def main():
192 | parser = argparse.ArgumentParser()
193 |
194 | ## Required parameters
195 | parser.add_argument("--input_file", default=None, type=str, required=True)
196 | parser.add_argument("--output_file", default=None, type=str, required=True)
197 | parser.add_argument("--bert_model", default=None, type=str, required=True,
198 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
199 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
200 |
201 | ## Other parameters
202 | parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
203 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
204 | parser.add_argument("--max_seq_length", default=128, type=int,
205 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
206 | "than this will be truncated, and sequences shorter than this will be padded.")
207 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
208 | parser.add_argument("--local_rank",
209 | type=int,
210 | default=-1,
211 | help = "local_rank for distributed training on gpus")
212 | parser.add_argument("--no_cuda",
213 | action='store_true',
214 | help="Whether not to use CUDA when available")
215 |
216 | args = parser.parse_args()
217 |
218 | if args.local_rank == -1 or args.no_cuda:
219 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
220 | n_gpu = torch.cuda.device_count()
221 | else:
222 | device = torch.device("cuda", args.local_rank)
223 | n_gpu = 1
224 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
225 | torch.distributed.init_process_group(backend='nccl')
226 | logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))
227 |
228 | layer_indexes = [int(x) for x in args.layers.split(",")]
229 |
230 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
231 |
232 | examples = read_examples(args.input_file)
233 |
234 | features = convert_examples_to_features(
235 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer)
236 |
237 | unique_id_to_feature = {}
238 | for feature in features:
239 | unique_id_to_feature[feature.unique_id] = feature
240 |
241 | model = BertModel.from_pretrained(args.bert_model)
242 | model.to(device)
243 |
244 | if args.local_rank != -1:
245 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
246 | output_device=args.local_rank)
247 | elif n_gpu > 1:
248 | model = torch.nn.DataParallel(model)
249 |
250 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
251 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
252 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
253 |
254 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
255 | if args.local_rank == -1:
256 | eval_sampler = SequentialSampler(eval_data)
257 | else:
258 | eval_sampler = DistributedSampler(eval_data)
259 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
260 |
261 | model.eval()
262 | with open(args.output_file, "w", encoding='utf-8') as writer:
263 | for input_ids, input_mask, example_indices in eval_dataloader:
264 | input_ids = input_ids.to(device)
265 | input_mask = input_mask.to(device)
266 |
267 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
268 | all_encoder_layers = all_encoder_layers
269 |
270 | for b, example_index in enumerate(example_indices):
271 | feature = features[example_index.item()]
272 | unique_id = int(feature.unique_id)
273 | # feature = unique_id_to_feature[unique_id]
274 | output_json = collections.OrderedDict()
275 | output_json["linex_index"] = unique_id
276 | all_out_features = []
277 | for (i, token) in enumerate(feature.tokens):
278 | all_layers = []
279 | for (j, layer_index) in enumerate(layer_indexes):
280 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
281 | layer_output = layer_output[b]
282 | layers = collections.OrderedDict()
283 | layers["index"] = layer_index
284 | layers["values"] = [
285 | round(x.item(), 6) for x in layer_output[i]
286 | ]
287 | all_layers.append(layers)
288 | out_features = collections.OrderedDict()
289 | out_features["token"] = token
290 | out_features["layers"] = all_layers
291 | all_out_features.append(out_features)
292 | output_json["features"] = all_out_features
293 | writer.write(json.dumps(output_json) + "\n")
294 |
295 |
296 | if __name__ == "__main__":
297 | main()
--------------------------------------------------------------------------------
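extract_features.py is meant to be run from the command line: read_examples expects one sentence per line, or a sentence pair written as "text_a ||| text_b", and the script writes one JSON record of per-token, per-layer activations for each example. An invocation sketch in the same os.system wrapper style the run_classifier_*.py files below use; the input and output paths are placeholders:

import os

# Hypothetical paths; the flags mirror the argparse options defined above.
# Input format (see read_examples): one sentence per line, or a pair written
# as "text_a ||| text_b" on a single line.
command = 'python extract_features.py --input_file data/sentences.txt ' \
          '--output_file results/features.jsonl ' \
          '--bert_model bert-base-uncased --do_lower_case ' \
          '--max_seq_length 128 --batch_size 32'

os.system(command)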
/BERT/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import sys
9 | import json
10 | import logging
11 | import os
12 | import shutil
13 | import tempfile
14 | import fnmatch
15 | from functools import wraps
16 | from hashlib import sha256
17 | import sys
18 | from io import open
19 |
20 | import boto3
21 | import requests
22 | from botocore.exceptions import ClientError
23 | from tqdm import tqdm
24 |
25 | try:
26 | from torch.hub import _get_torch_home
27 | torch_cache_home = _get_torch_home()
28 | except ImportError:
29 | torch_cache_home = os.path.expanduser(
30 | os.getenv('TORCH_HOME', os.path.join(
31 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
32 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
33 |
34 | try:
35 | from urllib.parse import urlparse
36 | except ImportError:
37 | from urlparse import urlparse
38 |
39 | try:
40 | from pathlib import Path
41 | PYTORCH_PRETRAINED_BERT_CACHE = Path(
42 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
43 | except (AttributeError, ImportError):
44 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
45 | default_cache_path)
46 |
47 | CONFIG_NAME = "bert_config.json"
48 | WEIGHTS_NAME = "pytorch_model.bin"
49 |
50 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
51 |
52 |
53 | def url_to_filename(url, etag=None):
54 | """
55 | Convert `url` into a hashed filename in a repeatable way.
56 | If `etag` is specified, append its hash to the url's, delimited
57 | by a period.
58 | """
59 | url_bytes = url.encode('utf-8')
60 | url_hash = sha256(url_bytes)
61 | filename = url_hash.hexdigest()
62 |
63 | if etag:
64 | etag_bytes = etag.encode('utf-8')
65 | etag_hash = sha256(etag_bytes)
66 | filename += '.' + etag_hash.hexdigest()
67 |
68 | return filename
69 |
70 |
71 | def filename_to_url(filename, cache_dir=None):
72 | """
73 | Return the url and etag (which may be ``None``) stored for `filename`.
74 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
75 | """
76 | if cache_dir is None:
77 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
78 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
79 | cache_dir = str(cache_dir)
80 |
81 | cache_path = os.path.join(cache_dir, filename)
82 | if not os.path.exists(cache_path):
83 | raise EnvironmentError("file {} not found".format(cache_path))
84 |
85 | meta_path = cache_path + '.json'
86 | if not os.path.exists(meta_path):
87 | raise EnvironmentError("file {} not found".format(meta_path))
88 |
89 | with open(meta_path, encoding="utf-8") as meta_file:
90 | metadata = json.load(meta_file)
91 | url = metadata['url']
92 | etag = metadata['etag']
93 |
94 | return url, etag
95 |
96 |
97 | def cached_path(url_or_filename, cache_dir=None):
98 | """
99 | Given something that might be a URL (or might be a local path),
100 | determine which. If it's a URL, download the file and cache it, and
101 | return the path to the cached file. If it's already a local path,
102 | make sure the file exists and then return the path.
103 | """
104 | if cache_dir is None:
105 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
106 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
107 | url_or_filename = str(url_or_filename)
108 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
109 | cache_dir = str(cache_dir)
110 |
111 | parsed = urlparse(url_or_filename)
112 |
113 | if parsed.scheme in ('http', 'https', 's3'):
114 | # URL, so get it from the cache (downloading if necessary)
115 | return get_from_cache(url_or_filename, cache_dir)
116 | elif os.path.exists(url_or_filename):
117 | # File, and it exists.
118 | return url_or_filename
119 | elif parsed.scheme == '':
120 | # File, but it doesn't exist.
121 | raise EnvironmentError("file {} not found".format(url_or_filename))
122 | else:
123 | # Something unknown
124 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
125 |
126 |
127 | def split_s3_path(url):
128 | """Split a full s3 path into the bucket name and path."""
129 | parsed = urlparse(url)
130 | if not parsed.netloc or not parsed.path:
131 | raise ValueError("bad s3 path {}".format(url))
132 | bucket_name = parsed.netloc
133 | s3_path = parsed.path
134 | # Remove '/' at beginning of path.
135 | if s3_path.startswith("/"):
136 | s3_path = s3_path[1:]
137 | return bucket_name, s3_path
138 |
139 |
140 | def s3_request(func):
141 | """
142 | Wrapper function for s3 requests in order to create more helpful error
143 | messages.
144 | """
145 |
146 | @wraps(func)
147 | def wrapper(url, *args, **kwargs):
148 | try:
149 | return func(url, *args, **kwargs)
150 | except ClientError as exc:
151 | if int(exc.response["Error"]["Code"]) == 404:
152 | raise EnvironmentError("file {} not found".format(url))
153 | else:
154 | raise
155 |
156 | return wrapper
157 |
158 |
159 | @s3_request
160 | def s3_etag(url):
161 | """Check ETag on S3 object."""
162 | s3_resource = boto3.resource("s3")
163 | bucket_name, s3_path = split_s3_path(url)
164 | s3_object = s3_resource.Object(bucket_name, s3_path)
165 | return s3_object.e_tag
166 |
167 |
168 | @s3_request
169 | def s3_get(url, temp_file):
170 | """Pull a file directly from S3."""
171 | s3_resource = boto3.resource("s3")
172 | bucket_name, s3_path = split_s3_path(url)
173 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
174 |
175 |
176 | def http_get(url, temp_file):
177 | req = requests.get(url, stream=True)
178 | content_length = req.headers.get('Content-Length')
179 | total = int(content_length) if content_length is not None else None
180 | progress = tqdm(unit="B", total=total)
181 | for chunk in req.iter_content(chunk_size=1024):
182 | if chunk: # filter out keep-alive new chunks
183 | progress.update(len(chunk))
184 | temp_file.write(chunk)
185 | progress.close()
186 |
187 |
188 | def get_from_cache(url, cache_dir=None):
189 | """
190 | Given a URL, look for the corresponding dataset in the local cache.
191 | If it's not there, download it. Then return the path to the cached file.
192 | """
193 | if cache_dir is None:
194 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
195 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
196 | cache_dir = str(cache_dir)
197 |
198 | if not os.path.exists(cache_dir):
199 | os.makedirs(cache_dir)
200 |
201 | # Get eTag to add to filename, if it exists.
202 | if url.startswith("s3://"):
203 | etag = s3_etag(url)
204 | else:
205 | try:
206 | response = requests.head(url, allow_redirects=True)
207 | if response.status_code != 200:
208 | etag = None
209 | else:
210 | etag = response.headers.get("ETag")
211 | except EnvironmentError:
212 | etag = None
213 |
214 | if sys.version_info[0] == 2 and etag is not None:
215 | etag = etag.decode('utf-8')
216 | filename = url_to_filename(url, etag)
217 |
218 | # get cache path to put the file
219 | cache_path = os.path.join(cache_dir, filename)
220 |
221 | # If we don't have a connection (etag is None) and can't identify the file
222 | # try to get the last downloaded one
223 | if not os.path.exists(cache_path) and etag is None:
224 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
225 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
226 | if matching_files:
227 | cache_path = os.path.join(cache_dir, matching_files[-1])
228 |
229 | if not os.path.exists(cache_path):
230 | # Download to temporary file, then copy to cache dir once finished.
231 | # Otherwise you get corrupt cache entries if the download gets interrupted.
232 | with tempfile.NamedTemporaryFile() as temp_file:
233 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
234 |
235 | # GET file object
236 | if url.startswith("s3://"):
237 | s3_get(url, temp_file)
238 | else:
239 | http_get(url, temp_file)
240 |
241 | # we are copying the file before closing it, so flush to avoid truncation
242 | temp_file.flush()
243 | # shutil.copyfileobj() starts at the current position, so go to the start
244 | temp_file.seek(0)
245 |
246 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
247 | with open(cache_path, 'wb') as cache_file:
248 | shutil.copyfileobj(temp_file, cache_file)
249 |
250 | logger.info("creating metadata file for %s", cache_path)
251 | meta = {'url': url, 'etag': etag}
252 | meta_path = cache_path + '.json'
253 | with open(meta_path, 'w') as meta_file:
254 | output_string = json.dumps(meta)
255 | if sys.version_info[0] == 2 and isinstance(output_string, str):
256 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2
257 | meta_file.write(output_string)
258 |
259 | logger.info("removing temp file %s", temp_file.name)
260 |
261 | return cache_path
262 |
263 |
264 | def read_set_from_file(filename):
265 | '''
266 | Extract a de-duped collection (set) of text from a file.
267 | Expected file format is one item per line.
268 | '''
269 | collection = set()
270 | with open(filename, 'r', encoding='utf-8') as file_:
271 | for line in file_:
272 | collection.add(line.rstrip())
273 | return collection
274 |
275 |
276 | def get_file_extension(path, dot=True, lower=True):
277 | ext = os.path.splitext(path)[1]
278 | ext = ext if dot else ext[1:]
279 | return ext.lower() if lower else ext
--------------------------------------------------------------------------------
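cached_path is the entry point the rest of the code relies on: a local path is returned after an existence check, while an http/https/s3 URL is downloaded once into PYTORCH_PRETRAINED_BERT_CACHE under a sha256-derived name (url_to_filename) next to a .json metadata sidecar that filename_to_url can read back. A small sketch of that behaviour, assuming the repository root is on PYTHONPATH (so the module imports as BERT.file_utils), network access, and that the vocabulary URL listed in tokenization.py is still reachable:

import os

from BERT.file_utils import cached_path, url_to_filename, filename_to_url

vocab_url = 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt'

local_copy = cached_path(vocab_url)      # downloads on first call, then reuses the cache
print(local_copy)
print(url_to_filename(vocab_url))        # sha256 of the URL alone (no ETag suffix)
print(filename_to_url(os.path.basename(local_copy)))  # reads the .json sidecar -> (url, etag)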
/BERT/run_classifier_AG.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | command = 'python run_classifier.py --data_dir /afs/csail.mit.edu/u/z/zhijing/proj/to_di/data/ag ' \
4 | '--bert_model bert-base-uncased ' \
5 | '--task_name ag --output_dir results/ag --cache_dir pytorch_cache --do_train --do_eval --do_lower_case '
6 |
7 | os.system(command)
--------------------------------------------------------------------------------
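This wrapper, like the sibling run_classifier_*.py scripts that follow, simply shells out to run_classifier.py with data paths hard-coded for the authors' cluster. A parameterized variant you might adapt for a local setup; the paths below are placeholders and the flags mirror the wrapper above:

import os

# Hypothetical local paths; the flags simply mirror the hard-coded command above.
data_dir = 'data/ag'        # wherever the AG's News train/dev files live
output_dir = 'results/ag'

command = ('python run_classifier.py --data_dir {} '
           '--bert_model bert-base-uncased '
           '--task_name ag --output_dir {} '
           '--cache_dir pytorch_cache --do_train --do_eval '
           '--do_lower_case').format(data_dir, output_dir)

os.system(command)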
/BERT/run_classifier_Fake.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | command = 'python run_classifier.py --data_dir /afs/csail.mit.edu/u/z/zhijing/proj/to_di/data/fake ' \
4 | '--bert_model bert-base-uncased --max_seq_length 256 --train_batch_size 16 ' \
5 | '--task_name fake --output_dir results/fake --cache_dir pytorch_cache --do_train --do_eval --do_lower_case '
6 |
7 | os.system(command)
--------------------------------------------------------------------------------
/BERT/run_classifier_IMDB.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | command = 'python run_classifier.py --data_dir /data/medg/misc/jindi/nlp/datasets/imdb ' \
4 | '--bert_model bert-base-uncased --max_seq_length 256 --train_batch_size 32 ' \
5 | '--task_name imdb --output_dir results/imdb --cache_dir pytorch_cache --do_train --do_eval --do_lower_case ' \
6 | '--num_train_epochs 3.'
7 |
8 | os.system(command)
--------------------------------------------------------------------------------
/BERT/run_classifier_MR.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | command = 'python run_classifier.py --data_dir /data/medg/misc/jindi/nlp/datasets/mr ' \
4 | '--bert_model bert-base-uncased ' \
5 | '--task_name mr --output_dir results/mr_retrain --cache_dir pytorch_cache --do_train --do_eval ' \
6 | '--do_lower_case '
7 |
8 | os.system(command)
--------------------------------------------------------------------------------
/BERT/run_classifier_Yelp.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | command = 'python run_classifier.py --data_dir /afs/csail.mit.edu/u/z/zhijing/proj/to_di/data/yelp ' \
4 | '--bert_model bert-base-uncased --max_seq_length 128 --train_batch_size 32 ' \
5 | '--task_name yelp --output_dir results/yelp --cache_dir pytorch_cache --do_train --do_eval --do_lower_case ' \
6 | '--num_train_epochs 2.'
7 |
8 | os.system(command)
--------------------------------------------------------------------------------
/BERT/run_classifier_mnli.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | command = 'python run_classifier.py --data_dir /data/medg/misc/jindi/nlp/datasets/MNLI ' \
4 | '--bert_model bert-base-uncased ' \
5 | '--task_name mnli --output_dir results/MNLI --cache_dir pytorch_cache --do_eval --do_lower_case ' \
6 | '--do_resume'
7 |
8 | os.system(command)
--------------------------------------------------------------------------------
/BERT/run_classifier_snli.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | command = 'python run_classifier.py --data_dir /data/medg/misc/jindi/nlp/datasets/SNLI/snli_1.0 ' \
4 | '--bert_model bert-base-uncased ' \
5 | '--task_name snli --output_dir results/SNLI_retrain --cache_dir pytorch_cache --do_train --do_eval --do_lower_case ' \
6 | # '--do_resume'
7 |
8 | os.system(command)
--------------------------------------------------------------------------------
/BERT/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 | import six
25 |
26 | from .file_utils import cached_path
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
38 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
39 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
40 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
41 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
42 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
43 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
44 | }
45 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
46 | 'bert-base-uncased': 512,
47 | 'bert-large-uncased': 512,
48 | 'bert-base-cased': 512,
49 | 'bert-large-cased': 512,
50 | 'bert-base-multilingual-uncased': 512,
51 | 'bert-base-multilingual-cased': 512,
52 | 'bert-base-chinese': 512,
53 | 'bert-base-german-cased': 512,
54 | 'bert-large-uncased-whole-word-masking': 512,
55 | 'bert-large-cased-whole-word-masking': 512,
56 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512,
57 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512,
58 | 'bert-base-cased-finetuned-mrpc': 512,
59 | }
60 | VOCAB_NAME = 'vocab.txt'
61 |
62 |
63 | def load_vocab(vocab_file):
64 | """Loads a vocabulary file into a dictionary."""
65 | vocab = collections.OrderedDict()
66 | index = 0
67 | with open(vocab_file, "r", encoding="utf-8") as reader:
68 | while True:
69 | token = reader.readline()
70 | if not token:
71 | break
72 | token = token.strip()
73 | vocab[token] = index
74 | index += 1
75 | return vocab
76 |
77 |
78 | def whitespace_tokenize(text):
79 | """Runs basic whitespace cleaning and splitting on a piece of text."""
80 | text = text.strip()
81 | if not text:
82 | return []
83 | tokens = text.split()
84 | return tokens
85 |
86 |
87 | class BertTokenizer(object):
88 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
89 |
90 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
91 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
92 | """Constructs a BertTokenizer.
93 | Args:
94 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
95 | do_lower_case: Whether to lower case the input
96 | Only has an effect when do_wordpiece_only=False
97 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
98 | max_len: An artificial maximum length to truncate tokenized sequences to;
99 | Effective maximum length is always the minimum of this
100 | value (if specified) and the underlying BERT model's
101 | sequence length.
102 | never_split: List of tokens which will never be split during tokenization.
103 | Only has an effect when do_wordpiece_only=False
104 | """
105 | if not os.path.isfile(vocab_file):
106 | raise ValueError(
107 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
108 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
109 | self.vocab = load_vocab(vocab_file)
110 | self.ids_to_tokens = collections.OrderedDict(
111 | [(ids, tok) for tok, ids in self.vocab.items()])
112 | self.do_basic_tokenize = do_basic_tokenize
113 | if do_basic_tokenize:
114 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
115 | never_split=never_split)
116 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
117 | self.max_len = max_len if max_len is not None else int(1e12)
118 |
119 | def tokenize(self, text):
120 | split_tokens = []
121 | if self.do_basic_tokenize:
122 | for token in self.basic_tokenizer.tokenize(text):
123 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
124 | split_tokens.append(sub_token)
125 | else:
126 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
127 | return split_tokens
128 |
129 | def convert_tokens_to_ids(self, tokens):
130 | """Converts a sequence of tokens into ids using the vocab."""
131 | ids = []
132 | for token in tokens:
133 | ids.append(self.vocab[token])
134 | if len(ids) > self.max_len:
135 | logger.warning(
136 | "Token indices sequence length is longer than the specified maximum "
137 | " sequence length for this BERT model ({} > {}). Running this"
138 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
139 | )
140 | return ids
141 |
142 | def convert_ids_to_tokens(self, ids):
143 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
144 | tokens = []
145 | for i in ids:
146 | tokens.append(self.ids_to_tokens[i])
147 | return tokens
148 |
149 | def save_vocabulary(self, vocab_path):
150 | """Save the tokenizer vocabulary to a directory or file."""
151 | index = 0
152 | if os.path.isdir(vocab_path):
153 | vocab_file = os.path.join(vocab_path, VOCAB_NAME)
154 | with open(vocab_file, "w", encoding="utf-8") as writer:
155 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
156 | if index != token_index:
157 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
158 | " Please check that the vocabulary is not corrupted!".format(vocab_file))
159 | index = token_index
160 | writer.write(token + u'\n')
161 | index += 1
162 | return vocab_file
163 |
164 | @classmethod
165 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
166 | """
167 | Instantiate a PreTrainedBertModel from a pre-trained model file.
168 | Download and cache the pre-trained model file if needed.
169 | """
170 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
171 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
172 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
173 | logger.warning("The pre-trained model you are loading is a cased model but you have not set "
174 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
175 | "you may want to check this behavior.")
176 | kwargs['do_lower_case'] = False
177 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
178 | logger.warning("The pre-trained model you are loading is an uncased model but you have set "
179 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
180 | "but you may want to check this behavior.")
181 | kwargs['do_lower_case'] = True
182 | else:
183 | vocab_file = pretrained_model_name_or_path
184 | if os.path.isdir(vocab_file):
185 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
186 | # redirect to the cache, if necessary
187 | try:
188 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
189 | except EnvironmentError:
190 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
191 | logger.error(
192 | "Couldn't reach server at '{}' to download vocabulary.".format(
193 | vocab_file))
194 | else:
195 | logger.error(
196 | "Model name '{}' was not found in model name list ({}). "
197 | "We assumed '{}' was a path or url but couldn't find any file "
198 | "associated to this path or url.".format(
199 | pretrained_model_name_or_path,
200 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
201 | vocab_file))
202 | return None
203 | if resolved_vocab_file == vocab_file:
204 | logger.info("loading vocabulary file {}".format(vocab_file))
205 | else:
206 | logger.info("loading vocabulary file {} from cache at {}".format(
207 | vocab_file, resolved_vocab_file))
208 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
209 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
210 | # than the number of positional embeddings
211 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
212 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
213 | # Instantiate tokenizer.
214 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
215 | return tokenizer
216 |
217 |
218 | class BasicTokenizer(object):
219 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
220 |
221 | def __init__(self,
222 | do_lower_case=True,
223 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
224 | """Constructs a BasicTokenizer.
225 | Args:
226 | do_lower_case: Whether to lower case the input.
227 | """
228 | self.do_lower_case = do_lower_case
229 | self.never_split = never_split
230 |
231 | def tokenize(self, text):
232 | """Tokenizes a piece of text."""
233 | text = self._clean_text(text)
234 | # This was added on November 1st, 2018 for the multilingual and Chinese
235 | # models. This is also applied to the English models now, but it doesn't
236 | # matter since the English models were not trained on any Chinese data
237 | # and generally don't have any Chinese data in them (there are Chinese
238 | # characters in the vocabulary because Wikipedia does have some Chinese
239 | # words in the English Wikipedia.).
240 | text = self._tokenize_chinese_chars(text)
241 | orig_tokens = whitespace_tokenize(text)
242 | split_tokens = []
243 | for token in orig_tokens:
244 | if self.do_lower_case and token not in self.never_split:
245 | token = token.lower()
246 | token = self._run_strip_accents(token)
247 | split_tokens.extend(self._run_split_on_punc(token))
248 |
249 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
250 | return output_tokens
251 |
252 | def _run_strip_accents(self, text):
253 | """Strips accents from a piece of text."""
254 | text = unicodedata.normalize("NFD", text)
255 | output = []
256 | for char in text:
257 | cat = unicodedata.category(char)
258 | if cat == "Mn":
259 | continue
260 | output.append(char)
261 | return "".join(output)
262 |
263 | def _run_split_on_punc(self, text):
264 | """Splits punctuation on a piece of text."""
265 | if text in self.never_split:
266 | return [text]
267 | chars = list(text)
268 | i = 0
269 | start_new_word = True
270 | output = []
271 | while i < len(chars):
272 | char = chars[i]
273 | if _is_punctuation(char):
274 | output.append([char])
275 | start_new_word = True
276 | else:
277 | if start_new_word:
278 | output.append([])
279 | start_new_word = False
280 | output[-1].append(char)
281 | i += 1
282 |
283 | return ["".join(x) for x in output]
284 |
285 | def _tokenize_chinese_chars(self, text):
286 | """Adds whitespace around any CJK character."""
287 | output = []
288 | for char in text:
289 | cp = ord(char)
290 | if self._is_chinese_char(cp):
291 | output.append(" ")
292 | output.append(char)
293 | output.append(" ")
294 | else:
295 | output.append(char)
296 | return "".join(output)
297 |
298 | def _is_chinese_char(self, cp):
299 | """Checks whether CP is the codepoint of a CJK character."""
300 | # This defines a "chinese character" as anything in the CJK Unicode block:
301 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
302 | #
303 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
304 | # despite its name. The modern Korean Hangul alphabet is a different block,
305 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
306 | # space-separated words, so they are not treated specially and handled
307 | # like all of the other languages.
308 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
309 | (cp >= 0x3400 and cp <= 0x4DBF) or #
310 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
311 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
312 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
313 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
314 | (cp >= 0xF900 and cp <= 0xFAFF) or #
315 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
316 | return True
317 |
318 | return False
319 |
320 | def _clean_text(self, text):
321 | """Performs invalid character removal and whitespace cleanup on text."""
322 | output = []
323 | for char in text:
324 | cp = ord(char)
325 | if cp == 0 or cp == 0xfffd or _is_control(char):
326 | continue
327 | if _is_whitespace(char):
328 | output.append(" ")
329 | else:
330 | output.append(char)
331 | return "".join(output)
332 |
333 |
334 | class WordpieceTokenizer(object):
335 | """Runs WordPiece tokenization."""
336 |
337 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
338 | self.vocab = vocab
339 | self.unk_token = unk_token
340 | self.max_input_chars_per_word = max_input_chars_per_word
341 |
342 | def tokenize(self, text):
343 | """Tokenizes a piece of text into its word pieces.
344 | This uses a greedy longest-match-first algorithm to perform tokenization
345 | using the given vocabulary.
346 | For example:
347 | input = "unaffable"
348 | output = ["un", "##aff", "##able"]
349 | Args:
350 | text: A single token or whitespace separated tokens. This should have
351 | already been passed through `BasicTokenizer`.
352 | Returns:
353 | A list of wordpiece tokens.
354 | """
355 |
356 | output_tokens = []
357 | for token in whitespace_tokenize(text):
358 | chars = list(token)
359 | if len(chars) > self.max_input_chars_per_word:
360 | output_tokens.append(self.unk_token)
361 | continue
362 |
363 | is_bad = False
364 | start = 0
365 | sub_tokens = []
366 | while start < len(chars):
367 | end = len(chars)
368 | cur_substr = None
369 | while start < end:
370 | substr = "".join(chars[start:end])
371 | if start > 0:
372 | substr = "##" + substr
373 | if substr in self.vocab:
374 | cur_substr = substr
375 | break
376 | end -= 1
377 | if cur_substr is None:
378 | is_bad = True
379 | break
380 | sub_tokens.append(cur_substr)
381 | start = end
382 |
383 | if is_bad:
384 | output_tokens.append(self.unk_token)
385 | else:
386 | output_tokens.extend(sub_tokens)
387 | return output_tokens
388 |
389 |
390 | def _is_whitespace(char):
391 | """Checks whether `chars` is a whitespace character."""
392 | # \t, \n, and \r are technically control characters but we treat them
393 | # as whitespace since they are generally considered as such.
394 | if char == " " or char == "\t" or char == "\n" or char == "\r":
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat == "Zs":
398 | return True
399 | return False
400 |
401 |
402 | def _is_control(char):
403 | """Checks whether `chars` is a control character."""
404 | # These are technically control characters but we count them as whitespace
405 | # characters.
406 | if char == "\t" or char == "\n" or char == "\r":
407 | return False
408 | cat = unicodedata.category(char)
409 | if cat.startswith("C"):
410 | return True
411 | return False
412 |
413 |
414 | def _is_punctuation(char):
415 | """Checks whether `chars` is a punctuation character."""
416 | cp = ord(char)
417 | # We treat all non-letter/number ASCII as punctuation.
418 | # Characters such as "^", "$", and "`" are not in the Unicode
419 | # Punctuation class but we treat them as punctuation anyways, for
420 | # consistency.
421 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
422 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
423 | return True
424 | cat = unicodedata.category(char)
425 | if cat.startswith("P"):
426 | return True
427 | return False
428 |
429 |
430 | def convert_to_unicode(text):
431 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
432 | if six.PY3:
433 | if isinstance(text, str):
434 | return text
435 | elif isinstance(text, bytes):
436 | return text.decode("utf-8", "ignore")
437 | else:
438 | raise ValueError("Unsupported string type: %s" % (type(text)))
439 | elif six.PY2:
440 | if isinstance(text, str):
441 | return text.decode("utf-8", "ignore")
442 | elif isinstance(text, unicode):
443 | return text
444 | else:
445 | raise ValueError("Unsupported string type: %s" % (type(text)))
446 | else:
447 | raise ValueError("Not running on Python2 or Python 3?")
--------------------------------------------------------------------------------
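The classes above compose into a single pipeline: BasicTokenizer cleans the text, lower-cases it (for uncased models), strips accents and splits punctuation, and WordpieceTokenizer then greedily matches the longest vocabulary prefix, marking continuation pieces with "##". A short sketch of that pipeline, assuming the repository root is on PYTHONPATH so this file imports as BERT.tokenization and that the uncased vocabulary can be downloaded by name:

from BERT.tokenization import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Expected along the lines of ['the', 'movie', 'was', 'un', '##aff', '##able', '.'];
# the "unaffable" split is the example given in the WordpieceTokenizer docstring.
tokens = tokenizer.tokenize("The movie was unaffable.")
print(tokens)

ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokenizer.convert_ids_to_tokens(ids) == tokens)   # round-trip through the vocab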
/ESIM/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/TextFooler/6a56cb10f1055e00a8fbc6882f289e91cfe60f4e/ESIM/.DS_Store
--------------------------------------------------------------------------------
/ESIM/esim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/TextFooler/6a56cb10f1055e00a8fbc6882f289e91cfe60f4e/ESIM/esim/__init__.py
--------------------------------------------------------------------------------
/ESIM/esim/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocessor and dataset definition for NLI.
3 | """
4 | # Aurelien Coet, 2018.
5 |
6 | import string
7 | import torch
8 | import numpy as np
9 |
10 | from collections import Counter
11 | from torch.utils.data import Dataset
12 |
13 |
14 | class Preprocessor(object):
15 | """
16 | Preprocessor class for Natural Language Inference datasets.
17 |
18 | The class can be used to read NLI datasets, build worddicts for them
19 | and transform their premises, hypotheses and labels into lists of
20 | integer indices.
21 | """
22 |
23 | def __init__(self,
24 | lowercase=False,
25 | ignore_punctuation=False,
26 | num_words=None,
27 | stopwords=[],
28 | labeldict={},
29 | bos=None,
30 | eos=None):
31 | """
32 | Args:
33 | lowercase: A boolean indicating whether the words in the datasets
34 | being preprocessed must be lowercased or not. Defaults to
35 | False.
36 | ignore_punctuation: A boolean indicating whether punctuation must
37 | be ignored or not in the datasets preprocessed by the object.
38 | num_words: An integer indicating the number of words to use in the
39 | worddict of the object. If set to None, all the words in the
40 | data are kept. Defaults to None.
41 | stopwords: A list of words that must be ignored when building the
42 | worddict for a dataset. Defaults to an empty list.
43 | bos: A string indicating the symbol to use for the 'beginning of
44 | sentence' token in the data. If set to None, the token isn't
45 | used. Defaults to None.
46 | eos: A string indicating the symbol to use for the 'end of
47 | sentence' token in the data. If set to None, the token isn't
48 | used. Defaults to None.
49 | """
50 | self.lowercase = lowercase
51 | self.ignore_punctuation = ignore_punctuation
52 | self.num_words = num_words
53 | self.stopwords = stopwords
54 | self.labeldict = labeldict
55 | self.bos = bos
56 | self.eos = eos
57 |
58 | def read_data(self, filepath):
59 | """
60 | Read the premises, hypotheses and labels from some NLI dataset's
61 | file and return them in a dictionary. The file should be in the same
62 | form as SNLI's .txt files.
63 |
64 | Args:
65 | filepath: The path to a file containing some premises, hypotheses
66 | and labels that must be read. The file should be formatted in
67 | the same way as the SNLI (and MultiNLI) dataset.
68 |
69 | Returns:
70 | A dictionary containing three lists, one for the premises, one for
71 | the hypotheses, and one for the labels in the input data.
72 | """
73 | with open(filepath, 'r', encoding='utf8') as input_data:
74 | ids, premises, hypotheses, labels = [], [], [], []
75 |
76 | # Translation tables to remove parentheses and punctuation from
77 | # strings.
78 | parentheses_table = str.maketrans({'(': None, ')': None})
79 | punct_table = str.maketrans({key: ' '
80 | for key in string.punctuation})
81 |
82 | # Ignore the headers on the first line of the file.
83 | next(input_data)
84 |
85 | for line in input_data:
86 | line = line.strip().split('\t')
87 |
88 | # Ignore sentences that have no gold label.
89 | if line[0] == '-':
90 | continue
91 |
92 | pair_id = line[7]
93 | premise = line[1]
94 | hypothesis = line[2]
95 |
96 | # Remove '(' and ')' from the premises and hypotheses.
97 | premise = premise.translate(parentheses_table)
98 | hypothesis = hypothesis.translate(parentheses_table)
99 |
100 | if self.lowercase:
101 | premise = premise.lower()
102 | hypothesis = hypothesis.lower()
103 |
104 | if self.ignore_punctuation:
105 | premise = premise.translate(punct_table)
106 | hypothesis = hypothesis.translate(punct_table)
107 |
108 | # Each premise and hypothesis is split into a list of words.
109 | premises.append([w for w in premise.rstrip().split()
110 | if w not in self.stopwords])
111 | hypotheses.append([w for w in hypothesis.rstrip().split()
112 | if w not in self.stopwords])
113 | labels.append(line[0])
114 | ids.append(pair_id)
115 |
116 | return {"ids": ids,
117 | "premises": premises,
118 | "hypotheses": hypotheses,
119 | "labels": labels}
120 |
121 | def build_worddict(self, data):
122 | """
123 | Build a dictionary associating words to unique integer indices for
124 | some dataset. The worddict can then be used to transform the words
125 | in datasets to their indices.
126 |
127 | Args:
128 | data: A dictionary containing the premises, hypotheses and
129 | labels of some NLI dataset, in the format returned by the
130 | 'read_data' method of the Preprocessor class.
131 | """
132 | words = []
133 | [words.extend(sentence) for sentence in data['premises']]
134 | [words.extend(sentence) for sentence in data['hypotheses']]
135 |
136 | counts = Counter(words)
137 | num_words = self.num_words
138 | if self.num_words is None:
139 | num_words = len(counts)
140 |
141 | self.worddict = {}
142 |
143 | # Special indices are used for padding, out-of-vocabulary words, and
144 | # beginning and end of sentence tokens.
145 | self.worddict["_PAD_"] = 0
146 | self.worddict["_OOV_"] = 1
147 |
148 | offset = 2
149 | if self.bos:
150 | self.worddict["_BOS_"] = 2
151 | offset += 1
152 | if self.eos:
153 | self.worddict["_EOS_"] = 3
154 | offset += 1
155 |
156 | for i, word in enumerate(counts.most_common(num_words)):
157 | self.worddict[word[0]] = i + offset
158 |
159 | if self.labeldict == {}:
160 | label_names = set(data['labels'])
161 | self.labeldict = {label_name: i
162 | for i, label_name in enumerate(label_names)}
163 |
164 | def words_to_indices(self, sentence):
165 | """
166 | Transform the words in a sentence to their corresponding integer
167 | indices.
168 |
169 | Args:
170 | sentence: A list of words that must be transformed to indices.
171 |
172 | Returns:
173 | A list of indices.
174 | """
175 | indices = []
176 | # Include the beginning of sentence token at the start of the sentence
177 | # if one is defined.
178 | if self.bos:
179 | indices.append(self.worddict["_BOS_"])
180 |
181 | for word in sentence:
182 | if word in self.worddict:
183 | index = self.worddict[word]
184 | else:
185 | # Words absent from 'worddict' are treated as a special
186 | # out-of-vocabulary word (OOV).
187 | index = self.worddict['_OOV_']
188 | indices.append(index)
189 | # Add the end of sentence token at the end of the sentence if one
190 | # is defined.
191 | if self.eos:
192 | indices.append(self.worddict["_EOS_"])
193 |
194 | return indices
195 |
196 | def indices_to_words(self, indices):
197 | """
198 | Transform the indices in a list to their corresponding words in
199 | the object's worddict.
200 |
201 | Args:
202 | indices: A list of integer indices corresponding to words in
203 | the Preprocessor's worddict.
204 |
205 | Returns:
206 | A list of words.
207 | """
208 | return [list(self.worddict.keys())[list(self.worddict.values())
209 | .index(i)]
210 | for i in indices]
211 |
212 | def transform_to_indices(self, data):
213 | """
214 | Transform the words in the premises and hypotheses of a dataset, as
215 | well as their associated labels, to integer indices.
216 |
217 | Args:
218 | data: A dictionary containing lists of premises, hypotheses
219 | and labels, in the format returned by the 'read_data'
220 | method of the Preprocessor class.
221 |
222 | Returns:
223 | A dictionary containing the transformed premises, hypotheses and
224 | labels.
225 | """
226 | transformed_data = {"ids": [],
227 | "premises": [],
228 | "hypotheses": [],
229 | "labels": []}
230 |
231 | for i, premise in enumerate(data['premises']):
232 | # Ignore sentences that have a label for which no index was
233 | # defined in 'labeldict'.
234 | label = data["labels"][i]
235 | if label not in self.labeldict and label != "hidden":
236 | continue
237 |
238 | transformed_data["ids"].append(data["ids"][i])
239 |
240 | if label == "hidden":
241 | transformed_data["labels"].append(-1)
242 | else:
243 | transformed_data["labels"].append(self.labeldict[label])
244 |
245 | indices = self.words_to_indices(premise)
246 | transformed_data["premises"].append(indices)
247 |
248 | indices = self.words_to_indices(data["hypotheses"][i])
249 | transformed_data["hypotheses"].append(indices)
250 |
251 | return transformed_data
252 |
253 | def build_embedding_matrix(self, embeddings_file):
254 | """
255 |         Build an embedding matrix with pretrained weights for the object's
256 | worddict.
257 |
258 | Args:
259 | embeddings_file: A file containing pretrained word embeddings.
260 |
261 | Returns:
262 | A numpy matrix of size (num_words+n_special_tokens, embedding_dim)
263 | containing pretrained word embeddings (the +n_special_tokens is for
264 | the padding and out-of-vocabulary tokens, as well as BOS and EOS if
265 | they're used).
266 | """
267 |         # Load the word embeddings in a dictionary.
268 | embeddings = {}
269 | with open(embeddings_file, 'r', encoding='utf8') as input_data:
270 | for line in input_data:
271 | line = line.split()
272 |
273 | try:
274 | # Check that the second element on the line is the start
275 | # of the embedding and not another word. Necessary to
276 | # ignore multiple word lines.
277 | float(line[1])
278 | word = line[0]
279 | if word in self.worddict:
280 | embeddings[word] = line[1:]
281 |
282 | # Ignore lines corresponding to multiple words separated
283 | # by spaces.
284 | except ValueError:
285 | continue
286 |
287 | num_words = len(self.worddict)
288 | embedding_dim = len(list(embeddings.values())[0])
289 | embedding_matrix = np.zeros((num_words, embedding_dim))
290 |
291 | # Actual building of the embedding matrix.
292 | missed = 0
293 | for word, i in self.worddict.items():
294 | if word in embeddings:
295 | embedding_matrix[i] = np.array(embeddings[word], dtype=float)
296 | else:
297 | if word == "_PAD_":
298 | continue
299 | missed += 1
300 | # Out of vocabulary words are initialised with random gaussian
301 | # samples.
302 | embedding_matrix[i] = np.random.normal(size=(embedding_dim))
303 | print("Missed words: ", missed)
304 |
305 | return embedding_matrix
306 |
307 |
308 | class NLIDataset(Dataset):
309 | """
310 | Dataset class for Natural Language Inference datasets.
311 |
312 | The class can be used to read preprocessed datasets where the premises,
313 | hypotheses and labels have been transformed to unique integer indices
314 | (this can be done with the 'preprocess_data' script in the 'scripts'
315 | folder of this repository).
316 | """
317 |
318 | def __init__(self,
319 | data,
320 | padding_idx=0,
321 | max_premise_length=None,
322 | max_hypothesis_length=None):
323 | """
324 | Args:
325 | data: A dictionary containing the preprocessed premises,
326 | hypotheses and labels of some dataset.
327 | padding_idx: An integer indicating the index being used for the
328 | padding token in the preprocessed data. Defaults to 0.
329 | max_premise_length: An integer indicating the maximum length
330 | accepted for the sequences in the premises. If set to None,
331 | the length of the longest premise in 'data' is used.
332 | Defaults to None.
333 | max_hypothesis_length: An integer indicating the maximum length
334 | accepted for the sequences in the hypotheses. If set to None,
335 | the length of the longest hypothesis in 'data' is used.
336 | Defaults to None.
337 | """
338 | self.premises_lengths = [len(seq) for seq in data["premises"]]
339 | self.max_premise_length = max_premise_length
340 | if self.max_premise_length is None:
341 | self.max_premise_length = max(self.premises_lengths)
342 |
343 | self.hypotheses_lengths = [len(seq) for seq in data["hypotheses"]]
344 | self.max_hypothesis_length = max_hypothesis_length
345 | if self.max_hypothesis_length is None:
346 | self.max_hypothesis_length = max(self.hypotheses_lengths)
347 |
348 | self.num_sequences = len(data["premises"])
349 |
350 | self.data = {"ids": [],
351 | "premises": torch.ones((self.num_sequences,
352 | self.max_premise_length),
353 | dtype=torch.long) * padding_idx,
354 | "hypotheses": torch.ones((self.num_sequences,
355 | self.max_hypothesis_length),
356 | dtype=torch.long) * padding_idx,
357 | "labels": torch.tensor(data["labels"], dtype=torch.long)}
358 |
359 | for i, premise in enumerate(data["premises"]):
360 | self.data["ids"].append(data["ids"][i])
361 | end = min(len(premise), self.max_premise_length)
362 | self.data["premises"][i][:end] = torch.tensor(premise[:end])
363 |
364 | hypothesis = data["hypotheses"][i]
365 | end = min(len(hypothesis), self.max_hypothesis_length)
366 | self.data["hypotheses"][i][:end] = torch.tensor(hypothesis[:end])
367 |
368 | def __len__(self):
369 | return self.num_sequences
370 |
371 | def __getitem__(self, index):
372 | return {"id": self.data["ids"][index],
373 | "premise": self.data["premises"][index],
374 | "premise_length": min(self.premises_lengths[index],
375 | self.max_premise_length),
376 | "hypothesis": self.data["hypotheses"][index],
377 | "hypothesis_length": min(self.hypotheses_lengths[index],
378 | self.max_hypothesis_length),
379 | "label": self.data["labels"][index]}
380 |
--------------------------------------------------------------------------------
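
A minimal sketch of how the Preprocessor and NLIDataset classes defined above fit together, using a tiny hand-built data dictionary in place of the output of 'read_data'. The sentences, labels and pair ids are purely illustrative, and the constructor arguments mirror the defaults passed by the preprocessing scripts further below:

from torch.utils.data import DataLoader

from esim.data import Preprocessor, NLIDataset

# Toy dataset in the format described by the docstrings above
# (keys 'ids', 'premises', 'hypotheses' and 'labels').
data = {"ids": ["pair_0", "pair_1"],
        "premises": [["a", "man", "is", "running"],
                     ["two", "dogs", "are", "playing"]],
        "hypotheses": [["a", "person", "is", "moving"],
                       ["the", "dogs", "are", "sleeping"]],
        "labels": ["entailment", "contradiction"]}

preprocessor = Preprocessor(lowercase=False, ignore_punctuation=False,
                            num_words=None, stopwords=[], labeldict={},
                            bos=None, eos=None)
preprocessor.build_worddict(data)             # fills worddict and labeldict
indexed = preprocessor.transform_to_indices(data)

dataset = NLIDataset(indexed)                 # pads premises and hypotheses
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for batch in loader:
    # Padded index tensors plus the true (unpadded) length of each sequence.
    print(batch["premise"].shape, batch["premise_length"],
          batch["hypothesis"].shape, batch["label"])
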
/ESIM/esim/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Definition of custom layers for the ESIM model.
3 | """
4 | # Aurelien Coet, 2018.
5 |
6 | import torch.nn as nn
7 |
8 | from .utils import sort_by_seq_lens, masked_softmax, weighted_sum
9 |
10 |
11 | # Class widely inspired from:
12 | # https://github.com/allenai/allennlp/blob/master/allennlp/modules/input_variational_dropout.py
13 | class RNNDropout(nn.Dropout):
14 | """
15 | Dropout layer for the inputs of RNNs.
16 |
17 | Apply the same dropout mask to all the elements of the same sequence in
18 | a batch of sequences of size (batch, sequences_length, embedding_dim).
19 | """
20 |
21 | def forward(self, sequences_batch):
22 | """
23 | Apply dropout to the input batch of sequences.
24 |
25 | Args:
26 | sequences_batch: A batch of sequences of vectors that will serve
27 | as input to an RNN.
28 |                 Tensor of size (batch, sequences_length, embedding_dim).
29 |
30 | Returns:
31 | A new tensor on which dropout has been applied.
32 | """
33 | ones = sequences_batch.data.new_ones(sequences_batch.shape[0],
34 | sequences_batch.shape[-1])
35 | dropout_mask = nn.functional.dropout(ones, self.p, self.training,
36 | inplace=False)
37 | return dropout_mask.unsqueeze(1) * sequences_batch
38 |
39 |
40 | class Seq2SeqEncoder(nn.Module):
41 | """
42 | RNN taking variable length padded sequences of vectors as input and
43 | encoding them into padded sequences of vectors of the same length.
44 |
45 | This module is useful to handle batches of padded sequences of vectors
46 |     that have different lengths and that need to be passed through an RNN.
47 | The sequences are sorted in descending order of their lengths, packed,
48 | passed through the RNN, and the resulting sequences are then padded and
49 | permuted back to the original order of the input sequences.
50 | """
51 |
52 | def __init__(self,
53 | rnn_type,
54 | input_size,
55 | hidden_size,
56 | num_layers=1,
57 | bias=True,
58 | dropout=0.0,
59 | bidirectional=False):
60 | """
61 | Args:
62 | rnn_type: The type of RNN to use as encoder in the module.
63 | Must be a class inheriting from torch.nn.RNNBase
64 | (such as torch.nn.LSTM for example).
65 | input_size: The number of expected features in the input of the
66 | module.
67 | hidden_size: The number of features in the hidden state of the RNN
68 | used as encoder by the module.
69 | num_layers: The number of recurrent layers in the encoder of the
70 | module. Defaults to 1.
71 | bias: If False, the encoder does not use bias weights b_ih and
72 | b_hh. Defaults to True.
73 | dropout: If non-zero, introduces a dropout layer on the outputs
74 | of each layer of the encoder except the last one, with dropout
75 | probability equal to 'dropout'. Defaults to 0.0.
76 | bidirectional: If True, the encoder of the module is bidirectional.
77 | Defaults to False.
78 | """
79 | assert issubclass(rnn_type, nn.RNNBase),\
80 | "rnn_type must be a class inheriting from torch.nn.RNNBase"
81 |
82 | super(Seq2SeqEncoder, self).__init__()
83 |
84 | self.rnn_type = rnn_type
85 | self.input_size = input_size
86 | self.hidden_size = hidden_size
87 | self.num_layers = num_layers
88 | self.bias = bias
89 | self.dropout = dropout
90 | self.bidirectional = bidirectional
91 |
92 | self._encoder = rnn_type(input_size,
93 | hidden_size,
94 | num_layers=num_layers,
95 | bias=bias,
96 | batch_first=True,
97 | dropout=dropout,
98 | bidirectional=bidirectional)
99 |
100 | def forward(self, sequences_batch, sequences_lengths):
101 | """
102 | Args:
103 | sequences_batch: A batch of variable length sequences of vectors.
104 | The batch is assumed to be of size
105 | (batch, sequence, vector_dim).
106 | sequences_lengths: A 1D tensor containing the sizes of the
107 | sequences in the input batch.
108 |
109 | Returns:
110 | reordered_outputs: The outputs (hidden states) of the encoder for
111 | the sequences in the input batch, in the same order.
112 | """
113 | sorted_batch, sorted_lengths, _, restoration_idx =\
114 | sort_by_seq_lens(sequences_batch, sequences_lengths)
115 | packed_batch = nn.utils.rnn.pack_padded_sequence(sorted_batch,
116 | sorted_lengths,
117 | batch_first=True)
118 |
119 | outputs, _ = self._encoder(packed_batch, None)
120 |
121 | outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs,
122 | batch_first=True)
123 | reordered_outputs = outputs.index_select(0, restoration_idx)
124 |
125 | return reordered_outputs
126 |
127 |
128 | class SoftmaxAttention(nn.Module):
129 | """
130 | Attention layer taking premises and hypotheses encoded by an RNN as input
131 | and computing the soft attention between their elements.
132 |
133 | The dot product of the encoded vectors in the premises and hypotheses is
134 | first computed. The softmax of the result is then used in a weighted sum
135 | of the vectors of the premises for each element of the hypotheses, and
136 | conversely for the elements of the premises.
137 | """
138 |
139 | def forward(self,
140 | premise_batch,
141 | premise_mask,
142 | hypothesis_batch,
143 | hypothesis_mask):
144 | """
145 | Args:
146 | premise_batch: A batch of sequences of vectors representing the
147 | premises in some NLI task. The batch is assumed to have the
148 | size (batch, sequences, vector_dim).
149 | premise_mask: A mask for the sequences in the premise batch, to
150 | ignore padding data in the sequences during the computation of
151 | the attention.
152 | hypothesis_batch: A batch of sequences of vectors representing the
153 | hypotheses in some NLI task. The batch is assumed to have the
154 | size (batch, sequences, vector_dim).
155 | hypothesis_mask: A mask for the sequences in the hypotheses batch,
156 | to ignore padding data in the sequences during the computation
157 | of the attention.
158 |
159 | Returns:
160 | attended_premises: The sequences of attention vectors for the
161 | premises in the input batch.
162 | attended_hypotheses: The sequences of attention vectors for the
163 | hypotheses in the input batch.
164 | """
165 | # Dot product between premises and hypotheses in each sequence of
166 | # the batch.
167 | similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
168 | .contiguous())
169 |
170 | # Softmax attention weights.
171 | prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
172 | hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
173 | .contiguous(),
174 | premise_mask)
175 |
176 |         # Weighted sums of the hypotheses for the premises' attention,
177 | # and vice-versa for the attention of the hypotheses.
178 | attended_premises = weighted_sum(hypothesis_batch,
179 | prem_hyp_attn,
180 | premise_mask)
181 | attended_hypotheses = weighted_sum(premise_batch,
182 | hyp_prem_attn,
183 | hypothesis_mask)
184 |
185 | return attended_premises, attended_hypotheses
186 |
--------------------------------------------------------------------------------
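
A minimal sketch of the two main layers above on random tensors: a shared bidirectional LSTM Seq2SeqEncoder over two padded batches, followed by SoftmaxAttention between the encoded "premises" and "hypotheses". All sizes are illustrative, and the masks are built by hand here, whereas the full model derives them from the padded word-index tensors with get_mask:

import torch
import torch.nn as nn

from esim.layers import Seq2SeqEncoder, SoftmaxAttention

batch_size, premise_len, hypothesis_len, dim = 4, 7, 5, 32
premises = torch.randn(batch_size, premise_len, dim)
hypotheses = torch.randn(batch_size, hypothesis_len, dim)
premise_lengths = torch.tensor([7, 6, 4, 3])
hypothesis_lengths = torch.tensor([5, 5, 3, 2])

# Encode both sides with a shared BiLSTM (output size is 2 * hidden_size).
encoder = Seq2SeqEncoder(nn.LSTM, input_size=dim, hidden_size=64,
                         bidirectional=True)
encoded_premises = encoder(premises, premise_lengths)         # (4, 7, 128)
encoded_hypotheses = encoder(hypotheses, hypothesis_lengths)  # (4, 5, 128)

# Binary masks marking the non-padded positions of each sequence.
premise_mask = (torch.arange(premise_len)[None, :]
                < premise_lengths[:, None]).float()
hypothesis_mask = (torch.arange(hypothesis_len)[None, :]
                   < hypothesis_lengths[:, None]).float()

attention = SoftmaxAttention()
attended_premises, attended_hypotheses = attention(
    encoded_premises, premise_mask, encoded_hypotheses, hypothesis_mask)
print(attended_premises.shape, attended_hypotheses.shape)  # same as encoded
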
/ESIM/esim/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Definition of the ESIM model.
3 | """
4 | # Aurelien Coet, 2018.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention
10 | from .utils import get_mask, replace_masked
11 |
12 |
13 | class ESIM(nn.Module):
14 | """
15 | Implementation of the ESIM model presented in the paper "Enhanced LSTM for
16 | Natural Language Inference" by Chen et al.
17 | """
18 |
19 | def __init__(self,
20 | vocab_size,
21 | embedding_dim,
22 | hidden_size,
23 | embeddings=None,
24 | padding_idx=0,
25 | dropout=0.5,
26 | num_classes=3,
27 | device="cpu"):
28 | """
29 | Args:
30 | vocab_size: The size of the vocabulary of embeddings in the model.
31 | embedding_dim: The dimension of the word embeddings.
32 | hidden_size: The size of all the hidden layers in the network.
33 | embeddings: A tensor of size (vocab_size, embedding_dim) containing
34 | pretrained word embeddings. If None, word embeddings are
35 | initialised randomly. Defaults to None.
36 | padding_idx: The index of the padding token in the premises and
37 | hypotheses passed as input to the model. Defaults to 0.
38 | dropout: The dropout rate to use between the layers of the network.
39 | A dropout rate of 0 corresponds to using no dropout at all.
40 | Defaults to 0.5.
41 | num_classes: The number of classes in the output of the network.
42 | Defaults to 3.
43 | device: The name of the device on which the model is being
44 | executed. Defaults to 'cpu'.
45 | """
46 | super(ESIM, self).__init__()
47 |
48 | self.vocab_size = vocab_size
49 | self.embedding_dim = embedding_dim
50 | self.hidden_size = hidden_size
51 | self.num_classes = num_classes
52 | self.dropout = dropout
53 | self.device = device
54 |
55 | self._word_embedding = nn.Embedding(self.vocab_size,
56 | self.embedding_dim,
57 | padding_idx=padding_idx,
58 | _weight=embeddings)
59 |
60 | if self.dropout:
61 | self._rnn_dropout = RNNDropout(p=self.dropout)
62 | # self._rnn_dropout = nn.Dropout(p=self.dropout)
63 |
64 | self._encoding = Seq2SeqEncoder(nn.LSTM,
65 | self.embedding_dim,
66 | self.hidden_size,
67 | bidirectional=True)
68 |
69 | self._attention = SoftmaxAttention()
70 |
71 | self._projection = nn.Sequential(nn.Linear(4*2*self.hidden_size,
72 | self.hidden_size),
73 | nn.ReLU())
74 |
75 | self._composition = Seq2SeqEncoder(nn.LSTM,
76 | self.hidden_size,
77 | self.hidden_size,
78 | bidirectional=True)
79 |
80 | self._classification = nn.Sequential(nn.Dropout(p=self.dropout),
81 | nn.Linear(2*4*self.hidden_size,
82 | self.hidden_size),
83 | nn.Tanh(),
84 | nn.Dropout(p=self.dropout),
85 | nn.Linear(self.hidden_size,
86 | self.num_classes))
87 |
88 | # Initialize all weights and biases in the model.
89 | self.apply(_init_esim_weights)
90 |
91 | def forward(self,
92 | premises,
93 | premises_lengths,
94 | hypotheses,
95 | hypotheses_lengths):
96 | """
97 | Args:
98 |             premises: A batch of variable length sequences of word indices
99 | representing premises. The batch is assumed to be of size
100 | (batch, premises_length).
101 | premises_lengths: A 1D tensor containing the lengths of the
102 | premises in 'premises'.
103 |             hypotheses: A batch of variable length sequences of word indices
104 | representing hypotheses. The batch is assumed to be of size
105 | (batch, hypotheses_length).
106 | hypotheses_lengths: A 1D tensor containing the lengths of the
107 | hypotheses in 'hypotheses'.
108 |
109 | Returns:
110 | logits: A tensor of size (batch, num_classes) containing the
111 | logits for each output class of the model.
112 | probabilities: A tensor of size (batch, num_classes) containing
113 | the probabilities of each output class in the model.
114 | """
115 | premises_mask = get_mask(premises, premises_lengths).to(self.device)
116 | hypotheses_mask = get_mask(hypotheses, hypotheses_lengths)\
117 | .to(self.device)
118 |
119 | embedded_premises = self._word_embedding(premises)
120 | embedded_hypotheses = self._word_embedding(hypotheses)
121 |
122 | if self.dropout:
123 | embedded_premises = self._rnn_dropout(embedded_premises)
124 | embedded_hypotheses = self._rnn_dropout(embedded_hypotheses)
125 |
126 | encoded_premises = self._encoding(embedded_premises,
127 | premises_lengths)
128 | encoded_hypotheses = self._encoding(embedded_hypotheses,
129 | hypotheses_lengths)
130 |
131 | attended_premises, attended_hypotheses =\
132 | self._attention(encoded_premises, premises_mask,
133 | encoded_hypotheses, hypotheses_mask)
134 |
135 | enhanced_premises = torch.cat([encoded_premises,
136 | attended_premises,
137 | encoded_premises - attended_premises,
138 | encoded_premises * attended_premises],
139 | dim=-1)
140 | enhanced_hypotheses = torch.cat([encoded_hypotheses,
141 | attended_hypotheses,
142 | encoded_hypotheses -
143 | attended_hypotheses,
144 | encoded_hypotheses *
145 | attended_hypotheses],
146 | dim=-1)
147 |
148 | projected_premises = self._projection(enhanced_premises)
149 | projected_hypotheses = self._projection(enhanced_hypotheses)
150 |
151 | if self.dropout:
152 | projected_premises = self._rnn_dropout(projected_premises)
153 | projected_hypotheses = self._rnn_dropout(projected_hypotheses)
154 |
155 | v_ai = self._composition(projected_premises, premises_lengths)
156 | v_bj = self._composition(projected_hypotheses, hypotheses_lengths)
157 |
158 | v_a_avg = torch.sum(v_ai * premises_mask.unsqueeze(1)
159 | .transpose(2, 1), dim=1)\
160 | / torch.sum(premises_mask, dim=1, keepdim=True)
161 | v_b_avg = torch.sum(v_bj * hypotheses_mask.unsqueeze(1)
162 | .transpose(2, 1), dim=1)\
163 | / torch.sum(hypotheses_mask, dim=1, keepdim=True)
164 |
165 | v_a_max, _ = replace_masked(v_ai, premises_mask, -1e7).max(dim=1)
166 | v_b_max, _ = replace_masked(v_bj, hypotheses_mask, -1e7).max(dim=1)
167 |
168 | v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)
169 |
170 | logits = self._classification(v)
171 | probabilities = nn.functional.softmax(logits, dim=-1)
172 |
173 | return logits, probabilities
174 |
175 |
176 | def _init_esim_weights(module):
177 | """
178 | Initialise the weights of the ESIM model.
179 | """
180 | if isinstance(module, nn.Linear):
181 | nn.init.xavier_uniform_(module.weight.data)
182 | nn.init.constant_(module.bias.data, 0.0)
183 |
184 | elif isinstance(module, nn.LSTM):
185 | nn.init.xavier_uniform_(module.weight_ih_l0.data)
186 | nn.init.orthogonal_(module.weight_hh_l0.data)
187 | nn.init.constant_(module.bias_ih_l0.data, 0.0)
188 | nn.init.constant_(module.bias_hh_l0.data, 0.0)
189 | hidden_size = module.bias_hh_l0.data.shape[0] // 4
190 | module.bias_hh_l0.data[hidden_size:(2*hidden_size)] = 1.0
191 |
192 | if (module.bidirectional):
193 | nn.init.xavier_uniform_(module.weight_ih_l0_reverse.data)
194 | nn.init.orthogonal_(module.weight_hh_l0_reverse.data)
195 | nn.init.constant_(module.bias_ih_l0_reverse.data, 0.0)
196 | nn.init.constant_(module.bias_hh_l0_reverse.data, 0.0)
197 | module.bias_hh_l0_reverse.data[hidden_size:(2*hidden_size)] = 1.0
198 |
--------------------------------------------------------------------------------
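
A minimal sketch of a forward pass through the ESIM module defined above on randomly generated index sequences. The vocabulary size, dimensions and lengths are illustrative; real batches come from NLIDataset, which pads with index 0 exactly as assumed here:

import torch

from esim.model import ESIM

vocab_size, embedding_dim, hidden_size, num_classes = 100, 50, 64, 3
model = ESIM(vocab_size, embedding_dim, hidden_size,
             num_classes=num_classes, device="cpu")

batch_size, premise_len, hypothesis_len = 4, 9, 6
premises = torch.randint(1, vocab_size, (batch_size, premise_len))
hypotheses = torch.randint(1, vocab_size, (batch_size, hypothesis_len))
premises_lengths = torch.tensor([9, 8, 5, 4])
hypotheses_lengths = torch.tensor([6, 6, 3, 2])

# Zero out positions beyond each sequence's length so the padding index (0)
# is consistent with the lengths, as it would be in NLIDataset batches.
for i in range(batch_size):
    premises[i, premises_lengths[i].item():] = 0
    hypotheses[i, hypotheses_lengths[i].item():] = 0

logits, probabilities = model(premises, premises_lengths,
                              hypotheses, hypotheses_lengths)
print(logits.shape)              # torch.Size([4, 3])
print(probabilities.sum(dim=1))  # each row sums to 1
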
/ESIM/esim/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for the ESIM model.
3 | """
4 | # Aurelien Coet, 2018.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | # Code widely inspired from:
11 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
12 | def sort_by_seq_lens(batch, sequences_lengths, descending=True):
13 | """
14 | Sort a batch of padded variable length sequences by length.
15 |
16 | Args:
17 | batch: A batch of padded variable length sequences. The batch should
18 | have the dimensions (batch_size x max_sequence_length x *).
19 | sequences_lengths: A tensor containing the lengths of the sequences in the
20 | input batch. The tensor should be of size (batch_size).
21 | descending: A boolean value indicating whether to sort the sequences
22 | by their lengths in descending order. Defaults to True.
23 |
24 | Returns:
25 | sorted_batch: A tensor containing the input batch reordered by
26 | sequences lengths.
27 | sorted_seq_lens: A tensor containing the sorted lengths of the
28 | sequences in the input batch.
29 | sorting_idx: A tensor containing the indices used to permute the input
30 | batch in order to get 'sorted_batch'.
31 | restoration_idx: A tensor containing the indices that can be used to
32 | restore the order of the sequences in 'sorted_batch' so that it
33 | matches the input batch.
34 | """
35 | sorted_seq_lens, sorting_index =\
36 | sequences_lengths.sort(0, descending=descending)
37 |
38 | sorted_batch = batch.index_select(0, sorting_index)
39 |
40 | idx_range =\
41 | sequences_lengths.new_tensor(torch.arange(0, len(sequences_lengths)))
42 | _, reverse_mapping = sorting_index.sort(0, descending=False)
43 | restoration_index = idx_range.index_select(0, reverse_mapping)
44 |
45 | return sorted_batch, sorted_seq_lens, sorting_index, restoration_index
46 |
47 |
48 | def get_mask(sequences_batch, sequences_lengths):
49 | """
50 | Get the mask for a batch of padded variable length sequences.
51 |
52 | Args:
53 | sequences_batch: A batch of padded variable length sequences
54 | containing word indices. Must be a 2-dimensional tensor of size
55 | (batch, sequence).
56 | sequences_lengths: A tensor containing the lengths of the sequences in
57 | 'sequences_batch'. Must be of size (batch).
58 |
59 | Returns:
60 | A mask of size (batch, max_sequence_length), where max_sequence_length
61 | is the length of the longest sequence in the batch.
62 | """
63 | batch_size = sequences_batch.size()[0]
64 | max_length = torch.max(sequences_lengths)
65 | mask = torch.ones(batch_size, max_length, dtype=torch.float)
66 | mask[sequences_batch[:, :max_length] == 0] = 0.0
67 | return mask
68 |
69 |
70 | # Code widely inspired from:
71 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
72 | def masked_softmax(tensor, mask):
73 | """
74 | Apply a masked softmax on the last dimension of a tensor.
75 | The input tensor and mask should be of size (batch, *, sequence_length).
76 |
77 | Args:
78 | tensor: The tensor on which the softmax function must be applied along
79 | the last dimension.
80 | mask: A mask of the same size as the tensor with 0s in the positions of
81 | the values that must be masked and 1s everywhere else.
82 |
83 | Returns:
84 | A tensor of the same size as the inputs containing the result of the
85 | softmax.
86 | """
87 | tensor_shape = tensor.size()
88 | reshaped_tensor = tensor.view(-1, tensor_shape[-1])
89 |
90 | # Reshape the mask so it matches the size of the input tensor.
91 | while mask.dim() < tensor.dim():
92 | mask = mask.unsqueeze(1)
93 | mask = mask.expand_as(tensor).contiguous().float()
94 | reshaped_mask = mask.view(-1, mask.size()[-1])
95 |
96 | result = nn.functional.softmax(reshaped_tensor * reshaped_mask, dim=-1)
97 | result = result * reshaped_mask
98 | # 1e-13 is added to avoid divisions by zero.
99 | result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
100 |
101 | return result.view(*tensor_shape)
102 |
103 |
104 | # Code widely inspired from:
105 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
106 | def weighted_sum(tensor, weights, mask):
107 | """
108 | Apply a weighted sum on the vectors along the last dimension of 'tensor',
109 | and mask the vectors in the result with 'mask'.
110 |
111 | Args:
112 | tensor: A tensor of vectors on which a weighted sum must be applied.
113 | weights: The weights to use in the weighted sum.
114 | mask: A mask to apply on the result of the weighted sum.
115 |
116 | Returns:
117 | A new tensor containing the result of the weighted sum after the mask
118 | has been applied on it.
119 | """
120 | weighted_sum = weights.bmm(tensor)
121 |
122 | while mask.dim() < weighted_sum.dim():
123 | mask = mask.unsqueeze(1)
124 | mask = mask.transpose(-1, -2)
125 | mask = mask.expand_as(weighted_sum).contiguous().float()
126 |
127 | return weighted_sum * mask
128 |
129 |
130 | # Code inspired from:
131 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
132 | def replace_masked(tensor, mask, value):
133 | """
134 |     Replace all the values of the vectors in 'tensor' that are masked in
135 |     'mask' with 'value'.
136 |
137 | Args:
138 | tensor: The tensor in which the masked vectors must have their values
139 | replaced.
140 | mask: A mask indicating the vectors which must have their values
141 | replaced.
142 | value: The value to place in the masked vectors of 'tensor'.
143 |
144 | Returns:
145 | A new tensor of the same size as 'tensor' where the values of the
146 | vectors masked in 'mask' were replaced by 'value'.
147 | """
148 | mask = mask.unsqueeze(1).transpose(2, 1)
149 | reverse_mask = 1.0 - mask
150 | values_to_add = value * reverse_mask
151 | return tensor * mask + values_to_add
152 |
153 |
154 | def correct_predictions(output_probabilities, targets):
155 | """
156 | Compute the number of predictions that match some target classes in the
157 | output of a model.
158 |
159 | Args:
160 | output_probabilities: A tensor of probabilities for different output
161 | classes.
162 | targets: The indices of the actual target classes.
163 |
164 | Returns:
165 | The number of correct predictions in 'output_probabilities'.
166 | """
167 | _, out_classes = output_probabilities.max(dim=1)
168 | correct = (out_classes == targets).sum()
169 | return correct.item()
170 |
--------------------------------------------------------------------------------
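
A small numerical sketch of the masking utilities above: get_mask builds a binary mask from a padded batch of word indices, masked_softmax normalises attention scores while giving padded positions zero weight, and weighted_sum applies the resulting weights. The indices and scores are illustrative:

import torch

from esim.utils import get_mask, masked_softmax, weighted_sum

# Two sequences of word indices; the second is padded with the index 0.
sequences = torch.tensor([[4, 2, 7],
                          [5, 3, 0]])
lengths = torch.tensor([3, 2])

mask = get_mask(sequences, lengths)
print(mask)
# tensor([[1., 1., 1.],
#         [1., 1., 0.]])

# Toy attention scores between 2 "query" positions and the 3 positions above.
scores = torch.randn(2, 2, 3)            # (batch, queries, sequence_length)
attn = masked_softmax(scores, mask)      # padded positions get zero weight
print(attn.sum(dim=-1))                  # every row of weights sums to ~1

values = torch.randn(2, 3, 8)            # the vectors being attended over
query_mask = torch.ones(2, 2)            # no padding on the query side here
context = weighted_sum(values, attn, query_mask)
print(context.shape)                     # torch.Size([2, 2, 8])
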
/ESIM/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jind11/TextFooler/6a56cb10f1055e00a8fbc6882f289e91cfe60f4e/ESIM/scripts/.DS_Store
--------------------------------------------------------------------------------
/ESIM/scripts/fetch_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Fetch datasets and pretrained word embeddings for the ESIM model.
3 |
4 | By default, the script downloads the following.
5 | - The SNLI corpus;
6 | - GloVe word embeddings (840B - 300d).
7 | """
8 | # Aurelien Coet, 2018.
9 |
10 | import os
11 | import argparse
12 | import zipfile
13 | import wget
14 |
15 |
16 | def download(url, targetdir):
17 | """
18 | Download a file and save it in some target directory.
19 |
20 | Args:
21 | url: The url from which the file must be downloaded.
22 | targetdir: The path to the directory where the file must be saved.
23 |
24 | Returns:
25 | The path to the downloaded file.
26 | """
27 | print("* Downloading data from {}...".format(url))
28 | filepath = os.path.join(targetdir, url.split('/')[-1])
29 | wget.download(url, filepath)
30 | return filepath
31 |
32 |
33 | def unzip(filepath):
34 | """
35 | Extract the data from a zipped file and delete the archive.
36 |
37 | Args:
38 | filepath: The path to the zipped file.
39 | """
40 | print("\n* Extracting: {}...".format(filepath))
41 | dirpath = os.path.dirname(filepath)
42 | with zipfile.ZipFile(filepath) as zf:
43 | for name in zf.namelist():
44 | # Ignore useless files in archives.
45 | if "__MACOSX" in name or\
46 | ".DS_Store" in name or\
47 | "Icon" in name:
48 | continue
49 | zf.extract(name, dirpath)
50 | # Delete the archive once the data has been extracted.
51 | os.remove(filepath)
52 |
53 |
54 | def download_unzip(url, targetdir):
55 | """
56 | Download and unzip data from some url and save it in a target directory.
57 |
58 | Args:
59 | url: The url to download the data from.
60 | targetdir: The target directory in which to download and unzip the
61 | data.
62 | """
63 | filepath = os.path.join(targetdir, url.split('/')[-1])
64 |
65 | if not os.path.exists(targetdir):
66 | print("* Creating target directory {}...".format(targetdir))
67 | os.makedirs(targetdir)
68 |
69 | # Download and unzip if the target directory is empty.
70 | if not os.listdir(targetdir):
71 | unzip(download(url, targetdir))
72 | # Skip downloading if the zipped data is already available.
73 | elif os.path.exists(filepath):
74 | print("* Found zipped data - skipping download...")
75 | unzip(filepath)
76 | # Skip download and unzipping if the unzipped data is already available.
77 | else:
78 | print("* Found unzipped data for {}, skipping download and unzip..."
79 | .format(targetdir))
80 |
81 |
82 | if __name__ == "__main__":
83 | # Default data.
84 | snli_url = "https://nlp.stanford.edu/projects/snli/snli_1.0.zip"
85 | glove_url = "http://www-nlp.stanford.edu/data/glove.840B.300d.zip"
86 |
87 |     parser = argparse.ArgumentParser(description='Download an NLI dataset and pretrained embeddings')
88 | parser.add_argument('--dataset_url',
89 | default=snli_url,
90 | help='URL of the dataset to download')
91 | parser.add_argument('--embeddings_url',
92 | default=glove_url,
93 | help='URL of the pretrained embeddings to download')
94 | parser.add_argument('--target_dir',
95 | default=os.path.join('..', 'data'),
96 | help='Path to a directory where data must be saved')
97 | args = parser.parse_args()
98 |
99 | if not os.path.exists(args.target_dir):
100 | os.makedirs(args.target_dir)
101 |
102 | print(20*'=', "Fetching the dataset:", 20*'=')
103 | download_unzip(args.dataset_url, os.path.join(args.target_dir, "dataset"))
104 |
105 | print(20*'=', "Fetching the word embeddings:", 20*'=')
106 | download_unzip(args.embeddings_url,
107 | os.path.join(args.target_dir, "embeddings"))
108 |
--------------------------------------------------------------------------------
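
The same fetching step can also be driven programmatically rather than through the command line; a minimal sketch, assuming the script's directory is on the Python path and reusing the script's default URLs (the target directory is illustrative):

import os

from fetch_data import download_unzip

target_dir = os.path.join("..", "data")
download_unzip("https://nlp.stanford.edu/projects/snli/snli_1.0.zip",
               os.path.join(target_dir, "dataset"))
download_unzip("http://www-nlp.stanford.edu/data/glove.840B.300d.zip",
               os.path.join(target_dir, "embeddings"))
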
/ESIM/scripts/preprocessing/preprocess_bnli.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the Breaking NLI data set.
3 | """
4 |
5 | import os
6 | import json
7 | import pickle
8 |
9 | from nltk import word_tokenize
10 |
11 | from esim.data import Preprocessor
12 |
13 |
14 | def jsonl_to_txt(input_file, output_file):
15 | """
16 | Transform the Breaking NLI data from a jsonl file to .txt for
17 | further processing.
18 |
19 | Args:
20 | input_file: The path to the Breaking NLI data set in jsonl format.
21 |         output_file: The path to the .txt file where the transformed data must
22 | be saved.
23 | """
24 | with open(input_file, 'r') as input_f, open(output_file, 'w') as output_f:
25 | output_f.write("label\tsentence1\tsentence2\t\t\t\t\t\tpairID\n")
26 |
27 | for line in input_f:
28 | data = json.loads(line)
29 |
30 | # Sentences in the Breaking NLI data set aren't distributed in the
31 | # form of binary parses, so we must tokenise them with nltk.
32 | sentence1 = word_tokenize(data['sentence1'])
33 | sentence1 = " ".join(sentence1)
34 | sentence2 = word_tokenize(data['sentence2'])
35 | sentence2 = " ".join(sentence2)
36 |
37 | # The 5 tabs between sentence 2 and the pairID are added to
38 | # follow the same structure as the txt files in SNLI and MNLI.
39 | output_f.write(data['gold_label'] + "\t" + sentence1 + "\t" +
40 | sentence2 + "\t\t\t\t\t" + str(data['pairID']) +
41 | "\n")
42 |
43 |
44 | def preprocess_BNLI_data(input_file,
45 | targetdir,
46 | worddict,
47 | labeldict):
48 | """
49 | Preprocess the BNLI data set so it can be used to test a model trained
50 | on SNLI.
51 |
52 | Args:
53 |         input_file: The path to the file containing the Breaking NLI (BNLI) data.
54 |         targetdir: The path to the directory where the preprocessed Breaking
55 | NLI data must be saved.
56 | worddict: The path to the pickled worddict used for preprocessing the
57 | training data on which models were trained before being tested on
58 | BNLI.
59 | labeldict: The dict of labels used for the training data on which
60 | models were trained before being tested on BNLI.
61 | """
62 | if not os.path.exists(targetdir):
63 | os.makedirs(targetdir)
64 |
65 | output_file = os.path.join(targetdir, "bnli.txt")
66 |
67 | print(20*"=", " Preprocessing Breaking NLI data set ", 20*"=")
68 |     print("\t* Transforming jsonl data to txt...")
69 | jsonl_to_txt(input_file, output_file)
70 |
71 | preprocessor = Preprocessor(labeldict=labeldict)
72 |
73 | with open(worddict, 'rb') as pkl:
74 | wdict = pickle.load(pkl)
75 | preprocessor.worddict = wdict
76 |
77 | print("\t* Reading txt data...")
78 | data = preprocessor.read_data(output_file)
79 |
80 | print("\t* Transforming words in premises and hypotheses to indices...")
81 | transformed_data = preprocessor.transform_to_indices(data)
82 |
83 | print("\t* Saving result...")
84 | with open(os.path.join(targetdir, "bnli_data.pkl"), 'wb') as pkl_file:
85 | pickle.dump(transformed_data, pkl_file)
86 |
87 |
88 | if __name__ == "__main__":
89 | import argparse
90 |
91 | parser = argparse.ArgumentParser(description='Preprocess the Breaking\
92 | NLI (BNLI) dataset')
93 | parser.add_argument('--config',
94 | default="../config/preprocessing/bnli_preprocessing.json",
95 | help='Path to a configuration file for preprocessing BNLI')
96 | args = parser.parse_args()
97 |
98 | with open(os.path.normpath(args.config), 'r') as cfg_file:
99 | config = json.load(cfg_file)
100 |
101 | preprocess_BNLI_data(os.path.normpath(config["data_file"]),
102 | os.path.normpath(config["target_dir"]),
103 | os.path.normpath(config["worddict"]),
104 | config["labeldict"])
105 |
--------------------------------------------------------------------------------
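
For reference, a minimal sketch of the tab-separated line that 'jsonl_to_txt' above writes for a single Breaking NLI record (the label, tokenised sentences and pairID are illustrative); the five padding tabs leave empty columns so that the layout lines up with the SNLI/MultiNLI .txt files:

label = "contradiction"
sentence1 = "A man is eating an apple ."
sentence2 = "A man is eating a pear ."
pair_id = "42"

# Same concatenation as in jsonl_to_txt above.
line = label + "\t" + sentence1 + "\t" + sentence2 + "\t\t\t\t\t" + pair_id + "\n"
print(repr(line))
# 'contradiction\tA man is eating an apple .\tA man is eating a pear .\t\t\t\t\t42\n'
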
/ESIM/scripts/preprocessing/preprocess_mnli.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the MultiNLI dataset and word embeddings to be used by the
3 | ESIM model.
4 | """
5 | # Aurelien Coet, 2019.
6 |
7 | import os
8 | import pickle
9 | import fnmatch
10 | import json
11 |
12 | from esim.data import Preprocessor
13 |
14 |
15 | def preprocess_MNLI_data(inputdir,
16 | embeddings_file,
17 | targetdir,
18 | lowercase=False,
19 | ignore_punctuation=False,
20 | num_words=None,
21 | stopwords=[],
22 | labeldict={},
23 | bos=None,
24 | eos=None):
25 | """
26 | Preprocess the data from the MultiNLI corpus so it can be used by the
27 | ESIM model.
28 | Compute a worddict from the train set, and transform the words in
29 | the sentences of the corpus to their indices, as well as the labels.
30 | Build an embedding matrix from pretrained word vectors.
31 | The preprocessed data is saved in pickled form in some target directory.
32 |
33 | Args:
34 | inputdir: The path to the directory containing the NLI corpus.
35 | embeddings_file: The path to the file containing the pretrained
36 | word vectors that must be used to build the embedding matrix.
37 | targetdir: The path to the directory where the preprocessed data
38 | must be saved.
39 | lowercase: Boolean value indicating whether to lowercase the premises
40 |             and hypotheses in the input data. Defaults to False.
41 | ignore_punctuation: Boolean value indicating whether to remove
42 | punctuation from the input data. Defaults to False.
43 | num_words: Integer value indicating the size of the vocabulary to use
44 | for the word embeddings. If set to None, all words are kept.
45 | Defaults to None.
46 | stopwords: A list of words that must be ignored when preprocessing
47 | the data. Defaults to an empty list.
48 | bos: A string indicating the symbol to use for beginning of sentence
49 | tokens. If set to None, bos tokens aren't used. Defaults to None.
50 | eos: A string indicating the symbol to use for end of sentence tokens.
51 | If set to None, eos tokens aren't used. Defaults to None.
52 | """
53 | if not os.path.exists(targetdir):
54 | os.makedirs(targetdir)
55 |
56 | # Retrieve the train, dev and test data files from the dataset directory.
57 | train_file = ""
58 | matched_dev_file = ""
59 | mismatched_dev_file = ""
60 | matched_test_file = ""
61 | mismatched_test_file = ""
62 | for file in os.listdir(inputdir):
63 | if fnmatch.fnmatch(file, '*_train.txt'):
64 | train_file = file
65 | elif fnmatch.fnmatch(file, '*_dev_matched.txt'):
66 | matched_dev_file = file
67 | elif fnmatch.fnmatch(file, '*_dev_mismatched.txt'):
68 | mismatched_dev_file = file
69 | elif fnmatch.fnmatch(file, '*_test_matched_unlabeled.txt'):
70 | matched_test_file = file
71 | elif fnmatch.fnmatch(file, '*_test_mismatched_unlabeled.txt'):
72 | mismatched_test_file = file
73 |
74 | # -------------------- Train data preprocessing -------------------- #
75 | preprocessor = Preprocessor(lowercase=lowercase,
76 | ignore_punctuation=ignore_punctuation,
77 | num_words=num_words,
78 | stopwords=stopwords,
79 | labeldict=labeldict,
80 | bos=bos,
81 | eos=eos)
82 |
83 | print(20*"=", " Preprocessing train set ", 20*"=")
84 | print("\t* Reading data...")
85 | data = preprocessor.read_data(os.path.join(inputdir, train_file))
86 |
87 | print("\t* Computing worddict and saving it...")
88 | preprocessor.build_worddict(data)
89 | with open(os.path.join(targetdir, "worddict.pkl"), 'wb') as pkl_file:
90 | pickle.dump(preprocessor.worddict, pkl_file)
91 |
92 | print("\t* Transforming words in premises and hypotheses to indices...")
93 | transformed_data = preprocessor.transform_to_indices(data)
94 | print("\t* Saving result...")
95 | with open(os.path.join(targetdir, "train_data.pkl"), 'wb') as pkl_file:
96 | pickle.dump(transformed_data, pkl_file)
97 |
98 | # -------------------- Validation data preprocessing -------------------- #
99 | print(20*"=", " Preprocessing dev sets ", 20*"=")
100 | print("\t* Reading matched dev data...")
101 | data = preprocessor.read_data(os.path.join(inputdir, matched_dev_file))
102 |
103 | print("\t* Transforming words in premises and hypotheses to indices...")
104 | transformed_data = preprocessor.transform_to_indices(data)
105 | print("\t* Saving result...")
106 | with open(os.path.join(targetdir, "matched_dev_data.pkl"), 'wb') as pkl_file:
107 | pickle.dump(transformed_data, pkl_file)
108 |
109 | print("\t* Reading mismatched dev data...")
110 | data = preprocessor.read_data(os.path.join(inputdir, mismatched_dev_file))
111 |
112 | print("\t* Transforming words in premises and hypotheses to indices...")
113 | transformed_data = preprocessor.transform_to_indices(data)
114 | print("\t* Saving result...")
115 | with open(os.path.join(targetdir, "mismatched_dev_data.pkl"), 'wb') as pkl_file:
116 | pickle.dump(transformed_data, pkl_file)
117 |
118 | # # -------------------- Test data preprocessing -------------------- #
119 | # print(20*"=", " Preprocessing test sets ", 20*"=")
120 | # print("\t* Reading matched test data...")
121 | # data = preprocessor.read_data(os.path.join(inputdir, matched_test_file))
122 | #
123 | # print("\t* Transforming words in premises and hypotheses to indices...")
124 | # transformed_data = preprocessor.transform_to_indices(data)
125 | # print("\t* Saving result...")
126 | # with open(os.path.join(targetdir, "matched_test_data.pkl"), 'wb') as pkl_file:
127 | # pickle.dump(transformed_data, pkl_file)
128 | #
129 | # print("\t* Reading mismatched test data...")
130 | # data = preprocessor.read_data(os.path.join(inputdir, mismatched_test_file))
131 | #
132 | # print("\t* Transforming words in premises and hypotheses to indices...")
133 | # transformed_data = preprocessor.transform_to_indices(data)
134 | # print("\t* Saving result...")
135 | # with open(os.path.join(targetdir, "mismatched_test_data.pkl"), 'wb') as pkl_file:
136 | # pickle.dump(transformed_data, pkl_file)
137 |
138 | # -------------------- Embeddings preprocessing -------------------- #
139 | print(20*"=", " Preprocessing embeddings ", 20*"=")
140 | print("\t* Building embedding matrix and saving it...")
141 | embed_matrix = preprocessor.build_embedding_matrix(embeddings_file)
142 | with open(os.path.join(targetdir, "embeddings.pkl"), 'wb') as pkl_file:
143 | pickle.dump(embed_matrix, pkl_file)
144 |
145 |
146 | if __name__ == "__main__":
147 | import argparse
148 |
149 | parser = argparse.ArgumentParser(description='Preprocess the MultiNLI dataset')
150 | parser.add_argument('--config',
151 | default="../config/preprocessing/mnli_preprocessing.json",
152 | help='Path to a configuration file for preprocessing MultiNLI')
153 | args = parser.parse_args()
154 |
155 | with open(os.path.normpath(args.config), 'r') as cfg_file:
156 | config = json.load(cfg_file)
157 |
158 | preprocess_MNLI_data(os.path.normpath(config["data_dir"]),
159 | os.path.normpath(config["embeddings_file"]),
160 | os.path.normpath(config["target_dir"]),
161 | lowercase=config["lowercase"],
162 | ignore_punctuation=config["ignore_punctuation"],
163 | num_words=config["num_words"],
164 | stopwords=config["stopwords"],
165 | labeldict=config["labeldict"],
166 | bos=config["bos"],
167 | eos=config["eos"])
168 |
--------------------------------------------------------------------------------
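
A sketch of the kind of configuration file the __main__ block above expects. The key names are the ones read by the script; the paths and the label-to-index mapping are illustrative and must match the local directory layout and the labels used at training time:

import json

config = {
    "data_dir": "../../data/dataset/multinli_1.0",
    "embeddings_file": "../../data/embeddings/glove.840B.300d.txt",
    "target_dir": "../../data/preprocessed/MNLI",
    "lowercase": False,
    "ignore_punctuation": False,
    "num_words": None,
    "stopwords": [],
    "labeldict": {"entailment": 0, "neutral": 1, "contradiction": 2},
    "bos": None,
    "eos": None
}

# json.dump converts None/False to null/false in the written file.
with open("mnli_preprocessing.json", "w") as cfg_file:
    json.dump(config, cfg_file, indent=4)
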
/ESIM/scripts/preprocessing/preprocess_snli.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the SNLI dataset and word embeddings to be used by the ESIM model.
3 | """
4 | # Aurelien Coet, 2018.
5 |
6 | import os
7 | import pickle
8 | import fnmatch
9 | import json
10 |
11 | from esim.data import Preprocessor
12 |
13 |
14 | def preprocess_SNLI_data(inputdir,
15 | embeddings_file,
16 | targetdir,
17 | lowercase=False,
18 | ignore_punctuation=False,
19 | num_words=None,
20 | stopwords=[],
21 | labeldict={},
22 | bos=None,
23 | eos=None):
24 | """
25 | Preprocess the data from the SNLI corpus so it can be used by the
26 | ESIM model.
27 | Compute a worddict from the train set, and transform the words in
28 | the sentences of the corpus to their indices, as well as the labels.
29 | Build an embedding matrix from pretrained word vectors.
30 | The preprocessed data is saved in pickled form in some target directory.
31 |
32 | Args:
33 | inputdir: The path to the directory containing the NLI corpus.
34 | embeddings_file: The path to the file containing the pretrained
35 | word vectors that must be used to build the embedding matrix.
36 | targetdir: The path to the directory where the preprocessed data
37 | must be saved.
38 | lowercase: Boolean value indicating whether to lowercase the premises
39 |             and hypotheses in the input data. Defaults to False.
40 | ignore_punctuation: Boolean value indicating whether to remove
41 | punctuation from the input data. Defaults to False.
42 | num_words: Integer value indicating the size of the vocabulary to use
43 | for the word embeddings. If set to None, all words are kept.
44 | Defaults to None.
45 | stopwords: A list of words that must be ignored when preprocessing
46 | the data. Defaults to an empty list.
47 | bos: A string indicating the symbol to use for beginning of sentence
48 | tokens. If set to None, bos tokens aren't used. Defaults to None.
49 | eos: A string indicating the symbol to use for end of sentence tokens.
50 | If set to None, eos tokens aren't used. Defaults to None.
51 | """
52 | if not os.path.exists(targetdir):
53 | os.makedirs(targetdir)
54 |
55 | # Retrieve the train, dev and test data files from the dataset directory.
56 | train_file = ""
57 | dev_file = ""
58 | test_file = ""
59 | for file in os.listdir(inputdir):
60 | if fnmatch.fnmatch(file, '*_train.txt'):
61 | train_file = file
62 | elif fnmatch.fnmatch(file, '*_dev.txt'):
63 | dev_file = file
64 | elif fnmatch.fnmatch(file, '*_test.txt'):
65 | test_file = file
66 |
67 | # -------------------- Train data preprocessing -------------------- #
68 | preprocessor = Preprocessor(lowercase=lowercase,
69 | ignore_punctuation=ignore_punctuation,
70 | num_words=num_words,
71 | stopwords=stopwords,
72 | labeldict=labeldict,
73 | bos=bos,
74 | eos=eos)
75 |
76 | print(20*"=", " Preprocessing train set ", 20*"=")
77 | print("\t* Reading data...")
78 | data = preprocessor.read_data(os.path.join(inputdir, train_file))
79 |
80 | print("\t* Computing worddict and saving it...")
81 | preprocessor.build_worddict(data)
82 | with open(os.path.join(targetdir, "worddict.pkl"), 'wb') as pkl_file:
83 | pickle.dump(preprocessor.worddict, pkl_file)
84 |
85 | print("\t* Transforming words in premises and hypotheses to indices...")
86 | transformed_data = preprocessor.transform_to_indices(data)
87 | print("\t* Saving result...")
88 | with open(os.path.join(targetdir, "train_data.pkl"), 'wb') as pkl_file:
89 | pickle.dump(transformed_data, pkl_file)
90 |
91 | # -------------------- Validation data preprocessing -------------------- #
92 | print(20*"=", " Preprocessing dev set ", 20*"=")
93 | print("\t* Reading data...")
94 | data = preprocessor.read_data(os.path.join(inputdir, dev_file))
95 |
96 | print("\t* Transforming words in premises and hypotheses to indices...")
97 | transformed_data = preprocessor.transform_to_indices(data)
98 | print("\t* Saving result...")
99 | with open(os.path.join(targetdir, "dev_data.pkl"), 'wb') as pkl_file:
100 | pickle.dump(transformed_data, pkl_file)
101 |
102 | # -------------------- Test data preprocessing -------------------- #
103 | print(20*"=", " Preprocessing test set ", 20*"=")
104 | print("\t* Reading data...")
105 | data = preprocessor.read_data(os.path.join(inputdir, test_file))
106 |
107 | print("\t* Transforming words in premises and hypotheses to indices...")
108 | transformed_data = preprocessor.transform_to_indices(data)
109 | print("\t* Saving result...")
110 | with open(os.path.join(targetdir, "test_data.pkl"), 'wb') as pkl_file:
111 | pickle.dump(transformed_data, pkl_file)
112 |
113 | # -------------------- Embeddings preprocessing -------------------- #
114 | print(20*"=", " Preprocessing embeddings ", 20*"=")
115 | print("\t* Building embedding matrix and saving it...")
116 | embed_matrix = preprocessor.build_embedding_matrix(embeddings_file)
117 | with open(os.path.join(targetdir, "embeddings.pkl"), 'wb') as pkl_file:
118 | pickle.dump(embed_matrix, pkl_file)
119 |
120 |
121 | if __name__ == "__main__":
122 | import argparse
123 |
124 | parser = argparse.ArgumentParser(description='Preprocess the SNLI dataset')
125 | parser.add_argument('--config',
126 | default="../config/preprocessing/snli_preprocessing.json",
127 | help='Path to a configuration file for preprocessing SNLI')
128 | args = parser.parse_args()
129 |
130 | with open(os.path.normpath(args.config), 'r') as cfg_file:
131 | config = json.load(cfg_file)
132 |
133 | preprocess_SNLI_data(os.path.normpath(config["data_dir"]),
134 | os.path.normpath(config["embeddings_file"]),
135 | os.path.normpath(config["target_dir"]),
136 | lowercase=config["lowercase"],
137 | ignore_punctuation=config["ignore_punctuation"],
138 | num_words=config["num_words"],
139 | stopwords=config["stopwords"],
140 | labeldict=config["labeldict"],
141 | bos=config["bos"],
142 | eos=config["eos"])
143 |
--------------------------------------------------------------------------------
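
A minimal sketch of loading the pickled artefacts that the preprocessing script above writes to its target directory, for instance to inspect the vocabulary or wrap the training data in an NLIDataset (the paths are illustrative):

import pickle

from esim.data import NLIDataset

with open("../../data/preprocessed/SNLI/worddict.pkl", "rb") as pkl:
    worddict = pickle.load(pkl)
with open("../../data/preprocessed/SNLI/train_data.pkl", "rb") as pkl:
    train_data = NLIDataset(pickle.load(pkl))

print(len(worddict), "words in the vocabulary,", len(train_data), "training pairs")
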
/ESIM/scripts/testing/test_mnli.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the ESIM model on the preprocessed MultiNLI dataset.
3 | """
4 | # Aurelien Coet, 2019.
5 |
6 | import os
7 | import pickle
8 | import argparse
9 | import torch
10 | import json
11 |
12 | from torch.utils.data import DataLoader
13 | from esim.data import NLIDataset
14 | from esim.model import ESIM
15 |
16 |
17 | def predict(model, dataloader, labeldict):
18 | """
19 | Predict the labels of an unlabelled test set with a pretrained model.
20 |
21 | Args:
22 | model: The torch module which must be used to make predictions.
23 | dataloader: A DataLoader object to iterate over some dataset.
24 | labeldict: A dictionary associating labels to integer values.
25 |
26 | Returns:
27 | A dictionary associating pair ids to predicted labels.
28 | """
29 | # Switch the model to eval mode.
30 | model.eval()
31 | device = model.device
32 |
33 |     # Invert the labeldict to associate integers with labels.
34 | labels = {index: label for label, index in labeldict.items()}
35 | predictions = {}
36 |
37 | # Deactivate autograd for evaluation.
38 | with torch.no_grad():
39 | for batch in dataloader:
40 |
41 | # Move input and output data to the GPU if one is used.
42 | ids = batch["id"]
43 | premises = batch['premise'].to(device)
44 | premises_lengths = batch['premise_length'].to(device)
45 | hypotheses = batch['hypothesis'].to(device)
46 | hypotheses_lengths = batch['hypothesis_length'].to(device)
47 |
48 | _, probs = model(premises,
49 | premises_lengths,
50 | hypotheses,
51 | hypotheses_lengths)
52 |
53 | _, preds = probs.max(dim=1)
54 |
55 | for i, pair_id in enumerate(ids):
56 | predictions[pair_id] = labels[int(preds[i])]
57 |
58 | return predictions
59 |
60 |
61 | def main(test_files, pretrained_file, labeldict, output_dir, batch_size=32):
62 | """
63 | Test the ESIM model with pretrained weights on the MultiNLI dataset.
64 |
65 | Args:
66 | test_files: The paths to the preprocessed matched and mismatched MNLI
67 | test sets.
68 | pretrained_file: The path to a checkpoint produced by the
69 | 'train_mnli' script.
70 | labeldict: A dictionary associating labels (classes) to integer values.
71 | output_dir: The path to a directory where the predictions of the model
72 | must be saved.
73 | batch_size: The size of the batches used for testing. Defaults to 32.
74 | """
75 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
76 |
77 | print(20 * "=", " Preparing for testing ", 20 * "=")
78 |
79 | output_dir = os.path.normpath(output_dir)
80 | if not os.path.exists(output_dir):
81 | os.makedirs(output_dir)
82 |
83 | checkpoint = torch.load(pretrained_file)
84 |
85 | # Retrieve model parameters from the checkpoint.
86 | vocab_size = checkpoint['model']['_word_embedding.weight'].size(0)
87 | embedding_dim = checkpoint['model']['_word_embedding.weight'].size(1)
88 | hidden_size = checkpoint['model']['_projection.0.weight'].size(0)
89 | num_classes = checkpoint['model']['_classification.4.weight'].size(0)
90 |
91 | print("\t* Loading test data...")
92 | with open(os.path.normpath(test_files["matched"]), 'rb') as pkl:
93 | matched_test_data = NLIDataset(pickle.load(pkl))
94 | with open(os.path.normpath(test_files["mismatched"]), 'rb') as pkl:
95 | mismatched_test_data = NLIDataset(pickle.load(pkl))
96 |
97 | matched_test_loader = DataLoader(matched_test_data,
98 | shuffle=False,
99 | batch_size=batch_size)
100 | mismatched_test_loader = DataLoader(mismatched_test_data,
101 | shuffle=False,
102 | batch_size=batch_size)
103 |
104 | print("\t* Building model...")
105 | model = ESIM(vocab_size,
106 | embedding_dim,
107 | hidden_size,
108 | num_classes=num_classes,
109 | device=device).to(device)
110 |
111 | model.load_state_dict(checkpoint['model'])
112 |
113 | print(20 * "=",
114 | " Prediction on MNLI with ESIM model on device: {} ".format(device),
115 | 20 * "=")
116 |
117 | print("\t* Prediction for matched test set...")
118 | predictions = predict(model, matched_test_loader, labeldict)
119 |
120 | with open(os.path.join(output_dir, "matched_predictions.csv"), 'w') as output_f:
121 | output_f.write("pairID,gold_label\n")
122 | for pair_id in predictions:
123 | output_f.write(pair_id+","+predictions[pair_id]+"\n")
124 |
125 | print("\t* Prediction for mismatched test set...")
126 | predictions = predict(model, mismatched_test_loader, labeldict)
127 |
128 | with open(os.path.join(output_dir, "mismatched_predictions.csv"), 'w') as output_f:
129 | output_f.write("pairID,gold_label\n")
130 | for pair_id in predictions:
131 | output_f.write(pair_id+","+predictions[pair_id]+"\n")
132 |
133 |
134 | if __name__ == "__main__":
135 | parser = argparse.ArgumentParser(description='Test the ESIM model on\
136 | the MNLI matched and mismatched test sets')
137 | parser.add_argument('checkpoint',
138 | help="Path to a checkpoint with a pretrained model")
139 | parser.add_argument('--config', default='../config/testing/mnli_testing.json',
140 | help='Path to a configuration file')
141 | args = parser.parse_args()
142 |
143 | with open(os.path.normpath(args.config), 'r') as config_file:
144 | config = json.load(config_file)
145 |
146 | main(config['test_files'],
147 | args.checkpoint,
148 | config['labeldict'],
149 | config['output_dir'],
150 | config['batch_size'])
151 |
--------------------------------------------------------------------------------
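
A minimal sketch of reading back one of the prediction files written by the script above (the path is illustrative); each file starts with a 'pairID,gold_label' header and contains one row per test pair, which makes it easy to load with the csv module:

import csv

with open("matched_predictions.csv", newline="") as csv_file:
    reader = csv.DictReader(csv_file)
    predictions = {row["pairID"]: row["gold_label"] for row in reader}

print(len(predictions), "predictions loaded")
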
/ESIM/scripts/testing/test_snli.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the ESIM model on some preprocessed dataset.
3 | """
4 | # Aurelien Coet, 2018.
5 |
6 | import time
7 | import pickle
8 | import argparse
9 | import torch
10 |
11 | from torch.utils.data import DataLoader
12 | from esim.data import NLIDataset
13 | from esim.model import ESIM
14 | from esim.utils import correct_predictions
15 |
16 |
17 | def test(model, dataloader):
18 | """
19 | Test the accuracy of a model on some labelled test dataset.
20 |
21 | Args:
22 | model: The torch module on which testing must be performed.
23 | dataloader: A DataLoader object to iterate over some dataset.
24 |
25 | Returns:
26 | batch_time: The average time to predict the classes of a batch.
27 | total_time: The total time to process the whole dataset.
28 | accuracy: The accuracy of the model on the input data.
29 | """
30 | # Switch the model to eval mode.
31 | model.eval()
32 | device = model.device
33 |
34 | time_start = time.time()
35 | batch_time = 0.0
36 | accuracy = 0.0
37 |
38 | # Deactivate autograd for evaluation.
39 | with torch.no_grad():
40 | for batch in dataloader:
41 | batch_start = time.time()
42 |
43 | # Move input and output data to the GPU if one is used.
44 | premises = batch['premise'].to(device)
45 | premises_lengths = batch['premise_length'].to(device)
46 | hypotheses = batch['hypothesis'].to(device)
47 | hypotheses_lengths = batch['hypothesis_length'].to(device)
48 | labels = batch['label'].to(device)
49 |
50 | _, probs = model(premises,
51 | premises_lengths,
52 | hypotheses,
53 | hypotheses_lengths)
54 |
55 | accuracy += correct_predictions(probs, labels)
56 | batch_time += time.time() - batch_start
57 |
58 | batch_time /= len(dataloader)
59 | total_time = time.time() - time_start
60 | accuracy /= (len(dataloader.dataset))
61 |
62 | return batch_time, total_time, accuracy
63 |
64 |
65 | def main(test_file, pretrained_file, batch_size=32):
66 | """
67 | Test the ESIM model with pretrained weights on some dataset.
68 |
69 | Args:
70 | test_file: The path to a file containing preprocessed NLI data.
71 | pretrained_file: The path to a checkpoint produced by the
72 | 'train_model' script.
73 |         batch_size: The size of the batches used for testing. Defaults to 32.
74 |
75 |     Note:
76 |         The vocabulary size, embedding dimension, hidden size and number of
77 |         classes of the model are not passed as arguments: they are retrieved
78 |         directly from the pretrained checkpoint before the model is built,
79 |         so they always match the values that were used during training.
80 |
81 | """
82 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
83 |
84 | print(20 * "=", " Preparing for testing ", 20 * "=")
85 |
86 | checkpoint = torch.load(pretrained_file)
87 |
88 | # Retrieving model parameters from checkpoint.
89 | vocab_size = checkpoint['model']['_word_embedding.weight'].size(0)
90 | embedding_dim = checkpoint['model']['_word_embedding.weight'].size(1)
91 | hidden_size = checkpoint['model']['_projection.0.weight'].size(0)
92 | num_classes = checkpoint['model']['_classification.4.weight'].size(0)
93 |
94 | print("\t* Loading test data...")
95 | with open(test_file, 'rb') as pkl:
96 | test_data = NLIDataset(pickle.load(pkl))
97 |
98 | test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)
99 |
100 | print("\t* Building model...")
101 | model = ESIM(vocab_size,
102 | embedding_dim,
103 | hidden_size,
104 | num_classes=num_classes,
105 | device=device).to(device)
106 |
107 | model.load_state_dict(checkpoint['model'])
108 |
109 | print(20 * "=",
110 | " Testing ESIM model on device: {} ".format(device),
111 | 20 * "=")
112 | batch_time, total_time, accuracy = test(model, test_loader)
113 |
114 | print("-> Average batch processing time: {:.4f}s, total test time:\
115 | {:.4f}s, accuracy: {:.4f}%".format(batch_time, total_time, (accuracy*100)))
116 |
117 |
118 | if __name__ == "__main__":
119 | parser = argparse.ArgumentParser(description='Test the ESIM model on\
120 | some dataset')
121 | parser.add_argument('test_data',
122 | help="Path to a file containing preprocessed test data")
123 | parser.add_argument('checkpoint',
124 | help="Path to a checkpoint with a pretrained model")
125 | parser.add_argument('--batch_size', type=int, default=32,
126 | help='Batch size to use during testing')
127 | args = parser.parse_args()
128 |
129 | main(args.test_data,
130 | args.checkpoint,
131 | args.batch_size)
132 |
--------------------------------------------------------------------------------
/ESIM/scripts/training/train_mnli.py:
--------------------------------------------------------------------------------
1 | """
2 | Train the ESIM model on the preprocessed MultiNLI dataset.
3 | """
4 | # Aurelien Coet, 2019.
5 |
6 | import os
7 | import argparse
8 | import pickle
9 | import torch
10 | import json
11 |
12 | import matplotlib.pyplot as plt
13 | import torch.nn as nn
14 |
15 | from torch.utils.data import DataLoader
16 |
17 | from esim.data import NLIDataset
18 | from esim.model import ESIM
19 | from utils import train, validate
20 |
21 |
22 | def main(train_file,
23 | valid_files,
24 | embeddings_file,
25 | target_dir,
26 | hidden_size=300,
27 | dropout=0.5,
28 | num_classes=3,
29 | epochs=64,
30 | batch_size=32,
31 | lr=0.0004,
32 | patience=5,
33 | max_grad_norm=10.0,
34 | checkpoint=None):
35 | """
36 |     Train the ESIM model on the MultiNLI dataset.
37 |
38 | Args:
39 | train_file: A path to some preprocessed data that must be used
40 | to train the model.
41 | valid_files: A dict containing the paths to the preprocessed matched
42 | and mismatched datasets that must be used to validate the model.
43 | embeddings_file: A path to some preprocessed word embeddings that
44 | must be used to initialise the model.
45 | target_dir: The path to a directory where the trained model must
46 | be saved.
47 | hidden_size: The size of the hidden layers in the model. Defaults
48 | to 300.
49 | dropout: The dropout rate to use in the model. Defaults to 0.5.
50 | num_classes: The number of classes in the output of the model.
51 | Defaults to 3.
52 | epochs: The maximum number of epochs for training. Defaults to 64.
53 | batch_size: The size of the batches for training. Defaults to 32.
54 | lr: The learning rate for the optimizer. Defaults to 0.0004.
55 |         patience: The patience to use for early stopping. Defaults to 5.
56 |         max_grad_norm: The maximum norm for gradient clipping. Defaults to 10.0.
57 |         checkpoint: A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None.
58 | """
59 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
60 |
61 | print(20 * "=", " Preparing for training ", 20 * "=")
62 |
63 | if not os.path.exists(target_dir):
64 | os.makedirs(target_dir)
65 |
66 | # -------------------- Data loading ------------------- #
67 | print("\t* Loading training data...")
68 | with open(train_file, 'rb') as pkl:
69 | train_data = NLIDataset(pickle.load(pkl))
70 |
71 | train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
72 |
73 | print("\t* Loading validation data...")
74 | with open(os.path.normpath(valid_files["matched"]), 'rb') as pkl:
75 | matched_valid_data = NLIDataset(pickle.load(pkl))
76 |
77 | with open(os.path.normpath(valid_files["mismatched"]), 'rb') as pkl:
78 | mismatched_valid_data = NLIDataset(pickle.load(pkl))
79 |
80 | matched_valid_loader = DataLoader(matched_valid_data,
81 | shuffle=False,
82 | batch_size=batch_size)
83 | mismatched_valid_loader = DataLoader(mismatched_valid_data,
84 | shuffle=False,
85 | batch_size=batch_size)
86 |
87 | # -------------------- Model definition ------------------- #
88 | print('\t* Building model...')
89 | with open(embeddings_file, 'rb') as pkl:
90 | embeddings = torch.tensor(pickle.load(pkl), dtype=torch.float)\
91 | .to(device)
92 |
93 | model = ESIM(embeddings.shape[0],
94 | embeddings.shape[1],
95 | hidden_size,
96 | embeddings=embeddings,
97 | dropout=dropout,
98 | num_classes=num_classes,
99 | device=device).to(device)
100 |
101 | # -------------------- Preparation for training ------------------- #
102 | criterion = nn.CrossEntropyLoss()
103 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
104 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
105 | mode='max',
106 | factor=0.5,
107 | patience=0)
108 |
109 | best_score = 0.0
110 | start_epoch = 1
111 |
112 | # Data for loss curves plot.
113 | epochs_count = []
114 | train_losses = []
115 | matched_valid_losses = []
116 | mismatched_valid_losses = []
117 |
118 | # Continuing training from a checkpoint if one was given as argument.
119 | if checkpoint:
120 | checkpoint = torch.load(checkpoint)
121 | start_epoch = checkpoint['epoch'] + 1
122 | best_score = checkpoint['best_score']
123 |
124 | print("\t* Training will continue on existing model from epoch {}..."
125 | .format(start_epoch))
126 |
127 | model.load_state_dict(checkpoint['model'])
128 | optimizer.load_state_dict(checkpoint['optimizer'])
129 | epochs_count = checkpoint['epochs_count']
130 | train_losses = checkpoint['train_losses']
131 | matched_valid_losses = checkpoint['match_valid_losses']
132 | mismatched_valid_losses = checkpoint['mismatch_valid_losses']
133 |
134 | # Compute loss and accuracy before starting (or resuming) training.
135 | _, valid_loss, valid_accuracy = validate(model,
136 | matched_valid_loader,
137 | criterion)
138 | print("\t* Validation loss before training on matched data: {:.4f}, accuracy: {:.4f}%"
139 | .format(valid_loss, (valid_accuracy*100)))
140 |
141 | _, valid_loss, valid_accuracy = validate(model,
142 | mismatched_valid_loader,
143 | criterion)
144 | print("\t* Validation loss before training on mismatched data: {:.4f}, accuracy: {:.4f}%"
145 | .format(valid_loss, (valid_accuracy*100)))
146 |
147 | # -------------------- Training epochs ------------------- #
148 | print("\n",
149 | 20 * "=",
150 | "Training ESIM model on device: {}".format(device),
151 | 20 * "=")
152 |
153 | patience_counter = 0
154 | for epoch in range(start_epoch, epochs+1):
155 | epochs_count.append(epoch)
156 |
157 | print("* Training epoch {}:".format(epoch))
158 | epoch_time, epoch_loss, epoch_accuracy = train(model,
159 | train_loader,
160 | optimizer,
161 | criterion,
162 | epoch,
163 | max_grad_norm)
164 |
165 | train_losses.append(epoch_loss)
166 | print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%\n"
167 | .format(epoch_time, epoch_loss, (epoch_accuracy*100)))
168 |
169 | print("* Validation for epoch {} on matched data:".format(epoch))
170 | epoch_time, epoch_loss, epoch_accuracy = validate(model,
171 | matched_valid_loader,
172 | criterion)
173 | matched_valid_losses.append(epoch_loss)
174 | print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%"
175 | .format(epoch_time, epoch_loss, (epoch_accuracy*100)))
176 |
177 | print("* Validation for epoch {} on mismatched data:".format(epoch))
178 | epoch_time, epoch_loss, mis_epoch_accuracy = validate(model,
179 | mismatched_valid_loader,
180 | criterion)
181 | mismatched_valid_losses.append(epoch_loss)
182 | print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
183 | .format(epoch_time, epoch_loss, (mis_epoch_accuracy*100)))
184 |
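        # Note: the learning-rate scheduler and the early stopping logic below
        # both track the accuracy on the matched validation set only.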
185 | # Update the optimizer's learning rate with the scheduler.
186 | scheduler.step(epoch_accuracy)
187 |
188 | # Early stopping on validation accuracy.
189 | if epoch_accuracy < best_score:
190 | patience_counter += 1
191 | else:
192 | best_score = epoch_accuracy
193 | patience_counter = 0
194 | # Save the best model. The optimizer is not saved to avoid having
195 | # a checkpoint file that is too heavy to be shared. To resume
196 | # training from the best model, use the 'esim_*.pth.tar'
197 | # checkpoints instead.
198 | torch.save({'epoch': epoch,
199 | 'model': model.state_dict(),
200 | 'best_score': best_score,
201 | 'epochs_count': epochs_count,
202 | 'train_losses': train_losses,
203 | 'match_valid_losses': matched_valid_losses,
204 | 'mismatch_valid_losses': mismatched_valid_losses},
205 | os.path.join(target_dir, "best.pth.tar"))
206 |
207 | # # Save the model at each epoch.
208 | # torch.save({'epoch': epoch,
209 | # 'model': model.state_dict(),
210 | # 'best_score': best_score,
211 | # 'optimizer': optimizer.state_dict(),
212 | # 'epochs_count': epochs_count,
213 | # 'train_losses': train_losses,
214 | # 'match_valid_losses': matched_valid_losses,
215 | # 'mismatch_valid_losses': mismatched_valid_losses},
216 | # os.path.join(target_dir, "esim_{}.pth.tar".format(epoch)))
217 |
218 | if patience_counter >= patience:
219 | print("-> Early stopping: patience limit reached, stopping...")
220 | break
221 |
222 | # Plotting of the loss curves for the train and validation sets.
223 | plt.figure()
224 | plt.plot(epochs_count, train_losses, '-r')
225 | plt.plot(epochs_count, matched_valid_losses, '-b')
226 | plt.plot(epochs_count, mismatched_valid_losses, '-g')
227 | plt.xlabel('epoch')
228 | plt.ylabel('loss')
229 | plt.legend(['Training loss',
230 | 'Validation loss (matched set)',
231 | 'Validation loss (mismatched set)'])
232 | plt.title('Cross entropy loss')
233 | plt.show()
234 |
235 |
236 | if __name__ == "__main__":
237 | parser = argparse.ArgumentParser(description='Train the ESIM model on MultiNLI')
238 | parser.add_argument('--config',
239 | default="../config/training/mnli_training.json",
240 | help='Path to a json configuration file')
241 | parser.add_argument('--checkpoint',
242 | default=None,
243 | help='path to a checkpoint file to resume training')
244 | args = parser.parse_args()
245 |
246 | with open(os.path.normpath(args.config), 'r') as config_file:
247 | config = json.load(config_file)
248 |
249 | main(os.path.normpath(config["train_data"]),
250 | config["valid_data"],
251 | os.path.normpath(config["embeddings"]),
252 | os.path.normpath(config["target_dir"]),
253 | config["hidden_size"],
254 | config["dropout"],
255 | config["num_classes"],
256 | config["epochs"],
257 | config["batch_size"],
258 | config["lr"],
259 | config["patience"],
260 | config["max_gradient_norm"],
261 | args.checkpoint)
262 |
--------------------------------------------------------------------------------
/ESIM/scripts/training/train_snli.py:
--------------------------------------------------------------------------------
1 | """
2 | Train the ESIM model on the preprocessed SNLI dataset.
3 | """
4 | # Aurelien Coet, 2018.
5 |
6 | import os
7 | import argparse
8 | import pickle
9 | import torch
10 | import json
11 |
12 | import matplotlib.pyplot as plt
13 | import torch.nn as nn
14 |
15 | from torch.utils.data import DataLoader
16 |
17 | from esim.data import NLIDataset
18 | from esim.model import ESIM
19 | from utils import train, validate
20 |
21 |
22 | def main(train_file,
23 | valid_file,
24 | embeddings_file,
25 | target_dir,
26 | hidden_size=300,
27 | dropout=0.5,
28 | num_classes=3,
29 | epochs=64,
30 | batch_size=32,
31 | lr=0.0004,
32 | patience=5,
33 | max_grad_norm=10.0,
34 | checkpoint=None):
35 | """
36 | Train the ESIM model on the SNLI dataset.
37 |
38 | Args:
39 | train_file: A path to some preprocessed data that must be used
40 | to train the model.
41 | valid_file: A path to some preprocessed data that must be used
42 | to validate the model.
43 | embeddings_file: A path to some preprocessed word embeddings that
44 | must be used to initialise the model.
45 | target_dir: The path to a directory where the trained model must
46 | be saved.
47 | hidden_size: The size of the hidden layers in the model. Defaults
48 | to 300.
49 | dropout: The dropout rate to use in the model. Defaults to 0.5.
50 | num_classes: The number of classes in the output of the model.
51 | Defaults to 3.
52 | epochs: The maximum number of epochs for training. Defaults to 64.
53 | batch_size: The size of the batches for training. Defaults to 32.
54 | lr: The learning rate for the optimizer. Defaults to 0.0004.
55 |         patience: The patience to use for early stopping. Defaults to 5.
56 |         max_grad_norm: The maximum norm for gradient clipping. Defaults to 10.0.
57 |         checkpoint: A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None.
58 | """
59 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
60 |
61 | print(20 * "=", " Preparing for training ", 20 * "=")
62 |
63 | if not os.path.exists(target_dir):
64 | os.makedirs(target_dir)
65 |
66 | # -------------------- Data loading ------------------- #
67 | print("\t* Loading training data...")
68 | with open(train_file, 'rb') as pkl:
69 | train_data = NLIDataset(pickle.load(pkl))
70 |
71 | train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
72 |
73 | print("\t* Loading validation data...")
74 | with open(valid_file, 'rb') as pkl:
75 | valid_data = NLIDataset(pickle.load(pkl))
76 |
77 | valid_loader = DataLoader(valid_data, shuffle=False, batch_size=batch_size)
78 |
79 | # -------------------- Model definition ------------------- #
80 | print('\t* Building model...')
81 | with open(embeddings_file, 'rb') as pkl:
82 | embeddings = torch.tensor(pickle.load(pkl), dtype=torch.float)\
83 | .to(device)
84 |
85 | model = ESIM(embeddings.shape[0],
86 | embeddings.shape[1],
87 | hidden_size,
88 | embeddings=embeddings,
89 | dropout=dropout,
90 | num_classes=num_classes,
91 | device=device).to(device)
92 |
93 | # -------------------- Preparation for training ------------------- #
94 | criterion = nn.CrossEntropyLoss()
95 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
96 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
97 | mode='max',
98 | factor=0.5,
99 | patience=0)
100 |
101 | best_score = 0.0
102 | start_epoch = 1
103 |
104 | # Data for loss curves plot.
105 | epochs_count = []
106 | train_losses = []
107 | valid_losses = []
108 |
109 | # Continuing training from a checkpoint if one was given as argument.
110 | if checkpoint:
111 | checkpoint = torch.load(checkpoint)
112 | start_epoch = checkpoint['epoch'] + 1
113 | best_score = checkpoint['best_score']
114 |
115 | print("\t* Training will continue on existing model from epoch {}..."
116 | .format(start_epoch))
117 |
118 | model.load_state_dict(checkpoint['model'])
119 | optimizer.load_state_dict(checkpoint['optimizer'])
120 | epochs_count = checkpoint['epochs_count']
121 | train_losses = checkpoint['train_losses']
122 | valid_losses = checkpoint['valid_losses']
123 |
124 | # Compute loss and accuracy before starting (or resuming) training.
125 | _, valid_loss, valid_accuracy = validate(model,
126 | valid_loader,
127 | criterion)
128 | print("\t* Validation loss before training: {:.4f}, accuracy: {:.4f}%"
129 | .format(valid_loss, (valid_accuracy*100)))
130 |
131 | # -------------------- Training epochs ------------------- #
132 | print("\n",
133 | 20 * "=",
134 | "Training ESIM model on device: {}".format(device),
135 | 20 * "=")
136 |
137 | patience_counter = 0
138 | for epoch in range(start_epoch, epochs+1):
139 | epochs_count.append(epoch)
140 |
141 | print("* Training epoch {}:".format(epoch))
142 | epoch_time, epoch_loss, epoch_accuracy = train(model,
143 | train_loader,
144 | optimizer,
145 | criterion,
146 | epoch,
147 | max_grad_norm)
148 |
149 | train_losses.append(epoch_loss)
150 | print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%"
151 | .format(epoch_time, epoch_loss, (epoch_accuracy*100)))
152 |
153 | print("* Validation for epoch {}:".format(epoch))
154 | epoch_time, epoch_loss, epoch_accuracy = validate(model,
155 | valid_loader,
156 | criterion)
157 |
158 | valid_losses.append(epoch_loss)
159 | print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
160 | .format(epoch_time, epoch_loss, (epoch_accuracy*100)))
161 |
162 | # Update the optimizer's learning rate with the scheduler.
163 | scheduler.step(epoch_accuracy)
164 |
165 | # Early stopping on validation accuracy.
166 | if epoch_accuracy < best_score:
167 | patience_counter += 1
168 | else:
169 | best_score = epoch_accuracy
170 | patience_counter = 0
171 | # Save the best model. The optimizer is not saved to avoid having
172 | # a checkpoint file that is too heavy to be shared. To resume
173 | # training from the best model, use the 'esim_*.pth.tar'
174 | # checkpoints instead.
175 | torch.save({'epoch': epoch,
176 | 'model': model.state_dict(),
177 | 'best_score': best_score,
178 | 'epochs_count': epochs_count,
179 | 'train_losses': train_losses,
180 | 'valid_losses': valid_losses},
181 | os.path.join(target_dir, "best.pth.tar"))
182 |
183 | # Save the model at each epoch.
184 | torch.save({'epoch': epoch,
185 | 'model': model.state_dict(),
186 | 'best_score': best_score,
187 | 'optimizer': optimizer.state_dict(),
188 | 'epochs_count': epochs_count,
189 | 'train_losses': train_losses,
190 | 'valid_losses': valid_losses},
191 | os.path.join(target_dir, "esim_{}.pth.tar".format(epoch)))
192 |
193 | if patience_counter >= patience:
194 | print("-> Early stopping: patience limit reached, stopping...")
195 | break
196 |
197 | # Plotting of the loss curves for the train and validation sets.
198 | plt.figure()
199 | plt.plot(epochs_count, train_losses, '-r')
200 | plt.plot(epochs_count, valid_losses, '-b')
201 | plt.xlabel('epoch')
202 | plt.ylabel('loss')
203 | plt.legend(['Training loss', 'Validation loss'])
204 | plt.title('Cross entropy loss')
205 | plt.show()
206 |
207 |
208 | if __name__ == "__main__":
209 | parser = argparse.ArgumentParser(description='Train the ESIM model on SNLI')
210 | parser.add_argument('--config',
211 | default="../config/training/snli_training.json",
212 | help='Path to a json configuration file')
213 | parser.add_argument('--checkpoint',
214 | default=None,
215 | help='path to a checkpoint file to resume training')
216 | args = parser.parse_args()
217 |
218 | with open(os.path.normpath(args.config), 'r') as config_file:
219 | config = json.load(config_file)
220 |
221 | main(os.path.normpath(config["train_data"]),
222 | os.path.normpath(config["valid_data"]),
223 | os.path.normpath(config["embeddings"]),
224 | os.path.normpath(config["target_dir"]),
225 | config["hidden_size"],
226 | config["dropout"],
227 | config["num_classes"],
228 | config["epochs"],
229 | config["batch_size"],
230 | config["lr"],
231 | config["patience"],
232 | config["max_gradient_norm"],
233 | args.checkpoint)
234 |
--------------------------------------------------------------------------------
/ESIM/scripts/training/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for training and validating models.
3 | """
4 |
5 | import time
6 | import torch
7 |
8 | import torch.nn as nn
9 |
10 | from tqdm import tqdm
11 |
12 | from esim.utils import correct_predictions
13 |
14 |
15 | def train(model,
16 | dataloader,
17 | optimizer,
18 | criterion,
19 | epoch_number,
20 | max_gradient_norm):
21 | """
22 | Train a model for one epoch on some input data with a given optimizer and
23 | criterion.
24 |
25 | Args:
26 | model: A torch module that must be trained on some input data.
27 | dataloader: A DataLoader object to iterate over the training data.
28 | optimizer: A torch optimizer to use for training on the input model.
29 | criterion: A loss criterion to use for training.
30 | epoch_number: The number of the epoch for which training is performed.
31 | max_gradient_norm: Max. norm for gradient norm clipping.
32 |
33 | Returns:
34 | epoch_time: The total time necessary to train the epoch.
35 | epoch_loss: The training loss computed for the epoch.
36 | epoch_accuracy: The accuracy computed for the epoch.
37 | """
38 | # Switch the model to train mode.
39 | model.train()
40 | device = model.device
41 |
42 | epoch_start = time.time()
43 | batch_time_avg = 0.0
44 | running_loss = 0.0
45 | correct_preds = 0
46 |
47 | tqdm_batch_iterator = tqdm(dataloader)
48 | for batch_index, batch in enumerate(tqdm_batch_iterator):
49 | batch_start = time.time()
50 |
51 | # Move input and output data to the GPU if it is used.
52 | premises = batch['premise'].to(device)
53 | premises_lengths = batch['premise_length'].to(device)
54 | hypotheses = batch['hypothesis'].to(device)
55 | hypotheses_lengths = batch['hypothesis_length'].to(device)
56 | labels = batch['label'].to(device)
57 |
58 | optimizer.zero_grad()
59 |
60 | logits, probs = model(premises,
61 | premises_lengths,
62 | hypotheses,
63 | hypotheses_lengths)
64 | loss = criterion(logits, labels)
65 | loss.backward()
66 |
67 | nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
68 | optimizer.step()
69 |
70 | batch_time_avg += time.time() - batch_start
71 | running_loss += loss.item()
72 | correct_preds += correct_predictions(probs, labels)
73 |
74 | description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}"\
75 | .format(batch_time_avg/(batch_index+1),
76 | running_loss/(batch_index+1))
77 | tqdm_batch_iterator.set_description(description)
78 |
79 | epoch_time = time.time() - epoch_start
80 | epoch_loss = running_loss / len(dataloader)
81 | epoch_accuracy = correct_preds / len(dataloader.dataset)
82 |
83 | return epoch_time, epoch_loss, epoch_accuracy
84 |
85 |
86 | def validate(model, dataloader, criterion):
87 | """
88 | Compute the loss and accuracy of a model on some validation dataset.
89 |
90 | Args:
91 | model: A torch module for which the loss and accuracy must be
92 | computed.
93 | dataloader: A DataLoader object to iterate over the validation data.
94 |         criterion: A loss criterion to use for computing the loss.
95 |
96 |     The device to run on is taken directly from the model ('model.device').
97 |
98 | Returns:
99 | epoch_time: The total time to compute the loss and accuracy on the
100 | entire validation set.
101 | epoch_loss: The loss computed on the entire validation set.
102 | epoch_accuracy: The accuracy computed on the entire validation set.
103 | """
104 | # Switch to evaluate mode.
105 | model.eval()
106 | device = model.device
107 |
108 | epoch_start = time.time()
109 | running_loss = 0.0
110 | running_accuracy = 0.0
111 |
112 | # Deactivate autograd for evaluation.
113 | with torch.no_grad():
114 | for batch in dataloader:
115 | # Move input and output data to the GPU if one is used.
116 | premises = batch['premise'].to(device)
117 | premises_lengths = batch['premise_length'].to(device)
118 | hypotheses = batch['hypothesis'].to(device)
119 | hypotheses_lengths = batch['hypothesis_length'].to(device)
120 | labels = batch['label'].to(device)
121 |
122 | logits, probs = model(premises,
123 | premises_lengths,
124 | hypotheses,
125 | hypotheses_lengths)
126 | loss = criterion(logits, labels)
127 |
128 | running_loss += loss.item()
129 | running_accuracy += correct_predictions(probs, labels)
130 |
131 | epoch_time = time.time() - epoch_start
132 | epoch_loss = running_loss / len(dataloader)
133 | epoch_accuracy = running_accuracy / (len(dataloader.dataset))
134 |
135 | return epoch_time, epoch_loss, epoch_accuracy
136 |
--------------------------------------------------------------------------------
/ESIM/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 |
4 | setup(name='ESIM',
5 | version=1.0,
6 | url='https://github.com/coetaur0/ESIM',
7 | license='Apache 2',
8 | author='Aurelien Coet',
9 | author_email='aurelien.coet19@gmail.com',
10 | description='Implementation in Pytorch of the ESIM model for NLI',
11 | packages=[
12 | 'esim'
13 | ],
14 | install_requires=[
15 | 'numpy',
16 | 'nltk',
17 | 'matplotlib',
18 | 'tqdm',
19 | 'torch'
20 | ])
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Di Jin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TextFooler
2 | A Model for Natural Language Attack on Text Classification and Inference
3 |
4 | This is the source code for the paper: [Jin, Di, et al. "Is BERT Really Robust? Natural Language Attack on Text Classification and Entailment." arXiv preprint arXiv:1907.11932 (2019)](https://arxiv.org/pdf/1907.11932.pdf). If you use the code, please cite the paper:
5 |
6 | ```
7 | @article{jin2019bert,
8 | title={Is BERT Really Robust? Natural Language Attack on Text Classification and Entailment},
9 | author={Jin, Di and Jin, Zhijing and Zhou, Joey Tianyi and Szolovits, Peter},
10 | journal={arXiv preprint arXiv:1907.11932},
11 | year={2019}
12 | }
13 | ```
14 |
15 | ## Data
16 | Our 7 datasets are [here](https://bit.ly/nlp_adv_data).
17 |
18 | ## Prerequisites:
19 | Required packages are listed in the requirements.txt file:
20 | ```
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ## How to use
25 |
26 | * Run the following code to install the **esim** package:
27 |
28 | ```
29 | cd ESIM
30 | python setup.py install
31 | cd ..
32 | ```
33 |
34 | * (Optional) Run the following code to pre-compute the cosine similarity scores between word pairs based on the [counter-fitting word embeddings](https://drive.google.com/open?id=1bayGomljWb6HeYDMTDKXrh0HackKtSlx).
35 |
36 | ```
37 | python comp_cos_sim_mat.py [PATH_TO_COUNTER_FITTING_WORD_EMBEDDINGS]
38 | ```
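
The pre-computed similarity matrix is saved as `cos_sim_counter_fitting.npy` in the current directory; its path can then be passed to the `--counter_fitting_cos_sim_path` argument described below.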
39 |
40 | * Run the following code to generate the adversaries for text classification:
41 |
42 | ```
43 | python attack_classification.py
44 | ```
45 |
46 | For natural language inference:
47 |
48 | ```
49 | python attack_nli.py
50 | ```
51 |
52 | Example commands for these two scripts are given in [run_attack_classification.py](https://github.com/jind11/TextFooler/blob/master/run_attack_classification.py) and [run_attack_nli.py](https://github.com/jind11/TextFooler/blob/master/run_attack_nli.py). Here we explain each required argument in detail (a full example invocation is shown after the list):
53 |
54 | * --dataset_path: The path to the dataset. We put the 1000 examples for each dataset we used in the paper in the folder [data](https://github.com/jind11/TextFooler/tree/master/data).
55 | * --target_model: Name of the target model such as ''bert''.
56 | * --target_model_path: The path to the trained parameters of the target model. For ease of replication, we shared the [trained BERT model parameters](https://drive.google.com/drive/folders/1wKjelHFcqsT3GgA7LzWmoaAHcUkP4c7B?usp=sharing), the [trained LSTM model parameters](https://drive.google.com/drive/folders/108myH_HHtBJX8MvhBQuvTGb-kGOce5M2?usp=sharing), and the [trained CNN model parameters](https://drive.google.com/drive/folders/1Ifowzfers0m1Aw2vE8O7SMifHUhkTEjh?usp=sharing) on each dataset we used.
57 | * --counter_fitting_embeddings_path: The path to the counter-fitting word embeddings.
58 | * --counter_fitting_cos_sim_path: This is optional. If given, then the pre-computed cosine similarity scores based on the counter-fitting word embeddings will be loaded to save time. If not, it will be calculated.
59 | * --USE_cache_path: The path to save the USE model file (Downloading is automatic if this path is empty).
60 |
61 | Two more things to share with you:
62 |
63 | 1. In case you want to replicate our experiments for training the target models, we have shared the [seven datasets](https://drive.google.com/open?id=1N-FYUa5XN8qDs4SgttQQnrkeTXXAXjTv) we used, already processed for you!
64 |
65 | 2. In case you want to use the adversarial examples we generated on the benchmark data directly, [here they are](https://drive.google.com/drive/folders/12yeqcqZiEWuncC5zhSUmKBC3GLFiCEaN?usp=sharing).
66 |
--------------------------------------------------------------------------------
/comp_cos_sim_mat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 |
4 | embedding_path = sys.argv[1] # '/data/medg/misc/jindi/nlp/embeddings/counter-fitted-vectors.txt'
5 |
6 | embeddings = []
7 | with open(embedding_path, 'r') as ifile:
8 | for line in ifile:
9 | embedding = [float(num) for num in line.strip().split()[1:]]
10 | embeddings.append(embedding)
11 | embeddings = np.array(embeddings)
12 | print(embeddings.T.shape)
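# Row-normalise the embeddings so that the dot product of any two rows equals their cosine similarity.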
13 | norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
14 | embeddings = np.asarray(embeddings / norm, "float32")
15 | product = np.dot(embeddings, embeddings.T)
16 | np.save(('cos_sim_counter_fitting.npy'), product)
17 |
--------------------------------------------------------------------------------
/criteria.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import nltk
3 | from nltk.stem.wordnet import WordNetLemmatizer
4 | from pattern.en import conjugate, lemma, lexeme, PRESENT, SG, PL, PAST, PROGRESSIVE
5 | import random
6 |
7 |
8 | # Function 0: List of stop words
9 | def get_stopwords():
10 | '''
11 | :return: a set of 266 stop words from nltk. eg. {'someone', 'anyhow', 'almost', 'none', 'mostly', 'around', 'being', 'fifteen', 'moreover', 'whoever', 'further', 'not', 'side', 'keep', 'does', 'regarding', 'until', 'across', 'during', 'nothing', 'of', 'we', 'eleven', 'say', 'between', 'upon', 'whole', 'in', 'nowhere', 'show', 'forty', 'hers', 'may', 'who', 'onto', 'amount', 'you', 'yours', 'his', 'than', 'it', 'last', 'up', 'ca', 'should', 'hereafter', 'others', 'would', 'an', 'all', 'if', 'otherwise', 'somehow', 'due', 'my', 'as', 'since', 'they', 'therein', 'together', 'hereupon', 'go', 'throughout', 'well', 'first', 'thence', 'yet', 'were', 'neither', 'too', 'whether', 'call', 'a', 'without', 'anyway', 'me', 'made', 'the', 'whom', 'but', 'and', 'nor', 'although', 'nine', 'whose', 'becomes', 'everywhere', 'front', 'thereby', 'both', 'will', 'move', 'every', 'whence', 'used', 'therefore', 'anyone', 'into', 'meanwhile', 'perhaps', 'became', 'same', 'something', 'very', 'where', 'besides', 'own', 'whereby', 'whither', 'quite', 'wherever', 'why', 'latter', 'down', 'she', 'sometimes', 'about', 'sometime', 'eight', 'ever', 'towards', 'however', 'noone', 'three', 'top', 'can', 'or', 'did', 'seemed', 'that', 'because', 'please', 'whereafter', 'mine', 'one', 'us', 'within', 'themselves', 'only', 'must', 'whereas', 'namely', 'really', 'yourselves', 'against', 'thus', 'thru', 'over', 'some', 'four', 'her', 'just', 'two', 'whenever', 'seeming', 'five', 'him', 'using', 'while', 'already', 'alone', 'been', 'done', 'is', 'our', 'rather', 'afterwards', 'for', 'back', 'third', 'himself', 'put', 'there', 'under', 'hereby', 'among', 'anywhere', 'at', 'twelve', 'was', 'more', 'doing', 'become', 'name', 'see', 'cannot', 'once', 'thereafter', 'ours', 'part', 'below', 'various', 'next', 'herein', 'also', 'above', 'beside', 'another', 'had', 'has', 'to', 'could', 'least', 'though', 'your', 'ten', 'many', 'other', 'from', 'get', 'which', 'with', 'latterly', 'now', 'never', 'most', 'so', 'yourself', 'amongst', 'whatever', 'whereupon', 'their', 'serious', 'make', 'seem', 'often', 'on', 'seems', 'any', 'hence', 'herself', 'myself', 'be', 'either', 'somewhere', 'before', 'twenty', 'here', 'beyond', 'this', 'else', 'nevertheless', 'its', 'he', 'except', 'when', 'again', 'thereupon', 'after', 'through', 'ourselves', 'along', 'former', 'give', 'enough', 'them', 'behind', 'itself', 'wherein', 'always', 'such', 'several', 'these', 'everyone', 'toward', 'have', 'nobody', 'elsewhere', 'empty', 'few', 'six', 'formerly', 'do', 'no', 'then', 'unless', 'what', 'how', 'even', 'i', 'indeed', 'still', 'might', 'off', 'those', 'via', 'fifty', 'each', 'out', 'less', 're', 'take', 'by', 'hundred', 'much', 'anything', 'becoming', 'am', 'everything', 'per', 'full', 'sixty', 'are', 'bottom', 'beforehand'}
12 | '''
13 | stop_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost', 'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another', 'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as', 'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides', 'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn', "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere', 'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for', 'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence', 'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his', 'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's", 'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn', "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself', 'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none', 'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only', 'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per', 'please','s', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow', 'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs', 'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein', 'thereupon', 'these', 'they','this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too','toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't", 'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where', 'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won', "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've", 'your', 'yours', 'yourself', 'yourselves']
14 | stop_words = set(stop_words)
15 | return stop_words
16 | # StopWords = {}
17 | # StopWords['nltk'] = set(nltk.corpus.stopwords.words('english'))
18 | #
19 | # import spacy
20 | # nlp = spacy.load("en")
21 | # StopWords['spacy'] = nlp.Defaults.stop_words
22 | #
23 | # return StopWords['nltk'] # | StopWords['spacy']
24 |
25 |
26 | UniversalPos = ['NOUN', 'VERB', 'ADJ', 'ADV',
27 | 'PRON', 'DET', 'ADP', 'NUM',
28 | 'CONJ', 'PRT', '.', 'X']
29 |
30 |
31 | # Function 1:
32 | def get_pos(sent, tagset='universal'):
33 | '''
34 | :param sent: list of word strings
35 | tagset: {'universal', 'default'}
36 | :return: list of pos tags.
37 | Universal (Coarse) Pos tags has 12 categories
38 | - NOUN (nouns)
39 | - VERB (verbs)
40 | - ADJ (adjectives)
41 | - ADV (adverbs)
42 | - PRON (pronouns)
43 | - DET (determiners and articles)
44 | - ADP (prepositions and postpositions)
45 | - NUM (numerals)
46 | - CONJ (conjunctions)
47 | - PRT (particles)
48 | - . (punctuation marks)
49 | - X (a catch-all for other categories such as abbreviations or foreign words)
50 | '''
51 | if tagset == 'default':
52 | word_n_pos_list = nltk.pos_tag(sent)
53 | elif tagset == 'universal':
54 | word_n_pos_list = nltk.pos_tag(sent, tagset=tagset)
55 | _, pos_list = zip(*word_n_pos_list)
56 | return pos_list
57 |
58 |
59 | # Function 2: Pos Filter
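# e.g. pos_filter('NOUN', ['NOUN', 'VERB', 'ADJ']) -> [True, True, False], since NOUN and VERB are treated as interchangeable.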
60 | def pos_filter(ori_pos, new_pos_list):
61 | same = [True if ori_pos == new_pos or (set([ori_pos, new_pos]) <= set(['NOUN', 'VERB']))
62 | else False
63 | for new_pos in new_pos_list]
64 | return same
65 |
66 |
67 | # Function 3:
68 | def get_v_tense(sent):
69 | '''
70 | :param sent: a list of words
71 | :return tenses: a dict {key (word ix): value (tense, e.g. VBD)}
72 | pos of verbs
73 | - VB Verb, base form
74 | - VBD Verb, past tense
75 | - VBG Verb, gerund or present participle
76 | - VBN Verb, past participle
77 | - VBP Verb, non-3rd person singular present
78 | - VBZ Verb, 3rd person singular present
79 | '''
80 | word_n_pos_list = nltk.pos_tag(sent)
81 | _, pos_list = zip(*word_n_pos_list)
82 | tenses = {w_ix: tense for w_ix, tense in enumerate(pos_list) if tense.startswith('V')}
83 | return tenses
84 |
85 |
86 | def change_tense(word, tense, lemmatize=False):
87 | '''
88 | en.verb.tenses():
89 | ['past', '3rd singular present', 'past participle', 'infinitive',
90 | 'present participle', '1st singular present', '1st singular past',
91 | 'past plural', '2nd singular present', '2nd singular past',
92 | '3rd singular past', 'present plural']
93 | :return:
94 | reference link: https://www.clips.uantwerpen.be/pages/pattern-en#conjugation
95 | '''
96 | if lemmatize:
97 | word = WordNetLemmatizer().lemmatize(word, 'v')
98 | # if pos(word) is not verb, return word
99 | lookup = {
100 | 'VB': conjugate(verb=word, tense=PRESENT, number=SG),
101 | 'VBD': conjugate(verb=word, tense=PAST, aspect=PROGRESSIVE, number=SG),
102 | 'VBG': conjugate(verb=word, tense=PRESENT, aspect=PROGRESSIVE, number=SG),
103 | 'VBN': conjugate(verb=word, tense=PAST, aspect=PROGRESSIVE, number=SG),
104 | 'VBP': conjugate(verb=word, tense=PRESENT, number=PL),
105 | 'VBZ': conjugate(verb=word, tense=PRESENT, number=SG),
106 | }
107 | return lookup[tense]
108 |
109 |
110 | def get_sent_list():
111 | file_format = "/afs/csail.mit.edu/u/z/zhijing/proj/to_di/data/{}/test_lm.txt"
112 | content = []
113 | for dataset in ['ag', 'fake', 'mr', 'yelp']:
114 | file = file_format.format(dataset)
115 | with open(file) as f:
116 | content += [line.strip().split() for line in f if line.strip()]
117 | return content
118 |
119 |
120 | def check_pos(sent_list, win_size=10):
121 | '''
122 | :param sent_list:
123 | :param win_size:
124 | :param pad_size:
125 | :return: diff_ix = Counter({0: 606, 1: 180, 2: 42, 3: 15, 4: 5, 5: 1})
126 | len(sent_list) = 60139
127 | '''
128 |
129 | sent_list = sent_list[:]
130 | random.shuffle(sent_list)
131 | sent_list = sent_list[:100]
132 |
133 | center_ix = [random.randint(0 + win_size // 2, len(sent) - 1 - win_size // 2)
134 | if len(sent) > win_size else len(sent) // 2
135 | for sent in sent_list]
136 | word_range = [[max(0, cen_ix - win_size // 2), min(len(sent), cen_ix + win_size // 2)]
137 | for cen_ix, sent in zip(center_ix, sent_list)]
138 |
139 | assert len(center_ix) == len(word_range)
140 | assert len(center_ix) == len(sent_list)
141 |
142 | corr_pos = [get_pos(sent)[word_range[sent_ix][0]: word_range[sent_ix][1]] for sent_ix, sent in enumerate(sent_list)]
143 | part_pos = [get_pos(sent[word_range[sent_ix][0]: word_range[sent_ix][1]]) for sent_ix, sent in enumerate(sent_list)]
144 | # corr_pos = [sent_pos[pad_size: -pad_size] if len(sent_pos) > 2 * pad_size else sent_pos
145 | # for sent_ix, sent_pos in enumerate(corr_pos)]
146 | # part_pos = [sent_pos[pad_size: -pad_size] if len(sent_pos) > 2 * pad_size else sent_pos
147 | # for sent_ix, sent_pos in enumerate(part_pos)]
148 |
149 | diff_ix = []
150 | diff_s_ix = []
151 | for sent_ix, (sent_pos_corr, sent_pos_part) in enumerate(zip(corr_pos, part_pos)):
152 | cen_ix = center_ix[sent_ix] - word_range[sent_ix][0]
153 | if sent_pos_corr[cen_ix] != sent_pos_part[cen_ix]:
154 | diff_s_ix += [sent_ix]
155 | # show_var(["diff_s_ix", "win_size"])
156 |
157 | if diff_s_ix:
158 | import pdb;
159 | pdb.set_trace()
160 | # if sent_pos_corr != sent_pos_part:
161 | # diff_ix += [w_ix for w_ix, (p_corr, p_part) in enumerate(zip(sent_pos_corr, sent_pos_part))
162 | # if p_corr != p_part]
163 | # diff_s_ix += [sent_ix]
164 |
165 |
166 | def main():
167 | # Function 0:
168 | stop_words = get_stopwords()
169 |
170 | # Function 1:
171 | sent = 'i have a dream'.split()
172 | pos_list = get_pos(sent)
173 | sent_list = get_sent_list()
174 | for _ in range(10):
175 | check_pos(sent_list)
176 | import pdb;
177 | pdb.set_trace()
178 |
179 | # Function 2:
180 | ori_pos = 'NOUN'
181 | new_pos_list = ['NOUN', 'VERB', 'ADJ', 'ADV', 'X', '.']
182 | same = pos_filter(ori_pos, new_pos_list)
183 |
184 | # Function 3:
185 | tenses = get_v_tense(sent)
186 |
187 |     # The following call does not work because the NodeBox English linguistic
188 |     # library (http://nodebox.net/code/index.php/Linguistics) fails to import.
189 | new_word = change_tense('made', 'VBD')
190 | import pdb;
191 | pdb.set_trace()
192 |
193 |
194 | if __name__ == "__main__":
195 | main()
196 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import os
3 | import sys
4 | import re
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 |
10 | def clean_str(string, TREC=False):
11 | """
12 | Tokenization/string cleaning for all datasets except for SST.
13 | Every dataset is lower cased except for TREC
14 | """
15 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
16 | string = re.sub(r"\'s", " \'s", string)
17 | string = re.sub(r"\'ve", " \'ve", string)
18 | string = re.sub(r"n\'t", " n\'t", string)
19 | string = re.sub(r"\'re", " \'re", string)
20 | string = re.sub(r"\'d", " \'d", string)
21 | string = re.sub(r"\'ll", " \'ll", string)
22 | string = re.sub(r",", " , ", string)
23 | string = re.sub(r"!", " ! ", string)
24 | string = re.sub(r"\(", " \( ", string)
25 | string = re.sub(r"\)", " \) ", string)
26 | string = re.sub(r"\?", " \? ", string)
27 | string = re.sub(r"\s{2,}", " ", string)
28 | return string.strip() if TREC else string.strip().lower()
29 |
30 | def read_corpus(path, clean=True, MR=True, TREC=False, encoding='utf8', shuffle=False, lower=True):
31 | data = []
32 | labels = []
33 | with open(path, encoding=encoding) as fin:
34 | for line in fin:
35 | if MR:
36 | label, sep, text = line.partition(' ')
37 | label = int(label)
38 | else:
39 | label, sep, text = line.partition(',')
40 | label = int(label) - 1
41 |             # Strip the line and, when requested, normalise it with clean_str().
42 |             text = clean_str(text.strip(), TREC=TREC) if clean else text.strip()
43 | if lower:
44 | text = text.lower()
45 | labels.append(label)
46 | data.append(text.split())
47 |
48 | if shuffle:
49 | perm = list(range(len(data)))
50 | random.shuffle(perm)
51 | data = [data[i] for i in perm]
52 | labels = [labels[i] for i in perm]
53 |
54 | return data, labels
55 |
56 | def read_MR(path, seed=1234):
57 | file_path = os.path.join(path, "rt-polarity.all")
58 | data, labels = read_corpus(file_path, encoding='latin-1')
59 | random.seed(seed)
60 | perm = list(range(len(data)))
61 | random.shuffle(perm)
62 | data = [ data[i] for i in perm ]
63 | labels = [ labels[i] for i in perm ]
64 | return data, labels
65 |
66 | def read_SUBJ(path, seed=1234):
67 | file_path = os.path.join(path, "subj.all")
68 | data, labels = read_corpus(file_path, encoding='latin-1')
69 | random.seed(seed)
70 | perm = list(range(len(data)))
71 | random.shuffle(perm)
72 | data = [ data[i] for i in perm ]
73 | labels = [ labels[i] for i in perm ]
74 | return data, labels
75 |
76 | def read_CR(path, seed=1234):
77 | file_path = os.path.join(path, "custrev.all")
78 | data, labels = read_corpus(file_path)
79 | random.seed(seed)
80 | perm = list(range(len(data)))
81 | random.shuffle(perm)
82 | data = [ data[i] for i in perm ]
83 | labels = [ labels[i] for i in perm ]
84 | return data, labels
85 |
86 | def read_MPQA(path, seed=1234):
87 | file_path = os.path.join(path, "mpqa.all")
88 | data, labels = read_corpus(file_path)
89 | random.seed(seed)
90 | perm = list(range(len(data)))
91 | random.shuffle(perm)
92 | data = [ data[i] for i in perm ]
93 | labels = [ labels[i] for i in perm ]
94 | return data, labels
95 |
96 | def read_TREC(path, seed=1234):
97 | train_path = os.path.join(path, "TREC.train.all")
98 | test_path = os.path.join(path, "TREC.test.all")
99 |     train_x, train_y = read_corpus(train_path, TREC=True, lower=False, encoding='latin-1')
100 |     test_x, test_y = read_corpus(test_path, TREC=True, lower=False, encoding='latin-1')
101 | random.seed(seed)
102 | perm = list(range(len(train_x)))
103 | random.shuffle(perm)
104 | train_x = [ train_x[i] for i in perm ]
105 | train_y = [ train_y[i] for i in perm ]
106 | return train_x, train_y, test_x, test_y
107 |
108 | def read_SST(path, seed=1234):
109 | train_path = os.path.join(path, "stsa.binary.phrases.train")
110 | valid_path = os.path.join(path, "stsa.binary.dev")
111 | test_path = os.path.join(path, "stsa.binary.test")
112 | train_x, train_y = read_corpus(train_path, False)
113 | valid_x, valid_y = read_corpus(valid_path, False)
114 | test_x, test_y = read_corpus(test_path, False)
115 | random.seed(seed)
116 | perm = list(range(len(train_x)))
117 | random.shuffle(perm)
118 | train_x = [ train_x[i] for i in perm ]
119 | train_y = [ train_y[i] for i in perm ]
120 | return train_x, train_y, valid_x, valid_y, test_x, test_y
121 |
122 | def cv_split(data, labels, nfold, test_id):
123 | assert (nfold > 1) and (test_id >= 0) and (test_id < nfold)
124 | lst_x = [ x for i, x in enumerate(data) if i%nfold != test_id ]
125 | lst_y = [ y for i, y in enumerate(labels) if i%nfold != test_id ]
126 | test_x = [ x for i, x in enumerate(data) if i%nfold == test_id ]
127 | test_y = [ y for i, y in enumerate(labels) if i%nfold == test_id ]
128 | perm = list(range(len(lst_x)))
129 | random.shuffle(perm)
130 | M = int(len(lst_x)*0.9)
131 | train_x = [ lst_x[i] for i in perm[:M] ]
132 | train_y = [ lst_y[i] for i in perm[:M] ]
133 | valid_x = [ lst_x[i] for i in perm[M:] ]
134 | valid_y = [ lst_y[i] for i in perm[M:] ]
135 | return train_x, train_y, valid_x, valid_y, test_x, test_y
136 |
137 | def cv_split2(data, labels, nfold, valid_id):
138 | assert (nfold > 1) and (valid_id >= 0) and (valid_id < nfold)
139 | train_x = [ x for i, x in enumerate(data) if i%nfold != valid_id ]
140 | train_y = [ y for i, y in enumerate(labels) if i%nfold != valid_id ]
141 | valid_x = [ x for i, x in enumerate(data) if i%nfold == valid_id ]
142 | valid_y = [ y for i, y in enumerate(labels) if i%nfold == valid_id ]
143 | return train_x, train_y, valid_x, valid_y
144 |
145 | def pad(sequences, pad_token='<pad>', pad_left=True):
146 | ''' input sequences is a list of text sequence [[str]]
147 | pad each text sequence to the length of the longest
148 | '''
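    # e.g. pad([['a', 'b']]) -> [['<pad>', '<pad>', '<pad>', 'a', 'b']]: sequences are left-padded to a length of at least 5 by default.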
149 | max_len = max(5,max(len(seq) for seq in sequences))
150 | if pad_left:
151 | return [ [pad_token]*(max_len-len(seq)) + seq for seq in sequences ]
152 | return [ seq + [pad_token]*(max_len-len(seq)) for seq in sequences ]
153 |
154 |
155 | def create_one_batch(x, y, map2id, oov='<oov>'):
156 | oov_id = map2id[oov]
157 | x = pad(x)
158 | length = len(x[0])
159 | batch_size = len(x)
160 | x = [ map2id.get(w, oov_id) for seq in x for w in seq ]
161 | x = torch.LongTensor(x)
162 | assert x.size(0) == length*batch_size
163 | return x.view(batch_size, length).t().contiguous().cuda(), torch.LongTensor(y).cuda()
164 |
165 |
166 | def create_one_batch_x(x, map2id, oov='<oov>'):
167 | oov_id = map2id[oov]
168 | x = pad(x)
169 | length = len(x[0])
170 | batch_size = len(x)
171 | x = [ map2id.get(w, oov_id) for seq in x for w in seq ]
172 | x = torch.LongTensor(x)
173 | assert x.size(0) == length*batch_size
174 | return x.view(batch_size, length).t().contiguous().cuda()
175 |
176 |
177 | # shuffle training examples and create mini-batches
178 | def create_batches(x, y, batch_size, map2id, perm=None, sort=False):
179 |
180 | lst = perm or range(len(x))
181 |
182 | # sort sequences based on their length; necessary for SST
183 | if sort:
184 | lst = sorted(lst, key=lambda i: len(x[i]))
185 |
186 | x = [ x[i] for i in lst ]
187 | y = [ y[i] for i in lst ]
188 |
189 | sum_len = 0.
190 | for ii in x:
191 | sum_len += len(ii)
192 | batches_x = [ ]
193 | batches_y = [ ]
194 | size = batch_size
195 | nbatch = (len(x)-1) // size + 1
196 | for i in range(nbatch):
197 | bx, by = create_one_batch(x[i*size:(i+1)*size], y[i*size:(i+1)*size], map2id)
198 | batches_x.append(bx)
199 | batches_y.append(by)
200 |
201 | if sort:
202 | perm = list(range(nbatch))
203 | random.shuffle(perm)
204 | batches_x = [ batches_x[i] for i in perm ]
205 | batches_y = [ batches_y[i] for i in perm ]
206 |
207 | sys.stdout.write("{} batches, avg sent len: {:.1f}\n".format(
208 | nbatch, sum_len/len(x)
209 | ))
210 |
211 | return batches_x, batches_y
212 |
213 |
214 | # shuffle training examples and create mini-batches
215 | def create_batches_x(x, batch_size, map2id, perm=None, sort=False):
216 |
217 | lst = perm or range(len(x))
218 |
219 | # sort sequences based on their length; necessary for SST
220 | if sort:
221 | lst = sorted(lst, key=lambda i: len(x[i]))
222 |
223 | x = [ x[i] for i in lst ]
224 |
225 | sum_len = 0.0
226 | batches_x = [ ]
227 | size = batch_size
228 | nbatch = (len(x)-1) // size + 1
229 | for i in range(nbatch):
230 | bx = create_one_batch_x(x[i*size:(i+1)*size], map2id)
231 | sum_len += len(bx)
232 | batches_x.append(bx)
233 |
234 | if sort:
235 | perm = list(range(nbatch))
236 | random.shuffle(perm)
237 | batches_x = [ batches_x[i] for i in perm ]
238 |
239 | # sys.stdout.write("{} batches, avg len: {:.1f}\n".format(
240 | # nbatch, sum_len/nbatch
241 | # ))
242 |
243 | return batches_x
244 |
245 |
246 | def load_embedding_npz(path):
247 | data = np.load(path)
248 | return [ w.decode('utf8') for w in data['words'] ], data['vals']
249 |
250 | def load_embedding_txt(path):
251 | file_open = gzip.open if path.endswith(".gz") else open
252 | words = [ ]
253 | vals = [ ]
254 | with file_open(path, encoding='utf-8') as fin:
255 | fin.readline()
256 | for line in fin:
257 | line = line.rstrip()
258 | if line:
259 | parts = line.split(' ')
260 | words.append(parts[0])
261 | vals += [ float(x) for x in parts[1:] ]
262 | return words, np.asarray(vals).reshape(len(words),-1)
263 |
264 | def load_embedding(path):
265 | if path.endswith(".npz"):
266 | return load_embedding_npz(path)
267 | else:
268 | return load_embedding_txt(path)
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def deep_iter(x):
8 | if isinstance(x, list) or isinstance(x, tuple):
9 | for u in x:
10 | for v in deep_iter(u):
11 | yield v
12 | else:
13 | yield x
14 |
15 | class CNN_Text(nn.Module):
16 |
17 | def __init__(self, n_in, widths=[3,4,5], filters=100):
18 | super(CNN_Text,self).__init__()
19 | Ci = 1
20 | Co = filters
21 | h = n_in
22 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (w, h)) for w in widths])
23 |
24 | def forward(self, x):
25 | # x is (batch, len, d)
26 | x = x.unsqueeze(1) # (batch, Ci, len, d)
27 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] #[(batch, Co, len), ...]
28 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]
29 | x = torch.cat(x, 1)
30 | return x
31 |
32 |
33 | class EmbeddingLayer(nn.Module):
34 |     def __init__(self, n_d=100, embs=None, fix_emb=True, oov='<oov>', pad='<pad>', normalize=True):
35 | super(EmbeddingLayer, self).__init__()
36 | word2id = {}
37 | if embs is not None:
38 | embwords, embvecs = embs
39 | for word in embwords:
40 | assert word not in word2id, "Duplicate words in pre-trained embeddings"
41 | word2id[word] = len(word2id)
42 |
43 | sys.stdout.write("{} pre-trained word embeddings loaded.\n".format(len(word2id)))
44 | # if n_d != len(embvecs[0]):
45 | # sys.stdout.write("[WARNING] n_d ({}) != word vector size ({}). Use {} for embeddings.\n".format(
46 | # n_d, len(embvecs[0]), len(embvecs[0])
47 | # ))
48 | n_d = len(embvecs[0])
49 |
50 | # for w in deep_iter(words):
51 | # if w not in word2id:
52 | # word2id[w] = len(word2id)
53 |
54 | if oov not in word2id:
55 | word2id[oov] = len(word2id)
56 |
57 | if pad not in word2id:
58 | word2id[pad] = len(word2id)
59 |
60 | self.word2id = word2id
61 | self.n_V, self.n_d = len(word2id), n_d
62 | self.oovid = word2id[oov]
63 | self.padid = word2id[pad]
64 | self.embedding = nn.Embedding(self.n_V, n_d)
65 | self.embedding.weight.data.uniform_(-0.25, 0.25)
66 |
67 | if embs is not None:
68 | weight = self.embedding.weight
69 | weight.data[:len(embwords)].copy_(torch.from_numpy(embvecs))
70 | sys.stdout.write("embedding shape: {}\n".format(weight.size()))
71 |
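        # Optionally L2-normalise each embedding vector to unit length.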
72 | if normalize:
73 | weight = self.embedding.weight
74 | norms = weight.data.norm(2,1)
75 | if norms.dim() == 1:
76 | norms = norms.unsqueeze(1)
77 | weight.data.div_(norms.expand_as(weight.data))
78 |
79 | if fix_emb:
80 | self.embedding.weight.requires_grad = False
81 |
82 | def forward(self, input):
83 | return self.embedding(input)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | astor==0.8.1
3 | beautifulsoup4==4.9.1
4 | boto3==1.14.7
5 | botocore==1.17.7
6 | certifi==2020.4.5.2
7 | chardet==3.0.4
8 | click==7.1.2
9 | docutils==0.15.2
10 | feedparser==5.2.1
11 | gast==0.3.3
12 | grpcio==1.29.0
13 | h5py==2.10.0
14 | idna==2.9
15 | importlib-metadata==1.6.1
16 | jmespath==0.10.0
17 | joblib==0.15.1
18 | Keras-Applications==1.0.8
19 | Keras-Preprocessing==1.1.2
20 | lxml==4.5.1
21 | Markdown==3.2.2
22 | nltk==3.5
23 | numpy==1.19.0
24 | protobuf==3.12.2
25 | python-dateutil==2.8.1
26 | python-docx==0.8.10
27 | regex==2020.6.8
28 | requests==2.24.0
29 | s3transfer==0.3.3
30 | six==1.15.0
31 | soupsieve==2.0.1
32 | tensorboard==1.12.2
33 | tensorflow-gpu==1.12.0
34 | tensorflow-hub==0.7.0
35 | termcolor==1.1.0
36 | torch==1.2.0
37 | tqdm==4.46.1
38 | urllib3==1.25.9
39 | Werkzeug==1.0.1
40 | zipp==3.1.0
41 | python==3.6
42 |
--------------------------------------------------------------------------------
/run_attack_classification.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # for wordLSTM target
4 | # command = 'python attack_classification.py --dataset_path data/yelp ' \
5 | # '--target_model wordLSTM --batch_size 128 ' \
6 | # '--target_model_path /scratch/jindi/adversary/BERT/results/yelp ' \
7 | # '--word_embeddings_path /data/medg/misc/jindi/nlp/embeddings/glove.6B/glove.6B.200d.txt ' \
8 | # '--counter_fitting_embeddings_path /data/medg/misc/jindi/nlp/embeddings/counter-fitted-vectors.txt ' \
9 | # '--counter_fitting_cos_sim_path ./cos_sim_counter_fitting.npy ' \
10 | # '--USE_cache_path /scratch/jindi/tf_cache'
11 |
12 | # for BERT target
13 | command = 'python attack_classification.py --dataset_path data/yelp ' \
14 | '--target_model bert ' \
15 | '--target_model_path /scratch/jindi/adversary/BERT/results/yelp ' \
16 | '--max_seq_length 256 --batch_size 32 ' \
17 | '--counter_fitting_embeddings_path /data/medg/misc/jindi/nlp/embeddings/counter-fitted-vectors.txt ' \
18 | '--counter_fitting_cos_sim_path /scratch/jindi/adversary/cos_sim_counter_fitting.npy ' \
19 | '--USE_cache_path /scratch/jindi/tf_cache'
20 |
21 | os.system(command)
--------------------------------------------------------------------------------
/run_attack_nli.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # for ESIM target model
4 | # command = 'python attack_nli.py --dataset_path data/snli ' \
5 | # '--target_model esim --target_model_path ESIM/data/checkpoints/SNLI/best.pth.tar ' \
6 | # '--word_embeddings_path ESIM/data/preprocessed/SNLI/worddict.pkl ' \
7 | # '--counter_fitting_embeddings_path /data/medg/misc/jindi/nlp/embeddings/counter-fitted-vectors.txt ' \
8 | # '--counter_fitting_cos_sim_path ./cos_sim_counter_fitting.npy ' \
9 | # '--USE_cache_path /scratch/jindi/tf_cache' \
10 | # '--output_dir results/snli_esim'
11 |
12 | # for InferSent target model
13 | # command = 'python attack_nli.py --dataset_path data/snli ' \
14 | # '--target_model infersent ' \
15 | # '--target_model_path /scratch/jindi/adversary/BERT/results/SNLI ' \
16 | # '--word_embeddings_path /data/medg/misc/jindi/nlp/embeddings/glove.840B/glove.840B.300d.txt ' \
17 | # '--counter_fitting_embeddings_path /data/medg/misc/jindi/nlp/embeddings/counter-fitted-vectors.txt ' \
18 | # '--counter_fitting_cos_sim_path ./cos_sim_counter_fitting.npy ' \
19 | # '--USE_cache_path /scratch/jindi/tf_cache ' \
20 | # '--output_dir results/snli_infersent'
21 |
22 | # for BERT target model
23 | command = 'python attack_nli.py --dataset_path data/snli ' \
24 | '--target_model bert ' \
25 | '--target_model_path /scratch/jindi/adversary/BERT/results/SNLI ' \
26 | '--counter_fitting_embeddings_path /data/medg/misc/jindi/nlp/embeddings/counter-fitted-vectors.txt ' \
27 | '--counter_fitting_cos_sim_path /scratch/jindi/adversary/cos_sim_counter_fitting.npy ' \
28 | '--USE_cache_path /scratch/jindi/tf_cache ' \
29 | '--output_dir results/snli_bert'
30 |
31 | os.system(command)
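32 | 
33 | # A minimal sketch (assumed checkpoint path) for attacking MNLI with the BERT target;
34 | # uncomment and adjust the paths, then re-run this script:
35 | # command = 'python attack_nli.py --dataset_path data/mnli ' \
36 | # '--target_model bert ' \
37 | # '--target_model_path /path/to/bert/results/mnli ' \
38 | # '--counter_fitting_embeddings_path /path/to/counter-fitted-vectors.txt ' \
39 | # '--counter_fitting_cos_sim_path ./cos_sim_counter_fitting.npy ' \
40 | # '--USE_cache_path /path/to/tf_cache ' \
41 | # '--output_dir results/mnli_bert'
42 | # os.system(command)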
--------------------------------------------------------------------------------
/train_classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import time
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from torch.autograd import Variable
13 |
14 | # from sru import *
15 | import dataloader
16 | import modules
17 |
18 | class Model(nn.Module):
19 | def __init__(self, embedding, hidden_size=150, depth=1, dropout=0.3, cnn=False, nclasses=2):
20 | super(Model, self).__init__()
21 | self.cnn = cnn
22 | self.drop = nn.Dropout(dropout)
23 | self.emb_layer = modules.EmbeddingLayer(
24 | embs = dataloader.load_embedding(embedding)
25 | )
26 | self.word2id = self.emb_layer.word2id
27 |
28 | if cnn:
29 | self.encoder = modules.CNN_Text(
30 | self.emb_layer.n_d,
31 | widths = [3,4,5],
32 | filters=hidden_size
33 | )
34 | d_out = 3*hidden_size
35 | else:
36 | self.encoder = nn.LSTM(
37 | self.emb_layer.n_d,
38 | hidden_size//2,
39 | depth,
40 | dropout = dropout,
41 | # batch_first=True,
42 | bidirectional=True
43 | )
44 | d_out = hidden_size
45 | # else:
46 | # self.encoder = SRU(
47 | # emb_layer.n_d,
48 | # args.d,
49 | # args.depth,
50 | # dropout = args.dropout,
51 | # )
52 | # d_out = args.d
53 | self.out = nn.Linear(d_out, nclasses)
54 |
55 | def forward(self, input):
56 | if self.cnn:
57 | input = input.t()
58 | emb = self.emb_layer(input)
59 | emb = self.drop(emb)
60 |
61 | if self.cnn:
62 | output = self.encoder(emb)
63 | else:
64 | output, hidden = self.encoder(emb)
65 | # output = output[-1]
66 | output = torch.max(output, dim=0)[0].squeeze()
67 |
68 | output = self.drop(output)
69 | return self.out(output)
70 |
71 | def text_pred(self, text, batch_size=32):
72 | batches_x = dataloader.create_batches_x(
73 | text,
74 | batch_size, ##TODO
75 | self.word2id
76 | )
77 | outs = []
78 | with torch.no_grad():
79 | for x in batches_x:
80 | x = Variable(x)
81 | if self.cnn:
82 | x = x.t()
83 | emb = self.emb_layer(x)
84 |
85 | if self.cnn:
86 | output = self.encoder(emb)
87 | else:
88 | output, hidden = self.encoder(emb)
89 | # output = output[-1]
90 | output = torch.max(output, dim=0)[0]
91 |
92 | outs.append(F.softmax(self.out(output), dim=-1))
93 |
94 | return torch.cat(outs, dim=0)
95 |
96 |
97 | def eval_model(niter, model, input_x, input_y):
98 | model.eval()
99 | # N = len(valid_x)
100 | # criterion = nn.CrossEntropyLoss()
101 | correct = 0.0
102 | cnt = 0.
103 | # total_loss = 0.0
104 | with torch.no_grad():
105 | for x, y in zip(input_x, input_y):
106 | x, y = Variable(x), Variable(y)  # volatile was removed in PyTorch 0.4+; torch.no_grad() above already disables autograd
107 | output = model(x)
108 | # loss = criterion(output, y)
109 | # total_loss += loss.item()*x.size(1)
110 | pred = output.data.max(1)[1]
111 | correct += pred.eq(y.data).cpu().sum()
112 | cnt += y.numel()
113 | model.train()
114 | return correct.item()/cnt
115 |
116 | def train_model(epoch, model, optimizer,
117 | train_x, train_y,
118 | test_x, test_y,
119 | best_test, save_path):
120 |
121 | model.train()
122 | niter = epoch*len(train_x)
123 | criterion = nn.CrossEntropyLoss()
124 |
125 | cnt = 0
126 | for x, y in zip(train_x, train_y):
127 | niter += 1
128 | cnt += 1
129 | model.zero_grad()
130 | x, y = Variable(x), Variable(y)
131 | output = model(x)
132 | loss = criterion(output, y)
133 | loss.backward()
134 | optimizer.step()
135 |
136 | test_acc = eval_model(niter, model, test_x, test_y)
137 |
138 | sys.stdout.write("Epoch={} iter={} lr={:.6f} train_loss={:.6f} test_err={:.6f}\n".format(
139 | epoch, niter,
140 | optimizer.param_groups[0]['lr'],
141 | loss.item(),
142 | test_acc
143 | ))
144 |
145 | if test_acc > best_test:
146 | best_test = test_acc
147 | if save_path:
148 | torch.save(model.state_dict(), save_path)
149 | # test_err = eval_model(niter, model, test_x, test_y)
150 | sys.stdout.write("\n")
151 | return best_test
152 |
153 | def save_data(data, labels, path, type='train'):
154 | with open(os.path.join(path, type+'.txt'), 'w') as ofile:
155 | for text, label in zip(data, labels):
156 | ofile.write('{} {}\n'.format(label, ' '.join(text)))
157 |
158 | def main(args):
159 | if args.dataset == 'mr':
160 | # data, label = dataloader.read_MR(args.path)
161 | # train_x, train_y, test_x, test_y = dataloader.cv_split2(
162 | # data, label,
163 | # nfold=10,
164 | # valid_id=args.cv
165 | # )
166 | #
167 | # if args.save_data_split:
168 | # save_data(train_x, train_y, args.path, 'train')
169 | # save_data(test_x, test_y, args.path, 'test')
170 | train_x, train_y = dataloader.read_corpus('/data/medg/misc/jindi/nlp/datasets/mr/train.txt')
171 | test_x, test_y = dataloader.read_corpus('/data/medg/misc/jindi/nlp/datasets/mr/test.txt')
172 | elif args.dataset == 'imdb':
173 | train_x, train_y = dataloader.read_corpus(os.path.join('/data/medg/misc/jindi/nlp/datasets/imdb',
174 | 'train_tok.csv'),
175 | clean=False, MR=True, shuffle=True)
176 | test_x, test_y = dataloader.read_corpus(os.path.join('/data/medg/misc/jindi/nlp/datasets/imdb',
177 | 'test_tok.csv'),
178 | clean=False, MR=True, shuffle=True)
179 | else:
180 | train_x, train_y = dataloader.read_corpus('/afs/csail.mit.edu/u/z/zhijing/proj/to_di/data/{}/'
181 | 'train_tok.csv'.format(args.dataset),
182 | clean=False, MR=False, shuffle=True)
183 | test_x, test_y = dataloader.read_corpus('/afs/csail.mit.edu/u/z/zhijing/proj/to_di/data/{}/'
184 | 'test_tok.csv'.format(args.dataset),
185 | clean=False, MR=False, shuffle=True)
186 |
187 | nclasses = max(train_y) + 1
188 | # elif args.dataset == 'subj':
189 | # data, label = dataloader.read_SUBJ(args.path)
190 | # elif args.dataset == 'cr':
191 | # data, label = dataloader.read_CR(args.path)
192 | # elif args.dataset == 'mpqa':
193 | # data, label = dataloader.read_MPQA(args.path)
194 | # elif args.dataset == 'trec':
195 | # train_x, train_y, test_x, test_y = dataloader.read_TREC(args.path)
196 | # data = train_x + test_x
197 | # label = None
198 | # elif args.dataset == 'sst':
199 | # train_x, train_y, valid_x, valid_y, test_x, test_y = dataloader.read_SST(args.path)
200 | # data = train_x + valid_x + test_x
201 | # label = None
202 | # else:
203 | # raise Exception("unknown dataset: {}".format(args.dataset))
204 |
205 | # if args.dataset == 'trec':
206 |
207 |
208 | # elif args.dataset != 'sst':
209 | # train_x, train_y, valid_x, valid_y, test_x, test_y = dataloader.cv_split(
210 | # data, label,
211 | # nfold = 10,
212 | # test_id = args.cv
213 | # )
214 |
215 | model = Model(args.embedding, args.d, args.depth, args.dropout, args.cnn, nclasses).cuda()
216 | need_grad = lambda x: x.requires_grad
217 | optimizer = optim.Adam(
218 | filter(need_grad, model.parameters()),
219 | lr = args.lr
220 | )
221 |
222 | train_x, train_y = dataloader.create_batches(
223 | train_x, train_y,
224 | args.batch_size,
225 | model.word2id,
226 | )
227 | # valid_x, valid_y = dataloader.create_batches(
228 | # valid_x, valid_y,
229 | # args.batch_size,
230 | # emb_layer.word2id,
231 | # )
232 | test_x, test_y = dataloader.create_batches(
233 | test_x, test_y,
234 | args.batch_size,
235 | model.word2id,
236 | )
237 |
238 | best_test = 0
239 | # test_err = 1e+8
240 | for epoch in range(args.max_epoch):
241 | best_test = train_model(epoch, model, optimizer,
242 | train_x, train_y,
243 | # valid_x, valid_y,
244 | test_x, test_y,
245 | best_test, args.save_path
246 | )
247 | if args.lr_decay>0:
248 | optimizer.param_groups[0]['lr'] *= args.lr_decay
249 |
250 | # sys.stdout.write("best_valid: {:.6f}\n".format(
251 | # best_valid
252 | # ))
253 | sys.stdout.write("test_err: {:.6f}\n".format(
254 | best_test
255 | ))
256 |
257 | if __name__ == "__main__":
258 | argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler='resolve')
259 | argparser.add_argument("--cnn", action='store_true', help="whether to use cnn")
260 | argparser.add_argument("--lstm", action='store_true', help="whether to use lstm")
261 | argparser.add_argument("--dataset", type=str, default="mr", help="which dataset")
262 | argparser.add_argument("--embedding", type=str, required=True, help="word vectors")
263 | argparser.add_argument("--batch_size", "--batch", type=int, default=32)
264 | argparser.add_argument("--max_epoch", type=int, default=70)
265 | argparser.add_argument("--d", type=int, default=150)
266 | argparser.add_argument("--dropout", type=float, default=0.3)
267 | argparser.add_argument("--depth", type=int, default=1)
268 | argparser.add_argument("--lr", type=float, default=0.001)
269 | argparser.add_argument("--lr_decay", type=float, default=0)
270 | argparser.add_argument("--cv", type=int, default=0)
271 | argparser.add_argument("--save_path", type=str, default='')
272 | argparser.add_argument("--save_data_split", action='store_true', help="whether to save train/test split")
273 | argparser.add_argument("--gpu_id", type=int, default=0)
274 |
275 | args = argparser.parse_args()
276 | # args.save_path = os.path.join(args.save_path, args.dataset)
277 | print(args)
278 | torch.cuda.set_device(args.gpu_id)
279 | main(args)
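280 | 
281 | # Example invocation (a sketch; embedding and save paths are placeholders):
282 | #   python train_classifier.py --dataset mr --cnn \
283 | #       --embedding /path/to/glove.6B.200d.txt \
284 | #       --save_path results/mr_cnn.pt --gpu_id 0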
--------------------------------------------------------------------------------