├── README.md
├── hyper.ipynb
├── images
│   ├── README.md
│   └── small.png
├── main.py
├── poster.pdf
└── poster_hyperalignment.pdf
/README.md:
--------------------------------------------------------------------------------
# HYPER (brain decoding framework) 🧠 + 🤖 + 📖 = ✨

This repo accompanies the paper "Hyperrealistic neural decoding for reconstructing faces from fMRI activations via the GAN latent space" ([Dado et al., 2022](https://www.nature.com/articles/s41598-021-03938-w)). The study introduces a novel experimental paradigm that uses synthesized yet highly naturalistic stimuli with a priori known feature representations, together with an implementation thereof: HYperrealistic reconstruction of PERception (HYPER) of faces from brain recordings. The goal is to reveal what information is present in the recorded brain responses by reconstructing the original faces presented to the participants.

[Click here](https://medium.com/neural-coding-lab/neural-decoding-w-synthesized-reality-5eeb476f399) for the blog post / tutorial.

## The experiment

Two participants were presented with face images while we recorded their brain responses in the MRI scanner. After collecting the (faces, responses) dataset, we trained a decoding model to reconstruct what the participants were seeing from their (held-out test set) fMRI recordings alone.

![](https://github.com/neuralcodinglab/HYPER/blob/master/images/small.png)

The faces in the presented photographs do not really exist but are artificially generated by a progressive GAN ([PGGAN](https://github.com/tkarras/progressive_growing_of_gans)) from randomly sampled latent vectors. The results suggest that the PGGAN latent space and the neural face manifold have an approximately linear relationship that can be exploited during brain decoding. That is, the latent vectors used for face generation effectively capture the same defining stimulus features as the fMRI measurements. As such, we can predict the latents that underlie the perceived face images and feed them to the PGGAN for (re)generation, leading to the most accurate reconstructions of perception to date.
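
In code, the core decoding step boils down to a linear regression from voxel space to the 512-dimensional PGGAN latent space. Here is a minimal NumPy sketch of the idea (illustrative only; the actual training code lives in `hyper.ipynb`). The array names and sizes below are stand-ins, not the real dataset:

```python
import numpy as np

# Stand-in data with the dimensions used in this repo (hypothetical values;
# the actual arrays come from the fMRI dataset on Google Drive).
n_trials, n_voxels, n_latents = 30, 4096, 512
X_train = np.random.randn(n_trials, n_voxels)   # fMRI responses
Z_train = np.random.randn(n_trials, n_latents)  # latents of the presented faces
X_test = np.random.randn(36, n_voxels)          # held-out responses

# Fit a linear map from voxel space to GAN latent space (least squares).
W, *_ = np.linalg.lstsq(X_train, Z_train, rcond=None)

# Predict the latents underlying the held-out percepts; feeding Z_pred to the
# pretrained PGGAN generator then yields the face reconstructions.
Z_pred = X_test @ W
```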

🤖🤖🤖

## Required components

This repo contains a Jupyter notebook that presents the approach. All data required to reproduce the results, the preprocessing steps, and the test set images (stimuli and reconstructions) are available on [Google Drive](https://drive.google.com/drive/u/1/folders/1NEblHtlRFvUyD5CA2sqSVfcGlfJBqw_T).

## Conclusion 🚀

We used GANs for neural decoding of perception. Note that the results of this study are valid reconstructions of visual perception regardless of the synthetic nature of the stimuli themselves. Considering the speed of progress in the field of generative modeling, this framework will likely yield even more impressive reconstructions of perception. Our approach constitutes a leap forward in our ability to reconstruct percepts from patterns of human brain activity.

--------------------------------------------------------------------------------
/hyper.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "hyper",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "5UtrCvbS4RG7"
      },
      "source": [
        "from __future__ import annotations\n",
        "import matplotlib.pyplot as plt\n",
        "from google.colab import drive\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "import pickle\n",
        "import os"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gP7lLgUQ5nii",
        "outputId": "54f418b1-f5eb-4cf0-cbd6-ba9dec43d08d"
      },
      "source": [
        "drive.mount('/content/drive')"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "P0tYmkTi4Tj-",
        "outputId": "e6a17d30-dfb8-4929-9c4d-5abc839a9304"
      },
      "source": [
        "!pip install mxnet-cu101\n",
        "\n",
        "from typing import Tuple, Union\n",
        "from mxnet import nd, symbol, gluon, autograd\n",
        "from mxnet.gluon.nn import HybridBlock, HybridSequential, Conv2D, Dense, LeakyReLU\n",
        "from mxnet.initializer import Zero\n",
        "from mxnet.io import NDArrayIter"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: mxnet-cu101 in /usr/local/lib/python3.7/dist-packages (1.7.0.post1)\n",
            "Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101) (1.19.5)\n",
            "Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101) (2.23.0)\n",
            "Requirement already satisfied: graphviz<0.9.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from mxnet-cu101) (0.8.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101) (3.0.4)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet-cu101) (2020.12.5)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MP1X8zSQ4Tmh"
      },
      "source": [
        "def load_dataset(t, x, batch_size):\n",
        "    # Batched, shuffled iterator over (fMRI responses x, target latents t).\n",
        "    return NDArrayIter({\"x\": nd.stack(*x, axis=0)}, {\"t\": nd.stack(*t, axis=0)}, batch_size, shuffle=True)"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SBAbcsjq4To6"
      },
      "source": [
        "class Linear(HybridSequential):\n",
        "    def __init__(self, n_in, n_out):\n",
        "        super(Linear, self).__init__()\n",
        "        with self.name_scope():\n",
        "            self.add(Dense(n_out, in_units=n_in))\n",
        "\n",
        "\n",
        "class Pixelnorm(HybridBlock):\n",
        "    def __init__(self, epsilon: float = 1e-8) -> None:\n",
        "        super(Pixelnorm, self).__init__()\n",
        "        self.epsilon = epsilon\n",
        "\n",
        "    def hybrid_forward(self, F, x) -> nd:\n",
        "        # Normalize each pixel's feature vector to unit length (PGGAN pixelnorm).\n",
        "        return x * F.rsqrt(F.mean(F.square(x), 1, True) + self.epsilon)\n",
        "\n",
        "\n",
        "class Bias(HybridBlock):\n",
        "    def __init__(self, shape: Tuple) -> None:\n",
        "        super(Bias, self).__init__()\n",
        "        self.shape = shape\n",
        "        with self.name_scope():\n",
        "            self.b = self.params.get(\"b\", init=Zero(), shape=shape)\n",
        "\n",
        "    def hybrid_forward(self, F, x, b) -> nd:\n",
        "        # b is supplied by gluon from self.params; broadcast over N, H, W.\n",
        "        return F.broadcast_add(x, b[None, :, None, None])\n",
        "\n",
        "\n",
        "class Block(HybridSequential):\n",
        "    def __init__(self, channels: int, in_channels: int) -> None:\n",
        "        super(Block, self).__init__()\n",
        "        self.channels = channels\n",
        "        self.in_channels = in_channels\n",
        "        with self.name_scope():\n",
        "            self.add(Conv2D(channels, 3, padding=1, in_channels=in_channels))\n",
        "            self.add(LeakyReLU(0.2))\n",
        "            self.add(Pixelnorm())\n",
        "            self.add(Conv2D(channels, 3, padding=1, in_channels=channels))\n",
        "            self.add(LeakyReLU(0.2))\n",
        "            self.add(Pixelnorm())\n",
        "\n",
        "    def hybrid_forward(self, F, x) -> nd:\n",
        "        # Nearest-neighbour 2x upsampling, then the two conv stages above.\n",
        "        x = F.repeat(x, 2, 2)\n",
        "        x = F.repeat(x, 2, 3)\n",
        "        for i in range(len(self)):\n",
        "            x = self[i](x)\n",
        "        return x"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nbt6GO9Y4Trm"
      },
      "source": [
        "class Generator(HybridSequential):\n",
        "    def __init__(self) -> None:\n",
        "        super(Generator, self).__init__()\n",
        "        with self.name_scope():\n",
        "            self.add(Pixelnorm())\n",
        "            self.add(Dense(8192, use_bias=False, in_units=512))\n",
        "            self.add(Bias((512,)))\n",
        "            self.add(LeakyReLU(0.2))\n",
        "            self.add(Pixelnorm())\n",
        "            self.add(Conv2D(512, 3, padding=1, in_channels=512))\n",
        "            self.add(LeakyReLU(0.2))\n",
        "            self.add(Pixelnorm())\n",
        "\n",
        "            self.add(Block(512, 512))  # 8\n",
        "            self.add(Block(512, 512))\n",
        "            self.add(Block(512, 512))\n",
        "            self.add(Block(256, 512))\n",
        "            self.add(Block(128, 256))\n",
        "            self.add(Block(64, 128))\n",
        "            self.add(Block(32, 64))\n",
        "            self.add(Block(16, 32))  # 15\n",
        "            self.add(Conv2D(3, 1, in_channels=16))\n",
        "\n",
        "    def hybrid_forward(self, F: Union[nd, symbol], x: nd, layer: int) -> nd:\n",
        "        x = F.Reshape(self[1](self[0](x)), (-1, 512, 4, 4))\n",
        "        for i in range(2, len(self)):\n",
        "            x = self[i](x)\n",
        "            if i == layer + 7:  # optionally stop at an intermediate block\n",
        "                return x\n",
        "        return x"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "URXm1Nk54fz3"
      },
      "source": [
        "max_epoch = 1500\n",
        "batch_size = 30\n",
        "n_vox = 4096  # fMRI voxels per trial\n",
        "n_lat = 512   # PGGAN latent dimensionality"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wD8Fux7h4f2U"
      },
      "source": [
        "# Note that we are using gradient descent to fit the weights of the dense layer, whereas ordinary\n",
        "# least squares would yield a similar solution. However, the current setup lets you experiment with\n",
        "# more sophisticated models (e.g., predict intermediate layer activations of the PGGAN and include\n",
        "# them in your loss function).\n",
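        "# For reference, a closed-form least-squares fit of the same linear map would look roughly like\n",
        "# this (illustrative sketch only, not used below):\n",
        "#   W = np.linalg.lstsq(X_tr, T_tr, rcond=None)[0]\n",
        "#   lat_pred = X_te @ W\n",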
        "generator = Generator()\n",
        "generator.load_parameters(\"/content/drive/MyDrive/HYPER/data/generator.params\")\n",
        "mean_squared_error = gluon.loss.L2Loss()\n",
        "for subject in [1, 2]:\n",
        "\n",
        "    # Data\n",
        "    with open(\"/content/drive/MyDrive/HYPER/data/data_%i.dat\" % subject, 'rb') as f:\n",
        "        X_tr, T_tr, X_te, T_te = pickle.load(f)\n",
        "    train = load_dataset(nd.array(T_tr), nd.array(X_tr), batch_size)\n",
        "    test = load_dataset(nd.array(T_te), nd.array(X_te), batch_size=36)\n",
        "\n",
        "    # Training\n",
        "    vox_to_lat = Linear(n_vox, n_lat)\n",
        "    vox_to_lat.initialize()\n",
        "    trainer = gluon.Trainer(vox_to_lat.collect_params(), \"Adam\", {\"learning_rate\": 0.00001, \"wd\": 0.01})\n",
        "    epoch = 0\n",
        "    results_tr = []\n",
        "    results_te = []\n",
        "    while epoch < max_epoch:\n",
        "        train.reset()\n",
        "        test.reset()\n",
        "        loss_tr = 0\n",
        "        loss_te = 0\n",
        "        count = 0\n",
        "        for batch_tr in train:\n",
        "            with autograd.record():\n",
        "                lat_Y = vox_to_lat(batch_tr.data[0])\n",
        "                loss = mean_squared_error(lat_Y, batch_tr.label[0])\n",
        "            loss.backward()\n",
        "            trainer.step(batch_size)\n",
        "            loss_tr += loss.mean().asnumpy()\n",
        "            count += 1\n",
        "        for batch_te in test:\n",
        "            lat_Y = vox_to_lat(batch_te.data[0])\n",
        "            loss = mean_squared_error(lat_Y, batch_te.label[0])\n",
        "            loss_te += loss.mean().asnumpy()\n",
        "        loss_tr_normalized = loss_tr / count\n",
        "        results_tr.append(loss_tr_normalized)\n",
        "        results_te.append(loss_te)\n",
        "        epoch += 1\n",
        "        print(\"Epoch %i: %.4f / %.4f\" % (epoch, loss_tr_normalized, loss_te))\n",
        "\n",
        "    plt.figure()\n",
        "    plt.plot(np.arange(epoch), results_tr)\n",
        "    plt.plot(np.arange(epoch), results_te)\n",
        "    plt.savefig(\"loss_s%i.png\" % subject)\n",
        "\n",
        "    # Testing and reconstructing\n",
        "    lat_Y = vox_to_lat(nd.array(X_te))\n",
        "    out_dir = '/content/faces_%i' % subject\n",
        "    if not os.path.exists(out_dir):\n",
        "        os.mkdir(out_dir)\n",
        "    for i, latent in enumerate(lat_Y):\n",
        "        face = generator(latent[None], 9).asnumpy()\n",
        "        face = np.clip(np.rint(127.5 * face + 127.5), 0.0, 255.0)\n",
        "        face = face.astype(\"uint8\")\n",
        "        face = face.transpose(0, 2, 3, 1)\n",
        "        Image.fromarray(face[0], 'RGB').save(out_dir + \"/%d.png\" % i)"
      ],
      "execution_count": 8,
      "outputs": []
    }
  ]
}

--------------------------------------------------------------------------------
/images/README.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/images/small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralcodinglab/HYPER/6d0562d90ae9f5b9e332c6a666e5f128faa70319/images/small.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
from chainer import serializers
from chainer import Chain
import chainer.links as L
import tensorflow as tf
from PIL import Image
import numpy as np
import cupy
import pickle
import os

from keras_vggface.vggface import VGGFace
from keras.preprocessing import image
from keras_vggface import utils

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
config.log_device_placement = True
sess = tf.compat.v1.Session(config=config)


class ModelLinear(Chain):
    def __init__(self, n_in, n_out):
        super(ModelLinear, self).__init__()
        with self.init_scope():
            self.fc1 = L.Linear(n_in, n_out)

    def __call__(self, x):
        return self.fc1(x)


def load_network(dev):
    # Pretrained PGGAN pickle from the official progressive_growing_of_gans release
    # (unpickling requires that repo's code to be importable).
    p = "karras2018iclr-celebahq-1024x1024.pkl"
    sess = tf.compat.v1.InteractiveSession()
    with tf.device('/gpu:%d' % dev):
        _, _, Gs = pickle.load(open(p, "rb"))

    weights = sess.run(tf.compat.v1.trainable_variables())
    with open('pggan_weights', 'wb') as f:
        pickle.dump(weights, f)
    return Gs


def face_from_latent(model, latents, my_path, save_image=True):
    dummy_label = np.zeros([1] + model.input_shapes[1][1:])
    if save_image and not os.path.exists(my_path):
        os.mkdir(my_path)
    for i in range(latents.shape[0]):
        latent = np.expand_dims(latents[i], 0)
        face = model.run(latent, dummy_label)
        # Map generator output from [-1, 1] to 8-bit RGB, NCHW -> NHWC.
        face = np.clip(np.rint((face + 1.0) / 2.0 * 255.0), 0.0, 255.0).astype(np.uint8)
        face = face.transpose((0, 2, 3, 1))

        if save_image:
            save_path = os.path.join(my_path, '%d.png' % i)
            Image.fromarray(face[0], 'RGB').save(save_path)
        else:
            Image.fromarray(face[0], 'RGB').show()


def get_features(img):
    # 2048-dim VGGFace (resnet50) feature vector; model output has shape (1, 1, 1, 2048).
    img = img.resize((224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = utils.preprocess_input(x, version=2)
    feats = vgg_features.predict(x)
    return feats[0][0][0]


if __name__ == '__main__':

    np.random.seed(1412)
    dev = 1
    with tf.device('/gpu:%d' % dev):
        model = load_network(dev=dev)

    # data
    with open('data_sub1_4096.dat', 'rb') as fp:
        _, _, X_test, T_test = pickle.load(fp)

    # predict latents
    with cupy.cuda.Device(dev):
        model_linear = ModelLinear(n_in=4096, n_out=512).to_gpu(dev)
        serializers.load_npz('l0_s1_5000_final.model', model_linear)
        X_test = cupy.array(X_test, dtype=cupy.float32)
        T_test = cupy.array(T_test, dtype=cupy.float32)
        Y_test = model_linear(X_test).array

    # generate stimuli and reconstructions
    face_from_latent(model, cupy.asnumpy(T_test), 'stimuli', save_image=True)
    face_from_latent(model, cupy.asnumpy(Y_test), 'reconstructions', save_image=True)

    # stimuli vs. reconstructions
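    # Per-trial evaluation metrics:
    #   lsim -- similarity of predicted vs. ground-truth PGGAN latents,
    #   fsim -- similarity of VGGFace (resnet50) feature vectors,
    #   ssim -- scikit-image structural similarity of the two images.
    # lsim and fsim are mapped to (0, 1] via 1 / (1 + MSE).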
    trials = len(T_test)
    metrics = {"lsim": np.zeros((trials,)),
               "fsim": np.zeros((trials,)),
               "ssim": np.zeros((trials,)),
               "gender": np.zeros((trials,)),
               "age": np.zeros((trials,)),
               "eyeglasses": np.zeros((trials,)),
               "pose": np.zeros((trials,)),
               "smile": np.zeros((trials,))}

    test_feats = np.zeros((trials, 2048))
    pred_feats = np.zeros((trials, 2048))
    test_scores = np.zeros((5, trials))
    pred_scores = np.zeros((5, trials))
    vgg_features = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))

    # Semantic attribute boundaries (hyperplane normals) in PGGAN latent space.
    gender = np.load('bounds/pggan_celebahq_gender_boundary.npy')
    age = np.load('bounds/pggan_celebahq_age_boundary.npy')
    eyeglasses = np.load('bounds/pggan_celebahq_eyeglasses_boundary.npy')
    pose = np.load('bounds/pggan_celebahq_pose_boundary.npy')
    smile = np.load('bounds/pggan_celebahq_smile_boundary.npy')
    boundaries = [gender, age, eyeglasses, pose, smile]

    Y_test = cupy.asnumpy(Y_test)
    T_test = cupy.asnumpy(T_test)
    for trial in range(trials):
        stim_ssim = np.array(Image.open("stimuli/%i.png" % trial))
        recon_ssim = np.array(Image.open("reconstructions/%i.png" % trial))
        stim_fsim = image.load_img("stimuli/%i.png" % trial, target_size=(224, 224))
        recon_fsim = image.load_img("reconstructions/%i.png" % trial, target_size=(224, 224))

        test_feats[trial] = get_features(stim_fsim)
        pred_feats[trial] = get_features(recon_fsim)
        metrics['lsim'][trial] = 1. / (1 + mean_squared_error(Y_test[trial], T_test[trial]))
        metrics['fsim'][trial] = 1. / (1 + mean_squared_error(test_feats[trial], pred_feats[trial]))
        metrics['ssim'][trial] = ssim(stim_ssim, recon_ssim, multichannel=True)

        # Project both latents onto each attribute hyperplane normal.
        for i, boundary in enumerate(boundaries):
            test_scores[i, trial] = T_test[trial].reshape(1, -1).dot(boundary.T)[0][0]
            pred_scores[i, trial] = Y_test[trial].reshape(1, -1).dot(boundary.T)[0][0]

    # print metrics
    print("Latent similarity: %.4f" % metrics['lsim'].mean())
    print("Feature similarity: %.4f" % metrics['fsim'].mean())
    print("Structural similarity: %.4f" % metrics['ssim'].mean())

    names = ["Gender", "Age", "Eyeglasses", "Pose", "Smile"]
    for i, name in enumerate(names):
        corr, pval = pearsonr(test_scores[i], pred_scores[i])
        print("%s Corr. coef.: %.4f" % (name, corr))

--------------------------------------------------------------------------------
/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralcodinglab/HYPER/6d0562d90ae9f5b9e332c6a666e5f128faa70319/poster.pdf
--------------------------------------------------------------------------------
/poster_hyperalignment.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralcodinglab/HYPER/6d0562d90ae9f5b9e332c6a666e5f128faa70319/poster_hyperalignment.pdf
--------------------------------------------------------------------------------