├── src
│   ├── hg_api
│   │   ├── __init__.py
│   │   ├── interactive.py
│   │   └── interactive_eval.py
│   ├── nli
│   │   ├── __init__.py
│   │   ├── inspection_tools.py
│   │   ├── evaluation.py
│   │   ├── inference_debug.py
│   │   ├── training_extra.py
│   │   └── training.py
│   ├── flint
│   │   ├── __init__.py
│   │   ├── data_utils
│   │   │   ├── __init__.py
│   │   │   ├── fields.py
│   │   │   └── batchbuilder.py
│   │   └── torch_util.py
│   ├── modeling
│   │   ├── __init__.py
│   │   └── res_encoder.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── save_tool.py
│   │   └── list_dict_data_tool.py
│   ├── dataset_tools
│   │   ├── __init__.py
│   │   ├── format_convert.py
│   │   └── build_data.py
│   └── config.py
├── CODE_OF_CONDUCT.md
├── setup.sh
├── script
│   ├── example_scripts
│   │   ├── train_roberta_small.sh
│   │   ├── train_xlnet.sh
│   │   ├── train_roberta.sh
│   │   └── ANLI_on_Google_Colab.ipynb
│   └── download_data.sh
├── CONTRIBUTING.md
├── mds
│   ├── verifier_labels.md
│   └── start_your_nli_research.md
├── README.md
└── LICENSE
/src/hg_api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /src/flint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /src/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /src/dataset_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /src/flint/data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
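Once `source setup.sh` (below) has put `src` on `PYTHONPATH`, the layout above maps directly to import paths; a minimal sketch, assuming the shell ran `setup.sh` from the project root:

```python
import config                                      # resolves to src/config.py
from flint.data_utils.fields import RawFlintField  # src/flint/data_utils/fields.py
from utils import common                           # src/utils/common.py

print(config.PRO_ROOT)  # the repository root, as a pathlib.Path
```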
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import os 8 | from pathlib import Path 9 | 10 | SRC_ROOT = Path(os.path.dirname(os.path.realpath(__file__))) 11 | PRO_ROOT = SRC_ROOT.parent 12 | 13 | if __name__ == '__main__': 14 | pass 15 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Add the project's src (and utest) directories to PYTHONPATH 7 | export DIR_TMP="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 8 | 9 | export PYTHONPATH=$PYTHONPATH:$DIR_TMP/src 10 | export PYTHONPATH=$PYTHONPATH:$DIR_TMP/utest 11 | 12 | echo PYTHONPATH=$PYTHONPATH -------------------------------------------------------------------------------- /script/example_scripts/train_roberta_small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | export MASTER_PORT=88888 8 | 9 | echo $CUDA_VISIBLE_DEVICES 10 | nvidia-smi 11 | # End visible GPUs. 12 | 13 | # setup conda environment 14 | source setup.sh 15 | 16 | which python 17 | 18 | python src/nli/training.py \ 19 | --model_class_name "roberta-base" \ 20 | -n 1 \ 21 | -g 1 \ 22 | --single_gpu \ 23 | -nr 0 \ 24 | --max_length 156 \ 25 | --gradient_accumulation_steps 4 \ 26 | --per_gpu_train_batch_size 4 \ 27 | --per_gpu_eval_batch_size 16 \ 28 | --save_prediction \ 29 | --train_data \ 30 | anli_r1_train:none,anli_r2_train:none,anli_r3_train:none \ 31 | --train_weights \ 32 | 10,20,10 \ 33 | --eval_data \ 34 | anli_r1_dev:none,anli_r2_dev:none,anli_r3_dev:none \ 35 | --eval_frequency 2000 \ 36 | --experiment_name "roberta-base|r1*10+r2*20+r3*10|nli" -------------------------------------------------------------------------------- /script/example_scripts/train_xlnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | export MASTER_PORT=88888 8 | 9 | echo $CUDA_VISIBLE_DEVICES 10 | nvidia-smi 11 | # End visible GPUs. 12 | 13 | # setup conda environment 14 | source setup.sh
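# Note on the data flags used below (the `name:path` convention is inferred
# from these example scripts and from src/nli/evaluation.py, not from separate
# docs): --train_data and --eval_data take comma-separated `name:path` pairs,
# where datasets registered in src/nli/training.py use `none` as the path;
# --train_weights lists one upsampling multiplier per training set, e.g.
# 1,1,1,10,20,10 repeats ANLI R1/R2/R3 10x/20x/10x relative to SNLI/MNLI/FEVER.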
15 | 16 | which python 17 | 18 | python src/nli/training.py \ 19 | --model_class_name "xlnet-large" \ 20 | -n 1 \ 21 | -g 8 \ 22 | -nr 0 \ 23 | --max_length 156 \ 24 | --gradient_accumulation_steps 2 \ 25 | --per_gpu_train_batch_size 8 \ 26 | --per_gpu_eval_batch_size 16 \ 27 | --save_prediction \ 28 | --train_data \ 29 | snli_train:none,mnli_train:none,fever_train:none,anli_r1_train:none,anli_r2_train:none,anli_r3_train:none \ 30 | --train_weights \ 31 | 1,1,1,10,20,10 \ 32 | --eval_data \ 33 | snli_dev:none,mnli_m_dev:none,mnli_mm_dev:none,anli_r1_dev:none,anli_r2_dev:none,anli_r3_dev:none \ 34 | --eval_frequency 2000 \ 35 | --experiment_name "xlnet-large|snli+mnli+fnli+r1*10+r2*20+r3*10|nli" -------------------------------------------------------------------------------- /script/example_scripts/train_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | export MASTER_PORT=88888 8 | 9 | echo $CUDA_VISIBLE_DEVICES 10 | nvidia-smi 11 | # End visible GPUs. 12 | 13 | # setup conda environment 14 | source setup.sh 15 | 16 | which python 17 | 18 | python src/nli/training.py \ 19 | --model_class_name "roberta-large" \ 20 | -n 1 \ 21 | -g 8 \ 22 | -nr 0 \ 23 | --fp16 \ 24 | --fp16_opt_level O2 \ 25 | --max_length 156 \ 26 | --gradient_accumulation_steps 1 \ 27 | --per_gpu_train_batch_size 16 \ 28 | --per_gpu_eval_batch_size 32 \ 29 | --save_prediction \ 30 | --train_data \ 31 | snli_train:none,mnli_train:none,fever_train:none,anli_r1_train:none,anli_r2_train:none,anli_r3_train:none \ 32 | --train_weights \ 33 | 1,1,1,10,20,10 \ 34 | --eval_data \ 35 | snli_dev:none,mnli_m_dev:none,mnli_mm_dev:none,anli_r1_dev:none,anli_r2_dev:none,anli_r3_dev:none \ 36 | --eval_frequency 2000 \ 37 | --experiment_name "roberta-large|snli+mnli+fnli+r1*10+r2*20+r3*10|nli" -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ANLI 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## License 30 | By contributing to ANLI, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /mds/verifier_labels.md: -------------------------------------------------------------------------------- 1 | ## Verifier Labels 2 | We released additional verifier labels for Rounds 1, 2, and 3 on May 11, 2022. 3 | The labels are in [`verifier_labels/verifier_labels_R1-3.jsonl`](https://github.com/facebookresearch/anli/blob/main/verifier_labels/verifier_labels_R1-3.jsonl). 4 | 5 | ## File Format 6 | Each line in the jsonl file records the verifier labels for one example (data point) in the ANLI dataset. 7 | You can use the `uid` field to map the verifier labels back to the original released examples. 8 | 9 | Example: 10 | ```python 11 | {"uid": "385b7051-d09f-4ecb-9224-fcc5f0615408", "verifier labels": ["e", "n", "n"]} 12 | {"uid": "44ee99dc-4179-4160-885d-98e17f203bac", "verifier labels": ["c", "c"]} 13 | ... 14 | ``` 15 | 16 | ## Verifier label statistics 17 | The table below shows the number of verification labels for each split in ANLI R1 to R3. 18 | Note that for every example in the dev and test splits, at least two verifiers agreed with the writer of the example. 19 | 20 | Number of verification labels|0|2|3 21 | ---|---|---|--- 22 | R1-Train|13816|2526|604 23 | R1-Dev|0|702|298 24 | R1-Test|0|740|260 25 | R2-Train|39895|3843|1722 26 | R2-Dev|0|710|290 27 | R2-Test|0|672|328 28 | R3-Train|85751|10030|4678 29 | R3-Dev|0|809|391 30 | R3-Test|0|820|380 31 | 32 | Total number of examples: 169265 33 | 34 | -------------------------------------------------------------------------------- /src/flint/data_utils/fields.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | 9 | class FlintField(object): 10 | @classmethod 11 | def batching(cls, batched_data): 12 | raise NotImplementedError() 13 | 14 | 15 | class RawFlintField(FlintField): 16 | @classmethod 17 | def batching(cls, batched_data): 18 | return batched_data 19 | 20 | 21 | class LabelFlintField(FlintField): 22 | def batching(self, batched_data): 23 | return torch.tensor(batched_data) 24 | 25 | 26 | class ArrayIndexFlintField(FlintField): 27 | def __init__(self, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False) -> None: 28 | super().__init__() 29 | self.pad_idx = pad_idx 30 | self.eos_idx = eos_idx 31 | self.left_pad = left_pad 32 | self.move_eos_to_beginning = move_eos_to_beginning 33 | 34 | def collate_tokens(self, values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False): 35 | """ 36 | Convert a list of 1d tensors into a padded 2d tensor. For example, [[1, 2], [3]] with pad_idx=0 becomes [[1, 2], [3, 0]], or [[1, 2], [0, 3]] when left_pad=True.
37 | """ 38 | if not torch.is_tensor(values[0]): 39 | values = [torch.tensor(v) for v in values] 40 | 41 | size = max(v.size(0) for v in values) 42 | res = values[0].new(len(values), size).fill_(pad_idx) 43 | 44 | def copy_tensor(src, dst): 45 | assert dst.numel() == src.numel() 46 | if move_eos_to_beginning: 47 | assert src[-1] == eos_idx 48 | dst[0] = eos_idx 49 | dst[1:] = src[:-1] 50 | else: 51 | dst.copy_(src) 52 | 53 | for i, v in enumerate(values): 54 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 55 | return res 56 | 57 | def batching(self, batched_data): 58 | return self.collate_tokens(batched_data, 59 | self.pad_idx, 60 | self.eos_idx, 61 | self.left_pad, 62 | self.move_eos_to_beginning) 63 | -------------------------------------------------------------------------------- /src/hg_api/interactive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 7 | import torch 8 | 9 | def evaluate(tokenizer, model, premise, hypothesis): 10 | max_length = 256 11 | 12 | tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis, 13 | max_length=max_length, 14 | return_token_type_ids=True, truncation=True) 15 | 16 | input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0) 17 | # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart. 18 | token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0) 19 | attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0) 20 | 21 | outputs = model(input_ids, 22 | attention_mask=attention_mask, 23 | token_type_ids=token_type_ids, 24 | labels=None) 25 | # Note: 26 | # "id2label": { 27 | # "0": "entailment", 28 | # "1": "neutral", 29 | # "2": "contradiction" 30 | # }, 31 | 32 | predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist() # batch_size only one 33 | 34 | #print("Premise:", premise) 35 | #print("Hypothesis:", hypothesis) 36 | print("Prediction:") 37 | print("Entailment:", predicted_probability[0]) 38 | print("Neutral:", predicted_probability[1]) 39 | print("Contradiction:", predicted_probability[2]) 40 | 41 | print("="*20) 42 | 43 | if __name__ == '__main__': 44 | print("Loading model...") 45 | 46 | # hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli" 47 | # hg_model_hub_name = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli" 48 | # hg_model_hub_name = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli" 49 | # hg_model_hub_name = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli" 50 | hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli" 51 | 52 | tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name) 53 | model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name) 54 | print("Model loaded!") 55 | 56 | while True: 57 | premise = input("Premise> ") 58 | hypothesis = input("Hypothesis> ") 59 | 60 | evaluate(tokenizer, model, premise, hypothesis) 61 | -------------------------------------------------------------------------------- /src/flint/data_utils/batchbuilder.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from typing import Dict, Type 8 | 9 | from flint.data_utils.fields import FlintField, RawFlintField 10 | 11 | 12 | class BaseBatchBuilder(object): 13 | def __init__(self, batching_schema: Dict[str, FlintField]) -> None: 14 | super().__init__() 15 | self.batching_schema: Dict[str, FlintField] = batching_schema 16 | 17 | def __call__(self, batch): 18 | field_names = batch[0].keys() 19 | batched_data = dict() 20 | 21 | for field_name in field_names: 22 | if field_name not in self.batching_schema: 23 | # default is RawFlintField 24 | batched_data[field_name] = RawFlintField.batching([item[field_name] for item in batch]) 25 | 26 | else: 27 | batched_data[field_name] = self.batching_schema[field_name].batching([item[field_name] for item in batch]) 28 | 29 | return batched_data 30 | 31 | 32 | def has_tensor(obj) -> bool: 33 | """ 34 | Given a possibly complex data structure, 35 | check if it has any torch.Tensors in it. 36 | """ 37 | if isinstance(obj, torch.Tensor): 38 | return True 39 | elif isinstance(obj, dict): 40 | return any(has_tensor(value) for value in obj.values()) 41 | elif isinstance(obj, (list, tuple)): 42 | return any(has_tensor(item) for item in obj) 43 | else: 44 | return False 45 | 46 | 47 | def move_to_device(obj, cuda_device: int): 48 | """ 49 | Given a structure (possibly) containing Tensors on the CPU, 50 | move all the Tensors to the specified GPU (or do nothing, if they should be on the CPU). 51 | """ 52 | 53 | if cuda_device < 0 or not has_tensor(obj): 54 | return obj 55 | elif isinstance(obj, torch.Tensor): 56 | return obj.cuda(cuda_device) 57 | elif isinstance(obj, dict): 58 | return {key: move_to_device(value, cuda_device) for key, value in obj.items()} 59 | elif isinstance(obj, list): 60 | return [move_to_device(item, cuda_device) for item in obj] 61 | elif isinstance(obj, tuple) and hasattr(obj, "_fields"): 62 | # This is the best way to detect a NamedTuple, it turns out. 63 | return obj.__class__(*(move_to_device(item, cuda_device) for item in obj)) 64 | elif isinstance(obj, tuple): 65 | return tuple(move_to_device(item, cuda_device) for item in obj) 66 | else: 67 | return obj 68 | 69 | 70 | if __name__ == '__main__': 71 | print(RawFlintField.batching) -------------------------------------------------------------------------------- /src/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import json 7 | from json import JSONEncoder 8 | from tqdm import tqdm 9 | import config 10 | 11 | 12 | registered_jsonabl_classes = {} 13 | 14 | # Some Jsonable classes, for easy json serialization. 
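# A minimal usage sketch of the helpers defined below (the `Point` class is
# hypothetical):
#
#     class Point(JsonableObj):
#         def __init__(self, x, y):
#             self.x, self.y = x, y
#
#     register_class(Point)
#     s = json_dumps(Point(1, 2))   # '{"_jcls_": "Point", "x": 1, "y": 2}'
#     p = json_loads(s)             # rebuilt Point instance (no __init__ call)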
15 | 16 | 17 | def register_class(cls): 18 | global registered_jsonabl_classes 19 | if cls not in registered_jsonabl_classes: 20 | registered_jsonabl_classes.update({cls.__name__: cls}) 21 | 22 | 23 | class JsonableObj(object): 24 | pass 25 | 26 | 27 | class JsonableObjectEncoder(JSONEncoder): 28 | def default(self, o): 29 | if isinstance(o, JsonableObj): 30 | d = {'_jcls_': type(o).__name__} 31 | d.update(vars(o)) 32 | return d 33 | else: 34 | return super().default(o) 35 | 36 | 37 | def unserialize_JsonableObject(d): 38 | global registered_jsonabl_classes 39 | classname = d.pop('_jcls_', None) 40 | if classname: 41 | cls = registered_jsonabl_classes[classname] 42 | obj = cls.__new__(cls) # Make instance without calling __init__ 43 | for key, value in d.items(): 44 | setattr(obj, key, value) 45 | return obj 46 | else: 47 | return d 48 | 49 | 50 | def json_dumps(item): 51 | return json.dumps(item, cls=JsonableObjectEncoder) 52 | 53 | 54 | def json_loads(item_str): 55 | return json.loads(item_str, object_hook=unserialize_JsonableObject) 56 | 57 | # Json Serializable object finished. 58 | 59 | 60 | def save_jsonl(d_list, filename): 61 | print("Save to Jsonl:", filename) 62 | with open(filename, encoding='utf-8', mode='w') as out_f: 63 | for item in d_list: 64 | out_f.write(json.dumps(item, cls=JsonableObjectEncoder) + '\n') 65 | 66 | 67 | def load_jsonl(filename, debug_num=None): 68 | d_list = [] 69 | with open(filename, encoding='utf-8', mode='r') as in_f: 70 | print("Load Jsonl:", filename) 71 | for line in tqdm(in_f): 72 | item = json.loads(line.strip(), object_hook=unserialize_JsonableObject) 73 | d_list.append(item) 74 | if debug_num is not None and 0 < debug_num == len(d_list): 75 | break 76 | 77 | return d_list 78 | 79 | 80 | def load_json(filename, **kwargs): 81 | with open(filename, encoding='utf-8', mode='r') as in_f: 82 | return json.load(in_f, object_hook=unserialize_JsonableObject, **kwargs) 83 | 84 | 85 | def save_json(obj, filename, **kwargs): 86 | with open(filename, encoding='utf-8', mode='w') as out_f: 87 | json.dump(obj, out_f, cls=JsonableObjectEncoder, **kwargs) 88 | out_f.close() -------------------------------------------------------------------------------- /src/dataset_tools/format_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from utils import common 7 | from typing import List, Dict 8 | from tqdm import tqdm 9 | from collections import defaultdict 10 | import config 11 | from pathlib import Path 12 | 13 | 14 | smnli_label2std_label = defaultdict(lambda: "o") # o stands for all other label that is invalid. 
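# For example, after the .update calls below, smnli_label2std_label["entailment"]
# returns "e", while any label not listed falls back to "o" and is filtered out
# by the converter functions further down.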
15 | smnli_label2std_label.update({ 16 | "entailment": "e", 17 | "neutral": "n", 18 | "contradiction": "c", 19 | "hidden": "h", 20 | }) 21 | 22 | fever_label2std_label = defaultdict(lambda: "o") 23 | fever_label2std_label.update({ 24 | 'SUPPORTS': "e", 25 | 'NOT ENOUGH INFO': "n", 26 | 'REFUTES': "c", 27 | 'hidden': "h", 28 | }) 29 | 30 | anli_label2std_label = defaultdict(lambda: "o") 31 | anli_label2std_label.update({ 32 | 'e': "e", 33 | 'n': "n", 34 | 'c': "c", 35 | 'hidden': "h", 36 | }) 37 | 38 | # standard output format: {uid, premise, hypothesis, label, extra_dataset_related_field.} 39 | 40 | 41 | def sm_nli2std_format(d_list, filter_invalid=True): 42 | p_list: List[Dict] = [] 43 | for item in d_list: 44 | formatted_item: Dict = dict() 45 | formatted_item['uid']: str = item["pairID"] 46 | formatted_item['premise']: str = item["sentence1"] 47 | formatted_item['hypothesis']: str = item["sentence2"] 48 | formatted_item['label']: str = smnli_label2std_label[item["gold_label"]] 49 | if filter_invalid and formatted_item['label'] == 'o': 50 | continue # Skip example with invalid label. 51 | 52 | p_list.append(formatted_item) 53 | return p_list 54 | 55 | 56 | def fever_nli2std_format(d_list, filter_invalid=True): 57 | p_list: List[Dict] = [] 58 | for item in d_list: 59 | formatted_item: Dict = dict() 60 | formatted_item['uid']: str = item["fid"] 61 | formatted_item['premise']: str = item["context"] 62 | formatted_item['hypothesis']: str = item["query"] 63 | formatted_item['label']: str = fever_label2std_label[item["label"]] 64 | if filter_invalid and formatted_item['label'] == 'o': 65 | continue # Skip example with invalid label. 66 | 67 | p_list.append(formatted_item) 68 | return p_list 69 | 70 | 71 | def a_nli2std_format(d_list, filter_invalid=True): 72 | p_list: List[Dict] = [] 73 | for item in d_list: 74 | formatted_item: Dict = dict() 75 | formatted_item['uid']: str = item["uid"] 76 | formatted_item['premise']: str = item["context"] 77 | formatted_item['hypothesis']: str = item["hypothesis"] 78 | formatted_item['label']: str = anli_label2std_label[item["label"]] 79 | formatted_item['reason']: str = item["reason"] 80 | if filter_invalid and formatted_item['label'] == 'o': 81 | continue # Skip example with invalid label. 82 | 83 | p_list.append(formatted_item) 84 | return p_list 85 | 86 | 87 | if __name__ == '__main__': 88 | pass -------------------------------------------------------------------------------- /src/utils/save_tool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
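# Note on ScoreLogger.incorporate_results (defined below): it returns True if
# any tracked score matches or beats its current best (>=), updates those best
# values, and appends every call, improved or not, to the logging list; see the
# __main__ block at the end of this file for a small usage demo.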
5 | 6 | import os 7 | from pathlib import Path 8 | 9 | import config 10 | from datetime import datetime 11 | 12 | from utils import common 13 | 14 | 15 | class ScoreLogger(object): 16 | def __init__(self, init_tracking_dict) -> None: 17 | super().__init__() 18 | self.logging_item_list = [] 19 | self.score_tracker = dict() 20 | self.score_tracker.update(init_tracking_dict) 21 | 22 | def incorporate_results(self, score_dict, save_key, item=None) -> bool: 23 | assert len(score_dict.keys()) == len(self.score_tracker.keys()) 24 | for fieldname in score_dict.keys(): 25 | assert fieldname in self.score_tracker 26 | 27 | valid_improvement = False 28 | for fieldname, value in score_dict.items(): 29 | if score_dict[fieldname] >= self.score_tracker[fieldname]: 30 | self.score_tracker[fieldname] = score_dict[fieldname] 31 | valid_improvement = True 32 | 33 | self.logging_item_list.append({'k': save_key, 'v': item}) 34 | 35 | return valid_improvement 36 | 37 | def logging_to_file(self, filename): 38 | if Path(filename).is_file(): 39 | old_logging_list = common.load_json(filename) 40 | current_saved_key = set() 41 | 42 | for item in self.logging_item_list: 43 | current_saved_key.add(item['k']) 44 | 45 | for item in old_logging_list: 46 | if item['k'] not in current_saved_key: 47 | raise ValueError("Previous logged item can not be found!") 48 | 49 | common.save_json(self.logging_item_list, filename, indent=2, sort_keys=True) 50 | 51 | 52 | def gen_file_prefix(model_name, directory_name='saved_models', date=None): 53 | date_now = datetime.now().strftime("%m-%d-%H:%M:%S") if not date else date 54 | file_path = os.path.join(config.PRO_ROOT / directory_name / '_'.join((date_now, model_name))) 55 | if not os.path.exists(file_path): 56 | os.makedirs(file_path) 57 | return file_path, date_now 58 | 59 | 60 | def get_cur_time_str(): 61 | date_now = datetime.now().strftime("%m-%d[%H:%M:%S]") 62 | return date_now 63 | 64 | 65 | if __name__ == "__main__": 66 | # print(gen_file_prefix("this_is_my_model.")) 67 | # print(get_cur_time_str()) 68 | score_logger = ScoreLogger({'a_score': -1, 'b_score': -1}) 69 | print(score_logger.incorporate_results({'a_score': 2, 'b_score': -1}, 'key-1', {'a_score': 2, 'b_score': -1})) 70 | print(score_logger.incorporate_results({'a_score': 2, 'b_score': 3}, 'key-2', {'a_score': 2, 'b_score': 3})) 71 | print(score_logger.incorporate_results({'a_score': 2, 'b_score': 4}, 'key-2', {'a_score': 2, 'b_score': 4})) 72 | print(score_logger.incorporate_results({'a_score': 1, 'b_score': 2}, 'key-2', {'a_score': 1, 'b_score': 2})) 73 | print(score_logger.score_tracker) 74 | score_logger.logging_to_file('for_testing.json') -------------------------------------------------------------------------------- /src/hg_api/interactive_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
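# Note: the __main__ block below reads SNLI dev from ../../data/snli_1.0/
# relative to this file; running script/download_data.sh from the project root
# downloads and unpacks that directory.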
5 | 6 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 7 | import torch 8 | import json 9 | 10 | 11 | def get_prediction(tokenizer, model, premise, hypothesis, max_length=256): 12 | tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis, 13 | max_length=max_length, 14 | return_token_type_ids=True, truncation=True) 15 | 16 | input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0) 17 | token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0) 18 | attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0) 19 | 20 | outputs = model(input_ids, 21 | attention_mask=attention_mask, 22 | token_type_ids=token_type_ids, 23 | labels=None) 24 | 25 | predicted_probability = torch.softmax(outputs[0], dim=1)[0] # batch_size only one 26 | predicted_index = torch.argmax(predicted_probability) 27 | predicted_probability = predicted_probability.tolist() 28 | 29 | return predicted_probability, predicted_index 30 | 31 | 32 | if __name__ == '__main__': 33 | premise = "Two women are embracing while holding to go packages." 34 | hypothesis = "The men are fighting outside a deli." 35 | 36 | hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli" 37 | # hg_model_hub_name = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli" 38 | # hg_model_hub_name = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli" 39 | # hg_model_hub_name = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli" 40 | # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli" 41 | 42 | tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name) 43 | model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name) 44 | 45 | snli_dev = [] 46 | SNLI_DEV_FILE_PATH = "../../data/snli_1.0/snli_1.0_dev.jsonl" # you can change this to other path. 47 | with open(SNLI_DEV_FILE_PATH, mode='r', encoding='utf-8') as in_f: 48 | for line in in_f: 49 | if line: 50 | cur_item = json.loads(line) 51 | if cur_item['gold_label'] != '-': 52 | snli_dev.append(cur_item) 53 | 54 | total = 0 55 | correct = 0 56 | label_mapping = { 57 | 0: 'entailment', 58 | 1: 'neutral', 59 | 2: 'contradiction', 60 | } 61 | 62 | print("Start evaluating...") # this might take a while. 63 | for item in snli_dev: 64 | _, pred_index = get_prediction(tokenizer, model, item['sentence1'], item['sentence2']) 65 | if label_mapping[int(pred_index)] == item['gold_label']: 66 | correct += 1 67 | total += 1 68 | if total % 200 == 0 and total != 0: 69 | print(f"{total} finished.") 70 | 71 | print("Total / Correct / Accuracy:", f"{total} / {correct} / {correct / total}") -------------------------------------------------------------------------------- /src/nli/inspection_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
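# Rough usage sketch for this module (the wiring is inferred from
# get_model_prediction and get_lig_object below, not from a documented API):
#
#     lig = get_lig_object(model, model_class_item)
#     attributions, delta = lig.attribute(
#         inputs=input_ids,
#         target=pred_label_index,  # class to attribute, e.g. the predicted label
#         additional_forward_args=(attention_mask, token_type_ids,
#                                  model, model_class_item, True),
#         return_convergence_delta=True)
#     token_importance = summarize_attributions(attributions)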
5 | 6 | 7 | import torch 8 | import logging 9 | from captum.attr import LayerIntegratedGradients 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def summarize_attributions(attributions): 15 | """ 16 | Summarises the attribution across multiple runs 17 | """ 18 | attributions = attributions.sum(dim=-1).squeeze(0) 19 | attributions = attributions / torch.norm(attributions) 20 | return attributions 21 | 22 | 23 | def get_model_prediction(input_ids, attention_mask, token_type_ids, model, model_class_item, with_gradient=False): 24 | model.eval() 25 | 26 | if not with_gradient: 27 | with torch.no_grad(): 28 | if model_class_item['model_class_name'] in ["distilbert", "bart-large"]: 29 | outputs = model(input_ids, 30 | attention_mask=attention_mask, 31 | labels=None) 32 | else: 33 | outputs = model(input_ids, 34 | attention_mask=attention_mask, 35 | token_type_ids=token_type_ids, 36 | labels=None) 37 | else: 38 | if model_class_item['model_class_name'] in ["distilbert", "bart-large"]: 39 | outputs = model(input_ids, 40 | attention_mask=attention_mask, 41 | labels=None) 42 | else: 43 | outputs = model(input_ids, 44 | attention_mask=attention_mask, 45 | token_type_ids=token_type_ids, 46 | labels=None) 47 | 48 | return outputs[0] 49 | 50 | 51 | def get_lig_object(model, model_class_item): 52 | insight_supported = model_class_item['insight_supported'] if 'insight_supported' in model_class_item else False 53 | internal_model_name = model_class_item['internal_model_name'] 54 | lig = None # default is None. 55 | if not insight_supported: 56 | logger.warning(f"Inspection for model '{model_class_item['model_class_name']}' is not supported.") 57 | return lig 58 | 59 | if isinstance(internal_model_name, list): 60 | current_layer = model 61 | for layer_n in internal_model_name: 62 | current_layer = current_layer.__getattr__(layer_n) 63 | # print(current_layer) 64 | lig = LayerIntegratedGradients(get_model_prediction, current_layer) 65 | else: 66 | lig = LayerIntegratedGradients(get_model_prediction, 67 | model.__getattr__(internal_model_name).embeddings.word_embeddings) 68 | return lig 69 | 70 | 71 | def get_tokenized_input_tokens(tokenizer, token_ids): 72 | raw_words_list = tokenizer.convert_ids_to_tokens(token_ids) 73 | string_tokens = [tokenizer.convert_tokens_to_string(word) for word in raw_words_list] 74 | # still need some cleanup, remove space within tokens 75 | output_tokens = [] 76 | for t in string_tokens: 77 | output_tokens.append(t.replace(" ", "")) 78 | return output_tokens 79 | 80 | 81 | def cleanup_tokenization_special_tokens(tokens, importance, tokenizer): 82 | filtered_tokens = [] 83 | filtered_importance = [] 84 | for t, i in zip(tokens, importance): 85 | if t in tokenizer.all_special_tokens: 86 | continue 87 | else: 88 | filtered_tokens.append(t) 89 | filtered_importance.append(i) 90 | return filtered_tokens, filtered_importance 91 | -------------------------------------------------------------------------------- /script/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/env bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #ANLI_VERSION=0.1 8 | ANLI_VERSION=1.0 9 | #echo ${ANLI_VERSION} 10 | 11 | if [[ -z "$DIR_TMP" ]]; then # If project root not defined. 
12 | # get the directory of this file 13 | export CURRENT_FILE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 14 | # setup root directory. 15 | export DIR_TMP=$(cd "${CURRENT_FILE_DIR}/.."; pwd) 16 | fi 17 | 18 | export DIR_TMP=$(cd "${DIR_TMP}"; pwd) 19 | echo "The path of project root: ${DIR_TMP}" 20 | 21 | 22 | # check if data exist. 23 | if [[ ! -d ${DIR_TMP}/data ]]; then 24 | mkdir ${DIR_TMP}/data 25 | fi 26 | 27 | # download the snli data. 28 | cd ${DIR_TMP}/data 29 | if [[ ! -d snli_1.0 ]]; then 30 | wget https://nlp.stanford.edu/projects/snli/snli_1.0.zip 31 | unzip "snli_1.0.zip" 32 | rm -rf "snli_1.0.zip" && rm -rf "__MACOSX" 33 | echo "SNLI Ready" 34 | fi 35 | 36 | # download the mnli data. 37 | cd ${DIR_TMP}/data 38 | if [[ ! -d multinli_1.0 ]]; then 39 | wget "https://cims.nyu.edu/~sbowman/multinli/multinli_1.0.zip" 40 | unzip "multinli_1.0.zip" 41 | rm -rf "multinli_1.0.zip" && rm -rf "__MACOSX" 42 | echo "MNLI Ready" 43 | fi 44 | 45 | # download the fever nli data. 46 | cd ${DIR_TMP}/data 47 | if [[ ! -d nli_fever ]]; then 48 | wget "https://www.dropbox.com/s/hylbuaovqwo2zav/nli_fever.zip" 49 | unzip "nli_fever.zip" 50 | rm -rf "nli_fever.zip" && rm -rf "__MACOSX" 51 | echo "FEVER NLI Ready" 52 | fi 53 | 54 | # download the anli_v0.1 55 | cd ${DIR_TMP}/data 56 | if [[ ! -d anli_v${ANLI_VERSION} ]]; then 57 | wget "https://dl.fbaipublicfiles.com/anli/anli_v${ANLI_VERSION}.zip" 58 | unzip "anli_v${ANLI_VERSION}.zip" 59 | rm -rf "anli_v${ANLI_VERSION}.zip" && rm -rf "__MACOSX" 60 | echo "ANLI Ready" 61 | fi 62 | 63 | ALL_DATA_CHECKED=true 64 | 65 | # Check data SNLI: 66 | cd ${DIR_TMP}/data 67 | if [[ -f snli_1.0/snli_1.0_train.jsonl ]] && [[ -f snli_1.0/snli_1.0_dev.jsonl ]] && [[ -f snli_1.0/snli_1.0_test.jsonl ]]; then 68 | echo "SNLI checked." 69 | else 70 | echo "Some SNLI files are not ready. Please remove the \"snli_1.0\" directory and run download.sh again." 71 | ALL_DATA_CHECKED=false 72 | fi 73 | 74 | # Check data MNLI: 75 | cd ${DIR_TMP}/data 76 | if [[ -f multinli_1.0/multinli_1.0_train.jsonl ]] && [[ -f multinli_1.0/multinli_1.0_dev_mismatched.jsonl ]] && [[ -f multinli_1.0/multinli_1.0_dev_matched.jsonl ]]; then 77 | echo "MNLI checked." 78 | else 79 | echo "Some MNLI files are not ready. Please remove the \"multinli_1.0\" directory and run download.sh again." 80 | ALL_DATA_CHECKED=false 81 | fi 82 | 83 | # Check data FEVER NLI: 84 | cd ${DIR_TMP}/data 85 | if [[ -f nli_fever/train_fitems.jsonl ]] && \ 86 | [[ -f nli_fever/test_fitems.jsonl ]] && \ 87 | [[ -f nli_fever/dev_fitems.jsonl ]]; then 88 | echo "FEVER NLI checked." 89 | else 90 | echo "Some FEVER NLI files are not ready. Please remove the \"nli_fever\" directory and run download.sh again." 91 | ALL_DATA_CHECKED=false 92 | fi 93 | 94 | # Check data ANLI: 95 | cd ${DIR_TMP}/data 96 | if [[ -f anli_v${ANLI_VERSION}/R1/train.jsonl ]] && \ 97 | [[ -f anli_v${ANLI_VERSION}/R1/dev.jsonl ]] && \ 98 | [[ -f anli_v${ANLI_VERSION}/R1/test.jsonl ]] && \ 99 | [[ -f anli_v${ANLI_VERSION}/R2/train.jsonl ]] && \ 100 | [[ -f anli_v${ANLI_VERSION}/R2/dev.jsonl ]] && \ 101 | [[ -f anli_v${ANLI_VERSION}/R2/test.jsonl ]] && \ 102 | [[ -f anli_v${ANLI_VERSION}/R3/train.jsonl ]] && \ 103 | [[ -f anli_v${ANLI_VERSION}/R3/dev.jsonl ]] && \ 104 | [[ -f anli_v${ANLI_VERSION}/R3/test.jsonl ]]; \ 105 | then 106 | echo "ANLI checked." 107 | else 108 | echo "Some ANLI files are not ready. Please remove the \"anli_v${ANLI_VERSION}\" directory and run download.sh again." 
109 | ALL_DATA_CHECKED=false 110 | fi 111 | 112 | if [[ ${ALL_DATA_CHECKED} == true ]]; 113 | then 114 | echo "Data download completed and checked." 115 | else 116 | echo "Some data is missing. Please examine again or delete the data directory and re-run download.sh." 117 | fi -------------------------------------------------------------------------------- /src/utils/list_dict_data_tool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import uuid 7 | 8 | 9 | def list_to_dict(d_list, key_fields): # '_id' or 'pid' 10 | d_dict = dict() 11 | for item in d_list: 12 | assert key_fields in item 13 | d_dict[item[key_fields]] = item 14 | return d_dict 15 | 16 | 17 | def dict_to_list(d_dict): 18 | d_list = [] 19 | for key, value in d_dict.items(): 20 | d_list.append(value) 21 | return d_list 22 | 23 | 24 | def append_item_from_dict_to_list(d_list, d_dict, key_fieldname, append_fieldnames): 25 | if not isinstance(append_fieldnames, list): 26 | append_fieldnames = [append_fieldnames] 27 | for item in d_list: 28 | key = item[key_fieldname] 29 | if key in d_dict: 30 | for append_fieldname in append_fieldnames: 31 | item[append_fieldname] = d_dict[key][append_fieldname] 32 | else: 33 | print(f"Potential Error: {key} not in scored_dict. Maybe bc all forward items are empty.") 34 | for append_fieldname in append_fieldnames: 35 | item[append_fieldname] = [] 36 | return d_list 37 | 38 | 39 | def append_item_from_dict_to_list_hotpot_style(d_list, d_dict, key_fieldname, append_fieldnames): 40 | if not isinstance(append_fieldnames, list): 41 | append_fieldnames = [append_fieldnames] 42 | for item in d_list: 43 | key = item[key_fieldname] 44 | for append_fieldname in append_fieldnames: 45 | if key in d_dict[append_fieldname]: 46 | item[append_fieldname] = d_dict[append_fieldname][key] 47 | else: 48 | print(f"Potential Error: {key} not in scored_dict. Maybe bc all forward items are empty.") 49 | # for append_fieldname in append_fieldnames: 50 | item[append_fieldname] = [] 51 | return d_list 52 | 53 | 54 | def append_subfield_from_list_to_dict(subf_list, d_dict, o_key_field_name, subfield_key_name, 55 | subfield_name='merged_field', check=False): 56 | # Often times, we will need to split the one data point to multiple items to be feeded into neural networks 57 | # and after we obtain the results we will need to map the results back to original data point with some keys. 58 | 59 | # This method is used for this purpose. 60 | # The method can be invoke multiple times, (in practice usually one batch per time.) 61 | """ 62 | :param subf_list: The forward list. 63 | :param d_dict: The dict that contain keys mapping to original data point. 64 | :param o_key_field_name: The fieldname of original data point key. 'pid' 65 | :param subfield_key_name: The fieldname of the sub item. 'fid' 66 | :param subfield_name: The merge field name. 
'merged_field' 67 | :param check: 68 | :return: 69 | """ 70 | for key in d_dict.keys(): 71 | d_dict[key][subfield_name] = dict() 72 | 73 | for item in subf_list: 74 | assert o_key_field_name in item 75 | assert subfield_key_name in item 76 | map_id = item[o_key_field_name] 77 | sub_filed_id = item[subfield_key_name] 78 | assert map_id in d_dict 79 | 80 | # if subfield_name not in d_dict[map_id]: 81 | # d_dict[map_id][subfield_name] = dict() 82 | 83 | if sub_filed_id not in d_dict[map_id][subfield_name]: 84 | if check: 85 | assert item[o_key_field_name] == map_id 86 | d_dict[map_id][subfield_name][sub_filed_id] = item 87 | else: 88 | print("Duplicate forward item with key:", sub_filed_id) 89 | 90 | return d_dict 91 | 92 | 93 | if __name__ == '__main__': 94 | oitems = [] 95 | for i in range(3): 96 | oitems.append({'_id': i}) 97 | 98 | fitems = [] 99 | for item in oitems: 100 | oid = item['_id'] 101 | for i in range(int(oid) + 1): 102 | fid = str(uuid.uuid4()) 103 | fitems.append({ 104 | 'oid': oid, 105 | 'fid': fid, 106 | }) 107 | 108 | o_dict = list_to_dict(oitems, '_id') 109 | append_subfield_from_list_to_dict(fitems, o_dict, 'oid', 'fid', check=True) 110 | 111 | print(fitems) 112 | print(o_dict) 113 | -------------------------------------------------------------------------------- /src/dataset_tools/build_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from pathlib import Path 7 | 8 | import config 9 | from dataset_tools.format_convert import sm_nli2std_format, fever_nli2std_format, a_nli2std_format 10 | from utils import common 11 | 12 | # ANLI_VERSION = 1.0 13 | 14 | 15 | def build_snli(path: Path): 16 | snli_data_root_path = (path / "snli") 17 | if not snli_data_root_path.exists(): 18 | snli_data_root_path.mkdir() 19 | o_train = common.load_jsonl(config.PRO_ROOT / "data/snli_1.0/snli_1.0_train.jsonl") 20 | o_dev = common.load_jsonl(config.PRO_ROOT / "data/snli_1.0/snli_1.0_dev.jsonl") 21 | o_test = common.load_jsonl(config.PRO_ROOT / "data/snli_1.0/snli_1.0_test.jsonl") 22 | 23 | d_trian = sm_nli2std_format(o_train) 24 | d_dev = sm_nli2std_format(o_dev) 25 | d_test = sm_nli2std_format(o_test) 26 | 27 | print("SNLI examples without gold label have been filtered.") 28 | print("SNLI Train size:", len(d_trian)) 29 | print("SNLI Dev size:", len(d_dev)) 30 | print("SNLI Test size:", len(d_test)) 31 | 32 | common.save_jsonl(d_trian, snli_data_root_path / 'train.jsonl') 33 | common.save_jsonl(d_dev, snli_data_root_path / 'dev.jsonl') 34 | common.save_jsonl(d_test, snli_data_root_path / 'test.jsonl') 35 | 36 | 37 | def build_mnli(path: Path): 38 | data_root_path = (path / "mnli") 39 | if not data_root_path.exists(): 40 | data_root_path.mkdir() 41 | o_train = common.load_jsonl(config.PRO_ROOT / "data/multinli_1.0/multinli_1.0_train.jsonl") 42 | o_mm_dev = common.load_jsonl(config.PRO_ROOT / "data/multinli_1.0/multinli_1.0_dev_mismatched.jsonl") 43 | o_m_dev = common.load_jsonl(config.PRO_ROOT / "data/multinli_1.0/multinli_1.0_dev_matched.jsonl") 44 | 45 | d_trian = sm_nli2std_format(o_train) 46 | d_mm_dev = sm_nli2std_format(o_mm_dev) 47 | d_m_test = sm_nli2std_format(o_m_dev) 48 | 49 | print("MNLI examples without gold label have been filtered.") 50 | print("MNLI Train size:", len(d_trian)) 51 | print("MNLI MisMatched Dev size:", 
len(d_mm_dev)) 52 | print("MNLI Matched dev size:", len(d_m_test)) 53 | 54 | common.save_jsonl(d_trian, data_root_path / 'train.jsonl') 55 | common.save_jsonl(d_mm_dev, data_root_path / 'mm_dev.jsonl') 56 | common.save_jsonl(d_m_test, data_root_path / 'm_dev.jsonl') 57 | 58 | 59 | def build_fever_nli(path: Path): 60 | data_root_path = (path / "fever_nli") 61 | if not data_root_path.exists(): 62 | data_root_path.mkdir() 63 | 64 | o_train = common.load_jsonl(config.PRO_ROOT / "data/nli_fever/train_fitems.jsonl") 65 | o_dev = common.load_jsonl(config.PRO_ROOT / "data/nli_fever/dev_fitems.jsonl") 66 | o_test = common.load_jsonl(config.PRO_ROOT / "data/nli_fever/test_fitems.jsonl") 67 | 68 | d_trian = fever_nli2std_format(o_train) 69 | d_dev = fever_nli2std_format(o_dev) 70 | d_test = fever_nli2std_format(o_test) 71 | 72 | print("FEVER-NLI Train size:", len(d_trian)) 73 | print("FEVER-NLI Dev size:", len(d_dev)) 74 | print("FEVER-NLI Test size:", len(d_test)) 75 | 76 | common.save_jsonl(d_trian, data_root_path / 'train.jsonl') 77 | common.save_jsonl(d_dev, data_root_path / 'dev.jsonl') 78 | common.save_jsonl(d_test, data_root_path / 'test.jsonl') 79 | 80 | 81 | def build_anli(path: Path, round=1, version='1.0'): 82 | data_root_path = (path / "anli") 83 | if not data_root_path.exists(): 84 | data_root_path.mkdir() 85 | 86 | round_tag = str(round) 87 | 88 | o_train = common.load_jsonl(config.PRO_ROOT / f"data/anli_v{version}/R{round_tag}/train.jsonl") 89 | o_dev = common.load_jsonl(config.PRO_ROOT / f"data/anli_v{version}/R{round_tag}/dev.jsonl") 90 | o_test = common.load_jsonl(config.PRO_ROOT / f"data/anli_v{version}/R{round_tag}/test.jsonl") 91 | 92 | d_trian = a_nli2std_format(o_train) 93 | d_dev = a_nli2std_format(o_dev) 94 | d_test = a_nli2std_format(o_test) 95 | 96 | print(f"ANLI (R{round_tag}) Train size:", len(d_trian)) 97 | print(f"ANLI (R{round_tag}) Dev size:", len(d_dev)) 98 | print(f"ANLI (R{round_tag}) Test size:", len(d_test)) 99 | 100 | if not (data_root_path / f"r{round_tag}").exists(): 101 | (data_root_path / f"r{round_tag}").mkdir() 102 | 103 | common.save_jsonl(d_trian, data_root_path / f"r{round_tag}" / 'train.jsonl') 104 | common.save_jsonl(d_dev, data_root_path / f"r{round_tag}" / 'dev.jsonl') 105 | common.save_jsonl(d_test, data_root_path / f"r{round_tag}" / 'test.jsonl') 106 | 107 | 108 | def build_data(): 109 | processed_data_root = config.PRO_ROOT / "data" / "build" 110 | if not processed_data_root.exists(): 111 | processed_data_root.mkdir() 112 | build_snli(processed_data_root) 113 | build_mnli(processed_data_root) 114 | build_fever_nli(processed_data_root) 115 | for round in [1, 2, 3]: 116 | build_anli(processed_data_root, round) 117 | 118 | print("NLI data built!") 119 | 120 | 121 | if __name__ == '__main__': 122 | build_data() -------------------------------------------------------------------------------- /src/nli/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
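# Typical invocation sketch (flags are defined in evaluation() below; the
# checkpoint and output paths are placeholders):
#
#     python src/nli/evaluation.py \
#         --model_class_name roberta-large \
#         --model_checkpoint_path <saved_models/.../model.pt> \
#         --eval_data anli_r1_dev:none,anli_r2_dev:none,anli_r3_dev:none \
#         --output_prediction_path <output_dir> \
#         --per_gpu_eval_batch_size 16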
5 | 6 | import argparse 7 | from pathlib import Path 8 | 9 | import config 10 | from flint.data_utils.fields import RawFlintField, LabelFlintField, ArrayIndexFlintField 11 | from utils import common, list_dict_data_tool, save_tool 12 | from nli.training import MODEL_CLASSES, registered_path, build_eval_dataset_loader_and_sampler, NLITransform, \ 13 | NLIDataset, count_acc, evaluation_dataset, eval_model 14 | 15 | import torch 16 | 17 | import pprint 18 | 19 | pp = pprint.PrettyPrinter(indent=2) 20 | 21 | 22 | def evaluation(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--cpu", action="store_true", help="If set, we only use CPU.") 25 | parser.add_argument( 26 | "--model_class_name", 27 | type=str, 28 | help="Set the model class of the experiment.", 29 | required=True 30 | ) 31 | 32 | parser.add_argument( 33 | "--model_checkpoint_path", 34 | type=str, 35 | help='Set the path to the model checkpoint.', required=True) 36 | 37 | parser.add_argument( 38 | "--output_prediction_path", 39 | type=str, 40 | default=None, 41 | help='Set the path to save the prediction.') 42 | 43 | parser.add_argument( 44 | "--per_gpu_eval_batch_size", default=16, type=int, help="Batch size per GPU/CPU for evaluation.", 45 | ) 46 | 47 | parser.add_argument("--max_length", default=156, type=int, help="Max length of the sequences.") 48 | 49 | parser.add_argument("--eval_data", 50 | type=str, 51 | help="The evaluation data used in the experiments.") 52 | 53 | args = parser.parse_args() 54 | 55 | if args.cpu: 56 | args.global_rank = -1 57 | else: 58 | args.global_rank = 0 59 | 60 | model_checkpoint_path = args.model_checkpoint_path 61 | num_labels = 3 62 | # we are doing NLI so we set num_labels = 3, for other tasks we can change this value. 63 | 64 | max_length = args.max_length 65 | 66 | model_class_item = MODEL_CLASSES[args.model_class_name] 67 | model_name = model_class_item['model_name'] 68 | do_lower_case = model_class_item['do_lower_case'] if 'do_lower_case' in model_class_item else False 69 | 70 | tokenizer = model_class_item['tokenizer'].from_pretrained(model_name, 71 | cache_dir=str(config.PRO_ROOT / "trans_cache"), 72 | do_lower_case=do_lower_case) 73 | 74 | model = model_class_item['sequence_classification'].from_pretrained(model_name, 75 | cache_dir=str(config.PRO_ROOT / "trans_cache"), 76 | num_labels=num_labels) 77 | 78 | model.load_state_dict(torch.load(model_checkpoint_path)) 79 | 80 | padding_token_value = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0] 81 | padding_segement_value = model_class_item["padding_segement_value"] 82 | padding_att_value = model_class_item["padding_att_value"] 83 | left_pad = model_class_item['left_pad'] if 'left_pad' in model_class_item else False 84 | 85 | batch_size_per_gpu_eval = args.per_gpu_eval_batch_size 86 | 87 | eval_data_str = args.eval_data 88 | eval_data_name = [] 89 | eval_data_path = [] 90 | eval_data_list = [] 91 | 92 | eval_data_named_path = eval_data_str.split(',') 93 | 94 | for named_path in eval_data_named_path: 95 | ind = named_path.find(':') 96 | name = named_path[:ind] 97 | path = named_path[ind + 1:] 98 | if name in registered_path: 99 | d_list = common.load_jsonl(registered_path[name]) 100 | else: 101 | d_list = common.load_jsonl(path) 102 | eval_data_name.append(name) 103 | eval_data_path.append(path) 104 | 105 | eval_data_list.append(d_list) 106 | 107 | batching_schema = { 108 | 'uid': RawFlintField(), 109 | 'y': LabelFlintField(), 110 | 'input_ids': ArrayIndexFlintField(pad_idx=padding_token_value, left_pad=left_pad), 111 |
'token_type_ids': ArrayIndexFlintField(pad_idx=padding_segement_value, left_pad=left_pad), 112 | 'attention_mask': ArrayIndexFlintField(pad_idx=padding_att_value, left_pad=left_pad), 113 | } 114 | 115 | data_transformer = NLITransform(model_name, tokenizer, max_length) 116 | eval_data_loaders = [] 117 | for eval_d_list in eval_data_list: 118 | d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(eval_d_list, data_transformer, 119 | batching_schema, 120 | batch_size_per_gpu_eval) 121 | eval_data_loaders.append(d_dataloader) 122 | 123 | if not args.cpu: 124 | torch.cuda.set_device(0) 125 | model.cuda(0) 126 | 127 | r_dict = dict() 128 | # Eval loop: 129 | for i in range(len(eval_data_name)): 130 | cur_eval_data_name = eval_data_name[i] 131 | cur_eval_data_list = eval_data_list[i] 132 | cur_eval_dataloader = eval_data_loaders[i] 133 | # cur_eval_raw_data_list = eval_raw_data_list[i] 134 | 135 | evaluation_dataset(args, cur_eval_dataloader, cur_eval_data_list, model, r_dict, 136 | eval_name=cur_eval_data_name) 137 | 138 | # save prediction: 139 | if args.output_prediction_path is not None: 140 | cur_results_path = Path(args.output_prediction_path) 141 | if not cur_results_path.exists(): 142 | cur_results_path.mkdir(parents=True) 143 | for key, item in r_dict.items(): 144 | common.save_jsonl(item['predictions'], cur_results_path / f"{key}.jsonl") 145 | 146 | # avoid saving too many things 147 | for key, item in r_dict.items(): 148 | del r_dict[key]['predictions'] 149 | common.save_json(r_dict, cur_results_path / "results_dict.json", indent=2) 150 | 151 | 152 | if __name__ == '__main__': 153 | evaluation() 154 | -------------------------------------------------------------------------------- /src/modeling/res_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
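# Minimal instantiation sketch for the baseline encoders defined below
# (vocabulary size and input tensors are hypothetical):
#
#     model = ResEncoder(v_size=30000, embd_dim=300)
#     loss, logits = model(input_ids, attention_mask, labels=labels)
#
# Both ResEncoder and BagOfWords return a (loss, logits) tuple, with loss None
# when labels is None.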
5 | import torch 6 | import torch.nn as nn 7 | from torch import optim 8 | from torch.autograd import Variable 9 | from torch.nn import MSELoss, CrossEntropyLoss 10 | 11 | import flint.torch_util as torch_util 12 | from tqdm import tqdm 13 | import os 14 | from datetime import datetime 15 | 16 | 17 | class EmptyScheduler(object): 18 | def __init__(self): 19 | self._state_dict = dict() 20 | 21 | def step(self): 22 | pass 23 | 24 | def state_dict(self): 25 | return self._state_dict 26 | 27 | 28 | class ResEncoder(nn.Module): 29 | def __init__(self, h_size=[1024, 1024, 1024], v_size=10, embd_dim=300, mlp_d=1024, 30 | dropout_r=0.1, k=3, n_layers=1, num_labels=3): 31 | super(ResEncoder, self).__init__() 32 | self.Embd = nn.Embedding(v_size, embd_dim) 33 | self.num_labels = num_labels 34 | 35 | self.lstm = nn.LSTM(input_size=embd_dim, hidden_size=h_size[0], 36 | num_layers=1, bidirectional=True) 37 | 38 | self.lstm_1 = nn.LSTM(input_size=(embd_dim + h_size[0] * 2), hidden_size=h_size[1], 39 | num_layers=1, bidirectional=True) 40 | 41 | self.lstm_2 = nn.LSTM(input_size=(embd_dim + h_size[0] * 2), hidden_size=h_size[2], 42 | num_layers=1, bidirectional=True) 43 | 44 | self.h_size = h_size 45 | self.k = k 46 | 47 | # self.mlp_1 = nn.Linear(h_size[2] * 2 * 4, mlp_d) 48 | self.mlp_1 = nn.Linear(h_size[2] * 2, mlp_d) 49 | self.mlp_2 = nn.Linear(mlp_d, mlp_d) 50 | self.sm = nn.Linear(mlp_d, self.num_labels) 51 | 52 | if n_layers == 1: 53 | self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r), 54 | self.sm]) 55 | elif n_layers == 2: 56 | self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r), 57 | self.mlp_2, nn.ReLU(), nn.Dropout(dropout_r), 58 | self.sm]) 59 | else: 60 | print("Error num layers") 61 | 62 | def init_embedding(self, embedding): 63 | self.Embd.weight = embedding.weight 64 | 65 | def forward(self, input_ids, attention_mask, labels=None): 66 | # if self.max_l: 67 | # l1 = l1.clamp(max=self.max_l) 68 | # l2 = l2.clamp(max=self.max_l) 69 | # if s1.size(0) > self.max_l: 70 | # s1 = s1[:self.max_l, :] 71 | # if s2.size(0) > self.max_l: 72 | # s2 = s2[:self.max_l, :] 73 | batch_l_1 = torch.sum(attention_mask, dim=1) 74 | 75 | # p_s1 = self.Embd(s1) 76 | embedding_1 = self.Embd(input_ids) 77 | 78 | s1_layer1_out = torch_util.auto_rnn(self.lstm, embedding_1, batch_l_1) 79 | # s2_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, p_s2, l2) 80 | 81 | # Length truncate 82 | # len1 = s1_layer1_out.size(0) 83 | # len2 = s2_layer1_out.size(0) 84 | # p_s1 = p_s1[:len1, :, :] 85 | # p_s2 = p_s2[:len2, :, :] 86 | 87 | # Using high way 88 | s1_layer2_in = torch.cat([embedding_1, s1_layer1_out], dim=2) 89 | # s2_layer2_in = torch.cat([p_s2, s2_layer1_out], dim=2) 90 | 91 | s1_layer2_out = torch_util.auto_rnn(self.lstm_1, s1_layer2_in, batch_l_1) 92 | # s2_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, s2_layer2_in, l2) 93 | 94 | s1_layer3_in = torch.cat([embedding_1, s1_layer1_out + s1_layer2_out], dim=2) 95 | # s2_layer3_in = torch.cat([p_s2, s2_layer1_out + s2_layer2_out], dim=2) 96 | 97 | s1_layer3_out = torch_util.auto_rnn(self.lstm_2, s1_layer3_in, batch_l_1) 98 | # s2_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, s2_layer3_in, l2) 99 | 100 | s1_layer3_maxout = torch_util.max_along_time(s1_layer3_out, batch_l_1) 101 | # s2_layer3_maxout = torch_util.max_along_time(s2_layer3_out, l2) 102 | 103 | # Only use the last layer 104 | # features = torch.cat([s1_layer3_maxout, s2_layer3_maxout, 105 | # torch.abs(s1_layer3_maxout - s2_layer3_maxout), 
106 | # s1_layer3_maxout * s2_layer3_maxout], 107 | # dim=1) 108 | 109 | features = torch.cat([s1_layer3_maxout], 110 | dim=1) 111 | 112 | logits = self.classifier(features) 113 | 114 | loss = None 115 | if labels is not None: 116 | if self.num_labels == 1: 117 | # We are doing regression 118 | loss_fct = MSELoss() 119 | loss = loss_fct(logits.view(-1), labels.view(-1)) 120 | else: 121 | loss_fct = CrossEntropyLoss() 122 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 123 | 124 | return (loss, logits) 125 | 126 | 127 | class BagOfWords(nn.Module): 128 | def __init__(self, v_size=10, embd_dim=300, mlp_d=1024, 129 | dropout_r=0.1, n_layers=1, num_labels=3): 130 | super(BagOfWords, self).__init__() 131 | self.Embd = nn.Embedding(v_size, embd_dim) 132 | self.num_labels = num_labels 133 | 134 | # self.mlp_1 = nn.Linear(h_size[2] * 2 * 4, mlp_d) 135 | self.mlp_1 = nn.Linear(embd_dim, mlp_d) 136 | self.mlp_2 = nn.Linear(mlp_d, mlp_d) 137 | self.sm = nn.Linear(mlp_d, self.num_labels) 138 | 139 | if n_layers == 1: 140 | self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r), 141 | self.sm]) 142 | elif n_layers == 2: 143 | self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r), 144 | self.mlp_2, nn.ReLU(), nn.Dropout(dropout_r), 145 | self.sm]) 146 | else: 147 | print("Error num layers") 148 | 149 | def init_embedding(self, embedding): 150 | self.Embd.weight = embedding.weight 151 | 152 | def forward(self, input_ids, attention_mask, labels=None): 153 | # if self.max_l: 154 | # l1 = l1.clamp(max=self.max_l) 155 | # l2 = l2.clamp(max=self.max_l) 156 | # if s1.size(0) > self.max_l: 157 | # s1 = s1[:self.max_l, :] 158 | # if s2.size(0) > self.max_l: 159 | # s2 = s2[:self.max_l, :] 160 | batch_l_1 = torch.sum(attention_mask, dim=1) 161 | 162 | # p_s1 = self.Embd(s1) 163 | embedding_1 = self.Embd(input_ids) 164 | 165 | s1_layer3_maxout = torch_util.avg_along_time(embedding_1, batch_l_1) 166 | # s2_layer3_maxout = torch_util.max_along_time(s2_layer3_out, l2) 167 | 168 | # Only use the last layer 169 | # features = torch.cat([s1_layer3_maxout, s2_layer3_maxout, 170 | # torch.abs(s1_layer3_maxout - s2_layer3_maxout), 171 | # s1_layer3_maxout * s2_layer3_maxout], 172 | # dim=1) 173 | 174 | features = torch.cat([s1_layer3_maxout], 175 | dim=1) 176 | 177 | logits = self.classifier(features) 178 | 179 | loss = None 180 | if labels is not None: 181 | if self.num_labels == 1: 182 | # We are doing regression 183 | loss_fct = MSELoss() 184 | loss = loss_fct(logits.view(-1), labels.view(-1)) 185 | else: 186 | loss_fct = CrossEntropyLoss() 187 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 188 | 189 | return (loss, logits) -------------------------------------------------------------------------------- /src/nli/inference_debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
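# This module is a small debugging utility: it loads a trained NLI checkpoint,
# builds a one-example evaluation batch, and prints the predicted label and
# class probabilities, reusing the tokenizer setup, batching schema, and model
# registry (MODEL_CLASSES) from nli.training. See the __main__ block at the
# bottom for example checkpoint paths and an end-to-end invocation.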
5 | 6 | import argparse 7 | from pathlib import Path 8 | import uuid 9 | import numpy as np 10 | 11 | import config 12 | from flint.data_utils.batchbuilder import move_to_device 13 | from flint.data_utils.fields import RawFlintField, LabelFlintField, ArrayIndexFlintField 14 | from utils import common, list_dict_data_tool, save_tool 15 | from nli.training import MODEL_CLASSES, registered_path, build_eval_dataset_loader_and_sampler, NLITransform, \ 16 | NLIDataset, count_acc, evaluation_dataset, eval_model 17 | 18 | import torch 19 | 20 | import pprint 21 | 22 | pp = pprint.PrettyPrinter(indent=2) 23 | 24 | 25 | id2label = { 26 | 0: 'e', 27 | 1: 'n', 28 | 2: 'c', 29 | -1: '-', 30 | } 31 | 32 | 33 | def softmax(x): 34 | """Compute softmax values for each sets of scores in x.""" 35 | e_x = np.exp(np.asarray(x) - np.max(x)) 36 | return e_x / e_x.sum() 37 | 38 | 39 | def eval_model(model, dev_dataloader, device_num, args): 40 | model.eval() 41 | 42 | uid_list = [] 43 | y_list = [] 44 | pred_list = [] 45 | logits_list = [] 46 | 47 | with torch.no_grad(): 48 | for i, batch in enumerate(dev_dataloader, 0): 49 | batch = move_to_device(batch, device_num) 50 | 51 | if args.model_class_name in ["distilbert", "bart-large"]: 52 | outputs = model(batch['input_ids'], 53 | attention_mask=batch['attention_mask'], 54 | labels=None) 55 | else: 56 | outputs = model(batch['input_ids'], 57 | attention_mask=batch['attention_mask'], 58 | token_type_ids=batch['token_type_ids'], 59 | labels=None) 60 | 61 | # print(outputs) 62 | logits = outputs[0] 63 | 64 | uid_list.extend(list(batch['uid'])) 65 | y_list.extend(batch['y'].tolist()) 66 | pred_list.extend(torch.max(logits, 1)[1].view(logits.size(0)).tolist()) 67 | logits_list.extend(logits.tolist()) 68 | 69 | assert len(pred_list) == len(logits_list) 70 | assert len(pred_list) == len(logits_list) 71 | 72 | result_items_list = [] 73 | for i in range(len(uid_list)): 74 | r_item = dict() 75 | r_item['uid'] = uid_list[i] 76 | r_item['logits'] = logits_list[i] 77 | r_item['probability'] = softmax(r_item['logits']) 78 | r_item['predicted_label'] = id2label[pred_list[i]] 79 | 80 | result_items_list.append(r_item) 81 | 82 | return result_items_list 83 | 84 | 85 | def inference(model_class_name, model_checkpoint_path, max_length, premise, hypothesis, cpu=True): 86 | parser = argparse.ArgumentParser() 87 | args = parser.parse_args() 88 | 89 | # CPU for now 90 | if cpu: 91 | args.global_rank = -1 92 | else: 93 | args.global_rank = 0 94 | 95 | model_checkpoint_path = model_checkpoint_path 96 | args.model_class_name = model_class_name 97 | num_labels = 3 98 | # we are doing NLI so we set num_labels = 3, for other task we can change this value. 
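    # Label indices follow the id2label mapping defined at the top of this file:
    # 0 -> entailment ('e'), 1 -> neutral ('n'), 2 -> contradiction ('c').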
99 | 100 | max_length = max_length 101 | 102 | model_class_item = MODEL_CLASSES[model_class_name] 103 | model_name = model_class_item['model_name'] 104 | do_lower_case = model_class_item['do_lower_case'] if 'do_lower_case' in model_class_item else False 105 | 106 | tokenizer = model_class_item['tokenizer'].from_pretrained(model_name, 107 | cache_dir=str(config.PRO_ROOT / "trans_cache"), 108 | do_lower_case=do_lower_case) 109 | 110 | model = model_class_item['sequence_classification'].from_pretrained(model_name, 111 | cache_dir=str(config.PRO_ROOT / "trans_cache"), 112 | num_labels=num_labels) 113 | 114 | model.load_state_dict(torch.load(model_checkpoint_path)) 115 | 116 | padding_token_value = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0] 117 | padding_segement_value = model_class_item["padding_segement_value"] 118 | padding_att_value = model_class_item["padding_att_value"] 119 | left_pad = model_class_item['left_pad'] if 'left_pad' in model_class_item else False 120 | 121 | batch_size_per_gpu_eval = 16 122 | 123 | eval_data_list = [{ 124 | 'uid': str(uuid.uuid4()), 125 | 'premise': premise, 126 | 'hypothesis': hypothesis, 127 | 'label': 'h' # hidden 128 | }] 129 | 130 | batching_schema = { 131 | 'uid': RawFlintField(), 132 | 'y': LabelFlintField(), 133 | 'input_ids': ArrayIndexFlintField(pad_idx=padding_token_value, left_pad=left_pad), 134 | 'token_type_ids': ArrayIndexFlintField(pad_idx=padding_segement_value, left_pad=left_pad), 135 | 'attention_mask': ArrayIndexFlintField(pad_idx=padding_att_value, left_pad=left_pad), 136 | } 137 | 138 | data_transformer = NLITransform(model_name, tokenizer, max_length) 139 | 140 | d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(eval_data_list, data_transformer, 141 | batching_schema, 142 | batch_size_per_gpu_eval) 143 | 144 | if not cpu: 145 | torch.cuda.set_device(0) 146 | model.cuda(0) 147 | 148 | pred_output_list = eval_model(model, d_dataloader, args.global_rank, args) 149 | # r_dict = dict() 150 | # Eval loop: 151 | # print(pred_output_list) 152 | return pred_output_list[0] 153 | 154 | 155 | if __name__ == '__main__': 156 | # model_class_name = "roberta-large" 157 | # model_checkpoint_path = config.PRO_ROOT / "saved_models/06-29-22:16:24_roberta-large|snli+mnli+fnli+r1*10+r2*20+r3*10|nli/checkpoints/e(0)|i(24000)|snli_dev#(0.9252)|mnli_m_dev#(0.899)|mnli_mm_dev#(0.9002)|anli_r1_dev#(0.74)|anli_r1_test#(0.742)|anli_r2_dev#(0.506)|anli_r2_test#(0.498)|anli_r3_dev#(0.4667)|anli_r3_test#(0.455)/model.pt" 158 | 159 | # model_class_name = "xlnet-large" 160 | # model_checkpoint_path = config.PRO_ROOT / "saved_models/06-29-23:04:33_xlnet-large|snli+mnli+fnli+r1*10+r2*20+r3*10|nli/checkpoints/e(1)|i(30000)|snli_dev#(0.9274)|mnli_m_dev#(0.8981)|mnli_mm_dev#(0.8947)|anli_r1_dev#(0.735)|anli_r1_test#(0.701)|anli_r2_dev#(0.521)|anli_r2_test#(0.514)|anli_r3_dev#(0.5075)|anli_r3_test#(0.4975)/model.pt" 161 | 162 | model_class_name = "albert-xxlarge" 163 | model_checkpoint_path = config.PRO_ROOT / "saved_models/06-29-23:09:03_albert-xxlarge|snli+mnli+fnli+r1*10+r2*20+r3*10|nli/checkpoints/e(0)|i(16000)|snli_dev#(0.9246)|mnli_m_dev#(0.8948)|mnli_mm_dev#(0.8932)|anli_r1_dev#(0.733)|anli_r1_test#(0.711)|anli_r2_dev#(0.571)|anli_r2_test#(0.57)|anli_r3_dev#(0.5817)|anli_r3_test#(0.5375)/model.pt" 164 | # 165 | # model_class_name = "bart-large" 166 | # model_checkpoint_path = config.PRO_ROOT / 
"saved_models/06-30-08:23:44_bart-large|snli+mnli+fnli+r1*10+r2*20+r3*10|nli/checkpoints/e(1)|i(40000)|snli_dev#(0.9298)|mnli_m_dev#(0.8941)|mnli_mm_dev#(0.8973)|anli_r1_dev#(0.736)|anli_r1_test#(0.72)|anli_r2_dev#(0.533)|anli_r2_test#(0.514)|anli_r3_dev#(0.5058)|anli_r3_test#(0.5042)/model.pt" 167 | # 168 | # model_class_name = "electra-large" 169 | # model_checkpoint_path = config.PRO_ROOT / "saved_models/08-02-08:58:05_electra-large|snli+mnli+fnli+r1*10+r2*20+r3*10|nli/checkpoints/e(0)|i(12000)|snli_dev#(0.9168)|mnli_m_dev#(0.8597)|mnli_mm_dev#(0.8661)|anli_r1_dev#(0.672)|anli_r1_test#(0.678)|anli_r2_dev#(0.536)|anli_r2_test#(0.522)|anli_r3_dev#(0.55)|anli_r3_test#(0.5217)/model.pt" 170 | 171 | max_length = 184 172 | 173 | premise = "Two women are embracing while holding to go packages." 174 | hypothesis = "The men are fighting outside a deli." 175 | 176 | pred_output = inference(model_class_name, model_checkpoint_path, max_length, premise, hypothesis, cpu=True) 177 | print(pred_output) 178 | -------------------------------------------------------------------------------- /mds/start_your_nli_research.md: -------------------------------------------------------------------------------- 1 | # Start your NLI Research 2 | This tutorial gives detailed instructions to help you start research on NLI with pre-trained state-of-the-art models (June 2020). 3 | 4 | ## Requirements 5 | - python 3.6+ 6 | - tqdm 7 | - torch 1.4.0 https://pytorch.org/ 8 | - transformers 3.0.2 https://github.com/huggingface/transformers/ 9 | 10 | ## Initial Setup 11 | ### 1. Setup your python environment and install the requirements. 12 | ### 2. Clone this repo. 13 | ``` 14 | git clone https://github.com/facebookresearch/anli.git 15 | ``` 16 | ### 3. Run the following command in your terminal. 17 | ``` 18 | source setup.sh 19 | ``` 20 | The command will assign the environment variables `$DIR_TMP` and `$PYTHONPATH` 21 | to the project root path and src path, respectively. 22 | These two variables are needed to run any following scripts in this tutorial. 23 | 24 | ## Data Preparation 25 | ### 1. Run the following command to download the data. 26 | ``` 27 | cd $DIR_TMP # Make sure that you run all the scripts in the root of this repo. 28 | bash script/download_data.sh 29 | ``` 30 | All the data (SNLI, MNLI, FEVER-NLI, ANLI) will be downloaded in the `data` directory. 31 | If any data is missing, you can remove the unchecked folders in the `data` directory and re-download. 32 | 33 | ### 2. Run the following python script to build the data. 34 | ```bash 35 | cd $DIR_TMP # Make sure that you run all the scripts in the root of this repo. 36 | python src/dataset_tools/build_data.py # If you encounter import errors, please make sure you have run `source setup.sh` to set up the `$PYTHONPATH` 37 | ``` 38 | The script will convert SNLI, MNLI, FEVER-NLI, ANLI all into the same unified NLI format, and will also remove examples in SNLI and MNLI that do not have a gold label (as in prior work). 39 | 40 | #### Data Directory 41 | Once the `build_data.py` script has completed successfully, the `data` directory (in the project root) should contains a directory called `build` containing the dataset files in the unified data format. Your data directory should have a structure like the one attached below. 
42 | ``` 43 | data 44 | ├── build 45 | │   ├── anli 46 | │   │   ├── r1 47 | │   │   │   ├── dev.jsonl 48 | │   │   │   ├── test.jsonl 49 | │   │   │   └── train.jsonl 50 | │   │   ├── r2 51 | │   │   │   ├── dev.jsonl 52 | │   │   │   ├── test.jsonl 53 | │   │   │   └── train.jsonl 54 | │   │   └── r3 55 | │   │   ├── dev.jsonl 56 | │   │   ├── test.jsonl 57 | │   │   └── train.jsonl 58 | │   ├── fever_nli 59 | │   │   ├── dev.jsonl 60 | │   │   ├── test.jsonl 61 | │   │   └── train.jsonl 62 | │   ├── mnli 63 | │   │   ├── m_dev.jsonl 64 | │   │   ├── mm_dev.jsonl 65 | │   │   └── train.jsonl 66 | │   └── snli 67 | │   ├── dev.jsonl 68 | │   ├── test.jsonl 69 | │   └── train.jsonl 70 | ├── anli_v1.0 71 | │   ├── ... # The unzipped ANLI data with original format. 72 | ├── multinli_1.0 73 | │   ├── ... # The unzipped MNLI data with original format. 74 | ├── nli_fever 75 | │   ├── ... # The unzipped NLI-FEVER data with original format. 76 | └── snli_1.0 77 | ├── ... # The unzipped SNLI data with original format. 78 | ``` 79 | 80 | #### Data Format 81 | NLI is basically a 3-way sequence-to-label classification task where the inputs are two textual sequences 82 | called `premise` and `hypothesis`, and the output is a discrete label that is either `entailment`, `contradiction`, or `neutral`. 83 | (Some works consider it to be a 2-way classification tasks. The code here can easily be converted to any sequence-to-label task with some hacking.) 84 | 85 | The training script will load NLI data with a unified `jsonl` format in which each line is a JSON object for one NLI example. 86 | The JSON object should have the following fields: 87 | - "uid": unique id of the example; 88 | - "premise": the premise of the NLI example; 89 | - "hypothesis": the hypothesis of the NLI example; 90 | - "label": the label of the example. The label is from the set {"e", "n", "c"}, denoting the 3 classes "entailment", "neutral", or "contradiction", respectively. 91 | - Additional dataset specific fields... 92 | 93 | Here is one example from SNLI: 94 | ```json 95 | { 96 | "uid": "4705552913.jpg#2r1n", 97 | "premise": "Two women are embracing while holding to go packages.", 98 | "hypothesis": "The sisters are hugging goodbye while holding to go packages after just eating lunch.", 99 | "label": "n" 100 | } 101 | ``` 102 | 103 | Note that some training examples and all the development and test examples in ANLI have a `reason` field showing the reason for the annotation. Please read the paper for details. 104 | 105 | If you want to train or evaluate the model on data other than SNLI, MNLI, FEVER-NLI, or ANLI, we recommend that you refer to `src/dataset_tools/build_data.py` and `src/dataset_tools/format_convert.py` and use the tools in the repo to build your own data to avoid any exceptions. 106 | 107 | ## Model Training 108 | Now, you can use the following script to start training your NLI models. 
109 | 
110 | ```bash
111 | export MASTER_ADDR=localhost
112 | 
113 | python src/nli/training.py \
114 |     --model_class_name "roberta-large" \
115 |     -n 1 \
116 |     -g 2 \
117 |     -nr 0 \
118 |     --max_length 156 \
119 |     --gradient_accumulation_steps 1 \
120 |     --per_gpu_train_batch_size 16 \
121 |     --per_gpu_eval_batch_size 16 \
122 |     --save_prediction \
123 |     --train_data snli_train:none,mnli_train:none \
124 |     --train_weights 1,1 \
125 |     --eval_data snli_dev:none \
126 |     --eval_frequency 2000 \
127 |     --experiment_name "roberta-large|snli|nli"
128 | ```
129 | 
130 | ### Argument Explanation
131 | - "--model_class_name": This argument specifies the model class used for training. We currently support "roberta-large", "roberta-base", "xlnet-large", "xlnet-base", "bert-large", "bert-base", "albert-xxlarge", "bart-large".
132 | Note that `model_class_name` must be the same when you later load the checkpoint for evaluation.
133 | - "-n": The number of nodes (machines) for training. In most cases, we recommend setting it to 1. Multi-node training is not tested.
134 | - "-g": The number of GPUs for training.
135 | - "-nr": Node rank. In most cases, we recommend setting it to 0. Multi-node training is not tested.
136 | - "--train_data": Specifies the sources of training data, separated by commas. The string before the colon is the name of the source and the string after the colon is the location of the data. Note that for SNLI, MNLI, FEVER-NLI, and ANLI, you should just give "none" as the location because their locations have been manually registered in the script. For customized input data, you will need to specify the actual path of the data.
137 | - "--train_weights": Specifies how much training data to take from each source. At every training epoch, the training data is re-sampled from each source and then combined. Each weight is a multiplier on the sampled training size for the corresponding source. "1" means we just add all the data from that source to the training data. "0.5" means we sample half of the data from that source. "3.8" means we sample (with replacement) the training data such that the resulting size is 3.8 * the size of that source.
138 | Note: the two arguments above determine the data used at every training epoch. The number of values for `--train_weights` needs to match the number of items in `--train_data`.
139 | For example, suppose snli_train has 100 examples, and `[name_of_your_data]` has 200 examples. Then, `--train_data snli_train:none,[name_of_your_data]:[path_to_your_data]/[filename_of_your_data].jsonl --train_weights 0.5,2` means training the model with 100 * 0.5 = 50 snli training examples and 200 * 2 = 400 `[name_of_your_data]` examples sampled with replacement at every epoch (see also the runnable sketch after this list).
140 | - "--eval_data": Specifies the sources of evaluation data, separated by commas (same format as "--train_data").
141 | - "--eval_frequency": The number of iteration steps between two saved model checkpoints.
142 | - "--experiment_name": The name of the experiment. During training, the checkpoints will be saved in the `saved_models/{TRAINING_START_TIME}_[experiment_name]` directory (in the project root). So, the name is an important identifier for finding the saved checkpoints.
143 | 
144 | The other arguments should be self-explanatory. We recommend that you read the code if you are unsure about a specific argument.
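To make the custom-data arguments concrete, here is a minimal sketch of a single-GPU run that mixes SNLI with your own dataset. The source name `my_data` and the path `data/build/my_data/train.jsonl` are placeholders, and the file is assumed to follow the unified JSONL format described above:
```bash
# "my_data" and its path below are placeholders for your own unified-format JSONL file.
python src/nli/training.py \
    --model_class_name "roberta-base" \
    -n 1 \
    -g 1 \
    --single_gpu \
    -nr 0 \
    --max_length 156 \
    --gradient_accumulation_steps 1 \
    --per_gpu_train_batch_size 16 \
    --per_gpu_eval_batch_size 16 \
    --train_data snli_train:none,my_data:data/build/my_data/train.jsonl \
    --train_weights 1,2 \
    --eval_data snli_dev:none \
    --eval_frequency 2000 \
    --experiment_name "roberta-base|snli+my_data|nli"
```
With `--train_weights 1,2`, every epoch trains on all of SNLI plus the custom set re-sampled (with replacement) to twice its original size.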
145 | 
146 | The example scripts `script/example_scripts/train_roberta.sh` and `script/example_scripts/train_xlnet.sh` can be used to reproduce the leaderboard RoBERTa and XLNet results, respectively.
147 | An **important detail** is that the training data used in the above experiments is SNLI + MNLI + FEVER-NLI + 10*A1 (10 times upsampled ANLI-R1) + 20*A2 (20 times upsampled ANLI-R2) + 10*A3 (10 times upsampled ANLI-R3), as specified in the following arguments:
148 | ```
149 | --train_data snli_train:none,mnli_train:none,fever_train:none,anli_r1_train:none,anli_r2_train:none,anli_r3_train:none \
150 | --train_weights 1,1,1,10,20,10 \
151 | ```
152 | The scripts were tested on a machine with 8 Tesla V100 (16GB) GPUs.
153 | 
154 | You can try a smaller RoBERTa-base model with `script/example_scripts/train_roberta_small.sh`, which can be run on a single GPU with 12GB of memory.
155 | 
156 | During training, model checkpoints will be automatically saved in the `saved_models` directory.
157 | 
158 | #### Batch Size
159 | Training batch size can affect performance. The actual training batch size can be calculated as `[number_of_nodes](-n)` * `[number_of_gpus_per_node](-g)` * `[per_gpu_train_batch_size]` * `[gradient_accumulation_steps]`.
160 | If GPU memory is limited, you can use a small forward batch size with more gradient accumulation steps. E.g. `-n 1 -g 2 -nr 0 --gradient_accumulation_steps 8 --per_gpu_train_batch_size 8` still gives you a 2 * 8 * 8 = 128 effective training batch size.
161 | 
162 | #### Distributed Data Parallel
163 | The code uses PyTorch distributed data parallel for multi-GPU training, which in principle supports any number of GPUs.
164 | You need to set the `$MASTER_ADDR` variable to pass the IP address of the master process to the python script.
165 | In most cases, you will use a single machine and can simply set this variable to "localhost".
166 | 
167 | ## Evaluating Trained Models
168 | ```bash
169 | python src/nli/evaluation.py \
170 |     --model_class_name "roberta-large" \
171 |     --max_length 156 \
172 |     --per_gpu_eval_batch_size 16 \
173 |     --model_checkpoint_path \
174 |     "[the directory that contains your checkpoint]/model.pt" \
175 |     --eval_data anli_r1_test:none,anli_r2_test:none,anli_r3_test:none \
176 |     --output_prediction_path [the path of the directory you want the output to be saved]
177 | ```
178 | Notice:
179 | 1. The "model_class_name" needs to be the same as the one used in `training.py`.
180 | 2. You need to specify the path to your model parameters (the file named `model.pt`).
181 | 3. Evaluation is done on a single GPU.
182 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial NLI
2 | 
3 | ## Papers
4 | 
5 | ### Dataset
6 | [**Adversarial NLI: A New Benchmark for Natural Language Understanding**](https://arxiv.org/abs/1910.14599)
7 | 
8 | ### Annotations of the Dataset for Error Analysis
9 | [**ANLIzing the Adversarial Natural Language Inference Dataset**](https://arxiv.org/abs/2010.12729)
10 | 
11 | ## Dataset
12 | Version 1.0 is available here: https://dl.fbaipublicfiles.com/anli/anli_v1.0.zip.
13 | ### Format
14 | The dataset files are all in JSONL format (one JSON object per line). Below is one example (pretty-printed) with self-explanatory fields.
15 | Note that each example (each line) in the files contains a `uid` field that represents **a unique id** across all the examples in all three rounds of ANLI.
16 | ```
17 | {
18 |     "uid": "8a91e1a2-9a32-4fd9-b1b6-bd2ee2287c8f",
19 |     "premise": "Javier Torres (born May 14, 1988 in Artesia, California) is an undefeated Mexican American professional boxer in the Heavyweight division.
20 | Torres was the second rated U.S. amateur boxer in the Super Heavyweight division and a member of the Mexican Olympic team.",
21 |     "hypothesis": "Javier was born in Mexico",
22 |     "label": "c",
23 |     "reason": "The paragraph states that Javier was born in the California, US."
24 | }
25 | ```
26 | 
27 | ### Reason
28 | The ANLI dataset contains a `reason` field for each example in the `dev` and `test` splits and for some examples in the `train` split. The reason was collected by asking annotators: "Please write a reason for your statement belonging to the category and why you think it was difficult for the system."
29 | 
30 | 
31 | 
32 | ### Verifier Labels (Updated on May 11, 2022)
33 | All the examples in our dev and test sets are verified by 2 verifiers, or 3 if the first 2 verifiers do not agree with each other. We released additional verifier labels in [`verifier_labels/verifier_labels_R1-3.jsonl`](https://github.com/facebookresearch/anli/blob/main/verifier_labels/verifier_labels_R1-3.jsonl).
34 | Please refer to the [verifier_labels_readme](https://github.com/facebookresearch/anli/blob/main/mds/verifier_labels.md) or Sec 2.1, Appendix C and Figure 7 in the [ANLI paper](https://arxiv.org/pdf/1910.14599.pdf) for more details about the verifier labels.
35 | 
36 | 
37 | ### Annotations for Error Analysis
38 | 
39 | An in-depth error analysis of the dataset is available here: https://github.com/facebookresearch/anli/tree/main/anlizinganli
40 | 
41 | We use a fine-grained annotation scheme of the different aspects of inference that are responsible for the gold classification labels, and use it to hand-code all three of the ANLI development sets. These annotations can be used to answer a variety of interesting questions: which inference types are most common, which models have the highest performance on each reasoning type, and which types are the most challenging for state-of-the-art models?
42 | 
43 | 
44 | ## Leaderboard
45 | 
46 | If you want to have your model added to the leaderboard, please reach out to us or submit a PR.
47 | 
48 | Model | Publication | A1 | A2 | A3
49 | ---|---|---|---|---
50 | InfoBERT (RoBERTa Large) | [Wang et al., 2020](https://openreview.net/forum?id=hpH98mK5Puk) | 75.5 | 51.4 | 49.8
51 | ALUM (RoBERTa Large) | [Liu et al., 2020](https://arxiv.org/abs/2004.08994) | 72.3 | 52.1 | 48.4
52 | GPT-3 | [Brown et al., 2020](https://arxiv.org/abs/2005.14165) | 36.8 | 34.0 | 40.2
53 | ALBERT ([using the checkpoint in this codebase](#albert)) | [Lan et al., 2019](https://arxiv.org/abs/1909.11942) | 73.6 | 58.6 | 53.4
54 | XLNet Large | [Yang et al., 2019](https://arxiv.org/abs/1906.08237) | 67.6 | 50.7 | 48.3
55 | RoBERTa Large | [Liu et al., 2019](https://arxiv.org/abs/1907.11692) | 73.8 | 48.9 | 44.4
56 | BERT Large | [Devlin et al., 2018](https://arxiv.org/abs/1810.04805) | 57.4 | 48.3 | 43.5
57 | 
58 | (Updated on Jan 21 2021: The three entries at the bottom show the test set numbers from Table 3 in the [ANLI paper](https://arxiv.org/abs/1910.14599). We recommend that you report test set results in your paper. Dev scores, obtained for the models in this code base, are reported [below](#checkpoint_results).)
59 | 60 | ## Implementation 61 | 62 | To facilitate research in the field of NLI, we provide an easy-to-use codebase for NLI data preparation and modeling. 63 | The code is built upon [Transformers](https://huggingface.co/transformers/) with a special focus on NLI. 64 | 65 | We welcome researchers from various fields (linguistics, machine learning, cognitive science, psychology, etc.) to try NLI. 66 | You can use the code to reproduce the results in our paper or even as a starting point for your research. 67 | 68 | Please read more in [**Start your NLI research**](mds/start_your_nli_research.md). 69 | 70 | An important detail in our experiments is that we combine SNLI+MNLI+FEVER-NLI and up-sample different rounds of ANLI to train the models. 71 | **We highly recommend you refer to the above link for reproducing the results and training your models such that the results will be comparable to the ones on the leaderboard.** 72 | 73 | (Updated on May 11, 2022) 74 | Thanks to [Jared Contrascere](https://github.com/contracode). Now, Researchers can use the [notebook](https://github.com/facebookresearch/anli/blob/main/script/example_scripts/ANLI_on_Google_Colab.ipynb) to run experiments quickly via Google Colab. 75 | 76 | ## Pre-trained Models 77 | Pre-trained NLI models can be easily called through huggingface model hub. 78 | 79 | Version information: 80 | ``` 81 | python==3.7 82 | torch==1.7 83 | transformers==3.0.2 or later (tested: 3.0.2, 3.1.0, 4.0.0) 84 | ``` 85 | 86 | Models: `RoBERTa`, `ALBert`, `BART`, `ELECTRA`, `XLNet`. 87 | 88 | The training data is a combination of [`SNLI`](https://nlp.stanford.edu/projects/snli/), [`MNLI`](https://cims.nyu.edu/~sbowman/multinli/), [`FEVER-NLI`](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md), [`ANLI (R1, R2, R3)`](https://github.com/facebookresearch/anli). Please also cite the datasets if you are using the pre-trained model. 89 | 90 | Please try the code snippet below. 91 | ```python 92 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 93 | import torch 94 | 95 | if __name__ == '__main__': 96 | max_length = 256 97 | 98 | premise = "Two women are embracing while holding to go packages." 99 | hypothesis = "The men are fighting outside a deli." 100 | 101 | hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli" 102 | # hg_model_hub_name = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli" 103 | # hg_model_hub_name = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli" 104 | # hg_model_hub_name = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli" 105 | # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli" 106 | 107 | tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name) 108 | model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name) 109 | 110 | tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis, 111 | max_length=max_length, 112 | return_token_type_ids=True, truncation=True) 113 | 114 | input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0) 115 | 116 | # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart. 
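    # For bart-like models, the equivalent call skips token_type_ids, e.g.
    # (a sketch; the output format may vary across transformers versions):
    # outputs = model(input_ids, attention_mask=attention_mask, labels=None)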
117 | token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0) 118 | attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0) 119 | 120 | outputs = model(input_ids, 121 | attention_mask=attention_mask, 122 | token_type_ids=token_type_ids, 123 | labels=None) 124 | # Note: 125 | # "id2label": { 126 | # "0": "entailment", 127 | # "1": "neutral", 128 | # "2": "contradiction" 129 | # }, 130 | 131 | predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist() # batch_size only one 132 | 133 | print("Premise:", premise) 134 | print("Hypothesis:", hypothesis) 135 | print("Entailment:", predicted_probability[0]) 136 | print("Neutral:", predicted_probability[1]) 137 | print("Contradiction:", predicted_probability[2]) 138 | ``` 139 | 140 | If you are using our pre-trained model checkpoints with the above code snippet, you would expect to get the following numbers. 141 | 142 | Huggingface Model Hub Checkpoint | A1 (dev) | A2 (dev) | A3 (dev) | A1 (test) | A2 (test) | A3 (test) 143 | ---|---|---|---|---|---|--- 144 | ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli | 73.8 | 50.8 | 46.1 | 73.6 | 49.3 | 45.5 145 | ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli | 73.4 | 52.3 | 50.8 | 70.0 | 51.4 | 49.8 146 | ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli | 76.0 | 57.0 | 57.0 | 73.6 | 58.6 | 53.4 147 | 148 | 149 | More in [here](https://github.com/facebookresearch/anli/blob/master/src/hg_api/interactive_eval.py). 150 | 151 | ## Rules 152 | 153 | When using this dataset, we ask that you obey some very simple rules: 154 | 155 | 1. We want to make it easy for people to provide ablations on test sets without being rate limited, so we release labeled test sets with this distribution. We trust that you will act in good faith, and will not tune on the test set (this should really go without saying)! We may release unlabeled test sets later. 156 | 157 | 2. **Training data is for training, development data is for development, and test data is for reporting test numbers.** This means that you should not e.g. train on the train+dev data from rounds 1 and 2 and then report an increase in performance on the test set of round 3. 158 | 159 | 3. We will host a leaderboard on this page. If you want to be added to the leaderboard, please contact us and/or submit a PR with a link to your paper, a link to your code in a public repository (e.g. Github), together with the following information: number of parameters in your model, data used for (pre-)training, and your dev and test results for *each* round, as well as the total over *all* rounds. 
160 | 161 | ## Other NLI Reference 162 | 163 | We used following NLI resources in training the backend model of the adversarial collection: 164 | - [**SNLI**](https://nlp.stanford.edu/projects/snli/) 165 | - [**MultiNLI**](https://www.nyu.edu/projects/bowman/multinli/) 166 | - [**NLI style FEVER**](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md) 167 | 168 | ## Citations 169 | 170 | ### Dataset 171 | ``` 172 | @inproceedings{nie-etal-2020-adversarial, 173 | title = "Adversarial {NLI}: A New Benchmark for Natural Language Understanding", 174 | author = "Nie, Yixin and 175 | Williams, Adina and 176 | Dinan, Emily and 177 | Bansal, Mohit and 178 | Weston, Jason and 179 | Kiela, Douwe", 180 | booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics", 181 | year = "2020", 182 | publisher = "Association for Computational Linguistics", 183 | } 184 | ``` 185 | 186 | ### Annotations of the Dataset for Error Analysis 187 | ``` 188 | @article{williams-etal-2020-anlizing, 189 | title = "ANLIzing the Adversarial Natural Language Inference Dataset", 190 | author = "Adina Williams and 191 | Tristan Thrush and 192 | Douwe Kiela", 193 | booktitle = "Proceedings of the 5th Annual Meeting of the Society for Computation in Linguistics", 194 | year = "2022", 195 | publisher = "Association for Computational Linguistics", 196 | } 197 | ``` 198 | 199 | ## License 200 | ANLI is licensed under Creative Commons-Non Commercial 4.0. See the LICENSE file for details. 201 | -------------------------------------------------------------------------------- /script/example_scripts/ANLI_on_Google_Colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "MWihWfGuVkRP" 7 | }, 8 | "source": [ 9 | "# ANLI Dataset Filtering\n", 10 | "\n", 11 | "This Google Colab notebook is inspired by the [*Start Your NLI Research*](https://github.com/facebookresearch/anli/blob/main/mds/start_your_nli_research.md) instructions located on the [ANLI](https://github.com/facebookresearch/anli) GitHub repo. This is intended to be run on a [Google Colab Pro](https://colab.research.google.com/signup) or [Pro+](https://colab.research.google.com/signup) account leveraging a GPU-backed runtime." 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "id": "4PAZ13O77C2w" 18 | }, 19 | "source": [ 20 | "## Connect to Google Drive\n", 21 | "\n", 22 | "We will connect to [Google Drive](https://drive.google.com) to store weights and data within the cloud. This is needed because Google Colab has a maximum 24-hour runtime, even with Pro and Pro+ accounts. After running the cell below, you will be prompted to connect with your Google Drive account." 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "id": "nln7suF-69LL" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "# Mount into drive\n", 34 | "from google.colab import drive\n", 35 | "drive.mount(\"/content/drive\")" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": { 41 | "id": "B7Dy2XhY7_c9" 42 | }, 43 | "source": [ 44 | "This cell created an `ANLI Project Data` folder within your `Colab Notebooks` folder." 
45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "Cea7z6I869OY" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "%mkdir -p /content/drive/MyDrive/Colab\\ Notebooks/ANLI\\ Project\\ Data\n", 56 | "%mkdir -p /content/drive/MyDrive/Colab\\ Notebooks/ANLI\\ Project\\ Data/scripts\n", 57 | "%mkdir -p /content/drive/MyDrive/Colab\\ Notebooks/ANLI\\ Project\\ Data/checkpoints" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": { 63 | "id": "BMBDQPTvX2Yt" 64 | }, 65 | "source": [ 66 | "## GPU Allocation\n", 67 | "\n", 68 | "It is a good idea to capture what kind of GPU we have allocated to us. The following commands do this in a summarized and a verbose manner." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "id": "I_1_x5JOYHP3" 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "!nvidia-smi -L" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "id": "PekoZAVzYHSj" 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "!nvidia-smi -q" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": { 96 | "id": "ZrB7azJfVsqC" 97 | }, 98 | "source": [ 99 | "## Project Setup\n", 100 | "\n", 101 | "### Code Setup\n", 102 | "\n", 103 | "First, we need to download the [ANLI](https://github.com/facebookresearch/anli) repo and build the dataset." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "id": "9vvrCFC5Vq8T" 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "!git clone https://github.com/facebookresearch/anli.git 2>/dev/null\n", 115 | "!source anli/setup.sh" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": { 121 | "id": "r0MUm0J5WgjE" 122 | }, 123 | "source": [ 124 | "Then, we'll change directory into the source code directory, `anli`." 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "id": "xo4L-XjwEcQD" 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "import os\n", 136 | "import sys\n", 137 | "\n", 138 | "try:\n", 139 | " os.chdir('anli/')\n", 140 | "except FileNotFoundError as e:\n", 141 | " print(f\"Could not change directory: {str(e)}\")" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "id": "cAFxj96xYVtP" 148 | }, 149 | "source": [ 150 | "Finally, as far as code goes, we'll need the `transformers` module from the popular [Hugging Face](https://huggingface.co/docs/transformers/index) open-source NLP company and `sentencepiece` which is needed to support experiments with the xlnet model." 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": { 157 | "id": "wmVDK6niYUYi" 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "!pip install transformers sentencepiece" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": { 167 | "id": "QX-Pz088WuXO" 168 | }, 169 | "source": [ 170 | "#### Environment Variables\n", 171 | "\n", 172 | "Before moving onto dataset setup, we'll set some environment variables to prepare for the Bash and Python scripts that follow." 
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "id": "39E8dSShPDcS" 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "%env PYTHONPATH='/env/python:/content/anli/src:/content/anli/utest:/content/anli/src/dataset_tools'\n", 184 | "%env MASTER_ADDR=localhost" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "id": "sXU1zxgOXKT-" 191 | }, 192 | "source": [ 193 | "### Dataset Setup\n", 194 | "\n", 195 | "We can't train a model without data, so this will download the SNLI, MNLI, FEVER, and NLI datasets." 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": { 202 | "id": "UYf3Epg2SMaG" 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "!bash ./script/download_data.sh" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": { 212 | "id": "SDjY0zN0XtXu" 213 | }, 214 | "source": [ 215 | "Now, we'll transform the dataset into a format that the ANLI project expects." 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": { 222 | "id": "KN6qem_WFxlf" 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "!python ./src/dataset_tools/build_data.py" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "4Pk1JIFy-PGx" 233 | }, 234 | "source": [ 235 | "## Update Training Script" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": { 241 | "id": "FE-Krd7BAKC1" 242 | }, 243 | "source": [ 244 | "If you've placed a modified `training.py` script in your GDrive `Colab Notebooks/ANLI Project Data/scripts/` directory, uncomment and run the following line so that your updated script will be used in the ***Model Training*** section." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": { 251 | "id": "b8nAruWu_3dW" 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "#!cp /content/drive/MyDrive/Colab\\ Notebooks/ANLI\\ Project\\ Data/scripts/training.py ./src/nli/training.py" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": { 261 | "id": "dI0C8fWcAii0" 262 | }, 263 | "source": [ 264 | "Alternatively, if you are storing an updated `training.py` in a public GitHub/GitLab repo, uncomment and run the following line after updating the URL to point to the *raw* file. In a browser, this will look like a plaintext version of the file." 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": { 271 | "id": "zUUVh9k8-SA4" 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "#!curl https://https://raw.githubusercontent.com/username/project/main/src/nli/training.py -o ./src/nli/training.py" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "source": [ 281 | "## Update Data" 282 | ], 283 | "metadata": { 284 | "id": "usAZ5Dr7zgWp" 285 | } 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "source": [ 290 | "If any custom datasets are used for training or evaluation, the following lines bring them from Google Drive into Colab." 
291 | ], 292 | "metadata": { 293 | "id": "TXawrKnBB3NY" 294 | } 295 | }, 296 | { 297 | "cell_type": "code", 298 | "source": [ 299 | "#%mkdir -p experiments" 300 | ], 301 | "metadata": { 302 | "id": "RhTsSQRk0FV0" 303 | }, 304 | "execution_count": null, 305 | "outputs": [] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "source": [ 310 | "#!cp -R /content/drive/MyDrive/Colab\\ Notebooks/ANLI\\ Project\\ Data/data/* ./experiments" 311 | ], 312 | "metadata": { 313 | "id": "Gcwqhgx1zgoC" 314 | }, 315 | "execution_count": null, 316 | "outputs": [] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": { 321 | "id": "8DK8PO33Y8At" 322 | }, 323 | "source": [ 324 | "## Model Training" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": { 330 | "id": "5qZUOS-jZL2P" 331 | }, 332 | "source": [ 333 | "Note that a list of supported models and extra, undocumented command line arguments are located in `/content/anli/src/nli/training.py`. Comments are below where changes have been made from [*Start Your NLI Research*](https://github.com/facebookresearch/anli/blob/main/mds/start_your_nli_research.md) instructions.\n", 334 | "\n", 335 | "During training, model checkpoints will be automatically saved in a `saved_models` directory.\n", 336 | "\n", 337 | "***Changelog***\n", 338 | "\n", 339 | "* `-g 1`: This was changed to 1, since we only have one GPU.\n", 340 | "* `--single_gpu`: This was added to suppress PyTorch Multiprocessing logic from kicking in.\n", 341 | "* `--experiment_name`: The name of the experiment. During training, model checkpoints will be saved in `saved_models/{TRAINING_START_TIME}_[experiment_name]`." 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": { 348 | "id": "tU8pKQHcHtC2" 349 | }, 350 | "outputs": [], 351 | "source": [ 352 | "!python ./src/nli/training.py \\\n", 353 | " --model_class_name \"roberta-large\" \\\n", 354 | " -n 1 \\\n", 355 | " -g 1 \\\n", 356 | " --single_gpu \\\n", 357 | " -nr 0 \\\n", 358 | " --max_length 156 \\\n", 359 | " --gradient_accumulation_steps 1 \\\n", 360 | " --per_gpu_train_batch_size 16 \\\n", 361 | " --per_gpu_eval_batch_size 16 \\\n", 362 | " --save_prediction \\\n", 363 | " --train_data snli_train:none,mnli_train:none \\\n", 364 | " --train_weights 1,1 \\\n", 365 | " --eval_data snli_dev:none \\\n", 366 | " --eval_frequency 2000 \\\n", 367 | " --experiment_name \"roberta-large|snli|nli\"" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": { 373 | "id": "L70Ca7T8FqxL" 374 | }, 375 | "source": [ 376 | "**Make sure to queue this command alongside your training cell.** This will copy saved checkpoints from Colab to your Google Drive." 
377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": { 383 | "id": "XbfUqimYa0tY" 384 | }, 385 | "outputs": [], 386 | "source": [ 387 | "!cp -R ./saved_models/* /content/drive/MyDrive/Colab\\ Notebooks/ANLI\\ Project\\ Data/checkpoints/" 388 | ] 389 | } 390 | ], 391 | "metadata": { 392 | "accelerator": "GPU", 393 | "colab": { 394 | "background_execution": "on", 395 | "collapsed_sections": [], 396 | "machine_shape": "hm", 397 | "name": "ANLI on Google Colab.ipynb", 398 | "provenance": [] 399 | }, 400 | "kernelspec": { 401 | "display_name": "Python 3", 402 | "name": "python3" 403 | }, 404 | "language_info": { 405 | "name": "python" 406 | } 407 | }, 408 | "nbformat": 4, 409 | "nbformat_minor": 0 410 | } -------------------------------------------------------------------------------- /src/flint/torch_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | import numpy as np 10 | import functools 11 | 12 | 13 | # def get_length_and_mask(seq): 14 | # len_mask = (seq != 0).long() 15 | # len_t = get_lengths_from_binary_sequence_mask(len_mask) 16 | # return len_mask, len_t 17 | 18 | 19 | def length_truncate(seq, max_l, is_elmo=False): 20 | def _truncate(seq): 21 | if seq.size(1) > max_l: 22 | return seq[:, :max_l, ...] 23 | else: 24 | return seq 25 | 26 | if not is_elmo: 27 | return _truncate(seq) 28 | else: 29 | s1_elmo_embd = dict() 30 | s1_elmo_embd['mask'] = _truncate(seq['mask']) 31 | s1_elmo_embd['elmo_representations'] = [] 32 | for e_rep in seq['elmo_representations']: 33 | s1_elmo_embd['elmo_representations'].append(_truncate(e_rep)) 34 | return s1_elmo_embd 35 | 36 | 37 | def pad_1d(seq, pad_l): 38 | """ 39 | The seq is a sequence having shape [T, ..]. Note: The seq contains only one instance. This is not batched. 40 | 41 | :param seq: Input sequence with shape [T, ...] 42 | :param pad_l: The required pad_length. 43 | :return: Output sequence will have shape [Pad_L, ...] 44 | """ 45 | l = seq.size(0) 46 | if l >= pad_l: 47 | return seq[:pad_l, ] # Truncate the length if the length is bigger than required padded_length. 48 | else: 49 | pad_seq = Variable(seq.data.new(pad_l - l, *seq.size()[1:]).zero_()) # Requires_grad is False 50 | return torch.cat([seq, pad_seq], dim=0) 51 | 52 | 53 | def get_state_shape(rnn: nn.RNN, batch_size, bidirectional=False): 54 | """ 55 | Return the state shape of a given RNN. This is helpful when you want to create a init state for RNN. 56 | 57 | Example: 58 | c0 = h0 = Variable(src_seq_p.data.new(*get_state_shape([your rnn], 3, bidirectional)).zero_()) 59 | 60 | :param rnn: nn.LSTM, nn.GRU or subclass of nn.RNN 61 | :param batch_size: 62 | :param bidirectional: 63 | :return: 64 | """ 65 | if bidirectional: 66 | return rnn.num_layers * 2, batch_size, rnn.hidden_size 67 | else: 68 | return rnn.num_layers, batch_size, rnn.hidden_size 69 | 70 | 71 | def pack_list_sequence(inputs, l, max_l=None, batch_first=True): 72 | """ 73 | Pack a batch of Tensor into one Tensor with max_length. 74 | :param inputs: 75 | :param l: 76 | :param max_l: The max_length of the packed sequence. 
77 | :param batch_first: 78 | :return: 79 | """ 80 | batch_list = [] 81 | max_l = max(list(l)) if not max_l else max_l 82 | batch_size = len(inputs) 83 | 84 | for b_i in range(batch_size): 85 | batch_list.append(pad_1d(inputs[b_i], max_l)) 86 | pack_batch_list = torch.stack(batch_list, dim=1) if not batch_first \ 87 | else torch.stack(batch_list, dim=0) 88 | return pack_batch_list 89 | 90 | 91 | def pack_for_rnn_seq(inputs, lengths, batch_first=True, states=None): 92 | """ 93 | :param states: [rnn.num_layers, batch_size, rnn.hidden_size] 94 | :param inputs: Shape of the input should be [B, T, D] if batch_first else [T, B, D]. 95 | :param lengths: [B] 96 | :param batch_first: 97 | :return: 98 | """ 99 | if not batch_first: 100 | _, sorted_indices = lengths.sort() 101 | ''' 102 | Reverse to decreasing order 103 | ''' 104 | r_index = reversed(list(sorted_indices)) 105 | 106 | s_inputs_list = [] 107 | lengths_list = [] 108 | reverse_indices = np.zeros(lengths.size(0), dtype=np.int64) 109 | 110 | for j, i in enumerate(r_index): 111 | s_inputs_list.append(inputs[:, i, :].unsqueeze(1)) 112 | lengths_list.append(lengths[i]) 113 | reverse_indices[i] = j 114 | 115 | reverse_indices = list(reverse_indices) 116 | 117 | s_inputs = torch.cat(s_inputs_list, 1) 118 | packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list) 119 | 120 | return packed_seq, reverse_indices 121 | 122 | else: 123 | _, sorted_indices = lengths.sort() 124 | ''' 125 | Reverse to decreasing order 126 | ''' 127 | r_index = reversed(list(sorted_indices)) 128 | 129 | s_inputs_list = [] 130 | lengths_list = [] 131 | reverse_indices = np.zeros(lengths.size(0), dtype=np.int64) 132 | 133 | if states is None: 134 | states = () 135 | elif not isinstance(states, tuple): 136 | states = (states,) # rnn.num_layers, batch_size, rnn.hidden_size 137 | 138 | states_lists = tuple([] for _ in states) 139 | 140 | for j, i in enumerate(r_index): 141 | s_inputs_list.append(inputs[i, :, :]) 142 | lengths_list.append(lengths[i]) 143 | reverse_indices[i] = j 144 | 145 | for state_list, state in zip(states_lists, states): 146 | state_list.append(state[:, i, :].unsqueeze(1)) 147 | 148 | reverse_indices = list(reverse_indices) 149 | 150 | s_inputs = torch.stack(s_inputs_list, dim=0) 151 | packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list, batch_first=batch_first) 152 | 153 | r_states = tuple(torch.cat(state_list, dim=1) for state_list in states_lists) 154 | if len(r_states) == 1: 155 | r_states = r_states[0] 156 | 157 | return packed_seq, reverse_indices, r_states 158 | 159 | 160 | def unpack_from_rnn_seq(packed_seq, reverse_indices, batch_first=True): 161 | unpacked_seq, _ = nn.utils.rnn.pad_packed_sequence(packed_seq, batch_first=batch_first) 162 | s_inputs_list = [] 163 | 164 | if not batch_first: 165 | for i in reverse_indices: 166 | s_inputs_list.append(unpacked_seq[:, i, :].unsqueeze(1)) 167 | return torch.cat(s_inputs_list, 1) 168 | else: 169 | for i in reverse_indices: 170 | s_inputs_list.append(unpacked_seq[i, :, :].unsqueeze(0)) 171 | return torch.cat(s_inputs_list, 0) 172 | 173 | 174 | def reverse_indice_for_state(states, reverse_indices): 175 | """ 176 | :param states: [rnn.num_layers, batch_size, rnn.hidden_size] 177 | :param reverse_indices: [batch_size] 178 | :return: 179 | """ 180 | if states is None: 181 | states = () 182 | elif not isinstance(states, tuple): 183 | states = (states,) # rnn.num_layers, batch_size, rnn.hidden_size 184 | 185 | states_lists = tuple([] for _ in states) 186 | for i in 
reverse_indices:
187 |         for state_list, state in zip(states_lists, states):
188 |             state_list.append(state[:, i, :].unsqueeze(1))
189 | 
190 |     r_states = tuple(torch.cat(state_list, dim=1) for state_list in states_lists)
191 |     if len(r_states) == 1:
192 |         r_states = r_states[0]
193 |     return r_states
194 | 
195 | 
196 | def auto_rnn(rnn: nn.RNN, seqs, lengths, batch_first=True, init_state=None, output_last_states=False):
197 |     batch_size = seqs.size(0) if batch_first else seqs.size(1)
198 |     state_shape = get_state_shape(rnn, batch_size, rnn.bidirectional)
199 | 
200 |     # if init_state is None:
201 |     #     h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
202 |     # else:
203 |     #     h0 = init_state[0]  # rnn.num_layers, batch_size, rnn.hidden_size
204 |     #     c0 = init_state[1]
205 | 
206 |     packed_pinputs, r_index, init_state = pack_for_rnn_seq(seqs, lengths, batch_first, init_state)
207 | 
208 |     if len(init_state) == 0:
209 |         h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
210 |         init_state = (h0, c0)
211 | 
212 |     output, last_state = rnn(packed_pinputs, init_state)
213 |     output = unpack_from_rnn_seq(output, r_index, batch_first)
214 | 
215 |     if not output_last_states:
216 |         return output
217 |     else:
218 |         last_state = reverse_indice_for_state(last_state, r_index)
219 |         return output, last_state
220 | 
221 | 
222 | def pack_sequence_for_linear(inputs, lengths, batch_first=True):
223 |     """
224 |     :param inputs: [B, T, D] if batch_first
225 |     :param lengths: [B]
226 |     :param batch_first:
227 |     :return:
228 |     """
229 |     batch_list = []
230 |     if batch_first:
231 |         for i, l in enumerate(lengths):
232 |             # print(inputs[i, :l].size())
233 |             batch_list.append(inputs[i, :l])
234 |         packed_sequence = torch.cat(batch_list, 0)
235 |         # if chuck:
236 |         #     return list(torch.chunk(packed_sequence, chuck, dim=0))
237 |         # else:
238 |         return packed_sequence
239 |     else:
240 |         raise NotImplementedError()
241 | 
242 | 
243 | def chucked_forward(inputs, net, chuck=None):
244 |     if not chuck:
245 |         return net(inputs)
246 |     else:
247 |         output_list = [net(c) for c in torch.chunk(inputs, chuck, dim=0)]  # 'c' avoids shadowing the chunk count `chuck`
248 |         return torch.cat(output_list, dim=0)
249 | 
250 | 
251 | def unpack_sequence_for_linear(inputs, lengths, batch_first=True):
252 |     batch_list = []
253 |     max_l = max(lengths)
254 | 
255 |     if not isinstance(inputs, list):
256 |         inputs = [inputs]
257 |     inputs = torch.cat(inputs)
258 | 
259 |     if batch_first:
260 |         start = 0
261 |         for l in lengths:
262 |             end = start + l
263 |             batch_list.append(pad_1d(inputs[start:end], max_l))
264 |             start = end
265 |         return torch.stack(batch_list)
266 |     else:
267 |         raise NotImplementedError()
268 | 
269 | 
270 | def seq2seq_cross_entropy(logits, label, l, chuck=None, sos_truncate=True):
271 |     """
272 |     :param logits: [exB, V] : exB = sum(l)
273 |     :param label: [B, T] : a batch of label sequences
274 |     :param l: [B] : a batch of LongTensor indicating the length of each input
275 |     :param chuck: Number of chunks to process
276 |     :return: A loss value
277 |     """
278 |     packed_label = pack_sequence_for_linear(label, l)
279 |     cross_entropy_loss = functools.partial(F.cross_entropy, reduction='sum')  # summed here, averaged by `total` below
280 |     total = sum(l)
281 | 
282 |     assert total == logits.size(0) or packed_label.size(0) == logits.size(0), \
283 |         "logits length mismatch with label length."
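    # The branch below computes the summed cross-entropy either in one shot or
    # piecewise over `chuck` chunks (lower peak memory for long batches);
    # multiplying by 1 / total then turns the summed loss into a per-token mean.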
284 | 
285 |     if chuck:
286 |         logits_losses = 0
287 |         for x, y in zip(torch.chunk(logits, chuck, dim=0), torch.chunk(packed_label, chuck, dim=0)):
288 |             logits_losses += cross_entropy_loss(x, y)
289 |         return logits_losses * (1 / total)
290 |     else:
291 |         return cross_entropy_loss(logits, packed_label) * (1 / total)
292 | 
293 | 
294 | def max_along_time(inputs, lengths, list_in=False):
295 |     """
296 |     :param inputs: [B, T, D]
297 |     :param lengths: [B]
298 |     :param list_in: if True, `inputs` is a list of [T_i, D] tensors instead of one padded tensor.
299 |     :return: [B, D] max-pooled over the valid time steps
300 |     """
301 |     ls = list(lengths)
302 | 
303 |     if not list_in:
304 |         b_seq_max_list = []
305 |         for i, l in enumerate(ls):
306 |             seq_i = inputs[i, :l, :]
307 |             seq_i_max, _ = seq_i.max(dim=0)
308 |             seq_i_max = seq_i_max.squeeze()
309 |             b_seq_max_list.append(seq_i_max)
310 | 
311 |         return torch.stack(b_seq_max_list)
312 |     else:
313 |         b_seq_max_list = []
314 |         for i, l in enumerate(ls):
315 |             seq_i = inputs[i]
316 |             seq_i_max, _ = seq_i.max(dim=0)
317 |             seq_i_max = seq_i_max.squeeze()
318 |             b_seq_max_list.append(seq_i_max)
319 | 
320 |         return torch.stack(b_seq_max_list)
321 | 
322 | 
323 | def avg_along_time(inputs, lengths, list_in=False):
324 |     """
325 |     :param inputs: [B, T, D]
326 |     :param lengths: [B]
327 |     :param list_in: if True, `inputs` is a list of [T_i, D] tensors instead of one padded tensor.
328 |     :return: [B, D] mean-pooled over the valid time steps
329 |     """
330 |     ls = list(lengths)
331 | 
332 |     if not list_in:
333 |         b_seq_avg_list = []
334 |         for i, l in enumerate(ls):
335 |             seq_i = inputs[i, :l, :]
336 |             seq_i_avg = seq_i.mean(dim=0)
337 |             seq_i_avg = seq_i_avg.squeeze()
338 |             b_seq_avg_list.append(seq_i_avg)
339 | 
340 |         return torch.stack(b_seq_avg_list)
341 |     else:
342 |         b_seq_avg_list = []
343 |         for i, l in enumerate(ls):
344 |             seq_i = inputs[i]
345 |             seq_i_avg = seq_i.mean(dim=0)  # mean() returns a tensor, not a (values, indices) pair
346 |             seq_i_avg = seq_i_avg.squeeze()
347 |             b_seq_avg_list.append(seq_i_avg)
348 | 
349 |         return torch.stack(b_seq_avg_list)
350 | 
351 | 
352 | # def length_truncate(inputs, lengths, max_len):
353 | #     """
354 | #     :param inputs: [B, T]
355 | #     :param lengths: [B]
356 | #     :param max_len: int
357 | #     :return: [B, T]
358 | #     """
359 | #     max_l = max(1, max_len)
360 | #     max_s1_l = min(max(lengths), max_l)
361 | #     lengths = lengths.clamp(min=1, max=max_s1_l)
362 | #     if inputs.size(1) > max_s1_l:
363 | #         inputs = inputs[:, :max_s1_l]
364 | #
365 | #     return inputs, lengths, max_s1_l
366 | 
367 | 
368 | def get_reverse_indices(indices, lengths):
369 |     r_indices = indices.data.new(indices.size()).fill_(0)
370 |     batch_size = indices.size(0)
371 |     for i in range(int(batch_size)):
372 |         b_ind = indices[i]
373 |         b_l = lengths[i]
374 |         for k, ind in enumerate(b_ind):
375 |             if k >= b_l:
376 |                 break
377 |             r_indices[i, int(ind)] = k
378 |     return r_indices
379 | 
380 | 
381 | def index_ordering(inputs, lengths, indices, pad_value=0):
382 |     """
383 |     :param inputs: [B, T, ~]
384 |     :param lengths: [B]
385 |     :param indices: [B, T]
386 |     :return:
387 |     """
388 |     batch_size = inputs.size(0)
389 |     ordered_out_list = []
390 |     for i in range(int(batch_size)):
391 |         b_input = inputs[i]
392 |         b_l = lengths[i]
393 |         b_ind = indices[i]
394 |         b_out = b_input[b_ind]
395 |         if b_out.size(0) > b_l:
396 |             b_out[b_l:] = pad_value
397 |         ordered_out_list.append(b_out)
398 | 
399 |     outs = torch.stack(ordered_out_list, dim=0)
400 |     return outs
401 | 
402 | 
403 | def start_and_end_token_handling(inputs, lengths, sos_index=1, eos_index=2, pad_index=0,
404 |                                  op=None):
405 |     """
406 |     :param inputs: [B, T]
407 |     :param lengths: [B]
408 |     :param sos_index:
409 |     :param eos_index:
410 |     :param pad_index:
411 |     :return:
412 |     """
413 |     batch_size =
414 | 
415 |     if not op:
416 |         return inputs, lengths
417 |     elif op == 'rm_start':
418 |         inputs = torch.cat([inputs[:, 1:], Variable(inputs.data.new(batch_size, 1).zero_())], dim=1)
419 |         return inputs, lengths - 1
420 |     elif op == 'rm_end':
421 |         for i in range(batch_size):
422 |             pass
423 |             # Potential problems!?
424 |             # inputs[i, lengths[i] - 1] = pad_index
425 |         return inputs, lengths - 1
426 |     elif op == 'rm_both':
427 |         for i in range(batch_size):
428 |             pass
429 |             # Potential problems!?
430 |             # inputs[i, lengths[i] - 1] = pad_index
431 |         inputs = torch.cat([inputs[:, 1:], Variable(inputs.data.new(batch_size, 1).zero_())], dim=1)
432 |         return inputs, lengths - 2
433 | 
434 | 
435 | def seq2seq_att(mems, lengths, state, att_net=None):
436 |     """
437 |     :param mems: [B, T, D_mem] These are the memories.
438 |         This variable is called 'memory' because attention is like reading something and then
439 |         making alignments with your memories.
440 |         The memory here is usually the input hidden states of the encoder.
441 | 
442 |     :param lengths: [B]
443 |     :param state: [B, D_state]
444 |         This variable is called 'state' because it is the state perceived at the current time step.
445 | 
446 |     :param att_net: The attention network used to calculate the alignment score between the
447 |         state and the memories.
448 |         The input of the att_net is mems and state with shapes:
449 |         mems: [exB, D_mem]
450 |         state: [exB, D_state]
451 |         The return of the att_net is [exB, 1]
452 | 
453 |         So any function that maps a vector to a scalar will work.
454 | 
455 |     :return: [B, D_result]
456 |     """
457 | 
458 |     d_state = state.size(1)
459 | 
460 |     if not att_net:
461 |         return state
462 |     else:
463 |         batch_list_mems = []
464 |         batch_list_state = []
465 |         for i, l in enumerate(lengths):
466 |             b_mems = mems[i, :l]  # [T, D_mem]
467 |             batch_list_mems.append(b_mems)
468 | 
469 |             b_state = state[i].expand(b_mems.size(0), d_state)  # [T, D_state]
470 |             batch_list_state.append(b_state)
471 | 
472 |         packed_sequence_mems = torch.cat(batch_list_mems, 0)  # [sum(l), D_mem]
473 |         packed_sequence_state = torch.cat(batch_list_state, 0)  # [sum(l), D_state]
474 | 
475 |         align_score = att_net(packed_sequence_mems, packed_sequence_state)  # [sum(l), 1]
476 | 
477 |         # The scores are grouped by sequence, e.g. [(a1, a2, a3), (a1, a2), (a1, a2, a3, a4)].
478 |         # aligned_seq = packed_sequence_mems * align_score
479 | 
480 |         start = 0
481 |         result_list = []
482 |         for i, l in enumerate(lengths):
483 |             end = start + l
484 | 
485 |             b_mems = packed_sequence_mems[start:end, :]  # [l, D_mems]
486 |             b_score = align_score[start:end, :]  # [l, 1]
487 | 
488 |             softed_b_score = F.softmax(b_score, dim=0)  # [l, 1] normalize over the l memory positions
489 | 
490 |             weighted_sum = torch.sum(b_mems * softed_b_score, dim=0, keepdim=False)  # [D_mems]
491 | 
492 |             result_list.append(weighted_sum)
493 | 
494 |             start = end
495 | 
496 |         result = torch.stack(result_list, dim=0)
497 | 
498 |         return result
499 | 
500 | # Test something
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 | 
3 | =======================================================================
4 | 
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. 
Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. 
Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 
136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. 
To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /src/nli/training_extra.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the 4 | # LICENSE file in the root directory of this source tree. 
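# training_extra.py mirrors src/nli/training.py (included later in this tree) but
# additionally registers two non-transformer baselines, 'lstm-resencoder' and
# 'bag-of-words', which reuse BERT's word embeddings inside the ResEncoder and
# BagOfWords modules from modeling/res_encoder.py.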
5 | 6 | import argparse 7 | from pathlib import Path 8 | 9 | from torch.optim import Adam 10 | from transformers import RobertaTokenizer, RobertaForSequenceClassification 11 | from transformers import XLNetTokenizer, XLNetForSequenceClassification 12 | # from transformers import XLNetTokenizer 13 | # from modeling.dummy_modeling_xlnet import XLNetForSequenceClassification 14 | from transformers import BertTokenizer, BertForSequenceClassification 15 | from transformers import AlbertTokenizer, AlbertForSequenceClassification 16 | from transformers import DistilBertTokenizer, DistilBertForSequenceClassification 17 | from transformers import BartTokenizer, BartForSequenceClassification 18 | from transformers import ElectraTokenizer, ElectraForSequenceClassification 19 | 20 | from torch.utils.data import Dataset, DataLoader, DistributedSampler, RandomSampler, SequentialSampler 21 | import config 22 | from transformers import AdamW 23 | from transformers import get_linear_schedule_with_warmup 24 | from flint.data_utils.batchbuilder import BaseBatchBuilder, move_to_device 25 | from flint.data_utils.fields import RawFlintField, LabelFlintField, ArrayIndexFlintField 26 | from modeling.res_encoder import ResEncoder, EmptyScheduler, BagOfWords 27 | from utils import common, list_dict_data_tool, save_tool 28 | import os 29 | import torch.multiprocessing as mp 30 | import torch.distributed as dist 31 | import torch.nn as nn 32 | 33 | import numpy as np 34 | import random 35 | import torch 36 | from tqdm import tqdm 37 | import math 38 | import copy 39 | 40 | import pprint 41 | 42 | pp = pprint.PrettyPrinter(indent=2) 43 | 44 | # from fairseq.data.data_utils import collate_tokens 45 | 46 | MODEL_CLASSES = { 47 | "lstm-resencoder": { 48 | "model_name": "bert-large-uncased", 49 | "tokenizer": BertTokenizer, 50 | "sequence_classification": BertForSequenceClassification, 51 | # "padding_token_value": 0, 52 | "padding_segement_value": 0, 53 | "padding_att_value": 0, 54 | "do_lower_case": True, 55 | }, 56 | 57 | "bag-of-words": { 58 | "model_name": "bert-large-uncased", 59 | "tokenizer": BertTokenizer, 60 | "sequence_classification": BertForSequenceClassification, 61 | # "padding_token_value": 0, 62 | "padding_segement_value": 0, 63 | "padding_att_value": 0, 64 | "do_lower_case": True, 65 | }, 66 | 67 | "bert-base": { 68 | "model_name": "bert-base-uncased", 69 | "tokenizer": BertTokenizer, 70 | "sequence_classification": BertForSequenceClassification, 71 | # "padding_token_value": 0, 72 | "padding_segement_value": 0, 73 | "padding_att_value": 0, 74 | "do_lower_case": True, 75 | }, 76 | 77 | "bert-large": { 78 | "model_name": "bert-large-uncased", 79 | "tokenizer": BertTokenizer, 80 | "sequence_classification": BertForSequenceClassification, 81 | # "padding_token_value": 0, 82 | "padding_segement_value": 0, 83 | "padding_att_value": 0, 84 | "do_lower_case": True, 85 | "internal_model_name": "bert", 86 | 'insight_supported': True, 87 | }, 88 | 89 | "xlnet-base": { 90 | "model_name": "xlnet-base-cased", 91 | "tokenizer": XLNetTokenizer, 92 | "sequence_classification": XLNetForSequenceClassification, 93 | # "padding_token_value": 0, 94 | "padding_segement_value": 4, 95 | "padding_att_value": 0, 96 | "left_pad": True, 97 | "internal_model_name": ["transformer", "word_embedding"], 98 | }, 99 | "xlnet-large": { 100 | "model_name": "xlnet-large-cased", 101 | "tokenizer": XLNetTokenizer, 102 | "sequence_classification": XLNetForSequenceClassification, 103 | "padding_segement_value": 4, 104 | "padding_att_value": 0, 
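        # XLNet is pre-trained with left-padded inputs (hence 'left_pad' below) and, in its
        # original implementation, marks padding tokens with segment id 4 ('padding_segement_value').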
105 | "left_pad": True, 106 | "internal_model_name": ["transformer", "word_embedding"], 107 | 'insight_supported': True, 108 | }, 109 | 110 | "roberta-base": { 111 | "model_name": "roberta-base", 112 | "tokenizer": RobertaTokenizer, 113 | "sequence_classification": RobertaForSequenceClassification, 114 | "padding_segement_value": 0, 115 | "padding_att_value": 0, 116 | "internal_model_name": "roberta", 117 | 'insight_supported': True, 118 | }, 119 | "roberta-large": { 120 | "model_name": "roberta-large", 121 | "tokenizer": RobertaTokenizer, 122 | "sequence_classification": RobertaForSequenceClassification, 123 | "padding_segement_value": 0, 124 | "padding_att_value": 0, 125 | "internal_model_name": "roberta", 126 | 'insight_supported': True, 127 | }, 128 | 129 | "albert-xxlarge": { 130 | "model_name": "albert-xxlarge-v2", 131 | "tokenizer": AlbertTokenizer, 132 | "sequence_classification": AlbertForSequenceClassification, 133 | "padding_segement_value": 0, 134 | "padding_att_value": 0, 135 | "do_lower_case": True, 136 | "internal_model_name": "albert", 137 | 'insight_supported': True, 138 | }, 139 | 140 | "distilbert": { 141 | "model_name": "distilbert-base-cased", 142 | "tokenizer": DistilBertTokenizer, 143 | "sequence_classification": DistilBertForSequenceClassification, 144 | "padding_segement_value": 0, 145 | "padding_att_value": 0, 146 | }, 147 | 148 | "bart-large": { 149 | "model_name": "facebook/bart-large", 150 | "tokenizer": BartTokenizer, 151 | "sequence_classification": BartForSequenceClassification, 152 | "padding_segement_value": 0, 153 | "padding_att_value": 0, 154 | "internal_model_name": ["model", "encoder", "embed_tokens"], 155 | 'insight_supported': True, 156 | }, 157 | 158 | "electra-base": { 159 | "model_name": "google/electra-base-discriminator", 160 | "tokenizer": ElectraTokenizer, 161 | "sequence_classification": ElectraForSequenceClassification, 162 | "padding_segement_value": 0, 163 | "padding_att_value": 0, 164 | "internal_model_name": "electra", 165 | 'insight_supported': True, 166 | }, 167 | 168 | "electra-large": { 169 | "model_name": "google/electra-large-discriminator", 170 | "tokenizer": ElectraTokenizer, 171 | "sequence_classification": ElectraForSequenceClassification, 172 | "padding_segement_value": 0, 173 | "padding_att_value": 0, 174 | "internal_model_name": "electra", 175 | 'insight_supported': True, 176 | } 177 | } 178 | 179 | registered_path = { 180 | 'snli_train': config.PRO_ROOT / "data/build/snli/train.jsonl", 181 | 'snli_dev': config.PRO_ROOT / "data/build/snli/dev.jsonl", 182 | 'snli_test': config.PRO_ROOT / "data/build/snli/test.jsonl", 183 | 184 | 'mnli_train': config.PRO_ROOT / "data/build/mnli/train.jsonl", 185 | 'mnli_m_dev': config.PRO_ROOT / "data/build/mnli/m_dev.jsonl", 186 | 'mnli_mm_dev': config.PRO_ROOT / "data/build/mnli/mm_dev.jsonl", 187 | 188 | 'fever_train': config.PRO_ROOT / "data/build/fever_nli/train.jsonl", 189 | 'fever_dev': config.PRO_ROOT / "data/build/fever_nli/dev.jsonl", 190 | 'fever_test': config.PRO_ROOT / "data/build/fever_nli/test.jsonl", 191 | 192 | 'anli_r1_train': config.PRO_ROOT / "data/build/anli/r1/train.jsonl", 193 | 'anli_r1_dev': config.PRO_ROOT / "data/build/anli/r1/dev.jsonl", 194 | 'anli_r1_test': config.PRO_ROOT / "data/build/anli/r1/test.jsonl", 195 | 196 | 'anli_r2_train': config.PRO_ROOT / "data/build/anli/r2/train.jsonl", 197 | 'anli_r2_dev': config.PRO_ROOT / "data/build/anli/r2/dev.jsonl", 198 | 'anli_r2_test': config.PRO_ROOT / "data/build/anli/r2/test.jsonl", 199 | 200 | 'anli_r3_train': 
config.PRO_ROOT / "data/build/anli/r3/train.jsonl", 201 | 'anli_r3_dev': config.PRO_ROOT / "data/build/anli/r3/dev.jsonl", 202 | 'anli_r3_test': config.PRO_ROOT / "data/build/anli/r3/test.jsonl", 203 | } 204 | 205 | nli_label2index = { 206 | 'e': 0, 207 | 'n': 1, 208 | 'c': 2, 209 | 'h': -1, 210 | } 211 | 212 | 213 | def set_seed(seed): 214 | random.seed(seed) 215 | np.random.seed(seed) 216 | torch.manual_seed(seed) 217 | 218 | 219 | class NLIDataset(Dataset): 220 | def __init__(self, data_list, transform) -> None: 221 | super().__init__() 222 | self.d_list = data_list 223 | self.len = len(self.d_list) 224 | self.transform = transform 225 | 226 | def __getitem__(self, index: int): 227 | return self.transform(self.d_list[index]) 228 | 229 | # you should write schema for each of the input elements 230 | 231 | def __len__(self) -> int: 232 | return self.len 233 | 234 | 235 | class NLITransform(object): 236 | def __init__(self, model_name, tokenizer, max_length=None): 237 | self.model_name = model_name 238 | self.tokenizer = tokenizer 239 | self.max_length = max_length 240 | 241 | def __call__(self, sample): 242 | processed_sample = dict() 243 | processed_sample['uid'] = sample['uid'] 244 | processed_sample['gold_label'] = sample['label'] 245 | processed_sample['y'] = nli_label2index[sample['label']] 246 | 247 | # premise: str = sample['premise'] 248 | premise: str = sample['context'] if 'context' in sample else sample['premise'] 249 | hypothesis: str = sample['hypothesis'] 250 | 251 | if premise.strip() == '': 252 | premise = 'empty' 253 | 254 | if hypothesis.strip() == '': 255 | hypothesis = 'empty' 256 | 257 | tokenized_input_seq_pair = self.tokenizer.encode_plus(premise, hypothesis, 258 | max_length=self.max_length, 259 | return_token_type_ids=True, truncation=True) 260 | 261 | processed_sample.update(tokenized_input_seq_pair) 262 | 263 | return processed_sample 264 | 265 | 266 | def build_eval_dataset_loader_and_sampler(d_list, data_transformer, batching_schema, batch_size_per_gpu_eval): 267 | d_dataset = NLIDataset(d_list, data_transformer) 268 | d_sampler = SequentialSampler(d_dataset) 269 | d_dataloader = DataLoader(dataset=d_dataset, 270 | batch_size=batch_size_per_gpu_eval, 271 | shuffle=False, # 272 | num_workers=0, 273 | pin_memory=True, 274 | sampler=d_sampler, 275 | collate_fn=BaseBatchBuilder(batching_schema)) # 276 | return d_dataset, d_sampler, d_dataloader 277 | 278 | 279 | def sample_data_list(d_list, ratio): 280 | if ratio <= 0: 281 | raise ValueError("Invalid training weight ratio. 
Please change --train_weights.")
282 |     upper_int = int(math.ceil(ratio))
283 |     if upper_int == 1:
284 |         return d_list  # if the ratio is 1, we just return the data list
285 |     else:
286 |         sampled_d_list = []
287 |         for _ in range(upper_int):
288 |             sampled_d_list.extend(copy.deepcopy(d_list))
289 |         if np.isclose(ratio, upper_int):
290 |             return sampled_d_list
291 |         else:
292 |             sampled_length = int(ratio * len(d_list))
293 |             random.shuffle(sampled_d_list)
294 |             return sampled_d_list[:sampled_length]
295 | 
296 | 
297 | def main():
298 |     parser = argparse.ArgumentParser()
299 |     parser.add_argument("--cpu", action="store_true", help="If set, we only use CPU.")
300 |     parser.add_argument("--single_gpu", action="store_true", help="If set, we only use a single GPU.")
301 |     parser.add_argument("--fp16", action="store_true", help="If set, we will use fp16.")
302 | 
303 |     parser.add_argument(
304 |         "--fp16_opt_level",
305 |         type=str,
306 |         default="O1",
307 |         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
308 |              "See details at https://nvidia.github.io/apex/amp.html",
309 |     )
310 | 
311 |     # environment arguments
312 |     parser.add_argument('-s', '--seed', default=1, type=int, metavar='N',
313 |                         help='manual random seed')
314 |     parser.add_argument('-n', '--num_nodes', default=1, type=int, metavar='N',
315 |                         help='number of nodes')
316 |     parser.add_argument('-g', '--gpus_per_node', default=1, type=int,
317 |                         help='number of gpus per node')
318 |     parser.add_argument('-nr', '--node_rank', default=0, type=int,
319 |                         help='ranking within the nodes')
320 | 
321 |     # experiments specific arguments
322 |     parser.add_argument('--debug_mode',
323 |                         action='store_true',
324 |                         dest='debug_mode',
325 |                         help='whether this is debug mode or normal mode')
326 | 
327 |     parser.add_argument(
328 |         "--model_class_name",
329 |         type=str,
330 |         help="Set the model class of the experiment.",
331 |     )
332 | 
333 |     parser.add_argument(
334 |         "--experiment_name",
335 |         type=str,
336 |         help="Set the name of the experiment. [model_name]/[data]/[task]/[other]",
337 |     )
338 | 
339 |     parser.add_argument(
340 |         "--save_prediction",
341 |         action='store_true',
342 |         dest='save_prediction',
343 |         help='Whether we want to save predictions.')
344 | 
345 |     parser.add_argument('--epochs', default=2, type=int, metavar='N',
346 |                         help='number of total epochs to run')
347 |     parser.add_argument(
348 |         "--per_gpu_train_batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.",
349 |     )
350 |     parser.add_argument(
351 |         "--gradient_accumulation_steps",
352 |         type=int,
353 |         default=1,
354 |         help="Number of update steps to accumulate before performing a backward/update pass.",
355 |     )
356 |     parser.add_argument(
357 |         "--per_gpu_eval_batch_size", default=64, type=int, help="Batch size per GPU/CPU for evaluation.",
358 |     )
359 | 
360 |     parser.add_argument("--max_length", default=160, type=int, help="Max length of the sequences.")
361 | 
362 |     parser.add_argument("--warmup_steps", default=-1, type=int, help="Linear warmup over warmup_steps.")
363 |     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
364 |     parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
365 |     parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
366 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
367 | 
368 |     parser.add_argument(
369 |         "--eval_frequency", default=1000, type=int, help="Set the evaluation frequency: evaluate every X global steps.",
370 |     )
371 | 
372 |     parser.add_argument("--train_data",
373 |                         type=str,
374 |                         help="The training data used in the experiments.")
375 | 
376 |     parser.add_argument("--train_weights",
377 |                         type=str,
378 |                         help="The training data weights used in the experiments.")
379 | 
380 |     parser.add_argument("--eval_data",
381 |                         type=str,
382 |                         help="The evaluation data used in the experiments.")
383 | 
384 |     args = parser.parse_args()
385 | 
386 |     if args.cpu:
387 |         args.world_size = 1
388 |         train(-1, args)
389 |     elif args.single_gpu:
390 |         args.world_size = 1
391 |         train(0, args)
392 |     else:  # distributed multiGPU training
393 |         #########################################################
394 |         args.world_size = args.gpus_per_node * args.num_nodes  #
395 |         # os.environ['MASTER_ADDR'] = '152.2.142.184'  # This is the IP address for nlp5
396 |         # maybe we will automatically retrieve the IP later.
397 |         os.environ['MASTER_PORT'] = '88888'  # NOTE: valid TCP ports are <= 65535; set this to a free, valid port before launching.
398 |         mp.spawn(train, nprocs=args.gpus_per_node, args=(args,))  # spawn this many processes on this node
399 |         # remember train is called as train(i, args).
400 |         #########################################################
401 | 
402 | 
403 | def train(local_rank, args):
404 |     # debug = False
405 |     # print("GPU:", gpu)
406 |     # world_size = args.world_size
407 |     args.global_rank = args.node_rank * args.gpus_per_node + local_rank
408 |     args.local_rank = local_rank
409 |     # args.warmup_steps = 20
410 |     debug_count = 1000
411 |     num_epoch = args.epochs
412 | 
413 |     actual_train_batch_size = args.world_size * args.per_gpu_train_batch_size * args.gradient_accumulation_steps
414 |     args.actual_train_batch_size = actual_train_batch_size
415 | 
416 |     set_seed(args.seed)
417 |     num_labels = 3  # we are doing NLI, so we set num_labels = 3; for other tasks we can change this value.
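    # Label indices follow nli_label2index above: 0 = entailment ('e'), 1 = neutral ('n'),
    # 2 = contradiction ('c'); 'h' maps to -1 (e.g. hidden/unlabeled examples) and is not
    # one of the three trained classes.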
418 | 
419 |     max_length = args.max_length
420 | 
421 |     model_class_item = MODEL_CLASSES[args.model_class_name]
422 |     model_class_name = args.model_class_name
423 |     model_name = model_class_item['model_name']
424 |     do_lower_case = model_class_item['do_lower_case'] if 'do_lower_case' in model_class_item else False
425 | 
426 |     tokenizer = model_class_item['tokenizer'].from_pretrained(model_name,
427 |                                                               cache_dir=str(config.PRO_ROOT / "trans_cache"),
428 |                                                               do_lower_case=do_lower_case)
429 | 
430 |     if model_class_name in ['lstm-resencoder']:
431 |         hg_model = model_class_item['sequence_classification'].from_pretrained(model_name,
432 |                                                                                cache_dir=str(
433 |                                                                                    config.PRO_ROOT / "trans_cache"),
434 |                                                                                num_labels=num_labels)
435 |         embedding = hg_model.bert.embeddings.word_embeddings
436 |         model = ResEncoder(v_size=embedding.weight.size(0), embd_dim=embedding.weight.size(1))
437 |         model.Embd.weight = embedding.weight
438 | 
439 |     elif model_class_name in ['bag-of-words']:
440 |         hg_model = model_class_item['sequence_classification'].from_pretrained(model_name,
441 |                                                                                cache_dir=str(
442 |                                                                                    config.PRO_ROOT / "trans_cache"),
443 |                                                                                num_labels=num_labels)
444 |         embedding = hg_model.bert.embeddings.word_embeddings
445 |         model = BagOfWords(v_size=embedding.weight.size(0), embd_dim=embedding.weight.size(1))
446 |         model.Embd.weight = embedding.weight
447 | 
448 |     else:
449 |         model = model_class_item['sequence_classification'].from_pretrained(model_name,
450 |                                                                             cache_dir=str(config.PRO_ROOT / "trans_cache"),
451 |                                                                             num_labels=num_labels)
452 | 
453 |     padding_token_value = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
454 |     padding_segement_value = model_class_item["padding_segement_value"]
455 |     padding_att_value = model_class_item["padding_att_value"]
456 |     left_pad = model_class_item['left_pad'] if 'left_pad' in model_class_item else False
457 | 
458 |     batch_size_per_gpu_train = args.per_gpu_train_batch_size
459 |     batch_size_per_gpu_eval = args.per_gpu_eval_batch_size
460 | 
461 |     if not args.cpu and not args.single_gpu:
462 |         dist.init_process_group(
463 |             backend='nccl',
464 |             init_method='env://',
465 |             world_size=args.world_size,
466 |             rank=args.global_rank
467 |         )
468 | 
469 |     train_data_str = args.train_data
470 |     train_data_weights_str = args.train_weights
471 |     eval_data_str = args.eval_data
472 | 
473 |     train_data_name = []
474 |     train_data_path = []
475 |     train_data_list = []
476 |     train_data_weights = []
477 | 
478 |     eval_data_name = []
479 |     eval_data_path = []
480 |     eval_data_list = []
481 | 
482 |     train_data_named_path = train_data_str.split(',')
483 |     weights_str = train_data_weights_str.split(',') if train_data_weights_str is not None else None
484 | 
485 |     eval_data_named_path = eval_data_str.split(',')
486 | 
487 |     for named_path in train_data_named_path:
488 |         ind = named_path.find(':')
489 |         name = named_path[:ind]
490 |         path = named_path[ind + 1:]  # everything after ':' is the (optional) file path
491 |         if name in registered_path:
492 |             d_list = common.load_jsonl(registered_path[name])
493 |         else:
494 |             d_list = common.load_jsonl(path)
495 | 
496 |         train_data_name.append(name)
497 |         train_data_path.append(path)
498 | 
499 |         train_data_list.append(d_list)
500 | 
501 |     if weights_str is not None:
502 |         for weights in weights_str:
503 |             train_data_weights.append(float(weights))
504 |     else:
505 |         for i in range(len(train_data_list)):
506 |             train_data_weights.append(1)
507 | 
508 |     for named_path in eval_data_named_path:
509 |         ind = named_path.find(':')
510 |         name = named_path[:ind]
511 |         path = named_path[ind + 1:]  # everything after ':' is the (optional) file path
512 |         if name in registered_path:
513 |             d_list = common.load_jsonl(registered_path[name])
514 |         else:
515 |             d_list = common.load_jsonl(path)
516 |         eval_data_name.append(name)
517 |         eval_data_path.append(path)
518 | 
519 |         eval_data_list.append(d_list)
520 | 
521 |     assert len(train_data_weights) == len(train_data_list)
522 | 
523 |     batching_schema = {
524 |         'uid': RawFlintField(),
525 |         'y': LabelFlintField(),
526 |         'input_ids': ArrayIndexFlintField(pad_idx=padding_token_value, left_pad=left_pad),
527 |         'token_type_ids': ArrayIndexFlintField(pad_idx=padding_segement_value, left_pad=left_pad),
528 |         'attention_mask': ArrayIndexFlintField(pad_idx=padding_att_value, left_pad=left_pad),
529 |     }
530 | 
531 |     data_transformer = NLITransform(model_name, tokenizer, max_length)
532 |     # data_transformer = NLITransform(model_name, tokenizer, max_length, with_element=True)
533 | 
534 |     eval_data_loaders = []
535 |     for eval_d_list in eval_data_list:
536 |         d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(eval_d_list, data_transformer,
537 |                                                                                    batching_schema,
538 |                                                                                    batch_size_per_gpu_eval)
539 |         eval_data_loaders.append(d_dataloader)
540 | 
541 |     # Estimate the training size:
542 |     training_list = []
543 |     for i in range(len(train_data_list)):
544 |         print("Build Training Data ...")
545 |         train_d_list = train_data_list[i]
546 |         train_d_name = train_data_name[i]
547 |         train_d_weight = train_data_weights[i]
548 |         cur_train_list = sample_data_list(train_d_list, train_d_weight)  # we can apply a different sampling strategy here later.
549 |         print(f"Data Name:{train_d_name}; Weight: {train_d_weight}; "
550 |               f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}")
551 |         training_list.extend(cur_train_list)
552 |     estimated_training_size = len(training_list)
553 |     print("Estimated training size:", estimated_training_size)
554 |     # Estimate the training size ends:
555 | 
556 |     # t_total = estimated_training_size // args.gradient_accumulation_steps * num_epoch
557 |     t_total = estimated_training_size * num_epoch // args.actual_train_batch_size
558 |     if args.warmup_steps <= 0:  # set the warmup steps to 0.1 * total steps if the given warmup step is non-positive.
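        # default: linearly warm up over the first 10% of the total optimizer steps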
559 |         args.warmup_steps = int(t_total * 0.1)
560 | 
561 |     if not args.cpu:
562 |         torch.cuda.set_device(args.local_rank)
563 |         model.cuda(args.local_rank)
564 | 
565 |     no_decay = ["bias", "LayerNorm.weight"]
566 |     optimizer_grouped_parameters = [
567 |         {
568 |             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
569 |             "weight_decay": args.weight_decay,
570 |         },
571 |         {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
572 |     ]
573 | 
574 |     if model_class_name not in ['lstm-resencoder']:
575 |         optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
576 |         scheduler = get_linear_schedule_with_warmup(
577 |             optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
578 |         )
579 |     else:
580 |         optimizer = Adam(optimizer_grouped_parameters)
581 |         scheduler = EmptyScheduler()
582 | 
583 |     if args.fp16:
584 |         try:
585 |             from apex import amp
586 |         except ImportError:
587 |             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
588 |         model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
589 | 
590 |     if not args.cpu and not args.single_gpu:
591 |         model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
592 |                                                     output_device=local_rank, find_unused_parameters=True)
593 | 
594 |     args_dict = dict(vars(args))
595 |     file_path_prefix = '.'
596 |     if args.global_rank in [-1, 0]:
597 |         print("Total Steps:", t_total)
598 |         args.total_step = t_total
599 |         print("Warmup Steps:", args.warmup_steps)
600 |         print("Actual Training Batch Size:", actual_train_batch_size)
601 |         print("Arguments:", pp.pformat(args_dict))  # use pformat() so the rendered arguments appear inside the print
602 | 
603 |     # Let's build the logger and log everything before the start of the first training epoch.
604 |     if args.global_rank in [-1, 0]:  # only do logging if we use cpu or global_rank=0
605 |         if not args.debug_mode:
606 |             file_path_prefix, date = save_tool.gen_file_prefix(f"{args.experiment_name}")
607 |             # # # Create Log File
608 |             # Save the source code.
609 |             script_name = os.path.basename(__file__)
610 |             with open(os.path.join(file_path_prefix, script_name), 'w') as out_f, open(__file__, 'r') as it:
611 |                 out_f.write(it.read())
612 |                 out_f.flush()
613 | 
614 |             # Save option file
615 |             common.save_json(args_dict, os.path.join(file_path_prefix, "args.json"))
616 |             checkpoints_path = Path(file_path_prefix) / "checkpoints"
617 |             if not checkpoints_path.exists():
618 |                 checkpoints_path.mkdir()
619 |             prediction_path = Path(file_path_prefix) / "predictions"
620 |             if not prediction_path.exists():
621 |                 prediction_path.mkdir()
622 | 
623 |     global_step = 0
624 | 
625 |     # print(f"Global Rank:{args.global_rank} ### ", 'Init!')
626 | 
627 |     for epoch in tqdm(range(num_epoch), desc="Epoch", disable=args.global_rank not in [-1, 0]):
628 |         # Let's build up the training dataset for this epoch
629 |         training_list = []
630 |         for i in range(len(train_data_list)):
631 |             print("Build Training Data ...")
632 |             train_d_list = train_data_list[i]
633 |             train_d_name = train_data_name[i]
634 |             train_d_weight = train_data_weights[i]
635 |             cur_train_list = sample_data_list(train_d_list, train_d_weight)  # we can apply a different sampling strategy here later.
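            # sample_data_list upsamples each source by its weight: e.g. a weight of 2.5
            # makes ceil(2.5) = 3 copies, shuffles them, and truncates to int(2.5 * len)
            # items, so some datasets can be over-represented in each epoch.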
636 |             print(f"Data Name:{train_d_name}; Weight: {train_d_weight}; "
637 |                   f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}")
638 |             training_list.extend(cur_train_list)
639 | 
640 |         random.shuffle(training_list)
641 |         train_dataset = NLIDataset(training_list, data_transformer)
642 | 
643 |         train_sampler = SequentialSampler(train_dataset)
644 |         if not args.cpu and not args.single_gpu:
645 |             print("Use distributed sampler.")
646 |             train_sampler = DistributedSampler(train_dataset, args.world_size, args.global_rank,
647 |                                                shuffle=True)
648 | 
649 |         train_dataloader = DataLoader(dataset=train_dataset,
650 |                                       batch_size=batch_size_per_gpu_train,
651 |                                       shuffle=False,  #
652 |                                       num_workers=0,
653 |                                       pin_memory=True,
654 |                                       sampler=train_sampler,
655 |                                       collate_fn=BaseBatchBuilder(batching_schema))  #
656 |         # training build finished.
657 | 
658 |         print(debug_node_info(args), "epoch: ", epoch)
659 | 
660 |         if not args.cpu and not args.single_gpu:
661 |             train_sampler.set_epoch(epoch)  # set the epoch to ensure random sampling at each epoch
662 | 
663 |         for forward_step, batch in enumerate(tqdm(train_dataloader, desc="Iteration",
664 |                                                   disable=args.global_rank not in [-1, 0]), 0):
665 |             model.train()
666 | 
667 |             batch = move_to_device(batch, local_rank)
668 |             # print(batch['input_ids'], batch['y'])
669 |             if args.model_class_name in ["distilbert", "bart-large", "lstm-resencoder", "bag-of-words"]:
670 |                 outputs = model(batch['input_ids'],
671 |                                 attention_mask=batch['attention_mask'],
672 |                                 labels=batch['y'])
673 |             else:
674 |                 outputs = model(batch['input_ids'],
675 |                                 attention_mask=batch['attention_mask'],
676 |                                 token_type_ids=batch['token_type_ids'],
677 |                                 labels=batch['y'])
678 |             loss, logits = outputs[:2]
679 |             # print(debug_node_info(args), loss, logits, batch['uid'])
680 |             # print(debug_node_info(args), loss, batch['uid'])
681 | 
682 |             # Accumulated loss
683 |             if args.gradient_accumulation_steps > 1:
684 |                 loss = loss / args.gradient_accumulation_steps
685 | 
686 |             # if this forward step needs model updates
687 |             # handle fp16
688 |             if args.fp16:
689 |                 with amp.scale_loss(loss, optimizer) as scaled_loss:
690 |                     scaled_loss.backward()
691 |             else:
692 |                 loss.backward()
693 | 
694 |             # Gradient clipping: only applied when max_grad_norm > 0
695 |             if (forward_step + 1) % args.gradient_accumulation_steps == 0:
696 |                 if args.max_grad_norm > 0:
697 |                     if args.fp16:
698 |                         torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
699 |                     else:
700 |                         torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
701 | 
702 |                 optimizer.step()
703 |                 scheduler.step()  # Update learning rate schedule
704 |                 model.zero_grad()
705 | 
706 |                 global_step += 1
707 | 
708 |                 if args.global_rank in [-1, 0] and args.eval_frequency > 0 and global_step % args.eval_frequency == 0:
709 |                     r_dict = dict()
710 |                     # Eval loop:
711 |                     for i in range(len(eval_data_name)):
712 |                         cur_eval_data_name = eval_data_name[i]
713 |                         cur_eval_data_list = eval_data_list[i]
714 |                         cur_eval_dataloader = eval_data_loaders[i]
715 |                         # cur_eval_raw_data_list = eval_raw_data_list[i]
716 | 
717 |                         evaluation_dataset(args, cur_eval_dataloader, cur_eval_data_list, model, r_dict,
718 |                                            eval_name=cur_eval_data_name)
719 | 
720 |                     # saving checkpoints
721 |                     current_checkpoint_filename = \
722 |                         f'e({epoch})|i({global_step})'
723 | 
724 |                     for i in range(len(eval_data_name)):
725 |                         cur_eval_data_name = eval_data_name[i]
726 |                         current_checkpoint_filename += \
727 |                             f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'
728 | 
729 |                     if not args.debug_mode:
730 |                         # save model:
731 |                         model_output_dir = checkpoints_path / current_checkpoint_filename
732 |                         if not model_output_dir.exists():
733 |                             model_output_dir.mkdir()
734 |                         model_to_save = (
735 |                             model.module if hasattr(model, "module") else model
736 |                         )  # Take care of distributed/parallel training
737 | 
738 |                         torch.save(model_to_save.state_dict(), str(model_output_dir / "model.pt"))
739 |                         torch.save(optimizer.state_dict(), str(model_output_dir / "optimizer.pt"))
740 |                         torch.save(scheduler.state_dict(), str(model_output_dir / "scheduler.pt"))
741 | 
742 |                     # save prediction:
743 |                     if not args.debug_mode and args.save_prediction:
744 |                         cur_results_path = prediction_path / current_checkpoint_filename
745 |                         if not cur_results_path.exists():
746 |                             cur_results_path.mkdir(parents=True)
747 |                         for key, item in r_dict.items():
748 |                             common.save_jsonl(item['predictions'], cur_results_path / f"{key}.jsonl")
749 | 
750 |                         # avoid saving too many things
751 |                         for key, item in r_dict.items():
752 |                             del r_dict[key]['predictions']
753 |                         common.save_json(r_dict, cur_results_path / "results_dict.json", indent=2)
754 | 
755 |         # End of epoch evaluation.
756 |         if args.global_rank in [-1, 0]:
757 |             r_dict = dict()
758 |             # Eval loop:
759 |             for i in range(len(eval_data_name)):
760 |                 cur_eval_data_name = eval_data_name[i]
761 |                 cur_eval_data_list = eval_data_list[i]
762 |                 cur_eval_dataloader = eval_data_loaders[i]
763 |                 # cur_eval_raw_data_list = eval_raw_data_list[i]
764 | 
765 |                 evaluation_dataset(args, cur_eval_dataloader, cur_eval_data_list, model, r_dict,
766 |                                    eval_name=cur_eval_data_name)
767 | 
768 |             # saving checkpoints
769 |             current_checkpoint_filename = \
770 |                 f'e({epoch})|i({global_step})'
771 | 
772 |             for i in range(len(eval_data_name)):
773 |                 cur_eval_data_name = eval_data_name[i]
774 |                 current_checkpoint_filename += \
775 |                     f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'
776 | 
777 |             if not args.debug_mode:
778 |                 # save model:
779 |                 model_output_dir = checkpoints_path / current_checkpoint_filename
780 |                 if not model_output_dir.exists():
781 |                     model_output_dir.mkdir()
782 |                 model_to_save = (
783 |                     model.module if hasattr(model, "module") else model
784 |                 )  # Take care of distributed/parallel training
785 | 
786 |                 torch.save(model_to_save.state_dict(), str(model_output_dir / "model.pt"))
787 |                 torch.save(optimizer.state_dict(), str(model_output_dir / "optimizer.pt"))
788 |                 torch.save(scheduler.state_dict(), str(model_output_dir / "scheduler.pt"))
789 | 
790 |             # save prediction:
791 |             if not args.debug_mode and args.save_prediction:
792 |                 cur_results_path = prediction_path / current_checkpoint_filename
793 |                 if not cur_results_path.exists():
794 |                     cur_results_path.mkdir(parents=True)
795 |                 for key, item in r_dict.items():
796 |                     common.save_jsonl(item['predictions'], cur_results_path / f"{key}.jsonl")
797 | 
798 |                 # avoid saving too many things
799 |                 for key, item in r_dict.items():
800 |                     del r_dict[key]['predictions']
801 |                 common.save_json(r_dict, cur_results_path / "results_dict.json", indent=2)
802 | 
803 | 
804 | id2label = {
805 |     0: 'e',
806 |     1: 'n',
807 |     2: 'c',
808 |     -1: '-',
809 | }
810 | 
811 | 
812 | def count_acc(gt_list, pred_list):
813 |     assert len(gt_list) == len(pred_list)
814 |     gt_dict = list_dict_data_tool.list_to_dict(gt_list, 'uid')
815 |     pred_dict = list_dict_data_tool.list_to_dict(pred_list, 'uid')
816 |     total_count = 0
817 |     hit = 0
818 |     for key, value in pred_dict.items():
819 |         if gt_dict[key]['label'] == value['predicted_label']:
820 |             hit += 1
821 |         total_count += 1
822 |     return hit, total_count
823 | 
824 | 
825 | def evaluation_dataset(args, eval_dataloader, eval_list, model, r_dict, eval_name):
826 |     # r_dict = dict()
827 |     pred_output_list = eval_model(model, eval_dataloader, args.global_rank, args)
828 |     predictions = pred_output_list
829 |     hit, total = count_acc(eval_list, pred_output_list)
830 | 
831 |     print(debug_node_info(args), f"{eval_name} Acc:", hit, total, hit / total)
832 | 
833 |     r_dict[f'{eval_name}'] = {
834 |         'acc': hit / total,
835 |         'correct_count': hit,
836 |         'total_count': total,
837 |         'predictions': predictions,
838 |     }
839 | 
840 | 
841 | def eval_model(model, dev_dataloader, device_num, args):
842 |     model.eval()
843 | 
844 |     uid_list = []
845 |     y_list = []
846 |     pred_list = []
847 |     logits_list = []
848 | 
849 |     with torch.no_grad():
850 |         for i, batch in enumerate(dev_dataloader, 0):
851 |             batch = move_to_device(batch, device_num)
852 | 
853 |             if args.model_class_name in ["distilbert", "bart-large", 'lstm-resencoder', "bag-of-words"]:
854 |                 outputs = model(batch['input_ids'],
855 |                                 attention_mask=batch['attention_mask'],
856 |                                 labels=batch['y'])
857 |             else:
858 |                 outputs = model(batch['input_ids'],
859 |                                 attention_mask=batch['attention_mask'],
860 |                                 token_type_ids=batch['token_type_ids'],
861 |                                 labels=batch['y'])
862 | 
863 |             loss, logits = outputs[:2]
864 | 
865 |             uid_list.extend(list(batch['uid']))
866 |             y_list.extend(batch['y'].tolist())
867 |             pred_list.extend(torch.max(logits, 1)[1].view(logits.size(0)).tolist())
868 |             logits_list.extend(logits.tolist())
869 | 
870 |     assert len(pred_list) == len(logits_list)
871 |     assert len(pred_list) == len(uid_list)
872 | 
873 |     result_items_list = []
874 |     for i in range(len(uid_list)):
875 |         r_item = dict()
876 |         r_item['uid'] = uid_list[i]
877 |         r_item['logits'] = logits_list[i]
878 |         r_item['predicted_label'] = id2label[pred_list[i]]
879 | 
880 |         result_items_list.append(r_item)
881 | 
882 |     return result_items_list
883 | 
884 | 
885 | def debug_node_info(args):
886 |     names = ['global_rank', 'local_rank', 'node_rank']
887 |     values = []
888 | 
889 |     for name in names:
890 |         if name in args:
891 |             values.append(getattr(args, name))
892 |         else:
893 |             return "Pro:No node info "
894 | 
895 |     return "Pro:" + '|'.join([f"{name}:{value}" for name, value in zip(names, values)]) + "||Print:"
896 | 
897 | 
898 | if __name__ == '__main__':
899 |     main()
900 | 
--------------------------------------------------------------------------------
/src/nli/training.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under Creative Commons-Non Commercial 4.0 found in the
4 | # LICENSE file in the root directory of this source tree.
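# training.py is the main fine-tuning entry point. MODEL_CLASSES below maps a short
# model key (e.g. 'roberta-large') to its Hugging Face tokenizer, sequence
# classification class, and padding conventions; registered_path maps dataset
# nicknames to pre-built jsonl files under data/build/.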
5 | 6 | import argparse 7 | from pathlib import Path 8 | 9 | from transformers import RobertaTokenizer, RobertaForSequenceClassification 10 | from transformers import XLNetTokenizer, XLNetForSequenceClassification 11 | # from transformers import XLNetTokenizer 12 | # from modeling.dummy_modeling_xlnet import XLNetForSequenceClassification 13 | from transformers import BertTokenizer, BertForSequenceClassification 14 | from transformers import AlbertTokenizer, AlbertForSequenceClassification 15 | from transformers import DistilBertTokenizer, DistilBertForSequenceClassification 16 | from transformers import BartTokenizer, BartForSequenceClassification 17 | from transformers import ElectraTokenizer, ElectraForSequenceClassification 18 | 19 | from torch.utils.data import Dataset, DataLoader, DistributedSampler, RandomSampler, SequentialSampler 20 | import config 21 | from transformers import AdamW 22 | from transformers import get_linear_schedule_with_warmup 23 | from flint.data_utils.batchbuilder import BaseBatchBuilder, move_to_device 24 | from flint.data_utils.fields import RawFlintField, LabelFlintField, ArrayIndexFlintField 25 | from utils import common, list_dict_data_tool, save_tool 26 | import os 27 | import torch.multiprocessing as mp 28 | import torch.distributed as dist 29 | import torch.nn as nn 30 | 31 | import numpy as np 32 | import random 33 | import torch 34 | from tqdm import tqdm 35 | import math 36 | import copy 37 | 38 | import pprint 39 | 40 | pp = pprint.PrettyPrinter(indent=2) 41 | 42 | # from fairseq.data.data_utils import collate_tokens 43 | 44 | MODEL_CLASSES = { 45 | "bert-base": { 46 | "model_name": "bert-base-uncased", 47 | "tokenizer": BertTokenizer, 48 | "sequence_classification": BertForSequenceClassification, 49 | # "padding_token_value": 0, 50 | "padding_segement_value": 0, 51 | "padding_att_value": 0, 52 | "do_lower_case": True, 53 | }, 54 | "bert-large": { 55 | "model_name": "bert-large-uncased", 56 | "tokenizer": BertTokenizer, 57 | "sequence_classification": BertForSequenceClassification, 58 | # "padding_token_value": 0, 59 | "padding_segement_value": 0, 60 | "padding_att_value": 0, 61 | "do_lower_case": True, 62 | "internal_model_name": "bert", 63 | 'insight_supported': True, 64 | }, 65 | 66 | "xlnet-base": { 67 | "model_name": "xlnet-base-cased", 68 | "tokenizer": XLNetTokenizer, 69 | "sequence_classification": XLNetForSequenceClassification, 70 | # "padding_token_value": 0, 71 | "padding_segement_value": 4, 72 | "padding_att_value": 0, 73 | "left_pad": True, 74 | "internal_model_name": ["transformer", "word_embedding"], 75 | }, 76 | "xlnet-large": { 77 | "model_name": "xlnet-large-cased", 78 | "tokenizer": XLNetTokenizer, 79 | "sequence_classification": XLNetForSequenceClassification, 80 | "padding_segement_value": 4, 81 | "padding_att_value": 0, 82 | "left_pad": True, 83 | "internal_model_name": ["transformer", "word_embedding"], 84 | 'insight_supported': True, 85 | }, 86 | 87 | "roberta-base": { 88 | "model_name": "roberta-base", 89 | "tokenizer": RobertaTokenizer, 90 | "sequence_classification": RobertaForSequenceClassification, 91 | "padding_segement_value": 0, 92 | "padding_att_value": 0, 93 | "internal_model_name": "roberta", 94 | 'insight_supported': True, 95 | }, 96 | "roberta-large": { 97 | "model_name": "roberta-large", 98 | "tokenizer": RobertaTokenizer, 99 | "sequence_classification": RobertaForSequenceClassification, 100 | "padding_segement_value": 0, 101 | "padding_att_value": 0, 102 | "internal_model_name": "roberta", 103 | 
'insight_supported': True, 104 | }, 105 | 106 | "albert-xxlarge": { 107 | "model_name": "albert-xxlarge-v2", 108 | "tokenizer": AlbertTokenizer, 109 | "sequence_classification": AlbertForSequenceClassification, 110 | "padding_segement_value": 0, 111 | "padding_att_value": 0, 112 | "do_lower_case": True, 113 | "internal_model_name": "albert", 114 | 'insight_supported': True, 115 | }, 116 | 117 | "distilbert": { 118 | "model_name": "distilbert-base-cased", 119 | "tokenizer": DistilBertTokenizer, 120 | "sequence_classification": DistilBertForSequenceClassification, 121 | "padding_segement_value": 0, 122 | "padding_att_value": 0, 123 | }, 124 | 125 | "bart-large": { 126 | "model_name": "facebook/bart-large", 127 | "tokenizer": BartTokenizer, 128 | "sequence_classification": BartForSequenceClassification, 129 | "padding_segement_value": 0, 130 | "padding_att_value": 0, 131 | "internal_model_name": ["model", "encoder", "embed_tokens"], 132 | 'insight_supported': True, 133 | }, 134 | 135 | "electra-base": { 136 | "model_name": "google/electra-base-discriminator", 137 | "tokenizer": ElectraTokenizer, 138 | "sequence_classification": ElectraForSequenceClassification, 139 | "padding_segement_value": 0, 140 | "padding_att_value": 0, 141 | "internal_model_name": "electra", 142 | 'insight_supported': True, 143 | }, 144 | 145 | "electra-large": { 146 | "model_name": "google/electra-large-discriminator", 147 | "tokenizer": ElectraTokenizer, 148 | "sequence_classification": ElectraForSequenceClassification, 149 | "padding_segement_value": 0, 150 | "padding_att_value": 0, 151 | "internal_model_name": "electra", 152 | 'insight_supported': True, 153 | } 154 | } 155 | 156 | registered_path = { 157 | 'snli_train': config.PRO_ROOT / "data/build/snli/train.jsonl", 158 | 'snli_dev': config.PRO_ROOT / "data/build/snli/dev.jsonl", 159 | 'snli_test': config.PRO_ROOT / "data/build/snli/test.jsonl", 160 | 161 | 'mnli_train': config.PRO_ROOT / "data/build/mnli/train.jsonl", 162 | 'mnli_m_dev': config.PRO_ROOT / "data/build/mnli/m_dev.jsonl", 163 | 'mnli_mm_dev': config.PRO_ROOT / "data/build/mnli/mm_dev.jsonl", 164 | 165 | 'fever_train': config.PRO_ROOT / "data/build/fever_nli/train.jsonl", 166 | 'fever_dev': config.PRO_ROOT / "data/build/fever_nli/dev.jsonl", 167 | 'fever_test': config.PRO_ROOT / "data/build/fever_nli/test.jsonl", 168 | 169 | 'anli_r1_train': config.PRO_ROOT / "data/build/anli/r1/train.jsonl", 170 | 'anli_r1_dev': config.PRO_ROOT / "data/build/anli/r1/dev.jsonl", 171 | 'anli_r1_test': config.PRO_ROOT / "data/build/anli/r1/test.jsonl", 172 | 173 | 'anli_r2_train': config.PRO_ROOT / "data/build/anli/r2/train.jsonl", 174 | 'anli_r2_dev': config.PRO_ROOT / "data/build/anli/r2/dev.jsonl", 175 | 'anli_r2_test': config.PRO_ROOT / "data/build/anli/r2/test.jsonl", 176 | 177 | 'anli_r3_train': config.PRO_ROOT / "data/build/anli/r3/train.jsonl", 178 | 'anli_r3_dev': config.PRO_ROOT / "data/build/anli/r3/dev.jsonl", 179 | 'anli_r3_test': config.PRO_ROOT / "data/build/anli/r3/test.jsonl", 180 | } 181 | 182 | nli_label2index = { 183 | 'e': 0, 184 | 'n': 1, 185 | 'c': 2, 186 | 'h': -1, 187 | } 188 | 189 | 190 | def set_seed(seed): 191 | random.seed(seed) 192 | np.random.seed(seed) 193 | torch.manual_seed(seed) 194 | 195 | 196 | class NLIDataset(Dataset): 197 | def __init__(self, data_list, transform) -> None: 198 | super().__init__() 199 | self.d_list = data_list 200 | self.len = len(self.d_list) 201 | self.transform = transform 202 | 203 | def __getitem__(self, index: int): 204 | return 
self.transform(self.d_list[index]) 205 | 206 | # A batching schema must be defined for each field of the processed samples (see batching_schema in train()). 207 | 208 | def __len__(self) -> int: 209 | return self.len 210 | 211 | 212 | class NLITransform(object): 213 | def __init__(self, model_name, tokenizer, max_length=None): 214 | self.model_name = model_name 215 | self.tokenizer = tokenizer 216 | self.max_length = max_length 217 | 218 | def __call__(self, sample): 219 | processed_sample = dict() 220 | processed_sample['uid'] = sample['uid'] 221 | processed_sample['gold_label'] = sample['label'] 222 | processed_sample['y'] = nli_label2index[sample['label']] 223 | 224 | # premise: str = sample['premise'] 225 | premise: str = sample['context'] if 'context' in sample else sample['premise'] 226 | hypothesis: str = sample['hypothesis'] 227 | 228 | if premise.strip() == '': 229 | premise = 'empty' 230 | 231 | if hypothesis.strip() == '': 232 | hypothesis = 'empty' 233 | 234 | tokenized_input_seq_pair = self.tokenizer.encode_plus(premise, hypothesis, 235 | max_length=self.max_length, 236 | return_token_type_ids=True, truncation=True) 237 | 238 | processed_sample.update(tokenized_input_seq_pair) 239 | 240 | return processed_sample 241 | 242 | 243 | def build_eval_dataset_loader_and_sampler(d_list, data_transformer, batching_schema, batch_size_per_gpu_eval): 244 | d_dataset = NLIDataset(d_list, data_transformer) 245 | d_sampler = SequentialSampler(d_dataset) 246 | d_dataloader = DataLoader(dataset=d_dataset, 247 | batch_size=batch_size_per_gpu_eval, 248 | shuffle=False, # 249 | num_workers=0, 250 | pin_memory=True, 251 | sampler=d_sampler, 252 | collate_fn=BaseBatchBuilder(batching_schema)) # 253 | return d_dataset, d_sampler, d_dataloader 254 | 255 | 256 | def sample_data_list(d_list, ratio): 257 | if ratio <= 0: 258 | raise ValueError("Invalid training weight ratio. Please change --train_weights.") 259 | upper_int = int(math.ceil(ratio)) 260 | if upper_int == 1: 261 | return d_list # if ratio is 1 then we just return the data list 262 | else: 263 | sampled_d_list = [] 264 | for _ in range(upper_int): 265 | sampled_d_list.extend(copy.deepcopy(d_list)) 266 | if np.isclose(ratio, upper_int): 267 | return sampled_d_list 268 | else: 269 | sampled_length = int(ratio * len(d_list)) 270 | random.shuffle(sampled_d_list) 271 | return sampled_d_list[:sampled_length] 272 | 273 | 274 | def main(): 275 | parser = argparse.ArgumentParser() 276 | parser.add_argument("--cpu", action="store_true", help="If set, we only use CPU.") 277 | parser.add_argument("--single_gpu", action="store_true", help="If set, we only use single GPU.") 278 | parser.add_argument("--fp16", action="store_true", help="If set, we will use fp16.") 279 | 280 | parser.add_argument( 281 | "--fp16_opt_level", 282 | type=str, 283 | default="O1", 284 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
285 | "See details at https://nvidia.github.io/apex/amp.html", 286 | ) 287 | 288 | # environment arguments 289 | parser.add_argument('-s', '--seed', default=1, type=int, metavar='N', 290 | help='manual random seed') 291 | parser.add_argument('-n', '--num_nodes', default=1, type=int, metavar='N', 292 | help='number of nodes') 293 | parser.add_argument('-g', '--gpus_per_node', default=1, type=int, 294 | help='number of gpus per node') 295 | parser.add_argument('-nr', '--node_rank', default=0, type=int, 296 | help='ranking within the nodes') 297 | 298 | # experiments specific arguments 299 | parser.add_argument('--debug_mode', 300 | action='store_true', 301 | dest='debug_mode', 302 | help='weather this is debug mode or normal') 303 | 304 | parser.add_argument( 305 | "--model_class_name", 306 | type=str, 307 | help="Set the model class of the experiment.", 308 | ) 309 | 310 | parser.add_argument( 311 | "--experiment_name", 312 | type=str, 313 | help="Set the name of the experiment. [model_name]/[data]/[task]/[other]", 314 | ) 315 | 316 | parser.add_argument( 317 | "--save_prediction", 318 | action='store_true', 319 | dest='save_prediction', 320 | help='Do we want to save prediction') 321 | 322 | parser.add_argument( 323 | "--resume_path", 324 | type=str, 325 | default=None, 326 | help="If we want to resume model training, we need to set the resume path to restore state dicts.", 327 | ) 328 | parser.add_argument( 329 | "--global_iteration", 330 | type=int, 331 | default=0, 332 | help="This argument is only used if we resume model training.", 333 | ) 334 | 335 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 336 | help='number of total epochs to run') 337 | parser.add_argument('--total_step', default=-1, type=int, metavar='N', 338 | help='number of step to update, default calculate with total data size.' 
339 | 'If set, the epoch count is raised to a large value (10000) so training runs until this step count is reached.') 340 | 341 | parser.add_argument('--sampler_seed', default=-1, type=int, metavar='N', 342 | help='The seed that controls the data sampling order.') 343 | 344 | parser.add_argument( 345 | "--per_gpu_train_batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.", 346 | ) 347 | parser.add_argument( 348 | "--gradient_accumulation_steps", 349 | type=int, 350 | default=1, 351 | help="Number of updates steps to accumulate before performing a backward/update pass.", 352 | ) 353 | parser.add_argument( 354 | "--per_gpu_eval_batch_size", default=64, type=int, help="Batch size per GPU/CPU for evaluation.", 355 | ) 356 | 357 | parser.add_argument("--max_length", default=160, type=int, help="Max length of the sequences.") 358 | 359 | parser.add_argument("--warmup_steps", default=-1, type=int, help="Linear warmup over warmup_steps.") 360 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 361 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 362 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 363 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 364 | 365 | parser.add_argument( 366 | "--eval_frequency", default=1000, type=int, help="set the evaluation frequency, evaluate every X global steps.", 367 | ) 368 | 369 | parser.add_argument("--train_data", 370 | type=str, 371 | help="The training data used in the experiments.") 372 | 373 | parser.add_argument("--train_weights", 374 | type=str, 375 | help="The training data weights used in the experiments.") 376 | 377 | parser.add_argument("--eval_data", 378 | type=str, 379 | help="The evaluation data used in the experiments.") 380 | 381 | args = parser.parse_args() 382 | 383 | if args.cpu: 384 | args.world_size = 1 385 | train(-1, args) 386 | elif args.single_gpu: 387 | args.world_size = 1 388 | train(0, args) 389 | else: # distributed multi-GPU training 390 | ######################################################### 391 | args.world_size = args.gpus_per_node * args.num_nodes # 392 | # os.environ['MASTER_ADDR'] = '152.2.142.184' # This is the IP address for nlp5 393 | # maybe we will automatically retrieve the IP later. 394 | os.environ['MASTER_PORT'] = '18888' # must be a valid TCP port (< 65536) 395 | mp.spawn(train, nprocs=args.gpus_per_node, args=(args,)) # spawn one training process per GPU on this node 396 | # remember train is called as train(i, args). 397 | ######################################################### 398 | 399 | 400 | def train(local_rank, args): 401 | # debug = False 402 | # print("GPU:", gpu) 403 | # world_size = args.world_size 404 | args.global_rank = args.node_rank * args.gpus_per_node + local_rank 405 | args.local_rank = local_rank 406 | # args.warmup_steps = 20 407 | debug_count = 1000 408 | 409 | if args.total_step > 0: 410 | num_epoch = 10000 # if total_step is set, num_epoch is effectively unbounded; training stops at total_step. 411 | else: 412 | num_epoch = args.epochs 413 | 414 | actual_train_batch_size = args.world_size * args.per_gpu_train_batch_size * args.gradient_accumulation_steps 415 | args.actual_train_batch_size = actual_train_batch_size 416 | 417 | set_seed(args.seed) 418 | num_labels = 3 # we are doing NLI so we set num_labels = 3; for other tasks this value can be changed.
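# Illustrative sketch of the data-spec format consumed below (the name 'my_extra'
# and its path are hypothetical): --train_data and --eval_data take comma-separated
# "name:path" pairs. A name found in registered_path loads the registered file and
# the path part is ignored (by convention given as 'none'); any other name is
# loaded from the given path:
#
#   --train_data anli_r1_train:none,my_extra:/path/to/extra.jsonl --train_weights 1,2.5
#
# keeps anli_r1_train at its original size and up-samples my_extra to roughly 2.5x
# its size each epoch via sample_data_list(). Worked example of the batch arithmetic
# above: world_size=2, per_gpu_train_batch_size=16, gradient_accumulation_steps=4
# gives an actual train batch size of 2 * 16 * 4 = 128.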
419 | 420 | max_length = args.max_length 421 | 422 | model_class_item = MODEL_CLASSES[args.model_class_name] 423 | model_name = model_class_item['model_name'] 424 | do_lower_case = model_class_item['do_lower_case'] if 'do_lower_case' in model_class_item else False 425 | 426 | tokenizer = model_class_item['tokenizer'].from_pretrained(model_name, 427 | cache_dir=str(config.PRO_ROOT / "trans_cache"), 428 | do_lower_case=do_lower_case) 429 | 430 | model = model_class_item['sequence_classification'].from_pretrained(model_name, 431 | cache_dir=str(config.PRO_ROOT / "trans_cache"), 432 | num_labels=num_labels) 433 | 434 | padding_token_value = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0] 435 | padding_segement_value = model_class_item["padding_segement_value"] 436 | padding_att_value = model_class_item["padding_att_value"] 437 | left_pad = model_class_item['left_pad'] if 'left_pad' in model_class_item else False 438 | 439 | batch_size_per_gpu_train = args.per_gpu_train_batch_size 440 | batch_size_per_gpu_eval = args.per_gpu_eval_batch_size 441 | 442 | if not args.cpu and not args.single_gpu: 443 | dist.init_process_group( 444 | backend='nccl', 445 | init_method='env://', 446 | world_size=args.world_size, 447 | rank=args.global_rank 448 | ) 449 | 450 | train_data_str = args.train_data 451 | train_data_weights_str = args.train_weights 452 | eval_data_str = args.eval_data 453 | 454 | train_data_name = [] 455 | train_data_path = [] 456 | train_data_list = [] 457 | train_data_weights = [] 458 | 459 | eval_data_name = [] 460 | eval_data_path = [] 461 | eval_data_list = [] 462 | 463 | train_data_named_path = train_data_str.split(',') 464 | weights_str = train_data_weights_str.split(',') if train_data_weights_str is not None else None 465 | 466 | eval_data_named_path = eval_data_str.split(',') 467 | 468 | for named_path in train_data_named_path: 469 | ind = named_path.find(':') 470 | name = named_path[:ind] 471 | path = named_path[ind + 1:] 472 | if name in registered_path: 473 | d_list = common.load_jsonl(registered_path[name]) 474 | else: 475 | d_list = common.load_jsonl(path) 476 | 477 | train_data_name.append(name) 478 | train_data_path.append(path) 479 | 480 | train_data_list.append(d_list) 481 | 482 | if weights_str is not None: 483 | for weights in weights_str: 484 | train_data_weights.append(float(weights)) 485 | else: 486 | for i in range(len(train_data_list)): 487 | train_data_weights.append(1) 488 | 489 | for named_path in eval_data_named_path: 490 | ind = named_path.find(':') 491 | name = named_path[:ind] 492 | path = named_path[ind + 1:] 493 | if name in registered_path: 494 | d_list = common.load_jsonl(registered_path[name]) 495 | else: 496 | d_list = common.load_jsonl(path) 497 | eval_data_name.append(name) 498 | eval_data_path.append(path) 499 | 500 | eval_data_list.append(d_list) 501 | 502 | assert len(train_data_weights) == len(train_data_list) 503 | 504 | batching_schema = { 505 | 'uid': RawFlintField(), 506 | 'y': LabelFlintField(), 507 | 'input_ids': ArrayIndexFlintField(pad_idx=padding_token_value, left_pad=left_pad), 508 | 'token_type_ids': ArrayIndexFlintField(pad_idx=padding_segement_value, left_pad=left_pad), 509 | 'attention_mask': ArrayIndexFlintField(pad_idx=padding_att_value, left_pad=left_pad), 510 | } 511 | 512 | data_transformer = NLITransform(model_name, tokenizer, max_length) 513 | # data_transformer = NLITransform(model_name, tokenizer, max_length, with_element=True) 514 | 515 | eval_data_loaders = [] 516 | for eval_d_list in eval_data_list: 517 | 
d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(eval_d_list, data_transformer, 518 | batching_schema, 519 | batch_size_per_gpu_eval) 520 | eval_data_loaders.append(d_dataloader) 521 | 522 | # Estimate the training size: 523 | training_list = [] 524 | for i in range(len(train_data_list)): 525 | print("Build Training Data ...") 526 | train_d_list = train_data_list[i] 527 | train_d_name = train_data_name[i] 528 | train_d_weight = train_data_weights[i] 529 | cur_train_list = sample_data_list(train_d_list, train_d_weight) # we could apply a different sampling strategy here later. 530 | print(f"Data Name:{train_d_name}; Weight: {train_d_weight}; " 531 | f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}") 532 | training_list.extend(cur_train_list) 533 | estimated_training_size = len(training_list) 534 | print("Estimated training size:", estimated_training_size) 535 | # End of training size estimation. 536 | 537 | # t_total = estimated_training_size // args.gradient_accumulation_steps * num_epoch 538 | # t_total = estimated_training_size * num_epoch // args.actual_train_batch_size 539 | if args.total_step <= 0: 540 | t_total = estimated_training_size * num_epoch // args.actual_train_batch_size 541 | else: 542 | t_total = args.total_step 543 | 544 | if args.warmup_steps <= 0: # default the warmup steps to 0.1 * total steps when no positive value is given. 545 | args.warmup_steps = int(t_total * 0.1) 546 | 547 | if not args.cpu: 548 | torch.cuda.set_device(args.local_rank) 549 | model.cuda(args.local_rank) 550 | 551 | no_decay = ["bias", "LayerNorm.weight"] 552 | optimizer_grouped_parameters = [ 553 | { 554 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 555 | "weight_decay": args.weight_decay, 556 | }, 557 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 558 | ] 559 | 560 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 561 | scheduler = get_linear_schedule_with_warmup( 562 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 563 | ) 564 | 565 | global_step = 0 566 | 567 | if args.resume_path: 568 | print("Resume Training") 569 | global_step = args.global_iteration 570 | print("Resume Global Step: ", global_step) 571 | model.load_state_dict(torch.load(str(Path(args.resume_path) / "model.pt"), map_location=torch.device('cpu'))) 572 | optimizer.load_state_dict(torch.load(str(Path(args.resume_path) / "optimizer.pt"), map_location=torch.device('cpu'))) 573 | scheduler.load_state_dict(torch.load(str(Path(args.resume_path) / "scheduler.pt"), map_location=torch.device('cpu'))) 574 | print("State Resumed") 575 | 576 | if args.fp16: 577 | try: 578 | from apex import amp 579 | except ImportError: 580 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 581 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 582 | 583 | if not args.cpu and not args.single_gpu: 584 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], 585 | output_device=local_rank, find_unused_parameters=True) 586 | 587 | args_dict = dict(vars(args)) 588 | file_path_prefix = '.'
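# Worked example of the step arithmetic above, with hypothetical sizes:
# estimated_training_size = 100000, num_epoch = 2, actual_train_batch_size = 64
# give t_total = 100000 * 2 // 64 = 3125 update steps and, when no positive
# --warmup_steps is supplied, args.warmup_steps = int(3125 * 0.1) = 312.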
589 | if args.global_rank in [-1, 0]: 590 | print("Total Steps:", t_total) 591 | args.total_step = t_total 592 | print("Warmup Steps:", args.warmup_steps) 593 | print("Actual Training Batch Size:", actual_train_batch_size) 594 | print("Arguments:"); pp.pprint(args) 595 | 596 | is_finished = False 597 | 598 | # Let's build the logger and log everything before the start of the first training epoch. 599 | if args.global_rank in [-1, 0]: # only do logging if we use cpu or global_rank=0 600 | resume_prefix = "" 601 | # if args.resume_path: 602 | # resume_prefix = "resumed_" 603 | 604 | if not args.debug_mode: 605 | file_path_prefix, date = save_tool.gen_file_prefix(f"{args.experiment_name}") 606 | # Create the log file. 607 | # Save the source code. 608 | script_name = os.path.basename(__file__) 609 | with open(os.path.join(file_path_prefix, script_name), 'w') as out_f, open(__file__, 'r') as it: 610 | out_f.write(it.read()) 611 | out_f.flush() 612 | 613 | # Save option file 614 | common.save_json(args_dict, os.path.join(file_path_prefix, "args.json")) 615 | checkpoints_path = Path(file_path_prefix) / "checkpoints" 616 | if not checkpoints_path.exists(): 617 | checkpoints_path.mkdir() 618 | prediction_path = Path(file_path_prefix) / "predictions" 619 | if not prediction_path.exists(): 620 | prediction_path.mkdir() 621 | 622 | # if this is a resumed run, save the resume path. 623 | if args.resume_path: 624 | with open(os.path.join(file_path_prefix, "resume_log.txt"), 'w') as out_f: 625 | out_f.write(str(args.resume_path)) 626 | out_f.flush() 627 | 628 | # print(f"Global Rank:{args.global_rank} ### ", 'Init!') 629 | 630 | for epoch in tqdm(range(num_epoch), desc="Epoch", disable=args.global_rank not in [-1, 0]): 631 | # Let's build up training dataset for this epoch 632 | training_list = [] 633 | for i in range(len(train_data_list)): 634 | print("Build Training Data ...") 635 | train_d_list = train_data_list[i] 636 | train_d_name = train_data_name[i] 637 | train_d_weight = train_data_weights[i] 638 | cur_train_list = sample_data_list(train_d_list, train_d_weight) # we could apply a different sampling strategy here later. 639 | print(f"Data Name:{train_d_name}; Weight: {train_d_weight}; " 640 | f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}") 641 | training_list.extend(cur_train_list) 642 | 643 | random.shuffle(training_list) 644 | train_dataset = NLIDataset(training_list, data_transformer) 645 | 646 | train_sampler = SequentialSampler(train_dataset) 647 | if not args.cpu and not args.single_gpu: 648 | print("Use distributed sampler.") 649 | train_sampler = DistributedSampler(train_dataset, args.world_size, args.global_rank, 650 | shuffle=True) 651 | 652 | train_dataloader = DataLoader(dataset=train_dataset, 653 | batch_size=batch_size_per_gpu_train, 654 | shuffle=False, # 655 | num_workers=0, 656 | pin_memory=True, 657 | sampler=train_sampler, 658 | collate_fn=BaseBatchBuilder(batching_schema)) # 659 | # training build finished.
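# Illustrative sketch (toy data) of how sample_data_list() handles a fractional
# weight: sample_data_list(list(range(10)), 1.5) computes upper_int = ceil(1.5) = 2,
# duplicates the list twice (20 items), shuffles, and truncates to
# int(1.5 * 10) = 15 items, so each example appears once or twice in this
# epoch's training list.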
660 | 661 | print(debug_node_info(args), "epoch: ", epoch) 662 | 663 | if not args.cpu and not args.single_gpu: 664 | if args.sampler_seed == -1: 665 | train_sampler.set_epoch(epoch) # set the epoch so shuffling differs across epochs 666 | else: 667 | train_sampler.set_epoch(epoch + args.sampler_seed) 668 | 669 | for forward_step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", 670 | disable=args.global_rank not in [-1, 0]), 0): 671 | model.train() 672 | 673 | batch = move_to_device(batch, local_rank) 674 | # print(batch['input_ids'], batch['y']) 675 | if args.model_class_name in ["distilbert", "bart-large"]: 676 | outputs = model(batch['input_ids'], 677 | attention_mask=batch['attention_mask'], 678 | labels=batch['y']) 679 | else: 680 | outputs = model(batch['input_ids'], 681 | attention_mask=batch['attention_mask'], 682 | token_type_ids=batch['token_type_ids'], 683 | labels=batch['y']) 684 | loss, logits = outputs[:2] 685 | # print(debug_node_info(args), loss, logits, batch['uid']) 686 | # print(debug_node_info(args), loss, batch['uid']) 687 | 688 | # Accumulated loss 689 | if args.gradient_accumulation_steps > 1: 690 | loss = loss / args.gradient_accumulation_steps 691 | 692 | # if this forward step needs a model update 693 | # handle fp16 694 | if args.fp16: 695 | with amp.scale_loss(loss, optimizer) as scaled_loss: 696 | scaled_loss.backward() 697 | else: 698 | loss.backward() 699 | 700 | # Gradient clipping: only applied when max_grad_norm > 0. 701 | if (forward_step + 1) % args.gradient_accumulation_steps == 0: 702 | if args.max_grad_norm > 0: 703 | if args.fp16: 704 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 705 | else: 706 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 707 | 708 | optimizer.step() 709 | scheduler.step() # Update learning rate schedule 710 | model.zero_grad() 711 | 712 | global_step += 1 713 | 714 | if args.global_rank in [-1, 0] and args.eval_frequency > 0 and global_step % args.eval_frequency == 0: 715 | r_dict = dict() 716 | # Eval loop: 717 | for i in range(len(eval_data_name)): 718 | cur_eval_data_name = eval_data_name[i] 719 | cur_eval_data_list = eval_data_list[i] 720 | cur_eval_dataloader = eval_data_loaders[i] 721 | # cur_eval_raw_data_list = eval_raw_data_list[i] 722 | 723 | evaluation_dataset(args, cur_eval_dataloader, cur_eval_data_list, model, r_dict, 724 | eval_name=cur_eval_data_name) 725 | 726 | # saving checkpoints 727 | current_checkpoint_filename = \ 728 | f'e({epoch})|i({global_step})' 729 | 730 | for i in range(len(eval_data_name)): 731 | cur_eval_data_name = eval_data_name[i] 732 | current_checkpoint_filename += \ 733 | f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})' 734 | 735 | if not args.debug_mode: 736 | # save model: 737 | model_output_dir = checkpoints_path / current_checkpoint_filename 738 | if not model_output_dir.exists(): 739 | model_output_dir.mkdir() 740 | model_to_save = ( 741 | model.module if hasattr(model, "module") else model 742 | ) # Take care of distributed/parallel training 743 | 744 | torch.save(model_to_save.state_dict(), str(model_output_dir / "model.pt")) 745 | torch.save(optimizer.state_dict(), str(model_output_dir / "optimizer.pt")) 746 | torch.save(scheduler.state_dict(), str(model_output_dir / "scheduler.pt")) 747 | 748 | # save prediction: 749 | if not args.debug_mode and args.save_prediction: 750 | cur_results_path = prediction_path / current_checkpoint_filename 751 | if not cur_results_path.exists(): 752 |
cur_results_path.mkdir(parents=True) 753 | for key, item in r_dict.items(): 754 | common.save_jsonl(item['predictions'], cur_results_path / f"{key}.jsonl") 755 | 756 | # avoid saving too many things 757 | for key, item in r_dict.items(): 758 | del r_dict[key]['predictions'] 759 | common.save_json(r_dict, cur_results_path / "results_dict.json", indent=2) 760 | 761 | if args.total_step > 0 and global_step == t_total: 762 | # stop once total_step is set and global_step reaches t_total. 763 | is_finished = True 764 | break 765 | 766 | # End of epoch evaluation. 767 | if args.global_rank in [-1, 0] and args.total_step <= 0: 768 | r_dict = dict() 769 | # Eval loop: 770 | for i in range(len(eval_data_name)): 771 | cur_eval_data_name = eval_data_name[i] 772 | cur_eval_data_list = eval_data_list[i] 773 | cur_eval_dataloader = eval_data_loaders[i] 774 | # cur_eval_raw_data_list = eval_raw_data_list[i] 775 | 776 | evaluation_dataset(args, cur_eval_dataloader, cur_eval_data_list, model, r_dict, 777 | eval_name=cur_eval_data_name) 778 | 779 | # saving checkpoints 780 | current_checkpoint_filename = \ 781 | f'e({epoch})|i({global_step})' 782 | 783 | for i in range(len(eval_data_name)): 784 | cur_eval_data_name = eval_data_name[i] 785 | current_checkpoint_filename += \ 786 | f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})' 787 | 788 | if not args.debug_mode: 789 | # save model: 790 | model_output_dir = checkpoints_path / current_checkpoint_filename 791 | if not model_output_dir.exists(): 792 | model_output_dir.mkdir() 793 | model_to_save = ( 794 | model.module if hasattr(model, "module") else model 795 | ) # Take care of distributed/parallel training 796 | 797 | torch.save(model_to_save.state_dict(), str(model_output_dir / "model.pt")) 798 | torch.save(optimizer.state_dict(), str(model_output_dir / "optimizer.pt")) 799 | torch.save(scheduler.state_dict(), str(model_output_dir / "scheduler.pt")) 800 | 801 | # save prediction: 802 | if not args.debug_mode and args.save_prediction: 803 | cur_results_path = prediction_path / current_checkpoint_filename 804 | if not cur_results_path.exists(): 805 | cur_results_path.mkdir(parents=True) 806 | for key, item in r_dict.items(): 807 | common.save_jsonl(item['predictions'], cur_results_path / f"{key}.jsonl") 808 | 809 | # avoid saving too many things 810 | for key, item in r_dict.items(): 811 | del r_dict[key]['predictions'] 812 | common.save_json(r_dict, cur_results_path / "results_dict.json", indent=2) 813 | 814 | if is_finished: 815 | break 816 | 817 | 818 | id2label = { 819 | 0: 'e', 820 | 1: 'n', 821 | 2: 'c', 822 | -1: '-', 823 | } 824 | 825 | 826 | def count_acc(gt_list, pred_list): 827 | assert len(gt_list) == len(pred_list) 828 | gt_dict = list_dict_data_tool.list_to_dict(gt_list, 'uid') 829 | pred_list = list_dict_data_tool.list_to_dict(pred_list, 'uid') 830 | total_count = 0 831 | hit = 0 832 | for key, value in pred_list.items(): 833 | if gt_dict[key]['label'] == value['predicted_label']: 834 | hit += 1 835 | total_count += 1 836 | return hit, total_count 837 | 838 | 839 | def evaluation_dataset(args, eval_dataloader, eval_list, model, r_dict, eval_name): 840 | # r_dict = dict() 841 | pred_output_list = eval_model(model, eval_dataloader, args.global_rank, args) 842 | predictions = pred_output_list 843 | hit, total = count_acc(eval_list, pred_output_list) 844 | 845 | print(debug_node_info(args), f"{eval_name} Acc:", hit, total, hit / total) 846 | 847 | r_dict[f'{eval_name}'] = { 848 | 'acc': hit / total, 849 | 'correct_count': hit, 850 |
'total_count': total, 851 | 'predictions': predictions, 852 | } 853 | 854 | 855 | def eval_model(model, dev_dataloader, device_num, args): 856 | model.eval() 857 | 858 | uid_list = [] 859 | y_list = [] 860 | pred_list = [] 861 | logits_list = [] 862 | 863 | with torch.no_grad(): 864 | for i, batch in enumerate(dev_dataloader, 0): 865 | batch = move_to_device(batch, device_num) 866 | 867 | if args.model_class_name in ["distilbert", "bart-large"]: 868 | outputs = model(batch['input_ids'], 869 | attention_mask=batch['attention_mask'], 870 | labels=batch['y']) 871 | else: 872 | outputs = model(batch['input_ids'], 873 | attention_mask=batch['attention_mask'], 874 | token_type_ids=batch['token_type_ids'], 875 | labels=batch['y']) 876 | 877 | loss, logits = outputs[:2] 878 | 879 | uid_list.extend(list(batch['uid'])) 880 | y_list.extend(batch['y'].tolist()) 881 | pred_list.extend(torch.max(logits, 1)[1].view(logits.size(0)).tolist()) 882 | logits_list.extend(logits.tolist()) 883 | 884 | assert len(pred_list) == len(logits_list) 885 | assert len(pred_list) == len(uid_list) # keep all parallel result lists aligned 886 | 887 | result_items_list = [] 888 | for i in range(len(uid_list)): 889 | r_item = dict() 890 | r_item['uid'] = uid_list[i] 891 | r_item['logits'] = logits_list[i] 892 | r_item['predicted_label'] = id2label[pred_list[i]] 893 | 894 | result_items_list.append(r_item) 895 | 896 | return result_items_list 897 | 898 | 899 | def debug_node_info(args): 900 | names = ['global_rank', 'local_rank', 'node_rank'] 901 | values = [] 902 | 903 | for name in names: 904 | if name in args: 905 | values.append(getattr(args, name)) 906 | else: 907 | return "Pro:No node info " 908 | 909 | return "Pro:" + '|'.join([f"{name}:{value}" for name, value in zip(names, values)]) + "||Print:" 910 | 911 | 912 | if __name__ == '__main__': 913 | main() --------------------------------------------------------------------------------
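A minimal inference sketch for checkpoints saved by src/nli/training.py above, assuming the same transformers version used for training; the checkpoint directory name and the example sentences are hypothetical:

import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification

ckpt_dir = "saved_models/roberta-base_example/checkpoints/e(0)|i(2000)|anli_r1_dev#(0.45)"  # hypothetical
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=3)
model.load_state_dict(torch.load(ckpt_dir + "/model.pt", map_location="cpu"))
model.eval()

# Encode a premise/hypothesis pair the same way NLITransform does.
inputs = tokenizer.encode_plus("A soccer game with multiple males playing.",
                               "Some men are playing a sport.",
                               truncation=True, max_length=156, return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids=inputs["input_ids"],
                   attention_mask=inputs["attention_mask"])[0]
pred = {0: 'e', 1: 'n', 2: 'c'}[logits.argmax(dim=1).item()]  # id2label from training.py
print(pred)  # one of 'e' / 'n' / 'c'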