├── models ├── __init__.py ├── proto │ ├── protonet.sh │ ├── protaugment.py │ ├── protaugment-tmp.py │ └── protonet.py ├── matching │ ├── matchingnet.sh │ └── matchingnet.py ├── relation │ ├── relationnet.sh │ └── relationnet.py ├── bert_baseline │ ├── baseline.sh │ └── baseline.py ├── induction │ ├── inductionnet.sh │ └── inductionnet.py └── encoders │ └── bert_encoder.py ├── .gitignore ├── utils ├── python.py ├── math.py ├── scripts │ ├── prepate-intent-dataset.py │ └── runner.sh ├── data.py └── few_shot.py ├── README.md ├── requirements.txt └── language_modeling └── run_language_modeling.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.bert_encoder import BERTEncoder 2 | from .bert_baseline.baseline import BaselineNet 3 | from .induction.inductionnet import InductionNet 4 | from .matching.matchingnet import MatchingNet 5 | from .proto.protonet import ProtoNet 6 | from .relation.relationnet import RelationNet 7 | -------------------------------------------------------------------------------- /models/proto/protonet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd $HOME/Projects/FewShotText 3 | source .venv/bin/activate 4 | source .envrc 5 | echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" 6 | command -v nvidia-smi >/dev/null && { 7 | echo "GPU Devices:" 8 | nvidia-smi 9 | } || { 10 | : 11 | } 12 | 13 | PYTHONPATH=. 
#!/usr/bin/env bash
# Launcher for the Matching Network experiment: moves to the project root,
# activates the virtualenv and the local environment file, reports the GPUs
# that are visible, then forwards every CLI argument to
# models/matching/matchingnet.py.

# Stop right away if the project directory is missing, rather than sourcing
# and running files relative to whatever directory we happen to be in.
cd "${HOME}/Projects/FewShotText" || exit 1
source .venv/bin/activate
source .envrc
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"

# Only print GPU details when nvidia-smi is installed; its absence is not an error.
if command -v nvidia-smi >/dev/null; then
    echo "GPU Devices:"
    nvidia-smi
fi

# "$@" must be quoted so that arguments containing spaces or glob characters
# reach the Python script unchanged.
PYTHONPATH=. python models/matching/matchingnet.py "$@"
#!/usr/bin/env bash
# Launcher for the BERT baseline experiment: moves to the project root,
# activates the virtualenv and the local environment file, reports the GPUs
# that are visible, then forwards every CLI argument to
# models/bert_baseline/baseline.py.

# Stop right away if the project directory is missing, rather than sourcing
# and running files relative to whatever directory we happen to be in.
cd "${HOME}/Projects/FewShotText" || exit 1
source .venv/bin/activate
source .envrc
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"

# Only print GPU details when nvidia-smi is installed; its absence is not an error.
if command -v nvidia-smi >/dev/null; then
    echo "GPU Devices:"
    nvidia-smi
fi

# "$@" must be quoted so that arguments containing spaces or glob characters
# reach the Python script unchanged.
PYTHONPATH=. python models/bert_baseline/baseline.py "$@"
def cosine_similarity(x, y):
    """
    Pairwise cosine similarity between the rows of two 2-D tensors.
    :param x: tensor of shape N x D
    :param y: tensor of shape M x D
    :return: N x M tensor whose entry (i, j) is the cosine of x[i] and y[j]
    """
    # normalize() divides each row by max(||row||, eps). Results are unchanged
    # for non-zero rows, while an all-zero row now yields a similarity of 0
    # instead of the NaN produced by dividing by a zero norm.
    x = torch.nn.functional.normalize(x, p=2, dim=1)
    y = torch.nn.functional.normalize(y, p=2, dim=1)

    return x @ y.T
3):] 18 | ) 19 | 20 | write_jsonl_data([ 21 | d for d in full if d['label'] in train_labels 22 | ], f"data/{name}/train.jsonl", force=True) 23 | write_txt_data(train_labels, f"data/{name}/labels.train.txt") 24 | 25 | write_jsonl_data([ 26 | d for d in full if d['label'] in valid_labels 27 | ], f"data/{name}/valid.jsonl", force=True) 28 | write_txt_data(valid_labels, f"data/{name}/labels.valid.txt") 29 | 30 | write_jsonl_data([ 31 | d for d in full if d['label'] in test_labels 32 | ], f"data/{name}/test.jsonl", force=True) 33 | write_txt_data(test_labels, f"data/{name}/labels.test.txt") 34 | 35 | 36 | if __name__ == "__main__": 37 | for name in ('OOS', 'TREC28', 'Liu'): 38 | process_dataset(name) 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FewShotText 2 | This repository contains code for the paper [A Neural Few-Shot Text Classification Reality Check](https://arxiv.org/abs/2101.12073) 3 | 4 | ## Environment setup 5 | ```bash 6 | # Create environment 7 | python3 -m virtualenv .venv --python=python3.6 8 | 9 | # Install environment 10 | .venv/bin/pip install -r requirements.txt 11 | 12 | # Activate environment 13 | source .venv/bin/activate 14 | ``` 15 | 16 | ## Fine-tuning BERT on the MLM task 17 | ```bash 18 | model_name=bert-base-cased 19 | block_size=256 20 | dataset=OOS 21 | output_dir=transformer_models/${dataset}/fine-tuned 22 | 23 | python language_modeling/run_language_modeling.py \ 24 | --model_name_or_path ${model_name} \ 25 | --output_dir ${output_dir} \ 26 | --mlm \ 27 | --do_train \ 28 | --train_data_file data/${dataset}/full/full-train.txt \ 29 | --do_eval \ 30 | --eval_data_file data/${dataset}/full/full-test.txt \ 31 | --overwrite_output_dir \ 32 | --logging_steps=1000 \ 33 | --line_by_line \ 34 | --logging_dir ${output_dir} \ 35 | --block_size ${block_size} \ 36 | --save_steps=1000 \ 37 | --num_train_epochs 20 \ 
class BERTEncoder(nn.Module):
    """
    Sentence encoder built on a pretrained transformer loaded through
    ``AutoModel`` / ``AutoTokenizer``.

    Calling the module with a list of strings returns the model output's
    ``pooler_output`` for that batch.
    """

    def __init__(self, config_name_or_path, max_length: int = 64):
        """
        :param config_name_or_path: model name or path handed to ``from_pretrained``
        :param max_length: maximum number of tokens kept per sentence (longer
            inputs are truncated); defaults to the previously hard-coded 64
        """
        super().__init__()
        logger.info(f"Loading Encoder @ {config_name_or_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(config_name_or_path)
        self.bert = AutoModel.from_pretrained(config_name_or_path).to(device)
        logger.info(f"Encoder loaded.")
        self.max_length: int = max_length
        # False until the first batch has been embedded (see embed_sentences).
        self.warmed: bool = False

    def embed_sentences(self, sentences: List[str]):
        """
        Tokenize and encode a batch of sentences.
        :param sentences: list of raw sentences
        :return: ``pooler_output`` of the underlying model for the batch
        """
        if self.warmed:
            # Later batches are padded to their own longest sentence.
            padding = True
        else:
            # The very first batch is padded to the full max_length
            # (presumably to exercise the largest input shape once up front
            # -- TODO confirm intent).
            padding = "max_length"
            self.warmed = True

        batch = self.tokenizer(
            sentences,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding=padding
        )
        batch = {k: v.to(device) for k, v in batch.items()}

        # Call the module itself rather than .forward() so that any registered
        # hooks are executed.
        fw = self.bert(**batch)
        return fw.pooler_output

    def forward(self, sentences: List[str]):
        """
        Embed ``sentences``; on failure, log the offending batch and re-raise.
        """
        try:
            return self.embed_sentences(sentences)
        except Exception as e:
            logger.error(f"could not embed sentence {sentences} (err: {type(e)}, {e}, {str(e)}")
            # Bare raise keeps the original traceback untouched.
            raise
mistune==0.8.4 61 | nbclient==0.5.1 62 | nbconvert==6.0.7 63 | nbformat==5.0.8 64 | nest-asyncio==1.4.3 65 | nltk==3.5 66 | notebook==6.1.5 67 | numpy==1.19.4 68 | oauthlib==3.1.0 69 | omegaconf==2.0.5 70 | opt-einsum==3.3.0 71 | packaging==21.0 72 | pandas==1.1.5 73 | pandocfilters==1.4.3 74 | parso==0.7.1 75 | pexpect==4.8.0 76 | pickleshare==0.7.5 77 | Pillow==8.2.0 78 | portalocker==2.0.0 79 | prometheus-client==0.9.0 80 | promise==2.3 81 | prompt-toolkit==3.0.8 82 | protobuf==3.12.2 83 | psutil==5.7.3 84 | ptyprocess==0.6.0 85 | pyarrow==2.0.0 86 | pyasn1==0.4.8 87 | pyasn1-modules==0.2.8 88 | pybind11==2.5.0 89 | pycparser==2.20 90 | pydeck==0.5.0 91 | Pygments==2.7.4 92 | pyparsing==3.0.0a1 93 | pyrsistent==0.17.3 94 | python-dateutil==2.8.1 95 | pytz==2020.4 96 | PyYAML==5.4 97 | pyzmq==19.0.1 98 | regex==2020.6.8 99 | requests==2.26.0 100 | requests-oauthlib==1.3.0 101 | rouge-score==0.0.4 102 | rsa==4.7 103 | sacrebleu==1.4.14 104 | sacremoses==0.0.43 105 | scikit-learn==0.23.1 106 | scipy==1.5.0rc1 107 | Send2Trash==1.5.0 108 | sentencepiece==0.1.92 109 | sentry-sdk==0.19.5 110 | seqeval==1.2.2 111 | shortuuid==1.0.1 112 | simpletransformers==0.51.3 113 | six==1.15.0 114 | smmap==3.0.4 115 | streamlit==0.72.0 116 | subprocess32==3.5.4 117 | tensorboard==2.4.0 118 | tensorboard-plugin-wit==1.6.0.post3 119 | tensorboardX==2.0 120 | tensorflow==2.4.2 121 | tensorflow-estimator==2.4.0 122 | tensorflow-hub==0.10.0 123 | termcolor==1.1.0 124 | terminado==0.9.1 125 | testpath==0.4.4 126 | threadpoolctl==2.1.0 127 | tokenizers==0.10.3 128 | toml==0.10.2 129 | toolz==0.11.1 130 | torch==1.5.0 131 | torchfile==0.1.0 132 | torchnet==0.0.4 133 | torchvision==0.6.0 134 | tornado==6.0.4 135 | tqdm==4.54.1 136 | traitlets==4.3.3 137 | transformers==4.8.2 138 | typing-extensions==3.7.4.3 139 | tzlocal==2.1 140 | urllib3==1.26.5 141 | validators==0.18.1 142 | visdom==0.1.8.9 143 | wandb==0.10.12 144 | watchdog==1.0.1 145 | wcwidth==0.2.5 146 | webencodings==0.5.1 147 | 
def write_jsonl_data(jsonl_data: List[Dict], jsonl_path: str, force=False):
    """
    Write a list of dicts to a file, one JSON document per line.
    :param jsonl_data: items to serialize
    :param jsonl_path: destination file path
    :param force: overwrite the destination if it already exists
    :raises FileExistsError: if the destination exists and ``force`` is false
    """
    if os.path.exists(jsonl_path) and not force:
        raise FileExistsError
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly: the platform default encoding may be unable
    # to encode them, and get_jsonl_data reads the file back as UTF-8.
    with open(jsonl_path, 'w', encoding="utf-8") as file:
        for line in jsonl_data:
            file.write(json.dumps(line, ensure_ascii=False) + '\n')
class UnlabeledDataLoader:
    """
    Loads a jsonl file, groups its items by label and samples items that carry
    an ``"augmentations"`` entry.
    """

    def __init__(self, file_path: str):
        """
        :param file_path: path to the ``.jsonl`` file to load
        """
        self.file_path = file_path
        self.raw_data = get_jsonl_data(self.file_path)
        self.data_dict = raw_data_to_dict(self.raw_data, shuffle=True)

    def create_episode(self, n_augment: int = 0):
        """
        Draw ``n_augment`` distinct items: a label is picked uniformly, then
        an item is picked uniformly within that label.
        :param n_augment: number of items to draw
        :return: dict with a single key ``"x_augment"`` holding the drawn items
        :raises ValueError: if more items are requested than are available
        :raises KeyError: if a drawn item has no ``"augmentations"`` entry
        """
        episode = dict()
        augmentations = list()
        if n_augment:
            keys = list(self.data_dict.keys())
            n_available = sum(len(self.data_dict[k]) for k in keys)
            # Without this guard the re-sampling loop below would spin forever
            # once every (label, index) pair has been used.
            if n_augment > n_available:
                raise ValueError(
                    f"Cannot draw {n_augment} distinct items: only {n_available} available."
                )

            # A set gives O(1) "already used" checks.
            already_done = set()
            for _ in range(n_augment):
                # Draw a random label
                key = random.choice(keys)
                # Draw a random data index
                ix = random.choice(range(len(self.data_dict[key])))
                # If already used, re-sample
                while (key, ix) in already_done:
                    key = random.choice(keys)
                    ix = random.choice(range(len(self.data_dict[key])))
                already_done.add((key, ix))

                item = self.data_dict[key][ix]
                if "augmentations" not in item:
                    raise KeyError(f"Input data {item} does not contain any augmentations / is not properly formatted.")
                augmentations.append(item)

        episode["x_augment"] = augmentations

        return episode
self.data_dict.values()]) >= n_support + n_query + n_unlabeled 112 | 113 | for key, val in self.data_dict.items(): 114 | random.shuffle(val) 115 | 116 | if n_support: 117 | episode["xs"] = [[self.data_dict[k][i] for i in range(n_support)] for k in rand_keys] 118 | if n_query: 119 | episode["xq"] = [[self.data_dict[k][n_support + i] for i in range(n_query)] for k in rand_keys] 120 | 121 | if n_unlabeled: 122 | episode['xu'] = [item for k in rand_keys for item in self.data_dict[k][n_support + n_query:n_support + n_query + n_unlabeled]] 123 | 124 | if n_augment: 125 | episode = dict(**episode, **self.unlabeled_data_loader.create_episode(n_augment=n_augment)) 126 | 127 | return episode 128 | -------------------------------------------------------------------------------- /utils/scripts/runner.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | for n_class in 5 2 3 4; do 4 | for seed in 1 2 3 4 5; do 5 | for shots in 5; do 6 | for dataset in OOS TREC28 Liu; do 7 | 8 | OUTPUT_ROOT="runs/${dataset}/${n_class}C_${shots}K/seed${seed}" 9 | data_params=" 10 | --train-path data/${dataset}/train.jsonl 11 | --valid-path data/${dataset}/valid.jsonl 12 | --test-path data/${dataset}/test.jsonl" 13 | 14 | baseline_params=" 15 | --n-classes ${n_class} 16 | --n-support ${shots} 17 | --n-test-episodes 600" 18 | 19 | few_shot_params=" 20 | --n-classes ${n_class} 21 | --n-support ${shots} 22 | --n-query 5 23 | --n-test-episodes 600" 24 | 25 | model_params=" 26 | --model-name-or-path transformer_models/${dataset}/fine-tuned" 27 | 28 | 29 | baseline_training_params=" 30 | --n-train-epoch 10 31 | --seed ${seed}" 32 | 33 | few_shot_training_params=" 34 | --max-iter 10000 35 | --evaluate-every 100 36 | --early-stop 25 37 | --seed ${seed}" 38 | 39 | # Baseline 40 | OUT_PATH="${OUTPUT_ROOT}/baseline" 41 | if [[ -d "${OUT_PATH}" ]]; then 42 | echo "${OUT_PATH} already exists. Skipping." 
43 | else 44 | mkdir -p ${OUT_PATH} 45 | LOGS_PATH="${OUT_PATH}/training.log" 46 | 47 | ./models/bert_baseline/baseline.sh \ 48 | $(echo ${data_params}) \ 49 | $(echo ${baseline_params}) \ 50 | $(echo ${model_params}) \ 51 | $(echo ${baseline_training_params}) \ 52 | --output-path "${OUT_PATH}/output" > ${LOGS_PATH} 53 | fi 54 | 55 | # Induction Network 56 | OUT_PATH="${OUTPUT_ROOT}/induction" 57 | if [[ -d "${OUT_PATH}" ]]; then 58 | echo "${OUT_PATH} already exists. Skipping." 59 | else 60 | mkdir -p ${OUT_PATH} 61 | LOGS_PATH="${OUT_PATH}/training.log" 62 | 63 | ./models/induction/inductionnet.sh \ 64 | $(echo ${data_params}) \ 65 | $(echo ${few_shot_params}) \ 66 | $(echo ${model_params}) \ 67 | $(echo ${few_shot_training_params}) \ 68 | --ntl-n-slices 100 \ 69 | --n-routing-iter 3 \ 70 | --output-path "${OUT_PATH}/output" > ${LOGS_PATH} 71 | fi 72 | 73 | # Relation Network 74 | for relation_module_type in base ntl; do 75 | 76 | OUT_PATH="${OUTPUT_ROOT}/relation-${relation_module_type}" 77 | if [[ -d "${OUT_PATH}" ]]; then 78 | echo "${OUT_PATH} already exists. Skipping." 79 | else 80 | mkdir -p ${OUT_PATH} 81 | LOGS_PATH="${OUT_PATH}/training.log" 82 | 83 | ./models/relation/relationnet.sh \ 84 | $(echo ${data_params}) \ 85 | $(echo ${few_shot_params}) \ 86 | $(echo ${model_params}) \ 87 | $(echo ${few_shot_training_params}) \ 88 | --relation-module-type "${relation_module_type}" \ 89 | --output-path "${OUT_PATH}/output" > ${LOGS_PATH} 90 | 91 | fi 92 | done 93 | 94 | 95 | for metric in euclidean cosine; do 96 | # Baseline++ 97 | OUT_PATH="${OUTPUT_ROOT}/baseline++_${metric}" 98 | if [[ -d "${OUT_PATH}" ]]; then 99 | echo "${OUT_PATH} already exists. Skipping." 
100 | else 101 | mkdir -p ${OUT_PATH} 102 | LOGS_PATH="${OUT_PATH}/training.log" 103 | 104 | ./models/bert_baseline/baseline.sh \ 105 | $(echo ${data_params}) \ 106 | $(echo ${baseline_params}) \ 107 | $(echo ${model_params}) \ 108 | $(echo ${baseline_training_params}) \ 109 | --pp --metric "${metric}" \ 110 | --output-path "${OUT_PATH}/output" > ${LOGS_PATH} 111 | fi 112 | 113 | # Matching Network 114 | OUT_PATH="${OUTPUT_ROOT}/matching-${metric}" 115 | if [[ -d "${OUT_PATH}" ]]; then 116 | echo "${OUT_PATH} already exists. Skipping." 117 | else 118 | mkdir -p ${OUT_PATH} 119 | LOGS_PATH="${OUT_PATH}/training.log" 120 | 121 | ./models/matching/matchingnet.sh \ 122 | $(echo ${data_params}) \ 123 | $(echo ${model_params}) \ 124 | $(echo ${few_shot_params}) \ 125 | $(echo ${few_shot_training_params}) \ 126 | --metric ${metric} \ 127 | --output-path "${OUT_PATH}/output" > ${LOGS_PATH} 128 | fi 129 | 130 | # Prototypical Network 131 | OUT_PATH="${OUTPUT_ROOT}/proto-${metric}" 132 | if [[ -d "${OUT_PATH}" ]]; then 133 | echo "${OUT_PATH} already exists. Skipping." 134 | else 135 | mkdir -p ${OUT_PATH} 136 | LOGS_PATH="${OUT_PATH}/training.log" 137 | 138 | ./models/proto/protonet.sh \ 139 | $(echo ${data_params}) \ 140 | $(echo ${model_params}) \ 141 | $(echo ${few_shot_params}) \ 142 | $(echo ${few_shot_training_params}) \ 143 | --metric ${metric} \ 144 | --output-path "${OUT_PATH}/output" > ${LOGS_PATH} 145 | fi 146 | 147 | # Proto++ 148 | OUT_PATH="${OUTPUT_ROOT}/proto++-${metric}" 149 | if [[ -d "${OUT_PATH}" ]]; then 150 | echo "${OUT_PATH} already exists. Skipping." 
def create_episode(data_dict, n_support, n_classes, n_query, n_unlabeled=0, n_augment=0):
    """
    Build one few-shot episode from a ``label -> list of items`` mapping.

    Every list in ``data_dict`` is shuffled in place on each call.

    :param data_dict: mapping from label to the list of items of that label
    :param n_support: number of support items per sampled class
    :param n_classes: number of classes to sample (capped at the number of labels)
    :param n_query: number of query items per sampled class
    :param n_unlabeled: number of unlabeled items per sampled class
    :param n_augment: number of distinct items (drawn over all labels) returned
        together with the texts of their augmentations
    :return: dict with "xs" and "xq" (one list per sampled class), plus "xu"
        (flat list) when ``n_unlabeled`` is set and "x_augment" (list of
        ``(sentence, [augmentation texts])`` tuples) when ``n_augment`` is set
    :raises AssertionError: if some label has fewer than
        ``n_support + n_query + n_unlabeled`` items
    :raises ValueError: if ``n_augment`` exceeds the total number of items
    :raises KeyError: if a drawn item has no ``"augmentations"`` entry
    """
    n_classes = min(n_classes, len(data_dict.keys()))
    rand_keys = np.random.choice(list(data_dict.keys()), n_classes, replace=False)

    assert min([len(val) for val in data_dict.values()]) >= n_support + n_query + n_unlabeled

    for val in data_dict.values():
        random.shuffle(val)

    # Support, query and unlabeled items are consecutive, disjoint slices of
    # each shuffled class list.
    episode = {
        "xs": [
            [data_dict[k][i] for i in range(n_support)] for k in rand_keys
        ],
        "xq": [
            [data_dict[k][n_support + i] for i in range(n_query)] for k in rand_keys
        ]
    }

    if n_unlabeled:
        # The slice length is n_unlabeled (it used to be a hard-coded 10,
        # which ignored the requested amount).
        start = n_support + n_query
        episode['xu'] = [
            item for k in rand_keys for item in data_dict[k][start:start + n_unlabeled]
        ]

    if n_augment:
        keys = list(data_dict.keys())
        n_available = sum(len(data_dict[k]) for k in keys)
        # Without this guard the re-sampling loop below would never terminate
        # once every (label, index) pair has been used.
        if n_augment > n_available:
            raise ValueError(
                f"Cannot draw {n_augment} distinct items: only {n_available} available."
            )

        augmentations = list()
        already_done = set()
        for _ in range(n_augment):
            # Draw a random label
            key = random.choice(keys)
            # Draw a random data index
            ix = random.choice(range(len(data_dict[key])))
            # If already used, re-sample
            while (key, ix) in already_done:
                key = random.choice(keys)
                ix = random.choice(range(len(data_dict[key])))
            already_done.add((key, ix))

            item = data_dict[key][ix]
            if "augmentations" not in item:
                raise KeyError(f"Input data {item} does not contain any augmentations / is not properly formatted.")
            augmentations.append((
                item["sentence"],
                [aug["text"] for aug in item["augmentations"]]
            ))
        episode["x_augment"] = augmentations

    return episode
def create_ARSC_test_episode(prefix: str = "data/ARSC-Yu/raw", n_query: int = 5, n_unlabeled=0, set_type: str = "test"):
    """
    Build one evaluation episode for a randomly chosen ARSC target task.

    A label is drawn from ``{prefix}/workspace.target.list`` and a binary task
    from (2, 4, 5). The support set is the task's ``.train`` file; query (and
    optional unlabeled) items come from its ``.test`` or ``.dev`` file.

    :param prefix: directory holding the ARSC list and task files
    :param n_query: number of query sentences per class
    :param n_unlabeled: number of unlabeled sentences per class
    :param set_type: which split the query items come from, "test" or "dev"
    :return: dict with "xs" and "xq" (one list of sentences per class, classes
        in sorted label order), plus "xu" (flat list) when ``n_unlabeled`` is set
    :raises AssertionError: on an invalid ``set_type``, a support file that
        does not contain exactly 10 rows, or a class with fewer than
        ``n_query + n_unlabeled`` query rows
    """
    assert set_type in ("test", "dev")
    # Context manager so the list file is closed deterministically.
    with open(f"{prefix}/workspace.target.list", "r") as file:
        labels = [line.strip() for line in file]

    # Pick a random label
    label = random.choice(labels)

    # Pick a random binary task (2, 4, 5)
    binary_task = random.choice([2, 4, 5])

    support_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.train", label=label)
    assert len(support_data) == 10  # 2 * 5 shots
    support_dict = collections.defaultdict(list)
    for d in support_data:
        support_dict[d['label']].append(d['sentence'])

    # Read the query split from `prefix` as well; it used to be read from the
    # hard-coded default directory, which silently ignored a custom prefix.
    query_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.{set_type}", label=label)
    query_dict = collections.defaultdict(list)
    for d in query_data:
        query_dict[d['label']].append(d['sentence'])

    assert min([len(val) for val in query_dict.values()]) >= n_query + n_unlabeled

    for val in query_dict.values():
        random.shuffle(val)

    # Sorted label order keeps support and query classes aligned.
    class_keys = sorted(query_dict.keys())
    episode = {
        "xs": [
            [sentence for sentence in support_dict[k]] for k in class_keys
        ],
        "xq": [
            [query_dict[k][i] for i in range(n_query)] for k in class_keys
        ]
    }

    if n_unlabeled:
        # Unlabeled items are taken right after the query items of each class.
        episode['xu'] = [
            item for k in class_keys for item in query_dict[k][n_query:n_query + n_unlabeled]
        ]
    return episode
get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.dev", label=label) 209 | test_data = get_tsv_data(f"data/ARSC-Yu/raw/{label}.t{binary_task}.test", label=label) 210 | tasks.append({ 211 | "xs": [ 212 | [d['sentence'] for d in train_data if d['label'] == f"{label}-1"], 213 | [d['sentence'] for d in train_data if d['label'] == f"{label}1"], 214 | ], 215 | "x_valid": [ 216 | [d['sentence'] for d in valid_data if d['label'] == f"{label}-1"], 217 | [d['sentence'] for d in valid_data if d['label'] == f"{label}1"], 218 | ], 219 | "x_test": [ 220 | [d['sentence'] for d in test_data if d['label'] == f"{label}-1"], 221 | [d['sentence'] for d in test_data if d['label'] == f"{label}1"], 222 | ], 223 | }) 224 | 225 | assert all([len(task['xs'][0]) == len(task['xs'][1]) for task in tasks]) 226 | return tasks 227 | -------------------------------------------------------------------------------- /language_modeling/run_language_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 
@dataclass
class ModelArguments:
    """
    Options selecting the model, config and tokenizer to fine-tune, or to train from scratch.
    """

    # Checkpoint used to initialise the weights; left unset to start from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata=dict(
            help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
        ),
    )

    # Architecture name, only consulted when no config/checkpoint is given.
    model_type: Optional[str] = field(
        default=None,
        metadata=dict(
            help="If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)
        ),
    )

    # Optional overrides for where the config and the tokenizer are loaded from.
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )

    # Passed as ``cache_dir`` to the ``from_pretrained`` calls.
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
    )
def main():
    """
    Fine-tune (or train from scratch) a masked language model and optionally evaluate it.

    Command-line arguments are parsed into ``ModelArguments``,
    ``DataTrainingArguments`` and ``TrainingArguments``. The config, tokenizer
    and model are then loaded, the datasets are built, and a ``Trainer`` runs
    training and/or evaluation as requested by ``--do_train`` / ``--do_eval``.

    :return: ``{"perplexity": float}`` when evaluation ran in this process
        (local rank -1 or 0), otherwise an empty dict.
    :raises ValueError: if evaluation is requested without ``--eval_data_file``;
        if the output directory is non-empty and may not be overwritten; if no
        config can be built (no config name, no checkpoint, no valid
        ``--model_type``); if no tokenizer source is given; or if a
        BERT/RoBERTa-like model is run without ``--mlm``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        # Without this check a missing/unknown --model_type surfaced as a bare KeyError.
        if model_args.model_type not in CONFIG_MAPPING:
            raise ValueError(
                "Cannot build a config from scratch: pass --model_type (one of: "
                + ", ".join(MODEL_TYPES)
                + "), or use --config_name / --model_name_or_path."
            )
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it, "
            "and load it from here, using --tokenizer_name"
        )

    if model_args.model_name_or_path:
        model = AutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedLM.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
        raise ValueError(
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling)."
        )

    # Both branches now read `model_max_length`; the first one used the
    # deprecated `max_len` alias while the second already used the new name.
    if data_args.block_size <= 0:
        # Our input block size will be the max possible for the model
        data_args.block_size = tokenizer.model_max_length
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Get datasets
    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Training
    if training_args.do_train:
        # Only a local checkpoint directory is handed to the trainer as `model_path`.
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()
        perplexity = math.exp(eval_output["eval_loss"])
        result = {"perplexity": perplexity}

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        results.update(result)

    return results
logging.basicConfig() 21 | logger = logging.getLogger(__name__) 22 | logger.setLevel(logging.DEBUG) 23 | 24 | warnings.simplefilter('ignore') 25 | 26 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 27 | 28 | 29 | class MatchingNet(nn.Module): 30 | def __init__(self, encoder, metric: str = "cosine"): 31 | super(MatchingNet, self).__init__() 32 | 33 | self.encoder = encoder 34 | self.metric = metric 35 | assert self.metric in ("cosine", "euclidean") 36 | 37 | def loss(self, sample): 38 | """ 39 | :param sample: { 40 | "xs": [ 41 | [support_A_1, support_A_2, ...], 42 | [support_B_1, support_B_2, ...], 43 | [support_C_1, support_C_2, ...], 44 | ... 45 | ], 46 | "xq": [ 47 | [query_A_1, query_A_2, ...], 48 | [query_B_1, query_B_2, ...], 49 | [query_C_1, query_C_2, ...], 50 | ... 51 | ] 52 | } 53 | :return: 54 | """ 55 | xs = sample["xs"] # support 56 | xq = sample["xq"] # query 57 | 58 | n_class = len(xs) 59 | assert len(xq) == n_class 60 | n_support = len(xs[0]) 61 | n_query = len(xq[0]) 62 | 63 | x = [item for xs_ in xs for item in xs_] + [item for xq_ in xq for item in xq_] 64 | z = self.encoder.forward(x) 65 | z_support = z[:n_class * n_support] 66 | z_query = z[n_class * n_support:] 67 | 68 | if self.metric == "euclidean": 69 | similarities = -euclidean_dist(z_query, z_support) 70 | elif self.metric == "cosine": 71 | similarities = cosine_similarity(z_query, z_support) * 5 72 | else: 73 | raise NotImplementedError 74 | 75 | # Average over support samples 76 | distances_from_query_to_classes = torch.cat([similarities[:, c * n_support: (c + 1) * n_support].mean(1).view(1, -1) for c in range(n_class)]).T 77 | true_labels = torch.zeros_like(distances_from_query_to_classes) 78 | 79 | for ix_class, class_query_sentences in enumerate(xq): 80 | for ix_sentence, sentence in enumerate(class_query_sentences): 81 | true_labels[ix_class * n_query + ix_sentence, ix_class] = 1 82 | 83 | loss_fn = nn.CrossEntropyLoss() 84 | loss_val = 
loss_fn(distances_from_query_to_classes, true_labels.argmax(1)) 85 | acc_val = (true_labels.argmax(1) == distances_from_query_to_classes.argmax(1)).float().mean() 86 | f1_val = f1_score( 87 | true_labels.argmax(1).cpu().numpy(), 88 | distances_from_query_to_classes.argmax(1).cpu().numpy(), 89 | average="weighted" 90 | ) 91 | return loss_val, { 92 | "loss": loss_val.item(), 93 | "metrics": { 94 | "acc": acc_val.item(), 95 | "loss": loss_val.item(), 96 | "f1": f1_val, 97 | }, 98 | "y_hat": distances_from_query_to_classes.argmax(1).cpu().detach().numpy() 99 | } 100 | 101 | def train_step(self, optimizer, data_dict: Dict[str, List[str]], n_support, n_classes, n_query): 102 | 103 | episode = create_episode( 104 | data_dict=data_dict, 105 | n_support=n_support, 106 | n_classes=n_classes, 107 | n_query=n_query 108 | ) 109 | 110 | self.train() 111 | optimizer.zero_grad() 112 | torch.cuda.empty_cache() 113 | loss, loss_dict = self.loss(episode) 114 | loss.backward() 115 | optimizer.step() 116 | 117 | return loss, loss_dict 118 | 119 | def test_step(self, data_dict, n_support, n_classes, n_query, n_episodes=1000): 120 | metrics = collections.defaultdict(list) 121 | self.eval() 122 | for i in range(n_episodes): 123 | episode = create_episode( 124 | data_dict=data_dict, 125 | n_support=n_support, 126 | n_classes=n_classes, 127 | n_query=n_query 128 | ) 129 | 130 | with torch.no_grad(): 131 | loss, loss_dict = self.loss(episode) 132 | 133 | for k, v in loss_dict["metrics"].items(): 134 | metrics[k].append(v) 135 | 136 | return { 137 | key: np.mean(value) for key, value in metrics.items() 138 | } 139 | 140 | def train_step_ARSC(self, data_path: str, optimizer): 141 | episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=5) 142 | 143 | self.train() 144 | optimizer.zero_grad() 145 | torch.cuda.empty_cache() 146 | loss, loss_dict = self.loss(episode) 147 | loss.backward() 148 | optimizer.step() 149 | 150 | return loss, loss_dict 151 | 152 | def 
def run_matching(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_query: int,
        n_classes: int,
        valid_path: str = None,
        test_path: str = None,
        output_path: str = f"runs/{now()}",
        max_iter: int = 10000,
        evaluate_every: int = 100,
        early_stop: int = None,
        n_test_episodes: int = 1000,
        log_every: int = 10,
        metric: str = "cosine",
        arsc_format: bool = False,
        data_path: str = None
):
    """
    Train a MatchingNet episodically, with periodic evaluation and logging.

    :param train_path: jsonl training file (read only when ``arsc_format`` is False)
    :param model_name_or_path: passed to ``BERTEncoder``
    :param n_support: support sentences per class in each episode
    :param n_query: query sentences per class in each episode
    :param n_classes: classes per episode
    :param valid_path: if set, a "valid" evaluation runs every ``evaluate_every`` steps
    :param test_path: if set, a "test" evaluation runs every ``evaluate_every`` steps
    :param output_path: run directory; it may already exist only if it is empty
    :param max_iter: maximum number of training episodes
    :param evaluate_every: number of training steps between two evaluations
    :param early_stop: stop after this many consecutive evaluations without a new
        best validation accuracy; a falsy value disables early stopping
    :param n_test_episodes: episodes per evaluation
    :param log_every: number of training steps between two training-metric logs
    :param metric: forwarded to ``MatchingNet`` ("cosine" or "euclidean")
    :param arsc_format: use the ARSC train/test steps, which read from ``data_path``
    :param data_path: forwarded to the ARSC steps as their ``data_path``
    :raises FileExistsError: if ``output_path`` exists and is not empty

    Side effects: writes tensorboard event files under ``<output_path>/logs`` and
    ``<output_path>/metrics.json`` once training ends.
    """
    if output_path:
        if os.path.exists(output_path) and len(os.listdir(output_path)):
            raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok=True: the guard above deliberately lets an existing *empty*
    # directory through, which a plain os.makedirs() would then reject.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "logs/train"), exist_ok=True)
    train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer: SummaryWriter = None
    test_writer: SummaryWriter = None
    log_dict = dict(train=list())

    if valid_path:
        os.makedirs(os.path.join(output_path, "logs/valid"), exist_ok=True)
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
    if test_path:
        os.makedirs(os.path.join(output_path, "logs/test"), exist_ok=True)
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        # Group sentences by label; optionally shuffle each group in place.
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for sentences in labels_dict.values():
                random.shuffle(sentences)
        return labels_dict

    try:
        # Load model
        bert = BERTEncoder(model_name_or_path).to(device)
        matching_net = MatchingNet(encoder=bert, metric=metric)
        optimizer = torch.optim.Adam(matching_net.parameters(), lr=2e-5)

        # Load data (the ARSC steps build their episodes from `data_path` instead)
        train_data_dict = None
        valid_data_dict = None
        test_data_dict = None
        if not arsc_format:
            train_data = get_jsonl_data(train_path)
            train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True)
            logger.info(f"train labels: {train_data_dict.keys()}")

            if valid_path:
                valid_data = get_jsonl_data(valid_path)
                valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True)
                logger.info(f"valid labels: {valid_data_dict.keys()}")

            if test_path:
                test_data = get_jsonl_data(test_path)
                test_data_dict = raw_data_to_labels_dict(test_data, shuffle=True)
                logger.info(f"test labels: {test_data_dict.keys()}")

        train_metrics = collections.defaultdict(list)
        n_eval_since_last_best = 0
        best_valid_acc = 0.0

        for step in range(max_iter):
            if not arsc_format:
                _, loss_dict = matching_net.train_step(
                    optimizer=optimizer,
                    data_dict=train_data_dict,
                    n_support=n_support,
                    n_query=n_query,
                    n_classes=n_classes
                )
            else:
                _, loss_dict = matching_net.train_step_ARSC(
                    data_path=data_path,
                    optimizer=optimizer
                )
            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging: report the mean of each metric over the window, then reset it
            if (step + 1) % log_every == 0:
                for key, value in train_metrics.items():
                    train_writer.add_scalar(tag=key, scalar_value=np.mean(value), global_step=step)
                logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()]))
                log_dict["train"].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": np.mean(value)
                        }
                        for key, value in train_metrics.items()
                    ],
                    "global_step": step
                })

                train_metrics = collections.defaultdict(list)

            if (valid_path or test_path) and (step + 1) % evaluate_every == 0:
                for path, writer, set_type, set_data in zip(
                        [valid_path, test_path],
                        [valid_writer, test_writer],
                        ["valid", "test"],
                        [valid_data_dict, test_data_dict]
                ):
                    if not path:
                        continue

                    if not arsc_format:
                        set_results = matching_net.test_step(
                            data_dict=set_data,
                            n_support=n_support,
                            n_query=n_query,
                            n_classes=n_classes,
                            n_episodes=n_test_episodes
                        )
                    else:
                        set_results = matching_net.test_step_ARSC(
                            data_path=data_path,
                            n_episodes=n_test_episodes,
                            set_type={"valid": "dev", "test": "test"}[set_type]
                        )

                    for key, val in set_results.items():
                        writer.add_scalar(tag=key, scalar_value=val, global_step=step)
                    log_dict[set_type].append({
                        "metrics": [
                            {
                                "tag": key,
                                "value": val
                            }
                            for key, val in set_results.items()
                        ],
                        "global_step": step
                    })

                    logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))
                    if set_type == "valid":
                        if set_results["acc"] > best_valid_acc:
                            best_valid_acc = set_results["acc"]
                            n_eval_since_last_best = 0
                            logger.info(f"Better eval results!")
                        else:
                            n_eval_since_last_best += 1
                            logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

                # Checked once every evaluated set has been logged, and at this
                # nesting level on purpose: the `break` must leave the training
                # loop, not merely the loop over evaluation sets.
                if early_stop and n_eval_since_last_best >= early_stop:
                    logger.warning(f"Early-stopping.")
                    break

        with open(os.path.join(output_path, "metrics.json"), "w") as file:
            json.dump(log_dict, file, ensure_ascii=False)
    finally:
        # Flush and release the event files even when training fails.
        for writer in (train_writer, valid_writer, test_writer):
            if writer is not None:
                writer.close()
each evaluation (on both valid, test)") 354 | parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging") 355 | parser.add_argument("--seed", type=int, default=42, help="Random seed to set") 356 | parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled") 357 | 358 | # Few-Shot related stuff 359 | parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class") 360 | parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class") 361 | parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode") 362 | parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)") 363 | 364 | # Which metric to use 365 | parser.add_argument("--metric", type=str, default="cosine", help="Metric to use", choices=("euclidean", "cosine")) 366 | 367 | # ARSC data 368 | parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format") 369 | args = parser.parse_args() 370 | 371 | # Set random seed 372 | set_seeds(args.seed) 373 | 374 | # Check if data path(s) exist 375 | for arg in [args.train_path, args.valid_path, args.test_path]: 376 | if arg and not os.path.exists(arg): 377 | raise FileNotFoundError(f"Data @ {arg} not found.") 378 | 379 | # Run 380 | run_matching( 381 | train_path=args.train_path, 382 | valid_path=args.valid_path, 383 | test_path=args.test_path, 384 | output_path=args.output_path, 385 | 386 | model_name_or_path=args.model_name_or_path, 387 | 388 | n_support=args.n_support, 389 | n_query=args.n_query, 390 | n_classes=args.n_classes, 391 | n_test_episodes=args.n_test_episodes, 392 | 393 | max_iter=args.max_iter, 394 | evaluate_every=args.evaluate_every, 395 | metric=args.metric, 396 | early_stop=args.early_stop, 397 | arsc_format=args.arsc_format, 
398 | data_path=args.data_path, 399 | log_every=args.log_every 400 | ) 401 | 402 | # Save config 403 | with open(os.path.join(args.output_path, "config.json"), "w") as file: 404 | json.dump(vars(args), file, ensure_ascii=False) 405 | 406 | 407 | if __name__ == "__main__": 408 | main() 409 | -------------------------------------------------------------------------------- /models/relation/relationnet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from utils.data import get_jsonl_data 4 | from utils.python import now, set_seeds 5 | import random 6 | import collections 7 | import os 8 | from typing import List, Dict 9 | from tensorboardX import SummaryWriter 10 | import numpy as np 11 | from models.encoders.bert_encoder import BERTEncoder 12 | import torch 13 | import torch.nn as nn 14 | import warnings 15 | import logging 16 | from utils.few_shot import create_episode, create_ARSC_train_episode, create_ARSC_test_episode 17 | 18 | logging.basicConfig() 19 | logger = logging.getLogger(__name__) 20 | logger.setLevel(logging.DEBUG) 21 | 22 | warnings.simplefilter('ignore') 23 | 24 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 25 | 26 | 27 | class RelationNet(nn.Module): 28 | def __init__(self, encoder, hidden_dim: int = 768, relation_module_type: str = "base", ntl_n_slices: int = 100): 29 | super(RelationNet, self).__init__() 30 | 31 | self.encoder = encoder 32 | self.relation_module_type = relation_module_type 33 | self.ntl_n_slices = ntl_n_slices 34 | self.hidden_dim = hidden_dim 35 | 36 | # Declare relation module 37 | if self.relation_module_type == "base": 38 | self.relation_module = RelationModule(input_dim=hidden_dim).to(device) 39 | elif self.relation_module_type == "ntl": 40 | self.relation_module = NTLRelationModule(input_dim=hidden_dim, n_slice=self.ntl_n_slices).to(device) 41 | else: 42 | raise NotImplementedError(f"relation module type 
{self.relation_module_type} not implemented.") 43 | 44 | def loss(self, sample): 45 | """ 46 | :param sample: { 47 | "xs": [ 48 | [support_A_1, support_A_2, ...], 49 | [support_B_1, support_B_2, ...], 50 | [support_C_1, support_C_2, ...], 51 | ... 52 | ], 53 | "xq": [ 54 | [query_A_1, query_A_2, ...], 55 | [query_B_1, query_B_2, ...], 56 | [query_C_1, query_C_2, ...], 57 | ... 58 | ] 59 | } 60 | :return: 61 | """ 62 | xs = sample["xs"] # support 63 | xq = sample["xq"] # query 64 | 65 | n_class = len(xs) 66 | assert len(xq) == n_class 67 | n_support = len(xs[0]) 68 | n_query = len(xq[0]) 69 | 70 | x = [item for xs_ in xs for item in xs_] + [item for xq_ in xq for item in xq_] 71 | z = self.encoder.forward(x) 72 | z_dim = z.size(-1) 73 | 74 | z_query = z[n_class * n_support:] 75 | z_proto = z[:n_class * n_support].view(n_class, n_support, z_dim).mean(1) 76 | 77 | relation_module_scores = self.relation_module.forward(z_q=z_query, z_c=z_proto) 78 | true_labels = torch.zeros_like(relation_module_scores).to(device) 79 | 80 | for ix_class, class_query_sentences in enumerate(xq): 81 | for ix_sentence, sentence in enumerate(class_query_sentences): 82 | true_labels[ix_class * n_query + ix_sentence, ix_class] = 1 83 | 84 | # MSE LOSS 85 | # relation_module_scores = torch.sigmoid(relation_module_scores) 86 | # loss_fn = nn.MSELoss() 87 | # loss_val = loss_fn(relation_module_scores, true_labels) 88 | # acc_full = ((relation_module_scores > 0.5).float() == true_labels.float()).float().mean() 89 | # acc_exact = (((relation_module_scores > 0.5).float() - true_labels.float()).abs().max(dim=1)[0] == 0).float().mean() 90 | # acc_max = (relation_module_scores.argmax(1) == true_labels.argmax(1)).float().mean() 91 | # 92 | # return loss_val, { 93 | # "loss": loss_val.item(), 94 | # "metrics": { 95 | # "loss": loss_val.item(), 96 | # "acc_full": acc_full.item(), 97 | # "acc_exact": acc_exact.item(), 98 | # "acc_max": acc_max.item(), 99 | # "acc": acc_max.item() 100 | # }, 101 | # 
"y_hat": relation_module_scores.argmax(1).cpu().detach().numpy() 102 | # } 103 | 104 | # CE LOSS 105 | loss_fn = nn.CrossEntropyLoss() 106 | loss_val = loss_fn(relation_module_scores, true_labels.argmax(1)) 107 | acc_val = (true_labels.argmax(1) == relation_module_scores.argmax(1)).float().mean() 108 | return loss_val, { 109 | "loss": loss_val.item(), 110 | "metrics": { 111 | "loss": loss_val.item(), 112 | "acc": acc_val.item() 113 | }, 114 | "y_hat": relation_module_scores.argmax(1).cpu().detach().numpy() 115 | } 116 | 117 | def train_step(self, optimizer, data_dict: Dict[str, List[str]], n_support, n_classes, n_query): 118 | 119 | episode = create_episode( 120 | data_dict=data_dict, 121 | n_support=n_support, 122 | n_classes=n_classes, 123 | n_query=n_query 124 | ) 125 | 126 | self.train() 127 | optimizer.zero_grad() 128 | torch.cuda.empty_cache() 129 | loss, loss_dict = self.loss(episode) 130 | loss.backward() 131 | optimizer.step() 132 | 133 | return loss, loss_dict 134 | 135 | def test_step(self, data_dict, n_support, n_classes, n_query, n_episodes=1000): 136 | metrics = collections.defaultdict(list) 137 | self.eval() 138 | for i in range(n_episodes): 139 | episode = create_episode( 140 | data_dict=data_dict, 141 | n_support=n_support, 142 | n_classes=n_classes, 143 | n_query=n_query 144 | ) 145 | 146 | with torch.no_grad(): 147 | loss, loss_dict = self.loss(episode) 148 | 149 | for key, value in loss_dict["metrics"].items(): 150 | metrics[key].append(value) 151 | 152 | return { 153 | key: np.mean(value) for key, value in metrics.items() 154 | } 155 | 156 | def train_step_ARSC(self, data_path: str, optimizer): 157 | episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=5) 158 | 159 | self.train() 160 | optimizer.zero_grad() 161 | torch.cuda.empty_cache() 162 | loss, loss_dict = self.loss(episode) 163 | loss.backward() 164 | optimizer.step() 165 | 166 | return loss, loss_dict 167 | 168 | def test_step_ARSC(self, data_path: str, 
class RelationModule(nn.Module):
    """
    Relation scorer: a two-layer feed-forward network applied to every
    concatenated (query, class) embedding pair.
    """

    def __init__(self, input_dim):
        """
        :param input_dim: size of one embedding; the first layer takes ``2 * input_dim``
        """
        super(RelationModule, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=input_dim * 2, out_features=input_dim),
            nn.ReLU(),
            nn.Dropout(p=0.25)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=1)
        )

    def forward(self, z_q, z_c):
        """
        :param z_q: query embeddings, (n_query, input_dim)
        :param z_c: class embeddings, (n_class, input_dim)
        :return: scores of shape (n_query, n_class)
        """
        n_class = z_c.size(0)
        n_query = z_q.size(0)
        # Row i * n_class + j of `pairs` is [z_q[i], z_c[j]]: every query is
        # repeated n_class times in a row while the class block is tiled.
        pairs = torch.cat((
            z_q.repeat(1, n_class).view(-1, z_q.size(-1)),
            z_c.repeat(n_query, 1)
        ), dim=1)

        return self.fc2(self.fc1(pairs)).view(n_query, n_class)


class NTLRelationModule(nn.Module):
    """
    Bilinear relation scorer: ``n_slice`` learned matrices M_s give
    ``z_q @ M_s @ z_c.T`` for every (query, class) pair; the slices then go
    through ReLU, dropout and a linear layer that outputs one logit per pair.
    """

    def __init__(self, input_dim, n_slice=100):
        """
        :param input_dim: embedding size; each slice is (input_dim, input_dim)
        :param n_slice: number of bilinear slices
        """
        super(NTLRelationModule, self).__init__()
        self.n_slice = n_slice
        # Gaussian initialisation (numpy RNG), each slice scaled to unit Frobenius norm.
        M = np.random.randn(n_slice, input_dim, input_dim)
        M = M / np.linalg.norm(M, axis=(1, 2))[:, None, None]
        self.register_parameter("M", nn.Parameter(torch.Tensor(M)))
        self.dropout = nn.Dropout(p=0.25)
        self.fc = nn.Linear(n_slice, 1)

    def forward(self, z_q, z_c):
        """
        :param z_q: query embeddings, (n_query, input_dim)
        :param z_c: class embeddings, (n_class, input_dim)
        :return: logits of shape (n_query, n_class)
        """
        n_query = z_q.size(0)
        n_class = z_c.size(0)

        # All slices in one contraction instead of a Python loop of n_slice
        # matmuls: bilinear[q, c, s] = z_q[q] @ M[s] @ z_c[c].
        bilinear = torch.einsum("qi,sij,cj->qcs", z_q, self.M, z_c)
        v = self.dropout(torch.relu(bilinear.reshape(-1, self.n_slice)))
        r_logit = self.fc(v).view(n_query, n_class)
        return r_logit
str, 230 | model_name_or_path: str, 231 | n_support: int, 232 | n_query: int, 233 | n_classes: int, 234 | valid_path: str = None, 235 | test_path: str = None, 236 | output_path: str = f"runs/{now()}", 237 | max_iter: int = 10000, 238 | evaluate_every: int = 100, 239 | early_stop: int = None, 240 | n_test_episodes: int = 1000, 241 | log_every: int = 10, 242 | relation_module_type: str = "base", 243 | ntl_n_slices: int = 100, 244 | arsc_format: bool = False, 245 | data_path: str = None 246 | ): 247 | if output_path: 248 | if os.path.exists(output_path) and len(os.listdir(output_path)): 249 | raise FileExistsError(f"Output path {output_path} already exists. Exiting.") 250 | 251 | # -------------------- 252 | # Creating Log Writers 253 | # -------------------- 254 | os.makedirs(output_path) 255 | os.makedirs(os.path.join(output_path, "logs/train")) 256 | train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1) 257 | valid_writer: SummaryWriter = None 258 | test_writer: SummaryWriter = None 259 | log_dict = dict(train=list()) 260 | 261 | if valid_path: 262 | os.makedirs(os.path.join(output_path, "logs/valid")) 263 | valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1) 264 | log_dict["valid"] = list() 265 | if test_path: 266 | os.makedirs(os.path.join(output_path, "logs/test")) 267 | test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1) 268 | log_dict["test"] = list() 269 | 270 | def raw_data_to_labels_dict(data, shuffle=True): 271 | labels_dict = collections.defaultdict(list) 272 | for item in data: 273 | labels_dict[item["label"]].append(item["sentence"]) 274 | labels_dict = dict(labels_dict) 275 | if shuffle: 276 | for key, val in labels_dict.items(): 277 | random.shuffle(val) 278 | return labels_dict 279 | 280 | # Load model 281 | bert = BERTEncoder(model_name_or_path).to(device) 282 | matching_net = 
RelationNet(encoder=bert, relation_module_type=relation_module_type, ntl_n_slices=ntl_n_slices) 283 | optimizer = torch.optim.Adam(matching_net.parameters(), lr=2e-5) 284 | 285 | # Load data 286 | if not arsc_format: 287 | train_data = get_jsonl_data(train_path) 288 | train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True) 289 | logger.info(f"train labels: {train_data_dict.keys()}") 290 | 291 | if valid_path: 292 | valid_data = get_jsonl_data(valid_path) 293 | valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True) 294 | logger.info(f"valid labels: {valid_data_dict.keys()}") 295 | else: 296 | valid_data_dict = None 297 | 298 | if test_path: 299 | test_data = get_jsonl_data(test_path) 300 | test_data_dict = raw_data_to_labels_dict(test_data, shuffle=True) 301 | logger.info(f"test labels: {test_data_dict.keys()}") 302 | else: 303 | test_data_dict = None 304 | else: 305 | train_data_dict = None 306 | test_data_dict = None 307 | valid_data_dict = None 308 | 309 | train_metrics = collections.defaultdict(list) 310 | n_eval_since_last_best = 0 311 | best_valid_acc = 0.0 312 | 313 | for step in range(max_iter): 314 | if not arsc_format: 315 | loss, loss_dict = matching_net.train_step( 316 | optimizer=optimizer, 317 | data_dict=train_data_dict, 318 | n_support=n_support, 319 | n_query=n_query, 320 | n_classes=n_classes 321 | ) 322 | else: 323 | loss, loss_dict = matching_net.train_step_ARSC(optimizer=optimizer, data_path=data_path) 324 | 325 | for key, value in loss_dict["metrics"].items(): 326 | train_metrics[key].append(value) 327 | 328 | # Logging 329 | if (step + 1) % log_every == 0: 330 | for key, value in train_metrics.items(): 331 | train_writer.add_scalar(tag=key, scalar_value=np.mean(value), global_step=step) 332 | 333 | logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()])) 334 | log_dict["train"].append({ 335 | "metrics": [ 336 | { 337 | "tag": key, 338 | "value": np.mean(value) 339 | 
def main():
    """Command-line entry point for training / evaluating a Relation Network.

    Parses the arguments, seeds the random number generators, checks that the
    provided train/valid/test paths exist, calls ``run_relation`` with the
    parsed options and finally writes the parsed arguments to
    ``<output-path>/config.json``.

    :raises FileNotFoundError: if a provided train/valid/test path does not exist
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, default=None, help="Path to training data")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f"runs/{now()}")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Relation Network-specific
    parser.add_argument("--relation-module-type", type=str, required=True, help="Which relation module to use")
    parser.add_argument("--ntl-n-slices", type=int, default=100, help="Number of matrices to use in NTL")
    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")

    args = parser.parse_args()
    logger.debug(f"Received args {args}")

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist
    for arg in [args.train_path, args.valid_path, args.test_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Run
    run_relation(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        output_path=args.output_path,

        model_name_or_path=args.model_name_or_path,

        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,

        max_iter=args.max_iter,
        evaluate_every=args.evaluate_every,
        # Previously parsed but never forwarded, so --log-every had no effect.
        log_every=args.log_every,

        relation_module_type=args.relation_module_type,
        ntl_n_slices=args.ntl_n_slices,

        early_stop=args.early_stop,
        arsc_format=args.arsc_format,
        data_path=args.data_path
    )

    # Save config
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False)
48 | ], 49 | "xq": [ 50 | [query_A_1, query_A_2, ...], 51 | [query_B_1, query_B_2, ...], 52 | [query_C_1, query_C_2, ...], 53 | ... 54 | ] 55 | } 56 | :return: 57 | """ 58 | xs = sample["xs"] # support 59 | xq = sample["xq"] # query 60 | 61 | n_class = len(xs) 62 | assert len(xq) == n_class 63 | n_support = len(xs[0]) 64 | n_query = len(xq[0]) 65 | 66 | x = [item for xs_ in xs for item in xs_] + [item for xq_ in xq for item in xq_] 67 | z = self.encoder.forward(x) 68 | z_dim = z.size(-1) 69 | 70 | z_query = z[n_class * n_support:] 71 | z_support = z[:n_class * n_support].view(n_class, n_support, z_dim) 72 | 73 | class_representatives = self.induction_module.forward(z_s=z_support) 74 | relation_module_scores = self.relation_module.forward(z_q=z_query, z_c=class_representatives) 75 | true_labels = torch.zeros_like(relation_module_scores).to(device) 76 | 77 | for ix_class, class_query_sentences in enumerate(xq): 78 | for ix_sentence, sentence in enumerate(class_query_sentences): 79 | true_labels[ix_class * n_query + ix_sentence, ix_class] = 1 80 | 81 | # MSE LOSS 82 | # relation_module_scores = torch.sigmoid(relation_module_scores) 83 | # loss_fn = nn.MSELoss() 84 | # loss_val = loss_fn(relation_module_scores, true_labels) 85 | # acc_full = ((relation_module_scores > 0.5).float() == true_labels.float()).float().mean() 86 | # acc_exact = (((relation_module_scores > 0.5).float() - true_labels.float()).abs().max(dim=1)[0] == 0).float().mean() 87 | # acc_max = (relation_module_scores.argmax(1) == true_labels.argmax(1)).float().mean() 88 | # 89 | # return loss_val, { 90 | # "loss": loss_val.item(), 91 | # "metrics": { 92 | # "loss": loss_val.item(), 93 | # "acc_full": acc_full.item(), 94 | # "acc_exact": acc_exact.item(), 95 | # "acc_max": acc_max.item(), 96 | # "acc": acc_max.item() 97 | # }, 98 | # "y_hat": relation_module_scores.argmax(1).cpu().detach().numpy() 99 | # } 100 | 101 | # CE LOSS 102 | loss_fn = nn.CrossEntropyLoss() 103 | loss_val = 
loss_fn(relation_module_scores, true_labels.argmax(1)) 104 | acc_val = (true_labels.argmax(1) == relation_module_scores.argmax(1)).float().mean() 105 | return loss_val, { 106 | "loss": loss_val.item(), 107 | "metrics": { 108 | "loss": loss_val.item(), 109 | "acc": acc_val.item() 110 | }, 111 | "y_hat": relation_module_scores.argmax(1).cpu().detach().numpy() 112 | } 113 | 114 | def train_step(self, optimizer, data_dict: Dict[str, List[str]], n_support, n_classes, n_query): 115 | 116 | episode = create_episode( 117 | data_dict=data_dict, 118 | n_support=n_support, 119 | n_classes=n_classes, 120 | n_query=n_query 121 | ) 122 | 123 | self.train() 124 | optimizer.zero_grad() 125 | torch.cuda.empty_cache() 126 | loss, loss_dict = self.loss(episode) 127 | loss.backward() 128 | optimizer.step() 129 | 130 | return loss, loss_dict 131 | 132 | def test_step(self, data_dict, n_support, n_classes, n_query, n_episodes=1000): 133 | metrics = collections.defaultdict(list) 134 | self.eval() 135 | for i in range(n_episodes): 136 | episode = create_episode( 137 | data_dict=data_dict, 138 | n_support=n_support, 139 | n_classes=n_classes, 140 | n_query=n_query 141 | ) 142 | 143 | with torch.no_grad(): 144 | loss, loss_dict = self.loss(episode) 145 | 146 | for key, value in loss_dict["metrics"].items(): 147 | metrics[key].append(value) 148 | 149 | return { 150 | key: np.mean(value) for key, value in metrics.items() 151 | } 152 | 153 | def train_step_ARSC(self, data_path: str, optimizer): 154 | episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=5) 155 | 156 | self.train() 157 | optimizer.zero_grad() 158 | torch.cuda.empty_cache() 159 | loss, loss_dict = self.loss(episode) 160 | loss.backward() 161 | optimizer.step() 162 | 163 | return loss, loss_dict 164 | 165 | def test_step_ARSC(self, data_path: str, n_episodes=1000, set_type="test"): 166 | assert set_type in ("dev", "test") 167 | metrics = collections.defaultdict(list) 168 | self.eval() 169 | for i in 
class InductionModule(nn.Module):
    """Dynamic-routing induction module turning support embeddings into one vector per class."""

    def __init__(self, input_dim: int, n_routing_iter: int = 3):
        """
        :param input_dim: dimension of the support embeddings
        :param n_routing_iter: number of routing iterations (must be >= 1)
        :raises ValueError: if ``n_routing_iter`` is lower than 1
        """
        super(InductionModule, self).__init__()
        if n_routing_iter < 1:
            # With zero iterations no class vector would ever be produced in forward().
            raise ValueError(f"n_routing_iter must be >= 1, got {n_routing_iter}")
        self.input_dim: int = input_dim
        self.n_routing_iter: int = n_routing_iter

        # Init Ws, bs
        self.Ws_bs = nn.Linear(input_dim, input_dim)

    @staticmethod
    def squash(x):
        """
        Row-wise squashing: each row keeps its direction and gets norm ||x||^2 / (1 + ||x||^2).

        :param x: 2-D tensor, one vector per row
        :return: tensor of the same shape as ``x``
        """
        norm = x.norm(dim=1, keepdim=True)
        # (x / ||x||) * ||x||^2 / (1 + ||x||^2) simplified to avoid the division by ||x||,
        # which produced NaN for all-zero rows.
        return x * (norm / (1 + norm ** 2))

    def forward(self, z_s):
        """
        :param z_s: embedding of support samples, shape=(C, K, hidden_dim)
        :return: class representatives, shape=(C, hidden_dim)
        """
        C, K, hidden_dim = z_s.size()
        class_representatives = list()

        for i in range(C):
            z_squashed = self.squash(self.Ws_bs(z_s[i]))
            # Routing logits start at zero; allocated on the same device / dtype as the
            # embeddings so the module does not depend on a module-level device.
            b_i = z_squashed.new_zeros(K)
            for iteration in range(self.n_routing_iter):
                d_i = b_i.softmax(dim=-1)
                c_i = torch.matmul(d_i, z_squashed)
                c_i = self.squash(c_i.view(1, -1))
                # Out-of-place update: keeps the softmax input of this iteration intact for autograd.
                b_i = b_i + (z_squashed @ c_i.view(-1, 1)).view(-1)

            class_representatives.append(c_i)
        return torch.cat(class_representatives)
nn.Dropout(p=0.25) 226 | self.fc = nn.Linear(n_slice, 1) 227 | 228 | def forward(self, z_q, z_c): 229 | n_query = z_q.size(0) 230 | n_class = z_c.size(0) 231 | 232 | v = self.dropout(nn.ReLU()(torch.cat([(z_q @ m @ z_c.T).unsqueeze(-1) for m in self.M], dim=-1).view(-1, self.n_slice))) 233 | r_logit = self.fc(v).view(n_query, n_class) 234 | return r_logit 235 | 236 | 237 | def run_induction( 238 | train_path: str, 239 | model_name_or_path: str, 240 | n_support: int, 241 | n_query: int, 242 | n_classes: int, 243 | valid_path: str = None, 244 | test_path: str = None, 245 | output_path: str = f"runs/{now()}", 246 | max_iter: int = 10000, 247 | evaluate_every: int = 100, 248 | early_stop: int = None, 249 | n_test_episodes: int = 1000, 250 | log_every: int = 10, 251 | ntl_n_slices: int = 100, 252 | n_routing_iter: int = 3, 253 | arsc_format: bool = False, 254 | data_path: str = None 255 | ): 256 | if output_path: 257 | if os.path.exists(output_path) and len(os.listdir(output_path)): 258 | raise FileExistsError(f"Output path {output_path} already exists. 
Exiting.") 259 | 260 | # -------------------- 261 | # Creating Log Writers 262 | # -------------------- 263 | os.makedirs(output_path) 264 | os.makedirs(os.path.join(output_path, "logs/train")) 265 | train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1) 266 | valid_writer: SummaryWriter = None 267 | test_writer: SummaryWriter = None 268 | log_dict = dict(train=list()) 269 | 270 | if valid_path: 271 | os.makedirs(os.path.join(output_path, "logs/valid")) 272 | valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1) 273 | log_dict["valid"] = list() 274 | if test_path: 275 | os.makedirs(os.path.join(output_path, "logs/test")) 276 | test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1) 277 | log_dict["test"] = list() 278 | 279 | def raw_data_to_labels_dict(data, shuffle=True): 280 | labels_dict = collections.defaultdict(list) 281 | for item in data: 282 | labels_dict[item["label"]].append(item["sentence"]) 283 | labels_dict = dict(labels_dict) 284 | if shuffle: 285 | for key, val in labels_dict.items(): 286 | random.shuffle(val) 287 | return labels_dict 288 | 289 | # Load model 290 | bert = BERTEncoder(model_name_or_path).to(device) 291 | induction_net = InductionNet( 292 | encoder=bert, 293 | ntl_n_slices=ntl_n_slices, 294 | n_routing_iter=n_routing_iter 295 | ) 296 | optimizer = torch.optim.Adam(induction_net.parameters(), lr=2e-5) 297 | 298 | # Load data 299 | if not arsc_format: 300 | train_data = get_jsonl_data(train_path) 301 | train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True) 302 | logger.info(f"train labels: {train_data_dict.keys()}") 303 | 304 | if valid_path: 305 | valid_data = get_jsonl_data(valid_path) 306 | valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True) 307 | logger.info(f"valid labels: {valid_data_dict.keys()}") 308 | else: 309 | valid_data_dict = None 310 | 
311 | if test_path: 312 | test_data = get_jsonl_data(test_path) 313 | test_data_dict = raw_data_to_labels_dict(test_data, shuffle=True) 314 | logger.info(f"test labels: {test_data_dict.keys()}") 315 | else: 316 | test_data_dict = None 317 | else: 318 | train_data_dict = None 319 | test_data_dict = None 320 | valid_data_dict = None 321 | 322 | train_metrics = collections.defaultdict(list) 323 | n_eval_since_last_best = 0 324 | best_valid_acc = 0.0 325 | 326 | for step in range(max_iter): 327 | if not arsc_format: 328 | loss, loss_dict = induction_net.train_step( 329 | optimizer=optimizer, 330 | data_dict=train_data_dict, 331 | n_support=n_support, 332 | n_query=n_query, 333 | n_classes=n_classes 334 | ) 335 | else: 336 | loss, loss_dict = induction_net.train_step_ARSC( 337 | optimizer=optimizer, 338 | data_path=data_path 339 | ) 340 | 341 | for key, value in loss_dict["metrics"].items(): 342 | train_metrics[key].append(value) 343 | 344 | # Logging 345 | if (step + 1) % log_every == 0: 346 | for key, value in train_metrics.items(): 347 | train_writer.add_scalar(tag=key, scalar_value=np.mean(value), global_step=step) 348 | 349 | logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()])) 350 | log_dict["train"].append({ 351 | "metrics": [ 352 | { 353 | "tag": key, 354 | "value": np.mean(value) 355 | } 356 | for key, value in train_metrics.items() 357 | ], 358 | "global_step": step 359 | }) 360 | 361 | train_metrics = collections.defaultdict(list) 362 | 363 | if valid_path or test_path or data_path: 364 | if (step + 1) % evaluate_every == 0: 365 | for path, writer, set_type, set_data in zip( 366 | [valid_path, test_path], 367 | [valid_writer, test_writer], 368 | ["valid", "test"], 369 | [valid_data_dict, test_data_dict] 370 | ): 371 | if path: 372 | if not arsc_format: 373 | set_results = induction_net.test_step( 374 | data_dict=set_data, 375 | n_support=n_support, 376 | n_query=n_query, 377 | n_classes=n_classes, 378 
def main():
    """Command-line entry point for training / evaluating an Induction Network.

    Parses the arguments, seeds the random number generators, checks that the
    provided train/valid/test paths exist, calls ``run_induction`` with the
    parsed options and finally writes the parsed arguments to
    ``<output-path>/config.json``.

    :raises FileNotFoundError: if a provided train/valid/test path does not exist
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, default=None, help="Path to training data")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f"runs/{now()}")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Induction Network-specific
    parser.add_argument("--ntl-n-slices", type=int, default=100, help="Number of matrices to use in NTL")
    parser.add_argument("--n-routing-iter", type=int, default=3, help="Number of routing iterations in the induction module")

    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")
    args = parser.parse_args()

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist
    for arg in [args.train_path, args.valid_path, args.test_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Run
    run_induction(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        output_path=args.output_path,

        model_name_or_path=args.model_name_or_path,

        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,

        max_iter=args.max_iter,
        evaluate_every=args.evaluate_every,
        # Previously parsed but never forwarded, so --log-every had no effect.
        log_every=args.log_every,
        n_routing_iter=args.n_routing_iter,
        ntl_n_slices=args.ntl_n_slices,
        early_stop=args.early_stop,
        arsc_format=args.arsc_format,
        data_path=args.data_path
    )

    # Save config
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False)
class ProtAugmentNet(nn.Module):
    """Prototypical network with an optional unsupervised loss computed on paraphrase augmentations."""

    def __init__(self, encoder, metric="euclidean"):
        """
        :param encoder: sentence encoder exposing ``embed_sentences(list_of_str)``
        :param metric: 'euclidean' or 'cosine'
        """
        super(ProtAugmentNet, self).__init__()

        self.encoder: BERTEncoder = encoder
        self.metric = metric
        assert self.metric in ('euclidean', 'cosine')

    def loss(self, sample, supervised_loss_share: float = 0):
        """
        :param supervised_loss_share: share of supervised loss in total loss
            (only used when the sample contains augmentations)
        :param sample: {
            "xs": [
                [support_A_1, support_A_2, ...],
                [support_B_1, support_B_2, ...],
                [support_C_1, support_C_2, ...],
                ...
            ],
            "xq": [
                [query_A_1, query_A_2, ...],
                [query_B_1, query_B_2, ...],
                [query_C_1, query_C_2, ...],
                ...
            ],
            "x_augment": [  # optional
                {"src_text": ..., "tgt_texts": [...]},
                ...
            ]
        }
        Each support / query item is a dict holding its text under the "sentence" key.
        :return: (loss, dict with "metrics", distances and "target")
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        # Class index of every query, shape (n_class, n_query, 1); flattened when used as a target.
        target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long().to(device)

        supports = [item["sentence"] for xs_ in xs for item in xs_]
        queries = [item["sentence"] for xq_ in xq for item in xq_]
        n_labeled = len(supports) + len(queries)

        has_augment = "x_augment" in sample
        if has_augment:
            augmentations = sample["x_augment"]

            n_augmentations_samples = len(augmentations)
            n_augmentations_per_sample = [len(item['tgt_texts']) for item in augmentations]
            # Every source text must come with the same number of paraphrases.
            assert len(set(n_augmentations_per_sample)) == 1
            n_augmentations_per_sample = n_augmentations_per_sample[0]

            augmentations_supports = [text for item in augmentations for text in item["tgt_texts"]]
            augmentation_queries = [item["src_text"] for item in augmentations]

            # Encode everything in a single pass
            x = supports + queries + augmentations_supports + augmentation_queries
        else:
            # When not using augmentations
            x = supports + queries

        z = self.encoder.embed_sentences(x)
        z_dim = z.size(-1)

        # Dispatch: prototypes are the mean of each class' support embeddings
        z_support = z[:len(supports)].view(n_class, n_support, z_dim).mean(dim=[1])
        z_query = z[len(supports):n_labeled]
        if has_augment:
            aug_end = n_labeled + n_augmentations_per_sample * n_augmentations_samples
            # One "prototype" per source text: the mean of its paraphrases
            z_aug_support = (z[n_labeled:aug_end]
                             .view(n_augmentations_samples, n_augmentations_per_sample, z_dim).mean(dim=[1]))
            z_aug_query = z[aug_end:]

        if self.metric == "euclidean":
            supervised_dists = euclidean_dist(z_query, z_support)
            if has_augment:
                unsupervised_dists = euclidean_dist(z_aug_query, z_aug_support)
        elif self.metric == "cosine":
            supervised_dists = (-cosine_similarity(z_query, z_support) + 1) * 5
            if has_augment:
                unsupervised_dists = (-cosine_similarity(z_aug_query, z_aug_support) + 1) * 5
        else:
            raise NotImplementedError

        loss_fn = nn.CrossEntropyLoss()
        flat_targets = target_inds.reshape(-1)
        supervised_loss = loss_fn(-supervised_dists, flat_targets)
        _, y_hat_supervised = (-supervised_dists).max(1)
        acc_val_supervised = torch.eq(y_hat_supervised, flat_targets).float().mean()

        if has_augment:
            # Unsupervised loss: the i-th source text must be closest to the i-th paraphrase prototype.
            # torch.arange replaces the deprecated torch.range (same values: 0 .. n - 1).
            unsupervised_target_inds = torch.arange(n_augmentations_samples, device=device).long()
            unsupervised_loss = loss_fn(-unsupervised_dists, unsupervised_target_inds)
            _, y_hat_unsupervised = (-unsupervised_dists).max(1)
            acc_val_unsupervised = torch.eq(y_hat_unsupervised, unsupervised_target_inds.reshape(-1)).float().mean()

            # Final loss
            assert 0 <= supervised_loss_share <= 1
            final_loss = (supervised_loss_share) * supervised_loss + (1 - supervised_loss_share) * unsupervised_loss

            return final_loss, {
                "metrics": {
                    "supervised_acc": acc_val_supervised.item(),
                    "unsupervised_acc": acc_val_unsupervised.item(),
                    "supervised_loss": supervised_loss.item(),
                    "unsupervised_loss": unsupervised_loss.item(),
                    "supervised_loss_share": supervised_loss_share,
                    "final_loss": final_loss.item(),
                },
                "supervised_dists": supervised_dists,
                "unsupervised_dists": unsupervised_dists,
                "target": target_inds
            }

        return supervised_loss, {
            "metrics": {
                "acc": acc_val_supervised.item(),
                "loss": supervised_loss.item(),
            },
            "dists": supervised_dists,
            "target": target_inds
        }

    def train_step(self, optimizer, episode, supervised_loss_share: float):
        """Run one optimization step on ``episode`` and return ``(loss, loss_dict)``."""
        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self.loss(episode, supervised_loss_share=supervised_loss_share)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step(self, dataset: FewShotDataset, n_episodes: int = 1000):
        """
        Evaluate on ``n_episodes`` episodes drawn with ``dataset.get_episode()``.

        :return: dict mapping each metric name to its mean over the episodes
        """
        metrics = collections.defaultdict(list)

        self.eval()
        for i in range(n_episodes):
            episode = dataset.get_episode()

            with torch.no_grad():
                loss, loss_dict = self.loss(episode, supervised_loss_share=1)

            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }
def run_proto(
        # Compulsory!
        data_path: str,
        train_labels_path: str,
        model_name_or_path: str,

        # Few-shot Stuff
        n_support: int,
        n_query: int,
        n_classes: int,
        metric: str = "euclidean",

        # Optional path to augmented data
        unlabeled_path: str = None,

        # Path training data ONLY (optional)
        train_path: str = None,

        # Validation & test
        valid_labels_path: str = None,
        test_labels_path: str = None,
        evaluate_every: int = 100,
        n_test_episodes: int = 1000,

        # Logging & Saving
        output_path: str = f'runs/{now()}',
        log_every: int = 10,

        # Training stuff
        max_iter: int = 10000,
        early_stop: int = None,

        # Augmentation & paraphrase
        n_augmentation: int = 5,
        paraphrase_model_name_or_path: str = None,
        paraphrase_tokenizer_name_or_path: str = None,
        paraphrase_num_beams: int = None,
        paraphrase_beam_group_size: int = None,
        paraphrase_diversity_penalty: float = None,
        paraphrase_filtering_strategy: str = None,
        paraphrase_drop_strategy: str = None,
        paraphrase_drop_chance_speed: str = None,
        paraphrase_drop_chance_auc: float = None,
        supervised_loss_share_fn: Callable[[int, int], float] = lambda x, y: 1 - (x / y),

        augmentation_data_path: str = None
):
    """
    Train a ``ProtAugmentNet`` episodically, evaluate it periodically and log everything under ``output_path``.

    TensorBoard scalars are written to ``logs/train``, ``logs/valid`` and ``logs/test`` inside
    ``output_path``; every logged value is also dumped to ``metrics.json`` once training ends.

    :param data_path: data used for the valid/test datasets, and for training when ``train_path`` is not given
    :param train_labels_path: labels file of the training dataset
    :param model_name_or_path: forwarded to ``BERTEncoder``
    :param n_support: forwarded to every dataset as ``n_support``
    :param n_query: forwarded to every dataset as ``n_query``
    :param n_classes: forwarded to every dataset as ``n_classes``
    :param metric: forwarded to ``ProtAugmentNet``
    :param unlabeled_path: forwarded as ``unlabeled_file_path`` to ``FewShotSSLParaphraseDataset``
    :param train_path: when given, training data is read from here instead of ``data_path``
    :param valid_labels_path: when given, a validation dataset is built and evaluated every ``evaluate_every`` steps
    :param test_labels_path: when given, a test dataset is built and evaluated every ``evaluate_every`` steps
    :param evaluate_every: number of training steps between two evaluations
    :param n_test_episodes: number of episodes per evaluation
    :param output_path: output directory; must not exist yet, or be empty
    :param log_every: number of training steps between two training logs
    :param max_iter: maximum number of training steps
    :param early_stop: stop once this many validation evaluations happened since the best accuracy (falsy = disabled)
    :param n_augmentation: forwarded as ``n_unlabeled`` to the training dataset
    :param paraphrase_model_name_or_path: forwarded to ``DBSParaphraseModel`` (unused with ``augmentation_data_path``)
    :param paraphrase_tokenizer_name_or_path: tokenizer loaded with ``AutoTokenizer`` and forwarded to the paraphraser
    :param paraphrase_num_beams: forwarded to ``DBSParaphraseModel``
    :param paraphrase_beam_group_size: forwarded to ``DBSParaphraseModel``
    :param paraphrase_diversity_penalty: forwarded to ``DBSParaphraseModel``
    :param paraphrase_filtering_strategy: forwarded to ``DBSParaphraseModel``
    :param paraphrase_drop_strategy: ``"unigram"`` or ``"bigram"`` select the matching batch preparer, anything else the base one
    :param paraphrase_drop_chance_speed: forwarded to the unigram batch preparer
    :param paraphrase_drop_chance_auc: forwarded to the unigram batch preparer
    :param supervised_loss_share_fn: called as ``fn(step, max_iter)`` before each training step
    :param augmentation_data_path: when given, pre-generated augmentations are read through ``FewShotSSLFileDataset``
    :raises FileExistsError: if ``output_path`` exists and is not empty
    """
    if output_path:
        if os.path.exists(output_path) and len(os.listdir(output_path)):
            raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok=True: the guard above accepts an existing *empty* directory, which a plain
    # os.makedirs() would reject with FileExistsError.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "logs/train"), exist_ok=True)
    train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer: SummaryWriter = None
    test_writer: SummaryWriter = None
    log_dict = dict(train=list())

    # ----------
    # Load model
    # ----------
    bert = BERTEncoder(model_name_or_path).to(device)
    protonet: ProtAugmentNet = ProtAugmentNet(encoder=bert, metric=metric)
    optimizer = torch.optim.Adam(protonet.parameters(), lr=2e-5)

    # ------------------
    # Load Train Dataset
    # ------------------
    if augmentation_data_path:
        # If an augmentation data path is provided, uses those pre-generated augmentations
        train_dataset = FewShotSSLFileDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=augmentation_data_path,
        )
    else:
        # ---------------------
        # Load paraphrase model
        # ---------------------
        # The paraphraser is kept on CPU for 20newsgroup data, and also whenever CUDA is not
        # available (previously "cuda" was requested unconditionally in that case).
        paraphrase_on_cpu = "20newsgroup" in data_path or not torch.cuda.is_available()
        paraphrase_model_device = torch.device("cpu") if paraphrase_on_cpu else torch.device("cuda")
        logger.info(f"Paraphrase model device: {paraphrase_model_device}")
        paraphrase_tokenizer = AutoTokenizer.from_pretrained(paraphrase_tokenizer_name_or_path)
        if paraphrase_drop_strategy == "unigram":
            paraphrase_batch_preparer = UnigramRandomDropParaphraseBatchPreparer(
                tokenizer=paraphrase_tokenizer,
                auc=paraphrase_drop_chance_auc,
                drop_chance_speed=paraphrase_drop_chance_speed,
                device=paraphrase_model_device
            )
        elif paraphrase_drop_strategy == "bigram":
            paraphrase_batch_preparer = BigramDropParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)
        else:
            paraphrase_batch_preparer = BaseParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)

        paraphrase_model = DBSParaphraseModel(
            model_name_or_path=paraphrase_model_name_or_path,
            tok_name_or_path=paraphrase_tokenizer_name_or_path,
            num_beams=paraphrase_num_beams,
            beam_group_size=paraphrase_beam_group_size,
            diversity_penalty=paraphrase_diversity_penalty,
            filtering_strategy=paraphrase_filtering_strategy,
            paraphrase_batch_preparer=paraphrase_batch_preparer,
            device=paraphrase_model_device
        )

        train_dataset = FewShotSSLParaphraseDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=unlabeled_path,
            paraphrase_model=paraphrase_model
        )
    logger.info(f"Train dataset has {len(train_dataset)} items")

    # ---------
    # Load data
    # ---------
    logger.info(f"train labels: {train_dataset.data.keys()}")
    valid_dataset: FewShotDataset = None
    if valid_labels_path:
        os.makedirs(os.path.join(output_path, "logs/valid"), exist_ok=True)
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
        valid_dataset = FewShotDataset(data_path=data_path, labels_path=valid_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"valid labels: {valid_dataset.data.keys()}")
        # Train and validation classes must be disjoint
        assert len(set(valid_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    test_dataset: FewShotDataset = None
    if test_labels_path:
        os.makedirs(os.path.join(output_path, "logs/test"), exist_ok=True)
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()
        test_dataset = FewShotDataset(data_path=data_path, labels_path=test_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"test labels: {test_dataset.data.keys()}")
        # Train and test classes must be disjoint
        assert len(set(test_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    train_metrics = collections.defaultdict(list)
    n_eval_since_last_best = 0
    best_valid_acc = 0.0
    evaluation_sets = (
        ("valid", valid_writer, valid_dataset),
        ("test", test_writer, test_dataset),
    )

    for step in range(max_iter):
        episode = train_dataset.get_episode()

        supervised_loss_share = supervised_loss_share_fn(step, max_iter)
        _, loss_dict = protonet.train_step(optimizer=optimizer, episode=episode, supervised_loss_share=supervised_loss_share)

        for key, value in loss_dict["metrics"].items():
            train_metrics[key].append(value)

        # Logging
        if (step + 1) % log_every == 0:
            averaged = {key: np.mean(value) for key, value in train_metrics.items()}
            for key, value in averaged.items():
                train_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
            logger.info("train | " + " | ".join([f"{key}:{value:.4f}" for key, value in averaged.items()]))
            log_dict["train"].append({
                "metrics": [
                    {
                        "tag": key,
                        "value": value
                    }
                    for key, value in averaged.items()
                ],
                "global_step": step
            })

            # Start a fresh averaging window
            train_metrics = collections.defaultdict(list)

        # Evaluation (datasets that were not requested are None and skipped)
        if (step + 1) % evaluate_every == 0:
            for set_type, writer, set_dataset in evaluation_sets:
                if not set_dataset:
                    continue

                set_results = protonet.test_step(
                    dataset=set_dataset,
                    n_episodes=n_test_episodes
                )

                for key, val in set_results.items():
                    writer.add_scalar(tag=key, scalar_value=val, global_step=step)
                log_dict[set_type].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": val
                        }
                        for key, val in set_results.items()
                    ],
                    "global_step": step
                })

                logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))
                if set_type == "valid":
                    if set_results["acc"] > best_valid_acc:
                        best_valid_acc = set_results["acc"]
                        n_eval_since_last_best = 0
                        logger.info("Better eval results!")
                    else:
                        n_eval_since_last_best += 1
                        logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

        # Checked at training-loop level so that `break` really ends training
        # (inside the evaluation loop it would only skip the remaining eval sets).
        if early_stop and n_eval_since_last_best >= early_stop:
            logger.warning("Early-stopping.")
            break

    with open(os.path.join(output_path, 'metrics.json'), "w") as file:
        json.dump(log_dict, file, ensure_ascii=False)

    # Release the event files held by the TensorBoard writers
    for writer in (train_writer, valid_writer, test_writer):
        if writer is not None:
            writer.close()
def main():
    """
    Command-line entry point: parse the arguments, seed the RNGs, run ``run_proto`` and save the
    parsed arguments to ``config.json`` inside ``--output-path``.

    :raises FileNotFoundError: if one of the provided data/labels paths does not exist
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True, help="Path to data")
    parser.add_argument("--train-labels-path", type=str, required=True, help="Path to train labels")
    parser.add_argument("--train-path", type=str, help="Path to training data (if provided, picks training data from this path instead of --data-path")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--metric", type=str, default="euclidean", help="Metric to use", choices=("euclidean", "cosine"))

    # Path to augmented data
    parser.add_argument("--unlabeled-path", type=str, help="Path to data containing augmentations used for consistency")

    # Validation & test
    parser.add_argument("--valid-labels-path", type=str, required=True, help="Path to valid labels")
    parser.add_argument("--test-labels-path", type=str, required=True, help="Path to test labels")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Logging & Saving
    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")

    # Training stuff
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Augmentation & Paraphrase
    parser.add_argument("--n-augmentation", type=int, help="Number of unlabeled data points per class (proto++)", default=0)
    parser.add_argument("--paraphrase-model-name-or-path", type=str)
    parser.add_argument("--paraphrase-tokenizer-name-or-path", type=str)
    parser.add_argument("--paraphrase-num-beams", type=int)
    parser.add_argument("--paraphrase-beam-group-size", type=int)
    parser.add_argument("--paraphrase-diversity-penalty", type=float)
    parser.add_argument("--paraphrase-filtering-strategy", type=str)
    parser.add_argument("--paraphrase-drop-strategy", type=str)
    parser.add_argument("--paraphrase-drop-chance-speed", type=str)
    parser.add_argument("--paraphrase-drop-chance-auc", type=float)

    # Augmentation file path (optional, but if provided it will be used)
    parser.add_argument("--augmentation-data-path", type=str)

    # Seed
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")

    # Supervised loss share
    parser.add_argument("--supervised-loss-share-power", default=1.0, type=float, help="supervised_loss_share = 1 - (x/y) ** ")

    args = parser.parse_args()
    logger.debug(f"Received args: {json.dumps(args.__dict__, sort_keys=True, ensure_ascii=False, indent=1)}")
    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist (optional paths are only checked when provided)
    for arg in [args.data_path, args.train_path, args.train_labels_path, args.valid_labels_path, args.test_labels_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Create supervised_loss_share_fn
    def get_supervised_loss_share_fn(supervised_loss_share_power: Union[int, float]) -> Callable[[int, int], float]:
        """Build the schedule ``1 - (current_step / max_steps) ** power``."""

        def _supervised_loss_share_fn(current_step: int, max_steps: int) -> float:
            assert current_step <= max_steps
            return 1 - (current_step / max_steps) ** supervised_loss_share_power

        return _supervised_loss_share_fn

    supervised_loss_share_fn = get_supervised_loss_share_fn(args.supervised_loss_share_power)

    # Run
    run_proto(
        data_path=args.data_path,
        train_labels_path=args.train_labels_path,
        train_path=args.train_path,
        model_name_or_path=args.model_name_or_path,
        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        metric=args.metric,
        unlabeled_path=args.unlabeled_path,

        valid_labels_path=args.valid_labels_path,
        test_labels_path=args.test_labels_path,
        evaluate_every=args.evaluate_every,
        n_test_episodes=args.n_test_episodes,

        output_path=args.output_path,
        log_every=args.log_every,
        max_iter=args.max_iter,
        early_stop=args.early_stop,

        n_augmentation=args.n_augmentation,
        paraphrase_model_name_or_path=args.paraphrase_model_name_or_path,
        paraphrase_tokenizer_name_or_path=args.paraphrase_tokenizer_name_or_path,
        paraphrase_num_beams=args.paraphrase_num_beams,
        paraphrase_beam_group_size=args.paraphrase_beam_group_size,
        # Was parsed but never forwarded, so the CLI value was silently ignored
        paraphrase_diversity_penalty=args.paraphrase_diversity_penalty,
        paraphrase_filtering_strategy=args.paraphrase_filtering_strategy,
        paraphrase_drop_strategy=args.paraphrase_drop_strategy,
        paraphrase_drop_chance_speed=args.paraphrase_drop_chance_speed,
        paraphrase_drop_chance_auc=args.paraphrase_drop_chance_auc,
        supervised_loss_share_fn=supervised_loss_share_fn,

        augmentation_data_path=args.augmentation_data_path
    )

    # Save config (done after the run: run_proto refuses a non-empty output directory)
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False, indent=1)
class ProtAugmentNet(nn.Module):
    """
    Prototypical network with an optional unsupervised consistency loss computed on paraphrases.

    Sentences are embedded with ``encoder.embed_sentences``; queries are classified by their
    distance (``euclidean`` or ``cosine``) to the mean embedding of each class' supports.
    """

    def __init__(self, encoder, metric="euclidean"):
        """
        :param encoder: object exposing ``embed_sentences(list_of_sentences)``
        :param metric: ``"euclidean"`` or ``"cosine"``
        """
        super().__init__()

        self.encoder: BERTEncoder = encoder
        self.metric = metric
        assert self.metric in ('euclidean', 'cosine')

    def _distances(self, z_query, z_prototypes):
        """Return the distances between queries and prototypes for the configured metric."""
        if self.metric == "euclidean":
            return euclidean_dist(z_query, z_prototypes)
        if self.metric == "cosine":
            # Cosine similarity turned into a distance: (1 - similarity) * 5
            return (-cosine_similarity(z_query, z_prototypes) + 1) * 5
        raise NotImplementedError

    def loss(self, sample, supervised_loss_share: float = 0):
        """
        Compute the episode loss and its metrics.

        :param supervised_loss_share: share of supervised loss in total loss
            (only used when the sample contains augmentations)
        :param sample: {
            "xs": [
                [support_A_1, support_A_2, ...],
                [support_B_1, support_B_2, ...],
                [support_C_1, support_C_2, ...],
                ...
            ],
            "xq": [
                [query_A_1, query_A_2, ...],
                [query_B_1, query_B_2, ...],
                [query_C_1, query_C_2, ...],
                ...
            ],
            "x_augment": [   # optional
                {"src_text": ..., "tgt_texts": [...]},
                ...
            ]
        }
        Every support/query item is a dict holding its text under the "sentence" key.
        :return: ``(loss, info)``. Without "x_augment", ``info`` holds "metrics" (acc, loss), "dists" and
            "target"; with it, "metrics" (supervised/unsupervised acc and loss, supervised_loss_share,
            final_loss), "supervised_dists", "unsupervised_dists" and "target".
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        supports = [item["sentence"] for xs_ in xs for item in xs_]
        queries = [item["sentence"] for xq_ in xq for item in xq_]
        texts = supports + queries

        has_augment = "x_augment" in sample
        if has_augment:
            augmentations = sample["x_augment"]
            n_augmentations_samples = len(augmentations)

            # Every source text must come with the same number of paraphrases
            paraphrase_counts = {len(item["tgt_texts"]) for item in augmentations}
            assert len(paraphrase_counts) == 1
            n_augmentations_per_sample = paraphrase_counts.pop()

            # Paraphrases act as "supports", the source texts as "queries"
            texts += [paraphrase for item in augmentations for paraphrase in item["tgt_texts"]]
            texts += [item["src_text"] for item in augmentations]

        # Encode everything in a single pass
        z = self.encoder.embed_sentences(texts)
        z_dim = z.size(-1)

        # Dispatch (explicit offsets instead of negative indexing)
        support_end = len(supports)
        query_end = support_end + len(queries)
        z_support = z[:support_end].view(n_class, n_support, z_dim).mean(dim=[1])
        z_query = z[support_end:query_end]

        # Targets live on the same device as the embeddings they are compared with
        target_inds = torch.arange(n_class, device=z.device).view(n_class, 1, 1).expand(n_class, n_query, 1)
        flat_targets = target_inds.reshape(-1)

        criterion = nn.CrossEntropyLoss()
        supervised_dists = self._distances(z_query, z_support)
        supervised_loss = criterion(-supervised_dists, flat_targets)
        _, y_hat_supervised = (-supervised_dists).max(1)
        acc_val_supervised = torch.eq(y_hat_supervised, flat_targets).float().mean()

        if not has_augment:
            return supervised_loss, {
                "metrics": {
                    "acc": acc_val_supervised.item(),
                    "loss": supervised_loss.item(),
                },
                "dists": supervised_dists,
                "target": target_inds
            }

        aug_support_end = query_end + n_augmentations_per_sample * n_augmentations_samples
        z_aug_support = (z[query_end:aug_support_end]
                         .view(n_augmentations_samples, n_augmentations_per_sample, z_dim).mean(dim=[1]))
        z_aug_query = z[aug_support_end:aug_support_end + n_augmentations_samples]

        # Unsupervised loss: the i-th source text must match the prototype of its own paraphrases.
        # torch.arange replaces the deprecated torch.range (same 0..n-1 integer targets).
        unsupervised_dists = self._distances(z_aug_query, z_aug_support)
        unsupervised_target_inds = torch.arange(n_augmentations_samples, device=z.device)
        unsupervised_loss = criterion(-unsupervised_dists, unsupervised_target_inds)
        _, y_hat_unsupervised = (-unsupervised_dists).max(1)
        acc_val_unsupervised = torch.eq(y_hat_unsupervised, unsupervised_target_inds).float().mean()

        # Final loss
        assert 0 <= supervised_loss_share <= 1
        final_loss = supervised_loss_share * supervised_loss + (1 - supervised_loss_share) * unsupervised_loss

        return final_loss, {
            "metrics": {
                "supervised_acc": acc_val_supervised.item(),
                "unsupervised_acc": acc_val_unsupervised.item(),
                "supervised_loss": supervised_loss.item(),
                "unsupervised_loss": unsupervised_loss.item(),
                "supervised_loss_share": supervised_loss_share,
                "final_loss": final_loss.item(),
            },
            "supervised_dists": supervised_dists,
            "unsupervised_dists": unsupervised_dists,
            "target": target_inds
        }

    def train_step(self, optimizer, episode, supervised_loss_share: float):
        """
        Run one optimisation step on ``episode``.

        :return: ``(loss, loss_dict)`` as returned by ``loss``
        """
        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self.loss(episode, supervised_loss_share=supervised_loss_share)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step(self, dataset: "FewShotDataset", n_episodes: int = 1000):
        """
        Evaluate on ``n_episodes`` episodes drawn with ``dataset.get_episode()``.

        :return: dict mapping each metric name to its mean over the episodes
        """
        metrics = collections.defaultdict(list)

        self.eval()
        for _ in range(n_episodes):
            episode = dataset.get_episode()

            with torch.no_grad():
                _, loss_dict = self.loss(episode, supervised_loss_share=1)

            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }
def run_proto(
        # Compulsory!
        data_path: str,
        train_labels_path: str,
        model_name_or_path: str,

        # Few-shot Stuff
        n_support: int,
        n_query: int,
        n_classes: int,
        metric: str = "euclidean",

        # Optional path to augmented data
        unlabeled_path: str = None,

        # Path training data ONLY (optional)
        train_path: str = None,

        # Validation & test
        valid_labels_path: str = None,
        test_labels_path: str = None,
        evaluate_every: int = 100,
        n_test_episodes: int = 1000,

        # Logging & Saving
        output_path: str = f'runs/{now()}',
        log_every: int = 10,

        # Training stuff
        max_iter: int = 10000,
        early_stop: int = None,

        # Augmentation & paraphrase
        n_augmentation: int = 5,
        paraphrase_model_name_or_path: str = None,
        paraphrase_tokenizer_name_or_path: str = None,
        paraphrase_num_beams: int = None,
        paraphrase_beam_group_size: int = None,
        paraphrase_diversity_penalty: float = None,
        paraphrase_filtering_strategy: str = None,
        paraphrase_drop_strategy: str = None,
        paraphrase_drop_chance_speed: str = None,
        paraphrase_drop_chance_auc: float = None,
        supervised_loss_share_fn: Callable[[int, int], float] = lambda x, y: 1 - (x / y),

        augmentation_data_path: str = None
):
    """
    Train a ``ProtAugmentNet`` episodically, evaluate it periodically and log everything under ``output_path``.

    TensorBoard scalars are written to ``logs/train``, ``logs/valid`` and ``logs/test`` inside
    ``output_path``; every logged value is also dumped to ``metrics.json`` once training ends.

    :param data_path: data used for the valid/test datasets, and for training when ``train_path`` is not given
    :param train_labels_path: labels file of the training dataset
    :param model_name_or_path: forwarded to ``BERTEncoder``
    :param n_support: forwarded to every dataset as ``n_support``
    :param n_query: forwarded to every dataset as ``n_query``
    :param n_classes: forwarded to every dataset as ``n_classes``
    :param metric: forwarded to ``ProtAugmentNet``
    :param unlabeled_path: forwarded as ``unlabeled_file_path`` to ``FewShotSSLParaphraseDataset``
    :param train_path: when given, training data is read from here instead of ``data_path``
    :param valid_labels_path: when given, a validation dataset is built and evaluated every ``evaluate_every`` steps
    :param test_labels_path: when given, a test dataset is built and evaluated every ``evaluate_every`` steps
    :param evaluate_every: number of training steps between two evaluations
    :param n_test_episodes: number of episodes per evaluation
    :param output_path: output directory; must not exist yet, or be empty
    :param log_every: number of training steps between two training logs
    :param max_iter: maximum number of training steps
    :param early_stop: stop once this many validation evaluations happened since the best accuracy (falsy = disabled)
    :param n_augmentation: forwarded as ``n_unlabeled`` to the training dataset
    :param paraphrase_model_name_or_path: forwarded to ``DBSParaphraseModel`` (unused with ``augmentation_data_path``)
    :param paraphrase_tokenizer_name_or_path: tokenizer loaded with ``AutoTokenizer`` and forwarded to the paraphraser
    :param paraphrase_num_beams: forwarded to ``DBSParaphraseModel``
    :param paraphrase_beam_group_size: forwarded to ``DBSParaphraseModel``
    :param paraphrase_diversity_penalty: forwarded to ``DBSParaphraseModel``
    :param paraphrase_filtering_strategy: forwarded to ``DBSParaphraseModel``
    :param paraphrase_drop_strategy: ``"unigram"`` or ``"bigram"`` select the matching batch preparer, anything else the base one
    :param paraphrase_drop_chance_speed: forwarded to the unigram batch preparer
    :param paraphrase_drop_chance_auc: forwarded to the unigram batch preparer
    :param supervised_loss_share_fn: called as ``fn(step, max_iter)`` before each training step
    :param augmentation_data_path: when given, pre-generated augmentations are read through ``FewShotSSLFileDataset``
    :raises FileExistsError: if ``output_path`` exists and is not empty
    """
    if output_path:
        if os.path.exists(output_path) and len(os.listdir(output_path)):
            raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok=True: the guard above accepts an existing *empty* directory, which a plain
    # os.makedirs() would reject with FileExistsError.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "logs/train"), exist_ok=True)
    train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer: SummaryWriter = None
    test_writer: SummaryWriter = None
    log_dict = dict(train=list())

    # ----------
    # Load model
    # ----------
    bert = BERTEncoder(model_name_or_path).to(device)
    protonet: ProtAugmentNet = ProtAugmentNet(encoder=bert, metric=metric)
    optimizer = torch.optim.Adam(protonet.parameters(), lr=2e-5)

    # ------------------
    # Load Train Dataset
    # ------------------
    if augmentation_data_path:
        # If an augmentation data path is provided, uses those pre-generated augmentations
        train_dataset = FewShotSSLFileDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=augmentation_data_path,
        )
    else:
        # ---------------------
        # Load paraphrase model
        # ---------------------
        # The paraphraser is kept on CPU for 20newsgroup data, and also whenever CUDA is not
        # available (previously "cuda" was requested unconditionally in that case).
        paraphrase_on_cpu = "20newsgroup" in data_path or not torch.cuda.is_available()
        paraphrase_model_device = torch.device("cpu") if paraphrase_on_cpu else torch.device("cuda")
        logger.info(f"Paraphrase model device: {paraphrase_model_device}")
        paraphrase_tokenizer = AutoTokenizer.from_pretrained(paraphrase_tokenizer_name_or_path)
        if paraphrase_drop_strategy == "unigram":
            paraphrase_batch_preparer = UnigramRandomDropParaphraseBatchPreparer(
                tokenizer=paraphrase_tokenizer,
                auc=paraphrase_drop_chance_auc,
                drop_chance_speed=paraphrase_drop_chance_speed,
                device=paraphrase_model_device
            )
        elif paraphrase_drop_strategy == "bigram":
            paraphrase_batch_preparer = BigramDropParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)
        else:
            paraphrase_batch_preparer = BaseParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)

        paraphrase_model = DBSParaphraseModel(
            model_name_or_path=paraphrase_model_name_or_path,
            tok_name_or_path=paraphrase_tokenizer_name_or_path,
            num_beams=paraphrase_num_beams,
            beam_group_size=paraphrase_beam_group_size,
            diversity_penalty=paraphrase_diversity_penalty,
            filtering_strategy=paraphrase_filtering_strategy,
            paraphrase_batch_preparer=paraphrase_batch_preparer,
            device=paraphrase_model_device
        )

        train_dataset = FewShotSSLParaphraseDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=unlabeled_path,
            paraphrase_model=paraphrase_model
        )
    logger.info(f"Train dataset has {len(train_dataset)} items")

    # ---------
    # Load data
    # ---------
    logger.info(f"train labels: {train_dataset.data.keys()}")
    valid_dataset: FewShotDataset = None
    if valid_labels_path:
        os.makedirs(os.path.join(output_path, "logs/valid"), exist_ok=True)
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
        valid_dataset = FewShotDataset(data_path=data_path, labels_path=valid_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"valid labels: {valid_dataset.data.keys()}")
        # Train and validation classes must be disjoint
        assert len(set(valid_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    test_dataset: FewShotDataset = None
    if test_labels_path:
        os.makedirs(os.path.join(output_path, "logs/test"), exist_ok=True)
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()
        test_dataset = FewShotDataset(data_path=data_path, labels_path=test_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"test labels: {test_dataset.data.keys()}")
        # Train and test classes must be disjoint
        assert len(set(test_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    train_metrics = collections.defaultdict(list)
    n_eval_since_last_best = 0
    best_valid_acc = 0.0
    evaluation_sets = (
        ("valid", valid_writer, valid_dataset),
        ("test", test_writer, test_dataset),
    )

    for step in range(max_iter):
        episode = train_dataset.get_episode()

        supervised_loss_share = supervised_loss_share_fn(step, max_iter)
        _, loss_dict = protonet.train_step(optimizer=optimizer, episode=episode, supervised_loss_share=supervised_loss_share)

        for key, value in loss_dict["metrics"].items():
            train_metrics[key].append(value)

        # Logging
        if (step + 1) % log_every == 0:
            averaged = {key: np.mean(value) for key, value in train_metrics.items()}
            for key, value in averaged.items():
                train_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
            logger.info("train | " + " | ".join([f"{key}:{value:.4f}" for key, value in averaged.items()]))
            log_dict["train"].append({
                "metrics": [
                    {
                        "tag": key,
                        "value": value
                    }
                    for key, value in averaged.items()
                ],
                "global_step": step
            })

            # Start a fresh averaging window
            train_metrics = collections.defaultdict(list)

        # Evaluation (datasets that were not requested are None and skipped)
        if (step + 1) % evaluate_every == 0:
            for set_type, writer, set_dataset in evaluation_sets:
                if not set_dataset:
                    continue

                set_results = protonet.test_step(
                    dataset=set_dataset,
                    n_episodes=n_test_episodes
                )

                for key, val in set_results.items():
                    writer.add_scalar(tag=key, scalar_value=val, global_step=step)
                log_dict[set_type].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": val
                        }
                        for key, val in set_results.items()
                    ],
                    "global_step": step
                })

                logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))
                if set_type == "valid":
                    if set_results["acc"] > best_valid_acc:
                        best_valid_acc = set_results["acc"]
                        n_eval_since_last_best = 0
                        logger.info("Better eval results!")
                    else:
                        n_eval_since_last_best += 1
                        logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

        # Checked at training-loop level so that `break` really ends training
        # (inside the evaluation loop it would only skip the remaining eval sets).
        if early_stop and n_eval_since_last_best >= early_stop:
            logger.warning("Early-stopping.")
            break

    with open(os.path.join(output_path, 'metrics.json'), "w") as file:
        json.dump(log_dict, file, ensure_ascii=False)

    # Release the event files held by the TensorBoard writers
    for writer in (train_writer, valid_writer, test_writer):
        if writer is not None:
            writer.close()
def main():
    """
    Command-line entry point: parse the arguments, seed the RNGs, run ``run_proto`` and save the
    parsed arguments to ``config.json`` inside ``--output-path``.

    :raises FileNotFoundError: if one of the provided data/labels paths does not exist
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True, help="Path to data")
    parser.add_argument("--train-labels-path", type=str, required=True, help="Path to train labels")
    parser.add_argument("--train-path", type=str, help="Path to training data (if provided, picks training data from this path instead of --data-path")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--metric", type=str, default="euclidean", help="Metric to use", choices=("euclidean", "cosine"))

    # Path to augmented data
    parser.add_argument("--unlabeled-path", type=str, required=True, help="Path to data containing augmentations used for consistency")

    # Validation & test
    parser.add_argument("--valid-labels-path", type=str, required=True, help="Path to valid labels")
    parser.add_argument("--test-labels-path", type=str, required=True, help="Path to test labels")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Logging & Saving
    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")

    # Training stuff
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Augmentation & Paraphrase
    parser.add_argument("--n-augmentation", type=int, help="Number of unlabeled data points per class (proto++)", default=0)
    parser.add_argument("--paraphrase-model-name-or-path", type=str)
    parser.add_argument("--paraphrase-tokenizer-name-or-path", type=str)
    parser.add_argument("--paraphrase-num-beams", type=int)
    parser.add_argument("--paraphrase-beam-group-size", type=int)
    parser.add_argument("--paraphrase-diversity-penalty", type=float)
    parser.add_argument("--paraphrase-filtering-strategy", type=str)
    parser.add_argument("--paraphrase-drop-strategy", type=str)
    parser.add_argument("--paraphrase-drop-chance-speed", type=str)
    parser.add_argument("--paraphrase-drop-chance-auc", type=float)

    # Augmentation file path (optional, but if provided it will be used)
    parser.add_argument("--augmentation-data-path", type=str)

    # Seed
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")

    # Supervised loss share
    parser.add_argument("--supervised-loss-share-power", default=1.0, type=float, help="supervised_loss_share = 1 - (x/y) ** ")

    args = parser.parse_args()
    logger.debug(f"Received args: {json.dumps(args.__dict__, sort_keys=True, ensure_ascii=False, indent=1)}")
    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist (optional paths are only checked when provided)
    for arg in [args.data_path, args.train_path, args.train_labels_path, args.valid_labels_path, args.test_labels_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Create supervised_loss_share_fn
    def get_supervised_loss_share_fn(supervised_loss_share_power: Union[int, float]) -> Callable[[int, int], float]:
        """Build the schedule ``1 - (current_step / max_steps) ** power``."""

        def _supervised_loss_share_fn(current_step: int, max_steps: int) -> float:
            assert current_step <= max_steps
            return 1 - (current_step / max_steps) ** supervised_loss_share_power

        return _supervised_loss_share_fn

    supervised_loss_share_fn = get_supervised_loss_share_fn(args.supervised_loss_share_power)

    # Run
    run_proto(
        data_path=args.data_path,
        train_labels_path=args.train_labels_path,
        train_path=args.train_path,
        model_name_or_path=args.model_name_or_path,
        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        metric=args.metric,
        unlabeled_path=args.unlabeled_path,

        valid_labels_path=args.valid_labels_path,
        test_labels_path=args.test_labels_path,
        evaluate_every=args.evaluate_every,
        n_test_episodes=args.n_test_episodes,

        output_path=args.output_path,
        log_every=args.log_every,
        max_iter=args.max_iter,
        early_stop=args.early_stop,

        n_augmentation=args.n_augmentation,
        paraphrase_model_name_or_path=args.paraphrase_model_name_or_path,
        paraphrase_tokenizer_name_or_path=args.paraphrase_tokenizer_name_or_path,
        paraphrase_num_beams=args.paraphrase_num_beams,
        paraphrase_beam_group_size=args.paraphrase_beam_group_size,
        # Was parsed but never forwarded, so the CLI value was silently ignored
        paraphrase_diversity_penalty=args.paraphrase_diversity_penalty,
        paraphrase_filtering_strategy=args.paraphrase_filtering_strategy,
        paraphrase_drop_strategy=args.paraphrase_drop_strategy,
        paraphrase_drop_chance_speed=args.paraphrase_drop_chance_speed,
        paraphrase_drop_chance_auc=args.paraphrase_drop_chance_auc,
        supervised_loss_share_fn=supervised_loss_share_fn,

        augmentation_data_path=args.augmentation_data_path
    )

    # Save config (done after the run: run_proto refuses a non-empty output directory)
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False, indent=1)
class ProtoNet(nn.Module):
    """Prototypical network on top of a sentence encoder.

    Each loss method embeds all sentences of one episode with
    ``encoder.embed_sentences``, averages the support embeddings of every class
    into a prototype, and classifies queries by their distance to the prototypes.
    """

    def __init__(self, encoder: BERTEncoder, metric="euclidean"):
        """
        :param encoder: sentence encoder; only its ``embed_sentences`` method is used by this class
        :param metric: distance used between queries and prototypes, "euclidean" or "cosine"
        """
        super(ProtoNet, self).__init__()

        self.encoder: BERTEncoder = encoder
        self.metric = metric
        assert self.metric in ('euclidean', 'cosine')

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _distances(self, points, prototypes):
        """Return the pairwise distances between `points` and `prototypes` for `self.metric`."""
        if self.metric == "euclidean":
            return euclidean_dist(points, prototypes)
        if self.metric == "cosine":
            # Turn the similarity into a distance-like score; the factor 5 sharpens the softmax.
            return (-cosine_similarity(points, prototypes) + 1) * 5
        raise NotImplementedError

    @staticmethod
    def _episode_targets(n_class, n_query, device):
        """Class index of every query, shaped (n_class, n_query, 1), created on `device`."""
        return torch.arange(n_class, device=device).view(n_class, 1, 1).expand(n_class, n_query, 1)

    @staticmethod
    def _loss_and_accuracy(dists, targets):
        """Cross-entropy and accuracy obtained by using negated distances as logits."""
        logits = -dists
        loss_val = torch_functional.cross_entropy(logits, targets)
        _, y_hat = logits.max(1)
        acc_val = torch.eq(y_hat, targets).float().mean()
        return loss_val, acc_val

    def _episode_loss(self, episode, use_unlabeled, supervised_loss_share: float = 0):
        """Dispatch to the soft k-means loss (unlabeled data) or to the regular loss."""
        if use_unlabeled:
            return self.loss_softkmeans(episode)
        return self.loss(episode, supervised_loss_share=supervised_loss_share)

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------
    def loss(self, sample, supervised_loss_share: float = 0):
        """
        Prototypical loss of one episode, optionally mixed with an unsupervised
        loss computed on paraphrases ("x_augment").

        :param supervised_loss_share: share of supervised loss in total loss (only used with augmentations)
        :param sample: {
            "xs": [
                [support_A_1, support_A_2, ...],
                [support_B_1, support_B_2, ...],
                [support_C_1, support_C_2, ...],
                ...
            ],
            "xq": [
                [query_A_1, query_A_2, ...],
                [query_B_1, query_B_2, ...],
                [query_C_1, query_C_2, ...],
                ...
            ],
            "x_augment": optional, [
                {"sentence": ..., "augmentations": [{"text": ...}, ...]},
                ...
            ]
        }
        Every support / query item is a dict holding its text under "sentence".
        :return: (loss tensor, dict with "metrics", the distance matrices and "target")
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        supports = [item["sentence"] for xs_ in xs for item in xs_]
        queries = [item["sentence"] for xq_ in xq for item in xq_]

        # An absent or empty "x_augment" both mean "purely supervised episode".
        augmentations = sample["x_augment"] if "x_augment" in sample else None
        has_augmentations = bool(augmentations)

        n_augmentations_samples = 0
        n_augmentations_per_sample = 0
        x = supports + queries
        if has_augmentations:
            n_augmentations_samples = len(augmentations)
            counts = {len(item["augmentations"]) for item in augmentations}
            # The paraphrase embeddings are reshaped into a dense tensor below, so every
            # sample needs the same, non-zero, number of paraphrases (any number, not only 5).
            assert len(counts) == 1, f"Samples have different numbers of augmentations: {sorted(counts)}"
            n_augmentations_per_sample = counts.pop()
            assert n_augmentations_per_sample > 0, "Samples in x_augment have no augmentations"

            # Paraphrases act as "supports", the original sentences as "queries".
            x = (
                x
                + [aug["text"] for item in augmentations for aug in item["augmentations"]]
                + [item["sentence"] for item in augmentations]
            )

        # Encode everything in a single pass
        z = self.encoder.embed_sentences(x)
        z_dim = z.size(-1)

        # Dispatch
        support_end = len(supports)
        query_end = support_end + len(queries)
        z_support = z[:support_end].view(n_class, n_support, z_dim).mean(dim=1)
        z_query = z[support_end:query_end]

        # Targets live on the same device as the embeddings.
        target_inds = self._episode_targets(n_class, n_query, z.device)

        supervised_dists = self._distances(z_query, z_support)
        supervised_loss, acc_val_supervised = self._loss_and_accuracy(supervised_dists, target_inds.reshape(-1))

        if not has_augmentations:
            return supervised_loss, {
                "metrics": {
                    "acc": acc_val_supervised.item(),
                    "loss": supervised_loss.item(),
                },
                "dists": supervised_dists,
                "target": target_inds
            }

        # Unsupervised loss: each original sentence must be closest to the mean of its own paraphrases
        aug_end = query_end + n_augmentations_samples * n_augmentations_per_sample
        z_aug_support = (
            z[query_end:aug_end]
            .view(n_augmentations_samples, n_augmentations_per_sample, z_dim)
            .mean(dim=1)
        )
        z_aug_query = z[aug_end:aug_end + n_augmentations_samples]

        unsupervised_dists = self._distances(z_aug_query, z_aug_support)
        # torch.arange replaces the deprecated torch.range(0, n - 1).long()
        unsupervised_target_inds = torch.arange(n_augmentations_samples, device=z.device)
        unsupervised_loss, acc_val_unsupervised = self._loss_and_accuracy(unsupervised_dists, unsupervised_target_inds)

        # Final loss
        assert 0 <= supervised_loss_share <= 1
        final_loss = supervised_loss_share * supervised_loss + (1 - supervised_loss_share) * unsupervised_loss

        return final_loss, {
            "metrics": {
                "supervised_acc": acc_val_supervised.item(),
                "unsupervised_acc": acc_val_unsupervised.item(),
                "supervised_loss": supervised_loss.item(),
                "unsupervised_loss": unsupervised_loss.item(),
                "supervised_loss_share": supervised_loss_share,
                "final_loss": final_loss.item(),
            },
            "supervised_dists": supervised_dists,
            "unsupervised_dists": unsupervised_dists,
            "target": target_inds
        }

    def loss_softkmeans(self, sample):
        """
        Prototypical loss where the prototypes are refined with unlabeled data:
        every unlabeled embedding contributes to each class prototype with a weight
        given by a softmax over its (negated) distances to the initial prototypes.

        :param sample: episode with "xs" (support), "xq" (query) and "xu" (flat list of unlabeled items)
        :return: (loss tensor, dict with 'loss', "metrics", 'dists' and 'target')
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query
        xu = sample['xu']  # unlabeled

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        x = (
            [item["sentence"] for xs_ in xs for item in xs_]
            + [item["sentence"] for xq_ in xq for item in xq_]
            + [item["sentence"] for item in xu]
        )
        z = self.encoder.embed_sentences(x)
        z_dim = z.size(-1)

        support_end = n_class * n_support
        query_end = support_end + n_class * n_query
        zs = z[:support_end]
        zq = z[support_end:query_end]
        zu = z[query_end:]
        z_proto = zs.view(n_class, n_support, z_dim).mean(1)

        # NOTE: the soft assignment always uses the euclidean distance, whatever self.metric is.
        distances_to_proto = euclidean_dist(torch.cat((zs, zu)), z_proto)
        distances_to_proto_normed = torch.softmax(-distances_to_proto, dim=-1)
        unlabeled_weights = distances_to_proto_normed[support_end:]

        refined_protos = list()
        support_weights = torch.ones(n_support, device=z.device, dtype=unlabeled_weights.dtype)
        for class_ix in range(n_class):
            # Labeled supports count fully (weight 1), unlabeled points with their soft assignment.
            members = torch.cat((zs[class_ix * n_support: (class_ix + 1) * n_support], zu))
            weights = torch.cat((support_weights, unlabeled_weights[:, class_ix]))
            refined_proto = (members.t() * weights).sum(1) / weights.sum()
            refined_protos.append(refined_proto.view(1, -1))
        refined_protos = torch.cat(refined_protos)

        dists = self._distances(zq, refined_protos)

        target_inds = self._episode_targets(n_class, n_query, z.device)
        # Flattened targets keep the accuracy correct even when n_query == 1 or n_class == 1,
        # where a blanket squeeze() would drop a dimension and make torch.eq broadcast.
        loss_val, acc_val = self._loss_and_accuracy(dists, target_inds.reshape(-1))

        return loss_val, {
            'loss': loss_val.item(),
            "metrics": {
                "acc": acc_val.item(),
                "loss": loss_val.item(),
            },
            'dists': dists,
            'target': target_inds
        }

    def loss_consistency(self, sample):
        """
        Consistency loss: every original sentence has to be closer to the mean
        embedding of its own augmentations than to those of the other samples.

        :param sample: {"x_augment": [(A, [A_1, A_2, ..., A_n]), (B, [B_1, B_2, ..., B_m]), ...]}
        :return: (loss tensor, dict with 'loss', "metrics", 'dists' and 'target')
        """
        x_augment = sample["x_augment"]
        n_samples = len(x_augment)

        # Each sample occupies 1 (original) + n (augmentations) consecutive rows of z.
        lengths = [1 + len(augments) for _, augments in x_augment]

        x = list()
        for sentence, augs in x_augment:
            x.append(sentence)
            x.extend(augs)

        z = self.encoder.embed_sentences(x)
        assert len(z) == sum(lengths)

        original_embeddings = list()
        augmented_embeddings = list()
        offset = 0
        for length in lengths:
            original_embeddings.append(z[offset])
            # `length` already counts the original sentence, so the augmentations are the
            # rows offset + 1 .. offset + length - 1 (the end of the slice is exclusive).
            augmented_embeddings.append(z[offset + 1:offset + length].mean(0))
            offset += length

        original_embeddings = torch.stack(original_embeddings)
        augmented_embeddings = torch.stack(augmented_embeddings)
        dists = self._distances(original_embeddings, augmented_embeddings)

        # Sample i must be matched with augmentation group i.
        target_inds = torch.arange(n_samples, device=z.device)
        loss_val, acc_val = self._loss_and_accuracy(dists, target_inds)

        return loss_val, {
            'loss': loss_val.item(),
            "metrics": {
                "acc": acc_val.item(),
                "loss": loss_val.item(),
            },
            'dists': dists,
            'target': target_inds
        }

    # ------------------------------------------------------------------
    # Training / evaluation steps
    # ------------------------------------------------------------------
    def train_step(self, optimizer, episode, supervised_loss_share: float, unlabeled: bool = False):
        """
        Run one optimization step on `episode`.

        :param optimizer: optimizer holding the parameters to update
        :param episode: episode dict, see `loss` / `loss_softkmeans`
        :param supervised_loss_share: forwarded to `loss` (ignored when `unlabeled` is True)
        :param unlabeled: use the soft k-means loss instead of the regular one
        :return: (loss tensor, loss dict)
        """
        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self._episode_loss(episode, unlabeled, supervised_loss_share=supervised_loss_share)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step(self,
                  data_loader: FewShotDataLoader,
                  n_support: int,
                  n_query: int,
                  n_classes: int,
                  n_unlabeled: int = 0,
                  n_episodes: int = 1000):
        """
        Evaluate on `n_episodes` episodes drawn from `data_loader` (never with augmentations).

        :return: dict mapping each metric name to its mean over the episodes
        """
        metrics = collections.defaultdict(list)

        self.eval()
        for _ in range(n_episodes):
            episode = data_loader.create_episode(
                n_support=n_support,
                n_query=n_query,
                n_unlabeled=n_unlabeled,
                n_classes=n_classes,
                n_augment=0
            )

            with torch.no_grad():
                loss, loss_dict = self._episode_loss(episode, bool(n_unlabeled), supervised_loss_share=0)

            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }

    def train_step_ARSC(self, data_path: str, optimizer, n_unlabeled: int):
        """
        Draw one ARSC training episode (5 supports, 5 queries) and run one optimization step on it.

        :return: (loss tensor, loss dict)
        """
        episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=5, n_unlabeled=n_unlabeled)

        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self._episode_loss(episode, bool(n_unlabeled))
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step_ARSC(self, data_path: str, n_unlabeled=0, n_episodes=1000, set_type="test"):
        """
        Evaluate on `n_episodes` ARSC episodes of the "dev" or "test" split.

        :return: dict mapping each metric name to its mean over the episodes
        """
        assert set_type in ("dev", "test")
        metrics = collections.defaultdict(list)
        self.eval()
        for _ in range(n_episodes):
            episode = create_ARSC_test_episode(prefix=data_path, n_query=5, n_unlabeled=n_unlabeled, set_type=set_type)

            with torch.no_grad():
                loss, loss_dict = self._episode_loss(episode, bool(n_unlabeled))

            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }
def run_proto(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_query: int,
        n_classes: int,
        unlabeled_path: str = None,
        valid_path: str = None,
        test_path: str = None,
        n_unlabeled: int = 0,
        n_augment: int = 0,
        output_path: str = f'runs/{now()}',
        max_iter: int = 10000,
        evaluate_every: int = 100,
        early_stop: int = None,
        n_test_episodes: int = 1000,
        log_every: int = 10,
        metric: str = "euclidean",
        arsc_format: bool = False,
        data_path: str = None,
        supervised_loss_share_fn: Callable[[int, int], float] = lambda x, y: 1 - (x / y)
):
    """
    Train a ProtoNet episodically, with periodic evaluation and optional early stopping.

    :param train_path: training data (unused when `arsc_format` is True)
    :param model_name_or_path: transformer model handed to BERTEncoder
    :param n_support: number of support points per class
    :param n_query: number of query points per class
    :param n_classes: number of classes per episode
    :param unlabeled_path: optional file handed to the training data loader as `unlabeled_file_path`
    :param valid_path: validation data; enables evaluation on the "valid" split
    :param test_path: test data; enables evaluation on the "test" split
    :param n_unlabeled: unlabeled points per episode; > 0 switches to the soft k-means loss
    :param n_augment: number of augmented samples per training episode
    :param output_path: run directory (must not exist or must be empty); receives logs/ and metrics.json
    :param max_iter: maximum number of training episodes
    :param evaluate_every: number of training episodes between two evaluations
    :param early_stop: stop after this many evaluations without a better validation accuracy (0 / None = disabled)
    :param n_test_episodes: number of episodes per evaluation
    :param log_every: number of training episodes between two training logs
    :param metric: "euclidean" or "cosine"
    :param arsc_format: use the ARSC episode helpers (data read from `data_path`) instead of data loaders
    :param data_path: ARSC data prefix
    :param supervised_loss_share_fn: maps (current_step, max_iter) to the share of the supervised loss
    :raises FileExistsError: if `output_path` already exists and is not empty
    """
    if output_path:
        if os.path.exists(output_path) and len(os.listdir(output_path)):
            raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok: an existing *empty* directory passes the check above and must be usable.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "logs/train"))
    writers: Dict[str, SummaryWriter] = {
        "train": SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    }
    log_dict = dict(train=list())

    eval_paths = {"valid": valid_path, "test": test_path}
    for set_type, path in eval_paths.items():
        if path:
            os.makedirs(os.path.join(output_path, f"logs/{set_type}"))
            writers[set_type] = SummaryWriter(logdir=os.path.join(output_path, f"logs/{set_type}"), flush_secs=1, max_queue=1)
            log_dict[set_type] = list()

    try:
        # Load model
        bert = BERTEncoder(model_name_or_path).to(device)
        protonet = ProtoNet(encoder=bert, metric=metric)
        optimizer = torch.optim.Adam(protonet.parameters(), lr=2e-5)

        # Load data (ARSC episodes are created straight from `data_path`, no loader needed)
        train_data_loader = None
        eval_data_loaders = {"valid": None, "test": None}
        if not arsc_format:
            train_data_loader = FewShotDataLoader(train_path, unlabeled_file_path=unlabeled_path)
            logger.info(f"train labels: {train_data_loader.data_dict.keys()}")

            for set_type, path in eval_paths.items():
                if path:
                    eval_data_loaders[set_type] = FewShotDataLoader(path)
                    logger.info(f"{set_type} labels: {eval_data_loaders[set_type].data_dict.keys()}")

        train_metrics = collections.defaultdict(list)
        n_eval_since_last_best = 0
        best_valid_acc = 0.0

        for step in range(max_iter):
            if not arsc_format:
                episode = train_data_loader.create_episode(
                    n_support=n_support,
                    n_query=n_query,
                    n_classes=n_classes,
                    n_unlabeled=n_unlabeled,
                    n_augment=n_augment
                )
            else:
                # Same call as ProtoNet.train_step_ARSC: the episode needs the data prefix and,
                # for the soft k-means loss, the unlabeled points.
                episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=5, n_unlabeled=n_unlabeled)

            supervised_loss_share = supervised_loss_share_fn(step, max_iter)
            loss, loss_dict = protonet.train_step(
                optimizer=optimizer,
                episode=episode,
                unlabeled=(n_unlabeled > 0),
                supervised_loss_share=supervised_loss_share
            )

            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging
            if (step + 1) % log_every == 0:
                averaged = {key: np.mean(value) for key, value in train_metrics.items()}
                for key, value in averaged.items():
                    writers["train"].add_scalar(tag=key, scalar_value=value, global_step=step)
                logger.info(f"train | " + " | ".join([f"{key}:{value:.4f}" for key, value in averaged.items()]))
                log_dict["train"].append({
                    "metrics": [{"tag": key, "value": value} for key, value in averaged.items()],
                    "global_step": step
                })

                train_metrics = collections.defaultdict(list)

            # Evaluation
            if not ((valid_path or test_path) and (step + 1) % evaluate_every == 0):
                continue

            for set_type, path in eval_paths.items():
                if not path:
                    continue

                if not arsc_format:
                    set_results = protonet.test_step(
                        # Each split is evaluated with its own loader.
                        data_loader=eval_data_loaders[set_type],
                        n_unlabeled=n_unlabeled,
                        n_support=n_support,
                        n_query=n_query,
                        n_classes=n_classes,
                        n_episodes=n_test_episodes
                    )
                else:
                    set_results = protonet.test_step_ARSC(
                        data_path=data_path,
                        n_unlabeled=n_unlabeled,
                        n_episodes=n_test_episodes,
                        set_type={"valid": "dev", "test": "test"}[set_type]
                    )

                for key, val in set_results.items():
                    writers[set_type].add_scalar(tag=key, scalar_value=val, global_step=step)
                log_dict[set_type].append({
                    "metrics": [{"tag": key, "value": val} for key, val in set_results.items()],
                    "global_step": step
                })

                logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))
                if set_type == "valid":
                    if set_results["acc"] > best_valid_acc:
                        best_valid_acc = set_results["acc"]
                        n_eval_since_last_best = 0
                        logger.info(f"Better eval results!")
                    else:
                        n_eval_since_last_best += 1
                        logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

            # Checked after both splits were evaluated, and leaves the training loop itself.
            if early_stop and n_eval_since_last_best >= early_stop:
                logger.warning(f"Early-stopping.")
                break

        with open(os.path.join(output_path, 'metrics.json'), "w") as file:
            json.dump(log_dict, file, ensure_ascii=False)
    finally:
        # Release the event-file handles even when training fails.
        for writer in writers.values():
            writer.close()
def main(argv=None):
    """
    Command-line entry point: parse the arguments, train with `run_proto` and
    save the parsed configuration as config.json in the output directory.

    :param argv: argument list to parse; None (default) makes argparse read sys.argv[1:]
    :raises FileNotFoundError: if a provided train / valid / test / unlabeled path does not exist
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, required=True, help="Path to training data")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--unlabeled-path", type=str, default=None, help="Path to data containing augmentations used for consistency")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--n-unlabeled", type=int, help="Number of unlabeled data points per class (proto++)", default=0)
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-augment", type=int, default=0, help="Number of augmented samples to take")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Metric to use in proto distance calculation
    parser.add_argument("--metric", type=str, default="euclidean", help="Metric to use", choices=("euclidean", "cosine"))

    # Supervised loss share
    parser.add_argument("--supervised-loss-share-power", default=1.0, type=float,
                        help="supervised_loss_share = 1 - (x/y) ** <power>, with x the current step and y the max number of steps")

    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")
    args = parser.parse_args(argv)

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist, so a typo fails here rather than after the model is loaded
    for arg in [args.train_path, args.valid_path, args.test_path, args.unlabeled_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Create supervised_loss_share_fn
    def get_supervised_loss_share_fn(supervised_loss_share_power: Union[int, float]) -> Callable[[int, int], float]:
        """Build the schedule: the share starts at 1 and decays towards 0 as training progresses."""

        def _supervised_loss_share_fn(current_step: int, max_steps: int) -> float:
            assert current_step <= max_steps
            return 1 - (current_step / max_steps) ** supervised_loss_share_power

        return _supervised_loss_share_fn

    supervised_loss_share_fn = get_supervised_loss_share_fn(args.supervised_loss_share_power)

    # Run
    run_proto(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        output_path=args.output_path,
        unlabeled_path=args.unlabeled_path,

        model_name_or_path=args.model_name_or_path,
        n_unlabeled=args.n_unlabeled,

        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,
        n_augment=args.n_augment,

        max_iter=args.max_iter,
        evaluate_every=args.evaluate_every,

        metric=args.metric,
        early_stop=args.early_stop,
        arsc_format=args.arsc_format,
        data_path=args.data_path,
        log_every=args.log_every,

        supervised_loss_share_fn=supervised_loss_share_fn
    )

    # Save config (only after the run: run_proto refuses a non-empty output directory)
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False)
SummaryWriter 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import warnings 19 | import logging 20 | from utils.math import euclidean_dist, cosine_similarity 21 | 22 | logging.basicConfig() 23 | logger = logging.getLogger(__name__) 24 | logger.setLevel(logging.DEBUG) 25 | 26 | warnings.simplefilter('ignore') 27 | 28 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 29 | 30 | 31 | class BaselineNet(nn.Module): 32 | def __init__( 33 | self, 34 | encoder, 35 | is_pp: bool = False, 36 | hidden_dim: int = 768, 37 | metric: str = "cosine" 38 | ): 39 | super(BaselineNet, self).__init__() 40 | self.encoder = encoder 41 | self.dropout = nn.Dropout(p=0.25).to(device) 42 | self.is_pp = is_pp 43 | self.hidden_dim = hidden_dim 44 | self.metric = metric 45 | assert self.metric in ("euclidean", "cosine") 46 | 47 | def train_ARSC_one_episode( 48 | self, 49 | data_path: str, 50 | n_iter: int = 100, 51 | ): 52 | self.train() 53 | episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=0, n_unlabeled=0) 54 | n_episode_classes = len(episode["xs"]) 55 | loss_fn = nn.CrossEntropyLoss() 56 | episode_matrix = None 57 | episode_classifier = None 58 | if self.is_pp: 59 | with torch.no_grad(): 60 | init_matrix = np.array([ 61 | [ 62 | self.encoder.forward([sentence]).squeeze().cpu().detach().numpy() 63 | for sentence in episode["xs"][c] 64 | ] 65 | for c in range(n_episode_classes) 66 | ]).mean(1) 67 | 68 | episode_matrix = torch.Tensor(init_matrix).to(device) 69 | episode_matrix.requires_grad = True 70 | optimizer = torch.optim.Adam(list(self.parameters()) + [episode_matrix], lr=2e-5) 71 | else: 72 | episode_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_episode_classes).to(device) 73 | optimizer = torch.optim.Adam(list(self.parameters()) + list(episode_classifier.parameters()), lr=2e-5) 74 | 75 | # Train on support 76 | iter_bar = tqdm.tqdm(range(n_iter)) 77 | losses = list() 78 | 
accuracies = list() 79 | 80 | for _ in iter_bar: 81 | optimizer.zero_grad() 82 | 83 | sentences = [sentence for sentence_list in episode["xs"] for sentence in sentence_list] 84 | labels = torch.Tensor([ix for ix, sl in enumerate(episode["xs"]) for _ in sl]).long().to(device) 85 | z = self.encoder(sentences) 86 | 87 | # z = batch_embeddings 88 | 89 | if self.is_pp: 90 | if self.metric == "cosine": 91 | z = cosine_similarity(z, episode_matrix) * 5 92 | elif self.metric == "euclidean": 93 | z = -euclidean_dist(z, episode_matrix) 94 | else: 95 | raise NotImplementedError 96 | else: 97 | z = self.dropout(z) 98 | z = episode_classifier(z) 99 | 100 | loss = loss_fn(input=z, target=labels) 101 | acc = (z.argmax(1) == labels).float().mean() 102 | loss.backward() 103 | optimizer.step() 104 | iter_bar.set_description(f"{loss.item():.3f} | {acc.item():.3f}") 105 | losses.append(loss.item()) 106 | accuracies.append(acc.item()) 107 | return { 108 | "loss": np.mean(losses), 109 | "acc": np.mean(accuracies) 110 | } 111 | 112 | def run_ARSC( 113 | self, 114 | data_path: str, 115 | train_summary_writer: SummaryWriter = None, 116 | valid_summary_writer: SummaryWriter = None, 117 | test_summary_writer: SummaryWriter = None, 118 | n_episodes: int = 1000, 119 | n_train_iter: int = 100, 120 | train_eval_every: int = 100, 121 | n_test_iter: int = 1000, 122 | test_eval_every: int = 100, 123 | ): 124 | metrics = list() 125 | for episode_ix in range(n_episodes): 126 | output = self.train_ARSC_one_episode(data_path=data_path, n_iter=n_train_iter) 127 | episode_metrics = { 128 | "train": output 129 | } 130 | 131 | if train_summary_writer: 132 | train_summary_writer.add_scalar(tag=f'loss', global_step=episode_ix, scalar_value=output["loss"]) 133 | train_summary_writer.add_scalar(tag=f'acc', global_step=episode_ix, scalar_value=output["acc"]) 134 | 135 | # Running evaluation 136 | if (train_eval_every and (episode_ix + 1) % train_eval_every == 0) or (not train_eval_every and episode_ix + 1 == 
n_episodes): 137 | test_metrics = self.test_model_ARSC( 138 | data_path=data_path, 139 | valid_summary_writer=valid_summary_writer, 140 | test_summary_writer=test_summary_writer, 141 | n_iter=n_test_iter, 142 | eval_every=test_eval_every 143 | ) 144 | episode_metrics["test"] = test_metrics 145 | 146 | metrics.append(episode_metrics) 147 | return metrics 148 | 149 | def test_model_ARSC( 150 | self, 151 | data_path: str, 152 | n_iter: int = 1000, 153 | valid_summary_writer: SummaryWriter = None, 154 | test_summary_writer: SummaryWriter = None, 155 | eval_every: int = 100 156 | ): 157 | self.eval() 158 | 159 | tasks = get_ARSC_test_tasks(prefix=data_path) 160 | metrics = list() 161 | logger.info("Embedding sentences...") 162 | sentences_to_embed = [ 163 | s 164 | for task in tasks 165 | for sentences_lists in task['xs'] + task['x_test'] + task['x_valid'] 166 | for s in sentences_lists 167 | ] 168 | 169 | # sentence_to_embedding_dict = {s: np.random.randn(768) for s in tqdm.tqdm(sentences_to_embed)} 170 | sentence_to_embedding_dict = {s: self.encoder.forward([s]).cpu().detach().numpy().squeeze() for s in tqdm.tqdm(sentences_to_embed)} 171 | for ix_task, task in enumerate(tasks): 172 | task_metrics = list() 173 | 174 | n_episode_classes = 2 175 | loss_fn = nn.CrossEntropyLoss() 176 | episode_matrix = None 177 | episode_classifier = None 178 | if self.is_pp: 179 | with torch.no_grad(): 180 | init_matrix = np.array([ 181 | [ 182 | sentence_to_embedding_dict[sentence] 183 | for sentence in task["xs"][c] 184 | ] 185 | for c in range(n_episode_classes) 186 | ]).mean(1) 187 | 188 | episode_matrix = torch.Tensor(init_matrix).to(device) 189 | episode_matrix.requires_grad = True 190 | optimizer = torch.optim.Adam([episode_matrix], lr=2e-5) 191 | else: 192 | episode_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_episode_classes).to(device) 193 | optimizer = torch.optim.Adam(list(episode_classifier.parameters()), lr=2e-5) 194 | 195 | # Train on support 196 | 
iter_bar = tqdm.tqdm(range(n_iter)) 197 | losses = list() 198 | accuracies = list() 199 | 200 | for iteration in iter_bar: 201 | optimizer.zero_grad() 202 | 203 | sentences = [sentence for sentence_list in task["xs"] for sentence in sentence_list] 204 | labels = torch.Tensor([ix for ix, sl in enumerate(task["xs"]) for _ in sl]).long().to(device) 205 | batch_embeddings = torch.Tensor([sentence_to_embedding_dict[s] for s in sentences]).to(device) 206 | # z = self.encoder(sentences) 207 | z = batch_embeddings 208 | 209 | if self.is_pp: 210 | if self.metric == "cosine": 211 | z = cosine_similarity(z, episode_matrix) * 5 212 | elif self.metric == "euclidean": 213 | z = -euclidean_dist(z, episode_matrix) 214 | else: 215 | raise NotImplementedError 216 | else: 217 | z = self.dropout(z) 218 | z = episode_classifier(z) 219 | 220 | loss = loss_fn(input=z, target=labels) 221 | acc = (z.argmax(1) == labels).float().mean() 222 | loss.backward() 223 | optimizer.step() 224 | iter_bar.set_description(f"{loss.item():.3f} | {acc.item():.3f}") 225 | losses.append(loss.item()) 226 | accuracies.append(acc.item()) 227 | 228 | if (eval_every and (iteration + 1) % eval_every == 0) or (not eval_every and iteration + 1 == n_iter): 229 | self.eval() 230 | if not self.is_pp: 231 | episode_classifier.eval() 232 | 233 | # -------------- 234 | # VALIDATION 235 | # -------------- 236 | valid_query_data_list = [ 237 | {"sentence": sentence, "label": label} 238 | for label, sentences in enumerate(task["x_valid"]) 239 | for sentence in sentences 240 | ] 241 | 242 | valid_query_labels = torch.Tensor([d['label'] for d in valid_query_data_list]).long().to(device) 243 | logits = list() 244 | with torch.no_grad(): 245 | for ix in range(0, len(valid_query_data_list), 16): 246 | batch = valid_query_data_list[ix:ix + 16] 247 | batch_sentences = [d['sentence'] for d in batch] 248 | batch_embeddings = torch.Tensor([sentence_to_embedding_dict[s] for s in batch_sentences]).to(device) 249 | # z = 
self.encoder(batch_sentences) 250 | z = batch_embeddings 251 | 252 | if self.is_pp: 253 | if self.metric == "cosine": 254 | z = cosine_similarity(z, episode_matrix) * 5 255 | elif self.metric == "euclidean": 256 | z = -euclidean_dist(z, episode_matrix) 257 | else: 258 | raise NotImplementedError 259 | else: 260 | z = episode_classifier(z) 261 | 262 | logits.append(z) 263 | logits = torch.cat(logits, dim=0) 264 | y_hat = logits.argmax(1) 265 | 266 | valid_loss = loss_fn(input=logits, target=valid_query_labels) 267 | valid_acc = (y_hat == valid_query_labels).float().mean() 268 | 269 | # -------------- 270 | # TEST 271 | # -------------- 272 | test_query_data_list = [ 273 | {"sentence": sentence, "label": label} 274 | for label, sentences in enumerate(task["x_test"]) 275 | for sentence in sentences 276 | ] 277 | 278 | test_query_labels = torch.Tensor([d['label'] for d in test_query_data_list]).long().to(device) 279 | logits = list() 280 | with torch.no_grad(): 281 | for ix in range(0, len(test_query_data_list), 16): 282 | batch = test_query_data_list[ix:ix + 16] 283 | batch_sentences = [d['sentence'] for d in batch] 284 | batch_embeddings = torch.Tensor([sentence_to_embedding_dict[s] for s in batch_sentences]).to(device) 285 | # z = self.encoder(batch_sentences) 286 | z = batch_embeddings 287 | 288 | if self.is_pp: 289 | if self.metric == "cosine": 290 | z = cosine_similarity(z, episode_matrix) * 5 291 | elif self.metric == "euclidean": 292 | z = -euclidean_dist(z, episode_matrix) 293 | else: 294 | raise NotImplementedError 295 | else: 296 | z = episode_classifier(z) 297 | 298 | logits.append(z) 299 | logits = torch.cat(logits, dim=0) 300 | y_hat = logits.argmax(1) 301 | 302 | test_loss = loss_fn(input=logits, target=test_query_labels) 303 | test_acc = (y_hat == test_query_labels).float().mean() 304 | 305 | # --RETURN METRICS 306 | task_metrics.append({ 307 | "test": { 308 | "loss": test_loss.item(), 309 | "acc": test_acc.item() 310 | }, 311 | "valid": { 312 | "loss": 
valid_loss.item(), 313 | "acc": valid_acc.item() 314 | }, 315 | "step": iteration + 1 316 | }) 317 | # if valid_summary_writer: 318 | # valid_summary_writer.add_scalar(tag=f'loss', global_step=ix_task, scalar_value=valid_loss.item()) 319 | # valid_summary_writer.add_scalar(tag=f'acc', global_step=ix_task, scalar_value=valid_acc.item()) 320 | # if test_summary_writer: 321 | # test_summary_writer.add_scalar(tag=f'loss', global_step=ix_task, scalar_value=test_loss.item()) 322 | # test_summary_writer.add_scalar(tag=f'acc', global_step=ix_task, scalar_value=test_acc.item()) 323 | metrics.append(task_metrics) 324 | return metrics 325 | 326 | def train_model( 327 | self, 328 | data_dict: Dict[str, List[str]], 329 | summary_writer: SummaryWriter = None, 330 | n_epoch: int = 400, 331 | batch_size: int = 16, 332 | log_every: int = 10): 333 | self.train() 334 | 335 | training_classes = sorted(set(data_dict.keys())) 336 | n_training_classes = len(training_classes) 337 | class_to_ix = {c: ix for ix, c in enumerate(training_classes)} 338 | training_data_list = [{"sentence": sentence, "label": label} for label, sentences in data_dict.items() for sentence in sentences] 339 | 340 | training_matrix = None 341 | training_classifier = None 342 | 343 | if self.is_pp: 344 | training_matrix = torch.randn(n_training_classes, self.hidden_dim, requires_grad=True, device=device) 345 | optimizer = torch.optim.Adam(list(self.parameters()) + [training_matrix], lr=2e-5) 346 | else: 347 | training_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_training_classes).to(device) 348 | optimizer = torch.optim.Adam(list(self.parameters()) + list(training_classifier.parameters()), lr=2e-5) 349 | 350 | n_samples = len(training_data_list) 351 | loss_fn = nn.CrossEntropyLoss() 352 | global_step = 0 353 | 354 | # Metrics 355 | training_losses = list() 356 | training_accuracies = list() 357 | 358 | for _ in tqdm.tqdm(range(n_epoch)): 359 | random.shuffle(training_data_list) 360 | for ix in 
tqdm.tqdm(range(0, n_samples, batch_size)): 361 | optimizer.zero_grad() 362 | torch.cuda.empty_cache() 363 | 364 | batch_items = training_data_list[ix:ix + batch_size] 365 | batch_sentences = [d['sentence'] for d in batch_items] 366 | batch_labels = torch.Tensor([class_to_ix[d['label']] for d in batch_items]).long().to(device) 367 | z = self.encoder(batch_sentences) 368 | if self.is_pp: 369 | if self.metric == "cosine": 370 | z = cosine_similarity(z, training_matrix) * 5 371 | elif self.metric == "euclidean": 372 | z = -euclidean_dist(z, training_matrix) 373 | else: 374 | raise NotImplementedError 375 | else: 376 | z = self.dropout(z) 377 | z = training_classifier(z) 378 | loss = loss_fn(input=z, target=batch_labels) 379 | acc = (z.argmax(1) == batch_labels).float().mean() 380 | loss.backward() 381 | optimizer.step() 382 | 383 | global_step += 1 384 | training_losses.append(loss.item()) 385 | training_accuracies.append(acc.item()) 386 | if (global_step % log_every) == 0: 387 | if summary_writer: 388 | summary_writer.add_scalar(tag="loss", global_step=global_step, scalar_value=np.mean(training_losses)) 389 | summary_writer.add_scalar(tag="acc", global_step=global_step, scalar_value=np.mean(training_accuracies)) 390 | # Empty metrics 391 | training_losses = list() 392 | training_accuracies = list() 393 | 394 | def test_one_episode( 395 | self, 396 | support_data_dict: Dict[str, List[str]], 397 | query_data_dict: Dict[str, List[str]], 398 | sentence_to_embedding_dict: Dict, 399 | batch_size: int = 4, 400 | n_iter: int = 1000, 401 | summary_writer: SummaryWriter = None, 402 | summary_tag_prefix: str = None, 403 | ): 404 | 405 | # Check data integrity 406 | assert set(support_data_dict.keys()) == set(query_data_dict.keys()) 407 | 408 | # Freeze encoder 409 | self.encoder.eval() 410 | 411 | episode_classes = sorted(set(support_data_dict.keys())) 412 | n_episode_classes = len(episode_classes) 413 | class_to_ix = {c: ix for ix, c in enumerate(episode_classes)} 414 | 
ix_to_class = {ix: c for ix, c in enumerate(episode_classes)} 415 | support_data_list = [{"sentence": sentence, "label": label} for label, sentences in support_data_dict.items() for sentence in sentences] 416 | support_data_list = (support_data_list * batch_size * n_iter)[:(batch_size * n_iter)] 417 | 418 | loss_fn = nn.CrossEntropyLoss() 419 | episode_matrix = None 420 | episode_classifier = None 421 | if self.is_pp: 422 | init_matrix = np.array([ 423 | [ 424 | sentence_to_embedding_dict[sentence].ravel() 425 | for sentence in support_data_dict[ix_to_class[c]] 426 | ] 427 | for c in range(n_episode_classes) 428 | ]).mean(1) 429 | 430 | episode_matrix = torch.Tensor(init_matrix).to(device) 431 | episode_matrix.requires_grad = True 432 | optimizer = torch.optim.Adam([episode_matrix], lr=1e-3) 433 | else: 434 | episode_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_episode_classes).to(device) 435 | optimizer = torch.optim.Adam(list(episode_classifier.parameters()), lr=1e-3) 436 | 437 | # Train on support 438 | iter_bar = tqdm.tqdm(range(n_iter)) 439 | for iteration in iter_bar: 440 | optimizer.zero_grad() 441 | 442 | batch = support_data_list[iteration * batch_size: iteration * batch_size + batch_size] 443 | batch_sentences = [d['sentence'] for d in batch] 444 | batch_embeddings = torch.Tensor([sentence_to_embedding_dict[s] for s in batch_sentences]).to(device) 445 | batch_labels = torch.Tensor([class_to_ix[d['label']] for d in batch]).long().to(device) 446 | # z = self.encoder(batch_sentences) 447 | z = batch_embeddings 448 | 449 | if self.is_pp: 450 | if self.metric == "cosine": 451 | z = cosine_similarity(z, episode_matrix) * 5 452 | elif self.metric == "euclidean": 453 | z = -euclidean_dist(z, episode_matrix) 454 | else: 455 | raise NotImplementedError 456 | else: 457 | z = self.dropout(z) 458 | z = episode_classifier(z) 459 | 460 | loss = loss_fn(input=z, target=batch_labels) 461 | acc = (z.argmax(1) == batch_labels).float().mean() 462 | 
loss.backward() 463 | optimizer.step() 464 | iter_bar.set_description(f"{loss.item():.3f} | {acc.item():.3f}") 465 | 466 | if summary_writer: 467 | summary_writer.add_scalar(tag=f'{summary_tag_prefix}_loss', global_step=iteration, scalar_value=loss.item()) 468 | summary_writer.add_scalar(tag=f'{summary_tag_prefix}_acc', global_step=iteration, scalar_value=acc.item()) 469 | 470 | # Predict on query 471 | self.eval() 472 | if not self.is_pp: 473 | episode_classifier.eval() 474 | 475 | query_data_list = [{"sentence": sentence, "label": label} for label, sentences in query_data_dict.items() for sentence in sentences] 476 | query_labels = torch.Tensor([class_to_ix[d['label']] for d in query_data_list]).long().to(device) 477 | logits = list() 478 | with torch.no_grad(): 479 | for ix in range(0, len(query_data_list), 16): 480 | batch = query_data_list[ix:ix + 16] 481 | batch_sentences = [d['sentence'] for d in batch] 482 | batch_embeddings = torch.Tensor([sentence_to_embedding_dict[s] for s in batch_sentences]).to(device) 483 | # z = self.encoder(batch_sentences) 484 | z = batch_embeddings 485 | 486 | if self.is_pp: 487 | if self.metric == "cosine": 488 | z = cosine_similarity(z, episode_matrix) * 5 489 | elif self.metric == "euclidean": 490 | z = -euclidean_dist(z, episode_matrix) 491 | else: 492 | raise NotImplementedError 493 | else: 494 | z = episode_classifier(z) 495 | 496 | logits.append(z) 497 | logits = torch.cat(logits, dim=0) 498 | y_hat = logits.argmax(1) 499 | 500 | y_pred = logits.argmax(1).cpu().detach().numpy() 501 | probas_pred = logits.cpu().detach().numpy() 502 | probas_pred = np.exp(probas_pred) / np.exp(probas_pred).sum(1)[:, None] 503 | 504 | y_true = query_labels.cpu().detach().numpy() 505 | where_ok = np.where(y_pred == y_true)[0] 506 | import uuid 507 | tag = str(uuid.uuid4()) 508 | summary_writer.add_text(tag=tag, text_string=json.dumps(ix_to_class, ensure_ascii=False), global_step=0) 509 | if len(where_ok): 510 | # Looking for OK but with less 
confidence (not too easy) 511 | ok_idx = sorted(where_ok, key=lambda x: probas_pred[x][y_pred[x]])[0] 512 | ok_sentence = query_data_list[ok_idx]['sentence'] 513 | ok_prediction = ix_to_class[y_pred[ok_idx]] 514 | ok_label = query_data_list[ok_idx]['label'] 515 | summary_writer.add_text( 516 | tag=tag, 517 | text_string=json.dumps({ 518 | "sentence": ok_sentence, 519 | "true_label": ok_label, 520 | "predicted_label": ok_prediction, 521 | "p": probas_pred[ok_idx].tolist(), 522 | }), 523 | global_step=1) 524 | 525 | where_ko = np.where(y_pred != y_true)[0] 526 | if len(where_ko): 527 | # Looking for KO but with most confidence 528 | ko_idx = sorted(where_ko, key=lambda x: probas_pred[x][y_pred[x]], reverse=True)[0] 529 | ko_sentence = query_data_list[ko_idx]['sentence'] 530 | ko_prediction = ix_to_class[y_pred[ko_idx]] 531 | ko_label = query_data_list[ko_idx]['label'] 532 | summary_writer.add_text( 533 | tag=tag, 534 | text_string=json.dumps({ 535 | "sentence": ko_sentence, 536 | "true_label": ko_label, 537 | "predicted_label": ko_prediction, 538 | "p": probas_pred[ko_idx].tolist() 539 | }), 540 | global_step=2) 541 | 542 | loss = loss_fn(input=logits, target=query_labels) 543 | acc = (y_hat == query_labels).float().mean() 544 | 545 | return { 546 | "loss": loss.item(), 547 | "acc": acc.item() 548 | } 549 | 550 | def test_model( 551 | self, 552 | data_dict: Dict[str, List[str]], 553 | n_support: int, 554 | n_classes: int, 555 | n_episodes=600, 556 | summary_writer: SummaryWriter = None, 557 | n_test_iter: int = 100, 558 | test_batch_size: int = 4 559 | ): 560 | test_metrics = list() 561 | 562 | # Freeze encoder 563 | self.encoder.eval() 564 | logger.info("Embedding sentences...") 565 | sentences_to_embed = [s for label, sentences in data_dict.items() for s in sentences] 566 | sentence_to_embedding_dict = {s: self.encoder.forward([s]).cpu().detach().numpy().squeeze() for s in tqdm.tqdm(sentences_to_embed)} 567 | 568 | for episode in tqdm.tqdm(range(n_episodes)): 569 | 
episode_classes = np.random.choice(list(data_dict.keys()), size=n_classes, replace=False) 570 | episode_query_data_dict = dict() 571 | episode_support_data_dict = dict() 572 | 573 | for episode_class in episode_classes: 574 | random.shuffle(data_dict[episode_class]) 575 | episode_support_data_dict[episode_class] = data_dict[episode_class][:n_support] 576 | episode_query_data_dict[episode_class] = data_dict[episode_class][n_support:] 577 | 578 | episode_metrics = self.test_one_episode( 579 | support_data_dict=episode_support_data_dict, 580 | query_data_dict=episode_query_data_dict, 581 | n_iter=n_test_iter, 582 | batch_size=test_batch_size, 583 | sentence_to_embedding_dict=sentence_to_embedding_dict, 584 | summary_writer=summary_writer 585 | ) 586 | logger.info(f"Episode metrics: {episode_metrics}") 587 | test_metrics.append(episode_metrics) 588 | for metric_name, metric_value in episode_metrics.items(): 589 | summary_writer.add_scalar(tag=metric_name, global_step=episode, scalar_value=metric_value) 590 | 591 | return test_metrics 592 | 593 | 594 | def run_baseline( 595 | train_path: str, 596 | model_name_or_path: str, 597 | n_support: int, 598 | n_classes: int, 599 | valid_path: str = None, 600 | test_path: str = None, 601 | output_path: str = f'runs/{now()}', 602 | n_test_episodes: int = 600, 603 | log_every: int = 10, 604 | n_train_epoch: int = 400, 605 | train_batch_size: int = 16, 606 | is_pp: bool = False, 607 | test_batch_size: int = 4, 608 | n_test_iter: int = 100, 609 | metric: str = "cosine", 610 | arsc_format: bool = False, 611 | data_path: str = None 612 | ): 613 | if output_path: 614 | if os.path.exists(output_path) and len(os.listdir(output_path)): 615 | raise FileExistsError(f"Output path {output_path} already exists. 
Exiting.") 616 | 617 | # -------------------- 618 | # Creating Log Writers 619 | # -------------------- 620 | os.makedirs(output_path) 621 | os.makedirs(os.path.join(output_path, "logs/train")) 622 | train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1) 623 | valid_writer: SummaryWriter = None 624 | test_writer: SummaryWriter = None 625 | log_dict = dict(train=list()) 626 | 627 | if valid_path: 628 | os.makedirs(os.path.join(output_path, "logs/valid")) 629 | valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1) 630 | log_dict["valid"] = list() 631 | if test_path: 632 | os.makedirs(os.path.join(output_path, "logs/test")) 633 | test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1) 634 | log_dict["test"] = list() 635 | 636 | def raw_data_to_labels_dict(data, shuffle=True): 637 | labels_dict = collections.defaultdict(list) 638 | for item in data: 639 | labels_dict[item['label']].append(item["sentence"]) 640 | labels_dict = dict(labels_dict) 641 | if shuffle: 642 | for key, val in labels_dict.items(): 643 | random.shuffle(val) 644 | return labels_dict 645 | 646 | # Load model 647 | bert = BERTEncoder(model_name_or_path).to(device) 648 | baseline_net = BaselineNet(encoder=bert, is_pp=is_pp, metric=metric).to(device) 649 | 650 | # Load data 651 | if not arsc_format: 652 | train_data = get_jsonl_data(train_path) 653 | train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True) 654 | logger.info(f"train labels: {train_data_dict.keys()}") 655 | 656 | if valid_path: 657 | valid_data = get_jsonl_data(valid_path) 658 | valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True) 659 | logger.info(f"valid labels: {valid_data_dict.keys()}") 660 | else: 661 | valid_data_dict = None 662 | 663 | if test_path: 664 | test_data = get_jsonl_data(test_path) 665 | test_data_dict = 
raw_data_to_labels_dict(test_data, shuffle=True) 666 | logger.info(f"test labels: {test_data_dict.keys()}") 667 | else: 668 | test_data_dict = None 669 | 670 | baseline_net.train_model( 671 | data_dict=train_data_dict, 672 | summary_writer=train_writer, 673 | n_epoch=n_train_epoch, 674 | batch_size=train_batch_size, 675 | log_every=log_every 676 | ) 677 | 678 | # Validation 679 | if valid_path: 680 | validation_metrics = baseline_net.test_model( 681 | data_dict=valid_data_dict, 682 | n_support=n_support, 683 | n_classes=n_classes, 684 | n_episodes=n_test_episodes, 685 | summary_writer=valid_writer, 686 | n_test_iter=n_test_iter, 687 | test_batch_size=test_batch_size 688 | ) 689 | with open(os.path.join(output_path, 'validation_metrics.json'), "w") as file: 690 | json.dump(validation_metrics, file, ensure_ascii=False) 691 | # Test 692 | if test_path: 693 | test_metrics = baseline_net.test_model( 694 | data_dict=test_data_dict, 695 | n_support=n_support, 696 | n_classes=n_classes, 697 | n_episodes=n_test_episodes, 698 | summary_writer=test_writer 699 | ) 700 | 701 | with open(os.path.join(output_path, 'test_metrics.json'), "w") as file: 702 | json.dump(test_metrics, file, ensure_ascii=False) 703 | 704 | else: 705 | # baseline_net.train_model_ARSC( 706 | # train_summary_writer=train_writer, 707 | # n_episodes=10, 708 | # n_train_iter=20 709 | # ) 710 | # metrics = baseline_net.test_model_ARSC( 711 | # n_iter=n_test_iter, 712 | # valid_summary_writer=valid_writer, 713 | # test_summary_writer=test_writer 714 | # ) 715 | metrics = baseline_net.run_ARSC( 716 | train_summary_writer=train_writer, 717 | valid_summary_writer=valid_writer, 718 | test_summary_writer=test_writer, 719 | n_episodes=1000, 720 | train_eval_every=50, 721 | n_train_iter=50, 722 | n_test_iter=200, 723 | test_eval_every=25, 724 | data_path=data_path 725 | ) 726 | with open(os.path.join(output_path, 'baseline_metrics.json'), "w") as file: 727 | json.dump(metrics, file, ensure_ascii=False) 728 | 729 | 730 
def main():
    """Parse the command line, run the baseline and save the parsed arguments.

    The arguments are written to ``<output-path>/config.json`` once
    ``run_baseline`` has returned.

    :raises FileNotFoundError: if a given train / valid / test path does not exist.
    """
    parser = argparse.ArgumentParser()
    # Only needed for the standard (non-ARSC) format; checked after parsing.
    parser.add_argument("--train-path", type=str, default=None, help="Path to training data (required unless --arsc-format is set)")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-test-episodes", type=int, default=600, help="Number of episodes during evaluation (valid, test)")
    parser.add_argument("--n-train-epoch", type=int, default=400, help="Number of epoch during training")
    parser.add_argument("--train-batch-size", type=int, default=16, help="Batch size used during training")
    parser.add_argument("--n-test-iter", type=int, default=100, help="Number of training iterations during testing episodes")
    parser.add_argument("--test-batch-size", type=int, default=4, help="Batch size used during testing episodes")

    # Baseline++
    parser.add_argument("--pp", default=False, action="store_true", help="Boolean to use the ++ baseline model")
    parser.add_argument("--metric", default="cosine", type=str, help="Which metric to use in baseline++", choices=("euclidean", "cosine"))
    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")

    args = parser.parse_args()
    logger.info(f"Received args:{args}")

    # The ARSC run never reads the training file, so only insist on
    # --train-path for the standard format.
    if not args.arsc_format and not args.train_path:
        parser.error("the following arguments are required: --train-path (unless --arsc-format is set)")

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist
    for arg in [args.train_path, args.valid_path, args.test_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Run
    run_baseline(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        data_path=args.data_path,
        output_path=args.output_path,

        model_name_or_path=args.model_name_or_path,

        n_support=args.n_support,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,
        log_every=args.log_every,
        n_train_epoch=args.n_train_epoch,
        train_batch_size=args.train_batch_size,
        is_pp=args.pp,

        test_batch_size=args.test_batch_size,
        n_test_iter=args.n_test_iter,
        metric=args.metric,
        arsc_format=args.arsc_format
    )

    # Save config
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False, indent=1)


if __name__ == '__main__':
    main()