├── .gitignore ├── .style.yapf ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── mnist_video.gif ├── plot_all_the_things.sh ├── shell.nix └── src ├── .gitignore ├── cifar100_calibration_plot.py ├── cifar100_full_acc5.py ├── cifar100_resnet20_ensembling.py ├── cifar100_resnet20_interp_logits.py ├── cifar100_resnet20_split_data_plot.py ├── cifar100_resnet20_train.py ├── cifar100_resnet20_weight_matching.py ├── cifar10_mlp_activation_matching.py ├── cifar10_mlp_barrier_vs_epoch_matching.py ├── cifar10_mlp_barrier_vs_epoch_plot.py ├── cifar10_mlp_interp_plot.py ├── cifar10_mlp_ste2.py ├── cifar10_mlp_train.py ├── cifar10_mlp_weight_matching.py ├── cifar10_resnet20_interp_plot.py ├── cifar10_resnet20_train.py ├── cifar10_resnet20_weight_matching.py ├── cifar10_resnet20_width_ablation_plot.py ├── cifar10_vgg_activation_matching.py ├── cifar10_vgg_cosine_similarity_matching.py ├── cifar10_vgg_run.py ├── cifar10_vgg_ste.py ├── cifar10_vgg_ste2.py ├── cifar10_vgg_weight_matching.py ├── cifar10_vgg_width_ablation_plot.py ├── datasets.py ├── imagenet_resnet50_interp_plot.py ├── imagenet_resnet50_weight_matching.py ├── matplotlib_style.py ├── mnist_barrier_vs_epoch_matching.py ├── mnist_barrier_vs_epoch_plot.py ├── mnist_convnet_plot.py ├── mnist_convnet_run.py ├── mnist_mlp_activation_matching.py ├── mnist_mlp_cosine_similarity_matching.py ├── mnist_mlp_interp_plot.py ├── mnist_mlp_loss_contour.py ├── mnist_mlp_ste.py ├── mnist_mlp_ste2.py ├── mnist_mlp_steepest_descent.py ├── mnist_mlp_train.py ├── mnist_mlp_weight_matching.py ├── mnist_mlp_wm_many.py ├── mnist_vgg16_run.py ├── mnist_vgg_weight_matching.py ├── mnist_video.py ├── online_stats.py ├── parallel_cifar10_run.py ├── parallel_mnist_videos.py ├── plot_utils.py ├── resnet.py ├── resnet20.py ├── sgd_is_special.py ├── should_be_deterministic.py ├── utils.py └── weight_matching.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.mp4 2 | *.pdf 3 | *.png 4 | *.eps 5 | 6 | # Nix build results 7 | result* 8 | 9 | # wandb related 10 | artifacts/ 11 | wandb/ 12 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | INDENT_WIDTH = 2 3 | BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION = 1 4 | COLUMN_LIMIT = 100 5 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "nixEnvSelector.nixFile": "${workspaceRoot}/shell.nix", 3 | "python.formatting.provider": "yapf" 4 | } 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Samuel Ainsworth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Git Re-Basin: Merging Models modulo Permutation Symmetries 2 | 3 | ![Video demonstrating the effect of our permutation matching algorithm on the loss landscape throughout training.](mnist_video.gif) 4 | 5 | Code for the paper [Git Re-Basin: Merging Models modulo Permutation Symmetries](https://arxiv.org/abs/2209.04836). 6 | 7 | Abstract: 8 | 9 | > The success of deep learning is thanks to our ability to solve certain massive non-convex optimization problems with relative ease. Despite non-convex optimization being NP-hard, simple algorithms -- often variants of stochastic gradient descent -- exhibit surprising effectiveness in fitting large neural networks in practice. We argue that neural network loss landscapes contain (nearly) a single basin, after accounting for all possible permutation symmetries of hidden units. We introduce three algorithms to permute the units of one model to bring them into alignment with units of a reference model. This transformation produces a functionally equivalent set of weights that lie in an approximately convex basin near the reference model. Experimentally, we demonstrate the single basin phenomenon across a variety of model architectures and datasets, including the first (to our knowledge) demonstration of zero-barrier linear mode connectivity between independently trained ResNet models on CIFAR-10 and CIFAR-100. Additionally, we identify intriguing phenomena relating model width and training time to mode connectivity across a variety of models and datasets. Finally, we discuss shortcomings of a single basin theory, including a counterexample to the linear mode connectivity hypothesis. 
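The core workflow lives in `src/`: compute a permutation of model $B$'s hidden units that best aligns them with model $A$ (`weight_matching.py`), apply it, and then linearly interpolate the aligned weights. The sketch below pieces this together from the scripts in `src/`; it is not a script that ships with the repo, and in real use `params_a` and `params_b` would be independently trained checkpoints (e.g. from `cifar10_mlp_train.py` with different `--seed` values) rather than the fresh initializations used here as stand-ins.

```python
# Rough sketch of the weight-matching workflow (hypothetical driver script;
# see src/cifar10_mlp_weight_matching.py for the real thing).
import jax.numpy as jnp
from jax import random

from cifar10_mlp_train import MLPModel, make_stuff
from datasets import load_cifar10
from utils import flatten_params, lerp, unflatten_params
from weight_matching import apply_permutation, mlp_permutation_spec, weight_matching

model = MLPModel()
stuff = make_stuff(model)

# Stand-ins for two independently trained models. In practice these would be
# trained checkpoints, not fresh initializations.
params_a = model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"]
params_b = model.init(random.PRNGKey(1), jnp.zeros((1, 32, 32, 3)))["params"]

# Find the permutation of B's hidden units that best aligns B with A.
# This MLP has 3 permutable hidden layers, hence mlp_permutation_spec(3).
spec = mlp_permutation_spec(3)
perm = weight_matching(random.PRNGKey(123), spec,
                       flatten_params(params_a), flatten_params(params_b))
params_b_perm = unflatten_params(
    apply_permutation(spec, perm, flatten_params(params_b)))

# Permuting B's units leaves its predictions unchanged, but moves its weights
# toward A's basin, so the midpoint of the two weight vectors acts as a merged model.
merged = lerp(0.5, params_a, params_b_perm)

train_ds, test_ds = load_cifar10()
test_loss, test_acc = stuff["dataset_loss_and_accuracy"](merged, test_ds, 10_000)
print(test_loss, test_acc)
```

The interpolation scripts such as `src/cifar10_mlp_weight_matching.py` sweep `lerp(lam, ...)` over λ ∈ [0, 1] (25 points) to plot the loss barrier between the two models, both with and without the permutation applied.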
10 | -------------------------------------------------------------------------------- /mnist_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuela/git-re-basin/ef40098257ab97243930eba737d6dcb8edd5863e/mnist_video.gif -------------------------------------------------------------------------------- /plot_all_the_things.sh: -------------------------------------------------------------------------------- 1 | python cifar10_mlp_plot.py 2 | python mnist_mlp_interp_plot.py 3 | 4 | python mnist_barrier_vs_epoch_plot.py 5 | python cifar10_mlp_barrier_vs_epoch_plot.py 6 | 7 | python cifar10_vgg_width_ablation_plot.py 8 | python cifar10_resnet20_width_ablation_plot.py 9 | 10 | python sgd_is_special.py 11 | 12 | python cifar100_resnet20_split_data_plot.py 13 | -------------------------------------------------------------------------------- /shell.nix: -------------------------------------------------------------------------------- 1 | # Run with nixGL, eg `nixGLNvidia-510.47.03 python cifar10_convnet_run.py --test` 2 | 3 | # To prevent JAX from allocating all GPU memory: XLA_PYTHON_CLIENT_PREALLOCATE=false 4 | # To push build to cachix: nix-store -qR --include-outputs $(nix-instantiate shell.nix) | cachix push ploop 5 | 6 | let 7 | # pkgs = import (/home/skainswo/dev/nixpkgs) { }; 8 | 9 | # Last updated: 2022-05-16. Check for new commits at status.nixos.org. 10 | pkgs = import (fetchTarball "https://github.com/NixOS/nixpkgs/archive/556ce9a40abde33738e6c9eac65f965a8be3b623.tar.gz") { 11 | config.allowUnfree = true; 12 | # These actually cause problems for some reason. bug report? 13 | # config.cudaSupport = true; 14 | # config.cudnnSupport = true; 15 | }; 16 | in 17 | pkgs.mkShell { 18 | buildInputs = with pkgs; [ 19 | ffmpeg 20 | python3 21 | python3Packages.augmax 22 | python3Packages.einops 23 | python3Packages.flax 24 | python3Packages.ipython 25 | python3Packages.jax 26 | # See https://discourse.nixos.org/t/petition-to-build-and-cache-unfree-packages-on-cache-nixos-org/17440/14 27 | # as to why we don't use the source builds of jaxlib/tensorflow. 28 | (python3Packages.jaxlib-bin.override { 29 | cudaSupport = true; 30 | }) 31 | python3Packages.matplotlib 32 | # python3Packages.pandas 33 | python3Packages.plotly 34 | # python3Packages.scikit-learn 35 | python3Packages.seaborn 36 | (python3Packages.tensorflow-bin.override { 37 | cudaSupport = false; 38 | }) 39 | # Thankfully tensorflow-datasets does not have tensorflow as a propagatedBuildInput. If that were the case for any 40 | # of these dependencies, we'd be in trouble since Python does not like multiple versions of the same package in 41 | # PYTHONPATH. 42 | python3Packages.tensorflow-datasets 43 | python3Packages.tqdm 44 | python3Packages.wandb 45 | 46 | # Necessary for LaTeX in matplotlib. 47 | texlive.combined.scheme-full 48 | 49 | yapf 50 | ]; 51 | 52 | # Don't clog EFS with wandb results. Wandb will create and use /tmp/wandb. 
53 | WANDB_DIR = "/tmp"; 54 | WANDB_CACHE_DIR = "/tmp"; 55 | } 56 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /src/cifar100_calibration_plot.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from jax import nn 6 | from tqdm import tqdm 7 | 8 | import matplotlib_style as _ 9 | 10 | # See https://github.com/google/jax/issues/696#issuecomment-642457347 11 | # import os 12 | # os.environ["JAX_PLATFORM_NAME"] = "cpu" 13 | 14 | NUM_CLASSES = 100 15 | 16 | data = pickle.load(open("../cifar100_interp_logits.pkl", "rb")) 17 | 18 | num_bins = 15 19 | bins = np.linspace(0, 1, num_bins + 1) 20 | bin_locations = 0.5 * (bins[:-1] + bins[1:]) 21 | 22 | def one(bin_ix, probs, labels): 23 | lo, hi = bins[bin_ix], bins[bin_ix + 1] 24 | mask = (lo <= probs) & (probs <= hi) 25 | y_onehot = nn.one_hot(labels, NUM_CLASSES) 26 | return np.mean(y_onehot[mask]) 27 | 28 | ### Plotting 29 | plt.figure(figsize=(12, 6)) 30 | 31 | # Train 32 | plt.subplot(1, 2, 1) 33 | plotting_ds_name = "train" 34 | plotting_ds = data[f"{plotting_ds_name}_dataset"] 35 | 36 | a_probs = nn.softmax(data[f"a_{plotting_ds_name}_logits"]) 37 | b_probs = nn.softmax(data[f"b_{plotting_ds_name}_logits"]) 38 | clever_probs = nn.softmax(data[f"clever_{plotting_ds_name}_logits"]) 39 | naive_probs = nn.softmax(data[f"naive_{plotting_ds_name}_logits"]) 40 | ensemble_probs = nn.softmax(0.5 * (data[f"a_{plotting_ds_name}_logits"] + data[f"b_{plotting_ds_name}_logits"])) 41 | 42 | model_a_ys = [one(ix, a_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 43 | model_b_ys = [one(ix, b_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 44 | wm_ys = [one(ix, clever_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 45 | naive_ys = [one(ix, naive_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 46 | ensemble_ys = [one(ix, ensemble_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 47 | 48 | plt.plot([0, 1], [0, 1], color="tab:grey", linestyle="dotted", label="Perfect calibration") 49 | plt.plot(bin_locations, model_a_ys, alpha=0.5, label="Model A") 50 | plt.plot(bin_locations, model_b_ys, alpha=0.5, label="Model B") 51 | plt.plot(bin_locations, naive_ys, color="tab:grey", marker=".", label="Naïve merging") 52 | plt.plot(bin_locations, ensemble_ys, color="tab:purple", marker="2", label="Model ensemble") 53 | plt.plot(bin_locations, wm_ys, color="tab:green", marker="^", linewidth=2, label="Weight matching") 54 | plt.xlabel("Predicted probability") 55 | plt.ylabel("True probability") 56 | plt.axis("equal") 57 | plt.legend() 58 | plt.title("Train") 59 | plt.xlim(0, 1) 60 | plt.ylim(0, 1) 61 | plt.xticks(np.linspace(0, 1, 5)) 62 | plt.yticks(np.linspace(0, 1, 5)) 63 | 64 | # Test 65 | plt.subplot(1, 2, 2) 66 | plotting_ds_name = "test" 67 | plotting_ds = data[f"{plotting_ds_name}_dataset"] 68 | 69 | a_probs = nn.softmax(data[f"a_{plotting_ds_name}_logits"]) 70 | b_probs = nn.softmax(data[f"b_{plotting_ds_name}_logits"]) 71 | clever_probs = nn.softmax(data[f"clever_{plotting_ds_name}_logits"]) 72 | naive_probs = nn.softmax(data[f"naive_{plotting_ds_name}_logits"]) 73 | ensemble_probs = nn.softmax(0.5 * (data[f"a_{plotting_ds_name}_logits"] + 
data[f"b_{plotting_ds_name}_logits"])) 74 | 75 | model_a_ys = [one(ix, a_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 76 | model_b_ys = [one(ix, b_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 77 | wm_ys = [one(ix, clever_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 78 | naive_ys = [one(ix, naive_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 79 | ensemble_ys = [one(ix, ensemble_probs, plotting_ds["labels"]) for ix in tqdm(range(num_bins))] 80 | 81 | plt.plot([0, 1], [0, 1], color="tab:grey", linestyle="dotted", label="Perfect calibration") 82 | plt.plot(bin_locations, model_a_ys, alpha=0.5, linestyle="dashed", label="Model A") 83 | plt.plot(bin_locations, model_b_ys, alpha=0.5, linestyle="dashed", label="Model B") 84 | plt.plot(bin_locations, 85 | naive_ys, 86 | color="tab:grey", 87 | marker=".", 88 | linestyle="dashed", 89 | label="Naïve merging") 90 | plt.plot(bin_locations, 91 | ensemble_ys, 92 | color="tab:purple", 93 | marker="2", 94 | linestyle="dashed", 95 | label="Model ensemble") 96 | plt.plot(bin_locations, 97 | wm_ys, 98 | color="tab:green", 99 | marker="^", 100 | linewidth=2, 101 | linestyle="dashed", 102 | label="Weight matching") 103 | plt.xlabel("Predicted probability") 104 | plt.ylabel("True probability") 105 | plt.axis("equal") 106 | plt.title("Test") 107 | plt.xlim(0, 1) 108 | plt.ylim(0, 1) 109 | plt.xticks(np.linspace(0, 1, 5)) 110 | plt.yticks(np.linspace(0, 1, 5)) 111 | 112 | plt.suptitle("CIFAR-100 Split Datasets, Calibration") 113 | plt.tight_layout() 114 | plt.savefig("figs/cifar100_calibration_plot.png", dpi=300) 115 | plt.savefig("figs/cifar100_calibration_plot.pdf") 116 | -------------------------------------------------------------------------------- /src/cifar100_full_acc5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import jax.numpy as jnp 5 | import wandb 6 | from flax.serialization import from_bytes 7 | from jax import random 8 | 9 | from cifar100_resnet20_train import NUM_CLASSES, make_stuff 10 | from datasets import load_cifar100 11 | from resnet20 import BLOCKS_PER_GROUP, ResNet 12 | from utils import ec2_get_instance_type 13 | 14 | NUM_CLASSES = 100 15 | 16 | # https://wandb.ai/skainswo/git-re-basin/runs/f40w12z7/overview?workspace=user-skainswo 17 | # use model=v11, width multiplier=32, load-epoch=249 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--model", type=str, required=True) 21 | parser.add_argument("--width-multiplier", type=int, required=True) 22 | parser.add_argument("--load-epoch", type=int, required=True) 23 | args = parser.parse_args() 24 | 25 | with wandb.init( 26 | project="git-re-basin", 27 | entity="skainswo", 28 | tags=["cifar100", "resnet20", "top5"], 29 | job_type="analysis", 30 | ) as wandb_run: 31 | config = wandb.config 32 | config.ec2_instance_type = ec2_get_instance_type() 33 | config.model = args.model 34 | config.width_multiplier = args.width_multiplier 35 | config.load_epoch = args.load_epoch 36 | 37 | model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"], 38 | num_classes=NUM_CLASSES, 39 | width_multiplier=config.width_multiplier) 40 | stuff = make_stuff(model) 41 | 42 | def load_model(filepath): 43 | with open(filepath, "rb") as fh: 44 | return from_bytes( 45 | model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"], fh.read()) 46 | 47 | filename = f"checkpoint{config.load_epoch}" 48 | model = load_model( 49 | Path( 
50 | wandb_run.use_artifact(f"cifar100-resnet-weights:{config.model}").get_path( 51 | filename).download())) 52 | 53 | train_ds, test_ds = load_cifar100() 54 | 55 | test_loss, test_acc1, test_acc5 = stuff["dataset_loss_and_accuracies"](model, test_ds, 1000) 56 | 57 | print({ 58 | "test_loss": test_loss, 59 | "test_acc1": test_acc1, 60 | "test_acc5": test_acc5, 61 | }) 62 | wandb_run.log({ 63 | "test_loss": test_loss, 64 | "test_acc1": test_acc1, 65 | "test_acc5": test_acc5, 66 | }) 67 | -------------------------------------------------------------------------------- /src/cifar100_resnet20_ensembling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import jax.nn 5 | import jax.numpy as jnp 6 | import matplotlib.pyplot as plt 7 | import optax 8 | import wandb 9 | from flax.serialization import from_bytes 10 | from jax import jit, random, vmap 11 | from tqdm import tqdm 12 | 13 | from cifar100_resnet20_train import NUM_CLASSES, make_stuff 14 | from datasets import load_cifar100 15 | from resnet20 import BLOCKS_PER_GROUP, ResNet 16 | from utils import ec2_get_instance_type, lerp 17 | 18 | NUM_CLASSES = 100 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 23 | parser.add_argument("--model-a", type=str, required=True) 24 | parser.add_argument("--model-b", type=str, required=True) 25 | parser.add_argument("--width-multiplier", type=int, required=True) 26 | parser.add_argument("--load-epoch", type=int, required=True) 27 | args = parser.parse_args() 28 | 29 | with wandb.init( 30 | project="git-re-basin", 31 | entity="skainswo", 32 | tags=["cifar100", "resnet20", "ensembling"], 33 | job_type="analysis", 34 | ) as wandb_run: 35 | config = wandb.config 36 | config.ec2_instance_type = ec2_get_instance_type() 37 | config.model_a = args.model_a 38 | config.model_b = args.model_b 39 | config.width_multiplier = args.width_multiplier 40 | config.seed = args.seed 41 | config.load_epoch = args.load_epoch 42 | 43 | model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"], 44 | num_classes=NUM_CLASSES, 45 | width_multiplier=config.width_multiplier) 46 | stuff = make_stuff(model) 47 | 48 | def load_model(filepath): 49 | with open(filepath, "rb") as fh: 50 | return from_bytes( 51 | model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"], fh.read()) 52 | 53 | filename = f"checkpoint{config.load_epoch}" 54 | model_a = load_model( 55 | Path( 56 | wandb_run.use_artifact(f"cifar100-resnet-weights:{config.model_a}").get_path( 57 | filename).download())) 58 | model_b = load_model( 59 | Path( 60 | wandb_run.use_artifact(f"cifar100-resnet-weights:{config.model_b}").get_path( 61 | filename).download())) 62 | 63 | train_ds, test_ds = load_cifar100() 64 | 65 | @jit 66 | def batch_logits(params, images_u8): 67 | images_f32 = vmap(stuff["normalize_transform"])(None, images_u8) 68 | return model.apply({"params": params}, images_f32) 69 | 70 | batch_size = 500 71 | 72 | def dataset_logits(params, dataset): 73 | num_examples = dataset["images_u8"].shape[0] 74 | assert num_examples % batch_size == 0 75 | num_batches = num_examples // batch_size 76 | batch_ix = jnp.arange(num_examples).reshape((num_batches, batch_size)) 77 | # Can't use vmap or run in a single batch since that overloads GPU memory. 
78 | return jnp.concatenate([ 79 | batch_logits(params, dataset["images_u8"][batch_ix[i, :], ...]) 80 | for i in tqdm(range(num_batches)) 81 | ], 82 | axis=0) 83 | 84 | train_logits_a = dataset_logits(model_a, train_ds) 85 | train_logits_b = dataset_logits(model_b, train_ds) 86 | test_logits_a = dataset_logits(model_a, test_ds) 87 | test_logits_b = dataset_logits(model_b, test_ds) 88 | 89 | lambdas = jnp.linspace(0, 1, num=25) 90 | train_loss_interp = jnp.array([ 91 | jnp.mean( 92 | optax.softmax_cross_entropy(logits=lerp(lam, train_logits_a, train_logits_b), 93 | labels=jax.nn.one_hot(train_ds["labels"], NUM_CLASSES))) 94 | for lam in lambdas 95 | ]) 96 | test_loss_interp = jnp.array([ 97 | jnp.mean( 98 | optax.softmax_cross_entropy(logits=lerp(lam, test_logits_a, test_logits_b), 99 | labels=jax.nn.one_hot(test_ds["labels"], NUM_CLASSES))) 100 | for lam in lambdas 101 | ]) 102 | 103 | train_acc_interp = jnp.array([ 104 | jnp.sum( 105 | jnp.argmax(lerp(lam, train_logits_a, train_logits_b), axis=-1) == train_ds["labels"]) 106 | for lam in lambdas 107 | ]) / train_ds["labels"].shape[0] 108 | test_acc_interp = jnp.array([ 109 | jnp.sum(jnp.argmax(lerp(lam, test_logits_a, test_logits_b), axis=-1) == test_ds["labels"]) 110 | for lam in lambdas 111 | ]) / test_ds["labels"].shape[0] 112 | 113 | wandb_run.log({ 114 | "train_loss_interp": train_loss_interp, 115 | "test_loss_interp": test_loss_interp, 116 | "train_acc_interp": train_acc_interp, 117 | "test_acc_interp": test_acc_interp 118 | }) 119 | 120 | fig = plt.figure() 121 | plt.plot(lambdas, train_loss_interp, label="train") 122 | plt.plot(lambdas, test_loss_interp, label="test") 123 | plt.legend() 124 | plt.xlabel("Lambda") 125 | plt.ylabel("Loss") 126 | plt.title("Ensembling Loss") 127 | wandb_run.log({"loss_interp": wandb.Image(fig)}) 128 | 129 | fig = plt.figure() 130 | plt.plot(lambdas, train_acc_interp, label="train") 131 | plt.plot(lambdas, test_acc_interp, label="test") 132 | plt.legend() 133 | plt.xlabel("Lambda") 134 | plt.ylabel("Top-1 Accuracy") 135 | plt.title("Ensembling Accuracy") 136 | wandb_run.log({"acc_interp": wandb.Image(fig)}) 137 | -------------------------------------------------------------------------------- /src/cifar100_resnet20_interp_logits.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from pathlib import Path 4 | 5 | import jax.numpy as jnp 6 | import wandb 7 | from flax.serialization import from_bytes 8 | from jax import random 9 | 10 | from cifar100_resnet20_train import make_stuff 11 | from datasets import load_cifar100 12 | from resnet20 import BLOCKS_PER_GROUP, ResNet 13 | from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params 14 | from weight_matching import apply_permutation, resnet20_permutation_spec 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model-a", type=str, required=True) 19 | parser.add_argument("--model-b", type=str, required=True) 20 | parser.add_argument("--permutation", type=str, required=True) 21 | parser.add_argument("--width-multiplier", type=int, required=True) 22 | parser.add_argument("--load-epoch", type=int, required=True) 23 | args = parser.parse_args() 24 | 25 | with wandb.init( 26 | project="git-re-basin", 27 | entity="skainswo", 28 | tags=["cifar100", "resnet20", "logits"], 29 | job_type="analysis", 30 | ) as wandb_run: 31 | config = wandb.config 32 | config.ec2_instance_type = ec2_get_instance_type() 33 | 
config.model_a = args.model_a 34 | config.model_b = args.model_b 35 | config.permutation = args.permutation 36 | config.width_multiplier = args.width_multiplier 37 | config.load_epoch = args.load_epoch 38 | 39 | model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"], 40 | num_classes=100, 41 | width_multiplier=config.width_multiplier) 42 | stuff = make_stuff(model) 43 | 44 | def load_model(filepath): 45 | with open(filepath, "rb") as fh: 46 | return from_bytes( 47 | model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"], fh.read()) 48 | 49 | filename = f"checkpoint{config.load_epoch}" 50 | model_a = load_model( 51 | Path( 52 | wandb_run.use_artifact(f"cifar100-resnet-weights:{config.model_a}").get_path( 53 | filename).download())) 54 | model_b = load_model( 55 | Path( 56 | wandb_run.use_artifact(f"cifar100-resnet-weights:{config.model_b}").get_path( 57 | filename).download())) 58 | 59 | train_ds, test_ds = load_cifar100() 60 | 61 | permutation_spec = resnet20_permutation_spec() 62 | final_permutation = pickle.load( 63 | Path( 64 | wandb_run.use_artifact(f"model_b_permutation:{config.permutation}").get_path( 65 | "permutation.pkl").download()).open("rb")) 66 | 67 | print("model A") 68 | a_train_logits = stuff["dataset_logits"](model_a, train_ds, 1000) 69 | a_test_logits = stuff["dataset_logits"](model_a, test_ds, 1000) 70 | 71 | print("model B") 72 | b_train_logits = stuff["dataset_logits"](model_b, train_ds, 1000) 73 | b_test_logits = stuff["dataset_logits"](model_b, test_ds, 1000) 74 | 75 | print("naive interpolation") 76 | naive_interp_p = lerp(0.5, model_a, model_b) 77 | naive_train_logits = stuff["dataset_logits"](naive_interp_p, train_ds, 1000) 78 | naive_test_logits = stuff["dataset_logits"](naive_interp_p, test_ds, 1000) 79 | 80 | model_b_clever = unflatten_params( 81 | apply_permutation(permutation_spec, final_permutation, flatten_params(model_b))) 82 | 83 | print("clever interpolation") 84 | clever_interp_p = lerp(0.5, model_a, model_b_clever) 85 | clever_train_logits = stuff["dataset_logits"](clever_interp_p, train_ds, 1000) 86 | clever_test_logits = stuff["dataset_logits"](clever_interp_p, test_ds, 1000) 87 | 88 | with Path("cifar100_interp_logits.pkl").open("wb") as fh: 89 | pickle.dump( 90 | { 91 | "train_dataset": train_ds, 92 | "test_dataset": test_ds, 93 | "a_train_logits": a_train_logits, 94 | "a_test_logits": a_test_logits, 95 | "b_train_logits": b_train_logits, 96 | "b_test_logits": b_test_logits, 97 | "naive_train_logits": naive_train_logits, 98 | "naive_test_logits": naive_test_logits, 99 | "clever_train_logits": clever_train_logits, 100 | "clever_test_logits": clever_test_logits, 101 | }, fh) 102 | -------------------------------------------------------------------------------- /src/cifar100_resnet20_split_data_plot.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import wandb 6 | 7 | import matplotlib_style as _ 8 | from utils import lerp 9 | 10 | api = wandb.Api() 11 | # width multiplier = 32 12 | # weight_matching_run = api.run("skainswo/git-re-basin/r9tyrenf") # lerp 13 | weight_matching_run = api.run("skainswo/git-re-basin/3so345lw") # lerp, with top-5 14 | # weight_matching_run = api.run("skainswo/git-re-basin/bcu96ank") # slerp 15 | ensembling_run = api.run("skainswo/git-re-basin/2nwx9yyu") 16 | combined_training_run = api.run("skainswo/git-re-basin/f40w12z7") 17 | 18 | ensembling_data = 
pickle.load(open("../cifar100_interp_logits.pkl", "rb")) 19 | 20 | ### Loss plot 21 | fig = plt.figure() 22 | ax = fig.add_subplot(111) 23 | lambdas = np.linspace(0, 1, 25) 24 | 25 | # Naive 26 | # ax.plot(lambdas, 27 | # weight_matching_run.summary["train_loss_interp_naive"], 28 | # color="grey", 29 | # linewidth=2, 30 | # label=f"Naïve weight interp.") 31 | ax.plot(lambdas, 32 | weight_matching_run.summary["test_loss_interp_naive"], 33 | color="grey", 34 | linewidth=2, 35 | linestyle="dashed", 36 | label="Naïve weight interp.") 37 | 38 | # Ensembling 39 | ax.plot(lambdas, 40 | ensembling_run.summary["test_loss_interp"], 41 | color="tab:purple", 42 | marker="2", 43 | linestyle="dashed", 44 | linewidth=2, 45 | label="Ensembling") 46 | 47 | # Weight matching 48 | # ax.plot(lambdas, 49 | # weight_matching_run.summary["train_loss_interp_clever"], 50 | # color="tab:green", 51 | # marker="v", 52 | # linewidth=2, 53 | # label="Weight matching weight interp.") 54 | ax.plot(lambdas, 55 | weight_matching_run.summary["test_loss_interp_clever"], 56 | color="tab:green", 57 | marker="^", 58 | linestyle="dashed", 59 | linewidth=2, 60 | label="Weight matching") 61 | 62 | ax.axhline(y=combined_training_run.summary["test_loss"], 63 | linewidth=2, 64 | linestyle="dashed", 65 | label="Combined data training") 66 | 67 | ax.set_ylim(1, 5.1) 68 | ax.set_xlabel("$\lambda$") 69 | ax.set_xticks([0, 1]) 70 | ax.set_xticklabels(["Model $A$\nDataset $A$", "Model $B$\nDataset $B$"]) 71 | ax.set_ylabel("Test loss") 72 | ax.set_title("Split data training") 73 | ax.legend(loc="upper right", framealpha=0.5) 74 | fig.tight_layout() 75 | 76 | plt.savefig("figs/cifar100_resnet20_split_data_test_loss.png", dpi=300) 77 | plt.savefig("figs/cifar100_resnet20_split_data_test_loss.pdf") 78 | 79 | ### Top-1 Accuracy plot 80 | fig = plt.figure() 81 | ax = fig.add_subplot(111) 82 | lambdas = np.linspace(0, 1, 25) 83 | 84 | # Naive 85 | # ax.plot(lambdas, 86 | # weight_matching_run.summary["train_loss_interp_naive"], 87 | # color="grey", 88 | # linewidth=2, 89 | # label=f"Naïve weight interp.") 90 | ax.plot(lambdas, 91 | 100 * np.array(weight_matching_run.summary["test_acc1_interp_naive"]), 92 | color="grey", 93 | linewidth=2, 94 | linestyle="dashed", 95 | label="Naïve weight interp.") 96 | 97 | # Ensembling 98 | ax.plot(lambdas, 99 | 100 * np.array(ensembling_run.summary["test_acc_interp"]), 100 | color="tab:purple", 101 | marker="2", 102 | linestyle="dashed", 103 | linewidth=2, 104 | label="Ensembling") 105 | 106 | # Weight matching 107 | # ax.plot(lambdas, 108 | # weight_matching_run.summary["train_loss_interp_clever"], 109 | # color="tab:green", 110 | # marker="v", 111 | # linewidth=2, 112 | # label="Weight matching weight interp.") 113 | ax.plot(lambdas, 114 | 100 * np.array(weight_matching_run.summary["test_acc1_interp_clever"]), 115 | color="tab:green", 116 | marker="^", 117 | linestyle="dashed", 118 | linewidth=2, 119 | label="Weight matching") 120 | 121 | ax.axhline(y=100 * combined_training_run.summary["test_accuracy"], 122 | linewidth=2, 123 | linestyle="dashed", 124 | label="Combined data training") 125 | 126 | # ax.set_ylim(1, 5.1) 127 | ax.set_xlabel("$\lambda$") 128 | ax.set_xticks([0, 1]) 129 | ax.set_xticklabels(["Model $A$\nDataset $A$", "Model $B$\nDataset $B$"]) 130 | ax.set_ylabel("Top-1 accuracy") 131 | ax.set_title("CIFAR-100, Split data training") 132 | # ax.legend(loc="upper right", framealpha=0.5) 133 | fig.tight_layout() 134 | 135 | plt.savefig("figs/cifar100_resnet20_split_data_test_acc1.png", dpi=300) 136 
| plt.savefig("figs/cifar100_resnet20_split_data_test_acc1.pdf") 137 | 138 | ### Top-5 Accuracy plot 139 | fig = plt.figure() 140 | ax = fig.add_subplot(111) 141 | lambdas = np.linspace(0, 1, 25) 142 | 143 | # Naive 144 | # ax.plot(lambdas, 145 | # weight_matching_run.summary["train_loss_interp_naive"], 146 | # color="grey", 147 | # linewidth=2, 148 | # label=f"Naïve weight interp.") 149 | ax.plot(lambdas, 150 | 100 * np.array(weight_matching_run.summary["test_acc5_interp_naive"]), 151 | color="grey", 152 | linewidth=2, 153 | linestyle="dashed", 154 | label="Naïve weight interp.") 155 | 156 | # Weight matching 157 | # ax.plot(lambdas, 158 | # weight_matching_run.summary["train_loss_interp_clever"], 159 | # color="tab:green", 160 | # marker="v", 161 | # linewidth=2, 162 | # label="Weight matching weight interp.") 163 | ax.plot( 164 | lambdas, 165 | 100 * np.array(weight_matching_run.summary["test_acc5_interp_clever"]), 166 | color="tab:green", 167 | marker="^", 168 | linestyle="dashed", 169 | linewidth=2, 170 | label="Weight matching", 171 | ) 172 | 173 | # Ensembling 174 | def lam_top5_acc(lam): 175 | logits = lerp(lam, ensembling_data["a_test_logits"], ensembling_data["b_test_logits"]) 176 | labels = ensembling_data["test_dataset"]["labels"] 177 | top5_num_correct = np.sum(np.isin(labels[:, np.newaxis], np.argsort(logits, axis=-1)[:, -5:])) 178 | return top5_num_correct / len(labels) 179 | 180 | ax.plot( 181 | lambdas, 182 | [100 * lam_top5_acc(lam) for lam in lambdas], 183 | color="tab:purple", 184 | marker="2", 185 | linestyle="dashed", 186 | linewidth=2, 187 | label="Ensembling", 188 | ) 189 | 190 | # See https://wandb.ai/skainswo/git-re-basin/runs/10kebhlr?workspace=user-skainswo for the calculation of this value 191 | ax.axhline(y=100.0, linewidth=2, linestyle="dashed", label="Combined data training") 192 | 193 | # ax.set_ylim(1, 5.1) 194 | ax.set_xlabel("$\lambda$") 195 | ax.set_xticks([0, 1]) 196 | ax.set_xticklabels(["Model $A$\nDataset $A$", "Model $B$\nDataset $B$"]) 197 | ax.set_ylabel("Top-5 accuracy") 198 | ax.set_title("CIFAR-100, Split data training") 199 | # ax.legend(loc="upper right", framealpha=0.5) 200 | fig.tight_layout() 201 | 202 | plt.savefig("figs/cifar100_resnet20_split_data_test_acc5.png", dpi=300) 203 | plt.savefig("figs/cifar100_resnet20_split_data_test_acc5.pdf") 204 | -------------------------------------------------------------------------------- /src/cifar100_resnet20_train.py: -------------------------------------------------------------------------------- 1 | """Train a ResNet20 model on a biased subset of CIFAR-100. Then we'll 2 | interpolate between them downstream and hopefully do better. 3 | """ 4 | import argparse 5 | 6 | import augmax 7 | import flax 8 | import jax.nn 9 | import jax.numpy as jnp 10 | import numpy as np 11 | import optax 12 | import tensorflow as tf 13 | import wandb 14 | from flax.training.train_state import TrainState 15 | from jax import jit, random, value_and_grad, vmap 16 | from tqdm import tqdm 17 | 18 | from datasets import load_cifar100, load_cifar100_split 19 | from resnet20 import BLOCKS_PER_GROUP, ResNet 20 | from utils import ec2_get_instance_type, rngmix, timeblock 21 | 22 | # See https://github.com/tensorflow/tensorflow/issues/53831. 23 | 24 | # See https://github.com/google/jax/issues/9454. 25 | tf.config.set_visible_devices([], "GPU") 26 | 27 | NUM_CLASSES = 100 28 | 29 | def make_stuff(model): 30 | train_transform = augmax.Chain( 31 | # augmax does not seem to support random crops with padding. 
See https://github.com/khdlr/augmax/issues/6. 32 | augmax.RandomSizedCrop(32, 32, zoom_range=(0.8, 1.2)), 33 | augmax.HorizontalFlip(), 34 | augmax.Rotate(), 35 | ) 36 | # Applied to all input images, test and train. 37 | normalize_transform = augmax.Chain(augmax.ByteToFloat(), augmax.Normalize()) 38 | 39 | @jit 40 | def batch_eval(params, images_u8, labels): 41 | images_f32 = vmap(normalize_transform)(None, images_u8) 42 | y_onehot = jax.nn.one_hot(labels, NUM_CLASSES) 43 | logits = model.apply({"params": params}, images_f32) 44 | l = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot)) 45 | top1_num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels) 46 | # See https://github.com/google/jax/issues/2079. argpartition is not currently (2022-09-28) supported. 47 | # top5_num_correct = jnp.sum( 48 | # jnp.isin(labels[:, jnp.newaxis], 49 | # jnp.argpartition(logits, -5, axis=-1)[:, -5:])) 50 | top5_num_correct = jnp.sum( 51 | jnp.isin(labels[:, jnp.newaxis], 52 | jnp.argsort(logits, axis=-1)[:, -5:])) 53 | return l, { 54 | "logits": logits, 55 | "top1_num_correct": top1_num_correct, 56 | "top5_num_correct": top5_num_correct 57 | } 58 | 59 | @jit 60 | def step(rng, train_state, images, labels): 61 | images_transformed = vmap(train_transform)(random.split(rng, images.shape[0]), images) 62 | (l, info), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_transformed, 63 | labels) 64 | # logits can be quite heaviweight, so we don't want to pass those along. 65 | return train_state.apply_gradients(grads=g), {**info, "batch_loss": l, "logits": None} 66 | 67 | def dataset_loss_and_accuracies(params, dataset, batch_size: int): 68 | num_examples = dataset["images_u8"].shape[0] 69 | assert num_examples % batch_size == 0 70 | num_batches = num_examples // batch_size 71 | batch_ix = jnp.arange(num_examples).reshape((num_batches, batch_size)) 72 | # Can't use vmap or run in a single batch since that overloads GPU memory. 73 | losses, infos = zip(*[ 74 | batch_eval( 75 | params, 76 | dataset["images_u8"][batch_ix[i, :], :, :, :], 77 | dataset["labels"][batch_ix[i, :]], 78 | ) for i in range(num_batches) 79 | ]) 80 | return ( 81 | jnp.sum(batch_size * jnp.array(losses)) / num_examples, 82 | sum(x["top1_num_correct"] for x in infos) / num_examples, 83 | sum(x["top5_num_correct"] for x in infos) / num_examples, 84 | ) 85 | 86 | def dataset_logits(params, dataset, batch_size: int): 87 | num_examples = dataset["images_u8"].shape[0] 88 | assert num_examples % batch_size == 0 89 | num_batches = num_examples // batch_size 90 | batch_ix = jnp.arange(num_examples).reshape((num_batches, batch_size)) 91 | # Can't use vmap or run in a single batch since that overloads GPU memory. 92 | _, infos = zip(*[ 93 | batch_eval( 94 | params, 95 | dataset["images_u8"][batch_ix[i, :], :, :, :], 96 | dataset["labels"][batch_ix[i, :]], 97 | ) for i in range(num_batches) 98 | ]) 99 | return jnp.concatenate([x["logits"] for x in infos]) 100 | 101 | return { 102 | "train_transform": train_transform, 103 | "normalize_transform": normalize_transform, 104 | "batch_eval": batch_eval, 105 | "step": step, 106 | "dataset_loss_and_accuracies": dataset_loss_and_accuracies, 107 | "dataset_logits": dataset_logits, 108 | } 109 | 110 | def init_train_state(rng, model, learning_rate, num_epochs, batch_size, num_train_examples, 111 | weight_decay: float): 112 | # See https://github.com/kuangliu/pytorch-cifar. 
113 | warmup_epochs = 5 114 | steps_per_epoch = num_train_examples // batch_size 115 | lr_schedule = optax.warmup_cosine_decay_schedule( 116 | init_value=1e-6, 117 | peak_value=learning_rate, 118 | warmup_steps=warmup_epochs * steps_per_epoch, 119 | # Confusingly, `decay_steps` is actually the total number of steps, 120 | # including the warmup. 121 | decay_steps=num_epochs * steps_per_epoch, 122 | ) 123 | tx = optax.chain(optax.add_decayed_weights(weight_decay), optax.sgd(lr_schedule, momentum=0.9)) 124 | # tx = optax.adamw(learning_rate=lr_schedule, weight_decay=5e-4) 125 | vars = model.init(rng, jnp.zeros((1, 32, 32, 3))) 126 | return TrainState.create(apply_fn=model.apply, params=vars["params"], tx=tx) 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--test", action="store_true", help="Run in smoke-test mode") 131 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 132 | parser.add_argument("--data-split", choices=["split1", "split2", "both"], required=True) 133 | parser.add_argument("--width-multiplier", type=int, default=1) 134 | parser.add_argument("--weight-decay", type=float, default=1e-4) 135 | args = parser.parse_args() 136 | 137 | with wandb.init( 138 | project="git-re-basin", 139 | entity="skainswo", 140 | tags=["cifar10", "resnet", "training"], 141 | mode="disabled" if args.test else "online", 142 | job_type="train", 143 | ) as wandb_run: 144 | artifact = wandb.Artifact("cifar100-resnet-weights", type="model-weights") 145 | 146 | config = wandb.config 147 | config.ec2_instance_type = ec2_get_instance_type() 148 | config.test = args.test 149 | config.seed = args.seed 150 | config.data_split = args.data_split 151 | config.learning_rate = 0.1 152 | config.num_epochs = 10 if args.test else 250 153 | config.batch_size = 100 154 | config.width_multiplier = args.width_multiplier 155 | config.weight_decay = args.weight_decay 156 | 157 | rng = random.PRNGKey(config.seed) 158 | 159 | model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"], 160 | num_classes=NUM_CLASSES, 161 | width_multiplier=config.width_multiplier) 162 | 163 | with timeblock("load datasets"): 164 | if config.data_split == "both": 165 | train_ds, test_ds = load_cifar100() 166 | else: 167 | split1, split2, test_ds = load_cifar100_split() 168 | train_ds = split1 if config.data_split == "split1" else split2 169 | 170 | print("train_ds labels hash", hash(np.array(train_ds["labels"]).tobytes())) 171 | print("test_ds labels hash", hash(np.array(test_ds["labels"]).tobytes())) 172 | 173 | num_train_examples = train_ds["images_u8"].shape[0] 174 | num_test_examples = test_ds["images_u8"].shape[0] 175 | assert num_train_examples % config.batch_size == 0 176 | print("num_train_examples", num_train_examples) 177 | print("num_test_examples", num_test_examples) 178 | 179 | stuff = make_stuff(model) 180 | train_state = init_train_state(rngmix(rng, "init"), 181 | model=model, 182 | learning_rate=config.learning_rate, 183 | num_epochs=config.num_epochs, 184 | batch_size=config.batch_size, 185 | num_train_examples=train_ds["images_u8"].shape[0], 186 | weight_decay=config.weight_decay) 187 | 188 | for epoch in tqdm(range(config.num_epochs)): 189 | infos = [] 190 | with timeblock(f"Epoch"): 191 | batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape( 192 | (-1, config.batch_size)) 193 | batch_rngs = random.split(rngmix(rng, f"batch_rngs-{epoch}"), batch_ix.shape[0]) 194 | for i in range(batch_ix.shape[0]): 195 | p = 
batch_ix[i, :] 196 | images_u8 = train_ds["images_u8"][p, :, :, :] 197 | labels = train_ds["labels"][p] 198 | train_state, info = stuff["step"](batch_rngs[i], train_state, images_u8, labels) 199 | infos.append(info) 200 | 201 | train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples 202 | train_acc1 = sum(x["top1_num_correct"] for x in infos) / num_train_examples 203 | train_acc5 = sum(x["top5_num_correct"] for x in infos) / num_train_examples 204 | 205 | # Evaluate test loss/accuracy 206 | with timeblock("Test set eval"): 207 | test_loss, test_acc1, test_acc5 = stuff["dataset_loss_and_accuracies"](train_state.params, 208 | test_ds, 1000) 209 | 210 | # See https://github.com/wandb/client/issues/3690. 211 | wandb_run.log({ 212 | "epoch": epoch, 213 | "train_loss": train_loss, 214 | "test_loss": test_loss, 215 | "train_acc1": train_acc1, 216 | "test_acc1": test_acc1, 217 | "train_acc5": train_acc5, 218 | "test_acc5": test_acc5, 219 | }) 220 | 221 | # No point saving the model at all if we're running in test mode. 222 | if (not config.test) and (epoch % 10 == 0 or epoch == config.num_epochs - 1): 223 | with timeblock("model serialization"): 224 | # See https://github.com/wandb/client/issues/3823 225 | filename = f"/tmp/checkpoint{epoch}" 226 | with open(filename, mode="wb") as f: 227 | f.write(flax.serialization.to_bytes(train_state.params)) 228 | artifact.add_file(filename) 229 | 230 | # This will be a no-op when config.test is enabled anyhow, since wandb will 231 | # be initialized with mode="disabled". 232 | wandb_run.log_artifact(artifact) 233 | -------------------------------------------------------------------------------- /src/cifar10_mlp_barrier_vs_epoch_matching.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import jax.numpy as jnp 5 | import wandb 6 | from flax.serialization import from_bytes 7 | from jax import random 8 | from tqdm import tqdm 9 | 10 | from cifar10_mlp_train import MLPModel, make_stuff 11 | from datasets import load_cifar10 12 | from utils import flatten_params, lerp, unflatten_params 13 | from weight_matching import (apply_permutation, mlp_permutation_spec, weight_matching) 14 | 15 | with wandb.init( 16 | project="git-re-basin", 17 | entity="skainswo", 18 | tags=["cifar10", "mlp", "weight-matching", "barrier-vs-epoch"], 19 | job_type="analysis", 20 | ) as wandb_run: 21 | # api = wandb.Api() 22 | # seed0_run = api.run("skainswo/git-re-basin/1b1gztfx") 23 | # seed1_run = api.run("skainswo/git-re-basin/1hrmw7wr") 24 | 25 | config = wandb.config 26 | config.total_epochs = 100 27 | config.seed = 123 28 | 29 | model = MLPModel() 30 | stuff = make_stuff(model) 31 | 32 | def load_model(filepath): 33 | with open(filepath, "rb") as fh: 34 | return from_bytes( 35 | model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"], fh.read()) 36 | 37 | seed0_artifact = Path(wandb_run.use_artifact("cifar10-mlp-weights:v13").download()) 38 | seed1_artifact = Path(wandb_run.use_artifact("cifar10-mlp-weights:v14").download()) 39 | 40 | permutation_spec = mlp_permutation_spec(3) 41 | 42 | def match_one_epoch(epoch: int): 43 | model_a = load_model(seed0_artifact / f"checkpoint{epoch}") 44 | model_b = load_model(seed1_artifact / f"checkpoint{epoch}") 45 | return weight_matching( 46 | random.PRNGKey(config.seed), 47 | permutation_spec, 48 | flatten_params(model_a), 49 | flatten_params(model_b), 50 | ) 51 | 52 | permutation_vs_epoch = [match_one_epoch(i) for 
i in tqdm(range(config.total_epochs))] 53 | 54 | artifact = wandb.Artifact("cifar10_permutation_vs_epoch", 55 | type="permutation_vs_epoch", 56 | metadata={ 57 | "dataset": "cifar10", 58 | "model": "mlp", 59 | "analysis": "weight-matching" 60 | }) 61 | with artifact.new_file("permutation_vs_epoch.pkl", mode="wb") as f: 62 | pickle.dump(permutation_vs_epoch, f) 63 | wandb_run.log_artifact(artifact) 64 | 65 | # Eval 66 | train_ds, test_ds = load_cifar10() 67 | 68 | def eval_one(epoch, permutation): 69 | model_a = load_model(seed0_artifact / f"checkpoint{epoch}") 70 | model_b = load_model(seed1_artifact / f"checkpoint{epoch}") 71 | 72 | lambdas = jnp.linspace(0, 1, num=25) 73 | 74 | model_b_perm = unflatten_params( 75 | apply_permutation(permutation_spec, permutation, flatten_params(model_b))) 76 | 77 | train_loss_interp = [] 78 | test_loss_interp = [] 79 | train_acc_interp = [] 80 | test_acc_interp = [] 81 | for lam in lambdas: 82 | clever_p = lerp(lam, model_a, model_b_perm) 83 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](clever_p, train_ds, 10_000) 84 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](clever_p, test_ds, 10_000) 85 | train_loss_interp.append(train_loss) 86 | test_loss_interp.append(test_loss) 87 | train_acc_interp.append(train_acc) 88 | test_acc_interp.append(test_acc) 89 | 90 | return { 91 | "train_loss_interp": train_loss_interp, 92 | "test_loss_interp": test_loss_interp, 93 | "train_acc_interp": train_acc_interp, 94 | "test_acc_interp": test_acc_interp, 95 | } 96 | 97 | interp_eval_vs_epoch = [eval_one(i, p) for i, p in tqdm(enumerate(permutation_vs_epoch))] 98 | 99 | artifact = wandb.Artifact("cifar10_permutation_eval_vs_epoch", 100 | type="permutation_eval_vs_epoch", 101 | metadata={ 102 | "dataset": "cifar10", 103 | "model": "mlp", 104 | "analysis": "weight-matching", 105 | "interpolation": "lerp" 106 | }) 107 | with artifact.new_file("permutation_eval_vs_epoch.pkl", mode="wb") as f: 108 | pickle.dump(interp_eval_vs_epoch, f) 109 | wandb_run.log_artifact(artifact) 110 | -------------------------------------------------------------------------------- /src/cifar10_mlp_barrier_vs_epoch_plot.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import wandb 7 | 8 | import matplotlib_style as _ 9 | from plot_utils import loss_barrier_is_nonnegative 10 | 11 | max_epoch = 25 12 | 13 | api = wandb.Api() 14 | # https://wandb.ai/skainswo/git-re-basin/runs/1t9yk4tm 15 | # run = api.run("skainswo/git-re-basin/1t9yk4tm") 16 | artifact = Path( 17 | api.artifact("skainswo/git-re-basin/cifar10_permutation_eval_vs_epoch:v0").download()) 18 | 19 | with open(artifact / "permutation_eval_vs_epoch.pkl", "rb") as f: 20 | interp_eval_vs_epoch = pickle.load(f) 21 | 22 | train_loss_interp = np.array([x["train_loss_interp"] for x in interp_eval_vs_epoch]) 23 | train_barrier_vs_epoch = np.max(train_loss_interp, 24 | axis=1) - 0.5 * (train_loss_interp[:, 0] + train_loss_interp[:, -1]) 25 | 26 | test_loss_interp = np.array([x["test_loss_interp"] for x in interp_eval_vs_epoch]) 27 | test_barrier_vs_epoch = np.max(test_loss_interp, 28 | axis=1) - 0.5 * (test_loss_interp[:, 0] + test_loss_interp[:, -1]) 29 | 30 | fig = plt.figure() 31 | # fig = plt.figure(figsize=(8, 4)) 32 | ax = fig.add_subplot(111) 33 | 34 | ax.arrow(5, 0.27, -4, 0.04, alpha=0.25) 35 | ins1 = ax.inset_axes((0.2, 0.7, 0.25, 0.25)) 36 | 
ins1.plot(train_loss_interp[0, :]) 37 | ins1.plot(test_loss_interp[0, :], linestyle="dashed") 38 | ins1.set_xticks([]) 39 | ins1.set_yticks([]) 40 | 41 | ax.arrow(22, 0.15, 0, -0.135, alpha=0.25) 42 | ins2 = ax.inset_axes((0.72, 0.35, 0.25, 0.25)) 43 | ins2.plot(train_loss_interp[21, :]) 44 | ins2.plot(test_loss_interp[21, :], linestyle="dashed") 45 | ins2.set_xticks([]) 46 | ins2.set_yticks([]) 47 | 48 | ax.plot( 49 | 1 + np.arange(max_epoch), 50 | train_barrier_vs_epoch[:max_epoch], 51 | marker="o", 52 | linewidth=2, 53 | label="Train", 54 | ) 55 | ax.plot( 56 | 1 + np.arange(max_epoch), 57 | test_barrier_vs_epoch[:max_epoch], 58 | marker="^", 59 | linestyle="dashed", 60 | linewidth=2, 61 | label="Test", 62 | ) 63 | 64 | loss_barrier_is_nonnegative(ax) 65 | 66 | ax.set_xlabel("Epoch") 67 | ax.set_ylabel("Loss barrier") 68 | ax.set_title(f"CIFAR-10") 69 | ax.legend(loc="upper right", framealpha=0.5) 70 | fig.tight_layout() 71 | 72 | plt.savefig("figs/cifar10_mlp_barrier_vs_epoch.png", dpi=300) 73 | plt.savefig("figs/cifar10_mlp_barrier_vs_epoch.eps") 74 | plt.savefig("figs/cifar10_mlp_barrier_vs_epoch.pdf") 75 | -------------------------------------------------------------------------------- /src/cifar10_mlp_interp_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import wandb 4 | 5 | import matplotlib_style as _ 6 | 7 | if __name__ == "__main__": 8 | api = wandb.Api() 9 | activation_matching_run = api.run("skainswo/git-re-basin/3kcqs9ns") 10 | weight_matching_run = api.run("skainswo/git-re-basin/2al62vvv") 11 | ste_matching_run = api.run("skainswo/git-re-basin/371j84rs") 12 | 13 | ### Loss plot 14 | fig = plt.figure() 15 | ax = fig.add_subplot(111) 16 | lambdas = np.linspace(0, 1, 25) 17 | 18 | # Naive 19 | ax.plot(lambdas, 20 | np.array(activation_matching_run.summary["train_loss_interp_naive"]), 21 | color="grey", 22 | linewidth=2, 23 | label=f"Naïve") 24 | ax.plot(lambdas, 25 | np.array(activation_matching_run.summary["test_loss_interp_naive"]), 26 | color="grey", 27 | linewidth=2, 28 | linestyle="dashed") 29 | 30 | # Activation matching 31 | ax.plot(lambdas, 32 | np.array(activation_matching_run.summary["train_loss_interp_clever"]), 33 | color="tab:blue", 34 | marker="*", 35 | linewidth=2, 36 | label=f"Activation matching") 37 | ax.plot(lambdas, 38 | np.array(activation_matching_run.summary["test_loss_interp_clever"]), 39 | color="tab:blue", 40 | marker="*", 41 | linewidth=2, 42 | linestyle="dashed") 43 | 44 | # Weight matching 45 | ax.plot(lambdas, 46 | np.array(weight_matching_run.summary["train_loss_interp_clever"]), 47 | color="tab:green", 48 | marker="^", 49 | linewidth=2, 50 | label=f"Weight matching") 51 | ax.plot(lambdas, 52 | np.array(weight_matching_run.summary["test_loss_interp_clever"]), 53 | color="tab:green", 54 | marker="^", 55 | linestyle="dashed", 56 | linewidth=2) 57 | 58 | # STE matching 59 | ax.plot(lambdas, 60 | np.array(ste_matching_run.summary["train_loss_interp_clever"]), 61 | color="tab:red", 62 | marker="p", 63 | linewidth=2, 64 | label=f"STE matching") 65 | ax.plot(lambdas, 66 | np.array(ste_matching_run.summary["test_loss_interp_clever"]), 67 | color="tab:red", 68 | marker="p", 69 | linestyle="dashed", 70 | linewidth=2) 71 | 72 | ax.set_xlabel("$\lambda$") 73 | ax.set_xticks([0, 1]) 74 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 75 | ax.set_ylabel("Loss") 76 | ax.set_title(f"CIFAR-10, MLP") 77 | # ax.legend(loc="upper right", framealpha=0.5) 78 | 
fig.tight_layout() 79 | 80 | plt.savefig("figs/cifar10_mlp_loss_interp.png", dpi=300) 81 | plt.savefig("figs/cifar10_mlp_loss_interp.pdf") 82 | 83 | ### Accuracy plot 84 | fig = plt.figure() 85 | ax = fig.add_subplot(111) 86 | lambdas = np.linspace(0, 1, 25) 87 | 88 | # Naive 89 | ax.plot(lambdas, 90 | 100 * np.array(activation_matching_run.summary["train_acc_interp_naive"]), 91 | color="grey", 92 | linewidth=2, 93 | label=f"Naïve") 94 | ax.plot(lambdas, 95 | 100 * np.array(activation_matching_run.summary["test_acc_interp_naive"]), 96 | color="grey", 97 | linewidth=2, 98 | linestyle="dashed") 99 | 100 | # Activation matching 101 | ax.plot(lambdas, 102 | 100 * np.array(activation_matching_run.summary["train_acc_interp_clever"]), 103 | color="tab:blue", 104 | marker="*", 105 | linewidth=2, 106 | label=f"Activation matching") 107 | ax.plot(lambdas, 108 | 100 * np.array(activation_matching_run.summary["test_acc_interp_clever"]), 109 | color="tab:blue", 110 | marker="*", 111 | linewidth=2, 112 | linestyle="dashed") 113 | 114 | # Weight matching 115 | ax.plot(lambdas, 116 | 100 * np.array(weight_matching_run.summary["train_acc_interp_clever"]), 117 | color="tab:green", 118 | marker="^", 119 | linewidth=2, 120 | label=f"Weight matching") 121 | ax.plot(lambdas, 122 | 100 * np.array(weight_matching_run.summary["test_acc_interp_clever"]), 123 | color="tab:green", 124 | marker="^", 125 | linestyle="dashed", 126 | linewidth=2) 127 | 128 | # STE matching 129 | ax.plot(lambdas, 130 | 100 * np.array(ste_matching_run.summary["train_acc_interp_clever"]), 131 | color="tab:red", 132 | marker="p", 133 | linewidth=2, 134 | label=f"STE matching") 135 | ax.plot(lambdas, 136 | 100 * np.array(ste_matching_run.summary["test_acc_interp_clever"]), 137 | color="tab:red", 138 | marker="p", 139 | linestyle="dashed", 140 | linewidth=2) 141 | 142 | ax.set_xlabel("$\lambda$") 143 | ax.set_xticks([0, 1]) 144 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 145 | ax.set_ylabel("Accuracy") 146 | ax.set_title("CIFAR-10, MLP") 147 | # ax.legend(loc="lower right", framealpha=0.5) 148 | fig.tight_layout() 149 | 150 | plt.savefig("figs/cifar10_mlp_accuracy_interp.png", dpi=300) 151 | plt.savefig("figs/cifar10_mlp_accuracy_interp.pdf") 152 | -------------------------------------------------------------------------------- /src/cifar10_mlp_train.py: -------------------------------------------------------------------------------- 1 | """Train a MLP on CIFAR-10 on one random seed.""" 2 | import argparse 3 | 4 | import flax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import optax 8 | import tensorflow as tf 9 | import wandb 10 | from flax import linen as nn 11 | from flax.training.train_state import TrainState 12 | from jax import random, tree_map 13 | from tqdm import tqdm 14 | 15 | from cifar10_vgg_run import make_stuff 16 | from datasets import load_cifar10 17 | from utils import ec2_get_instance_type, flatten_params, rngmix, timeblock 18 | 19 | # See https://github.com/tensorflow/tensorflow/issues/53831. 20 | 21 | # See https://github.com/google/jax/issues/9454. 
22 | tf.config.set_visible_devices([], "GPU") 23 | 24 | activation = nn.relu 25 | 26 | class MLPModel(nn.Module): 27 | 28 | @nn.compact 29 | def __call__(self, x): 30 | x = jnp.reshape(x, (-1, 32 * 32 * 3)) 31 | x = nn.Dense(512)(x) 32 | x = activation(x) 33 | x = nn.Dense(512)(x) 34 | x = activation(x) 35 | x = nn.Dense(512)(x) 36 | x = activation(x) 37 | x = nn.Dense(10)(x) 38 | x = nn.log_softmax(x) 39 | return x 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--test", action="store_true", help="Run in smoke-test mode") 44 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 45 | parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True) 46 | parser.add_argument("--learning-rate", type=float, required=True) 47 | args = parser.parse_args() 48 | 49 | with wandb.init( 50 | project="git-re-basin", 51 | entity="skainswo", 52 | tags=["cifar10", "mlp", "training"], 53 | mode="disabled" if args.test else "online", 54 | job_type="train", 55 | ) as wandb_run: 56 | artifact = wandb.Artifact("cifar10-mlp-weights", type="model-weights") 57 | 58 | config = wandb.config 59 | config.ec2_instance_type = ec2_get_instance_type() 60 | config.test = args.test 61 | config.seed = args.seed 62 | config.optimizer = args.optimizer 63 | config.learning_rate = args.learning_rate 64 | config.num_epochs = 100 65 | config.batch_size = 100 66 | 67 | rng = random.PRNGKey(config.seed) 68 | 69 | model = MLPModel() 70 | stuff = make_stuff(model) 71 | 72 | with timeblock("load datasets"): 73 | train_ds, test_ds = load_cifar10() 74 | print("train_ds labels hash", hash(np.array(train_ds["labels"]).tobytes())) 75 | print("test_ds labels hash", hash(np.array(test_ds["labels"]).tobytes())) 76 | 77 | num_train_examples = train_ds["images_u8"].shape[0] 78 | num_test_examples = test_ds["images_u8"].shape[0] 79 | assert num_train_examples % config.batch_size == 0 80 | print("num_train_examples", num_train_examples) 81 | print("num_test_examples", num_test_examples) 82 | 83 | if config.optimizer == "sgd": 84 | lr_schedule = optax.warmup_cosine_decay_schedule( 85 | init_value=1e-6, 86 | peak_value=config.learning_rate, 87 | warmup_steps=num_train_examples // config.batch_size, 88 | # Confusingly, `decay_steps` is actually the total number of steps, 89 | # including the warmup. 
90 | decay_steps=config.num_epochs * (num_train_examples // config.batch_size), 91 | ) 92 | # tx = optax.sgd(lr_schedule, momentum=0.9) 93 | tx = optax.chain(optax.add_decayed_weights(5e-4), optax.sgd(lr_schedule, momentum=0.9)) 94 | elif config.optimizer == "adam": 95 | tx = optax.adam(config.learning_rate) 96 | else: 97 | tx = optax.adamw(config.learning_rate, weight_decay=5e-4) 98 | 99 | train_state = TrainState.create( 100 | apply_fn=model.apply, 101 | params=model.init(rngmix(rng, "init"), jnp.zeros((1, 32, 32, 3)))["params"], 102 | tx=tx, 103 | ) 104 | 105 | for epoch in tqdm(range(config.num_epochs)): 106 | infos = [] 107 | with timeblock(f"Epoch"): 108 | batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape( 109 | (-1, config.batch_size)) 110 | batch_rngs = random.split(rngmix(rng, f"batch_rngs-{epoch}"), batch_ix.shape[0]) 111 | for i in range(batch_ix.shape[0]): 112 | p = batch_ix[i, :] 113 | images_u8 = train_ds["images_u8"][p, :, :, :] 114 | labels = train_ds["labels"][p] 115 | train_state, info = stuff["step"](batch_rngs[i], train_state, images_u8, labels) 116 | infos.append(info) 117 | 118 | train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples 119 | train_accuracy = sum(x["num_correct"] for x in infos) / num_train_examples 120 | 121 | # Evaluate test loss/accuracy 122 | with timeblock("Test set eval"): 123 | test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, test_ds, 124 | 1000) 125 | 126 | params_l2 = tree_map(lambda x: jnp.sqrt(jnp.sum(x**2)), 127 | flatten_params({"params_l2": train_state.params})) 128 | 129 | # See https://github.com/wandb/client/issues/3690. 130 | wandb_run.log({ 131 | "epoch": epoch, 132 | "train_loss": train_loss, 133 | "test_loss": test_loss, 134 | "train_accuracy": train_accuracy, 135 | "test_accuracy": test_accuracy, 136 | **params_l2 137 | }) 138 | 139 | # No point saving the model at all if we're running in test mode. 140 | # With layer width 512, the MLP is 3.7MB per checkpoint. 141 | if not config.test: 142 | with timeblock("model serialization"): 143 | with artifact.new_file(f"checkpoint{epoch}", mode="wb") as f: 144 | f.write(flax.serialization.to_bytes(train_state.params)) 145 | 146 | # This will be a no-op when config.test is enabled anyhow, since wandb will 147 | # be initialized with mode="disabled". 
148 | wandb_run.log_artifact(artifact) 149 | -------------------------------------------------------------------------------- /src/cifar10_mlp_weight_matching.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from pathlib import Path 4 | 5 | import jax.numpy as jnp 6 | import matplotlib.pyplot as plt 7 | import wandb 8 | from flax.serialization import from_bytes 9 | from jax import random 10 | from tqdm import tqdm 11 | 12 | from cifar10_mlp_train import MLPModel, make_stuff 13 | from datasets import load_cifar10 14 | from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params 15 | from weight_matching import (apply_permutation, mlp_permutation_spec, weight_matching) 16 | 17 | def plot_interp_loss(epoch, lambdas, train_loss_interp_naive, test_loss_interp_naive, 18 | train_loss_interp_clever, test_loss_interp_clever): 19 | fig = plt.figure() 20 | ax = fig.add_subplot(111) 21 | ax.plot(lambdas, 22 | train_loss_interp_naive, 23 | linestyle="dashed", 24 | color="tab:blue", 25 | alpha=0.5, 26 | linewidth=2, 27 | label="Train, naïve interp.") 28 | ax.plot(lambdas, 29 | test_loss_interp_naive, 30 | linestyle="dashed", 31 | color="tab:orange", 32 | alpha=0.5, 33 | linewidth=2, 34 | label="Test, naïve interp.") 35 | ax.plot(lambdas, 36 | train_loss_interp_clever, 37 | linestyle="solid", 38 | color="tab:blue", 39 | linewidth=2, 40 | label="Train, permuted interp.") 41 | ax.plot(lambdas, 42 | test_loss_interp_clever, 43 | linestyle="solid", 44 | color="tab:orange", 45 | linewidth=2, 46 | label="Test, permuted interp.") 47 | ax.set_xlabel("$\lambda$") 48 | ax.set_xticks([0, 1]) 49 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 50 | ax.set_ylabel("Loss") 51 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 52 | ax.set_title(f"Loss landscape between the two models (epoch {epoch})") 53 | ax.legend(loc="upper right", framealpha=0.5) 54 | fig.tight_layout() 55 | return fig 56 | 57 | def plot_interp_acc(epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 58 | train_acc_interp_clever, test_acc_interp_clever): 59 | fig = plt.figure() 60 | ax = fig.add_subplot(111) 61 | ax.plot(lambdas, 62 | train_acc_interp_naive, 63 | linestyle="dashed", 64 | color="tab:blue", 65 | alpha=0.5, 66 | linewidth=2, 67 | label="Train, naïve interp.") 68 | ax.plot(lambdas, 69 | test_acc_interp_naive, 70 | linestyle="dashed", 71 | color="tab:orange", 72 | alpha=0.5, 73 | linewidth=2, 74 | label="Test, naïve interp.") 75 | ax.plot(lambdas, 76 | train_acc_interp_clever, 77 | linestyle="solid", 78 | color="tab:blue", 79 | linewidth=2, 80 | label="Train, permuted interp.") 81 | ax.plot(lambdas, 82 | test_acc_interp_clever, 83 | linestyle="solid", 84 | color="tab:orange", 85 | linewidth=2, 86 | label="Test, permuted interp.") 87 | ax.set_xlabel("$\lambda$") 88 | ax.set_xticks([0, 1]) 89 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 90 | ax.set_ylabel("Accuracy") 91 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 92 | ax.set_title(f"Accuracy between the two models (epoch {epoch})") 93 | ax.legend(loc="lower right", framealpha=0.5) 94 | fig.tight_layout() 95 | return fig 96 | 97 | def main(): 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 100 | parser.add_argument("--model-a", type=str, required=True) 101 | parser.add_argument("--model-b", type=str, required=True) 102 | parser.add_argument("--load-epoch", type=int, required=True) 103 | 
args = parser.parse_args() 104 | 105 | with wandb.init( 106 | project="git-re-basin", 107 | entity="skainswo", 108 | tags=["cifar10", "mlp", "weight-matching"], 109 | job_type="analysis", 110 | ) as wandb_run: 111 | config = wandb.config 112 | config.ec2_instance_type = ec2_get_instance_type() 113 | config.model_a = args.model_a 114 | config.model_b = args.model_b 115 | config.seed = args.seed 116 | config.load_epoch = args.load_epoch 117 | 118 | model = MLPModel() 119 | stuff = make_stuff(model) 120 | 121 | def load_model(filepath): 122 | with open(filepath, "rb") as fh: 123 | return from_bytes( 124 | model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"], fh.read()) 125 | 126 | filename = f"checkpoint{config.load_epoch}" 127 | model_a = load_model( 128 | Path( 129 | wandb_run.use_artifact(f"cifar10-mlp-weights:{config.model_a}").get_path( 130 | filename).download())) 131 | model_b = load_model( 132 | Path( 133 | wandb_run.use_artifact(f"cifar10-mlp-weights:{config.model_b}").get_path( 134 | filename).download())) 135 | 136 | train_ds, test_ds = load_cifar10() 137 | 138 | permutation_spec = mlp_permutation_spec(3) 139 | final_permutation = weight_matching(random.PRNGKey(config.seed), permutation_spec, 140 | flatten_params(model_a), flatten_params(model_b)) 141 | 142 | # Save final_permutation as an Artifact 143 | artifact = wandb.Artifact("model_b_permutation", 144 | type="permutation", 145 | metadata={ 146 | "dataset": "cifar10", 147 | "model": "mlp", 148 | "analysis": "weight-matching" 149 | }) 150 | with artifact.new_file("permutation.pkl", mode="wb") as f: 151 | pickle.dump(final_permutation, f) 152 | wandb_run.log_artifact(artifact) 153 | 154 | lambdas = jnp.linspace(0, 1, num=25) 155 | train_loss_interp_naive = [] 156 | test_loss_interp_naive = [] 157 | train_acc_interp_naive = [] 158 | test_acc_interp_naive = [] 159 | for lam in tqdm(lambdas): 160 | naive_p = lerp(lam, model_a, model_b) 161 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](naive_p, train_ds, 10_000) 162 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](naive_p, test_ds, 10_000) 163 | train_loss_interp_naive.append(train_loss) 164 | test_loss_interp_naive.append(test_loss) 165 | train_acc_interp_naive.append(train_acc) 166 | test_acc_interp_naive.append(test_acc) 167 | 168 | model_b_clever = unflatten_params( 169 | apply_permutation(permutation_spec, final_permutation, flatten_params(model_b))) 170 | 171 | train_loss_interp_clever = [] 172 | test_loss_interp_clever = [] 173 | train_acc_interp_clever = [] 174 | test_acc_interp_clever = [] 175 | for lam in tqdm(lambdas): 176 | clever_p = lerp(lam, model_a, model_b_clever) 177 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](clever_p, train_ds, 10_000) 178 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](clever_p, test_ds, 10_000) 179 | train_loss_interp_clever.append(train_loss) 180 | test_loss_interp_clever.append(test_loss) 181 | train_acc_interp_clever.append(train_acc) 182 | test_acc_interp_clever.append(test_acc) 183 | 184 | assert len(lambdas) == len(train_loss_interp_naive) 185 | assert len(lambdas) == len(test_loss_interp_naive) 186 | assert len(lambdas) == len(train_acc_interp_naive) 187 | assert len(lambdas) == len(test_acc_interp_naive) 188 | assert len(lambdas) == len(train_loss_interp_clever) 189 | assert len(lambdas) == len(test_loss_interp_clever) 190 | assert len(lambdas) == len(train_acc_interp_clever) 191 | assert len(lambdas) == len(test_acc_interp_clever) 192 | 193 | print("Plotting...") 194 | 
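    # The figures below trace loss/accuracy along the straight line
    # (1 - lambda) * theta_A + lambda * theta_B: the "naive" curves interpolate
    # the raw checkpoints directly, while the "clever" curves first permute
    # model B's hidden units (via weight_matching above) to align them with
    # model A.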
fig = plot_interp_loss(config.load_epoch, lambdas, train_loss_interp_naive, 195 | test_loss_interp_naive, train_loss_interp_clever, 196 | test_loss_interp_clever) 197 | plt.savefig(f"cifar10_mlp_weight_matching_interp_loss_epoch{config.load_epoch}.png", dpi=300) 198 | wandb.log({"interp_loss_fig": wandb.Image(fig)}, commit=False) 199 | plt.close(fig) 200 | 201 | fig = plot_interp_acc(config.load_epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 202 | train_acc_interp_clever, test_acc_interp_clever) 203 | plt.savefig(f"cifar10_mlp_weight_matching_interp_accuracy_epoch{config.load_epoch}.png", 204 | dpi=300) 205 | wandb.log({"interp_acc_fig": wandb.Image(fig)}, commit=False) 206 | plt.close(fig) 207 | 208 | wandb.log({ 209 | "train_loss_interp_naive": train_loss_interp_naive, 210 | "test_loss_interp_naive": test_loss_interp_naive, 211 | "train_acc_interp_naive": train_acc_interp_naive, 212 | "test_acc_interp_naive": test_acc_interp_naive, 213 | "train_loss_interp_clever": train_loss_interp_clever, 214 | "test_loss_interp_clever": test_loss_interp_clever, 215 | "train_acc_interp_clever": train_acc_interp_clever, 216 | "test_acc_interp_clever": test_acc_interp_clever, 217 | }) 218 | 219 | print({ 220 | "train_loss_interp_naive": train_loss_interp_naive, 221 | "test_loss_interp_naive": test_loss_interp_naive, 222 | "train_acc_interp_naive": train_acc_interp_naive, 223 | "test_acc_interp_naive": test_acc_interp_naive, 224 | "train_loss_interp_clever": train_loss_interp_clever, 225 | "test_loss_interp_clever": test_loss_interp_clever, 226 | "train_acc_interp_clever": train_acc_interp_clever, 227 | "test_acc_interp_clever": test_acc_interp_clever, 228 | }) 229 | 230 | if __name__ == "__main__": 231 | main() 232 | -------------------------------------------------------------------------------- /src/cifar10_resnet20_interp_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import wandb 4 | 5 | import matplotlib_style as _ 6 | 7 | api = wandb.Api() 8 | wm32_run = api.run("skainswo/git-re-basin/223t7txl") 9 | 10 | # ins2.plot(all_runs[-1].summary["train_loss_interp_clever"]) 11 | # ins2.plot(all_runs[-1].summary["test_loss_interp_clever"], linestyle="dashed") 12 | # ins2.set_xticks([]) 13 | # ins2.set_yticks([]) 14 | # ymin, ymax = ins2.get_ylim() 15 | # ins2.set_ylim((ymin - 0.2 * (ymax - ymin), ymax + 0.2 * (ymax - ymin))) 16 | 17 | ### Loss plot 18 | fig = plt.figure() 19 | ax = fig.add_subplot(111) 20 | lambdas = np.linspace(0, 1, 25) 21 | 22 | # Naive 23 | ax.plot( 24 | lambdas, 25 | wm32_run.summary["train_loss_interp_naive"], 26 | color="grey", 27 | linewidth=2, 28 | # label="Naïve", 29 | ) 30 | ax.plot( 31 | lambdas, 32 | wm32_run.summary["test_loss_interp_naive"], 33 | color="grey", 34 | linewidth=2, 35 | linestyle="dashed", 36 | ) 37 | 38 | # Activation matching 39 | # ax.plot(lambdas, 40 | # np.array(activation_matching_run.summary["train_loss_interp_clever"]), 41 | # color="tab:blue", 42 | # marker="*", 43 | # linewidth=2, 44 | # label=f"Activation matching") 45 | # ax.plot(lambdas, 46 | # np.array(activation_matching_run.summary["test_loss_interp_clever"]), 47 | # color="tab:blue", 48 | # marker="*", 49 | # linewidth=2, 50 | # linestyle="dashed") 51 | 52 | # Weight matching 53 | ax.plot(lambdas, 54 | wm32_run.summary["train_loss_interp_clever"], 55 | color="tab:green", 56 | marker="^", 57 | linewidth=2, 58 | label="Weight matching") 59 | ax.plot( 60 | lambdas, 61 | 
wm32_run.summary["test_loss_interp_clever"], 62 | color="tab:green", 63 | marker="^", 64 | linestyle="dashed", 65 | linewidth=2, 66 | ) 67 | 68 | # STE matching 69 | # ax.plot(lambdas, 70 | # np.array(ste_matching_run.summary["train_loss_interp_clever"]), 71 | # color="tab:red", 72 | # marker="p", 73 | # linewidth=2, 74 | # label=f"STE matching") 75 | # ax.plot(lambdas, 76 | # np.array(ste_matching_run.summary["test_loss_interp_clever"]), 77 | # color="tab:red", 78 | # marker="p", 79 | # linestyle="dashed", 80 | # linewidth=2) 81 | 82 | # ax.plot([], [], color="grey", linewidth=2, label="Train") 83 | # ax.plot([], [], color="grey", linewidth=2, linestyle="dashed", label="Test") 84 | 85 | ax.set_xlabel("$\lambda$") 86 | ax.set_xticks([0, 1]) 87 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 88 | ax.set_ylabel("Loss") 89 | ax.set_title(f"CIFAR-10, ResNet20 (32× width)") 90 | # ax.legend(loc="upper right", framealpha=0.5) 91 | fig.tight_layout() 92 | 93 | plt.savefig("figs/cifar10_resnet20_loss_interp.png", dpi=300) 94 | plt.savefig("figs/cifar10_resnet20_loss_interp.pdf") 95 | 96 | ### Accuracy plot 97 | fig = plt.figure() 98 | ax = fig.add_subplot(111) 99 | 100 | # Naive 101 | ax.plot( 102 | lambdas, 103 | 100 * np.array(wm32_run.summary["train_acc_interp_naive"]), 104 | color="grey", 105 | linewidth=2, 106 | label="Train", 107 | ) 108 | ax.plot( 109 | lambdas, 110 | 100 * np.array(wm32_run.summary["test_acc_interp_naive"]), 111 | color="grey", 112 | linewidth=2, 113 | linestyle="dashed", 114 | label="Test", 115 | ) 116 | 117 | # Activation matching 118 | # ax.plot(lambdas, 119 | # np.array(activation_matching_run.summary["train_acc_interp_clever"]), 120 | # color="tab:blue", 121 | # marker="*", 122 | # linewidth=2) 123 | # ax.plot(lambdas, 124 | # np.array(activation_matching_run.summary["test_acc_interp_clever"]), 125 | # color="tab:blue", 126 | # marker="*", 127 | # linewidth=2, 128 | # linestyle="dashed") 129 | 130 | # Weight matching 131 | ax.plot( 132 | lambdas, 133 | 100 * np.array(wm32_run.summary["train_acc_interp_clever"]), 134 | color="tab:green", 135 | marker="^", 136 | linewidth=2, 137 | ) 138 | ax.plot( 139 | lambdas, 140 | 100 * np.array(wm32_run.summary["test_acc_interp_clever"]), 141 | color="tab:green", 142 | marker="^", 143 | linestyle="dashed", 144 | linewidth=2, 145 | ) 146 | 147 | # STE matching 148 | # ax.plot(lambdas, 149 | # np.array(ste_matching_run.summary["train_acc_interp_clever"]), 150 | # color="tab:red", 151 | # marker="p", 152 | # linewidth=2) 153 | # ax.plot(lambdas, 154 | # np.array(ste_matching_run.summary["test_acc_interp_clever"]), 155 | # color="tab:red", 156 | # marker="p", 157 | # linestyle="dashed", 158 | # linewidth=2) 159 | 160 | ax.set_xlabel("$\lambda$") 161 | ax.set_xticks([0, 1]) 162 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 163 | ax.set_ylabel("Accuracy") 164 | ax.set_title("CIFAR-10, ResNet20 (32× width)") 165 | # ax.legend(loc="lower right", framealpha=0.5) 166 | fig.tight_layout() 167 | 168 | plt.savefig("figs/cifar10_resnet20_accuracy_interp.png", dpi=300) 169 | plt.savefig("figs/cifar10_resnet20_accuracy_interp.pdf") 170 | -------------------------------------------------------------------------------- /src/cifar10_resnet20_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | See 4 | * https://github.com/hushon/JAX-ResNet-CIFAR10/blob/main/resnet_cifar.py 5 | * https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py 6 | """ 7 | import argparse 8 | 9 | import 
augmax 10 | import flax 11 | import jax.nn 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import tensorflow as tf 16 | import wandb 17 | from flax.training.train_state import TrainState 18 | from jax import jit, random, value_and_grad, vmap 19 | from tqdm import tqdm 20 | 21 | from cifar100_resnet20_train import NUM_CLASSES 22 | from datasets import load_cifar10, load_cifar10_split 23 | from resnet20 import BLOCKS_PER_GROUP, ResNet 24 | from utils import ec2_get_instance_type, rngmix, timeblock 25 | 26 | # See https://github.com/tensorflow/tensorflow/issues/53831. 27 | 28 | # See https://github.com/google/jax/issues/9454. 29 | tf.config.set_visible_devices([], "GPU") 30 | 31 | NUM_CLASSES = 10 32 | 33 | def make_stuff(model): 34 | train_transform = augmax.Chain( 35 | # augmax does not seem to support random crops with padding. See https://github.com/khdlr/augmax/issues/6. 36 | augmax.RandomSizedCrop(32, 32, zoom_range=(0.8, 1.2)), 37 | augmax.HorizontalFlip(), 38 | augmax.Rotate(), 39 | ) 40 | # Applied to all input images, test and train. 41 | normalize_transform = augmax.Chain(augmax.ByteToFloat(), augmax.Normalize()) 42 | 43 | @jit 44 | def batch_eval(params, images_u8, labels): 45 | images_f32 = vmap(normalize_transform)(None, images_u8) 46 | y_onehot = jax.nn.one_hot(labels, NUM_CLASSES) 47 | logits = model.apply({"params": params}, images_f32) 48 | l = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot)) 49 | num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels) 50 | return l, {"num_correct": num_correct} 51 | 52 | @jit 53 | def step(rng, train_state, images, labels): 54 | images_transformed = vmap(train_transform)(random.split(rng, images.shape[0]), images) 55 | (l, info), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_transformed, 56 | labels) 57 | return train_state.apply_gradients(grads=g), {"batch_loss": l, **info} 58 | 59 | def dataset_loss_and_accuracy(params, dataset, batch_size: int): 60 | num_examples = dataset["images_u8"].shape[0] 61 | assert num_examples % batch_size == 0 62 | num_batches = num_examples // batch_size 63 | batch_ix = jnp.arange(num_examples).reshape((num_batches, batch_size)) 64 | # Can't use vmap or run in a single batch since that overloads GPU memory. 65 | losses, infos = zip(*[ 66 | batch_eval( 67 | params, 68 | dataset["images_u8"][batch_ix[i, :], :, :, :], 69 | dataset["labels"][batch_ix[i, :]], 70 | ) for i in range(num_batches) 71 | ]) 72 | return ( 73 | jnp.sum(batch_size * jnp.array(losses)) / num_examples, 74 | sum(x["num_correct"] for x in infos) / num_examples, 75 | ) 76 | 77 | return { 78 | "train_transform": train_transform, 79 | "normalize_transform": normalize_transform, 80 | "batch_eval": batch_eval, 81 | "step": step, 82 | "dataset_loss_and_accuracy": dataset_loss_and_accuracy, 83 | } 84 | 85 | def init_train_state(rng, model, learning_rate, num_epochs, batch_size, num_train_examples, 86 | weight_decay: float): 87 | # See https://github.com/kuangliu/pytorch-cifar. 88 | warmup_epochs = 5 89 | steps_per_epoch = num_train_examples // batch_size 90 | lr_schedule = optax.warmup_cosine_decay_schedule( 91 | init_value=1e-6, 92 | peak_value=learning_rate, 93 | warmup_steps=warmup_epochs * steps_per_epoch, 94 | # Confusingly, `decay_steps` is actually the total number of steps, 95 | # including the warmup. 
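      # Rough numbers (assuming the full 50,000-image CIFAR-10 split and
      # batch_size=100): steps_per_epoch = 500, so the warmup spans
      # 5 * 500 = 2,500 steps and decay_steps = 250 * 500 = 125,000 steps for
      # the default 250-epoch run.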
96 | decay_steps=num_epochs * steps_per_epoch, 97 | ) 98 | tx = optax.chain(optax.add_decayed_weights(weight_decay), optax.sgd(lr_schedule, momentum=0.9)) 99 | # tx = optax.adamw(learning_rate=lr_schedule, weight_decay=5e-4) 100 | vars = model.init(rng, jnp.zeros((1, 32, 32, 3))) 101 | return TrainState.create(apply_fn=model.apply, params=vars["params"], tx=tx) 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--test", action="store_true", help="Run in smoke-test mode") 106 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 107 | parser.add_argument("--data-split", choices=["split1", "split2", "both"], required=True) 108 | parser.add_argument("--width-multiplier", type=int, default=1) 109 | parser.add_argument("--weight-decay", type=float, default=1e-4) 110 | args = parser.parse_args() 111 | 112 | with wandb.init( 113 | project="git-re-basin", 114 | entity="skainswo", 115 | tags=["cifar10", "resnet", "training"], 116 | mode="disabled" if args.test else "online", 117 | job_type="train", 118 | ) as wandb_run: 119 | artifact = wandb.Artifact("cifar10-resnet-weights", type="model-weights") 120 | 121 | config = wandb.config 122 | config.ec2_instance_type = ec2_get_instance_type() 123 | config.test = args.test 124 | config.seed = args.seed 125 | config.data_split = args.data_split 126 | config.learning_rate = 0.1 127 | config.num_epochs = 10 if args.test else 250 128 | config.batch_size = 100 129 | config.width_multiplier = args.width_multiplier 130 | config.weight_decay = args.weight_decay 131 | 132 | rng = random.PRNGKey(config.seed) 133 | 134 | model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"], 135 | num_classes=NUM_CLASSES, 136 | width_multiplier=config.width_multiplier) 137 | 138 | with timeblock("load datasets"): 139 | if config.data_split == "both": 140 | train_ds, test_ds = load_cifar10() 141 | else: 142 | split1, split2, test_ds = load_cifar10_split() 143 | train_ds = split1 if config.data_split == "split1" else split2 144 | 145 | print("train_ds labels hash", hash(np.array(train_ds["labels"]).tobytes())) 146 | print("test_ds labels hash", hash(np.array(test_ds["labels"]).tobytes())) 147 | 148 | num_train_examples = train_ds["images_u8"].shape[0] 149 | num_test_examples = test_ds["images_u8"].shape[0] 150 | assert num_train_examples % config.batch_size == 0 151 | print("num_train_examples", num_train_examples) 152 | print("num_test_examples", num_test_examples) 153 | 154 | stuff = make_stuff(model) 155 | train_state = init_train_state(rngmix(rng, "init"), 156 | model=model, 157 | learning_rate=config.learning_rate, 158 | num_epochs=config.num_epochs, 159 | batch_size=config.batch_size, 160 | num_train_examples=train_ds["images_u8"].shape[0], 161 | weight_decay=config.weight_decay) 162 | 163 | for epoch in tqdm(range(config.num_epochs)): 164 | infos = [] 165 | with timeblock(f"Epoch"): 166 | batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape( 167 | (-1, config.batch_size)) 168 | batch_rngs = random.split(rngmix(rng, f"batch_rngs-{epoch}"), batch_ix.shape[0]) 169 | for i in range(batch_ix.shape[0]): 170 | p = batch_ix[i, :] 171 | images_u8 = train_ds["images_u8"][p, :, :, :] 172 | labels = train_ds["labels"][p] 173 | train_state, info = stuff["step"](batch_rngs[i], train_state, images_u8, labels) 174 | infos.append(info) 175 | 176 | train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples 177 | train_accuracy = 
sum(x["num_correct"] for x in infos) / num_train_examples 178 | 179 | # Evaluate test loss/accuracy 180 | with timeblock("Test set eval"): 181 | test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, test_ds, 182 | 1000) 183 | 184 | # See https://github.com/wandb/client/issues/3690. 185 | wandb_run.log({ 186 | "epoch": epoch, 187 | "train_loss": train_loss, 188 | "test_loss": test_loss, 189 | "train_accuracy": train_accuracy, 190 | "test_accuracy": test_accuracy, 191 | }) 192 | 193 | # No point saving the model at all if we're running in test mode. 194 | if (not config.test) and (epoch % 10 == 0 or epoch == config.num_epochs - 1): 195 | with timeblock("model serialization"): 196 | # See https://github.com/wandb/client/issues/3823 197 | filename = f"/tmp/checkpoint{epoch}" 198 | with open(filename, mode="wb") as f: 199 | f.write(flax.serialization.to_bytes(train_state.params)) 200 | artifact.add_file(filename) 201 | 202 | # This will be a no-op when config.test is enabled anyhow, since wandb will 203 | # be initialized with mode="disabled". 204 | wandb_run.log_artifact(artifact) 205 | -------------------------------------------------------------------------------- /src/cifar10_resnet20_weight_matching.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from pathlib import Path 4 | 5 | import jax.numpy as jnp 6 | import matplotlib.pyplot as plt 7 | import wandb 8 | from flax.serialization import from_bytes 9 | from jax import random 10 | from tqdm import tqdm 11 | 12 | from cifar10_resnet20_train import BLOCKS_PER_GROUP, ResNet, make_stuff 13 | from datasets import load_cifar10 14 | from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params 15 | from weight_matching import (apply_permutation, resnet20_permutation_spec, weight_matching) 16 | 17 | def plot_interp_loss(epoch, lambdas, train_loss_interp_naive, test_loss_interp_naive, 18 | train_loss_interp_clever, test_loss_interp_clever): 19 | fig = plt.figure() 20 | ax = fig.add_subplot(111) 21 | ax.plot(lambdas, 22 | train_loss_interp_naive, 23 | linestyle="dashed", 24 | color="tab:blue", 25 | alpha=0.5, 26 | linewidth=2, 27 | label="Train, naïve interp.") 28 | ax.plot(lambdas, 29 | test_loss_interp_naive, 30 | linestyle="dashed", 31 | color="tab:orange", 32 | alpha=0.5, 33 | linewidth=2, 34 | label="Test, naïve interp.") 35 | ax.plot(lambdas, 36 | train_loss_interp_clever, 37 | linestyle="solid", 38 | color="tab:blue", 39 | linewidth=2, 40 | label="Train, permuted interp.") 41 | ax.plot(lambdas, 42 | test_loss_interp_clever, 43 | linestyle="solid", 44 | color="tab:orange", 45 | linewidth=2, 46 | label="Test, permuted interp.") 47 | ax.set_xlabel("$\lambda$") 48 | ax.set_xticks([0, 1]) 49 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 50 | ax.set_ylabel("Loss") 51 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 52 | ax.set_title(f"Loss landscape between the two models (epoch {epoch})") 53 | ax.legend(loc="upper right", framealpha=0.5) 54 | fig.tight_layout() 55 | return fig 56 | 57 | def plot_interp_acc(epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 58 | train_acc_interp_clever, test_acc_interp_clever): 59 | fig = plt.figure() 60 | ax = fig.add_subplot(111) 61 | ax.plot(lambdas, 62 | train_acc_interp_naive, 63 | linestyle="dashed", 64 | color="tab:blue", 65 | alpha=0.5, 66 | linewidth=2, 67 | label="Train, naïve interp.") 68 | ax.plot(lambdas, 69 | test_acc_interp_naive, 70 | 
linestyle="dashed", 71 | color="tab:orange", 72 | alpha=0.5, 73 | linewidth=2, 74 | label="Test, naïve interp.") 75 | ax.plot(lambdas, 76 | train_acc_interp_clever, 77 | linestyle="solid", 78 | color="tab:blue", 79 | linewidth=2, 80 | label="Train, permuted interp.") 81 | ax.plot(lambdas, 82 | test_acc_interp_clever, 83 | linestyle="solid", 84 | color="tab:orange", 85 | linewidth=2, 86 | label="Test, permuted interp.") 87 | ax.set_xlabel("$\lambda$") 88 | ax.set_xticks([0, 1]) 89 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 90 | ax.set_ylabel("Accuracy") 91 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 92 | ax.set_title(f"Accuracy between the two models (epoch {epoch})") 93 | ax.legend(loc="lower right", framealpha=0.5) 94 | fig.tight_layout() 95 | return fig 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 100 | parser.add_argument("--model-a", type=str, required=True) 101 | parser.add_argument("--model-b", type=str, required=True) 102 | parser.add_argument("--width-multiplier", type=int, required=True) 103 | parser.add_argument("--load-epoch", type=int, required=True) 104 | args = parser.parse_args() 105 | 106 | with wandb.init( 107 | project="git-re-basin", 108 | entity="skainswo", 109 | tags=["cifar10", "resnet20", "weight-matching"], 110 | job_type="analysis", 111 | ) as wandb_run: 112 | config = wandb.config 113 | config.ec2_instance_type = ec2_get_instance_type() 114 | config.model_a = args.model_a 115 | config.model_b = args.model_b 116 | config.width_multiplier = args.width_multiplier 117 | config.seed = args.seed 118 | config.load_epoch = args.load_epoch 119 | 120 | model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"], 121 | num_classes=10, 122 | width_multiplier=config.width_multiplier) 123 | stuff = make_stuff(model) 124 | 125 | def load_model(filepath): 126 | with open(filepath, "rb") as fh: 127 | return from_bytes( 128 | model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"], fh.read()) 129 | 130 | filename = f"checkpoint{config.load_epoch}" 131 | model_a = load_model( 132 | Path( 133 | wandb_run.use_artifact(f"cifar10-resnet-weights:{config.model_a}").get_path( 134 | filename).download())) 135 | model_b = load_model( 136 | Path( 137 | wandb_run.use_artifact(f"cifar10-resnet-weights:{config.model_b}").get_path( 138 | filename).download())) 139 | 140 | train_ds, test_ds = load_cifar10() 141 | 142 | permutation_spec = resnet20_permutation_spec() 143 | final_permutation = weight_matching(random.PRNGKey(config.seed), permutation_spec, 144 | flatten_params(model_a), flatten_params(model_b)) 145 | 146 | # Save final_permutation as an Artifact 147 | artifact = wandb.Artifact("model_b_permutation", 148 | type="permutation", 149 | metadata={ 150 | "dataset": "cifar10", 151 | "model": "resnet20", 152 | "analysis": "weight-matching" 153 | }) 154 | with artifact.new_file("permutation.pkl", mode="wb") as f: 155 | pickle.dump(final_permutation, f) 156 | wandb_run.log_artifact(artifact) 157 | 158 | lambdas = jnp.linspace(0, 1, num=25) 159 | train_loss_interp_naive = [] 160 | test_loss_interp_naive = [] 161 | train_acc_interp_naive = [] 162 | test_acc_interp_naive = [] 163 | for lam in tqdm(lambdas): 164 | naive_p = lerp(lam, model_a, model_b) 165 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](naive_p, train_ds, 1000) 166 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](naive_p, test_ds, 1000) 167 | 
train_loss_interp_naive.append(train_loss) 168 | test_loss_interp_naive.append(test_loss) 169 | train_acc_interp_naive.append(train_acc) 170 | test_acc_interp_naive.append(test_acc) 171 | 172 | model_b_clever = unflatten_params( 173 | apply_permutation(permutation_spec, final_permutation, flatten_params(model_b))) 174 | 175 | train_loss_interp_clever = [] 176 | test_loss_interp_clever = [] 177 | train_acc_interp_clever = [] 178 | test_acc_interp_clever = [] 179 | for lam in tqdm(lambdas): 180 | clever_p = lerp(lam, model_a, model_b_clever) 181 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](clever_p, train_ds, 1000) 182 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](clever_p, test_ds, 1000) 183 | train_loss_interp_clever.append(train_loss) 184 | test_loss_interp_clever.append(test_loss) 185 | train_acc_interp_clever.append(train_acc) 186 | test_acc_interp_clever.append(test_acc) 187 | 188 | assert len(lambdas) == len(train_loss_interp_naive) 189 | assert len(lambdas) == len(test_loss_interp_naive) 190 | assert len(lambdas) == len(train_acc_interp_naive) 191 | assert len(lambdas) == len(test_acc_interp_naive) 192 | assert len(lambdas) == len(train_loss_interp_clever) 193 | assert len(lambdas) == len(test_loss_interp_clever) 194 | assert len(lambdas) == len(train_acc_interp_clever) 195 | assert len(lambdas) == len(test_acc_interp_clever) 196 | 197 | print("Plotting...") 198 | fig = plot_interp_loss(config.load_epoch, lambdas, train_loss_interp_naive, 199 | test_loss_interp_naive, train_loss_interp_clever, 200 | test_loss_interp_clever) 201 | plt.savefig(f"cifar10_resnet20_weight_matching_interp_loss_epoch{config.load_epoch}.png", 202 | dpi=300) 203 | wandb.log({"interp_loss_fig": wandb.Image(fig)}, commit=False) 204 | plt.close(fig) 205 | 206 | fig = plot_interp_acc(config.load_epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 207 | train_acc_interp_clever, test_acc_interp_clever) 208 | plt.savefig(f"cifar10_resnet20_weight_matching_interp_accuracy_epoch{config.load_epoch}.png", 209 | dpi=300) 210 | wandb.log({"interp_acc_fig": wandb.Image(fig)}, commit=False) 211 | plt.close(fig) 212 | 213 | wandb.log({ 214 | "train_loss_interp_naive": train_loss_interp_naive, 215 | "test_loss_interp_naive": test_loss_interp_naive, 216 | "train_acc_interp_naive": train_acc_interp_naive, 217 | "test_acc_interp_naive": test_acc_interp_naive, 218 | "train_loss_interp_clever": train_loss_interp_clever, 219 | "test_loss_interp_clever": test_loss_interp_clever, 220 | "train_acc_interp_clever": train_acc_interp_clever, 221 | "test_acc_interp_clever": test_acc_interp_clever, 222 | }) 223 | 224 | print({ 225 | "train_loss_interp_naive": train_loss_interp_naive, 226 | "test_loss_interp_naive": test_loss_interp_naive, 227 | "train_acc_interp_naive": train_acc_interp_naive, 228 | "test_acc_interp_naive": test_acc_interp_naive, 229 | "train_loss_interp_clever": train_loss_interp_clever, 230 | "test_loss_interp_clever": test_loss_interp_clever, 231 | "train_acc_interp_clever": train_acc_interp_clever, 232 | "test_acc_interp_clever": test_acc_interp_clever, 233 | }) 234 | 235 | # if __name__ == "__main__": 236 | # main() 237 | -------------------------------------------------------------------------------- /src/cifar10_resnet20_width_ablation_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import wandb 4 | 5 | import matplotlib_style as _ 6 | from plot_utils import 
loss_barrier_is_nonnegative 7 | 8 | if __name__ == "__main__": 9 | api = wandb.Api() 10 | wm1_run = api.run("skainswo/git-re-basin/61rq8cq7") 11 | wm2_run = api.run("skainswo/git-re-basin/2xer5nly") 12 | wm4_run = api.run("skainswo/git-re-basin/3gtdjmnf") 13 | wm8_run = api.run("skainswo/git-re-basin/3gq93t7g") 14 | wm16_run = api.run("skainswo/git-re-basin/2zeq7qy2") 15 | wm32_run = api.run("skainswo/git-re-basin/223t7txl") 16 | 17 | all_runs = [wm1_run, wm2_run, wm4_run, wm8_run, wm16_run, wm32_run] 18 | 19 | # fig = plt.figure() 20 | # ax = fig.add_subplot(111) 21 | # lambdas = np.linspace(0, 1, 25) 22 | # wm_glyphs = ["⅛", "¼", "½", "1", "2", "4"] 23 | # cmap = plt.get_cmap("YlOrRd") 24 | # for i, wm_glyph, run in zip(range(len(all_runs)), wm_glyphs, all_runs): 25 | # ys = np.array(run.summary["train_loss_interp_clever"]) 26 | # ys = ys - 0.5 * (ys[0] + ys[-1]) 27 | # ax.plot(lambdas, 28 | # ys, 29 | # color=cmap(0.25 + 0.75 * i / len(all_runs)), 30 | # linewidth=2, 31 | # label=f"{wm_glyph}× width") 32 | # ax.set_xlabel("$\lambda$") 33 | # ax.set_xticks([0, 1]) 34 | # ax.set_xticklabels(["Model $A$", "Model $B$"]) 35 | # ax.set_ylabel("Training loss barrier") 36 | # ax.set_title(f"CIFAR-10 ResNet Width Ablation") 37 | # ax.legend(loc="upper right", framealpha=0.5) 38 | # fig.tight_layout() 39 | 40 | # plt.savefig("figs/cifar10_resnet20_width_ablation_plot_train_loss.png", dpi=300) 41 | # plt.savefig("figs/cifar10_resnet20_width_ablation_plot_train_loss.pdf") 42 | # plt.savefig("figs/cifar10_resnet20_width_ablation_plot_train_loss.eps") 43 | 44 | # ############ 45 | # fig = plt.figure() 46 | # ax = fig.add_subplot(111) 47 | # lambdas = np.linspace(0, 1, 25) 48 | # wm_glyphs = ["⅛", "¼", "½", "1", "2", "4"] 49 | # cmap = plt.get_cmap("YlOrRd") 50 | # for i, wm_glyph, run in zip(range(len(all_runs)), wm_glyphs, all_runs): 51 | # ys = np.array(run.summary["test_loss_interp_clever"]) 52 | # ys = ys - 0.5 * (ys[0] + ys[-1]) 53 | # ax.plot(lambdas, 54 | # ys, 55 | # color=cmap(0.25 + 0.75 * i / len(all_runs)), 56 | # linewidth=2, 57 | # label=f"{wm_glyph}× width") 58 | # ax.set_xlabel("$\lambda$") 59 | # ax.set_xticks([0, 1]) 60 | # ax.set_xticklabels(["Model $A$", "Model $B$"]) 61 | # ax.set_ylabel("Test loss barrier") 62 | # ax.set_title(f"CIFAR-10 ResNet Width Ablation") 63 | # ax.legend(loc="upper right", framealpha=0.5) 64 | # fig.tight_layout() 65 | 66 | # plt.savefig("figs/cifar10_resnet20_width_ablation_plot_test_loss.png", dpi=300) 67 | # plt.savefig("figs/cifar10_resnet20_width_ablation_plot_test_loss.pdf") 68 | # plt.savefig("figs/cifar10_resnet20_width_ablation_plot_test_loss.eps") 69 | 70 | ### 71 | fig = plt.figure() 72 | # fig = plt.figure(figsize=(8, 6)) 73 | ax = fig.add_subplot(111) 74 | lambdas = np.linspace(0, 1, 25) 75 | wm_glyphs = ["1", "2", "4", "8", "16", "32"] 76 | 77 | train_barriers = [ 78 | max(run.summary["train_loss_interp_clever"]) - 0.5 * 79 | (run.summary["train_loss_interp_clever"][0] + run.summary["train_loss_interp_clever"][-1]) 80 | for run in all_runs 81 | ] 82 | test_barriers = [ 83 | max(run.summary["test_loss_interp_clever"]) - 0.5 * 84 | (run.summary["test_loss_interp_clever"][0] + run.summary["test_loss_interp_clever"][-1]) 85 | for run in all_runs 86 | ] 87 | 88 | ax.plot( 89 | train_barriers, 90 | marker="o", 91 | linewidth=2, 92 | label="Train", 93 | ) 94 | ax.plot( 95 | test_barriers, 96 | marker="^", 97 | linestyle="dashed", 98 | linewidth=2, 99 | label="Test", 100 | ) 101 | 102 | ax.arrow(5, 0, -0.75, 0.5, alpha=0.25) 103 | ins2 = 
ax.inset_axes((0.7, 0.3, 0.25, 0.25)) 104 | ins2.plot(all_runs[-1].summary["train_loss_interp_clever"]) 105 | ins2.plot(all_runs[-1].summary["test_loss_interp_clever"], linestyle="dashed") 106 | ins2.set_xticks([]) 107 | ins2.set_yticks([]) 108 | ymin, ymax = ins2.get_ylim() 109 | ins2.set_ylim((ymin - 0.2 * (ymax - ymin), ymax + 0.2 * (ymax - ymin))) 110 | 111 | loss_barrier_is_nonnegative(ax) 112 | 113 | ax.set_xlabel("Width multiplier") 114 | ax.set_xticks(range(len(all_runs))) 115 | ax.set_xticklabels([f"{x}×" for x in wm_glyphs]) 116 | ax.set_ylabel("Loss barrier") 117 | ax.set_title("ResNet20") 118 | ax.legend(loc="upper right", framealpha=0.5) 119 | fig.tight_layout() 120 | 121 | plt.savefig("figs/cifar10_resnet20_width_ablation_line_plot.png", dpi=300) 122 | plt.savefig("figs/cifar10_resnet20_width_ablation_line_plot.pdf") 123 | -------------------------------------------------------------------------------- /src/cifar10_vgg_weight_matching.py: -------------------------------------------------------------------------------- 1 | """ 2 | --width-multiplier=8 --model-a=v4 --model-b=v5 3 | --width-multiplier=16 --model-a=v6 --model-b=v7 4 | --width-multiplier=32 --model-a=v8 --model-b=v9 5 | --width-multiplier=64 --model-a=v12 --model-b=v13 6 | --width-multiplier=128 --model-a=v10 --model-b=v11 7 | --width-multiplier=256 --model-a=v16 --model-b=v18 8 | # --width-multiplier=512 --model-a= --model-b= 9 | """ 10 | import argparse 11 | import pickle 12 | from pathlib import Path 13 | 14 | import jax.numpy as jnp 15 | import matplotlib.pyplot as plt 16 | import wandb 17 | from flax.serialization import from_bytes 18 | from jax import random 19 | from tqdm import tqdm 20 | 21 | from cifar10_vgg_run import (VGG16Wide, init_train_state, make_stuff, make_vgg_width_ablation) 22 | from datasets import load_cifar10 23 | from utils import (ec2_get_instance_type, flatten_params, lerp, rngmix, unflatten_params) 24 | from weight_matching import (apply_permutation, vgg16_permutation_spec, weight_matching) 25 | 26 | def plot_interp_loss(epoch, lambdas, train_loss_interp_naive, test_loss_interp_naive, 27 | train_loss_interp_clever, test_loss_interp_clever): 28 | fig = plt.figure() 29 | ax = fig.add_subplot(111) 30 | ax.plot(lambdas, 31 | train_loss_interp_naive, 32 | linestyle="dashed", 33 | color="tab:blue", 34 | alpha=0.5, 35 | linewidth=2, 36 | label="Train, naïve interp.") 37 | ax.plot(lambdas, 38 | test_loss_interp_naive, 39 | linestyle="dashed", 40 | color="tab:orange", 41 | alpha=0.5, 42 | linewidth=2, 43 | label="Test, naïve interp.") 44 | ax.plot(lambdas, 45 | train_loss_interp_clever, 46 | linestyle="solid", 47 | color="tab:blue", 48 | linewidth=2, 49 | label="Train, permuted interp.") 50 | ax.plot(lambdas, 51 | test_loss_interp_clever, 52 | linestyle="solid", 53 | color="tab:orange", 54 | linewidth=2, 55 | label="Test, permuted interp.") 56 | ax.set_xlabel("$\lambda$") 57 | ax.set_xticks([0, 1]) 58 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 59 | ax.set_ylabel("Loss") 60 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 61 | ax.set_title(f"Loss landscape between the two models (epoch {epoch})") 62 | ax.legend(loc="upper right", framealpha=0.5) 63 | fig.tight_layout() 64 | return fig 65 | 66 | def plot_interp_acc(epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 67 | train_acc_interp_clever, test_acc_interp_clever): 68 | fig = plt.figure() 69 | ax = fig.add_subplot(111) 70 | ax.plot(lambdas, 71 | train_acc_interp_naive, 72 | linestyle="dashed", 73 | 
color="tab:blue", 74 | alpha=0.5, 75 | linewidth=2, 76 | label="Train, naïve interp.") 77 | ax.plot(lambdas, 78 | test_acc_interp_naive, 79 | linestyle="dashed", 80 | color="tab:orange", 81 | alpha=0.5, 82 | linewidth=2, 83 | label="Test, naïve interp.") 84 | ax.plot(lambdas, 85 | train_acc_interp_clever, 86 | linestyle="solid", 87 | color="tab:blue", 88 | linewidth=2, 89 | label="Train, permuted interp.") 90 | ax.plot(lambdas, 91 | test_acc_interp_clever, 92 | linestyle="solid", 93 | color="tab:orange", 94 | linewidth=2, 95 | label="Test, permuted interp.") 96 | ax.set_xlabel("$\lambda$") 97 | ax.set_xticks([0, 1]) 98 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 99 | ax.set_ylabel("Accuracy") 100 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 101 | ax.set_title(f"Accuracy between the two models (epoch {epoch})") 102 | ax.legend(loc="lower right", framealpha=0.5) 103 | fig.tight_layout() 104 | return fig 105 | 106 | def main(): 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("--model-a", type=str, required=True) 109 | parser.add_argument("--model-b", type=str, required=True) 110 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 111 | parser.add_argument("--width-multiplier", type=int, required=True) 112 | args = parser.parse_args() 113 | 114 | with wandb.init( 115 | project="git-re-basin", 116 | entity="skainswo", 117 | tags=["cifar10", "vgg16", "weight-matching"], 118 | job_type="analysis", 119 | ) as wandb_run: 120 | config = wandb.config 121 | config.ec2_instance_type = ec2_get_instance_type() 122 | config.model_a = args.model_a 123 | config.model_b = args.model_b 124 | config.seed = args.seed 125 | config.width_multiplier = args.width_multiplier 126 | config.load_epoch = 99 127 | 128 | # model = VGG16Wide() 129 | model = make_vgg_width_ablation(config.width_multiplier) 130 | 131 | def load_model(filepath): 132 | with open(filepath, "rb") as fh: 133 | return from_bytes( 134 | init_train_state(random.PRNGKey(0), 135 | model, 136 | learning_rate=-1, 137 | num_epochs=100, 138 | batch_size=100, 139 | num_train_examples=50_000), fh.read()) 140 | 141 | filename = f"checkpoint{config.load_epoch}" 142 | model_a = load_model( 143 | Path( 144 | wandb_run.use_artifact(f"cifar10-vgg-weights:{config.model_a}").get_path( 145 | filename).download())) 146 | model_b = load_model( 147 | Path( 148 | wandb_run.use_artifact(f"cifar10-vgg-weights:{config.model_b}").get_path( 149 | filename).download())) 150 | 151 | stuff = make_stuff(model) 152 | train_ds, test_ds = load_cifar10() 153 | 154 | permutation_spec = vgg16_permutation_spec() 155 | final_permutation = weight_matching(random.PRNGKey(config.seed), permutation_spec, 156 | flatten_params(model_a.params), 157 | flatten_params(model_b.params)) 158 | 159 | # Save final_permutation as an Artifact 160 | artifact = wandb.Artifact("model_b_permutation", 161 | type="permutation", 162 | metadata={ 163 | "dataset": "cifar10", 164 | "model": "vgg16" 165 | }) 166 | with artifact.new_file("permutation.pkl", mode="wb") as f: 167 | pickle.dump(final_permutation, f) 168 | wandb_run.log_artifact(artifact) 169 | 170 | lambdas = jnp.linspace(0, 1, num=25) 171 | train_loss_interp_naive = [] 172 | test_loss_interp_naive = [] 173 | train_acc_interp_naive = [] 174 | test_acc_interp_naive = [] 175 | for lam in tqdm(lambdas): 176 | naive_p = lerp(lam, model_a.params, model_b.params) 177 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](naive_p, train_ds, 1000) 178 | test_loss, test_acc = 
stuff["dataset_loss_and_accuracy"](naive_p, test_ds, 1000) 179 | train_loss_interp_naive.append(train_loss) 180 | test_loss_interp_naive.append(test_loss) 181 | train_acc_interp_naive.append(train_acc) 182 | test_acc_interp_naive.append(test_acc) 183 | 184 | model_b_clever = unflatten_params( 185 | apply_permutation(permutation_spec, final_permutation, flatten_params(model_b.params))) 186 | 187 | train_loss_interp_clever = [] 188 | test_loss_interp_clever = [] 189 | train_acc_interp_clever = [] 190 | test_acc_interp_clever = [] 191 | for lam in tqdm(lambdas): 192 | clever_p = lerp(lam, model_a.params, model_b_clever) 193 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](clever_p, train_ds, 1000) 194 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](clever_p, test_ds, 1000) 195 | train_loss_interp_clever.append(train_loss) 196 | test_loss_interp_clever.append(test_loss) 197 | train_acc_interp_clever.append(train_acc) 198 | test_acc_interp_clever.append(test_acc) 199 | 200 | assert len(lambdas) == len(train_loss_interp_naive) 201 | assert len(lambdas) == len(test_loss_interp_naive) 202 | assert len(lambdas) == len(train_acc_interp_naive) 203 | assert len(lambdas) == len(test_acc_interp_naive) 204 | assert len(lambdas) == len(train_loss_interp_clever) 205 | assert len(lambdas) == len(test_loss_interp_clever) 206 | assert len(lambdas) == len(train_acc_interp_clever) 207 | assert len(lambdas) == len(test_acc_interp_clever) 208 | 209 | print("Plotting...") 210 | fig = plot_interp_loss(config.load_epoch, lambdas, train_loss_interp_naive, 211 | test_loss_interp_naive, train_loss_interp_clever, 212 | test_loss_interp_clever) 213 | plt.savefig(f"cifar10_vgg16_weight_matching_interp_loss_epoch{config.load_epoch}.png", dpi=300) 214 | wandb.log({"interp_loss_fig": wandb.Image(fig)}, commit=False) 215 | plt.close(fig) 216 | 217 | fig = plot_interp_acc(config.load_epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 218 | train_acc_interp_clever, test_acc_interp_clever) 219 | plt.savefig(f"cifar10_vgg16_weight_matching_interp_accuracy_epoch{config.load_epoch}.png", 220 | dpi=300) 221 | wandb.log({"interp_acc_fig": wandb.Image(fig)}, commit=False) 222 | plt.close(fig) 223 | 224 | wandb.log({ 225 | "train_loss_interp_naive": train_loss_interp_naive, 226 | "test_loss_interp_naive": test_loss_interp_naive, 227 | "train_acc_interp_naive": train_acc_interp_naive, 228 | "test_acc_interp_naive": test_acc_interp_naive, 229 | "train_loss_interp_clever": train_loss_interp_clever, 230 | "test_loss_interp_clever": test_loss_interp_clever, 231 | "train_acc_interp_clever": train_acc_interp_clever, 232 | "test_acc_interp_clever": test_acc_interp_clever, 233 | }) 234 | 235 | print({ 236 | "train_loss_interp_naive": train_loss_interp_naive, 237 | "test_loss_interp_naive": test_loss_interp_naive, 238 | "train_acc_interp_naive": train_acc_interp_naive, 239 | "test_acc_interp_naive": test_acc_interp_naive, 240 | "train_loss_interp_clever": train_loss_interp_clever, 241 | "test_loss_interp_clever": test_loss_interp_clever, 242 | "train_acc_interp_clever": train_acc_interp_clever, 243 | "test_acc_interp_clever": test_acc_interp_clever, 244 | }) 245 | 246 | if __name__ == "__main__": 247 | main() 248 | -------------------------------------------------------------------------------- /src/cifar10_vgg_width_ablation_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import 
seaborn as sns 5 | import wandb 6 | 7 | import matplotlib_style as _ 8 | from plot_utils import loss_barrier_is_nonnegative 9 | 10 | if __name__ == "__main__": 11 | api = wandb.Api() 12 | wm8_run = api.run("skainswo/git-re-basin/i85jxkgf") 13 | wm16_run = api.run("skainswo/git-re-basin/gk003bct") 14 | wm32_run = api.run("skainswo/git-re-basin/1np06qag") 15 | wm64_run = api.run("skainswo/git-re-basin/31kskp4e") 16 | wm128_run = api.run("skainswo/git-re-basin/17huxv8g") 17 | wm256_run = api.run("skainswo/git-re-basin/37g4iks3") 18 | all_runs = [wm8_run, wm16_run, wm32_run, wm64_run, wm128_run, wm256_run] 19 | 20 | fig = plt.figure() 21 | ax = fig.add_subplot(111) 22 | lambdas = np.linspace(0, 1, 25) 23 | wm_glyphs = ["⅛", "¼", "½", "1", "2", "4"] 24 | cmap = plt.get_cmap("YlOrRd") 25 | for i, wm_glyph, run in zip(range(len(all_runs)), wm_glyphs, all_runs): 26 | ys = np.array(run.summary["train_loss_interp_clever"]) 27 | ys = ys - 0.5 * (ys[0] + ys[-1]) 28 | ax.plot(lambdas, 29 | ys, 30 | color=cmap(0.25 + 0.75 * i / len(all_runs)), 31 | linewidth=2, 32 | label=f"{wm_glyph}× width") 33 | ax.set_xlabel("$\lambda$") 34 | ax.set_xticks([0, 1]) 35 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 36 | ax.set_ylabel("Training loss barrier") 37 | ax.set_title(f"CIFAR-10 VGG-16 Width Ablation") 38 | ax.legend(loc="upper right", framealpha=0.5) 39 | fig.tight_layout() 40 | 41 | plt.savefig("figs/cifar10_vgg_width_ablation_plot_train_loss.png", dpi=300) 42 | plt.savefig("figs/cifar10_vgg_width_ablation_plot_train_loss.pdf") 43 | plt.savefig("figs/cifar10_vgg_width_ablation_plot_train_loss.eps") 44 | 45 | ############ 46 | fig = plt.figure() 47 | ax = fig.add_subplot(111) 48 | lambdas = np.linspace(0, 1, 25) 49 | wm_glyphs = ["⅛", "¼", "½", "1", "2", "4"] 50 | cmap = plt.get_cmap("YlOrRd") 51 | for i, wm_glyph, run in zip(range(len(all_runs)), wm_glyphs, all_runs): 52 | ys = np.array(run.summary["test_loss_interp_clever"]) 53 | ys = ys - 0.5 * (ys[0] + ys[-1]) 54 | ax.plot(lambdas, 55 | ys, 56 | color=cmap(0.25 + 0.75 * i / len(all_runs)), 57 | linewidth=2, 58 | label=f"{wm_glyph}× width") 59 | ax.set_xlabel("$\lambda$") 60 | ax.set_xticks([0, 1]) 61 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 62 | ax.set_ylabel("Test loss barrier") 63 | ax.set_title(f"CIFAR-10 VGG-16 Width Ablation") 64 | ax.legend(loc="upper right", framealpha=0.5) 65 | fig.tight_layout() 66 | 67 | plt.savefig("figs/cifar10_vgg_width_ablation_plot_test_loss.png", dpi=300) 68 | plt.savefig("figs/cifar10_vgg_width_ablation_plot_test_loss.pdf") 69 | plt.savefig("figs/cifar10_vgg_width_ablation_plot_test_loss.eps") 70 | 71 | ### 72 | fig = plt.figure() 73 | # fig = plt.figure(figsize=(8, 6)) 74 | ax = fig.add_subplot(111) 75 | lambdas = np.linspace(0, 1, 25) 76 | wm_glyphs = ["⅛", "¼", "½", "1", "2", "4"] 77 | 78 | train_barriers = [ 79 | max(run.summary["train_loss_interp_clever"]) - 0.5 * 80 | (run.summary["train_loss_interp_clever"][0] + run.summary["train_loss_interp_clever"][-1]) 81 | for run in all_runs 82 | ] 83 | test_barriers = [ 84 | max(run.summary["test_loss_interp_clever"]) - 0.5 * 85 | (run.summary["test_loss_interp_clever"][0] + run.summary["test_loss_interp_clever"][-1]) 86 | for run in all_runs 87 | ] 88 | 89 | ax.plot( 90 | train_barriers, 91 | marker="o", 92 | linewidth=2, 93 | label=f"Train", 94 | ) 95 | ax.plot( 96 | test_barriers, 97 | marker="^", 98 | linestyle="dashed", 99 | linewidth=2, 100 | label=f"Test", 101 | ) 102 | 103 | loss_barrier_is_nonnegative(ax) 104 | 105 | ax.set_xlabel("Width multiplier") 106 | 
ax.set_xticks(range(len(all_runs))) 107 | ax.set_xticklabels([f"{x}×" for x in wm_glyphs]) 108 | ax.set_ylabel("Loss barrier") 109 | ax.set_title(f"VGG-16") 110 | # ax.legend(loc="upper right", framealpha=0.5) 111 | fig.tight_layout() 112 | 113 | plt.savefig("figs/cifar10_vgg_width_ablation_line_plot.png", dpi=300) 114 | plt.savefig("figs/cifar10_vgg_width_ablation_line_plot.pdf") 115 | plt.savefig("figs/cifar10_vgg_width_ablation_line_plot.eps") 116 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow_datasets as tfds 3 | from torchvision import datasets, transforms 4 | import torch, os 5 | 6 | def load_cifar10(): 7 | """Return the training and test datasets, as jnp.array's.""" 8 | train_ds_images_u8, train_ds_labels = tfds.as_numpy( 9 | tfds.load("cifar10", split="train", batch_size=-1, as_supervised=True)) 10 | test_ds_images_u8, test_ds_labels = tfds.as_numpy( 11 | tfds.load("cifar10", split="test", batch_size=-1, as_supervised=True)) 12 | train_ds = {"images_u8": train_ds_images_u8, "labels": train_ds_labels} 13 | test_ds = {"images_u8": test_ds_images_u8, "labels": test_ds_labels} 14 | return train_ds, test_ds 15 | 16 | def load_cifar100(): 17 | train_ds_images_u8, train_ds_labels = tfds.as_numpy( 18 | tfds.load("cifar100", split="train", batch_size=-1, as_supervised=True)) 19 | test_ds_images_u8, test_ds_labels = tfds.as_numpy( 20 | tfds.load("cifar100", split="test", batch_size=-1, as_supervised=True)) 21 | train_ds = {"images_u8": train_ds_images_u8, "labels": train_ds_labels} 22 | test_ds = {"images_u8": test_ds_images_u8, "labels": test_ds_labels} 23 | return train_ds, test_ds 24 | 25 | def _split_cifar(train_ds, label_split: int): 26 | """Split a CIFAR-ish dataset into two biased subsets.""" 27 | assert train_ds["images_u8"].shape[0] == 50_000 28 | assert train_ds["labels"].shape[0] == 50_000 29 | 30 | # We randomly permute the training data, just in case there's some kind of 31 | # non-iid ordering coming out of tfds. 32 | perm = np.random.default_rng(123).permutation(50_000) 33 | train_images_u8 = train_ds["images_u8"][perm, :, :, :] 34 | train_labels = train_ds["labels"][perm] 35 | 36 | # This just so happens to be a clean 25000/25000 split. 
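  # Concretely (for CIFAR-10 with label_split=5): labels 0-4 and labels 5-9 each
  # account for 25,000 images. Below, split s1 takes the first 5,000 "low-label"
  # images plus the remaining 20,000 "high-label" images, and s2 takes the
  # complementary 5,000/20,000 mix, so each split has 25,000 images but a
  # deliberately skewed (80/20) class distribution.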
37 | lt_images_u8 = train_images_u8[train_labels < label_split] 38 | lt_labels = train_labels[train_labels < label_split] 39 | gte_images_u8 = train_images_u8[train_labels >= label_split] 40 | gte_labels = train_labels[train_labels >= label_split] 41 | s1 = { 42 | "images_u8": np.concatenate((lt_images_u8[:5000], gte_images_u8[5000:]), axis=0), 43 | "labels": np.concatenate((lt_labels[:5000], gte_labels[5000:]), axis=0) 44 | } 45 | s2 = { 46 | "images_u8": np.concatenate((gte_images_u8[:5000], lt_images_u8[5000:]), axis=0), 47 | "labels": np.concatenate((gte_labels[:5000], lt_labels[5000:]), axis=0) 48 | } 49 | return s1, s2 50 | 51 | def load_cifar10_split(): 52 | train_ds, test_ds = load_cifar10() 53 | s1, s2 = _split_cifar(train_ds, label_split=5) 54 | return s1, s2, test_ds 55 | 56 | def load_cifar100_split(): 57 | train_ds, test_ds = load_cifar100() 58 | s1, s2 = _split_cifar(train_ds, label_split=50) 59 | return s1, s2, test_ds 60 | 61 | class ImageNet: 62 | def __init__(self): 63 | super(ImageNet, self).__init__() 64 | 65 | data_root = "/tmp" 66 | 67 | # Data loading code 68 | kwargs = {"num_workers": 4} 69 | 70 | # Data loading code 71 | traindir = os.path.join(data_root, "train") 72 | valdir = os.path.join(data_root, "val") 73 | 74 | normalize = transforms.Normalize( 75 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 76 | ) 77 | 78 | train_dataset = datasets.ImageFolder( 79 | traindir, 80 | transforms.Compose( 81 | [ 82 | transforms.RandomResizedCrop(224), 83 | transforms.RandomHorizontalFlip(), 84 | transforms.ToTensor(), 85 | normalize, 86 | ] 87 | ), 88 | ) 89 | 90 | self.train_loader = torch.utils.data.DataLoader( 91 | train_dataset, batch_size=1000, shuffle=True, **kwargs 92 | ) 93 | 94 | self.val_loader = torch.utils.data.DataLoader( 95 | datasets.ImageFolder( 96 | valdir, 97 | transforms.Compose( 98 | [ 99 | transforms.Resize(256), 100 | transforms.CenterCrop(224), 101 | transforms.ToTensor(), 102 | normalize, 103 | ] 104 | ), 105 | ), 106 | batch_size=1000, 107 | shuffle=False, 108 | **kwargs 109 | ) 110 | -------------------------------------------------------------------------------- /src/imagenet_resnet50_weight_matching.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from pathlib import Path 4 | import jax 5 | import torch 6 | torch.set_default_tensor_type(torch.FloatTensor) 7 | 8 | import jax.numpy as jnp 9 | import matplotlib.pyplot as plt 10 | import wandb 11 | from flax.serialization import from_bytes 12 | from jax import random 13 | from tqdm import tqdm 14 | 15 | from resnet import ResNet50 16 | from cifar10_resnet20_train import BLOCKS_PER_GROUP, ResNet, make_stuff 17 | from datasets import ImageNet 18 | from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params 19 | from weight_matching import ( 20 | apply_permutation, 21 | resnet50_permutation_spec, 22 | weight_matching, 23 | ) 24 | import pickle 25 | 26 | import numpy as np 27 | import torchmetrics 28 | import tqdm 29 | 30 | 31 | def plot_interp_loss( 32 | epoch, 33 | lambdas, 34 | train_loss_interp_naive, 35 | test_loss_interp_naive, 36 | train_loss_interp_clever, 37 | test_loss_interp_clever, 38 | ): 39 | fig = plt.figure() 40 | ax = fig.add_subplot(111) 41 | ax.plot( 42 | lambdas, 43 | train_loss_interp_naive, 44 | linestyle="dashed", 45 | color="tab:blue", 46 | alpha=0.5, 47 | linewidth=2, 48 | label="Train, naïve interp.", 49 | ) 50 | ax.plot( 51 | lambdas, 52 | test_loss_interp_naive, 53 | 
linestyle="dashed", 54 | color="tab:orange", 55 | alpha=0.5, 56 | linewidth=2, 57 | label="Test, naïve interp.", 58 | ) 59 | ax.plot( 60 | lambdas, 61 | train_loss_interp_clever, 62 | linestyle="solid", 63 | color="tab:blue", 64 | linewidth=2, 65 | label="Train, permuted interp.", 66 | ) 67 | ax.plot( 68 | lambdas, 69 | test_loss_interp_clever, 70 | linestyle="solid", 71 | color="tab:orange", 72 | linewidth=2, 73 | label="Test, permuted interp.", 74 | ) 75 | ax.set_xlabel("$\lambda$") 76 | ax.set_xticks([0, 1]) 77 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 78 | ax.set_ylabel("Loss") 79 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 80 | ax.set_title(f"Loss landscape between the two models (epoch {epoch})") 81 | ax.legend(loc="upper right", framealpha=0.5) 82 | fig.tight_layout() 83 | return fig 84 | 85 | 86 | def plot_interp_acc( 87 | epoch, 88 | lambdas, 89 | train_acc_interp_naive, 90 | test_acc_interp_naive, 91 | train_acc_interp_clever, 92 | test_acc_interp_clever, 93 | ): 94 | fig = plt.figure() 95 | ax = fig.add_subplot(111) 96 | ax.plot( 97 | lambdas, 98 | train_acc_interp_naive, 99 | linestyle="dashed", 100 | color="tab:blue", 101 | alpha=0.5, 102 | linewidth=2, 103 | label="Train, naïve interp.", 104 | ) 105 | ax.plot( 106 | lambdas, 107 | test_acc_interp_naive, 108 | linestyle="dashed", 109 | color="tab:orange", 110 | alpha=0.5, 111 | linewidth=2, 112 | label="Test, naïve interp.", 113 | ) 114 | ax.plot( 115 | lambdas, 116 | train_acc_interp_clever, 117 | linestyle="solid", 118 | color="tab:blue", 119 | linewidth=2, 120 | label="Train, permuted interp.", 121 | ) 122 | ax.plot( 123 | lambdas, 124 | test_acc_interp_clever, 125 | linestyle="solid", 126 | color="tab:orange", 127 | linewidth=2, 128 | label="Test, permuted interp.", 129 | ) 130 | ax.set_xlabel("$\lambda$") 131 | ax.set_xticks([0, 1]) 132 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 133 | ax.set_ylabel("Accuracy") 134 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 135 | ax.set_title(f"Accuracy between the two models (epoch {epoch})") 136 | ax.legend(loc="lower right", framealpha=0.5) 137 | fig.tight_layout() 138 | return fig 139 | 140 | 141 | if __name__ == "__main__": 142 | # parser = argparse.ArgumentParser() 143 | # parser.add_argument("--seed", type=int, default=0, help="Random seed") 144 | # args = parser.parse_args() 145 | 146 | model = ResNet50(n_classes=1000) 147 | 148 | def load_model(filepath): 149 | with open(filepath, "rb") as fh: 150 | return from_bytes( 151 | model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"], 152 | fh.read(), 153 | ) 154 | 155 | model_a = pickle.load(open("../a_variables.pkl", "rb")) 156 | model_b = pickle.load(open("../b_variables.pkl", "rb")) 157 | 158 | # train_ds, test_ds = load_cifar10() 159 | 160 | permutation_spec = resnet50_permutation_spec() 161 | 162 | @jax.jit 163 | def model_apply1(variables, images): 164 | return model.apply(variables, images, mutable=["batch_stats"]) 165 | 166 | @jax.jit 167 | def model_apply2(variables, images): 168 | return model.apply(variables, images) 169 | 170 | accs = [] 171 | for key in range(10): 172 | final_permutation = weight_matching( 173 | random.PRNGKey(key), 174 | permutation_spec, 175 | flatten_params(model_a["params"]), 176 | flatten_params(model_b["params"]), 177 | ) 178 | 179 | model_b_clever = model_b["params"] 180 | model_b_clever = unflatten_params( 181 | apply_permutation(permutation_spec, final_permutation, flatten_params(model_b["params"]))) 182 | 183 | dataloader = ImageNet() 
184 | loss = torch.nn.CrossEntropyLoss() 185 | 186 | # for lam in np.linspace(0, 1, 25): 187 | lam = 0.5 188 | model_ab = { 189 | "params": lerp(lam, model_a["params"], model_b_clever), 190 | # "params": model_a["params"], 191 | "batch_stats": model_a["batch_stats"] 192 | } 193 | # from flax.core import unfreeze 194 | # model_ab = unfreeze(pickle.load(open("final-perm-interp.pkl", "rb"))) 195 | 196 | for (images, labels), _ in zip(dataloader.train_loader, tqdm.trange(100)): 197 | images = jnp.moveaxis(jnp.asarray(images.numpy()), 1, 3) 198 | y, new_batch_stats = model_apply1(model_ab, images) 199 | model_ab["batch_stats"] = new_batch_stats['batch_stats'] 200 | # logits = torch.tensor(np.asarray(y)) 201 | # print((labels.unsqueeze(1) == logits.argmax(dim=1, keepdim=True)).sum(), loss(logits, labels)) 202 | 203 | acc1 = torchmetrics.Accuracy() 204 | acc5 = torchmetrics.Accuracy(top_k=5) 205 | sum_loss = 0 206 | for images, labels in dataloader.val_loader: 207 | images = jnp.moveaxis(jnp.asarray(images.numpy()), 1, 3) 208 | y = model_apply2(model_ab, images) 209 | logits = torch.tensor(np.asarray(y)) 210 | sum_loss += loss(logits, labels) 211 | acc1(logits, labels) 212 | acc5(logits, labels) 213 | 214 | # print({ 215 | # # "lambda": lam, 216 | # # "loss": (sum_loss / len(dataloader.val_loader)).item(), 217 | # "acc1": acc1.compute().item(), 218 | # "acc5": acc5.compute().item() 219 | # }) 220 | 221 | accs.append(acc1.compute().item()) 222 | 223 | 224 | # if __name__ == "__main__": 225 | # main() 226 | -------------------------------------------------------------------------------- /src/matplotlib_style.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import seaborn as sns 3 | 4 | matplotlib.rcParams["font.family"] = "serif" 5 | sns.set_context("talk") 6 | -------------------------------------------------------------------------------- /src/mnist_barrier_vs_epoch_matching.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import jax.numpy as jnp 5 | import wandb 6 | from flax.serialization import from_bytes 7 | from jax import random 8 | from tqdm import tqdm 9 | 10 | from mnist_mlp_train import MLPModel, load_datasets, make_stuff 11 | from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params 12 | from weight_matching import (apply_permutation, mlp_permutation_spec, weight_matching) 13 | 14 | with wandb.init( 15 | project="git-re-basin", 16 | entity="skainswo", 17 | tags=["mnist", "mlp", "weight-matching", "barrier-vs-epoch"], 18 | job_type="analysis", 19 | ) as wandb_run: 20 | # api = wandb.Api() 21 | # seed0_run = api.run("skainswo/git-re-basin/1b1gztfx") 22 | # seed1_run = api.run("skainswo/git-re-basin/1hrmw7wr") 23 | 24 | config = wandb.config 25 | config.ec2_instance_type = ec2_get_instance_type() 26 | config.total_epochs = 100 27 | config.seed = 123 28 | 29 | model = MLPModel() 30 | stuff = make_stuff(model) 31 | 32 | def load_model(filepath): 33 | with open(filepath, "rb") as fh: 34 | return from_bytes( 35 | model.init(random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)))["params"], fh.read()) 36 | 37 | seed0_artifact = Path(wandb_run.use_artifact("mnist-mlp-weights:v15").download()) 38 | seed1_artifact = Path(wandb_run.use_artifact("mnist-mlp-weights:v16").download()) 39 | 40 | permutation_spec = mlp_permutation_spec(3) 41 | 42 | def match_one_epoch(epoch: int): 43 | model_a = load_model(seed0_artifact / 
f"checkpoint{epoch}") 44 | model_b = load_model(seed1_artifact / f"checkpoint{epoch}") 45 | return weight_matching( 46 | random.PRNGKey(config.seed), 47 | permutation_spec, 48 | flatten_params(model_a), 49 | flatten_params(model_b), 50 | ) 51 | 52 | permutation_vs_epoch = [match_one_epoch(i) for i in tqdm(range(config.total_epochs))] 53 | 54 | artifact = wandb.Artifact("mnist_permutation_vs_epoch", 55 | type="permutation_vs_epoch", 56 | metadata={ 57 | "dataset": "mnist", 58 | "model": "mlp", 59 | "analysis": "weight-matching" 60 | }) 61 | with artifact.new_file("permutation_vs_epoch.pkl", mode="wb") as f: 62 | pickle.dump(permutation_vs_epoch, f) 63 | wandb_run.log_artifact(artifact) 64 | 65 | # Eval 66 | train_ds, test_ds = load_datasets() 67 | 68 | def eval_one(epoch, permutation): 69 | model_a = load_model(seed0_artifact / f"checkpoint{epoch}") 70 | model_b = load_model(seed1_artifact / f"checkpoint{epoch}") 71 | 72 | lambdas = jnp.linspace(0, 1, num=25) 73 | 74 | model_b_perm = unflatten_params( 75 | apply_permutation(permutation_spec, permutation, flatten_params(model_b))) 76 | 77 | naive_train_loss_interp = [] 78 | naive_test_loss_interp = [] 79 | naive_train_acc_interp = [] 80 | naive_test_acc_interp = [] 81 | for lam in lambdas: 82 | naive_p = lerp(lam, model_a, model_b) 83 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](naive_p, train_ds, 10_000) 84 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](naive_p, test_ds, 10_000) 85 | naive_train_loss_interp.append(train_loss) 86 | naive_test_loss_interp.append(test_loss) 87 | naive_train_acc_interp.append(train_acc) 88 | naive_test_acc_interp.append(test_acc) 89 | 90 | clever_train_loss_interp = [] 91 | clever_test_loss_interp = [] 92 | clever_train_acc_interp = [] 93 | clever_test_acc_interp = [] 94 | for lam in lambdas: 95 | clever_p = lerp(lam, model_a, model_b_perm) 96 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](clever_p, train_ds, 10_000) 97 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](clever_p, test_ds, 10_000) 98 | clever_train_loss_interp.append(train_loss) 99 | clever_test_loss_interp.append(test_loss) 100 | clever_train_acc_interp.append(train_acc) 101 | clever_test_acc_interp.append(test_acc) 102 | 103 | return { 104 | "naive_train_loss_interp": naive_train_loss_interp, 105 | "naive_test_loss_interp": naive_test_loss_interp, 106 | "naive_train_acc_interp": naive_train_acc_interp, 107 | "naive_test_acc_interp": naive_test_acc_interp, 108 | "clever_train_loss_interp": clever_train_loss_interp, 109 | "clever_test_loss_interp": clever_test_loss_interp, 110 | "clever_train_acc_interp": clever_train_acc_interp, 111 | "clever_test_acc_interp": clever_test_acc_interp, 112 | } 113 | 114 | interp_eval_vs_epoch = [eval_one(i, p) for i, p in tqdm(enumerate(permutation_vs_epoch))] 115 | 116 | artifact = wandb.Artifact("mnist_permutation_eval_vs_epoch", 117 | type="permutation_eval_vs_epoch", 118 | metadata={ 119 | "dataset": "mnist", 120 | "model": "mlp", 121 | "analysis": "weight-matching", 122 | "interpolation": "lerp" 123 | }) 124 | with artifact.new_file("permutation_eval_vs_epoch.pkl", mode="wb") as f: 125 | pickle.dump(interp_eval_vs_epoch, f) 126 | wandb_run.log_artifact(artifact) 127 | -------------------------------------------------------------------------------- /src/mnist_barrier_vs_epoch_plot.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import 
numpy as np 6 | import wandb 7 | 8 | import matplotlib_style as _ 9 | from plot_utils import loss_barrier_is_nonnegative 10 | 11 | max_epoch = 25 12 | 13 | api = wandb.Api() 14 | # run = api.run("skainswo/git-re-basin/begnvj15") 15 | artifact = Path( 16 | api.artifact("skainswo/git-re-basin/mnist_permutation_eval_vs_epoch:v0").download()) 17 | 18 | with open(artifact / "permutation_eval_vs_epoch.pkl", "rb") as f: 19 | interp_eval_vs_epoch = pickle.load(f) 20 | 21 | train_loss_interp = np.array([x["train_loss_interp"] for x in interp_eval_vs_epoch]) 22 | train_barrier_vs_epoch = np.max(train_loss_interp, 23 | axis=1) - 0.5 * (train_loss_interp[:, 0] + train_loss_interp[:, -1]) 24 | 25 | test_loss_interp = np.array([x["test_loss_interp"] for x in interp_eval_vs_epoch]) 26 | test_barrier_vs_epoch = np.max(test_loss_interp, 27 | axis=1) - 0.5 * (test_loss_interp[:, 0] + test_loss_interp[:, -1]) 28 | 29 | fig = plt.figure() 30 | # fig = plt.figure(figsize=(8, 4)) 31 | ax = fig.add_subplot(111) 32 | 33 | ax.arrow(5, 0.51, -4, 0.085, alpha=0.25) 34 | ins1 = ax.inset_axes((0.2, 0.7, 0.25, 0.25)) 35 | ins1.plot(train_loss_interp[0, :]) 36 | ins1.plot(test_loss_interp[0, :], linestyle="dashed") 37 | ins1.set_xticks([]) 38 | ins1.set_yticks([]) 39 | 40 | ax.arrow(21, 0.2, 4, -0.2, alpha=0.25) 41 | ins2 = ax.inset_axes((0.7, 0.3, 0.25, 0.25)) 42 | ins2.plot(train_loss_interp[25, :]) 43 | ins2.plot(test_loss_interp[25, :], linestyle="dashed") 44 | ins2.set_xticks([]) 45 | ins2.set_yticks([]) 46 | ymin, ymax = ins2.get_ylim() 47 | ins2.set_ylim((ymin - 0.2 * (ymax - ymin), ymax + 0.2 * (ymax - ymin))) 48 | 49 | ax.plot( 50 | 1 + np.arange(max_epoch), 51 | train_barrier_vs_epoch[:max_epoch], 52 | marker="o", 53 | linewidth=2, 54 | label="Train", 55 | ) 56 | ax.plot( 57 | 1 + np.arange(max_epoch), 58 | test_barrier_vs_epoch[:max_epoch], 59 | marker="^", 60 | linestyle="dashed", 61 | linewidth=2, 62 | label="Test", 63 | ) 64 | 65 | loss_barrier_is_nonnegative(ax) 66 | 67 | ax.set_xlabel("Epoch") 68 | ax.set_ylabel("Loss barrier") 69 | ax.set_title(f"MNIST") 70 | # ax.legend(loc="upper right", framealpha=0.5) 71 | fig.tight_layout() 72 | 73 | plt.savefig("figs/mnist_mlp_barrier_vs_epoch.png", dpi=300) 74 | plt.savefig("figs/mnist_mlp_barrier_vs_epoch.eps") 75 | plt.savefig("figs/mnist_mlp_barrier_vs_epoch.pdf") 76 | -------------------------------------------------------------------------------- /src/mnist_convnet_run.py: -------------------------------------------------------------------------------- 1 | """Train a convnet on MNIST on one random seed. Serialize the model for 2 | interpolation downstream.""" 3 | import argparse 4 | 5 | import jax.numpy as jnp 6 | import optax 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | from flax import linen as nn 10 | from flax.training.checkpoints import restore_checkpoint, save_checkpoint 11 | from flax.training.train_state import TrainState 12 | from jax import jit, random, value_and_grad 13 | from tqdm import tqdm 14 | 15 | import wandb 16 | from utils import RngPooper, ec2_get_instance_type, timeblock 17 | 18 | # See https://github.com/tensorflow/tensorflow/issues/53831. 19 | 20 | # See https://github.com/google/jax/issues/9454. 
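# Hiding every GPU from TensorFlow keeps the tf.data input pipeline on the CPU, so it does not
# reserve GPU memory that JAX needs for the model itself.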
21 | tf.config.set_visible_devices([], "GPU") 22 | 23 | activation = nn.relu 24 | 25 | class TestModel(nn.Module): 26 | 27 | @nn.compact 28 | def __call__(self, x): 29 | x = nn.Conv(features=8, kernel_size=(3, 3))(x) 30 | x = activation(x) 31 | x = nn.Conv(features=16, kernel_size=(3, 3))(x) 32 | x = activation(x) 33 | x = nn.Conv(features=32, kernel_size=(3, 3))(x) 34 | x = activation(x) 35 | 36 | x = jnp.mean(x, axis=-1) 37 | x = jnp.reshape(x, (x.shape[0], -1)) 38 | x = nn.Dense(32)(x) 39 | x = activation(x) 40 | x = nn.Dense(10)(x) 41 | x = nn.log_softmax(x) 42 | return x 43 | 44 | class ConvNetModel(nn.Module): 45 | 46 | @nn.compact 47 | def __call__(self, x): 48 | x = nn.Conv(features=128, kernel_size=(3, 3))(x) 49 | x = activation(x) 50 | x = nn.Conv(features=128, kernel_size=(3, 3))(x) 51 | x = activation(x) 52 | x = nn.Conv(features=128, kernel_size=(3, 3))(x) 53 | x = activation(x) 54 | # Take the mean along the channel dimension. Otherwise the following dense 55 | # layer is massive. 56 | x = jnp.mean(x, axis=-1) 57 | x = jnp.reshape(x, (x.shape[0], -1)) 58 | x = nn.Dense(1024)(x) 59 | x = activation(x) 60 | x = nn.Dense(10)(x) 61 | x = nn.log_softmax(x) 62 | return x 63 | 64 | def make_stuff(model): 65 | ret = lambda: None 66 | 67 | @jit 68 | def batch_loss(params, images, y_onehot): 69 | logits = model.apply({"params": params}, images) 70 | return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot)) 71 | 72 | @jit 73 | def batch_num_correct(params, images, y_onehot): 74 | logits = model.apply({"params": params}, images) 75 | return jnp.sum(jnp.argmax(logits, axis=-1) == jnp.argmax(y_onehot, axis=-1)) 76 | 77 | @jit 78 | def step(train_state, images, y_onehot): 79 | l, g = value_and_grad(batch_loss)(train_state.params, images, y_onehot) 80 | return train_state.apply_gradients(grads=g), l 81 | 82 | def dataset_loss(params, ds): 83 | # Note that we multiply by the batch size here, in order to get the sum of the 84 | # losses, then average over the whole dataset. 85 | return jnp.mean(jnp.array([x.shape[0] * batch_loss(params, x, y) for x, y in ds])) 86 | 87 | def dataset_total_correct(params, ds): 88 | return jnp.sum(jnp.array([batch_num_correct(params, x, y) for x, y in ds])) 89 | 90 | ret.batch_loss = batch_loss 91 | ret.batch_num_correct = batch_num_correct 92 | ret.step = step 93 | ret.dataset_loss = dataset_loss 94 | ret.dataset_total_correct = dataset_total_correct 95 | return ret 96 | 97 | def get_datasets(test_mode): 98 | """Return the training and test datasets, unbatched. 99 | 100 | test_mode: Whether or not we're running in "smoke test" mode. 101 | """ 102 | train_ds = tfds.load("mnist", split="train", as_supervised=True) 103 | test_ds = tfds.load("mnist", split="test", as_supervised=True) 104 | # Note: The take/cache warning: 105 | # 2022-01-25 07:32:58.144059: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 106 | # is not because we're actually doing this in the wrong order, but rather that 107 | # the dataset is loaded in and called .cache() on before we receive it. 
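# In smoke-test mode we keep only a handful of examples so that a full train/eval cycle
# finishes in seconds.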
108 | if test_mode: 109 | train_ds = train_ds.take(13) 110 | test_ds = test_ds.take(17) 111 | 112 | # Normalize 0-255 pixel values to 0.0-1.0 113 | normalize = lambda image, label: (tf.cast(image, tf.float32) / 255.0, tf.one_hot(label, depth=10)) 114 | train_ds = train_ds.map(normalize).cache() 115 | test_ds = test_ds.map(normalize).cache() 116 | return train_ds, test_ds 117 | 118 | def init_train_state(rng, learning_rate, model): 119 | tx = optax.adam(learning_rate) 120 | vars = model.init(rng, jnp.zeros((1, 28, 28, 1))) 121 | return TrainState.create(apply_fn=model.apply, params=vars["params"], tx=tx) 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument("--test", action="store_true", help="Run in smoke-test mode") 126 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 127 | parser.add_argument("--resume", type=str, help="wandb run to resume from (eg. 1kqqa9js)") 128 | parser.add_argument("--resume-epoch", 129 | type=int, 130 | help="The epoch to resume from. Required if --resume is set.") 131 | args = parser.parse_args() 132 | 133 | wandb.init(project="git-re-basin", 134 | entity="skainswo", 135 | tags=["mnist", "convnet"], 136 | resume="must" if args.resume is not None else None, 137 | id=args.resume, 138 | mode="disabled" if args.test else "online") 139 | 140 | # Note: hopefully it's ok that we repeat this even when resuming a run? 141 | config = wandb.config 142 | config.ec2_instance_type = ec2_get_instance_type() 143 | config.test = args.test 144 | config.seed = args.seed 145 | config.learning_rate = 0.001 146 | config.num_epochs = 10 if config.test else 50 147 | config.batch_size = 7 if config.test else 512 148 | 149 | rp = RngPooper(random.PRNGKey(config.seed)) 150 | 151 | model = TestModel() if config.test else ConvNetModel() 152 | stuff = make_stuff(model) 153 | 154 | train_ds, test_ds = get_datasets(test_mode=config.test) 155 | num_train_examples = train_ds.cardinality().numpy() 156 | num_test_examples = test_ds.cardinality().numpy() 157 | 158 | train_state = init_train_state(rp.poop(), config.learning_rate, model) 159 | start_epoch = 0 160 | 161 | if args.resume is not None: 162 | # Bring the desired resume epoch into the wandb run directory so that it 163 | # can then be picked up by `restore_checkpoint` below. 164 | wandb.restore(f"checkpoint_{args.resume_epoch}") 165 | last_epoch, train_state = restore_checkpoint(wandb.run.dir, (0, train_state)) 166 | # We need to increment last_epoch, because we store `(i, train_state)` 167 | # where `train_state` is the state _after_ i'th epoch. So we're actually 168 | # starting from the next epoch. 169 | start_epoch = last_epoch + 1 170 | 171 | for epoch in tqdm(range(start_epoch, config.num_epochs), 172 | initial=start_epoch, 173 | total=config.num_epochs): 174 | with timeblock(f"Epoch"): 175 | # Set the seed as a hash of the epoch and the overall random seed, so that 176 | # we ensure different seeds see different data orders, since tfds's random 177 | # seed is independent of our `RngPooper`.
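# (e.g. two runs launched with --seed 0 and --seed 1 shuffle differently, and each epoch
# within a single run is reshuffled as well).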
178 | for images, labels in tfds.as_numpy( 179 | train_ds.shuffle(num_train_examples, 180 | seed=hash(f"{config.seed}-{epoch}")).batch(config.batch_size)): 181 | train_state, batch_loss = stuff.step(train_state, images, labels) 182 | 183 | train_ds_batched = tfds.as_numpy(train_ds.batch(config.batch_size)) 184 | test_ds_batched = tfds.as_numpy(test_ds.batch(config.batch_size)) 185 | 186 | # Evaluate train/test loss/accuracy 187 | with timeblock("Model eval"): 188 | train_loss = stuff.dataset_loss(train_state.params, train_ds_batched) 189 | test_loss = stuff.dataset_loss(train_state.params, test_ds_batched) 190 | train_accuracy = stuff.dataset_total_correct(train_state.params, 191 | train_ds_batched) / num_train_examples 192 | test_accuracy = stuff.dataset_total_correct(train_state.params, 193 | test_ds_batched) / num_test_examples 194 | 195 | if not config.test: 196 | # See https://docs.wandb.ai/guides/track/advanced/save-restore 197 | save_checkpoint(wandb.run.dir, (epoch, train_state), epoch, keep_every_n_steps=10) 198 | 199 | wandb.log({ 200 | "epoch": epoch, 201 | "train_loss": train_loss, 202 | "test_loss": test_loss, 203 | "train_accuracy": train_accuracy, 204 | "test_accuracy": test_accuracy, 205 | }) 206 | -------------------------------------------------------------------------------- /src/mnist_mlp_interp_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import wandb 4 | 5 | import matplotlib_style as _ 6 | 7 | if __name__ == "__main__": 8 | api = wandb.Api() 9 | activation_matching_run = api.run("skainswo/git-re-basin/15tzxwm2") 10 | weight_matching_run = api.run("skainswo/git-re-basin/1i7s3fop") 11 | ste_matching_run = api.run("skainswo/git-re-basin/3tycauxs") 12 | 13 | ### Loss plot 14 | fig = plt.figure() 15 | ax = fig.add_subplot(111) 16 | lambdas = np.linspace(0, 1, 25) 17 | 18 | # Naive 19 | ax.plot(lambdas, 20 | np.array(activation_matching_run.summary["train_loss_interp_naive"]), 21 | color="grey", 22 | linewidth=2, 23 | label=f"Naïve") 24 | ax.plot(lambdas, 25 | np.array(activation_matching_run.summary["test_loss_interp_naive"]), 26 | color="grey", 27 | linewidth=2, 28 | linestyle="dashed") 29 | 30 | # Activation matching 31 | ax.plot(lambdas, 32 | np.array(activation_matching_run.summary["train_loss_interp_clever"]), 33 | color="tab:blue", 34 | marker="*", 35 | linewidth=2, 36 | label=f"Activation matching") 37 | ax.plot(lambdas, 38 | np.array(activation_matching_run.summary["test_loss_interp_clever"]), 39 | color="tab:blue", 40 | marker="*", 41 | linewidth=2, 42 | linestyle="dashed") 43 | 44 | # Weight matching 45 | ax.plot(lambdas, 46 | np.array(weight_matching_run.summary["train_loss_interp_clever"]), 47 | color="tab:green", 48 | marker="^", 49 | linewidth=2, 50 | label=f"Weight matching") 51 | ax.plot(lambdas, 52 | np.array(weight_matching_run.summary["test_loss_interp_clever"]), 53 | color="tab:green", 54 | marker="^", 55 | linestyle="dashed", 56 | linewidth=2) 57 | 58 | # STE matching 59 | ax.plot(lambdas, 60 | np.array(ste_matching_run.summary["train_loss_interp_clever"]), 61 | color="tab:red", 62 | marker="p", 63 | linewidth=2, 64 | label=f"STE matching") 65 | ax.plot(lambdas, 66 | np.array(ste_matching_run.summary["test_loss_interp_clever"]), 67 | color="tab:red", 68 | marker="p", 69 | linestyle="dashed", 70 | linewidth=2) 71 | 72 | ax.set_xlabel("$\lambda$") 73 | ax.set_xticks([0, 1]) 74 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 75 | 
ax.set_ylabel("Loss") 76 | ax.set_title(f"MNIST, MLP") 77 | ax.legend(loc="upper right", framealpha=0.5) 78 | fig.tight_layout() 79 | 80 | plt.savefig("figs/mnist_mlp_loss_interp.png", dpi=300) 81 | plt.savefig("figs/mnist_mlp_loss_interp.pdf") 82 | 83 | ### Accuracy plot 84 | fig = plt.figure() 85 | ax = fig.add_subplot(111) 86 | lambdas = np.linspace(0, 1, 25) 87 | 88 | # Naive 89 | ax.plot(lambdas, 90 | 100 * np.array(activation_matching_run.summary["train_acc_interp_naive"]), 91 | color="grey", 92 | linewidth=2, 93 | label="Train") 94 | ax.plot(lambdas, 95 | 100 * np.array(activation_matching_run.summary["test_acc_interp_naive"]), 96 | color="grey", 97 | linewidth=2, 98 | linestyle="dashed", 99 | label="Test") 100 | 101 | # Activation matching 102 | ax.plot(lambdas, 103 | 100 * np.array(activation_matching_run.summary["train_acc_interp_clever"]), 104 | color="tab:blue", 105 | marker="*", 106 | linewidth=2) 107 | ax.plot(lambdas, 108 | 100 * np.array(activation_matching_run.summary["test_acc_interp_clever"]), 109 | color="tab:blue", 110 | marker="*", 111 | linewidth=2, 112 | linestyle="dashed") 113 | 114 | # Weight matching 115 | ax.plot(lambdas, 116 | 100 * np.array(weight_matching_run.summary["train_acc_interp_clever"]), 117 | color="tab:green", 118 | marker="^", 119 | linewidth=2) 120 | ax.plot(lambdas, 121 | 100 * np.array(weight_matching_run.summary["test_acc_interp_clever"]), 122 | color="tab:green", 123 | marker="^", 124 | linestyle="dashed", 125 | linewidth=2) 126 | 127 | # STE matching 128 | ax.plot(lambdas, 129 | 100 * np.array(ste_matching_run.summary["train_acc_interp_clever"]), 130 | color="tab:red", 131 | marker="p", 132 | linewidth=2) 133 | ax.plot(lambdas, 134 | 100 * np.array(ste_matching_run.summary["test_acc_interp_clever"]), 135 | color="tab:red", 136 | marker="p", 137 | linestyle="dashed", 138 | linewidth=2) 139 | 140 | ax.set_xlabel("$\lambda$") 141 | ax.set_xticks([0, 1]) 142 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 143 | ax.set_ylabel("Accuracy") 144 | ax.set_title("MNIST, MLP") 145 | ax.legend(loc="lower right", framealpha=0.5) 146 | fig.tight_layout() 147 | 148 | plt.savefig("figs/mnist_mlp_accuracy_interp.png", dpi=300) 149 | plt.savefig("figs/mnist_mlp_accuracy_interp.pdf") 150 | -------------------------------------------------------------------------------- /src/mnist_mlp_loss_contour.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import jax.numpy as jnp 4 | import matplotlib 5 | import matplotlib.colors as colors 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import wandb 9 | from flax.serialization import from_bytes 10 | from jax import random 11 | from jax.flatten_util import ravel_pytree 12 | from matplotlib import tri 13 | from matplotlib.offsetbox import AnnotationBbox, OffsetImage 14 | from scipy.stats import qmc 15 | from tqdm import tqdm 16 | 17 | import matplotlib_style as _ 18 | from mnist_mlp_train import MLPModel, load_datasets, make_stuff 19 | from utils import (ec2_get_instance_type, flatten_params, lerp, timeblock, 20 | unflatten_params) 21 | from weight_matching import (apply_permutation, mlp_permutation_spec, 22 | weight_matching) 23 | 24 | matplotlib.rcParams["text.usetex"] = True 25 | 26 | with wandb.init( 27 | project="git-re-basin", 28 | entity="skainswo", 29 | tags=["mnist", "mlp", "weight-matching", "loss-contour"], 30 | job_type="analysis", 31 | ) as wandb_run: 32 | # api = wandb.Api() 33 | # seed0_run = 
api.run("skainswo/git-re-basin/1b1gztfx") 34 | # seed1_run = api.run("skainswo/git-re-basin/1hrmw7wr") 35 | 36 | config = wandb.config 37 | config.ec2_instance_type = ec2_get_instance_type() 38 | config.epoch = 99 39 | config.seed = 123 40 | config.num_eval_points = 2048 41 | 42 | model = MLPModel() 43 | stuff = make_stuff(model) 44 | 45 | def load_model(filepath): 46 | with open(filepath, "rb") as fh: 47 | return from_bytes( 48 | model.init(random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)))["params"], fh.read()) 49 | 50 | model_a = load_model( 51 | Path( 52 | wandb_run.use_artifact("mnist-mlp-weights:v15").get_path( 53 | f"checkpoint{config.epoch}").download())) 54 | model_b = load_model( 55 | Path( 56 | wandb_run.use_artifact("mnist-mlp-weights:v16").get_path( 57 | f"checkpoint{config.epoch}").download())) 58 | 59 | permutation_spec = mlp_permutation_spec(3) 60 | 61 | with timeblock("weight_matching"): 62 | permutation = weight_matching( 63 | random.PRNGKey(config.seed), 64 | permutation_spec, 65 | flatten_params(model_a), 66 | flatten_params(model_b), 67 | ) 68 | 69 | # Eval 70 | train_ds, test_ds = load_datasets() 71 | 72 | model_b_rebasin = unflatten_params( 73 | apply_permutation(permutation_spec, permutation, flatten_params(model_b))) 74 | 75 | # We use model_a as the origin 76 | 77 | model_a_flat, unflatten = ravel_pytree(model_a) 78 | model_b_flat, _ = ravel_pytree(model_b) 79 | model_b_rebasin_flat, _ = ravel_pytree(model_b_rebasin) 80 | 81 | # project the vector a onto the vector b 82 | proj = lambda a, b: jnp.dot(a, b) / jnp.dot(b, b) * b 83 | norm = lambda a: jnp.sqrt(jnp.dot(a, a)) 84 | normalize = lambda a: a / norm(a) 85 | 86 | basis1 = model_b_flat - model_a_flat 87 | scale = norm(basis1) 88 | basis1_normed = normalize(basis1) 89 | 90 | a_to_pi_b = model_b_rebasin_flat - model_a_flat 91 | basis2 = a_to_pi_b - proj(a_to_pi_b, basis1) 92 | basis2_normed = normalize(basis2) 93 | 94 | project2d = lambda theta: jnp.array( 95 | [jnp.dot(theta - model_a_flat, basis1_normed), 96 | jnp.dot(theta - model_a_flat, basis2_normed)]) / scale 97 | 98 | eval_points = qmc.scale( 99 | qmc.Sobol(d=2, scramble=True, seed=config.seed).random(config.num_eval_points), [-0.5, -0.5], 100 | [1.5, 1.5]) 101 | 102 | def eval_one(xy): 103 | params = unflatten(model_a_flat + scale * (basis1_normed * xy[0] + basis2_normed * xy[1])) 104 | return stuff["dataset_loss_and_accuracy"](params, test_ds, 10_000) 105 | 106 | eval_results = jnp.array(list(map(eval_one, tqdm(eval_points)))) 107 | 108 | # Create grid values first. 109 | xi = np.linspace(-0.5, 1.5) 110 | yi = np.linspace(-0.5, 1.5) 111 | 112 | # Linearly interpolate the data (x, y) on a grid defined by (xi, yi). 
113 | triang = tri.Triangulation(eval_points[:, 0], eval_points[:, 1]) 114 | # We need to cap the maximum loss value so that the contouring is not completely saturated by wildly large losses 115 | interpolator = tri.LinearTriInterpolator(triang, jnp.minimum(0.55, eval_results[:, 0])) 116 | # interpolator = tri.LinearTriInterpolator(triang, jnp.log(jnp.minimum(1.5, eval_results[:, 0]))) 117 | zi = interpolator(*np.meshgrid(xi, yi)) 118 | 119 | plt.figure() 120 | num_levels = 13 121 | plt.contour(xi, yi, zi, levels=num_levels, linewidths=0.25, colors="grey", alpha=0.5) 122 | # cmap_name = "RdGy" 123 | # cmap_name = "RdYlBu" 124 | # cmap_name = "Spectral" 125 | cmap_name = "coolwarm_r" 126 | 127 | # cmap_name = "YlOrBr_r" 128 | # cmap_name = "RdBu" 129 | 130 | # See https://stackoverflow.com/a/18926541/3880977 131 | def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100): 132 | return colors.LinearSegmentedColormap.from_list( 133 | 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval), 134 | cmap(np.linspace(minval, maxval, n))) 135 | 136 | cmap = truncate_colormap(plt.get_cmap(cmap_name), 0.0, 0.9) 137 | plt.contourf(xi, yi, zi, levels=num_levels, cmap=cmap, extend="both") 138 | 139 | x, y = project2d(model_a_flat) 140 | plt.scatter([x], [y], marker="x", color="white", zorder=10) 141 | 142 | x, y = project2d(model_b_flat) 143 | plt.scatter([x], [y], marker="x", color="white", zorder=10) 144 | 145 | x, y = project2d(model_b_rebasin_flat) 146 | plt.scatter([x], [y], marker="x", color="white", zorder=10) 147 | 148 | label_bboxes = dict(facecolor="tab:grey", boxstyle="round", edgecolor="none", alpha=0.5) 149 | plt.text(-0.075, 150 | -0.1, 151 | r"${\bf \Theta_A}$", 152 | color="white", 153 | fontsize=24, 154 | bbox=label_bboxes, 155 | horizontalalignment="right", 156 | verticalalignment="top") 157 | plt.text(1.075, 158 | -0.1, 159 | r"${\bf \Theta_B}$", 160 | color="white", 161 | fontsize=24, 162 | bbox=label_bboxes, 163 | horizontalalignment="left", 164 | verticalalignment="top") 165 | x, y = project2d(model_b_rebasin_flat) 166 | plt.text(x - 0.075, 167 | y + 0.1, 168 | r"${\bf \pi(\Theta_B)}$", 169 | color="white", 170 | fontsize=24, 171 | bbox=label_bboxes, 172 | horizontalalignment="right", 173 | verticalalignment="bottom") 174 | 175 | # https://github.com/matplotlib/matplotlib/issues/17284#issuecomment-772820638 176 | # Draw line only 177 | connectionstyle = "arc3,rad=-0.3" 178 | plt.annotate("", 179 | xy=(1, 0), 180 | xytext=(x, y), 181 | arrowprops=dict(arrowstyle="-", 182 | edgecolor="white", 183 | facecolor="none", 184 | linewidth=5, 185 | linestyle=(0, (5, 3)), 186 | shrinkA=20, 187 | shrinkB=15, 188 | connectionstyle=connectionstyle)) 189 | # Draw arrow head only 190 | plt.annotate("", 191 | xy=(1, 0), 192 | xytext=(x, y), 193 | arrowprops=dict(arrowstyle="<|-", 194 | edgecolor="none", 195 | facecolor="white", 196 | mutation_scale=40, 197 | linewidth=0, 198 | shrinkA=12.5, 199 | shrinkB=15, 200 | connectionstyle=connectionstyle)) 201 | 202 | plt.annotate("", 203 | xy=(0, 0), 204 | xytext=(x, y), 205 | arrowprops=dict(arrowstyle="-", 206 | edgecolor="white", 207 | alpha=0.5, 208 | facecolor="none", 209 | linewidth=2, 210 | linestyle="-", 211 | shrinkA=10, 212 | shrinkB=10)) 213 | plt.annotate("", 214 | xy=(0, 0), 215 | xytext=(1, 0), 216 | arrowprops=dict(arrowstyle="-", 217 | edgecolor="white", 218 | alpha=0.5, 219 | facecolor="none", 220 | linewidth=2, 221 | linestyle="-", 222 | shrinkA=10, 223 | shrinkB=10)) 224 | 225 | plt.gca().add_artist( 226 | 
AnnotationBbox(OffsetImage(plt.imread( 227 | "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/240/apple/325/check-mark-button_2705.png" 228 | ), 229 | zoom=0.1), (x / 2, y / 2), 230 | frameon=False)) 231 | plt.gca().add_artist( 232 | AnnotationBbox(OffsetImage(plt.imread( 233 | "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/240/apple/325/cross-mark_274c.png" 234 | ), 235 | zoom=0.1), (0.5, 0), 236 | frameon=False)) 237 | 238 | # "Git Re-Basin" box 239 | # box_x = 0.5 * (arrow_start[0] + arrow_stop[0]) 240 | # box_y = 0.5 * (arrow_start[1] + arrow_stop[1]) 241 | # box_x = 0.5 * (arrow_start[0] + arrow_stop[0]) + 0.325 242 | # box_y = 0.5 * (arrow_start[1] + arrow_stop[1]) + 0.2 243 | 244 | box_x = 0.5 245 | box_y = 1.3 246 | # git_rebasin_text = r"\textsc{Git Re-Basin}" 247 | git_rebasin_text = r"\textbf{Git Re-Basin}" 248 | # git_rebasin_text = r"\texttt{\textdollar{} git re-basin}" 249 | 250 | # Draw box only 251 | plt.text(box_x, 252 | box_y, 253 | git_rebasin_text, 254 | color=(0.0, 0.0, 0.0, 0.0), 255 | fontsize=24, 256 | horizontalalignment="center", 257 | verticalalignment="center", 258 | bbox=dict(boxstyle="round", fc=(1, 1, 1, 1), ec="black", pad=0.4)) 259 | # Draw text only 260 | plt.text(box_x, 261 | box_y - 0.0115, 262 | git_rebasin_text, 263 | color=(0.0, 0.0, 0.0, 1.0), 264 | fontsize=24, 265 | horizontalalignment="center", 266 | verticalalignment="center") 267 | 268 | # plt.colorbar() 269 | plt.xlim(-0.4, 1.4) 270 | plt.ylim(-0.45, 1.3) 271 | # plt.xlim(-0.9, 1.9) 272 | # plt.ylim(-0.9, 1.9) 273 | plt.xticks([]) 274 | plt.yticks([]) 275 | plt.tight_layout() 276 | plt.savefig("figs/mnist_mlp_loss_contour.png", dpi=300) 277 | plt.savefig("figs/mnist_mlp_loss_contour.pdf") 278 | -------------------------------------------------------------------------------- /src/mnist_mlp_train.py: -------------------------------------------------------------------------------- 1 | """Train an MLP on MNIST on one random seed. Serialize the model for 2 | interpolation downstream.""" 3 | import argparse 4 | 5 | import augmax 6 | import flax 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | import optax 11 | import tensorflow as tf 12 | import tensorflow_datasets as tfds 13 | import wandb 14 | from flax import linen as nn 15 | from flax.training.train_state import TrainState 16 | from jax import jit, random, tree_map, value_and_grad, vmap 17 | from tqdm import tqdm 18 | 19 | from utils import ec2_get_instance_type, flatten_params, rngmix, timeblock 20 | 21 | # See https://github.com/tensorflow/tensorflow/issues/53831. 22 | 23 | # See https://github.com/google/jax/issues/9454. 
24 | tf.config.set_visible_devices([], "GPU") 25 | 26 | activation = nn.relu 27 | 28 | class MLPModel(nn.Module): 29 | 30 | @nn.compact 31 | def __call__(self, x): 32 | x = jnp.reshape(x, (-1, 28 * 28)) 33 | x = nn.Dense(512)(x) 34 | x = activation(x) 35 | x = nn.Dense(512)(x) 36 | x = activation(x) 37 | x = nn.Dense(512)(x) 38 | x = activation(x) 39 | x = nn.Dense(10)(x) 40 | x = nn.log_softmax(x) 41 | return x 42 | 43 | def make_stuff(model): 44 | normalize_transform = augmax.ByteToFloat() 45 | 46 | @jit 47 | def batch_eval(params, images_u8, labels): 48 | images_f32 = vmap(normalize_transform)(None, images_u8) 49 | logits = model.apply({"params": params}, images_f32) 50 | y_onehot = jax.nn.one_hot(labels, 10) 51 | loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot)) 52 | num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == jnp.argmax(y_onehot, axis=-1)) 53 | return loss, {"num_correct": num_correct} 54 | 55 | @jit 56 | def step(train_state, images_f32, labels): 57 | (l, info), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_f32, labels) 58 | return train_state.apply_gradients(grads=g), {"batch_loss": l, **info} 59 | 60 | def dataset_loss_and_accuracy(params, dataset, batch_size: int): 61 | num_examples = dataset["images_u8"].shape[0] 62 | assert num_examples % batch_size == 0 63 | num_batches = num_examples // batch_size 64 | batch_ix = jnp.arange(num_examples).reshape((num_batches, batch_size)) 65 | # Can't use vmap or run in a single batch since that overloads GPU memory. 66 | losses, infos = zip(*[ 67 | batch_eval( 68 | params, 69 | dataset["images_u8"][batch_ix[i, :], :, :, :], 70 | dataset["labels"][batch_ix[i, :]], 71 | ) for i in range(num_batches) 72 | ]) 73 | return ( 74 | jnp.sum(batch_size * jnp.array(losses)) / num_examples, 75 | sum(x["num_correct"] for x in infos) / num_examples, 76 | ) 77 | 78 | return { 79 | "normalize_transform": normalize_transform, 80 | "batch_eval": batch_eval, 81 | "step": step, 82 | "dataset_loss_and_accuracy": dataset_loss_and_accuracy 83 | } 84 | 85 | def load_datasets(): 86 | """Return the training and test datasets, unbatched.""" 87 | # See https://www.tensorflow.org/datasets/overview#as_batched_tftensor_batch_size-1. 
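# batch_size=-1 materializes each split as a single NumPy batch in memory, which is fine for
# MNIST-sized data and lets us index examples directly.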
88 | train_ds_images_u8, train_ds_labels = tfds.as_numpy( 89 | tfds.load("mnist", split="train", batch_size=-1, as_supervised=True)) 90 | test_ds_images_u8, test_ds_labels = tfds.as_numpy( 91 | tfds.load("mnist", split="test", batch_size=-1, as_supervised=True)) 92 | train_ds = {"images_u8": train_ds_images_u8, "labels": train_ds_labels} 93 | test_ds = {"images_u8": test_ds_images_u8, "labels": test_ds_labels} 94 | return train_ds, test_ds 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--test", action="store_true", help="Run in smoke-test mode") 99 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 100 | parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True) 101 | parser.add_argument("--learning-rate", type=float, required=True) 102 | args = parser.parse_args() 103 | 104 | with wandb.init( 105 | project="git-re-basin", 106 | entity="skainswo", 107 | tags=["mnist", "mlp", "training"], 108 | mode="disabled" if args.test else "online", 109 | job_type="train", 110 | ) as wandb_run: 111 | artifact = wandb.Artifact("mnist-mlp-weights", type="model-weights") 112 | 113 | config = wandb.config 114 | config.ec2_instance_type = ec2_get_instance_type() 115 | config.test = args.test 116 | config.seed = args.seed 117 | config.optimizer = args.optimizer 118 | config.learning_rate = args.learning_rate 119 | config.num_epochs = 100 120 | config.batch_size = 500 121 | 122 | rng = random.PRNGKey(config.seed) 123 | 124 | model = MLPModel() 125 | stuff = make_stuff(model) 126 | 127 | with timeblock("load_datasets"): 128 | train_ds, test_ds = load_datasets() 129 | print("train_ds labels hash", hash(np.array(train_ds["labels"]).tobytes())) 130 | print("test_ds labels hash", hash(np.array(test_ds["labels"]).tobytes())) 131 | 132 | num_train_examples = train_ds["images_u8"].shape[0] 133 | num_test_examples = test_ds["images_u8"].shape[0] 134 | assert num_train_examples % config.batch_size == 0 135 | print("num_train_examples", num_train_examples) 136 | print("num_test_examples", num_test_examples) 137 | 138 | if config.optimizer == "sgd": 139 | # See runs: 140 | # * https://wandb.ai/skainswo/git-re-basin/runs/3blb4uhm 141 | # * https://wandb.ai/skainswo/git-re-basin/runs/174j7umt 142 | # * https://wandb.ai/skainswo/git-re-basin/runs/td02y8gg 143 | lr_schedule = optax.warmup_cosine_decay_schedule( 144 | init_value=1e-6, 145 | peak_value=config.learning_rate, 146 | warmup_steps=10, 147 | # Confusingly, `decay_steps` is actually the total number of steps, 148 | # including the warmup. 
149 | decay_steps=config.num_epochs * (num_train_examples // config.batch_size), 150 | ) 151 | tx = optax.sgd(lr_schedule, momentum=0.9) 152 | elif config.optimizer == "adam": 153 | # See runs: 154 | # - https://wandb.ai/skainswo/git-re-basin/runs/1b1gztfx (trim-fire-575) 155 | # - https://wandb.ai/skainswo/git-re-basin/runs/1hrmw7wr (wild-dream-576) 156 | tx = optax.adam(config.learning_rate) 157 | else: 158 | # See runs: 159 | # - https://wandb.ai/skainswo/git-re-basin/runs/k4luj7er (faithful-spaceship-579) 160 | # - https://wandb.ai/skainswo/git-re-basin/runs/3ru7xy8c (sage-forest-580) 161 | tx = optax.adamw(config.learning_rate, weight_decay=1e-4) 162 | 163 | train_state = TrainState.create( 164 | apply_fn=model.apply, 165 | params=model.init(rngmix(rng, "init"), jnp.zeros((1, 28, 28, 1)))["params"], 166 | tx=tx, 167 | ) 168 | 169 | for epoch in tqdm(range(config.num_epochs)): 170 | infos = [] 171 | with timeblock(f"Epoch"): 172 | batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape( 173 | (-1, config.batch_size)) 174 | for i in range(batch_ix.shape[0]): 175 | p = batch_ix[i, :] 176 | images_u8 = train_ds["images_u8"][p, :, :, :] 177 | labels = train_ds["labels"][p] 178 | train_state, info = stuff["step"](train_state, images_u8, labels) 179 | infos.append(info) 180 | 181 | train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples 182 | train_accuracy = sum(x["num_correct"] for x in infos) / num_train_examples 183 | 184 | # Evaluate train/test loss/accuracy 185 | with timeblock("Test set eval"): 186 | test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, test_ds, 187 | 10_000) 188 | 189 | params_l2 = tree_map(lambda x: jnp.sqrt(jnp.sum(x**2)), 190 | flatten_params({"params_l2": train_state.params})) 191 | 192 | # See https://github.com/wandb/client/issues/3690. 193 | wandb_run.log({ 194 | "epoch": epoch, 195 | "train_loss": train_loss, 196 | "test_loss": test_loss, 197 | "train_accuracy": train_accuracy, 198 | "test_accuracy": test_accuracy, 199 | **params_l2 200 | }) 201 | 202 | # With layer width 512, the MLP is 3.7MB per checkpoint. 203 | with timeblock("model serialization"): 204 | with artifact.new_file(f"checkpoint{epoch}", mode="wb") as f: 205 | f.write(flax.serialization.to_bytes(train_state.params)) 206 | 207 | # This will be a no-op when config.test is enabled anyhow, since wandb will 208 | # be initialized with mode="disabled". 
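# Every per-epoch checkpoint file was added to `artifact` above, so one log_artifact call at
# the end uploads them together as a single versioned artifact.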
209 | wandb_run.log_artifact(artifact) 210 | 211 | if __name__ == "__main__": 212 | main() 213 | -------------------------------------------------------------------------------- /src/mnist_mlp_weight_matching.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from pathlib import Path 4 | 5 | import jax.numpy as jnp 6 | import matplotlib.pyplot as plt 7 | import wandb 8 | from flax.serialization import from_bytes 9 | from jax import random 10 | from tqdm import tqdm 11 | 12 | from mnist_mlp_train import MLPModel, load_datasets, make_stuff 13 | from utils import ec2_get_instance_type, flatten_params, lerp, unflatten_params 14 | from weight_matching import (apply_permutation, mlp_permutation_spec, weight_matching) 15 | 16 | def plot_interp_loss(epoch, lambdas, train_loss_interp_naive, test_loss_interp_naive, 17 | train_loss_interp_clever, test_loss_interp_clever): 18 | fig = plt.figure() 19 | ax = fig.add_subplot(111) 20 | ax.plot(lambdas, 21 | train_loss_interp_naive, 22 | linestyle="dashed", 23 | color="tab:blue", 24 | alpha=0.5, 25 | linewidth=2, 26 | label="Train, naïve interp.") 27 | ax.plot(lambdas, 28 | test_loss_interp_naive, 29 | linestyle="dashed", 30 | color="tab:orange", 31 | alpha=0.5, 32 | linewidth=2, 33 | label="Test, naïve interp.") 34 | ax.plot(lambdas, 35 | train_loss_interp_clever, 36 | linestyle="solid", 37 | color="tab:blue", 38 | linewidth=2, 39 | label="Train, permuted interp.") 40 | ax.plot(lambdas, 41 | test_loss_interp_clever, 42 | linestyle="solid", 43 | color="tab:orange", 44 | linewidth=2, 45 | label="Test, permuted interp.") 46 | ax.set_xlabel("$\lambda$") 47 | ax.set_xticks([0, 1]) 48 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 49 | ax.set_ylabel("Loss") 50 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 51 | ax.set_title(f"Loss landscape between the two models (epoch {epoch})") 52 | ax.legend(loc="upper right", framealpha=0.5) 53 | fig.tight_layout() 54 | return fig 55 | 56 | def plot_interp_acc(epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 57 | train_acc_interp_clever, test_acc_interp_clever): 58 | fig = plt.figure() 59 | ax = fig.add_subplot(111) 60 | ax.plot(lambdas, 61 | train_acc_interp_naive, 62 | linestyle="dashed", 63 | color="tab:blue", 64 | alpha=0.5, 65 | linewidth=2, 66 | label="Train, naïve interp.") 67 | ax.plot(lambdas, 68 | test_acc_interp_naive, 69 | linestyle="dashed", 70 | color="tab:orange", 71 | alpha=0.5, 72 | linewidth=2, 73 | label="Test, naïve interp.") 74 | ax.plot(lambdas, 75 | train_acc_interp_clever, 76 | linestyle="solid", 77 | color="tab:blue", 78 | linewidth=2, 79 | label="Train, permuted interp.") 80 | ax.plot(lambdas, 81 | test_acc_interp_clever, 82 | linestyle="solid", 83 | color="tab:orange", 84 | linewidth=2, 85 | label="Test, permuted interp.") 86 | ax.set_xlabel("$\lambda$") 87 | ax.set_xticks([0, 1]) 88 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 89 | ax.set_ylabel("Accuracy") 90 | # TODO label x=0 tick as \theta_1, and x=1 tick as \theta_2 91 | ax.set_title(f"Accuracy between the two models (epoch {epoch})") 92 | ax.legend(loc="lower right", framealpha=0.5) 93 | fig.tight_layout() 94 | return fig 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--model-a", type=str, required=True) 99 | parser.add_argument("--model-b", type=str, required=True) 100 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 101 | args = parser.parse_args() 102 | 103 | with 
wandb.init( 104 | project="git-re-basin", 105 | entity="skainswo", 106 | tags=["mnist", "mlp", "weight-matching"], 107 | job_type="analysis", 108 | ) as wandb_run: 109 | config = wandb.config 110 | config.ec2_instance_type = ec2_get_instance_type() 111 | config.model_a = args.model_a 112 | config.model_b = args.model_b 113 | config.seed = args.seed 114 | config.load_epoch = 99 115 | 116 | model = MLPModel() 117 | stuff = make_stuff(model) 118 | 119 | def load_model(filepath): 120 | with open(filepath, "rb") as fh: 121 | return from_bytes( 122 | model.init(random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)))["params"], fh.read()) 123 | 124 | filename = f"checkpoint{config.load_epoch}" 125 | model_a = load_model( 126 | Path( 127 | wandb_run.use_artifact(f"mnist-mlp-weights:{config.model_a}").get_path( 128 | filename).download())) 129 | model_b = load_model( 130 | Path( 131 | wandb_run.use_artifact(f"mnist-mlp-weights:{config.model_b}").get_path( 132 | filename).download())) 133 | 134 | train_ds, test_ds = load_datasets() 135 | 136 | permutation_spec = mlp_permutation_spec(3) 137 | final_permutation = weight_matching(random.PRNGKey(config.seed), permutation_spec, 138 | flatten_params(model_a), flatten_params(model_b)) 139 | 140 | # Save final_permutation as an Artifact 141 | artifact = wandb.Artifact("mnist_mlp_weight_matching", 142 | type="permutation", 143 | metadata={ 144 | "dataset": "mnist", 145 | "model": "mlp", 146 | "analysis": "weight-matching" 147 | }) 148 | with artifact.new_file("permutation.pkl", mode="wb") as f: 149 | pickle.dump(final_permutation, f) 150 | wandb_run.log_artifact(artifact) 151 | 152 | lambdas = jnp.linspace(0, 1, num=25) 153 | train_loss_interp_naive = [] 154 | test_loss_interp_naive = [] 155 | train_acc_interp_naive = [] 156 | test_acc_interp_naive = [] 157 | for lam in tqdm(lambdas): 158 | naive_p = lerp(lam, model_a, model_b) 159 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](naive_p, train_ds, 10_000) 160 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](naive_p, test_ds, 10_000) 161 | train_loss_interp_naive.append(train_loss) 162 | test_loss_interp_naive.append(test_loss) 163 | train_acc_interp_naive.append(train_acc) 164 | test_acc_interp_naive.append(test_acc) 165 | 166 | model_b_clever = unflatten_params( 167 | apply_permutation(permutation_spec, final_permutation, flatten_params(model_b))) 168 | 169 | train_loss_interp_clever = [] 170 | test_loss_interp_clever = [] 171 | train_acc_interp_clever = [] 172 | test_acc_interp_clever = [] 173 | for lam in tqdm(lambdas): 174 | clever_p = lerp(lam, model_a, model_b_clever) 175 | train_loss, train_acc = stuff["dataset_loss_and_accuracy"](clever_p, train_ds, 10_000) 176 | test_loss, test_acc = stuff["dataset_loss_and_accuracy"](clever_p, test_ds, 10_000) 177 | train_loss_interp_clever.append(train_loss) 178 | test_loss_interp_clever.append(test_loss) 179 | train_acc_interp_clever.append(train_acc) 180 | test_acc_interp_clever.append(test_acc) 181 | 182 | assert len(lambdas) == len(train_loss_interp_naive) 183 | assert len(lambdas) == len(test_loss_interp_naive) 184 | assert len(lambdas) == len(train_acc_interp_naive) 185 | assert len(lambdas) == len(test_acc_interp_naive) 186 | assert len(lambdas) == len(train_loss_interp_clever) 187 | assert len(lambdas) == len(test_loss_interp_clever) 188 | assert len(lambdas) == len(train_acc_interp_clever) 189 | assert len(lambdas) == len(test_acc_interp_clever) 190 | 191 | print("Plotting...") 192 | fig = plot_interp_loss(config.load_epoch, lambdas, 
train_loss_interp_naive, 193 | test_loss_interp_naive, train_loss_interp_clever, 194 | test_loss_interp_clever) 195 | plt.savefig(f"mnist_mlp_weight_matching_interp_loss_epoch{config.load_epoch}.png", dpi=300) 196 | wandb.log({"interp_loss_fig": wandb.Image(fig)}) 197 | plt.close(fig) 198 | 199 | fig = plot_interp_acc(config.load_epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 200 | train_acc_interp_clever, test_acc_interp_clever) 201 | plt.savefig(f"mnist_mlp_weight_matching_interp_accuracy_epoch{config.load_epoch}.png", dpi=300) 202 | wandb.log({"interp_acc_fig": wandb.Image(fig)}) 203 | plt.close(fig) 204 | 205 | wandb.log({ 206 | "train_loss_interp_naive": train_loss_interp_naive, 207 | "test_loss_interp_naive": test_loss_interp_naive, 208 | "train_acc_interp_naive": train_acc_interp_naive, 209 | "test_acc_interp_naive": test_acc_interp_naive, 210 | "train_loss_interp_clever": train_loss_interp_clever, 211 | "test_loss_interp_clever": test_loss_interp_clever, 212 | "train_acc_interp_clever": train_acc_interp_clever, 213 | "test_acc_interp_clever": test_acc_interp_clever, 214 | }) 215 | 216 | print({ 217 | "train_loss_interp_naive": train_loss_interp_naive, 218 | "test_loss_interp_naive": test_loss_interp_naive, 219 | "train_acc_interp_naive": train_acc_interp_naive, 220 | "test_acc_interp_naive": test_acc_interp_naive, 221 | "train_loss_interp_clever": train_loss_interp_clever, 222 | "test_loss_interp_clever": test_loss_interp_clever, 223 | "train_acc_interp_clever": train_acc_interp_clever, 224 | "test_acc_interp_clever": test_acc_interp_clever, 225 | }) 226 | 227 | if __name__ == "__main__": 228 | main() 229 | -------------------------------------------------------------------------------- /src/mnist_vgg16_run.py: -------------------------------------------------------------------------------- 1 | """Train VGG16 on MNIST.""" 2 | import argparse 3 | 4 | import augmax 5 | import flax 6 | import jax.nn 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | import tensorflow as tf 11 | from flax import linen as nn 12 | from flax.training.train_state import TrainState 13 | from jax import jit, random, value_and_grad, vmap 14 | from tqdm import tqdm 15 | 16 | import wandb 17 | from mnist_mlp_run import load_datasets 18 | from utils import ec2_get_instance_type, rngmix, timeblock 19 | 20 | # See https://github.com/tensorflow/tensorflow/issues/53831. 21 | 22 | # See https://github.com/google/jax/issues/9454. 23 | tf.config.set_visible_devices([], "GPU") 24 | 25 | def make_vgg(backbone_layers, classifier_width: int, norm): 26 | 27 | class VGG(nn.Module): 28 | 29 | @nn.compact 30 | def __call__(self, x): 31 | for l in backbone_layers: 32 | if isinstance(l, int): 33 | x = nn.Conv(features=l, kernel_size=(3, 3))(x) 34 | x = norm()(x) 35 | x = nn.relu(x) 36 | elif l == "m": 37 | x = nn.max_pool(x, (2, 2), strides=(2, 2)) 38 | else: 39 | raise 40 | 41 | # Classifier 42 | # Note: everyone seems to do a different thing here. 43 | # * https://github.com/davisyoshida/vgg16-haiku/blob/4ef0bd001bf9daa4cfb2fa83ea3956ec01add3a8/vgg/vgg.py#L56 44 | # does average pooling with a kernel size of (7, 7) 45 | # * https://github.com/kuangliu/pytorch-cifar/blob/49b7aa97b0c12fe0d4054e670403a16b6b834ddd/models/vgg.py#L37 46 | # does average pooling with a kernel size of (1, 1) which doesn't seem 47 | # to accomplish anything. See https://github.com/kuangliu/pytorch-cifar/issues/110. 
48 | # But this paper also doesn't really do the dense layers the same as in 49 | # the paper either... 50 | # * The paper itself doesn't mention any kind of pooling... 51 | # 52 | # I'll stick to replicating the paper as closely as possible for now. 53 | (_b, w, h, _c) = x.shape 54 | assert w == h == 1 55 | x = jnp.reshape(x, (x.shape[0], -1)) 56 | x = nn.Dense(classifier_width)(x) 57 | x = nn.relu(x) 58 | x = nn.Dense(classifier_width)(x) 59 | x = nn.relu(x) 60 | x = nn.Dense(10)(x) 61 | x = nn.log_softmax(x) 62 | return x 63 | 64 | return VGG 65 | 66 | TestVGG = make_vgg( 67 | [64, 64, "m", 64, 64, "m", 64, 64, 64, "m", 64, 64, 64, "m", 64, 64, 64, "m"], 68 | classifier_width=8, 69 | # norm=lambda: lambda x: x, 70 | norm=nn.LayerNorm) 71 | 72 | VGG16 = make_vgg( 73 | [64, 64, "m", 128, 128, "m", 256, 256, 256, "m", 512, 512, 512, "m", 512, 512, 512, "m"], 74 | classifier_width=4096, 75 | norm=nn.LayerNorm) 76 | 77 | VGG16ThinClassifier = make_vgg( 78 | [64, 64, "m", 128, 128, "m", 256, 256, 256, "m", 512, 512, 512, "m", 512, 512, 512, "m"], 79 | classifier_width=512, 80 | norm=nn.LayerNorm) 81 | 82 | def make_vgg_width_ablation(width_multiplier: int): 83 | m = width_multiplier 84 | return make_vgg([ 85 | m * 1, m * 1, "m", m * 2, m * 2, "m", m * 4, m * 4, m * 4, "m", m * 8, m * 8, m * 8, "m", 86 | m * 8, m * 8, m * 8, "m" 87 | ], 88 | classifier_width=m * 8, 89 | norm=nn.LayerNorm)() 90 | 91 | # 378.2MB 92 | VGG16Wide = make_vgg( 93 | [512, 512, "m", 512, 512, "m", 512, 512, 512, "m", 512, 512, 512, "m", 512, 512, 512, "m"], 94 | classifier_width=4096, 95 | norm=nn.LayerNorm) 96 | 97 | VGG19 = make_vgg([ 98 | 64, 64, "m", 128, 128, "m", 256, 256, 256, 256, "m", 512, 512, 512, 512, "m", 512, 512, 512, 99 | 512, "m" 100 | ], 101 | classifier_width=4096, 102 | norm=nn.LayerNorm) 103 | 104 | def make_stuff(model): 105 | # Applied to all input images, test and train. 106 | normalize_transform = augmax.Chain(augmax.Resize(32, 32), augmax.ByteToFloat()) 107 | 108 | @jit 109 | def batch_eval(params, images_u8, labels): 110 | images_f32 = vmap(normalize_transform)(None, images_u8) 111 | y_onehot = jax.nn.one_hot(labels, 10) 112 | logits = model.apply({"params": params}, images_f32) 113 | l = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot)) 114 | num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels) 115 | return l, {"num_correct": num_correct} 116 | 117 | @jit 118 | def step(rng, train_state, images_u8, labels): 119 | # images_transformed = vmap(train_transform)(random.split(rng, images_u8.shape[0]), images_u8) 120 | (l, info), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_u8, labels) 121 | return train_state.apply_gradients(grads=g), {"batch_loss": l, **info} 122 | 123 | def dataset_loss_and_accuracy(params, dataset, batch_size: int): 124 | num_examples = dataset["images_u8"].shape[0] 125 | assert num_examples % batch_size == 0 126 | num_batches = num_examples // batch_size 127 | batch_ix = jnp.arange(num_examples).reshape((num_batches, batch_size)) 128 | # Can't use vmap or run in a single batch since that overloads GPU memory. 
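# Each batch_eval call returns the *mean* loss over its batch, so we scale by batch_size and
# divide by num_examples to recover the dataset-wide mean:
#   dataset_loss = sum_i(batch_size * batch_mean_loss_i) / num_examples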
129 | losses, infos = zip(*[ 130 | batch_eval( 131 | params, 132 | dataset["images_u8"][batch_ix[i, :], :, :, :], 133 | dataset["labels"][batch_ix[i, :]], 134 | ) for i in range(num_batches) 135 | ]) 136 | return ( 137 | jnp.sum(batch_size * jnp.array(losses)) / num_examples, 138 | sum(x["num_correct"] for x in infos) / num_examples, 139 | ) 140 | 141 | return { 142 | "normalize_transform": normalize_transform, 143 | "batch_eval": batch_eval, 144 | "step": step, 145 | "dataset_loss_and_accuracy": dataset_loss_and_accuracy, 146 | } 147 | 148 | def init_train_state(rng, model, learning_rate, num_epochs, batch_size, num_train_examples): 149 | # See https://github.com/kuangliu/pytorch-cifar. 150 | warmup_epochs = 1 151 | steps_per_epoch = num_train_examples // batch_size 152 | lr_schedule = optax.warmup_cosine_decay_schedule( 153 | init_value=1e-6, 154 | peak_value=learning_rate, 155 | warmup_steps=warmup_epochs * steps_per_epoch, 156 | # Confusingly, `decay_steps` is actually the total number of steps, 157 | # including the warmup. 158 | decay_steps=num_epochs * steps_per_epoch, 159 | ) 160 | tx = optax.chain(optax.add_decayed_weights(5e-4), optax.sgd(lr_schedule, momentum=0.9)) 161 | # tx = optax.adamw(learning_rate=lr_schedule, weight_decay=5e-4) 162 | vars = model.init(rng, jnp.zeros((1, 32, 32, 1))) 163 | return TrainState.create(apply_fn=model.apply, params=vars["params"], tx=tx) 164 | 165 | if __name__ == "__main__": 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--test", action="store_true", help="Run in smoke-test mode") 168 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 169 | parser.add_argument("--width-multiplier", type=int, default=64) 170 | args = parser.parse_args() 171 | 172 | with wandb.init( 173 | project="git-re-basin", 174 | entity="skainswo", 175 | tags=["mnist", "vgg16"], 176 | mode="disabled" if args.test else "online", 177 | job_type="train", 178 | ) as wandb_run: 179 | artifact = wandb.Artifact("mnist-vgg16-weights", type="model-weights") 180 | 181 | config = wandb.config 182 | config.ec2_instance_type = ec2_get_instance_type() 183 | config.test = args.test 184 | config.seed = args.seed 185 | config.learning_rate = 0.001 186 | config.num_epochs = 25 187 | config.width_multiplier = args.width_multiplier 188 | config.batch_size = 100 189 | 190 | rng = random.PRNGKey(config.seed) 191 | 192 | # model = TestVGG() if config.test else VGG16ThinClassifier() 193 | model = make_vgg_width_ablation(config.width_multiplier) 194 | with timeblock("load datasets"): 195 | train_ds, test_ds = load_datasets() 196 | print("train_ds labels hash", hash(np.array(train_ds["labels"]).tobytes())) 197 | print("test_ds labels hash", hash(np.array(test_ds["labels"]).tobytes())) 198 | 199 | num_train_examples = train_ds["images_u8"].shape[0] 200 | num_test_examples = test_ds["images_u8"].shape[0] 201 | assert num_train_examples % config.batch_size == 0 202 | print("num_train_examples", num_train_examples) 203 | print("num_test_examples", num_test_examples) 204 | 205 | stuff = make_stuff(model) 206 | train_state = init_train_state(rngmix(rng, "init"), 207 | model=model, 208 | learning_rate=config.learning_rate, 209 | num_epochs=config.num_epochs, 210 | batch_size=config.batch_size, 211 | num_train_examples=train_ds["images_u8"].shape[0]) 212 | 213 | for epoch in tqdm(range(config.num_epochs)): 214 | infos = [] 215 | with timeblock(f"Epoch"): 216 | batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape( 217 | (-1, 
config.batch_size)) 218 | batch_rngs = random.split(rngmix(rng, f"batch_rngs-{epoch}"), batch_ix.shape[0]) 219 | for i in range(batch_ix.shape[0]): 220 | p = batch_ix[i, :] 221 | images_u8 = train_ds["images_u8"][p, :, :, :] 222 | labels = train_ds["labels"][p] 223 | train_state, info = stuff["step"](batch_rngs[i], train_state, images_u8, labels) 224 | infos.append(info) 225 | 226 | train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples 227 | train_accuracy = sum(x["num_correct"] for x in infos) / num_train_examples 228 | 229 | # Evaluate test loss/accuracy 230 | with timeblock("Test set eval"): 231 | test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, test_ds, 232 | 1000) 233 | 234 | # See https://github.com/wandb/client/issues/3690. 235 | wandb_run.log({ 236 | "epoch": epoch, 237 | "train_loss": train_loss, 238 | "test_loss": test_loss, 239 | "train_accuracy": train_accuracy, 240 | "test_accuracy": test_accuracy, 241 | }) 242 | 243 | # No point saving the model at all if we're running in test mode. 244 | if (not config.test) and (epoch % 10 == 0 or epoch == config.num_epochs - 1): 245 | with timeblock("model serialization"): 246 | with artifact.new_file(f"checkpoint{epoch}", mode="wb") as f: 247 | f.write(flax.serialization.to_bytes(train_state)) 248 | 249 | # This will be a no-op when config.test is enabled anyhow, since wandb will 250 | # be initialized with mode="disabled". 251 | wandb_run.log_artifact(artifact) 252 | -------------------------------------------------------------------------------- /src/mnist_video.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import wandb 8 | from matplotlib.ticker import FormatStrFormatter 9 | from tqdm import tqdm 10 | 11 | import matplotlib_style as _ 12 | 13 | api = wandb.Api() 14 | # run = api.run("skainswo/git-re-basin/begnvj15") 15 | artifact = Path( 16 | api.artifact("skainswo/git-re-basin/mnist_permutation_eval_vs_epoch:v1").download()) 17 | 18 | with open(artifact / "permutation_eval_vs_epoch.pkl", "rb") as f: 19 | interp_eval_vs_epoch = pickle.load(f) 20 | 21 | lambdas = np.linspace(0, 1, 25) 22 | 23 | grey_color = "black" 24 | highlight_color = "tab:orange" 25 | 26 | # for epoch in tqdm(range(10)): 27 | for epoch in tqdm(range(len(interp_eval_vs_epoch))): 28 | fig = plt.figure(figsize=(12, 6)) 29 | ax1 = fig.add_subplot(1, 2, 1) 30 | 31 | # Naive 32 | ax1.plot( 33 | lambdas, 34 | np.array(interp_eval_vs_epoch[epoch]["naive_train_loss_interp"]), 35 | color=grey_color, 36 | linewidth=2, 37 | # label="Naïve", 38 | label="Before", 39 | ) 40 | ax1.plot(lambdas, 41 | np.array(interp_eval_vs_epoch[epoch]["naive_test_loss_interp"]), 42 | color=grey_color, 43 | linewidth=2, 44 | linestyle="dashed") 45 | 46 | ax1.plot( 47 | lambdas, 48 | np.array(interp_eval_vs_epoch[epoch]["clever_train_loss_interp"]), 49 | color=highlight_color, 50 | # marker="^", 51 | linewidth=5, 52 | label="After (ours)") 53 | ax1.plot( 54 | lambdas, 55 | np.array(interp_eval_vs_epoch[epoch]["clever_test_loss_interp"]), 56 | color=highlight_color, 57 | # marker="^", 58 | linestyle="dashed", 59 | linewidth=5) 60 | 61 | # ax1.set_ylim(-0.05, 1.6) 62 | 63 | # ax1.set_xlabel("$\lambda$") 64 | ax1.set_xticks([0, 1]) 65 | ax1.set_xticklabels(["Model $A$", "Model $B$"]) 66 | ax1.set_ylabel("Loss", labelpad=7.5, fontsize=20) 67 | # 
ax1.set_title("Loss") 68 | ax1.legend(loc="upper right", framealpha=0.5) 69 | 70 | # Accuracy 71 | ax2 = fig.add_subplot(1, 2, 2) 72 | 73 | ax2.plot(lambdas, 74 | np.array(interp_eval_vs_epoch[epoch]["naive_train_acc_interp"]), 75 | color=grey_color, 76 | linewidth=2, 77 | label="Train") 78 | ax2.plot(lambdas, 79 | np.array(interp_eval_vs_epoch[epoch]["naive_test_acc_interp"]), 80 | color=grey_color, 81 | linewidth=2, 82 | linestyle="dashed", 83 | label="Test") 84 | 85 | ax2.plot( 86 | lambdas, 87 | np.array(interp_eval_vs_epoch[epoch]["clever_train_acc_interp"]), 88 | color=highlight_color, 89 | # marker="^", 90 | linewidth=5, 91 | # label="Ours", 92 | ) 93 | ax2.plot( 94 | lambdas, 95 | np.array(interp_eval_vs_epoch[epoch]["clever_test_acc_interp"]), 96 | color=highlight_color, 97 | # marker="^", 98 | linestyle="dashed", 99 | linewidth=5) 100 | 101 | # Prevent this from changing from frame to frame, messing up the spacing 102 | # See https://stackoverflow.com/questions/29188757/matplotlib-specify-format-of-floats-for-tick-labels 103 | ax2.yaxis.set_major_formatter(FormatStrFormatter("%.2f")) 104 | 105 | ax2.yaxis.tick_right() 106 | ax2.yaxis.set_label_position("right") 107 | 108 | # ax2.set_ylim(0.8, 1.01) 109 | 110 | # ax2.set_xlabel("$\lambda$") 111 | ax2.set_xticks([0, 1]) 112 | ax2.set_xticklabels(["Model $A$", "Model $B$"]) 113 | 114 | # 0.7, 0.725, ..., 0.975, 1.0 115 | allowed_ticks = 1.0 - np.arange(13)[::-1] * 0.025 116 | 117 | # For some reason the first/last ticks reported here are actually invisible... 118 | actual_ticks = ax2.get_yticks()[1:-1] 119 | ax2.set_yticks( 120 | [x for x in allowed_ticks if min(actual_ticks) - 1e-3 <= x <= max(actual_ticks) + 1e-3]) 121 | 122 | ax2.set_ylabel("Accuracy", rotation=270, labelpad=30, fontsize=20) 123 | # ax2.set_title("Accuracy") 124 | ax2.legend(loc="lower right", framealpha=0.5) 125 | 126 | fig.suptitle(f"Merging NNs before/after permuting weights (epoch {epoch+1})") 127 | # fig.tight_layout() 128 | 129 | plt.savefig(f"tmp/mnist_video_{epoch:05d}.png", dpi=300) 130 | plt.close(fig) 131 | 132 | subprocess.run([ 133 | "ffmpeg", "-r", "10", "-i", "tmp/mnist_video_%05d.png", "-vcodec", "libx264", "-crf", "15", 134 | "-pix_fmt", "yuv420p", "-y", "mnist_video.mp4" 135 | ], 136 | check=True) 137 | subprocess.run([ 138 | "ffmpeg", "-f", "image2", "-r", "10", "-i", "tmp/mnist_video_%05d.png", "-loop", "0", "-y", 139 | "mnist_video.gif" 140 | ], 141 | check=True) 142 | -------------------------------------------------------------------------------- /src/online_stats.py: -------------------------------------------------------------------------------- 1 | """Online-ish Pearson correlation of all n x n variable pairs simultaneously.""" 2 | from typing import NamedTuple 3 | 4 | import jax.numpy as jnp 5 | 6 | 7 | class OnlineMean(NamedTuple): 8 | sum: jnp.ndarray 9 | count: int 10 | 11 | @staticmethod 12 | def init(num_features: int): 13 | return OnlineMean(sum=jnp.zeros(num_features), count=0) 14 | 15 | def update(self, batch: jnp.ndarray): 16 | return OnlineMean(self.sum + jnp.sum(batch, axis=0), self.count + batch.shape[0]) 17 | 18 | def mean(self): 19 | return self.sum / self.count 20 | 21 | class OnlineCovariance(NamedTuple): 22 | a_mean: jnp.ndarray # (d, ) 23 | b_mean: jnp.ndarray # (d, ) 24 | cov: jnp.ndarray # (d, d) 25 | var_a: jnp.ndarray # (d, ) 26 | var_b: jnp.ndarray # (d, ) 27 | count: int 28 | 29 | @staticmethod 30 | def init(a_mean: jnp.ndarray, b_mean: jnp.ndarray): 31 | assert a_mean.shape == b_mean.shape 32 | assert 
len(a_mean.shape) == 1 33 | d = a_mean.shape[0] 34 | return OnlineCovariance(a_mean, 35 | b_mean, 36 | cov=jnp.zeros((d, d)), 37 | var_a=jnp.zeros((d, )), 38 | var_b=jnp.zeros((d, )), 39 | count=0) 40 | 41 | def update(self, a_batch, b_batch): 42 | assert a_batch.shape == b_batch.shape 43 | batch_size, _ = a_batch.shape 44 | a_res = a_batch - self.a_mean 45 | b_res = b_batch - self.b_mean 46 | return OnlineCovariance(a_mean=self.a_mean, 47 | b_mean=self.b_mean, 48 | cov=self.cov + a_res.T @ b_res, 49 | var_a=self.var_a + jnp.sum(a_res**2, axis=0), 50 | var_b=self.var_b + jnp.sum(b_res**2, axis=0), 51 | count=self.count + batch_size) 52 | 53 | def covariance(self): 54 | return self.cov / (self.count - 1) 55 | 56 | def a_variance(self): 57 | return self.var_a / (self.count - 1) 58 | 59 | def b_variance(self): 60 | return self.var_b / (self.count - 1) 61 | 62 | def a_stddev(self): 63 | return jnp.sqrt(self.a_variance()) 64 | 65 | def b_stddev(self): 66 | return jnp.sqrt(self.b_variance()) 67 | 68 | def E_ab(self): 69 | return self.covariance() + jnp.outer(self.a_mean, self.b_mean) 70 | 71 | def pearson_correlation(self): 72 | # Note that the 1/(n-1) normalization terms cancel out nicely here. 73 | # TODO: clip? 74 | eps = 0 75 | # Dead units will have zero variance, which produces NaNs. Convert those to 76 | # zeros with nan_to_num. 77 | return jnp.nan_to_num(self.cov / (jnp.sqrt(self.var_a[:, jnp.newaxis]) + eps) / 78 | (jnp.sqrt(self.var_b) + eps)) 79 | 80 | class OnlineInnerProduct(NamedTuple): 81 | val: jnp.ndarray # (d, d) 82 | 83 | @staticmethod 84 | def init(d: int): 85 | return OnlineInnerProduct(val=jnp.zeros((d, d))) 86 | 87 | def update(self, a_batch, b_batch): 88 | assert a_batch.shape == b_batch.shape 89 | return OnlineInnerProduct(val=self.val + a_batch.T @ b_batch) 90 | 91 | # def online_pearson_init_state(n): 92 | # return { 93 | # "Exy": jnp.zeros((n, n)), 94 | # "Ex": jnp.zeros((n, )), 95 | # "Ey": jnp.zeros((n, )), 96 | # "Ex2": jnp.zeros((n, )), 97 | # "Ey2": jnp.zeros((n, )), 98 | # "samples": 0, 99 | # } 100 | 101 | # def online_pearson_update(state, x_batch, y_batch): 102 | # """Online-ish Pearson update. 
103 | 104 | # x_batch and y_batch are assumed to be of shape (batch_size, n).""" 105 | # assert x_batch.shape == y_batch.shape 106 | # batch_size = x_batch.shape[0] 107 | # return { 108 | # "Exy": state["Exy"] + x_batch.T @ y_batch, 109 | # "Ex": state["Ex"] + jnp.sum(x_batch, axis=0), 110 | # "Ey": state["Ey"] + jnp.sum(y_batch, axis=0), 111 | # "Ex2": state["Ex2"] + jnp.sum(x_batch**2, axis=0), 112 | # "Ey2": state["Ey2"] + jnp.sum(y_batch**2, axis=0), 113 | # "samples": state["samples"] + batch_size, 114 | # } 115 | 116 | # def online_pearson_finalize(state): 117 | # samples = state["samples"] 118 | # Exy = state["Exy"] / samples 119 | # Ex = state["Ex"] / samples 120 | # Ey = state["Ey"] / samples 121 | # Ex2 = state["Ex2"] / samples 122 | # Ey2 = state["Ey2"] / samples 123 | 124 | # print("finalizing pearson") 125 | # print((Exy - Ex[jnp.newaxis, :] * Ey).min(), (Exy - Ex[jnp.newaxis, :] * Ey).max()) 126 | # print(jnp.sqrt(Ex2 - Ex**2).min(), jnp.sqrt(Ex2 - Ex**2).max()) 127 | # print(jnp.sqrt(Ey2 - Ey**2).min(), jnp.sqrt(Ey2 - Ey**2).max()) 128 | # # Note that this will not be symmetric in general 129 | # # return (Exy - Ex[jnp.newaxis, :] * Ey) / jnp.sqrt(Ex2 - Ex**2)[jnp.newaxis, :] / jnp.sqrt(Ey2 - 130 | # # Ey**2) 131 | # return (Exy - Ex[jnp.newaxis, :] * Ey) 132 | 133 | # # def online_ 134 | 135 | # def test_one_batch(): 136 | # rp = RngPooper(random.PRNGKey(123)) 137 | # n = 3 138 | # x_batch = random.normal(rp.poop(), (1024, n)) 139 | # y_batch = random.normal(rp.poop(), (1024, n)) 140 | # state = online_pearson_init_state(n) 141 | # state = online_pearson_update(state, x_batch, y_batch) 142 | # pred = online_pearson_finalize(state) 143 | # print(pred) 144 | # gt = jnp.corrcoef(x_batch, y_batch, rowvar=False)[:n, n:] 145 | # print(gt) 146 | # np.testing.assert_allclose(pred, gt, rtol=1e0, atol=1e-2) 147 | 148 | # def test_multiple_batches(): 149 | # rp = RngPooper(random.PRNGKey(123)) 150 | # n = 3 151 | # x_batch = random.normal(rp.poop(), (1024, n)) 152 | # y_batch = random.normal(rp.poop(), (1024, n)) 153 | # state = online_pearson_init_state(n) 154 | # state = online_pearson_update(state, x_batch[:500, :], y_batch[:500, :]) 155 | # state = online_pearson_update(state, x_batch[500:750, :], y_batch[500:750, :]) 156 | # state = online_pearson_update(state, x_batch[750:, :], y_batch[750:, :]) 157 | # pred = online_pearson_finalize(state) 158 | # print(pred) 159 | # gt = jnp.corrcoef(x_batch, y_batch, rowvar=False)[:n, n:] 160 | # print(gt) 161 | # np.testing.assert_allclose(pred, gt, rtol=1e0, atol=1e-2) 162 | 163 | # if __name__ == "__main__": 164 | # test_one_batch() 165 | # test_multiple_batches() 166 | -------------------------------------------------------------------------------- /src/parallel_cifar10_run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import jax.numpy as jnp 4 | import optax 5 | import tensorflow as tf 6 | import tensorflow_datasets as tfds 7 | from flax import linen as nn 8 | from jax import jit, random, tree_map, value_and_grad 9 | from tqdm import tqdm 10 | 11 | import wandb 12 | from parallel_mnist_plots import plot_interp_acc, plot_interp_loss 13 | from permutations import permutify 14 | from utils import RngPooper, ec2_get_instance_type, timeblock 15 | 16 | # See https://github.com/google/jax/issues/9454. 
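# Hiding the GPUs from TensorFlow keeps tf.data/TFDS from reserving GPU memory
# that JAX needs; TensorFlow is only used for the input pipeline in this script.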
17 | tf.config.set_visible_devices([], "GPU") 18 | 19 | config = wandb.config 20 | config.ec2_instance_type = ec2_get_instance_type() 21 | config.smoke_test = "--test" in sys.argv 22 | config.learning_rate = 0.001 23 | config.num_epochs = 10 if config.smoke_test else 50 24 | config.batch_size = 7 if config.smoke_test else 256 25 | 26 | wandb.init(entity="skainswo", 27 | project="git-re-basin", 28 | tags=["cifar10"], 29 | mode="disabled" if config.smoke_test else "online") 30 | 31 | rp = RngPooper(random.PRNGKey(0)) 32 | 33 | activation = nn.relu 34 | 35 | if config.smoke_test: 36 | 37 | class Model(nn.Module): 38 | 39 | @nn.compact 40 | def __call__(self, x): 41 | x = nn.Conv(features=8, kernel_size=(3, 3), strides=(1, 1))(x) 42 | x = activation(x) 43 | x = jnp.reshape(x, (x.shape[0], -1)) 44 | x = nn.Dense(10)(x) 45 | x = nn.log_softmax(x) 46 | return x 47 | 48 | else: 49 | 50 | class Model(nn.Module): 51 | 52 | @nn.compact 53 | def __call__(self, x): 54 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x) 55 | x = activation(x) 56 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x) 57 | x = activation(x) 58 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x) 59 | x = activation(x) 60 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x) 61 | x = activation(x) 62 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x) 63 | x = activation(x) 64 | x = jnp.reshape(x, (x.shape[0], -1)) 65 | x = nn.Dense(10)(x) 66 | x = nn.log_softmax(x) 67 | return x 68 | 69 | model = Model() 70 | 71 | @jit 72 | def batch_loss(params, x, y): 73 | logits = model.apply(params, x) 74 | return -jnp.mean(jnp.sum(y * logits, axis=-1)) 75 | 76 | @jit 77 | def batch_num_correct(params, x, y): 78 | logits = model.apply(params, x) 79 | return jnp.sum(jnp.argmax(logits, axis=-1) == jnp.argmax(y, axis=-1)) 80 | 81 | @jit 82 | def step(opt_state, params, x, y): 83 | l, g = value_and_grad(batch_loss)(params, x, y) 84 | updates, opt_state = tx.update(g, opt_state) 85 | params = optax.apply_updates(params, updates) 86 | return params, opt_state, l 87 | 88 | # See https://github.com/tensorflow/tensorflow/issues/53831. 89 | train_ds = tfds.load("cifar10", split="train", as_supervised=True) 90 | test_ds = tfds.load("cifar10", split="test", as_supervised=True) 91 | 92 | # Note: The take/cache warning: 93 | # 2022-01-25 07:32:58.144059: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 94 | # is not because we're actually doing this in the wrong order, but rather that 95 | # the dataset is loaded in and called .cache() on before we receive it. 
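# In smoke-test mode we keep only a handful of examples (13 train / 17 test)
# below, so the full train/eval/plot pipeline can be exercised quickly without
# doing any meaningful training.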
96 | if config.smoke_test: 97 | train_ds = train_ds.take(13) 98 | test_ds = test_ds.take(17) 99 | 100 | # Normalize 0-255 pixel values to 0.0-1.0 101 | normalize = lambda image, label: (tf.cast(image, tf.float32) / 255.0, tf.one_hot(label, depth=10)) 102 | train_ds = train_ds.map(normalize).cache() 103 | test_ds = test_ds.map(normalize).cache() 104 | 105 | num_train_examples = train_ds.cardinality().numpy() 106 | num_test_examples = test_ds.cardinality().numpy() 107 | 108 | def dataset_loss(params, ds): 109 | # Note that we multiply by the batch size here, in order to get the sum of the 110 | # losses, then average over the whole dataset. 111 | return jnp.mean(jnp.array([x.shape[0] * batch_loss(params, x, y) for x, y in ds])) 112 | 113 | def dataset_total_correct(params, ds): 114 | return jnp.sum(jnp.array([batch_num_correct(params, x, y) for x, y in ds])) 115 | 116 | tx = optax.adam(config.learning_rate) 117 | 118 | params1 = model.init(rp.poop(), jnp.zeros((1, 32, 32, 3))) 119 | params2 = model.init(rp.poop(), jnp.zeros((1, 32, 32, 3))) 120 | opt_state1 = tx.init(params1) 121 | opt_state2 = tx.init(params2) 122 | for epoch in tqdm(range(config.num_epochs)): 123 | with timeblock(f"Epoch"): 124 | for images, labels in tfds.as_numpy( 125 | train_ds.shuffle(num_train_examples, seed=hash(f"{epoch}-1")).batch(config.batch_size)): 126 | params1, opt_state1, loss1 = step(opt_state1, params1, images, labels) 127 | for images, labels in tfds.as_numpy( 128 | train_ds.shuffle(num_train_examples, seed=hash(f"{epoch}-2")).batch(config.batch_size)): 129 | params2, opt_state2, loss2 = step(opt_state2, params2, images, labels) 130 | 131 | train_ds_batched = tfds.as_numpy(train_ds.batch(config.batch_size)) 132 | test_ds_batched = tfds.as_numpy(test_ds.batch(config.batch_size)) 133 | 134 | # This is inclusive on both ends. 135 | lambdas = jnp.linspace(0, 1, num=10) 136 | 137 | # TODO implement permutify for convnets! 138 | # params2_permuted = permutify(params1, params2) 139 | params2_permuted = params2 140 | 141 | def interp_naive(lam): 142 | return tree_map(lambda a, b: b * lam + a * (1 - lam), params1, params2) 143 | 144 | def interp_clever(lam): 145 | return tree_map(lambda a, b: b * lam + a * (1 - lam), params1, params2_permuted) 146 | 147 | with timeblock("Interpolation plot"): 148 | train_loss_interp_naive = jnp.array( 149 | [dataset_loss(interp_naive(l), train_ds_batched) for l in lambdas]) 150 | test_loss_interp_naive = jnp.array( 151 | [dataset_loss(interp_naive(l), test_ds_batched) for l in lambdas]) 152 | train_acc_interp_naive = jnp.array( 153 | [dataset_total_correct(interp_naive(l), train_ds_batched) 154 | for l in lambdas]) / num_train_examples 155 | test_acc_interp_naive = jnp.array( 156 | [dataset_total_correct(interp_naive(l), test_ds_batched) 157 | for l in lambdas]) / num_test_examples 158 | 159 | train_loss_interp_clever = jnp.array( 160 | [dataset_loss(interp_clever(l), train_ds_batched) for l in lambdas]) 161 | test_loss_interp_clever = jnp.array( 162 | [dataset_loss(interp_clever(l), test_ds_batched) for l in lambdas]) 163 | train_acc_interp_clever = jnp.array( 164 | [dataset_total_correct(interp_clever(l), train_ds_batched) 165 | for l in lambdas]) / num_train_examples 166 | test_acc_interp_clever = jnp.array( 167 | [dataset_total_correct(interp_clever(l), test_ds_batched) 168 | for l in lambdas]) / num_test_examples 169 | 170 | # These are redundant with the full arrays above, but we want pretty plots in 171 | # wandb. 
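  # Since lambdas is inclusive on both ends, the lambda=0 and lambda=1 endpoints
  # of the naive interpolation are exactly params1 and params2, so slicing them
  # out below gives each individual model's metrics without any extra eval passes.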
172 | train_loss1 = train_loss_interp_naive[0] 173 | train_loss2 = train_loss_interp_naive[-1] 174 | test_loss1 = test_loss_interp_naive[0] 175 | test_loss2 = test_loss_interp_naive[-1] 176 | train_acc1 = train_acc_interp_naive[0] 177 | train_acc2 = train_acc_interp_naive[-1] 178 | test_acc1 = test_acc_interp_naive[0] 179 | test_acc2 = test_acc_interp_naive[-1] 180 | 181 | interp_loss_plot = plot_interp_loss(epoch, lambdas, train_loss_interp_naive, 182 | test_loss_interp_naive, train_loss_interp_clever, 183 | test_loss_interp_clever) 184 | interp_acc_plot = plot_interp_acc(epoch, lambdas, train_acc_interp_naive, test_acc_interp_naive, 185 | train_acc_interp_clever, test_acc_interp_clever) 186 | wandb.log({ 187 | "epoch": epoch, 188 | "train_loss1": train_loss1, 189 | "train_loss2": train_loss2, 190 | "train_acc1": train_acc1, 191 | "train_acc2": train_acc2, 192 | "test_loss1": test_loss1, 193 | "test_loss2": test_loss2, 194 | "test_acc1": test_acc1, 195 | "test_acc2": test_acc2, 196 | # This doesn't really change, but it's more convenient to store it here 197 | # when we go to make videos/plots later. 198 | "lambdas": lambdas, 199 | "train_loss_interp_naive": train_loss_interp_naive, 200 | "test_loss_interp_naive": test_loss_interp_naive, 201 | "train_acc_interp_naive": train_acc_interp_naive, 202 | "test_acc_interp_naive": test_acc_interp_naive, 203 | "train_loss_interp_clever": train_loss_interp_clever, 204 | "test_loss_interp_clever": test_loss_interp_clever, 205 | "train_acc_interp_clever": train_acc_interp_clever, 206 | "test_acc_interp_clever": test_acc_interp_clever, 207 | "interp_loss_plot": wandb.Image(interp_loss_plot), 208 | "interp_acc_plot": wandb.Image(interp_acc_plot), 209 | }) 210 | -------------------------------------------------------------------------------- /src/parallel_mnist_videos.py: -------------------------------------------------------------------------------- 1 | # Example usage: 2 | # python parallel_mnist_videos.py skainswo/git-re-basin/2vzg9n1u 3 | 4 | import subprocess 5 | import sys 6 | import tempfile 7 | from pathlib import Path 8 | 9 | import jax.numpy as jnp 10 | import matplotlib.pyplot as plt 11 | 12 | import wandb 13 | from parallel_mnist_plots import plot_interp_acc, plot_interp_loss 14 | 15 | api = wandb.Api() 16 | run = api.run(sys.argv[1]) 17 | history = run.history() 18 | 19 | # TODO: this should no longer be necessary... 
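# (The training script, parallel_cifar10_run.py, now logs "lambdas" into the
# wandb history at every epoch, so this hardcoded copy should only be needed
# when replaying older runs.)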
20 | lambdas = jnp.linspace(0, 1, num=10) 21 | 22 | with tempfile.TemporaryDirectory() as tempdir: 23 | for step in history: 24 | fig = plot_interp_loss(step["epoch"], lambdas, step["train_loss_interp_naive"], 25 | step["test_loss_interp_naive"], step["train_loss_interp_clever"], 26 | step["test_loss_interp_clever"]) 27 | plt.savefig(Path(tempdir) / f"{step['epoch']:05d}.png") 28 | plt.close(fig) 29 | 30 | subprocess.run([ 31 | "ffmpeg", "-r", "10", "-i", 32 | Path(tempdir) / "%05d.png", "-vcodec", "libx264", "-crf", "15", "-pix_fmt", "yuv420p", "-y", 33 | f"parallel_mnist_interp_loss.mp4" 34 | ], 35 | check=True) 36 | 37 | with tempfile.TemporaryDirectory() as tempdir: 38 | for step in history: 39 | fig = plot_interp_acc(step["epoch"], lambdas, step["train_acc_interp_naive"], 40 | step["test_acc_interp_naive"], step["train_acc_interp_clever"], 41 | step["test_acc_interp_clever"]) 42 | plt.savefig(Path(tempdir) / f"{step['epoch']:05d}.png") 43 | plt.close(fig) 44 | 45 | subprocess.run([ 46 | "ffmpeg", "-r", "10", "-i", 47 | Path(tempdir) / "%05d.png", "-vcodec", "libx264", "-crf", "15", "-pix_fmt", "yuv420p", "-y", 48 | f"parallel_mnist_interp_acc.mp4" 49 | ], 50 | check=True) 51 | -------------------------------------------------------------------------------- /src/plot_utils.py: -------------------------------------------------------------------------------- 1 | def loss_barrier_is_nonnegative(ax): 2 | ax.axhline(y=0, color="tab:grey", linestyle=":", alpha=0, zorder=-1) 3 | ylim = ax.get_ylim() 4 | # See https://stackoverflow.com/a/5197426/3880977 5 | ax.axhspan(-0.1, 0, color="tab:grey", alpha=0.25, zorder=-2) 6 | ax.axhspan(-0.1, 0, facecolor="none", edgecolor="tab:grey", alpha=0.25, hatch="//", zorder=-1) 7 | ax.set_ylim(ylim) 8 | -------------------------------------------------------------------------------- /src/resnet20.py: -------------------------------------------------------------------------------- 1 | from einops import reduce 2 | from flax import linen as nn 3 | 4 | def reverse_compose(x, fs): 5 | for f in fs: 6 | x = f(x) 7 | return x 8 | 9 | class Block(nn.Module): 10 | num_channels: int = None 11 | strides: int = None 12 | 13 | def setup(self): 14 | self.conv1 = nn.Conv(features=self.num_channels, 15 | kernel_size=(3, 3), 16 | strides=self.strides, 17 | use_bias=False) 18 | self.norm1 = nn.LayerNorm() 19 | self.conv2 = nn.Conv(features=self.num_channels, kernel_size=(3, 3), strides=1, use_bias=False) 20 | self.norm2 = nn.LayerNorm() 21 | 22 | # When strides != 1, it must be 2, which means that we halve the width and height of the input, while doubling the 23 | # number of channels. Therefore we need to correspondingly halve the width and height of the residuals/shortcut. 24 | if self.strides != 1: 25 | assert self.strides == 2 26 | 27 | # Supposedly this is the original description, but it is not easily compatible with our weight matching stuff 28 | # since it plays games with the channel structure by padding things around. 29 | # self.shortcut = lambda x: jnp.pad(x[:, ::2, ::2, :], ( 30 | # (0, 0), (0, 0), (0, 0), (self.num_channels // 4, self.num_channels // 4)), 31 | # "constant", 32 | # constant_values=0) 33 | 34 | # This is not the original, but is fairly common based on other implementations.
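      # (A projection shortcut: the strided conv + norm below halves the spatial
      # dimensions and maps onto num_channels channels, so that the residual
      # addition in __call__ type-checks. The classic ResNet "option B" uses a
      # strided 1x1 conv for this; a strided 3x3 conv is a common variant.)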
35 | self.shortcut = nn.Sequential([ 36 | nn.Conv(features=self.num_channels, 37 | kernel_size=(3, 3), 38 | strides=self.strides, 39 | use_bias=False), 40 | nn.LayerNorm() 41 | ]) 42 | else: 43 | self.shortcut = lambda x: x 44 | 45 | def __call__(self, x): 46 | y = x 47 | y = self.conv1(y) 48 | y = self.norm1(y) 49 | y = nn.relu(y) 50 | y = self.conv2(y) 51 | y = self.norm2(y) 52 | return nn.relu(y + self.shortcut(x)) 53 | 54 | class BlockGroup(nn.Module): 55 | num_channels: int = None 56 | num_blocks: int = None 57 | strides: int = None 58 | 59 | def setup(self): 60 | assert self.num_blocks > 0 61 | self.blocks = ( 62 | [Block(num_channels=self.num_channels, strides=self.strides)] + 63 | [Block(num_channels=self.num_channels, strides=1) for _ in range(self.num_blocks - 1)]) 64 | 65 | def __call__(self, x): 66 | return reverse_compose(x, self.blocks) 67 | 68 | class ResNet(nn.Module): 69 | blocks_per_group: int = None 70 | num_classes: int = None 71 | width_multiplier: int = 1 72 | 73 | def setup(self): 74 | wm = self.width_multiplier 75 | 76 | self.conv1 = nn.Conv(features=16 * wm, kernel_size=(3, 3), use_bias=False) 77 | self.norm1 = nn.LayerNorm() 78 | 79 | channels_per_group = (16 * wm, 32 * wm, 64 * wm) 80 | strides_per_group = (1, 2, 2) 81 | self.blockgroups = [ 82 | BlockGroup(num_channels=c, num_blocks=b, strides=s) 83 | for c, b, s in zip(channels_per_group, self.blocks_per_group, strides_per_group) 84 | ] 85 | 86 | self.dense = nn.Dense(self.num_classes) 87 | 88 | def __call__(self, x): 89 | x = self.conv1(x) 90 | x = self.norm1(x) 91 | x = nn.relu(x) 92 | x = reverse_compose(x, self.blockgroups) 93 | x = reduce(x, "n h w c -> n c", "mean") 94 | x = self.dense(x) 95 | x = nn.log_softmax(x) 96 | return x 97 | 98 | BLOCKS_PER_GROUP = { 99 | "resnet20": (3, 3, 3), 100 | "resnet32": (5, 5, 5), 101 | "resnet44": (7, 7, 7), 102 | "resnet56": (9, 9, 9), 103 | "resnet110": (18, 18, 18), 104 | } 105 | -------------------------------------------------------------------------------- /src/sgd_is_special.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import matplotlib.pyplot as plt 3 | from flax import linen as nn 4 | from jax import random, tree_map 5 | from matplotlib.colors import ListedColormap 6 | 7 | import matplotlib_style as _ 8 | from utils import unflatten_params 9 | 10 | rng = random.PRNGKey(0) 11 | 12 | class Model(nn.Module): 13 | 14 | @nn.compact 15 | def __call__(self, x): 16 | x = nn.Dense(2)(x) 17 | x = nn.PReLU()(x) 18 | x = nn.Dense(2)(x) 19 | x = nn.PReLU()(x) 20 | x = nn.Dense(1)(x) 21 | return x 22 | 23 | model = Model() 24 | 25 | # dense kernel shape: (in, out) 26 | dtype = jnp.float32 27 | paramsA = { 28 | "PReLU_0/negative_slope": jnp.array(0.01), 29 | "PReLU_1/negative_slope": jnp.array(0.01), 30 | "Dense_0/kernel": jnp.array([[-1, 0], [0, -1]], dtype=dtype), 31 | "Dense_0/bias": jnp.array([1, 0], dtype=dtype), 32 | "Dense_1/kernel": jnp.array([[-1, 0], [0, 1]], dtype=dtype), 33 | "Dense_1/bias": jnp.array([1, 0], dtype=dtype), 34 | "Dense_2/kernel": jnp.array([[-1], [-1]], dtype=dtype), 35 | "Dense_2/bias": jnp.array([0], dtype=dtype), 36 | } 37 | paramsB1 = { 38 | "PReLU_0/negative_slope": jnp.array(0.01), 39 | "PReLU_1/negative_slope": jnp.array(0.01), 40 | "Dense_0/kernel": jnp.array([[1, 0], [0, 1]], dtype=dtype), 41 | "Dense_0/bias": jnp.array([0, 1], dtype=dtype), 42 | "Dense_1/kernel": jnp.array([[1, 0], [0, -1]], dtype=dtype), 43 | "Dense_1/bias": jnp.array([0, 1], dtype=dtype), 44 | 
"Dense_2/kernel": jnp.array([[-1], [-1]], dtype=dtype), 45 | "Dense_2/bias": jnp.array([0], dtype=dtype), 46 | } 47 | 48 | def swap_layer(layer: int, params): 49 | ix = jnp.array([1, 0]) 50 | return { 51 | **params, 52 | f"Dense_{layer}/kernel": params[f"Dense_{layer}/kernel"][:, ix], 53 | f"Dense_{layer}/bias": params[f"Dense_{layer}/bias"][ix], 54 | f"Dense_{layer+1}/kernel": params[f"Dense_{layer+1}/kernel"][ix, :], 55 | } 56 | 57 | swap_first_layer = lambda params: swap_layer(0, params) 58 | swap_second_layer = lambda params: swap_layer(1, params) 59 | 60 | paramsB2 = swap_first_layer(paramsB1) 61 | paramsB3 = swap_second_layer(paramsB1) 62 | paramsB4 = swap_first_layer(swap_second_layer(paramsB1)) 63 | 64 | # Assert that [swapfirst, swapsecond] is the same as [swapsecond, swapfirst]. 65 | assert jnp.all( 66 | jnp.array( 67 | list( 68 | tree_map(jnp.allclose, swap_first_layer(swap_second_layer(paramsB1)), 69 | swap_second_layer(swap_first_layer(paramsB1))).values()))) 70 | 71 | num_examples = 1024 72 | testX = random.uniform(rng, (num_examples, 2), dtype=dtype, minval=-1, maxval=1) 73 | testY = (testX[:, 0] <= 0) & (testX[:, 1] >= 0) 74 | 75 | def accuracy(params): 76 | return jnp.sum((model.apply({"params": unflatten_params(params)}, testX) >= 0 77 | ).flatten() == testY) / num_examples 78 | 79 | # assert accuracy(paramsA) == 1.0 80 | # assert accuracy(paramsB1) == 1.0 81 | # assert accuracy(paramsB2) == 1.0 82 | # assert accuracy(paramsB3) == 1.0 83 | # assert accuracy(paramsB4) == 1.0 84 | 85 | def interp_params(lam, pA, pB): 86 | return tree_map(lambda a, b: lam * a + (1 - lam) * b, pA, pB) 87 | 88 | def plot_interp_loss(): 89 | lambdas = jnp.linspace(0, 1, num=10) 90 | interp1 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB1)) for lam in lambdas]) 91 | interp2 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB2)) for lam in lambdas]) 92 | interp3 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB3)) for lam in lambdas]) 93 | interp4 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB4)) for lam in lambdas]) 94 | 95 | fig = plt.figure() 96 | ax = fig.add_subplot(1, 1, 1) 97 | # We make losses start at 0, since that intuitively makes more sense. 
98 | ax.plot(lambdas, -interp1 + 1, linewidth=2, marker="o", label="Identity") 99 | ax.plot(lambdas, -interp2 + 1, linewidth=2, marker=".", label="Swap layer 1") 100 | ax.plot(lambdas, -interp3 + 1, linewidth=2, marker="x", label="Swap layer 2") 101 | ax.plot(lambdas, -interp4 + 1, linewidth=2, marker="*", label="Swap both") 102 | ax.plot([-1, 2], [0, 0], linestyle=":", color="tab:grey", alpha=0.5, label="Perfect performance") 103 | ax.set_xlabel("$\lambda$") 104 | ax.set_xticks([0, 1]) 105 | ax.set_xticklabels(["Model $A$", "Model $B$"]) 106 | ax.set_xlim(-0.05, 1.05) 107 | ax.set_ylabel("Loss") 108 | ax.set_title("All possible permutations") 109 | ax.legend(framealpha=0.5) 110 | fig.tight_layout() 111 | return fig 112 | 113 | fig = plot_interp_loss() 114 | plt.savefig(f"figs/sgd_is_special_loss_interp.png", dpi=300) 115 | plt.savefig(f"figs/sgd_is_special_loss_interp.eps") 116 | plt.savefig(f"figs/sgd_is_special_loss_interp.pdf") 117 | plt.close(fig) 118 | 119 | def plot_interp_loss_zoom(max_lambda): 120 | lambdas = jnp.linspace(0, max_lambda, num=10) 121 | interp1 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB1)) for lam in lambdas]) 122 | interp2 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB2)) for lam in lambdas]) 123 | interp3 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB3)) for lam in lambdas]) 124 | interp4 = jnp.array([accuracy(interp_params(lam, paramsA, paramsB4)) for lam in lambdas]) 125 | 126 | fig = plt.figure() 127 | ax = fig.add_subplot(1, 1, 1) 128 | # We make losses start at 0, since that intuitively makes more sense. 129 | ax.plot(lambdas, -interp1 + 1, linewidth=2, marker="o", label="Identity") 130 | ax.plot(lambdas, -interp2 + 1, linewidth=2, marker=".", label="Swap layer 1") 131 | ax.plot(lambdas, -interp3 + 1, linewidth=2, marker="x", label="Swap layer 2") 132 | ax.plot(lambdas, -interp4 + 1, linewidth=2, marker="*", label="Swap both") 133 | ax.plot([-1, 2], [0, 0], 134 | linestyle="dashed", 135 | color="tab:grey", 136 | alpha=0.5, 137 | label="Perfect performance") 138 | 139 | # ax.set_xscale("log") 140 | # ax.set_yscale("log") 141 | 142 | ax.set_xlabel("$\lambda$") 143 | # ax.set_xlim(-0.05, max_lambda * 1.05) 144 | ax.set_ylabel("Loss") 145 | ax.set_title("All possible permutations between two globally optimal models (zoom)") 146 | ax.legend(framealpha=0.5) 147 | fig.tight_layout() 148 | return fig 149 | 150 | fig = plot_interp_loss_zoom(max_lambda=1e-6) 151 | plt.savefig(f"figs/sgd_is_special_loss_interp_zoom.png", dpi=300) 152 | # plt.savefig(f"figs/sgd_is_special_loss_interp.pdf") 153 | plt.close(fig) 154 | 155 | def plot_data(): 156 | fig = plt.figure() 157 | ax = fig.add_subplot(1, 1, 1) 158 | ax.scatter(testX[testY, 0], 159 | testX[testY, 1], 160 | edgecolor="tab:green", 161 | facecolor="none", 162 | marker="o", 163 | label="$y=1$") 164 | ax.scatter(testX[~testY, 0], testX[~testY, 1], color="tab:red", marker="x", label="$y=0$") 165 | ax.axhline(0, color="tab:grey", alpha=0.25) 166 | ax.axvline(0, color="tab:grey", alpha=0.25) 167 | ax.set_xlabel("$x_1$") 168 | ax.set_ylabel("$x_2$") 169 | ax.set_title("Data") 170 | # ax.legend(framealpha=0.5) 171 | fig.tight_layout() 172 | return fig 173 | 174 | fig = plot_data() 175 | plt.savefig(f"figs/sgd_is_special_data.png", dpi=300) 176 | plt.savefig(f"figs/sgd_is_special_data.eps") 177 | plt.savefig(f"figs/sgd_is_special_data.pdf") 178 | plt.close() 179 | 180 | extrema = model.apply({"params": unflatten_params(paramsA)}, 181 | jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], 
dtype=dtype)) 182 | min_score = jnp.min(extrema) 183 | max_score = jnp.max(extrema) 184 | 185 | def plot_detailed_view(): 186 | lambdas = jnp.linspace(0, 1, num=9) 187 | 188 | s = 2 189 | fig, ax = plt.subplots(4, len(lambdas), figsize=(len(lambdas) * s, 4 * s)) 190 | 191 | ticks = jnp.linspace(-1, 1, num=100) 192 | xx1, xx2 = jnp.meshgrid(ticks, ticks) 193 | meshX = jnp.stack([xx1.flatten(), xx2.flatten()], axis=1) 194 | 195 | def asdf(row, paramsB): 196 | for i in range(len(lambdas)): 197 | params = interp_params(lambdas[i], paramsA, paramsB) 198 | meshY = model.apply({"params": unflatten_params(params)}, meshX).flatten() 199 | decision_boundary = (meshY >= 0).astype(float) 200 | ax[row, i].contourf(xx1, 201 | xx2, 202 | meshY.reshape(xx1.shape), 203 | levels=jnp.linspace(min_score, max_score, num=25), 204 | cmap="copper") 205 | # ax[row, i].contourf(xx1, 206 | # xx2, 207 | # decision_boundary.reshape(xx1.shape), 208 | # cmap=ListedColormap(np.array([[0, 0, 0, 0.0], [0, 1, 0, 0.5]]))) 209 | ax[row, i].set_xticks([]) 210 | ax[row, i].set_yticks([]) 211 | 212 | asdf(0, paramsB1) 213 | asdf(1, paramsB2) 214 | asdf(2, paramsB3) 215 | asdf(3, paramsB4) 216 | 217 | ax[0, 0].set_title("Model A", fontweight="bold") 218 | ax[0, -1].set_title("Model B", fontweight="bold") 219 | ax[0, 4].set_title("⟵ $\\lambda$ ⟶") 220 | 221 | ax[0, 0].set_ylabel("Identity") 222 | ax[1, 0].set_ylabel("Swap layer 1") 223 | ax[2, 0].set_ylabel("Swap layer 2") 224 | ax[3, 0].set_ylabel("Swap both") 225 | 226 | fig.tight_layout() 227 | return fig 228 | 229 | fig = plot_detailed_view() 230 | plt.savefig(f"figs/sgd_is_special_detailed_view.png", dpi=300) 231 | plt.savefig(f"figs/sgd_is_special_detailed_view.eps") 232 | plt.savefig(f"figs/sgd_is_special_detailed_view.pdf") 233 | plt.close(fig) 234 | -------------------------------------------------------------------------------- /src/should_be_deterministic.py: -------------------------------------------------------------------------------- 1 | """See https://github.com/google/jax/discussions/10674.""" 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import optax 6 | from flax import linen as nn 7 | from flax.training.train_state import TrainState 8 | from jax import jit, random, value_and_grad 9 | 10 | activation = nn.relu 11 | 12 | class MLPModel(nn.Module): 13 | 14 | @nn.compact 15 | def __call__(self, x): 16 | x = jnp.reshape(x, (-1, 28 * 28)) 17 | x = nn.Dense(512)(x) 18 | x = activation(x) 19 | x = nn.Dense(512)(x) 20 | x = activation(x) 21 | x = nn.Dense(512)(x) 22 | x = activation(x) 23 | x = nn.Dense(10)(x) 24 | x = nn.log_softmax(x) 25 | return x 26 | 27 | def make_stuff(model): 28 | 29 | @jit 30 | def batch_eval(params, images_u8, labels): 31 | images_f32 = jnp.array(images_u8, dtype=jnp.float32) / 256.0 32 | logits = model.apply({"params": params}, images_f32) 33 | y_onehot = jax.nn.one_hot(labels, 10) 34 | loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot)) 35 | num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == jnp.argmax(y_onehot, axis=-1)) 36 | return loss, num_correct 37 | 38 | @jit 39 | def step(train_state, images_f32, labels): 40 | (l, num_correct), g = value_and_grad(batch_eval, has_aux=True)(train_state.params, images_f32, 41 | labels) 42 | return train_state.apply_gradients(grads=g), l 43 | 44 | def dataset_loss_and_accuracy(params, dataset, batch_size: int): 45 | num_examples = dataset["images_u8"].shape[0] 46 | assert num_examples % batch_size == 0 47 | num_batches = num_examples // batch_size 48 | 
batch_ix = jnp.arange(num_examples).reshape((num_batches, batch_size)) 49 | # Can't use vmap or run in a single batch since that overloads GPU memory. 50 | losses, num_corrects = zip(*[ 51 | batch_eval( 52 | params, 53 | dataset["images_u8"][batch_ix[i, :], :, :, :], 54 | dataset["labels"][batch_ix[i, :]], 55 | ) for i in range(num_batches) 56 | ]) 57 | losses = jnp.array(losses) 58 | num_corrects = jnp.array(num_corrects) 59 | return jnp.sum(batch_size * losses) / num_examples, jnp.sum(num_corrects) / num_examples 60 | 61 | return { 62 | "batch_eval": batch_eval, 63 | "step": step, 64 | "dataset_loss_and_accuracy": dataset_loss_and_accuracy 65 | } 66 | 67 | def get_datasets(): 68 | num_train = 1000 69 | num_test = 1000 70 | return { 71 | "images_u8": 72 | random.choice(random.PRNGKey(0), jnp.arange(256, dtype=jnp.uint8), (num_train, 28, 28, 1)), 73 | "labels": 74 | random.choice(random.PRNGKey(2), jnp.arange(10, dtype=jnp.uint8), (num_train, )) 75 | }, { 76 | "images_u8": 77 | random.choice(random.PRNGKey(3), jnp.arange(256, dtype=jnp.uint8), (num_test, 28, 28, 1)), 78 | "labels": 79 | random.choice(random.PRNGKey(4), jnp.arange(10, dtype=jnp.uint8), (num_test, )) 80 | } 81 | 82 | def init_train_state(rng, learning_rate, model): 83 | tx = optax.adam(learning_rate) 84 | vars = model.init(rng, jnp.zeros((1, 28, 28, 1))) 85 | return TrainState.create(apply_fn=model.apply, params=vars["params"], tx=tx) 86 | 87 | def main(): 88 | learning_rate = 0.001 89 | num_epochs = 10 90 | batch_size = 100 91 | 92 | train_ds, test_ds = get_datasets() 93 | print("train_ds images_u8 hash", hash(np.array(train_ds["images_u8"]).tobytes())) 94 | print("train_ds labels hash", hash(np.array(train_ds["labels"]).tobytes())) 95 | print("test_ds images_u8 hash", hash(np.array(test_ds["images_u8"]).tobytes())) 96 | print("test_ds labels hash", hash(np.array(test_ds["labels"]).tobytes())) 97 | 98 | num_train_examples = train_ds["images_u8"].shape[0] 99 | assert num_train_examples % batch_size == 0 100 | 101 | model = MLPModel() 102 | stuff = make_stuff(model) 103 | train_state = init_train_state(random.PRNGKey(123), learning_rate, model) 104 | 105 | for epoch in range(num_epochs): 106 | batch_ix = random.permutation(random.PRNGKey(epoch), num_train_examples).reshape( 107 | (-1, batch_size)) 108 | for i in range(batch_ix.shape[0]): 109 | p = batch_ix[i, :] 110 | images_u8 = train_ds["images_u8"][p, :, :, :] 111 | labels = train_ds["labels"][p] 112 | train_state, batch_loss = stuff["step"](train_state, images_u8, labels) 113 | 114 | # Evaluate train/test loss/accuracy 115 | train_loss, train_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, train_ds, 116 | 1000) 117 | test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](train_state.params, test_ds, 1000) 118 | 119 | print({ 120 | "epoch": epoch, 121 | "batch_loss": float(batch_loss), 122 | "train_loss": float(train_loss), 123 | "test_loss": float(test_loss), 124 | "train_accuracy": float(train_accuracy), 125 | "test_accuracy": float(test_accuracy), 126 | }) 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import re 3 | import time 4 | from contextlib import contextmanager 5 | 6 | import jax.numpy as jnp 7 | from flax import traverse_util 8 | from flax.core import freeze, unfreeze 9 | from jax import random, tree_map 10 | from jax.tree_util 
import tree_reduce 11 | 12 | rngmix = lambda rng, x: random.fold_in(rng, hash(x)) 13 | 14 | @contextmanager 15 | def timeblock(name): 16 | start = time.time() 17 | try: 18 | yield 19 | finally: 20 | end = time.time() 21 | print(f"{name} took {end - start:.5f} seconds") 22 | 23 | class RngPooper: 24 | """A stateful wrapper around stateless random.PRNGKey's.""" 25 | 26 | def __init__(self, init_rng): 27 | self.rng = init_rng 28 | 29 | def poop(self): 30 | self.rng, rng_key = random.split(self.rng) 31 | return rng_key 32 | 33 | def l1prox(x, alpha): 34 | return jnp.sign(x) * jnp.maximum(0, jnp.abs(x) - alpha) 35 | 36 | def ec2_get_instance_type(): 37 | # See also https://stackoverflow.com/questions/51486405/aws-ec2-command-line-display-instance-type/51486782 38 | return open("/sys/devices/virtual/dmi/id/product_name").read().strip() 39 | 40 | # Utilities for dealing with flax model parameters 41 | def partition(pred, iterable): 42 | trues = [] 43 | falses = [] 44 | for item in iterable: 45 | if pred(item): 46 | trues.append(item) 47 | else: 48 | falses.append(item) 49 | return trues, falses 50 | 51 | def partition_dict(pred, d): 52 | trues = {} 53 | falses = {} 54 | for k, v in d.items(): 55 | if pred(k): 56 | trues[k] = v 57 | else: 58 | falses[k] = v 59 | return trues, falses 60 | 61 | def flatten_params(params): 62 | return {"/".join(k): v for k, v in traverse_util.flatten_dict(unfreeze(params)).items()} 63 | 64 | def unflatten_params(flat_params): 65 | return freeze( 66 | traverse_util.unflatten_dict({tuple(k.split("/")): v 67 | for k, v in flat_params.items()})) 68 | 69 | def merge_params(a, b): 70 | return unflatten_params({**a, **b}) 71 | 72 | def kmatch(pattern, key): 73 | regex = "^" 74 | i = 0 75 | while i < len(pattern): 76 | if pattern[i] == "*": 77 | if i + 1 < len(pattern) and pattern[i + 1] == "*": 78 | regex += "(.*)" 79 | i += 2 80 | else: 81 | regex += "([^\/]*)" 82 | i += 1 83 | else: 84 | regex += pattern[i] 85 | i += 1 86 | regex += "$" 87 | return re.fullmatch(regex, key) 88 | 89 | assert kmatch("*", "a") is not None 90 | assert kmatch("*", "a").group(0) == "a" 91 | assert kmatch("*", "a").group(1) == "a" 92 | assert kmatch("abc", "def") is None 93 | assert kmatch("abc/*/ghi", "abc/def/ghi").group(1) == "def" 94 | assert kmatch("abc/**/jkl", "abc/def/ghi/jkl").group(1) == "def/ghi" 95 | assert kmatch("abc/*/jkl", "abc/def/ghi/jkl") is None 96 | assert kmatch("**/*", "abc/def/ghi/jkl").group(1) == "abc/def/ghi" 97 | assert kmatch("**/*", "abc/def/ghi/jkl").group(2) == "jkl" 98 | 99 | def lerp(lam, t1, t2): 100 | return tree_map(lambda a, b: (1 - lam) * a + lam * b, t1, t2) 101 | 102 | def tree_norm(t): 103 | return jnp.sqrt(tree_reduce(operator.add, tree_map(lambda x: jnp.sum(x**2), t))) 104 | 105 | def tree_l2(t1, t2): 106 | return tree_norm(tree_map(lambda x, y: x - y, t1, t2)) 107 | 108 | def slerp(lam, t1, t2): 109 | # See https://en.wikipedia.org/wiki/Slerp 110 | om = jnp.arccos( 111 | tree_reduce(operator.add, tree_map(lambda x, y: jnp.sum(x * y), t1, t2)) / 112 | (tree_norm(t1) * tree_norm(t2))) 113 | sinom = jnp.sin(om) 114 | return tree_map( 115 | lambda x, y: jnp.sin((1 - lam) * om) / sinom * x + jnp.sin(lam * om) / sinom * y, 116 | t1, 117 | t2, 118 | ) 119 | --------------------------------------------------------------------------------
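A minimal usage sketch for the pytree helpers above (src/utils.py). This is not part of the repository: the TinyMLP model and all parameter names are made up for illustration, and it assumes it is run from inside src/ so that the utils import resolves.

import jax.numpy as jnp
from flax import linen as nn
from jax import random

from utils import flatten_params, lerp, unflatten_params

class TinyMLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(4)(x)
    x = nn.relu(x)
    return nn.Dense(2)(x)

model = TinyMLP()
paramsA = model.init(random.PRNGKey(0), jnp.zeros((1, 3)))["params"]
paramsB = model.init(random.PRNGKey(1), jnp.zeros((1, 3)))["params"]

# flatten_params turns the nested parameter tree into a flat dict keyed by
# "/"-joined paths, and unflatten_params inverts it exactly.
flatA = flatten_params(paramsA)
assert set(flatA.keys()) == {"Dense_0/kernel", "Dense_0/bias", "Dense_1/kernel", "Dense_1/bias"}
assert jnp.allclose(unflatten_params(flatA)["Dense_0"]["kernel"], paramsA["Dense_0"]["kernel"])

# lerp(lam, t1, t2) interpolates leaf-wise: lam=0 gives t1 and lam=1 gives t2.
mid = lerp(0.5, paramsA, paramsB)
assert jnp.allclose(mid["Dense_0"]["kernel"],
                    0.5 * (paramsA["Dense_0"]["kernel"] + paramsB["Dense_0"]["kernel"]))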