├── NN.pdf
├── README.md
├── addons
│   ├── README.md
│   ├── TI_FrozenWin.ipynb
│   ├── net_FrozenInputWeights.dat
│   ├── toymodel.png
│   └── toymodel.py
├── main.ipynb
├── net_active.dat
├── net_passive.dat
└── simple.ipynb

/NN.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThomasMiconi/TransitiveInference/fbdebf06f6d335bcd4473a4fa24926208a98910f/NN.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Meta-training plastic networks for transitive inference

This is the code for the paper [Neural mechanisms of relational learning and fast knowledge reassembly in plastic neural networks](https://thomasmiconi.github.io/NN.pdf), by Thomas Miconi and Kenneth Kay, Nature Neuroscience 2024 (previous preprint [here](https://www.biorxiv.org/content/10.1101/2023.07.27.550739)).

We also include parameter files for two pre-trained networks, representing each of the two strategies (active, list-linking and passive, not list-linking) described in the paper.

The code consists of two notebooks. These notebooks are immediately usable on Google Colab, as-is.

If you just want to understand how the system works, it is **highly** recommended to look at `simple.ipynb` first.

The code actually used for the paper is in `main.ipynb`. It includes a lot of additional code for running the various experiments from the paper. By contrast, the code in `simple.ipynb` (which only contains one large code cell) is a simplified version that only includes the basic code for meta-training a plastic network for transitive inference. The network structure and experimental settings are essentially identical between the two, with only the additional code for the various side experiments removed.

Note that the networks produced by `simple.ipynb` can be used in the EVAL (figure-producing) mode of `main.ipynb`.

Consult the respective notebooks for more details.

### To generate the figures from the paper

1- Copy `net_active.dat` to `net.dat` and upload it to where the notebook can access it.

2- In line 207 of `main.ipynb`, set EVAL to `True`

3- Run `main.ipynb` (making sure that `net.dat` is in the path of your notebook)

This produces figures for the active strategy (capable of list-linking). Other figures may need further modifications; consult the relevant cells in `main.ipynb`.

To produce similar figures for the passive strategy (not capable of list-linking), use `net_passive.dat` (and rename it to `net.dat`) instead.

### To train your own networks from scratch

1- In line 207 of `main.ipynb`, set EVAL to `False`

2- Run `main.ipynb`

This will run for 30000 iterations (which might take a few hours) and produce a fully meta-trained plastic network, stored in `net.dat`. You can then use `main.ipynb` (with EVAL set to `True` in line 207) to produce figures for this trained network.
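For example, a minimal sketch of the figure-reproduction steps above (assuming a local run with the repo's `.dat` files in the notebook's working directory; on Colab, upload the renamed file through the file pane instead):

```python
# Sketch only: stage a pre-trained network for the EVAL mode described above.
import shutil

shutil.copy('net_active.dat', 'net.dat')  # active strategy; use 'net_passive.dat' for the passive one
# Then open main.ipynb, set EVAL to True (line 207), and run all cells.
```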
--------------------------------------------------------------------------------
/addons/README.md:
--------------------------------------------------------------------------------
### Additional code

These are additional programs mentioned in the Appendix of the paper:

1- `toymodel.py` contains the toy model version of the neural algorithm.
2- `TI_FrozenWin.ipynb` contains a version of the main program with frozen, non-trainable input weights. It behaves much like `main.ipynb` in the main folder (which it is based on).

--------------------------------------------------------------------------------
/addons/TI_FrozenWin.ipynb:
--------------------------------------------------------------------------------
{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## Frozen Input Weights Version"
      ],
      "metadata": {
        "id": "ph0xjIU6csTK"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mORufQW4nEYK"
      },
      "source": [
        "## HOW TO USE THIS NOTEBOOK\n",
        "\n",
        "\n",
        "1. Run this notebook once. It will run for 30000 iterations, meta-training a network and storing the trained network's parameters in the `net.dat` file.\n",
        "\n",
        "2. In line 234, set EVAL=True\n",
        "\n",
        "3. Re-run the notebook. It will evaluate the trained network stored in `net.dat` (without changing its parameters, of course) by running evaluation episodes over a large batch (with the default EVAL settings below, one episode over a batch of 2000)\n",
        "\n",
        "4. See the figures\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "(Note: This code added some decorrelation between cues in comparison to a previous version)"
      ],
      "metadata": {
        "id": "UHx325-3Zxrn"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Nboz_4ynCaZ"
      },
      "outputs": [],
      "source": [
        "\n",
        "# What GPU are we using?\n",
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x277nTktumok"
      },
      "outputs": [],
      "source": [
        "# Based on the code for the Stimulus-response task as described in Miconi et al. 
ICLR 2019.\n", 62 | "\n", 63 | "import argparse\n", 64 | "import pdb\n", 65 | "import torch\n", 66 | "import torch.nn as nn\n", 67 | "import numpy as np\n", 68 | "from numpy import random\n", 69 | "import torch.nn.functional as F\n", 70 | "from torch import optim\n", 71 | "from torch.optim import lr_scheduler\n", 72 | "import random\n", 73 | "import sys\n", 74 | "import pickle\n", 75 | "import time\n", 76 | "import os\n", 77 | "import platform\n", 78 | "\n", 79 | "import numpy as np\n", 80 | "import glob\n", 81 | "\n", 82 | "\n", 83 | "\n", 84 | "myseed = -1\n", 85 | "\n", 86 | "\n", 87 | "# If running this code on a cluster, uncomment the following, and pass a RNG seed as the --seed parameter on the command line\n", 88 | "# parser = argparse.ArgumentParser()\n", 89 | "# parser.add_argument('--seed', type=int, default=-1)\n", 90 | "# args = parser.parse_args()\n", 91 | "# myseed = args.seed\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | "# This needs to be before parameter initialization\n", 96 | "NBMASSEDTRIALS = 0\n", 97 | "MASSEDPAIR = [3,4]\n", 98 | "\n", 99 | "\n", 100 | "\n", 101 | "\n", 102 | "np.set_printoptions(precision=5)\n", 103 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 104 | "# device = 'cpu'\n", 105 | "\n", 106 | "\n", 107 | "params={}\n", 108 | "params['rngseed']=myseed # RNG seed, or -1 for no seed\n", 109 | "params['rew']=1.0 # reward amount\n", 110 | "params['wp']=.0 # penalty for taking action 1 (not used here)\n", 111 | "params['bent']=.1 # entropy incentive (actually sum-of-squares)\n", 112 | "params['blossv']=.1 # value prediction loss coefficient\n", 113 | "params['gr']=.9 # Gamma for temporal reward discounting\n", 114 | "\n", 115 | "params['hs']=200 # Size of the RNN's hidden layer\n", 116 | "params['bs']=32 # Batch size\n", 117 | "params['gc']=2.0 # Gradient clipping\n", 118 | "params['eps']=1e-6 # A parameter for Adam\n", 119 | "params['nbiter']= 30000 # 60000\n", 120 | "params['save_every']=200\n", 121 | "params['pe']= 101 #\"print every\"\n", 122 | "\n", 123 | "\n", 124 | "params['nbcuesrange'] = range(4,9)\n", 125 | "# params['nbcues']= 5 # 7 # number of inputs - number of different stimuli used for each episode\n", 126 | "\n", 127 | "params['cs']= 15 # 10 # Cue size - number of binary elements in each cue vector (not including the 'go' bit and additional inputs, see below)\n", 128 | "\n", 129 | "params['triallen'] = 4 # 4 # 5 # 5 + 1 # 4 + 1 # Each trial has: stimulus presentation, 'go' cue, then 3 empty trials.\n", 130 | "NUMRESPONSESTEP = 1\n", 131 | "params['nbtraintrials'] = 20 # 22 # 20 # 5 # The first nbtraintrials are the \"train\" trials. This is included in nbtrials.\n", 132 | "params['nbtesttrials'] = 10 # 2 # 12 # 10 # The last nbtesttrials are the \"test\" trials. 
This is included in nbtrials.\n", 133 | "params['nbtrials'] = params['nbtraintrials'] + NBMASSEDTRIALS + params['nbtesttrials'] # 20 # Number of trials per episode\n", 134 | "params['eplen'] = params['nbtrials'] * params['triallen'] # eplen = episode length\n", 135 | "params['testlmult'] = 3.0 # multiplier for the loss during the test trials\n", 136 | "params['l2'] = 0 # 1e-5 # L2 penalty\n", 137 | "params['lr'] = 1e-4\n", 138 | "params['lpw'] = 1e-4 # 3 # plastic weight loss\n", 139 | "params['lda'] = 0 # 1e-4 # 3e-5 # 1e-4 # 1e-5 # DA output penalty\n", 140 | "params['lhl1'] = 0 # 3e-5\n", 141 | "params['nbepsbwresets'] = 3 # 1\n", 142 | "\n", 143 | "PROBAOLDDATA = .25\n", 144 | "POSALPHA = False\n", 145 | "POSALPHAINITONLY = False\n", 146 | "VECTALPHA = False\n", 147 | "\n", 148 | "# RNN with plastic connections and neuromodulation (\"DA\").\n", 149 | "# Plasticity only in the recurrent connections, for now.\n", 150 | "\n", 151 | "class RetroModulRNN(nn.Module):\n", 152 | " def __init__(self, params):\n", 153 | " super(RetroModulRNN, self).__init__()\n", 154 | " # NOTE: 'outputsize' excludes the value and neuromodulator outputs!\n", 155 | " for paramname in ['outputsize', 'inputsize', 'hs', 'bs']:\n", 156 | " if paramname not in params.keys():\n", 157 | " raise KeyError(\"Must provide missing key in argument 'params': \"+paramname)\n", 158 | " NBDA = 2 # 2 DA neurons, we take the difference - see below\n", 159 | " self.params = params\n", 160 | " self.activ = torch.tanh\n", 161 | " self.i2h = torch.nn.Linear(self.params['inputsize'], params['hs']).to(device)\n", 162 | "\n", 163 | "\n", 164 | " # FREEZING INPUT WEIGHTS\n", 165 | "\n", 166 | " # Doesnt work:\n", 167 | " # self.i2h.weight.data.fill_(0.0)\n", 168 | " # self.i2h.weight[:,:GG['hs']//4].data.fill_(1.0) # only a portion of the RNN get inputs\n", 169 | "\n", 170 | " # # This works well. net1.dat. It does do list-linking. Seems to be using a broadly similar strategy too - at least the DA traces are similar.\n", 171 | " # self.i2h.weight.data.fill_(0.0)\n", 172 | " # for nn in range(GG['inputsize']):\n", 173 | " # self.i2h.weight.data[nn,nn] = 1.0\n", 174 | " # # pdb.set_trace()\n", 175 | " # self.i2h.weight.data[nn+GG['inputsize'],nn] = 1.0\n", 176 | " # self.i2h.bias.data.fill_(0.0)\n", 177 | " # self.i2h.requires_grad = False\n", 178 | "\n", 179 | "\n", 180 | " # # This may work but slowly?...\n", 181 | " # self.i2h.weight[GG['hs']//3:,:].data.fill_(0.0) # only a portion of the RNN get inputs\n", 182 | " # self.i2h.bias.data.fill_(0.0)\n", 183 | " # self.i2h.requires_grad = False\n", 184 | "\n", 185 | " # This also works well and does list-linking. 
The resulting network was saved as net2.dat.\n",
"        self.i2h.weight[params['hs']//3:,:].data.fill_(0.0) # only a portion of the RNN gets inputs\n",
"        self.i2h.weight[params['hs']//3:,:].data *= 10.0 # (note: a no-op, since these rows were just zeroed above)\n",
"        self.i2h.bias.data.fill_(0.0)\n",
"        self.i2h.requires_grad_(False) # freeze the input weights (a plain attribute assignment, self.i2h.requires_grad = False, would not affect the parameters)\n",
"\n",
"\n",
"\n",
"\n",
"        self.w = torch.nn.Parameter(( (1.0 / np.sqrt(params['hs'])) * ( 2.0 * torch.rand(params['hs'], params['hs']) - 1.0) ).to(device), requires_grad=True)\n",
"        #self.alpha = torch.nn.Parameter((.1 * torch.ones(params['hs'], params['hs'])).to(device), requires_grad=True)\n",
"        # self.alpha = torch.nn.Parameter((.01 * torch.ones(params['hs'], params['hs'])).to(device), requires_grad=True)\n",
"        if VECTALPHA:\n",
"            self.alpha = .01 * (2.0 * torch.rand(params['hs'], 1) -1.0).to(device) # A column vector, so each neuron has a single plasticity coefficient applied to all its input connections\n",
"        else:\n",
"            self.alpha = .01 * (2.0 * torch.rand(params['hs'], params['hs']) -1.0).to(device)\n",
"        if POSALPHA or POSALPHAINITONLY:\n",
"            self.alpha = torch.abs(self.alpha)\n",
"        self.alpha = torch.nn.Parameter(self.alpha, requires_grad=True)\n",
"        # self.etaet = torch.nn.Parameter((.5 * torch.ones(1)).to(device), requires_grad=True) # Everyone has the same etaet\n",
"        self.etaet = torch.nn.Parameter((.7 * torch.ones(1)).to(device), requires_grad=True) # Everyone has the same etaet\n",
"        # self.DAmult = torch.nn.Parameter((1.0 * torch.ones(1)).to(device), requires_grad=True) # Everyone has the same DAmult\n",
"        self.DAmult = torch.nn.Parameter((1.0 * torch.ones(1)).to(device), requires_grad=True) # Everyone has the same DAmult\n",
"        # self.DAmult = .2\n",
"        self.h2DA = torch.nn.Linear(params['hs'], NBDA).to(device) # DA output\n",
"        self.h2o = torch.nn.Linear(params['hs'], self.params['outputsize']).to(device) # Actual output\n",
"        self.h2v = torch.nn.Linear(params['hs'], 1).to(device) # V prediction\n",
"\n",
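"    # Editor's summary of one step of the plastic RNN (see forward() below for the exact code):\n",
"    #   h(t) = tanh( W_in x(t) + (w + alpha*pw) h(t-1) )\n",
"    #   pw  <- pw + m(t) * et              (m(t) = DAout, the neuromodulatory output; pw is then clipped)\n",
"    #   et  <- (1 - etaet) * et + etaet * tanh( h(t) h(t-1)^T )\n",
"\n",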
"    def forward(self, inputs, hidden, et, pw):\n",
"        BATCHSIZE = inputs.shape[0] # self.params['bs']\n",
"        HS = self.params['hs']\n",
"        assert pw.shape[0] == hidden.shape[0] == et.shape[0] == BATCHSIZE\n",
"\n",
"        # Multiplying inputs (i.e. current hidden values) by the total recurrent weights, w + alpha * plastic_weights\n",
"        hactiv = self.activ(self.i2h(inputs).view(BATCHSIZE, HS, 1) + torch.matmul((self.w + torch.mul(self.alpha, pw)),\n",
"                            hidden.view(BATCHSIZE, HS, 1))).view(BATCHSIZE, HS)\n",
"        activout = self.h2o(hactiv) # Output layer. Pure linear, raw scores - will be softmaxed later\n",
"        valueout = self.h2v(hactiv) # Value prediction\n",
"\n",
"        # Now computing the Hebbian updates...\n",
"\n",
"        # With batching, DAout is a matrix of size BS x 1\n",
"        DAout2 = torch.tanh(self.h2DA(hactiv))\n",
"        DAout = self.DAmult * (DAout2[:,0] - DAout2[:,1])[:,None] # DA output is the difference between two tanh neurons - allows negative, positive and easy stable 0 output (by jamming both neurons to max or min)\n",
"\n",
"\n",
"        # Eligibility trace gets stamped into the plastic weights - gated by DAout\n",
"        deltapw = DAout.view(BATCHSIZE,1,1) * et\n",
"        pw = pw + deltapw\n",
"\n",
"        torch.clip_(pw, min=-50.0, max=50.0)\n",
"\n",
"\n",
"\n",
"        # Updating the eligibility trace - Hebbian update with a simple decay\n",
"        # NOTE: the decay is for the eligibility trace, NOT the plastic weights (which never decay during a lifetime, i.e. an episode)\n",
"        deltaet = torch.bmm(hactiv.view(BATCHSIZE, HS, 1), hidden.view(BATCHSIZE, 1, HS)) # batched outer product; at this point 'hactiv' is the output and 'hidden' is the input (i.e. activities from previous time step)\n",
"        # deltaet = torch.bmm(hactiv.view(BATCHSIZE, HS, 1), hidden.view(BATCHSIZE, 1, HS)) - et * hactiv[:, :, None] ** 2 # Oja's rule (...? anyway, doesn't ensure stability with tanh and arbitrary damult / etaet)\n",
"        # deltaet = torch.bmm(hactiv.view(BATCHSIZE, HS, 1), hidden.view(BATCHSIZE, 1, HS)) - hactiv.view(BATCHSIZE, HS, 1) * et # Instar rule (?)\n",
"\n",
"        deltaet = torch.tanh(deltaet)\n",
"\n",
"        et = (1 - self.etaet) * et + self.etaet * deltaet\n",
"        # et = deltaet\n",
"\n",
"        hidden = hactiv\n",
"        return activout, valueout, DAout, hidden, et, pw\n",
"\n",
"\n",
"\n",
"\n",
"    def initialZeroET(self, mybs):\n",
"        # return torch.zeros(self.params['bs'], self.params['hs'], self.params['hs'], requires_grad=False).to(device)\n",
"        return torch.zeros(mybs, self.params['hs'], self.params['hs'], requires_grad=False).to(device)\n",
"\n",
"    def initialZeroPlasticWeights(self, mybs):\n",
"        return torch.zeros(mybs, self.params['hs'], self.params['hs'] , requires_grad=False).to(device)\n",
"    def initialZeroState(self, mybs):\n",
"        return torch.zeros(mybs, self.params['hs'], requires_grad=False ).to(device)\n",
"\n",
"\n",
"\n",
"print(\"Starting...\")\n",
"\n",
"print(\"Passed params: \", params)\n",
"print(platform.uname())\n",
"suffix = \"_\"+\"\".join( [str(kk)+str(vv)+\"_\" if kk != 'pe' and kk != 'nbsteps' and kk != 'rngseed' and kk != 'save_every' and kk != 'test_every' else '' for kk, vv in sorted(zip(params.keys(), params.values()))] ) + \"_rng\" + str(params['rngseed']) # Turning the parameters into a nice suffix for filenames\n",
"print(suffix)\n",
"\n",
"\n",
"# Total input size = cue size + one 'go' bit + 4 additional inputs\n",
"ADDINPUT = 4 # the 4 additional inputs: 1 bias input, 1 for time elapsed in the episode, 1 for the previous reward, 1 unused\n",
"# Resulting input layout (with cs=15): [0:15] first cue, [15:30] second cue, [30] 'go' bit,\n",
"# [31] bias, [32] time, [33] previous reward, [34] unused, [35:37] previously chosen action (one-hot)\n",
"NBSTIMBITS = 2 * params['cs'] + 1 # The additional bit is for the response cue (i.e. 
the \"Go\" cue)\n", 278 | "params['outputsize'] = 2 # \"response\" and \"no response\"\n", 279 | "params['inputsize'] = NBSTIMBITS + ADDINPUT + params['outputsize'] # The total number of input bits is the size of cues, plus the \"response cue\" binary input, plus the number of additional inputs, plus the number of actions\n", 280 | "\n", 281 | "\n", 282 | "# Initialize random seeds, unless rngseed is -1 (first two redundant?)\n", 283 | "if params['rngseed'] > -1 :\n", 284 | " print(\"Setting random seed\", params['rngseed'])\n", 285 | " np.random.seed(params['rngseed']); random.seed(params['rngseed']); torch.manual_seed(params['rngseed'])\n", 286 | "else:\n", 287 | " print(\"No random seed.\")\n", 288 | "\n", 289 | "\n", 290 | "\n", 291 | "\n", 292 | "\n", 293 | "# Are we running in evaluation mode?\n", 294 | "EVAL = False\n", 295 | "\n", 296 | "\n", 297 | "# Various possible experiments:\n", 298 | "\n", 299 | "RESETHIDDENEVERYTRIAL = RESETETEVERYTRIAL = True # False # True\n", 300 | "\n", 301 | "\n", 302 | "\n", 303 | "RESETPWEVERYTRIAL = False\n", 304 | "\n", 305 | "\n", 306 | "ONLYTWOLASTADJ = False\n", 307 | "\n", 308 | "LINKEDLISTSEVAL = False\n", 309 | "LINKINGISSHAM = False\n", 310 | "\n", 311 | "FIXEDCUES = False\n", 312 | "\n", 313 | "HALFNOBARREDPAIRUNTILT18 = False # Ensures that half the batch never sees the \"barred\" pair before trial 18. This should only be used for one thing: ensuring enough selects and selectadd's when looking at single-step weight changes, so that some figures look better.\n", 314 | "BARREDPAIR = [3,4]\n", 315 | "# BARREDPAIR = [2,3]\n", 316 | "# TO MAKE THE PLOTS WITH ADDITIONAL BARRED PAIR:\n", 317 | "# 1- Set the proper ADDBARREDPAIRR below (the pair just before or just after the main BARRED PAIR)\n", 318 | "# 2- Set SHOWALLSELECTS = False below\n", 319 | "\n", 320 | "\n", 321 | "\n", 322 | "if EVAL:\n", 323 | " params['nbiter'] = 1 # 5 # 10\n", 324 | " params['bs'] = 2000\n", 325 | " params['nbcues'] = 8\n", 326 | " if not LINKEDLISTSEVAL:\n", 327 | " params['nbepsbwresets'] = 1\n", 328 | " torch.set_grad_enabled(False)\n", 329 | "if LINKEDLISTSEVAL:\n", 330 | " assert EVAL\n", 331 | " assert NBMASSEDTRIALS==0\n", 332 | " assert params['nbepsbwresets'] == 3\n", 333 | " params['nbiter'] = 3\n", 334 | " params['nbcues'] = 8 # 10\n", 335 | " params['bs'] = 4000\n", 336 | " SHOWFIRSTHALFFIRST = 1 # np.random.randint(2)\n", 337 | " # The following applies for the first 2 episodes, then will be modified later for the 3rd episode\n", 338 | " params['nbtraintrials'] = 10\n", 339 | " params['nbtesttrials'] = 0\n", 340 | " params['nbtrials'] = params['nbtraintrials'] + params['nbtesttrials']\n", 341 | " params['eplen'] = params['nbtrials'] * params['triallen'] # eplen = episode length\n", 342 | "\n", 343 | "if FIXEDCUES:\n", 344 | " params['bs'] = 2000\n", 345 | "\n", 346 | "BS = params['bs'] # Batch size\n", 347 | "\n", 348 | "\n", 349 | "assert not ( (NBMASSEDTRIALS > 0 ) and (not EVAL) ) # We should only use massed trials in eval, not training\n", 350 | "if ONLYTWOLASTADJ:\n", 351 | " assert params['nbcues'] == 7\n", 352 | "if HALFNOBARREDPAIRUNTILT18:\n", 353 | " assert EVAL and (NBMASSEDTRIALS == 0) and not LINKEDLISTSEVAL and not ONLYTWOLASTADJ and not FIXEDCUES\n", 354 | "if LINKINGISSHAM:\n", 355 | " assert LINKEDLISTSEVAL\n", 356 | "\n", 357 | "\n", 358 | "\n", 359 | "\n", 360 | "\n", 361 | "print(\"Initializing network\")\n", 362 | "net = RetroModulRNN(params)\n", 363 | "if EVAL:\n", 364 | " net.load_state_dict(torch.load('net.dat'))\n", 365 | " 
net.eval()\n", 366 | " # net.alpha *= -1; net.DAmult *= -1 # Should leave the system invariant\n", 367 | "\n", 368 | "\n", 369 | "\n", 370 | "\n", 371 | "\n", 372 | "\n", 373 | "\n", 374 | "print (\"Shape of all optimized parameters:\", [x.size() for x in net.parameters()])\n", 375 | "allsizes = [torch.numel(x.data.cpu()) for x in net.parameters()]\n", 376 | "print (\"Size (numel) of all optimized elements:\", allsizes)\n", 377 | "print (\"Total size (numel) of all optimized elements:\", sum(allsizes))\n", 378 | "\n", 379 | "print(\"Initializing optimizer\")\n", 380 | "optimizer = torch.optim.Adam(net.parameters(), lr=1.0*params['lr'], eps=params['eps'], weight_decay=params['l2'])\n", 381 | "\n", 382 | "# A lot of logging...\n", 383 | "all_losses = []\n", 384 | "all_grad_norms = []\n", 385 | "all_losses_objective = []\n", 386 | "all_mean_rewards_ep = []\n", 387 | "all_mean_testrewards_ep = []\n", 388 | "all_losses_v = []\n", 389 | "\n", 390 | "oldcuedata = []\n", 391 | "\n", 392 | "lossbetweensaves = 0\n", 393 | "nowtime = time.time()\n", 394 | "\n", 395 | "nbtrials = [0]*BS\n", 396 | "totalnbtrials = 0\n", 397 | "nbtrialswithcc = 0\n", 398 | "\n", 399 | "print(\"Starting episodes!\")\n", 400 | "\n", 401 | "\n", 402 | "for numepisode in range(params['nbiter']):\n", 403 | "\n", 404 | "\n", 405 | " PRINTTRACE = False\n", 406 | " if (numepisode) % (params['pe']) == 0 or EVAL:\n", 407 | " PRINTTRACE = True\n", 408 | "\n", 409 | "\n", 410 | "\n", 411 | " if LINKEDLISTSEVAL and numepisode == 2:\n", 412 | " params['nbtraintrials'] = 1 if LINKINGISSHAM else 4 # 12 # 7\n", 413 | " params['nbtesttrials'] = 1\n", 414 | " params['nbtrials'] = params['nbtraintrials'] + params['nbtesttrials']\n", 415 | " params['eplen'] = params['nbtrials'] * params['triallen'] # eplen = episode length\n", 416 | "\n", 417 | " optimizer.zero_grad()\n", 418 | " loss = 0\n", 419 | " lossv = 0\n", 420 | " lossDA = 0\n", 421 | " lossHL1 = 0\n", 422 | " # The freshly generated uedata will be appended to oldcuedata later, after the episode is run\n", 423 | " if numepisode % params['nbepsbwresets'] == 0:\n", 424 | " if not EVAL:\n", 425 | " params['nbcues']= random.choice(params['nbcuesrange'])\n", 426 | " oldcuedata = []\n", 427 | " hidden = net.initialZeroState(BS)\n", 428 | " et = net.initialZeroET(BS) # The Hebbian eligibility trace\n", 429 | " pw = net.initialZeroPlasticWeights(BS)\n", 430 | " else:\n", 431 | " hidden = hidden.detach()\n", 432 | " et = et.detach()\n", 433 | " pw = pw.detach()\n", 434 | "\n", 435 | " numstep_ep = 0\n", 436 | " iscorrect_thisep = np.zeros((BS, params['nbtrials']))\n", 437 | " istest_thisep = np.zeros((BS, params['nbtrials']))\n", 438 | " isadjacent_thisep = np.zeros((BS, params['nbtrials']))\n", 439 | " isolddata_thisep = np.zeros((BS, params['nbtrials']))\n", 440 | " resps_thisep = np.zeros((BS, params['nbtrials']))\n", 441 | " cuepairs_thisep = []\n", 442 | " numactionschosen_alltrialsandsteps_thisep = np.zeros((BS, params['nbtrials'], params['triallen'])).astype(int)\n", 443 | " if EVAL:\n", 444 | " allpwsavs_thisep = []\n", 445 | " ds_thisep =[]; rs_thisep = []\n", 446 | " allrates_thisep = np.zeros((BS, params['hs'], params['eplen']))\n", 447 | "\n", 448 | "\n", 449 | " # Generate the bitstring for each cue number for this episode. Make sure they're all different (important when using very small cues for debugging, e.g. 
cs=2, ni=2)\n", 450 | "\n", 451 | "\n", 452 | "\n", 453 | " # print(\"Generating cues...\")\n", 454 | " if FIXEDCUES:\n", 455 | " # Debugging only: Never change cue data\n", 456 | " if numepisode == 0:\n", 457 | " cuedata=[]\n", 458 | " for nb in range(BS):\n", 459 | " cuedata.append([])\n", 460 | " for ncue in range(params['nbcues']):\n", 461 | " if nb == 0:\n", 462 | " assert len(cuedata[nb]) == ncue\n", 463 | " foundsame = 1\n", 464 | " cpt = 0\n", 465 | " while foundsame > 0 :\n", 466 | " cpt += 1\n", 467 | " if cpt > 10000:\n", 468 | " # This should only occur with very weird parameters, e.g. cs=2, ni>4\n", 469 | " raise ValueError(\"Could not generate a full list of different cues\")\n", 470 | " foundsame = 0\n", 471 | " candidate = np.random.randint(2, size=params['cs']) * 2 - 1\n", 472 | " for backtrace in range(ncue):\n", 473 | " # if np.array_equal(cuedata[nb][backtrace], candidate):\n", 474 | " if np.mean(cuedata[nb][backtrace] == candidate) > .66 :\n", 475 | " # if np.abs(np.mean(cuedata[nb][backtrace] * candidate)) > .1 :\n", 476 | " foundsame = 1\n", 477 | " cuedata[nb].append(candidate)\n", 478 | " else:\n", 479 | " cuedata[nb].append(cuedata[0][ncue])\n", 480 | "\n", 481 | " else: # Not fixed cues\n", 482 | " if not LINKEDLISTSEVAL or numepisode == 0:\n", 483 | " # if numepisode == 0: # THIS DOESN't WORK TO FIX CUES! Different nb's still have different cues\n", 484 | " cuedata=[]\n", 485 | " for nb in range(BS):\n", 486 | " cuedata.append([])\n", 487 | " for ncue in range(params['nbcues']):\n", 488 | " assert len(cuedata[nb]) == ncue\n", 489 | " foundsame = 1\n", 490 | " cpt = 0\n", 491 | " while foundsame > 0 :\n", 492 | " cpt += 1\n", 493 | " if cpt > 10000:\n", 494 | " # This should only occur with very weird parameters, e.g. cs=2, ni>4\n", 495 | " raise ValueError(\"Could not generate a full list of different cues\")\n", 496 | " foundsame = 0\n", 497 | " candidate = np.random.randint(2, size=params['cs']) * 2 - 1\n", 498 | " for backtrace in range(ncue):\n", 499 | " # if np.array_equal(cuedata[nb][backtrace], candidate):\n", 500 | " # if np.abs(np.mean(cuedata[nb][backtrace] * candidate)) > .2 :\n", 501 | " # if np.sum(cuedata[nb][backtrace] != candidate) < 4: # 2:\n", 502 | " if np.mean(cuedata[nb][backtrace] == candidate) > .66 :\n", 503 | " foundsame = 1\n", 504 | "\n", 505 | " cuedata[nb].append(candidate)\n", 506 | " # print(\"Cues generated!\")\n", 507 | "\n", 508 | "\n", 509 | "\n", 510 | "\n", 511 | "\n", 512 | " # # The freshly generated uedata will be appended to oldcuedata later, after the episode is run\n", 513 | " # if numepisode % params['nbepsbwresets'] == 0:\n", 514 | " # oldcuedata = []\n", 515 | "\n", 516 | " reward = np.zeros(BS)\n", 517 | " sumreward = np.zeros(BS)\n", 518 | " sumrewardtest = np.zeros(BS)\n", 519 | " rewards = []\n", 520 | " vs = []\n", 521 | " logprobs = []\n", 522 | " cues=[]\n", 523 | " for nb in range(BS):\n", 524 | " cues.append([])\n", 525 | " dist = 0\n", 526 | " numactionschosen = np.zeros(BS, dtype='int32')\n", 527 | "\n", 528 | " #reward = 0.0\n", 529 | " #rewards = []\n", 530 | " #vs = []\n", 531 | " #logprobs = []\n", 532 | " #sumreward = 0.0\n", 533 | " nbtrials = np.zeros(BS)\n", 534 | " nbtesttrials = nbtesttrials_correct = nbtesttrials_adjcues = nbtesttrials_adjcues_correct = nbtesttrials_nonadjcues = nbtesttrials_nonadjcues_correct = 0\n", 535 | " nbrewardabletrials = np.zeros(BS)\n", 536 | " thistrialhascorrectorder = np.zeros(BS)\n", 537 | " thistrialhasadjacentcues = np.zeros(BS)\n", 538 | " 
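# per-batch-element flag: was the response on this trial correct? (refreshed every trial, like the two flags above)\n",
"    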
thistrialhascorrectanswer = np.zeros(BS)\n", 539 | "\n", 540 | "\n", 541 | " # 2 steps of blank input between episodes. Not sure if it helps.\n", 542 | " inputs = np.zeros((BS, params['inputsize']), dtype='float32')\n", 543 | " inputsC = torch.from_numpy(inputs).detach().to(device)\n", 544 | " for nn in range(2):\n", 545 | " y, v, DAout, hidden, et, pw = net(inputsC, hidden, et, pw) # y should output raw scores, not probas\n", 546 | "\n", 547 | "\n", 548 | "\n", 549 | " #print(\"EPISODE \", numepisode)\n", 550 | "\n", 551 | " for numtrial in range(params['nbtrials']):\n", 552 | "\n", 553 | "\n", 554 | " if RESETHIDDENEVERYTRIAL:\n", 555 | " hidden = net.initialZeroState(BS)\n", 556 | " if RESETETEVERYTRIAL:\n", 557 | " # et = et * 0 # net.initialZeroET()\n", 558 | " et = net.initialZeroET(BS)\n", 559 | " if RESETPWEVERYTRIAL:\n", 560 | " pw = net.initialZeroPlasticWeights(BS)\n", 561 | "\n", 562 | " # First, we prepare the specific sequence of inputs for this trial\n", 563 | " # The inputs can be a pair of cue numbers, or -1 (empty stimulus), or a single number equal to params['nbcues'], which indicates the 'response' cue.\n", 564 | " # These will be translated into actual network inputs (using the actual bitstrings) later.\n", 565 | " # Remember that the actual data for each cue (i.e. its actual bitstring) is randomly generated for each episode, above\n", 566 | "\n", 567 | " cuepairs_thistrial = []\n", 568 | " for nb in range(BS):\n", 569 | " thistrialhascorrectorder[nb] = 0\n", 570 | " cuerange = range(params['nbcues'])\n", 571 | " if LINKEDLISTSEVAL:\n", 572 | " if SHOWFIRSTHALFFIRST:\n", 573 | " if numepisode == 0:\n", 574 | " cuerange = range(params['nbcues']//2)\n", 575 | " elif numepisode == 1:\n", 576 | " cuerange = range(params['nbcues']//2, params['nbcues'])\n", 577 | " else:\n", 578 | " cuerange = range(params['nbcues'])\n", 579 | " else:\n", 580 | " if numepisode == 0:\n", 581 | " cuerange = range(params['nbcues']//2, params['nbcues'])\n", 582 | " elif numepisode == 1:\n", 583 | " cuerange = range(params['nbcues']//2)\n", 584 | " else:\n", 585 | " cuerange = range(params['nbcues'])\n", 586 | " # # In any trial, we show exactly two cues (randomly chosen), simultaneously:\n", 587 | " cuepair = list(np.random.choice(cuerange, 2, replace=False))\n", 588 | "\n", 589 | " # If the trial is NOT a test trial, these two cues should be adjacent\n", 590 | " if nbtrials[nb] < params['nbtraintrials'] or (ONLYTWOLASTADJ and nbtrials[nb] >= params['nbtrials'] - 2):\n", 591 | " if ONLYTWOLASTADJ and nbtrials[nb] >= params['nbtrials'] - 2:\n", 592 | " while abs(cuepair[0] - cuepair[1]) > 1 or 0 in cuepair or 6 in cuepair:\n", 593 | " cuepair = list(np.random.choice(cuerange, 2, replace=False))\n", 594 | " else:\n", 595 | " while abs(cuepair[0] - cuepair[1]) > 1 :\n", 596 | " cuepair = list(np.random.choice(cuerange, 2, replace=False))\n", 597 | " else:\n", 598 | " assert nbtrials[nb] >= params['nbtraintrials']\n", 599 | " if ONLYTWOLASTADJ:\n", 600 | " assert nbtrials[nb] < params['nbtrials'] - 2\n", 601 | " while not(\n", 602 | " (2 in cuepair and 0 in cuepair )\n", 603 | " or (2 in cuepair and 4 in cuepair )\n", 604 | " or (4 in cuepair and 6 in cuepair )\n", 605 | " or (3 in cuepair and 0 in cuepair )\n", 606 | " or (3 in cuepair and 6 in cuepair )\n", 607 | " ):\n", 608 | " cuepair = list(np.random.choice(cuerange, 2, replace=False))\n", 609 | "\n", 610 | " if NBMASSEDTRIALS > 0 and nbtrials[nb] >= params['nbtraintrials'] and numtrial < params['nbtrials'] - params['nbtesttrials']:\n", 611 | 
" cuepair = MASSEDPAIR\n", 612 | "\n", 613 | " if LINKEDLISTSEVAL and numepisode == 2:\n", 614 | " if numtrial < params['nbtraintrials']:\n", 615 | " if LINKINGISSHAM:\n", 616 | " cuepair = [params['nbcues']//2-3,params['nbcues']//2-2] # Sanity check for debugging: this should lead to chance perf in the test trial of 3rd episode\n", 617 | " else:\n", 618 | " cuepair = [params['nbcues']//2-1,params['nbcues']//2] if np.random.randint(2) else [params['nbcues']//2,params['nbcues']//2-1]\n", 619 | " # else nothing, we're in the 'test' phase (which is now only 1 trial) and we sample from all pairs above\n", 620 | "\n", 621 | " if nb > params['bs']//2 and HALFNOBARREDPAIRUNTILT18:\n", 622 | " if numtrial == 18:\n", 623 | " cuepair = BARREDPAIR if np.random.randint(2) else [BARREDPAIR[1],BARREDPAIR[0]]\n", 624 | " elif numtrial < 18:\n", 625 | " while True:\n", 626 | " cuepair = list(np.random.choice(cuerange, 2, replace=False))\n", 627 | " if (abs(cuepair[0] - cuepair[1]) == 1) :\n", 628 | " if (BARREDPAIR[0] not in cuepair) or (BARREDPAIR[1] not in cuepair):\n", 629 | " break\n", 630 | "\n", 631 | "\n", 632 | "\n", 633 | " thistrialhascorrectorder[nb] = 1 if cuepair[0] < cuepair[1] else 0\n", 634 | " thistrialhasadjacentcues[nb] = 1 if (abs(cuepair[0]-cuepair[1]) == 1) else 0\n", 635 | " isadjacent_thisep[nb,numtrial] = thistrialhasadjacentcues[nb]\n", 636 | " istest_thisep[nb, numtrial] = 1 if numtrial >= params['nbtraintrials'] + NBMASSEDTRIALS else 0\n", 637 | "\n", 638 | " # mycues = [cuepair,cuepair]\n", 639 | " mycues = [cuepair,]\n", 640 | " cuepairs_thistrial.append(cuepair)\n", 641 | "\n", 642 | " # Filling up other inputs for this trial\n", 643 | " # # We first insert some empty time steps at random either before or after the stimulus\n", 644 | " # for nc in range(params['triallen'] - len(mycues) - 3):\n", 645 | " # # mycues.insert(np.random.randint(len(mycues)), -1)\n", 646 | " # mycues.insert(0, -1)\n", 647 | " # No, we don't do that any more.\n", 648 | "\n", 649 | " mycues.append(params['nbcues']) # The 'go' cue, instructing response from the network\n", 650 | " mycues.append(-1) # One empty step.During the first empty step, reward (computed on the previous step) is seen by the network.\n", 651 | " mycues.append(-1)\n", 652 | " # mycues.append(-1)\n", 653 | " assert len(mycues) == params['triallen']\n", 654 | " assert mycues[NUMRESPONSESTEP] == params['nbcues'] # The 'response' step is signalled by the 'go' cue, whose number is params['nbcues'].\n", 655 | " cues[nb] = mycues\n", 656 | "\n", 657 | " cuepairs_thisep.append(cuepairs_thistrial)\n", 658 | "\n", 659 | " # In test period, if there ars some old cues in the store, some trials will use old cues\n", 660 | " if len(oldcuedata) > 0 and numtrial >= params['nbtraintrials'] + NBMASSEDTRIALS:\n", 661 | " for nb in range(BS):\n", 662 | " if np.random.rand() < PROBAOLDDATA:\n", 663 | " isolddata_thisep[nb,numtrial] = 1\n", 664 | "\n", 665 | "\n", 666 | "\n", 667 | " # Now we are ready to actually run the trial:\n", 668 | "\n", 669 | " for numstep in range(params['triallen']):\n", 670 | "\n", 671 | " inputs = np.zeros((BS, params['inputsize']), dtype='float32')\n", 672 | "\n", 673 | " for nb in range(BS):\n", 674 | " # Turning the cue number for this time step into actual (signed) bitstring inputs, using the cue data generated at the beginning of the episode - or, ocasionally, oldcuedata\n", 675 | " inputs[nb, :NBSTIMBITS] = 0\n", 676 | " if cues[nb][numstep] != -1 and cues[nb][numstep] != params['nbcues']:\n", 677 | " assert 
len(cues[nb][numstep]) == 2\n", 678 | " if isolddata_thisep[nb, numtrial]:\n", 679 | " oldpos = np.random.randint(len(oldcuedata))\n", 680 | " inputs[nb, :NBSTIMBITS-1] = np.concatenate( ( oldcuedata[oldpos][nb][cues[nb][numstep][0]][:], oldcuedata[oldpos][nb][cues[nb][numstep][1]][:] ) )\n", 681 | " else:\n", 682 | " inputs[nb, :NBSTIMBITS-1] = np.concatenate( ( cuedata[nb][cues[nb][numstep][0]][:], cuedata[nb][cues[nb][numstep][1]][:] ) )\n", 683 | " if cues[nb][numstep] == params['nbcues']:\n", 684 | " inputs[nb, NBSTIMBITS-1] = 1 # \"Go\" cue\n", 685 | "\n", 686 | " inputs[nb, NBSTIMBITS + 0] = 1.0 # Bias neuron, probably not necessary\n", 687 | " inputs[nb,NBSTIMBITS + 1] = numstep_ep / params['eplen'] # Time passed in this episode. Should it be the trial? Doesn't matter much anyway.\n", 688 | " inputs[nb, NBSTIMBITS + 2] = 1.0 * reward[nb] # Reward from previous time step\n", 689 | "\n", 690 | "\n", 691 | " # Original:\n", 692 | " # if numstep > 0:\n", 693 | " # inputs[nb, NBSTIMBITS + ADDINPUT + numactionschosen[nb]] = 1 # Previously chosen action\n", 694 | " # DEBUGGING !!\n", 695 | " # if numstep == 2:\n", 696 | " # if not (numtrial == 0 and numstep == 0):\n", 697 | " # if numstep == 0 and numtrial > 0:\n", 698 | " assert NUMRESPONSESTEP + 1 < params['triallen'] # If that is not the case, we must provide the action signal in the next trial (this works)\n", 699 | " if numstep == NUMRESPONSESTEP + 1:\n", 700 | " inputs[nb, NBSTIMBITS + ADDINPUT + numactionschosen[nb]] = 1 # Previously chosen action\n", 701 | "\n", 702 | "\n", 703 | " # inputsC = torch.from_numpy(inputs, requires_grad=False).to(device)\n", 704 | " inputsC = torch.from_numpy(inputs).detach().to(device)\n", 705 | "\n", 706 | "\n", 707 | "\n", 708 | " pwold = pw.clone()\n", 709 | "\n", 710 | " ## Running the network\n", 711 | " y, v, DAout, hidden, et, pw = net(inputsC, hidden, et, pw) # y should output raw scores, not probas\n", 712 | "\n", 713 | " # This should hold true if we reset h and et (not pw) between every episode:\n", 714 | " # if numstep < 2:\n", 715 | " # assert torch.sum(torch.abs(pwold-pw)) < 1e-8\n", 716 | "\n", 717 | "\n", 718 | " if EVAL:\n", 719 | " allrates_thisep[:, :, numstep_ep] = hidden.cpu().numpy()[:,:]\n", 720 | " ds_thisep.append(DAout.cpu().numpy())\n", 721 | " rs_thisep.append(reward[:, None])\n", 722 | " # LIMITSAVPW = 200\n", 723 | " if numtrial in [0,1, 18,19]:\n", 724 | " allpwsavs_thisep.append(pw.cpu().numpy().astype('float16'))\n", 725 | " else:\n", 726 | " allpwsavs_thisep.append(None)\n", 727 | "\n", 728 | "\n", 729 | "\n", 730 | "\n", 731 | " # Choosing the action from the outputs\n", 732 | " y = F.softmax(y, dim=1)\n", 733 | " # Must convert y to probas to use this !\n", 734 | " distrib = torch.distributions.Categorical(y)\n", 735 | " actionschosen = distrib.sample()\n", 736 | " logprobs.append(distrib.log_prob(actionschosen)) # To be used later for the A2C algorithm\n", 737 | " # if numstep == NUMRESPONSESTEP: # 2: # 4: #3: # 2:\n", 738 | " # logprobs.append(distrib.log_prob(actionschosen)) # To be used later for the A2C algorithm\n", 739 | " # else:\n", 740 | " # logprobs.append(0)\n", 741 | " numactionschosen = actionschosen.data.cpu().numpy() # Store as scalars (for the whole batch)\n", 742 | "\n", 743 | " if PRINTTRACE:\n", 744 | " print(\"Tr\", numtrial, \"Step \", numstep, \", Cue 1 (0):\", inputs[0,:params['cs']], \"Cue 2 (0):\", inputs[0,params['cs']:2*params['cs']],\n", 745 | " \"Other inputs:\", inputs[0, 2*params['cs']:], \"\\n - Outputs(0): \", 
y.data.cpu().numpy()[0,:], \" - action chosen(0): \", numactionschosen[0],\n", 746 | " \"TrialLen:\", params['triallen'], \"numstep\", numstep, \"TTHCC(0): \", thistrialhascorrectorder[0], \"TTHOC(0):\", isolddata_thisep[0, numtrial], \"Reward (based on prev step): \", reward[0], \", DAout:\", float(DAout[0]), \", cues(0):\", cues[0] ) #, \", cc(0):\", correctcue[0])\n", 747 | "\n", 748 | "\n", 749 | " # Computing the rewards. This is done for each time step.\n", 750 | " reward = np.zeros(BS, dtype='float32')\n", 751 | " for nb in range(BS):\n", 752 | " if numactionschosen[nb] == 1:\n", 753 | " # Small penalty for any non-rest action taken\n", 754 | " # In practice, this would usually be 0\n", 755 | " reward[nb] -= params['wp']\n", 756 | "\n", 757 | " numactionschosen_alltrialsandsteps_thisep[nb, numtrial, numstep] = numactionschosen[nb]\n", 758 | "\n", 759 | " if numstep == NUMRESPONSESTEP: # 2: # 4: #3: # 2:\n", 760 | " # This is the 'response' step of the trial (and we showed the response signal\n", 761 | " assert cues[nb][numstep] == params['nbcues']\n", 762 | " resps_thisep[nb, numtrial] = numactionschosen[nb] *2 - 1 # Store the response in this timestep as the response for the whole trial, for logging/analysis purposes\n", 763 | " # We must deliver reward (which will be perceived by the agent at the next step), positive or negative, depending on response\n", 764 | " thistrialhascorrectanswer[nb] = 1\n", 765 | " if thistrialhascorrectorder[nb] and numactionschosen[nb] == 1:\n", 766 | " reward[nb] += params['rew']\n", 767 | " elif (not thistrialhascorrectorder[nb]) and numactionschosen[nb] == 0:\n", 768 | " reward[nb] += params['rew']\n", 769 | " else:\n", 770 | " reward[nb] -= params['rew']\n", 771 | " thistrialhascorrectanswer[nb] = 0\n", 772 | " iscorrect_thisep[nb, numtrial] = thistrialhascorrectanswer[nb]\n", 773 | "\n", 774 | " if ( cuepairs_thistrial[nb][0] < cuepairs_thistrial[nb][1] ) and numactionschosen[nb] == 1:\n", 775 | " assert thistrialhascorrectanswer[nb]\n", 776 | " if ( cuepairs_thistrial[nb][0] > cuepairs_thistrial[nb][1] ) and numactionschosen[nb] == 1:\n", 777 | " assert not thistrialhascorrectanswer[nb]\n", 778 | " if ( cuepairs_thistrial[nb][0] < cuepairs_thistrial[nb][1] ) and numactionschosen[nb] == 0:\n", 779 | " assert not thistrialhascorrectanswer[nb]\n", 780 | " if ( cuepairs_thistrial[nb][0] > cuepairs_thistrial[nb][1] ) and numactionschosen[nb] == 0:\n", 781 | " assert thistrialhascorrectanswer[nb]\n", 782 | "\n", 783 | "\n", 784 | "\n", 785 | "\n", 786 | " if numstep == params['triallen'] - 1:\n", 787 | " # This was the last step of the trial\n", 788 | " nbtrials[nb] += 1\n", 789 | " totalnbtrials += 1\n", 790 | " if thistrialhascorrectorder[nb]:\n", 791 | " nbtrialswithcc += 1\n", 792 | "\n", 793 | "\n", 794 | "\n", 795 | " rewards.append(reward)\n", 796 | " vs.append(v)\n", 797 | " sumreward += reward\n", 798 | " if numtrial >= params['nbtrials'] - params['nbtesttrials']:\n", 799 | " sumrewardtest += reward\n", 800 | " # lossDA += torch.sum(torch.abs(DAout))\n", 801 | " lossDA += torch.sum(torch.abs(DAout / (1e-8 + net.DAmult))) # This is a hack to \"remove\" DAmult from the L1 penalty. Assumes DAmult never goes to < 0.\n", 802 | " lossHL1 += torch.mean(torch.abs(hidden))\n", 803 | "\n", 804 | "\n", 805 | " loss += (params['bent'] * y.pow(2).sum() / BS ) # In real A2c, this is an entropy incentive. 
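(Minimizing the summed squared softmax probabilities pushes the policy away from deterministic one-hot outputs, so it plays a similar exploratory role.) 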
Our original version of PyTorch did not have an entropy() function for Distribution, so we use sum-of-squares instead.\n",
"\n",
"            numstep_ep += 1\n",
"\n",
"\n",
"        # All steps done for this trial\n",
"        if numtrial >= params['nbtrials'] - params['nbtesttrials']:\n",
"            sumrewardtest += reward\n",
"            nbtesttrials += BS\n",
"            nbtesttrials_correct += np.sum(thistrialhascorrectanswer)\n",
"            nbtesttrials_adjcues += np.sum(thistrialhasadjacentcues)\n",
"            nbtesttrials_adjcues_correct += np.sum(thistrialhasadjacentcues * thistrialhascorrectanswer)\n",
"            nbtesttrials_nonadjcues += np.sum(1 - thistrialhasadjacentcues)\n",
"            nbtesttrials_nonadjcues_correct += np.sum((1-thistrialhasadjacentcues) * thistrialhascorrectanswer)\n",
"\n",
"\n",
"    # All trials done for this episode\n",
"\n",
"    oldcuedata.append(cuedata)\n",
"    if EVAL:\n",
"        ds_thisep = np.hstack(ds_thisep)\n",
"        rs_thisep = np.hstack(rs_thisep)\n",
"\n",
"    # Computing the various losses for A2C (outer-loop training)\n",
"\n",
"    R = torch.zeros(BS, requires_grad=False).to(device)\n",
"    gammaR = params['gr']\n",
"    for numstepb in reversed(range(params['eplen'])) :\n",
"        R = gammaR * R + torch.from_numpy(rewards[numstepb]).detach().to(device) # discounted return, accumulated backward: R(t) = r(t) + gamma * R(t+1)\n",
"        # ctrR = R - vs[numstepb][0]\n",
"        # ctrR = R - vs[numstepb]\n",
"        ctrR = R - vs[numstepb][:,0] # I think this is right... (the advantage: return minus the value prediction)\n",
"        lossv += ctrR.pow(2).sum() / BS\n",
"        LOSSMULT = params['testlmult'] if numstepb > params['eplen'] - params['triallen'] * params['nbtesttrials'] else 1.0\n",
"\n",
"        # NOTE: We accumulate the logprobs from all time steps, even when the output is ignored (it is only used to sample the response at time step 1, i.e. RESPONSETIME)\n",
"        # Unsurprisingly, performance is better if we only record the logprobs for response time (and set them to 0 otherwise), but we keep this version because it was used in the paper.\n",
"        loss -= LOSSMULT * (logprobs[numstepb] * ctrR.detach()).sum() / BS # Action policy loss\n",
"\n",
"\n",
"\n",
"    lossobj = float(loss)\n",
"    loss += params['blossv'] * lossv # lossmult is not applied to value-prediction loss; is it right?...\n",
"    loss += params['lda'] * lossDA # lossDA is loss on absolute value of DA output (see above)\n",
"    loss += params['lhl1'] * lossHL1\n",
"    loss /= params['eplen']\n",
"    losspw = torch.mean(pw ** 2) * params['lpw'] # loss on squared final plastic weights is not divided by episode length\n",
"    loss += losspw\n",
"\n",
"    if PRINTTRACE:\n",
"        print(\"lossobj (with coeff):\", lossobj / params['eplen'], \", lossv (with coeff): \", params['blossv'] * float(lossv) / params['eplen'],\n",
"            \"lossDA (with coeff): \", params['lda'] * float(lossDA) / params['eplen'],\", losspw:\", float(losspw))\n",
"        print (\"Total reward for this episode(0):\", sumreward[0], \"Prop. 
of trials w/ rewarded cue:\", (nbtrialswithcc / totalnbtrials), \" Total Nb of trials:\", totalnbtrials)\n", 858 | " print(\"Nb Test Trials:\", nbtesttrials, \", Nb Test Trials AdjCues:\", nbtesttrials_adjcues, \", Nb Test Trials NonAdjCues:\", nbtesttrials_nonadjcues)\n", 859 | " if nbtesttrials > 0:\n", 860 | " # Should always be the case except for LinkedListsEval\n", 861 | " print(\"Test Perf (both methods):\", np.array([nbtesttrials_correct / nbtesttrials, np.sum(iscorrect_thisep * istest_thisep) / np.sum(istest_thisep)]),\n", 862 | " \"Test Perf AdjCues:\", np.array([(nbtesttrials_adjcues_correct / nbtesttrials_adjcues)]) if nbtesttrials_adjcues > 0 else 'N/A',\n", 863 | " \"Test Perf NonAdjCues:\", np.array([nbtesttrials_nonadjcues_correct / nbtesttrials_nonadjcues]) if nbtesttrials_nonadjcues > 0 else 'N/A',\n", 864 | " \"Test perf old cues:\", np.array([np.sum(iscorrect_thisep * istest_thisep * isolddata_thisep) / np.sum(istest_thisep * isolddata_thisep)]) if np.sum(istest_thisep * isolddata_thisep) > 0 else \"N/A\" ,\n", 865 | " )\n", 866 | "\n", 867 | "\n", 868 | " if not EVAL:\n", 869 | " loss.backward()\n", 870 | " gn = torch.nn.utils.clip_grad_norm_(net.parameters(), params['gc'])\n", 871 | " all_grad_norms.append(gn)\n", 872 | " if numepisode > 100: # Burn-in period\n", 873 | " optimizer.step()\n", 874 | " if POSALPHA:\n", 875 | " torch.clip_(net.alpha.data, min=0)\n", 876 | "\n", 877 | "\n", 878 | " lossnum = float(loss)\n", 879 | " lossbetweensaves += lossnum\n", 880 | " all_losses_objective.append(lossnum)\n", 881 | " all_mean_rewards_ep.append(sumreward.mean())\n", 882 | " all_mean_testrewards_ep.append(sumrewardtest.mean())\n", 883 | "\n", 884 | "\n", 885 | " if PRINTTRACE:\n", 886 | "\n", 887 | " print(\"Episode\", numepisode, \"====\")\n", 888 | " print(\"Mean loss: \", lossbetweensaves / params['pe'])\n", 889 | " lossbetweensaves = 0\n", 890 | " print(\"Mean reward per episode (over whole batch and last\", params['pe'], \"episodes: \", np.sum(all_mean_rewards_ep[-params['pe']:])/ params['pe'])\n", 891 | " print(\"Mean test-time reward per episode (over whole batch and last\", params['pe'], \"episodes: \", np.sum(all_mean_testrewards_ep[-params['pe']:])/ params['pe'])\n", 892 | " previoustime = nowtime\n", 893 | " nowtime = time.time()\n", 894 | " print(\"Time spent on last\", params['pe'], \"iters: \", nowtime - previoustime)\n", 895 | "\n", 896 | " # print(\" etaet: \", net.etaet.data.cpu().numpy(), \" DAmult: \", net.DAmult.data.cpu().numpy(), \" mean-abs pw: \", np.mean(np.abs(pw.data.cpu().numpy())))\n", 897 | " print(\" etaet: \", net.etaet.data.cpu().numpy(), \" DAmult: \", float(net.DAmult), \" mean-abs pw: \", np.mean(np.abs(pw.data.cpu().numpy())))\n", 898 | " print(\"min/max/med-abs w, alpha, pw\")\n", 899 | " print(float(torch.min(net.w)), float(torch.max(net.w)), float(torch.median(torch.abs(net.w))))\n", 900 | " print(float(torch.min(net.alpha)), float(torch.max(net.alpha)), float(torch.median(torch.abs(net.alpha))))\n", 901 | " print(float(torch.min(pw)), float(torch.max(pw)), float(torch.median(torch.abs(pw))))\n", 902 | " # pwc = pw.cpu().numpy()\n", 903 | " # print(np.min(pwc), np.max(pwc), np.median(np.abs(pwc)))\n", 904 | "\n", 905 | " # if (numepisode) % params['save_every'] == 0:\n", 906 | " if EVAL:\n", 907 | " np.savez('outcomes_'+str(numepisode)+'.npz', c=iscorrect_thisep.astype(int), a=isadjacent_thisep.astype(int),\n", 908 | " cp=np.moveaxis(np.array(cuepairs_thisep),1,0), r=resps_thisep.astype(int), ac = 
numactionschosen_alltrialsandsteps_thisep.astype(int))\n", 909 | " np.save('allrates_thisep_'+str(numepisode)+'.npy', allrates_thisep)\n", 910 | "\n", 911 | " if (numepisode) % params['save_every'] == 0 and numepisode > 0:\n", 912 | " losslast100 = np.mean(all_losses_objective[-100:])\n", 913 | " print(\"Average loss over the last 100 episodes:\", losslast100)\n", 914 | " print(\"Saving local files...\")\n", 915 | "\n", 916 | " if (not EVAL) and numepisode > 0:\n", 917 | " # print(\"Saving model parameters...\")\n", 918 | " # torch.save(net.state_dict(), 'net_'+suffix+'.dat')\n", 919 | " torch.save(net.state_dict(), 'netAE'+str(params['rngseed'])+'.dat')\n", 920 | " torch.save(net.state_dict(), 'net.dat')\n", 921 | "\n", 922 | " # with open('rewards_'+suffix+'.txt', 'w') as thefile:\n", 923 | " # for item in all_mean_rewards_ep[::10]:\n", 924 | " # thefile.write(\"%s\\n\" % item)\n", 925 | " # with open('testrew_'+suffix+'.txt', 'w') as thefile:\n", 926 | " # for item in all_mean_testrewards_ep[::10]:\n", 927 | " # thefile.write(\"%s\\n\" % item)\n", 928 | " with open('tAE'+str(params['rngseed'])+'.txt', 'w') as thefile:\n", 929 | " for item in all_mean_testrewards_ep[::10]:\n", 930 | " thefile.write(\"%s\\n\" % item)\n", 931 | "\n", 932 | "\n" 933 | ] 934 | }, 935 | { 936 | "cell_type": "code", 937 | "source": [ 938 | "if EVAL == False:\n", 939 | " raise ValueError(\"No need to go further if not in EVAL mode.\")" 940 | ], 941 | "metadata": { 942 | "id": "4u1PoRtDY-OH" 943 | }, 944 | "execution_count": null, 945 | "outputs": [] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": null, 950 | "metadata": { 951 | "id": "exEby-1KF0nX" 952 | }, 953 | "outputs": [], 954 | "source": [ 955 | "pp = print" 956 | ] 957 | }, 958 | { 959 | "cell_type": "code", 960 | "source": [ 961 | " # R = torch.zeros(BS, requires_grad=False).to(device)\n", 962 | " # gammaR = params['gr']\n", 963 | " # for numstepb in reversed(range(params['eplen'])) :\n", 964 | " # R = gammaR * R + torch.from_numpy(rewards[numstepb]).detach().to(device)\n", 965 | " # ctrR = R - vs[numstepb][0]\n", 966 | " # lossv += ctrR.pow(2).sum() / BS\n", 967 | " # LOSSMULT = params['testlmult'] if numstepb > params['eplen'] - params['triallen'] * params['nbtesttrials'] else 1.0\n", 968 | " # loss -= LOSSMULT * (logprobs[numstepb] * ctrR.detach()).sum() / BS # Action poliy loss\n", 969 | "\n", 970 | "\n", 971 | "pp(vs[numstepb][0].shape)\n", 972 | "pp(vs[numstepb][:,0].shape)\n", 973 | "pp(vs[numstepb].shape)\n", 974 | "pp(R.shape)\n", 975 | "\n", 976 | "pp( (R-vs[numstepb]).shape )\n", 977 | "pp( (R-vs[numstepb][:,0]).shape )\n", 978 | "pp( (R-vs[numstepb][0]).shape )\n" 979 | ], 980 | "metadata": { 981 | "id": "wTPwS8TiRXUK" 982 | }, 983 | "execution_count": null, 984 | "outputs": [] 985 | }, 986 | { 987 | "cell_type": "code", 988 | "source": [ 989 | "# net.alpha *= -1.0\n", 990 | "# net.DAmult *= -1.0\n", 991 | "# torch.save(net.state_dict(), 'net30K13_flip.dat')\n" 992 | ], 993 | "metadata": { 994 | "id": "P3MBdl3tteNr" 995 | }, 996 | "execution_count": null, 997 | "outputs": [] 998 | }, 999 | { 1000 | "cell_type": "code", 1001 | "execution_count": null, 1002 | "metadata": { 1003 | "id": "P2AO7RMv8vo1" 1004 | }, 1005 | "outputs": [], 1006 | "source": [ 1007 | "print([float(x) for x in [torch.min(net.alpha), torch.max(net.alpha)]])\n", 1008 | "print([float(x) for x in [torch.mean(net.alpha), torch.median(net.alpha)]])\n", 1009 | "print(float(torch.mean((net.alpha>0).float())))\n" 1010 | ] 1011 | }, 1012 | { 1013 | 
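"cell_type": "markdown",
"source": [
"(Editor's illustration, not part of the original experiments: a minimal, self-contained sketch of the DA-gated plasticity update performed inside `RetroModulRNN.forward`, run on random tensors. All names are local to this cell and the sizes are arbitrary.)"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hedged sketch of one plasticity step, mirroring forward() above, on random data.\n",
"BSd, HSd = 4, 8                       # small batch / hidden sizes, for illustration only\n",
"h_prev = torch.randn(BSd, HSd)        # hidden state at t-1\n",
"h_curr = torch.randn(BSd, HSd)        # hidden state at t\n",
"et_demo = torch.zeros(BSd, HSd, HSd)  # eligibility trace\n",
"pw_demo = torch.zeros(BSd, HSd, HSd)  # plastic weights\n",
"m_demo = torch.randn(BSd, 1)          # neuromodulatory output m(t) (DAout in forward())\n",
"etaet_demo = 0.7                      # trace update rate (net.etaet is meta-trained)\n",
"# 1) m(t) stamps the current eligibility trace into the plastic weights, which are then clipped:\n",
"pw_demo = pw_demo + m_demo.view(BSd, 1, 1) * et_demo\n",
"torch.clip_(pw_demo, min=-50.0, max=50.0)\n",
"# 2) the trace decays toward the squashed Hebbian outer product of post- and pre-synaptic activities:\n",
"delta_et = torch.tanh(torch.bmm(h_curr.view(BSd, HSd, 1), h_prev.view(BSd, 1, HSd)))\n",
"et_demo = (1 - etaet_demo) * et_demo + etaet_demo * delta_et\n",
"print(pw_demo.shape, et_demo.shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{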
"cell_type": "code", 1014 | "source": [ 1015 | "\n", 1016 | "import matplotlib.pyplot as plt\n", 1017 | "\n", 1018 | "torch.set_grad_enabled(False)\n", 1019 | "\n", 1020 | "\n", 1021 | "# Note: this is reward computed in previous stime step, shown at the time it is percceived byb the agent....\n", 1022 | "# That is: rreward is computed at response time (2nd time step of each trial), then perceived by the agent at the next step (3rd time step)\n", 1023 | "# The DA output at time t perceives rreward shown at time t too.\n", 1024 | "\n", 1025 | "\n", 1026 | "# ================\n", 1027 | "# ================\n", 1028 | "DAmultiplier = 1.0\n", 1029 | "\n", 1030 | "if False:\n", 1031 | " # We may adjust direction of DA trace depending on whether alpha is mostly-negative or mostly-positive? Better not, too confusing. Just realize that sometimes mutually-compensating sign changes will occur, which don't affect the overall algorithm\n", 1032 | "\n", 1033 | " DAmultiplier = 1.0 if float(torch.median(net.alpha)) > 0 else -1.0\n", 1034 | "\n", 1035 | " net.alpha *= DAmultiplier\n", 1036 | "\n", 1037 | " net.DAmult *= DAmultiplier\n", 1038 | " allpwsavs_thisep = [x * DAmultiplier if x is not None else None for x in allpwsavs_thisep ]\n", 1039 | " ds_thisep = ds_thisep * DAmultiplier\n", 1040 | "\n", 1041 | "# ================\n", 1042 | "# ================\n", 1043 | "\n", 1044 | "\n", 1045 | "print(\"DAmultiplier:\", DAmultiplier)\n", 1046 | "\n", 1047 | "LL = params['eplen']\n", 1048 | "TL = params['triallen']\n", 1049 | "\n", 1050 | "xt = np.arange(0, ds_thisep.shape[1]+1, TL*10)\n", 1051 | "\n", 1052 | "plt.figure(figsize=(8,6))\n", 1053 | "\n", 1054 | "ds= ds_thisep.copy()\n", 1055 | "rs = rs_thisep.copy()\n", 1056 | "\n", 1057 | "\n", 1058 | "for numplot in range(4):\n", 1059 | " ds_blank01 = ds[numplot,:].copy()\n", 1060 | " if RESETHIDDENEVERYTRIAL:\n", 1061 | " # If we reset h and et every trial, the first two steps of DA for each trial are irrelevant because there can't be any weight modification\n", 1062 | " # So we zero them out (might as well not show them but the graph would look bizarre)\n", 1063 | " # ds_blank01_b[0::params['triallen']] = [np.NAN for x in ds_blank01[0::params['triallen']]]\n", 1064 | " # ds_blank01_b[1::params['triallen']] = [np.NAN for x in ds_blank01[1::params['triallen']]]\n", 1065 | " ds_blank01[0::params['triallen']] = [0 for x in ds_blank01[0::params['triallen']]]\n", 1066 | " ds_blank01[1::params['triallen']] = [0 for x in ds_blank01[1::params['triallen']]]\n", 1067 | " plt.subplot(2,2,numplot+1)\n", 1068 | " plt.plot(ds_blank01, 'r', label='m(t)')\n", 1069 | " plt.plot(rs[numplot, :], 'b', label='reward')\n", 1070 | " # plt.xticks((xt-1)*4, xt)\n", 1071 | " plt.xticks(xt, xt//TL)\n", 1072 | " if numplot >1:\n", 1073 | " plt.xlabel('Trial')\n", 1074 | " if numplot ==0:\n", 1075 | " plt.legend()\n", 1076 | " if NBMASSEDTRIALS > 0:\n", 1077 | " plt.axvspan(params['nbtraintrials']* TL, params['nbtraintrials']*TL + NBMASSEDTRIALS * TL, color='orange', alpha=0.2, lw=0)\n", 1078 | "\n", 1079 | "plt.tight_layout()\n", 1080 | "plt.savefig('traces.png', dpi=300)\n", 1081 | "\n", 1082 | "\n", 1083 | "# raise ValueError" 1084 | ], 1085 | "metadata": { 1086 | "id": "YrDnmkE3O1s_" 1087 | }, 1088 | "execution_count": null, 1089 | "outputs": [] 1090 | }, 1091 | { 1092 | "cell_type": "code", 1093 | "source": [ 1094 | "# raise ValueError" 1095 | ], 1096 | "metadata": { 1097 | "id": "YjPI0eSZqxpY" 1098 | }, 1099 | "execution_count": null, 1100 | "outputs": [] 1101 | }, 1102 | { 
1103 | "cell_type": "code", 1104 | "source": [ 1105 | "# Zoom\n", 1106 | "if not LINKEDLISTSEVAL:\n", 1107 | " numplot = 0\n", 1108 | " NUMTRIAL_Z = 3\n", 1109 | " plt.figure(figsize=(4.5,3))\n", 1110 | " ds_blank01 = ds[numplot,:].copy()\n", 1111 | " ds_blank01[0::params['triallen']] = [0 for x in ds_blank01[0::params['triallen']]]\n", 1112 | " ds_blank01[1::params['triallen']] = [0 for x in ds_blank01[1::params['triallen']]]\n", 1113 | " rz = rs[numplot, NUMTRIAL_Z * params['triallen']:NUMTRIAL_Z * params['triallen'] + params['triallen']*4+1]\n", 1114 | " dz = ds_blank01[NUMTRIAL_Z * params['triallen']:NUMTRIAL_Z * params['triallen'] + params['triallen']*4+1]\n", 1115 | " maxlen = len(rz)\n", 1116 | " plt.xticks(np.arange(0, maxlen, params['triallen']), NUMTRIAL_Z + np.arange(0, maxlen // params['triallen']+1))\n", 1117 | " for ybar in np.arange(0, maxlen, params['triallen']):\n", 1118 | " plt.axvline(x=ybar, color='k', ls=':')\n", 1119 | " plt.xlabel('Trials')\n", 1120 | " plt.plot(dz, 'r')\n", 1121 | " # plt.plot(range(2,maxlen,params['triallen']), rz[2::params['triallen']], 'o', markersize=10)\n", 1122 | " plt.plot(range(2,maxlen,params['triallen']), dz[2::params['triallen']], 'ro', markersize=8, label='Step 3')\n", 1123 | " # plt.plot(range(3,maxlen,params['triallen']), rz[3::params['triallen']], 'x', markersize=10)\n", 1124 | " plt.plot(range(3,maxlen,params['triallen']), dz[3::params['triallen']], 'rx', markersize=8, label='Step 4')\n", 1125 | " plt.ylim((-4,4))\n", 1126 | " plt.yticks((-3, 0, 3))\n", 1127 | " plt.ylabel('$m(t)$', fontsize=14, color='r')\n", 1128 | " plt.gca().yaxis.set_label_coords(-.075,.5)\n", 1129 | " plt.tick_params(axis='y', colors='r')\n", 1130 | " plt.legend(loc='upper left')\n", 1131 | "\n", 1132 | " ax2 = plt.twinx()\n", 1133 | " ax2.plot(rz, 'b')\n", 1134 | " ax2.set_ylim((-1.5, 1.5))\n", 1135 | " ax2.set_yticks((-1, 0, 1))\n", 1136 | " ax2.set_ylabel('$R(t)$', color='b', fontsize=14, rotation=-90)\n", 1137 | " # ax2.yaxis.label.set_color('b')\n", 1138 | " ax2.tick_params(axis='y', colors='b')\n", 1139 | " ax2.yaxis.set_label_coords(1.15,.5)\n", 1140 | "\n", 1141 | " plt.tight_layout()\n", 1142 | " plt.savefig('zoom.png', dpi=300)\n", 1143 | "\n" 1144 | ], 1145 | "metadata": { 1146 | "id": "5gd0-Jarbngh" 1147 | }, 1148 | "execution_count": null, 1149 | "outputs": [] 1150 | }, 1151 | { 1152 | "cell_type": "code", 1153 | "source": [ 1154 | "import scipy\n", 1155 | "from scipy import stats\n", 1156 | "\n", 1157 | "# NUMTRIAL_VP = 1\n", 1158 | "NUMTRIAL_VP = 4\n", 1159 | "# NUMTRIAL_VP = 18\n", 1160 | "print(ds.shape, rs.shape)\n", 1161 | "# dt2 = ds[:,2::params['triallen']]; rt2 = rs[:,2::params['triallen']]\n", 1162 | "# dt3 = ds[:,3::params['triallen']]; rt3 = rs[:,3::params['triallen']]\n", 1163 | "dt2 = ds[:,NUMTRIAL_VP * params['triallen'] + 2]\n", 1164 | "rt2 = rs[:,NUMTRIAL_VP * params['triallen'] + 2]\n", 1165 | "dt3 = ds[:,NUMTRIAL_VP * params['triallen'] + 3]\n", 1166 | "# rt3 = rs[:,18*params['triallen'] + 3]. 
# We only use rt2, since reward is only given at t=2 !\n", 1167 | "print(dt2.shape, rt2.shape)\n", 1168 | "print(np.unique(rt2))\n", 1169 | "dt2_rpos = dt2[rt2 > 0]\n", 1170 | "dt2_rneg = dt2[rt2 < 0]\n", 1171 | "dt3_rpos = dt3[rt2 > 0] # note: using rt2\n", 1172 | "dt3_rneg = dt3[rt2 < 0] # note: using rt2\n", 1173 | "print(dt2_rneg.shape, dt2_rpos.shape) # the dt3_x have th same shapes\n", 1174 | "\n", 1175 | "plt.figure(figsize=(4,2))\n", 1176 | "\n", 1177 | "plt.subplot(1,2,1)\n", 1178 | "# x_dt2_rpos = .1 * (np.random.randn(dt2_rpos.size) - .5)\n", 1179 | "plt.violinplot((dt2_rneg, dt2_rpos), showextrema=False, showmedians=True, widths = [dt2_rneg.size / dt2.size, dt2_rpos.size / dt2.size])\n", 1180 | "plt.xlim((.5, 2.5))\n", 1181 | "plt.xticks((1, 2), ('Rew-', 'Rew+'))\n", 1182 | "res = scipy.stats.mannwhitneyu(dt2_rneg, dt2_rpos)\n", 1183 | "print(res.pvalue)\n", 1184 | "plt.ylim((-3.5, 3.5))\n", 1185 | "plt.yticks((-3,0,3))\n", 1186 | "plt.title('$m(t=3)$')\n", 1187 | "\n", 1188 | "plt.subplot(1,2,2)\n", 1189 | "plt.violinplot((dt3_rneg, dt3_rpos), showextrema=False, showmedians=True, widths = [dt3_rneg.size / dt3.size, dt3_rpos.size / dt3.size])\n", 1190 | "plt.ylim((-3.5, 3.5))\n", 1191 | "plt.yticks((-3,0,3))\n", 1192 | "plt.xlim((.5, 2.5))\n", 1193 | "plt.xticks((1, 2), ('Rew-', 'Rew+'))\n", 1194 | "res = scipy.stats.mannwhitneyu(dt3_rneg, dt3_rpos)\n", 1195 | "# res = scipy.stats.mannwhitneyu(dt3_rneg, np.random.permutation(dt3_rneg)) # For debugging\n", 1196 | "print(res.pvalue)\n", 1197 | "plt.title('$m(t=4)$')\n", 1198 | "\n", 1199 | "plt.tight_layout()\n", 1200 | "plt.savefig('violin.png', dpi=300)\n" 1201 | ], 1202 | "metadata": { 1203 | "id": "3a8QwKL4ECdQ" 1204 | }, 1205 | "execution_count": null, 1206 | "outputs": [] 1207 | }, 1208 | { 1209 | "cell_type": "code", 1210 | "source": [ 1211 | "# print(ds[0,2::4])\n", 1212 | "# print(ds[0,3::4])\n", 1213 | "# print(\"Correlation between DA at time steps 2 and 3 across trials (all batch):\", np.corrcoef(ds[:,2::4].flatten(), ds[:,3::4].flatten())[0,1])\n", 1214 | "# print(\"Correlation between DA at time step 2 and trial reward across trials (all batch):\", np.corrcoef(ds[:,2::4].flatten(), rs[:,2::4].flatten())[0,1])\n", 1215 | "# print(\"Correlation between DA at time step 3 and trial reward across trials (all batch):\", np.corrcoef(ds[:,3::4].flatten(), rs[:,2::4].flatten())[0,1])\n", 1216 | "# print(\"Covariance between DA at time step 2 and trial reward across trials (all batch):\", np.cov(ds[:,2::4].flatten(), rs[:,2::4].flatten())[0,1])\n", 1217 | "# print(\"Covariance between DA at time step 3 and trial reward across trials (all batch):\", np.cov(ds[:,3::4].flatten(), rs[:,2::4].flatten())[0,1])\n", 1218 | "\n" 1219 | ], 1220 | "metadata": { 1221 | "id": "Shupwa8BmOvh" 1222 | }, 1223 | "execution_count": null, 1224 | "outputs": [] 1225 | }, 1226 | { 1227 | "cell_type": "code", 1228 | "execution_count": null, 1229 | "metadata": { 1230 | "id": "kREuFmpHOOsH" 1231 | }, 1232 | "outputs": [], 1233 | "source": [ 1234 | "# Performances IN THE TEST PHASE are collected for all pairs\n", 1235 | "\n", 1236 | "orderofcuepairs=[]\n", 1237 | "for nc in range(1,params['nbcues']):\n", 1238 | " for nc2 in range(params['nbcues']):\n", 1239 | " if nc+nc2 >= params['nbcues']:\n", 1240 | " break\n", 1241 | " orderofcuepairs.append([nc2, nc2+nc])\n", 1242 | "print(orderofcuepairs)\n", 1243 | "\n", 1244 | "\n", 1245 | "e = 2 if LINKEDLISTSEVAL else 0\n", 1246 | "o = np.load('outcomes_'+str(e)+'.npz', allow_pickle=True); c = o['c']; a = 
o['a']; na = 1-a; cp = o['cp']; r = o['r']\n", 1247 | "\n", 1248 | "\n", 1249 | "if ONLYTWOLASTADJ:\n", 1250 | " testperf_adj = np.sum(c[:,-2:] * a[:,-2:]) / np.sum(a[:, -2:] )\n", 1251 | " testperf_nonadj = np.sum(c[:,-2:] * na[:,-2:]) / np.sum(na[:, -2:] )\n", 1252 | " print(testperf_adj, testperf_nonadj)\n", 1253 | "\n", 1254 | "allperfs = []\n", 1255 | "NBSPLITS = 10\n", 1256 | "SPLITSIZE = BS // NBSPLITS\n", 1257 | "for nsp in range(NBSPLITS):\n", 1258 | " # c = corrects[e]\n", 1259 | " # a = adjs[e]\n", 1260 | " # cp = cuepairs[e]\n", 1261 | " # r = resps[e]\n", 1262 | " perfs = np.zeros(len(orderofcuepairs))\n", 1263 | " nbs = np.zeros(len(orderofcuepairs))\n", 1264 | " for pos, p in enumerate(orderofcuepairs):\n", 1265 | " nbthispair = nbcorrectthispair = 0\n", 1266 | " for nb in range(nsp*SPLITSIZE, nsp*SPLITSIZE + SPLITSIZE):\n", 1267 | " for nt in range(cp.shape[1]):\n", 1268 | " if nt < params['nbtrials'] - params['nbtesttrials']:\n", 1269 | " # if nt > 2: # Should be terrible\n", 1270 | " continue\n", 1271 | " if ( cp[nb, nt, 0] == p[0] and cp[nb, nt, 1] == p[1]) or ( cp[nb, nt, 0] == p[1] and cp[nb, nt, 1] == p[0]):\n", 1272 | " nbthispair += 1\n", 1273 | " if (cp[nb, nt, 0] == p[0] and r[nb,nt] == 1) or (cp[nb, nt, 0] == p[1] and r[nb,nt] == -1): # response was correct\n", 1274 | " assert c[nb,nt] == 1 # In this house we believe in consistency\n", 1275 | " nbcorrectthispair += 1\n", 1276 | " assert nbthispair > 0\n", 1277 | " perfs[pos] = nbcorrectthispair / nbthispair if nbthispair > 0 else -10\n", 1278 | " nbs[pos] = nbthispair\n", 1279 | " allperfs.append(perfs)\n", 1280 | "\n", 1281 | "# else:\n", 1282 | "# for e in range(NBEPISODES):\n", 1283 | "# c = corrects[e]\n", 1284 | "# a = adjs[e]\n", 1285 | "# cp = cuepairs[e]\n", 1286 | "# r = resps[e]\n", 1287 | "# perfs = np.zeros(len(orderofcuepairs))\n", 1288 | "# nbs = np.zeros(len(orderofcuepairs))\n", 1289 | "# for pos, p in enumerate(orderofcuepairs):\n", 1290 | "# nbthispair = nbcorrectthispair = 0\n", 1291 | "# for nb in range(cp.shape[0]):\n", 1292 | "# for nt in range(cp.shape[1]):\n", 1293 | "# if nt < params['nbtrials'] - params['nbtesttrials']:\n", 1294 | "# # if nt > 2: # Should be terrible\n", 1295 | "# continue\n", 1296 | "# if ( cp[nb, nt, 0] == p[0] and cp[nb, nt, 1] == p[1]) or ( cp[nb, nt, 0] == p[1] and cp[nb, nt, 1] == p[0]):\n", 1297 | "# nbthispair += 1\n", 1298 | "# if (cp[nb, nt, 0] == p[0] and r[nb,nt] == 1) or (cp[nb, nt, 0] == p[1] and r[nb,nt] == -1): # response was correct\n", 1299 | "# assert c[nb,nt] == 1 # In this house we believe in consistency\n", 1300 | "# nbcorrectthispair += 1\n", 1301 | "# perfs[pos] = nbcorrectthispair / nbthispair if nbthispair > 0 else -0.1\n", 1302 | "# nbs[pos] = nbthispair\n", 1303 | "# allperfs.append(perfs)\n", 1304 | "\n", 1305 | "\n", 1306 | "# allperfs contains the performance for each pair, for each split\n", 1307 | "\n", 1308 | "allperfs = np.array(allperfs)\n", 1309 | "print(np.median(allperfs, axis=0))\n", 1310 | "print(nbs)\n", 1311 | "print(allperfs.shape)\n" 1312 | ] 1313 | }, 1314 | { 1315 | "cell_type": "code", 1316 | "execution_count": null, 1317 | "metadata": { 1318 | "id": "dWXCaktjXhIa" 1319 | }, 1320 | "outputs": [], 1321 | "source": [ 1322 | "print(\"NOTE: This figure may look quite different every time you re-run the notebook, even if you use the same network parameters.\")\n", 1323 | "\n", 1324 | "plt.figure(figsize=(7,4))\n", 1325 | "\n", 1326 | "alphabet = [chr(i) for i in range(ord('A'),ord('Z')+1)] # How did we live without 
StackOverflow?\n", 1327 | "\n", 1328 | "# assert params['nbcues'] == 8 # Not sure if necessary\n", 1329 | "strt = 0\n", 1330 | "offset = 0 # offset is only here to put gaps in the x axis of the graph, it's not used to compute the actual values\n", 1331 | "rnge = params['nbcues']-1\n", 1332 | "xtks0 = []\n", 1333 | "xtks1 = []\n", 1334 | "for nump in range(params['nbcues']-1):\n", 1335 | " print(strt, strt+rnge)\n", 1336 | " # xtks = xtks + [alphabet[x[0]]+alphabet[x[1]] for x in orderofcuepairs[strt:strt+rnge]]\n", 1337 | " # xtks = xtks + [alphabet[x[0]]+alphabet[x[1]]\n", 1338 | " xtks0 = xtks0 + list(range(strt+offset, strt+rnge+offset))\n", 1339 | " for numx, x in enumerate(orderofcuepairs[strt:strt+rnge]):\n", 1340 | " xtks1= xtks1 + [('\\n' if numx % 2 == 1 else '') + alphabet[x[0]] +alphabet[x[1]]]\n", 1341 | "\n", 1342 | " plt.plot(range(strt+offset,strt+rnge+offset), np.median(allperfs[:,strt:strt+rnge], axis=0))\n", 1343 | " plt.fill_between(range(strt+offset,strt+rnge+offset), np.quantile(allperfs[:,strt:strt+rnge], .25, axis=0), np.quantile(allperfs[:,strt:strt+rnge], .75, axis=0), alpha=.3)\n", 1344 | "\n", 1345 | " # plt.plot(range(strt+offset,strt+rnge+offset), np.mean(allperfs[:,strt:strt+rnge], axis=0))\n", 1346 | " # plt.fill_between(range(strt+offset,strt+rnge+offset), np.mean(allperfs[:,strt:strt+rnge], axis=0) - np.std(allperfs[:,strt:strt+rnge], axis=0),\n", 1347 | " # np.mean(allperfs[:,strt:strt+rnge], axis=0) + np.std(allperfs[:,strt:strt+rnge], axis=0), alpha=.3)\n", 1348 | "\n", 1349 | " if rnge == 1:\n", 1350 | "\n", 1351 | " plt.plot([strt+offset], [np.median(allperfs[:,strt])], '.')\n", 1352 | " plt.errorbar([strt+offset], [np.median(allperfs[:,strt]),], [[np.median(allperfs[:,strt]) - np.quantile(allperfs[:,strt], .25)], [np.quantile(allperfs[:,strt], .75) - np.median(allperfs[:,strt])]]) # The most troublesome line in the whole codebase. We love matplotlib.\n", 1353 | "\n", 1354 | " # plt.plot([strt+offset], [np.mean(allperfs[:,strt])], '.')\n", 1355 | " # plt.errorbar([strt+offset], [np.mean(allperfs[:,strt]),], [np.std(allperfs[:,strt]),] ) # The most troublesome line in the whole codebase. 
We love matplotlib.\n", 1356 | "\n", 1357 | "  strt += rnge\n", 1358 | "  rnge -= 1\n", 1359 | "  offset += 2\n", 1360 | "# plt.xticks(range(strt+offset-2), xtks[:strt+offset-2])\n", 1361 | "plt.xticks(xtks0, xtks1)\n", 1362 | "plt.ylabel('% correct (last '+str(params['nbtesttrials'])+' trials)')\n", 1363 | "if LINKEDLISTSEVAL:\n", 1364 | "  plt.axhline(y=0.5, color='k', linestyle='--')\n", 1365 | "  plt.ylabel('% correct (last test trial)')\n", 1366 | "else:\n", 1367 | "  plt.ylabel('% correct (last '+str(params['nbtesttrials'])+' trials)')\n", 1368 | "\n", 1369 | "if not FIXEDCUES:\n", 1370 | "  if LINKEDLISTSEVAL:\n", 1371 | "    plt.savefig('SDE_LINKEDLISTS'+('_SHAM' if LINKINGISSHAM else '')+'.png', dpi=300)\n", 1372 | "  else:\n", 1373 | "    plt.savefig('SDE.png', dpi=300)\n" 1374 | ] 1375 | }, 1376 | { 1377 | "cell_type": "code", 1378 | "source": [ 1379 | "if LINKEDLISTSEVAL:\n", 1380 | "  raise(ValueError(\"No point in going further for linked-lists experiments\"))" 1381 | ], 1382 | "metadata": { 1383 | "id": "ZBeHj7Nl9VhM" 1384 | }, 1385 | "execution_count": null, 1386 | "outputs": [] 1387 | }, 1388 | { 1389 | "cell_type": "code", 1390 | "execution_count": null, 1391 | "metadata": { 1392 | "id": "Wbm4Da8Ob2bk" 1393 | }, 1394 | "outputs": [], 1395 | "source": [ 1396 | "\n", 1397 | "\n", 1398 | "a = np.load('allrates_thisep_0.npy')\n", 1399 | "plt.plot(a[0, :10, -20:].T) # Notice the .T: first 10 neurons, last 20 timesteps (for the whole episode, so several trials)" 1400 | ] 1401 | }, 1402 | { 1403 | "cell_type": "code", 1404 | "execution_count": null, 1405 | "metadata": { 1406 | "id": "4lQq81XgsZO6" 1407 | }, 1408 | "outputs": [], 1409 | "source": [ 1410 | "# Try to predict whether the first cue was cue number X, based on firing rates at time T for the last trial\n", 1411 | "# Corr on test set: .5 for cues 0 and 6, .2 for 1 and 5, all the rest terrible....... This is only if T = 2 or 3, none other.\n", 1412 | "# Trial 19: stim 0 is .25, 6 is .23, 1 is .1\n", 1413 | "# Also same values for logistic regression.\n", 1414 | "\n", 1415 | "# Also tried to predict simply whether the cues were correctly ordered (i.e. the correct response)\n", 1416 | "# First, panic attack: great correlation from trial 0 !!\n", 1417 | "# But then I realized it was for position in trial 3 - after reward is received!\n", 1418 | "# By contrast, at pos 2, correlation (i.e. ability to detect whether the two cues were correctly\n", 1419 | "# ordered from neural rates) is ~0 at trial 0, but >.8 at trial 29. Good!\n", 1420 | "# at trial 19, only .5 corr... (though note that this is on adjacent pairs) - at trial 20, .74... logistic: .6+-.8/.\n", 1421 | "\n", 1422 | "# Pos 2, trial 29: perfect-ish prediction of the difference between the ranks of stim1 and stim2 - even though prediction of \"whether cue N is the 1st/2nd\" is very bad (for non-end cues, and not very good for end-cues)...\n", 1423 | "# This requires a recurrent step. At trial time 1, prediction is at chance, correlation 0.03... Strange? No, there was no randomness, the empty cue is always before the stimulus.\n", 1424 | "# trial 19 (ie only on adjacent pairs): predicts only at ~.5 corr. trial 20: .8, trial 22: .86\n", 1425 | "\n", 1426 | "# What about simply predicting stim1 and stim2's ranks themselves?... Decent, but not perfect prediction - corr is .75 instead of .94.\n", 1427 | "# Trial 19: chance!! (-.002), Trial 20: .64,\n", 1428 | "# Predicting the value of the first cue jumps from ~0 at trial 19, to .6+ at trial 20 !... 
.5 trial 29.\n", 1429 | "\n", 1430 | "# NOTE: until trial 20, ordered and dists are essentially the same quantity with different scaling (because only adjacent pairs)\n", 1431 | "\n", 1432 | "# At trial pos 1 (stim presentation), nothing can be decoded. Yet 'hidden' definitely registers the inputs - if I pass in TTHCO as an input, decoding of ordered becomes 1.0 (also, decoding of 'first' becomes very good, especially when non-adjacent?....)\n", 1433 | "# Unsurprising. There is no way it can register inputs as anything other than their actual bitstring content, since input weights are non-plastic.\n", 1434 | "# However, ability to now decode first-0 (better than first_3) and first (even better after 20!) is surprising. Is it plasticity-related?\n", 1435 | "# Same (giving TTHCO as input) but resetting pw and et at the beginning of all trials. Similar... Note that this ability is flat before 20 and after.\n", 1436 | "# So at trial pos 1, *IF* the decoder has access to TTHCO and whatever it sees at time 1, it can decode 'first' to good value (.2 before 20, equal to first_0, .6 after 20, better than first_0, first_3 always chance)...\n", 1437 | "# .. but not without TTHCO - then it can decode nothing at all.\n", 1438 | "# Confirmed that it's really the information itself that matters, not indirectly through recurrence or plasticity (if numtrial > 19 and numtrial % 2 != 0 and numstep == 1:, reset everything, get the expected serrated curves)\n", 1439 | "# Note that you also get the dists to .8, after trial 20...\n", 1440 | "# But lol, you get the same values with zeroing out the cue input\n", 1441 | "# So you can decode 'first' to .6 (after 20) just from knowing TTHCO...\n", 1442 | "\n", 1443 | "\n", 1444 | "import sklearn.linear_model\n", 1445 | "import sklearn.neural_network\n", 1446 | "\n", 1447 | "# biga = []; bigcp = []; bigcorr= []; bigresps = []; bigac = []\n", 1448 | "# for numep in range(NBEPISODES):\n", 1449 | "#   biga.append(np.load('allrates_thisep_'+str(numep)+'.npy'))\n", 1450 | "#   o = np.load('outcomes_'+str(numep)+'.npz', allow_pickle=True);\n", 1451 | "#   bigac.append(o['ac'])\n", 1452 | "#   bigcp.append(o['cp'])\n", 1453 | "#   bigcorr.append(o['c'])\n", 1454 | "#   bigresps.append(o['r'])\n", 1455 | "\n", 1456 | "\n", 1457 | "numep = 2 if LINKEDLISTSEVAL else 0\n", 1458 | "o = np.load('outcomes_'+str(numep)+'.npz', allow_pickle=True); corr = o['c']; ac = o['ac']; cp = o['cp']; resps = o['r']\n", 1459 | "\n", 1460 | "allrates = np.load('allrates_thisep_'+str(numep)+'.npy')\n", 1461 | "\n", 1462 | "# cp = np.vstack(bigcp)\n", 1463 | "# asav = a.copy()\n", 1464 | "# corr = np.vstack(bigcorr)\n", 1465 | "# resps = np.vstack(bigresps)\n", 1466 | "# ac = np.vstack(bigac)\n", 1467 | "print(cp.shape, a.shape, corr.shape, resps.shape, ac.shape)\n", 1468 | "#(5000, 30, 2) (5000, 200, 150) (5000, 30) (5000, 30) (5000, 30, 5)\n", 1469 | "print(np.unique(resps)) # -1, 1\n", 1470 | "print(np.unique(ac)) # 0, 1\n", 1471 | "\n", 1472 | "a = allrates.copy()\n", 1473 | "\n", 1474 | "assert cp.shape[0] == BS and cp.shape[1] == params['nbtrials']\n", 1475 | "assert a.shape[0] == BS and a.shape[1] == params['hs'] and a.shape[2] == params['eplen']\n", 1476 | "assert corr.shape[0] == BS and corr.shape[1] == params['nbtrials']\n", 1477 | "# assert cp.shape[0] == NBEPISODES * BS and cp.shape[1] == params['nbtrials']\n", 1478 | "# assert a.shape[0] == NBEPISODES * BS and a.shape[1] == params['hs'] and a.shape[2] == params['eplen']\n", 1479 | "# assert corr.shape[0] == NBEPISODES * BS and corr.shape[1] == 
params['nbtrials']\n", 1480 | "\n", 1481 | "# \"first_iscuenum\": for each stimulus (dim 1), whether this stimulus was the \"first\" stimulus for joined-batch element dim 2, numtrial dim 3\n", 1482 | "first_iscuenum = np.zeros((params['nbcues'], BS, params['nbtrials']))\n", 1483 | "second_iscuenum = np.zeros((params['nbcues'], BS, params['nbtrials']))\n", 1484 | "ordered = np.zeros(( BS, params['nbtrials']))\n", 1485 | "dists = np.zeros(( BS, params['nbtrials']))\n", 1486 | "first = np.zeros(( BS, params['nbtrials']))\n", 1487 | "second = np.zeros(( BS, params['nbtrials']))\n", 1488 | "for nb in range( BS):\n", 1489 | "  for nt in range(params['nbtrials']):\n", 1490 | "    first_iscuenum[cp[nb, nt, 0], nb, nt] = 1\n", 1491 | "    second_iscuenum[cp[nb, nt, 1], nb, nt] = 1\n", 1492 | "    ordered[nb, nt] = 1 if cp[nb, nt, 0] < cp[nb, nt, 1] else 0\n", 1493 | "    dists[nb, nt] = cp[nb, nt, 0] - cp[nb, nt, 1]\n", 1494 | "    first[nb, nt] = cp[nb, nt, 0]\n", 1495 | "    second[nb, nt] = cp[nb, nt, 1]\n", 1496 | "# Note that corr already has the adequate shape, nb x nt. So does resps.\n", 1497 | "\n", 1498 | "print(first_iscuenum[1, 0, :]) # For stimulus 1, batch element 0, all trials\n", 1499 | "print(cp[0, :, 0]) # For batch element 0, all trials, first of pair. They had better correspond to the previous line!\n", 1500 | "\n", 1501 | "# NUMSTIM = 2 # for the _iscuenum arrays\n", 1502 | "NUMTRIAL = 19 # 0 # 29 # 19\n", 1503 | "POSINTRIAL = 1 # 5 # 3 # 2 # 1\n", 1504 | "\n", 1505 | "x = a[:, :, POSINTRIAL::params['triallen']]\n", 1506 | "print(x.shape) # NBEPISODES*BS, 200, 30\n", 1507 | "# y = first_iscuenum[NUMSTIM,:,:][:, None, :]\n", 1508 | "# y = second_iscuenum[NUMSTIM,:,:][:, None, :]\n", 1509 | "# y = ordered[:, None, :]\n", 1510 | "# y = dists[:, None, :]\n", 1511 | "y = first[:, None, :]\n", 1512 | "# y = second[:, None, :]\n", 1513 | "\n", 1514 | "yall = np.hstack((\n", 1515 | "  first_iscuenum[0,:,:][:, None, :],\n", 1516 | "  first_iscuenum[3,:,:][:, None, :],\n", 1517 | "  first[:, None, :],\n", 1518 | "  ordered[:, None, :],\n", 1519 | "  dists[:, None, :],\n", 1520 | "  corr[:, None, :],\n", 1521 | "  resps[:, None, :],\n", 1522 | "))\n", 1523 | "print(y.shape)\n", 1524 | "print(yall.shape)\n", 1525 | "\n", 1526 | "\n", 1527 | "# All runs are equally independent of each other, whether they come from different episodes or from different batch elements within an episode.\n", 1528 | "# So we can just split the data into two parts (train on all but the last 200 runs, test on the last 200).\n", 1529 | "\n", 1530 | "model = sklearn.linear_model.LinearRegression()\n", 1531 | "# model = sklearn.linear_model.LogisticRegression(class_weight='balanced',solver='newton-cholesky',max_iter=10000)\n", 1532 | "# model = sklearn.linear_model.Ridge()\n", 1533 | "curves = np.zeros((params['nbtrials'], yall.shape[1]))\n", 1534 | "for NUMTRIAL in range(params['nbtrials']):\n", 1535 | "  model.fit(x[:-200,:,NUMTRIAL], yall[:-200, :, NUMTRIAL])\n", 1536 | "  out = model.predict(x[-200:, :,NUMTRIAL])\n", 1537 | "  for nc in range(yall.shape[1]):\n", 1538 | "    # curves[NUMTRIAL, nc] = np.corrcoef(np.argsort(np.argsort(out.T[nc,:])), np.argsort(np.argsort(yall[-200:, nc, NUMTRIAL])).T)[0,1]\n", 1539 | "    curves[NUMTRIAL, nc] = np.corrcoef(out.T[nc,:], yall[-200:, nc, NUMTRIAL].T)[0,1]\n", 1540 | "    # curves[NUMTRIAL, nc] = np.mean(out.T[nc,:].astype(int) == yall[-200:, nc, NUMTRIAL].T)\n", 1541 | "\n" 1542 | ] 1543 | }, 1544 | { 1545 | "cell_type": "code", 1546 | "execution_count": null, 1547 | "metadata": { 1548 | "id": "l6zf21pOBNAN" 1549 | }, 1550 | "outputs": [], 1551 | "source": [ 1552 | "plt.figure(figsize=(7,4))\n", 1553 | 
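"# Added note (sketch, not in the original notebook): curves[t, q] is the held-out\n", "# correlation between the linear decoder's prediction and quantity q at trial t\n", "# (the decoder is trained on all but the last 200 batch elements and tested on the last 200).\n",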
"plt.plot(curves, label=['Cue1_is_A', 'Cue1_is_D', 'Rank_of_Cue1', 'Ordered?', 'Distance', 'Resp_was_corr', 'Resp+/-'])\n", 1554 | "plt.xlabel('Trial')\n", 1555 | "plt.ylabel('% correct')\n", 1556 | "plt.legend(loc=(1.04,0))\n", 1557 | "plt.tight_layout()\n", 1558 | "plt.savefig('decoding.png',dpi=300)\n" 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": "code", 1563 | "execution_count": null, 1564 | "metadata": { 1565 | "id": "x4sGEIuJl0wx" 1566 | }, 1567 | "outputs": [], 1568 | "source": [ 1569 | "\n", 1570 | "# Decoding \"ordered\" is easy late in the episode, of course, so the system works,,,\n", 1571 | "# model = sklearn.neural_network.MLPClassifier(activation='tanh', hidden_layer_sizes=[100,], max_iter=2000)\n", 1572 | "# model = sklearn.neural_network.MLPRegressor (activation='tanh', hidden_layer_sizes=[100,], max_iter=2000)\n", 1573 | "model = sklearn.linear_model.LinearRegression()\n", 1574 | "curvenn = np.zeros(params['nbtrials'])\n", 1575 | "# curvennother = np.zeros(params['nbtrials'])\n", 1576 | "ynn = first_iscuenum[3, :, :]\n", 1577 | "# ynnother = first_iscuenum[4, :, :]\n", 1578 | "# ynn = ordered\n", 1579 | "NUMTRIAL = 19\n", 1580 | "POSINTRIALTRAIN = 0\n", 1581 | "POSINTRIALTEST = 1\n", 1582 | "mxtrain= allrates[:, :, POSINTRIALTRAIN::params['triallen']]\n", 1583 | "mxtest= allrates[:, :, POSINTRIALTEST::params['triallen']]\n", 1584 | "for NUMTRIAL in range(params['nbtrials']):\n", 1585 | " if NUMTRIAL % 6 == 0:\n", 1586 | " print(NUMTRIAL)\n", 1587 | " model.fit(mxtrain[:-200,:,NUMTRIAL], ynn[:-200, NUMTRIAL])\n", 1588 | " out = model.predict(mxtest[-200:, :,NUMTRIAL])\n", 1589 | " # print(out[:10]) # Might be all 0s!\n", 1590 | " curvenn[NUMTRIAL] = np.corrcoef(out.T[:], ynn[-200:, NUMTRIAL].T)[0,1]\n", 1591 | " # curvennother[NUMTRIAL] = np.corrcoef(out.T[:], ynnother[-200:, NUMTRIAL].T)[0,1]\n", 1592 | " else:\n", 1593 | " curvenn[NUMTRIAL] = curvenn[NUMTRIAL-1]\n", 1594 | " # curvennother[NUMTRIAL] = curvennother[NUMTRIAL-1]\n", 1595 | "print(curvenn)\n", 1596 | "# print(curvennother)" 1597 | ] 1598 | }, 1599 | { 1600 | "cell_type": "code", 1601 | "execution_count": null, 1602 | "metadata": { 1603 | "id": "rq--BP3IhE12" 1604 | }, 1605 | "outputs": [], 1606 | "source": [ 1607 | "print(a.shape, first_iscuenum.shape)\n", 1608 | "POSINTRIAL = 1\n", 1609 | "NUMTRIAL = 19\n", 1610 | "rates = allrates[:, :, POSINTRIAL::params['triallen']]\n", 1611 | "rates = rates[:, :, NUMTRIAL]\n", 1612 | "rates = rates - np.mean(rates, axis=0)[None, :]\n", 1613 | "rates2 = np.mean(rates[first_iscuenum[2, :, NUMTRIAL] == 1, :], axis=0)\n", 1614 | "rates3 = np.mean(rates[first_iscuenum[3, :, NUMTRIAL] == 1, :], axis=0)\n", 1615 | "rates4 = np.mean(rates[first_iscuenum[4, :, NUMTRIAL] == 1, :], axis=0)\n", 1616 | "rates5 = np.mean(rates[first_iscuenum[5, :, NUMTRIAL] == 1, :], axis=0)\n", 1617 | "rates6 = np.mean(rates[first_iscuenum[6, :, NUMTRIAL] == 1, :], axis=0)\n", 1618 | "rates7 = np.mean(rates[first_iscuenum[7, :, NUMTRIAL] == 1, :], axis=0)\n", 1619 | "print(np.corrcoef(rates2, rates3))\n", 1620 | "print(np.corrcoef(rates2, rates4))\n", 1621 | "print(np.corrcoef(rates2, rates5))\n", 1622 | "print(np.corrcoef(rates2, rates6))\n", 1623 | "print(np.corrcoef(rates2, rates7))\n", 1624 | "\n" 1625 | ] 1626 | }, 1627 | { 1628 | "cell_type": "code", 1629 | "execution_count": null, 1630 | "metadata": { 1631 | "id": "Ad-haXTQ6YGv" 1632 | }, 1633 | "outputs": [], 1634 | "source": [ 1635 | "print(x.shape) # 5000 (BS * NbEpisodes), 200 (NbNeurons), 32 (NbTrials)\n", 1636 | "print(yall.shape) # 
5000 (BS * NbEpisodes), 7 (number of different quantities), 32 (NbTrials)\n", 1637 | "print(yall[:, 1, :].shape)\n", 1638 | "print(first_iscuenum.shape) # 8 (number of different cues), 5000 (BS * NbEpisodes), 32 (NbTrials)\n", 1639 | "print(ordered.shape) # 5000 (BS * NbEpisodes), 32 (NbTrials)\n", 1640 | "print(corr.shape) # 5000 (BS * NbEpisodes), 32 (NbTrials)\n", 1641 | "\n", 1642 | "\n", 1643 | "\n", 1644 | "NUMTRIALPCA = 19 # 21 # 34\n", 1645 | "POSINTRIALPCA = 1\n", 1646 | "\n", 1647 | "# mx = x[:,:, NUMTRIAL]\n", 1648 | "mxp = allrates[:, :, POSINTRIALPCA::params['triallen']]\n", 1649 | "mx = mxp[:,:, NUMTRIALPCA]\n", 1650 | "\n", 1651 | "# mx = mx - np.mean(mx, axis=0)[None, :]\n", 1652 | "# mx = mx / (1e-8 + np.std(mx, axis=0)[None, :])\n", 1653 | "\n", 1654 | "\n", 1655 | "MAX = 300\n", 1656 | "\n", 1657 | "from sklearn.decomposition import PCA\n", 1658 | "\n", 1659 | "pca = PCA(n_components=50)\n", 1660 | "mx_pca = pca.fit_transform(mx)\n", 1661 | "print(mx_pca.shape) # 5000 (BS * NbEpisodes), 50 (number of PCA components)\n", 1662 | "print(\"Variance explained for each successive PC:\", pca.explained_variance_ratio_)\n", 1663 | "\n", 1664 | "plt.figure(figsize=(20,10))\n", 1665 | "plt.subplot(3,4,1)\n", 1666 | "mx_pca_firstis0 = mx_pca[first_iscuenum[0, :, NUMTRIALPCA] == 1, :]\n", 1667 | "plt.plot(mx_pca_firstis0[:, 0], mx_pca_firstis0[:, 1], '.g', alpha=.3, label='First is 0')\n", 1668 | "mx_pca_firstis2 = mx_pca[first_iscuenum[2, :, NUMTRIALPCA] == 1, :]\n", 1669 | "plt.plot(mx_pca_firstis2[:, 0], mx_pca_firstis2[:, 1], '.r', alpha=.3, label='First is 2')\n", 1670 | "mx_pca_firstis3 = mx_pca[first_iscuenum[3, :, NUMTRIALPCA] == 1, :]\n", 1671 | "plt.plot(mx_pca_firstis3[:, 0], mx_pca_firstis3[:, 1], '.b', alpha=.3, label='First is 3')\n", 1672 | "mx_pca_firstis4 = mx_pca[first_iscuenum[4, :, NUMTRIALPCA] == 1, :]\n", 1673 | "plt.plot(mx_pca_firstis4[:, 0], mx_pca_firstis4[:, 1], '.y', alpha=.3, label='First is 4')\n", 1674 | "mx_pca_firstis5 = mx_pca[first_iscuenum[5, :, NUMTRIALPCA] == 1, :]\n", 1675 | "plt.plot(mx_pca_firstis5[:, 0], mx_pca_firstis5[:, 1], '.c', alpha=.3, label='First is 5')\n", 1676 | "plt.xlabel('PC 0'); plt.ylabel('PC 1')\n", 1677 | "plt.legend()\n", 1678 | "\n", 1679 | "\n", 1680 | "# mx_pca_firstis1 = mx_pca[first_iscuenum[1, :, NUMTRIAL] == 1, :]\n", 1681 | "# plt.plot(mx_pca_firstis1[:, 0], mx_pca_firstis1[:, 1], '.')\n", 1682 | "# mx_pca_firstis2 = mx_pca[first_iscuenum[2, :, NUMTRIAL] == 1, :]\n", 1683 | "# plt.plot(mx_pca_firstis2[:, 0], mx_pca_firstis2[:, 1], '.')\n", 1684 | "# mx_pca_firstis4 = mx_pca[first_iscuenum[4, :, NUMTRIAL] == 1, :]\n", 1685 | "# plt.plot(mx_pca_firstis4[:, 0], mx_pca_firstis4[:, 1], '.')\n", 1686 | "# mx_pca_firstis5 = mx_pca[first_iscuenum[5, :, NUMTRIAL] == 1, :]\n", 1687 | "# plt.plot(mx_pca_firstis5[:, 0], mx_pca_firstis5[:, 1], '.')\n", 1688 | "# mx_pca_firstis6 = mx_pca[first_iscuenum[6, :, NUMTRIAL] == 1, :]\n", 1689 | "# plt.plot(mx_pca_firstis6[:, 0], mx_pca_firstis6[:, 1], '.')\n", 1690 | "\n", 1691 | "\n", 1692 | "\n", 1693 | "\n", 1694 | "plt.subplot(3,4,2)\n", 1695 | "respPpoints = mx_pca[resps[:, NUMTRIALPCA] == 1, :]\n", 1696 | "plt.plot(respPpoints[:, 0], respPpoints[:, 1], '.g', alpha=.2, label='Resp+')\n", 1697 | "\n", 1698 | "respNpoints = mx_pca[resps[:, NUMTRIALPCA] == -1, :]\n", 1699 | "plt.plot(respNpoints[:, 0], respNpoints[:, 1], '.r', alpha=.2, label='Resp-')\n", 1700 | "plt.xlabel('PC 0'); plt.ylabel('PC 1')\n", 1701 | "plt.legend()\n", 1702 | "\n", 1703 | "\n", 1704 | "\n", 1705 | 
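"# Added note (sketch, not in the original notebook): each panel below colors the same\n", "# PCA point cloud by a different task variable; clearly separated clusters mean that\n", "# variable is (approximately) linearly decodable from the population state at this trial and time step.\n",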
1706 | "plt.subplot(3,4,3)\n", 1707 | "orderedpoints = mx_pca[ordered[:, NUMTRIALPCA] == 1, :]\n", 1708 | "plt.plot(orderedpoints[:, 0], orderedpoints[:, 1], '.g', alpha=.2, label='Ordered')\n", 1709 | "\n", 1710 | "notorderedpoints = mx_pca[ordered[:, NUMTRIALPCA] == 0, :]\n", 1711 | "plt.plot(notorderedpoints[:, 0], notorderedpoints[:, 1], '.r', alpha=.2, label='Not Ordered')\n", 1712 | "plt.xlabel('PC 0'); plt.ylabel('PC 1')\n", 1713 | "plt.legend()\n", 1714 | "\n", 1715 | "\n", 1716 | "plt.subplot(3,4,4)\n", 1717 | "currcorrpoints = mx_pca[corr[:, NUMTRIALPCA] == 1, :]\n", 1718 | "plt.plot(currcorrpoints[:, 0], currcorrpoints[:, 1], '.g', alpha=.2, label='Curr Resp Correct')\n", 1719 | "\n", 1720 | "currnotcorrpoints = mx_pca[corr[:, NUMTRIALPCA] == 0, :]\n", 1721 | "plt.plot(currnotcorrpoints[:, 0], currnotcorrpoints[:, 1], '.r', alpha=.2, label='Curr Resp Wrong')\n", 1722 | "plt.xlabel('PC 0'); plt.ylabel('PC 1')\n", 1723 | "plt.legend()\n", 1724 | "\n", 1725 | "\n", 1726 | "plt.subplot(3,4,5)\n", 1727 | "\n", 1728 | "\n", 1729 | "\n", 1730 | "prevcorrpoints = mx_pca[resps[:, NUMTRIALPCA-1] == 1, :]\n", 1731 | "plt.plot(prevcorrpoints[:, 0], prevcorrpoints[:, 1], '.g', alpha=.2, label='PrevTrial Resp+')\n", 1732 | "prevnotcorrpoints = mx_pca[resps[:, NUMTRIALPCA-1] == -1, :]\n", 1733 | "plt.plot(prevnotcorrpoints[:, 0], prevnotcorrpoints[:, 1], '.r', alpha=.2, label='PrevTrial Resp-')\n", 1734 | "plt.xlabel('PCA 0'); plt.ylabel('PCA 1')\n", 1735 | "plt.legend()\n", 1736 | "\n", 1737 | "\n", 1738 | "\n", 1739 | "\n", 1740 | "plt.subplot(3,4,6)\n", 1741 | "prespPpoints = mx_pca[resps[:, NUMTRIALPCA-1] == 1, :]\n", 1742 | "plt.plot(prespPpoints[:, 2], prespPpoints[:, 3], '.g', alpha=.2, label='Prev Trial Resp+')\n", 1743 | "prespNpoints = mx_pca[resps[:, NUMTRIALPCA-1] == -1, :]\n", 1744 | "plt.plot(prespNpoints[:, 2], prespNpoints[:, 3], '.r', alpha=.2, label='Prev Trial Resp-')\n", 1745 | "plt.xlabel('PCA 2'); plt.ylabel('PCA 3')\n", 1746 | "plt.legend()\n", 1747 | "\n", 1748 | "plt.subplot(3,4,7)\n", 1749 | "preva0points = mx_pca[ac[:, NUMTRIALPCA, POSINTRIALPCA-1] == 0, :]\n", 1750 | "plt.plot(preva0points[:, 0], preva0points[:, 1], '.g', alpha=.2, label='PrevStep Action 0')\n", 1751 | "preva1points = mx_pca[ac[:, NUMTRIALPCA, POSINTRIALPCA-1] == 1, :]\n", 1752 | "plt.plot(preva1points[:, 0], preva1points[:, 1], '.r', alpha=.2, label='PrevStep Action 1')\n", 1753 | "plt.xlabel('PCA 0'); plt.ylabel('PCA 1')\n", 1754 | "plt.legend()\n", 1755 | "\n", 1756 | "\n", 1757 | "plt.subplot(3,4,8)\n", 1758 | "mx_pca_secondis0 = mx_pca[second_iscuenum[0, :, NUMTRIALPCA] == 1, :]\n", 1759 | "plt.plot(mx_pca_secondis0[:, 0], mx_pca_secondis0[:, 1], '.g', alpha=.3, label='Second is 0')\n", 1760 | "mx_pca_secondis3 = mx_pca[second_iscuenum[3, :, NUMTRIALPCA] == 1, :]\n", 1761 | "plt.plot(mx_pca_secondis3[:, 0], mx_pca_secondis3[:, 1], '.b', alpha=.3, label='Second is 3')\n", 1762 | "mx_pca_secondis2 = mx_pca[second_iscuenum[2, :, NUMTRIALPCA] == 1, :]\n", 1763 | "plt.plot(mx_pca_secondis2[:, 0], mx_pca_secondis2[:, 1], '.r', alpha=.3, label='Second is 2')\n", 1764 | "mx_pca_secondis4 = mx_pca[second_iscuenum[4, :, NUMTRIALPCA] == 1, :]\n", 1765 | "plt.plot(mx_pca_secondis4[:, 0], mx_pca_secondis4[:, 1], '.y', alpha=.3, label='Second is 4')\n", 1766 | "mx_pca_secondis5 = mx_pca[second_iscuenum[5, :, NUMTRIALPCA] == 1, :]\n", 1767 | "plt.plot(mx_pca_secondis5[:, 0], mx_pca_secondis5[:, 1], '.c', alpha=.3, label='Second is 5')\n", 1768 | "plt.xlabel('PCA 0'); plt.ylabel('PCA 1')\n", 1769 | 
"plt.legend()\n", 1770 | "\n", 1771 | "\n", 1772 | "# plt.subplot(3,4,8)\n", 1773 | "# mx_pca_secondis0 = mx_pca[second_iscuenum[0, :, NUMTRIAL-1] == 1, :]\n", 1774 | "# plt.plot(mx_pca_secondis0[:, 4], mx_pca_secondis0[:, 3], '.g', alpha=.3, label='Prev Second is 0')\n", 1775 | "# mx_pca_secondis3 = mx_pca[second_iscuenum[3, :, NUMTRIAL-1] == 1, :]\n", 1776 | "# plt.plot(mx_pca_secondis3[:, 4], mx_pca_secondis3[:, 3], '.b', alpha=.3, label='Prev Second is 3')\n", 1777 | "# mx_pca_secondis2 = mx_pca[second_iscuenum[2, :, NUMTRIAL-1] == 1, :]\n", 1778 | "# plt.plot(mx_pca_secondis2[:, 4], mx_pca_secondis2[:, 3], '.r', alpha=.3, label='Prev Second is 2')\n", 1779 | "# mx_pca_secondis4 = mx_pca[second_iscuenum[4, :, NUMTRIAL-1] == 1, :]\n", 1780 | "# plt.plot(mx_pca_secondis4[:, 4], mx_pca_secondis4[:, 3], '.y', alpha=.3, label='Prev Second is 4')\n", 1781 | "# mx_pca_secondis5 = mx_pca[second_iscuenum[5, :, NUMTRIAL-1] == 1, :]\n", 1782 | "# plt.plot(mx_pca_secondis5[:, 4], mx_pca_secondis5[:, 3], '.c', alpha=.3, label='Prev Second is 5')\n", 1783 | "# plt.xlabel('PCA 4'); plt.ylabel('PCA 3')\n", 1784 | "# plt.legend()\n", 1785 | "\n", 1786 | "\n", 1787 | "\n", 1788 | "plt.subplot(3,4,9)\n", 1789 | "respPpoints = mx_pca[resps[:, NUMTRIALPCA] == 1, :]\n", 1790 | "plt.plot(respPpoints[:, 2], respPpoints[:, 3], '.g', alpha=.2, label='Resp+')\n", 1791 | "\n", 1792 | "respNpoints = mx_pca[resps[:, NUMTRIALPCA] == -1, :]\n", 1793 | "plt.plot(respNpoints[:, 2], respNpoints[:, 3], '.r', alpha=.2, label='Resp-')\n", 1794 | "plt.xlabel('PCA 2'); plt.ylabel('PCA 3')\n", 1795 | "plt.legend()\n", 1796 | "\n", 1797 | "\n", 1798 | "\n", 1799 | "plt.subplot(3,4,10)\n", 1800 | "prevcorrpoints = mx_pca[corr[:, NUMTRIALPCA] == 1, :]\n", 1801 | "plt.plot(prevcorrpoints[:, 2], prevcorrpoints[:, 3], '.g', alpha=.2, label='ThisTrial Resp Correct')\n", 1802 | "prevnotcorrpoints = mx_pca[corr[:, NUMTRIALPCA] == 0, :]\n", 1803 | "plt.plot(prevnotcorrpoints[:, 2], prevnotcorrpoints[:, 3], '.r', alpha=.2, label='ThisTrial Resp Wrong')\n", 1804 | "plt.xlabel('PCA 2'); plt.ylabel('PCA 3')\n", 1805 | "plt.legend()\n", 1806 | "\n", 1807 | "# plt.subplot(3,4,10)\n", 1808 | "# prevcorrpoints = mx_pca[corr[:, NUMTRIAL-1] == 1, :]\n", 1809 | "# plt.plot(prevcorrpoints[:, 2], prevcorrpoints[:, 3], '.g', alpha=.2, label='PrevTrial Resp Correct')\n", 1810 | "# prevnotcorrpoints = mx_pca[corr[:, NUMTRIAL-1] == 0, :]\n", 1811 | "# plt.plot(prevnotcorrpoints[:, 2], prevnotcorrpoints[:, 3], '.r', alpha=.2, label='PrevTrial Resp Wrong')\n", 1812 | "# plt.xlabel('PCA 2'); plt.ylabel('PCA 3')\n", 1813 | "# plt.legend()\n", 1814 | "\n", 1815 | "\n", 1816 | "plt.subplot(3,4,11)\n", 1817 | "mx_pca_secondis0 = mx_pca[second_iscuenum[0, :, NUMTRIALPCA] == 1, :]\n", 1818 | "plt.plot(mx_pca_secondis0[:, 2], mx_pca_secondis0[:, 3], '.g', alpha=.3, label='Second is 0')\n", 1819 | "mx_pca_secondis3 = mx_pca[second_iscuenum[3, :, NUMTRIALPCA] == 1, :]\n", 1820 | "plt.plot(mx_pca_secondis3[:, 2], mx_pca_secondis3[:, 3], '.b', alpha=.3, label='Second is 3')\n", 1821 | "mx_pca_secondis2 = mx_pca[second_iscuenum[2, :, NUMTRIALPCA] == 1, :]\n", 1822 | "plt.plot(mx_pca_secondis2[:, 2], mx_pca_secondis2[:, 3], '.r', alpha=.3, label='Second is 2')\n", 1823 | "mx_pca_secondis4 = mx_pca[second_iscuenum[4, :, NUMTRIALPCA] == 1, :]\n", 1824 | "plt.plot(mx_pca_secondis4[:, 2], mx_pca_secondis4[:, 3], '.y', alpha=.3, label='Second is 4')\n", 1825 | "mx_pca_secondis5 = mx_pca[second_iscuenum[5, :, NUMTRIALPCA] == 1, :]\n", 1826 | "plt.plot(mx_pca_secondis5[:, 
2], mx_pca_secondis5[:, 3], '.c', alpha=.3, label='Second is 5')\n", 1827 | "plt.xlabel('PCA 2'); plt.ylabel('PCA 3')\n", 1828 | "plt.legend()\n", 1829 | "\n", 1830 | "plt.subplot(3,4,12)\n", 1831 | "prevcorrpoints = mx_pca[corr[:, NUMTRIALPCA-1] == 1, :]\n", 1832 | "plt.plot(prevcorrpoints[:, 0], prevcorrpoints[:, 1], '.g', alpha=.2, label='PrevTrial Resp Correct')\n", 1833 | "prevnotcorrpoints = mx_pca[corr[:, NUMTRIALPCA-1] == 0, :]\n", 1834 | "plt.plot(prevnotcorrpoints[:, 0], prevnotcorrpoints[:, 1], '.r', alpha=.2, label='PrevTrial Resp Wrong')\n", 1835 | "plt.xlabel('PCA 0'); plt.ylabel('PCA 1')\n", 1836 | "plt.legend()\n" 1837 | ] 1838 | }, 1839 | { 1840 | "cell_type": "code", 1841 | "execution_count": null, 1842 | "metadata": { 1843 | "id": "_gHSWamRrHd1" 1844 | }, 1845 | "outputs": [], 1846 | "source": [ 1847 | "# If the PCA was done on numstep 1 (response step), the first PC / eigenv, which represents 'decision', should be strongly aligned with the input weights of the network's output head\n", 1848 | "wo = net.h2o.weight.cpu().numpy()\n", 1849 | "print(\"Correlation between the two rows of Wout:\")\n", 1850 | "print(np.corrcoef(wo[0,:], wo[1,:]))\n", 1851 | "\n", 1852 | "\n", 1853 | "wo = wo[1,:] - wo[0,:]\n", 1854 | "print(\"Correlation between Wout and PC1:\")\n", 1855 | "print(np.corrcoef(wo, pca.components_[0,:]))\n", 1856 | "\n", 1857 | "print(\"Correlation between Win for each cue\")\n", 1858 | "wi = net.i2h.weight.cpu().numpy()\n", 1859 | "print(wi.shape)\n", 1860 | "wi1 = wi[:,:params['cs']].flatten()\n", 1861 | "wi2 = wi[:,params['cs']:2*params['cs']].flatten()\n", 1862 | "print(np.corrcoef(wi1, wi2))\n", 1863 | "\n" 1864 | ] 1865 | }, 1866 | { 1867 | "cell_type": "code", 1868 | "execution_count": null, 1869 | "metadata": { 1870 | "id": "pB0JC4z08Xyc" 1871 | }, 1872 | "outputs": [], 1873 | "source": [ 1874 | "\n", 1875 | "wo = net.h2o.weight.cpu().numpy()\n", 1876 | "wo =wo[1,:] - wo[0,:]\n", 1877 | "print(wo.shape)\n", 1878 | "\n", 1879 | "woPCA = pca.transform(wo[None, :])[0,:]\n", 1880 | "print(woPCA.shape)\n", 1881 | "print(woPCA[:4])\n", 1882 | "\n" 1883 | ] 1884 | }, 1885 | { 1886 | "cell_type": "code", 1887 | "execution_count": null, 1888 | "metadata": { 1889 | "id": "ob1n-qbm71a-" 1890 | }, 1891 | "outputs": [], 1892 | "source": [ 1893 | "\n", 1894 | "\n", 1895 | "plt.figure(figsize=(6,6))\n", 1896 | "\n", 1897 | "\n", 1898 | "\n", 1899 | "plt.subplot(2,2,1)\n", 1900 | "\n", 1901 | "\n", 1902 | "respPpoints = mx_pca[resps[:, NUMTRIALPCA] == 1, :]\n", 1903 | "plt.plot(respPpoints[:, 0], respPpoints[:, 1], '+c', alpha=.3, label='Choose Stim1')\n", 1904 | "\n", 1905 | "respNpoints = mx_pca[resps[:, NUMTRIALPCA] == -1, :]\n", 1906 | "plt.plot(respNpoints[:, 0], respNpoints[:, 1], '.r', alpha=.2, label='Choose Stim2')\n", 1907 | "plt.xlabel('PC 1'); plt.ylabel('PC 2')\n", 1908 | "\n", 1909 | "\n", 1910 | "# Arrrow represents the output weights vector\n", 1911 | "# Direction of arrow (left/right) is random....but it should point towards the blue crosses (positive response) !\n", 1912 | "\n", 1913 | "arrlength = 1.3\n", 1914 | "plt.arrow(0, 0, arrlength * woPCA[0], arrlength * woPCA[1], color='k', zorder=10, width=.1, head_width=.5)#, label='$\\mathbf{W}_{out}$')\n", 1915 | "# plt.text(-3.5, -.6, '$\\mathbf{W}_{out}$', fontsize=15)\n", 1916 | "plt.text(0, -.75, '$\\mathbf{W_{out}}$', fontsize=15)\n", 1917 | "\n", 1918 | "plt.legend()\n", 1919 | "\n", 1920 | "\n", 1921 | "\n", 1922 | "\n", 1923 | "plt.subplot(2,2,2)\n", 1924 | "orderedpoints = mx_pca[ordered[:, 
NUMTRIALPCA] == 1, :]\n", 1925 | "plt.plot(orderedpoints[:, 0], orderedpoints[:, 1], '+c', alpha=.3, label='Stim1>Stim2')\n", 1926 | "\n", 1927 | "notorderedpoints = mx_pca[ordered[:, NUMTRIALPCA] == 0, :]\n", 1928 | "plt.plot(notorderedpoints[:, 0], notorderedpoints[:, 1], '.r', alpha=.2, label='Stim2>Stim1')\n", 1929 | "plt.xlabel('PC 1'); plt.ylabel('PC 2')\n", 1930 | "plt.legend()\n", 1931 | "\n", 1932 | "\n", 1933 | "plt.subplot(2,2,3)\n", 1934 | "\n", 1935 | "\n", 1936 | "currcorrpoints = mx_pca[corr[:, NUMTRIALPCA] == 1, :]\n", 1937 | "plt.plot(currcorrpoints[:, 0], currcorrpoints[:, 1], '+c', alpha=.3, label='Correct')\n", 1938 | "\n", 1939 | "currnotcorrpoints = mx_pca[corr[:, NUMTRIALPCA] == 0, :]\n", 1940 | "plt.plot(currnotcorrpoints[:, 0], currnotcorrpoints[:, 1], '.r', alpha=.2, label='Error')\n", 1941 | "plt.xlabel('PC 1'); plt.ylabel('PC 2')\n", 1942 | "plt.legend()\n", 1943 | "\n", 1944 | "\n", 1945 | "\n", 1946 | "\n", 1947 | "\n", 1948 | "plt.subplot(2,2,4)\n", 1949 | "\n", 1950 | "mx_pca_firstis0 = mx_pca[first_iscuenum[0, :, NUMTRIALPCA] == 1, :]\n", 1951 | "plt.plot(mx_pca_firstis0[:, 0], mx_pca_firstis0[:, 1], '.g', alpha=.3, label='Cue1:A')\n", 1952 | "mx_pca_firstis2 = mx_pca[first_iscuenum[2, :, NUMTRIALPCA] == 1, :]\n", 1953 | "plt.plot(mx_pca_firstis2[:, 0], mx_pca_firstis2[:, 1], '.r', alpha=.3, label='Cue1:B')\n", 1954 | "mx_pca_firstis3 = mx_pca[first_iscuenum[3, :, NUMTRIALPCA] == 1, :]\n", 1955 | "plt.plot(mx_pca_firstis3[:, 0], mx_pca_firstis3[:, 1], '.b', alpha=.3, label='Cue1:C')\n", 1956 | "mx_pca_firstis4 = mx_pca[first_iscuenum[4, :, NUMTRIALPCA] == 1, :]\n", 1957 | "plt.plot(mx_pca_firstis4[:, 0], mx_pca_firstis4[:, 1], '.y', alpha=.3, label='Cue1:D')\n", 1958 | "# mx_pca_firstis5 = mx_pca[first_iscuenum[5, :, NUMTRIALPCA] == 1, :]\n", 1959 | "# plt.plot(mx_pca_firstis5[:, 0], mx_pca_firstis5[:, 1], '.c', alpha=.3, label='Cue1=E')\n", 1960 | "plt.xlabel('PC 1'); plt.ylabel('PC 2')\n", 1961 | "plt.legend(ncol=2)\n", 1962 | "\n", 1963 | "# plt.gcf().suptitle(\"PCA of $\\mathbf{r}(t=2)$, trial 20, 2000 runs\")\n", 1964 | "\n", 1965 | "plt.tight_layout()\n", 1966 | "\n", 1967 | "plt.savefig('PCA.png', dpi=300)\n", 1968 | "\n" 1969 | ] 1970 | }, 1971 | { 1972 | "cell_type": "code", 1973 | "execution_count": null, 1974 | "metadata": { 1975 | "id": "_bvf83An1CDl" 1976 | }, 1977 | "outputs": [], 1978 | "source": [ 1979 | "# Extracting the patterns produced in hidden neurons by each individual cue when presented as 1st cue, on step 0 (i.e. without any recurrence)\n", 1980 | "\n", 1981 | "# NOTE: This is only useful across batch if we use fixed cues !\n", 1982 | "BS2 = 10\n", 1983 | "\n", 1984 | "inputs = np.zeros((BS2, params['inputsize']), dtype='float32')\n", 1985 | "\n", 1986 | "for nb in range(params['nbcues']):\n", 1987 | " # Turning the cue number for this time step into actual (signed) bitstring inputs, using the cue data generated at the beginning of the episode - or, ocasionally, oldcuedata\n", 1988 | " inputs[nb, :NBSTIMBITS] = 0\n", 1989 | " inputs[nb, :params['cs']] = cuedata[0][nb][:]\n", 1990 | "\n", 1991 | " inputs[nb, NBSTIMBITS + 0] = 1.0 # Bias neuron, probably not necessary\n", 1992 | " inputs[nb,NBSTIMBITS + 1] = 0 #numstep_ep / params['eplen'] # Time passed in this episode. Should it be the trial? 
Doesn't matter much anyway.\n", 1993 | "  inputs[nb, NBSTIMBITS + 2] = 0 # 1.0 * reward[nb] # Reward from previous time step\n", 1994 | "\n", 1995 | "  # if numstep == 0 and numtrial > 0:\n", 1996 | "  #   inputs[nb, NBSTIMBITS + ADDINPUT + numactionschosen[nb]] = 1 # Previously chosen action\n", 1997 | "\n", 1998 | "\n", 1999 | "# inputsC = torch.from_numpy(inputs, requires_grad=False).to(device)\n", 2000 | "inputsN = torch.from_numpy(inputs).detach().to(device)\n", 2001 | "\n", 2002 | "hidden = net.initialZeroState(BS2)\n", 2003 | "et = net.initialZeroET(BS2) # The Hebbian eligibility trace\n", 2004 | "pw = net.initialZeroPlasticWeights(BS2)\n", 2005 | "\n", 2006 | "y, v, DAout, cue1patterns, et, pw = net(inputsN, hidden, et, pw) # y should output raw scores, not probas\n", 2007 | "\n", 2008 | "print(cue1patterns.shape) # (BS2, 200)\n", 2009 | "cue1patterns = cue1patterns[:params['nbcues'], :].cpu().numpy()\n" 2010 | ] 2011 | }, 2012 | { 2013 | "cell_type": "code", 2014 | "execution_count": null, 2015 | "metadata": { 2016 | "id": "41HlASWBZ0YX" 2017 | }, 2018 | "outputs": [], 2019 | "source": [ 2020 | "# *Single* cues, after one step of recurrence, are projected along the \"decision\" axis (which matches the output weights) at their appropriate rank.\n", 2021 | "\n", 2022 | "# You can't decode the rank of either the 1st or 2nd cue from neuron firing in normal network operation, because then the network sees both the first and the second cue at once, and\n", 2023 | "# (because the input weights for both cues are mirror images of each other) the firing rates reflect the *difference* between the two cues,\n", 2024 | "# which is exactly what is needed for the decision.\n", 2025 | "\n", 2026 | "# NOTE: This uses the specific representation of the nb'th cue in the nb'th network so it doesn't require fixed cues, but it's very noisy of course\n", 2027 | "# because the pw of only 1 batch element are used to assess the correlation of the step-1 representation of each cue with wo\n", 2028 | "\n", 2029 | "BS2 = 10\n", 2030 | "inputs = np.zeros((BS2, params['inputsize']), dtype='float32')\n", 2031 | "for nb in range(params['nbcues']):\n", 2032 | "  inputs[nb, :NBSTIMBITS] = 0\n", 2033 | "  # inputs[nb, :params['cs']] = cuedata[0][nb][:]\n", 2034 | "  inputs[nb, :params['cs']] = cuedata[nb][nb][:]\n", 2035 | "  inputs[nb, NBSTIMBITS + 0] = 1.0 # Bias neuron, probably not necessary\n", 2036 | "  inputs[nb,NBSTIMBITS + 1] = 0 #numstep_ep / params['eplen'] # Time passed in this episode. Should it be the trial? 
Doesn't matter much anyway.\n", 2037 | " inputs[nb, NBSTIMBITS + 2] = 0 # 1.0 * reward[nb] # Reward from previous time step\n", 2038 | "inputsN0 = torch.from_numpy(inputs).detach().to(device)\n", 2039 | "inputsN1 = inputsN0.clone()\n", 2040 | "inputsN1[:, :NBSTIMBITS] = 0 # Normally, stimuli are only presented at the first time step\n", 2041 | "\n", 2042 | "hidden = net.initialZeroState(BS2)\n", 2043 | "et = net.initialZeroET(BS2) # The Hebbian eligibility trace\n", 2044 | "\n", 2045 | "# pwsav = pwsav29.clone()\n", 2046 | "pwsav = allpwsavs_thisep[19*params['triallen']][:BS2, :].copy()\n", 2047 | "pwtest = torch.from_numpy(pwsav).to(device)\n", 2048 | "\n", 2049 | "# TESTNB = 0\n", 2050 | "# pwtest[:,:,:] = pwsav[TESTNB,:,:][None,:,:]\n", 2051 | "# print(\"These should all be identical:\")\n", 2052 | "# print(pwtest[0,7,:10])\n", 2053 | "# print(pwtest[1,7,:10])\n", 2054 | "# print(pwsav[5,7,:10])\n", 2055 | "\n", 2056 | "y, v, DAout, hidden, et, pw = net(inputsN0, hidden, et, pwtest)\n", 2057 | "y, v, DAout, hiddenout, et, pw = net(inputsN1, hidden, et, pw)\n", 2058 | "\n", 2059 | "hiddenout = hiddenout.cpu().numpy()\n", 2060 | "\n", 2061 | "wo = net.h2o.weight.cpu().numpy()\n", 2062 | "wo = wo[1,:] - wo[0,:] # output weight vector\n", 2063 | "print(\"Correlation of step-1 representation of each cue with output-weights vector\\n(after training period) (FOR THE PW OF A SINGLE BATCH ELEMENT!) (very noisy due to not averaging overr batch):\")\n", 2064 | "for nb in range(params['nbcues']):\n", 2065 | " print(\"Cue\", nb, \":\", np.corrcoef(hiddenout[nb,:], wo)[0,1])#, np.sum(hiddenout[nb,:]**2) np.sum(hiddenout[nb,:] * wo) / np.sqrt(np.sum(wo ** 2)), np.sum(hiddenout[nb,:] * wo) / (np.sqrt(np.sum(wo ** 2)) * np.sqrt(sum(hiddenout[nb,:]**2))))\n", 2066 | "# print(cp[TESTNB,19:22,:].T)\n", 2067 | "\n" 2068 | ] 2069 | }, 2070 | { 2071 | "cell_type": "code", 2072 | "execution_count": null, 2073 | "metadata": { 2074 | "id": "pkLBqIyQ7uyZ" 2075 | }, 2076 | "outputs": [], 2077 | "source": [ 2078 | "# Extracting the patterns produced in hidden neurons by each individual cue when presented as 1st cue, on step 0 (i.e. without any recurrence)\n", 2079 | "\n", 2080 | "# For each batch element separately ! No need for fixed cues\n", 2081 | "\n", 2082 | "\n", 2083 | "tic =time.time()\n", 2084 | "MYBS = params['bs']\n", 2085 | "\n", 2086 | "inputs = np.zeros((params['bs'], params['inputsize']), dtype='float32')\n", 2087 | "cuedata_arr= np.array(cuedata)\n", 2088 | "# raise ValueError\n", 2089 | "\n", 2090 | "print(\"Extracting step-0 (feedforward) representations of all cues, for each batchelement (they differ because each batchelement has its own randomly generated cues)\")\n", 2091 | "\n", 2092 | "cue1patternsallbatch = []\n", 2093 | "for nb in range(params['nbcues']):\n", 2094 | " print(\"Cue\", nb)\n", 2095 | " # Turning the cue number for this time step into actual (signed) bitstring inputs, using the cue data generated at the beginning of the episode - or, ocasionally, oldcuedata\n", 2096 | " inputs[:, :NBSTIMBITS] = 0\n", 2097 | " inputs[:, :params['cs']] = cuedata_arr[:, nb, :]\n", 2098 | "\n", 2099 | " inputs[:, NBSTIMBITS + 0] = 1.0 # Bias neuron, probably not necessary\n", 2100 | " inputs[:,NBSTIMBITS + 1] = 0 #numstep_ep / params['eplen'] # Time passed in this episode. Should it be the trial? 
Doesn't matter much anyway.\n", 2101 | " inputs[:, NBSTIMBITS + 2] = 0 # 1.0 * reward[nb] # Reward from previous time step\n", 2102 | "\n", 2103 | " inputsN = torch.from_numpy(inputs).detach().to(device)\n", 2104 | "\n", 2105 | " hidden = net.initialZeroState(MYBS)\n", 2106 | " et = net.initialZeroET(MYBS) # The Hebbian eligibility trace\n", 2107 | " pw = net.initialZeroPlasticWeights(MYBS)\n", 2108 | "\n", 2109 | " y, v, DAout, hidden, et, pw = net(inputsN, hidden, et, pw) # y should output raw scores, not probas\n", 2110 | " cue1patternsallbatch.append(hidden.cpu().numpy())\n", 2111 | "\n", 2112 | "print(\"Time taken:\", time.time()-tic)\n" 2113 | ] 2114 | }, 2115 | { 2116 | "cell_type": "code", 2117 | "execution_count": null, 2118 | "metadata": { 2119 | "id": "d8uDR2YZ_4oN" 2120 | }, 2121 | "outputs": [], 2122 | "source": [ 2123 | "# *Single* cues, after one step of recurrence, are projected along the \"decision\" axis (which matches the output weights) at their appropriate rank.\n", 2124 | "\n", 2125 | "# You can't decode the rank of either 1st oe 2nd cue from neuron firing in normal network operation, because then the network sees both the first and the second cue as one and\n", 2126 | "# (becausethe input weights for both cues are mirror images of each other) reflect the *difference* between the two cues\n", 2127 | "# which is exactly what is needed for decision)\n", 2128 | "\n", 2129 | "# NOTE: This uses the specific representation of the nb'th cue in the nb'th network so it doesn't require fixed cues\n", 2130 | "MYBS = params['bs']\n", 2131 | "\n", 2132 | "inputs = np.zeros((params['bs'], params['inputsize']), dtype='float32')\n", 2133 | "cuedata_arr= np.array(cuedata)\n", 2134 | "# raise ValueError\n", 2135 | "\n", 2136 | "pwsav = allpwsavs_thisep[19*params['triallen']].copy()\n", 2137 | "pwtest = torch.from_numpy(pwsav).to(device)\n", 2138 | "\n", 2139 | "wo = net.h2o.weight.cpu().numpy()\n", 2140 | "wo =wo[1,:] - wo[0,:]\n", 2141 | "\n", 2142 | "allcorrs = []\n", 2143 | "print(\"Mean (across batch) correlation of step-1 representation of each cue with output-weights vector\\n(trial 20):\")\n", 2144 | "\n", 2145 | "for nb in range(params['nbcues']):\n", 2146 | " # Turning the cue number for this time step into actual (signed) bitstring inputs, using the cue data generated at the beginning of the episode - or, ocasionally, oldcuedata\n", 2147 | " inputs[:, :NBSTIMBITS] = 0\n", 2148 | " inputs[:, :params['cs']] = cuedata_arr[:, nb, :]\n", 2149 | "\n", 2150 | " inputs[:, NBSTIMBITS + 0] = 1.0 # Bias neuron, probably not necessary\n", 2151 | " inputs[:,NBSTIMBITS + 1] = 0 #numstep_ep / params['eplen'] # Time passed in this episode. Should it be the trial? 
Doesn't matter much anyway.\n", 2152 | " inputs[:, NBSTIMBITS + 2] = 0 # 1.0 * reward[nb] # Reward from previous time step\n", 2153 | "\n", 2154 | " inputsN0 = torch.from_numpy(inputs).detach().to(device)\n", 2155 | " inputsN1 = inputsN0.clone()\n", 2156 | " inputsN1[:, :NBSTIMBITS] = 0 # Normally, stimuli are only presented at the first time step\n", 2157 | "\n", 2158 | " hidden = net.initialZeroState(MYBS)\n", 2159 | " et = net.initialZeroET(MYBS) # The Hebbian eligibility trace\n", 2160 | " pw = net.initialZeroPlasticWeights(MYBS)\n", 2161 | "\n", 2162 | " y, v, DAout, hidden, et, pw = net(inputsN0, hidden, et, pwtest)\n", 2163 | " y, v, DAout, hiddenout, et, pw = net(inputsN1, hidden, et, pw)\n", 2164 | "\n", 2165 | " hiddenout = hiddenout.cpu().numpy()\n", 2166 | " z = np.corrcoef(hiddenout, wo) # Very wasteful, computes correlations between all the hiddenouts in the batch! But so be it.\n", 2167 | " z = z[:-1,-1] # Last item is always 1.0\n", 2168 | " print(\"Cue\", nb, \":\", np.mean(z))\n", 2169 | "\n", 2170 | " allcorrs.append(z)\n", 2171 | "\n", 2172 | "\n", 2173 | "allcorrs = np.array(allcorrs)" 2174 | ] 2175 | }, 2176 | { 2177 | "cell_type": "code", 2178 | "source": [ 2179 | "print(\"Correlation between the step-2 (learned) representation of each individual cue, and the output weight vector w_out\")\n", 2180 | "print(allcorrs.shape) # 8, 2000\n", 2181 | "\n", 2182 | "plt.figure(figsize=(3,3))\n", 2183 | "plt.xticks(np.arange(8), alphabet[:8])\n", 2184 | "plt.plot(np.array(np.mean(allcorrs, axis=1)), 'orange')\n", 2185 | "plt.errorbar(np.arange(8), np.mean(allcorrs, axis=1),yerr=np.std(allcorrs, axis=1),color='b', marker='o', linestyle='none')\n", 2186 | "plt.xlabel(\"Single Cue (X)\")\n", 2187 | "plt.ylabel(\"Corr($\\psi_{t2}(X), \\mathbf{w}_{out}$)\")\n", 2188 | "plt.tight_layout()\n", 2189 | "\n", 2190 | "plt.savefig(\"cuecorrs.png\", dpi=300)" 2191 | ], 2192 | "metadata": { 2193 | "id": "lQp8Vjv8pvqK" 2194 | }, 2195 | "execution_count": null, 2196 | "outputs": [] 2197 | }, 2198 | { 2199 | "cell_type": "code", 2200 | "source": [ 2201 | "print(\"Plotting the change in alignment between step-2 cue representation and wout, step by step, for each cue\")\n", 2202 | "\n", 2203 | "MYBS = params['bs']\n", 2204 | "\n", 2205 | "\n", 2206 | "# At which trial are we going to assess the changes in correlation?\n", 2207 | "NUMTRIALCHNGCORR = 18 # 5 or 6 or 18 or 19, because only data from trials 5,6, 18 and 19 are stored in the main code cell. Paper uses 18. 19 should have fewer data points if we select those that didn't see the central pair before that, obviously (and if we don't use HALFNOBARREDPAIRUNTILT18) - but it works.\n", 2208 | "# NUMTRIALCHNGCORR = 19 # 5 or 6 or 18 or 19, because only trials 5,6, 18 and 19 are stored in the main cell\n", 2209 | "\n", 2210 | "\n", 2211 | "\n", 2212 | "np.set_printoptions(linewidth=np.inf)\n", 2213 | "\n", 2214 | "# We select the runs where the cue pair at trial NUMTRIALCHNGCORR was 4&3 or 3&4, AND it was not shown in any of the previous trials (i.e. 
the \"spontaneous list linking\" episodes)\n", 2215 | "# (response for such trials will invariably be wrong, I think)\n", 2216 | "\n", 2217 | "selectpair = BARREDPAIR\n", 2218 | "\n", 2219 | "# ADDBARREDPAIR = [BARREDPAIR[0]-1, BARREDPAIR[0]]\n", 2220 | "ADDBARREDPAIR = [BARREDPAIR[1], BARREDPAIR[1]+1]\n", 2221 | "\n", 2222 | "\n", 2223 | "selects = []\n", 2224 | "for nb in range(MYBS):\n", 2225 | " include = True\n", 2226 | " if selectpair[0] not in cp[nb, NUMTRIALCHNGCORR,:] or selectpair[1] not in cp[nb, NUMTRIALCHNGCORR,:]:\n", 2227 | " include = False\n", 2228 | " else:\n", 2229 | " for nt in range(NUMTRIALCHNGCORR): # going to NUMTRIALCHNGCORR - 1 incl.\n", 2230 | " if cp[nb, nt, 0] in selectpair and cp[nb, nt, 1] in selectpair:\n", 2231 | " include = False\n", 2232 | " break\n", 2233 | " if include:\n", 2234 | " selects.append(nb)\n", 2235 | "print(len(selects))\n", 2236 | "print(selects)\n", 2237 | "\n", 2238 | "\n", 2239 | "# Additionally, from all these, we separately select those where the pair juust beforre OR just after the main barrred pair was also not shown before\n", 2240 | "selectsadd = []\n", 2241 | "for nb in selects:\n", 2242 | " include = True\n", 2243 | " for nt in range(NUMTRIALCHNGCORR): # going to NUMTRIALCHNGCORR - 1 incl.\n", 2244 | " if cp[nb, nt, 0] in [ADDBARREDPAIR[0],ADDBARREDPAIR[1]] and cp[nb, nt, 1] in [ADDBARREDPAIR[0],ADDBARREDPAIR[1]] : # Any order would cause pairing.\n", 2245 | " include = False\n", 2246 | " break\n", 2247 | " if include:\n", 2248 | " selectsadd.append(nb)\n", 2249 | "print(len(selectsadd))\n", 2250 | "print(selectsadd)\n", 2251 | "\n", 2252 | "\n", 2253 | "NUMBECHNGCORR = selects[0] # NUMBE = \"Number (index) of the batch element\"\n", 2254 | "\n", 2255 | "\n", 2256 | "\n", 2257 | "print(\"Cue pairs for batch element\", NUMBECHNGCORR, \":\")\n", 2258 | "print(np.hstack((cp[NUMBECHNGCORR,:,:], np.arange(cp.shape[1])[:, None])).T)\n", 2259 | "\n", 2260 | "MYBS = params['bs']\n", 2261 | "\n", 2262 | "ds= ds_thisep.copy()\n", 2263 | "rs = rs_thisep.copy()\n", 2264 | "\n", 2265 | "inputs = np.zeros((params['bs'], params['inputsize']), dtype='float32')\n", 2266 | "cuedata_arr= np.array(cuedata)\n", 2267 | "\n", 2268 | "torch.set_grad_enabled(False)\n", 2269 | "\n", 2270 | "# wo = net.h2o.weight.cpu().numpy()[0,:] # output weights\n", 2271 | "wo = net.h2o.weight.cpu().numpy()\n", 2272 | "wo = wo[1,:] - wo[0,:] # output weight vector\n", 2273 | "\n", 2274 | "oldvals = [''] * 8\n", 2275 | "\n", 2276 | "myselects = selects\n", 2277 | "# myselects = selectsadd\n", 2278 | "\n", 2279 | "corrseachstep = []\n", 2280 | "\n", 2281 | "\n", 2282 | "for numstep in range(4):\n", 2283 | " actualnumstep = NUMTRIALCHNGCORR * params['triallen'] + numstep\n", 2284 | " print(\"Actual step:\", actualnumstep)\n", 2285 | " pwsavfull = allpwsavs_thisep[actualnumstep]\n", 2286 | " pwtest = torch.from_numpy(pwsavfull).to(device)\n", 2287 | " corrsthiscue = []\n", 2288 | " for nbcue in range(8):\n", 2289 | "\n", 2290 | " # We run the network for two time steps, using the stored plastic weights in pwtest, so we can extract the step-2 representation of each single cue, as encoded\n", 2291 | " # by the successive states of the plastic weights over the 4 time steps of the trial\n", 2292 | "\n", 2293 | " hidden = net.initialZeroState(MYBS)\n", 2294 | " et = net.initialZeroET(MYBS) # The Hebbian eligibility trace\n", 2295 | "\n", 2296 | " inputs[:, :NBSTIMBITS] = 0\n", 2297 | " inputs[:, :params['cs']] = cuedata_arr[:, nbcue, :]\n", 2298 | "\n", 2299 | " inputs[:, 
NBSTIMBITS + 0] = 1.0 # Bias neuron, probably not necessary\n", 2300 | " inputs[:,NBSTIMBITS + 1] = 0 #numstep_ep / params['eplen'] # Time passed in this episode. Should it be the trial? Doesn't matter much anyway.\n", 2301 | " inputs[:, NBSTIMBITS + 2] = 0 # 1.0 * reward[nb] # Reward from previous time step\n", 2302 | "\n", 2303 | " inputsN0 = torch.from_numpy(inputs).detach().to(device)\n", 2304 | " inputsN1 = inputsN0.clone()\n", 2305 | " inputsN1[:, :NBSTIMBITS] = 0 # Normally, stimuli are only presented at the first time step\n", 2306 | "\n", 2307 | " y, v, DAout, hidden, et, pw = net(inputsN0, hidden, et, pwtest)\n", 2308 | " y, v, DAout, hiddenout, et, pw = net(inputsN1, hidden, et, pwtest)\n", 2309 | "\n", 2310 | " hiddenout = hiddenout.cpu().numpy()\n", 2311 | "\n", 2312 | " cor = np.corrcoef(hiddenout, wo)[:-1,-1] # Again, very wasteful\n", 2313 | " corrsthiscue.append(cor)\n", 2314 | "\n", 2315 | " if numstep > 0:\n", 2316 | " print(\"Corr:\", np.mean(cor[selects]), \"| Change from previous step:\", np.mean(cor[myselects] - corrseachstep[numstep-1][nbcue][myselects]))\n", 2317 | " else:\n", 2318 | " print(\"Corr:\", np.mean(cor[selects]), \"| - \")\n", 2319 | "\n", 2320 | " corrseachstep.append(corrsthiscue)\n", 2321 | "\n", 2322 | "corrseachstep = np.array(corrseachstep)\n" 2323 | ], 2324 | "metadata": { 2325 | "id": "UNujHwUdkPoD" 2326 | }, 2327 | "execution_count": null, 2328 | "outputs": [] 2329 | }, 2330 | { 2331 | "cell_type": "code", 2332 | "execution_count": null, 2333 | "metadata": { 2334 | "id": "G5nCT85NBcHC" 2335 | }, 2336 | "outputs": [], 2337 | "source": [ 2338 | "if False:\n", 2339 | " # Code uses 0 counting, but figure legends use 1-counting: \"Step 3\" is step2, \"Step 4 \" is step3.\n", 2340 | " plt.figure(figsize=(4,3))\n", 2341 | " pp(len(step3corrdiffs), step3corrdiffs[0].shape)\n", 2342 | " pp([np.median(x) for x in step3corrdiffs])\n", 2343 | "\n", 2344 | "\n", 2345 | " myselects = selects\n", 2346 | " # myselects = selectsadd\n", 2347 | "\n", 2348 | "\n", 2349 | " plt.subplot(1,2,1)\n", 2350 | " plt.title(\"Step 3\") # 1-counting in the paper\n", 2351 | " plt.axhline(0, color='k')\n", 2352 | " plt.bar(list(range(8)), [np.median(x[myselects]) for x in step2corrdiffs], yerr=np.vstack(([np.quantile(x[myselects],.75)-np.median(x[myselects]) for x in step2corrdiffs],\n", 2353 | " [np.median(x[myselects]) - np.quantile(x[myselects],.25) for x in step2corrdiffs])),\n", 2354 | " color='r',edgecolor='k', lw=1, width=1.0)\n", 2355 | " plt.ylim([-.7,.7])\n", 2356 | " plt.xticks(range(8), alphabet[:8])\n", 2357 | "\n", 2358 | " plt.subplot(1,2,2)\n", 2359 | " plt.title(\"Step 4\")\n", 2360 | " plt.axhline(0, color='k')\n", 2361 | " plt.bar(list(range(8)), [np.median(x[myselects]) for x in step3corrdiffs], yerr=np.vstack(([np.quantile(x[myselects],.75)-np.median(x[myselects]) for x in step3corrdiffs],\n", 2362 | " [np.median(x[myselects]) - np.quantile(x[myselects],.25) for x in step3corrdiffs])),\n", 2363 | " color='r',edgecolor='k', lw=1, width=1.0)\n", 2364 | " plt.ylim([-.7,.7])\n", 2365 | " plt.yticks([])\n", 2366 | " # plt.tick_params(axis='y', labelsize=8) #which='both', labelleft=False, labelright=True)\n", 2367 | " plt.xticks(range(8), alphabet[:8])\n", 2368 | "\n", 2369 | " plt.tight_layout()\n", 2370 | "\n", 2371 | "\n", 2372 | "\n", 2373 | "\n", 2374 | " # Plot the acctual corr, in addition to the changes ni corr....\n", 2375 | "\n", 2376 | " # plt.plot(list(range(8)), [np.median(x) for x in step3corrdiffs])" 2377 | ] 2378 | }, 2379 | { 2380 | "cell_type": 
"code", 2381 | "source": [ 2382 | "# Code uses 0 counting, but figure legends use 1-counting: \"Step 3\" is step2, \"Step 4 \" is step3.\n", 2383 | "plt.figure(figsize=(6,4))\n", 2384 | "\n", 2385 | "\n", 2386 | "nump=0\n", 2387 | "\n", 2388 | "\n", 2389 | "\n", 2390 | "\n", 2391 | "SHOWALLSELECTS = True\n", 2392 | "\n", 2393 | "if SHOWALLSELECTS:\n", 2394 | " myselects = selects\n", 2395 | "else:\n", 2396 | " myselects = selectsadd\n", 2397 | "\n", 2398 | "\n", 2399 | "\n", 2400 | "\n", 2401 | "# print(\">\",cp.shape, corr.shape)\n", 2402 | "# myselects = np.isin(cp[:,18,0], [3,4]) & np.isin(cp[:,18,1], [3,4]) & ~corr[:,18]\n", 2403 | "# # myselects = np.isin(cp[:,18,0], [3,4]) & np.isin(cp[:,18,1], [3,4]) & corr[:,18]\n", 2404 | "# myselects = [x for x in range(myselects.shape[0]) if myselects[x]] # Rest of the code expects a list of indices !\n", 2405 | "\n", 2406 | "\n", 2407 | "\n", 2408 | "print(len(myselects))\n", 2409 | "ss = np.zeros(corrseachstep.shape[-1])\n", 2410 | "ss[myselects] = 1\n", 2411 | "ss = ss>0 # BBoolean\n", 2412 | "pp(corrseachstep.shape)\n", 2413 | "# myselects = selectsadd\n", 2414 | "pp(corrseachstep[nump,:,:].shape)\n", 2415 | "pp(corrseachstep[nump,:,myselects].shape) # Somehow this permutes dimensions....\n", 2416 | "pp(corrseachstep[nump,:,ss].shape) # So does this\n", 2417 | "pp(corrseachstep[nump,:,myselects].T.shape) # So need to permute back\n", 2418 | "pp(np.mean(corrseachstep[nump,:,myselects].T, axis=1).shape)\n", 2419 | "\n", 2420 | "\n", 2421 | "for nump in range(4):\n", 2422 | " plt.subplot(2,4,1+nump)\n", 2423 | "\n", 2424 | " plt.xticks(np.arange(8), alphabet[:8])\n", 2425 | " plt.plot(np.array(np.mean(corrseachstep[nump,:,myselects].T, axis=1)), 'orange')\n", 2426 | " plt.errorbar(np.arange(8), np.mean(corrseachstep[nump,:,myselects].T, axis=1),yerr=np.std(corrseachstep[nump,:,myselects].T, axis=1),color='b', marker='o', linestyle='none')\n", 2427 | "\n", 2428 | " # plt.xticks(np.arange(8), alphabet[:8])\n", 2429 | " # plt.plot(np.mean(corrseachstep[nump,:,myselects].T, axis=1)) # Again, the .T is just there to cancel weird dimension-permuting by numpy when using \"\"\"smart\"\"\" indexing\n", 2430 | " # plt.plot(np.mean(corrseachstep[nump,:,myselects].T, axis=1), 'o')\n", 2431 | " # plt.xlabel(\"Single Cue (X)\")\n", 2432 | " plt.title('Step '+str(nump+1))\n", 2433 | " if nump == 0:\n", 2434 | " plt.ylabel(\"Corr($\\psi_{t2}(X), \\mathbf{w}_{out}$)\")\n", 2435 | " else:\n", 2436 | " plt.yticks([])\n", 2437 | "\n", 2438 | "\n", 2439 | "\n", 2440 | "for nump in range(4):\n", 2441 | " plt.subplot(2,4,5+nump)\n", 2442 | " if nump>1:\n", 2443 | " corrdiffs = corrseachstep[nump,:,myselects].T - corrseachstep[nump - 1,:,myselects].T\n", 2444 | " # plt.title(\"Step 3\") # 1-counting in the paper\n", 2445 | " plt.axhline(0, color='k')\n", 2446 | " plt.bar(list(range(8)), np.mean(corrdiffs, axis=1), yerr=np.std(corrdiffs, axis=1) ,\n", 2447 | " color='r',edgecolor='k', lw=1, width=1.0)\n", 2448 | " # plt.bar(list(range(8)), np.median(corrdiffs, axis=1), yerr=np.vstack( ( np.quantile(corrdiffs, .75, axis=1) -np.median(corrdiffs, axis=1) ,\n", 2449 | " # np.median(corrdiffs, axis=1) - np.quantile(corrdiffs, .25, axis=1)) ) ,\n", 2450 | " # color='r',edgecolor='k', lw=1, width=1.0)\n", 2451 | " plt.ylim([-.7,.7])\n", 2452 | " plt.xticks(range(8), alphabet[:8])\n", 2453 | " if nump ==3:\n", 2454 | " plt.yticks([])\n", 2455 | " else:\n", 2456 | " plt.axis('off')\n", 2457 | " if nump==0:\n", 2458 | " plt.text(-.7, .1, \"Change from\\nprevious step\", 
rotation=90)\n", 2459 | " plt.text(0.15,.5,' No\\nchange')\n", 2460 | "\n", 2461 | "# plt.tight_layout()\n", 2462 | "\n", 2463 | "plt.savefig('corrchanges'+str(BARREDPAIR[0])+str(BARREDPAIR[1])+('_add'+str(ADDBARREDPAIR[0])+str(ADDBARREDPAIR[1]) if not SHOWALLSELECTS else '')+'.png', dpi=300)\n", 2464 | "\n", 2465 | "\n", 2466 | "# Plot the acctual corr, in addition to the changes ni corr....\n", 2467 | "\n", 2468 | "# # plt.plot(list(range(8)), [np.median(x) for x in step3corrdiffs])" 2469 | ], 2470 | "metadata": { 2471 | "id": "luw2u7BKi5YR" 2472 | }, 2473 | "execution_count": null, 2474 | "outputs": [] 2475 | }, 2476 | { 2477 | "cell_type": "code", 2478 | "source": [ 2479 | "# The plastic weights that change most between steps 2-3 (coupling) are quite distinct from those that change most between steps 3-4 (representation changes) ?\n", 2480 | "# Actually correlation is ~0....\n", 2481 | "# But the matrices look really different! The former is mostly horizontal lines, the latter have much more vertical structure...\n", 2482 | "# However, the two groups (if they are groups) do not seem to have strongly different signs of alpha...\n", 2483 | "\n", 2484 | "\n", 2485 | " # actualnumstep = NUMTRIALCHNGCORR * params['triallen'] + numstep\n", 2486 | " # print(\"Actual step:\", actualnumstep)\n", 2487 | " # pwsavfull = allpwsavs_thisep[actualnumstep]\n", 2488 | "\n", 2489 | "actualnumstep0 = NUMTRIALCHNGCORR * params['triallen']\n", 2490 | "pw1m0 = allpwsavs_thisep[actualnumstep0+1] - allpwsavs_thisep[actualnumstep0]\n", 2491 | "pw2m1 = allpwsavs_thisep[actualnumstep0+2] - allpwsavs_thisep[actualnumstep0+1]\n", 2492 | "pw3m2 = allpwsavs_thisep[actualnumstep0+3] - allpwsavs_thisep[actualnumstep0+2]\n", 2493 | "\n", 2494 | "pw1m0 = pw1m0 ** 2\n", 2495 | "pw2m1 = pw2m1 ** 2\n", 2496 | "pw3m2 = pw3m2 ** 2\n", 2497 | "pw2m1_s = np.sum(pw2m1, axis=0)\n", 2498 | "pw3m2_s = np.sum(pw3m2, axis=0)\n", 2499 | "print(np.max(pw1m0)) # This one should be 0\n", 2500 | "print(np.max(pw2m1))\n", 2501 | "print(pw3m2.shape) # 2000 200 200\n", 2502 | "\n", 2503 | "print(np.corrcoef(pw2m1_s.flatten()>1000, pw3m2_s.flatten()>1000))\n", 2504 | "print(\"ABSOLUTE CHANGES IN PLASTIC WEIGHTS BETWEEEN TIME STEPS, summed over whole batch\")\n", 2505 | "plt.figure()\n", 2506 | "plt.subplot(1,2,1)\n", 2507 | "plt.title(\"step 2-3\")\n", 2508 | "plt.imshow(pw2m1_s)\n", 2509 | "plt.colorbar()\n", 2510 | "plt.subplot(1,2,2)\n", 2511 | "plt.title(\"step 3-4\")\n", 2512 | "plt.imshow(pw3m2_s)\n", 2513 | "plt.colorbar()\n", 2514 | "plt.figure()\n", 2515 | "plt.subplot(1,2,1)\n", 2516 | "plt.hist(pw2m1_s.flatten(),bins=100)\n", 2517 | "plt.subplot(1,2,2)\n", 2518 | "plt.hist(pw3m2_s.flatten(),bins=100)\n", 2519 | "aa = net.alpha.detach().cpu().numpy()\n", 2520 | "a1 = aa[pw2m1_s>1000]\n", 2521 | "a2 = aa[pw3m2_s>1000]\n", 2522 | "print(np.min(a1), np.max(a1), np.mean(a1))\n", 2523 | "print(np.min(a2), np.max(a2), np.mean(a2))\n" 2524 | ], 2525 | "metadata": { 2526 | "id": "ATGpQMV88m1h" 2527 | }, 2528 | "execution_count": null, 2529 | "outputs": [] 2530 | }, 2531 | { 2532 | "cell_type": "code", 2533 | "execution_count": null, 2534 | "metadata": { 2535 | "id": "bjGtDv6yDB47" 2536 | }, 2537 | "outputs": [], 2538 | "source": [ 2539 | "# This is an older analysis, requiring fixed cues for the entire batch\n", 2540 | "\n", 2541 | "\n", 2542 | "if False:\n", 2543 | "\n", 2544 | " # The pw change at step 3, *after* multiplication by alpha, correlates (negatively!) 
with the outer product of cue1pattern for the shown cues (and adjacents) with the output axis.\n", 2545 | "\n", 2546 | "    # If not multiplying by alpha, this effect disappears (or is strongly diminished and reversed in sign??), confirming the importance of taking alpha into consideration (and that the system itself does just that)\n", 2547 | "\n", 2548 | "    # This is all good and as predicted, but... the correlation is small! -.15/+.07 (tbf the change in correlation of 1-step output with wo, as per the previous cell, was also small: -.28/+.15)\n", 2549 | "\n", 2550 | "    # The correlation is actually a bit lower when using outer product of cue1patterns with pca vector 0 (the \"decision axis\") at step 1...\n", 2551 | "\n", 2552 | "    # Requires Fixed Cues\n", 2553 | "\n", 2554 | "    pwdiff = allpwsavs_thisep[NUMTRIALCHNGCORR * params['triallen'] + 3][NUMBECHNGCORR, :, :] - allpwsavs_thisep[NUMTRIALCHNGCORR * params['triallen'] + 2][NUMBECHNGCORR, :, :]\n", 2555 | "\n", 2556 | "    pwdiffalpha = pwdiff * net.alpha.cpu().numpy()\n", 2557 | "    pwdiffnotalpha = pwdiff\n", 2558 | "\n", 2559 | "    mats =[]\n", 2560 | "    wo = net.h2o.weight.cpu().numpy()[0,:] # output weights\n", 2561 | "    for nc in range(8):\n", 2562 | "        mats.append(np.matmul(wo[:, None], cue1patterns[nc,:][None, :]))\n", 2563 | "        print(\"With alpha:\", np.corrcoef(mats[nc].flatten(), pwdiffalpha.flatten())[0,1], \", without alpha:\", np.corrcoef(mats[nc].flatten(), pwdiffnotalpha.flatten())[0,1])\n", 2564 | "\n", 2565 | "    for nc in range(8):\n", 2566 | "        print(\"Projection of pwdiff over outer prod of cue1 and wo (flattened, with alpha):\", np.sum(mats[nc].flatten() * pwdiffalpha.flatten()) / np.linalg.norm(mats[nc].flatten()))\n", 2567 | "\n", 2568 | "    # The correlation is stronger when you add together the various outer products of the shown and adjacent cues with wo (with correct sign and multiplier), though still low (.2/.3)\n", 2569 | "    print(np.corrcoef((-.5 * mats[2] - mats[3] + mats[4] + .5 * mats[5]).flatten(), pwdiffalpha.flatten())[0,1])" 2570 | ] 2571 | }, 2572 | { 2573 | "cell_type": "code", 2574 | "execution_count": null, 2575 | "metadata": { 2576 | "id": "CiSYlQIPH4_5" 2577 | }, 2578 | "outputs": [], 2579 | "source": [ 2580 | "# Same using PC 0 (note that sign is irrelevant there)\n", 2581 | "if False:\n", 2582 | "    for nc in range(8):\n", 2583 | "        mat = np.matmul(pca.components_[0,:][:, None], cue1patterns[nc,:][None, :])\n", 2584 | "        print(\"With alpha:\", np.corrcoef(mat.flatten(), pwdiffalpha.flatten())[0,1], \", without alpha:\", np.corrcoef(mat.flatten(), pwdiffnotalpha.flatten())[0,1])" 2585 | ] 2586 | }, 2587 | { 2588 | "cell_type": "code", 2589 | "execution_count": null, 2590 | "metadata": { 2591 | "id": "yuYTs1cdFJ_d" 2592 | }, 2593 | "outputs": [], 2594 | "source": [ 2595 | "# Alternatively: we directly find the input vector (for each cue) and the output vector (common to all cues)\n", 2596 | "# such that their outer product * alpha, when multiplied by the same cue's FF representation, best matches the output weights\n", 2597 | "\n", 2598 | "# This is only for one single batch element !!! This was an early analysis. The one actually used for the paper performs the search for all batch elements. See below.\n", 2599 | "\n", 2600 | "# Need to add net.w, with a small multiplier to match realistic relative sizes!\n", 2601 | "\n", 2602 | "# Works ! BUT! 
The optimization problem seems to have two solutions.\n", 2603 | "\n", 2604 | "# Outcome 1 (most frequent): the v1s are strongly represented at step 1, according to whether or not the corresponding cues exhibit large change\n", 2605 | "# above, and have no corr with step-0 (FF) cue representations, and v2 is strongly represented at step 2 and has no corr with output weights.\n", 2606 | "\n", 2607 | "# Outcome 2 (not changing anything else, just re-running this very cell and the next and leaving everything else\n", 2608 | "# unchanged): the v1s have ~0 representation at step 1, but are strongly correlated with the step-0 (FF) cue representations (r>.7), and\n", 2609 | "# the v2 is moderately represented at step 1 and has corr .9 with output weights...\n", 2610 | "\n", 2611 | "\n", 2612 | "# This one uses cue1patternsallbatch, which does not require FixedCues.\n", 2613 | "# BUT!...\n", 2614 | "# It only uses the cue1 patterns of one element in batch!\n", 2615 | "\n", 2616 | "if False:\n", 2617 | "    MYNUM = selects[1]  # NUMBECHNGCORR\n", 2618 | "\n", 2619 | "    if True:\n", 2620 | "        wo = net.h2o.weight.cpu().numpy()\n", 2621 | "        wo = wo[1,:] - wo[0,:] # output weights\n", 2622 | "\n", 2623 | "    HS = params['hs']\n", 2624 | "    torch.set_grad_enabled(True)\n", 2625 | "    alph = torch.zeros(HS, HS)\n", 2626 | "    alph[:,:] = net.alpha[:,:]\n", 2627 | "    ww = torch.zeros(HS, HS)\n", 2628 | "    ww[:,:] = net.w[:,:]\n", 2629 | "\n", 2630 | "    # cue1patterns has shape 8,200\n", 2631 | "    # cue1patternsallbatch[0] has shape 2000,200\n", 2632 | "    # c1ps = torch.from_numpy(cue1patterns)\n", 2633 | "    c1ps = np.vstack([x[MYNUM,:] for x in cue1patternsallbatch])\n", 2634 | "    c1ps = torch.from_numpy(c1ps)\n", 2635 | "    wwo = torch.from_numpy(wo)\n", 2636 | "\n", 2637 | "    v1s = []\n", 2638 | "    for nc in range(8):\n", 2639 | "        v1s.append( torch.rand(HS, requires_grad=True) )\n", 2640 | "        v1s[nc].data = .01 * (v1s[nc].data - .5)\n", 2641 | "\n", 2642 | "    v2 = torch.rand(HS, requires_grad=True)\n", 2643 | "    v2.data = .01 * (v2.data - .5)\n", 2644 | "\n", 2645 | "    optimexp = torch.optim.Adam((v1s + [v2]), lr=3e-4, weight_decay=1e-3)\n", 2646 | "    # optimexp = torch.optim.Adam((v1s + [v2]), lr=3e-4, weight_decay=1e-2)\n", 2647 | "    # optimexp = torch.optim.Adam((v1s + [v2]), lr=1e-3, weight_decay=1e-2)\n", 2648 | "    for numstep in range(3000):\n", 2649 | "        optimexp.zero_grad()\n", 2650 | "        loss = 0\n", 2651 | "        for nc in range(8):\n", 2652 | "            # tgtmat = torch.flatten(torch.from_numpy(mats[nc])).detach()\n", 2653 | "            outer = torch.matmul(v2[:, None], v1s[nc][None,:])\n", 2654 | "            # outera = .03*ww.detach()+outer * alph.detach()\n", 2655 | "            outera = .3*ww.detach()+outer * alph.detach()\n", 2656 | "            prod = torch.matmul(outera, c1ps[nc,:].detach())\n", 2657 | "            cc = torch.corrcoef(torch.vstack( (prod, wwo) ))\n", 2658 | "            # print(cc)\n", 2659 | "            loss += -cc[0,1] / 8\n", 2660 | "            # if numstep % 10 == 0:\n", 2661 | "            #     print(float(loss))\n", 2662 | "        loss.backward()\n", 2663 | "        if numstep % 100 == 0:\n", 2664 | "            print(\"loss:\", float(loss))\n", 2665 | "        optimexp.step()\n", 2666 | "\n", 2667 | "    torch.set_grad_enabled(False)\n", 2668 | "    print(\"THIS IS ONLY FOR BATCH ELEMENT\", MYNUM)\n" 2669 | ] 2670 | }, 2671 | { 2672 | "cell_type": "code", 2673 | "execution_count": null, 2674 | "metadata": { 2675 | "id": "ClMNtbnVJuSz" 2676 | }, 2677 | "outputs": [], 2678 | "source": [ 2679 | "\n", 2680 | "# Probing: are the vectors found above actually generated by the system in its activations? 
(Again, this one is for a single batch element!) (Be sure to use the same batch element as above)\n", 2681 | "\n", 2682 | "\n", 2683 | "# THERE ARE TWO POSSIBLE PATTERNS, for the exact same data, based on the outcome of the grad desc process in the cell above:\n", 2684 | "# 1- v1s match the FF (step-0) representations of the cues and V2 matches the H at step 1 (2nd step), and also the output weights\n", 2685 | "# 2- v1s of cues shown in that trial (and their adjacent cues in each direction) match the H at step 1, and v2 matches H at step 2 (with arbitrary sign).\n", 2686 | "#\n", 2687 | "# Note: outcome 2 seems to result in better (more negative) loss in the previous cell's optimization. (~-.84)\n", 2688 | "#\n", 2689 | "# BUT now with better optimization outcome 1 never happens\n", 2690 | "\n", 2691 | "# Remarkable match between each cue's correlation with h(t) (as seen here), and whether the corresponding cue had a change in representation alignment above, for any given batch element !\n", 2692 | "# E.g. if one cue failed to change its alignment, its corresponding v1 will not be represented in h (will have no correlation with h(time step 1), as computed here)...\n", 2693 | "\n", 2694 | "# MYNUM IS DEFINED IN PREVIOUS CELL, DON'T CHANGE\n", 2695 | "\n", 2696 | "if False:\n", 2697 | "\n", 2698 | "\n", 2699 | "    print(\"Cue pairs for batch element\", MYNUM, \":\")\n", 2700 | "    print(np.hstack((cp[MYNUM,:,:], np.arange(cp.shape[1])[:, None])).T)\n", 2701 | "    print(\"Responses of batch element\", MYNUM, \":\")\n", 2702 | "    print(resps[MYNUM, :])\n", 2703 | "\n", 2704 | "\n", 2705 | "    print(\"Correlations between the found v1's (estimates of 'optimal' step-1 representations to produce learning)\")\n", 2706 | "    print(allrates.shape, v1s[0].shape, np.vstack(v1s).shape, np.corrcoef(np.vstack(v1s)).shape)\n", 2707 | "    np.set_printoptions(suppress=True)\n", 2708 | "    print(np.corrcoef(np.vstack(v1s)))\n", 2709 | "\n", 2710 | "    print(\"Correlation between the actual cues themselves:\")\n", 2711 | "    cc = np.array(cuedata[MYNUM])\n", 2712 | "    print(np.corrcoef(cc))\n", 2713 | "\n", 2714 | "    alph = net.alpha.detach().cpu().numpy()\n", 2715 | "\n", 2716 | "    print(\"Correlation between found v1s and the FF (step-0) representation of the corresponding cue:\")\n", 2717 | "    for nc in range(8):\n", 2718 | "        # mats.append(np.matmul(wo[:, None], cue1patterns[nc,:][None, :]))\n", 2719 | "        # print( np.corrcoef(mats[nc].flatten(), (np.matmul( v2.detach()[:, None], v1s[nc].detach()[None,:] ) * alph).flatten())[0,1] )\n", 2720 | "        # print( np.corrcoef(v1s[nc].detach(), cue1patterns[nc,:] )[0,1] )\n", 2721 | "        print( np.corrcoef(v1s[nc].detach(), cue1patternsallbatch[nc][NUMBECHNGCORR,:] )[0,1] )\n", 2722 | "\n", 2723 | "    print(\"==\")\n", 2724 | "    ZS = (NUMTRIALCHNGCORR) *params['triallen'] + 1\n", 2725 | "\n", 2726 | "    print(\"Correlation between found v1s and the H vector at timestep\",ZS,\" for BE\", MYNUM,\":\")\n", 2727 | "    for nc in range(8):\n", 2728 | "        print( np.corrcoef(v1s[nc].detach(), allrates[MYNUM,:, ZS] )[0,1] )\n", 2729 | "\n", 2730 | "    print(\"Correlation between found v1s and the H vector at timestep 0 for BE\", MYNUM,\"(should be ~0 because it's trial step 0):\")\n", 2731 | "    for nc in range(8):\n", 2732 | "        print( np.corrcoef(v1s[nc].detach(), allrates[MYNUM,:, 0] )[0,1] )\n", 2733 | "    print(\"Correlation between found v1s and the H vector at timestep 1 for BE\", MYNUM,\"(should be high+ for 1st cue of trial 0 and high- for 2nd cue of trial 0 -or vice versa):\")\n", 2734 | "    for nc in range(8):\n", 
2735 | " print( np.corrcoef(v1s[nc].detach(), allrates[MYNUM,:, 1] )[0,1] )\n", 2736 | " print(\"Correlation between found v1s and the H vector at timestep 5 for BE\", MYNUM,\"(should be high+ for 1st cue of trial 1 and high- for 2nd cue of trial 1 -or vice versa):\")\n", 2737 | " for nc in range(8):\n", 2738 | " print( np.corrcoef(v1s[nc].detach(), allrates[MYNUM,:, 5] )[0,1] )\n", 2739 | " print(\"==\")\n", 2740 | " print(\"Correlation between found v2 and the H vector at timesteps\",ZS-1,\"to\", ZS+2, \"(incl) (BE\",MYNUM,\"):\")\n", 2741 | " print( np.corrcoef(v2.detach(), allrates[MYNUM,:, ZS-1] )[0,1] )\n", 2742 | " print( np.corrcoef(v2.detach(), allrates[MYNUM,:, ZS] )[0,1] )\n", 2743 | " print( np.corrcoef(v2.detach(), allrates[MYNUM,:, ZS+1] )[0,1] )\n", 2744 | " print( np.corrcoef(v2.detach(), allrates[MYNUM,:, ZS+2] )[0,1] )\n", 2745 | " print(\"==\")\n", 2746 | " print(\"Correlation btweeen found v2 and output weights:\")\n", 2747 | " print( np.corrcoef(v2.detach(), wo )[0,1] )\n", 2748 | "\n", 2749 | "\n" 2750 | ] 2751 | }, 2752 | { 2753 | "cell_type": "code", 2754 | "source": [ 2755 | "# This is the one actually used in the paper.\n", 2756 | "\n", 2757 | "# We optimize to find a set of v1s for each element in batch, and single commmon v2 for the whole batch.\n", 2758 | "\n", 2759 | "# More precisely: we directly find the input vector v1s (for each cue in each batch element) and the output vector v2 (common to all cues)\n", 2760 | "# such that their outer product * alpha + w (the +w is important!), when multiplied by the same cue's FF representation, best matches the output weights vector wo.\n", 2761 | "\n", 2762 | "# Need to add net.w, possibly with a small multiplier to help match the small sizes of the v1s and v2\n", 2763 | "\n", 2764 | "# Works ! BUT! In some circumstances, the optimization problem seems to have two solutions.\n", 2765 | "\n", 2766 | "# Outcome 1 (most frequent): the v1s are strongly represented at step 1 (i.e. 2nd step), according to whetherr or not the corrresponding cues exhibit large change\n", 2767 | "# above,and have no corr with step-0 (FF) cue representations, and v2 is strongly represented at step 2 and no corr with output weights.\n", 2768 | "\n", 2769 | "# Outcome 2 (not changing anything else, just re-runnning this very cell and the next and leaving everything else\n", 2770 | "# unchanged): the v1s have ~0 representationiat step 1, but are strongly correlated with the step-0 (FF) cue representtions (r>.7), and\n", 2771 | "# the v2 is moderately represented at step 1 and has corr .9 with output weights... Basically ignores alpha and w\n", 2772 | "\n", 2773 | "# FINDING: With better training (higher multiplier on w, mmore time, higher lr, etc.), the proportion of outome-2 becomes very low. 
But you need to watch out!\n", 2774 | "\n", 2775 | "# Also, if the final loss is ~-8000 or lower, you're generally in outcome 1.\n", 2776 | "\n", 2777 | "\n", 2778 | "if True:\n", 2779 | "    MYBS = params['bs']\n", 2780 | "    HS = params['hs']\n", 2781 | "\n", 2782 | "    torch.set_grad_enabled(True)\n", 2783 | "    alph = torch.zeros(HS, HS).to(device)\n", 2784 | "    alph[:,:] = net.alpha[:,:]\n", 2785 | "    ww = torch.zeros(HS, HS).to(device)\n", 2786 | "    ww[:,:] = net.w[:,:]\n", 2787 | "\n", 2788 | "    # cue1patterns has shape 8,200\n", 2789 | "    # cue1patternsallbatch[0] has shape 2000,200\n", 2790 | "    # c1ps = torch.from_numpy(cue1patterns)\n", 2791 | "    # c1ps = np.vstack([x[NUMBECHNGCORR,:] for x in cue1patternsallbatch])\n", 2792 | "    c1ps = [torch.from_numpy(x).to(device) for x in cue1patternsallbatch]\n", 2793 | "\n", 2794 | "    wo = net.h2o.weight.cpu().detach().numpy()\n", 2795 | "    wo = wo[1,:] - wo[0,:] # output weights\n", 2796 | "    wwo = torch.from_numpy(wo).to(device)  # wwo is just a vector\n", 2797 | "\n", 2798 | "    v1s = []\n", 2799 | "    for nc in range(8):\n", 2800 | "        v1s.append( torch.rand(MYBS, 1, HS, requires_grad=True, device=device) )\n", 2801 | "        v1s[nc].data = .01 * (v1s[nc].data - .5)\n", 2802 | "\n", 2803 | "    # v2 = torch.rand(MYBS, HS, 1, requires_grad=True, device=device)\n", 2804 | "    v2 = torch.rand(HS, 1, requires_grad=True, device=device)  # No batch dimension! 1 common v2 for the whole batch! So no worries about signs.\n", 2805 | "    v2.data = .01 * (v2.data - .5)\n", 2806 | "\n", 2807 | "\n", 2808 | "\n", 2809 | "    # optimexp = torch.optim.Adam((v1s + [v2]), lr=3e-4, weight_decay=1e-2)\n", 2810 | "    optimexp = torch.optim.Adam((v1s + [v2]), lr=1e-3, weight_decay=1e-2)\n", 2811 | "    # optimexp = torch.optim.Adam((v1s + [v2]), lr=1e-3, weight_decay=1e-3)\n", 2812 | "    # optimexp = torch.optim.Adam((v1s + [v2]), lr=1e-3, weight_decay=1e-3)\n", 2813 | "    for numstep in range(1000):\n", 2814 | "        optimexp.zero_grad()\n", 2815 | "        loss = 0\n", 2816 | "        for nc in range(8):\n", 2817 | "            # tgtmat = torch.flatten(torch.from_numpy(mats[nc])).detach()\n", 2818 | "            # outer = torch.matmul(v2[:, None], v1s[nc][None,:])\n", 2819 | "            outer = torch.matmul(v2[None, :, :], v1s[nc])  # Outer has shape 2000, 200, 200 - as expected (batched outer product)\n", 2820 | "            outera = 1.0 * ww.detach()[None,:,:] + outer * alph.detach()[None,:,:]  # outera: 2000, 200, 200\n", 2821 | "            prod = torch.matmul(outera, c1ps[nc].detach()[:, :, None])  # prod has shape 2000, 200, 1\n", 2822 | "\n", 2823 | "            cc = torch.nn.functional.cosine_similarity(prod, wwo[None,:,None].detach())  # cc should have shape 2000,1\n", 2824 | "            # cc = torch.corrcoef(torch.hstack( (prod, wwo) ))  # hstack, not vstack\n", 2825 | "            # loss += -cc[0,1] / 8\n", 2826 | "            loss += -torch.sum(cc)\n", 2827 | "            # if numstep % 10 == 0:\n", 2828 | "            #     print(float(loss))\n", 2829 | "        loss.backward()\n", 2830 | "        if numstep % 30 == 0:\n", 2831 | "            print(\"loss:\", float(loss))\n", 2832 | "        optimexp.step()\n", 2833 | "\n", 2834 | "    torch.set_grad_enabled(False)\n" 2835 | ], 2836 | "metadata": { 2837 | "id": "nd5xGDqBbPLO" 2838 | }, 2839 | "execution_count": null, 2840 | "outputs": [] 2841 | }, 2842 | { 2843 | "cell_type": "code", 2844 | "source": [ 2845 | "v2N = v2.cpu().detach().numpy()\n", 2846 | "pp(v2N.shape)  # 200 1\n", 2847 | "pp(resps.shape)  # 2000 30\n", 2848 | "pp(corr.shape)  # 2000, 30\n", 2849 | "\n", 2850 | "# Same sign in curve above. Corr with resp here is strong-negative [[ 1. -0.97788]. Again. 
Opposite sign: positive correlation.\n", 2851 | "\n", 2852 | "ZS = (NUMTRIALCHNGCORR) *params['triallen'] + 1 # Step 1 of trial 18\n", 2853 | "\n", 2854 | "\n", 2855 | "numstep = ZS+1\n", 2856 | "tmp1 = torch.from_numpy(v2N[None, :,0]) # 1, 200.\n", 2857 | "tmp2 = torch.from_numpy(allrates[:,:, numstep])\n", 2858 | "sims = torch.nn.functional.cosine_similarity(tmp1, tmp2).numpy() # 2000. We use cosine_similarity because it can be done in a batch\n", 2859 | "print(np.vstack((sims[:10], resps[:10, NUMTRIALCHNGCORR])))\n", 2860 | "print(np.corrcoef(sims, resps[:, NUMTRIALCHNGCORR]))\n", 2861 | "print(np.corrcoef(sims, corr[:, NUMTRIALCHNGCORR]))\n", 2862 | "\n", 2863 | "pp(\"--\")\n", 2864 | "\n", 2865 | "# At early trial, v2N is still well represented in r(t+2), and correlation with actual response is also high (if anything higher)\n", 2866 | "\n", 2867 | "EARLYTRIAL = 5\n", 2868 | "\n", 2869 | "numstep = EARLYTRIAL * params['triallen'] + 2\n", 2870 | "tmp1 = torch.from_numpy(v2N[None, :,0]) # 1, 200.\n", 2871 | "tmp2 = torch.from_numpy(allrates[:,:, numstep])\n", 2872 | "sims = torch.nn.functional.cosine_similarity(tmp1, tmp2).numpy() # 2000. We use cosine_similarity because it can be done in a batch\n", 2873 | "print(np.vstack((sims[:10], resps[:10, EARLYTRIAL])))\n", 2874 | "print(np.corrcoef(sims, resps[:, EARLYTRIAL]))\n", 2875 | "print(np.corrcoef(sims, corr[:, EARLYTRIAL]))\n", 2876 | "\n", 2877 | "poscorr = np.corrcoef(sims, resps[:, EARLYTRIAL])[0,1] > 0\n", 2878 | "pp('--')\n", 2879 | "\n", 2880 | "if poscorr:\n", 2881 | " print(\"Corr b/w found v2 and response is positive, nothing to do.\")\n", 2882 | "else:\n", 2883 | " print(\"Corr b/w found v2 and response is negative, flipping v1 and v2\")\n", 2884 | " v2 = -v2\n", 2885 | " v1s = [-x for x in v1s]" 2886 | ], 2887 | "metadata": { 2888 | "id": "nIMV8i0U0HuI" 2889 | }, 2890 | "execution_count": null, 2891 | "outputs": [] 2892 | }, 2893 | { 2894 | "cell_type": "code", 2895 | "execution_count": null, 2896 | "metadata": { 2897 | "id": "pdJHiKezjh7Z" 2898 | }, 2899 | "outputs": [], 2900 | "source": [ 2901 | "print(outer.shape)\n", 2902 | "print(outera.shape)\n", 2903 | "print(prod.shape)\n", 2904 | "print(c1ps[0].shape)\n", 2905 | "print(cc.shape)\n", 2906 | "print(v1s[0].shape, v2.shape)\n", 2907 | "print(cue1patternsallbatch[0].shape)\n" 2908 | ] 2909 | }, 2910 | { 2911 | "cell_type": "code", 2912 | "execution_count": null, 2913 | "metadata": { 2914 | "id": "mnEuhyU0EPCF" 2915 | }, 2916 | "outputs": [], 2917 | "source": [ 2918 | "# First, we test our found v1s and v2 for a single batch element, looking at which v1s are represented (if any) at various\n", 2919 | "# points in the episode (we use the v1s for this batch element, of course)\n", 2920 | "\n", 2921 | "alph = net.alpha.detach().cpu().numpy()\n", 2922 | "v1sN = [x.cpu().detach().numpy() for x in v1s]\n", 2923 | "v2N = v2.cpu().detach().numpy()\n", 2924 | "\n", 2925 | "MYNUM = selects[1]\n", 2926 | "# MYNUM = NUMBECHNGCORR\n", 2927 | "\n", 2928 | "wo = net.h2o.weight.cpu().detach().numpy()\n", 2929 | "wo =wo[1,:] - wo[0,:] # output weights\n", 2930 | "\n", 2931 | "print(\"Correlation between found v1s and the FF (step-0) representation of the corresponding cue:\")\n", 2932 | "for nc in range(8):\n", 2933 | " print( np.corrcoef(v1sN[nc][MYNUM,0,:], cue1patternsallbatch[nc][MYNUM,:] )[0,1] )\n", 2934 | " # torch.nn.functional.cosine_similarity( torch.from_numpy(v1sN[nc][MYNUM,0,:][None,:]), torch.from_numpy(cue1patternsallbatch[nc][MYNUM,:][None,:])))\n", 2935 | "\n", 2936 | 
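# (Aside, a minimal sketch: the correlation loops in this and neighboring cells build a
# full corrcoef matrix just to read off one row or column, which is wasteful, as noted
# earlier. A vectorized row-vs-vector Pearson correlation could look like the helper
# below; 'rowwise_corr' is hypothetical, not defined anywhere else in this notebook.
def rowwise_corr(M, v):
    # Correlate each row of M (shape n, d) with the vector v (shape d,)
    Mc = M - M.mean(axis=1, keepdims=True)
    vc = v - v.mean()
    return (Mc @ vc) / (np.linalg.norm(Mc, axis=1) * np.linalg.norm(vc))
# e.g. rowwise_corr(hiddenout, wo) should match np.corrcoef(hiddenout, wo)[:-1, -1].)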
"\n", 2937 | "print(\"==\")\n", 2938 | "ZS = (NUMTRIALCHNGCORR) *params['triallen'] + 1\n", 2939 | "\n", 2940 | "print(\"Correlation between found v1s and the H vector at timestep\",ZS,\" (BE\",MYNUM,\"):\")\n", 2941 | "for nc in range(8):\n", 2942 | " print( np.corrcoef(v1sN[nc][MYNUM,0,:], allrates[MYNUM,:, ZS] )[0,1] )\n", 2943 | "print(\"==\")\n", 2944 | "print(\"Correlation between found v2 and the H vector at timesteps\",ZS-1,\"to\", ZS+2, \"(incl):\")\n", 2945 | "print( np.corrcoef(v2N[:,0], allrates[MYNUM,:, ZS-1] )[0,1] )\n", 2946 | "print( np.corrcoef(v2N[:,0], allrates[MYNUM,:, ZS] )[0,1] )\n", 2947 | "print( np.corrcoef(v2N[:,0], allrates[MYNUM,:, ZS+1] )[0,1] )\n", 2948 | "print( np.corrcoef(v2N[:,0], allrates[MYNUM,:, ZS+2] )[0,1] )\n", 2949 | "print(\"==\")\n", 2950 | "print(\"Correlation btweeen found v2 and output weights:\")\n", 2951 | "print( np.corrcoef(v2N[:,0], wo )[0,1] )\n", 2952 | "\n" 2953 | ] 2954 | }, 2955 | { 2956 | "cell_type": "code", 2957 | "source": [ 2958 | "# Average representation of the v1s for each cue at time step 1 of trial 18, only for the selected batch elements (i.e. those who hadeither 3-4 or\n", 2959 | "# 4-3 on thistrial, and it was the first time they saw either)\n", 2960 | "\n", 2961 | "# Importantly, you need to adapt the sign of the v1s depending on whether it's 3-4 or 4-3. The first cue represents iits v1s with positive sign, the seond cue represents them with negative signs\n", 2962 | "\n", 2963 | "#So we show the curves separately for 3-4 and 4-3\n", 2964 | "\n", 2965 | "#print( np.corrcoef(v2N[MYNUM,:,0], allrates[MYNUM,:, ZS+1] )[0,1] )\n", 2966 | "# v2N: 2000, 200, 1\n", 2967 | "# allrates[:,:,ZS+!]: 2000, 200\n", 2968 | "v1sN = [x.cpu().detach().numpy() for x in v1s]\n", 2969 | "v2N = v2.cpu().detach().numpy()\n", 2970 | "\n", 2971 | "print(v1sN[0].shape, v2N.shape, wo.shape)\n", 2972 | "print(cp.shape,\".\")\n", 2973 | "\n", 2974 | "cue1patternsallbatchN = np.array(cue1patternsallbatch)\n", 2975 | "pp(\">\", cue1patternsallbatchN.shape)\n", 2976 | "\n", 2977 | "ZS = (NUMTRIALCHNGCORR) *params['triallen'] + 1 # Step 1 of trial 18, again\n", 2978 | "\n", 2979 | "allsims=None\n", 2980 | "\n", 2981 | "\n", 2982 | "# We extract the runs where the cues of trial NUMTRIALCHNGCORR are D and E, shown in the right order\n", 2983 | "# These do NOT need to be from 'selects' !\n", 2984 | "# ss1 = np.zeros(cp.shape[0]); ss1[selects] = 1; ss1[cp[:,18,0] < cp[:, 18, 1]] = 0; ss1 = ss1>0;\n", 2985 | "# ss2 = np.zeros(cp.shape[0]); ss2[selects] = 1; ss2[cp[:,18,0] > cp[:, 18, 1]] = 0; ss2 = ss2>0;\n", 2986 | "ss1 = (cp[:,NUMTRIALCHNGCORR,0] == 3) & (cp[:,NUMTRIALCHNGCORR,1] == 4)\n", 2987 | "ss2 = (cp[:,NUMTRIALCHNGCORR,0] == 4) & (cp[:,NUMTRIALCHNGCORR,1] == 3)\n", 2988 | "\n", 2989 | "\n", 2990 | "pp(\"selected in ss1:\", np.sum(ss1), \"| selected in ss2\", np.sum(ss2))\n", 2991 | "\n", 2992 | "\n", 2993 | "allsims_adap=[]\n", 2994 | "allsims_orig=[]\n", 2995 | "for numstep in range(ZS-1, ZS+3):\n", 2996 | " print(\"Step\", numstep, \":\")\n", 2997 | " allsims_adap_thisstep = []\n", 2998 | " allsims_orig_thisstep = []\n", 2999 | " for nc in range(8):\n", 3000 | " tmp1 = torch.from_numpy(v1sN[nc][:,0,:]) # 2000, 200. The 0 is a dummy dimension\n", 3001 | " tmp1b = torch.from_numpy(cue1patternsallbatchN[nc,:,:]) # 2000, 200.\n", 3002 | " tmp2 = torch.from_numpy(allrates[:,:, numstep])\n", 3003 | " allsims_adap_thisstep.append( torch.nn.functional.cosine_similarity(tmp1, tmp2).numpy() ) # 2000. 
We use cosine_similarity because it can be done in a batch\n", 3004 | " allsims_orig_thisstep.append( torch.nn.functional.cosine_similarity(tmp1b, tmp2).numpy() ) # 2000. We use cosine_similarity because it can be done in a batch\n", 3005 | " print(\"Adapted cue\", nc, \":\", np.mean(allsims_adap_thisstep[nc][ss1]))\n", 3006 | "\n", 3007 | " allsims_adap.append(allsims_adap_thisstep)\n", 3008 | " allsims_orig.append(allsims_orig_thisstep)\n", 3009 | "\n", 3010 | "allsims_adap = np.array(allsims_adap) # 4 steps x 8 cues x 2000 batch elements\n", 3011 | "allsims_orig = np.array(allsims_orig) # 4 steps x 8 cues x 2000 batch elements\n", 3012 | "pp( allsims_adap.shape )\n", 3013 | "\n", 3014 | "\n", 3015 | "for numfig in range(2):\n", 3016 | "\n", 3017 | " quantitytoplot = [allsims_adap, allsims_orig][numfig]\n", 3018 | "\n", 3019 | " plt.figure(figsize=(6,2))\n", 3020 | "\n", 3021 | " for nump in range(4):\n", 3022 | " plt.subplot(1,4,1+nump)\n", 3023 | " # plt.axhline(0, color='gray')\n", 3024 | "\n", 3025 | " plt.xticks(np.arange(8), alphabet[:8])\n", 3026 | " plt.ylim([-.35, .35] if numfig == 0 else [-.8, .8])\n", 3027 | " plt.errorbar(np.arange(8), np.mean(quantitytoplot[nump,:,ss1].T, axis=1),yerr=np.std(quantitytoplot[nump,:,ss1].T, axis=1),\n", 3028 | " color='orange', ecolor='b', markerfacecolor='b', markeredgecolor='b', marker='o', label='DE')\n", 3029 | " plt.errorbar(np.arange(8), np.mean(quantitytoplot[nump,:,ss2].T, axis=1),yerr=np.std(quantitytoplot[nump,:,ss2].T, axis=1),\n", 3030 | " color='m', linestyle='--', ecolor='c', markerfacecolor='c', markeredgecolor='c', marker='o', label='ED')\n", 3031 | "\n", 3032 | "\n", 3033 | " plt.title('Step '+str(nump+1))\n", 3034 | " if nump == 3:\n", 3035 | " plt.legend(loc=(1.1, .5))\n", 3036 | " if nump == 0:\n", 3037 | " plt.ylabel(r\"Corr($\\tilde{\\psi}_{t1}(X), \\mathbf{r}(t)$)\" if numfig == 0 else r\"Corr($\\psi_{t1}(X), \\mathbf{r}(t)$)\")\n", 3038 | " else:\n", 3039 | " plt.yticks([])\n", 3040 | "\n", 3041 | " plt.tight_layout()\n", 3042 | " plt.savefig((\"adapted\" if numfig == 0 else \"original\")+\"FFcuereps_t18_DEonly.png\", dpi=300)\n", 3043 | "\n" 3044 | ], 3045 | "metadata": { 3046 | "id": "aLbzRMPQcGcz" 3047 | }, 3048 | "execution_count": null, 3049 | "outputs": [] 3050 | }, 3051 | { 3052 | "cell_type": "code", 3053 | "source": [ 3054 | "# V2, the optimization-found adapted 'output' vector for producing adequate learning, is represented at time step 3, with\n", 3055 | "# sign almost (but not quite) identical to the network's response for this trial\n", 3056 | "\n", 3057 | "z = np.corrcoef(allrates[:,:, ZS+1], v2N[:,0] )[-1,:-1]\n", 3058 | "print(z.shape, v2N.shape) # 2000 / 200,1\n", 3059 | "print(z[:10])\n", 3060 | "print(resps[:10, 18])\n", 3061 | "print(\"Correlation between response and representation of V2 at step 2 (3rd step) of trial 18, across batch:\")\n", 3062 | "print(np.corrcoef(z, resps[:,18])[0,1])\n", 3063 | "\n", 3064 | "# ss1 = (cp[:,18,0] == 3) & (cp[:,18,1] == 4)\n", 3065 | "# ss1 = cp[:, 18, 0] > cp[:, 18, 1]\n", 3066 | "# ss2 = cp[:, 18, 0] < cp[:, 18, 1]\n", 3067 | "ss1 = resps[:, NUMTRIALCHNGCORR] >0\n", 3068 | "ss2 = resps[:, NUMTRIALCHNGCORR] < 1\n", 3069 | "\n", 3070 | "\n", 3071 | "ZS = (NUMTRIALCHNGCORR) *params['triallen'] + 1 # Step 1 of trial 18, again\n", 3072 | "\n", 3073 | "meancorrs1 = []\n", 3074 | "meancorrs2 = []\n", 3075 | "stdcorrs1 = []\n", 3076 | "stdcorrs2 = []\n", 3077 | "for numstep in range(4):\n", 3078 | " meancorrs1.append( np.mean(np.corrcoef(allrates[ss1,:, ZS-1+numstep], 
v2N[:,0] )[-1,:-1] ) )\n", 3079 | "    stdcorrs1.append( np.std(np.corrcoef(allrates[ss1,:, ZS-1+numstep], v2N[:,0] )[-1,:-1] ) )\n", 3080 | "    meancorrs2.append( np.mean(np.corrcoef(allrates[ss2,:, ZS-1+numstep], v2N[:,0] )[-1,:-1] ) )\n", 3081 | "    stdcorrs2.append( np.std(np.corrcoef(allrates[ss2,:, ZS-1+numstep], v2N[:,0] )[-1,:-1] ) )\n", 3082 | "print(meancorrs1)\n", 3083 | "print(meancorrs2)\n", 3084 | "\n", 3085 | "plt.figure(figsize=(3,3))\n", 3086 | "plt.errorbar(np.arange(4), meancorrs1, yerr=stdcorrs1,\n", 3087 | "             color='orange', ecolor='b', markerfacecolor='b', markeredgecolor='b', marker='o', label='Resp+')\n", 3088 | "plt.errorbar(np.arange(4), meancorrs2, yerr=stdcorrs2,\n", 3089 | "             color='m', linestyle='--', ecolor='c', markerfacecolor='c', markeredgecolor='c', marker='o', label='Resp-')\n", 3090 | "plt.xticks(range(4), ['Step '+str(x) for x in range(4)])\n", 3091 | "plt.title(r\"Corr($\\mathbf{\\tilde{w}}_{out}, \\mathbf{r}(t)$), trial 20, 2000 runs\")\n", 3092 | "plt.legend()\n", 3093 | "plt.savefig('adaptedwo.png', dpi=300)" 3094 | ], 3095 | "metadata": { 3096 | "id": "Ud7w111sC35R" 3097 | }, 3098 | "execution_count": null, 3099 | "outputs": [] 3100 | }, 3101 | { 3102 | "cell_type": "code", 3103 | "source": [ 3104 | "# Average representation of the v1s for each cue at time step 1 of trial 0, only for the selected batch elements (i.e. those who had either 3-4 or\n", 3105 | "# 4-3 on this trial, and it was the first time they saw either)\n", 3106 | "\n", 3107 | "# This is for the first trial, before any learning for this episode.\n", 3108 | "\n", 3109 | "# Importantly, you need to switch the sign of the v1s depending on whether it's 3-4 or 4-3. The first cue represents its v1s with positive sign, the second cue represents them with negative signs\n", 3110 | "\n", 3111 | "# Might be more appropriate to show separately for 3-4 and 4-3?\n", 3112 | "\n", 3113 | "#print( np.corrcoef(v2N[MYNUM,:,0], allrates[MYNUM,:, ZS+1] )[0,1] )\n", 3114 | "# v2N: 2000, 200, 1\n", 3115 | "# allrates[:,:,ZS+1]: 2000, 200\n", 3116 | "v1sN = [x.cpu().detach().numpy() for x in v1s]\n", 3117 | "v2N = v2.cpu().detach().numpy()\n", 3118 | "\n", 3119 | "print(v1sN[0].shape, v2N.shape, wo.shape)\n", 3120 | "print(cp.shape,\".\")\n", 3121 | "\n", 3122 | "cue1patternsallbatchN = np.array(cue1patternsallbatch)\n", 3123 | "pp(\">\", cue1patternsallbatchN.shape)\n", 3124 | "\n", 3125 | "\n", 3126 | "\n", 3127 | "# Trial 1, cues 3/4 or 4/3\n", 3128 | "ss1 = (cp[:,0,0] == 3) & (cp[:,0,1] == 4)\n", 3129 | "ss2 = (cp[:,0,0] == 4) & (cp[:,0,1] == 3)\n", 3130 | "\n", 3131 | "\n", 3132 | "pp(\"selected in ss1:\", np.sum(ss1), \"| selected in ss2\", np.sum(ss2))\n", 3133 | "\n", 3134 | "\n", 3135 | "allsimsT1=[]\n", 3136 | "\n", 3137 | "for numstep in range(0, 4):\n", 3138 | "    print(\"Step\", numstep, \":\")\n", 3139 | "    allsimsT1_thisstep = []\n", 3140 | "    for nc in range(8):\n", 3141 | "        tmp1 = torch.from_numpy(v1sN[nc][:,0,:])  # 2000, 200. The 0 is a dummy dimension\n", 3142 | "        tmp2 = torch.from_numpy(allrates[:,:, numstep])\n", 3143 | "        allsimsT1_thisstep.append( torch.nn.functional.cosine_similarity(tmp1, tmp2).numpy() )  # 2000. 
We use cosine_similarity because it can be done in a batch\n", 3144 | " allsimsT1.append(allsimsT1_thisstep)\n", 3145 | "\n", 3146 | "allsimsT1 = np.array(allsimsT1) # 4 steps x 8 cues x 2000 batch elements\n", 3147 | "pp( allsimsT1.shape )\n", 3148 | "\n", 3149 | "\n", 3150 | "\n", 3151 | "\n", 3152 | "plt.figure(figsize=(6,3))\n", 3153 | "\n", 3154 | "for nump in range(4):\n", 3155 | " plt.subplot(1,4,1+nump)\n", 3156 | " # plt.axhline(0, color='gray')\n", 3157 | "\n", 3158 | " plt.xticks(np.arange(8), alphabet[:8])\n", 3159 | " plt.ylim( [-.8, .8])\n", 3160 | " plt.errorbar(np.arange(8), np.mean(allsimsT1[nump,:,ss1].T, axis=1),yerr=np.std(allsimsT1[nump,:,ss1].T, axis=1),\n", 3161 | " color='orange', ecolor='b', markerfacecolor='b', markeredgecolor='b', marker='o', label='DE')\n", 3162 | " plt.errorbar(np.arange(8), np.mean(allsimsT1[nump,:,ss2].T, axis=1),yerr=np.std(allsimsT1[nump,:,ss2].T, axis=1),\n", 3163 | " color='m', linestyle='--', ecolor='c', markerfacecolor='c', markeredgecolor='c', marker='o', label='ED')\n", 3164 | "\n", 3165 | "\n", 3166 | " plt.title('Step '+str(nump+1))\n", 3167 | " if nump == 3:\n", 3168 | " plt.legend(loc=(1.1, .5))\n", 3169 | " if nump == 0:\n", 3170 | " plt.ylabel(r\"Corr($\\tilde{\\psi}_{t1}(X), \\mathbf{r}(t)$)\" )\n", 3171 | " else:\n", 3172 | " plt.yticks([])\n", 3173 | "plt.suptitle(\"Trial 1\")\n", 3174 | "plt.tight_layout()\n", 3175 | "plt.savefig(\"adaptedFFcuereps_t1_DEonly.png\", dpi=300)\n", 3176 | "\n" 3177 | ], 3178 | "metadata": { 3179 | "id": "gvA2ZVviYIvw" 3180 | }, 3181 | "execution_count": null, 3182 | "outputs": [] 3183 | }, 3184 | { 3185 | "cell_type": "code", 3186 | "source": [ 3187 | "\n", 3188 | "numstep = ZS\n", 3189 | "print(\"Step\", numstep, \":\")\n", 3190 | "meansims = []\n", 3191 | "zepairs = [[0,1], [1,0], [1,2], [2,1], [2,3], [3,2], [3,4], [4,3], [4,5], [5,4], [5,6], [6,5], [6,7], [7,6] ]\n", 3192 | "for pair in zepairs:\n", 3193 | " meansims_thispair = []\n", 3194 | " ss = (cp[:,NUMTRIALCHNGCORR,0] == pair[0] ) & (cp[:,NUMTRIALCHNGCORR,1] == pair[1])\n", 3195 | " for nc in range(8):\n", 3196 | " tmp1 = torch.from_numpy(v1sN[nc][:,0,:]) # 2000, 200. The 0 is a dummy dimension\n", 3197 | " tmp2 = torch.from_numpy(allrates[:,:, numstep])\n", 3198 | " sims = torch.nn.functional.cosine_similarity(tmp1, tmp2).numpy() # 2000. 
We use cosine_similarity because it can be done in a batch\n", 3199 | "        # allsims_thispair.append(sims)\n", 3200 | "        meansims_thispair.append(np.mean(sims[ss]))\n", 3201 | "    meansims.append(meansims_thispair)\n", 3202 | "\n", 3203 | "meansims = np.array(meansims)\n", 3204 | "pp(meansims.shape)\n", 3205 | "\n", 3206 | "plt.figure(figsize=(6,3))\n", 3207 | "plt.matshow(meansims.T, fignum=1)\n", 3208 | "plt.colorbar(fraction=0.026, pad=0.04)\n", 3209 | "plt.xticks(range(len(zepairs)), [alphabet[x[0]]+alphabet[x[1]] for x in zepairs])\n", 3210 | "plt.yticks(range(8), [r\"$\\tilde{\\psi}_{t1}($\"+str(alphabet[x])+r\"$)$\" for x in range(8)])\n", 3211 | "plt.tick_params(axis='x', labelbottom=True, labeltop=False, top=False)\n", 3212 | "plt.xlabel('Pair shown in trial 20')\n", 3213 | "plt.ylabel('Adapted FF cue repres.')\n", 3214 | "plt.title(r\"Corr($\\tilde{\\psi}_{t1}(X), \\mathbf{r}(t=2)$) (Trial 20, 2000 runs)\")\n", 3215 | "\n", 3216 | "plt.savefig(\"adaptedFFcuereps_t18_allpairs.png\", dpi=300)\n", 3217 | "\n" 3218 | ], 3219 | "metadata": { 3220 | "id": "EsbQ5XDnuU18" 3221 | }, 3222 | "execution_count": null, 3223 | "outputs": [] 3224 | }, 3225 | { 3226 | "cell_type": "code", 3227 | "source": [ 3228 | "\n", 3229 | "# NOTE: This picture looks better with HALFNOBARREDPAIRUNTILT18 = True (more data points)\n", 3230 | "\n", 3231 | "\n", 3232 | "# Correlation between how much cue barredpair[0]-1 was represented in the firing rates of time step ZS (step 1 of\n", 3233 | "# trial 18), and how much change there was in the step-1 representation of this same cue at time step 3\n", 3234 | "\n", 3235 | "# This graph is only useful if selects is large\n", 3236 | "\n", 3237 | "\n", 3238 | "\n", 3239 | "pp(allsims_adap.shape)  # 4 steps, 8 cues, 2000 batchsize\n", 3240 | "pp(corrseachstep.shape)  # 4, 8, 2000: at each step, corrs of the step-1 (2nd step, after one pass through learned plastic weights) representation of each cue with w_out (trial 18)\n", 3241 | "pp(cp.shape)  # 2000 , 30, 2\n", 3242 | "pp(BARREDPAIR, ADDBARREDPAIR)\n", 3243 | "pp(ds.shape)  # 2000, 120\n", 3244 | "dssign = np.sign(ds[:, params['triallen'] * 18 + 3])\n", 3245 | "print(np.mean(dssign<0))\n", 3246 | "pp(resps.shape)  # 2000, 30\n", 3247 | "pp(corr.shape)\n", 3248 | "pp(np.unique(resps))\n", 3249 | "\n", 3250 | "ss = np.zeros(allsims_adap.shape[-1])\n", 3251 | "# ss[selects] = 1; ss =ss>0  # Booleanize\n", 3252 | "# ss = (cp[:, 18, 0] == 3) & (cp[:, 18, 1] == 4)\n", 3253 | "ss = np.isin(cp[:, 18, 0], BARREDPAIR) & np.isin(cp[:, 18, 1], BARREDPAIR)\n", 3254 | "\n", 3255 | "\n", 3256 | "# # We look at the one in the additional barred pair that is not part of the barred pair\n", 3257 | "# if max(ADDBARREDPAIR) > max(BARREDPAIR):\n", 3258 | "#     zecue = max(ADDBARREDPAIR)\n", 3259 | "# else:\n", 3260 | "#     zecue = min(ADDBARREDPAIR)\n", 3261 | "\n", 3262 | "# Reminder:\n", 3263 | "#         tmp1 = torch.from_numpy(v1sN[nc][:,0,:])  # 2000, 200. The 0 is a dummy dimension\n", 3264 | "#         tmp2 = torch.from_numpy(allrates[:,:, numstep])\n", 3265 | "#         allsims_adap_thisstep.append( torch.nn.functional.cosine_similarity(tmp1, tmp2).numpy() )  # 2000. 
We use cosine_similarity because it can be done in a batch\n", 3266 | "\n", 3267 | "\n", 3268 | "zecue = 2 # C\n", 3269 | "plt.figure(figsize=(6,3))\n", 3270 | "plt.subplot(1,2,1)\n", 3271 | "plt.plot((allsims_adap[1, zecue,ss]) , corrseachstep[3, zecue,ss] - corrseachstep[2, zecue,ss] ,'.')\n", 3272 | "plt.xlabel(r\"Corr($\\mathbf{r}(t=2), \\mathbf{\\tilde{\\psi}}_{t1}(C)$)\")\n", 3273 | "plt.ylabel(r\"$\\Delta\\mathbf{\\psi}_{t2}(C)$ (Trial 20, step 4)\")\n", 3274 | "\n", 3275 | "\n", 3276 | "ss1 = ss & (corr[:,18] >0)# (np.abs(ds[:, params['triallen'] * 18 + 3])< .3)\n", 3277 | "ss2 = ss & (corr[:,18] <1)# (np.abs(ds[:, params['triallen'] * 18 + 3])< .3)\n", 3278 | "\n", 3279 | "pp(\"ss1:\", np.sum(ss1))\n", 3280 | "pp(\"ss2:\", np.sum(ss2))\n", 3281 | "\n", 3282 | "plt.subplot(1,2,2)\n", 3283 | "# plt.plot((allsims_adap[1, zecue,ss1]) , (corrseachstep[3, zecue,ss1] - corrseachstep[2, zecue,ss1]) * resps[ss1, 18] ,'c+', label='Correct')\n", 3284 | "# plt.plot((allsims_adap[1, zecue,ss2]) , (corrseachstep[3, zecue,ss2] - corrseachstep[2, zecue,ss2]) * resps[ss2, 18] ,'r.',label='Incorrect')\n", 3285 | "plt.plot((allsims_adap[1, zecue,ss1]) * resps[ss1, 18], (corrseachstep[3, zecue,ss1] - corrseachstep[2, zecue,ss1]) ,'c+', label='Correct')\n", 3286 | "plt.plot((allsims_adap[1, zecue,ss2]) * resps[ss2, 18] , (corrseachstep[3, zecue,ss2] - corrseachstep[2, zecue,ss2]),'r.',label='Incorrect')\n", 3287 | "plt.xlabel(r\"Corr($\\mathbf{r}(t=2), \\mathbf{\\tilde{\\psi}}_{t1}(C)$) * Resp\")\n", 3288 | "\n", 3289 | "plt.legend()\n", 3290 | "\n", 3291 | "\n", 3292 | "plt.savefig(\"psitildetodeltapsi.png\", dpi=300)\n" 3293 | ], 3294 | "metadata": { 3295 | "id": "lPvNqPU7sD_y" 3296 | }, 3297 | "execution_count": null, 3298 | "outputs": [] 3299 | }, 3300 | { 3301 | "cell_type": "code", 3302 | "source": [ 3303 | "# raise ValueError" 3304 | ], 3305 | "metadata": { 3306 | "id": "3CKRHVMqotlJ" 3307 | }, 3308 | "execution_count": null, 3309 | "outputs": [] 3310 | }, 3311 | { 3312 | "cell_type": "code", 3313 | "execution_count": null, 3314 | "metadata": { 3315 | "id": "xBCzMEcwhQ0p" 3316 | }, 3317 | "outputs": [], 3318 | "source": [ 3319 | "print(len(ds))\n", 3320 | "print(np.shape(np.hstack(ds)))\n", 3321 | "print(len(rs))\n", 3322 | "print(np.shape(np.hstack(rs)))\n", 3323 | "print(rs[0].shape)\n", 3324 | "print(ds[0].shape)" 3325 | ] 3326 | }, 3327 | { 3328 | "cell_type": "code", 3329 | "execution_count": null, 3330 | "metadata": { 3331 | "id": "sRtmdv94t0K0" 3332 | }, 3333 | "outputs": [], 3334 | "source": [ 3335 | "print(cp[NUMBECHNGCORR,5,:])\n", 3336 | "print(cp[NUMBECHNGCORR,6,:])\n", 3337 | "print(cp[NUMBECHNGCORR,:7,:].T)" 3338 | ] 3339 | }, 3340 | { 3341 | "cell_type": "code", 3342 | "execution_count": null, 3343 | "metadata": { 3344 | "id": "M9WE4vRLc3TR" 3345 | }, 3346 | "outputs": [], 3347 | "source": [ 3348 | "print(cp.shape, a.shape) # (5000, 35, 2) (5000, 200, 70) (70 = nbtrials * nbstepspertrial)\n", 3349 | "# am = allrates - np.mean(allrates, axis=0)[None,:,:]\n", 3350 | "am = allrates - 1e10\n", 3351 | "NUMOTHERTRIAL = 29\n", 3352 | "mask1 = np.argwhere((cp[:, 0, 0] == 2) & (cp[:, 0, 1]==3) )\n", 3353 | "mask2 = np.argwhere ((cp[:, 0, 0] == 3) & (cp[:, 0, 1]==2) )\n", 3354 | "mask3 = np.argwhere ((cp[:, 0, 0] == 2) & (cp[:, 0, 1]==1) )\n", 3355 | "mask4 = np.argwhere ((cp[:, 0, 0] == 6) & (cp[:, 0, 1]==7) )\n", 3356 | "mask5 = np.argwhere ((cp[:, 0, 0] == 0) & (cp[:, 0, 1]==1) )\n", 3357 | "mask6 = np.argwhere ((cp[:, 0, 0] == 1) & (cp[:, 0, 1]==2) )\n", 3358 | "mask7 = np.argwhere 
((cp[:, NUMOTHERTRIAL, 0] == 2) & (cp[:, NUMOTHERTRIAL, 1]==3) )\n", 3359 | "print(mask1[0], \"-\", mask2[0])\n", 3360 | "print(am[mask1[0], :, 0].shape)\n", 3361 | "print(np.corrcoef(am[mask1[0], :, 0], am[mask1[3], :, 0]))\n", 3362 | "print(np.corrcoef(am[mask1[0], :, 0], am[mask2[3], :, 0]))\n", 3363 | "print(np.corrcoef(am[mask1[0], :, 0], am[mask3[3], :, 0]))\n", 3364 | "print(np.corrcoef(am[mask1[0], :, 0], am[mask4[3], :, 0]))\n", 3365 | "print(np.corrcoef(am[mask1[0], :, 0], am[mask5[3], :, 0]))\n", 3366 | "print(np.corrcoef(am[mask1[0], :, 0], am[mask6[3], :, 0]))\n", 3367 | "print(np.corrcoef(am[mask1[0], :, 0], am[mask7[1], :, NUMOTHERTRIAL * 2]))\n", 3368 | "plt.figure(figsize=(10,5))\n", 3369 | "plt.plot(am[mask1[0], :, 0].T)\n", 3370 | "plt.plot(am[mask7[1], :, NUMOTHERTRIAL*2].T)\n" 3371 | ] 3372 | }, 3373 | { 3374 | "cell_type": "code", 3375 | "execution_count": null, 3376 | "metadata": { 3377 | "id": "esING6_8SFyy" 3378 | }, 3379 | "outputs": [], 3380 | "source": [ 3381 | "# print(nb, nt)\n", 3382 | "# print(p[:])\n", 3383 | "# print(cp[nb, nt,:])\n", 3384 | "# print(r[nb,nt])\n", 3385 | "# print(c[nb,nt])" 3386 | ] 3387 | }, 3388 | { 3389 | "cell_type": "code", 3390 | "execution_count": null, 3391 | "metadata": { 3392 | "id": "qHFcLGKh2cXY" 3393 | }, 3394 | "outputs": [], 3395 | "source": [ 3396 | "# len(all_grad_norms)\n", 3397 | "# print(all_grad_norms[-2])\n", 3398 | "# gns = np.array([float(x) for x in all_grad_norms])\n", 3399 | "# print(np.mean(gns>1.9), np.mean(gns>1.49), np.mean(gns>.99))" 3400 | ] 3401 | } 3402 | ], 3403 | "metadata": { 3404 | "accelerator": "GPU", 3405 | "colab": { 3406 | "provenance": [] 3407 | }, 3408 | "kernelspec": { 3409 | "display_name": "Python 3", 3410 | "name": "python3" 3411 | }, 3412 | "language_info": { 3413 | "name": "python" 3414 | } 3415 | }, 3416 | "nbformat": 4, 3417 | "nbformat_minor": 0 3418 | } -------------------------------------------------------------------------------- /addons/net_FrozenInputWeights.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasMiconi/TransitiveInference/fbdebf06f6d335bcd4473a4fa24926208a98910f/addons/net_FrozenInputWeights.dat -------------------------------------------------------------------------------- /addons/toymodel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasMiconi/TransitiveInference/fbdebf06f6d335bcd4473a4fa24926208a98910f/addons/toymodel.png -------------------------------------------------------------------------------- /addons/toymodel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import scipy 4 | 5 | 6 | ALPHA = 1.0 7 | NBEPISODES = 1000 8 | ONLYUPDATEONERROR = False # True # False 9 | TANH = False # True 10 | 11 | plt.figure(figsize=(4,3)) 12 | 13 | corrzperplot = [] 14 | 15 | for numplot, REINSTATE in enumerate([False, True]): 16 | valsperepisode = [] 17 | for numepisode in range(NBEPISODES): 18 | 19 | vals = 2.0 * np.random.rand(8) - 1.0 20 | for numtrial in range(20): 21 | pos1 = np.random.randint(7) 22 | pos2 = pos1 + 1 23 | response = vals[pos2] - vals[pos1] 24 | correct = True if response < 0 else False # vals1 should be > vals2 25 | incr = ALPHA * (1 if response > 0 else -1) * (-1 if correct else 1) 26 | if correct and ONLYUPDATEONERROR: 27 | incr = 0 28 | 29 | vals[pos1] += incr 30 | vals[pos2] -= incr 31 | if REINSTATE: 
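        # Clarifying note: under REINSTATE, half of the pair's update is also applied,
        # with matching signs, to the items flanking the shown pair (pos1-1 and pos2+1).
        # This "reinstatement" of adjacent items is what the plot below compares
        # against the plain, pair-only update.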
32 | if pos1>0: 33 | vals[pos1-1] += .5 * incr 34 | if pos2<7: 35 | vals[pos2+1] -= .5 * incr 36 | valsperepisode.append(vals) 37 | valsperepisode = np.array(valsperepisode) 38 | if TANH: 39 | valsperepisode = np.tanh(valsperepisode) 40 | mycolor = ['b','r'][numplot] 41 | 42 | corrz = np.corrcoef(valsperepisode, np.arange(8))[NBEPISODES][:NBEPISODES] # Correlation of each final set of vals (across all episodes) with arange(8); note that we must exclude tha last val which is just corr of arange(8) with itself 43 | corrzperplot.append(corrz) 44 | print("Mean/std of corrz b/w vals and arange(8) (reinstate="+str(REINSTATE)+"):", np.mean(corrz), np.std(corrz)) 45 | meanz = np.mean(valsperepisode, axis=0) 46 | stdz = np.std(valsperepisode, axis=0) 47 | 48 | plt.fill_between(np.arange(8), meanz-stdz, meanz+stdz, color=mycolor, alpha=.3) 49 | plt.plot(meanz, color=mycolor, label=('REINSTATE' if REINSTATE else 'NO-REINSTATE')) 50 | 51 | 52 | plt.xticks(np.arange(8), ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']) 53 | plt.legend() 54 | 55 | 56 | print("Mann-Whitney U-test (2-sided) on the groups of correlations from both plots (without/with reinstate):", 57 | scipy.stats.mannwhitneyu(corrzperplot[0], corrzperplot[1])) 58 | # scipy.stats.mannwhitneyu(corrzperplot[0], corrzperplot[0])) 59 | 60 | plt.show() 61 | plt.savefig('toymodel.png', dpi=300) 62 | 63 | 64 | -------------------------------------------------------------------------------- /net_active.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasMiconi/TransitiveInference/fbdebf06f6d335bcd4473a4fa24926208a98910f/net_active.dat -------------------------------------------------------------------------------- /net_passive.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThomasMiconi/TransitiveInference/fbdebf06f6d335bcd4473a4fa24926208a98910f/net_passive.dat -------------------------------------------------------------------------------- /simple.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "mORufQW4nEYK" 7 | }, 8 | "source": [ 9 | "## HOW TO USE THIS NOTEBOOK\n", 10 | "\n", 11 | "This is the simpler vesion of the code, without any experimental or evaluation code. It just meta-trains a network (over 30000 iterations) and stores the optimized network in `net.dat`. You can then use this file to run the EVAL mode of the main code (i.e. run `main.ipynb` with EVAL=True) and produce figures.\n", 12 | "\n", 13 | "If you want to understand how the system works, it is highly recommended to look at this code rather than the main code.\n", 14 | "\n", 15 | "This system uses the exact same training process as the main code, except for the fact that plastic weights are reset at every episode and no data from previous episodes is used (no attempt at continual meta-learning, unlike the main code where the network keeps memory of up to 3 sequences). However, the resulting network work just as well on all experiments from the main code, including list-linking." 
16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "id": "0Nboz_4ynCaZ" 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "\n", 27 | "# What GPU are we using?\n", 28 | "#!nvidia-smi" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "id": "x277nTktumok" 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "# Based on the code for the Stimulus-response task as described in Miconi et al. ICLR 2019.\n", 40 | "\n", 41 | "import argparse\n", 42 | "import pdb\n", 43 | "import torch\n", 44 | "import torch.nn as nn\n", 45 | "import numpy as np\n", 46 | "from numpy import random\n", 47 | "import torch.nn.functional as F\n", 48 | "from torch import optim\n", 49 | "import random\n", 50 | "import sys\n", 51 | "import pickle\n", 52 | "import time\n", 53 | "import os\n", 54 | "import platform\n", 55 | "\n", 56 | "import numpy as np\n", 57 | "import glob\n", 58 | "\n", 59 | "\n", 60 | "\n", 61 | "myseed = -1\n", 62 | "\n", 63 | "\n", 64 | "# If running this code on a cluster, uncomment the following, and pass a RNG seed as the --seed parameter on the command line\n", 65 | "# parser = argparse.ArgumentParser()\n", 66 | "# parser.add_argument('--seed', type=int, default=-1)\n", 67 | "# args = parser.parse_args()\n", 68 | "# myseed = args.seed\n", 69 | "\n", 70 | "\n", 71 | "\n", 72 | "\n", 73 | "\n", 74 | "np.set_printoptions(precision=5)\n", 75 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 76 | "# device = 'cpu'\n", 77 | "\n", 78 | "\n", 79 | "# Global parameters\n", 80 | "GG={}\n", 81 | "GG['rngseed']=myseed # RNG seed, or -1 for no seed\n", 82 | "GG['rew']=1.0 # reward amount\n", 83 | "GG['wp']=.0 # penalty for taking action 1 (not used here)\n", 84 | "GG['bent']=.1 # entropy incentive (actually sum-of-squares)\n", 85 | "GG['blossv']=.1 # value prediction loss coefficient\n", 86 | "GG['gr']=.9 # Gamma for temporal reward discounting\n", 87 | "\n", 88 | "GG['hs']=200 # Size of the RNN's hidden layer\n", 89 | "GG['bs']=32 # Batch size\n", 90 | "GG['gc']=2.0 # Gradient clipping\n", 91 | "GG['eps']=1e-6 # A parameter for Adam\n", 92 | "GG['nbiter']= 30000 # 60000\n", 93 | "GG['save_every']=200\n", 94 | "GG['pe']= 101 #\"print every\"\n", 95 | "\n", 96 | "\n", 97 | "GG['nbcuesrange'] = range(4,9) # The total number of cues varies from one episode to the next\n", 98 | "\n", 99 | "GG['cs']= 15 # 10 # Cue size - number of binary elements in each cue vector (not including the 'go' bit and additional inputs, see below)\n", 100 | "\n", 101 | "GG['triallen'] = 4 # Number of time steps in each trial\n", 102 | "NUMRESPONSESTEP = 1\n", 103 | "GG['nbtraintrials'] = 20 # The first nbtraintrials are the \"train\" trials. This is included in nbtrials.\n", 104 | "GG['nbtesttrials'] = 10 # The last nbtesttrials are the \"test\" trials. 
This is included in nbtrials.\n", 105 | "GG['nbtrials'] = GG['nbtraintrials'] + GG['nbtesttrials'] # Number of trials per episode\n", 106 | "GG['eplen'] = GG['nbtrials'] * GG['triallen'] # eplen = episode length\n", 107 | "GG['testlmult'] = 3.0 # multiplier for the loss during the test trials\n", 108 | "GG['l2'] = 0 # 1e-5 # L2 penalty\n", 109 | "GG['lr'] = 1e-4\n", 110 | "GG['lpw'] = 1e-4 # 3 # plastic weight loss\n", 111 | "\n", 112 | "\n", 113 | "\n", 114 | "\n", 115 | "\n", 116 | "# RNN with plastic connections and neuromodulation (\"DA\").\n", 117 | "# Plasticity only in the recurrent connections.\n", 118 | "\n", 119 | "class RetroModulRNN(nn.Module):\n", 120 | " def __init__(self, GG):\n", 121 | " super(RetroModulRNN, self).__init__()\n", 122 | " # NOTE: 'outputsize' excludes the value and neuromodulator outputs!\n", 123 | " for paramname in ['outputsize', 'inputsize', 'hs', 'bs']:\n", 124 | " if paramname not in GG.keys():\n", 125 | " raise KeyError(\"Must provide missing key in argument 'GG': \"+paramname)\n", 126 | " NBDA = 2 # 2 DA neurons, we take the difference - see below\n", 127 | " self.GG = GG\n", 128 | " self.activ = torch.tanh\n", 129 | " self.i2h = torch.nn.Linear(self.GG['inputsize'], GG['hs']).to(device)\n", 130 | " self.w = torch.nn.Parameter(( (1.0 / np.sqrt(GG['hs'])) * ( 2.0 * torch.rand(GG['hs'], GG['hs']) - 1.0) ).to(device), requires_grad=True)\n", 131 | " self.alpha = .01 * (2.0 * torch.rand(GG['hs'], GG['hs']) -1.0).to(device)\n", 132 | " self.alpha = torch.nn.Parameter(self.alpha, requires_grad=True)\n", 133 | " self.etaet = torch.nn.Parameter((.7 * torch.ones(1)).to(device), requires_grad=True) # Everyone has the same etaet\n", 134 | " self.DAmult = torch.nn.Parameter((1.0 * torch.ones(1)).to(device), requires_grad=True) # Everyone has the same DAmult\n", 135 | " self.h2DA = torch.nn.Linear(GG['hs'], NBDA).to(device) # DA output\n", 136 | " self.h2o = torch.nn.Linear(GG['hs'], self.GG['outputsize']).to(device) # Actual output\n", 137 | " self.h2v = torch.nn.Linear(GG['hs'], 1).to(device) # V prediction\n", 138 | "\n", 139 | " def forward(self, inputs, hidden, et, pw):\n", 140 | " BATCHSIZE = inputs.shape[0] # self.GG['bs']\n", 141 | " HS = self.GG['hs']\n", 142 | " assert pw.shape[0] == hidden.shape[0] == et.shape[0] == BATCHSIZE\n", 143 | "\n", 144 | " # Multiplying inputs (i.e. current hidden values) by the total recurrent weights, w + alpha * plastic_weights\n", 145 | " hactiv = self.activ(self.i2h(inputs).view(BATCHSIZE, HS, 1) + torch.matmul((self.w + torch.mul(self.alpha, pw)),\n", 146 | " hidden.view(BATCHSIZE, HS, 1))).view(BATCHSIZE, HS)\n", 147 | " activout = self.h2o(hactiv) # Output layer. 
170 | "\n",
171 | "\n",
172 | "\n",
173 | "    def initialZeroET(self, mybs):\n",
174 | "        # return torch.zeros(self.GG['bs'], self.GG['hs'], self.GG['hs'], requires_grad=False).to(device)\n",
175 | "        return torch.zeros(mybs, self.GG['hs'], self.GG['hs'], requires_grad=False).to(device)\n",
176 | "\n",
177 | "    def initialZeroPlasticWeights(self, mybs):\n",
178 | "        return torch.zeros(mybs, self.GG['hs'], self.GG['hs'], requires_grad=False).to(device)\n",
179 | "    def initialZeroState(self, mybs):\n",
180 | "        return torch.zeros(mybs, self.GG['hs'], requires_grad=False).to(device)\n",
181 | "\n",
182 | "\n",
183 | "\n",
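"# Illustrative usage of the class (kept commented out here, since GG['inputsize'] and\n",
"# GG['outputsize'] are only defined further below):\n",
"#   net = RetroModulRNN(GG)\n",
"#   h, et, pw = net.initialZeroState(4), net.initialZeroET(4), net.initialZeroPlasticWeights(4)\n",
"#   x = torch.zeros(4, GG['inputsize']).to(device)  # a batch of 4 blank input vectors\n",
"#   y, v, da, h, et, pw = net(x, h, et, pw)         # y: (4, outputsize) raw action scores\n",
"\n",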
the \"Go\" cue)\n", 195 | "GG['outputsize'] = 2 # \"response\" and \"no response\"\n", 196 | "GG['inputsize'] = NBSTIMBITS + ADDINPUT + GG['outputsize'] # The total number of input bits is the size of cues, plus the \"response cue\" binary input, plus the number of additional inputs, plus the number of actions\n", 197 | "\n", 198 | "\n", 199 | "# Initialize random seeds, unless rngseed is -1 (first two redundant?)\n", 200 | "if GG['rngseed'] > -1 :\n", 201 | " print(\"Setting random seed\", GG['rngseed'])\n", 202 | " np.random.seed(GG['rngseed']); random.seed(GG['rngseed']); torch.manual_seed(GG['rngseed'])\n", 203 | "else:\n", 204 | " print(\"No random seed.\")\n", 205 | "\n", 206 | "\n", 207 | "BS = GG['bs'] # Batch size\n", 208 | "\n", 209 | "\n", 210 | "print(\"Initializing network\")\n", 211 | "net = RetroModulRNN(GG)\n", 212 | "\n", 213 | "\n", 214 | "print (\"Shape of all optimized parameters:\", [x.size() for x in net.parameters()])\n", 215 | "allsizes = [torch.numel(x.data.cpu()) for x in net.parameters()]\n", 216 | "print (\"Size (numel) of all optimized elements:\", allsizes)\n", 217 | "print (\"Total size (numel) of all optimized elements:\", sum(allsizes))\n", 218 | "\n", 219 | "print(\"Initializing optimizer\")\n", 220 | "optimizer = torch.optim.Adam(net.parameters(), lr=1.0*GG['lr'], eps=GG['eps'], weight_decay=GG['l2'])\n", 221 | "\n", 222 | "\n", 223 | "lossbetweensaves = 0\n", 224 | "nowtime = time.time()\n", 225 | "\n", 226 | "nbtrials = [0]*BS\n", 227 | "totalnbtrials = 0\n", 228 | "nbtrialswithcc = 0\n", 229 | "all_mean_testrewards_ep = []\n", 230 | "\n", 231 | "\n", 232 | "\n", 233 | "\n", 234 | "print(\"Starting episodes!\")\n", 235 | "\n", 236 | "for numepisode in range(GG['nbiter']):\n", 237 | "\n", 238 | " PRINTTRACE = False\n", 239 | " if (numepisode) % (GG['pe']) == 0 :\n", 240 | " PRINTTRACE = True\n", 241 | "\n", 242 | " optimizer.zero_grad()\n", 243 | " loss = 0\n", 244 | " lossv = 0\n", 245 | " GG['nbcues']= random.choice(GG['nbcuesrange'])\n", 246 | " hidden = net.initialZeroState(BS)\n", 247 | " et = net.initialZeroET(BS) # The Hebbian eligibility trace\n", 248 | "\n", 249 | " # In this simplified version we just reset the plastic weights at every episode (the main version only resets it every 3rd episode and remembers previous lists)\n", 250 | " pw = net.initialZeroPlasticWeights(BS)\n", 251 | "\n", 252 | " numstep_ep = 0\n", 253 | " iscorrect_thisep = np.zeros((BS, GG['nbtrials']))\n", 254 | " istest_thisep = np.zeros((BS, GG['nbtrials']))\n", 255 | " isadjacent_thisep = np.zeros((BS, GG['nbtrials']))\n", 256 | " # isolddata_thisep = np.zeros((BS, GG['nbtrials']))\n", 257 | " resps_thisep = np.zeros((BS, GG['nbtrials']))\n", 258 | " cuepairs_thisep = []\n", 259 | " numactionschosen_alltrialsandsteps_thisep = np.zeros((BS, GG['nbtrials'], GG['triallen'])).astype(int)\n", 260 | "\n", 261 | "\n", 262 | " # Generate the bitstring for each cue number for this episode. Make sure they're all different (important when using very small cues for debugging, e.g. cs=2, ni=2)\n", 263 | " cuedata=[]\n", 264 | " for nb in range(BS):\n", 265 | " cuedata.append([])\n", 266 | " for ncue in range(GG['nbcues']):\n", 267 | " assert len(cuedata[nb]) == ncue\n", 268 | " foundsame = 1\n", 269 | " cpt = 0\n", 270 | " while foundsame > 0 :\n", 271 | " cpt += 1\n", 272 | " if cpt > 10000:\n", 273 | " # This should only occur with very weird parameters, e.g. 
284 | "\n",
285 | "    reward = np.zeros(BS)\n",
286 | "    sumreward = np.zeros(BS)\n",
287 | "    sumrewardtest = np.zeros(BS)\n",
288 | "    rewards = []\n",
289 | "    vs = []\n",
290 | "    logprobs = []\n",
291 | "    cues = []\n",
292 | "    for nb in range(BS):\n",
293 | "        cues.append([])\n",
294 | "\n",
295 | "    numactionschosen = np.zeros(BS, dtype='int32')\n",
296 | "\n",
297 | "\n",
298 | "    nbtrials = np.zeros(BS)\n",
299 | "    nbtesttrials = nbtesttrials_correct = nbtesttrials_adjcues = nbtesttrials_adjcues_correct = nbtesttrials_nonadjcues = nbtesttrials_nonadjcues_correct = 0\n",
300 | "    nbrewardabletrials = np.zeros(BS)\n",
301 | "    thistrialhascorrectorder = np.zeros(BS)\n",
302 | "    thistrialhasadjacentcues = np.zeros(BS)\n",
303 | "    thistrialhascorrectanswer = np.zeros(BS)\n",
304 | "\n",
305 | "\n",
306 | "    # 2 steps of blank input between episodes. Not sure if it helps.\n",
307 | "    inputs = np.zeros((BS, GG['inputsize']), dtype='float32')\n",
308 | "    inputsC = torch.from_numpy(inputs).detach().to(device)\n",
309 | "    for _ in range(2):\n",
310 | "        y, v, DAout, hidden, et, pw = net(inputsC, hidden, et, pw) # y should output raw scores, not probas\n",
311 | "\n",
312 | "\n",
313 | "\n",
314 | "\n",
315 | "    for numtrial in range(GG['nbtrials']):\n",
316 | "\n",
317 | "\n",
318 | "        # To simplify dynamics as much as possible, we reset hidden activations and eligibility traces (but not plastic weights) between trials.\n",
319 | "        hidden = net.initialZeroState(BS)\n",
320 | "        et = net.initialZeroET(BS)\n",
321 | "\n",
322 | "        # First, we prepare the specific sequence of inputs for this trial\n",
323 | "        # The inputs can be a pair of cue numbers, or -1 (empty stimulus), or a single number equal to GG['nbcues'], which indicates the 'response' cue.\n",
324 | "        # These will be translated into actual network inputs (using the actual bitstrings) later.\n",
325 | "        # Remember that the actual data for each cue (i.e. its actual bitstring) is randomly generated for each episode, above\n",
326 | "\n",
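"        # For example, with nbcues=5 a trial's sequence might be [[2, 3], 5, -1, -1]:\n",
"        # the pair of cues 2 and 3, then the 'go' cue (number 5), then two blank steps.\n",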
327 | "        cuepairs_thistrial = []\n",
328 | "        for nb in range(BS):\n",
329 | "            thistrialhascorrectorder[nb] = 0\n",
330 | "            cuerange = range(GG['nbcues'])\n",
331 | "            # In any trial, we show exactly two cues (randomly chosen), simultaneously:\n",
332 | "            cuepair = list(np.random.choice(cuerange, 2, replace=False))\n",
333 | "\n",
334 | "            # If the trial is NOT a test trial, these two cues should be adjacent\n",
335 | "            if nbtrials[nb] < GG['nbtraintrials']:\n",
336 | "                while abs(cuepair[0] - cuepair[1]) > 1:\n",
337 | "                    cuepair = list(np.random.choice(cuerange, 2, replace=False))\n",
338 | "            else:\n",
339 | "                assert nbtrials[nb] >= GG['nbtraintrials']\n",
340 | "\n",
341 | "            assert nbtrials[nb] == numtrial\n",
342 | "\n",
343 | "            thistrialhascorrectorder[nb] = 1 if cuepair[0] < cuepair[1] else 0\n",
344 | "            thistrialhasadjacentcues[nb] = 1 if (abs(cuepair[0]-cuepair[1]) == 1) else 0\n",
345 | "            isadjacent_thisep[nb,numtrial] = thistrialhasadjacentcues[nb]\n",
346 | "            istest_thisep[nb, numtrial] = 1 if numtrial >= GG['nbtraintrials'] else 0\n",
347 | "\n",
348 | "            # mycues = [cuepair,cuepair]\n",
349 | "            mycues = [cuepair,]\n",
350 | "            cuepairs_thistrial.append(cuepair)\n",
351 | "\n",
352 | "            mycues.append(GG['nbcues']) # The 'go' cue, instructing response from the network\n",
353 | "            mycues.append(-1) # One empty step. During the first empty step, reward (computed on the previous step) is seen by the network.\n",
354 | "            mycues.append(-1)\n",
355 | "            # mycues.append(-1)\n",
356 | "            assert len(mycues) == GG['triallen']\n",
357 | "            assert mycues[NUMRESPONSESTEP] == GG['nbcues'] # The 'response' step is signalled by the 'go' cue, whose number is GG['nbcues'].\n",
358 | "            cues[nb] = mycues\n",
359 | "\n",
360 | "        cuepairs_thisep.append(cuepairs_thistrial)\n",
361 | "\n",
362 | "\n",
363 | "        # Now we are ready to actually run the trial:\n",
364 | "\n",
365 | "        for numstep in range(GG['triallen']):\n",
366 | "\n",
367 | "            # Preparing inputs\n",
368 | "            inputs = np.zeros((BS, GG['inputsize']), dtype='float32')\n",
369 | "            for nb in range(BS):\n",
370 | "                # Turning the cue number for this time step into actual (signed) bitstring inputs, using the cue data generated at the beginning of the episode\n",
371 | "                inputs[nb, :NBSTIMBITS] = 0\n",
372 | "                if cues[nb][numstep] != -1 and cues[nb][numstep] != GG['nbcues']:\n",
373 | "                    assert len(cues[nb][numstep]) == 2\n",
374 | "                    inputs[nb, :NBSTIMBITS-1] = np.concatenate( ( cuedata[nb][cues[nb][numstep][0]][:], cuedata[nb][cues[nb][numstep][1]][:] ) )\n",
375 | "                if cues[nb][numstep] == GG['nbcues']:\n",
376 | "                    inputs[nb, NBSTIMBITS-1] = 1 # \"Go\" cue\n",
377 | "\n",
378 | "                inputs[nb, NBSTIMBITS + 0] = 1.0 # Bias neuron, probably not necessary\n",
379 | "                inputs[nb, NBSTIMBITS + 1] = numstep_ep / GG['eplen'] # Time passed in this episode. Should it be the trial? Doesn't matter much anyway.\n",
380 | "                inputs[nb, NBSTIMBITS + 2] = 1.0 * reward[nb] # Reward from previous time step\n",
381 | "\n",
382 | "                assert NUMRESPONSESTEP + 1 < GG['triallen'] # If that were not the case, we would have to provide the action signal in the next trial (which also works)\n",
383 | "                if numstep == NUMRESPONSESTEP + 1:\n",
384 | "                    inputs[nb, NBSTIMBITS + ADDINPUT + numactionschosen[nb]] = 1 # Previously chosen action, following standard meta-RL practice\n",
385 | "\n",
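"            # Layout of each input vector assembled above, for reference (with cs = GG['cs']):\n",
"            #   [0 : 2*cs]           the two cue bitstrings (or zeros)\n",
"            #   [2*cs]               the 'go' bit\n",
"            #   [2*cs+1 : 2*cs+5]    bias, elapsed time, previous reward, (unused)\n",
"            #   [2*cs+5 : 2*cs+7]    one-hot previously chosen action\n",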
386 | "\n",
387 | "            inputsC = torch.from_numpy(inputs).detach().to(device)\n",
388 | "\n",
389 | "\n",
390 | "\n",
391 | "\n",
392 | "            ## Running the network\n",
393 | "            y, v, DAout, hidden, et, pw = net(inputsC, hidden, et, pw) # y should output raw scores, not probas\n",
394 | "\n",
395 | "\n",
396 | "\n",
397 | "            # Choosing the action from the outputs\n",
398 | "            y = F.softmax(y, dim=1)\n",
399 | "            # Must convert y to probas to use this!\n",
400 | "            distrib = torch.distributions.Categorical(y)\n",
401 | "            actionschosen = distrib.sample()\n",
402 | "            logprobs.append(distrib.log_prob(actionschosen)) # To be used later for the A2C algorithm\n",
403 | "            # Alternatively: only record logprobs just after the response step (the only step where it matters). Better performance, but not used for the paper.\n",
404 | "            # if numstep == NUMRESPONSESTEP:\n",
405 | "            #     logprobs.append(distrib.log_prob(actionschosen)) # To be used later for the A2C algorithm\n",
406 | "            # else:\n",
407 | "            #     logprobs.append(0)\n",
408 | "            numactionschosen = actionschosen.data.cpu().numpy() # Store as scalars (for the whole batch)\n",
409 | "\n",
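"            # For example, Categorical(torch.tensor([[.9, .1]])).sample() returns action 0 with\n",
"            # probability .9; the stored log_prob of the sampled action is what the A2C policy\n",
"            # loss differentiates through at the end of the episode.\n",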
410 | "            if PRINTTRACE:\n",
411 | "                print(\"Tr\", numtrial, \"Step \", numstep, \", Cue 1 (0):\", inputs[0,:GG['cs']], \"Cue 2 (0):\", inputs[0,GG['cs']:2*GG['cs']],\n",
412 | "                      \"Other inputs:\", inputs[0, 2*GG['cs']:], \"\\n - Outputs(0): \", y.data.cpu().numpy()[0,:], \" - action chosen(0): \", numactionschosen[0],\n",
413 | "                      \"TrialLen:\", GG['triallen'], \"numstep\", numstep, \"TTHCC(0): \", thistrialhascorrectorder[0], \"Reward (based on prev step): \", reward[0], \", DAout:\", float(DAout[0]), \", cues(0):\", cues[0]) #, \", cc(0):\", correctcue[0])\n",
414 | "\n",
415 | "\n",
416 | "            # Computing the rewards. This is done for each time step.\n",
417 | "            reward = np.zeros(BS, dtype='float32')\n",
418 | "            for nb in range(BS):\n",
419 | "\n",
420 | "                numactionschosen_alltrialsandsteps_thisep[nb, numtrial, numstep] = numactionschosen[nb]\n",
421 | "\n",
422 | "                if numstep == NUMRESPONSESTEP:\n",
423 | "                    # This is the 'response' step of the trial (and we showed the response signal)\n",
424 | "                    assert cues[nb][numstep] == GG['nbcues']\n",
425 | "                    resps_thisep[nb, numtrial] = numactionschosen[nb] * 2 - 1 # Store the response in this timestep as the response for the whole trial, for logging/analysis purposes\n",
426 | "                    # We must deliver reward (which will be perceived by the agent at the next step), positive or negative, depending on response\n",
427 | "                    thistrialhascorrectanswer[nb] = 1\n",
428 | "                    if thistrialhascorrectorder[nb] and numactionschosen[nb] == 1:\n",
429 | "                        reward[nb] += GG['rew']\n",
430 | "                    elif (not thistrialhascorrectorder[nb]) and numactionschosen[nb] == 0:\n",
431 | "                        reward[nb] += GG['rew']\n",
432 | "                    else:\n",
433 | "                        reward[nb] -= GG['rew']\n",
434 | "                        thistrialhascorrectanswer[nb] = 0\n",
435 | "                    iscorrect_thisep[nb, numtrial] = thistrialhascorrectanswer[nb]\n",
436 | "\n",
437 | "                if numstep == GG['triallen'] - 1:\n",
438 | "                    # This was the last step of the trial\n",
439 | "                    nbtrials[nb] += 1\n",
440 | "                    totalnbtrials += 1\n",
441 | "                    if thistrialhascorrectorder[nb]:\n",
442 | "                        nbtrialswithcc += 1\n",
443 | "\n",
444 | "\n",
445 | "\n",
446 | "            rewards.append(reward)\n",
447 | "            vs.append(v)\n",
448 | "            sumreward += reward\n",
449 | "            if numtrial >= GG['nbtrials'] - GG['nbtesttrials']:\n",
450 | "                sumrewardtest += reward\n",
451 | "\n",
452 | "\n",
453 | "            loss += (GG['bent'] * y.pow(2).sum() / BS) # In real A2C, this is an entropy incentive. Our original version of PyTorch did not have an entropy() function for Distribution, so we use sum-of-squares instead.\n",
454 | "\n",
455 | "            numstep_ep += 1\n",
456 | "\n",
457 | "\n",
458 | "        # All steps done for this trial\n",
459 | "        if numtrial >= GG['nbtrials'] - GG['nbtesttrials']:\n",
460 | "            # (sumrewardtest is already accumulated inside the step loop above, so it is not added again here)\n",
461 | "            nbtesttrials += BS\n",
462 | "            nbtesttrials_correct += np.sum(thistrialhascorrectanswer)\n",
463 | "            nbtesttrials_adjcues += np.sum(thistrialhasadjacentcues)\n",
464 | "            nbtesttrials_adjcues_correct += np.sum(thistrialhasadjacentcues * thistrialhascorrectanswer)\n",
465 | "            nbtesttrials_nonadjcues += np.sum(1 - thistrialhasadjacentcues)\n",
466 | "            nbtesttrials_nonadjcues_correct += np.sum((1-thistrialhasadjacentcues) * thistrialhascorrectanswer)\n",
467 | "\n",
468 | "\n",
469 | "    # All trials done for this episode\n",
470 | "\n",
471 | "\n",
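"    # The loop below accumulates discounted returns backwards through the episode,\n",
"    #     R_t = r_t + gamma * R_{t+1}   (with R = 0 past the end of the episode),\n",
"    # and uses the advantage R_t - V_t for the policy term. For example, with gamma = .9 and\n",
"    # per-step rewards [0, 1, 0, -1], the returns are [0.171, 0.19, -0.9, -1.0].\n",
"\n",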
472 | "    # Computing the various losses for A2C (outer-loop training)\n",
473 | "\n",
474 | "    R = torch.zeros(BS, requires_grad=False).to(device)\n",
475 | "    gammaR = GG['gr']\n",
476 | "    for numstepb in reversed(range(GG['eplen'])):\n",
477 | "        R = gammaR * R + torch.from_numpy(rewards[numstepb]).detach().to(device)\n",
478 | "        ctrR = R - vs[numstepb][:,0] # Advantage: empirical discounted return minus the value prediction\n",
479 | "        lossv += ctrR.pow(2).sum() / BS\n",
480 | "        LOSSMULT = GG['testlmult'] if numstepb > GG['eplen'] - GG['triallen'] * GG['nbtesttrials'] else 1.0\n",
481 | "        loss -= LOSSMULT * (logprobs[numstepb] * ctrR.detach()).sum() / BS # Action policy loss\n",
482 | "\n",
483 | "\n",
484 | "\n",
485 | "    lossobj = float(loss)\n",
486 | "    loss += GG['blossv'] * lossv # Note that LOSSMULT is not applied to the value-prediction loss\n",
487 | "    loss /= GG['eplen']\n",
488 | "    losspw = torch.mean(pw ** 2) * GG['lpw'] # loss on squared final plastic weights is not divided by episode length\n",
489 | "    loss += losspw\n",
490 | "\n",
491 | "\n",
492 | "\n",
493 | "    loss.backward()\n",
494 | "    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), GG['gc'])\n",
495 | "    if numepisode > 100: # Burn-in period\n",
496 | "        optimizer.step()\n",
497 | "\n",
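"    # Outer-loop (meta-training) update above: gradient norms are clipped to GG['gc'], and no\n",
"    # parameter update is applied for the first 100 episodes, presumably to keep the noisy early\n",
"    # gradients from destabilizing training.\n",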
498 | "\n",
499 | "    lossnum = float(loss)\n",
500 | "    lossbetweensaves += lossnum\n",
501 | "    all_mean_testrewards_ep.append(sumrewardtest.mean())\n",
502 | "\n",
503 | "\n",
504 | "    if PRINTTRACE:\n",
505 | "\n",
506 | "        print(\"Episode\", numepisode, \"====\")\n",
507 | "        previoustime = nowtime\n",
508 | "        nowtime = time.time()\n",
509 | "        print(\"Time spent on last\", GG['pe'], \"iters: \", nowtime - previoustime)\n",
510 | "\n",
511 | "        print(\" etaet: \", net.etaet.data.cpu().numpy(), \" DAmult: \", float(net.DAmult), \" mean-abs pw: \", np.mean(np.abs(pw.data.cpu().numpy())))\n",
512 | "        print(\"min/max/med-abs w, alpha, pw\")\n",
513 | "        print(float(torch.min(net.w)), float(torch.max(net.w)), float(torch.median(torch.abs(net.w))))\n",
514 | "        print(float(torch.min(net.alpha)), float(torch.max(net.alpha)), float(torch.median(torch.abs(net.alpha))))\n",
515 | "        print(float(torch.min(pw)), float(torch.max(pw)), float(torch.median(torch.abs(pw))))\n",
516 | "\n",
517 | "        # print(\"lossobj (with coeff):\", lossobj / GG['eplen'], \", lossv (with coeff): \", GG['blossv'] * float(lossv) / GG['eplen'],\n",
518 | "        #       \", losspw:\", float(losspw))\n",
519 | "        # print(\"Total reward for this episode(0):\", sumreward[0], \"Prop. of trials w/ rewarded cue:\", (nbtrialswithcc / totalnbtrials), \" Total Nb of trials:\", totalnbtrials)\n",
520 | "        print(\"Nb Test Trials:\", nbtesttrials, \", Nb Test Trials AdjCues:\", nbtesttrials_adjcues, \", Nb Test Trials NonAdjCues:\", nbtesttrials_nonadjcues)\n",
521 | "        if nbtesttrials > 0:\n",
522 | "            # Should always be the case except for LinkedListsEval\n",
523 | "            print(\">>>> Test Performance (both methods):\", np.array([nbtesttrials_correct / nbtesttrials, np.sum(iscorrect_thisep * istest_thisep) / np.sum(istest_thisep)]),\n",
524 | "                  \"Test Perf AdjCues:\", np.array([(nbtesttrials_adjcues_correct / nbtesttrials_adjcues)]) if nbtesttrials_adjcues > 0 else 'N/A',\n",
525 | "                  \"Test Perf NonAdjCues:\", np.array([nbtesttrials_nonadjcues_correct / nbtesttrials_nonadjcues]) if nbtesttrials_nonadjcues > 0 else 'N/A'\n",
526 | "                  )\n",
527 | "\n",
528 | "\n",
529 | "    if numepisode % GG['save_every'] == 0 and numepisode > 0:\n",
530 | "        print(\"Saving local files...\")\n",
531 | "\n",
532 | "        if numepisode > 0:\n",
533 | "            # print(\"Saving model parameters...\")\n",
534 | "            # torch.save(net.state_dict(), 'net_'+suffix+'.dat')\n",
535 | "            torch.save(net.state_dict(), 'netAE'+str(GG['rngseed'])+'.dat')\n",
536 | "            torch.save(net.state_dict(), 'net.dat')\n",
537 | "\n",
538 | "        # with open('rewards_'+suffix+'.txt', 'w') as thefile:\n",
539 | "        #     for item in all_mean_rewards_ep[::10]:\n",
540 | "        #         thefile.write(\"%s\\n\" % item)\n",
541 | "        # with open('testrew_'+suffix+'.txt', 'w') as thefile:\n",
542 | "        #     for item in all_mean_testrewards_ep[::10]:\n",
543 | "        #         thefile.write(\"%s\\n\" % item)\n",
544 | "        with open('tAE'+str(GG['rngseed'])+'.txt', 'w') as thefile:\n",
545 | "            for item in all_mean_testrewards_ep[::10]:\n",
546 | "                thefile.write(\"%s\\n\" % item)\n",
547 | "\n",
548 | "\n"
549 | ]
550 | }
551 | ],
552 | "metadata": {
553 | "accelerator": "GPU",
554 | "colab": {
555 | "provenance": []
556 | },
557 | "kernelspec": {
558 | "display_name": "Python 3",
559 | "name": "python3"
560 | },
561 | "language_info": {
562 | "name": "python"
563 | }
564 | },
565 | "nbformat": 4,
566 | "nbformat_minor": 0
567 | }
--------------------------------------------------------------------------------