├── models
│   ├── __init__.py
│   ├── proto
│   │   ├── protonet.sh
│   │   ├── protaugment.py
│   │   ├── protaugment-tmp.py
│   │   └── protonet.py
│   ├── matching
│   │   ├── matchingnet.sh
│   │   └── matchingnet.py
│   ├── relation
│   │   ├── relationnet.sh
│   │   └── relationnet.py
│   ├── bert_baseline
│   │   ├── baseline.sh
│   │   └── baseline.py
│   ├── induction
│   │   ├── inductionnet.sh
│   │   └── inductionnet.py
│   └── encoders
│       └── bert_encoder.py
├── .gitignore
├── utils
│   ├── python.py
│   ├── math.py
│   ├── scripts
│   │   ├── prepate-intent-dataset.py
│   │   └── runner.sh
│   ├── data.py
│   └── few_shot.py
├── README.md
├── requirements.txt
└── language_modeling
    └── run_language_modeling.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoders.bert_encoder import BERTEncoder
2 | from .bert_baseline.baseline import BaselineNet
3 | from .induction.inductionnet import InductionNet
4 | from .matching.matchingnet import MatchingNet
5 | from .proto.protonet import ProtoNet
6 | from .relation.relationnet import RelationNet
7 |
--------------------------------------------------------------------------------
/models/proto/protonet.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch models/proto/protonet.py from the project root inside the project virtualenv.
# Every command-line argument is forwarded unchanged to the Python script.

# Project root (override with FEWSHOTTEXT_ROOT); stop rather than run from the wrong directory.
cd "${FEWSHOTTEXT_ROOT:-$HOME/Projects/FewShotText}" || exit 1
source .venv/bin/activate || exit 1
# .envrc holds local settings; it is git-ignored, so it may be absent.
if [[ -f .envrc ]]; then
    source .envrc
fi
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
if command -v nvidia-smi >/dev/null; then
    echo "GPU Devices:"
    nvidia-smi
fi

# Quoted "$@" forwards each argument intact, even when it contains spaces.
PYTHONPATH=. python models/proto/protonet.py "$@"
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .envrc
2 | .venv
3 | .idea
4 | **/__pycache__
5 | # Runs folder
6 | runs
7 | runs*
8 | runs.bkp
9 | tmp
10 |
11 | # Shell scripts are ignored by default
12 | **/*.sh
13 |
14 | # Jupyter stuff
15 | notebooks
16 | **/.ipynb_checkpoints
17 |
18 | # Spreadsheet stuff
19 | **/*.ods
20 | **/*.xlsx
21 |
22 | # Ignore data by default
23 | data
24 | **/*.zip
25 |
--------------------------------------------------------------------------------
/models/matching/matchingnet.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch models/matching/matchingnet.py from the project root inside the project virtualenv.
# Every command-line argument is forwarded unchanged to the Python script.

# Project root (override with FEWSHOTTEXT_ROOT); stop rather than run from the wrong directory.
cd "${FEWSHOTTEXT_ROOT:-$HOME/Projects/FewShotText}" || exit 1
source .venv/bin/activate || exit 1
# .envrc holds local settings; it is git-ignored, so it may be absent.
if [[ -f .envrc ]]; then
    source .envrc
fi
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
if command -v nvidia-smi >/dev/null; then
    echo "GPU Devices:"
    nvidia-smi
fi

# Quoted "$@" forwards each argument intact, even when it contains spaces.
PYTHONPATH=. python models/matching/matchingnet.py "$@"
14 |
--------------------------------------------------------------------------------
/models/relation/relationnet.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch models/relation/relationnet.py from the project root inside the project virtualenv.
# Every command-line argument is forwarded unchanged to the Python script.

# Project root (override with FEWSHOTTEXT_ROOT); stop rather than run from the wrong directory.
cd "${FEWSHOTTEXT_ROOT:-$HOME/Projects/FewShotText}" || exit 1
source .venv/bin/activate || exit 1
# .envrc holds local settings; it is git-ignored, so it may be absent.
if [[ -f .envrc ]]; then
    source .envrc
fi
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
if command -v nvidia-smi >/dev/null; then
    echo "GPU Devices:"
    nvidia-smi
fi

# Quoted "$@" forwards each argument intact, even when it contains spaces.
PYTHONPATH=. python models/relation/relationnet.py "$@"
14 |
--------------------------------------------------------------------------------
/models/bert_baseline/baseline.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch models/bert_baseline/baseline.py from the project root inside the project virtualenv.
# Every command-line argument is forwarded unchanged to the Python script.

# Project root (override with FEWSHOTTEXT_ROOT); stop rather than run from the wrong directory.
cd "${FEWSHOTTEXT_ROOT:-$HOME/Projects/FewShotText}" || exit 1
source .venv/bin/activate || exit 1
# .envrc holds local settings; it is git-ignored, so it may be absent.
if [[ -f .envrc ]]; then
    source .envrc
fi
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
if command -v nvidia-smi >/dev/null; then
    echo "GPU Devices:"
    nvidia-smi
fi

# Quoted "$@" forwards each argument intact, even when it contains spaces.
PYTHONPATH=. python models/bert_baseline/baseline.py "$@"
14 |
--------------------------------------------------------------------------------
/models/induction/inductionnet.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch models/induction/inductionnet.py from the project root inside the project virtualenv.
# Every command-line argument is forwarded unchanged to the Python script.

# Project root (override with FEWSHOTTEXT_ROOT); stop rather than run from the wrong directory.
cd "${FEWSHOTTEXT_ROOT:-$HOME/Projects/FewShotText}" || exit 1
source .venv/bin/activate || exit 1
# .envrc holds local settings; it is git-ignored, so it may be absent.
if [[ -f .envrc ]]; then
    source .envrc
fi
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
if command -v nvidia-smi >/dev/null; then
    echo "GPU Devices:"
    nvidia-smi
fi

# Quoted "$@" forwards each argument intact, even when it contains spaces.
PYTHONPATH=. python models/induction/inductionnet.py "$@"
14 |
--------------------------------------------------------------------------------
/utils/python.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import random
3 | import numpy as np
4 | import torch
5 |
6 |
def now():
    """Return the current local time as ``YYYY-MM-DD_HH:MM:SS.ffffff``."""
    return f"{datetime.datetime.now():%Y-%m-%d_%H:%M:%S.%f}"
9 |
10 |
def set_seeds(seed: int) -> None:
    """
    Seed every random number generator the project relies on.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG
    and the RNGs of all CUDA devices.
    :param seed: int
    :return: None
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
21 |
--------------------------------------------------------------------------------
/utils/math.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def euclidean_dist(x, y):
    """Pairwise squared Euclidean distances between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape (N, D).
        y: tensor of shape (M, D); its feature size must match ``x``.

    Returns:
        Tensor of shape (N, M) whose entry (i, j) is ``sum((x[i] - y[j]) ** 2)``.
        No square root is taken.
    """
    rows_x, rows_y = x.size(0), y.size(0)
    dim = x.size(1)
    assert dim == y.size(1)

    # Broadcast both inputs to (N, M, D) so every x row meets every y row.
    grid_x = x.unsqueeze(1).expand(rows_x, rows_y, dim)
    grid_y = y.unsqueeze(0).expand(rows_x, rows_y, dim)
    return (grid_x - grid_y).pow(2).sum(2)
16 |
17 |
def cosine_similarity(x, y, eps: float = 1e-12):
    """Pairwise cosine similarity between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape (N, D).
        y: tensor of shape (M, D).
        eps: lower bound applied to each row norm. It avoids dividing by zero, so an
            all-zero row yields similarity 0 instead of NaN. Rows with a norm above
            ``eps`` are unaffected.

    Returns:
        Tensor of shape (N, M).
    """
    x = x / x.norm(dim=1, keepdim=True).clamp(min=eps)
    y = y / y.norm(dim=1, keepdim=True).clamp(min=eps)

    return x @ y.T
23 |
--------------------------------------------------------------------------------
/utils/scripts/prepate-intent-dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from utils.data import get_jsonl_data, write_jsonl_data, write_txt_data
3 |
4 |
def process_dataset(name, seed: int = 42):
    """Split ``data/<name>/full.jsonl`` into class-disjoint train / valid / test files.

    The sorted label set is shuffled with an RNG seeded by ``seed`` and cut into
    thirds. Each example goes to the split that owns its label. For every split this
    (over)writes ``data/<name>/<split>.jsonl`` and ``data/<name>/labels.<split>.txt``.

    Args:
        name: dataset directory under ``data/``.
        seed: seed of the private RNG used for the label split and the example order.
    """
    full = get_jsonl_data(f"data/{name}/full.jsonl")

    # random.Random(seed).shuffle matches `random.seed(seed); random.shuffle(...)`,
    # so the label split is unchanged, but the global RNG is left untouched.
    rng = random.Random(seed)
    labels = sorted(set(d['label'] for d in full))
    rng.shuffle(labels)
    n_labels = len(labels)

    # Shuffle the examples with the same seeded RNG so the written files are reproducible.
    rng.shuffle(full)

    splits = (
        ("train", labels[:int(n_labels / 3)]),
        ("valid", labels[int(n_labels / 3):int(2 * n_labels / 3)]),
        ("test", labels[int(2 * n_labels / 3):]),
    )
    for split, split_labels in splits:
        split_label_set = set(split_labels)  # O(1) membership tests
        write_jsonl_data([
            d for d in full if d['label'] in split_label_set
        ], f"data/{name}/{split}.jsonl", force=True)
        # force=True as well: otherwise re-running crashes here after the jsonl was already overwritten.
        write_txt_data(split_labels, f"data/{name}/labels.{split}.txt", force=True)
34 |
35 |
if __name__ == "__main__":
    # Build the class-disjoint splits for every intent dataset.
    for dataset_name in ["OOS", "TREC28", "Liu"]:
        process_dataset(dataset_name)
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FewShotText
2 | This repository contains code for the paper [A Neural Few-Shot Text Classification Reality Check](https://arxiv.org/abs/2101.12073)
3 |
4 | ## Environment setup
5 | ```bash
6 | # Create environment
7 | python3 -m virtualenv .venv --python=python3.6
8 |
9 | # Install environment
10 | .venv/bin/pip install -r requirements.txt
11 |
12 | # Activate environment
13 | source .venv/bin/activate
14 | ```
15 |
16 | ## Fine-tuning BERT on the MLM task
17 | ```bash
18 | model_name=bert-base-cased
19 | block_size=256
20 | dataset=OOS
21 | output_dir=transformer_models/${dataset}/fine-tuned
22 |
23 | python language_modeling/run_language_modeling.py \
24 | --model_name_or_path ${model_name} \
25 | --output_dir ${output_dir} \
26 | --mlm \
27 | --do_train \
28 | --train_data_file data/${dataset}/full/full-train.txt \
29 | --do_eval \
30 | --eval_data_file data/${dataset}/full/full-test.txt \
31 | --overwrite_output_dir \
32 | --logging_steps=1000 \
33 | --line_by_line \
34 | --logging_dir ${output_dir} \
35 | --block_size ${block_size} \
36 | --save_steps=1000 \
37 | --num_train_epochs 20 \
38 | --save_total_limit 20 \
39 | --seed 42 \
40 | --evaluation_strategy epoch
41 | ```
42 |
43 | ## Training a few-shot model
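44 | To run the paper's experiments, use the `utils/scripts/runner.sh` script from the repository root. It expects the splits written by `utils/scripts/prepate-intent-dataset.py` under `data/<dataset>/` and the fine-tuned encoders under `transformer_models/<dataset>/fine-tuned`.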
45 |
46 | ## Reference
47 | If you use the data or code in this repository, please cite our paper:
48 | ```bibtex
49 | @article{dopierre2021neural,
50 | title={A Neural Few-Shot Text Classification Reality Check},
51 | author={Dopierre, Thomas and Gravier, Christophe and Logerais, Wilfried},
52 | journal={arXiv preprint arXiv:2101.12073},
53 | year={2021}
54 | }
55 | ```
56 |
--------------------------------------------------------------------------------
/models/encoders/bert_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch.nn as nn
4 | import logging
5 | import warnings
6 | import torch
7 | from transformers import AutoModel, AutoTokenizer
8 |
# Module-level logging: root handler configured once, this module's logger at DEBUG.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# NOTE(review): importing this module silences *all* warnings process-wide.
warnings.simplefilter('ignore')

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 |
class BERTEncoder(nn.Module):
    """Sentence encoder that returns the ``pooler_output`` of a pre-trained transformer.

    Args:
        config_name_or_path: model name or local path passed to
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
        max_length: maximum number of tokens kept per sentence; longer inputs are truncated.
    """

    def __init__(self, config_name_or_path, max_length: int = 64):
        super(BERTEncoder, self).__init__()
        logger.info(f"Loading Encoder @ {config_name_or_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(config_name_or_path)
        self.bert = AutoModel.from_pretrained(config_name_or_path).to(device)
        logger.info("Encoder loaded.")
        self.max_length = max_length
        # False until the first batch has been embedded (see embed_sentences).
        self.warmed: bool = False

    def embed_sentences(self, sentences: List[str]):
        """Tokenize ``sentences`` and return the model's pooled output, one row per sentence.

        The very first call pads to ``max_length``. Later calls pad only to the longest
        sentence in the batch. Presumably this allocates for the largest shape up front;
        TODO confirm intent.
        """
        if self.warmed:
            padding = True
        else:
            padding = "max_length"
            self.warmed = True
        batch = self.tokenizer(
            sentences,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding=padding
        )
        batch = {k: v.to(device) for k, v in batch.items()}

        # Call the module rather than .forward() so registered nn.Module hooks run.
        fw = self.bert(**batch)
        return fw.pooler_output

    def forward(self, sentences: List[str]):
        """Embed ``sentences``; log the offending batch and re-raise on failure."""
        try:
            return self.embed_sentences(sentences)
        except Exception as e:
            logger.error(f"could not embed sentence {sentences} (err: {type(e)}, {e})")
            # Bare raise keeps the original traceback untouched.
            raise
51 |
52 |
def test():
    """Smoke test: embed two sentences (one containing an emoji) and check the batch size.

    Loads ``bert-base-cased`` through BERTEncoder, so it may need network access or a
    populated model cache.
    """
    encoder = BERTEncoder("bert-base-cased")
    sentences = ["test sentence #1", "test sentence #2🍇"]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embeddings = encoder.embed_sentences(sentences)
    assert embeddings.shape[0] == len(sentences)
57 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | altair==4.1.0
3 | antlr4-python3-runtime==4.8
4 | argon2-cffi==20.1.0
5 | astor==0.8.1
6 | astunparse==1.6.3
7 | async-generator==1.10
8 | attrs==20.3.0
9 | backcall==0.2.0
10 | base58==2.0.1
11 | bleach==3.3.0
12 | blessings==1.7
13 | blinker==1.4
14 | cachetools==4.1.0
15 | certifi==2020.4.5.2
16 | cffi==1.14.4
17 | chardet==3.0.4
18 | charset-normalizer==2.0.3
19 | click==7.1.2
20 | configparser==5.0.1
21 | Cython==0.29.21
22 | dataclasses==0.7
23 | decorator==4.4.2
24 | defusedxml==0.6.0
25 | docker-pycreds==0.4.0
26 | entrypoints==0.3
27 | fairseq==0.10.1
28 | filelock==3.0.12
29 | flatbuffers==1.12
30 | future==0.18.2
31 | gast==0.3.3
32 | gitdb==4.0.5
33 | GitPython==3.1.11
34 | google-auth==1.17.2
35 | google-auth-oauthlib==0.4.1
36 | google-pasta==0.2.0
37 | grpcio==1.32.0
38 | h5py==2.10.0
39 | huggingface-hub==0.0.12
40 | hydra-core==1.0.4
41 | idna==2.9
42 | importlib-metadata==1.6.1
43 | importlib-resources==3.3.0
44 | ipykernel==5.4.2
45 | ipython==7.16.1
46 | ipython-genutils==0.2.0
47 | ipywidgets==7.5.1
48 | jedi==0.17.2
49 | Jinja2==2.11.3
50 | joblib==0.15.1
51 | jsonpatch==1.25
52 | jsonpointer==2.0
53 | jsonschema==3.2.0
54 | jupyter-client==6.1.7
55 | jupyter-core==4.7.0
56 | jupyterlab-pygments==0.1.2
57 | Keras-Preprocessing==1.1.2
58 | Markdown==3.2.2
59 | MarkupSafe==1.1.1
60 | mistune==0.8.4
61 | nbclient==0.5.1
62 | nbconvert==6.0.7
63 | nbformat==5.0.8
64 | nest-asyncio==1.4.3
65 | nltk==3.5
66 | notebook==6.1.5
67 | numpy==1.19.4
68 | oauthlib==3.1.0
69 | omegaconf==2.0.5
70 | opt-einsum==3.3.0
71 | packaging==21.0
72 | pandas==1.1.5
73 | pandocfilters==1.4.3
74 | parso==0.7.1
75 | pexpect==4.8.0
76 | pickleshare==0.7.5
77 | Pillow==8.2.0
78 | portalocker==2.0.0
79 | prometheus-client==0.9.0
80 | promise==2.3
81 | prompt-toolkit==3.0.8
82 | protobuf==3.12.2
83 | psutil==5.7.3
84 | ptyprocess==0.6.0
85 | pyarrow==2.0.0
86 | pyasn1==0.4.8
87 | pyasn1-modules==0.2.8
88 | pybind11==2.5.0
89 | pycparser==2.20
90 | pydeck==0.5.0
91 | Pygments==2.7.4
92 | pyparsing==3.0.0a1
93 | pyrsistent==0.17.3
94 | python-dateutil==2.8.1
95 | pytz==2020.4
96 | PyYAML==5.4
97 | pyzmq==19.0.1
98 | regex==2020.6.8
99 | requests==2.26.0
100 | requests-oauthlib==1.3.0
101 | rouge-score==0.0.4
102 | rsa==4.7
103 | sacrebleu==1.4.14
104 | sacremoses==0.0.43
105 | scikit-learn==0.23.1
106 | scipy==1.5.0rc1
107 | Send2Trash==1.5.0
108 | sentencepiece==0.1.92
109 | sentry-sdk==0.19.5
110 | seqeval==1.2.2
111 | shortuuid==1.0.1
112 | simpletransformers==0.51.3
113 | six==1.15.0
114 | smmap==3.0.4
115 | streamlit==0.72.0
116 | subprocess32==3.5.4
117 | tensorboard==2.4.0
118 | tensorboard-plugin-wit==1.6.0.post3
119 | tensorboardX==2.0
120 | tensorflow==2.4.2
121 | tensorflow-estimator==2.4.0
122 | tensorflow-hub==0.10.0
123 | termcolor==1.1.0
124 | terminado==0.9.1
125 | testpath==0.4.4
126 | threadpoolctl==2.1.0
127 | tokenizers==0.10.3
128 | toml==0.10.2
129 | toolz==0.11.1
130 | torch==1.5.0
131 | torchfile==0.1.0
132 | torchnet==0.0.4
133 | torchvision==0.6.0
134 | tornado==6.0.4
135 | tqdm==4.54.1
136 | traitlets==4.3.3
137 | transformers==4.8.2
138 | typing-extensions==3.7.4.3
139 | tzlocal==2.1
140 | urllib3==1.26.5
141 | validators==0.18.1
142 | visdom==0.1.8.9
143 | wandb==0.10.12
144 | watchdog==1.0.1
145 | wcwidth==0.2.5
146 | webencodings==0.5.1
147 | websocket-client==0.57.0
148 | Werkzeug==1.0.1
149 | widgetsnbextension==3.5.1
150 | wrapt==1.12.1
151 | zipp==3.1.0
152 |
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import collections
4 | import os
5 | import json
6 | from typing import List, Dict
7 |
8 |
def get_jsonl_data(jsonl_path: str):
    """Load a UTF-8 JSON Lines file into a list of parsed objects.

    Blank lines, such as a trailing empty line, are skipped.

    Raises:
        AssertionError: if ``jsonl_path`` does not end with ``.jsonl``.
    """
    assert jsonl_path.endswith(".jsonl")
    out = list()
    with open(jsonl_path, 'r', encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            # json.loads("") raises, so ignore empty lines instead of crashing.
            if not line:
                continue
            out.append(json.loads(line))
    return out
17 |
18 |
def write_jsonl_data(jsonl_data: List[Dict], jsonl_path: str, force=False):
    """Write ``jsonl_data`` to ``jsonl_path`` as UTF-8 JSON Lines, one object per line.

    Non-ASCII characters are written as-is (``ensure_ascii=False``). The file is
    therefore opened as UTF-8 explicitly, so the output does not depend on the
    platform's default encoding.

    Raises:
        FileExistsError: if ``jsonl_path`` exists and ``force`` is false.
    """
    if os.path.exists(jsonl_path) and not force:
        raise FileExistsError(jsonl_path)
    with open(jsonl_path, 'w', encoding="utf-8") as file:
        for line in jsonl_data:
            file.write(json.dumps(line, ensure_ascii=False) + '\n')
25 |
26 |
def get_txt_data(txt_path: str):
    """Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

    Raises:
        AssertionError: if ``txt_path`` does not end with ``.txt``.
    """
    assert txt_path.endswith(".txt")
    # Explicit UTF-8 so label files round-trip regardless of the platform locale.
    with open(txt_path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file]
31 |
32 |
def write_txt_data(data: List[str], path: str, force: bool = False):
    """Write each string of ``data`` on its own line to ``path`` (UTF-8).

    Raises:
        FileExistsError: if ``path`` exists and ``force`` is false.
    """
    if os.path.exists(path) and not force:
        raise FileExistsError(path)
    # Explicit UTF-8 so non-ASCII labels do not depend on the platform locale.
    with open(path, "w", encoding="utf-8") as file:
        for line in data:
            file.write(line + "\n")
39 |
40 |
def get_tsv_data(tsv_path: str, label: str = None):
    """Read a ``sentence<TAB>tag`` file into labelled examples.

    Each non-blank line becomes ``{"sentence": <col 0>, "label": label + <col 1>}``.
    When ``label`` is empty or None, the file's base name is used as the prefix.

    Raises:
        IndexError: if a non-blank line has no second column.
    """
    # Compute the fallback prefix once instead of on every line.
    if not label:
        label = os.path.basename(tsv_path)
    out = list()
    with open(tsv_path, "r", encoding="utf-8") as file:
        for line in file:
            stripped = line.strip()
            # A blank line (e.g. at end of file) has no tag column; skip it.
            if not stripped:
                continue
            fields = stripped.split('\t')
            out.append({
                "sentence": fields[0],
                "label": label + str(fields[1])
            })
    return out
54 |
55 |
def raw_data_to_dict(data, shuffle=True):
    """Group examples by their ``label`` field.

    Args:
        data: iterable of dicts, each carrying a ``label`` key.
        shuffle: when true, every label's list is shuffled in place with ``random.shuffle``.

    Returns:
        A plain ``dict`` mapping label -> list of examples, labels in first-seen order.
    """
    grouped = {}
    for example in data:
        grouped.setdefault(example['label'], []).append(example)
    if shuffle:
        for examples in grouped.values():
            random.shuffle(examples)
    return grouped
65 |
66 |
class UnlabeledDataLoader:
    """Serve random draws of examples that carry pre-computed augmentations.

    Examples are read from a JSONL file and grouped by their ``label`` field. Every
    example returned by :meth:`create_episode` must contain an ``augmentations`` key.
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.raw_data = get_jsonl_data(self.file_path)
        self.data_dict = raw_data_to_dict(self.raw_data, shuffle=True)

    def create_episode(self, n_augment: int = 0):
        """Draw ``n_augment`` distinct examples and return them as ``{"x_augment": [...]}``.

        Each draw picks a label uniformly, then an example of that label uniformly.
        Already-drawn (label, index) pairs are re-sampled.

        Raises:
            ValueError: if ``n_augment`` exceeds the number of available examples.
                Without this check the re-sampling loop would never terminate.
            KeyError: if a drawn example has no ``augmentations`` field.
        """
        episode = dict()
        augmentations = list()
        if n_augment:
            n_available = sum(len(items) for items in self.data_dict.values())
            if n_augment > n_available:
                raise ValueError(
                    f"Cannot draw {n_augment} distinct augmented examples from {n_available} available ones."
                )
            keys = list(self.data_dict.keys())
            already_done = set()  # O(1) membership instead of scanning a list
            for _ in range(n_augment):
                # Draw a random label, then a random data index within it
                key = random.choice(keys)
                ix = random.choice(range(len(self.data_dict[key])))
                # If already used, re-sample
                while (key, ix) in already_done:
                    key = random.choice(keys)
                    ix = random.choice(range(len(self.data_dict[key])))
                already_done.add((key, ix))
                item = self.data_dict[key][ix]
                if "augmentations" not in item:
                    raise KeyError(f"Input data {item} does not contain any augmentations / is not properly formatted.")
                augmentations.append(item)

        episode["x_augment"] = augmentations

        return episode
95 |
96 |
class FewShotDataLoader:
    """Sample few-shot episodes from a labelled JSONL file.

    Args:
        file_path: JSONL file of labelled examples, grouped by their ``label`` field.
        unlabeled_file_path: optional JSONL file of examples with augmentations. It is
            required to request ``n_augment`` in :meth:`create_episode`.
    """

    def __init__(self, file_path, unlabeled_file_path: str = None):
        self.raw_data = get_jsonl_data(file_path)
        self.data_dict = raw_data_to_dict(self.raw_data, shuffle=True)
        self.unlabeled_file_path = unlabeled_file_path
        # Always defined, so a missing unlabeled file gives a clear error later.
        self.unlabeled_data_loader = None
        if self.unlabeled_file_path:
            self.unlabeled_data_loader = UnlabeledDataLoader(file_path=self.unlabeled_file_path)

    def create_episode(self, n_support: int = 0, n_classes: int = 0, n_query: int = 0, n_unlabeled: int = 0, n_augment: int = 0):
        """Build one episode dict.

        When ``n_classes`` is set, ``n_classes`` labels are sampled (capped at the number
        of labels), every label's list is shuffled in place, and per sampled class the
        episode receives ``n_support`` examples in ``xs``, the next ``n_query`` in
        ``xq`` and the next ``n_unlabeled`` (flattened) in ``xu``, each only when the
        count is non-zero. When ``n_augment`` is set, the unlabeled loader's
        ``x_augment`` draw is merged in.

        Raises:
            AssertionError: if some label has fewer than n_support + n_query + n_unlabeled examples.
            ValueError: if ``n_augment`` is requested but no unlabeled file was given.
        """
        episode = dict()
        if n_classes:
            n_classes = min(n_classes, len(self.data_dict.keys()))
            rand_keys = np.random.choice(list(self.data_dict.keys()), n_classes, replace=False)

            assert min([len(val) for val in self.data_dict.values()]) >= n_support + n_query + n_unlabeled

            for val in self.data_dict.values():
                random.shuffle(val)

            if n_support:
                episode["xs"] = [self.data_dict[k][:n_support] for k in rand_keys]
            if n_query:
                episode["xq"] = [self.data_dict[k][n_support:n_support + n_query] for k in rand_keys]

            if n_unlabeled:
                start = n_support + n_query
                episode['xu'] = [item for k in rand_keys for item in self.data_dict[k][start:start + n_unlabeled]]

        if n_augment:
            unlabeled_loader = getattr(self, "unlabeled_data_loader", None)
            if unlabeled_loader is None:
                raise ValueError("n_augment requires an unlabeled_file_path with augmentations.")
            episode = {**episode, **unlabeled_loader.create_episode(n_augment=n_augment)}

        return episode
128 |
--------------------------------------------------------------------------------
/utils/scripts/runner.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the paper's experiment grid: for every (n_class, seed, shots, dataset)
# combination, train and evaluate the BERT baseline, Baseline++, Induction,
# Relation, Matching, Prototypical and Proto++ networks.
# Each run gets its own directory under runs/. A run whose directory already
# exists is skipped, so re-launching the script resumes an interrupted grid.
# NOTE(review): the directory is created before the run starts, so a crashed
# run is also skipped on the next launch. Delete its directory to retry it.

# run_experiment OUT_PATH COMMAND [ARGS...]
# Skip when OUT_PATH exists. Otherwise create it and run COMMAND ARGS... with
# --output-path OUT_PATH/output, redirecting stdout to OUT_PATH/training.log.
run_experiment() {
    local out_path="$1"
    shift
    if [[ -d "${out_path}" ]]; then
        echo "${out_path} already exists. Skipping."
        return 0
    fi
    mkdir -p "${out_path}"
    "$@" --output-path "${out_path}/output" > "${out_path}/training.log"
}

for n_class in 5 2 3 4; do
    for seed in 1 2 3 4 5; do
        for shots in 5; do
            for dataset in OOS TREC28 Liu; do

                OUTPUT_ROOT="runs/${dataset}/${n_class}C_${shots}K/seed${seed}"

                # Bash arrays keep each flag/value as its own word.
                data_params=(
                    --train-path "data/${dataset}/train.jsonl"
                    --valid-path "data/${dataset}/valid.jsonl"
                    --test-path "data/${dataset}/test.jsonl"
                )
                baseline_params=(
                    --n-classes "${n_class}"
                    --n-support "${shots}"
                    --n-test-episodes 600
                )
                few_shot_params=(
                    --n-classes "${n_class}"
                    --n-support "${shots}"
                    --n-query 5
                    --n-test-episodes 600
                )
                model_params=(
                    --model-name-or-path "transformer_models/${dataset}/fine-tuned"
                )
                baseline_training_params=(
                    --n-train-epoch 10
                    --seed "${seed}"
                )
                few_shot_training_params=(
                    --max-iter 10000
                    --evaluate-every 100
                    --early-stop 25
                    --seed "${seed}"
                )

                # Baseline
                run_experiment "${OUTPUT_ROOT}/baseline" \
                    ./models/bert_baseline/baseline.sh \
                    "${data_params[@]}" \
                    "${baseline_params[@]}" \
                    "${model_params[@]}" \
                    "${baseline_training_params[@]}"

                # Induction Network
                run_experiment "${OUTPUT_ROOT}/induction" \
                    ./models/induction/inductionnet.sh \
                    "${data_params[@]}" \
                    "${few_shot_params[@]}" \
                    "${model_params[@]}" \
                    "${few_shot_training_params[@]}" \
                    --ntl-n-slices 100 \
                    --n-routing-iter 3

                # Relation Network
                for relation_module_type in base ntl; do
                    run_experiment "${OUTPUT_ROOT}/relation-${relation_module_type}" \
                        ./models/relation/relationnet.sh \
                        "${data_params[@]}" \
                        "${few_shot_params[@]}" \
                        "${model_params[@]}" \
                        "${few_shot_training_params[@]}" \
                        --relation-module-type "${relation_module_type}"
                done

                # Proto++ gets a smaller --n-unlabeled on Liu.
                if [[ "${dataset}" == "Liu" ]]; then
                    n_unlabeled=10
                else
                    n_unlabeled=20
                fi

                for metric in euclidean cosine; do
                    # Baseline++
                    run_experiment "${OUTPUT_ROOT}/baseline++_${metric}" \
                        ./models/bert_baseline/baseline.sh \
                        "${data_params[@]}" \
                        "${baseline_params[@]}" \
                        "${model_params[@]}" \
                        "${baseline_training_params[@]}" \
                        --pp --metric "${metric}"

                    # Matching Network
                    run_experiment "${OUTPUT_ROOT}/matching-${metric}" \
                        ./models/matching/matchingnet.sh \
                        "${data_params[@]}" \
                        "${model_params[@]}" \
                        "${few_shot_params[@]}" \
                        "${few_shot_training_params[@]}" \
                        --metric "${metric}"

                    # Prototypical Network
                    run_experiment "${OUTPUT_ROOT}/proto-${metric}" \
                        ./models/proto/protonet.sh \
                        "${data_params[@]}" \
                        "${model_params[@]}" \
                        "${few_shot_params[@]}" \
                        "${few_shot_training_params[@]}" \
                        --metric "${metric}"

                    # Proto++
                    run_experiment "${OUTPUT_ROOT}/proto++-${metric}" \
                        ./models/proto/protonet.sh \
                        "${data_params[@]}" \
                        "${model_params[@]}" \
                        "${few_shot_params[@]}" \
                        "${few_shot_training_params[@]}" \
                        --metric "${metric}" \
                        --n-unlabeled "${n_unlabeled}"
                done
            done
        done
    done
done
174 |
--------------------------------------------------------------------------------
/utils/few_shot.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import numpy as np
4 | import random
5 | from typing import List
6 | from utils.data import get_tsv_data
7 | import torch
8 |
9 |
def random_sample_cls(sentences: List[str], labels: List[str], n_support: int, n_query: int, label: str):
    """
    Draw a random support set and query set from the sentences tagged ``label``.

    The matching sentences are permuted with ``torch.randperm``. The first
    ``n_support`` form the support set and the next ``n_query`` the query set
    (either may be shorter when too few sentences match).
    """
    pool = [sentences[idx] for idx, lab in enumerate(labels) if lab == label]
    order = torch.randperm(len(pool)).tolist()
    support = [pool[j] for j in order[:n_support]]
    query = [pool[j] for j in order[n_support:n_support + n_query]]
    return support, query
22 |
23 |
def create_episode(data_dict, n_support, n_classes, n_query, n_unlabeled=0, n_augment=0):
    """Sample one few-shot episode from examples grouped by label.

    Args:
        data_dict: mapping label -> list of examples; every list is shuffled in place.
        n_support: support examples per sampled class (``episode["xs"]``).
        n_classes: number of classes to sample, capped at the number of labels.
        n_query: query examples per sampled class (``episode["xq"]``).
        n_unlabeled: extra examples per sampled class, flattened into ``episode["xu"]``.
        n_augment: number of distinct examples (from any label) returned in
            ``episode["x_augment"]`` as ``(sentence, [augmentation texts])`` pairs.

    Raises:
        AssertionError: if some label has fewer than n_support + n_query + n_unlabeled examples.
        ValueError: if ``n_augment`` exceeds the total number of examples.
        KeyError: if a drawn example has no ``augmentations`` field.
    """
    n_classes = min(n_classes, len(data_dict.keys()))
    rand_keys = np.random.choice(list(data_dict.keys()), n_classes, replace=False)

    assert min([len(val) for val in data_dict.values()]) >= n_support + n_query + n_unlabeled

    for val in data_dict.values():
        random.shuffle(val)

    episode = {
        "xs": [data_dict[k][:n_support] for k in rand_keys],
        "xq": [data_dict[k][n_support:n_support + n_query] for k in rand_keys],
    }

    if n_unlabeled:
        # Take exactly n_unlabeled items per class: the count the assert above validated.
        start = n_support + n_query
        episode['xu'] = [
            item for k in rand_keys for item in data_dict[k][start:start + n_unlabeled]
        ]

    if n_augment:
        n_available = sum(len(val) for val in data_dict.values())
        if n_augment > n_available:
            # Otherwise the re-sampling loop below would never terminate.
            raise ValueError(
                f"Cannot draw {n_augment} distinct augmented examples from {n_available} available ones."
            )
        keys = list(data_dict.keys())
        augmentations = list()
        already_done = set()  # O(1) membership instead of scanning a list
        for _ in range(n_augment):
            # Draw a random label, then a random data index within it
            key = random.choice(keys)
            ix = random.choice(range(len(data_dict[key])))
            # If already used, re-sample
            while (key, ix) in already_done:
                key = random.choice(keys)
                ix = random.choice(range(len(data_dict[key])))
            already_done.add((key, ix))
            item = data_dict[key][ix]
            if "augmentations" not in item:
                raise KeyError(f"Input data {item} does not contain any augmentations / is not properly formatted.")
            augmentations.append((
                item["sentence"],
                [aug["text"] for aug in item["augmentations"]]
            ))
        episode["x_augment"] = augmentations

    return episode
69 |
70 |
def create_ARSC_train_episode(prefix: str = "data/ARSC-Yu/raw", n_support: int = 5, n_query: int = 5, n_unlabeled=0):
    """Sample one ARSC training episode from a random non-target domain and binary task.

    Training domains are those listed in ``workspace.filtered.list`` but not in
    ``workspace.target.list``. The chosen task's train/dev/test files are pooled,
    grouped by label, and split per label into ``xs`` (n_support), ``xq`` (n_query)
    and, optionally, ``xu`` (n_unlabeled, flattened).

    Raises:
        AssertionError: if a label of the chosen task has fewer than
            n_support + n_query + n_unlabeled sentences.
    """
    # Read the workspace lists inside `with` so the file handles are closed.
    with open(f"{prefix}/workspace.filtered.list", "r") as file:
        filtered_labels = set(line.strip() for line in file)
    with open(f"{prefix}/workspace.target.list", "r") as file:
        target_labels = set(line.strip() for line in file)
    labels = sorted(filtered_labels - target_labels)

    # Pick a random label
    label = random.choice(labels)

    # Pick a random binary task (2, 4, 5)
    binary_task = random.choice([2, 4, 5])

    # Fix: this label/binary task sucks
    while label == "office_products" and binary_task == 2:
        # Pick a random label
        label = random.choice(labels)

        # Pick a random binary task (2, 4, 5)
        binary_task = random.choice([2, 4, 5])

    data = (
        get_tsv_data(f"{prefix}/{label}.t{binary_task}.train", label=label) +
        get_tsv_data(f"{prefix}/{label}.t{binary_task}.dev", label=label) +
        get_tsv_data(f"{prefix}/{label}.t{binary_task}.test", label=label)
    )

    random.shuffle(data)
    task = collections.defaultdict(list)
    for d in data:
        task[d['label']].append(d['sentence'])
    task = dict(task)

    assert min([len(val) for val in task.values()]) >= n_support + n_query + n_unlabeled, \
        f"Label {label}_{binary_task}: min samples is {min([len(val) for val in task.values()])} while K+Q+U={n_support + n_query + n_unlabeled}"

    for val in task.values():
        random.shuffle(val)

    episode = {
        "xs": [
            [task[k][i] for i in range(n_support)] for k in task.keys()
        ],
        "xq": [
            [task[k][n_support + i] for i in range(n_query)] for k in task.keys()
        ]
    }

    if n_unlabeled:
        episode['xu'] = [
            item for k in task.keys() for item in task[k][n_support + n_query:n_support + n_query + n_unlabeled]
        ]
    return episode
122 |
123 |
def create_ARSC_test_episode(prefix: str = "data/ARSC-Yu/raw", n_query: int = 5, n_unlabeled=0, set_type: str = "test"):
    """Sample one ARSC evaluation episode from a random target domain and binary task.

    The support set is the task's fixed ``.train`` file (asserted to hold exactly 10
    sentences). Queries, and optionally unlabeled sentences, are drawn from its
    ``.dev`` or ``.test`` file depending on ``set_type``. All files are read from
    ``prefix``.

    Raises:
        AssertionError: on an invalid ``set_type``, a support file without exactly 10
            lines, or a query label with fewer than n_query + n_unlabeled sentences.
    """
    assert set_type in ("test", "dev")
    # Read inside `with` so the file handle is closed.
    with open(f"{prefix}/workspace.target.list", "r") as file:
        labels = [line.strip() for line in file]

    # Pick a random label
    label = random.choice(labels)

    # Pick a random binary task (2, 4, 5)
    binary_task = random.choice([2, 4, 5])

    support_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.train", label=label)
    assert len(support_data) == 10  # 2 * 5 shots
    support_dict = collections.defaultdict(list)
    for d in support_data:
        support_dict[d['label']].append(d['sentence'])

    # Queries come from the same `prefix` as the support set (previously hard-coded).
    query_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.{set_type}", label=label)
    query_dict = collections.defaultdict(list)
    for d in query_data:
        query_dict[d['label']].append(d['sentence'])

    assert min([len(val) for val in query_dict.values()]) >= n_query + n_unlabeled

    for val in query_dict.values():
        random.shuffle(val)

    episode = {
        "xs": [
            [sentence for sentence in support_dict[k]] for k in sorted(query_dict.keys())
        ],
        "xq": [
            [query_dict[k][i] for i in range(n_query)] for k in sorted(query_dict.keys())
        ]
    }

    if n_unlabeled:
        episode['xu'] = [
            item for k in sorted(query_dict.keys()) for item in query_dict[k][n_query:n_query + n_unlabeled]
        ]
    return episode
164 |
165 |
def create_ARSC_train_baseline_episode(prefix: str = "data/ARSC-Yu/raw"):
    """Return all sentences of one random non-target ARSC domain and binary task.

    The task's train/dev/test files are pooled, grouped by label and shuffled. The
    result is ``{"xs": [sentences of label 1, sentences of label 2, ...]}``.

    Args:
        prefix: directory holding the workspace lists and the task files (same
            default as the other ARSC helpers).
    """
    # Read the workspace lists inside `with` so the file handles are closed.
    with open(f"{prefix}/workspace.filtered.list", "r") as file:
        filtered_labels = set(line.strip() for line in file)
    with open(f"{prefix}/workspace.target.list", "r") as file:
        target_labels = set(line.strip() for line in file)
    labels = sorted(filtered_labels - target_labels)

    # Pick a random label
    label = random.choice(labels)

    # Pick a random binary task (2, 4, 5)
    binary_task = random.choice([2, 4, 5])

    data = (
        get_tsv_data(f"{prefix}/{label}.t{binary_task}.train", label=label) +
        get_tsv_data(f"{prefix}/{label}.t{binary_task}.dev", label=label) +
        get_tsv_data(f"{prefix}/{label}.t{binary_task}.test", label=label)
    )

    random.shuffle(data)
    task = collections.defaultdict(list)
    for d in data:
        task[d['label']].append(d['sentence'])
    task = dict(task)

    for val in task.values():
        random.shuffle(val)

    episode = {
        "xs": [
            list(task[k]) for k in task.keys()
        ]
    }

    return episode
199 |
200 |
def get_ARSC_test_tasks(prefix: str = "data/ARSC-Yu/raw"):
    """Load every (target domain, binary task) pair of the ARSC benchmark.

    For each target label (sorted) and each binary task 2, 4 and 5, the task dict
    holds ``xs`` (from ``.train``), ``x_valid`` (from ``.dev``) and ``x_test``
    (from ``.test``). Each is a pair ``[sentences tagged -1, sentences tagged 1]``.

    Args:
        prefix: directory holding the workspace list and the task files (same
            default as the other ARSC helpers).

    Raises:
        AssertionError: if a task's two support classes differ in size.
    """
    # Read inside `with` so the file handle is closed.
    with open(f"{prefix}/workspace.target.list", "r") as file:
        labels = sorted(set(line.strip() for line in file))

    tasks = list()
    for label in labels:
        for binary_task in (2, 4, 5):
            train_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.train", label=label)
            valid_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.dev", label=label)
            test_data = get_tsv_data(f"{prefix}/{label}.t{binary_task}.test", label=label)
            tasks.append({
                "xs": [
                    [d['sentence'] for d in train_data if d['label'] == f"{label}-1"],
                    [d['sentence'] for d in train_data if d['label'] == f"{label}1"],
                ],
                "x_valid": [
                    [d['sentence'] for d in valid_data if d['label'] == f"{label}-1"],
                    [d['sentence'] for d in valid_data if d['label'] == f"{label}1"],
                ],
                "x_test": [
                    [d['sentence'] for d in test_data if d['label'] == f"{label}-1"],
                    [d['sentence'] for d in test_data if d['label'] == f"{label}1"],
                ],
            })

    assert all([len(task['xs'][0]) == len(task['xs'][1]) for task in tasks])
    return tasks
227 |
--------------------------------------------------------------------------------
/language_modeling/run_language_modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19 | using a masked language modeling (MLM) loss.
20 | """
21 |
22 | import logging
23 | import math
24 | import os
25 | from dataclasses import dataclass, field
26 | from typing import Optional
27 |
28 | from transformers import (
29 | CONFIG_MAPPING,
30 | MODEL_WITH_LM_HEAD_MAPPING,
31 | AutoConfig,
32 | AutoModelForMaskedLM,
33 | AutoTokenizer,
34 | DataCollatorForLanguageModeling,
35 | HfArgumentParser,
36 | LineByLineTextDataset,
37 | PreTrainedTokenizer,
38 | TextDataset,
39 | Trainer,
40 | TrainingArguments,
41 | set_seed,
42 | )
43 |
# Logger shared by every function in this script.
logger = logging.getLogger(__name__)

# Config classes registered with an LM-head model in transformers, and their
# `model_type` strings (listed in the --model_type help text).
MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple([config_class.model_type for config_class in MODEL_CONFIG_CLASSES])
48 |
49 |
@dataclass
class ModelArguments:
    """
    Model-side command-line options: which checkpoint, config and tokenizer to
    fine-tune from, or which model type to build when training from scratch.

    Parsed by ``HfArgumentParser`` in ``main``; each field's ``help`` metadata
    is the option's description. Every field defaults to ``None``.
    """

    # Checkpoint (name or path) whose weights initialise the model.
    model_name_or_path: Optional[str] = field(
        metadata=dict(
            help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
        ),
        default=None,
    )
    # Key into CONFIG_MAPPING; main() only consults it when neither
    # config_name nor model_name_or_path is given.
    model_type: Optional[str] = field(
        metadata=dict(help="If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)),
        default=None,
    )
    # Config to load instead of the one matching model_name_or_path.
    config_name: Optional[str] = field(
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
        default=None,
    )
    # Tokenizer to load instead of the one matching model_name_or_path.
    tokenizer_name: Optional[str] = field(
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
        default=None,
    )
    # Passed as cache_dir to every from_pretrained call in main().
    cache_dir: Optional[str] = field(
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
        default=None,
    )
75 |
76 |
@dataclass
class DataTrainingArguments:
    """
    Data-side command-line options: input files, masking and block size.

    Parsed by ``HfArgumentParser`` in ``main``; each field's ``help`` metadata
    is the option's description.
    """

    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    # Selects LineByLineTextDataset (True) or TextDataset (False) in get_dataset.
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    mlm: bool = field(
        default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )

    # main() replaces values <= 0 with the tokenizer's maximum length and caps
    # larger values at that maximum.
    block_size: int = field(
        default=-1,
        metadata={
            # Each fragment ends with a space; without it the adjacent string
            # literals were glued together ("tokenization.The training ...").
            "help": "Optional input sequence length after tokenization. "
            "The training dataset will be truncated in blocks of this size for training. "
            "Defaults to the model max input length for single sentence inputs (take into account special tokens)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
113 |
114 |
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
    """Build the training or evaluation dataset described by ``args``.

    :param args: data options. ``eval_data_file`` is read when ``evaluate`` is
        true, ``train_data_file`` otherwise.
    :param tokenizer: tokenizer handed to the dataset class.
    :param evaluate: select the evaluation file instead of the training file.
    :param local_rank: kept for backward compatibility; unused.
    :return: a ``LineByLineTextDataset`` when ``args.line_by_line`` is set,
        otherwise a ``TextDataset``.
    :raises ValueError: if the selected data file was not provided.
    """
    file_path = args.eval_data_file if evaluate else args.train_data_file
    if file_path is None:
        which = "eval_data_file" if evaluate else "train_data_file"
        raise ValueError(f"No --{which} was provided.")

    if args.line_by_line:
        return LineByLineTextDataset(
            tokenizer=tokenizer, file_path=file_path, block_size=args.block_size
        )
    # Forward --overwrite_cache; it was declared in DataTrainingArguments but
    # previously ignored. NOTE(review): relies on TextDataset accepting
    # `overwrite_cache` (transformers 3.x/4.x do) — confirm against the pin.
    return TextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=args.block_size,
        overwrite_cache=args.overwrite_cache,
    )
125 |
126 |
def main():
    """Fine-tune (or train from scratch) a masked language model with ``Trainer``.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and ``TrainingArguments``
    from the command line, loads config/tokenizer/model, builds the datasets,
    then trains and/or evaluates. The evaluation perplexity (``exp(eval_loss)``)
    is logged and written to ``<output_dir>/eval_results_lm.txt``.

    :return: ``{"perplexity": float}`` when evaluation ran on this process,
        otherwise an empty dict.
    :raises ValueError: for inconsistent options:
        - a missing train/eval data file;
        - a non-empty output dir without --overwrite_output_dir;
        - a missing or unknown --model_type when training from scratch;
        - no tokenizer source;
        - a BERT/RoBERTa-like model without --mlm.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
    # Training counterpart of the check above, so a missing file fails before any model loads.
    if data_args.train_data_file is None and training_args.do_train:
        raise ValueError(
            "Cannot do training without a training data file. Either supply a file to --train_data_file "
            "or remove the --do_train argument."
        )

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: INFO on local_rank -1/0, WARN on every other process.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        # Without this guard a missing/unknown --model_type surfaces as a bare KeyError.
        if model_args.model_type not in CONFIG_MAPPING:
            raise ValueError(
                f"Unknown or missing --model_type {model_args.model_type!r} for training from scratch. "
                f"Choose one of: {', '.join(MODEL_TYPES)}"
            )
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name"
        )

    # Only needs the config, so check before (potentially large) weights are loaded.
    if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
        raise ValueError(
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling)."
        )

    if model_args.model_name_or_path:
        model = AutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedLM.from_config(config)

    # Match the embedding matrix to the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    # Our input block size will be at most the max possible for the model.
    # Both branches use `model_max_length`: `max_len` is its deprecated alias
    # (not available in newer transformers), and the else-branch already used
    # the new name.
    if data_args.block_size <= 0:
        data_args.block_size = tokenizer.model_max_length
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Get datasets
    train_dataset = (
        get_dataset(data_args, tokenizer=tokenizer)
        if training_args.do_train
        else None
    )
    eval_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, evaluate=True)
        if training_args.do_eval
        else None
    )
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Training
    if training_args.do_train:
        # Pass the checkpoint dir so Trainer can resume from it when it is a local directory.
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()
        perplexity = math.exp(eval_output["eval_loss"])
        result = {"perplexity": perplexity}

        # This script never creates output_dir itself, and with --do_eval only
        # nothing above saves into it.
        os.makedirs(training_args.output_dir, exist_ok=True)
        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        results.update(result)

    return results


if __name__ == "__main__":
    main()
280 |
--------------------------------------------------------------------------------
/models/matching/matchingnet.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from utils.data import get_jsonl_data
4 | from utils.python import now, set_seeds
5 | import random
6 | import collections
7 | import os
8 | from typing import List, Dict
9 | from tensorboardX import SummaryWriter
10 | import numpy as np
11 | from models.encoders.bert_encoder import BERTEncoder
12 | import torch
13 | import torch.nn as nn
14 | import warnings
15 | import logging
16 | from utils.few_shot import create_episode, create_ARSC_train_episode, create_ARSC_test_episode
17 | from utils.math import euclidean_dist, cosine_similarity
18 | from sklearn.metrics import f1_score
19 |
# Logging: install the default root handler, then let this module's logger
# emit DEBUG and above.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Suppress every Python warning, process-wide.
warnings.simplefilter('ignore')

# Use the default CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 |
28 |
class MatchingNet(nn.Module):
    """Matching network for few-shot text classification.

    Support and query sentences are embedded by ``encoder``. Each query is
    scored against every support sentence (scaled cosine similarity or
    negative euclidean distance). The scores are averaged per class to give
    class logits for a cross-entropy loss.
    """

    def __init__(self, encoder, metric: str = "cosine"):
        """
        :param encoder: sentence encoder. ``encoder.forward(list_of_sentences)``
            is assumed to return one embedding row per sentence (TODO confirm).
        :param metric: ``"cosine"`` or ``"euclidean"``.
        :raises AssertionError: for any other metric.
        """
        super(MatchingNet, self).__init__()

        self.encoder = encoder
        self.metric = metric
        assert self.metric in ("cosine", "euclidean")

    @staticmethod
    def _query_labels(xq, device=None):
        """Return a LongTensor holding the class index of every flattened query.

        The labels are built from each class's actual query count. They
        therefore stay aligned with the flattened query order even when classes
        have different numbers of queries; indexing with
        ``ix_class * len(xq[0])`` would not.
        """
        return torch.tensor(
            [ix_class for ix_class, class_queries in enumerate(xq) for _ in class_queries],
            dtype=torch.long,
            device=device,
        )

    def loss(self, sample):
        """
        Compute the loss and metrics of one episode.

        :param sample: {
            "xs": [
                [support_A_1, support_A_2, ...],
                [support_B_1, support_B_2, ...],
                ...
            ],
            "xq": [
                [query_A_1, query_A_2, ...],
                [query_B_1, query_B_2, ...],
                ...
            ]
        }
        All support classes must contain the same number of sentences.
        :return: ``(loss_tensor, info)``. ``info`` holds:
            - ``"loss"``;
            - ``"metrics"`` with ``acc``, ``loss`` and weighted ``f1``;
            - ``"y_hat"``, the predicted class index per query as a numpy array.
        """
        xs = sample["xs"]  # support
        xq = sample["xq"]  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])

        x = [item for xs_ in xs for item in xs_] + [item for xq_ in xq for item in xq_]
        z = self.encoder.forward(x)
        z_support = z[:n_class * n_support]
        z_query = z[n_class * n_support:]

        if self.metric == "euclidean":
            similarities = -euclidean_dist(z_query, z_support)
        elif self.metric == "cosine":
            # Fixed scale of 5 on the cosine scores (presumably in [-1, 1]) so
            # the softmax inside CrossEntropyLoss is less flat.
            similarities = cosine_similarity(z_query, z_support) * 5
        else:
            raise NotImplementedError

        # Average each query's scores over each class's support columns
        # (class c owns columns c * n_support .. (c + 1) * n_support - 1).
        logits = torch.stack(
            [similarities[:, c * n_support: (c + 1) * n_support].mean(1) for c in range(n_class)],
            dim=1,
        )
        target = self._query_labels(xq, device=logits.device)

        loss_val = nn.CrossEntropyLoss()(logits, target)
        y_hat = logits.argmax(1)
        acc_val = (target == y_hat).float().mean()
        f1_val = f1_score(
            target.cpu().numpy(),
            y_hat.cpu().numpy(),
            average="weighted"
        )
        return loss_val, {
            "loss": loss_val.item(),
            "metrics": {
                "acc": acc_val.item(),
                "loss": loss_val.item(),
                "f1": f1_val,
            },
            "y_hat": y_hat.cpu().detach().numpy()
        }

    def _fit_episode(self, optimizer, episode):
        """Run one optimisation step on ``episode`` and return ``self.loss``'s output."""
        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self.loss(episode)
        loss.backward()
        optimizer.step()
        return loss, loss_dict

    def _evaluate(self, make_episode, n_episodes):
        """Average the ``loss`` metrics over ``n_episodes`` episodes from ``make_episode()``, without gradients."""
        metrics = collections.defaultdict(list)
        self.eval()
        for _ in range(n_episodes):
            episode = make_episode()
            with torch.no_grad():
                _, loss_dict = self.loss(episode)
            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)
        return {key: np.mean(value) for key, value in metrics.items()}

    def train_step(self, optimizer, data_dict: Dict[str, List[str]], n_support, n_classes, n_query):
        """Sample one episode from ``data_dict`` and take one optimiser step on it."""
        episode = create_episode(
            data_dict=data_dict,
            n_support=n_support,
            n_classes=n_classes,
            n_query=n_query
        )
        return self._fit_episode(optimizer, episode)

    def test_step(self, data_dict, n_support, n_classes, n_query, n_episodes=1000):
        """Return the metrics averaged over ``n_episodes`` episodes sampled from ``data_dict``."""
        return self._evaluate(
            lambda: create_episode(
                data_dict=data_dict,
                n_support=n_support,
                n_classes=n_classes,
                n_query=n_query
            ),
            n_episodes,
        )

    def train_step_ARSC(self, data_path: str, optimizer, n_support: int = 5, n_query: int = 5):
        """One optimiser step on an ARSC training episode; sizes default to the previous hard-coded 5/5."""
        episode = create_ARSC_train_episode(prefix=data_path, n_support=n_support, n_query=n_query)
        return self._fit_episode(optimizer, episode)

    def test_step_ARSC(self, data_path: str, n_episodes=1000, set_type="test", n_query: int = 5):
        """Average metrics over ``n_episodes`` ARSC episodes from the ``"dev"`` or ``"test"`` split."""
        assert set_type in ("dev", "test")
        return self._evaluate(
            lambda: create_ARSC_test_episode(prefix=data_path, n_query=n_query, set_type=set_type),
            n_episodes,
        )
168 |
169 |
def run_matching(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_query: int,
        n_classes: int,
        valid_path: str = None,
        test_path: str = None,
        output_path: str = None,
        max_iter: int = 10000,
        evaluate_every: int = 100,
        early_stop: int = None,
        n_test_episodes: int = 1000,
        log_every: int = 10,
        metric: str = "cosine",
        arsc_format: bool = False,
        data_path: str = None
):
    """Episodically train a :class:`MatchingNet` and periodically evaluate it.

    - Every ``log_every`` steps, the averaged training metrics go to
      TensorBoard and ``logger``.
    - Every ``evaluate_every`` steps, the validation and/or test sets are
      evaluated.
    - Training runs ``max_iter`` steps, or stops earlier once ``early_stop``
      consecutive validation rounds fail to beat the best validation accuracy.
    - All logged metrics are written to ``<output_path>/metrics.json``.

    :param train_path: JSONL training data (not read when ``arsc_format``).
    :param model_name_or_path: transformer checkpoint for :class:`BERTEncoder`.
    :param n_support: support sentences per class and episode.
    :param n_query: query sentences per class and episode.
    :param n_classes: classes per episode.
    :param valid_path: optional JSONL validation data; enables validation
        (with ``arsc_format``, evaluation on the ARSC ``dev`` split).
    :param test_path: optional JSONL test data; enables testing
        (with ``arsc_format``, evaluation on the ARSC ``test`` split).
    :param output_path: run directory, which must be missing or empty.
        Defaults to ``runs/<now()>``, computed at call time.
    :param max_iter: number of training episodes.
    :param evaluate_every: training episodes between evaluation rounds.
    :param early_stop: non-improving validation rounds tolerated before
        stopping; ``None``/``0`` disables early stopping.
    :param n_test_episodes: episodes averaged per evaluation.
    :param log_every: training episodes between training-metric logs.
    :param metric: ``"cosine"`` or ``"euclidean"``.
    :param arsc_format: build episodes from the ARSC data at ``data_path``.
    :param data_path: ARSC data prefix (used only with ``arsc_format``).
    :raises FileExistsError: if ``output_path`` exists and is not empty.
    """
    # Resolve the default per call. An f-string default in the signature is
    # evaluated once at import time, so every call would share one directory.
    if not output_path:
        output_path = f"runs/{now()}"
    if os.path.exists(output_path) and len(os.listdir(output_path)):
        raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    def make_writer(set_type):
        # exist_ok: the guard above deliberately accepts an existing *empty*
        # output_path, which a plain os.makedirs(output_path) rejected.
        log_dir = os.path.join(output_path, "logs", set_type)
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(logdir=log_dir, flush_secs=1, max_queue=1)

    train_writer: SummaryWriter = make_writer("train")
    valid_writer: SummaryWriter = None
    test_writer: SummaryWriter = None
    log_dict = dict(train=list())

    if valid_path:
        valid_writer = make_writer("valid")
        log_dict["valid"] = list()
    if test_path:
        test_writer = make_writer("test")
        log_dict["test"] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        # Group sentences by label; optionally shuffle each label's sentences in place.
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for key, val in labels_dict.items():
                random.shuffle(val)
        return labels_dict

    # Load model
    bert = BERTEncoder(model_name_or_path).to(device)
    matching_net = MatchingNet(encoder=bert, metric=metric)
    optimizer = torch.optim.Adam(matching_net.parameters(), lr=2e-5)

    # Load data (ARSC episodes are built from data_path instead)
    train_data_dict = valid_data_dict = test_data_dict = None
    if not arsc_format:
        train_data_dict = raw_data_to_labels_dict(get_jsonl_data(train_path), shuffle=True)
        logger.info(f"train labels: {train_data_dict.keys()}")

        if valid_path:
            valid_data_dict = raw_data_to_labels_dict(get_jsonl_data(valid_path), shuffle=True)
            logger.info(f"valid labels: {valid_data_dict.keys()}")

        if test_path:
            test_data_dict = raw_data_to_labels_dict(get_jsonl_data(test_path), shuffle=True)
            logger.info(f"test labels: {test_data_dict.keys()}")

    train_metrics = collections.defaultdict(list)
    n_eval_since_last_best = 0
    best_valid_acc = 0.0

    try:
        for step in range(max_iter):
            if not arsc_format:
                _, loss_dict = matching_net.train_step(
                    optimizer=optimizer,
                    data_dict=train_data_dict,
                    n_support=n_support,
                    n_query=n_query,
                    n_classes=n_classes
                )
            else:
                _, loss_dict = matching_net.train_step_ARSC(
                    data_path=data_path,
                    optimizer=optimizer
                )
            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging
            if (step + 1) % log_every == 0:
                for key, value in train_metrics.items():
                    train_writer.add_scalar(tag=key, scalar_value=np.mean(value), global_step=step)
                logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()]))
                log_dict["train"].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": np.mean(value)
                        }
                        for key, value in train_metrics.items()
                    ],
                    "global_step": step
                })

                train_metrics = collections.defaultdict(list)

            if (valid_path or test_path) and (step + 1) % evaluate_every == 0:
                for path, writer, set_type, set_data in zip(
                        [valid_path, test_path],
                        [valid_writer, test_writer],
                        ["valid", "test"],
                        [valid_data_dict, test_data_dict]
                ):
                    if not path:
                        continue
                    if not arsc_format:
                        set_results = matching_net.test_step(
                            data_dict=set_data,
                            n_support=n_support,
                            n_query=n_query,
                            n_classes=n_classes,
                            n_episodes=n_test_episodes
                        )
                    else:
                        set_results = matching_net.test_step_ARSC(
                            data_path=data_path,
                            n_episodes=n_test_episodes,
                            set_type={"valid": "dev", "test": "test"}[set_type]
                        )

                    for key, val in set_results.items():
                        writer.add_scalar(tag=key, scalar_value=val, global_step=step)
                    log_dict[set_type].append({
                        "metrics": [
                            {
                                "tag": key,
                                "value": val
                            }
                            for key, val in set_results.items()
                        ],
                        "global_step": step
                    })

                    logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))
                    if set_type == "valid":
                        if set_results["acc"] > best_valid_acc:
                            best_valid_acc = set_results["acc"]
                            n_eval_since_last_best = 0
                            logger.info(f"Better eval results!")
                        else:
                            n_eval_since_last_best += 1
                            logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

                # Checked once per evaluation round, after both sets, so that
                # `break` leaves the training loop rather than the set loop.
                if early_stop and n_eval_since_last_best >= early_stop:
                    logger.warning(f"Early-stopping.")
                    break

        with open(os.path.join(output_path, "metrics.json"), "w") as file:
            json.dump(log_dict, file, ensure_ascii=False)
    finally:
        # Flush and release the TensorBoard event files, even if training fails.
        for writer in (train_writer, valid_writer, test_writer):
            if writer is not None:
                writer.close()
341 |
342 |
def main():
    """Command-line entry point for MatchingNet training.

    Parses the options, seeds the RNGs, checks that the data paths exist and
    calls :func:`run_matching`. After it returns, the options are saved to
    ``<output-path>/config.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, required=True, help="Path to training data")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f"runs/{now()}")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Which metric to use
    parser.add_argument("--metric", type=str, default="cosine", help="Metric to use", choices=("euclidean", "cosine"))

    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")
    args = parser.parse_args()

    # ARSC episodes are built from --data-path. Reject the combination up front;
    # otherwise data_path=None is only used after the encoder has been loaded.
    if args.arsc_format and not args.data_path:
        parser.error("--data-path is required when --arsc-format is set")

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist
    for arg in [args.train_path, args.valid_path, args.test_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Run
    run_matching(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        output_path=args.output_path,

        model_name_or_path=args.model_name_or_path,

        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,

        max_iter=args.max_iter,
        evaluate_every=args.evaluate_every,
        metric=args.metric,
        early_stop=args.early_stop,
        arsc_format=args.arsc_format,
        data_path=args.data_path,
        log_every=args.log_every
    )

    # Save config (run_matching requires an empty output dir, so this comes after it)
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False)


if __name__ == "__main__":
    main()
409 |
--------------------------------------------------------------------------------
/models/relation/relationnet.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from utils.data import get_jsonl_data
4 | from utils.python import now, set_seeds
5 | import random
6 | import collections
7 | import os
8 | from typing import List, Dict
9 | from tensorboardX import SummaryWriter
10 | import numpy as np
11 | from models.encoders.bert_encoder import BERTEncoder
12 | import torch
13 | import torch.nn as nn
14 | import warnings
15 | import logging
16 | from utils.few_shot import create_episode, create_ARSC_train_episode, create_ARSC_test_episode
17 |
# Default root handler; this module's own logger is opened up to DEBUG.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.DEBUG)

# Ignore every warning raised anywhere in the process.
warnings.simplefilter(action='ignore')

# Torch device for this module: CUDA when present, CPU otherwise.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
25 |
26 |
class RelationNet(nn.Module):
    """Relation network for few-shot text classification.

    Each class prototype is the mean encoder embedding of that class's support
    sentences. A relation module scores every (query, prototype) pair, and the
    scores are trained as class logits with cross-entropy.
    """

    def __init__(self, encoder, hidden_dim: int = 768, relation_module_type: str = "base", ntl_n_slices: int = 100):
        """
        :param encoder: sentence encoder. ``encoder.forward(list_of_sentences)``
            is assumed to return one embedding row per sentence (TODO confirm).
        :param hidden_dim: embedding size the relation module is built for.
        :param relation_module_type: ``"base"`` (:class:`RelationModule`) or
            ``"ntl"`` (:class:`NTLRelationModule`).
        :param ntl_n_slices: number of bilinear slices for the ``"ntl"`` module.
        :raises NotImplementedError: for any other ``relation_module_type``.
        """
        super(RelationNet, self).__init__()

        self.encoder = encoder
        self.relation_module_type = relation_module_type
        self.ntl_n_slices = ntl_n_slices
        self.hidden_dim = hidden_dim

        # Declare relation module
        if self.relation_module_type == "base":
            self.relation_module = RelationModule(input_dim=hidden_dim).to(device)
        elif self.relation_module_type == "ntl":
            self.relation_module = NTLRelationModule(input_dim=hidden_dim, n_slice=self.ntl_n_slices).to(device)
        else:
            raise NotImplementedError(f"relation module type {self.relation_module_type} not implemented.")

    @staticmethod
    def _query_labels(xq, device=None):
        """Return a LongTensor holding the class index of every flattened query.

        The labels are built from each class's actual query count. They
        therefore stay aligned with the flattened query order even when classes
        have different numbers of queries; indexing with
        ``ix_class * len(xq[0])`` would not.
        """
        return torch.tensor(
            [ix_class for ix_class, class_queries in enumerate(xq) for _ in class_queries],
            dtype=torch.long,
            device=device,
        )

    def loss(self, sample):
        """
        Compute the cross-entropy loss and accuracy of one episode.

        :param sample: {
            "xs": [
                [support_A_1, support_A_2, ...],
                [support_B_1, support_B_2, ...],
                ...
            ],
            "xq": [
                [query_A_1, query_A_2, ...],
                [query_B_1, query_B_2, ...],
                ...
            ]
        }
        All support classes must contain the same number of sentences.
        :return: ``(loss_tensor, info)``. ``info`` holds:
            - ``"loss"``;
            - ``"metrics"`` with ``loss`` and ``acc``;
            - ``"y_hat"``, the predicted class index per query as a numpy array.
        """
        xs = sample["xs"]  # support
        xq = sample["xq"]  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])

        x = [item for xs_ in xs for item in xs_] + [item for xq_ in xq for item in xq_]
        z = self.encoder.forward(x)
        z_dim = z.size(-1)

        z_query = z[n_class * n_support:]
        # Class prototypes: mean embedding of each class's support sentences.
        z_proto = z[:n_class * n_support].view(n_class, n_support, z_dim).mean(1)

        relation_module_scores = self.relation_module.forward(z_q=z_query, z_c=z_proto)
        target = self._query_labels(xq, device=relation_module_scores.device)

        loss_val = nn.CrossEntropyLoss()(relation_module_scores, target)
        y_hat = relation_module_scores.argmax(1)
        acc_val = (target == y_hat).float().mean()
        return loss_val, {
            "loss": loss_val.item(),
            "metrics": {
                "loss": loss_val.item(),
                "acc": acc_val.item()
            },
            "y_hat": y_hat.cpu().detach().numpy()
        }

    def _fit_episode(self, optimizer, episode):
        """Run one optimisation step on ``episode`` and return ``self.loss``'s output."""
        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self.loss(episode)
        loss.backward()
        optimizer.step()
        return loss, loss_dict

    def _evaluate(self, make_episode, n_episodes):
        """Average the ``loss`` metrics over ``n_episodes`` episodes from ``make_episode()``, without gradients."""
        metrics = collections.defaultdict(list)
        self.eval()
        for _ in range(n_episodes):
            episode = make_episode()
            with torch.no_grad():
                _, loss_dict = self.loss(episode)
            for key, value in loss_dict["metrics"].items():
                metrics[key].append(value)
        return {key: np.mean(value) for key, value in metrics.items()}

    def train_step(self, optimizer, data_dict: Dict[str, List[str]], n_support, n_classes, n_query):
        """Sample one episode from ``data_dict`` and take one optimiser step on it."""
        episode = create_episode(
            data_dict=data_dict,
            n_support=n_support,
            n_classes=n_classes,
            n_query=n_query
        )
        return self._fit_episode(optimizer, episode)

    def test_step(self, data_dict, n_support, n_classes, n_query, n_episodes=1000):
        """Return the metrics averaged over ``n_episodes`` episodes sampled from ``data_dict``."""
        return self._evaluate(
            lambda: create_episode(
                data_dict=data_dict,
                n_support=n_support,
                n_classes=n_classes,
                n_query=n_query
            ),
            n_episodes,
        )

    def train_step_ARSC(self, data_path: str, optimizer, n_support: int = 5, n_query: int = 5):
        """One optimiser step on an ARSC training episode; sizes default to the previous hard-coded 5/5."""
        episode = create_ARSC_train_episode(prefix=data_path, n_support=n_support, n_query=n_query)
        return self._fit_episode(optimizer, episode)

    def test_step_ARSC(self, data_path: str, n_episodes=1000, set_type="test", n_query: int = 5):
        """Average metrics over ``n_episodes`` ARSC episodes from the ``"dev"`` or ``"test"`` split."""
        assert set_type in ("dev", "test")
        return self._evaluate(
            lambda: create_ARSC_test_episode(prefix=data_path, n_query=n_query, set_type=set_type),
            n_episodes,
        )
184 |
185 |
class RelationModule(nn.Module):
    """Score (query, class) pairs with a two-layer MLP on concatenated embeddings.

    ``forward(z_q, z_c)`` returns an ``(n_query, n_class)`` tensor of raw
    scores. Entry ``[i, j]`` is the MLP applied to ``cat(z_q[i], z_c[j])``.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Hidden layer: 2 * input_dim -> input_dim, then ReLU and dropout (p=0.25).
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=2 * input_dim, out_features=input_dim),
            nn.ReLU(),
            nn.Dropout(p=0.25),
        )
        # Output layer: a single score per pair.
        self.fc2 = nn.Sequential(nn.Linear(in_features=input_dim, out_features=1))

    def forward(self, z_q, z_c):
        n_query = z_q.size(0)
        n_class = z_c.size(0)
        # Row i * n_class + j of `pairs` is cat(z_q[i], z_c[j]).
        left = z_q.repeat_interleave(n_class, dim=0)
        right = z_c.repeat(n_query, 1)
        pairs = torch.cat((left, right), dim=1)
        scores = self.fc2(self.fc1(pairs))
        return scores.view(n_query, n_class)
207 |
208 |
class NTLRelationModule(nn.Module):
    """Neural-tensor-layer relation scorer.

    Holds ``n_slice`` bilinear matrices ``M[s]``, each ``(input_dim, input_dim)``.
    For query ``i`` and class ``j`` the score is
    ``fc(dropout(relu([z_q[i] @ M[s] @ z_c[j] for s in slices])))``.
    """

    def __init__(self, input_dim, n_slice=100):
        super().__init__()
        self.n_slice = n_slice
        # Gaussian init from numpy's global RNG; each slice is scaled to unit
        # Frobenius norm.
        slices = np.random.randn(n_slice, input_dim, input_dim)
        frobenius = np.linalg.norm(slices, axis=(1, 2))
        self.M = nn.Parameter(torch.Tensor(slices / frobenius[:, None, None]))
        self.dropout = nn.Dropout(p=0.25)
        self.fc = nn.Linear(n_slice, 1)

    def forward(self, z_q, z_c):
        n_query, n_class = z_q.size(0), z_c.size(0)
        # bilinear[i, j, s] = z_q[i] @ M[s] @ z_c[j]
        bilinear = torch.stack([z_q @ m @ z_c.T for m in self.M], dim=-1)
        features = self.dropout(torch.relu(bilinear.view(-1, self.n_slice)))
        return self.fc(features).view(n_query, n_class)
226 |
227 |
def run_relation(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_query: int,
        n_classes: int,
        valid_path: str = None,
        test_path: str = None,
        output_path: str = f"runs/{now()}",
        max_iter: int = 10000,
        evaluate_every: int = 100,
        early_stop: int = None,
        n_test_episodes: int = 1000,
        log_every: int = 10,
        relation_module_type: str = "base",
        ntl_n_slices: int = 100,
        arsc_format: bool = False,
        data_path: str = None
):
    """Train a RelationNet on few-shot episodes, periodically evaluating and logging it.

    Metrics go to TensorBoard event files under ``<output_path>/logs/{train,valid,test}`` and to
    ``<output_path>/metrics.json``. The JSON file is written even if training is interrupted.

    :param train_path: JSONL training file whose items have ``label`` and ``sentence`` keys
        (ignored in ARSC mode).
    :param model_name_or_path: transformer checkpoint loaded by ``BERTEncoder``.
    :param n_support: support samples per class in each non-ARSC episode.
    :param n_query: query samples per class in each non-ARSC episode.
    :param n_classes: classes per non-ARSC episode.
    :param valid_path: optional JSONL validation file; validation accuracy drives early stopping.
    :param test_path: optional JSONL test file.
    :param output_path: run directory; must not exist or be empty.
    :param max_iter: number of training episodes.
    :param evaluate_every: training episodes between two evaluations.
    :param early_stop: stop after this many consecutive non-improving validations (falsy = disabled).
    :param n_test_episodes: episodes sampled per evaluation.
    :param log_every: training episodes between two training-metric logs.
    :param relation_module_type: relation module passed to ``RelationNet``.
    :param ntl_n_slices: number of NTL slices passed to ``RelationNet``.
    :param arsc_format: sample episodes with the ARSC helpers from ``data_path`` instead of JSONL files.
    :param data_path: ARSC data prefix. In ARSC mode, when neither ``valid_path`` nor ``test_path``
        is given, both the dev and test splits of ``data_path`` are evaluated.
    :raises ValueError: if ``output_path`` is empty.
    :raises FileExistsError: if ``output_path`` exists and is not empty.
    """
    if not output_path:
        raise ValueError("output_path must be a non-empty path")
    if os.path.exists(output_path) and len(os.listdir(output_path)):
        raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # Evaluation sets to run. ARSC dev/test episodes are sampled from `data_path`, so in ARSC mode
    # that path alone enables both sets (otherwise evaluation would silently never happen).
    eval_enabled = {"valid": bool(valid_path), "test": bool(test_path)}
    if arsc_format and data_path and not any(eval_enabled.values()):
        eval_enabled = {"valid": True, "test": True}
    eval_sets = [set_type for set_type in ("valid", "test") if eval_enabled[set_type]]

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok=True: an existing but *empty* output directory passes the check above and must not
    # make os.makedirs fail.
    writers = dict()
    log_dict = dict()
    for set_type in ["train"] + eval_sets:
        log_dir = os.path.join(output_path, f"logs/{set_type}")
        os.makedirs(log_dir, exist_ok=True)
        writers[set_type] = SummaryWriter(logdir=log_dir, flush_secs=1, max_queue=1)
        log_dict[set_type] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        """Group sentences by label ({label: [sentence, ...]}), optionally shuffling each group in place."""
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for sentences in labels_dict.values():
                random.shuffle(sentences)
        return labels_dict

    try:
        # Load model
        bert = BERTEncoder(model_name_or_path).to(device)
        relation_net = RelationNet(encoder=bert, relation_module_type=relation_module_type, ntl_n_slices=ntl_n_slices)
        optimizer = torch.optim.Adam(relation_net.parameters(), lr=2e-5)

        # Load data (in ARSC mode, episodes are sampled from `data_path` instead)
        data_dicts = {"train": None, "valid": None, "test": None}
        if not arsc_format:
            for set_type, path in (("train", train_path), ("valid", valid_path), ("test", test_path)):
                # The training file is always loaded; valid/test only when given.
                if set_type == "train" or path:
                    data_dicts[set_type] = raw_data_to_labels_dict(get_jsonl_data(path), shuffle=True)
                    logger.info(f"{set_type} labels: {data_dicts[set_type].keys()}")

        train_metrics = collections.defaultdict(list)
        n_eval_since_last_best = 0
        best_valid_acc = 0.0
        arsc_split = {"valid": "dev", "test": "test"}

        for step in range(max_iter):
            if arsc_format:
                _, loss_dict = relation_net.train_step_ARSC(optimizer=optimizer, data_path=data_path)
            else:
                _, loss_dict = relation_net.train_step(
                    optimizer=optimizer,
                    data_dict=data_dicts["train"],
                    n_support=n_support,
                    n_query=n_query,
                    n_classes=n_classes
                )

            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging
            if (step + 1) % log_every == 0:
                for key, value in train_metrics.items():
                    writers["train"].add_scalar(tag=key, scalar_value=np.mean(value), global_step=step)

                logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()]))
                log_dict["train"].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": np.mean(value)
                        }
                        for key, value in train_metrics.items()
                    ],
                    "global_step": step
                })

                train_metrics = collections.defaultdict(list)

            # Evaluation
            if not eval_sets or (step + 1) % evaluate_every != 0:
                continue

            should_stop = False
            for set_type in eval_sets:
                if arsc_format:
                    set_results = relation_net.test_step_ARSC(
                        data_path=data_path,
                        n_episodes=n_test_episodes,
                        set_type=arsc_split[set_type]
                    )
                else:
                    set_results = relation_net.test_step(
                        data_dict=data_dicts[set_type],
                        n_support=n_support,
                        n_query=n_query,
                        n_classes=n_classes,
                        n_episodes=n_test_episodes
                    )
                for key, val in set_results.items():
                    writers[set_type].add_scalar(tag=key, scalar_value=val, global_step=step)
                log_dict[set_type].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": val
                        }
                        for key, val in set_results.items()
                    ],
                    "global_step": step
                })

                logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))

                if set_type != "valid":
                    continue
                if set_results["acc"] > best_valid_acc:
                    best_valid_acc = set_results["acc"]
                    n_eval_since_last_best = 0
                    logger.info(f"Better eval results!")
                else:
                    n_eval_since_last_best += 1
                    logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

                if early_stop and n_eval_since_last_best >= early_stop:
                    logger.warning(f"Early-stopping.")
                    should_stop = True
                    break

            if should_stop:
                # Leave the training loop itself, not only the loop over evaluation sets.
                break
    finally:
        # Flush/close the event files and persist whatever metrics were collected, even on errors.
        for writer in writers.values():
            writer.close()
        with open(os.path.join(output_path, "metrics.json"), "w", encoding="utf-8") as file:
            json.dump(log_dict, file, ensure_ascii=False)
399 |
400 |
def main():
    """Parse command-line arguments, train a RelationNet with them, then save the run config."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, default=None, help="Path to training data")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f"runs/{now()}")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Relation Network-specific
    parser.add_argument("--relation-module-type", type=str, required=True, help="Which relation module to use")
    parser.add_argument("--ntl-n-slices", type=int, default=100, help="Number of matrices to use in NTL")
    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")

    args = parser.parse_args()
    logger.debug(f"Received args {args}")

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist.
    # --data-path is handed to the ARSC episode helpers as a `prefix`, so it is not checked here.
    for arg in [args.train_path, args.valid_path, args.test_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Run
    run_relation(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        output_path=args.output_path,

        model_name_or_path=args.model_name_or_path,

        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,

        max_iter=args.max_iter,
        evaluate_every=args.evaluate_every,
        # Without this, --log-every was parsed but silently ignored.
        log_every=args.log_every,

        relation_module_type=args.relation_module_type,
        ntl_n_slices=args.ntl_n_slices,

        early_stop=args.early_stop,
        arsc_format=args.arsc_format,
        data_path=args.data_path
    )

    # Save config. Written after the run because run_relation refuses a non-empty output directory.
    with open(os.path.join(args.output_path, "config.json"), "w", encoding="utf-8") as file:
        json.dump(vars(args), file, ensure_ascii=False)


if __name__ == "__main__":
    main()
471 |
--------------------------------------------------------------------------------
/models/induction/inductionnet.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from utils.data import get_jsonl_data
4 | from utils.python import now, set_seeds
5 | import random
6 | import collections
7 | import os
8 | from typing import List, Dict
9 | from tensorboardX import SummaryWriter
10 | import numpy as np
11 | from models.encoders.bert_encoder import BERTEncoder
12 | import torch
13 | import torch.nn as nn
14 | import warnings
15 | import logging
16 | from utils.few_shot import create_episode, create_ARSC_train_episode, create_ARSC_test_episode
17 |
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

warnings.simplefilter('ignore')

# Autograd anomaly detection adds extra checks to every forward/backward pass and is a global
# autograd setting. Enabling it unconditionally at import would also slow down every other model
# loaded alongside this module (models/__init__.py imports it), so it is opt-in for debugging:
# set FEWSHOT_DETECT_ANOMALY=1 to turn it on.
if os.environ.get("FEWSHOT_DETECT_ANOMALY", "").strip().lower() in ("1", "true", "yes"):
    torch.autograd.set_detect_anomaly(True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26 |
27 |
class InductionNet(nn.Module):
    """Induction Network: dynamic-routing class induction followed by an NTL relation module.

    Support embeddings from ``encoder`` are condensed into one representative per class by
    ``InductionModule``. ``NTLRelationModule`` then scores every (query, class) pair, and those
    scores are trained with cross-entropy.
    """

    def __init__(self, encoder, hidden_dim: int = 768, ntl_n_slices: int = 100, n_routing_iter: int = 3):
        """
        :param encoder: sentence encoder; ``encoder.forward(sentences)`` must return one embedding
            row per sentence.
        :param hidden_dim: size used to build the induction and relation modules (presumably must
            equal the encoder's embedding size — TODO confirm).
        :param ntl_n_slices: number of bilinear slices in the NTL relation module.
        :param n_routing_iter: number of dynamic-routing iterations in the induction module.
        """
        super(InductionNet, self).__init__()

        self.encoder = encoder

        self.ntl_n_slices: int = ntl_n_slices
        self.n_routing_iter: int = n_routing_iter
        self.hidden_dim = hidden_dim
        self.relation_module = NTLRelationModule(input_dim=self.hidden_dim, n_slice=self.ntl_n_slices).to(device)
        self.induction_module = InductionModule(input_dim=hidden_dim, n_routing_iter=self.n_routing_iter).to(device)

    def loss(self, sample):
        """Cross-entropy loss and accuracy of one episode.

        :param sample: {
            "xs": [[support_A_1, support_A_2, ...], [support_B_1, ...], ...],
            "xq": [[query_A_1, query_A_2, ...], [query_B_1, ...], ...]
        }
        The i-th inner list of ``xs`` and of ``xq`` belong to class i. Every class must have
        the same number of supports.
        :return: ``(loss, {"loss": float, "metrics": {"loss": float, "acc": float}, "y_hat": ndarray})``,
            where ``y_hat`` is the predicted class index of every query.
        """
        xs = sample["xs"]  # support
        xq = sample["xq"]  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])

        x = [item for xs_ in xs for item in xs_] + [item for xq_ in xq for item in xq_]
        z = self.encoder.forward(x)
        z_dim = z.size(-1)

        z_query = z[n_class * n_support:]
        z_support = z[:n_class * n_support].view(n_class, n_support, z_dim)

        class_representatives = self.induction_module.forward(z_s=z_support)
        relation_module_scores = self.relation_module.forward(z_q=z_query, z_c=class_representatives)

        # Class index of each query row, in the same order the queries were flattened into `x`.
        # This stays correct even when classes have different numbers of queries.
        target = torch.tensor(
            [ix_class for ix_class, class_queries in enumerate(xq) for _ in class_queries],
            dtype=torch.long,
            device=relation_module_scores.device,
        )

        loss_val = nn.CrossEntropyLoss()(relation_module_scores, target)
        y_hat = relation_module_scores.argmax(1)
        acc_val = (y_hat == target).float().mean()
        return loss_val, {
            "loss": loss_val.item(),
            "metrics": {
                "loss": loss_val.item(),
                "acc": acc_val.item()
            },
            "y_hat": y_hat.cpu().detach().numpy()
        }

    def train_step(self, optimizer, data_dict: Dict[str, List[str]], n_support, n_classes, n_query):
        """Run one optimisation step on an episode sampled from ``data_dict`` ({label: [sentence, ...]})."""
        episode = create_episode(
            data_dict=data_dict,
            n_support=n_support,
            n_classes=n_classes,
            n_query=n_query
        )

        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self.loss(episode)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step(self, data_dict, n_support, n_classes, n_query, n_episodes=1000):
        """Average the metrics of ``n_episodes`` episodes sampled from ``data_dict``, without gradients."""
        metrics = collections.defaultdict(list)
        self.eval()
        for _ in range(n_episodes):
            episode = create_episode(
                data_dict=data_dict,
                n_support=n_support,
                n_classes=n_classes,
                n_query=n_query
            )

            with torch.no_grad():
                _, loss_dict = self.loss(episode)

            for key, value in loss_dict["metrics"].items():
                metrics[key].append(value)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }

    def train_step_ARSC(self, data_path: str, optimizer, n_support: int = 5, n_query: int = 5):
        """Run one optimisation step on an ARSC training episode (``n_support``/``n_query`` default to 5)."""
        episode = create_ARSC_train_episode(prefix=data_path, n_support=n_support, n_query=n_query)

        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss, loss_dict = self.loss(episode)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step_ARSC(self, data_path: str, n_episodes=1000, set_type="test", n_query: int = 5):
        """Average metrics over ``n_episodes`` ARSC episodes of split ``set_type`` ("dev" or "test").

        :raises ValueError: if ``set_type`` is neither ``"dev"`` nor ``"test"``.
        """
        # Explicit check: an ``assert`` would be stripped when running with ``python -O``.
        if set_type not in ("dev", "test"):
            raise ValueError(f"set_type must be 'dev' or 'test', got {set_type!r}")
        metrics = collections.defaultdict(list)
        self.eval()
        for _ in range(n_episodes):
            episode = create_ARSC_test_episode(prefix=data_path, n_query=n_query, set_type=set_type)

            with torch.no_grad():
                _, loss_dict = self.loss(episode)

            for key, value in loss_dict["metrics"].items():
                metrics[key].append(value)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }
181 |
182 |
class InductionModule(nn.Module):
    """Induce one representative vector per class from its support embeddings by dynamic routing."""

    def __init__(self, input_dim: int, n_routing_iter: int = 3):
        super(InductionModule, self).__init__()
        self.input_dim: int = input_dim
        self.n_routing_iter: int = n_routing_iter

        # Init Ws, bs: one linear transform shared by all support embeddings
        self.Ws_bs = nn.Linear(input_dim, input_dim)

    @staticmethod
    def squash(x, dim: int = 1):
        """Squash ``x`` along ``dim``: keep each vector's direction, map its norm ``n`` to ``n**2 / (1 + n**2)``.

        Computed as ``x * n / (1 + n**2)``. This equals ``(x / n) * n**2 / (1 + n**2)`` but returns
        0 instead of NaN for an all-zero vector, and computes the norm only once.
        """
        norm = x.norm(dim=dim, keepdim=True)
        return x * (norm / (1 + norm ** 2))

    def forward(self, z_s):
        """
        :param z_s: embedding of support samples, shape=(C, K, hidden_dim)
        :return: class representatives, shape=(C, hidden_dim), on the same device as ``z_s``
        :raises ValueError: if ``n_routing_iter`` < 1 (no representative would be computed)
        """
        if self.n_routing_iter < 1:
            raise ValueError(f"n_routing_iter must be >= 1, got {self.n_routing_iter}")
        C, K, _ = z_s.size()

        # All classes are routed at once: transformed + squashed supports, (C, K, hidden_dim)
        z_squashed = self.squash(self.Ws_bs(z_s), dim=-1)
        # Routing logits, one per support sample of each class, (C, K)
        b = torch.zeros(C, K, dtype=z_squashed.dtype, device=z_squashed.device)
        c = None
        for iteration in range(self.n_routing_iter):
            d = b.softmax(dim=-1)  # coupling coefficients, (C, K)
            # Coupling-weighted sum of each class's supports, squashed: (C, hidden_dim)
            c = self.squash((d.unsqueeze(1) @ z_squashed).squeeze(1), dim=-1)
            if iteration < self.n_routing_iter - 1:
                # Raise the logits of supports that agree with the current representative
                # (the update after the final iteration would never be used).
                b = b + (z_squashed @ c.unsqueeze(-1)).squeeze(-1)
        return c
216 |
217 |
class NTLRelationModule(nn.Module):
    """Neural-tensor-layer relation module.

    Each of the ``n_slice`` learnable ``input_dim x input_dim`` matrices ``M[s]`` gives a bilinear
    score ``z_q @ M[s] @ z_c.T``. The ``n_slice`` scores of every (query, class) pair go through
    ReLU and dropout, then a linear layer reduces them to a single relation logit.
    """

    def __init__(self, input_dim, n_slice=100):
        super(NTLRelationModule, self).__init__()
        self.n_slice = n_slice
        # Random init (numpy global RNG), each slice rescaled to unit Frobenius norm.
        M = np.random.randn(n_slice, input_dim, input_dim)
        M = M / np.linalg.norm(M, axis=(1, 2))[:, None, None]
        self.register_parameter("M", nn.Parameter(torch.Tensor(M)))
        self.dropout = nn.Dropout(p=0.25)
        self.fc = nn.Linear(n_slice, 1)

    def forward(self, z_q, z_c):
        """Return relation logits of shape ``(n_query, n_class)``.

        :param z_q: query embeddings, shape ``(n_query, input_dim)``.
        :param z_c: class representatives, shape ``(n_class, input_dim)``.
        """
        n_query = z_q.size(0)
        n_class = z_c.size(0)

        # All slices in one broadcast batched matmul instead of a Python loop of 2 * n_slice matmuls:
        # (1, q, d) @ (s, d, d) -> (s, q, d); then @ (d, c) -> (s, q, c).
        bilinear = (z_q.unsqueeze(0) @ self.M) @ z_c.T
        # Slice axis last, so each row is the n_slice scores of one (query, class) pair in
        # query-major order; permute breaks contiguity, hence reshape rather than view.
        v = bilinear.permute(1, 2, 0).reshape(-1, self.n_slice)
        v = self.dropout(torch.relu(v))
        r_logit = self.fc(v).view(n_query, n_class)
        return r_logit
235 |
236 |
def run_induction(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_query: int,
        n_classes: int,
        valid_path: str = None,
        test_path: str = None,
        output_path: str = f"runs/{now()}",
        max_iter: int = 10000,
        evaluate_every: int = 100,
        early_stop: int = None,
        n_test_episodes: int = 1000,
        log_every: int = 10,
        ntl_n_slices: int = 100,
        n_routing_iter: int = 3,
        arsc_format: bool = False,
        data_path: str = None
):
    """Train an InductionNet on few-shot episodes, periodically evaluating and logging it.

    Metrics go to TensorBoard event files under ``<output_path>/logs/{train,valid,test}`` and to
    ``<output_path>/metrics.json``. The JSON file is written even if training is interrupted.

    :param train_path: JSONL training file whose items have ``label`` and ``sentence`` keys
        (ignored in ARSC mode).
    :param model_name_or_path: transformer checkpoint loaded by ``BERTEncoder``.
    :param n_support: support samples per class in each non-ARSC episode.
    :param n_query: query samples per class in each non-ARSC episode.
    :param n_classes: classes per non-ARSC episode.
    :param valid_path: optional JSONL validation file; validation accuracy drives early stopping.
    :param test_path: optional JSONL test file.
    :param output_path: run directory; must not exist or be empty.
    :param max_iter: number of training episodes.
    :param evaluate_every: training episodes between two evaluations.
    :param early_stop: stop after this many consecutive non-improving validations (falsy = disabled).
    :param n_test_episodes: episodes sampled per evaluation.
    :param log_every: training episodes between two training-metric logs.
    :param ntl_n_slices: number of NTL slices passed to ``InductionNet``.
    :param n_routing_iter: number of dynamic-routing iterations passed to ``InductionNet``.
    :param arsc_format: sample episodes with the ARSC helpers from ``data_path`` instead of JSONL files.
    :param data_path: ARSC data prefix. In ARSC mode, when neither ``valid_path`` nor ``test_path``
        is given, both the dev and test splits of ``data_path`` are evaluated.
    :raises ValueError: if ``output_path`` is empty.
    :raises FileExistsError: if ``output_path`` exists and is not empty.
    """
    if not output_path:
        raise ValueError("output_path must be a non-empty path")
    if os.path.exists(output_path) and len(os.listdir(output_path)):
        raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # Evaluation sets to run. ARSC dev/test episodes are sampled from `data_path`, so in ARSC mode
    # that path alone enables both sets (otherwise evaluation would silently never happen).
    eval_enabled = {"valid": bool(valid_path), "test": bool(test_path)}
    if arsc_format and data_path and not any(eval_enabled.values()):
        eval_enabled = {"valid": True, "test": True}
    eval_sets = [set_type for set_type in ("valid", "test") if eval_enabled[set_type]]

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok=True: an existing but *empty* output directory passes the check above and must not
    # make os.makedirs fail.
    writers = dict()
    log_dict = dict()
    for set_type in ["train"] + eval_sets:
        log_dir = os.path.join(output_path, f"logs/{set_type}")
        os.makedirs(log_dir, exist_ok=True)
        writers[set_type] = SummaryWriter(logdir=log_dir, flush_secs=1, max_queue=1)
        log_dict[set_type] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        """Group sentences by label ({label: [sentence, ...]}), optionally shuffling each group in place."""
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for sentences in labels_dict.values():
                random.shuffle(sentences)
        return labels_dict

    try:
        # Load model
        bert = BERTEncoder(model_name_or_path).to(device)
        induction_net = InductionNet(
            encoder=bert,
            ntl_n_slices=ntl_n_slices,
            n_routing_iter=n_routing_iter
        )
        optimizer = torch.optim.Adam(induction_net.parameters(), lr=2e-5)

        # Load data (in ARSC mode, episodes are sampled from `data_path` instead)
        data_dicts = {"train": None, "valid": None, "test": None}
        if not arsc_format:
            for set_type, path in (("train", train_path), ("valid", valid_path), ("test", test_path)):
                # The training file is always loaded; valid/test only when given.
                if set_type == "train" or path:
                    data_dicts[set_type] = raw_data_to_labels_dict(get_jsonl_data(path), shuffle=True)
                    logger.info(f"{set_type} labels: {data_dicts[set_type].keys()}")

        train_metrics = collections.defaultdict(list)
        n_eval_since_last_best = 0
        best_valid_acc = 0.0
        arsc_split = {"valid": "dev", "test": "test"}

        for step in range(max_iter):
            if arsc_format:
                _, loss_dict = induction_net.train_step_ARSC(
                    optimizer=optimizer,
                    data_path=data_path
                )
            else:
                _, loss_dict = induction_net.train_step(
                    optimizer=optimizer,
                    data_dict=data_dicts["train"],
                    n_support=n_support,
                    n_query=n_query,
                    n_classes=n_classes
                )

            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging
            if (step + 1) % log_every == 0:
                for key, value in train_metrics.items():
                    writers["train"].add_scalar(tag=key, scalar_value=np.mean(value), global_step=step)

                logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()]))
                log_dict["train"].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": np.mean(value)
                        }
                        for key, value in train_metrics.items()
                    ],
                    "global_step": step
                })

                train_metrics = collections.defaultdict(list)

            # Evaluation
            if not eval_sets or (step + 1) % evaluate_every != 0:
                continue

            should_stop = False
            for set_type in eval_sets:
                if arsc_format:
                    set_results = induction_net.test_step_ARSC(
                        data_path=data_path,
                        n_episodes=n_test_episodes,
                        set_type=arsc_split[set_type]
                    )
                else:
                    set_results = induction_net.test_step(
                        data_dict=data_dicts[set_type],
                        n_support=n_support,
                        n_query=n_query,
                        n_classes=n_classes,
                        n_episodes=n_test_episodes
                    )
                for key, val in set_results.items():
                    writers[set_type].add_scalar(tag=key, scalar_value=val, global_step=step)
                log_dict[set_type].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": val
                        }
                        for key, val in set_results.items()
                    ],
                    "global_step": step
                })

                logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))

                if set_type != "valid":
                    continue
                if set_results["acc"] > best_valid_acc:
                    best_valid_acc = set_results["acc"]
                    n_eval_since_last_best = 0
                    logger.info(f"Better eval results!")
                else:
                    n_eval_since_last_best += 1
                    logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

                if early_stop and n_eval_since_last_best >= early_stop:
                    logger.warning(f"Early-stopping.")
                    should_stop = True
                    break

            if should_stop:
                # Leave the training loop itself, not only the loop over evaluation sets.
                break
    finally:
        # Flush/close the event files and persist whatever metrics were collected, even on errors.
        for writer in writers.values():
            writer.close()
        with open(os.path.join(output_path, "metrics.json"), "w", encoding="utf-8") as file:
            json.dump(log_dict, file, ensure_ascii=False)
415 |
416 |
def main():
    """Parse command-line arguments, train an InductionNet with them, then save the run config."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, default=None, help="Path to training data")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f"runs/{now()}")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Induction Network-specific
    parser.add_argument("--ntl-n-slices", type=int, default=100, help="Number of matrices to use in NTL")
    parser.add_argument("--n-routing-iter", type=int, default=3, help="Number of routing iterations in the induction module")

    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")
    args = parser.parse_args()
    logger.debug(f"Received args {args}")

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist.
    # --data-path is handed to the ARSC episode helpers as a `prefix`, so it is not checked here.
    for arg in [args.train_path, args.valid_path, args.test_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Run
    run_induction(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        output_path=args.output_path,

        model_name_or_path=args.model_name_or_path,

        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,

        max_iter=args.max_iter,
        evaluate_every=args.evaluate_every,
        # Without this, --log-every was parsed but silently ignored.
        log_every=args.log_every,
        n_routing_iter=args.n_routing_iter,
        ntl_n_slices=args.ntl_n_slices,
        early_stop=args.early_stop,
        arsc_format=args.arsc_format,
        data_path=args.data_path
    )

    # Save config. Written after the run because run_induction refuses a non-empty output directory.
    with open(os.path.join(args.output_path, "config.json"), "w", encoding="utf-8") as file:
        json.dump(vars(args), file, ensure_ascii=False)


if __name__ == "__main__":
    main()
484 |
--------------------------------------------------------------------------------
/models/proto/protaugment.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
4 |
5 | import json
6 | import argparse
7 |
8 | from transformers import AutoTokenizer
9 |
10 | from models.encoders.bert_encoder import BERTEncoder
11 | from paraphrase.modeling import UnigramRandomDropParaphraseBatchPreparer, DBSParaphraseModel, BigramDropParaphraseBatchPreparer, BaseParaphraseBatchPreparer
12 | from paraphrase.utils.data import FewShotDataset, FewShotSSLParaphraseDataset, FewShotSSLFileDataset
13 | from utils.data import get_jsonl_data, FewShotDataLoader
14 | from utils.python import now, set_seeds
15 | import random
16 | import collections
17 | import os
18 | from typing import List, Dict, Callable, Union
19 | from tensorboardX import SummaryWriter
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as torch_functional
24 | from torch.autograd import Variable
25 | import warnings
26 | import logging
27 | from utils.few_shot import create_episode, create_ARSC_test_episode, create_ARSC_train_episode
28 | from utils.math import euclidean_dist, cosine_similarity
29 |
# Module logger, verbose by default.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Process-wide: every Python warning emitted after this point is silenced.
warnings.simplefilter('ignore')

# Prefer the CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36 |
37 |
38 | class ProtAugmentNet(nn.Module):
def __init__(self, encoder, metric="euclidean"):
    """
    :param encoder: sentence encoder exposing ``embed_sentences(list_of_sentences)``.
    :param metric: distance used between embeddings and prototypes, ``"euclidean"`` or ``"cosine"``.
    :raises ValueError: if ``metric`` is not one of the supported values.
    """
    # Validate explicitly (an ``assert`` is stripped under ``python -O``), before any state is set.
    if metric not in ('euclidean', 'cosine'):
        raise ValueError(f"metric must be 'euclidean' or 'cosine', got {metric!r}")
    super(ProtAugmentNet, self).__init__()

    self.encoder: BERTEncoder = encoder
    self.metric = metric
45 |
def loss(self, sample, supervised_loss_share: float = 0):
    """Prototypical loss of one episode, optionally mixed with an unsupervised augmentation loss.

    Supervised part: every query is classified by its negative distance to each class prototype,
    i.e. the mean support embedding of that class.
    Unsupervised part, only when ``sample`` has an ``"x_augment"`` entry: every augmented item's
    ``src_text`` must be matched to the prototype of its own ``tgt_texts``.

    :param sample: {
        "xs": [[support_A_1, support_A_2, ...], [support_B_1, ...], ...],
        "xq": [[query_A_1, query_A_2, ...], [query_B_1, ...], ...],
        "x_augment": [{"src_text": str, "tgt_texts": [str, ...]}, ...]   # optional
    }
    Every support/query item is a dict with a ``"sentence"`` key. Every ``x_augment`` item must
    have the same number of ``tgt_texts``.
    :param supervised_loss_share: weight in [0, 1] of the supervised loss in the final loss.
        It is only used when augmentations are present.
    :return: ``(loss, info)``.
        Without augmentations, ``info`` holds ``metrics`` (``acc``, ``loss``), ``dists`` and ``target``.
        With augmentations, it holds supervised/unsupervised metrics, ``supervised_dists``,
        ``unsupervised_dists`` and ``target``.
    :raises ValueError: if augmentations are present and ``supervised_loss_share`` is outside [0, 1].
    """
    xs = sample['xs']  # support
    xq = sample['xq']  # query

    n_class = len(xs)
    assert len(xq) == n_class
    n_support = len(xs[0])
    n_query = len(xq[0])

    # Class index of every query, shape (n_class, n_query, 1).
    target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long().to(device)

    supports = [item["sentence"] for xs_ in xs for item in xs_]
    queries = [item["sentence"] for xq_ in xq for item in xq_]
    x = supports + queries

    has_augment = "x_augment" in sample
    if has_augment:
        if not 0 <= supervised_loss_share <= 1:
            raise ValueError(f"supervised_loss_share must be in [0, 1], got {supervised_loss_share}")
        augmentations = sample["x_augment"]
        n_augmentations_samples = len(augmentations)
        n_augmentations_per_sample = [len(item['tgt_texts']) for item in augmentations]
        assert len(set(n_augmentations_per_sample)) == 1
        n_augmentations_per_sample = n_augmentations_per_sample[0]

        augmentation_supports = [tgt for item in augmentations for tgt in item["tgt_texts"]]
        augmentation_queries = [item["src_text"] for item in augmentations]
        x = x + augmentation_supports + augmentation_queries

    # Encode everything in one forward pass, then dispatch the rows by offset.
    z = self.encoder.embed_sentences(x)
    z_dim = z.size(-1)
    n_sup, n_qry = len(supports), len(queries)
    z_support = z[:n_sup].view(n_class, n_support, z_dim).mean(dim=[1])
    z_query = z[n_sup:n_sup + n_qry]
    if has_augment:
        aug_start = n_sup + n_qry
        aug_end = aug_start + n_augmentations_per_sample * n_augmentations_samples
        z_aug_support = (z[aug_start:aug_end]
                         .view(n_augmentations_samples, n_augmentations_per_sample, z_dim).mean(dim=[1]))
        z_aug_query = z[aug_end:]

    def distances(a, b):
        if self.metric == "euclidean":
            return euclidean_dist(a, b)
        if self.metric == "cosine":
            # (1 - similarity) * 5: larger means less similar.
            return (-cosine_similarity(a, b) + 1) * 5
        raise NotImplementedError

    supervised_dists = distances(z_query, z_support)
    if has_augment:
        unsupervised_dists = distances(z_aug_query, z_aug_support)

    cross_entropy = nn.CrossEntropyLoss()
    flat_targets = target_inds.reshape(-1)
    supervised_loss = cross_entropy(-supervised_dists, flat_targets)
    _, y_hat_supervised = (-supervised_dists).max(1)
    acc_val_supervised = torch.eq(y_hat_supervised, flat_targets).float().mean()

    if not has_augment:
        return supervised_loss, {
            "metrics": {
                "acc": acc_val_supervised.item(),
                "loss": supervised_loss.item(),
            },
            "dists": supervised_dists,
            "target": target_inds
        }

    # Unsupervised loss: the i-th src_text must be closest to the i-th paraphrase prototype.
    unsupervised_target_inds = torch.arange(n_augmentations_samples, device=device)
    unsupervised_loss = cross_entropy(-unsupervised_dists, unsupervised_target_inds)
    _, y_hat_unsupervised = (-unsupervised_dists).max(1)
    acc_val_unsupervised = torch.eq(y_hat_unsupervised, unsupervised_target_inds).float().mean()

    # Final loss
    final_loss = supervised_loss_share * supervised_loss + (1 - supervised_loss_share) * unsupervised_loss

    return final_loss, {
        "metrics": {
            "supervised_acc": acc_val_supervised.item(),
            "unsupervised_acc": acc_val_unsupervised.item(),
            "supervised_loss": supervised_loss.item(),
            "unsupervised_loss": unsupervised_loss.item(),
            "supervised_loss_share": supervised_loss_share,
            "final_loss": final_loss.item(),
        },
        "supervised_dists": supervised_dists,
        "unsupervised_dists": unsupervised_dists,
        "target": target_inds
    }
164 |
def train_step(self, optimizer, episode, supervised_loss_share: float):
    """Perform a single optimisation step on one training episode.

    Switches the module to train mode, zeroes the gradients, empties the CUDA
    allocator cache, back-propagates the loss returned by ``self.loss`` and
    applies ``optimizer.step()``.

    :return: the ``(loss, loss_dict)`` pair produced by ``self.loss``.
    """
    self.train()
    optimizer.zero_grad()
    # Release cached-but-unused GPU memory before the forward pass.
    torch.cuda.empty_cache()
    step_loss, step_info = self.loss(episode, supervised_loss_share=supervised_loss_share)
    step_loss.backward()
    optimizer.step()
    return step_loss, step_info
174 |
def test_step(self, dataset: FewShotDataset, n_episodes: int = 1000):
    """Evaluate on ``n_episodes`` episodes drawn from ``dataset``.

    The module is put in eval mode and every loss is computed under
    ``torch.no_grad()`` with ``supervised_loss_share=1``.

    :return: dict mapping each metric name reported by ``self.loss`` to its
        mean over all evaluated episodes.
    """
    self.eval()
    collected = collections.defaultdict(list)
    for _ in range(n_episodes):
        episode = dataset.get_episode()
        with torch.no_grad():
            _, episode_info = self.loss(episode, supervised_loss_share=1)
        for name, value in episode_info["metrics"].items():
            collected[name].append(value)
    return {name: np.mean(values) for name, values in collected.items()}
191 |
192 |
def run_proto(
        # Compulsory!
        data_path: str,
        train_labels_path: str,
        model_name_or_path: str,

        # Few-shot Stuff
        n_support: int,
        n_query: int,
        n_classes: int,
        metric: str = "euclidean",

        # Optional path to augmented data
        unlabeled_path: str = None,

        # Path training data ONLY (optional)
        train_path: str = None,

        # Validation & test
        valid_labels_path: str = None,
        test_labels_path: str = None,
        evaluate_every: int = 100,
        n_test_episodes: int = 1000,

        # Logging & Saving
        output_path: str = f'runs/{now()}',
        log_every: int = 10,

        # Training stuff
        max_iter: int = 10000,
        early_stop: int = None,

        # Augmentation & paraphrase
        n_augmentation: int = 5,
        paraphrase_model_name_or_path: str = None,
        paraphrase_tokenizer_name_or_path: str = None,
        paraphrase_num_beams: int = None,
        paraphrase_beam_group_size: int = None,
        paraphrase_diversity_penalty: float = None,
        paraphrase_filtering_strategy: str = None,
        paraphrase_drop_strategy: str = None,
        paraphrase_drop_chance_speed: str = None,
        paraphrase_drop_chance_auc: float = None,
        supervised_loss_share_fn: Callable[[int, int], float] = lambda x, y: 1 - (x / y),

        augmentation_data_path: str = None
):
    """Train a ProtAugmentNet on few-shot episodes with periodic valid/test evaluation.

    Training episodes come from ``FewShotSSLFileDataset`` (pre-generated augmentations)
    when ``augmentation_data_path`` is given. Otherwise they come from
    ``FewShotSSLParaphraseDataset``, which receives a ``DBSParaphraseModel`` and gets
    ``unlabeled_path`` as its ``unlabeled_file_path``. At step ``s`` the supervised share
    of the loss is ``supervised_loss_share_fn(s, max_iter)``.

    Every ``log_every`` steps the averaged training metrics are written to TensorBoard
    (``<output_path>/logs/train``) and to the in-memory log. Every ``evaluate_every``
    steps, each available split (valid, then test) is evaluated on ``n_test_episodes``
    episodes. Training stops after ``max_iter`` steps, or earlier once validation accuracy
    has not improved for ``early_stop`` consecutive evaluations (falsy disables this).
    All logged metrics are dumped to ``<output_path>/metrics.json`` at the end.

    :raises FileExistsError: if ``output_path`` already exists and is not empty.
    """
    if output_path and os.path.exists(output_path) and len(os.listdir(output_path)):
        raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok: an existing *empty* directory passed the check above and must be reusable.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "logs/train"), exist_ok=True)
    train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer: SummaryWriter = None
    test_writer: SummaryWriter = None
    log_dict = dict(train=list())

    # ----------
    # Load model
    # ----------
    bert = BERTEncoder(model_name_or_path).to(device)
    protonet: ProtAugmentNet = ProtAugmentNet(encoder=bert, metric=metric)
    optimizer = torch.optim.Adam(protonet.parameters(), lr=2e-5)

    # ------------------
    # Load Train Dataset
    # ------------------
    if augmentation_data_path:
        # If an augmentation data path is provided, uses those pre-generated augmentations
        train_dataset = FewShotSSLFileDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=augmentation_data_path,
        )
    else:
        # ---------------------
        # Load paraphrase model
        # ---------------------
        # 20newsgroup is forced onto CPU (presumably for memory reasons — TODO confirm).
        # Otherwise reuse the module-level `device` rather than hard-coding "cuda",
        # which crashed on CPU-only machines.
        paraphrase_model_device = torch.device("cpu") if "20newsgroup" in data_path else device
        logger.info(f"Paraphrase model device: {paraphrase_model_device}")
        paraphrase_tokenizer = AutoTokenizer.from_pretrained(paraphrase_tokenizer_name_or_path)
        if paraphrase_drop_strategy == "unigram":
            paraphrase_batch_preparer = UnigramRandomDropParaphraseBatchPreparer(
                tokenizer=paraphrase_tokenizer,
                auc=paraphrase_drop_chance_auc,
                drop_chance_speed=paraphrase_drop_chance_speed,
                device=paraphrase_model_device
            )
        elif paraphrase_drop_strategy == "bigram":
            paraphrase_batch_preparer = BigramDropParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)
        else:
            paraphrase_batch_preparer = BaseParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)

        paraphrase_model = DBSParaphraseModel(
            model_name_or_path=paraphrase_model_name_or_path,
            tok_name_or_path=paraphrase_tokenizer_name_or_path,
            num_beams=paraphrase_num_beams,
            beam_group_size=paraphrase_beam_group_size,
            diversity_penalty=paraphrase_diversity_penalty,
            filtering_strategy=paraphrase_filtering_strategy,
            paraphrase_batch_preparer=paraphrase_batch_preparer,
            device=paraphrase_model_device
        )

        train_dataset = FewShotSSLParaphraseDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=unlabeled_path,
            paraphrase_model=paraphrase_model
        )
    logger.info(f"Train dataset has {len(train_dataset)} items")

    # ---------
    # Load data
    # ---------
    logger.info(f"train labels: {train_dataset.data.keys()}")
    valid_dataset: FewShotDataset = None
    if valid_labels_path:
        os.makedirs(os.path.join(output_path, "logs/valid"), exist_ok=True)
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
        valid_dataset = FewShotDataset(data_path=data_path, labels_path=valid_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"valid labels: {valid_dataset.data.keys()}")
        # Few-shot evaluation requires classes never seen during training.
        assert len(set(valid_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    test_dataset: FewShotDataset = None
    if test_labels_path:
        os.makedirs(os.path.join(output_path, "logs/test"), exist_ok=True)
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()
        test_dataset = FewShotDataset(data_path=data_path, labels_path=test_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"test labels: {test_dataset.data.keys()}")
        assert len(set(test_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    train_metrics = collections.defaultdict(list)
    n_eval_since_last_best = 0
    best_valid_acc = 0.0

    try:
        for step in range(max_iter):
            episode = train_dataset.get_episode()

            supervised_loss_share = supervised_loss_share_fn(step, max_iter)
            loss, loss_dict = protonet.train_step(optimizer=optimizer, episode=episode, supervised_loss_share=supervised_loss_share)

            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging
            if (step + 1) % log_every == 0:
                for key, value in train_metrics.items():
                    train_writer.add_scalar(tag=key, scalar_value=np.mean(value), global_step=step)
                logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()]))
                log_dict["train"].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": np.mean(value)
                        }
                        for key, value in train_metrics.items()
                    ],
                    "global_step": step
                })

                train_metrics = collections.defaultdict(list)

            # Evaluation
            if (valid_labels_path or test_labels_path) and (step + 1) % evaluate_every == 0:
                for writer, set_type, set_dataset in zip(
                        [valid_writer, test_writer],
                        ["valid", "test"],
                        [valid_dataset, test_dataset]
                ):
                    if not set_dataset:
                        continue

                    set_results = protonet.test_step(
                        dataset=set_dataset,
                        n_episodes=n_test_episodes
                    )

                    for key, val in set_results.items():
                        writer.add_scalar(tag=key, scalar_value=val, global_step=step)
                    log_dict[set_type].append({
                        "metrics": [
                            {
                                "tag": key,
                                "value": val
                            }
                            for key, val in set_results.items()
                        ],
                        "global_step": step
                    })

                    logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))
                    if set_type == "valid":
                        if set_results["acc"] > best_valid_acc:
                            best_valid_acc = set_results["acc"]
                            n_eval_since_last_best = 0
                            logger.info(f"Better eval results!")
                        else:
                            n_eval_since_last_best += 1
                            logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

                # Checked once the whole evaluation round is done, so that `break` exits the
                # training loop itself (not just the split loop) and test results still get logged.
                if early_stop and n_eval_since_last_best >= early_stop:
                    logger.warning(f"Early-stopping.")
                    break
    finally:
        # Flush and release the TensorBoard event files.
        for open_writer in (train_writer, valid_writer, test_writer):
            if open_writer is not None:
                open_writer.close()

    with open(os.path.join(output_path, 'metrics.json'), "w") as file:
        json.dump(log_dict, file, ensure_ascii=False)
414 |
415 |
def main():
    """Parse CLI arguments, validate input paths, seed RNGs and launch ``run_proto``.

    The parsed arguments are saved to ``<output_path>/config.json`` after training ends.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True, help="Path to data")
    parser.add_argument("--train-labels-path", type=str, required=True, help="Path to train labels")
    parser.add_argument("--train-path", type=str, help="Path to training data (if provided, picks training data from this path instead of --data-path)")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--metric", type=str, default="euclidean", help="Metric to use", choices=("euclidean", "cosine"))

    # Path to augmented data
    parser.add_argument("--unlabeled-path", type=str, help="Path to data containing augmentations used for consistency")

    # Validation & test
    parser.add_argument("--valid-labels-path", type=str, required=True, help="Path to valid labels")
    parser.add_argument("--test-labels-path", type=str, required=True, help="Path to test labels")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Logging & Saving
    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")

    # Training stuff
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Augmentation & Paraphrase
    parser.add_argument("--n-augmentation", type=int, help="Number of unlabeled data points per class (proto++)", default=0)
    parser.add_argument("--paraphrase-model-name-or-path", type=str)
    parser.add_argument("--paraphrase-tokenizer-name-or-path", type=str)
    parser.add_argument("--paraphrase-num-beams", type=int)
    parser.add_argument("--paraphrase-beam-group-size", type=int)
    parser.add_argument("--paraphrase-diversity-penalty", type=float)
    parser.add_argument("--paraphrase-filtering-strategy", type=str)
    parser.add_argument("--paraphrase-drop-strategy", type=str)
    parser.add_argument("--paraphrase-drop-chance-speed", type=str)
    parser.add_argument("--paraphrase-drop-chance-auc", type=float)

    # Augmentation file path (optional, but if provided it will be used)
    parser.add_argument("--augmentation-data-path", type=str)

    # Seed
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")

    # Supervised loss share
    parser.add_argument("--supervised-loss-share-power", default=1.0, type=float,
                        help="supervised_loss_share = 1 - (x/y) ** supervised_loss_share_power, "
                             "with x the current step and y --max-iter")

    args = parser.parse_args()
    logger.debug(f"Received args: {json.dumps(args.__dict__, sort_keys=True, ensure_ascii=False, indent=1)}")
    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist; optional paths are only checked when given.
    for arg in [args.data_path, args.train_path, args.train_labels_path, args.valid_labels_path,
                args.test_labels_path, args.augmentation_data_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Create supervised_loss_share_fn
    def get_supervised_loss_share_fn(supervised_loss_share_power: Union[int, float]) -> Callable[[int, int], float]:
        def _supervised_loss_share_fn(current_step: int, max_steps: int) -> float:
            assert current_step <= max_steps
            return 1 - (current_step / max_steps) ** supervised_loss_share_power

        return _supervised_loss_share_fn

    supervised_loss_share_fn = get_supervised_loss_share_fn(args.supervised_loss_share_power)

    # Run
    run_proto(
        data_path=args.data_path,
        train_labels_path=args.train_labels_path,
        train_path=args.train_path,
        model_name_or_path=args.model_name_or_path,
        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        metric=args.metric,
        unlabeled_path=args.unlabeled_path,

        valid_labels_path=args.valid_labels_path,
        test_labels_path=args.test_labels_path,
        evaluate_every=args.evaluate_every,
        n_test_episodes=args.n_test_episodes,

        output_path=args.output_path,
        log_every=args.log_every,
        max_iter=args.max_iter,
        early_stop=args.early_stop,

        n_augmentation=args.n_augmentation,
        paraphrase_model_name_or_path=args.paraphrase_model_name_or_path,
        paraphrase_tokenizer_name_or_path=args.paraphrase_tokenizer_name_or_path,
        paraphrase_num_beams=args.paraphrase_num_beams,
        paraphrase_beam_group_size=args.paraphrase_beam_group_size,
        # Previously parsed but never forwarded, so the CLI flag had no effect.
        paraphrase_diversity_penalty=args.paraphrase_diversity_penalty,
        paraphrase_filtering_strategy=args.paraphrase_filtering_strategy,
        paraphrase_drop_strategy=args.paraphrase_drop_strategy,
        paraphrase_drop_chance_speed=args.paraphrase_drop_chance_speed,
        paraphrase_drop_chance_auc=args.paraphrase_drop_chance_auc,
        supervised_loss_share_fn=supervised_loss_share_fn,

        augmentation_data_path=args.augmentation_data_path
    )

    # Save config. This happens only after training because run_proto refuses
    # to start in a non-empty output directory.
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False, indent=1)
526 |
527 |
# Launch the CLI only when executed as a script, never on import.
if __name__ == "__main__":
    main()
530 |
--------------------------------------------------------------------------------
/models/proto/protaugment-tmp.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
4 |
5 | import json
6 | import argparse
7 |
8 | from transformers import AutoTokenizer
9 |
10 | from models.encoders.bert_encoder import BERTEncoder
11 | from paraphrase.modeling import UnigramRandomDropParaphraseBatchPreparer, DBSParaphraseModel, BigramDropParaphraseBatchPreparer, BaseParaphraseBatchPreparer
12 | from paraphrase.utils.data import FewShotDataset, FewShotSSLParaphraseDataset, FewShotSSLFileDataset
13 | from utils.data import get_jsonl_data, FewShotDataLoader
14 | from utils.python import now, set_seeds
15 | import random
16 | import collections
17 | import os
18 | from typing import List, Dict, Callable, Union
19 | from tensorboardX import SummaryWriter
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as torch_functional
24 | from torch.autograd import Variable
25 | import warnings
26 | import logging
27 | from utils.few_shot import create_episode, create_ARSC_test_episode, create_ARSC_train_episode
28 | from utils.math import euclidean_dist, cosine_similarity
29 |
# Module-level logger, emitting everything from DEBUG upwards.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Silence every Python warning process-wide.
warnings.simplefilter('ignore')

# Default torch device for this module: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36 |
37 |
class ProtAugmentNet(nn.Module):
    """Prototypical network with an optional unsupervised loss on augmented (paraphrased) data.

    Supervised part: the prototype of each class is the mean embedding of its support
    sentences, and each query is scored by its negated distance to every prototype.
    Unsupervised part (only when an episode carries ``"x_augment"``): the mean embedding
    of each unlabeled item's paraphrases (``tgt_texts``) acts as a prototype, and the
    item's source sentence (``src_text``) must be matched back to its own prototype.
    """

    def __init__(self, encoder, metric="euclidean"):
        super(ProtAugmentNet, self).__init__()

        # Must expose embed_sentences(list of str) returning one embedding row per sentence.
        self.encoder: BERTEncoder = encoder
        self.metric = metric
        assert self.metric in ('euclidean', 'cosine')

    def _distances(self, queries, prototypes):
        """Return the query-to-prototype distance matrix for ``self.metric``.

        Rows index queries and columns index prototypes (as consumed by ``loss``).
        """
        if self.metric == "euclidean":
            return euclidean_dist(queries, prototypes)
        if self.metric == "cosine":
            # Turn similarity into a distance: (1 - cos) scaled by 5.
            return (-cosine_similarity(queries, prototypes) + 1) * 5
        raise NotImplementedError

    def loss(self, sample, supervised_loss_share: float = 0):
        """Compute the episode loss and metrics.

        :param supervised_loss_share: weight of the supervised loss in the final loss
            (the unsupervised loss gets ``1 - supervised_loss_share``). It is only used,
            and must lie in [0, 1], when ``sample`` contains ``"x_augment"``.
        :param sample: {
            "xs": [
                [support_A_1, support_A_2, ...],
                [support_B_1, support_B_2, ...],
                [support_C_1, support_C_2, ...],
                ...
            ],
            "xq": [
                [query_A_1, query_A_2, ...],
                [query_B_1, query_B_2, ...],
                [query_C_1, query_C_2, ...],
                ...
            ],
            # optional
            "x_augment": [{"src_text": str, "tgt_texts": [str, ...]}, ...]
        }
        Each support/query item is a dict with a ``"sentence"`` key. Every class must have
        the same number of supports (and of queries), and every ``x_augment`` item the
        same number of ``tgt_texts``.
        :return: ``(loss, info)``. Without augmentation, ``loss`` is the supervised loss and
            ``info`` holds metrics ``acc``/``loss``, ``dists`` and ``target``. With
            augmentation, ``loss`` is the weighted sum of both losses and ``info`` holds
            supervised/unsupervised accuracies and losses, the share, the final loss, both
            distance matrices and ``target``.
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        # Class index of every query, in the same class-major order as `queries` below.
        # Plain tensor: `torch.autograd.Variable` is a deprecated no-op wrapper.
        target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long().to(device)

        supports = [item["sentence"] for xs_ in xs for item in xs_]
        queries = [item["sentence"] for xq_ in xq for item in xq_]
        texts = supports + queries

        has_augment = "x_augment" in sample
        if has_augment:
            augmentations = sample["x_augment"]
            n_augmentations_samples = len(augmentations)
            n_augmentations_per_sample = [len(item['tgt_texts']) for item in augmentations]
            # All items need the same number of paraphrases to be reshaped into one tensor.
            assert len(set(n_augmentations_per_sample)) == 1
            n_augmentations_per_sample = n_augmentations_per_sample[0]

            augmentation_supports = [tgt for item in augmentations for tgt in item["tgt_texts"]]
            augmentation_queries = [item["src_text"] for item in augmentations]
            texts = texts + augmentation_supports + augmentation_queries

        # Encode every sentence in a single call.
        z = self.encoder.embed_sentences(texts)
        z_dim = z.size(-1)

        # Dispatch embeddings back to their roles (same order as `texts`).
        query_start = len(supports)
        query_end = query_start + len(queries)
        z_support = z[:query_start].view(n_class, n_support, z_dim).mean(dim=[1])
        z_query = z[query_start:query_end]
        supervised_dists = self._distances(z_query, z_support)

        supervised_loss = torch_functional.cross_entropy(-supervised_dists, target_inds.reshape(-1))
        _, y_hat_supervised = (-supervised_dists).max(1)
        acc_val_supervised = torch.eq(y_hat_supervised, target_inds.reshape(-1)).float().mean()

        if not has_augment:
            return supervised_loss, {
                "metrics": {
                    "acc": acc_val_supervised.item(),
                    "loss": supervised_loss.item(),
                },
                "dists": supervised_dists,
                "target": target_inds
            }

        aug_query_start = query_end + n_augmentations_per_sample * n_augmentations_samples
        z_aug_support = (z[query_end:aug_query_start]
                         .view(n_augmentations_samples, n_augmentations_per_sample, z_dim)
                         .mean(dim=[1]))
        z_aug_query = z[aug_query_start:aug_query_start + n_augmentations_samples]
        unsupervised_dists = self._distances(z_aug_query, z_aug_support)

        # Unsupervised loss: source sentence i must be closest to the prototype built
        # from its own paraphrases. torch.arange replaces the deprecated torch.range.
        unsupervised_target_inds = torch.arange(n_augmentations_samples, device=device)
        unsupervised_loss = torch_functional.cross_entropy(-unsupervised_dists, unsupervised_target_inds)
        _, y_hat_unsupervised = (-unsupervised_dists).max(1)
        acc_val_unsupervised = torch.eq(y_hat_unsupervised, unsupervised_target_inds.reshape(-1)).float().mean()

        # Final loss
        assert 0 <= supervised_loss_share <= 1
        final_loss = supervised_loss_share * supervised_loss + (1 - supervised_loss_share) * unsupervised_loss

        return final_loss, {
            "metrics": {
                "supervised_acc": acc_val_supervised.item(),
                "unsupervised_acc": acc_val_unsupervised.item(),
                "supervised_loss": supervised_loss.item(),
                "unsupervised_loss": unsupervised_loss.item(),
                "supervised_loss_share": supervised_loss_share,
                "final_loss": final_loss.item(),
            },
            "supervised_dists": supervised_dists,
            "unsupervised_dists": unsupervised_dists,
            "target": target_inds
        }

    def train_step(self, optimizer, episode, supervised_loss_share: float):
        """Run one optimisation step on ``episode`` and return ``self.loss``'s output."""
        self.train()
        optimizer.zero_grad()
        # Release cached-but-unused GPU memory before the forward pass.
        torch.cuda.empty_cache()
        loss, loss_dict = self.loss(episode, supervised_loss_share=supervised_loss_share)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step(self, dataset: FewShotDataset, n_episodes: int = 1000):
        """Average each metric over ``n_episodes`` episodes evaluated in eval/no-grad mode."""
        metrics = collections.defaultdict(list)

        self.eval()
        for _ in range(n_episodes):
            episode = dataset.get_episode()

            with torch.no_grad():
                loss, loss_dict = self.loss(episode, supervised_loss_share=1)

            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }
191 |
192 |
def run_proto(
        # Compulsory!
        data_path: str,
        train_labels_path: str,
        model_name_or_path: str,

        # Few-shot Stuff
        n_support: int,
        n_query: int,
        n_classes: int,
        metric: str = "euclidean",

        # Optional path to augmented data
        unlabeled_path: str = None,

        # Path training data ONLY (optional)
        train_path: str = None,

        # Validation & test
        valid_labels_path: str = None,
        test_labels_path: str = None,
        evaluate_every: int = 100,
        n_test_episodes: int = 1000,

        # Logging & Saving
        output_path: str = f'runs/{now()}',
        log_every: int = 10,

        # Training stuff
        max_iter: int = 10000,
        early_stop: int = None,

        # Augmentation & paraphrase
        n_augmentation: int = 5,
        paraphrase_model_name_or_path: str = None,
        paraphrase_tokenizer_name_or_path: str = None,
        paraphrase_num_beams: int = None,
        paraphrase_beam_group_size: int = None,
        paraphrase_diversity_penalty: float = None,
        paraphrase_filtering_strategy: str = None,
        paraphrase_drop_strategy: str = None,
        paraphrase_drop_chance_speed: str = None,
        paraphrase_drop_chance_auc: float = None,
        supervised_loss_share_fn: Callable[[int, int], float] = lambda x, y: 1 - (x / y),

        augmentation_data_path: str = None
):
    """Train a ProtAugmentNet on few-shot episodes with periodic valid/test evaluation.

    Training episodes come from ``FewShotSSLFileDataset`` (pre-generated augmentations)
    when ``augmentation_data_path`` is given. Otherwise they come from
    ``FewShotSSLParaphraseDataset``, which receives a ``DBSParaphraseModel`` and gets
    ``unlabeled_path`` as its ``unlabeled_file_path``. At step ``s`` the supervised share
    of the loss is ``supervised_loss_share_fn(s, max_iter)``.

    Every ``log_every`` steps the averaged training metrics are written to TensorBoard
    (``<output_path>/logs/train``) and to the in-memory log. Every ``evaluate_every``
    steps, each available split (valid, then test) is evaluated on ``n_test_episodes``
    episodes. Training stops after ``max_iter`` steps, or earlier once validation accuracy
    has not improved for ``early_stop`` consecutive evaluations (falsy disables this).
    All logged metrics are dumped to ``<output_path>/metrics.json`` at the end.

    :raises FileExistsError: if ``output_path`` already exists and is not empty.
    """
    if output_path and os.path.exists(output_path) and len(os.listdir(output_path)):
        raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok: an existing *empty* directory passed the check above and must be reusable.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "logs/train"), exist_ok=True)
    train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer: SummaryWriter = None
    test_writer: SummaryWriter = None
    log_dict = dict(train=list())

    # ----------
    # Load model
    # ----------
    bert = BERTEncoder(model_name_or_path).to(device)
    protonet: ProtAugmentNet = ProtAugmentNet(encoder=bert, metric=metric)
    optimizer = torch.optim.Adam(protonet.parameters(), lr=2e-5)

    # ------------------
    # Load Train Dataset
    # ------------------
    if augmentation_data_path:
        # If an augmentation data path is provided, uses those pre-generated augmentations
        train_dataset = FewShotSSLFileDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=augmentation_data_path,
        )
    else:
        # ---------------------
        # Load paraphrase model
        # ---------------------
        # 20newsgroup is forced onto CPU (presumably for memory reasons — TODO confirm).
        # Otherwise reuse the module-level `device` rather than hard-coding "cuda",
        # which crashed on CPU-only machines.
        paraphrase_model_device = torch.device("cpu") if "20newsgroup" in data_path else device
        logger.info(f"Paraphrase model device: {paraphrase_model_device}")
        paraphrase_tokenizer = AutoTokenizer.from_pretrained(paraphrase_tokenizer_name_or_path)
        if paraphrase_drop_strategy == "unigram":
            paraphrase_batch_preparer = UnigramRandomDropParaphraseBatchPreparer(
                tokenizer=paraphrase_tokenizer,
                auc=paraphrase_drop_chance_auc,
                drop_chance_speed=paraphrase_drop_chance_speed,
                device=paraphrase_model_device
            )
        elif paraphrase_drop_strategy == "bigram":
            paraphrase_batch_preparer = BigramDropParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)
        else:
            paraphrase_batch_preparer = BaseParaphraseBatchPreparer(tokenizer=paraphrase_tokenizer, device=paraphrase_model_device)

        paraphrase_model = DBSParaphraseModel(
            model_name_or_path=paraphrase_model_name_or_path,
            tok_name_or_path=paraphrase_tokenizer_name_or_path,
            num_beams=paraphrase_num_beams,
            beam_group_size=paraphrase_beam_group_size,
            diversity_penalty=paraphrase_diversity_penalty,
            filtering_strategy=paraphrase_filtering_strategy,
            paraphrase_batch_preparer=paraphrase_batch_preparer,
            device=paraphrase_model_device
        )

        train_dataset = FewShotSSLParaphraseDataset(
            data_path=train_path if train_path else data_path,
            labels_path=train_labels_path,
            n_classes=n_classes,
            n_support=n_support,
            n_query=n_query,
            n_unlabeled=n_augmentation,
            unlabeled_file_path=unlabeled_path,
            paraphrase_model=paraphrase_model
        )
    logger.info(f"Train dataset has {len(train_dataset)} items")

    # ---------
    # Load data
    # ---------
    logger.info(f"train labels: {train_dataset.data.keys()}")
    valid_dataset: FewShotDataset = None
    if valid_labels_path:
        os.makedirs(os.path.join(output_path, "logs/valid"), exist_ok=True)
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
        valid_dataset = FewShotDataset(data_path=data_path, labels_path=valid_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"valid labels: {valid_dataset.data.keys()}")
        # Few-shot evaluation requires classes never seen during training.
        assert len(set(valid_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    test_dataset: FewShotDataset = None
    if test_labels_path:
        os.makedirs(os.path.join(output_path, "logs/test"), exist_ok=True)
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()
        test_dataset = FewShotDataset(data_path=data_path, labels_path=test_labels_path, n_classes=n_classes, n_support=n_support, n_query=n_query)
        logger.info(f"test labels: {test_dataset.data.keys()}")
        assert len(set(test_dataset.data.keys()) & set(train_dataset.data.keys())) == 0

    train_metrics = collections.defaultdict(list)
    n_eval_since_last_best = 0
    best_valid_acc = 0.0

    try:
        for step in range(max_iter):
            episode = train_dataset.get_episode()

            supervised_loss_share = supervised_loss_share_fn(step, max_iter)
            loss, loss_dict = protonet.train_step(optimizer=optimizer, episode=episode, supervised_loss_share=supervised_loss_share)

            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging
            if (step + 1) % log_every == 0:
                for key, value in train_metrics.items():
                    train_writer.add_scalar(tag=key, scalar_value=np.mean(value), global_step=step)
                logger.info(f"train | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in train_metrics.items()]))
                log_dict["train"].append({
                    "metrics": [
                        {
                            "tag": key,
                            "value": np.mean(value)
                        }
                        for key, value in train_metrics.items()
                    ],
                    "global_step": step
                })

                train_metrics = collections.defaultdict(list)

            # Evaluation
            if (valid_labels_path or test_labels_path) and (step + 1) % evaluate_every == 0:
                for writer, set_type, set_dataset in zip(
                        [valid_writer, test_writer],
                        ["valid", "test"],
                        [valid_dataset, test_dataset]
                ):
                    if not set_dataset:
                        continue

                    set_results = protonet.test_step(
                        dataset=set_dataset,
                        n_episodes=n_test_episodes
                    )

                    for key, val in set_results.items():
                        writer.add_scalar(tag=key, scalar_value=val, global_step=step)
                    log_dict[set_type].append({
                        "metrics": [
                            {
                                "tag": key,
                                "value": val
                            }
                            for key, val in set_results.items()
                        ],
                        "global_step": step
                    })

                    logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))
                    if set_type == "valid":
                        if set_results["acc"] > best_valid_acc:
                            best_valid_acc = set_results["acc"]
                            n_eval_since_last_best = 0
                            logger.info(f"Better eval results!")
                        else:
                            n_eval_since_last_best += 1
                            logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

                # Checked once the whole evaluation round is done, so that `break` exits the
                # training loop itself (not just the split loop) and test results still get logged.
                if early_stop and n_eval_since_last_best >= early_stop:
                    logger.warning(f"Early-stopping.")
                    break
    finally:
        # Flush and release the TensorBoard event files.
        for open_writer in (train_writer, valid_writer, test_writer):
            if open_writer is not None:
                open_writer.close()

    with open(os.path.join(output_path, 'metrics.json'), "w") as file:
        json.dump(log_dict, file, ensure_ascii=False)
414 |
415 |
def main():
    """Parse the command line and launch a ProtAugment training run.

    Checks that the input data paths exist, builds the supervised-loss-share schedule
    ``1 - (step / max_steps) ** power`` and forwards every option to ``run_proto``.
    The parsed arguments are dumped to ``<output-path>/config.json`` once ``run_proto`` returns.

    Raises:
        FileNotFoundError: if one of the given data / labels / unlabeled paths does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True, help="Path to data")
    parser.add_argument("--train-labels-path", type=str, required=True, help="Path to train labels")
    parser.add_argument("--train-path", type=str, help="Path to training data (if provided, picks training data from this path instead of --data-path)")
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--metric", type=str, default="euclidean", help="Metric to use", choices=("euclidean", "cosine"))

    # Path to augmented data
    parser.add_argument("--unlabeled-path", type=str, required=True, help="Path to data containing augmentations used for consistency")

    # Validation & test
    parser.add_argument("--valid-labels-path", type=str, required=True, help="Path to valid labels")
    parser.add_argument("--test-labels-path", type=str, required=True, help="Path to test labels")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Logging & Saving
    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")

    # Training stuff
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Augmentation & Paraphrase
    parser.add_argument("--n-augmentation", type=int, help="Number of unlabeled data points per class (proto++)", default=0)
    parser.add_argument("--paraphrase-model-name-or-path", type=str)
    parser.add_argument("--paraphrase-tokenizer-name-or-path", type=str)
    parser.add_argument("--paraphrase-num-beams", type=int)
    parser.add_argument("--paraphrase-beam-group-size", type=int)
    parser.add_argument("--paraphrase-diversity-penalty", type=float)
    parser.add_argument("--paraphrase-filtering-strategy", type=str)
    parser.add_argument("--paraphrase-drop-strategy", type=str)
    parser.add_argument("--paraphrase-drop-chance-speed", type=str)
    parser.add_argument("--paraphrase-drop-chance-auc", type=float)

    # Augmentation file path (optional, but if provided it will be used)
    parser.add_argument("--augmentation-data-path", type=str)

    # Seed
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")

    # Supervised loss share
    parser.add_argument("--supervised-loss-share-power", default=1.0, type=float, help="supervised_loss_share = 1 - (x/y) ** power, with x the current step and y --max-iter")

    args = parser.parse_args()
    logger.debug(f"Received args: {json.dumps(args.__dict__, sort_keys=True, ensure_ascii=False, indent=1)}")
    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist. --train-path and --unlabeled-path are read as input data too, so fail fast
    # here rather than deep inside run_proto.
    # NOTE(review): --augmentation-data-path is not checked; confirm whether it is always an existing input file.
    for arg in [args.data_path, args.train_labels_path, args.train_path, args.unlabeled_path,
                args.valid_labels_path, args.test_labels_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Create supervised_loss_share_fn: decays from 1 (fully supervised) at step 0 towards 0 at max_steps.
    def get_supervised_loss_share_fn(supervised_loss_share_power: Union[int, float]) -> Callable[[int, int], float]:
        def _supervised_loss_share_fn(current_step: int, max_steps: int) -> float:
            assert current_step <= max_steps
            return 1 - (current_step / max_steps) ** supervised_loss_share_power

        return _supervised_loss_share_fn

    supervised_loss_share_fn = get_supervised_loss_share_fn(args.supervised_loss_share_power)

    # Run
    # NOTE(review): --paraphrase-diversity-penalty is parsed above but not forwarded here; confirm whether
    # run_proto accepts a paraphrase_diversity_penalty argument before wiring it through.
    run_proto(
        data_path=args.data_path,
        train_labels_path=args.train_labels_path,
        train_path=args.train_path,
        model_name_or_path=args.model_name_or_path,
        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        metric=args.metric,
        unlabeled_path=args.unlabeled_path,

        valid_labels_path=args.valid_labels_path,
        test_labels_path=args.test_labels_path,
        evaluate_every=args.evaluate_every,
        n_test_episodes=args.n_test_episodes,

        output_path=args.output_path,
        log_every=args.log_every,
        max_iter=args.max_iter,
        early_stop=args.early_stop,

        n_augmentation=args.n_augmentation,
        paraphrase_model_name_or_path=args.paraphrase_model_name_or_path,
        paraphrase_tokenizer_name_or_path=args.paraphrase_tokenizer_name_or_path,
        paraphrase_num_beams=args.paraphrase_num_beams,
        paraphrase_beam_group_size=args.paraphrase_beam_group_size,
        paraphrase_filtering_strategy=args.paraphrase_filtering_strategy,
        paraphrase_drop_strategy=args.paraphrase_drop_strategy,
        paraphrase_drop_chance_speed=args.paraphrase_drop_chance_speed,
        paraphrase_drop_chance_auc=args.paraphrase_drop_chance_auc,
        supervised_loss_share_fn=supervised_loss_share_fn,

        augmentation_data_path=args.augmentation_data_path
    )

    # Save config
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False, indent=1)


if __name__ == '__main__':
    main()
530 |
--------------------------------------------------------------------------------
/models/proto/protonet.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from models.encoders.bert_encoder import BERTEncoder
4 | from utils.data import get_jsonl_data, FewShotDataLoader
5 | from utils.python import now, set_seeds
6 | import random
7 | import collections
8 | import os
9 | from typing import List, Dict, Callable, Union
10 | from tensorboardX import SummaryWriter
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as torch_functional
15 | from torch.autograd import Variable
16 | import warnings
17 | import logging
18 | from utils.few_shot import create_episode, create_ARSC_test_episode, create_ARSC_train_episode
19 | from utils.math import euclidean_dist, cosine_similarity
20 |
# Configure logging once at import time; this module's own logger emits everything down to DEBUG.
logging.basicConfig()
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)

# Suppress every Python warning for the whole process.
warnings.simplefilter(action='ignore')

# Prefer the GPU whenever torch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 |
29 |
class ProtoNet(nn.Module):
    """Prototypical network on top of a sentence encoder.

    Each class prototype is the mean embedding of that class's support sentences; queries are
    scored by negative distance to the prototypes (``metric`` is "euclidean" or "cosine").
    ``loss`` can additionally mix in an unsupervised term computed on augmented sentences, and
    ``loss_softkmeans`` refines the prototypes with unlabeled sentences.
    """

    def __init__(self, encoder: BERTEncoder, metric="euclidean"):
        super(ProtoNet, self).__init__()

        self.encoder: BERTEncoder = encoder
        self.metric = metric
        assert self.metric in ('euclidean', 'cosine')

    def _distances(self, queries: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
        """Distance matrix between ``queries`` (rows) and ``prototypes`` (columns) under ``self.metric``.

        Cosine similarity is mapped to a distance as ``(1 - cos) * 5`` so that both metrics can be
        used as negative logits.
        """
        if self.metric == "euclidean":
            return euclidean_dist(queries, prototypes)
        if self.metric == "cosine":
            return (-cosine_similarity(queries, prototypes) + 1) * 5
        raise NotImplementedError

    @staticmethod
    def _episode_targets(n_class: int, n_query: int) -> torch.Tensor:
        """Class index of every query, shaped (n_class, n_query, 1) in class-major order, on ``device``."""
        return torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long().to(device)

    def loss(self, sample, supervised_loss_share: float = 0):
        """Prototypical loss for one episode, optionally mixed with an augmentation-consistency loss.

        :param sample: {
            "xs": [[support_A_1, support_A_2, ...], [support_B_1, ...], ...],
            "xq": [[query_A_1, query_A_2, ...], [query_B_1, ...], ...],
            "x_augment" (optional): [{"sentence": ..., "augmentations": [{"text": ...}, ...]}, ...]
        }
            Support/query items are dicts with a "sentence" key. Every class must have the same number
            of supports and the same number of queries; every augmented sample must have the same,
            non-zero number of augmentations.
        :param supervised_loss_share: weight in [0, 1] of the supervised loss; only used when
            "x_augment" is present (the unsupervised loss gets ``1 - supervised_loss_share``).
        :return: ``(loss, info)``. Without augmentations ``info["metrics"]`` holds "acc" and "loss";
            with augmentations it holds supervised/unsupervised accuracies and losses, the share used
            and the final loss.
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        target_inds = self._episode_targets(n_class, n_query)

        supports = [item["sentence"] for xs_ in xs for item in xs_]
        queries = [item["sentence"] for xq_ in xq for item in xq_]
        x = supports + queries

        has_augmentations = ("x_augment" in sample)
        if has_augmentations:
            augmentations = sample["x_augment"]
            n_augmentations_samples = len(augmentations)
            augmentation_counts = {len(item['augmentations']) for item in augmentations}
            # Augmentation embeddings are reshaped into a (samples, per_sample, dim) grid, so every sample needs
            # the same non-zero number of augmentations (this used to be hard-coded to exactly 5).
            assert len(augmentation_counts) == 1 and min(augmentation_counts) > 0, \
                f"Expected the same non-zero number of augmentations per sample, got {sorted(augmentation_counts)}"
            n_augmentations_per_sample = augmentation_counts.pop()

            augmentations_supports = [aug["text"] for item in augmentations for aug in item["augmentations"]]
            augmentation_queries = [item["sentence"] for item in augmentations]
            x = x + augmentations_supports + augmentation_queries

        # Encode everything in one pass, then dispatch the rows back by position.
        z = self.encoder.embed_sentences(x)
        z_dim = z.size(-1)

        support_end = len(supports)
        query_end = support_end + len(queries)
        z_support = z[:support_end].view(n_class, n_support, z_dim).mean(dim=[1])
        z_query = z[support_end:query_end]

        # Supervised loss: softmax over negative distances to the class prototypes.
        supervised_dists = self._distances(z_query, z_support)
        flat_targets = target_inds.reshape(-1)
        supervised_loss = torch_functional.cross_entropy(-supervised_dists, flat_targets)
        _, y_hat_supervised = (-supervised_dists).max(1)
        acc_val_supervised = torch.eq(y_hat_supervised, flat_targets).float().mean()

        if not has_augmentations:
            return supervised_loss, {
                "metrics": {
                    "acc": acc_val_supervised.item(),
                    "loss": supervised_loss.item(),
                },
                "dists": supervised_dists,
                "target": target_inds
            }

        # Unsupervised loss: each original sentence (query) should be closest to the mean embedding of its own
        # augmentations (prototype), i.e. sample i has target i.
        aug_support_end = query_end + n_augmentations_per_sample * n_augmentations_samples
        z_aug_support = (z[query_end:aug_support_end]
                         .view(n_augmentations_samples, n_augmentations_per_sample, z_dim).mean(dim=[1]))
        z_aug_query = z[aug_support_end:]
        unsupervised_dists = self._distances(z_aug_query, z_aug_support)
        # torch.arange replaces the deprecated, end-inclusive torch.range; the values are identical.
        unsupervised_target_inds = torch.arange(n_augmentations_samples, device=device)
        unsupervised_loss = torch_functional.cross_entropy(-unsupervised_dists, unsupervised_target_inds)
        _, y_hat_unsupervised = (-unsupervised_dists).max(1)
        acc_val_unsupervised = torch.eq(y_hat_unsupervised, unsupervised_target_inds).float().mean()

        # Final loss
        assert 0 <= supervised_loss_share <= 1
        final_loss = supervised_loss_share * supervised_loss + (1 - supervised_loss_share) * unsupervised_loss

        return final_loss, {
            "metrics": {
                "supervised_acc": acc_val_supervised.item(),
                "unsupervised_acc": acc_val_unsupervised.item(),
                "supervised_loss": supervised_loss.item(),
                "unsupervised_loss": unsupervised_loss.item(),
                "supervised_loss_share": supervised_loss_share,
                "final_loss": final_loss.item(),
            },
            "supervised_dists": supervised_dists,
            "unsupervised_dists": unsupervised_dists,
            "target": target_inds
        }

    def loss_softkmeans(self, sample):
        """Prototypical loss with prototypes refined by one soft k-means step over unlabeled sentences.

        :param sample: {"xs": ..., "xq": ... (as in ``loss``), "xu": [{"sentence": ...}, ...]}
        :return: ``(loss, info)`` where ``info["metrics"]`` holds "acc" and "loss".
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query
        xu = sample['xu']  # unlabeled

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        target_inds = self._episode_targets(n_class, n_query)

        x = ([item["sentence"] for xs_ in xs for item in xs_]
             + [item["sentence"] for xq_ in xq for item in xq_]
             + [item["sentence"] for item in xu])
        z = self.encoder.embed_sentences(x)
        z_dim = z.size(-1)

        n_support_total = n_class * n_support
        n_query_total = n_class * n_query
        zs = z[:n_support_total]
        z_proto = zs.view(n_class, n_support, z_dim).mean(1)
        zq = z[n_support_total:n_support_total + n_query_total]
        zu = z[n_support_total + n_query_total:]

        # Soft assignment of every support/unlabeled point to the initial prototypes (always euclidean).
        distances_to_proto = euclidean_dist(torch.cat((zs, zu)), z_proto)
        distances_to_proto_normed = torch_functional.softmax(-distances_to_proto, dim=-1)

        # Refined prototype = weighted mean of the class's supports (weight 1) and all unlabeled points
        # (weighted by their soft assignment to that class).
        unlabeled_weights = distances_to_proto_normed[n_support_total:]
        support_weights = torch.ones(n_support).to(device)
        refined_protos = list()
        for class_ix in range(n_class):
            class_points = torch.cat((zs[class_ix * n_support:(class_ix + 1) * n_support], zu))
            weights = torch.cat((support_weights, unlabeled_weights[:, class_ix]))
            refined_proto = (class_points.t() * weights).sum(1) / weights.sum()
            refined_protos.append(refined_proto.view(1, -1))
        refined_protos = torch.cat(refined_protos)

        dists = self._distances(zq, refined_protos)

        log_p_y = torch_functional.log_softmax(-dists, dim=1).view(n_class, n_query, -1)
        loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(2)
        acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            "metrics": {
                "acc": acc_val.item(),
                "loss": loss_val.item(),
            },
            'dists': dists,
            'target': target_inds
        }

    def loss_consistency(self, sample):
        """Consistency loss: each sentence must be closest to the mean embedding of its own augmentations.

        :param sample: {"x_augment": [(A, [A_1, A_2, ..., A_n]), (B, [B_1, B_2, ..., B_m]), ...]}
            where each entry pairs a sentence with its (non-empty) list of augmented sentences.
        :return: ``(loss, info)`` where ``info["metrics"]`` holds "acc" and "loss" (sample i has target i).
        """
        x_augment = sample["x_augment"]
        n_samples = len(x_augment)

        # Each sample occupies 1 (original) + len(augments) consecutive rows of the embedding matrix.
        lengths = [1 + len(augments) for sentence, augments in x_augment]

        x = list()
        for sentence, augs in x_augment:
            x.append(sentence)
            x += augs

        z = self.encoder.embed_sentences(x)
        assert len(z) == sum(lengths)

        i = 0
        original_embeddings = list()
        augmented_embeddings = list()
        for length in lengths:
            original_embeddings.append(z[i])
            # The augmentations are the next `length - 1` rows; slicing up to `i + length + 1` (as before)
            # leaked the following sample's original sentence into this mean.
            augmented_embeddings.append(z[i + 1:i + length].mean(0))
            i += length

        # The distance helpers operate on 2-D tensors, not on Python lists of rows.
        original_embeddings = torch.stack(original_embeddings)
        augmented_embeddings = torch.stack(augmented_embeddings)
        dists = self._distances(original_embeddings, augmented_embeddings)

        target_inds = torch.arange(n_samples, device=device)
        log_p_y = torch_functional.log_softmax(-dists, dim=1)
        loss_val = torch_functional.nll_loss(log_p_y, target_inds)
        _, y_hat = log_p_y.max(1)
        acc_val = torch.eq(y_hat, target_inds).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item(),
            "metrics": {
                "acc": acc_val.item(),
                "loss": loss_val.item(),
            },
            'dists': dists,
            'target': target_inds
        }

    def train_step(self, optimizer, episode, supervised_loss_share: float, unlabeled: bool = False):
        """Run one optimisation step on ``episode`` and return ``(loss, loss_dict)``.

        Uses ``loss_softkmeans`` when ``unlabeled`` is set, ``loss`` (with ``supervised_loss_share``) otherwise.
        """
        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        if unlabeled:
            loss, loss_dict = self.loss_softkmeans(episode)
        else:
            loss, loss_dict = self.loss(episode, supervised_loss_share=supervised_loss_share)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step(self,
                  data_loader: FewShotDataLoader,
                  n_support: int,
                  n_query: int,
                  n_classes: int,
                  n_unlabeled: int = 0,
                  n_episodes: int = 1000):
        """Evaluate on ``n_episodes`` freshly sampled episodes (no augmentation, no gradients).

        :return: mapping metric name -> mean value over the episodes.
        """
        metrics = collections.defaultdict(list)

        self.eval()
        for _ in range(n_episodes):
            episode = data_loader.create_episode(
                n_support=n_support,
                n_query=n_query,
                n_unlabeled=n_unlabeled,
                n_classes=n_classes,
                n_augment=0
            )

            with torch.no_grad():
                if n_unlabeled:
                    loss, loss_dict = self.loss_softkmeans(episode)
                else:
                    loss, loss_dict = self.loss(episode, supervised_loss_share=0)

            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }

    def train_step_ARSC(self, data_path: str, optimizer, n_unlabeled: int):
        """Sample one ARSC training episode (5 supports / 5 queries) from ``data_path`` and train on it."""
        episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=5, n_unlabeled=n_unlabeled)

        self.train()
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        if n_unlabeled:
            loss, loss_dict = self.loss_softkmeans(episode)
        else:
            loss, loss_dict = self.loss(episode)
        loss.backward()
        optimizer.step()

        return loss, loss_dict

    def test_step_ARSC(self, data_path: str, n_unlabeled=0, n_episodes=1000, set_type="test"):
        """Evaluate on ``n_episodes`` ARSC episodes of split ``set_type`` ("dev" or "test").

        :return: mapping metric name -> mean value over the episodes.
        """
        assert set_type in ("dev", "test")
        metrics = collections.defaultdict(list)
        self.eval()
        for _ in range(n_episodes):
            episode = create_ARSC_test_episode(prefix=data_path, n_query=5, n_unlabeled=n_unlabeled, set_type=set_type)

            with torch.no_grad():
                if n_unlabeled:
                    loss, loss_dict = self.loss_softkmeans(episode)
                else:
                    loss, loss_dict = self.loss(episode)

            for k, v in loss_dict["metrics"].items():
                metrics[k].append(v)

        return {
            key: np.mean(value) for key, value in metrics.items()
        }
353 |
354 |
def _record_metrics(writer: SummaryWriter, records: list, metrics: Dict[str, float], step: int) -> None:
    """Send ``metrics`` to the TensorBoard ``writer`` and append a JSON-friendly entry to ``records``."""
    for key, value in metrics.items():
        writer.add_scalar(tag=key, scalar_value=value, global_step=step)
    records.append({
        "metrics": [
            {
                "tag": key,
                "value": value
            }
            for key, value in metrics.items()
        ],
        "global_step": step
    })


def run_proto(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_query: int,
        n_classes: int,
        unlabeled_path: str = None,
        valid_path: str = None,
        test_path: str = None,
        n_unlabeled: int = 0,
        n_augment: int = 0,
        output_path: str = f'runs/{now()}',
        max_iter: int = 10000,
        evaluate_every: int = 100,
        early_stop: int = None,
        n_test_episodes: int = 1000,
        log_every: int = 10,
        metric: str = "euclidean",
        arsc_format: bool = False,
        data_path: str = None,
        supervised_loss_share_fn: Callable[[int, int], float] = lambda x, y: 1 - (x / y)
):
    """Train a ProtoNet, evaluating periodically on validation/test episodes.

    TensorBoard logs go to ``<output_path>/logs/{train,valid,test}``. All logged metrics are also written to
    ``<output_path>/metrics.json`` when training ends (normally, by early stopping, or by an exception).
    Training stops early once validation accuracy has failed to improve ``early_stop`` evaluations in a row
    (``early_stop`` 0/None disables this).
    In ARSC mode, episodes are built from the ``data_path`` prefix instead of the train/valid/test files.

    Raises:
        FileExistsError: if ``output_path`` already exists and is not empty.
    """
    if output_path:
        if os.path.exists(output_path) and len(os.listdir(output_path)):
            raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok=True: an existing but *empty* output directory passes the check above, so creating it must not
    # fail (a plain os.makedirs raised FileExistsError in that case).
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "logs/train"), exist_ok=True)
    train_writer: SummaryWriter = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer: SummaryWriter = None
    test_writer: SummaryWriter = None
    log_dict = dict(train=list())

    if valid_path:
        os.makedirs(os.path.join(output_path, "logs/valid"), exist_ok=True)
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
    if test_path:
        os.makedirs(os.path.join(output_path, "logs/test"), exist_ok=True)
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()

    # Load model
    bert = BERTEncoder(model_name_or_path).to(device)
    protonet = ProtoNet(encoder=bert, metric=metric)
    optimizer = torch.optim.Adam(protonet.parameters(), lr=2e-5)

    # Load data (in ARSC mode episodes are built on the fly from data_path instead)
    train_data_loader = valid_data_loader = test_data_loader = None
    if not arsc_format:
        train_data_loader = FewShotDataLoader(train_path, unlabeled_file_path=unlabeled_path)
        logger.info(f"train labels: {train_data_loader.data_dict.keys()}")

        if valid_path:
            valid_data_loader = FewShotDataLoader(valid_path)
            logger.info(f"valid labels: {valid_data_loader.data_dict.keys()}")

        if test_path:
            test_data_loader = FewShotDataLoader(test_path)
            logger.info(f"test labels: {test_data_loader.data_dict.keys()}")

    train_metrics = collections.defaultdict(list)
    n_eval_since_last_best = 0
    best_valid_acc = 0.0

    try:
        for step in range(max_iter):
            if not arsc_format:
                episode = train_data_loader.create_episode(
                    n_support=n_support,
                    n_query=n_query,
                    n_classes=n_classes,
                    n_unlabeled=n_unlabeled,
                    n_augment=n_augment
                )
            else:
                # Same call as ProtoNet.train_step_ARSC; the episode builder needs the data prefix.
                episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=5, n_unlabeled=n_unlabeled)

            supervised_loss_share = supervised_loss_share_fn(step, max_iter)
            loss, loss_dict = protonet.train_step(optimizer=optimizer, episode=episode, unlabeled=(n_unlabeled > 0), supervised_loss_share=supervised_loss_share)

            for key, value in loss_dict["metrics"].items():
                train_metrics[key].append(value)

            # Logging: average the metrics accumulated since the last log, then reset the accumulator.
            if (step + 1) % log_every == 0:
                averaged = {key: np.mean(value) for key, value in train_metrics.items()}
                _record_metrics(train_writer, log_dict["train"], averaged, step)
                logger.info(f"train | " + " | ".join([f"{key}:{value:.4f}" for key, value in averaged.items()]))
                train_metrics = collections.defaultdict(list)

            if (valid_path or test_path) and (step + 1) % evaluate_every == 0:
                should_stop = False
                for path, writer, set_type, set_data_loader in zip(
                        [valid_path, test_path],
                        [valid_writer, test_writer],
                        ["valid", "test"],
                        # Each split is evaluated with its own loader (test previously reused the valid loader).
                        [valid_data_loader, test_data_loader]
                ):
                    if not path:
                        continue

                    if not arsc_format:
                        set_results = protonet.test_step(
                            data_loader=set_data_loader,
                            n_unlabeled=n_unlabeled,
                            n_support=n_support,
                            n_query=n_query,
                            n_classes=n_classes,
                            n_episodes=n_test_episodes
                        )
                    else:
                        set_results = protonet.test_step_ARSC(
                            data_path=data_path,
                            n_unlabeled=n_unlabeled,
                            n_episodes=n_test_episodes,
                            set_type={"valid": "dev", "test": "test"}[set_type]
                        )

                    _record_metrics(writer, log_dict[set_type], set_results, step)
                    logger.info(f"{set_type} | " + " | ".join([f"{key}:{np.mean(value):.4f}" for key, value in set_results.items()]))

                    if set_type == "valid":
                        if set_results["acc"] > best_valid_acc:
                            best_valid_acc = set_results["acc"]
                            n_eval_since_last_best = 0
                            logger.info(f"Better eval results!")
                        else:
                            n_eval_since_last_best += 1
                            logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

                        if early_stop and n_eval_since_last_best >= early_stop:
                            should_stop = True

                # Leave the training loop itself (not just the valid/test loop) once patience is exhausted.
                if should_stop:
                    logger.warning(f"Early-stopping.")
                    break
    finally:
        # Flush/close TensorBoard writers and persist whatever metrics were gathered, even on failure.
        for writer in (train_writer, valid_writer, test_writer):
            if writer is not None:
                writer.close()
        with open(os.path.join(output_path, 'metrics.json'), "w") as file:
            json.dump(log_dict, file, ensure_ascii=False)
520 |
521 |
def main():
    """Parse the command line and launch a ProtoNet training run via ``run_proto``.

    Builds the supervised-loss-share schedule ``1 - (step / max_steps) ** power`` and saves the parsed
    arguments to ``<output-path>/config.json`` after training.

    Raises:
        FileNotFoundError: if one of the given train/valid/test/unlabeled data paths does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, required=True, help="Path to training data")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--unlabeled-path", type=str, default=None, help="Path to data containing augmentations used for consistency")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--n-unlabeled", type=int, help="Number of unlabeled data points per class (proto++)", default=0)
    parser.add_argument("--max-iter", type=int, default=10000, help="Max number of training episodes")
    parser.add_argument("--evaluate-every", type=int, default=100, help="Number of training episodes between each evaluation (on both valid, test)")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training episodes between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")
    parser.add_argument("--early-stop", type=int, default=0, help="Number of worse evaluation steps before stopping. 0=disabled")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-query", type=int, default=5, help="Number of query points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-augment", type=int, default=0, help="Number of augmented samples to take")
    parser.add_argument("--n-test-episodes", type=int, default=1000, help="Number of episodes during evaluation (valid, test)")

    # Metric to use in proto distance calculation
    parser.add_argument("--metric", type=str, default="euclidean", help="Metric to use", choices=("euclidean", "cosine"))

    # Supervised loss share
    parser.add_argument("--supervised-loss-share-power", default=1.0, type=float, help="supervised_loss_share = 1 - (x/y) ** power, with x the current step and y --max-iter")

    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")
    args = parser.parse_args()

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist. --unlabeled-path is a data file read by the train loader, so it is checked too;
    # --data-path is passed as a `prefix=` to the ARSC episode builders, so it is deliberately not checked here.
    for arg in [args.train_path, args.valid_path, args.test_path, args.unlabeled_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Create supervised_loss_share_fn: decays from 1 (fully supervised) at step 0 towards 0 at max_steps.
    def get_supervised_loss_share_fn(supervised_loss_share_power: Union[int, float]) -> Callable[[int, int], float]:
        def _supervised_loss_share_fn(current_step: int, max_steps: int) -> float:
            assert current_step <= max_steps
            return 1 - (current_step / max_steps) ** supervised_loss_share_power

        return _supervised_loss_share_fn

    supervised_loss_share_fn = get_supervised_loss_share_fn(args.supervised_loss_share_power)

    # Run
    run_proto(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        output_path=args.output_path,
        unlabeled_path=args.unlabeled_path,

        model_name_or_path=args.model_name_or_path,
        n_unlabeled=args.n_unlabeled,

        n_support=args.n_support,
        n_query=args.n_query,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,
        n_augment=args.n_augment,

        max_iter=args.max_iter,
        evaluate_every=args.evaluate_every,

        metric=args.metric,
        early_stop=args.early_stop,
        arsc_format=args.arsc_format,
        data_path=args.data_path,
        log_every=args.log_every,

        supervised_loss_share_fn=supervised_loss_share_fn
    )

    # Save config (indented like the sibling ProtAugment script's config.json)
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False, indent=1)


if __name__ == '__main__':
    main()
611 |
--------------------------------------------------------------------------------
/models/bert_baseline/baseline.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 | import tqdm
5 |
6 | from models.encoders.bert_encoder import BERTEncoder
7 | from utils.data import get_jsonl_data
8 | from utils.python import now, set_seeds
9 | from utils.few_shot import create_ARSC_train_episode, get_ARSC_test_tasks
10 | import random
11 | import collections
12 | import os
13 | from typing import List, Dict
14 | from tensorboardX import SummaryWriter
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import warnings
19 | import logging
20 | from utils.math import euclidean_dist, cosine_similarity
21 |
# Configure logging once at import time; this module's own logger emits everything down to DEBUG.
logging.basicConfig()
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)

# Suppress every Python warning for the whole process.
warnings.simplefilter(action='ignore')

# Prefer the GPU whenever torch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29 |
30 |
class BaselineNet(nn.Module):
    """Fine-tuning baseline (Baseline / Baseline++) for few-shot text classification.

    ``encoder`` maps a list of sentences to a batch of embeddings. ``train_model``
    trains the encoder jointly with a head over all training classes; at
    evaluation time a fresh head is fitted on each episode's support set (on
    cached embeddings) and scored on its query set.

    Head types:

    * ``is_pp=False``: dropout followed by ``nn.Linear(hidden_dim, n_classes)``.
    * ``is_pp=True`` (Baseline++): a learnable ``(n_classes, hidden_dim)`` matrix;
      logits are ``5 * cosine_similarity`` (``metric="cosine"``) or
      ``-euclidean_dist`` (``metric="euclidean"``) between embeddings and its rows.
    """

    def __init__(
            self,
            encoder,
            is_pp: bool = False,
            hidden_dim: int = 768,
            metric: str = "cosine"
    ):
        super(BaselineNet, self).__init__()
        self.encoder = encoder
        # Only applied in front of the linear head while that head is being fitted.
        self.dropout = nn.Dropout(p=0.25).to(device)
        self.is_pp = is_pp
        self.hidden_dim = hidden_dim
        self.metric = metric
        assert self.metric in ("euclidean", "cosine")

    # ------------------------------------------------------------------
    # Private helpers shared by the training / evaluation loops
    # ------------------------------------------------------------------
    def _logits(self, z, class_matrix=None, classifier=None, apply_dropout=False):
        """Map embeddings ``z`` to class logits using the current head.

        Baseline++ (``is_pp``) compares ``z`` with the rows of ``class_matrix``;
        otherwise ``z`` is fed to ``classifier``, optionally through dropout first.
        """
        if self.is_pp:
            if self.metric == "cosine":
                # Similarities are scaled by a fixed factor of 5 before the softmax.
                return cosine_similarity(z, class_matrix) * 5
            if self.metric == "euclidean":
                return -euclidean_dist(z, class_matrix)
            raise NotImplementedError
        if apply_dropout:
            z = self.dropout(z)
        return classifier(z)

    def _embed_sentences(self, sentences):
        """Encode each distinct sentence on its own, without building autograd graphs.

        Returns a dict mapping sentence -> squeezed numpy embedding, used as a cache
        so that evaluation episodes never re-run the encoder.
        """
        unique_sentences = list(dict.fromkeys(sentences))
        with torch.no_grad():
            return {
                s: self.encoder.forward([s]).cpu().detach().numpy().squeeze()
                for s in tqdm.tqdm(unique_sentences)
            }

    @staticmethod
    def _stack_embeddings(sentences, sentence_to_embedding_dict):
        """Gather cached embeddings for ``sentences`` into one float tensor on ``device``."""
        stacked = np.stack([sentence_to_embedding_dict[s] for s in sentences])
        return torch.from_numpy(stacked).float().to(device)

    def _predict_cached(self, sentences, sentence_to_embedding_dict, class_matrix=None, classifier=None,
                        batch_size=16):
        """Return logits for ``sentences`` from cached embeddings, in batches and without gradients."""
        logits = []
        with torch.no_grad():
            for start in range(0, len(sentences), batch_size):
                z = self._stack_embeddings(sentences[start:start + batch_size], sentence_to_embedding_dict)
                logits.append(self._logits(z, class_matrix, classifier))
        return torch.cat(logits, dim=0)

    def _evaluate_cached(self, sentences_per_label, sentence_to_embedding_dict, loss_fn, class_matrix=None,
                         classifier=None):
        """Return ``(loss, accuracy)`` floats for sentences grouped by label index."""
        sentences = [s for sentence_list in sentences_per_label for s in sentence_list]
        labels = torch.tensor(
            [label for label, sentence_list in enumerate(sentences_per_label) for _ in sentence_list],
            dtype=torch.long, device=device,
        )
        logits = self._predict_cached(sentences, sentence_to_embedding_dict, class_matrix, classifier)
        loss = loss_fn(input=logits, target=labels)
        acc = (logits.argmax(1) == labels).float().mean()
        return loss.item(), acc.item()

    @staticmethod
    def _log_query_examples(summary_writer, logits, query_labels, query_data_list, ix_to_class):
        """Log the class mapping, the least-confident correct and the most-confident wrong query prediction."""
        import uuid

        y_pred = logits.argmax(1).cpu().numpy()
        y_true = query_labels.cpu().numpy()
        scores = logits.cpu().numpy()
        # Numerically stable softmax: subtracting the row max prevents exp() from
        # underflowing to 0/0 = NaN for large negative (e.g. Euclidean) logits.
        scores = np.exp(scores - scores.max(axis=1, keepdims=True))
        probas_pred = scores / scores.sum(axis=1, keepdims=True)
        confidence = probas_pred[np.arange(len(y_pred)), y_pred]

        tag = str(uuid.uuid4())
        summary_writer.add_text(tag=tag, text_string=json.dumps(ix_to_class, ensure_ascii=False), global_step=0)

        def log_example(idx, step):
            item = query_data_list[idx]
            summary_writer.add_text(
                tag=tag,
                text_string=json.dumps({
                    "sentence": item['sentence'],
                    "true_label": item['label'],
                    "predicted_label": ix_to_class[y_pred[idx]],
                    "p": probas_pred[idx].tolist(),
                }, ensure_ascii=False),
                global_step=step)

        where_ok = np.where(y_pred == y_true)[0]
        if len(where_ok):
            # Looking for OK but with less confidence (not too easy)
            log_example(min(where_ok, key=lambda i: confidence[i]), 1)

        where_ko = np.where(y_pred != y_true)[0]
        if len(where_ko):
            # Looking for KO but with most confidence
            log_example(max(where_ko, key=lambda i: confidence[i]), 2)

    # ------------------------------------------------------------------
    # ARSC protocol
    # ------------------------------------------------------------------
    def train_ARSC_one_episode(
            self,
            data_path: str,
            n_iter: int = 100,
    ):
        """Fine-tune the encoder and a fresh head on one sampled ARSC training episode.

        Samples an episode with 5 support sentences per class (no query) and runs
        ``n_iter`` full-batch Adam steps (lr=2e-5) on it, updating both the
        encoder and the episode head.

        Returns:
            dict with the mean support ``"loss"`` and ``"acc"`` over the iterations.
        """
        self.train()
        episode = create_ARSC_train_episode(prefix=data_path, n_support=5, n_query=0, n_unlabeled=0)
        n_episode_classes = len(episode["xs"])
        loss_fn = nn.CrossEntropyLoss()
        episode_matrix = None
        episode_classifier = None
        if self.is_pp:
            # Initialise each class row with the mean embedding of its support sentences.
            with torch.no_grad():
                init_matrix = np.array([
                    [
                        self.encoder.forward([sentence]).squeeze().cpu().detach().numpy()
                        for sentence in episode["xs"][c]
                    ]
                    for c in range(n_episode_classes)
                ]).mean(1)

            episode_matrix = torch.Tensor(init_matrix).to(device)
            episode_matrix.requires_grad = True
            optimizer = torch.optim.Adam(list(self.parameters()) + [episode_matrix], lr=2e-5)
        else:
            episode_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_episode_classes).to(device)
            optimizer = torch.optim.Adam(list(self.parameters()) + list(episode_classifier.parameters()), lr=2e-5)

        # The whole support set is the batch; it does not change across iterations.
        sentences = [sentence for sentence_list in episode["xs"] for sentence in sentence_list]
        labels = torch.tensor(
            [ix for ix, sentence_list in enumerate(episode["xs"]) for _ in sentence_list],
            dtype=torch.long, device=device,
        )

        # Train on support
        losses = list()
        accuracies = list()
        iter_bar = tqdm.tqdm(range(n_iter))
        for _ in iter_bar:
            optimizer.zero_grad()
            z = self._logits(self.encoder(sentences), episode_matrix, episode_classifier, apply_dropout=True)
            loss = loss_fn(input=z, target=labels)
            acc = (z.argmax(1) == labels).float().mean()
            loss.backward()
            optimizer.step()
            iter_bar.set_description(f"{loss.item():.3f} | {acc.item():.3f}")
            losses.append(loss.item())
            accuracies.append(acc.item())
        return {
            "loss": np.mean(losses),
            "acc": np.mean(accuracies)
        }

    def run_ARSC(
            self,
            data_path: str,
            train_summary_writer: SummaryWriter = None,
            valid_summary_writer: SummaryWriter = None,
            test_summary_writer: SummaryWriter = None,
            n_episodes: int = 1000,
            n_train_iter: int = 100,
            train_eval_every: int = 100,
            n_test_iter: int = 1000,
            test_eval_every: int = 100,
    ):
        """Alternate ARSC training episodes with periodic evaluation on the ARSC test tasks.

        Evaluation runs every ``train_eval_every`` episodes, or only after the last
        episode when ``train_eval_every`` is falsy.

        Returns:
            One dict per episode with key ``"train"`` (output of
            ``train_ARSC_one_episode``) and, on evaluation episodes, ``"test"``
            (output of ``test_model_ARSC``).
        """
        metrics = list()
        for episode_ix in range(n_episodes):
            output = self.train_ARSC_one_episode(data_path=data_path, n_iter=n_train_iter)
            episode_metrics = {
                "train": output
            }

            if train_summary_writer:
                train_summary_writer.add_scalar(tag='loss', global_step=episode_ix, scalar_value=output["loss"])
                train_summary_writer.add_scalar(tag='acc', global_step=episode_ix, scalar_value=output["acc"])

            # Running evaluation
            n_done = episode_ix + 1
            if (train_eval_every and n_done % train_eval_every == 0) or (not train_eval_every and n_done == n_episodes):
                episode_metrics["test"] = self.test_model_ARSC(
                    data_path=data_path,
                    valid_summary_writer=valid_summary_writer,
                    test_summary_writer=test_summary_writer,
                    n_iter=n_test_iter,
                    eval_every=test_eval_every
                )

            metrics.append(episode_metrics)
        return metrics

    def test_model_ARSC(
            self,
            data_path: str,
            n_iter: int = 1000,
            valid_summary_writer: SummaryWriter = None,
            test_summary_writer: SummaryWriter = None,
            eval_every: int = 100
    ):
        """Fit a fresh head on each ARSC test task's support set and track valid/test metrics.

        The encoder is not updated: every sentence is embedded once up front and
        only the head is optimised (Adam, lr=2e-5) for ``n_iter`` full-batch steps.
        Every ``eval_every`` steps (or at the last step when ``eval_every`` is
        falsy) the head is scored on the task's ``x_valid`` and ``x_test`` sets.
        ``valid_summary_writer`` / ``test_summary_writer`` are currently unused.

        Returns:
            One list per task of ``{"test": {...}, "valid": {...}, "step": int}``
            dicts, where the inner dicts hold ``"loss"`` and ``"acc"``.
        """
        self.eval()

        tasks = get_ARSC_test_tasks(prefix=data_path)
        logger.info("Embedding sentences...")
        sentence_to_embedding_dict = self._embed_sentences(
            s
            for task in tasks
            for sentences_lists in task['xs'] + task['x_test'] + task['x_valid']
            for s in sentences_lists
        )

        loss_fn = nn.CrossEntropyLoss()
        metrics = list()
        for task in tasks:
            task_metrics = list()

            n_episode_classes = len(task["xs"])
            episode_matrix = None
            episode_classifier = None
            if self.is_pp:
                # Initialise each class row with the mean embedding of its support sentences.
                init_matrix = np.array([
                    [sentence_to_embedding_dict[sentence] for sentence in class_sentences]
                    for class_sentences in task["xs"]
                ]).mean(1)

                episode_matrix = torch.Tensor(init_matrix).to(device)
                episode_matrix.requires_grad = True
                optimizer = torch.optim.Adam([episode_matrix], lr=2e-5)
            else:
                episode_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_episode_classes).to(device)
                optimizer = torch.optim.Adam(list(episode_classifier.parameters()), lr=2e-5)

            # The support batch is fixed for the whole task; build it once.
            support_sentences = [sentence for sentence_list in task["xs"] for sentence in sentence_list]
            support_labels = torch.tensor(
                [ix for ix, sentence_list in enumerate(task["xs"]) for _ in sentence_list],
                dtype=torch.long, device=device,
            )
            support_embeddings = self._stack_embeddings(support_sentences, sentence_to_embedding_dict)

            # Train on support
            iter_bar = tqdm.tqdm(range(n_iter))
            for iteration in iter_bar:
                optimizer.zero_grad()
                z = self._logits(support_embeddings, episode_matrix, episode_classifier, apply_dropout=True)
                loss = loss_fn(input=z, target=support_labels)
                acc = (z.argmax(1) == support_labels).float().mean()
                loss.backward()
                optimizer.step()
                iter_bar.set_description(f"{loss.item():.3f} | {acc.item():.3f}")

                step = iteration + 1
                if (eval_every and step % eval_every == 0) or (not eval_every and step == n_iter):
                    self.eval()
                    if episode_classifier is not None:
                        episode_classifier.eval()

                    valid_loss, valid_acc = self._evaluate_cached(
                        task["x_valid"], sentence_to_embedding_dict, loss_fn, episode_matrix, episode_classifier)
                    test_loss, test_acc = self._evaluate_cached(
                        task["x_test"], sentence_to_embedding_dict, loss_fn, episode_matrix, episode_classifier)

                    task_metrics.append({
                        "test": {
                            "loss": test_loss,
                            "acc": test_acc
                        },
                        "valid": {
                            "loss": valid_loss,
                            "acc": valid_acc
                        },
                        "step": step
                    })
            metrics.append(task_metrics)
        return metrics

    # ------------------------------------------------------------------
    # Standard (label -> sentences) protocol
    # ------------------------------------------------------------------
    def train_model(
            self,
            data_dict: Dict[str, List[str]],
            summary_writer: SummaryWriter = None,
            n_epoch: int = 400,
            batch_size: int = 16,
            log_every: int = 10):
        """Train the encoder jointly with a head over every class in ``data_dict``.

        ``data_dict`` maps a label to its sentences. Training runs ``n_epoch``
        epochs of shuffled mini-batches (Adam, lr=2e-5); every ``log_every``
        steps the mean loss/accuracy since the previous log are written to
        ``summary_writer`` (when given) and the running lists are reset.
        """
        self.train()

        training_classes = sorted(set(data_dict.keys()))
        n_training_classes = len(training_classes)
        class_to_ix = {c: ix for ix, c in enumerate(training_classes)}
        training_data_list = [{"sentence": sentence, "label": label} for label, sentences in data_dict.items() for sentence in sentences]

        training_matrix = None
        training_classifier = None

        if self.is_pp:
            training_matrix = torch.randn(n_training_classes, self.hidden_dim, requires_grad=True, device=device)
            optimizer = torch.optim.Adam(list(self.parameters()) + [training_matrix], lr=2e-5)
        else:
            training_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_training_classes).to(device)
            optimizer = torch.optim.Adam(list(self.parameters()) + list(training_classifier.parameters()), lr=2e-5)

        n_samples = len(training_data_list)
        loss_fn = nn.CrossEntropyLoss()
        global_step = 0

        # Metrics accumulated since the last log
        training_losses = list()
        training_accuracies = list()

        for _ in tqdm.tqdm(range(n_epoch)):
            random.shuffle(training_data_list)
            for ix in tqdm.tqdm(range(0, n_samples, batch_size)):
                optimizer.zero_grad()
                torch.cuda.empty_cache()

                batch_items = training_data_list[ix:ix + batch_size]
                batch_sentences = [d['sentence'] for d in batch_items]
                batch_labels = torch.tensor([class_to_ix[d['label']] for d in batch_items], dtype=torch.long, device=device)
                z = self._logits(self.encoder(batch_sentences), training_matrix, training_classifier, apply_dropout=True)
                loss = loss_fn(input=z, target=batch_labels)
                acc = (z.argmax(1) == batch_labels).float().mean()
                loss.backward()
                optimizer.step()

                global_step += 1
                training_losses.append(loss.item())
                training_accuracies.append(acc.item())
                if (global_step % log_every) == 0:
                    if summary_writer:
                        summary_writer.add_scalar(tag="loss", global_step=global_step, scalar_value=np.mean(training_losses))
                        summary_writer.add_scalar(tag="acc", global_step=global_step, scalar_value=np.mean(training_accuracies))
                    # Empty metrics
                    training_losses = list()
                    training_accuracies = list()

    def test_one_episode(
            self,
            support_data_dict: Dict[str, List[str]],
            query_data_dict: Dict[str, List[str]],
            sentence_to_embedding_dict: Dict,
            batch_size: int = 4,
            n_iter: int = 1000,
            summary_writer: SummaryWriter = None,
            summary_tag_prefix: str = None,
    ):
        """Fit a fresh head on one episode's support set and score it on the query set.

        Only the head is trained (Adam, lr=1e-3), for ``n_iter`` steps of
        ``batch_size`` support sentences, cycling through the support set.
        Embeddings come from ``sentence_to_embedding_dict``; the encoder is not
        called. When ``summary_writer`` is given, per-step support loss/accuracy
        and a few example query predictions are logged to it.

        Returns:
            ``{"loss": float, "acc": float}`` measured on the query set.
        """
        # Check data integrity
        assert set(support_data_dict.keys()) == set(query_data_dict.keys())

        # Freeze encoder
        self.encoder.eval()
        # Dropout in front of the linear head is meant to be active while fitting it.
        # Set it explicitly, because a previous episode leaves the model in eval mode.
        self.dropout.train()

        episode_classes = sorted(set(support_data_dict.keys()))
        n_episode_classes = len(episode_classes)
        class_to_ix = {c: ix for ix, c in enumerate(episode_classes)}
        ix_to_class = dict(enumerate(episode_classes))
        support_data_list = [{"sentence": sentence, "label": label} for label, sentences in support_data_dict.items() for sentence in sentences]
        # Repeat the support set so that it yields exactly n_iter batches of batch_size.
        n_items = batch_size * n_iter
        support_data_list = (support_data_list * n_items)[:n_items]

        loss_fn = nn.CrossEntropyLoss()
        episode_matrix = None
        episode_classifier = None
        if self.is_pp:
            # Initialise each class row with the mean embedding of its support sentences.
            init_matrix = np.array([
                [sentence_to_embedding_dict[sentence].ravel() for sentence in support_data_dict[c]]
                for c in episode_classes
            ]).mean(1)

            episode_matrix = torch.Tensor(init_matrix).to(device)
            episode_matrix.requires_grad = True
            optimizer = torch.optim.Adam([episode_matrix], lr=1e-3)
        else:
            episode_classifier = nn.Linear(in_features=self.hidden_dim, out_features=n_episode_classes).to(device)
            optimizer = torch.optim.Adam(list(episode_classifier.parameters()), lr=1e-3)

        # Train on support
        iter_bar = tqdm.tqdm(range(n_iter))
        for iteration in iter_bar:
            optimizer.zero_grad()

            batch = support_data_list[iteration * batch_size:(iteration + 1) * batch_size]
            batch_embeddings = self._stack_embeddings([d['sentence'] for d in batch], sentence_to_embedding_dict)
            batch_labels = torch.tensor([class_to_ix[d['label']] for d in batch], dtype=torch.long, device=device)
            z = self._logits(batch_embeddings, episode_matrix, episode_classifier, apply_dropout=True)

            loss = loss_fn(input=z, target=batch_labels)
            acc = (z.argmax(1) == batch_labels).float().mean()
            loss.backward()
            optimizer.step()
            iter_bar.set_description(f"{loss.item():.3f} | {acc.item():.3f}")

            if summary_writer:
                summary_writer.add_scalar(tag=f'{summary_tag_prefix}_loss', global_step=iteration, scalar_value=loss.item())
                summary_writer.add_scalar(tag=f'{summary_tag_prefix}_acc', global_step=iteration, scalar_value=acc.item())

        # Predict on query
        self.eval()
        if episode_classifier is not None:
            episode_classifier.eval()

        query_data_list = [{"sentence": sentence, "label": label} for label, sentences in query_data_dict.items() for sentence in sentences]
        query_labels = torch.tensor([class_to_ix[d['label']] for d in query_data_list], dtype=torch.long, device=device)
        logits = self._predict_cached(
            [d['sentence'] for d in query_data_list], sentence_to_embedding_dict, episode_matrix, episode_classifier)
        y_hat = logits.argmax(1)

        # Example logging is optional: summary_writer defaults to None.
        if summary_writer:
            self._log_query_examples(summary_writer, logits, query_labels, query_data_list, ix_to_class)

        loss = loss_fn(input=logits, target=query_labels)
        acc = (y_hat == query_labels).float().mean()

        return {
            "loss": loss.item(),
            "acc": acc.item()
        }

    def test_model(
            self,
            data_dict: Dict[str, List[str]],
            n_support: int,
            n_classes: int,
            n_episodes=600,
            summary_writer: SummaryWriter = None,
            n_test_iter: int = 100,
            test_batch_size: int = 4
    ):
        """Evaluate on ``n_episodes`` random ``n_classes``-way, ``n_support``-shot episodes.

        Every sentence in ``data_dict`` (label -> sentences) is embedded once up
        front. In each episode the first ``n_support`` sentences of each sampled
        class (after shuffling) form the support set, and the rest form the query set.
        Note: the sentence lists inside ``data_dict`` are shuffled in place.

        Returns:
            List of per-episode ``{"loss", "acc"}`` dicts from ``test_one_episode``.
        """
        test_metrics = list()

        # Freeze encoder
        self.encoder.eval()
        logger.info("Embedding sentences...")
        sentence_to_embedding_dict = self._embed_sentences(
            s for sentences in data_dict.values() for s in sentences)

        for episode in tqdm.tqdm(range(n_episodes)):
            episode_classes = np.random.choice(list(data_dict.keys()), size=n_classes, replace=False)
            episode_query_data_dict = dict()
            episode_support_data_dict = dict()

            for episode_class in episode_classes:
                random.shuffle(data_dict[episode_class])
                episode_support_data_dict[episode_class] = data_dict[episode_class][:n_support]
                episode_query_data_dict[episode_class] = data_dict[episode_class][n_support:]

            episode_metrics = self.test_one_episode(
                support_data_dict=episode_support_data_dict,
                query_data_dict=episode_query_data_dict,
                n_iter=n_test_iter,
                batch_size=test_batch_size,
                sentence_to_embedding_dict=sentence_to_embedding_dict,
                summary_writer=summary_writer
            )
            logger.info(f"Episode metrics: {episode_metrics}")
            test_metrics.append(episode_metrics)
            if summary_writer:
                for metric_name, metric_value in episode_metrics.items():
                    summary_writer.add_scalar(tag=metric_name, global_step=episode, scalar_value=metric_value)

        return test_metrics
592 |
593 |
def run_baseline(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_classes: int,
        valid_path: str = None,
        test_path: str = None,
        output_path: str = f'runs/{now()}',
        n_test_episodes: int = 600,
        log_every: int = 10,
        n_train_epoch: int = 400,
        train_batch_size: int = 16,
        is_pp: bool = False,
        test_batch_size: int = 4,
        n_test_iter: int = 100,
        metric: str = "cosine",
        arsc_format: bool = False,
        data_path: str = None
):
    """Train and evaluate the (Baseline / Baseline++) fine-tuning model.

    Standard mode: load train/valid/test data with ``get_jsonl_data`` (items
    with ``"sentence"`` and ``"label"`` keys) and train with
    ``BaselineNet.train_model``. Then run ``n_test_episodes``
    ``n_classes``-way / ``n_support``-shot episodes on the valid and test sets,
    and write ``validation_metrics.json`` / ``test_metrics.json`` into
    ``output_path``.

    ARSC mode (``arsc_format``): run ``BaselineNet.run_ARSC`` on ``data_path``
    and write ``baseline_metrics.json``.

    TensorBoard logs go to ``output_path/logs/{train,valid,test}``.

    Raises:
        ValueError: if ``output_path`` is empty.
        FileExistsError: if ``output_path`` exists and is not empty.
    """
    if not output_path:
        raise ValueError("An output path is required.")
    if os.path.exists(output_path) and len(os.listdir(output_path)):
        raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    # exist_ok: an existing *empty* output directory is accepted by the check above.
    os.makedirs(output_path, exist_ok=True)

    def make_writer(split):
        log_dir = os.path.join(output_path, f"logs/{split}")
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(logdir=log_dir, flush_secs=1, max_queue=1)

    train_writer: SummaryWriter = make_writer("train")
    valid_writer: SummaryWriter = make_writer("valid") if valid_path else None
    test_writer: SummaryWriter = make_writer("test") if test_path else None

    def raw_data_to_labels_dict(data, shuffle=True):
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item['label']].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for val in labels_dict.values():
                random.shuffle(val)
        return labels_dict

    try:
        # Load model
        bert = BERTEncoder(model_name_or_path).to(device)
        baseline_net = BaselineNet(encoder=bert, is_pp=is_pp, metric=metric).to(device)

        if arsc_format:
            # The ARSC schedule below is fixed here, not taken from the function arguments.
            metrics = baseline_net.run_ARSC(
                train_summary_writer=train_writer,
                valid_summary_writer=valid_writer,
                test_summary_writer=test_writer,
                n_episodes=1000,
                train_eval_every=50,
                n_train_iter=50,
                n_test_iter=200,
                test_eval_every=25,
                data_path=data_path
            )
            with open(os.path.join(output_path, 'baseline_metrics.json'), "w") as file:
                json.dump(metrics, file, ensure_ascii=False)
            return

        # Load data (all splits are loaded and shuffled before training starts)
        train_data_dict = raw_data_to_labels_dict(get_jsonl_data(train_path), shuffle=True)
        logger.info(f"train labels: {train_data_dict.keys()}")

        valid_data_dict = None
        if valid_path:
            valid_data_dict = raw_data_to_labels_dict(get_jsonl_data(valid_path), shuffle=True)
            logger.info(f"valid labels: {valid_data_dict.keys()}")

        test_data_dict = None
        if test_path:
            test_data_dict = raw_data_to_labels_dict(get_jsonl_data(test_path), shuffle=True)
            logger.info(f"test labels: {test_data_dict.keys()}")

        baseline_net.train_model(
            data_dict=train_data_dict,
            summary_writer=train_writer,
            n_epoch=n_train_epoch,
            batch_size=train_batch_size,
            log_every=log_every
        )

        def evaluate(data_dict, writer, file_name):
            # Validation and test use the same episode settings.
            split_metrics = baseline_net.test_model(
                data_dict=data_dict,
                n_support=n_support,
                n_classes=n_classes,
                n_episodes=n_test_episodes,
                summary_writer=writer,
                n_test_iter=n_test_iter,
                test_batch_size=test_batch_size
            )
            with open(os.path.join(output_path, file_name), "w") as file:
                json.dump(split_metrics, file, ensure_ascii=False)

        # Validation
        if valid_path:
            evaluate(valid_data_dict, valid_writer, 'validation_metrics.json')
        # Test
        if test_path:
            evaluate(test_data_dict, test_writer, 'test_metrics.json')
    finally:
        # Flush and release the TensorBoard event files.
        for writer in (train_writer, valid_writer, test_writer):
            if writer is not None:
                writer.close()
728 |
729 |
def main():
    """Command-line entry point: parse arguments, seed RNGs, run the baseline, then save the config."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-path", type=str, default=None, help="Path to training data (required unless --arsc-format)")
    parser.add_argument("--valid-path", type=str, default=None, help="Path to validation data")
    parser.add_argument("--test-path", type=str, default=None, help="Path to testing data")
    parser.add_argument("--data-path", type=str, default=None, help="Path to data (ARSC only)")

    parser.add_argument("--output-path", type=str, default=f'runs/{now()}')
    parser.add_argument("--model-name-or-path", type=str, required=True, help="Transformer model to use")
    parser.add_argument("--log-every", type=int, default=10, help="Number of training steps (batches) between each logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed to set")

    # Few-Shot related stuff
    parser.add_argument("--n-support", type=int, default=5, help="Number of support points for each class")
    parser.add_argument("--n-classes", type=int, default=5, help="Number of classes per episode")
    parser.add_argument("--n-test-episodes", type=int, default=600, help="Number of episodes during evaluation (valid, test)")
    parser.add_argument("--n-train-epoch", type=int, default=400, help="Number of epochs during training")
    parser.add_argument("--train-batch-size", type=int, default=16, help="Batch size used during training")
    parser.add_argument("--n-test-iter", type=int, default=100, help="Number of training iterations during testing episodes")
    parser.add_argument("--test-batch-size", type=int, default=4, help="Batch size used during testing episodes")

    # Baseline++
    parser.add_argument("--pp", default=False, action="store_true", help="Boolean to use the ++ baseline model")
    parser.add_argument("--metric", default="cosine", type=str, help="Which metric to use in baseline++", choices=("euclidean", "cosine"))
    # ARSC data
    parser.add_argument("--arsc-format", default=False, action="store_true", help="Using ARSC few-shot format")

    args = parser.parse_args()
    logger.info(f"Received args:{args}")

    # ARSC mode reads its data from --data-path; only the standard mode needs --train-path.
    if not args.arsc_format and not args.train_path:
        parser.error("--train-path is required unless --arsc-format is set")

    # Set random seed
    set_seeds(args.seed)

    # Check if data path(s) exist
    for arg in [args.train_path, args.valid_path, args.test_path]:
        if arg and not os.path.exists(arg):
            raise FileNotFoundError(f"Data @ {arg} not found.")

    # Run
    run_baseline(
        train_path=args.train_path,
        valid_path=args.valid_path,
        test_path=args.test_path,
        data_path=args.data_path,
        output_path=args.output_path,

        model_name_or_path=args.model_name_or_path,

        n_support=args.n_support,
        n_classes=args.n_classes,
        n_test_episodes=args.n_test_episodes,
        log_every=args.log_every,
        n_train_epoch=args.n_train_epoch,
        train_batch_size=args.train_batch_size,
        is_pp=args.pp,

        test_batch_size=args.test_batch_size,
        n_test_iter=args.n_test_iter,
        metric=args.metric,
        arsc_format=args.arsc_format
    )

    # Save config (after the run: run_baseline refuses a non-empty output directory)
    with open(os.path.join(args.output_path, "config.json"), "w") as file:
        json.dump(vars(args), file, ensure_ascii=False, indent=1)
795 |
796 |
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run the baseline.
    main()
799 |
--------------------------------------------------------------------------------