├── mabe ├── __init__.py ├── config.py ├── types.py ├── util.py ├── features.py ├── data.py ├── training.py ├── loss.py └── model.py ├── .isort.cfg ├── .flake8 ├── .pylintrc ├── pyproject.toml ├── setup.py ├── .pre-commit-config.yaml ├── LICENSE ├── hydrogen ├── merge_results.py ├── getposepca.py ├── create_splits.py ├── features.py ├── merge_results_alltasks.py ├── eda.py └── task3_onlylinear.py ├── scripts └── train.py └── README.md /mabe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length=100 3 | profile="black" 4 | forced_separate=mabe 5 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503 3 | max-line-length = 100 4 | select = B,C,E,F,W,T4,B9 5 | -------------------------------------------------------------------------------- /mabe/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | API_KEY = None 5 | ROOT_PATH = pathlib.Path(os.getenv("MABE_ROOT_PATH", "/srv/data/benwild/mabe")) 6 | -------------------------------------------------------------------------------- /mabe/types.py: -------------------------------------------------------------------------------- 1 | # dummy variables for flake8 2 | # see: https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md 3 | 4 | batch = None 5 | channels = None 6 | time = None 7 | behavior = None 8 | annotator = None 9 | classes = None 10 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [TYPECHECK] 2 | # List of members which are set dynamically and missed by pylint inference 3 | # system, and so shouldn't trigger E1101 when accessed. Python regular 4 | # expressions are accepted. 
5 | generated-members=numpy.*,torch.* 6 | 7 | [MESSAGES CONTROL] 8 | disable=unused-argument 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | target-version = ['py36', 'py37', 'py38'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | \.eggs 8 | | \.git 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | buck-out 15 | | build 16 | | dist 17 | )/ 18 | ''' 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | except ImportError: 4 | from distutils.core import setup 5 | 6 | 7 | setup( 8 | name="mabe", 9 | version="0.1", 10 | description="", 11 | author="Benjamin Wild", 12 | author_email="b.w@fu-berlin.de", 13 | packages=["mabe"], 14 | install_requires=[ 15 | "numpy", 16 | "scipy", 17 | "pandas", 18 | "sklearn", 19 | "torch", 20 | "numba", 21 | "joblib", 22 | "h5py", 23 | "matplotlib", 24 | "torchtyping", 25 | "fastprogress", 26 | "madgrad", 27 | "bayesian-optimization", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.4.0 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-ast 7 | - id: check-case-conflict 8 | - id: check-executables-have-shebangs 9 | - id: check-yaml 10 | - id: end-of-file-fixer 11 | - id: trailing-whitespace 12 | - id: requirements-txt-fixer 13 | - repo: https://github.com/asottile/pyupgrade 14 | rev: v2.9.0 15 | hooks: 16 | - id: pyupgrade 17 | args: [--py38-plus] 18 | - repo: https://github.com/pycqa/isort 19 | rev: 5.7.0 20 | hooks: 21 | - id: isort 22 | args: ["--profile", "black"] 23 | - repo: https://github.com/psf/black 24 | rev: 20.8b1 25 | hooks: 26 | - id: black 27 | language_version: python3 28 | - repo: https://github.com/kynan/nbstripout 29 | rev: 0.3.9 30 | hooks: 31 | - id: nbstripout 32 | - repo: https://github.com/pre-commit/mirrors-mypy 33 | rev: 'v0.812' 34 | hooks: 35 | - id: mypy 36 | args: ["--ignore-missing-imports"] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Benjamin Wild 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /hydrogen/merge_results.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import scipy 4 | import scipy.special 5 | import torch 6 | from fastprogress.fastprogress import force_console_behavior 7 | 8 | import mabe 9 | import mabe.config 10 | import mabe.loss 11 | import mabe.model 12 | import mabe.training 13 | 14 | # %% 15 | master_bar, progress_bar = force_console_behavior() 16 | 17 | # %% 18 | results_path = ( 19 | mabe.config.ROOT_PATH / "training_results_2021-04-10 19:36:30.018611_baseline_madgrad_0.834.pt" 20 | ) 21 | results = torch.load(results_path) 22 | 23 | # %% 24 | keys = results[0].test_logits.keys() 25 | 26 | merged_submission = {} 27 | for key in keys: 28 | preds = np.stack(list(map(lambda s: s.test_logits[key], results))) 29 | preds = np.argmax(scipy.special.softmax(preds, -1).mean(axis=0), axis=-1) 30 | 31 | merged_submission[key] = preds 32 | 33 | 34 | plt.plot(preds) 35 | 36 | 37 | # %% 38 | sample_submission = np.load( 39 | mabe.config.ROOT_PATH / "sample-submission.npy", allow_pickle=True 40 | ).item() 41 | 42 | 43 | def validate_submission(submission, sample_submission): 44 | if not isinstance(submission, dict): 45 | print("Submission should be dict") 46 | return False 47 | 48 | if not submission.keys() == sample_submission.keys(): 49 | print("Submission keys don't match") 50 | return False 51 | 52 | for key in submission: 53 | sv = submission[key] 54 | ssv = sample_submission[key] 55 | if not len(sv) == len(ssv): 56 | print(f"Submission lengths of {key} doesn't match") 57 | return False 58 | 59 | for key, sv in submission.items(): 60 | if not all(isinstance(x, (np.int32, np.int64, int)) for x in list(sv)): 61 | print(f"Submission of {key} is not all integers") 62 | return False 63 | 64 | print("All tests passed") 65 | return True 66 | 67 | 68 | if validate_submission(merged_submission, sample_submission): 69 | np.save(mabe.config.ROOT_PATH / "submission5.npy", merged_submission) 70 | -------------------------------------------------------------------------------- /hydrogen/getposepca.py: -------------------------------------------------------------------------------- 1 | import fastprogress 2 | import joblib 3 | import numpy as np 4 | import scipy.spatial 5 | import scipy.spatial.distance 6 | import scipy.special 7 | import sklearn 8 | import sklearn.decomposition 9 | 10 | import mabe 11 | import mabe.config 12 | import mabe.features 13 | 14 | # %% codecell 15 | master_bar, progress_bar = fastprogress.fastprogress.force_console_behavior() 16 | 17 | train_path = mabe.config.ROOT_PATH / "train.npy" 18 | train_task2_path = mabe.config.ROOT_PATH / "train_task2.npy" 19 | train_task3_path = mabe.config.ROOT_PATH / "train_task3.npy" 20 | test_path = mabe.config.ROOT_PATH / "test-release.npy" 21 | pca_path = mabe.config.ROOT_PATH / "pose-pca.joblib" 22 | 23 | # %% 24 | test = np.load(mabe.config.ROOT_PATH / "test-release.npy", allow_pickle=True).item() 25 | 26 | # %% codecell 27 | X, X_extra, Y, groups, annotators = mabe.features.load_dataset(train_path, raw_trajectories=True) 28 | 29 | # %% 30 | Xt2, Xt2_extra, Yt2, 
groupst2, annotatorst2 = mabe.features.load_dataset( 31 | train_task2_path, raw_trajectories=True 32 | ) 33 | 34 | X += Xt2 35 | X_extra += Xt2_extra 36 | Y += Yt2 37 | groups += groupst2 38 | annotators += annotatorst2 39 | clf_tasks = [0] * len(X) 40 | 41 | for behavior, data in mabe.features.load_task3_datasets(train_task3_path, raw_trajectories=True): 42 | clf_task = int(behavior[-1]) + 1 43 | print(behavior, clf_task) 44 | 45 | X += data[0] 46 | X_extra += data[1] 47 | Y += data[2] 48 | groups += data[3] 49 | annotators += data[4] 50 | clf_tasks += [clf_task] * len(data[0]) 51 | 52 | # %% codecell 53 | X_test, X_extra_test, Y_test, groups_test, _ = mabe.features.load_dataset( 54 | test_path, raw_trajectories=True 55 | ) 56 | 57 | # %% 58 | X += X_test 59 | Y += Y_test 60 | groups += groups_test 61 | 62 | # %% 63 | X_pdists = [np.stack([scipy.spatial.distance.pdist(s) for s in x]) for x in X] 64 | X_pdists = np.concatenate(X_pdists) 65 | X_pdists = np.log1p(X_pdists) 66 | 67 | # %% 68 | pca = sklearn.decomposition.PCA(0.95) 69 | pca.fit(X_pdists) 70 | 71 | np.set_printoptions(suppress=True) 72 | np.cumsum(pca.explained_variance_ratio_) 73 | 74 | joblib.dump(pca, mabe.config.ROOT_PATH / "pose_pca.joblib") 75 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import datetime 4 | from typing import Union 5 | 6 | import click 7 | import click_pathlib 8 | import numpy as np 9 | import torch 10 | 11 | import mabe 12 | import mabe.config 13 | import mabe.data 14 | import mabe.loss 15 | import mabe.model 16 | import mabe.ringbuffer 17 | import mabe.training 18 | import mabe.util 19 | 20 | 21 | def parse_optional_parameter(val: Union[str, int, float]) -> Union[str, int, float]: 22 | try: 23 | val = int(val) 24 | except ValueError: 25 | try: 26 | val = float(val) 27 | except ValueError: 28 | pass 29 | return val 30 | 31 | 32 | @click.command( 33 | context_settings=dict( 34 | ignore_unknown_options=True, 35 | allow_extra_args=True, 36 | ) 37 | ) 38 | @click.option("--model_name", type=str, required=True) 39 | @click.option("--device", type=str, required=True) 40 | @click.option("--device", type=str, required=True) 41 | @click.option( 42 | "--feature_path", 43 | type=click_pathlib.Path(exists=True), 44 | default=mabe.config.ROOT_PATH / "features.hdf5", 45 | ) 46 | @click.option("--from_split", type=int, default=0) 47 | @click.option("--to_split", type=int, default=10) 48 | @click.pass_context 49 | def train(click_context, model_name, device, feature_path, from_split, to_split): 50 | config_kwargs = {} 51 | for arg in click_context.args: 52 | key, val = arg.split("--")[1].split("=") 53 | config_kwargs[key] = parse_optional_parameter(val) 54 | 55 | data = mabe.data.DataWrapper(feature_path) 56 | results = [] 57 | for split_idx in range(from_split, to_split): 58 | config = mabe.training.TrainingConfig( 59 | split_idx=split_idx, feature_path=feature_path, **config_kwargs 60 | ) 61 | split = mabe.data.CVSplit(split_idx, data) 62 | trainer = mabe.training.Trainer(config, data, split, device) 63 | result = trainer.train_model() 64 | print(f"\nBest validation F1 (split {split_idx}): {result.best_val_f1[0]:.3f}") 65 | 66 | results.append(result) 67 | 68 | f1s = [r.best_val_f1[0] for r in results] 69 | print(f"Validation F1: {np.mean(f1s):.3f} ± {np.std(f1s):.3f}") 70 | 71 | torch.save( 72 | results, 73 | mabe.config.ROOT_PATH 74 | / 
f"training_results_{datetime.datetime.now()}_{model_name}_{np.mean(f1s):.3f}.pt", 75 | _use_new_zipfile_serialization=False, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | train() 81 | -------------------------------------------------------------------------------- /hydrogen/create_splits.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import h5py 4 | import numpy as np 5 | import sklearn 6 | import sklearn.decomposition 7 | import sklearn.linear_model 8 | import sklearn.preprocessing 9 | from fastprogress.fastprogress import force_console_behavior 10 | 11 | import mabe 12 | import mabe.config 13 | import mabe.loss 14 | import mabe.model 15 | import mabe.ringbuffer 16 | 17 | # %% 18 | num_splits = 32 19 | 20 | # %% 21 | master_bar, progress_bar = force_console_behavior() 22 | 23 | # %% 24 | feature_path = mabe.config.ROOT_PATH / "features_task123_final_pca.hdf5" 25 | 26 | # %% 27 | with h5py.File(feature_path, "r") as hdf: 28 | 29 | def load_all(groupname): 30 | return list(map(lambda v: v[:].astype(np.float32), hdf[groupname].values())) 31 | 32 | X_labeled = load_all("train/x") 33 | Y_labeled = load_all("train/y") 34 | 35 | annotators_labeled = np.array(list(map(lambda v: v[()], hdf["train/annotators"].values()))) 36 | num_annotators = len(np.unique(annotators_labeled)) 37 | 38 | clf_tasks_labeled = np.array(list(map(lambda v: v[()], hdf["train/clf_tasks"].values()))) 39 | num_clf_tasks = len(np.unique(clf_tasks_labeled)) 40 | 41 | X_unlabeled = load_all("test/x") 42 | Y_unlabeled = load_all("test/y") 43 | groups_unlabeled = list(map(lambda v: v[()], hdf["test/groups"].values())) 44 | 45 | # %% 46 | X = X_labeled + X_unlabeled 47 | Y = Y_labeled + Y_unlabeled 48 | 49 | scaler = sklearn.preprocessing.StandardScaler().fit(np.concatenate(X)) 50 | X = list(map(lambda x: scaler.transform(x), X)) 51 | X_labeled = list(map(lambda x: scaler.transform(x), X_labeled)) 52 | X_unlabeled = list(map(lambda x: scaler.transform(x), X_unlabeled)) 53 | 54 | # %% 55 | sample_lengths = np.array(list(map(len, X))) 56 | p_draw = sample_lengths / np.sum(sample_lengths) 57 | 58 | len(X), len(X_labeled), min(sample_lengths), max(sample_lengths) 59 | 60 | # %% 61 | for i in range(0, num_splits): 62 | indices_labeled = np.arange(len(X_labeled)) 63 | indices_unlabeled = len(X_labeled) + np.arange(len(X_unlabeled)) 64 | indices = np.arange(len(X)) 65 | 66 | # sample until the train split has at least one sample from each annotator 67 | valid = False 68 | while not valid: 69 | train_indices_labeled = np.random.choice( 70 | indices_labeled, int(0.85 * len(X_labeled)), replace=False 71 | ) 72 | val_indices_labeled = np.array( 73 | [i for i in indices_labeled if i not in train_indices_labeled] 74 | ) 75 | 76 | valid = len(np.unique(annotators_labeled[train_indices_labeled])) == num_annotators 77 | valid &= len(np.unique(clf_tasks_labeled[train_indices_labeled])) == num_clf_tasks 78 | valid &= len(np.unique(clf_tasks_labeled[val_indices_labeled])) >= ( 79 | num_clf_tasks - 1 80 | ) # one task with only one trajectory 81 | 82 | train_indices_unlabeled = np.random.choice( 83 | indices_unlabeled, int(0.85 * len(X_unlabeled)), replace=False 84 | ) 85 | train_indices = np.concatenate((train_indices_labeled, train_indices_unlabeled)) 86 | val_indices_unlabeled = np.array( 87 | [i for i in indices_unlabeled if i not in train_indices_unlabeled] 88 | ) 89 | val_indices = np.concatenate((val_indices_labeled, val_indices_unlabeled)) 90 | 91 | split = dict( 92 | 
indices_labeled=indices_labeled, 93 | indices_unlabeled=indices_unlabeled, 94 | indices=indices, 95 | train_indices_labeled=train_indices_labeled, 96 | train_indices_unlabeled=train_indices_unlabeled, 97 | train_indices=train_indices, 98 | val_indices_labeled=val_indices_labeled, 99 | val_indices_unlabeled=val_indices_unlabeled, 100 | val_indices=val_indices, 101 | ) 102 | 103 | pickle.dump(split, open(mabe.config.ROOT_PATH / f"split_{i}.pkl", "wb")) 104 | -------------------------------------------------------------------------------- /mabe/util.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import io 3 | 4 | import numpy as np 5 | import torch 6 | from fastprogress.fastprogress import force_console_behavior 7 | 8 | master_bar, progress_bar = force_console_behavior() 9 | 10 | 11 | def predict_test_data( 12 | cpc, logreg, data, device, config, params, fixed_params=False, task12=True, task3=True 13 | ): 14 | def load_task_params(task): 15 | if config.use_best_task0: 16 | task = 0 17 | 18 | if fixed_params: 19 | cpc_state_dict, logreg_state_dict = params 20 | else: 21 | cpc_state_dict, logreg_state_dict = params[task] 22 | cpc.load_state_dict(cpc_state_dict) 23 | logreg.load_state_dict(logreg_state_dict) 24 | cpc_ = cpc.eval() 25 | logreg_ = logreg.eval() 26 | return cpc_, logreg_ 27 | 28 | crop_pre, crop_post = cpc.get_crops(device) 29 | 30 | def add_padding(seq): 31 | return np.concatenate( 32 | ( 33 | np.zeros_like(seq)[:crop_pre], 34 | seq, 35 | np.zeros_like(seq)[:crop_post], 36 | ) 37 | ) 38 | 39 | cpc, logreg = load_task_params(0) 40 | test_predictions = collections.defaultdict(dict) 41 | test_logits = collections.defaultdict(dict) 42 | if task12: 43 | with torch.no_grad(): 44 | bar = progress_bar(range(len(data.X_unlabeled))) 45 | for idx in bar: 46 | x = add_padding(data.X_unlabeled[idx].astype(np.float32)) 47 | x_extra = None 48 | if config.use_extra_features: 49 | x_extra = data.X_unlabeled_extra[idx].astype(np.float32) 50 | x_extra = torch.from_numpy(x_extra).to(device, non_blocking=True) 51 | 52 | g = data.groups_unlabeled[idx] 53 | 54 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to( 55 | device, non_blocking=True 56 | ) 57 | x_emb = cpc.embedder(x) 58 | 59 | c = cpc.apply_contexter(x_emb, device) 60 | 61 | logreg_features = c[0].T 62 | 63 | for annotator in range(data.num_annotators): 64 | a = np.array([annotator]).repeat(len(logreg_features)) 65 | task = 0 66 | l = logreg(logreg_features, x_extra, a, task) 67 | 68 | l = torch.cat((l[:1], l), dim=0) # crop from feature preprocessing 69 | p = torch.argmax(l, dim=-1) 70 | 71 | assert len(p) == len(data.X_unlabeled[idx]) + 1 72 | 73 | with io.BytesIO() as buffer: 74 | np.savez_compressed(buffer, p.cpu().data.numpy()) 75 | test_predictions[g.decode("utf-8")][annotator] = buffer.getvalue() 76 | with io.BytesIO() as buffer: 77 | np.savez_compressed(buffer, l.cpu().data.numpy().astype(np.float32)) 78 | test_logits[g.decode("utf-8")][annotator] = buffer.getvalue() 79 | 80 | task3_test_logits = collections.defaultdict(dict) 81 | if task3: 82 | annotator = 0 # only first annotator for task 3 83 | for task in range(1, max(data.clf_tasks) + 1): 84 | cpc, logreg = load_task_params(task) 85 | 86 | with torch.no_grad(): 87 | # bar = progress_bar(range(len(data.X_unlabeled))) 88 | bar = range(len(data.X_unlabeled)) 89 | for idx in bar: 90 | x = add_padding(data.X_unlabeled[idx].astype(np.float32)) 91 | x_extra = None 92 | if config.use_extra_features: 93 | x_extra 
= data.X_unlabeled_extra[idx].astype(np.float32) 94 | x_extra = torch.from_numpy(x_extra).to(device, non_blocking=True) 95 | 96 | g = data.groups_unlabeled[idx] 97 | 98 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to( 99 | device, non_blocking=True 100 | ) 101 | x_emb = cpc.embedder(x) 102 | 103 | c = cpc.apply_contexter(x_emb, device) 104 | 105 | logreg_features = c[0].T 106 | a = np.array([annotator]).repeat(len(logreg_features)) 107 | l = logreg(logreg_features, x_extra, a, task) 108 | 109 | l = torch.cat((l[:1], l), dim=0) # crop from feature preprocessing 110 | p = torch.argmax(l, dim=-1) 111 | 112 | assert len(p) == len(data.X_unlabeled[idx]) + 1 113 | 114 | with io.BytesIO() as buffer: 115 | np.savez_compressed(buffer, l.cpu().data.numpy().astype(np.float32)) 116 | task3_test_logits[g.decode("utf-8")][task] = buffer.getvalue() 117 | 118 | return test_predictions, test_logits, task3_test_logits 119 | -------------------------------------------------------------------------------- /hydrogen/features.py: -------------------------------------------------------------------------------- 1 | # %% codecell 2 | import io 3 | 4 | import h5py 5 | import numpy as np 6 | import scipy.spatial 7 | import scipy.spatial.distance 8 | import scipy.special 9 | import torch 10 | from fastprogress.fastprogress import force_console_behavior 11 | 12 | import mabe 13 | import mabe.config 14 | import mabe.features 15 | 16 | # %% codecell 17 | master_bar, progress_bar = force_console_behavior() 18 | 19 | train_path = mabe.config.ROOT_PATH / "train.npy" 20 | train_task2_path = mabe.config.ROOT_PATH / "train_task2.npy" 21 | train_task3_path = mabe.config.ROOT_PATH / "train_task3.npy" 22 | test_path = mabe.config.ROOT_PATH / "test-release.npy" 23 | feature_path = mabe.config.ROOT_PATH / "features_task123_final_pca.hdf5" 24 | 25 | # %% 26 | test = np.load(mabe.config.ROOT_PATH / "test-release.npy", allow_pickle=True).item() 27 | 28 | # %% codecell 29 | X, X_extra, Y, groups, annotators = mabe.features.load_dataset(train_path) 30 | clf_tasks = [0] * len(X) 31 | 32 | Xt2, Xt2_extra, Yt2, groupst2, annotatorst2 = mabe.features.load_dataset(train_task2_path) 33 | 34 | X += Xt2 35 | X_extra += Xt2_extra 36 | Y += Yt2 37 | groups += groupst2 38 | annotators += annotatorst2 39 | clf_tasks = [0] * len(X) 40 | 41 | # %% 42 | for behavior, data in mabe.features.load_task3_datasets(train_task3_path): 43 | clf_task = int(behavior[-1]) + 1 44 | 45 | X += data[0] 46 | X_extra += data[1] 47 | Y += data[2] 48 | groups += data[3] 49 | annotators += data[4] 50 | clf_tasks += [clf_task] * len(data[0]) 51 | 52 | # %% codecell 53 | X_test, X_extra_test, Y_test, groups_test, _ = mabe.features.load_dataset(test_path) 54 | 55 | # %% 56 | results_files = ( 57 | "training_results_2021-05-03 07:12:21.568593_teacher_ensemble2_0to5_0.844.pt", 58 | "training_results_2021-05-03 07:30:45.111496_teacher_ensemble2_10to15_0.829.pt", 59 | "training_results_2021-05-03 19:30:01.286213_teacher_ensemble2_10to15_0.827.pt", 60 | "training_results_2021-05-03 19:50:54.641450_teacher_ensemble2_15to20_0.814.pt", 61 | "training_results_2021-05-03 07:20:44.078873_teacher_ensemble2_20to25_0.832.pt", 62 | "training_results_2021-05-03 19:50:16.149305_teacher_ensemble2_25to30_0.807.pt", 63 | ) 64 | 65 | # %% 66 | def get_annotator_logits(results, annotator_id="keep"): 67 | if annotator_id == "keep": 68 | return results.test_logits 69 | else: 70 | assert isinstance(annotator_id, int) 71 | test_logits = {} 72 | for key, value in 
results.test_logits.items(): 73 | test_logits[key] = {annotator_id: value[annotator_id]} 74 | return test_logits 75 | 76 | 77 | results = [] 78 | for filename in progress_bar(results_files): 79 | results += list(map(get_annotator_logits, torch.load(mabe.config.ROOT_PATH / filename))) 80 | 81 | len(results) 82 | 83 | # %% 84 | all_annotators = list({s["annotator_id"] for s in test["sequences"].values()}) 85 | # all_annotators = [0] 86 | 87 | # %% 88 | Y_test_dark_annotators = [] 89 | for key, sequence in progress_bar(test["sequences"].items()): 90 | all_probs = [] 91 | for annotator_id in all_annotators: 92 | preds = np.stack( 93 | list( 94 | map( 95 | lambda s: np.load(io.BytesIO(s[key][annotator_id]))["arr_0"], 96 | results, 97 | ) 98 | ) 99 | ) 100 | all_probs.append(scipy.special.softmax(preds, -1).mean(axis=0)) 101 | assert np.all(np.isfinite(all_probs[-1])) 102 | 103 | all_probs = np.stack(all_probs).transpose(1, 0, 2) 104 | Y_test_dark_annotators.append(all_probs) 105 | 106 | # %% 107 | sample_submission = np.load( 108 | mabe.config.ROOT_PATH / "sample-submission-task3.npy", allow_pickle=True 109 | ).item() 110 | 111 | # %% 112 | results = [] 113 | for filename in progress_bar(results_files): 114 | results += list( 115 | map(lambda r: r.task3_test_logits, torch.load(mabe.config.ROOT_PATH / filename)) 116 | ) 117 | 118 | len(results) 119 | 120 | # %% 121 | Y_test_dark_behaviors = [] 122 | for key, sequence in progress_bar(test["sequences"].items()): 123 | all_probs = [] 124 | for behavior_key in sorted(sample_submission.keys()): 125 | behavior_idx = int(behavior_key[-1]) 126 | 127 | preds = np.stack( 128 | list( 129 | map( 130 | lambda s: np.load(io.BytesIO(s[key][behavior_idx + 1]))["arr_0"], 131 | results, 132 | ) 133 | ) 134 | ) 135 | all_probs.append(scipy.special.softmax(preds, -1).mean(axis=0)) 136 | assert np.all(np.isfinite(all_probs[-1])) 137 | 138 | all_probs = np.stack(all_probs).transpose(1, 0, 2) 139 | Y_test_dark_behaviors.append(all_probs) 140 | 141 | # %% codecell 142 | with h5py.File(feature_path, "w") as hdf: 143 | 144 | def store(groupname, values): 145 | grp = hdf.create_group(groupname) 146 | for idx, v in enumerate(values): 147 | # grp.create_dataset(f"{idx:04d}", data=v, compression='lzf') 148 | grp[f"{idx:04d}"] = v 149 | 150 | store("train/x", X) 151 | store("train/x_extra", X_extra) 152 | store("train/y", Y) 153 | store("train/groups", groups) 154 | store("train/annotators", annotators) 155 | store("train/clf_tasks", clf_tasks) 156 | 157 | store("test/x", X_test) 158 | store("test/x_extra", X_extra_test) 159 | store("test/y", Y_test) 160 | store("test/y_dark_annotators", Y_test_dark_annotators) 161 | store("test/y_dark_behaviors", Y_test_dark_behaviors) 162 | store("test/groups", groups_test) 163 | -------------------------------------------------------------------------------- /mabe/features.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numba 3 | import numpy as np 4 | import scipy 5 | import scipy.spatial 6 | from fastprogress.fastprogress import force_console_behavior 7 | 8 | import mabe 9 | import mabe.config 10 | 11 | master_bar, progress_bar = force_console_behavior() 12 | 13 | 14 | pose_pca = joblib.load(mabe.config.ROOT_PATH / "pose_pca.joblib") 15 | 16 | 17 | @numba.njit 18 | def get_mouse_orientation_angle(mouse): 19 | tail_coords = mouse[:, :, -1] 20 | neck_coords = mouse[:, :, 3] 21 | 22 | orientation_vector = neck_coords - tail_coords 23 | body_angle = 
np.arctan2(orientation_vector[:, 1], orientation_vector[:, 0]) 24 | 25 | return tail_coords, neck_coords, orientation_vector, body_angle 26 | 27 | 28 | @numba.njit 29 | def rotate_broadcast(vector, angle): 30 | vector = vector.astype(np.float32) 31 | for t_idx in range(vector.shape[0]): 32 | c, s = np.cos(-angle[t_idx]), np.sin(-angle[t_idx]) 33 | R = np.array([[c, -s], [s, c]], dtype=np.float32) 34 | vector[t_idx] = np.dot(R, vector[t_idx]) 35 | return vector 36 | 37 | 38 | def normalize_mouse(mouse): 39 | tail_coords, neck_coords, orientation_vector, body_angle = get_mouse_orientation_angle(mouse) 40 | 41 | mid_coords = (tail_coords + neck_coords) / 2.0 42 | mouse -= mid_coords[:, :, None] 43 | 44 | mouse = rotate_broadcast(mouse, body_angle) 45 | 46 | return mouse, mid_coords 47 | 48 | 49 | def get_distance_angle_to(xy0, xy1): 50 | vector = xy1 - xy0 51 | distance = np.linalg.norm(vector, axis=1)[:, None] 52 | vector = vector / distance 53 | return vector, distance 54 | 55 | 56 | def get_distance_angle_between_mice(m0, m1): 57 | neck0 = m0[:, :, 3] 58 | neck1 = m1[:, :, 3] 59 | butt1 = m1[:, :, -1] 60 | 61 | tail_coords, neck_coords, orientation_vector, body_angle = get_mouse_orientation_angle(m0) 62 | 63 | v0, d0 = get_distance_angle_to(neck0, neck1) 64 | v0 = rotate_broadcast(v0, body_angle) 65 | v1, d1 = get_distance_angle_to(neck0, butt1) 66 | v1 = rotate_broadcast(v1, body_angle) 67 | 68 | return np.concatenate((v0, d0, v1, d1), axis=1) 69 | 70 | 71 | def get_movement_velocity_orienation(mouse): 72 | tail_coords, neck_coords, orientation_vector, body_angle = get_mouse_orientation_angle(mouse) 73 | 74 | mean_mouse = np.mean(mouse, axis=2) 75 | mean_motion = np.diff(mean_mouse, axis=0) 76 | 77 | velocity = np.einsum("bc,bc->b", orientation_vector[1:], mean_motion)[:, None] 78 | 79 | angle_diff = np.diff(body_angle, axis=0) 80 | orientation_change = np.stack((np.cos(angle_diff), np.sin(angle_diff)), axis=1) 81 | 82 | return velocity, orientation_change 83 | 84 | 85 | def wall_dist(mid): 86 | xmax = 1020 87 | ymax = 570 88 | 89 | wall_dist = np.linalg.norm( 90 | np.stack( 91 | ( 92 | np.stack((mid[:, 0], xmax - mid[:, 0])).min(axis=0), 93 | np.stack((mid[:, 1], ymax - mid[:, 1])).min(axis=0), 94 | ) 95 | ), 96 | axis=0, 97 | ) 98 | 99 | return np.tanh(wall_dist / 465) 100 | 101 | 102 | def get_pdists(trajectory): 103 | t = trajectory.transpose(0, 1, 3, 2) 104 | t = t.reshape(t.shape[0], -1, t.shape[-1])[1:] 105 | 106 | X_pdists = np.stack([scipy.spatial.distance.pdist(s) for s in t]) 107 | X_pdists = np.log1p(X_pdists) 108 | 109 | return pose_pca.transform(X_pdists) 110 | 111 | 112 | def transform_to_feature_vector( 113 | trajectory, with_abs_pos=False, with_wall_dist=True, with_pdists=True, raw_trajectories=False 114 | ): 115 | if raw_trajectories: 116 | t = trajectory.transpose(0, 1, 3, 2) 117 | t = t.reshape(t.shape[0], -1, t.shape[-1])[1:] 118 | 119 | return t, t 120 | 121 | m0 = trajectory[:, 0, :, :].copy() 122 | m1 = trajectory[:, 1, :, :].copy() 123 | velocity, orientation = get_movement_velocity_orienation(m0) 124 | relative_position_info = get_distance_angle_between_mice(m0, m1) 125 | m0, mid0 = normalize_mouse(m0) 126 | m1, mid1 = normalize_mouse(m1) 127 | 128 | if with_abs_pos: 129 | m0 = np.concatenate((m0, mid0[:, :, None]), axis=-1) 130 | m1 = np.concatenate((m1, mid1[:, :, None]), axis=-1) 131 | 132 | m0 = m0.reshape(-1, m0.shape[1] * m0.shape[2]) 133 | m1 = m1.reshape(-1, m1.shape[1] * m1.shape[2]) 134 | 135 | features = np.concatenate( 136 | (m0[1:], m1[1:], velocity, 
orientation, relative_position_info[1:]), axis=1 137 | ) 138 | 139 | if with_pdists: 140 | pdists = get_pdists(trajectory) 141 | features = np.concatenate((features, pdists), axis=1) 142 | 143 | indices = np.arange(0, velocity.shape[0]) 144 | is_beginning = np.tanh(indices / 1500).reshape(-1, 1) 145 | is_ending = np.tanh((velocity.shape[0] - indices) / 2300).reshape(-1, 1) 146 | 147 | extra_features = np.concatenate( 148 | (is_beginning, is_ending, wall_dist(mid0)[1:, None], wall_dist(mid1)[1:, None]), axis=1 149 | ) 150 | 151 | return features, extra_features 152 | 153 | 154 | def get_features_and_labels(sample_sequence, raw_trajectories=False): 155 | features, extra_features = transform_to_feature_vector( 156 | sample_sequence["keypoints"], raw_trajectories=raw_trajectories 157 | ) 158 | if "annotations" in sample_sequence: 159 | labels = sample_sequence["annotations"] 160 | labels = labels[1:] 161 | annotator = sample_sequence["annotator_id"] 162 | else: 163 | labels = np.array([-1] * features.shape[0]) 164 | annotator = None 165 | 166 | return features, extra_features, labels, annotator 167 | 168 | 169 | def load_dataset(path=None, raw_data=None, raw_trajectories=False): 170 | assert path is not None or raw_data is not None 171 | if raw_data is None: 172 | raw_data = np.load(path, allow_pickle=True).item() 173 | 174 | if "vocabulary" in raw_data: 175 | label_vocabulary = raw_data["vocabulary"] 176 | print(label_vocabulary) 177 | raw_data = raw_data["sequences"] 178 | 179 | X = [] 180 | X_extra = [] 181 | Y = [] 182 | groups = [] 183 | annotators = [] 184 | 185 | for key, data in raw_data.items(): 186 | x, x_extra, y, annotator = get_features_and_labels(data, raw_trajectories=raw_trajectories) 187 | X.append(x) 188 | X_extra.append(x_extra) 189 | Y.append(y) 190 | groups.append(key) 191 | annotators.append(annotator) 192 | 193 | return X, X_extra, Y, groups, annotators 194 | 195 | 196 | def load_task3_datasets(path, raw_trajectories=False): 197 | raw_data = np.load(path, allow_pickle=True).item() 198 | 199 | for behavior_key in raw_data.keys(): 200 | raw_data_behavior = raw_data[behavior_key] 201 | 202 | yield behavior_key, load_dataset( 203 | raw_data=raw_data_behavior, raw_trajectories=raw_trajectories 204 | ) 205 | -------------------------------------------------------------------------------- /hydrogen/merge_results_alltasks.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import io 3 | 4 | import numpy as np 5 | import scipy 6 | import scipy.special 7 | import torch 8 | from fastprogress.fastprogress import force_console_behavior 9 | 10 | import mabe 11 | import mabe.config 12 | import mabe.loss 13 | import mabe.model 14 | import mabe.training 15 | import mabe.util 16 | 17 | # %% 18 | master_bar, progress_bar = force_console_behavior() 19 | 20 | # %% 21 | results_files = ( 22 | "training_results_2021-05-06 03:44:12.966718_final_ensemble5_0to2_0.886.pt", 23 | "training_results_2021-05-06 09:21:07.001650_final_ensemble5_2to4_0.816.pt", 24 | "training_results_2021-05-06 15:31:40.573751_final_ensemble5_4to6_0.844.pt", 25 | "training_results_2021-05-06 21:23:28.094687_final_ensemble5_6to8_0.747.pt", 26 | "training_results_2021-05-06 03:42:25.017044_final_ensemble5_8to10_0.892.pt", 27 | "training_results_2021-05-06 09:07:51.314619_final_ensemble5_10to12_0.829.pt", 28 | "training_results_2021-05-06 09:41:19.682021_final_ensemble5_10to12_0.845.pt", 29 | "training_results_2021-05-06 
15:12:38.257900_final_ensemble5_12to14_0.846.pt", 30 | "training_results_2021-05-06 20:45:42.821152_final_ensemble5_14to16_0.843.pt", 31 | "training_results_2021-05-06 03:27:10.472722_final_ensemble5_16to18_0.780.pt", 32 | "training_results_2021-05-06 16:55:01.388127_final_ensemble5_20to22_0.776.pt", 33 | "training_results_2021-05-07 01:04:17.931439_final_ensemble5_22to24_0.861.pt", 34 | "training_results_2021-05-06 03:36:34.616866_final_ensemble5_24to26_0.874.pt", 35 | "training_results_2021-05-06 10:22:00.194306_final_ensemble5_26to28_0.838.pt", 36 | "training_results_2021-05-06 17:00:43.363940_final_ensemble5_28to30_0.764.pt", 37 | "training_results_2021-05-06 23:41:20.680606_final_ensemble5_30to32_0.855.pt", 38 | ) 39 | 40 | # %% 41 | test = np.load(mabe.config.ROOT_PATH / "test-release.npy", allow_pickle=True).item() 42 | 43 | # %% 44 | def get_task3_scores(result): 45 | mean_task3_score = np.mean( 46 | list(map(lambda v: np.max(np.array(v)), list(result.clf_val_f1s.values())[1:4])) 47 | + list(map(lambda v: np.max(np.array(v)), list(result.clf_val_f1s.values())[5:])) 48 | ) 49 | return [ 50 | (max(v) if v[0] is not None else mean_task3_score) for k, v in result.clf_val_f1s.items() 51 | ] 52 | 53 | 54 | # %% 55 | results = [] 56 | scores_list = [] 57 | for filename in progress_bar(results_files): 58 | result_batch = torch.load(mabe.config.ROOT_PATH / filename) 59 | results += list(map(lambda r: r.task3_test_logits, result_batch)) 60 | scores_list += [np.stack(list(map(get_task3_scores, result_batch)))] 61 | 62 | scores = np.concatenate(scores_list) 63 | scores = scores.T 64 | 65 | np.percentile(scores, [10, 50, 90], axis=0).mean(axis=1) 66 | 67 | # %% 68 | sample_submission = np.load( 69 | mabe.config.ROOT_PATH / "sample-submission-task3.npy", allow_pickle=True 70 | ).item() 71 | 72 | # %% 73 | submission: dict = collections.defaultdict(dict) 74 | for behavior_key in progress_bar(sample_submission.keys()): 75 | behavior_idx = int(behavior_key[-1]) 76 | 77 | for key, sequence in test["sequences"].items(): 78 | preds = np.stack( 79 | list( 80 | map( 81 | lambda s: np.load(io.BytesIO(s[key][behavior_idx + 1]))["arr_0"], 82 | results, 83 | ) 84 | ) 85 | ) 86 | 87 | """ 88 | preds = np.argmax( 89 | np.average(scipy.special.softmax(preds, -1), weights=scores[behavior_idx + 1], axis=0), 90 | axis=-1, 91 | ) 92 | """ 93 | preds = np.argmax(np.mean(scipy.special.softmax(preds, -1), axis=0), axis=-1) 94 | submission[behavior_key][key] = preds 95 | 96 | # %% 97 | def validate_submission_task3(submission, sample_submission): 98 | if not isinstance(submission, dict): 99 | print("Submission should be dict") 100 | return False 101 | 102 | if not submission.keys() == sample_submission.keys(): 103 | print("Submission keys don't match") 104 | return False 105 | for behavior in submission: 106 | sb = submission[behavior] 107 | ssb = sample_submission[behavior] 108 | if not isinstance(sb, dict): 109 | print("Submission should be dict") 110 | return False 111 | 112 | if not sb.keys() == ssb.keys(): 113 | print("Submission keys don't match") 114 | return False 115 | 116 | for key in sb: 117 | sv = sb[key] 118 | ssv = ssb[key] 119 | if not len(sv) == len(ssv): 120 | print(f"Submission lengths of {key} doesn't match") 121 | return False 122 | 123 | for key, sv in sb.items(): 124 | if not all(isinstance(x, (np.int32, np.int64, int)) for x in list(sv)): 125 | print(f"Submission of {key} is not all integers") 126 | return False 127 | 128 | print("All tests passed") 129 | return True 130 | 131 | 132 | # %% 133 | 
if validate_submission_task3(submission, sample_submission): 134 | np.save(mabe.config.ROOT_PATH / "task3_submission.npy", submission) 135 | 136 | # %% 137 | sample_submission = np.load( 138 | mabe.config.ROOT_PATH / "sample-submission.npy", allow_pickle=True 139 | ).item() 140 | 141 | # %% 142 | merged_submission = {} 143 | for key in progress_bar(sample_submission.keys()): 144 | annotator_id = 0 145 | preds = np.stack( 146 | list(map(lambda s: np.load(io.BytesIO(s[key][annotator_id]))["arr_0"], results)) 147 | ) 148 | preds = np.argmax(np.average(scipy.special.softmax(preds, -1), weights=scores, axis=0), axis=-1) 149 | 150 | merged_submission[key] = preds 151 | 152 | # %% 153 | def validate_submission_task1(submission, sample_submission): 154 | if not isinstance(submission, dict): 155 | print("Submission should be dict") 156 | return False 157 | 158 | if not submission.keys() == sample_submission.keys(): 159 | print("Submission keys don't match") 160 | return False 161 | 162 | for key in submission: 163 | sv = submission[key] 164 | ssv = sample_submission[key] 165 | if not len(sv) == len(ssv): 166 | print(f"Submission lengths of {key} doesn't match") 167 | return False 168 | 169 | for key, sv in submission.items(): 170 | if not all(isinstance(x, (np.int32, np.int64, int)) for x in list(sv)): 171 | print(f"Submission of {key} is not all integers") 172 | return False 173 | 174 | print("All tests passed") 175 | return True 176 | 177 | 178 | # %% 179 | if validate_submission_task1(merged_submission, sample_submission): 180 | np.save(mabe.config.ROOT_PATH / "task1_submission.npy", merged_submission) 181 | 182 | 183 | # %% 184 | merged_submission = {} 185 | for key, sequence in test["sequences"].items(): 186 | annotator_id = sequence["annotator_id"] 187 | preds = np.stack( 188 | list(map(lambda s: np.load(io.BytesIO(s.test_logits[key][annotator_id]))["arr_0"], results)) 189 | ) 190 | preds = np.argmax(scipy.special.softmax(preds, -1).mean(axis=0), axis=-1) 191 | 192 | merged_submission[key] = preds 193 | 194 | # %% 195 | sample_submission = np.load( 196 | mabe.config.ROOT_PATH / "sample-submission-task2.npy", allow_pickle=True 197 | ).item() 198 | 199 | 200 | def validate_submission_task2(submission, sample_submission): 201 | if not isinstance(submission, dict): 202 | print("Submission should be dict") 203 | return False 204 | 205 | if not submission.keys() == sample_submission.keys(): 206 | print("Submission keys don't match") 207 | return False 208 | 209 | for key in submission: 210 | sv = submission[key] 211 | ssv = sample_submission[key] 212 | if not len(sv) == len(ssv): 213 | print(f"Submission lengths of {key} doesn't match") 214 | return False 215 | 216 | for key, sv in submission.items(): 217 | if not all(isinstance(x, (np.int32, np.int64, int)) for x in list(sv)): 218 | print(f"Submission of {key} is not all integers") 219 | return False 220 | 221 | print("All tests passed") 222 | return True 223 | 224 | 225 | # %% 226 | if validate_submission_task2(merged_submission, sample_submission): 227 | np.save(mabe.config.ROOT_PATH / "task2_submission.npy", merged_submission) 228 | -------------------------------------------------------------------------------- /hydrogen/eda.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | import sklearn 5 | import sklearn.linear_model 6 | import sklearn.metrics 7 | import sklearn.model_selection 8 | from bayes_opt import 
BayesianOptimization, UtilityFunction 9 | from matplotlib import gridspec 10 | 11 | import mabe 12 | import mabe.config 13 | 14 | # %% 15 | train = np.load(mabe.config.ROOT_PATH / "train.npy", allow_pickle=True).item() 16 | test = np.load(mabe.config.ROOT_PATH / "test-release.npy", allow_pickle=True).item() 17 | sample_submission = np.load( 18 | mabe.config.ROOT_PATH / "sample-submission.npy", allow_pickle=True 19 | ).item() 20 | 21 | # %% 22 | sample_key = list(train["sequences"].keys())[4] 23 | sample = train["sequences"][sample_key] 24 | 25 | # Dimensions: (# frames) x (mouse ID) x (x, y coordinate) x (body part). 26 | # Units: pixels; coordinates are relative to the entire image. 27 | # Original image dimensions are 1024 x 570. 28 | 29 | train["vocabulary"] 30 | 31 | # %% 32 | sample.keys() 33 | 34 | # %% 35 | sample["keypoints"].shape 36 | 37 | # %% 38 | plt.scatter(*sample["keypoints"][-1000][0], c="blue") 39 | plt.scatter(*sample["keypoints"][-1000][1], c="red") 40 | 41 | # %% 42 | groups = train["sequences"].keys() 43 | annotation_df = pd.DataFrame( 44 | np.concatenate( 45 | list( 46 | map( 47 | lambda s: np.stack( 48 | ( 49 | np.arange(len(s["annotations"])), 50 | np.arange(len(s["annotations"]))[::-1], 51 | s["annotations"], 52 | ) 53 | ), 54 | train["sequences"].values(), 55 | ) 56 | ), 57 | axis=1, 58 | ).T, 59 | columns=["timestep", "timestep_reverse", "annotation"], 60 | ) 61 | 62 | # %% 63 | annotation_df.groupby("timestep").apply(lambda ts: (ts.annotation == 0).mean()).plot() 64 | 65 | # %% 66 | annotation_df.groupby("timestep_reverse").apply(lambda ts: (ts.annotation == 0).mean()).plot() 67 | 68 | # %% 69 | annotation_df.groupby("timestep").apply(lambda ts: (ts.annotation == 1).mean()).plot() 70 | 71 | # %% 72 | annotation_df.groupby("timestep_reverse").apply(lambda ts: (ts.annotation == 1).mean()).plot() 73 | 74 | # %% 75 | annotation_df.groupby("timestep").apply(lambda ts: (ts.annotation == 2).mean()).plot() 76 | 77 | # %% 78 | annotation_df.groupby("timestep").apply(lambda ts: (ts.annotation == 3).mean()).plot() 79 | 80 | # %% 81 | keypoints = list(map(lambda s: s["keypoints"], train["sequences"].values())) 82 | groups = np.concatenate([np.ones(len(keypoints[i])) * i for i in range(len(keypoints))]).astype( 83 | np.int 84 | ) 85 | 86 | # %% 87 | linear = sklearn.linear_model.LogisticRegression( 88 | multi_class="multinomial", max_iter=1000, class_weight="balanced" 89 | ) 90 | 91 | # %% 92 | def evaluate_f1(c): 93 | l = linear.fit( 94 | np.tanh(annotation_df[["timestep"]] / c), 95 | annotation_df["annotation"], 96 | ) 97 | return sklearn.metrics.f1_score( 98 | annotation_df["annotation"], 99 | l.predict( 100 | np.tanh(annotation_df[["timestep"]] / c), 101 | ), 102 | average="macro", 103 | ) 104 | 105 | 106 | # %% 107 | pbounds = {"c": (1, 5000)} 108 | 109 | optimizer = BayesianOptimization( 110 | f=evaluate_f1, 111 | pbounds=pbounds, 112 | random_state=42, 113 | ) 114 | 115 | optimizer.maximize( 116 | init_points=5, 117 | n_iter=50, 118 | ) 119 | 120 | print(optimizer.max) 121 | 122 | # %% 123 | def posterior(optimizer, x_obs, y_obs, grid): 124 | optimizer._gp.fit(x_obs, y_obs) 125 | 126 | mu, sigma = optimizer._gp.predict(grid, return_std=True) 127 | return mu, sigma 128 | 129 | 130 | def plot_gp(optimizer, x): 131 | fig = plt.figure(figsize=(16, 10)) 132 | steps = len(optimizer.space) 133 | fig.suptitle( 134 | f"Gaussian Process and Utility Function After {steps} Steps", fontdict={"size": 30} 135 | ) 136 | 137 | gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1]) 138 | 
axis = plt.subplot(gs[0]) 139 | acq = plt.subplot(gs[1]) 140 | 141 | x_obs = np.array([[res["params"]["c"]] for res in optimizer.res]) 142 | y_obs = np.array([res["target"] for res in optimizer.res]) 143 | 144 | mu, sigma = posterior(optimizer, x_obs, y_obs, x) 145 | axis.plot(x_obs.flatten(), y_obs, "D", markersize=8, label="Observations", color="r") 146 | axis.plot(x, mu, "--", color="k", label="Prediction") 147 | 148 | axis.fill( 149 | np.concatenate([x, x[::-1]]), 150 | np.concatenate([mu - 1.9600 * sigma, (mu + 1.9600 * sigma)[::-1]]), 151 | alpha=0.6, 152 | fc="c", 153 | ec="None", 154 | label="95% confidence interval", 155 | ) 156 | 157 | axis.set_ylim((None, None)) 158 | axis.set_ylabel("f(x)", fontdict={"size": 20}) 159 | axis.set_xlabel("x", fontdict={"size": 20}) 160 | 161 | utility_function = UtilityFunction(kind="ucb", kappa=5, xi=0) 162 | utility = utility_function.utility(x, optimizer._gp, 0) 163 | acq.plot(x, utility, label="Utility Function", color="purple") 164 | acq.plot( 165 | x[np.argmax(utility)], 166 | np.max(utility), 167 | "*", 168 | markersize=15, 169 | label="Next Best Guess", 170 | markerfacecolor="gold", 171 | markeredgecolor="k", 172 | markeredgewidth=1, 173 | ) 174 | acq.set_ylim((0, np.max(utility) + 0.5)) 175 | acq.set_ylabel("Utility", fontdict={"size": 20}) 176 | acq.set_xlabel("x", fontdict={"size": 20}) 177 | 178 | axis.legend(loc=2, bbox_to_anchor=(1.01, 1), borderaxespad=0.0) 179 | acq.legend(loc=2, bbox_to_anchor=(1.01, 1), borderaxespad=0.0) 180 | 181 | 182 | # %% 183 | plot_gp(optimizer, np.linspace(*pbounds["c"], num=1000)[:, None]) 184 | 185 | # %% 186 | def evaluate_f1_reverse(c): 187 | l = linear.fit( 188 | np.tanh(annotation_df[["timestep_reverse"]] / c), 189 | annotation_df["annotation"], 190 | ) 191 | return sklearn.metrics.f1_score( 192 | annotation_df["annotation"], 193 | l.predict( 194 | np.tanh(annotation_df[["timestep_reverse"]] / c), 195 | ), 196 | average="macro", 197 | ) 198 | 199 | 200 | # %% 201 | pbounds = {"c": (1, 10000)} 202 | 203 | optimizer = BayesianOptimization( 204 | f=evaluate_f1_reverse, 205 | pbounds=pbounds, 206 | random_state=42, 207 | ) 208 | 209 | optimizer.maximize( 210 | init_points=5, 211 | n_iter=50, 212 | ) 213 | 214 | print(optimizer.max) 215 | 216 | # %% 217 | plot_gp(optimizer, np.linspace(*pbounds["c"], num=1000)[:, None]) 218 | 219 | # %% 220 | def evaluate_f1_combined(c0, c1): 221 | features = np.tanh( 222 | np.concatenate( 223 | (annotation_df[["timestep"]] / c0, annotation_df[["timestep_reverse"]] / c1), axis=-1 224 | ) 225 | ) 226 | l = linear.fit( 227 | features, 228 | annotation_df["annotation"], 229 | ) 230 | return sklearn.metrics.f1_score( 231 | annotation_df["annotation"], 232 | l.predict(features), 233 | average="macro", 234 | ) 235 | 236 | 237 | # %% 238 | pbounds = {"c0": (1, 10000), "c1": (1, 10000)} 239 | 240 | optimizer = BayesianOptimization( 241 | f=evaluate_f1_combined, 242 | pbounds=pbounds, 243 | random_state=42, 244 | ) 245 | 246 | optimizer.maximize( 247 | init_points=5, 248 | n_iter=500, 249 | ) 250 | 251 | print(optimizer.max) 252 | 253 | # %% 254 | annotation_df["x"] = np.concatenate(list(map(lambda k: k[:, 0, :, 0], keypoints)))[:, 0] 255 | annotation_df["y"] = np.concatenate(list(map(lambda k: k[:, 0, :, 0], keypoints)))[:, 1] 256 | 257 | annotation_df["qx"] = pd.cut(annotation_df.x, 20, labels=False) 258 | annotation_df["qy"] = pd.cut(annotation_df.y, 20, labels=False) 259 | 260 | # %% 261 | plt.imshow( 262 | annotation_df.pivot_table( 263 | index="qx", columns="qy", 
values="annotation", aggfunc=lambda v: (v == 0).mean() 264 | ) 265 | ) 266 | plt.title("Attack") 267 | plt.show() 268 | 269 | # %% 270 | plt.imshow( 271 | annotation_df.pivot_table( 272 | index="qx", columns="qy", values="annotation", aggfunc=lambda v: (v == 1).mean() 273 | ) 274 | ) 275 | plt.title("Investigation") 276 | plt.show() 277 | 278 | # %% 279 | plt.imshow( 280 | annotation_df.pivot_table( 281 | index="qx", columns="qy", values="annotation", aggfunc=lambda v: (v == 2).mean() 282 | ) 283 | ) 284 | plt.title("Mount") 285 | plt.show() 286 | 287 | # %% 288 | plt.imshow( 289 | annotation_df.pivot_table( 290 | index="qx", columns="qy", values="annotation", aggfunc=lambda v: (v == 3).mean() 291 | ) 292 | ) 293 | plt.title("Other") 294 | plt.colorbar() 295 | plt.show() 296 | 297 | # %% 298 | c0 = 1485.151664444058 299 | c1 = 2308.0343145614083 300 | 301 | 302 | features = np.tanh( 303 | np.concatenate( 304 | (annotation_df[["timestep"]] / c0, annotation_df[["timestep_reverse"]] / c1), axis=-1 305 | ) 306 | ) 307 | wall_dist = np.linalg.norm( 308 | np.stack( 309 | ( 310 | np.stack((annotation_df.x, annotation_df.x.max() - annotation_df.x)).min(axis=0), 311 | np.stack((annotation_df.y, annotation_df.y.max() - annotation_df.y)).min(axis=0), 312 | ) 313 | ), 314 | axis=0, 315 | ) 316 | 317 | # %% 318 | def evaluate_f1_wdist(c): 319 | features = np.tanh(wall_dist / c)[:, None] 320 | l = linear.fit( 321 | features, 322 | annotation_df["annotation"], 323 | ) 324 | return sklearn.metrics.f1_score( 325 | annotation_df["annotation"], 326 | l.predict(features), 327 | average="macro", 328 | ) 329 | 330 | 331 | # %% 332 | pbounds = {"c": (1, 600)} 333 | 334 | optimizer = BayesianOptimization( 335 | f=evaluate_f1_wdist, 336 | pbounds=pbounds, 337 | random_state=42, 338 | ) 339 | 340 | optimizer.maximize( 341 | init_points=5, 342 | n_iter=50, 343 | ) 344 | 345 | print(optimizer.max) 346 | 347 | # %% 348 | plot_gp(optimizer, np.linspace(*pbounds["c"], num=1000)[:, None]) 349 | -------------------------------------------------------------------------------- /mabe/data.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import pickle 3 | 4 | import h5py 5 | import numba 6 | import numpy as np 7 | import sklearn 8 | import sklearn.preprocessing 9 | 10 | import mabe.config 11 | 12 | 13 | @dataclasses.dataclass 14 | class TrainingBatch: 15 | X: numba.typed.List # [np.array] 16 | X_extra: numba.typed.List # [np.array] 17 | Y: numba.typed.List # [np.array] 18 | indices: np.array 19 | annotators: numba.typed.List # [np.array] 20 | clf_tasks: np.array 21 | Y_dark_behaviors: numba.typed.List # [np.array] 22 | Y_dark_annotators: numba.typed.List # [np.array] 23 | 24 | 25 | class DataWrapper: 26 | def __init__(self, feature_path): 27 | self.vocabulary = ["attack", "investigation", "mount", "other"] 28 | 29 | with h5py.File(feature_path, "r") as hdf: 30 | 31 | def load_all(groupname): 32 | return list(map(lambda v: v[:].astype(np.float32), hdf[groupname].values())) 33 | 34 | self.X_labeled = load_all("train/x") 35 | self.Y_labeled = load_all("train/y") 36 | 37 | self.annotators_labeled = list(map(lambda v: v[()], hdf["train/annotators"].values())) 38 | self.clf_tasks_labeled = np.array( 39 | list(map(lambda v: int(v[()]), hdf["train/clf_tasks"].values())) 40 | ) 41 | 42 | self.X_unlabeled = load_all("test/x") 43 | self.Y_unlabeled = load_all("test/y") 44 | 45 | self.Y_unlabeled_dark_behaviors = load_all("test/y_dark_behaviors") 46 | num_dark_behaviors = 
self.Y_unlabeled_dark_behaviors[0].shape[1] 47 | num_dark_behaviors_classes = self.Y_unlabeled_dark_behaviors[0].shape[2] 48 | self.Y_labeled_dark_behaviors = [ 49 | np.full( 50 | (len(y), num_dark_behaviors, num_dark_behaviors_classes), -1, dtype=np.float32 51 | ) 52 | for y in self.Y_labeled 53 | ] 54 | 55 | self.Y_unlabeled_dark_annotators = load_all("test/y_dark_annotators") 56 | num_dark_annotators = self.Y_unlabeled_dark_annotators[0].shape[1] 57 | num_dark_annotators_classes = self.Y_unlabeled_dark_annotators[0].shape[2] 58 | self.Y_labeled_dark_annotators = [ 59 | np.full( 60 | (len(y), num_dark_annotators, num_dark_annotators_classes), -1, dtype=np.float32 61 | ) 62 | for y in self.Y_labeled 63 | ] 64 | 65 | self.annotators_unlabeled = [-1] * len(self.X_unlabeled) 66 | self.clf_tasks_unlabeled = np.array([-1] * len(self.X_unlabeled)) 67 | 68 | try: 69 | self.X_labeled_extra = load_all("train/x_extra") 70 | self.X_unlabeled_extra = load_all("test/x_extra") 71 | except KeyError: 72 | self.X_labeled_extra = None 73 | self.X_unlabeled_extra = None 74 | 75 | self.groups_unlabeled = list(map(lambda v: v[()], hdf["test/groups"].values())) 76 | 77 | self.X = self.X_labeled + self.X_unlabeled 78 | self.Y = self.Y_labeled + self.Y_unlabeled 79 | self.Y_dark_behaviors = self.Y_labeled_dark_behaviors + self.Y_unlabeled_dark_behaviors 80 | self.Y_dark_annotators = self.Y_labeled_dark_annotators + self.Y_unlabeled_dark_annotators 81 | self.annotators = self.annotators_labeled + self.annotators_unlabeled 82 | self.num_annotators = len(np.unique(self.annotators_labeled)) 83 | self.clf_tasks = np.concatenate((self.clf_tasks_labeled, self.clf_tasks_unlabeled)) 84 | self.num_clf_tasks = len(np.unique(self.clf_tasks)) 85 | 86 | if self.X_labeled_extra is not None: 87 | self.X_extra = self.X_labeled_extra + self.X_unlabeled_extra 88 | self.num_extra_features = self.X_extra[0].shape[-1] 89 | else: 90 | self.X_extra = None 91 | self.num_extra_features = 0 92 | 93 | scaler = sklearn.preprocessing.StandardScaler().fit(np.concatenate(self.X)) 94 | self.X = list(map(lambda x: scaler.transform(x), self.X)) 95 | self.X_labeled = list(map(lambda x: scaler.transform(x), self.X_labeled)) 96 | self.X_unlabeled = list(map(lambda x: scaler.transform(x), self.X_unlabeled)) 97 | 98 | self.sample_lengths = np.array(list(map(len, self.X))) 99 | 100 | 101 | class CVSplit: 102 | def __init__(self, split_idx, data): 103 | split = pickle.load(open(mabe.config.ROOT_PATH / f"split_{split_idx}.pkl", "rb")) 104 | 105 | self.indices_labeled = split["indices_labeled"] 106 | self.indices_unlabeled = split["indices_unlabeled"] 107 | self.indices = split["indices"] 108 | self.train_indices_labeled = split["train_indices_labeled"] 109 | self.train_indices_unlabeled = split["train_indices_unlabeled"] 110 | self.train_indices = split["train_indices"] 111 | self.val_indices_labeled = split["val_indices_labeled"] 112 | self.val_indices_unlabeled = split["val_indices_unlabeled"] 113 | self.val_indices = split["val_indices"] 114 | 115 | self.data = data 116 | self.calculate_draw_probs() 117 | 118 | def calculate_draw_probs(self): 119 | data = self.data 120 | sample_lengths = data.sample_lengths 121 | 122 | self.p_draw = sample_lengths / np.sum(sample_lengths) 123 | 124 | self.p_draw_labeled = sample_lengths[self.indices_labeled] / np.sum( 125 | sample_lengths[self.indices_labeled] 126 | ) 127 | self.p_draw_unlabeled = sample_lengths[self.indices_unlabeled] / np.sum( 128 | sample_lengths[self.indices_unlabeled] 129 | ) 130 | 
self.p_draw_train_labeled = sample_lengths[self.train_indices_labeled] / np.sum( 131 | sample_lengths[self.train_indices_labeled] 132 | ) 133 | self.p_draw_train_unlabeled = sample_lengths[self.train_indices_unlabeled] / np.sum( 134 | sample_lengths[self.train_indices_unlabeled] 135 | ) 136 | self.p_draw_train = sample_lengths[self.train_indices] / np.sum( 137 | sample_lengths[self.train_indices] 138 | ) 139 | self.p_draw_val_labeled = sample_lengths[self.val_indices_labeled] / np.sum( 140 | sample_lengths[self.val_indices_labeled] 141 | ) 142 | self.p_draw_val_unlabeled = sample_lengths[self.val_indices_unlabeled] / np.sum( 143 | sample_lengths[self.val_indices_unlabeled] 144 | ) 145 | self.p_draw_val = sample_lengths[self.val_indices] / np.sum( 146 | sample_lengths[self.val_indices] 147 | ) 148 | 149 | def get_train_batch( 150 | self, batch_size, random_noise=0.0, extra_features=False, dark_knowledge=False 151 | ): 152 | def random_task_train_index(task): 153 | task_train_indices = self.train_indices_labeled[ 154 | self.data.clf_tasks[self.train_indices_labeled] == task 155 | ] 156 | task_p_draw = self.data.sample_lengths[task_train_indices].astype(np.float) 157 | task_p_draw /= np.sum(task_p_draw) 158 | return np.array([np.random.choice(task_train_indices, p=task_p_draw)]) 159 | 160 | indices_batch_unlabeled = np.random.choice( 161 | self.indices_unlabeled, 162 | size=int(0.75 * batch_size - self.data.clf_tasks.max() - 1), 163 | p=self.p_draw_unlabeled, 164 | ) 165 | # at least one sample per task 166 | indices_batch = np.concatenate( 167 | ( 168 | np.random.choice( 169 | self.train_indices_labeled, 170 | size=int(0.25 * batch_size), 171 | p=self.p_draw_train_labeled, 172 | ), 173 | *[ 174 | random_task_train_index(task) 175 | for task in range(0, self.data.clf_tasks.max() + 1) 176 | ], 177 | indices_batch_unlabeled, 178 | ) 179 | ) 180 | 181 | assert np.all([i not in self.val_indices_labeled for i in indices_batch]) 182 | 183 | X_batch = numba.typed.List() 184 | augment = lambda x: x + np.random.randn(*x.shape) * random_noise 185 | if random_noise > 0: 186 | [X_batch.append(augment(self.data.X[i])) for i in indices_batch] 187 | else: 188 | [X_batch.append(self.data.X[i]) for i in indices_batch] 189 | 190 | Y_batch = numba.typed.List() 191 | [Y_batch.append(self.data.Y[i].astype(int)) for i in indices_batch] 192 | 193 | if dark_knowledge: 194 | Y_batch_dark_behaviors = numba.typed.List() 195 | [ 196 | Y_batch_dark_behaviors.append(self.data.Y_dark_behaviors[i].astype(np.float32)) 197 | for i in indices_batch 198 | ] 199 | Y_batch_dark_annotators = numba.typed.List() 200 | [ 201 | Y_batch_dark_annotators.append(self.data.Y_dark_annotators[i].astype(np.float32)) 202 | for i in indices_batch 203 | ] 204 | else: 205 | Y_batch_dark_behaviors = None 206 | Y_batch_dark_annotators = None 207 | 208 | annotators_batch = numba.typed.List() 209 | [ 210 | annotators_batch.append(np.array([self.data.annotators[i]]).repeat(len(y))) 211 | for i, y in zip(indices_batch, Y_batch) 212 | ] 213 | 214 | clf_tasks_batch = self.data.clf_tasks[indices_batch] 215 | 216 | for i, task in enumerate(clf_tasks_batch): 217 | if task > 0: 218 | assert np.all(annotators_batch[i] == 0) 219 | 220 | X_extra_batch = None 221 | if extra_features: 222 | X_extra_batch = numba.typed.List() 223 | [X_extra_batch.append(self.data.X_extra[i].astype(int)) for i in indices_batch] 224 | 225 | assert np.all([i not in self.val_indices_labeled for i in indices_batch]) 226 | assert np.all( 227 | [ 228 | Y_batch[i].max() < 0 229 | for i in 
range( 230 | len(indices_batch) - len(indices_batch_unlabeled), len(indices_batch) 231 | ) 232 | ] 233 | ) 234 | 235 | return TrainingBatch( 236 | X_batch, 237 | X_extra_batch, 238 | Y_batch, 239 | indices_batch, 240 | annotators_batch, 241 | clf_tasks_batch, 242 | Y_batch_dark_behaviors, 243 | Y_batch_dark_annotators, 244 | ) 245 | -------------------------------------------------------------------------------- /hydrogen/task3_onlylinear.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import sklearn 5 | import sklearn.decomposition 6 | import sklearn.linear_model 7 | import sklearn.pipeline 8 | import sklearn.preprocessing 9 | import torch 10 | from fastprogress.fastprogress import force_console_behavior 11 | 12 | import mabe 13 | import mabe.config 14 | import mabe.features 15 | import mabe.model 16 | 17 | master_bar, progress_bar = force_console_behavior() 18 | 19 | # %% 20 | device = "cuda:3" 21 | 22 | # %% 23 | result_file = "training_results_2021-04-16 05:26:50.104528_baseline2_task12_smallcontext_0.845.pt" 24 | # TODO: use all runs 25 | result = torch.load(mabe.config.ROOT_PATH / result_file)[0] 26 | 27 | # %% 28 | config = result.config 29 | cpc_params = result.best_params[0] 30 | num_features = 37 31 | num_extra_features = 2 32 | 33 | cpc = mabe.model.ConvCPC( 34 | num_features, 35 | config.num_embeddings, 36 | config.num_context, 37 | config.num_ahead, 38 | config.num_ahead_subsampling, 39 | config.subsample_length, 40 | num_embedder_blocks=config.num_embedder_blocks, 41 | input_dropout=config.input_dropout, 42 | head_dropout=config.head_dropout, 43 | dropout=config.dropout, 44 | split_idx=config.split_idx, 45 | num_extra_features=num_extra_features, 46 | ).to(device) 47 | 48 | cpc.load_state_dict(cpc_params) 49 | cpc = cpc.eval() 50 | 51 | # %% 52 | task3_path = mabe.config.ROOT_PATH / "train_task3.npy" 53 | test_path = mabe.config.ROOT_PATH / "test-release.npy" 54 | 55 | # %% 56 | X_test, X_test_extra, _, groups_test, _ = mabe.features.load_dataset(test_path) 57 | 58 | features_test = [] 59 | with torch.no_grad(): 60 | for idx in range(len(X_test)): 61 | # from feature preprocessing 62 | crop_pre = 1 63 | crop_post = 0 64 | 65 | group = groups_test[idx] 66 | x = X_test[idx].astype(np.float32) 67 | if config.use_extra_features: 68 | x_extra = X_test_extra[idx].astype(np.float32) 69 | x_extra = torch.from_numpy(x_extra).to(device, non_blocking=True) 70 | 71 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to(device, non_blocking=True) 72 | x_emb = cpc.embedder(x) 73 | 74 | crop = (x.shape[-1] - x_emb.shape[-1]) // 2 75 | crop_pre += crop 76 | crop_post += crop 77 | 78 | c = cpc.apply_contexter(x_emb, device) 79 | 80 | crop = x_emb.shape[-1] - c.shape[-1] 81 | crop_pre += crop 82 | 83 | logreg_features = c[0].T 84 | if config.use_extra_features: 85 | x_extra = x_extra[crop_pre : -(crop_post - 1)] 86 | 87 | if config.use_extra_features: 88 | x_cominbed = torch.cat((logreg_features, x_extra), dim=-1) 89 | else: 90 | x_combined = logreg_features 91 | 92 | x_combined = x_cominbed.cpu().data.numpy() 93 | features_test.append(x_combined) 94 | features_test = np.concatenate(features_test) 95 | 96 | # %% 97 | cv_scores = [] 98 | for behavior, (X, X_extra, Y, groups, annotators) in mabe.features.load_task3_datasets(task3_path): 99 | X_flat = [] 100 | Y_flat = [] 101 | groups_flat = [] 102 | with torch.no_grad(): 103 | for idx in range(len(X)): 104 | # from feature preprocessing 
105 | crop_pre = 1 106 | crop_post = 0 107 | 108 | x = X[idx].astype(np.float32) 109 | x_extra = None 110 | if config.use_extra_features: 111 | x_extra = X_extra[idx].astype(np.float32) 112 | x_extra = torch.from_numpy(x_extra).to(device, non_blocking=True) 113 | 114 | g = np.array([idx]) 115 | 116 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to(device, non_blocking=True) 117 | x_emb = cpc.embedder(x) 118 | 119 | crop = (x.shape[-1] - x_emb.shape[-1]) // 2 120 | crop_pre += crop 121 | crop_post += crop 122 | 123 | c = cpc.apply_contexter(x_emb, device) 124 | 125 | crop = x_emb.shape[-1] - c.shape[-1] 126 | crop_pre += crop 127 | 128 | logreg_features = c[0].T 129 | x_extra = x_extra[crop_pre : -(crop_post - 1)] 130 | y = Y[idx][crop_pre : -(crop_post - 1)] 131 | 132 | x_cominbed = torch.cat((logreg_features, x_extra), dim=-1) 133 | 134 | X_flat.append(x_cominbed.cpu().data.numpy()) 135 | Y_flat.append(y) 136 | groups_flat.append(g.repeat(len(y))) 137 | 138 | X_flat = np.concatenate(X_flat) 139 | Y_flat = np.concatenate(Y_flat) 140 | groups_flat = np.concatenate(groups_flat) 141 | 142 | print(behavior) 143 | print(len(np.unique(groups_flat))) 144 | if len(np.unique(groups_flat)) > 1: 145 | cv = sklearn.model_selection.GroupShuffleSplit(8) 146 | else: 147 | cv = sklearn.model_selection.StratifiedShuffleSplit(8) 148 | 149 | X_flat_all = np.concatenate((X_flat, features_test)) 150 | scaler = sklearn.preprocessing.StandardScaler().fit(X_flat_all) 151 | X_flat = scaler.transform(X_flat) 152 | 153 | linear = sklearn.pipeline.make_pipeline( 154 | sklearn.linear_model.LogisticRegression( 155 | multi_class="multinomial", class_weight="balanced", max_iter=1000, C=1e-1 156 | ) 157 | ) 158 | scores = sklearn.model_selection.cross_validate( 159 | linear, 160 | X_flat, 161 | Y_flat, 162 | n_jobs=8, 163 | cv=cv, 164 | groups=groups_flat, 165 | scoring=dict( 166 | f1=sklearn.metrics.make_scorer(sklearn.metrics.f1_score), # , average="macro"), 167 | precision=sklearn.metrics.make_scorer( 168 | sklearn.metrics.precision_score 169 | ), # , average="macro"), 170 | ), 171 | ) 172 | if len(np.unique(groups_flat)) > 1: 173 | cv_scores.append(scores["test_f1"]) 174 | print(np.median(scores["test_f1"])) 175 | print() 176 | 177 | print(np.mean(cv_scores)) 178 | 179 | 180 | # %% 181 | submission: dict[str, dict] = {} 182 | for behavior, (X, X_extra, Y, groups, annotators) in mabe.features.load_task3_datasets(task3_path): 183 | submission[behavior] = dict() 184 | X_flat = [] 185 | Y_flat = [] 186 | groups_flat = [] 187 | with torch.no_grad(): 188 | for idx in range(len(X)): 189 | # from feature preprocessing 190 | crop_pre = 1 191 | crop_post = 0 192 | 193 | x = X[idx].astype(np.float32) 194 | x_extra = None 195 | if config.use_extra_features: 196 | x_extra = X_extra[idx].astype(np.float32) 197 | x_extra = torch.from_numpy(x_extra).to(device, non_blocking=True) 198 | 199 | g = np.array([idx]) 200 | 201 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to(device, non_blocking=True) 202 | x_emb = cpc.embedder(x) 203 | 204 | crop = (x.shape[-1] - x_emb.shape[-1]) // 2 205 | crop_pre += crop 206 | crop_post += crop 207 | 208 | c = cpc.apply_contexter(x_emb, device) 209 | 210 | crop = x_emb.shape[-1] - c.shape[-1] 211 | crop_pre += crop 212 | 213 | logreg_features = c[0].T 214 | x_extra = x_extra[crop_pre : -(crop_post - 1)] 215 | y = Y[idx][crop_pre : -(crop_post - 1)] 216 | 217 | x_cominbed = torch.cat((logreg_features, x_extra), dim=-1) 218 | 219 | X_flat.append(x_cominbed.cpu().data.numpy()) 
220 | Y_flat.append(y) 221 | groups_flat.append(g.repeat(len(y))) 222 | 223 | X_flat = np.concatenate(X_flat) 224 | Y_flat = np.concatenate(Y_flat) 225 | groups_flat = np.concatenate(groups_flat) 226 | 227 | X_flat_all = np.concatenate((X_flat, features_test)) 228 | scaler = sklearn.preprocessing.StandardScaler().fit(X_flat_all) 229 | X_flat = scaler.transform(X_flat) 230 | 231 | linear = sklearn.pipeline.make_pipeline( 232 | sklearn.linear_model.LogisticRegression( 233 | multi_class="multinomial", class_weight="balanced", max_iter=1000, C=1e-1 234 | ) 235 | ) 236 | linear.fit(X_flat, Y_flat) 237 | 238 | with torch.no_grad(): 239 | for idx in range(len(X_test)): 240 | # from feature preprocessing 241 | crop_pre = 1 242 | crop_post = 0 243 | 244 | group = groups_test[idx] 245 | x = X_test[idx].astype(np.float32) 246 | x_extra = None 247 | if config.use_extra_features: 248 | x_extra = X_test_extra[idx].astype(np.float32) 249 | x_extra = torch.from_numpy(x_extra).to(device, non_blocking=True) 250 | 251 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to(device, non_blocking=True) 252 | x_emb = cpc.embedder(x) 253 | 254 | crop = (x.shape[-1] - x_emb.shape[-1]) // 2 255 | crop_pre += crop 256 | crop_post += crop 257 | 258 | c = cpc.apply_contexter(x_emb, device) 259 | 260 | crop = x_emb.shape[-1] - c.shape[-1] 261 | crop_pre += crop 262 | 263 | logreg_features = c[0].T 264 | x_extra = x_extra[crop_pre : -(crop_post - 1)] 265 | 266 | x_cominbed = torch.cat((logreg_features, x_extra), dim=-1) 267 | x_combined = x_cominbed.cpu().data.numpy() 268 | 269 | y_pred = linear.predict(scaler.transform(x_combined)) 270 | # TODO: off-by-one? 271 | y_pred = np.concatenate( 272 | (y_pred[:1].repeat(crop_pre), y_pred, y_pred[-1:].repeat(crop_post)) 273 | ) 274 | 275 | submission[behavior][group] = y_pred 276 | 277 | # %% 278 | sample_submission = np.load( 279 | mabe.config.ROOT_PATH / "sample-submission-task3.npy", allow_pickle=True 280 | ).item() 281 | 282 | 283 | def validate_submission(submission, sample_submission): 284 | if not isinstance(submission, dict): 285 | print("Submission should be dict") 286 | return False 287 | 288 | if not submission.keys() == sample_submission.keys(): 289 | print("Submission keys don't match") 290 | return False 291 | for behavior in submission: 292 | sb = submission[behavior] 293 | ssb = sample_submission[behavior] 294 | if not isinstance(sb, dict): 295 | print("Submission should be dict") 296 | return False 297 | 298 | if not sb.keys() == ssb.keys(): 299 | print("Submission keys don't match") 300 | return False 301 | 302 | for key in sb: 303 | sv = sb[key] 304 | ssv = ssb[key] 305 | if not len(sv) == len(ssv): 306 | print(f"Submission lengths of {key} doesn't match") 307 | return False 308 | 309 | for key, sv in sb.items(): 310 | if not all(isinstance(x, (np.int32, np.int64, int)) for x in list(sv)): 311 | print(f"Submission of {key} is not all integers") 312 | return False 313 | 314 | print("All tests passed") 315 | return True 316 | 317 | 318 | # %% 319 | if validate_submission(submission, sample_submission): 320 | np.save(mabe.config.ROOT_PATH / "task3_submission2.npy", submission) 321 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Animal Behavior challenge solution outline 2 | Author : Benjamin Wild () 3 | 4 | ## Overview 5 | 6 | The core idea of my solution was to get as much use as possible out of the limited amount 
of data 7 | available using a combination of several methods: 8 | 9 | * Feature preprocessing instead of augmentation: Use an egocentric representation of the data with 10 | relative orientations and velocities. I added a PCA embedding of the graph consisting of 11 | the pairwise distances of all points of both individuals to the features, and I'm also using 12 | absolute spatial and temporal information in the final classification layer of the model. 13 | 14 | * Semi-supervised learning using the InfoNCE / CPC objective. My initial attempt used an 15 | unsupervised CPC model pretrained on the entire dataset (train and test data combined). A linear 16 | classifier on these embeddings is about as good as the baseline model in Task 1. To further 17 | improve the accuracy, I then trained the model in a semi-supervised fashion with 75% of the 18 | samples in each batch sampled from the test sequences and the remaining 25% from the train 19 | sequences. The InfoNCE loss was computed for all samples and the task-specific classification 20 | losses only for the samples from the train sequences 21 | ([Oord et al., 2019](https://arxiv.org/abs/1807.03748), 22 | [Chen et al., 2020](https://arxiv.org/abs/2002.05709v3)). 23 | 24 | * Joint training of one model for all tasks: Instead of treating the three tasks separately, I opted 25 | to jointly train one model because I assumed that the task objectives are strongly correlated 26 | (i.e., the Task 1 classification task would regularize the model for the Task 3 classification 27 | tasks and vice versa). To that end, I use one model that extracts embeddings from the sequences 28 | and stack a multi-head classification layer for the three tasks on top of it. 29 | 30 | * CV / Ensembling / Label smoothing: I relied extensively on my local cross validation. I 31 | group samples by their sequence ID and ensure that in each CV split, all samples from one sequence 32 | are either in the train or validation set (I assumed this was also how the public / private 33 | leaderboard worked, but based on the small differences between the public and final scores this 34 | is apparently not the case). This local CV pipeline had a number of benefits: a) I could evaluate 35 | whether changes to the model significantly improved the performance without relying on and potentially 36 | overfitting to the public leaderboard. b) For each model trained on a CV split, I stored the 37 | parameters with the lowest validation loss. I then used these models in an ensemble for the 38 | submissions. I used label smoothing to improve the calibration of the model, thereby increasing 39 | the accuracy of the ensemble ([Guo et al., 2017](https://arxiv.org/abs/1706.04599), 40 | [Müller et al., 2020](https://arxiv.org/abs/1906.02629)). 41 | 42 | * Dark knowledge: I tried to improve the performance by using the predictions from the ensemble as 43 | soft labels for the test data, i.e. to bootstrap training labels for the model using an ensemble 44 | of the previous version of the model. In a previous project of mine, this worked quite well and 45 | I hoped that it would be very beneficial for the Task 3 behaviors. In the end, this approach only 46 | marginally improved the performance for Task 1, and not at all for Tasks 2 and 3. One explanation 47 | for this could be that the test data for which I created those soft labels was also the data the 48 | model was scored on for the leaderboards. Maybe this approach would have worked better with a 49 | truly separate dataset for unsupervised training?
50 | ([Hinton et al., 2015](https://arxiv.org/abs/1503.02531)) 51 | 52 | * Model architecture: No big surprises here. I use residual blocks with convolutional layers in a 53 | more or less standard ResNet/CPC architecture. The first part of the model (the embedder, 54 | in CPC terms) is non-causal. The second part (the CPC contexter) is a causal model. The CPC 55 | contexts are used in the InfoNCE objective and also in the classification head. I use LayerNorm to 56 | avoid potential problems of data leakage with BatchNorm and the CPC objective. 57 | 58 | ## Details 59 | 60 | The implementation consists of two parts: a) A core library located in the folder `mabe/` containing 61 | most of the code. b) A number of hydrogen notebooks located in the folder `hydrogen/` for EDA, 62 | feature preprocessing, submission creation, and in general for "throwaway" / experimental code, most 63 | of which never made it into this repository. 64 | 65 | ### Feature preprocessing 66 | 67 | `hydrogen/features.py`, `hydrogen/getposepca.py`, `mabe/features.py` 68 | 69 | Data augmentation is often used to encode domain knowledge in a machine learning model. Here, I 70 | attempted to encode this prior knowledge in the model architecture and in the preprocessing of 71 | the features. 72 | 73 | I define the orientation of a mouse as the vector from its tail to neck coordinates and compute 74 | distances, velocities, and angles between coordinates relative to this orientation, i.e. in an 75 | egocentric coordinate system. Angles are represented as unit vectors. 76 | 77 | I also included a learned representation of the graph of all point-to-point Euclidean distances of 78 | all tracking points of both mice, i.e. for two mice with 7 tracking points each the graph is 79 | stored in a 14 x 14 distance matrix. I then compute PCA features of the condensed form of these 80 | distance matrices for all points in the train and test sets and keep the first $n$ principal 81 | components that explain at least 95% of the total variance. 82 | 83 | I also compute the Euclidean distance to the nearest wall and temporal features; see `Joint training` 84 | for more details. 85 | 86 | ### Semi-supervised learning 87 | 88 | `mabe/model.py` 89 | 90 | The model utilizes the CPC / InfoNCE loss, a domain-agnostic unsupervised learning objective. The 91 | core idea is as follows: A small non-causal convolutional neural network (`Embedder`) processes the 92 | preprocessed input sequences and returns embedded sequences $e_t$ of the same length (same number of 93 | temporal steps). A much bigger causal convolutional network (`Contexter`) then returns contexts 94 | $c_t$ for each temporal step. The model also jointly learns a number of linear projections $W_n$ 95 | that map contexts $c_t$ into the future $c_{t+n}$. The InfoNCE objective is minimized if the cosine 96 | similarity between these mapped contexts $c_{t+n}$ and $e_{t+n}$ from the same sequence is 1 and 0 for 97 | randomly sampled $e$s from other sequences. See [Oord et al., 98 | 2019](https://arxiv.org/abs/1807.03748) for more details. Note that the loss used in my 99 | implementation is a slight variation of the original CPC loss (`NT-Xent`) as described in [Chen et 100 | al., 2020](https://arxiv.org/abs/2002.05709v3). 101 | 102 | In my implementation, I chose to use a fixed ratio (1:3) of labeled and unlabeled samples in each 103 | batch and computed the InfoNCE objective for all samples and the classification loss only for the 104 | labeled samples.
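To make this concrete, here is a minimal, self-contained sketch of the combined objective described above. It is not the repository's actual implementation (which lives in `mabe/model.py`, `mabe/loss.py`, and `mabe/training.py`); the tensor shapes, the temperature value, and the helper names are illustrative assumptions. It only mirrors the idea that the InfoNCE term sees every sample while the classification term ignores unlabeled frames, which carry the label `-1` (matching the `ignore_index=-1` convention in `mabe/training.py`):

```python
import torch
import torch.nn.functional as F


def info_nce(pred_future, true_future, temperature=0.1):
    # NT-Xent-style InfoNCE: for each sample, the matching (prediction, embedding)
    # pair is the positive; all other samples in the batch act as negatives.
    pred = F.normalize(pred_future, dim=-1)  # W_n c_t, shape (batch, dim)
    true = F.normalize(true_future, dim=-1)  # e_{t+n}, shape (batch, dim)
    logits = pred @ true.T / temperature     # pairwise cosine similarities
    targets = torch.arange(logits.shape[0], device=logits.device)  # positives on the diagonal
    return F.cross_entropy(logits, targets)


def semi_supervised_loss(pred_future, true_future, clf_logits, labels):
    # InfoNCE on every sample; cross entropy only where labels exist
    # (unlabeled test frames carry the label -1 and are ignored).
    cpc_loss = info_nce(pred_future, true_future)
    clf_loss = F.cross_entropy(clf_logits, labels, ignore_index=-1)
    return cpc_loss + clf_loss
```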
In each batch, I sample `batch_size` sequences proportional to their length (i.e., 105 | a longer sequence is sampled more often) and then uniformly sample a temporal subsequence of each 106 | sequence. 107 | 108 | ### Joint training 109 | 110 | `mabe/training.py` 111 | 112 | The core idea here was to share as many of the model parameters as possible for all tasks and the 113 | CPC objective. Therefore, only the last part of the model (`MultiAnnotatorLogisticRegressionHead`) 114 | is task-specific. 115 | 116 | To stabilize the training, batches are sampled s.t. they always contain at least one sample for each 117 | task, i.e. at least one sample is for the Task 0 behaviors, one for behavior-0 of Task 3, and so on. 118 | 119 | The `MultiAnnotatorLogisticRegressionHead` consists of a LayerNorm layer, a residual block, and 120 | final linear layers for each classification from Tasks 1 and 3 (i.e., a multinomial classification 121 | layer for Task 1, and 7 binary classification layers for Task 3). The residual block has one 122 | additional input: A learned embedding of the annotator for Task 3. This embedding is initialized as 123 | a diagonal matrix, i.e. the embedding for annotator 0 will initially be $[1, 0, 0, 0, 0, 0]$. The 124 | model can then learn similar embeddings for annotators with a similar style and the residual block 125 | can modify the inputs to the classification head to match the style of each annotator. The annotator 126 | embeddings are kept small to avoid overfitting, but I did not experiment with larger or different 127 | kinds of embeddings. 128 | 129 | The losses were scaled to prevent overfitting to the Task 3 classification tasks, which have a much 130 | smaller amount of training data. 131 | 132 | Finally, a number of features (`extra_features`) are only concatenated in this final regression head 133 | to avoid overfitting. I used representations of space (tanh of scaled Euclidean distance to the 134 | wall) and time (tanh of scaled timestep and (num_timesteps - timestep)) as extra features, where the 135 | scaling factors were determined via a Bayesian hyperparameter optimization (see `hydrogen/eda.py`). 136 | 137 | ### CV / Ensembling / Label smoothing 138 | 139 | `hydrogen/create_splits.py`, `hydrogen/merge_results_alltasks.py`, `mabe/loss.py` 140 | 141 | To be able to reliably measure whether modifications to the model improved the performance without 142 | overfitting to the public leaderboard, I created a fixed set of 32 cross-validation splits, whereby 143 | one sequence would always be completely in either the training or the validation split. Because of 144 | runtime and compute constraints, I usually only trained models on the first 10 splits and only 145 | trained models on the full set of CV splits prior to a submission. 146 | 147 | During each run, I stored the highest validation F1 scores for Task 1 and individually for each 148 | behavior from Task 3. I also kept a copy of the parameters of the entire model with the highest 149 | validation F1 scores for each behavior. I didn't explicitly measure the F1 scores for Task 2, 150 | because I assumed they would be strongly correlated with the Task 1 scores. After the training, I 151 | stored the validation scores and also predictions (logits) of the model with the highest scores in a 152 | `TrainingResult` dataclass object and stored a compressed copy of these results to the filesystem.
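To illustrate the grouping constraint mentioned above (every sequence ends up entirely in either the training or the validation split), here is a small sketch of how such splits can be drawn. The actual splits are produced by the `hydrogen/create_splits.py` notebook; the validation fraction and seed below are illustrative assumptions, not the values used in the repository:

```python
import numpy as np


def make_grouped_splits(sequence_ids, n_splits=32, val_fraction=0.2, seed=42):
    # sequence_ids: array of length n_samples mapping each sample to its sequence ID.
    # Whole sequences are assigned to either train or validation, never split across both.
    rng = np.random.default_rng(seed)
    unique_ids = np.unique(sequence_ids)
    n_val = max(1, int(round(len(unique_ids) * val_fraction)))
    splits = []
    for _ in range(n_splits):
        val_ids = rng.choice(unique_ids, size=n_val, replace=False)
        val_mask = np.isin(sequence_ids, val_ids)
        splits.append((np.flatnonzero(~val_mask), np.flatnonzero(val_mask)))
    return splits
```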
153 | 154 | Before submission, I loaded the predictions from each model and computed the average predicted 155 | per-class probabilities. The ensemble prediction was then simply the argmax of these averaged model 156 | predictions. Using a weighted mean with the validation F1 scores as weights did not significantly 157 | improve the results. 158 | 159 | Using a grouped cross validation approach worked well for Tasks 1 and 2, but was somewhat problematic 160 | for Task 3, where only a small number of sequences was available for each behavior. 161 | 162 | ### Dark knowledge 163 | 164 | `hydrogen/features.py`, `mabe/model.py` 165 | 166 | For Task 3, using grouped CV is problematic because only a few (in one case only one) sequences exist 167 | per behavior. I tried to circumvent this by first training models on the CV splits and then using the 168 | ensemble to create additional training data, thereby effectively bootstrapping a bigger training set 169 | from the dark knowledge ([Hinton et al., 2015](https://arxiv.org/abs/1503.02531)) of the ensemble. 170 | Because the ensemble consists of models trained on different CV splits, it has effectively been 171 | trained on the entire training set (all sequences). 172 | 173 | While I think that the idea is valid and I've successfully used this approach in a previous 174 | project, it didn't work nearly as well as I was hoping for in this challenge. I wasn't able to 175 | definitively figure out why, but these are the two potential problems that I see: a) When training 176 | using the dark knowledge loss, you're effectively leaking information from the validation split (via 177 | the knowledge contained in the ensemble), thereby making strategies like early stopping on the 178 | validation loss problematic. One way around this would be to use an additional test split, but for 179 | most behaviors in Task 3 there are not enough sequences available to do this properly. b) I used the 180 | dark knowledge term only for the test sequences, but these are also the sequences which get used for 181 | scoring. Maybe the model is able to overfit to the predictions of the ensemble on the test 182 | sequences, thereby rendering the loss term useless when trying to improve the predictions for the 183 | test sequences. 184 | 185 | It is possible that this approach would've worked better with a separate unlabeled dataset, or maybe 186 | even with different loss scaling factors or by applying the dark knowledge term to the train 187 | sequences from the other tasks, but I wasn't able to properly investigate this before the deadlines. 188 | 189 | ### Additional details 190 | 191 | * All final models use the `MADGRAD` optimizer with cosine learning rate annealing 192 | ([Defazio et al., 2021](https://arxiv.org/abs/2101.11075)). 193 | 194 | * Hyperparameters for the final model: All models use the default hyperparameters as defined in the 195 | `TrainingConfig` dataclass in `mabe/training.py` except for the Task 3 ensemble, for which 196 | `dark_knowledge_loss_scaler` was set to 0. 197 | 198 | * Almost no "traditional regularisation": A small weight decay of $1e-5$ is used during 199 | optimization. I tried to apply dropout at various positions in the model, but it never increased 200 | the model performance. I also briefly experimented with augmentation, but most reasonable 201 | augmentations are not necessary anymore after feature preprocessing. Domain knowledge might be 202 | helpful in designing better augmentations.
203 | 204 | * I used `sklearn.utils.class_weight.compute_class_weight` for all classification tasks based on the 205 | entire (train + validation) dataset. 206 | 207 | ## How to reproduce results 208 | 209 | The code assumes that all data (e.g., `train.npy`) is stored in the location defined in `ROOT_PATH` 210 | in `mabe/config.py`. Alternatively, the environment variable `MABE_ROOT_PATH` can be set to override 211 | this config variable. 212 | 213 | 0. Optional: Use the `hydrogen/getposepca.py` notebook to get the PCA model for the point-to-point 214 | distance graph features based on all data points from all three tasks. 215 | 216 | 1. Run the feature preprocessing notebook `hydrogen/features.py`. This will create an hdf5 file with 217 | the preprocessed sequences for all three tasks and the test data. Note: To reproduce the final 218 | results, you need to have a previously trained ensemble of models for the dark knowledge loss terms 219 | and a pretrained PCA for the relative positions of the tracking points. The improvements from 220 | these approaches are marginal and could also be ignored. 221 | 222 | 2. Create cross validation splits: I initially created a set of 32 fixed CV splits to be able to 223 | reliably test the effects of modifications to the model. These splits can be created using the 224 | `hydrogen/create_splits.py` notebook. 225 | 226 | 3. Train models using the training script `scripts/train.py`: This command line script trains a batch 227 | of models on the given CV splits and returns the cross validated F1 scores. All training 228 | hyperparameters defined in the `TrainingConfig` dataclass in `mabe/training.py` can be set using 229 | command line flags. Example: 230 | 231 | > `./train.py --model_name "ensemble_0to5" --device "cuda:0" --feature_path features.hdf5 232 | --weight_decay=1e-5 --from_split 0 --to_split 5` 233 | 234 | 4. Optional: Use this ensemble to bootstrap training data for step 1 using the 235 | `hydrogen/features.py` notebook. 236 | 237 | 5. Create the final submission: Use the notebook `hydrogen/merge_results_alltasks.py` to load the 238 | ensemble predictions, average them, and create the submissions for the three tasks. 239 | 240 | ## Final remarks 241 | 242 | I tried many things which turned out not to help much. The codebase is therefore somewhat 243 | convoluted and could be improved significantly if only the core functionality were desired. If I 244 | were to use such a model in production, I would only use good feature preprocessing with as much 245 | domain knowledge as possible, and utilize semi-supervised training using the InfoNCE objective. If 246 | accuracy were absolutely critical, I would also use an ensembling approach. 247 | 248 | Here are a couple of ideas that might further improve the results: 249 | 250 | * Class-Balanced Loss ([Cui et al., 2019](https://arxiv.org/abs/1901.05555)). 251 | 252 | * Proper hyperparameter optimization for CPC (batch size, embedding size, ...). There are also some 253 | recent papers that describe improvements to CPC, in particular w.r.t. the selection of negative 254 | samples for the InfoNCE objective 255 | ([e.g., Robinson et al., 2021](https://arxiv.org/abs/2010.04592)). 256 | 257 | * Train the model on Tasks 1 and 2, and only fine-tune the final classification layer for Task 3: This 258 | approach performed well in my local CV, but for some reason not at all on the leaderboard. I don't 259 | know why. 260 | 261 | * Extract embeddings from raw video data instead of pose tracking.
I think there's enough data available 262 | that such an approach might be feasible here, in particular in the semi-supervised setting. 263 | -------------------------------------------------------------------------------- /mabe/training.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import dataclasses 4 | import pathlib 5 | 6 | import madgrad 7 | import numpy as np 8 | import sklearn 9 | import sklearn.metrics 10 | import sklearn.utils 11 | import torch 12 | from fastprogress.fastprogress import progress_bar 13 | 14 | import mabe.data 15 | import mabe.loss 16 | import mabe.model 17 | 18 | 19 | @dataclasses.dataclass 20 | class TrainingConfig: 21 | split_idx: int 22 | feature_path: pathlib.Path = mabe.config.ROOT_PATH / "features.hdf5" 23 | batch_size: int = 32 24 | num_epochs: int = 40 25 | subsample_length: int = 256 26 | num_embeddings: int = 128 27 | num_context: int = 512 28 | num_ahead: int = 32 * 8 29 | num_ahead_subsampling: int = 32 30 | num_embedder_blocks: int = 3 31 | input_dropout: float = 0.0 32 | head_dropout: float = 0.0 33 | dropout: float = 0.0 34 | clf_loss_scaling: float = 1.0 35 | label_smoothing: float = 0.2 36 | optimizer: str = "SGD" 37 | learning_rate: float = 0.01 38 | weight_decay: float = 1e-4 39 | scheduler: str = "cosine_annealing" 40 | augmentation_random_noise: float = 0.0 41 | use_extra_features: bool = False 42 | extra_task_loss_scaler: float = 0.1 43 | dark_annotator_loss_scaler: float = 0.2 44 | dark_knowledge_loss_scaler: float = 0.5 45 | fade_out_dark_knowledge: bool = False 46 | test_run: bool = False 47 | label_smoothing_task3: bool = True 48 | use_best_task0: bool = False 49 | 50 | 51 | @dataclasses.dataclass 52 | class TrainingResult: 53 | config: TrainingConfig 54 | losses: list 55 | clf_losses: dict 56 | clf_val_f1s: dict 57 | best_val_f1: dict 58 | best_params: dict 59 | final_params: tuple 60 | test_predictions: dict 61 | test_logits: dict 62 | task3_test_logits: dict 63 | params_by_epoch: list 64 | 65 | 66 | class Trainer: 67 | config: TrainingConfig 68 | cpc: mabe.model.ConvCPC 69 | logreg: mabe.model.MultiAnnotatorLogisticRegressionHead 70 | scheduler: torch.optim.lr_scheduler._LRScheduler 71 | optimizer: torch.optim.Optimizer 72 | split: mabe.data.CVSplit 73 | data: mabe.data.DataWrapper 74 | device: str 75 | clf_loss: list 76 | dark_clf_loss: torch.nn.modules.loss._Loss 77 | 78 | losses: list 79 | clf_losses: dict 80 | dark_losses: dict 81 | clf_val_f1s: dict 82 | best_params: dict 83 | best_val_f1: dict 84 | params_by_epoch: list 85 | 86 | def __init__(self, config, data, split, device): 87 | self.config = config 88 | self.data = data 89 | self.split = split 90 | self.device = device 91 | 92 | if config.test_run: 93 | self.batches_per_epoch = 2 94 | else: 95 | self.batches_per_epoch = int( 96 | sum(data.sample_lengths) / config.subsample_length / config.batch_size 97 | ) 98 | self.num_extra_clf_tasks = ( 99 | len(np.unique(data.clf_tasks)) - 2 100 | ) # task12 clf and -1 for test data 101 | self.num_features = data.X[0].shape[-1] 102 | 103 | self.cpc = mabe.model.ConvCPC( 104 | self.num_features, 105 | config.num_embeddings, 106 | config.num_context, 107 | config.num_ahead, 108 | config.num_ahead_subsampling, 109 | config.subsample_length, 110 | num_embedder_blocks=config.num_embedder_blocks, 111 | input_dropout=config.input_dropout, 112 | head_dropout=config.head_dropout, 113 | dropout=config.dropout, 114 | split_idx=config.split_idx, 115 | 
num_extra_features=data.num_extra_features, 116 | ).to(device) 117 | 118 | self.logreg = mabe.model.MultiAnnotatorLogisticRegressionHead( 119 | config.num_context, 120 | data.num_annotators, 121 | data.num_extra_features, 122 | self.num_extra_clf_tasks, 123 | ).to(device) 124 | 125 | if config.optimizer == "SGD": 126 | self.optimizer = torch.optim.SGD( 127 | list(self.cpc.parameters()) + list(self.logreg.parameters()), 128 | weight_decay=config.weight_decay, 129 | lr=config.learning_rate, 130 | momentum=0.9, 131 | nesterov=True, 132 | ) 133 | elif config.optimizer == "MADGRAD": 134 | self.optimizer = madgrad.MADGRAD( 135 | list(self.cpc.parameters()) + list(self.logreg.parameters()), 136 | weight_decay=config.weight_decay, 137 | lr=config.learning_rate, 138 | ) 139 | 140 | self.clf_loss = [] 141 | for task in range(self.num_extra_clf_tasks + 1): 142 | task_indices = np.argwhere(data.clf_tasks_labeled == task).flatten() 143 | task_train_Y = np.concatenate([data.Y_labeled[i] for i in task_indices]).astype(np.int) 144 | # TODO: class_weights only on train samples? 145 | class_weights = sklearn.utils.class_weight.compute_class_weight( 146 | "balanced", classes=np.unique(task_train_Y), y=task_train_Y 147 | ) 148 | 149 | _, class_counts = np.unique(task_train_Y, return_counts=True) 150 | p_class = class_counts / np.sum(class_counts) 151 | 152 | if config.label_smoothing_task3: 153 | self.clf_loss.append( 154 | mabe.loss.CrossEntropyLoss( 155 | weight=torch.from_numpy(class_weights).to(device).float(), 156 | ignore_index=-1, 157 | smooth_eps=config.label_smoothing, 158 | smooth_dist=torch.from_numpy(p_class).to(device).float(), 159 | ).to(device) 160 | ) 161 | else: 162 | self.clf_loss.append( 163 | mabe.loss.CrossEntropyLoss( 164 | weight=torch.from_numpy(class_weights).to(device).float(), 165 | ignore_index=-1, 166 | ).to(device) 167 | ) 168 | 169 | # TODO: weight? 
170 | self.dark_clf_loss = mabe.loss.CrossEntropyLoss().to(device) 171 | 172 | if config.scheduler == "cosine_annealing": 173 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 174 | self.optimizer, config.num_epochs * self.batches_per_epoch 175 | ) 176 | elif config.scheduler == "none": 177 | pass 178 | else: 179 | assert False 180 | 181 | self.losses = [] 182 | self.clf_losses = collections.defaultdict(list) 183 | self.dark_losses = [] 184 | self.clf_val_f1s = collections.defaultdict(list) 185 | self.best_params = {} 186 | self.best_val_f1 = {} 187 | self.best_val_f1_combined = 0.0 188 | self.params_by_epoch = [] 189 | 190 | for task in range(self.num_extra_clf_tasks + 1): 191 | self.best_params[task] = ( 192 | self.get_cpu_params(self.cpc), 193 | self.get_cpu_params(self.logreg), 194 | ) 195 | 196 | def validation_f1(self, task: int): 197 | with torch.no_grad(): 198 | cpc = self.cpc.eval() 199 | logreg = self.logreg.eval() 200 | 201 | predictions = [] 202 | labels = [] 203 | annotators = [] 204 | 205 | crop_pre, crop_post = cpc.get_crops(self.device) 206 | 207 | def add_padding(seq): 208 | return np.concatenate( 209 | ( 210 | np.zeros_like(seq)[:crop_pre], 211 | seq, 212 | np.zeros_like(seq)[:crop_post], 213 | ) 214 | ) 215 | 216 | with torch.no_grad(): 217 | for idx in self.split.val_indices_labeled: 218 | if self.data.clf_tasks[idx] != task: 219 | continue 220 | 221 | y = self.data.Y_labeled[idx] 222 | a = np.array([self.data.annotators_labeled[idx]]).repeat(len(y)) 223 | if task > 0: 224 | assert np.all(a == 0) # only first annotator for task 3 225 | x = add_padding(self.data.X_labeled[idx].astype(np.float32)) 226 | if self.config.use_extra_features: 227 | x_extra = self.data.X_labeled_extra[idx].astype(np.float32) 228 | x_extra = torch.from_numpy(x_extra).to(self.device, non_blocking=True) 229 | 230 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to( 231 | self.device, non_blocking=True 232 | ) 233 | x_emb = cpc.embedder(x) 234 | 235 | c = cpc.apply_contexter(x_emb, self.device) 236 | 237 | logreg_features = c[0].T 238 | l = logreg(logreg_features, x_extra, a, task) 239 | p = torch.argmax(l, dim=-1) 240 | 241 | predictions.append(p.cpu().numpy()) 242 | labels.append(y) 243 | annotators.append(a) 244 | 245 | if len(predictions): 246 | annotators = np.concatenate(annotators).astype(np.int) 247 | predictions = np.concatenate(predictions).astype(np.int) 248 | labels_array = np.concatenate(labels).astype(np.int) 249 | 250 | if task == 0: 251 | # validation loss only for first annotator 252 | predictions = predictions[annotators == 0] 253 | labels_array = labels_array[annotators == 0] 254 | 255 | return mabe.loss.macro_f1_score(labels_array, predictions, 4) 256 | else: 257 | assert labels_array.max() == 1 258 | assert labels_array.min() == 0 259 | # calculate F1 score for behavior, not macro F1 260 | # return mabe.loss.f1_score_for_label(labels_array, predictions, 2, 1) 261 | return sklearn.metrics.f1_score(labels_array, predictions) 262 | else: 263 | return None 264 | 265 | def clf_task_loss( 266 | self, 267 | batch, 268 | X_extra_batch, 269 | Y_batch, 270 | contexts, 271 | annotators_batch, 272 | task, 273 | ): 274 | has_train_labels = batch.clf_tasks == task 275 | 276 | Y_batch_flat = Y_batch[has_train_labels].flatten().long() 277 | valids = Y_batch_flat >= 0 278 | assert torch.any(valids) 279 | 280 | logreg_features = contexts[has_train_labels] 281 | annotators_batch = annotators_batch[has_train_labels] 282 | if self.config.use_extra_features: 283 | 
X_extra_batch = X_extra_batch[has_train_labels] 284 | 285 | num_classes = 4 if task == 0 else 2 286 | clf_batch_loss = self.clf_loss[task]( 287 | self.logreg(logreg_features, X_extra_batch, annotators_batch, task).reshape( 288 | -1, num_classes 289 | ), 290 | Y_batch_flat, 291 | ) 292 | 293 | assert np.all( 294 | (annotators_batch.flatten() >= 0) 295 | | ((annotators_batch.flatten() == -1) & (Y_batch_flat.cpu().data.numpy() == -1)) 296 | ) 297 | 298 | return clf_batch_loss 299 | 300 | def dark_clf_losses( 301 | self, 302 | contexts, 303 | X_extra_batch, 304 | Y_batch_dark_behaviors, 305 | Y_batch_dark_annotators, 306 | annotators_batch, 307 | ): 308 | has_dark_labels = (Y_batch_dark_behaviors >= 0.0).sum(dim=(1, 2, 3)) > 0 309 | logreg_features = contexts[has_dark_labels] 310 | annotators_batch = np.zeros_like(annotators_batch[has_dark_labels.cpu().data.numpy()]) 311 | if self.config.use_extra_features: 312 | X_extra_batch = X_extra_batch[has_dark_labels] 313 | 314 | dark_behavior_losses_batch = [] 315 | num_classes = 2 316 | for task in range(1, self.num_extra_clf_tasks + 1): 317 | behavior = task - 1 318 | 319 | behavior_logits = Y_batch_dark_behaviors[has_dark_labels][:, :, behavior].reshape( 320 | -1, num_classes 321 | ) 322 | dark_clf_batch_loss = self.dark_clf_loss( 323 | self.logreg(logreg_features, X_extra_batch, annotators_batch, task).reshape( 324 | -1, num_classes 325 | ), 326 | behavior_logits, 327 | ) 328 | dark_behavior_losses_batch.append(dark_clf_batch_loss) 329 | dark_behavior_loss_batch = ( 330 | sum(dark_behavior_losses_batch) * self.config.extra_task_loss_scaler 331 | ) 332 | 333 | dark_annotator_losses_batch = [] 334 | num_classes = 4 335 | task = 0 336 | sum_annotators = 0 337 | for annotator in range(self.data.num_annotators): 338 | annotator_logits = Y_batch_dark_annotators[has_dark_labels][:, :, annotator].reshape( 339 | -1, num_classes 340 | ) 341 | dark_clf_batch_loss = self.dark_clf_loss( 342 | self.logreg(logreg_features, X_extra_batch, annotators_batch, task).reshape( 343 | -1, num_classes 344 | ), 345 | annotator_logits, 346 | ) 347 | 348 | scaler = 1 if annotator == 0 else self.config.dark_annotator_loss_scaler 349 | dark_clf_batch_loss *= scaler 350 | sum_annotators += scaler 351 | 352 | dark_annotator_losses_batch.append(dark_clf_batch_loss) 353 | dark_annotator_loss_batch = sum(dark_annotator_losses_batch) / sum_annotators 354 | 355 | return dark_behavior_loss_batch + dark_annotator_loss_batch 356 | 357 | def train_batch(self, epoch): 358 | self.optimizer.zero_grad() 359 | cpc = self.cpc.train() 360 | 361 | batch = self.split.get_train_batch( 362 | self.config.batch_size, 363 | random_noise=self.config.augmentation_random_noise, 364 | extra_features=self.config.use_extra_features, 365 | dark_knowledge=self.config.dark_knowledge_loss_scaler > 0.0, 366 | ) 367 | 368 | ( 369 | contexts, 370 | X_extra_batch, 371 | Y_batch, 372 | Y_batch_dark_behaviors, 373 | Y_batch_dark_annotators, 374 | annotators_batch, 375 | batch_loss, 376 | ) = cpc(batch, device=self.device, with_loss=True) 377 | self.losses.append(batch_loss.cpu().item()) 378 | 379 | batch_clf_task_losses = [] 380 | for task in range(self.num_extra_clf_tasks + 1): 381 | task_loss = self.clf_task_loss( 382 | batch, 383 | X_extra_batch, 384 | Y_batch, 385 | contexts, 386 | annotators_batch, 387 | task, 388 | ) 389 | 390 | self.clf_losses[task].append(task_loss.item()) 391 | 392 | loss_scaler = 1 393 | if task > 0: 394 | loss_scaler = self.config.extra_task_loss_scaler 395 | 396 | task_loss *= loss_scaler 
397 | batch_clf_task_losses.append(task_loss) 398 | batch_clf_task_loss = sum(batch_clf_task_losses) 399 | 400 | if self.config.dark_knowledge_loss_scaler > 0.0: 401 | dark_knowledge_loss = self.dark_clf_losses( 402 | contexts, 403 | X_extra_batch, 404 | Y_batch_dark_behaviors, 405 | Y_batch_dark_annotators, 406 | annotators_batch, 407 | ) 408 | self.dark_losses.append(dark_knowledge_loss.item()) 409 | else: 410 | dark_knowledge_loss = 0.0 411 | 412 | epoch_scaler = 1.0 413 | if self.config.fade_out_dark_knowledge: 414 | epoch_scaler = (self.config.num_epochs / (epoch + 1)) / self.config.num_epochs 415 | 416 | batch_loss = ( 417 | batch_loss 418 | + self.config.clf_loss_scaling * batch_clf_task_loss 419 | + self.config.dark_knowledge_loss_scaler * dark_knowledge_loss * epoch_scaler 420 | ) 421 | 422 | batch_loss.backward() 423 | self.optimizer.step() 424 | if self.scheduler is not None: 425 | self.scheduler.step() 426 | 427 | @staticmethod 428 | def get_lr(optimizer): 429 | for param_group in optimizer.param_groups: 430 | return param_group["lr"] 431 | 432 | def running(self, losses): 433 | return np.mean([i for i in losses[-self.batches_per_epoch :] if i is not None]) 434 | 435 | def log_batch(self, bar): 436 | task3_running_clf_loss = [] 437 | for task in range(1, self.num_extra_clf_tasks + 1): 438 | task3_running_clf_loss.append(self.running(self.clf_losses[task])) 439 | task3_running_clf_loss = np.mean(task3_running_clf_loss) 440 | 441 | task3_running_val_f1 = [] 442 | for task in range(1, self.num_extra_clf_tasks + 1): 443 | task3_running_val_f1.append(self.best_val_f1[task]) 444 | task3_running_val_f1 = np.mean([i for i in task3_running_val_f1 if i is not None]) 445 | 446 | bar.comment = ( 447 | f"Train: {self.running(self.losses):.3f} | " 448 | + f"CLF Train: {self.running(self.clf_losses[0]):.3f} | " 449 | + f"CLF Train [T3]: {task3_running_clf_loss:.3f} | " 450 | + f"Dark: {self.running(self.dark_losses):.3f} | " 451 | + f"Best CLF Val F1: {self.best_val_f1[0]:.3f} | " 452 | + f"Best CLF Val F1 [T3]: {task3_running_val_f1:.3f} | " 453 | + f"LR: {self.get_lr(self.optimizer):.4f}" 454 | ) 455 | 456 | @staticmethod 457 | def get_cpu_params(model): 458 | return copy.deepcopy({k: v.cpu().detach() for k, v in model.state_dict().items()}) 459 | 460 | def finalize_epoch(self): 461 | for task in range(self.num_extra_clf_tasks + 1): 462 | val_f1 = self.validation_f1(task) 463 | 464 | if val_f1 is not None: 465 | if val_f1 > self.best_val_f1[task]: 466 | self.best_params[task] = ( 467 | self.get_cpu_params(self.cpc), 468 | self.get_cpu_params(self.logreg), 469 | ) 470 | self.best_val_f1[task] = max(val_f1, self.best_val_f1[task]) 471 | self.clf_val_f1s[task].append(val_f1) 472 | 473 | # mean of validation f1s for all task3 subtasks 474 | """ 475 | val_f1 = np.mean( 476 | [ 477 | self.clf_val_f1s[task][-1] 478 | for task in range(1, self.num_extra_clf_tasks + 1) 479 | if self.clf_val_f1s[task][-1] is not None 480 | ] 481 | ) 482 | if val_f1 > self.best_val_f1_combined: 483 | # no validation data for task 3.3, use mean of other task 3 subtasks 484 | task = 4 485 | self.best_params[task] = ( 486 | self.get_cpu_params(self.cpc), 487 | self.get_cpu_params(self.logreg), 488 | ) 489 | self.best_val_f1_combined = val_f1 490 | """ 491 | 492 | self.params_by_epoch.append( 493 | ( 494 | self.get_cpu_params(self.cpc), 495 | self.get_cpu_params(self.logreg), 496 | ) 497 | ) 498 | 499 | def get_result(self) -> TrainingResult: 500 | final_params = ( 501 | self.get_cpu_params(self.cpc), 502 | 
self.get_cpu_params(self.logreg), 503 | ) 504 | 505 | test_predictions, test_logits, task3_test_logits = mabe.util.predict_test_data( 506 | self.cpc, self.logreg, self.data, self.device, self.config, self.best_params 507 | ) 508 | 509 | result = TrainingResult( 510 | config=self.config, 511 | losses=self.losses, 512 | clf_losses=self.clf_losses, 513 | clf_val_f1s=self.clf_val_f1s, 514 | best_val_f1=self.best_val_f1, 515 | best_params=self.best_params, 516 | params_by_epoch=self.params_by_epoch, 517 | final_params=final_params, 518 | test_predictions=test_predictions, 519 | test_logits=test_logits, 520 | task3_test_logits=task3_test_logits, 521 | ) 522 | 523 | return result 524 | 525 | def train_model(self) -> TrainingResult: 526 | for task in range(self.num_extra_clf_tasks + 1): 527 | val_f1 = self.validation_f1(task) 528 | self.best_val_f1[task] = val_f1 529 | self.clf_val_f1s[task].append(val_f1) 530 | 531 | # use combined bar for epochs and batches 532 | bar = progress_bar(range(self.config.num_epochs * self.batches_per_epoch)) 533 | bar_iter = iter(bar) 534 | for i_epoch in range(self.config.num_epochs): 535 | for _ in range(self.batches_per_epoch): 536 | next(bar_iter) 537 | self.train_batch(i_epoch) 538 | self.log_batch(bar) 539 | 540 | self.finalize_epoch() 541 | 542 | return self.get_result() 543 | -------------------------------------------------------------------------------- /mabe/loss.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import FloatTensor 8 | 9 | 10 | def onehot(indexes, N=None, ignore_index=None): 11 | """ 12 | Creates a one-representation of indexes with N possible entries 13 | if N is not specified, it will suit the maximum index appearing. 
14 | indexes is a long-tensor of indexes 15 | ignore_index will be zero in onehot representation 16 | """ 17 | if N is None: 18 | N = indexes.max() + 1 19 | sz = list(indexes.size()) 20 | output = indexes.new().byte().resize_(*sz, N).zero_() 21 | output.scatter_(-1, indexes.unsqueeze(-1), 1) 22 | if ignore_index is not None and ignore_index >= 0: 23 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0) 24 | return output 25 | 26 | 27 | def _is_long(x): 28 | if hasattr(x, "data"): 29 | x = x.data 30 | return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor) 31 | 32 | 33 | def cross_entropy( 34 | inputs, 35 | target, 36 | weight=None, 37 | ignore_index=-100, 38 | reduction="mean", 39 | smooth_eps=None, 40 | smooth_dist=None, 41 | from_logits=True, 42 | ): 43 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567""" 44 | smooth_eps = smooth_eps or 0 45 | 46 | # ordinary log-liklihood - use cross_entropy from nn 47 | if _is_long(target) and smooth_eps == 0: 48 | if from_logits: 49 | return F.cross_entropy( 50 | inputs, target, weight, ignore_index=ignore_index, reduction=reduction 51 | ) 52 | else: 53 | return F.nll_loss( 54 | inputs, target, weight, ignore_index=ignore_index, reduction=reduction 55 | ) 56 | 57 | if from_logits: 58 | # log-softmax of inputs 59 | lsm = F.log_softmax(inputs, dim=-1) 60 | else: 61 | lsm = inputs 62 | 63 | masked_indices = None 64 | num_classes = inputs.size(-1) 65 | 66 | if _is_long(target): 67 | masked_indices = target.eq(ignore_index) 68 | lsm = lsm[~masked_indices] 69 | target = target[~masked_indices] 70 | 71 | if smooth_eps > 0 and smooth_dist is not None: 72 | if _is_long(target): 73 | target = onehot(target, num_classes).type_as(inputs) 74 | if smooth_dist.dim() < target.dim(): 75 | smooth_dist = smooth_dist.unsqueeze(0) 76 | target.lerp_(smooth_dist, smooth_eps) 77 | 78 | if weight is not None: 79 | lsm = lsm * weight.unsqueeze(0) 80 | 81 | if _is_long(target): 82 | eps_sum = smooth_eps / num_classes 83 | eps_nll = 1.0 - eps_sum - smooth_eps 84 | likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) 85 | loss = -(eps_nll * likelihood + eps_sum * lsm.sum(-1)) 86 | else: 87 | loss = -(target * lsm).sum(-1) 88 | 89 | if reduction == "sum": 90 | loss = loss.sum() 91 | elif reduction == "mean": 92 | loss = loss.mean() 93 | 94 | return loss 95 | 96 | 97 | class CrossEntropyLoss(nn.CrossEntropyLoss): 98 | """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing""" 99 | 100 | def __init__( 101 | self, 102 | weight=None, 103 | ignore_index=-100, 104 | reduction="mean", 105 | smooth_eps=None, 106 | smooth_dist=None, 107 | from_logits=True, 108 | ): 109 | super().__init__(weight=weight, ignore_index=ignore_index, reduction=reduction) 110 | self.smooth_eps = smooth_eps 111 | self.smooth_dist = smooth_dist 112 | self.from_logits = from_logits 113 | 114 | def forward(self, input, target, smooth_dist=None): 115 | if smooth_dist is None: 116 | smooth_dist = self.smooth_dist 117 | return cross_entropy( 118 | input, 119 | target, 120 | weight=self.weight, 121 | ignore_index=self.ignore_index, 122 | reduction=self.reduction, 123 | smooth_eps=self.smooth_eps, 124 | smooth_dist=smooth_dist, 125 | from_logits=self.from_logits, 126 | ) 127 | 128 | 129 | def binary_cross_entropy( 130 | inputs, target, weight=None, reduction="mean", smooth_eps=None, from_logits=False 131 | ): 132 | """cross entropy loss, with support for label 
smoothing https://arxiv.org/abs/1512.00567""" 133 | smooth_eps = smooth_eps or 0 134 | if smooth_eps > 0: 135 | target = target.float() 136 | target.add_(smooth_eps).div_(2.0) 137 | if from_logits: 138 | return F.binary_cross_entropy_with_logits( 139 | inputs, target, weight=weight, reduction=reduction 140 | ) 141 | else: 142 | return F.binary_cross_entropy(inputs, target, weight=weight, reduction=reduction) 143 | 144 | 145 | def binary_cross_entropy_with_logits( 146 | inputs, target, weight=None, reduction="mean", smooth_eps=None, from_logits=True 147 | ): 148 | return binary_cross_entropy(inputs, target, weight, reduction, smooth_eps, from_logits) 149 | 150 | 151 | class BCELoss(nn.BCELoss): 152 | def __init__( 153 | self, 154 | weight=None, 155 | size_average=None, 156 | reduce=None, 157 | reduction="mean", 158 | smooth_eps=None, 159 | from_logits=False, 160 | ): 161 | super().__init__(weight, size_average, reduce, reduction) 162 | self.smooth_eps = smooth_eps 163 | self.from_logits = from_logits 164 | 165 | def forward(self, input, target): 166 | return binary_cross_entropy( 167 | input, 168 | target, 169 | weight=self.weight, 170 | reduction=self.reduction, 171 | smooth_eps=self.smooth_eps, 172 | from_logits=self.from_logits, 173 | ) 174 | 175 | 176 | class BCEWithLogitsLoss(BCELoss): 177 | def __init__( 178 | self, 179 | weight=None, 180 | size_average=None, 181 | reduce=None, 182 | reduction="mean", 183 | smooth_eps=None, 184 | from_logits=True, 185 | ): 186 | super().__init__( 187 | weight, size_average, reduce, reduction, smooth_eps=smooth_eps, from_logits=from_logits 188 | ) 189 | 190 | 191 | def range_to_anchors_and_delta(precision_range, num_anchors): 192 | """Calculates anchor points from precision range. 193 | Args: 194 | precision_range: an interval (a, b), where 0.0 <= a <= b <= 1.0 195 | num_anchors: int, number of equally spaced anchor points. 196 | Returns: 197 | precision_values: A `Tensor` of [num_anchors] equally spaced values 198 | in the interval precision_range. 199 | delta: The spacing between the values in precision_values. 200 | Raises: 201 | ValueError: If precision_range is invalid. 202 | """ 203 | # Validate precision_range. 204 | if len(precision_range) != 2: 205 | raise ValueError("length of precision_range (%d) must be 2" % len(precision_range)) 206 | if not 0 <= precision_range[0] <= precision_range[1] <= 1: 207 | raise ValueError( 208 | "precision values must follow 0 <= %f <= %f <= 1" 209 | % (precision_range[0], precision_range[1]) 210 | ) 211 | 212 | # Sets precision_values uniformly between min_precision and max_precision. 213 | precision_values = numpy.linspace( 214 | start=precision_range[0], stop=precision_range[1], num=num_anchors + 1 215 | )[1:] 216 | 217 | delta = (precision_range[1] - precision_range[0]) / num_anchors 218 | return FloatTensor(precision_values), delta 219 | 220 | 221 | def build_class_priors( 222 | labels, 223 | class_priors=None, 224 | weights=None, 225 | positive_pseudocount=1.0, 226 | negative_pseudocount=1.0, 227 | ): 228 | """build class priors, if necessary. 229 | For each class, the class priors are estimated as 230 | (P + sum_i w_i y_i) / (P + N + sum_i w_i), 231 | where y_i is the ith label, w_i is the ith weight, P is a pseudo-count of 232 | positive labels, and N is a pseudo-count of negative labels. 233 | Args: 234 | labels: A `Tensor` with shape [batch_size, num_classes]. 235 | Entries should be in [0, 1]. 
236 | class_priors: None, or a floating point `Tensor` of shape [C] 237 | containing the prior probability of each class (i.e. the fraction of the 238 | training data consisting of positive examples). If None, the class 239 | priors are computed from `targets` with a moving average. 240 | weights: `Tensor` of shape broadcastable to labels, [N, 1] or [N, C], 241 | where `N = batch_size`, C = num_classes` 242 | positive_pseudocount: Number of positive labels used to initialize the class 243 | priors. 244 | negative_pseudocount: Number of negative labels used to initialize the class 245 | priors. 246 | Returns: 247 | class_priors: A Tensor of shape [num_classes] consisting of the 248 | weighted class priors, after updating with moving average ops if created. 249 | """ 250 | if class_priors is not None: 251 | return class_priors 252 | 253 | N, C = labels.size() 254 | 255 | weighted_label_counts = (weights * labels).sum(0) 256 | 257 | weight_sum = weights.sum(0) 258 | 259 | class_priors = torch.div( 260 | weighted_label_counts + positive_pseudocount, 261 | weight_sum + positive_pseudocount + negative_pseudocount, 262 | ) 263 | 264 | return class_priors 265 | 266 | 267 | def weighted_hinge_loss(labels, logits, positive_weights=1.0, negative_weights=1.0): 268 | """ 269 | Args: 270 | labels: one-hot representation `Tensor` of shape broadcastable to logits 271 | logits: A `Tensor` of shape [N, C] or [N, C, K] 272 | positive_weights: Scalar or Tensor 273 | negative_weights: same shape as positive_weights 274 | Returns: 275 | 3D Tensor of shape [N, C, K], where K is length of positive weights 276 | or 2D Tensor of shape [N, C] 277 | """ 278 | positive_weights_is_tensor = torch.is_tensor(positive_weights) 279 | negative_weights_is_tensor = torch.is_tensor(negative_weights) 280 | 281 | # Validate positive_weights and negative_weights 282 | if positive_weights_is_tensor ^ negative_weights_is_tensor: 283 | raise ValueError( 284 | "positive_weights and negative_weights must be same shape Tensor " 285 | "or both be scalars. But positive_weight_is_tensor: %r, while " 286 | "negative_weight_is_tensor: %r" 287 | % (positive_weights_is_tensor, negative_weights_is_tensor) 288 | ) 289 | 290 | if positive_weights_is_tensor and (positive_weights.size() != negative_weights.size()): 291 | raise ValueError( 292 | "shape of positive_weights and negative_weights " 293 | "must be the same! " 294 | "shape of positive_weights is {0}, " 295 | "but shape of negative_weights is {1}" 296 | % {0: positive_weights.size(), 1: negative_weights.size()} 297 | ) 298 | 299 | # positive_term: Tensor [N, C] or [N, C, K] 300 | positive_term = (1 - logits).clamp(min=0) * labels 301 | negative_term = (1 + logits).clamp(min=0) * (1 - labels) 302 | 303 | if positive_weights_is_tensor and positive_term.dim() == 2: 304 | return ( 305 | positive_term.unsqueeze(-1) * positive_weights 306 | + negative_term.unsqueeze(-1) * negative_weights 307 | ) 308 | else: 309 | return positive_term * positive_weights + negative_term * negative_weights 310 | 311 | 312 | def true_positives_lower_bound(labels, logits, weights): 313 | """ 314 | true_positives_lower_bound defined in paper: 315 | "Scalable Learning of Non-Decomposable Objectives" 316 | Args: 317 | labels: A `Tensor` of shape broadcastable to logits. 318 | logits: A `Tensor` of shape [N, C] or [N, C, K]. 319 | If the third dimension is present, 320 | the lower bound is computed on each slice [:, :, k] independently. 
321 | weights: Per-example loss coefficients, with shape [N, 1] or [N, C] 322 | Returns: 323 | A `Tensor` of shape [C] or [C, K]. 324 | """ 325 | # A `Tensor` of shape [N, C] or [N, C, K] 326 | loss_on_positives = weighted_hinge_loss(labels, logits, negative_weights=0.0) 327 | 328 | weighted_loss_on_positives = ( 329 | weights.unsqueeze(-1) * (labels - loss_on_positives) 330 | if loss_on_positives.dim() > weights.dim() 331 | else weights * (labels - loss_on_positives) 332 | ) 333 | return weighted_loss_on_positives.sum(0) 334 | 335 | 336 | def false_postives_upper_bound(labels, logits, weights): 337 | """ 338 | false_positives_upper_bound defined in paper: 339 | "Scalable Learning of Non-Decomposable Objectives" 340 | Args: 341 | labels: A `Tensor` of shape broadcastable to logits. 342 | logits: A `Tensor` of shape [N, C] or [N, C, K]. 343 | If the third dimension is present, 344 | the lower bound is computed on each slice [:, :, k] independently. 345 | weights: Per-example loss coefficients, with shape broadcast-compatible with 346 | that of `labels`. i.e. [N, 1] or [N, C] 347 | Returns: 348 | A `Tensor` of shape [C] or [C, K]. 349 | """ 350 | loss_on_negatives = weighted_hinge_loss(labels, logits, positive_weights=0) 351 | 352 | weighted_loss_on_negatives = ( 353 | weights.unsqueeze(-1) * loss_on_negatives 354 | if loss_on_negatives.dim() > weights.dim() 355 | else weights * loss_on_negatives 356 | ) 357 | return weighted_loss_on_negatives.sum(0) 358 | 359 | 360 | class LagrangeMultiplier(torch.autograd.Function): 361 | @staticmethod 362 | def forward(ctx, input): 363 | ctx.save_for_backward(input) 364 | return input.clamp(min=0) 365 | 366 | @staticmethod 367 | def backward(ctx, grad_output): 368 | return grad_output.neg() 369 | 370 | 371 | def lagrange_multiplier(x): 372 | return LagrangeMultiplier.apply(x) 373 | 374 | 375 | class AUCPRHingeLoss(nn.Module): 376 | """area under the precision-recall curve loss, 377 | Reference: "Scalable Learning of Non-Decomposable Objectives", Section 5 \ 378 | TensorFlow Implementation: \ 379 | https://github.com/tensorflow/models/tree/master/research/global_objectives\ 380 | """ 381 | 382 | class Config: 383 | r""" 384 | Attributes: 385 | precision_range_lower (float): the lower range of precision values over 386 | which to compute AUC. Must be nonnegative, `\leq precision_range_upper`, 387 | and `leq 1.0`. 388 | precision_range_upper (float): the upper range of precision values over 389 | which to compute AUC. Must be nonnegative, `\geq precision_range_lower`, 390 | and `leq 1.0`. 391 | num_classes (int): number of classes(aka labels) 392 | num_anchors (int): The number of grid points used to approximate the 393 | Riemann sum. 394 | """ 395 | 396 | precision_range_lower: float = 0.0 397 | precision_range_upper: float = 1.0 398 | num_classes: int = 1 399 | num_anchors: int = 20 400 | 401 | def __init__(self, config, weights=None, *args, **kwargs): 402 | """Args: 403 | config: Config containing `precision_range_lower`, `precision_range_upper`, 404 | `num_classes`, `num_anchors` 405 | """ 406 | nn.Module.__init__(self) 407 | self.config = config 408 | 409 | self.num_classes = self.config.num_classes 410 | self.num_anchors = self.config.num_anchors 411 | self.precision_range = ( 412 | self.config.precision_range_lower, 413 | self.config.precision_range_upper, 414 | ) 415 | 416 | # Create precision anchor values and distance between anchors. 417 | # coresponding to [alpha_t] and [delta_t] in the paper. 
418 |         # precision_values: 1D `Tensor` of shape [K], where `K = num_anchors`
419 |         # delta: Scalar (since we use equal distance between anchors)
420 |         self.precision_values, self.delta = range_to_anchors_and_delta(
421 |             self.precision_range, self.num_anchors
422 |         )
423 | 
424 |         # notation is [b_k] in paper, Parameter of shape [C, K]
425 |         # where `C = number of classes` `K = num_anchors`
426 |         self.biases = nn.Parameter(
427 |             FloatTensor(self.config.num_classes, self.config.num_anchors).zero_()
428 |         )
429 |         self.lambdas = nn.Parameter(
430 |             FloatTensor(self.config.num_classes, self.config.num_anchors).data.fill_(1.0)
431 |         )
432 | 
433 |     def forward(self, logits, targets, reduce=True, size_average=True, weights=None):
434 |         """
435 |         Args:
436 |             logits: Variable :math:`(N, C)` where `C = number of classes`
437 |             targets: Variable :math:`(N)` where each value is
438 |                 `0 <= targets[i] <= C-1`
439 |             weights: Coefficients for the loss. Must be a `Tensor` of shape
440 |                 [N] or [N, C], where `N = batch_size`, `C = number of classes`.
441 |             size_average (bool, optional): By default, the losses are averaged
442 |                 over observations for each minibatch. However, if the field
443 |                 size_average is set to False, the losses are instead summed
444 |                 for each minibatch. Default: ``True``
445 |             reduce (bool, optional): By default, the losses are averaged or summed over
446 |                 observations for each minibatch depending on size_average. When reduce
447 |                 is False, returns a loss per input/target element instead and ignores
448 |                 size_average. Default: ``True``
449 |         """
450 |         C = 1 if logits.dim() == 1 else logits.size(1)
451 | 
452 |         if self.num_classes != C:
453 |             raise ValueError("num classes is %d while logits width is %d" % (self.num_classes, C))
454 | 
455 |         labels, weights = AUCPRHingeLoss._prepare_labels_weights(logits, targets, weights=weights)
456 | 
457 |         # Lagrange multipliers
458 |         # Lagrange multipliers are required to be nonnegative.
459 |         # Their gradient is reversed so that they are maximized
460 |         # (rather than minimized) by the optimizer.
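        # For reference, a sketch of the per-anchor surrogate assembled below
        # (notation as in the paper; logits are shifted by the per-anchor biases b_k,
        # and the hinge terms are additionally scaled by the per-example `weights`):
        #   loss_k = (1 + lambda_k * (1 - alpha_k)) * hinge_on_positives
        #            + lambda_k * alpha_k * hinge_on_negatives
        #            - lambda_k * (1 - alpha_k) * class_prior
        # The final loss is the Riemann sum of loss_k over anchors, scaled by `delta`
        # and normalized by the width of the precision range.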
461 |         # 1D `Tensor` of shape [K], where `K = num_anchors`
462 |         lambdas = lagrange_multiplier(self.lambdas)
463 |         # print("lambdas: {}".format(lambdas))
464 | 
465 |         precision_values = self.precision_values.to(logits.device)
466 | 
467 |         # A `Tensor` of shape [N, C, K]
468 |         hinge_loss = weighted_hinge_loss(
469 |             labels.unsqueeze(-1),
470 |             logits.unsqueeze(-1) - self.biases,
471 |             positive_weights=1.0 + lambdas * (1.0 - precision_values),
472 |             negative_weights=lambdas * precision_values,
473 |         )
474 | 
475 |         # 1D tensor of shape [C]
476 |         class_priors = build_class_priors(labels, weights=weights)
477 | 
478 |         # lambda_term: Tensor[C, K]
479 |         # according to paper, lambda_term = lambda * (1 - precision) * |Y^+|
480 |         # where |Y^+| is number of positive examples = N * class_priors
481 |         lambda_term = class_priors.unsqueeze(-1) * (lambdas * (1.0 - precision_values))
482 | 
483 |         per_anchor_loss = weights.unsqueeze(-1) * hinge_loss - lambda_term
484 | 
485 |         # Riemann sum over anchors, and normalized by precision range
486 |         # loss: Tensor[N, C]
487 |         loss = per_anchor_loss.sum(2) * self.delta
488 |         loss /= self.precision_range[1] - self.precision_range[0]
489 | 
490 |         if not reduce:
491 |             return loss
492 |         elif size_average:
493 |             return loss.mean()
494 |         else:
495 |             return loss.sum()
496 | 
497 |     @staticmethod
498 |     def _prepare_labels_weights(logits, targets, weights=None):
499 |         """
500 |         Args:
501 |             logits: Variable :math:`(N, C)` where `C = number of classes`
502 |             targets: Variable :math:`(N)` where each value is
503 |                 `0 <= targets[i] <= C-1`
504 |             weights: Coefficients for the loss. Must be a `Tensor` of shape
505 |                 [N] or [N, C], where `N = batch_size`, `C = number of classes`.
506 |         Returns:
507 |             labels: Tensor of shape [N, C], one-hot representation
508 |             weights: Tensor of shape broadcastable to labels
509 |         """
510 |         N, C = logits.size()
511 |         # Converts targets to one-hot representation.
Dim: [N, C] 512 | labels = ( 513 | FloatTensor(N, C).to(logits.device).zero_().scatter(1, targets.unsqueeze(1).data, 1) 514 | ) 515 | 516 | if weights is None: 517 | weights = FloatTensor(N).to(logits.device).fill_(1.0) 518 | 519 | if weights.dim() == 1: 520 | weights.unsqueeze_(-1) 521 | 522 | weights = weights.detach() 523 | 524 | return labels, weights 525 | 526 | 527 | def f1_loss(Y_batch_flat, Y_pred_flat, valids): 528 | valids = (Y_batch_flat >= 0).cpu().numpy() 529 | y_true = Y_batch_flat[valids] 530 | y_pred = torch.nn.functional.softmax(Y_pred_flat[valids], dim=-1) 531 | y_true = onehot(y_true, N=4) 532 | eps = 1e-8 533 | 534 | tp = (y_true * y_pred).sum(dim=0) 535 | # tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0) 536 | fp = ((1 - y_true) * y_pred).sum(dim=0) 537 | fn = (y_true * (1 - y_pred)).sum(dim=0) 538 | 539 | p = tp / (tp + fp + eps) 540 | r = tp / (tp + fn + eps) 541 | 542 | f1 = 2 * p * r / (p + r + eps) 543 | 544 | return 1 - f1.mean() 545 | 546 | 547 | @numba.njit 548 | def f1_score_for_label(y_true, y_pred, n_labels, i): 549 | yt = y_true == i 550 | yp = y_pred == i 551 | 552 | tp = np.sum(yt & yp) 553 | 554 | tpfp = np.sum(yp) 555 | tpfn = np.sum(yt) 556 | if tpfp == 0: 557 | precision = 0.0 558 | else: 559 | precision = tp / tpfp 560 | if tpfn == 0: 561 | print("[ERROR] label not found in y_true...") 562 | recall = 0.0 563 | else: 564 | recall = tp / tpfn 565 | 566 | if precision == 0.0 or recall == 0.0: 567 | f1 = 0.0 568 | else: 569 | f1 = 2 * precision * recall / (precision + recall) 570 | 571 | return f1 572 | 573 | 574 | @numba.njit 575 | def macro_f1_score(y_true, y_pred, n_labels): 576 | total_f1 = 0.0 577 | for i in range(n_labels): 578 | f1 = f1_score_for_label(y_true, y_pred, n_labels, i) 579 | total_f1 += f1 580 | return total_f1 / n_labels 581 | 582 | 583 | class ECELoss(nn.Module): 584 | """ 585 | Calculates the Expected Calibration Error of a model. 586 | (This isn't necessary for temperature scaling, just a cool metric). 587 | The input to this loss is the logits of a model, NOT the softmax scores. 588 | This divides the confidence outputs into equally-sized interval bins. 589 | In each bin, we compute the confidence gap: 590 | bin_gap = | avg_confidence_in_bin - accuracy_in_bin | 591 | We then return a weighted average of the gaps, based on the number 592 | of samples in each bin 593 | See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht. 594 | "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI. 595 | 2015. 
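    In symbols, with B equal-width confidence bins, n_b samples falling into bin b,
    and N samples in total:
        ECE = sum_b (n_b / N) * |acc(b) - conf(b)|
    where acc(b) is the accuracy and conf(b) the average confidence within bin b.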
596 | """ 597 | 598 | def __init__(self, n_bins=15): 599 | """ 600 | n_bins (int): number of confidence interval bins 601 | """ 602 | super().__init__() 603 | bin_boundaries = torch.linspace(0, 1, n_bins + 1) 604 | self.bin_lowers = bin_boundaries[:-1] 605 | self.bin_uppers = bin_boundaries[1:] 606 | 607 | def forward(self, logits, labels): 608 | softmaxes = F.softmax(logits, dim=1) 609 | confidences, predictions = torch.max(softmaxes, 1) 610 | accuracies = predictions.eq(labels) 611 | 612 | ece = torch.zeros(1, device=logits.device) 613 | for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers): 614 | # Calculated |confidence - accuracy| in each bin 615 | in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item()) 616 | prop_in_bin = in_bin.float().mean() 617 | if prop_in_bin.item() > 0: 618 | accuracy_in_bin = accuracies[in_bin].float().mean() 619 | avg_confidence_in_bin = confidences[in_bin].mean() 620 | ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin 621 | 622 | return ece 623 | -------------------------------------------------------------------------------- /mabe/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import numba 5 | import numpy as np 6 | import torch 7 | import torchtyping 8 | from fastprogress.fastprogress import force_console_behavior 9 | from torchtyping import TensorType # type: ignore 10 | 11 | import mabe 12 | import mabe.data 13 | from mabe.types import annotator, batch, behavior, channels, classes, time 14 | 15 | torchtyping.patch_typeguard() 16 | master_bar, progress_bar = force_console_behavior() 17 | 18 | 19 | class CausalConv1D(torch.nn.Module): 20 | def __init__( 21 | self, 22 | in_channels, 23 | out_channels, 24 | kernel_size=3, 25 | dilation=1, 26 | downsample=False, 27 | **conv1d_kwargs, 28 | ): 29 | super().__init__() 30 | self.pad = (kernel_size - 1) * dilation 31 | self.downsample = downsample 32 | self.conv = torch.nn.Conv1d( 33 | in_channels, 34 | out_channels, 35 | kernel_size, 36 | padding=self.pad, 37 | dilation=dilation, 38 | **conv1d_kwargs, 39 | ) 40 | 41 | def forward( 42 | self, x: TensorType["batch", "channels", "time", float] 43 | ) -> TensorType["batch", "channels", "time", float]: 44 | x = self.conv(x) 45 | 46 | # remove trailing padding 47 | x = x[:, :, self.pad : -self.pad] 48 | 49 | if self.downsample: 50 | x = x[:, :, :: self.conv.dilation[0]] 51 | 52 | return x 53 | 54 | 55 | class ResidualBlock(torch.nn.Module): 56 | def __init__(self, channels, padding, kernel_size, layer_norm=False, causal=False, **kwargs): 57 | super().__init__() 58 | 59 | # TODO: add dilation 60 | self.offset = (kernel_size - 1) // 2 61 | self.causal = causal 62 | 63 | self.conv1 = torch.nn.Conv1d(channels, channels, padding=padding, kernel_size=kernel_size) 64 | self.conv2 = torch.nn.Conv1d(channels, channels, padding=padding, kernel_size=kernel_size) 65 | 66 | self.layer_norm = layer_norm 67 | self.layer_norm1 = torch.nn.LayerNorm(channels) if layer_norm else None 68 | self.layer_norm2 = torch.nn.LayerNorm(channels) if layer_norm else None 69 | 70 | def forward( 71 | self, x: TensorType["batch", "channels", "time", float] 72 | ) -> TensorType["batch", "channels", "time", float]: 73 | x_ = x 74 | 75 | if self.layer_norm: 76 | x = x.transpose(2, 1) 77 | x = self.layer_norm1(x) 78 | x = x.transpose(2, 1) 79 | x = torch.nn.functional.leaky_relu(x) 80 | x = self.conv1(x) 81 | 82 | if self.layer_norm: 83 | x = x.transpose(2, 
1) 84 | x = self.layer_norm2(x) 85 | x = x.transpose(2, 1) 86 | x = torch.nn.functional.leaky_relu(x) 87 | x = self.conv2(x) 88 | 89 | if self.offset > 0: 90 | if self.causal: 91 | x_ = x_[:, :, 4 * self.offset :] 92 | else: 93 | x_ = x_[:, :, 2 * self.offset : -2 * self.offset] 94 | return x_ + x 95 | 96 | 97 | class Embedder(torch.nn.Module): 98 | def __init__( 99 | self, 100 | in_channels, 101 | out_channels, 102 | layer_norm=True, 103 | input_dropout=0, 104 | head_dropout=0, 105 | num_embedder_blocks=3, 106 | **kwargs, 107 | ): 108 | super().__init__() 109 | 110 | self.input_dropout = torch.nn.Dropout2d(p=input_dropout) 111 | self.head_dropout = torch.nn.Dropout2d(p=head_dropout) 112 | 113 | self.dropout = torch.nn.Dropout(input_dropout) 114 | self.head = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0) 115 | 116 | self.convolutions = torch.nn.Sequential( 117 | ResidualBlock(out_channels, kernel_size=1, padding=0, layer_norm=layer_norm), 118 | ResidualBlock(out_channels, kernel_size=1, padding=0, layer_norm=layer_norm), 119 | *[ 120 | ResidualBlock(out_channels, kernel_size=3, padding=0, layer_norm=layer_norm) 121 | for _ in range(num_embedder_blocks) 122 | ], 123 | ) 124 | 125 | def forward( 126 | self, x: TensorType["batch", "channels", "time", float] 127 | ) -> TensorType["batch", "channels", "time", float]: 128 | x = self.input_dropout(x) 129 | x = self.head(x) 130 | x = self.head_dropout(x) 131 | 132 | x = self.convolutions(x) 133 | 134 | return x 135 | 136 | 137 | class CausalResidualBlock(torch.nn.Module): 138 | def __init__(self, channels, layer_norm=False, **kwargs): 139 | super().__init__() 140 | 141 | self.conv1 = CausalConv1D(channels, channels, dilation=2) 142 | self.conv2 = CausalConv1D(channels, channels, dilation=1) 143 | 144 | self.layer_norm = layer_norm 145 | self.layer_norm1 = torch.nn.LayerNorm(channels) if layer_norm else None 146 | self.layer_norm2 = torch.nn.LayerNorm(channels) if layer_norm else None 147 | 148 | def forward( 149 | self, x: TensorType["batch", "channels", "time", float] 150 | ) -> TensorType["batch", "channels", "time", float]: 151 | x_ = x 152 | 153 | if self.layer_norm: 154 | x = x.transpose(2, 1) 155 | x = self.layer_norm1(x) 156 | x = x.transpose(2, 1) 157 | x = torch.nn.functional.leaky_relu(x) 158 | x = self.conv1(x) 159 | 160 | if self.layer_norm: 161 | x = x.transpose(2, 1) 162 | x = self.layer_norm2(x) 163 | x = x.transpose(2, 1) 164 | x = torch.nn.functional.leaky_relu(x) 165 | x = self.conv2(x) 166 | 167 | return x_ + x 168 | 169 | 170 | class Contexter(torch.nn.Module): 171 | def __init__(self, in_channels, out_channels, layer_norm=True, dropout=0, **kwargs): 172 | super().__init__() 173 | 174 | self.head = CausalConv1D(in_channels, out_channels) 175 | self.convolutions = torch.nn.Sequential( 176 | *[ 177 | ResidualBlock( 178 | out_channels, kernel_size=3, padding=0, layer_norm=layer_norm, causal=True 179 | ) 180 | for i in range(8) 181 | ] 182 | ) 183 | self.dropout = torch.nn.Dropout(dropout) 184 | 185 | def forward( 186 | self, x: TensorType["batch", "channels", "time", float] 187 | ) -> TensorType["batch", "channels", "time", float]: 188 | x = self.head(x) 189 | x = self.convolutions(x) 190 | x = self.dropout(x) 191 | 192 | return x 193 | 194 | 195 | def nt_xent_loss(predictions_to, has_value_to=None, cpc_tau=0.07, cpc_alpha=0.5): 196 | # Normalized Temperature-scaled Cross Entropy Loss 197 | # https://arxiv.org/abs/2002.05709v3 198 | 199 | if has_value_to is not None: 200 | predictions_to = 
predictions_to[has_value_to] 201 | labels = torch.where(has_value_to)[0] 202 | labels_diag = torch.zeros( 203 | predictions_to.shape[0], 204 | predictions_to.shape[1], 205 | device=predictions_to.device, 206 | dtype=torch.bool, 207 | ) 208 | labels_diag[torch.arange(predictions_to.shape[0]), labels] = True 209 | else: 210 | labels_diag = torch.diag( 211 | torch.ones(len(predictions_to), device=predictions_to.device) 212 | ).bool() 213 | 214 | neg = predictions_to[~labels_diag].reshape(len(predictions_to), -1) 215 | pos = predictions_to[labels_diag] 216 | 217 | neg_and_pos = torch.cat((neg, pos.unsqueeze(1)), dim=1) 218 | 219 | loss_pos = -pos / cpc_tau 220 | loss_neg = torch.logsumexp(neg_and_pos / cpc_tau, dim=1) 221 | 222 | loss = 2 * (cpc_alpha * loss_pos + (1.0 - cpc_alpha) * loss_neg) 223 | 224 | return loss 225 | 226 | 227 | # @numba.njit 228 | def subsample(X, X_, valid_lengths, combined_length, subsample_from_rand): 229 | for i, x in enumerate(X): 230 | from_idx = math.floor(subsample_from_rand[i] * max(0, len(x) - combined_length)) 231 | to_idx = min(len(x), from_idx + combined_length) 232 | valid_length = to_idx - from_idx 233 | sample = x[from_idx:to_idx] 234 | X_[i, :valid_length] = sample 235 | valid_lengths[i] = valid_length 236 | 237 | 238 | class ConvCPC(torch.nn.Module): 239 | def __init__( 240 | self, 241 | num_features, 242 | num_embeddings, 243 | num_context, 244 | num_ahead, 245 | num_ahead_subsampling, 246 | subsample_length, 247 | split_idx, 248 | num_embedder_blocks=3, 249 | input_dropout=0, 250 | head_dropout=0, 251 | dropout=0, 252 | num_extra_features=0, 253 | **kwargs, 254 | ): 255 | super().__init__() 256 | 257 | self.num_embeddings = num_embeddings 258 | self.num_context = num_context 259 | self.num_features = num_features 260 | self.num_extra_features = num_extra_features 261 | self.split_idx = split_idx 262 | 263 | self.embedder = Embedder( 264 | num_features, 265 | num_embeddings, 266 | num_embedder_blocks=num_embedder_blocks, 267 | input_dropout=input_dropout, 268 | head_dropout=head_dropout, 269 | dropout=dropout, 270 | ) 271 | self.contexter = Contexter(num_embeddings, num_context, dropout=dropout) 272 | 273 | self.projections = torch.nn.ModuleList( 274 | [ 275 | torch.nn.Linear(num_context, num_embeddings, bias=False) 276 | for _ in range(num_ahead // num_ahead_subsampling) 277 | ] 278 | ) 279 | 280 | self.num_ahead = num_ahead 281 | self.num_ahead_subsampling = num_ahead_subsampling 282 | self.subsample_length = subsample_length 283 | 284 | for m in self.modules(): 285 | if isinstance(m, torch.nn.Conv1d): 286 | torch.nn.init.kaiming_normal_(m.weight, mode="fan_out") 287 | torch.nn.init.zeros_(m.bias) 288 | elif isinstance(m, torch.nn.LayerNorm): 289 | torch.nn.init.constant_(m.weight, 1) 290 | torch.nn.init.constant_(m.bias, 0) 291 | elif isinstance(m, torch.nn.Linear): 292 | torch.nn.init.kaiming_normal_(m.weight, mode="fan_out") 293 | if m.bias is not None: 294 | torch.nn.init.zeros_(m.bias) 295 | 296 | self.crop_pre = None 297 | self.crop_post = None 298 | 299 | def get_crops(self, device): 300 | if self.crop_pre is None: 301 | with torch.no_grad(): 302 | initial_length = 128 303 | x_dummy = torch.zeros(1, self.num_features, initial_length, device=device) 304 | x_postemb = self.embedder(x_dummy) 305 | 306 | crop = (initial_length - x_postemb.shape[-1]) // 2 307 | self.crop_pre = crop 308 | self.crop_post = crop 309 | 310 | x_postcontext = self.apply_contexter(x_postemb, device) 311 | 312 | crop = x_postemb.shape[-1] - x_postcontext.shape[-1] 313 | 
self.crop_pre += crop
314 | 
315 |         return self.crop_pre, self.crop_post
316 | 
317 |     def random_subsample(
318 |         self,
319 |         batch: mabe.data.TrainingBatch,
320 |         subsample_from_rand: TensorType["batch"] = None,
321 |     ) -> Tuple[
322 |         TensorType["batch", "time", "channels", float],
323 |         TensorType["batch", "time", "channels", float],
324 |         TensorType["batch", "time", int],
325 |         TensorType["batch", "time", "behavior", "classes", float],
326 |         TensorType["batch", "time", "annotator", "classes", float],
327 |         TensorType["batch", "time", int],
328 |         TensorType["batch", int],
329 |     ]:
330 |         batch_size = len(batch.X)
331 |         if subsample_from_rand is None:
332 |             subsample_from_rand = np.random.rand(batch_size)
333 |         combined_length = self.subsample_length + self.num_ahead
334 | 
335 |         X_ = np.zeros((batch_size, combined_length, self.num_features), dtype=np.float32)
336 |         valid_lengths = np.zeros(batch_size, dtype=int)
337 |         subsample(batch.X, X_, valid_lengths, combined_length, subsample_from_rand)
338 | 
339 |         X_extra_ = None
340 |         if batch.X_extra is not None:
341 |             X_extra_ = np.zeros(
342 |                 (batch_size, combined_length, self.num_extra_features), dtype=np.float32
343 |             )
344 |             subsample(
345 |                 batch.X_extra,
346 |                 X_extra_,
347 |                 valid_lengths,
348 |                 combined_length,
349 |                 subsample_from_rand,
350 |             )
351 | 
352 |         Y_ = None
353 |         if batch.Y is not None:
354 |             Y_ = np.full((batch_size, combined_length), -1, dtype=int)
355 |             subsample(batch.Y, Y_, valid_lengths, combined_length, subsample_from_rand)
356 | 
357 |         Y_dark_behaviors_ = None
358 |         if batch.Y_dark_behaviors is not None:
359 |             num_dark_behaviors = batch.Y_dark_behaviors[0].shape[1]
360 |             num_dark_behaviors_classes = batch.Y_dark_behaviors[0].shape[2]
361 |             Y_dark_behaviors_ = np.full(
362 |                 (batch_size, combined_length, num_dark_behaviors, num_dark_behaviors_classes),
363 |                 -1,
364 |                 dtype=np.float32,
365 |             )
366 |             subsample(
367 |                 batch.Y_dark_behaviors,
368 |                 Y_dark_behaviors_,
369 |                 valid_lengths,
370 |                 combined_length,
371 |                 subsample_from_rand,
372 |             )
373 | 
374 |         Y_dark_annotators_ = None
375 |         if batch.Y_dark_annotators is not None:
376 |             num_dark_annotators = batch.Y_dark_annotators[0].shape[1]
377 |             num_dark_annotators_classes = batch.Y_dark_annotators[0].shape[2]
378 |             Y_dark_annotators_ = np.full(
379 |                 (batch_size, combined_length, num_dark_annotators, num_dark_annotators_classes),
380 |                 -1,
381 |                 dtype=np.float32,
382 |             )
383 |             subsample(
384 |                 batch.Y_dark_annotators,
385 |                 Y_dark_annotators_,
386 |                 valid_lengths,
387 |                 combined_length,
388 |                 subsample_from_rand,
389 |             )
390 | 
391 |         annotators_ = None
392 |         annotators_ = np.full((batch_size, combined_length), -1, dtype=int)
393 |         subsample(
394 |             batch.annotators,
395 |             annotators_,
396 |             valid_lengths,
397 |             combined_length,
398 |             subsample_from_rand,
399 |         )
400 | 
401 |         return X_, X_extra_, Y_, Y_dark_behaviors_, Y_dark_annotators_, annotators_, valid_lengths
402 | 
403 |     def subsample_and_pad(
404 |         self,
405 |         batch: mabe.data.TrainingBatch,
406 |         device: str = "cpu",
407 |         subsample_from_rand: TensorType["batch", int] = None,
408 |     ) -> Tuple[
409 |         TensorType["batch", "time", "channels", float],
410 |         TensorType["batch", "time", "channels", float],
411 |         TensorType["batch", "time", int],
412 |         TensorType["batch", "time", "behavior", "classes", float],
413 |         TensorType["batch", "time", "annotator", "classes", float],
414 |         TensorType["batch", "time", int],
415 |         TensorType["batch", int],
416 |     ]:
417 |         (
418 |             X,
419 |             X_extra,
420 |             Y,
421 |             Y_dark_behaviors,
422 | Y_dark_annotators, 423 | annotators, 424 | valid_lengths, 425 | ) = self.random_subsample(batch, subsample_from_rand=subsample_from_rand) 426 | X = torch.transpose(torch.from_numpy(X), 2, 1).to(device, non_blocking=True) 427 | 428 | if X_extra is not None: 429 | X_extra = torch.from_numpy(X_extra).to(device, non_blocking=True) 430 | 431 | if Y is not None: 432 | Y = torch.from_numpy(Y).to(device, non_blocking=True) 433 | 434 | if Y_dark_behaviors is not None: 435 | Y_dark_behaviors = torch.from_numpy(Y_dark_behaviors).to(device, non_blocking=True) 436 | 437 | if Y_dark_annotators is not None: 438 | Y_dark_annotators = torch.from_numpy(Y_dark_annotators).to(device, non_blocking=True) 439 | 440 | return X, X_extra, Y, Y_dark_behaviors, Y_dark_annotators, annotators, valid_lengths 441 | 442 | def cpc_loss( 443 | self, 444 | X_emb: TensorType["batch", "time", "channels", float], 445 | contexts: TensorType["batch", "time", "channels", float], 446 | from_timesteps: TensorType["batch", int], 447 | valid_lengths: TensorType["batch", int], 448 | ) -> TensorType[float]: 449 | batch_size = len(contexts) 450 | 451 | from_timesteps_reduced = from_timesteps 452 | contexts_from = torch.stack( 453 | [contexts[i, from_timesteps_reduced[i]] for i in range(batch_size)] 454 | ) 455 | 456 | embeddings_projections = torch.stack( 457 | list(map(lambda p: p(contexts_from), self.projections)) 458 | ) 459 | embeddings_projections = torch.nn.functional.normalize(embeddings_projections, p=2, dim=-1) 460 | 461 | ahead = torch.arange(self.num_ahead_subsampling, self.num_ahead, self.num_ahead_subsampling) 462 | to_timesteps = torch.from_numpy(from_timesteps_reduced)[:, None] + ahead 463 | 464 | X_emb_t = X_emb.transpose(2, 1) 465 | embeddings_to = X_emb_t[ 466 | torch.arange(batch_size)[:, None].repeat(1, len(ahead)).flatten(), 467 | to_timesteps.flatten(), 468 | ].reshape(batch_size, len(ahead), -1) 469 | embeddings_to = torch.nn.functional.normalize(embeddings_to, p=2, dim=-1) 470 | 471 | embeddings_projections = torch.stack( 472 | list(map(lambda p: p(contexts_from), self.projections[1:])) 473 | ) 474 | embeddings_projections = torch.nn.functional.normalize(embeddings_projections, p=2, dim=-1) 475 | 476 | # assume both are l2-normalized -> cosine similarity 477 | predictions_to = torch.einsum("tac,btc->tab", embeddings_projections, embeddings_to) 478 | 479 | batch_loss = [] 480 | for idx, ahead in enumerate( 481 | range(self.num_ahead_subsampling, self.num_ahead, self.num_ahead_subsampling) 482 | ): 483 | predictions_to_ahead = predictions_to[idx] 484 | 485 | has_value_to = torch.from_numpy((from_timesteps_reduced + ahead) < valid_lengths).to( 486 | X_emb.device, non_blocking=True 487 | ) 488 | loss = nt_xent_loss(predictions_to_ahead, has_value_to) 489 | loss_mean = torch.mean(loss) 490 | 491 | batch_loss.append(loss_mean) 492 | 493 | aggregated_batch_loss = sum(batch_loss) / len(batch_loss) 494 | 495 | return aggregated_batch_loss 496 | 497 | def apply_contexter(self, X_emb, device="cpu"): 498 | contexts = self.contexter(X_emb) 499 | 500 | return contexts 501 | 502 | def get_contexts(self, X, device="cpu"): 503 | contexts = [] 504 | with torch.no_grad(): 505 | bar = progress_bar(range(len(X))) 506 | for idx in bar: 507 | x = X[idx].astype(np.float32) 508 | 509 | x = torch.transpose(torch.from_numpy(x[None, :, :]), 2, 1).to( 510 | device, non_blocking=True 511 | ) 512 | x_emb = self.embedder(x) 513 | c = self.apply_contexter(x_emb, device) 514 | 515 | contexts.append(c.cpu().numpy()) 516 | 517 | contexts = 
np.concatenate(contexts) 518 | return contexts 519 | 520 | def add_padding_to_batch(self, batch: mabe.data.TrainingBatch, device: str): 521 | crop_pre, crop_post = self.get_crops(device) 522 | 523 | def concat_padding_1d(sequences): 524 | l = numba.typed.List() 525 | for x in sequences: 526 | l.append( 527 | np.concatenate( 528 | ( 529 | np.zeros((crop_pre,), dtype=x.dtype), 530 | x, 531 | np.zeros((crop_post,), dtype=x.dtype), 532 | ), 533 | axis=-1, 534 | ) 535 | ) 536 | return l 537 | 538 | def concat_padding_nd(sequences): 539 | l = numba.typed.List() 540 | for x in sequences: 541 | l.append( 542 | np.concatenate( 543 | ( 544 | np.zeros((crop_pre, *x.shape[1:]), dtype=x.dtype), 545 | x, 546 | np.zeros((crop_post, *x.shape[1:]), dtype=x.dtype), 547 | ), 548 | axis=0, 549 | ) 550 | ) 551 | return l 552 | 553 | batch.X = concat_padding_nd(batch.X) 554 | batch.X_extra = concat_padding_nd(batch.X_extra) 555 | batch.Y = concat_padding_1d(batch.Y) 556 | if batch.Y_dark_behaviors is not None: 557 | batch.Y_dark_behaviors = concat_padding_nd(batch.Y_dark_behaviors) 558 | if batch.Y_dark_annotators is not None: 559 | batch.Y_dark_annotators = concat_padding_nd(batch.Y_dark_annotators) 560 | batch.annotators = concat_padding_1d(batch.annotators) 561 | 562 | return batch 563 | 564 | def forward( 565 | self, 566 | batch: mabe.data.TrainingBatch, 567 | device: str = "cpu", 568 | with_loss: bool = False, 569 | min_from_timesteps: int = 0, 570 | subsample_from_rand: TensorType["batch", int] = None, 571 | ): 572 | # TODO: refactor subsampling 573 | batch_size = len(batch.X) 574 | batch = self.add_padding_to_batch(batch, device) 575 | 576 | ( 577 | X_batch_samples, 578 | X_extra_batch_samples, 579 | Y_batch_samples, 580 | Y_batch_dark_behaviors_samples, 581 | Y_batch_dark_annotators_samples, 582 | annotators_samples, 583 | valid_lengths, 584 | ) = self.subsample_and_pad(batch, device, subsample_from_rand=subsample_from_rand) 585 | 586 | X_emb = self.embedder(X_batch_samples) 587 | crop = (Y_batch_samples.shape[-1] - X_emb.shape[-1]) // 2 588 | valid_lengths -= crop 589 | if X_extra_batch_samples is not None: 590 | X_extra_batch_samples = X_extra_batch_samples[:, crop:-crop] 591 | if Y_batch_samples is not None: 592 | Y_batch_samples = Y_batch_samples[:, crop:-crop] 593 | if Y_batch_dark_behaviors_samples is not None: 594 | Y_batch_dark_behaviors_samples = Y_batch_dark_behaviors_samples[:, crop:-crop] 595 | if Y_batch_dark_annotators_samples is not None: 596 | Y_batch_dark_annotators_samples = Y_batch_dark_annotators_samples[:, crop:-crop] 597 | annotators_samples = annotators_samples[:, crop:-crop] 598 | 599 | max_from_timesteps = [x.shape[-1] - (self.num_ahead + crop) for x in X_emb] 600 | from_timesteps = np.random.randint( 601 | low=min_from_timesteps, 602 | high=max_from_timesteps, 603 | size=batch_size, 604 | ) 605 | 606 | contexts = self.apply_contexter(X_emb, device).transpose(2, 1) 607 | crop = Y_batch_samples.shape[-1] - contexts.shape[-2] 608 | valid_lengths -= crop 609 | if X_extra_batch_samples is not None: 610 | X_extra_batch_samples = X_extra_batch_samples[:, crop:] 611 | if Y_batch_samples is not None: 612 | Y_batch_samples = Y_batch_samples[:, crop:] 613 | if Y_batch_dark_behaviors_samples is not None: 614 | Y_batch_dark_behaviors_samples = Y_batch_dark_behaviors_samples[:, crop:] 615 | if Y_batch_dark_annotators_samples is not None: 616 | Y_batch_dark_annotators_samples = Y_batch_dark_annotators_samples[:, crop:] 617 | annotators_samples = annotators_samples[:, crop:] 618 | 619 | if 
with_loss: 620 | batch_loss = self.cpc_loss(X_emb, contexts, from_timesteps, valid_lengths) 621 | return ( 622 | contexts, 623 | X_extra_batch_samples, 624 | Y_batch_samples, 625 | Y_batch_dark_behaviors_samples, 626 | Y_batch_dark_annotators_samples, 627 | annotators_samples, 628 | batch_loss, 629 | ) 630 | else: 631 | return ( 632 | contexts, 633 | X_extra_batch_samples, 634 | Y_batch_samples, 635 | Y_batch_dark_behaviors_samples, 636 | Y_batch_dark_annotators_samples, 637 | annotators_samples, 638 | ) 639 | 640 | 641 | class MultiAnnotatorLogisticRegressionHead(torch.nn.Module): 642 | def __init__( 643 | self, 644 | num_features, 645 | num_annotators, 646 | num_extra_features=0, 647 | num_extra_clf_tasks=0, 648 | num_classes=4, 649 | num_extra_classes=2, 650 | ): 651 | super().__init__() 652 | 653 | annotator_embedding_size = num_annotators 654 | 655 | self.num_features_combined = num_features + num_extra_features 656 | 657 | self.ln = torch.nn.LayerNorm(num_features) 658 | 659 | self.logregs = [torch.nn.Linear(self.num_features_combined, num_classes)] 660 | for i in range(num_extra_clf_tasks): 661 | self.logregs.append(torch.nn.Linear(self.num_features_combined, num_extra_classes)) 662 | self.logregs = torch.nn.ModuleList(self.logregs) 663 | 664 | self.residual = torch.nn.Sequential( 665 | torch.nn.Linear( 666 | self.num_features_combined + annotator_embedding_size, 667 | self.num_features_combined, 668 | ), 669 | torch.nn.LeakyReLU(), 670 | torch.nn.Linear( 671 | self.num_features_combined, 672 | self.num_features_combined, 673 | ), 674 | torch.nn.LeakyReLU(), 675 | ) 676 | 677 | self.embedding = torch.nn.Parameter((torch.diag(torch.ones(num_annotators))).detach()) 678 | self.register_parameter(name="embedding", param=self.embedding) 679 | 680 | for m in self.modules(): 681 | if isinstance(m, torch.nn.LayerNorm): 682 | torch.nn.init.constant_(m.weight, 1) 683 | torch.nn.init.constant_(m.bias, 0) 684 | elif isinstance(m, torch.nn.Linear): 685 | torch.nn.init.normal_(m.weight, std=0.01) 686 | if m.bias is not None: 687 | torch.nn.init.zeros_(m.bias) 688 | 689 | def forward(self, x, x_extra, annotators, clf_task): 690 | x = self.ln(x) 691 | a = self.embedding[annotators.flatten()].reshape(*annotators.shape, -1) 692 | 693 | x_ = torch.cat((x, x_extra, a), dim=-1) 694 | x_ = self.residual(x_) 695 | 696 | x = torch.cat((x, x_extra), dim=-1) + x_ 697 | 698 | logits = self.logregs[clf_task](x) 699 | 700 | return logits 701 | --------------------------------------------------------------------------------