├── shift_happens ├── __init__.py ├── label_shift │ ├── README.md │ └── __init__.py ├── gendist │ ├── gendist │ │ ├── __init__.py │ │ ├── models.py │ │ ├── processing.py │ │ └── training.py │ ├── half-moons.mp4 │ ├── setup.py │ ├── experiments │ │ ├── multiprocess_test.py │ │ ├── mnist_zero_shot_test.py │ │ ├── mnist_rotation_meta.py │ │ └── mnist_rotation_data.py │ └── notebooks │ │ ├── dojax.py │ │ └── 013-metadata-inference.ipynb └── imagenet_flax │ ├── requirements.txt │ ├── logs │ └── lenet │ │ ├── checkpoint_1500 │ │ ├── checkpoint_3000 │ │ └── checkpoint_4500 │ ├── README.md │ ├── agents │ ├── mlp.py │ ├── lenet.py │ ├── models.py │ ├── nearest_centroid_classifier.py │ ├── nearest_centroid_classifier_old.py │ ├── models_test.py │ └── resnet.py │ ├── configs │ ├── default.py │ ├── resnet20_config.py │ ├── lenet_config.py │ └── tpu_dynamic.py │ ├── utils.py │ ├── gdumb.py │ ├── ncc_demo.py │ ├── imagenet_fake_data_benchmark.py │ ├── main.py │ ├── train_test.py │ ├── environment.py │ ├── gdumb_old.py │ └── train.py ├── README.md ├── LICENSE └── .gitignore /shift_happens/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /shift_happens/label_shift/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /shift_happens/label_shift/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # shift-happens 2 | Research code for ML with distribution shift 3 | -------------------------------------------------------------------------------- /shift_happens/gendist/gendist/__init__.py: 
-------------------------------------------------------------------------------- 1 | from . import training, processing, models -------------------------------------------------------------------------------- /shift_happens/gendist/half-moons.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/shift-happens/main/shift_happens/gendist/half-moons.mp4 -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/requirements.txt: -------------------------------------------------------------------------------- 1 | clu 2 | ml-collections 3 | optax 4 | tensorflow 5 | tensorflow-datasets 6 | imax -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/logs/lenet/checkpoint_1500: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/shift-happens/main/shift_happens/imagenet_flax/logs/lenet/checkpoint_1500 -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/logs/lenet/checkpoint_3000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/shift-happens/main/shift_happens/imagenet_flax/logs/lenet/checkpoint_3000 -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/logs/lenet/checkpoint_4500: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/shift-happens/main/shift_happens/imagenet_flax/logs/lenet/checkpoint_4500 -------------------------------------------------------------------------------- /shift_happens/gendist/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="gendist", 
5 | packages=find_packages(), 6 | install_requires=[ 7 | "jaxlib", 8 | "jax" 9 | ] 10 | ) -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/README.md: -------------------------------------------------------------------------------- 1 | # Continual Learning 2 | 3 | There are three different models 4 | - ResNet 5 | - LeNet 6 | - Nearest Centroid Classifier 7 | 8 | ### Running the Training of Deep Models 9 | You can train a single model by 10 | 11 | ```shell 12 | python main.py --workdir=./logs/lenet --config=configs/lenet_config.py 13 | ``` 14 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/agents/mlp.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from typing import Any, Sequence 4 | from flax import linen as nn 5 | 6 | ModuleDef = Any 7 | 8 | 9 | class MLP(nn.Module): 10 | layer_dims: Sequence[int] 11 | num_classes: int 12 | dtype: Any = jnp.float32 13 | 14 | @nn.compact 15 | def __call__(self, x, train: bool = True): 16 | x = x.reshape((x.shape[0], -1)) 17 | for layer_dim in self.layer_dims: 18 | x = nn.Dense(features=layer_dim, dtype=self.dtype)(x) 19 | x = nn.relu(x) 20 | x = nn.Dense(self.num_classes, dtype=self.dtype)(x) 21 | return x 22 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/configs/default.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | """Get the default hyperparameter configuration.""" 6 | config = ml_collections.ConfigDict() 7 | 8 | # As defined in the `models` module. 
9 | config.model = 'ResNet18' 10 | # `name` argument of tensorflow_datasets.builder() 11 | config.dataset = 'cifar10' 12 | config.image_size = -1 13 | 14 | config.learning_rate = 1e-7 15 | config.batch_size = -1 16 | 17 | config.num_epochs = 1 18 | config.log_every_steps = 1 19 | 20 | config.momentum_decay = 0.9 21 | config.weight_decay = 100. 22 | 23 | config.cache = False 24 | config.half_precision = False 25 | 26 | # If num_train_steps==-1 then the number of training steps is calculated from 27 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 28 | config.num_train_steps = 1 29 | config.steps_per_eval = 1 30 | 31 | return config 32 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/configs/resnet20_config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | """Get the default hyperparameter configuration.""" 6 | config = ml_collections.ConfigDict() 7 | 8 | # As defined in the `models` module. 9 | config.model = 'ResNet18' 10 | # `name` argument of tensorflow_datasets.builder() 11 | config.dataset = 'cifar10' 12 | config.image_size = -1 13 | 14 | config.learning_rate = 1e-7 15 | config.batch_size = 80 16 | 17 | config.num_epochs = 300 18 | config.log_every_steps = 100 19 | 20 | config.momentum_decay = 0.9 21 | config.weight_decay = 100. 22 | 23 | config.cache = False 24 | config.half_precision = False 25 | 26 | # If num_train_steps==-1 then the number of training steps is calculated from 27 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 
28 | config.num_train_steps = -1 29 | config.steps_per_eval = -1 30 | return config 31 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/configs/lenet_config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(seed=2): 5 | """Get the default hyperparameter configuration.""" 6 | config = ml_collections.ConfigDict() 7 | 8 | # As defined in the `models` module. 9 | config.model = 'LeNet' 10 | # `name` argument of tensorflow_datasets.builder() 11 | config.dataset = 'mnist' 12 | config.seed = seed 13 | config.image_size = -1 14 | 15 | config.learning_rate = 0.001 16 | config.batch_size = 80 17 | config.train_freq = 5 18 | 19 | config.num_epochs = 200 20 | config.log_every_steps = 100 21 | 22 | config.momentum_decay = 0.9 23 | config.weight_decay = 0.001 24 | 25 | config.cache = False 26 | config.half_precision = False 27 | 28 | # If num_train_steps==-1 then the number of training steps is calculated from 29 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 
30 | config.num_train_steps = -1 31 | config.steps_per_eval = -1 32 | return config 33 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/agents/lenet.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from typing import Any, Callable 4 | from flax import linen as nn 5 | 6 | ModuleDef = Any 7 | 8 | 9 | class LeNet5(nn.Module): 10 | num_classes: int 11 | act: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 12 | dtype: Any = jnp.float32 13 | 14 | @nn.compact 15 | def __call__(self, x, train: bool = True): 16 | """Network inspired by LeNet-5.""" 17 | x = self.act(nn.Conv(features=6, kernel_size=(5, 5), padding="SAME", dtype=self.dtype)(x)) 18 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID") 19 | x = self.act(nn.Conv(features=16, kernel_size=(5, 5), padding="VALID", dtype=self.dtype)(x)) 20 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID") 21 | x = x.reshape((x.shape[0], -1)) 22 | x = self.act(nn.Dense(features=120, dtype=self.dtype)(x)) 23 | x = self.act(nn.Dense(features=84, dtype=self.dtype)(x)) 24 | x = nn.Dense(features=self.num_classes, dtype=self.dtype)(x) 25 | return x 26 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/agents/models.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from functools import partial 3 | 4 | from agents.mlp import MLP 5 | from agents.lenet import LeNet5 6 | from agents.resnet import ResNet, BottleneckResNetBlock, ResNetBlock 7 | 8 | ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], 9 | block_cls=ResNetBlock) 10 | ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], 11 | block_cls=ResNetBlock) 12 | ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], 13 | block_cls=BottleneckResNetBlock) 14 | ResNet101 = partial(ResNet, stage_sizes=[3, 4, 
23, 3], 15 | block_cls=BottleneckResNetBlock) 16 | ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], 17 | block_cls=BottleneckResNetBlock) 18 | ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], 19 | block_cls=BottleneckResNetBlock) 20 | LeNet = partial(LeNet5, act=jax.nn.relu) 21 | 22 | mlp = partial(MLP, layer_dims=[256, 256]) 23 | 24 | # Used for testing only. 25 | _ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from jax.nn import log_softmax 4 | 5 | from collections import defaultdict 6 | 7 | 8 | def get_accuracy(y, logits): 9 | preds = jnp.argmax(logits, axis=-1) 10 | return jnp.mean(y == preds) 11 | 12 | 13 | def eval_step(env, state, predict_fn, prior=None, rotations=None): 14 | metrics = defaultdict(lambda: jnp.array([])) 15 | if prior is None: 16 | mean, *_ = state 17 | num_classes = len(mean) 18 | prior = log_softmax(jnp.ones((num_classes,))) 19 | 20 | if rotations is None: 21 | rotations = jnp.argwhere(env.rot_mat.sum(axis=0) > 0).flatten().tolist() 22 | 23 | for batch in env.test_data: 24 | input, label = jnp.array(batch['image']), jnp.array(batch['label']) 25 | if input.shape[-1] == 1: 26 | input = jnp.repeat(input, axis=-1, repeats=3) 27 | 28 | for degrees in rotations: 29 | if degrees: 30 | x = env.apply(input, degrees) 31 | else: 32 | x = input 33 | logits = predict_fn(state, x, prior) 34 | metrics[degrees] = jnp.append(metrics[degrees], get_accuracy(label, logits)) 35 | 36 | return {k: jnp.mean(v) for k, v in metrics.items()} 37 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/gdumb.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import ops, random 3 | 4 | from collections import defaultdict 5 | 6 | 7 | def init_sampler(): 8 | counts = defaultdict(int) 9 | dataset = (jnp.empty([]), jnp.empty([])) 10 | return counts, dataset 11 | 12 | 13 | def sample(state, carry, memory_size=200): 14 | counts, dataset = state 15 | key, x, y = carry 16 | n = len(counts.keys()) 17 | inputs, labels = dataset 18 | nperclass = memory_size / (n if n > 0 else 1) 19 | 20 | if n == 0: 21 | counts[y.item()] = 
counts[y.item()] + 1 22 | return counts, (x[None, ...], y) 23 | 24 | if y.item() not in counts or counts[y.item()] < nperclass: 25 | if len(inputs) >= memory_size: 26 | c_max = max(counts, key=counts.get) 27 | sample_key, key = random.split(key) 28 | indices = jnp.argwhere(labels == c_max) 29 | row = random.choice(sample_key, indices, shape=(1,)) 30 | inputs = ops.index_update(inputs, jnp.index_exp[row], x) 31 | labels = ops.index_update(labels, jnp.index_exp[row], y) 32 | counts[c_max] -= 1 33 | else: 34 | inputs = jnp.vstack([inputs, x[None, ...]]) 35 | labels = jnp.append(labels, y) 36 | 37 | counts[y.item()] += 1 38 | 39 | dataset = (inputs, labels) 40 | 41 | return counts, dataset 42 | -------------------------------------------------------------------------------- /shift_happens/gendist/experiments/multiprocess_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of the multiprocessing module in gendist for 3 | transforming multiple images. 4 | 5 | In this example, we transform each image in a 6 | dataset by a different angle. The first image 7 | is rotated by 0 degrees and the last image 8 | is rotated by 360 degrees. 
9 | """ 10 | 11 | import gendist 12 | import torchvision 13 | import numpy as np 14 | from augly import image 15 | 16 | def processor(X, angle): 17 | X_shift = image.aug_np_wrapper(X, image.rotate, degrees=angle) 18 | size_im = X_shift.shape[0] 19 | size_pad = (28 - size_im) // 2 20 | size_pad_mod = (28 - size_im) % 2 21 | X_shift = np.pad(X_shift, (size_pad, size_pad + size_pad_mod)) 22 | 23 | return X_shift 24 | 25 | 26 | if __name__ == "__main__": 27 | from time import time 28 | 29 | init_time = time() 30 | mnist_train = torchvision.datasets.MNIST(root="./data", train=True, download=True) 31 | images = np.array(mnist_train.data) / 255.0 32 | 33 | n_configs = len(images) 34 | degrees = np.linspace(0, 360, n_configs) 35 | configs = [{"angle": float(angle)} for angle in degrees] 36 | process = gendist.processing.Factory(processor) 37 | images_proc = process(images, configs, n_processes=90) 38 | end_time = time() 39 | 40 | print(f"Time elapsed: {end_time - init_time:.2f}s") 41 | print(images_proc.shape) 42 | -------------------------------------------------------------------------------- /shift_happens/gendist/gendist/models.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | from typing import Callable 3 | import jax.numpy as jnp 4 | 5 | class LeNet5(nn.Module): 6 | num_classes : int 7 | activation : Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 8 | @nn.compact 9 | def __call__(self, x): 10 | """Aleyna's network inspired by LeNet-5.""" 11 | x = x if len(x.shape) > 1 else x[None, :] 12 | x = x.reshape((x.shape[0], 28, 28, 1)) 13 | x = self.activation(nn.Conv(features=6, kernel_size=(5,5), padding="SAME")(x)) 14 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID") 15 | x = self.activation(nn.Conv(features=16, kernel_size=(5,5), padding="VALID")(x)) 16 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID") 17 | x = x.reshape((x.shape[0], -1)) 18 | x = 
self.activation(nn.Dense(features=120)(x)) 19 | x = self.activation(nn.Dense(features=84)(x)) 20 | x = nn.Dense(features=self.num_classes)(x) 21 | x = nn.log_softmax(x) 22 | return x 23 | 24 | 25 | class MLPDataV1(nn.Module): 26 | num_outputs: int 27 | @nn.compact 28 | def __call__(self, x): 29 | x = nn.relu(nn.Dense(800)(x)) 30 | x = nn.relu(nn.Dense(500)(x)) 31 | x = nn.Dense(self.num_outputs)(x) 32 | x = nn.log_softmax(x) 33 | return x 34 | 35 | 36 | class MLPWeightsV1(nn.Module): 37 | num_outputs: int 38 | @nn.compact 39 | def __call__(self, x): 40 | x = nn.relu(nn.Dense(200)(x)) 41 | x = nn.Dense(self.num_outputs)(x) 42 | return x 43 | 44 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/agents/nearest_centroid_classifier.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import vmap, ops, lax 3 | from jax.scipy.stats import multivariate_normal 4 | 5 | from typing import Any, Tuple 6 | 7 | Array = Any 8 | 9 | 10 | def init(inputs: Array, labels: Array): 11 | _, *input_shape = inputs.shape 12 | num_classes = jnp.unique(labels).size 13 | 14 | input_size = jnp.product(jnp.array(input_shape)) 15 | 16 | mean = jnp.zeros((num_classes, input_size)) 17 | scale = jnp.repeat(jnp.eye(input_size)[None, ...], repeats=num_classes, axis=0) 18 | counts = jnp.zeros((num_classes,)) 19 | state = (mean, scale, counts) 20 | 21 | def init_step(state, carry): 22 | input, cls = carry 23 | state = update(state, input.reshape((1, -1)), cls) 24 | return state, None 25 | 26 | state, _ = lax.scan(init_step, state, (inputs, labels)) 27 | 28 | return state 29 | 30 | 31 | def predict(state: Tuple[Array, Array, Array], inputs: Array, prior: Array): 32 | mean, scale, _ = state 33 | 34 | def cond_prob(input): 35 | return multivariate_normal.logpdf(input.reshape((1, -1)), 36 | mean=mean, 37 | cov=scale) 38 | 39 | logits = vmap(cond_prob)(inputs) 40 | return 
logits + prior 41 | 42 | 43 | def update(state: Tuple[Array, Array, Array], inputs: Array, cls: int): 44 | mean, scale, counts = state 45 | n = counts[cls] + len(inputs) 46 | 47 | prev_sum = mean[cls] * counts[cls] 48 | cur_sum = inputs.reshape((len(inputs), -1)).sum(axis=0) 49 | running_avg = (prev_sum + cur_sum) / n 50 | 51 | mean = ops.index_update(mean, jnp.index_exp[cls], running_avg) 52 | state = (mean, scale, counts) 53 | return state 54 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/agents/nearest_centroid_classifier_old.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import vmap, ops, lax 3 | from jax.scipy.stats import multivariate_normal 4 | 5 | from typing import Any, Tuple 6 | 7 | Array = Any 8 | 9 | 10 | def init(inputs: Array, labels: Array): 11 | _, *input_shape = inputs.shape 12 | num_classes = jnp.unique(labels).size 13 | 14 | input_size = jnp.product(jnp.array(input_shape)) 15 | 16 | mean = jnp.zeros((num_classes, input_size)) 17 | scale = jnp.repeat(jnp.eye(input_size)[None, ...], repeats=num_classes, axis=0) 18 | counts = jnp.zeros((num_classes,)) 19 | state = (mean, scale, counts) 20 | 21 | def init_step(state, carry): 22 | input, cls = carry 23 | state = update(state, input.reshape((1, -1)), cls) 24 | return state, None 25 | 26 | state, _ = lax.scan(init_step, state, (inputs, labels)) 27 | 28 | return state 29 | 30 | 31 | def predict(state: Tuple[Array, Array, Array], inputs: Array, prior: Array): 32 | mean, scale, _ = state 33 | 34 | def cond_prob(input): 35 | return multivariate_normal.logpdf(input.reshape((1, -1)), 36 | mean=mean, 37 | cov=scale) 38 | 39 | logits = vmap(cond_prob)(inputs) 40 | return logits + prior 41 | 42 | 43 | def update(state: Tuple[Array, Array, Array], inputs: Array, cls: int): 44 | mean, scale, counts = state 45 | n = counts[cls] + len(inputs) 46 | 47 | prev_sum = mean[cls] * 
counts[cls] 48 | cur_sum = inputs.reshape((len(inputs), -1)).sum(axis=0) 49 | running_avg = (prev_sum + cur_sum) / n 50 | 51 | mean = ops.index_update(mean, jnp.index_exp[cls], running_avg) 52 | state = (mean, scale, counts) 53 | return state 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | scripts/.DS_Store 3 | .DS_Store 4 | MNIST 5 | outputs 6 | 7 | *.mp4 8 | *.png 9 | *.gif 10 | *.pdf 11 | *.jpg 12 | *.jpeg 13 | 14 | __pycache__ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # Distribution / packaging 19 | .Python build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | *.manifest 35 | *.spec 36 | 37 | # Log files 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | *.log 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | .pytest_cache/ 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | # Jupyter Notebook 62 | .ipynb_checkpoints 63 | 64 | # IPython 65 | profile_default/ 66 | ipython_config.py 67 | 68 | # pyenv 69 | .python-version 70 | 71 | # pyflow 72 | __pypackages__/ 73 | 74 | # Environment 75 | .env 76 | .venv 77 | env/ 78 | venv/ 79 | ENV/ 80 | 81 | # If you are using PyCharm # 82 | .idea/* 83 | 84 | # Sublime Text 85 | *.tmlanguage.cache 86 | *.tmPreferences.cache 87 | *.stTheme.cache 88 | *.sublime-workspace 89 | *.sublime-project 90 | 91 | # sftp configuration file 92 | sftp-config.json 93 | 94 | # Package control specific files Package 95 | Control.last-run 96 | Control.ca-list 97 | Control.ca-bundle 98 | Control.system-ca-bundle 99 | GitHub.sublime-settings 100 | 101 | # Visual Studio Code # 102 | .vscode/* 
103 | !.vscode/settings.json 104 | !.vscode/tasks.json 105 | !.vscode/launch.json 106 | !.vscode/extensions.json 107 | .history 108 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/ncc_demo.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | import train 8 | import agents.nearest_centroid_classifier as agent 9 | from environment import Environment 10 | 11 | if __name__ == '__main__': 12 | key = jax.random.PRNGKey(0) 13 | local_device_count = jax.local_device_count() 14 | 15 | ds_name, num_classes = 'mnist', 10 16 | batch_size, num_pulls = 256, 10 17 | T = 200 18 | 19 | trans_mat, rot_mat = train.init_trans_mat_and_rot_mat(key, num_classes=num_classes) 20 | env = Environment('mnist', class_trans_mat=trans_mat, rot_mat=rot_mat, batch_size=batch_size) 21 | inputs, labels = env.warmup(num_pulls=num_pulls) 22 | 23 | state = agent.init(inputs, labels) 24 | prior = jax.nn.log_softmax(jnp.ones((num_classes))) 25 | 26 | accuracies, results = [defaultdict(list)] * num_classes, [] 27 | for t in range(T): 28 | batch = env.get_data() 29 | logits = agent.predict(state, batch['image'], prior) 30 | loss = train.cross_entropy_loss(logits, batch['label'], num_classes) 31 | accuracy = jnp.mean(jnp.argmax(logits, -1) == batch['label']) 32 | accuracies[env.current_class][env.rot.item()].append(accuracy) 33 | results.append(accuracy) 34 | state = agent.update(state, batch['image'], env.current_class) 35 | 36 | plt.style.use('seaborn-darkgrid') 37 | plt.figure(figsize=(12, 8)) 38 | plt.plot(np.arange(len(results)), results) 39 | plt.xlabel("Iteration") 40 | plt.ylabel("Accuracy") 41 | plt.savefig('./ncc_training.png') 42 | plt.show() 43 | 44 | '''fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(20, 8)) 45 | rotations = 
sorted(accuracies[0].keys()) 46 | for i, ax in enumerate(axes.flatten()): 47 | for rot, acc in accuracies[i].items(): 48 | ax.plot(np.arange(len(acc)), np.array(acc), 'o-', label=str(rot)) 49 | ax.legend() 50 | ax.set_title(f'Class {i}') 51 | ax.set_xlabel("Iteration") 52 | ax.set_ylabel("Accuracy") 53 | 54 | plt.tight_layout() 55 | plt.savefig('./ncc_cls_acc.png') 56 | plt.show()''' 57 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/imagenet_fake_data_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Benchmark for the ImageNet example using fake data for quick perf results.""" 16 | 17 | import pathlib 18 | import time 19 | 20 | from absl.testing import absltest 21 | from flax.testing import Benchmark 22 | import jax 23 | 24 | import tensorflow_datasets as tfds 25 | 26 | # Local imports. 27 | from configs import fake_data_benchmark as config_lib 28 | import train 29 | 30 | # Parse absl flags test_srcdir and test_tmpdir. 
31 | jax.config.parse_flags_with_absl() 32 | 33 | 34 | class ImagenetBenchmarkFakeData(Benchmark): 35 | """Runs ImageNet using fake data for quickly measuring performance.""" 36 | 37 | def test_fake_data(self): 38 | workdir = self.get_tmp_model_dir() 39 | config = config_lib.get_config() 40 | # Go two directories up to the root of the flax directory. 41 | flax_root_dir = pathlib.Path(__file__).parents[2] 42 | data_dir = str(flax_root_dir) + '/.tfds/metadata' 43 | 44 | # Warm-up first so that we are not measuring just compilation. 45 | with tfds.testing.mock_data(num_examples=1024, data_dir=data_dir): 46 | train.train_and_evaluate(config, workdir) 47 | 48 | start_time = time.time() 49 | with tfds.testing.mock_data(num_examples=1024, data_dir=data_dir): 50 | train.train_and_evaluate(config, workdir) 51 | benchmark_time = time.time() - start_time 52 | 53 | self.report_wall_time(benchmark_time) 54 | self.report_extras({ 55 | 'description': 'ImageNet ResNet50 with fake data', 56 | 'model_name': 'resnet50', 57 | 'parameters': f'hp=true,bs={config.batch_size}', 58 | }) 59 | 60 | 61 | if __name__ == '__main__': 62 | absltest.main() 63 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Main file for running the ImageNet example. 16 | 17 | This file is intentionally kept short. The majority of the logic is in libraries 18 | that can be easily tested and imported in Colab. 19 | """ 20 | import train 21 | import jax 22 | import tensorflow as tf 23 | 24 | from absl import app 25 | from absl import flags 26 | from absl import logging 27 | 28 | from ml_collections import config_flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 33 | config_flags.DEFINE_config_file( 34 | 'config', 35 | None, 36 | 'File path to the training hyperparameter configuration.', 37 | lock_config=True) 38 | 39 | 40 | def main(argv): 41 | if len(argv) > 1: 42 | raise app.UsageError('Too many command-line arguments.') 43 | 44 | # exit() 45 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 46 | # it unavailable to JAX. 47 | tf.config.experimental.set_visible_devices([], 'GPU') 48 | 49 | logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 50 | logging.info('JAX local devices: %r', jax.local_devices()) 51 | 52 | # Add a note so that we can tell which task is which JAX host. 53 | # (Depending on the platform task 0 is not guaranteed to be host 0) 54 | """platform.work_unit().set_task_status(f'process_index: {jax.process_index()}', 55 | f'process_count: {jax.process_count()}') 56 | platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, 57 | FLAGS.workdir, 'workdir')""" 58 | 59 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir) 60 | 61 | 62 | if __name__ == '__main__': 63 | flags.mark_flags_as_required(['config', 'workdir']) 64 | 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/train_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Flax Authors. 
class TrainTest(absltest.TestCase):
    """Smoke tests for model creation, the training loop and the loss in `train`."""

    def setUp(self):
        super().setUp()
        # Make sure tf does not allocate gpu memory.
        tf.config.experimental.set_visible_devices([], 'GPU')

    def test_create_model(self):
        """Tests creating model."""
        num_classes = 10
        model = train.create_model(
            model_cls=models._ResNet1,  # pylint: disable=protected-access
            num_classes=num_classes,
            half_precision=False)
        params, batch_stats = train.initialized(random.PRNGKey(0), 32, model)
        variables = {'params': params, 'batch_stats': batch_stats}
        x = random.normal(random.PRNGKey(1), (8, 32, 32, 3))
        y = model.apply(variables, x, train=False)
        self.assertEqual(y.shape, (8, num_classes))

    def test_train_and_evaluate(self):
        """Tests training and evaluation loop using mocked data."""
        # Define training configuration
        config = default_lib.get_config()
        config.model = '_ResNet1'
        config.batch_size = 8
        config.num_epochs = 1
        config.num_train_steps = 1
        config.steps_per_eval = 1

        # Temporary directory where tensorboard metrics are written. The
        # context manager removes it afterwards; `tempfile.mkdtemp()` left one
        # directory behind per test run.
        with tempfile.TemporaryDirectory() as workdir:
            train.train_and_evaluate(workdir=workdir, config=config)

    def test_log_likelihood(self):
        """Tests the log-likelihood of uniform logits against its closed form."""
        n, num_classes = 10, 10
        logits = jax.nn.log_softmax(jnp.ones((n, num_classes)))
        labels = jnp.arange(num_classes)
        loss = train.log_likelihood(logits, labels, num_classes)
        self.assertLess(loss, 0)
        self.assertEqual(loss, n * logits[0][0])
14 | 15 | """Tests for flax.examples.imagenet.models.""" 16 | 17 | from absl.testing import absltest 18 | 19 | import jax 20 | from jax import numpy as jnp 21 | 22 | import experiments.continual_learning.agents.models as models 23 | 24 | jax.config.update('jax_disable_most_optimizations', True) 25 | 26 | 27 | class ResNetV1Test(absltest.TestCase): 28 | """Test cases for ResNet v1 model definition.""" 29 | 30 | def test_resnet_v1_model(self): 31 | """Tests ResNet V1 model definition and output (variables).""" 32 | rng = jax.random.PRNGKey(0) 33 | model_def = models.ResNet50(num_classes=10, dtype=jnp.float32) 34 | variables = model_def.init( 35 | rng, jnp.ones((8, 224, 224, 3), jnp.float32)) 36 | 37 | self.assertLen(variables, 2) 38 | # Resnet50 model will create parameters for the following layers: 39 | # conv + batch_norm = 2 40 | # BottleneckResNetBlock in stages: [3, 4, 6, 3] = 16 41 | # Followed by a Dense layer = 1 42 | self.assertLen(variables['params'], 19) 43 | 44 | def test_lenet5_model(self): 45 | """Tests LeNeT5 model definition and output (variables).""" 46 | rng = jax.random.PRNGKey(0) 47 | num_classes = 10 48 | model_def = models.LeNet(num_classes=num_classes, dtype=jnp.float32) 49 | variables = model_def.init( 50 | rng, jnp.ones((8, 32, 32, 3), jnp.float32)) 51 | 52 | self.assertLen(variables, 1) 53 | # LeNet5 model will create parameters for the following layers: 54 | # 2 Conv + 3 Dense = 2 55 | self.assertLen(variables['params'], 5) 56 | # The output of the last layer of LeNet5 will be equal to the number 57 | # of classes. 
class Factory:
    """
    Base class to process / transform the elements of a numpy array
    according to a given function. To be used with gendist.TrainingConfig

    Parameters
    ----------
    processor: callable
        Called as ``processor(X, **configuration)`` on a single element
        and expected to return the transformed element.
    """
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, img, configs, n_processes=90):
        """Process ``img`` in parallel; see process_multiple_multiprocessing."""
        return self.process_multiple_multiprocessing(img, configs, n_processes)

    def process_single(self, X, *args, **kwargs):
        """
        Process a single element.

        Parameters
        ----------
        X: np.array
            A single numpy array
        args, kwargs:
            Processor's configuration parameters, forwarded unchanged
        """
        return self.processor(X, *args, **kwargs)

    def process_multiple(self, X_batch, configurations):
        """
        Process all elements of a numpy array according to a list
        of configurations. Element ``i`` is processed with
        ``configurations[i]``; the results are stacked along axis 0.

        Raises
        ------
        ValueError
            If ``X_batch`` is empty (there is nothing to stack).
        """
        processed = [
            self.process_single(X, **configuration)
            for X, configuration in zip(X_batch, configurations)
        ]
        return np.stack(processed, axis=0)

    def process_multiple_multiprocessing(self, X_dataset, configurations, n_processes):
        """
        Process elements in a numpy array in parallel.

        Parameters
        ----------
        X_dataset: array(N, ...)
            N elements of arbitrary shape
        configurations: list or dict
            List of configurations to apply to each element. Each
            element is a dict to pass to the processor. A single dict
            is applied to every element.
        n_processes: int
            Number of cores to use. Capped at N so no worker receives
            an empty chunk.

        Returns
        -------
        array(N, -1): processed elements, one flattened row per element
        """
        num_elements = len(X_dataset)
        if isinstance(configurations, dict):
            configurations = [configurations] * num_elements

        # An empty chunk makes process_multiple call np.stack on an empty
        # list, which raises; never create more chunks than elements.
        n_processes = max(1, min(n_processes, num_elements))

        dataset_chunks = np.array_split(X_dataset, n_processes)
        config_chunks = np.array_split(configurations, n_processes)

        with Pool(processes=n_processes) as pool:
            processed = pool.starmap(self.process_multiple, zip(dataset_chunks, config_chunks))
        # Leaving the `with` block terminates the pool; join() is only valid
        # once the pool has been closed or terminated.
        pool.join()

        dataset_proc = np.concatenate(processed, axis=0)
        return dataset_proc.reshape(num_elements, -1)
class ResNetBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and a shortcut connection.

    The shortcut is passed through a 1x1 projection convolution plus a norm
    layer whenever its shape differs from the main branch's output.
    """
    filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        shortcut = x

        # Main branch: conv -> norm -> act -> conv -> norm (scale initialised to zeros).
        out = self.conv(self.filters, (3, 3), self.strides)(x)
        out = self.act(self.norm()(out))
        out = self.conv(self.filters, (3, 3))(out)
        out = self.norm(scale_init=nn.initializers.zeros)(out)

        if shortcut.shape != out.shape:
            projection = self.conv(self.filters, (1, 1), self.strides, name='conv_proj')
            shortcut = self.norm(name='norm_proj')(projection(shortcut))

        return self.act(shortcut + out)
def predict_shifted_dataset(ix_seed, X_batch, processor, config, wmodel, wparams, dmodel, proj, fn_reconstruct):
    """
    Predict the data-model weights from a single shifted seed element and
    use them to evaluate the data model over the whole shifted batch.

    Parameters
    ----------
    ix_seed: int
        Index into X_batch of the element fed to the weights model
    X_batch: array
        Batch of unshifted elements
    processor: object
        Exposes ``process_single(x, **config)`` and is callable as
        ``processor(X_batch, config)``
    config: dict
        Configuration of the shift, passed to the processor
    wmodel: model for the latent space
    wparams: trained weights for the latent space
    dmodel: model for the observed space
    proj: object
        Exposes ``inverse_transform`` to map predicted weights back from
        the projected space
    fn_reconstruct: function
        Rebuilds the data model's parameters from a flat weight vector

    Returns
    -------
    Output of ``dmodel.apply`` on the shifted batch
    """
    # Use the argument, not a module-level loop variable named `ix`.
    x_seed = X_batch[ix_seed]
    x_shift = processor.process_single(x_seed, **config).ravel()

    predicted_weights = wmodel.apply(wparams, x_shift)
    predicted_weights = proj.inverse_transform(predicted_weights)
    predicted_weights = fn_reconstruct(predicted_weights)

    X_batch_shift = processor(X_batch, config)
    y_batch_hat = dmodel.apply(predicted_weights, X_batch_shift)

    return y_batch_hat
y_test_hat = predict_shifted_dataset(ix, X_test, proc_class, config, 86 | meta_model, meta_model_results["params"], 87 | data_model, pca, fn_reconstruct_params) 88 | y_test_hat = y_test_hat.argmax(axis=1) 89 | accuracy_learned = (y_test_hat == y_test).mean().item() 90 | acc_dict[ix] = accuracy_learned 91 | 92 | accuracy_configs_learned.append(acc_dict) 93 | 94 | angle = config["angle"] 95 | logger_row = "|".join([format(v, "0.2%") for v in acc_dict.values()]) 96 | logger_row = f"{angle=:0.4f} | " + logger_row 97 | 98 | logger.info(logger_row) 99 | 100 | pd.DataFrame(acc_dict).to_pickle(path_results) 101 | -------------------------------------------------------------------------------- /shift_happens/imagenet_flax/configs/tpu_dynamic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2021 The Flax Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def get_optimizer_config(config=None, learning_rate=0.1,
                         warmup_epochs=5,
                         momentum=0.9,
                         num_epochs=100,
                         log_every_steps=100,
                         num_train_steps=-1,
                         steps_per_eval=-1,
                         batch_size=-1,
                         opt_num_devices=8):
    """Return the optimizer hyperparameters as a ConfigDict.

    Args:
      config: Optional existing configuration. A `ConfigDict` is returned
        unchanged and a plain dict is wrapped in a `ConfigDict`; in both cases
        the keyword arguments below are ignored. When None, a new
        configuration is built from the keyword arguments.
      learning_rate: Stored as `learning_rate`.
      warmup_epochs: Stored as `warmup_epochs`.
      momentum: Stored as `momentum`.
      num_epochs: Stored as `num_epochs`.
      log_every_steps: Stored as `log_every_steps`.
      num_train_steps: Stored as `num_train_steps`; -1 means it is derived
        from `num_epochs` using the entire dataset.
      steps_per_eval: Stored as `steps_per_eval`; -1 is treated similarly.
      batch_size: Batch size; -1 selects max(opt_num_devices * 256, 8 * 1024).
      opt_num_devices: Stored as `opt_num_devices`.

    Returns:
      A `ml_collections.ConfigDict`.

    Raises:
      TypeError: If `config` is neither None, a dict nor a ConfigDict.
        Previously such values made the function return None silently.
    """
    if isinstance(config, ml_collections.ConfigDict):
        return config
    if isinstance(config, dict):
        return ml_collections.ConfigDict(initial_dictionary=config)
    if config is not None:
        raise TypeError(
            'config must be None, a dict or a ConfigDict, got '
            f'{type(config).__name__}')

    config = ml_collections.ConfigDict()
    config.learning_rate = learning_rate
    config.warmup_epochs = warmup_epochs
    config.momentum = momentum
    config.opt_num_devices = opt_num_devices
    config.num_epochs = num_epochs
    config.log_every_steps = log_every_steps

    # If num_train_steps==-1 then the number of training steps is calculated from
    # num_epochs using the entire dataset. Similarly for steps_per_eval.
    config.num_train_steps = num_train_steps
    config.steps_per_eval = steps_per_eval

    if batch_size == -1:
        config.batch_size = max(config.opt_num_devices * 256, 8 * 1024)
    else:
        config.batch_size = batch_size

    return config
def get_data(self):
    """Sample the next class and rotation, then return one rotated training batch.

    The PRNG key is split three ways: one part picks the next class from the
    row of `class_trans_mat` selected by the current class, one picks the next
    rotation from the row of `rot_mat` selected by the current rotation, and
    the third replaces `self.key`. `self.current_class` and `self.rot` are
    updated in place.

    Returns a dict with the rotated `image` batch and its `label` batch.
    """
    class_key, rot_key, next_key = random.split(self.key, 3)
    self.key = next_key

    class_probs = self.class_trans_mat[self.current_class].squeeze()
    self.current_class = random.choice(class_key, jnp.arange(self.num_classes), p=class_probs)
    batch = self.get_train_batch(self.current_class)

    rot_probs = self.rot_mat[self.rot].squeeze()
    self.rot = random.choice(rot_key, jnp.arange(len(self.rot_mat)), p=rot_probs)

    batch['image'] = self.apply(batch['image'], self.rot)
    return {'image': batch['image'], 'label': batch['label']}
def warmup(self, num_pulls: int):
    """Collect `num_pulls` training batches from every class, round-robin.

    Classes are visited in the order 0, 1, ..., num_classes - 1 and that
    sweep is repeated `num_pulls` times. The i-th batch drawn uses `i` as
    its shuffle seed.

    Returns a tuple `(images, labels)` with all batches stacked along the
    first axis.

    Raises ValueError if `num_pulls` is smaller than 1, since there would be
    nothing to stack.
    """
    if num_pulls < 1:
        raise ValueError(f"num_pulls must be at least 1, got {num_pulls}")

    # Plain Python ints match get_train_batch's `c: int, seed: int` signature.
    classes = list(range(self.num_classes)) * num_pulls
    inputs, labels = [], []

    for seed, c in enumerate(classes):
        batch = self.get_train_batch(c, seed)
        inputs.append(batch['image'])
        labels.append(batch['label'])
    return jnp.vstack(inputs), jnp.concatenate(labels)
def flat_and_concat_params(params_hist):
    """
    Flat and concat a list of parameters trained using
    a Flax model

    Parameters
    ----------
    params_hist: list of flax FrozenDicts
        List of flax FrozenDicts containing trained model
        weights. Must contain at least one element.

    Returns
    -------
    jnp.array: flattened weights, one row per element of params_hist
    function: function to unflatten (reconstruct) a single row of weights
    """
    # `import jax` alone does not guarantee that the flatten_util submodule
    # is loaded, so import it explicitly.
    from jax.flatten_util import ravel_pytree

    first_flat, reconstruct_pytree_fn = ravel_pytree(params_hist[0])
    remaining_flat = [ravel_pytree(params)[0] for params in params_hist[1:]]
    flat_params = jnp.stack([first_flat, *remaining_flat])
    return flat_params, reconstruct_pytree_fn
class BlurRad:
    """
    Callable that blurs every image of a batch with a fixed radius.

    Parameters
    ----------
    rad: float
        The amount of blurriness, passed to augly as ``radius``
    """
    def __init__(self, rad):
        self.rad = rad

    def __call__(self, img):
        """Blur a batch of images; see blur_multiple."""
        return self.blur_multiple(img)

    def blur(self, X):
        """
        Blur an image using the augly library

        Parameters
        ----------
        X: np.array
            A single NxM-dimensional array
        """
        return image.aug_np_wrapper(X, image.blur, radius=self.rad)

    def blur_multiple(self, X_batch):
        """
        Blur each element of X_batch and stack the results along axis 0.
        """
        return np.stack([self.blur(X) for X in X_batch], axis=0)
def proc_dataset_multiple(radii, img_dataset, n_processes=90):
    """
    Blur all images of a dataset stored in a numpy array with variable
    radius.

    Parameters
    ----------
    radii: array(N,) or scalar
        Intensity of bluriness. One per image. If a scalar
        (Python or NumPy, int or float), the same value is
        used for all images.
    img_dataset: array(N, L, K)
        N images of size LxK
    n_processes: int
        Number of processes to blur over. Capped at N so no worker
        receives an empty chunk.

    Returns
    -------
    array: the blurred images, concatenated along axis 0
    """
    n_images = len(img_dataset)

    # np.ndim covers every scalar kind; `np.float_` no longer exists in
    # NumPy 2.0 and the old check also missed integer radii.
    if np.ndim(radii) == 0:
        radii = radii * np.ones(n_images)

    # blur_multiple stacks its results and fails on an empty chunk, so never
    # split into more chunks than there are images.
    n_processes = max(1, min(n_processes, n_images))

    dataset_chunks = np.array_split(img_dataset, n_processes)
    radii_chunks = np.array_split(radii, n_processes)

    with Pool(processes=n_processes) as pool:
        blurred = pool.starmap(blur_multiple, zip(radii_chunks, dataset_chunks))

    return np.concatenate(blurred, axis=0)
def configure_covariates(key, processor, X, configs, n_subset):
    """
    Given a dataset with shape (n_train, ...), a subset size n_subset,
    and a list of configurations to transform the dataset, draw n_subset
    elements without replacement and apply every configuration to each
    of them.

    Parameters
    ----------
    key: jax PRNG key used to draw the subset
    processor: callable
        Called as ``processor(X_flat, configs_flat)`` with one
        configuration per element; must return one row per element.
    X: array(n_train, ...)
    configs: list of configurations
    n_subset: int
        Number of elements to draw; must not exceed n_train.

    Returns
    -------
    array(n_subset, n_configs, n_features): entry ``[s, c]`` is the
    s-th drawn element processed with ``configs[c]``.
    """
    n_configs = len(configs)
    n_train, *elem_dims = X.shape

    # Broadcasting against `imap` replicates the subset once per configuration,
    # giving (n_configs, n_subset, ...); `configs_transform` follows the same
    # configuration-major order.
    imap = np.ones((n_configs, 1, *elem_dims))
    configs_transform = np.repeat(configs, n_subset)
    # np.asarray replaces the `.to_py()` method, which current JAX arrays lack.
    subset_ix = np.asarray(jax.random.choice(key, n_train, (n_subset,), replace=False))
    X = X[subset_ix, ...] * imap
    X = processor(X.reshape(-1, *elem_dims), configs_transform)
    # Rows are ordered config-major (row = c * n_subset + s); the Fortran-order
    # reshape maps that to [s, c, feature].
    X = X.reshape((n_subset, n_configs, -1), order="F")

    return X
if __name__ == "__main__":
    # Train the meta (weights) model: usage `<script> <path to data-model pickle>`.
    # The result is written to `meta-model.pkl` next to the input file.
    import sys

    _, filename_data_model = sys.argv
    experiment_path, _ = os.path.split(filename_data_model)

    # Hyper-parameters. They are declared before first use: `n_train_subset`
    # used to be read before it was assigned.
    n_components = 60
    n_classes = 10
    n_train_subset = 6_000
    alpha = 0.01
    n_epochs = 150
    batch_size = 2000

    output = load_train_combo(filename_data_model)
    target_params = output["params"]
    list_configs = output["configs"]
    fn_reconstruct_params = output["fn_reconstruct"]
    # One trained data model per configuration (`list_params` was undefined here).
    n_configs = len(list_configs)

    processing_class = gendist.processing.Factory(rotate)
    key = jax.random.PRNGKey(314)
    key, key_subset = jax.random.split(key)

    mnist_train = torchvision.datasets.MNIST(root=".", train=True, download=True)
    X_train = np.array(mnist_train.data) / 255
    X_train = configure_covariates(key_subset, processing_class, X_train, list_configs, n_train_subset)
    n_train, *elem_dims = X_train.shape

    # Project the flattened data-model weights; the meta model is fit on the projection.
    pca = PCA(n_components=n_components)
    projected_params = pca.fit_transform(target_params)[None, ...]

    tx = optax.adam(learning_rate=alpha)
    lossfn = gendist.training.make_multi_output_loss_func
    weights_model = gendist.models.MLPWeightsV1(n_components)
    trainer = gendist.training.TrainingMeta(weights_model, lossfn, tx)

    meta_output = trainer.fit(key, X_train, projected_params, n_epochs, batch_size)
    meta_output["projection_model"] = pca

    filename_meta_model = "meta-model.pkl"
    filename_meta_model = os.path.join(experiment_path, filename_meta_model)
    with open(filename_meta_model, "wb") as f:
        pickle.dump(meta_output, f)
def main(key, base_path, trainer, X, y, configs, n_epochs, batch_size, evalfn,
         experiment_path=None, logname=None, filename=None, leave=True):
    """
    Run one full experiment: train a model per configuration, log the
    per-configuration metric and pickle the collected results.

    Parameters
    ----------
    key : jax.random.PRNGKey
        Key forwarded to `training_loop`.
    base_path : str
        Directory under which the experiment folder is created.
    trainer : gendist.training
        Trainer object forwarded to `training_loop`.
    X, y : ndarray
        Input and target data.
    configs : list
        List of configuration dictionaries.
    n_epochs, batch_size : int
        Training schedule forwarded to the trainer.
    evalfn : function
        Function to evaluate the model.
    experiment_path : str, optional
        Experiment folder name; defaults to a `%y%m%d%H%M` timestamp.
    logname, filename : str, optional
        Names of the log file and of the pickled result file.
    leave : bool
        Forwarded to the progress bar of `training_loop`.
    """
    if filename is None:
        filename = "data-model-result.pkl"
    if logname is None:
        logname = "log-data.log"

    run_name = experiment_path
    if run_name is None:
        run_name = datetime.now().strftime("%y%m%d%H%M")
    # NOTE(review): the flattened source does not show nesting; the folder
    # creation is assumed to run for user-supplied names too -- confirm.
    run_dir = create_experiment_path(base_path, run_name)

    logger.add(os.path.join(run_dir, "logs", logname), rotation="5mb")

    results = training_loop(key, X, y, configs, trainer, n_epochs, batch_size,
                            evalfn, logger, leave=leave)
    results["configs"] = configs

    target_file = os.path.join(run_dir, "output", filename)
    with open(target_file, "wb") as handle:
        pickle.dump(results, handle)
def init_sampler(x, y, num_classes):
    """
    Build the initial sampler state from an existing memory.

    Returns the per-class counts of `y` (length `num_classes`) together
    with the `(x, y)` pair that acts as the memory.
    """
    dataset = x, y
    counts = jnp.bincount(y, minlength=num_classes)
    return counts, dataset


def sample(key, X, Y, counts, datasets, memory_size=200):
    """
    Update a class-balanced memory with a batch of new observations.

    Every incoming pair `(x, y)` is accepted only if its class is absent
    from the memory or holds fewer than `memory_size // n_seen_classes`
    elements. While the memory has room the pair is appended; once it holds
    `memory_size` elements the pair overwrites a randomly chosen element of
    the currently largest class.

    Parameters
    ----------
    key : jax.random.PRNGKey
        Key used to pick the element that gets overwritten.
    X : array(B, ...)
        Incoming inputs.
    Y : array(B)
        Incoming integer labels.
    counts : array(C)
        Number of stored elements per class.
    datasets : tuple
        Current memory as `(inputs, labels)`.
    memory_size : int
        Maximum number of elements kept in memory.

    Returns
    -------
    counts : array(C)
        Updated per-class counts.
    (inputs, labels) : tuple of arrays
        Updated memory.
    """
    # Host-side bookkeeping only; numpy is already a dependency of jax.
    import numpy as np

    inputs, labels = datasets
    mem_inputs = list(np.asarray(inputs))
    mem_labels = [int(label) for label in np.asarray(labels)]
    tally = np.array(counts)  # writable copy, the caller's array is untouched

    # The previous version divided by zero when no class had been seen yet.
    n_seen = int((tally > 0).sum())
    nperclass = memory_size // max(n_seen, 1)

    num_items = len(X)
    item_keys = random.split(key, num_items) if num_items else ()

    for item_key, x, y in zip(item_keys, np.asarray(X), np.asarray(Y)):
        y = int(y)
        if not (tally[y] == 0 or tally[y] < nperclass):
            continue

        # The previous check used `len(datasets)`, i.e. the length of the
        # (inputs, labels) tuple, so the memory was never considered full.
        if len(mem_labels) < memory_size:
            mem_inputs.append(x)
            mem_labels.append(y)
        else:
            c_max = int(tally.argmax())
            candidates = [ix for ix, label in enumerate(mem_labels) if label == c_max]
            if not candidates:
                # counts and labels disagree; nothing can be evicted safely
                continue
            choice = int(random.randint(item_key, (), 0, len(candidates)))
            slot = candidates[choice]
            mem_inputs[slot] = x
            mem_labels[slot] = y
            tally[c_max] -= 1
        tally[y] += 1

    if not mem_labels:
        return jnp.asarray(tally), (inputs, labels)

    label_dtype = jnp.asarray(labels).dtype if len(labels) else jnp.asarray(Y).dtype
    new_inputs = jnp.asarray(np.stack(mem_inputs))
    new_labels = jnp.asarray(mem_labels, dtype=label_dtype)
    return jnp.asarray(tally), (new_inputs, new_labels)
class Environment:
    """
    Stream of class-conditional, rotated image batches drawn from a
    tensorflow_datasets dataset.

    The current class and the current rotation index are re-sampled on each
    call to `get_data` using the rows of `class_trans_mat` and `rot_mat`
    as sampling probabilities.
    """
    def prepare_data(self, key : Any, dataset_name : str):
        """
        Download `dataset_name`, split the train set into one filtered
        dataset per label and draw the initial class.
        """

        ds_builder = tfds.builder(dataset_name)
        ds_builder.download_and_prepare()
        ds_train = ds_builder.as_dataset(split="train", as_supervised=True)
        self.test_data = ds_builder.as_dataset(split="test", as_supervised=True)

        self.num_classes = ds_builder.info.features['label'].num_classes
        # Initial class has shape (1,); get_data later replaces it with a scalar draw.
        self.current_class = random.choice(key, jnp.arange(self.num_classes), shape=(1,))
        # One dataset per label value.
        self.sets = [ds_train.filter(lambda x, y: y == i) for i in range(self.num_classes)]
        # Hard-coded per-class sizes (ten entries summing to 60000), used as
        # shuffle buffer sizes in get_batch.
        # NOTE(review): presumably the MNIST train split -- other datasets
        # would need their own values; confirm.
        self.counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
        # Dynamic alternative kept for reference:
        # [st.reduce(np.int64(0), lambda x, _ : x + 1).numpy() for st in self.sets]


    def __init__(self, dataset_name='mnist', class_trans_mat=None, rot_mat=None, seed=0, batch_size=64):
        """
        Store the sampling matrices, seed TensorFlow and load the data.

        `class_trans_mat` and `rot_mat` are indexed by the current class /
        rotation index in `get_data`, so they must be provided before
        `get_data` is called even though they default to None.
        """
        self.class_trans_mat = class_trans_mat
        self.rot = 0
        self.rot_mat = rot_mat
        self.seed= seed
        self.batch_size = batch_size

        tf.random.set_seed(seed)
        key = random.PRNGKey(seed)
        self.prepare_data(key, dataset_name)


    def get_data(self, key):
        """
        Draw the next class and rotation index, and return one rotated
        batch as a dict with keys 'image' and 'label'.
        """
        class_key, rot_key, key = random.split(key, 3)
        # Next class is sampled from the row of the current class.
        self.current_class = random.choice(class_key, jnp.arange(self.num_classes), p=self.class_trans_mat[self.current_class].squeeze())
        batch = self.get_batch(self.current_class)
        # Next rotation index is sampled from the row of the current one;
        # the index itself is then passed to `apply` as the angle in degrees.
        self.rot = random.choice(rot_key, jnp.arange(len(self.rot_mat)), p=self.rot_mat[self.rot].squeeze())
        batch['image'] = self.apply(batch['image'], self.rot)
        return batch


    def apply(self, x : Array, degrees : float):
        """Rotate every image in `x` by `degrees` (converted to radians)."""
        rad = jnp.radians(degrees)
        # Creates transformation matrix
        transform = transforms.rotate(rad=rad)
        # A three-element mask value: assumes 3-channel images, which
        # get_batch guarantees for single-channel inputs by repeating them.
        apply_transform = partial(transforms.apply_transform, transform=transform, mask_value=jnp.array([0, 0, 0]))
        return vmap(apply_transform)(x)

    def get_batch(self, c: int):
        """
        Return the first shuffled batch of class `c` as JAX arrays.

        Single-channel images are repeated three times along the last axis.
        """
        dataset = self.sets[c]
        nexamples = self.counts[c]
        # Only the first batch of the freshly shuffled dataset is used.
        for images, labels in dataset.shuffle(nexamples).batch(self.batch_size):
            images = tf_to_jax(images)
            labels = tf_to_jax(labels)
            break

        *_, nchannels = images.shape
        if nchannels == 1:
            images = jnp.repeat(images, axis=-1, repeats=3)
        return {'image': images, 'label': labels}

    def warmup(self, num_pulls : int):
        """
        Pull `num_pulls` batches from every class, cycling through the
        classes in order, and return the stacked images and labels.
        """
        warmup_classes = jnp.arange(self.num_classes)
        warmup_classes = jnp.repeat(warmup_classes, num_pulls).reshape(self.num_classes, -1)
        # Column-major flattening interleaves the classes: 0, 1, ..., C-1, 0, 1, ...
        classes = warmup_classes.reshape(-1, order="F").astype(jnp.int32)
        num_warmup_classes, *_ = classes.shape
        inputs, labels = [], []

        for c in classes:
            batch = self.get_batch(c)
            inputs.append(batch['image'][None, ...])
            labels.append(batch['label'])

        images = jnp.vstack(inputs)
        # Merge the pull axis and the batch axis, keeping the last three axes.
        return images.reshape((-1, *images.shape[-3:])), jnp.concatenate(labels)
class TrainingBase:
    """
    Class to train a neural network model that transforms the input data
    given a processor function.

    Parameters
    ----------
    model: Flax model
        Object exposing `init(key, batch)` and `apply(params, X)`.
    processor: callable
        Called as `processor(X, config)`; returns the transformed inputs.
    loss_generator: callable
        Called as `loss_generator(model, X, y)`; returns `loss_fn(params)`.
    tx: optax gradient transformation
        Optimiser used to update the parameters.
    """
    def __init__(self, model, processor, loss_generator, tx):
        self.model = model
        self.processor = processor
        self.loss_generator = loss_generator
        self.tx = tx

    def fit(self, key, X_train, y_train, config, num_epochs, batch_size, evalfn=None):
        """
        Train a flax.linen model by transforming the data according to
        `config`.

        Parameters
        ----------
        key: jax.random.PRNGKey
            Random number generator key.
        X_train: jnp.array(N, ...)
            Training data.
        y_train: jnp.array(N, ...)
            Training target values, indexed along the first axis.
        config: dict
            Dictionary containing the training configuration to be passed to
            the processor.
        num_epochs: int
            Number of epochs to train the model.
        batch_size: int
            Number of samples per batch.
        evalfn: callable, optional
            Called as `evalfn(y_train, yhat)` on the processed train set
            after training.

        Returns
        -------
        dict with keys "losses" (average loss per epoch), "metric"
        (`evalfn` output or None) and "params" (trained parameters).
        """
        X_train_proc = self.processor(X_train, config)
        _, *input_shape = X_train_proc.shape

        # Use a dedicated key for initialisation (as TrainingMeta does)
        # instead of reusing the key that seeds the per-epoch shuffles.
        key, key_params = jax.random.split(key)
        batch = jnp.ones((batch_size, *input_shape))
        params = self.model.init(key_params, batch)
        optimiser_state = self.tx.init(params)

        losses = []
        for e in tqdm(range(num_epochs), leave=False):
            _, key = jax.random.split(key)
            params, optimiser_state, avg_loss = self.train_epoch(key, params, optimiser_state,
                                                                 X_train_proc, y_train, batch_size, e)
            losses.append(avg_loss)

        metric = None
        if evalfn is not None:
            yhat = self.model.apply(params, X_train_proc)
            metric = evalfn(y_train, yhat)

        training_output = {
            "losses": jnp.array(losses),
            "metric": metric,
            "params": params,
        }

        return training_output

    @partial(jax.jit, static_argnums=(0,))
    def train_step(self, params, opt_state, X_batch, y_batch):
        """
        Run one jitted optimisation step on a mini-batch.

        Returns the loss value, the updated parameters and the updated
        optimiser state.
        """
        loss_fn = self.loss_generator(self.model, X_batch, y_batch)
        loss_val, grads = jax.value_and_grad(loss_fn)(params)
        # `params` is forwarded because some optax transformations
        # (e.g. weight decay) need the current parameter values.
        updates, opt_state = self.tx.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        return loss_val, params, opt_state

    def get_batch_train_ixs(self, key, num_samples, batch_size):
        """
        Obtain the training indices to be used in an epoch of
        mini-batch optimisation.

        Returns an array of shape (num_samples // batch_size, batch_size);
        the remainder of the permutation is dropped.
        """
        steps_per_epoch = num_samples // batch_size
        batch_ixs = jax.random.permutation(key, num_samples)
        batch_ixs = batch_ixs[:steps_per_epoch * batch_size]
        batch_ixs = batch_ixs.reshape(steps_per_epoch, batch_size)

        return batch_ixs

    def train_epoch(self, key, params, opt_step, X, y, batch_size, epoch):
        """
        Run one pass of mini-batch optimisation over `X` and `y`.

        Returns the updated parameters, the updated optimiser state and the
        loss averaged over the mini-batches.

        Raises
        ------
        ValueError
            If `batch_size` exceeds the number of samples, so that no full
            mini-batch can be formed.
        """
        num_samples, *_ = X.shape
        batch_ixs = self.get_batch_train_ixs(key, num_samples, batch_size)
        if len(batch_ixs) == 0:
            raise ValueError(
                f"batch_size ({batch_size}) is larger than the number of samples ({num_samples})."
            )

        epoch_loss = 0.0
        for batch_ix in batch_ixs:
            X_batch = X[batch_ix, ...]
            y_batch = y[batch_ix, ...]
            loss, params, opt_step = self.train_step(params, opt_step, X_batch, y_batch)
            epoch_loss += loss

        epoch_loss = epoch_loss / len(batch_ixs)
        return params, opt_step, epoch_loss
class TrainingMeta(TrainingBase):
    """
    Trainer for a model that predicts model parameters. The input has the
    form NxMx... and the target has the form KxMxW, where
    * N: number of observations
    * M: number of transformations (configurations) per observation
    * ...: dimension specification of a single instance
    * K: number of samples per configuration
    * W: number of parameters per configuration
    """
    def __init__(self, model, loss_generator, tx):
        # No data processing takes place here: the processor is the identity.
        super().__init__(model, lambda x, _: x, loss_generator, tx)

    def fit(self, key, X_train, y_train, num_epochs, batch_size, leave_pb=True):
        """
        Fit the model on every (observation, configuration, sample) triple.

        Parameters
        ----------
        key: jax.random.PRNGKey
            Random number generator key.
        X_train: jnp.array(N, M, ...)
            Training data.
        y_train: jnp.array(K, M, W)
            Training target values.
        num_epochs: int
            Number of epochs to train the model.
        batch_size: int
            Number of samples per batch.
        leave_pb: bool
            If True, the progress bar is left open.

        Returns
        -------
        dict with keys "params" (trained parameters) and "losses"
        (average loss per epoch).
        """
        feature_shape = X_train.shape[1:]

        key, init_key = jax.random.split(key)
        # NOTE(review): the dummy batch keeps the configuration axis, while
        # train_epoch feeds batches without it -- confirm the model accepts both.
        dummy_batch = jnp.ones((batch_size, *feature_shape))
        params = self.model.init(init_key, dummy_batch)
        opt_state = self.tx.init(params)

        loss_history = []
        for epoch in tqdm(range(num_epochs), leave=leave_pb):
            key = jax.random.split(key)[1]
            params, opt_state, epoch_loss = self.train_epoch(
                key, params, opt_state, X_train, y_train, batch_size, epoch
            )
            loss_history.append(epoch_loss)

        return {
            "params": params,
            "losses": jnp.array(loss_history),
        }


    def train_epoch(self, key, params, opt_step, X, y, batch_size, epoch):
        """
        Train a model for one epoch considering an input of the form
        NxMx... and a target variable of the form KxMxW.
        """
        num_samples, configs_in_X = X.shape[0], X.shape[1]
        num_cycles, configs_in_y, _ = y.shape
        if configs_in_X != configs_in_y:
            raise ValueError("The number of configurations in X and y must be the same.")

        # A flat index enumerates every (sample, cycle, configuration)
        # triple, with the sample index varying fastest.
        per_config = num_samples * num_cycles
        total = per_config * configs_in_X
        all_batches = self.get_batch_train_ixs(key, total, batch_size)

        running_loss = 0.0
        for flat_ix in all_batches:
            sample_ix = flat_ix % num_samples
            cycle_ix = (flat_ix // num_samples) % num_cycles
            config_ix = flat_ix // per_config
            step_loss, params, opt_step = self.train_step(
                params, opt_step, X[sample_ix, config_ix, ...], y[cycle_ix, config_ix, ...]
            )
            running_loss += step_loss

        return params, opt_step, running_loss / len(all_batches)
def init_trans_mat_and_rot_mat(key, num_classes: int, max_degree: int = 359, rotation_indices: list = None):
    """
    Build the class transition matrix and the rotation transition matrix.

    Parameters
    ----------
    key: jax.random.PRNGKey
        Key used to draw the class transition logits.
    num_classes: int
        Number of classes; the class matrix is (num_classes, num_classes).
    max_degree: int
        Largest rotation index; the rotation matrix is
        (max_degree + 1, max_degree + 1).
    rotation_indices: list of int, optional
        Rotation indices that can be visited. Defaults to [0, 180].

    Returns
    -------
    trans_mat: array(num_classes, num_classes)
        Row-wise softmax of standard normal draws.
    rot_mat: array(max_degree + 1, max_degree + 1)
        Uniform weight 1 / len(rotation_indices) on every pair of listed
        indices, zero elsewhere.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if rotation_indices is None:
        rotation_indices = [0, 180]
    rotation_indices = list(rotation_indices)

    trans_mat = jax.nn.softmax(jax.random.normal(key, shape=(num_classes, num_classes)))

    rot_mat = np.zeros((max_degree + 1, max_degree + 1))
    if rotation_indices:
        rot_mat[np.ix_(rotation_indices, rotation_indices)] = 1 / len(rotation_indices)
    rot_mat = jnp.array(rot_mat)
    return trans_mat, rot_mat
def cross_entropy_loss(logits, labels, num_classes):
    """Return the mean softmax cross-entropy of `logits` against integer `labels`."""
    one_hot_labels = common_utils.onehot(labels, num_classes=num_classes)
    xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    return jnp.mean(xentropy)


def compute_metrics(logits, labels, num_classes):
    """
    Compute loss and accuracy, averaged across the 'batch' device axis.

    Must be called inside a mapped computation that defines the axis name
    'batch', since the result goes through `lax.pmean`.
    """
    # The leftover debug print of the argument shapes was removed: it ran
    # at trace time and polluted the training output.
    loss = cross_entropy_loss(logits, labels, num_classes)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    metrics = lax.pmean(metrics, axis_name='batch')
    return metrics


def create_learning_rate_fn(
        num_steps: int,
        init_learning_rate: float):
    """
    Create a cosine learning rate schedule.

    The returned function maps a step to
    0.5 * init_learning_rate * (1 + cos(pi * step / num_steps)), i.e. it
    starts at `init_learning_rate` and reaches zero at `num_steps`.
    """

    def schedule(step):
        t = step / num_steps
        return 0.5 * init_learning_rate * (1 + jnp.cos(t * jnp.pi))

    return schedule
def save_checkpoint(state, workdir):
    """
    Write the train state to `workdir`, keeping the three latest checkpoints.

    Only the process with index 0 writes. The state is expected to carry a
    leading per-device axis; the entry at index 0 is the one stored.
    """
    if jax.process_index() == 0:
        # get train state from the first replica
        # (jax.tree_util.tree_map replaces the deprecated jax.tree_map alias)
        state = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state))
        step = int(state.step)
        checkpoints.save_checkpoint(workdir, state, step, keep=3)
| # pmean only works inside pmap because it needs an axis name. 202 | # This function will average the inputs across all devices. 203 | cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') 204 | 205 | 206 | def sync_batch_stats(state): 207 | """Sync the batch statistics across replicas.""" 208 | # Each device has its own version of the running average batch statistics and 209 | # we sync them before evaluation. 210 | return state.replace(batch_stats=cross_replica_mean(state.batch_stats)) 211 | 212 | 213 | def create_train_state(rng, config: ml_collections.ConfigDict, 214 | model, image_size, learning_rate_fn): 215 | """Create initial training state.""" 216 | dynamic_scale = None 217 | platform = jax.local_devices()[0].platform 218 | 219 | if config.half_precision and platform == 'gpu': 220 | dynamic_scale = optim.DynamicScale() 221 | else: 222 | dynamic_scale = None 223 | 224 | params, batch_stats = initialized(rng, image_size, model) 225 | 226 | tx = optax.sgd( 227 | learning_rate=learning_rate_fn, 228 | momentum=config.momentum_decay, 229 | nesterov=True, 230 | ) 231 | 232 | state = TrainState.create( 233 | apply_fn=model.apply, 234 | params=params, 235 | tx=tx, 236 | batch_stats=batch_stats, 237 | dynamic_scale=dynamic_scale) 238 | return state 239 | 240 | 241 | def get_input_dtype(half_precision, platform): 242 | if half_precision: 243 | if platform == 'tpu': 244 | input_dtype = tf.bfloat16 245 | else: 246 | input_dtype = tf.float16 247 | else: 248 | input_dtype = tf.float32 249 | return input_dtype 250 | 251 | 252 | def train_and_evaluate(config: ml_collections.ConfigDict, 253 | workdir: str) -> TrainState: 254 | """Execute model training and evaluation loop. 255 | 256 | Args: 257 | config: Hyperparameter configuration for training and evaluation. 258 | workdir: Directory where the tensorboard summaries are written to. 259 | 260 | Returns: 261 | Final TrainState. 
262 | """ 263 | 264 | writer = metric_writers.create_default_writer( 265 | logdir=workdir, just_logging=jax.process_index() != 0) 266 | 267 | rng = random.PRNGKey(config.seed) 268 | 269 | if config.batch_size % jax.device_count() > 0: 270 | raise ValueError('Batch size must be divisible by the number of devices') 271 | local_batch_size = config.batch_size // jax.process_count() 272 | 273 | platform = jax.local_devices()[0].platform 274 | input_dtype = get_input_dtype(config.half_precision, platform) 275 | 276 | dataset_builder = tfds.builder(config.dataset) 277 | 278 | num_classes = dataset_builder.info.features['label'].num_classes 279 | train, test = "train", "validation" 280 | 281 | if config.dataset == "cifar10" or config.dataset == "mnist": 282 | test = "test" 283 | 284 | train_freq = config.train_freq 285 | 286 | init_key, rng = random.split(rng) 287 | trans_mat, rot_mat = init_trans_mat_and_rot_mat(init_key, num_classes) 288 | environment = Environment(config.dataset, trans_mat, rot_mat, batch_size=local_batch_size) 289 | 290 | steps_per_epoch = ( 291 | dataset_builder.info.splits[train].num_examples // config.batch_size 292 | ) 293 | 294 | if config.num_train_steps == -1: 295 | num_steps = int(steps_per_epoch * config.num_epochs) 296 | else: 297 | num_steps = config.num_train_steps 298 | 299 | if config.steps_per_eval == -1: 300 | num_validation_examples = dataset_builder.info.splits[ 301 | test].num_examples 302 | steps_per_eval = num_validation_examples // config.batch_size 303 | else: 304 | steps_per_eval = config.steps_per_eval 305 | 306 | steps_per_checkpoint = steps_per_epoch * 10 307 | 308 | model_cls = getattr(models, config.model) 309 | model = create_model( 310 | model_cls=model_cls, num_classes=num_classes, half_precision=config.half_precision) 311 | 312 | learning_rate_fn = create_learning_rate_fn( 313 | num_steps, config.learning_rate) 314 | 315 | if config.image_size == -1: 316 | image_size = dataset_builder.info.features['image'].shape[0] 317 
| else: 318 | image_size = config.image_size 319 | 320 | init_key, rng = random.split(rng) 321 | state = create_train_state(init_key, config, model, image_size, learning_rate_fn) 322 | state = restore_checkpoint(state, workdir) 323 | # step_offset > 0 if restarting from checkpoint 324 | step_offset = int(state.step) 325 | state = jax_utils.replicate(state) 326 | 327 | p_train_step = jax.pmap( 328 | functools.partial(train_step, learning_rate_fn=learning_rate_fn, 329 | num_classes=num_classes, 330 | weight_decay=config.weight_decay), axis_name='batch') 331 | 332 | p_eval_step = jax.pmap(functools.partial(eval_step, num_classes=num_classes), axis_name='batch') 333 | 334 | train_metrics = [] 335 | hooks = [] 336 | if jax.process_index() == 0: 337 | hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)] 338 | 339 | train_metrics_last_t = time.time() 340 | logging.info('Initial compilation, this might take some minutes...') 341 | local_device_count = jax.local_device_count() 342 | 343 | for step in range(step_offset, num_steps): 344 | batch = environment.get_data() 345 | if step == step_offset: 346 | counts, datasets = gdumb.init_sampler() 347 | 348 | if train_freq != 1: 349 | scan_key, rng = random.split(rng) 350 | keys = random.split(scan_key, local_batch_size) 351 | images, labels = batch['image'], batch['label'].flatten() 352 | images = images.reshape((local_batch_size, *images.shape[-3:])) 353 | for key, image, label in zip(keys, images, labels): 354 | counts, datasets = gdumb.sample((counts, datasets), (key, image, label)) 355 | 356 | n = (len(datasets[0]) // local_device_count) * local_device_count 357 | batch = {'image': datasets[0][:n].reshape((local_device_count, -1, *images.shape[-3:])), 358 | 'label': datasets[1][:n].reshape((local_device_count, -1))} 359 | 360 | if (step + 1) % train_freq == 0: 361 | state, metrics = p_train_step(state, batch) 362 | for h in hooks: 363 | h(step) 364 | if step == step_offset: 365 | logging.info('Initial 
compilation completed.') 366 | 367 | if config.get('log_every_steps'): 368 | train_metrics.append(metrics) 369 | if (step + 1) % config.log_every_steps == 0: 370 | train_metrics = common_utils.get_metrics(train_metrics) 371 | summary = { 372 | f'{train}_{k}': v 373 | for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items() 374 | } 375 | summary['steps_per_second'] = config.log_every_steps / ( 376 | time.time() - train_metrics_last_t) 377 | writer.write_scalars(step + 1, summary) 378 | train_metrics = [] 379 | train_metrics_last_t = time.time() 380 | 381 | if (step + 1) % steps_per_epoch == 0: 382 | epoch = step // steps_per_epoch 383 | eval_metrics = [] 384 | 385 | # sync batch statistics across replicas 386 | state = sync_batch_stats(state) 387 | for _ in range(steps_per_eval): 388 | eval_batch = environment.get_test_data(device_count=local_device_count) 389 | metrics = p_eval_step(state, eval_batch) 390 | eval_metrics.append(metrics) 391 | 392 | eval_metrics = common_utils.get_metrics(eval_metrics) 393 | summary = jax.tree_map(lambda x: x.mean(), eval_metrics) 394 | logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', 395 | epoch, summary['loss'], summary['accuracy'] * 100) 396 | writer.write_scalars( 397 | step + 1, {f'eval_{key}': val for key, val in summary.items()}) 398 | writer.flush() 399 | 400 | if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps: 401 | state = sync_batch_stats(state) 402 | save_checkpoint(state, workdir) 403 | 404 | # Wait until computations are done before exiting 405 | random.normal(random.PRNGKey(0), ()).block_until_ready() 406 | 407 | return state 408 | -------------------------------------------------------------------------------- /shift_happens/gendist/notebooks/013-metadata-inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f0a5a74e-cfc6-491c-8267-60c5b1ceefa5", 6 | "metadata": {}, 7 | 
"source": [ 8 | "# Metadata inference" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "e8326182-b4fd-4ded-ad61-7095eb4a8f69", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "/home/gerardoduran/documents/shift-happens/gendist/experiments\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%cd ../gendist/experiments\n", 27 | "%load_ext autoreload\n", 28 | "%autoreload 2" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "23c89333-052f-4bcc-9921-c28ff816419a", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import jax\n", 39 | "import optax\n", 40 | "import gendist\n", 41 | "import torchvision\n", 42 | "import numpy as np\n", 43 | "import jax.numpy as jnp\n", 44 | "import matplotlib.pyplot as plt\n", 45 | "import mnist_rotation_meta as metaexp" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 204, 51 | "id": "48f9cbe6-cbe3-425a-933d-343c325b706c", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "%config InlineBackend.figure_format = \"retina\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 82, 61 | "id": "e0c25fc1-f25f-4c39-997d-f32a7ab9d284", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "configs = np.linspace(0, 360, 150)\n", 66 | "list_configs = [{\"angle\": float(deg)} for deg in configs]" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 83, 72 | "id": "ca82ca0c-8efa-47b6-a655-725a06e65d09", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "key = jax.random.PRNGKey(314)\n", 77 | "key_subset, key_train = jax.random.split(key)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 159, 83 | "id": "664bfc7a-7fdc-4a9d-9689-196df47b1626", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "n_train_subset = 600\n", 88 | "mnist_train = 
torchvision.datasets.MNIST(root=\".\", train=True, download=True)\n", 89 | "X_train = np.array(mnist_train.data) / 255\n", 90 | "X_train = metaexp.configure_covariates(key_subset, metaexp.processing_class, X_train, list_configs, n_train_subset)\n", 91 | "X_train = X_train * 2 - 1" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 242, 97 | "id": "970348cd-f072-46ca-a3bb-6be88c781778", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "alpha = 0.00005\n", 102 | "# alpha = 0.001\n", 103 | "n_epochs = 100\n", 104 | "batch_size = 2000\n", 105 | "\n", 106 | "tx = optax.adam(learning_rate=alpha)\n", 107 | "lossfn = gendist.training.make_von_mises_loss_func\n", 108 | "meta_model = gendist.models.LeNet5Regression(1)\n", 109 | "# meta_model = gendist.models.MLPDataV1(1)\n", 110 | "\n", 111 | "configs_radians = configs / 180 * jnp.pi\n", 112 | "trainer = gendist.training.TrainingMeta(meta_model, lossfn, tx)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 243, 118 | "id": "483e53f0-a310-4f38-b69d-833afa320dd7", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "application/vnd.jupyter.widget-view+json": { 124 | "model_id": "62ab7d65b64246d0aa015b59248086dc", 125 | "version_major": 2, 126 | "version_minor": 0 127 | }, 128 | "text/plain": [ 129 | " 0%| | 0/100 [00:00]" 150 | ] 151 | }, 152 | "execution_count": 244, 153 | "metadata": {}, 154 | "output_type": "execute_result" 155 | }, 156 | { 157 | "data": { 158 | "image/png": "\n", 159 | "text/plain": [ 160 | "
" 161 | ] 162 | }, 163 | "metadata": { 164 | "image/png": { 165 | "height": 248, 166 | "width": 397 167 | }, 168 | "needs_background": "light" 169 | }, 170 | "output_type": "display_data" 171 | } 172 | ], 173 | "source": [ 174 | "plt.plot(meta_output[\"losses\"])" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 241, 180 | "id": "5ac01d9a-65c5-43d7-9d27-16248e722f64", 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | "[]" 187 | ] 188 | }, 189 | "execution_count": 241, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | }, 193 | { 194 | "data": { 195 | "image/png": "\n", 196 | "text/plain": [ 197 | "
" 198 | ] 199 | }, 200 | "metadata": { 201 | "image/png": { 202 | "height": 248, 203 | "width": 378 204 | }, 205 | "needs_background": "light" 206 | }, 207 | "output_type": "display_data" 208 | } 209 | ], 210 | "source": [ 211 | "x = jnp.linspace(0, 2 * jnp.pi)\n", 212 | "plt.plot(jnp.cos(x) + 1)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "f3029221-02ff-4b70-8fed-1005878c5595", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "jnp.sqrt(meta_output[\"losses\"])[-10:]" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 172, 228 | "id": "472ddb0c-0d3b-4aac-a523-46759d076024", 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "DeviceArray([104.64888 , 104.65269 , 104.653076, 104.651764, 104.662605,\n", 235 | " 104.65336 , 104.648094, 104.650826, 104.64987 , 104.64641 ], dtype=float32)" 236 | ] 237 | }, 238 | "execution_count": 172, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | "jnp.sqrt(meta_output[\"losses\"])[-10:]" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 246, 250 | "id": "fd2b9ff8-8036-420a-8761-0cf817aba4b1", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "configs_pred = meta_model.apply(meta_output[\"params\"], X_train.reshape(-1, 28**2))\n", 255 | "configs_pred = configs_pred.reshape(600, 150)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 247, 261 | "id": "595aa4ea-0901-48ee-8114-fd1392c0084d", 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "DeviceArray(3.0539749, dtype=float32)" 268 | ] 269 | }, 270 | "execution_count": 247, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "configs_pred.min()" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 248, 282 | "id": 
"c64b0544-a318-48ec-a547-426fbb88bf1f", 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "data": { 287 | "text/plain": [ 288 | "DeviceArray(3.264753, dtype=float32)" 289 | ] 290 | }, 291 | "execution_count": 248, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "configs_pred.max()" 298 | ] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "Python 3 (ipykernel)", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.9.5" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 5 322 | } 323 | --------------------------------------------------------------------------------