├── .github └── workflows │ └── test_and_deploy.yml ├── .gitignore ├── .style.yapf ├── LICENSE ├── README.md ├── codecov.yml ├── conftest.py ├── notebooks ├── rastermap_interactive.ipynb ├── rastermap_largescale.ipynb ├── rastermap_singleneurons.ipynb ├── rastermap_widefield.ipynb ├── rastermap_zebrafish.ipynb └── tutorial.ipynb ├── paper ├── fig1.ipynb ├── fig1.py ├── fig2.ipynb ├── fig2.py ├── fig3.ipynb ├── fig3.py ├── fig4.ipynb ├── fig4.py ├── fig5.ipynb ├── fig5.py ├── fig6.ipynb ├── fig6.py ├── fig7.ipynb ├── fig7.py ├── fig8.ipynb ├── fig8.py ├── fig_utils.py ├── loaders.py ├── metrics.py ├── other_upsampling.py ├── qrdqn.py ├── simulations.py ├── splitting.ipynb └── svca.ipynb ├── rastermap ├── __init__.py ├── __main__.py ├── cluster.py ├── gui │ ├── colormaps.py │ ├── gui.py │ ├── guiparts.py │ ├── io.py │ ├── menus.py │ ├── run.py │ └── views.py ├── io.py ├── rastermap.py ├── sort.py ├── svd.py ├── upsample.py └── utils.py ├── setup.py ├── tests ├── test_import.py └── test_rastermap.py └── tox.ini /.github/workflows/test_and_deploy.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - "v*x" 8 | tags: 9 | - "v*" # Push events to matching v*, i.e. 
v1.0, v20.15.10 10 | pull_request: 11 | branches: 12 | - main 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | 17 | jobs: 18 | test: 19 | name: ${{ matrix.platform }} py${{ matrix.python }} 20 | runs-on: ${{ matrix.platform }} 21 | strategy: 22 | matrix: 23 | platform: [ubuntu-latest, macos-latest, windows-latest] 24 | python: ["3.8", "3.9", "3.10"] 25 | 26 | steps: 27 | - uses: actions/checkout@v3 28 | 29 | - name: Set up Python ${{ matrix.python }} 30 | uses: actions/setup-python@v4 31 | with: 32 | python-version: ${{ matrix.python }} 33 | 34 | - name: Install ubuntu gl 35 | if: runner.os == 'Linux' 36 | run: | 37 | sudo apt-get update && sudo apt-get install libegl1 38 | - name: Install dependencies 39 | run: | 40 | pip install --upgrade pip 41 | pip install setuptools tox tox-gh-actions 42 | shell: bash 43 | 44 | - name: Test with tox 45 | run: tox 46 | env: 47 | PLATFORM: ${{ matrix.platform }} 48 | 49 | - name: Coverage 50 | uses: codecov/codecov-action@v3 51 | 52 | 53 | deploy: 54 | # this will run when you have tagged a commit, starting with "v*" 55 | # and requires that you have put your twine API key in your 56 | # github secrets (see readme for details) 57 | needs: [test] 58 | runs-on: ubuntu-latest 59 | if: contains(github.ref, 'tags') 60 | steps: 61 | - uses: actions/checkout@v3 62 | - name: Set up Python 63 | uses: actions/setup-python@v4 64 | with: 65 | python-version: "3.x" 66 | - name: Install dependencies 67 | run: | 68 | python -m pip install --upgrade pip 69 | pip install -U setuptools setuptools_scm wheel twine 70 | - name: Build and publish 71 | env: 72 | TWINE_USERNAME: __token__ 73 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 74 | run: | 75 | git tag 76 | python setup.py sdist bdist_wheel 77 | twine upload dist/* 78 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Windows image file caches 2 | Thumbs.db 3 | ehthumbs.db 4 | 5 | # Folder config file 6 | Desktop.ini 7 | 8 | # Recycle Bin used on file shares 9 | $RECYCLE.BIN/ 10 | 11 | # Compiled source 12 | *.mexa64 13 | *.mexw64 14 | *.asv 15 | 16 | # Windows shortcuts 17 | *.lnk 18 | 19 | # python and jupyter 20 | *.npy 21 | jupyter/.ipynb_checkpoints/ 22 | jupyter/__pycache__/ 23 | .ipynb_checkpoints/ 24 | __pycache__/ 25 | dist/ 26 | suite2p.egg-info/ 27 | build/ 28 | suite2p/classifiers/classifier_user.npy 29 | suite2p/ops/ops_user.npy 30 | *.ipynb 31 | 32 | # ========================= 33 | # Operating System Files 34 | # ========================= 35 | 36 | # OSX 37 | # ========================= 38 | 39 | .DS_Store 40 | .AppleDouble 41 | .LSOverride 42 | 43 | # Thumbnails 44 | ._* 45 | 46 | # Files that might appear on external disk 47 | .Spotlight-V100 48 | .Trashes 49 | 50 | # Directories potentially created on remote AFP share 51 | .AppleDB 52 | .AppleDesktop 53 | Network Trash Folder 54 | Temporary Items 55 | .apdisk 56 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | split_before_named_assigns = false 4 | column_limit = 88 5 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | - "rastermap/gui/*" 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | # adjust accordingly based on how flaky your tests are 9 | # this allows a 5% drop from the previous base commit coverage 10 | threshold: 5% 11 | -------------------------------------------------------------------------------- /conftest.py: 
-------------------------------------------------------------------------------- 1 | import os, warnings, time, tempfile, datetime, pathlib, shutil, subprocess 2 | from urllib.request import urlopen 3 | from urllib.parse import urlparse 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | import pytest 7 | 8 | def download_url_to_file(url, dst, progress=True): 9 | r"""Download object at the given URL to a local path. 10 | Thanks to torch, slightly modified 11 | Args: 12 | url (string): URL of the object to download 13 | dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` 14 | progress (bool, optional): whether or not to display a progress bar to stderr 15 | Default: True 16 | """ 17 | file_size = None 18 | import ssl 19 | ssl._create_default_https_context = ssl._create_unverified_context 20 | u = urlopen(url) 21 | meta = u.info() 22 | if hasattr(meta, 'getheaders'): 23 | content_length = meta.getheaders("Content-Length") 24 | else: 25 | content_length = meta.get_all("Content-Length") 26 | if content_length is not None and len(content_length) > 0: 27 | file_size = int(content_length[0]) 28 | # We deliberately save it in a temp file and move it after 29 | dst = os.path.expanduser(dst) 30 | dst_dir = os.path.dirname(dst) 31 | f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) 32 | try: 33 | with tqdm(total=file_size, disable=not progress, 34 | unit='B', unit_scale=True, unit_divisor=1024) as pbar: 35 | while True: 36 | buffer = u.read(8192) 37 | if len(buffer) == 0: 38 | break 39 | f.write(buffer) 40 | pbar.update(len(buffer)) 41 | f.close() 42 | shutil.move(f.name, dst) 43 | finally: 44 | f.close() 45 | if os.path.exists(f.name): 46 | os.remove(f.name) 47 | 48 | @pytest.fixture() 49 | def test_file(): 50 | ddir = Path.home().joinpath('.rastermap') 51 | ddir.mkdir(exist_ok=True) 52 | data_dir = ddir.joinpath('data') 53 | data_dir.mkdir(exist_ok=True) 54 | url = "https://osf.io/download/67f008f2f74150d8738b8257/" 55 | test_file 
= str(data_dir.joinpath("neuropop_test_data.npz")) 56 | if not os.path.exists(test_file): 57 | download_url_to_file(url, test_file) 58 | return test_file 59 | -------------------------------------------------------------------------------- /notebooks/rastermap_singleneurons.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MouseLand/rastermap/blob/main/notebooks/rastermap_singleneurons.ipynb)\n", 9 | "\n", 10 | "# Rastermap sorting of 137 neurons\n", 11 | "\n", 12 | "We will use a processed version of the data from [Grosmark & Buzsaki, 2016](https://crcns.org/data-sets/hc/hc-11/about-hc-11). 137 neurons from rat hippocampal CA1 were recorded over several hours using eight bilateral silicon-probes. We selected the 33 minute period from the recording in which the rat traverses a linear track of length 1.6 meters. We binned the neural activity into time bins of length 200 ms. " 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "First we will install the required packages, if not already installed. If on google colab, it will require you to click the \"RESTART RUNTIME\" button because we are updating numpy." 
21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "!pip install numpy>=1.24 # (required for google colab)\n", 30 | "!pip install rastermap\n", 31 | "!pip install matplotlib" 32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "### Load data and import libraries\n", 40 | "\n", 41 | "If not already downloaded, the following cell will automatically download the processed data stored [here](https://osf.io/szmw6)." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import numpy as np\n", 51 | "import matplotlib.pyplot as plt\n", 52 | "# importing rastermap\n", 53 | "# (this will be slow the first time since it is compiling the numba functions)\n", 54 | "from rastermap import Rastermap, utils\n", 55 | "from scipy.stats import zscore\n", 56 | "\n", 57 | "# download processed hippocampus recording from Grosmark & Buzsaki 2016\n", 58 | "filename = utils.download_data(data_type=\"hippocampus\")\n", 59 | "\n", 60 | "dat = np.load(filename)\n", 61 | "\n", 62 | "# spks is neurons by time\n", 63 | "# (each timepoint is 200 ms)\n", 64 | "spks = dat[\"spks\"]\n", 65 | "n_neurons, n_time = spks.shape\n", 66 | "print(f\"{n_neurons} neurons by {n_time} timepoints\")\n", 67 | "# zscore activity (each neuron activity trace is then mean 0 and standard-deviation 1)\n", 68 | "spks = zscore(spks, axis=1)\n", 69 | "\n", 70 | "# location of the rat and speed\n", 71 | "loc2d = dat[\"loc2d\"] # 2D location\n", 72 | "loc_signed = dat[\"loc_signed\"] # left runs are positive and right runs are negative\n", 73 | "speed = (np.diff(loc2d, axis=0)**2).sum(axis=1)**0.5\n", 74 | "speed = np.concatenate((np.zeros((1,)), speed), axis=0)\n", 75 | "\n", 76 | "# which neurons in the recording are pyramidal cells\n", 77 | "pyr_cells = 
dat[\"pyr_cells\"].astype(\"int\")" 78 | ] 79 | }, 80 | { 81 | "attachments": {}, 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "### Run Rastermap\n", 86 | "\n", 87 | "Let's sort the single neurons with Rastermap, skipping clustering and upsampling:" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "model = Rastermap(n_clusters=None, # None turns off clustering and sorts single neurons \n", 97 | " n_PCs=64, # use fewer PCs than neurons\n", 98 | " locality=0.1, # some locality in sorting (this is a value from 0-1)\n", 99 | " time_lag_window=15, # use future timepoints to compute correlation\n", 100 | " grid_upsample=0, # 0 turns off upsampling since we're using single neurons\n", 101 | " ).fit(spks)\n", 102 | "y = model.embedding # neurons x 1\n", 103 | "isort = model.isort" 104 | ] 105 | }, 106 | { 107 | "attachments": {}, 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "compute tuning curves along linear corridor" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "n_pos = 15\n", 121 | "bins = np.arange(-1, 1+1./n_pos, 1./n_pos)\n", 122 | "ibin = np.digitize(loc_signed, bins) - 1\n", 123 | "n_bins = ibin.max()\n", 124 | "inan = np.isnan(loc_signed)\n", 125 | "ibin[inan] = -1\n", 126 | "tcurves = np.zeros((spks.shape[0], n_bins))\n", 127 | "for b in range(n_bins):\n", 128 | " tcurves[:, b] = spks[:, ibin==b].mean(axis=1)\n", 129 | "tcurves -= tcurves.mean(axis=1, keepdims=True)" 130 | ] 131 | }, 132 | { 133 | "attachments": {}, 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "### Visualization\n", 138 | "\n", 139 | "Use the Rastermap sorting to visualize the neural activity and tuning curves:" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": 
{}, 146 | "outputs": [], 147 | "source": [ 148 | "# timepoints to visualize\n", 149 | "xmin = 1000\n", 150 | "xmax = xmin + 2000\n", 151 | "\n", 152 | "# make figure with grid for easy plotting\n", 153 | "fig = plt.figure(figsize=(12,8), dpi=200)\n", 154 | "grid = plt.GridSpec(10, 24, figure=fig, wspace = 0.1, hspace = 0.4)\n", 155 | "\n", 156 | "# plot location\n", 157 | "ax = plt.subplot(grid[0, :-5])\n", 158 | "ax.plot(loc2d[xmin:xmax])#, color=kp_colors[0])\n", 159 | "ax.set_xlim([0, xmax-xmin])\n", 160 | "ax.axis(\"off\")\n", 161 | "ax.set_title(\"2D location\")\n", 162 | "\n", 163 | "# plot running speed\n", 164 | "ax = plt.subplot(grid[1, :-5])\n", 165 | "ax.plot(speed[xmin:xmax], color=0.5*np.ones(3))\n", 166 | "ax.set_xlim([0, xmax-xmin])\n", 167 | "ax.axis(\"off\")\n", 168 | "ax.set_title(\"running speed\")\n", 169 | "\n", 170 | "# plot sorted neural activity\n", 171 | "ax = plt.subplot(grid[2:, :-5])\n", 172 | "ax.imshow(spks[isort, xmin:xmax], cmap=\"gray_r\", vmin=0, vmax=1.2, aspect=\"auto\")\n", 173 | "ax.set_xlabel(\"time\")\n", 174 | "ax.set_ylabel(\"superneurons\")\n", 175 | "\n", 176 | "# excitatory cells in yellow, and inhibitory cells in dark blue\n", 177 | "# (could replace this with a colorbar or other property)\n", 178 | "ax = plt.subplot(grid[2:, -5])\n", 179 | "ax.imshow(pyr_cells[isort, np.newaxis],\n", 180 | " cmap=\"viridis\", aspect=\"auto\")\n", 181 | "ax.axis(\"off\")\n", 182 | "\n", 183 | "# plot single-neuron tuning curves\n", 184 | "ax = plt.subplot(grid[2:, -4:])\n", 185 | "x = np.arange(0, n_pos)\n", 186 | "dy = 2\n", 187 | "xpad = n_pos/10\n", 188 | "nn = spks.shape[0]\n", 189 | "for t in range(len(tcurves)):\n", 190 | " ax.plot(x, tcurves[isort[t], :n_pos]*dy - dy/2 + nn - t, \n", 191 | " color=\"k\", lw=0.5)\n", 192 | " ax.plot(x+n_pos+xpad, tcurves[isort[t], n_pos:]*dy - dy/2 + nn - t, \n", 193 | " color=\"k\", lw=0.5)\n", 194 | "for j in range(2):\n", 195 | " xstr = \"position\\n(left run)\" if j==0 else \"position\\n(right 
run)\"\n", 196 | " ax.text(n_pos/2 + j*(n_pos+xpad), -14, xstr, ha=\"center\")\n", 197 | " ax.text(j*(n_pos+xpad), -3, \"0\")\n", 198 | " ax.text(n_pos + j*(n_pos+xpad), -3, \"1.6\", ha=\"right\")\n", 199 | "ax.set_ylim([0, nn])\n", 200 | "ax.axis(\"off\")" 201 | ] 202 | }, 203 | { 204 | "attachments": {}, 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "### Settings\n", 209 | "\n", 210 | "You can see all the rastermap settings with `Rastermap?`" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "Rastermap?" 220 | ] 221 | }, 222 | { 223 | "attachments": {}, 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "### Outputs\n", 228 | "\n", 229 | "All the attributes assigned to the Rastermap `model` are listed with `Rastermap.fit?`" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "Rastermap.fit?" 
239 | ] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3.8.10 64-bit", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.9.16" 259 | }, 260 | "orig_nbformat": 4, 261 | "vscode": { 262 | "interpreter": { 263 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" 264 | } 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 2 269 | } 270 | -------------------------------------------------------------------------------- /notebooks/rastermap_zebrafish.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MouseLand/rastermap/blob/main/notebooks/rastermap_zebrafish.ipynb)\n", 9 | "\n", 10 | "# Rastermap sorting of zebrafish neural activity\n", 11 | "\n", 12 | "We will use a zebrafish wholebrain neural activity recording from [Chen*, Mu*, Hu*, Kuan* et al 2018](https://doi.org/10.1016/j.neuron.2018.09.042). The full dataset is available [here](https://doi.org/10.25378/janelia.7272617). The recordings were performed at a rate of 2.1 Hz. We took the neurons with the highest variance signals and deconvolved them to reduce long timescales in the data from the calcium sensor." 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "First we will install the required packages, if not already installed. 
If on google colab, it will require you to click the \"RESTART RUNTIME\" button because we are updating numpy." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "!pip install numpy>=1.24 # (required for google colab)\n", 30 | "!pip install rastermap \n", 31 | "!pip install matplotlib" 32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "### Load data and import libraries\n", 40 | "\n", 41 | "If not already downloaded, the following cell will automatically download the processed data stored [here](https://osf.io/2w8pa)." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import numpy as np\n", 51 | "import matplotlib.pyplot as plt\n", 52 | "# importing rastermap\n", 53 | "# (this will be slow the first time since it is compiling the numba functions)\n", 54 | "from rastermap import Rastermap, utils\n", 55 | "from scipy.stats import zscore\n", 56 | "\n", 57 | "# download spontaneous activity\n", 58 | "filename = utils.download_data(data_type=\"fish\")\n", 59 | "\n", 60 | "dat = np.load(filename)\n", 61 | "\n", 62 | "# spks is neurons by time\n", 63 | "# (each timepoint is 476 ms)\n", 64 | "spks = dat[\"spks\"]\n", 65 | "n_neurons, n_time = spks.shape\n", 66 | "print(f\"{n_neurons} neurons by {n_time} timepoints\")\n", 67 | "\n", 68 | "# zscore activity (each neuron activity trace is then mean 0 and standard-deviation 1)\n", 69 | "spks = zscore(spks, axis=1)\n", 70 | "\n", 71 | "# XYZ position of each neuron in the recording\n", 72 | "xyz = dat[\"xyz\"]\n", 73 | "\n", 74 | "# load the stimulus times\n", 75 | "stims = dat[\"stims\"]\n", 76 | "# stim colors\n", 77 | "fcolor = np.zeros((stims.max()+1, 4))\n", 78 | "fcolor[0:3] = np.array([[0., 0.5, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0], \n", 79 | " [1., 1., 1., 1.]])\n", 80 | 
"fcolor[8:12] = np.array([[1.0,0.0,0,1],\n", 81 | " [0.8,1.0,0,1], [0,0,1,1], [0,1,1,1]])\n", 82 | "\n", 83 | "# load the fictive swimming\n", 84 | "swimming = dat[\"swimming\"]\n" 85 | ] 86 | }, 87 | { 88 | "attachments": {}, 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "### Run Rastermap\n", 93 | "\n", 94 | "Let's sort the single neurons with Rastermap, with clustering and upsampling:" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "model = Rastermap(n_clusters=100, # number of clusters to compute\n", 104 | " n_PCs=200, # number of PCs to use\n", 105 | " locality=0.1, # locality in sorting is low here to get more global sorting (this is a value from 0-1)\n", 106 | " time_lag_window=5, # use future timepoints to compute correlation\n", 107 | " grid_upsample=10, # default value, 10 is good for large recordings\n", 108 | " ).fit(spks)\n", 109 | "y = model.embedding # neurons x 1\n", 110 | "isort = model.isort" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "Let's create superneurons from Rastermap -- we sort the data and then sum over neighboring neurons:" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "nbin = 50 # number of neurons to bin over \n", 127 | "sn = utils.bin1d(spks[isort], bin_size=nbin, axis=0) # bin over neuron axis" 128 | ] 129 | }, 130 | { 131 | "attachments": {}, 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### Visualization\n", 136 | "\n", 137 | "Use the Rastermap sorting to visualize the neural activity (see Figure 4 from the paper for the stimulus legend):" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# timepoints to visualize\n", 
147 | "xmin = 5700\n", 148 | "xmax = 7860\n", 149 | "\n", 150 | "# make figure with grid for easy plotting\n", 151 | "fig = plt.figure(figsize=(12,6), dpi=200)\n", 152 | "grid = plt.GridSpec(9, 20, figure=fig, wspace = 0.05, hspace = 0.3)\n", 153 | "\n", 154 | "# plot swimming speed\n", 155 | "ax = plt.subplot(grid[0, :-1])\n", 156 | "ax.plot(swimming[xmin:xmax, 0], color=fcolor[11])\n", 157 | "ax.plot(swimming[xmin:xmax, 1], color=fcolor[10])\n", 158 | "ax.set_xlim([0, xmax-xmin])\n", 159 | "ax.axis(\"off\")\n", 160 | "ax.set_title(\"swimming speed\")\n", 161 | "\n", 162 | "# plot superneuron activity\n", 163 | "ax = plt.subplot(grid[1:, :-1])\n", 164 | "ax.imshow(sn[:, xmin:xmax], cmap=\"gray_r\", vmin=0, vmax=0.8, aspect=\"auto\")\n", 165 | "ax.set_xlabel(\"time\")\n", 166 | "ax.set_ylabel(\"superneurons\")\n", 167 | "\n", 168 | "# color time periods by stimuli\n", 169 | "from matplotlib import patches\n", 170 | "nn = sn.shape[0]\n", 171 | "stims_t = stims[xmin:xmax]\n", 172 | "starts = np.nonzero(np.diff(stims_t))\n", 173 | "starts = np.append(np.array([0]), starts)\n", 174 | "starts = np.append(starts, np.array([len(stims_t)-1]))\n", 175 | "for n in range(len(starts)-1):\n", 176 | " start = starts[n]+1\n", 177 | " stype = stims_t[start]\n", 178 | " if stype!=3:\n", 179 | " width = starts[n+1] - start + min(0, start)\n", 180 | " start = max(0, start)\n", 181 | " ax.add_patch(\n", 182 | " patches.Rectangle(xy=(start, 0), width=width,\n", 183 | " height=nn, facecolor=fcolor[stype], \n", 184 | " edgecolor=None, alpha=0.15*(stype!=2)))\n", 185 | "\n", 186 | "\n", 187 | "ax = plt.subplot(grid[1:, -1])\n", 188 | "ax.imshow(np.arange(0, len(sn))[:,np.newaxis], cmap=\"gist_ncar\", aspect=\"auto\")\n", 189 | "ax.axis(\"off\")" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "Color the neurons by their position in the rastermap:" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | 
"metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "plt.figure(figsize=(5, 5))\n", 206 | "plt.scatter(xyz[:,1], xyz[:,0], s=1, c=y, cmap=\"gist_ncar\", alpha=0.25)\n", 207 | "plt.xlabel('X position')\n", 208 | "plt.ylabel('Y position')\n", 209 | "plt.axis(\"square\")" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "We can also divide the rastermap into sections to more easily visualize spatial relations (as in Figure 4):" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "ny, nx = 3, 6\n", 226 | "nxy = nx * ny\n", 227 | "\n", 228 | "# divide into nxy sections\n", 229 | "nb = len(isort) // nxy\n", 230 | "colors = plt.get_cmap(\"gist_ncar\")(np.linspace(0, 0.9, nxy))\n", 231 | "\n", 232 | "# make figure with grid for easy plotting\n", 233 | "fig = plt.figure(figsize=(12,6), dpi=200)\n", 234 | "\n", 235 | "grid = plt.GridSpec(ny, nx, figure=fig, wspace = 0.25, hspace = 0.1)\n", 236 | "for j in range(nx):\n", 237 | " for k in range(ny):\n", 238 | " ax = plt.subplot(grid[k,j])\n", 239 | " # plot all neurons\n", 240 | " subsample = 25\n", 241 | " ax.scatter(xyz[:,1][::subsample], xyz[:,0][::subsample], s=2, alpha=1, \n", 242 | " color=0.9*np.ones(3), rasterized=True)\n", 243 | " ip = j + k*nx\n", 244 | " ix = isort[ip*nb : (ip+1)*nb]\n", 245 | " subsample = 1\n", 246 | " ax.scatter(xyz[ix,1][::subsample], xyz[ix,0][::subsample],\n", 247 | " s=0.5, alpha=0.3, color=colors[ip])\n", 248 | " ax.axis(\"off\")\n", 249 | " ax.axis(\"square\")\n", 250 | " ax.text(0.1,0,str(ip+1), transform=ax.transAxes, ha=\"right\")" 251 | ] 252 | }, 253 | { 254 | "attachments": {}, 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "### Settings\n", 259 | "\n", 260 | "You can see all the rastermap settings with `Rastermap?`" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 
266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "Rastermap?" 270 | ] 271 | }, 272 | { 273 | "attachments": {}, 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "### Outputs\n", 278 | "\n", 279 | "All the attributes assigned to the Rastermap `model` are listed with `Rastermap.fit?`" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "Rastermap.fit?" 289 | ] 290 | } 291 | ], 292 | "metadata": { 293 | "kernelspec": { 294 | "display_name": "Python 3.9.16 ('rastermap')", 295 | "language": "python", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.9.16" 309 | }, 310 | "orig_nbformat": 4, 311 | "vscode": { 312 | "interpreter": { 313 | "hash": "998540cc2fc2836a46e99cd3ca3c37c375205941b23fd1eb4b203c48f2be758f" 314 | } 315 | } 316 | }, 317 | "nbformat": 4, 318 | "nbformat_minor": 2 319 | } 320 | -------------------------------------------------------------------------------- /paper/fig1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "from scipy.ndimage import gaussian_filter1d\n", 13 | "from rastermap.svd import SVD\n", 14 | "import sys, os\n", 15 | "from rastermap import Rastermap\n", 16 | "from scipy.stats import zscore\n", 17 | "from rastermap.utils import bin1d\n", 18 | "\n", 19 | "sys.path.insert(0, '/github/rastermap/paper/')\n", 20 | "import metrics, simulations, fig1\n", 21 | "\n", 22 | "root = 
\"/media/carsen/ssd2/rastermap_paper/\"\n", 23 | "os.makedirs(os.path.join(root, \"simulations/\"), exist_ok=True)\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### make simulations" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "n_per_module = 1000\n", 40 | "for random_state in range(0, 10):\n", 41 | " out = simulations.make_full_simulation(n_per_module=n_per_module, random_state=random_state)\n", 42 | " spks, xi_all, stim_times_all, psth, psth_spont, iperm = out\n", 43 | " np.savez(os.path.join(root, \"simulations/\", f\"sim_{random_state}.npz\"), \n", 44 | " spks=spks, xi_all=xi_all, \n", 45 | " stim_times_all=np.array(stim_times_all, dtype=object), \n", 46 | " psth=psth, psth_spont=psth_spont, iperm=iperm)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "### run embedding algorithms and benchmark performance" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "simulations.embedding_performance(root, save=True)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "### make figure" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# root path has folder \"simulations\" with saved results\n", 79 | "# will save figures to \"figures\" folder\n", 80 | "os.makedirs(os.path.join(root, \"figures/\"), exist_ok=True)\n", 81 | "fig1.fig1(root, save_figure=True) " 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "### supp t-SNE + UMAP" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# run t-SNE with different 
perplexities\n", 98 | "knn = np.array([10,50,100,200,500])\n", 99 | "mnn_all = np.zeros((10, 7, len(knn)))\n", 100 | "rho_all = np.zeros((10, 7))\n", 101 | "embs_all = np.zeros((10, 7, 6000, 1))\n", 102 | "scores_all = np.zeros((10, 2, 8, 5))\n", 103 | "for random_state in range(10):\n", 104 | " print(random_state)\n", 105 | " dat = np.load(os.path.join(root, \"simulations\", f\"sim_{random_state}.npz\"), allow_pickle=True)\n", 106 | " spks = dat[\"spks\"]\n", 107 | " # run rastermap to get PCs\n", 108 | " model = Rastermap(n_clusters=100, n_PCs=200, locality=0.8,\n", 109 | " time_lag_window=10, time_bin=10).fit(spks) \n", 110 | " perplexities = []\n", 111 | " j = 0\n", 112 | " for perplexity in [10,30,60,100,200]:\n", 113 | " M = metrics.run_TSNE(model.Usv, perplexities=[perplexity])\n", 114 | " embs_all[random_state, j] = M\n", 115 | " j += 1\n", 116 | " perplexities.append([perplexity, 0])\n", 117 | " if perplexity > 60:\n", 118 | " M = metrics.run_TSNE(model.Usv, perplexities=[30, perplexity])\n", 119 | " embs_all[random_state, j] = M\n", 120 | " j += 1\n", 121 | " perplexities.append([30, perplexity])\n", 122 | " contamination_scores, triplet_scores = metrics.benchmarks(dat[\"xi_all\"], embs_all[random_state])\n", 123 | " mnn, rho = metrics.embedding_quality_gt(dat[\"xi_all\"], embs_all[random_state], knn=knn.copy())\n", 124 | " mnn_all[random_state], rho_all[random_state] = mnn, rho\n", 125 | " scores_all[random_state] = np.stack((contamination_scores, triplet_scores), \n", 126 | " axis=0)\n", 127 | " \n", 128 | "np.savez(os.path.join(root, \"simulations\", \"sim_performance_tsne.npz\"), \n", 129 | " embs_all=embs_all, scores_all=scores_all, \n", 130 | " mnn_all=mnn_all, rho_all=rho_all, knn=knn,\n", 131 | " perplexities=perplexities)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "# run UMAP with different n_neighbors\n", 141 | "knn = 
np.array([10,50,100,200,500])\n", 142 | "n_neighbors = np.array([5, 15, 30, 60, 100, 200])\n", 143 | "mnn_all = np.zeros((10, 6, len(knn)))\n", 144 | "rho_all = np.zeros((10, 6))\n", 145 | "embs_all = np.zeros((10, 6, 6000, 1))\n", 146 | "scores_all = np.zeros((10, 2, 7, 5))\n", 147 | "for random_state in range(10):\n", 148 | " print(random_state)\n", 149 | " dat = np.load(os.path.join(root, \"simulations\", f\"sim_{random_state}.npz\"), allow_pickle=True)\n", 150 | " spks = dat[\"spks\"]\n", 151 | " # run rastermap to get PCs\n", 152 | " model = Rastermap(n_clusters=100, n_PCs=200, locality=0.8,\n", 153 | " time_lag_window=10, time_bin=10).fit(spks) \n", 154 | " j = 0\n", 155 | " for nneigh in n_neighbors:\n", 156 | " M = metrics.run_UMAP(model.Usv, n_neighbors=nneigh)\n", 157 | " embs_all[random_state, j] = M\n", 158 | " j += 1\n", 159 | " contamination_scores, triplet_scores = metrics.benchmarks(dat[\"xi_all\"], embs_all[random_state])\n", 160 | " mnn, rho = metrics.embedding_quality_gt(dat[\"xi_all\"], embs_all[random_state], knn=knn.copy())\n", 161 | " mnn_all[random_state], rho_all[random_state] = mnn, rho\n", 162 | " scores_all[random_state] = np.stack((contamination_scores, triplet_scores), \n", 163 | " axis=0)\n", 164 | " \n", 165 | "np.savez(os.path.join(root, \"simulations\", \"sim_performance_umap.npz\"), \n", 166 | " embs_all=embs_all, scores_all=scores_all, \n", 167 | " mnn_all=mnn_all, rho_all=rho_all, knn=knn,\n", 168 | " n_neighbors=n_neighbors)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "### supp neighbor scores" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "from tqdm import trange\n", 185 | "# compute neighbor scores for original embeddings\n", 186 | "knn = np.array([10,50,100,200,500])\n", 187 | "d2 = np.load(os.path.join(root, \"simulations\", \"sim_performance.npz\"), 
allow_pickle=True) \n", 188 | "mnn_all = np.zeros((10, 7, len(knn)))\n", 189 | "rho_all = np.zeros((10, 7))\n", 190 | "for random_state in trange(10):\n", 191 | " dat = np.load(os.path.join(root, \"simulations\", f\"sim_{random_state}.npz\"), allow_pickle=True)\n", 192 | " embs = d2[\"embs_all\"][random_state].squeeze()\n", 193 | " mnn, rho = metrics.embedding_quality_gt(dat[\"xi_all\"], embs, knn=knn.copy())\n", 194 | " mnn_all[random_state], rho_all[random_state] = mnn, rho\n", 195 | "np.savez(os.path.join(root, \"simulations\", \"sim_performance_neigh.npz\"), \n", 196 | " mnn_all=mnn_all, rho_all=rho_all, knn=knn)\n" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "fig1.suppfig_scores(root, save_figure=True)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "### supp power-law only simulation" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# create sims and benchmark\n", 222 | "simulations.spont_simulations(root)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "### supp reproducible" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "simulations.repro_algs(root)\n", 239 | "fig1.suppfig_repro(root, save_fig=True)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "### supp parameter changes" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "# run param sweeps\n", 256 | "simulations.params_rastermap(root)\n", 257 | "# make figure\n", 258 | "fig1.suppfig_params(root)" 259 | ] 260 | }, 261 | { 262 | 
"cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "### supp no power-law noise added" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "\n", 275 | "n_per_module = 1000\n", 276 | "for random_state in range(1, 10):\n", 277 | " out = simulations.make_full_simulation(n_per_module=n_per_module, \n", 278 | " random_state=random_state, add_spont=False)\n", 279 | " spks, xi_all, stim_times_all, psth, psth_spont, iperm = out\n", 280 | " np.savez(os.path.join(root, \"simulations/\", f\"sim_no_add_spont_{random_state}.npz\"), \n", 281 | " spks=spks, xi_all=xi_all, \n", 282 | " stim_times_all=np.array(stim_times_all, dtype=object), \n", 283 | " psth=psth, psth_spont=psth_spont, iperm=iperm)\n", 284 | "\n", 285 | "# 6000 neurons in simulation with 5 modules\n", 286 | "from tqdm import trange\n", 287 | "embs_all = np.zeros((10, 7, 6000, 1))\n", 288 | "scores_all = np.zeros((10, 2, 8, 5))\n", 289 | "algos = [\"rastermap\", \"tSNE\", \"UMAP\", \"isomap\", \"laplacian\\neigenmaps\", \"hierarchical\\nclustering\", \"PCA\"]\n", 290 | "\n", 291 | "for random_state in trange(10):\n", 292 | " path = os.path.join(root, \"simulations\", f\"sim_no_add_spont_{random_state}.npz\")\n", 293 | " dat = np.load(path, allow_pickle=True)\n", 294 | " spks = dat[\"spks\"]\n", 295 | " embs, model = simulations.run_algos(spks, time_lag_window=10, locality=0.8)\n", 296 | "\n", 297 | " # benchmarks\n", 298 | " contamination_scores, triplet_scores = metrics.benchmarks(dat[\"xi_all\"], \n", 299 | " embs.copy())\n", 300 | " embs_all[random_state] = embs\n", 301 | " scores_all[random_state] = np.stack((contamination_scores, triplet_scores), \n", 302 | " axis=0)\n", 303 | " if random_state==0:\n", 304 | " xi_all = dat[\"xi_all\"]\n", 305 | "\n", 306 | "np.savez(os.path.join(root, \"simulations\", \"sim_no_add_spont_performance.npz\"), \n", 307 | " scores_all=scores_all, \n", 308 | " 
embs_all=embs_all,\n", 309 | " xi_all=xi_all)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "# make figure\n", 319 | "fig1.suppfig_spont(root)" 320 | ] 321 | } 322 | ], 323 | "metadata": { 324 | "kernelspec": { 325 | "display_name": "Python 3.9.16 ('rastermap')", 326 | "language": "python", 327 | "name": "python3" 328 | }, 329 | "language_info": { 330 | "codemirror_mode": { 331 | "name": "ipython", 332 | "version": 3 333 | }, 334 | "file_extension": ".py", 335 | "mimetype": "text/x-python", 336 | "name": "python", 337 | "nbconvert_exporter": "python", 338 | "pygments_lexer": "ipython3", 339 | "version": "3.9.16" 340 | }, 341 | "orig_nbformat": 4, 342 | "vscode": { 343 | "interpreter": { 344 | "hash": "998540cc2fc2836a46e99cd3ca3c37c375205941b23fd1eb4b203c48f2be758f" 345 | } 346 | } 347 | }, 348 | "nbformat": 4, 349 | "nbformat_minor": 2 350 | } 351 | -------------------------------------------------------------------------------- /paper/fig2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys, os\n", 10 | "import numpy as np\n", 11 | "from scipy.stats import zscore\n", 12 | "from rastermap import Rastermap, utils\n", 13 | "\n", 14 | "# path to paper code\n", 15 | "sys.path.insert(0, '/github/rastermap/paper')\n", 16 | "from loaders import tuning_curves_VR\n", 17 | "import fig2\n", 18 | "\n", 19 | "# path to directory with data etc\n", 20 | "### *** CHANGE THIS TO WHEREEVER YOU ARE DOWNLOADING THE DATA ***\n", 21 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 22 | "# (in this folder we have a \"data\" folder and a \"results\" folder)\n", 23 | "os.makedirs(os.path.join(root, \"data\"), exist_ok=True)\n", 24 | "os.makedirs(os.path.join(root, \"results\"), exist_ok=True)" 25 
| ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "### load virtual reality task data\n", 32 | "\n", 33 | "(this data will be available upon publication of the paper)\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "dat = np.load(os.path.join(root, \"data/\", \"corridor_neur.npz\"))\n", 43 | "corridor = np.load(os.path.join(root, \"data/\", \"corridor_behavior.npz\"))\n", 44 | "\n", 45 | "xpos, ypos, spks = dat[\"xpos\"], dat[\"ypos\"], dat[\"spks\"]\n", 46 | "spks = zscore(spks, axis=1)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "### run rastermap and compute tuning curves" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "model = Rastermap(n_clusters=100, n_PCs=200, \n", 63 | " time_lag_window=10, locality=0.75).fit(spks)\n", 64 | "isort = model.isort \n", 65 | "cc_nodes = model.cc\n", 66 | "bin_size = 100\n", 67 | "sn = zscore(utils.bin1d(spks[isort], bin_size, axis=0), axis=1)\n", 68 | "corridor_tuning = tuning_curves_VR(sn, corridor[\"VRpos\"], corridor[\"corridor_starts\"])\n", 69 | "\n", 70 | "# sort in time\n", 71 | "model2 = Rastermap(n_clusters=100, n_splits=0, locality=0.,\n", 72 | " n_PCs=200).fit(sn.T)\n", 73 | "isort2 = model2.isort\n", 74 | "\n", 75 | "np.savez(os.path.join(root, \"results\", \"corridor_proc.npz\"),\n", 76 | " sn=sn, xpos=xpos, ypos=ypos, isort=isort, isort2=isort2,\n", 77 | " cc_nodes=cc_nodes, corridor_tuning=corridor_tuning)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "### make figure" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# root path has folder \"results\" with saved results\n", 94 | "# will save 
figures to \"figures\" folder\n", 95 | "os.makedirs(os.path.join(root, \"figures/\"), exist_ok=True)\n", 96 | "fig2.fig2(root)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "### supplementary analysis" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "import metrics\n", 113 | "\n", 114 | "ys = [metrics.run_TSNE(model.Usv), \n", 115 | " metrics.run_UMAP(model.Usv)]\n", 116 | "\n", 117 | "snys = []\n", 118 | "ctunings = []\n", 119 | "for k in range(2):\n", 120 | " isorty = ys[k][:,0].argsort()\n", 121 | " sny = zscore(utils.bin1d(spks[isorty], 100, axis=0))\n", 122 | " ctuning = tuning_curves_VR(sny, corridor[\"VRpos\"], corridor[\"corridor_starts\"])\n", 123 | " snys.append(sny)\n", 124 | " ctunings.append(ctuning)\n", 125 | "\n", 126 | "np.savez(os.path.join(root, \"results\", \"corridor_supp.npz\"),\n", 127 | " snys=snys, ctunings=ctunings, \n", 128 | " corridor_starts=corridor[\"corridor_starts\"], \n", 129 | " corridor_widths=corridor[\"corridor_widths\"], \n", 130 | " reward_inds=corridor[\"reward_inds\"])" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "d = np.load(os.path.join(root, \"results\", \"corridor_supp.npz\"))\n", 140 | "fig = fig2._suppfig_vr_algs(**d)" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 3.9.16 ('rastermap')", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.9.16" 161 | }, 162 | "orig_nbformat": 4, 163 | "vscode": { 164 | "interpreter": {
165 | "hash": "998540cc2fc2836a46e99cd3ca3c37c375205941b23fd1eb4b203c48f2be758f" 166 | } 167 | } 168 | }, 169 | "nbformat": 4, 170 | "nbformat_minor": 2 171 | } 172 | -------------------------------------------------------------------------------- /paper/fig2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | import matplotlib.pyplot as plt 5 | from matplotlib import patches 6 | import os 7 | import numpy as np 8 | from fig_utils import * 9 | 10 | ccolor = [[0,1,0], [0,0,0.8]] 11 | 12 | def panel_neuron_pos(fig, grid, il, yratio, xpos0, ypos0, isort, brain_img): 13 | xpos, ypos = xpos0.copy(), -1*ypos0.copy() 14 | ylim = np.array([ypos.min(), ypos.max()]) 15 | xlim = np.array([xpos.min(), xpos.max()]) 16 | ylr = np.diff(ylim)[0] / np.diff(xlim)[0] 17 | 18 | ax = fig.add_subplot(grid[0,0]) 19 | poss = ax.get_position().bounds 20 | ax.set_position([poss[0]+0.01, poss[1]-.04, 1*poss[2], 21 | 1*poss[2] / ylr * yratio]) 22 | poss = ax.get_position().bounds 23 | 24 | memb = np.zeros_like(isort) 25 | memb[isort] = np.arange(0, len(isort)) 26 | subsample = 5 27 | ax.scatter(ypos[::subsample], xpos[::subsample], cmap=cmap_emb, 28 | s=0.5, alpha=0.5, c=memb[::subsample], rasterized=True) 29 | ax.axis("off") 30 | 31 | add_apml(ax, xpos, ypos) 32 | 33 | axin = fig.add_axes([poss[0]-0.02, poss[1] +poss[3]*.8, 34 | poss[2]*0.3, poss[3]*0.3]) 35 | axin.imshow(brain_img) 36 | axin.axis("off") 37 | transl = mtransforms.ScaledTranslation(-8 / 72, -0/ 72, fig.dpi_scale_trans) 38 | il = plot_label(ltr, il, axin, transl, fs_title) 39 | 40 | return il 41 | 42 | def panels_tuning(axs, il, padding, corridor_tuning, label_white=True): 43 | nov = 30 44 | n_corr, nn, npts = corridor_tuning.shape 45 | for icorr in range(n_corr): 46 | ctmax = corridor_tuning[icorr].max() 47 | ctmin = corridor_tuning[icorr].min() 48 | npl = 100 49 | ipl = 
np.linspace(1, nn-npl//4, npl).astype("int") 50 | for i in ipl: 51 | ct = corridor_tuning[icorr, i].copy() 52 | ct -= ctmin 53 | ct /= ctmax 54 | axs[icorr].plot(np.arange(0, npts), i - ct*nov + nov/2, #(n_sn-i-24)+ct*nov, 55 | color=ccolor[icorr], lw=0.5) 56 | axs[icorr].plot((npts*2/3) * np.ones(2), [0, nn*(1+padding)], 57 | color='k', lw=1, zorder=5) 58 | if label_white: 59 | axs[icorr].text(2/3 + 0.02, 0.02, 'white space start', 60 | transform=axs[icorr].transAxes, va='bottom', rotation=90) 61 | if icorr==0: 62 | axs[icorr].set_title("tuning curves") 63 | #text(0, 1, 'tuning curves', ha='left', 64 | # transform=axs[icorr].transAxes, fontsize="large") 65 | axs[icorr].text(1.1, -0.05, "position (cm)", ha="center", va="top", 66 | transform=axs[icorr].transAxes) 67 | transl = mtransforms.ScaledTranslation(-15 / 72, 5/ 72, axs[icorr].figure.dpi_scale_trans) 68 | il = plot_label(ltr, il, axs[icorr], transl, fs_title) 69 | 70 | axs[icorr].set_xlim([0, npts]) 71 | axs[icorr].set_ylim([0, nn*(1+padding)]) 72 | axs[icorr].invert_yaxis() 73 | axs[icorr].spines["left"].set_visible(False) 74 | axs[icorr].set_yticks([]) 75 | axs[icorr].set_xticks([0, 2/3*100]) 76 | axs[icorr].set_xticklabels(["0", "40"]) 77 | return il 78 | 79 | def panel_raster(fig, ax, il, padding, sn, xmin, xmax, 80 | corridor_starts, corridor_widths, reward_inds, 81 | cmap_neurons=True, 82 | title_str="neural activity in virtual reality"): 83 | poss = ax.get_position().bounds 84 | cax = fig.add_axes([poss[0]-0.035, poss[1]+poss[3]-0.12*poss[3], 85 | poss[3]*0.005, 0.1*poss[3]]) 86 | plot_raster(ax, sn, xmin=xmin, xmax=xmax, 87 | vmax=2, fs=3.38, n_neurons=5000, nper=100, label=True, 88 | padding=padding, cax=cax, cax_label="left", 89 | cax_orientation="vertical", label_pos="right") 90 | #plt.colorbar(im, cax, orientation="horizontal") 91 | #cax.set_xlabel("z-scored\n ") 92 | ax.set_title(title_str) 93 | transl = mtransforms.ScaledTranslation(-15 / 72, 5/ 72, fig.dpi_scale_trans) 94 | il = plot_label(ltr, 
il, ax, transl, fs_title) 95 | 96 | nn = sn.shape[0] 97 | if cmap_neurons: 98 | cax = fig.add_axes([poss[0]-poss[2]*0.02, poss[1], poss[2]*0.01, poss[3]]) 99 | cols = cmap_emb(np.linspace(0, 1, nn)) 100 | cax.imshow(cols[:,np.newaxis], aspect="auto") 101 | cax.set_ylim([0, (1+padding)*nn]) 102 | cax.invert_yaxis() 103 | cax.axis("off") 104 | 105 | # add corridor colors 106 | for n in range(len(corridor_starts)): 107 | if (corridor_starts[n,0]+corridor_widths[n] > xmin and 108 | corridor_starts[n,0] < xmax): 109 | icorr = int(corridor_starts[n,1]) 110 | start = corridor_starts[n,0] 111 | width = corridor_widths[n] 112 | width += min(0, start-xmin) 113 | start = max(0, start - xmin) 114 | width = min(width, xmax - xmin - start) 115 | ax.add_patch( 116 | patches.Rectangle(xy=(start, 0), width=width, 117 | height=nn, facecolor=ccolor[icorr], 118 | edgecolor=None, alpha=0.1)) 119 | # add reward events 120 | for n in range(len(reward_inds)): 121 | if reward_inds[n] > xmin and reward_inds[n] < xmax: 122 | start = int(reward_inds[n] - xmin) 123 | width = 0 124 | ax.add_patch(patches.Rectangle(xy=(start, 0), width=width, 125 | height=nn, facecolor=None, edgecolor='g', alpha=1)) 126 | 127 | return il 128 | 129 | def panel_events(ax, xmin, xmax, sound_inds, lick_inds, reward_inds): 130 | h1=ax.scatter(sound_inds-0.5,0*np.ones([len(sound_inds),]), 131 | color=[1.,0.6,0], marker='s', s=30) 132 | h2=ax.scatter(lick_inds-0.5,-1*np.ones([len(lick_inds),]), 133 | color=[1.0,0.3,0.3], marker='.', s=30) 134 | h0=ax.scatter(reward_inds-0.5,1*np.ones([len(reward_inds),]), 135 | color='g', marker='^', s=30) 136 | ax.axis('off') 137 | ax.set_xlim([xmin, xmax]) 138 | ax.set_ylim([-1.35, 1.35]) 139 | ax.legend([h0,h1,h2], ["reward", "tone", "licks"], 140 | handletextpad=0.01, labelspacing=0.15, loc=(-0.08,-0.31), 141 | labelcolor="linecolor", frameon=False) 142 | 143 | def panel_imgs(grid, il, corridor_imgs): 144 | Ly, Lx = corridor_imgs.shape[1:] 145 | Lyc = Lx*4 146 | xp = int(Lx*0.4) 
147 | imgs = 255*np.ones((Lx*2+xp*2, Lyc), "uint8") 148 | for k in range(2): 149 | imgs[(Lx+xp)*k+xp : (Lx+xp)*k+xp + Lx] = corridor_imgs[k, :Lyc].T 150 | imgs = np.tile(imgs[:,:,np.newaxis], (1,1,3)) 151 | 152 | ax = plt.subplot(grid[1,0]) 153 | ax.imshow(imgs) 154 | for k in range(2): 155 | ax.text(0, (Lx+xp)*k + xp-10, "leaves" if k==0 else "circles", 156 | color=ccolor[k]) 157 | ax.axis("off") 158 | ax.set_title("VR corridors") 159 | transl = mtransforms.ScaledTranslation(-15 / 72, 5/ 72, grid.figure.dpi_scale_trans) 160 | il = plot_label(ltr, il, ax, transl, fs_title) 161 | 162 | return il 163 | 164 | def panel_cc(grid, il, yratio, cc_nodes): 165 | ax = plt.subplot(grid[-1, 0]) 166 | poss = ax.get_position().bounds 167 | ax.set_position([poss[0], poss[1]-.0, 0.95*poss[2], 168 | 0.95*poss[2] * yratio]) 169 | poss = ax.get_position().bounds 170 | vmax = 1 171 | im = ax.imshow(cc_nodes, vmin=-vmax, vmax=vmax, cmap="RdBu_r") 172 | ax.axis("off") 173 | ax.set_title("asymmetric similarity") 174 | transl = mtransforms.ScaledTranslation(-15 / 72, 5/ 72, grid.figure.dpi_scale_trans) 175 | il = plot_label(ltr, il, ax, transl, fs_title) 176 | 177 | cax = grid.figure.add_axes([poss[0]+poss[2]*1.02, poss[1]+poss[3]*0.75, 178 | poss[2]*0.03, poss[3]*0.25]) 179 | plt.colorbar(im, cax) 180 | return il 181 | 182 | def _fig2(brain_img, sn, xpos, ypos, isort, isort2, cc_nodes, 183 | corridor_starts, corridor_widths, 184 | corridor_tuning, corridor_imgs, VRpos, 185 | reward_inds, sound_inds, lick_inds, run): 186 | fig = plt.figure(figsize=(14,7)) 187 | yratio = 14 / 7 188 | grid = plt.GridSpec(3,5, figure=fig, left=0.02, right=0.98, top=0.98, bottom=0.02, 189 | wspace = 0.3, hspace = 0.15) 190 | 191 | il = 0 192 | il = panel_neuron_pos(fig, grid, il, yratio, xpos, ypos, isort, brain_img) 193 | 194 | il = panel_imgs(grid, il, corridor_imgs) 195 | 196 | il = panel_cc(grid, il, yratio, cc_nodes) 197 | 198 | ax = plt.subplot(grid[:,1:]) 199 | pos = ax.get_position().bounds 200 | 
ax.remove() 201 | 202 | xmin = 0 203 | xmax=xmin+520 204 | 205 | nn = sn.shape[0] 206 | xr = xmax - xmin 207 | y0 = pos[1] 208 | x0 = pos[0] 209 | padding=0.025 210 | dye = 0.06 211 | dyr = 0.09 212 | dx = 0.8 213 | xpad = 0.03*pos[2] 214 | xpadt = 0.01*pos[2] 215 | dxt = ((1-dx)*pos[2]-xpad-xpadt)/2 216 | ypad = 0.02*pos[3] 217 | ys = y0+(dye+dyr)*pos[3]+ypad+0.01*pos[3] 218 | poss = [x0, ys, pos[2]*dx, pos[3]-ys] 219 | 220 | ax = fig.add_axes(poss) 221 | il = panel_raster(fig, ax, il, padding, sn, xmin, xmax, 222 | corridor_starts, corridor_widths, reward_inds) 223 | 224 | ax = fig.add_axes([poss[0], y0+dyr*pos[3]+ypad, poss[2], dye*pos[3]]) 225 | panel_events(ax, xmin, xmax, sound_inds, lick_inds, reward_inds) 226 | 227 | ax = fig.add_axes([poss[0], y0, poss[2], dyr*pos[3]]) 228 | ax.fill_between(np.arange(0, xr), run[xmin:xmax], color=kp_colors[0]) 229 | ax.set_xlim([0, xr]) 230 | ax.set_ylim([0, np.percentile(run[xmin:xmax], 99)]) 231 | ax.axis("off") 232 | ax.text(0.11,0.9,"running speed", transform=ax.transAxes, color=kp_colors[0]) 233 | 234 | axs = [fig.add_axes([poss[0]+poss[2]+xpad, poss[1], dxt, poss[3]]), 235 | fig.add_axes([poss[0]+poss[2]+xpad+xpadt+dxt, poss[1], dxt, poss[3]])] 236 | 237 | il = panels_tuning(axs, il, padding, corridor_tuning) 238 | 239 | return fig 240 | 241 | 242 | def fig2(root, save_figure=True): 243 | d = np.load(os.path.join(root, "results", "corridor_proc.npz"), allow_pickle=True) 244 | d2 = np.load(os.path.join(root, "data", "corridor_behavior.npz"), allow_pickle=True) 245 | try: 246 | brain_img = plt.imread(os.path.join(root, "figures", "brain_window_visual.png")) 247 | except: 248 | brain_img = np.zeros((50,50)) 249 | 250 | fig = _fig2(brain_img, **d, **d2) 251 | if save_figure: 252 | fig.savefig(os.path.join(root, "figures", "fig2.pdf"), dpi=200) 253 | 254 | 255 | def _suppfig_vr_algs(snys, ctunings, 256 | corridor_starts, corridor_widths, reward_inds): 257 | 258 | fig = plt.figure(figsize=(12,12)) 259 | grid = 
plt.GridSpec(2,1, figure=fig, left=0.06, right=0.96, top=0.96, bottom=0.04, 260 | wspace = 0.3, hspace = 0.15) 261 | 262 | xmin = 1000 263 | xmax=xmin+500 264 | il = 0 265 | padding = 0.025 266 | alg = ["t-SNE", "UMAP"] 267 | for k in range(2): 268 | sny = snys[k] 269 | ctuning = ctunings[k] 270 | 271 | ax = plt.subplot(grid[k]) 272 | pos = ax.get_position().bounds 273 | ax.remove() 274 | 275 | xmin = 1000 276 | xmax=xmin+500 277 | 278 | nn = sny.shape[0] 279 | xr = xmax - xmin 280 | y0 = pos[1] 281 | x0 = pos[0] 282 | padding=0.025 283 | dx = 0.8 284 | xpad = 0.03*pos[2] 285 | xpadt = 0.01*pos[2] 286 | dxt = ((1-dx)*pos[2]-xpad-xpadt)/2 287 | poss = [x0, y0, pos[2]*dx, pos[3]] 288 | 289 | ax = fig.add_axes(poss) 290 | il = panel_raster(fig, ax, il, padding, sny, xmin, xmax, 291 | corridor_starts, corridor_widths, reward_inds, 292 | cmap_neurons=False, title_str=f"{alg[k]} sorting") 293 | axs = [fig.add_axes([poss[0]+poss[2]+xpad, poss[1], dxt, poss[3]]), 294 | fig.add_axes([poss[0]+poss[2]+xpad+xpadt+dxt, poss[1], dxt, poss[3]])] 295 | il = panels_tuning(axs, il, padding, ctuning, label_white=False) 296 | return fig 297 | 298 | def suppfig_vr_algs(root, save_figure=True): 299 | d = np.load(os.path.join(root, "results", "corridor_supp.npz"), allow_pickle=True) 300 | fig = _suppfig_vr_algs(**d); 301 | if save_figure: 302 | fig.savefig(os.path.join(root, "figures", "suppfig_vr_algs.pdf")) 303 | -------------------------------------------------------------------------------- /paper/fig4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys, os\n", 10 | "import numpy as np\n", 11 | "from scipy.stats import zscore\n", 12 | "from rastermap import Rastermap, utils\n", 13 | "from neuropop import linear_prediction\n", 14 | "\n", 15 | "# path to paper code\n", 16 | "sys.path.insert(0, 
'/github/rastermap/paper')\n", 17 | "from loaders import (load_widefield_data, load_hippocampus_data, \n", 18 | " load_fish_data, tuning_curves_hipp)\n", 19 | "import fig4 \n", 20 | "\n", 21 | "# path to directory with data etc\n", 22 | "### *** CHANGE THIS TO WHEREEVER YOU ARE DOWNLOADING THE DATA ***\n", 23 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 24 | "# (in this folder we have a \"data\" folder and a \"results\" folder)\n", 25 | "os.makedirs(os.path.join(root, \"data\"), exist_ok=True)\n", 26 | "os.makedirs(os.path.join(root, \"results\"), exist_ok=True)\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "### hippocampus data\n", 34 | "\n", 35 | "We used the spiking data from Grosmark & Buszaki 2016, available [here](https://crcns.org/data-sets/hc/hc-11/about-hc-11). Specifically download the `Achilles_10252013_sessInfo.mat`." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "### path to mat file (SPECIFIC FOR YOUR COMPUTER)\n", 45 | "filename = os.path.join(root, \"data\", \"Achilles_10252013_sessInfo.mat\")\n", 46 | "\n", 47 | "# load data\n", 48 | "bin_sec = 0.2\n", 49 | "dat = load_hippocampus_data(filename, bin_sec=bin_sec)\n", 50 | "spks = dat[\"spks\"]\n", 51 | "spks = zscore(spks, axis=1)\n", 52 | "n_neurons, n_time = spks.shape \n", 53 | "pyr_cells = dat[\"pyr_cells\"]\n", 54 | "loc_signed = dat[\"loc_signed\"]\n", 55 | "loc2d = dat[\"loc2d\"]\n", 56 | "speed = (np.diff(loc2d, axis=0)**2).sum(axis=1)**0.5\n", 57 | "speed = np.concatenate((np.zeros((1,)), speed), axis=0)\n", 58 | "\n", 59 | "# compute tuning curves along linear corridor\n", 60 | "n_pos = 15\n", 61 | "bins = np.arange(-1, 1+1./n_pos, 1./n_pos)\n", 62 | "tcurves = tuning_curves_hipp(spks, loc_signed, bins)\n", 63 | "\n", 64 | "model = Rastermap(n_clusters=None, \n", 65 | " n_PCs=64, \n", 66 | " locality=0.1, \n", 67 | " time_lag_window=15, 
\n", 68 | " symmetric=False,\n", 69 | " grid_upsample=0,\n", 70 | " ).fit(spks)\n", 71 | "isort = model.isort \n", 72 | "cc_nodes = model.cc\n", 73 | "\n", 74 | "np.savez(os.path.join(root, \"results\", \"hippocampus_proc.npz\"),\n", 75 | " spks=spks, pyr_cells=pyr_cells, speed=speed, loc2d=loc2d, tcurves=tcurves,\n", 76 | " isort=isort, cc_nodes=cc_nodes)\n", 77 | " " 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "### widefield data\n", 85 | "\n", 86 | "We used the widefield data from the Musall*, Kaufman* et al paper [here](https://labshare.cshl.edu/shares/library/repository/38599/Widefield/mSM43/21-Nov-2017/). Specifically download `interpVc.mat`, `Vc.mat` and `regData.mat`. Then the following code will sort the data and compute the neural prediction from the task variables and behaviors." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "### path to mat files (SPECIFIC FOR YOUR COMPUTER)\n", 96 | "expname = \"mSM43_21-Nov-2017\"\n", 97 | "data_path = os.path.join(root, \"data/\", expname)\n", 98 | "\n", 99 | "### load data\n", 100 | "out = load_widefield_data(data_path)\n", 101 | "(U0, sv, Vsv, ypos, xpos, regressors, \n", 102 | " behav_idx, stim_times, reward_times, stim_labels) = out\n", 103 | "\n", 104 | "\n", 105 | "### run rastermap\n", 106 | "model = Rastermap(n_clusters=100, locality=0.5, time_lag_window=10,\n", 107 | " n_PCs=U0.shape[1]).fit(Usv = U0 * sv, \n", 108 | " Vsv = Vsv)\n", 109 | "\n", 110 | "isort = model.isort \n", 111 | "cc_nodes = model.cc\n", 112 | "Vsv_sub = model.Vsv\n", 113 | "\n", 114 | "bin_size = 200\n", 115 | "U_sn = utils.bin1d(U0[isort], bin_size=bin_size, axis=0) # bin over voxel axis\n", 116 | "sn = U_sn @ Vsv_sub.T\n", 117 | "sn = zscore(sn, axis=1)\n", 118 | "\n", 119 | "### predict activity from behavior\n", 120 | "ve, _, sn_pred, itest = linear_prediction.prediction_wrapper(regressors, 
sn.T, lam=1e4, delay=0)\n", 121 | "sn_pred = sn_pred.T\n", 122 | "itest = itest.flatten()\n", 123 | "print(ve)\n", 124 | "ve, _, sn_pred_beh, itest = linear_prediction.prediction_wrapper(regressors[:,behav_idx], sn.T, lam=1e4, delay=0)\n", 125 | "itest = itest.flatten()\n", 126 | "sn_pred_beh = sn_pred_beh.T\n", 127 | "print(ve)\n", 128 | "\n", 129 | "np.savez(os.path.join(root, \"results\", \"widefield_proc.npz\"),\n", 130 | " stim_times_0=stim_times[0], \n", 131 | " stim_times_1=stim_times[1], \n", 132 | " stim_times_2=stim_times[2], \n", 133 | " stim_times_3=stim_times[3], \n", 134 | " stim_labels=stim_labels, reward_times=reward_times, \n", 135 | " sn=sn, sn_pred=sn_pred, sn_pred_beh=sn_pred_beh, \n", 136 | " bin_size=bin_size, itest=itest, ypos=ypos, xpos=xpos, isort=isort)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "### zebrafish data\n", 144 | "\n", 145 | "We used a neural recording from Chen*, Mu*, Hu*, Kuan* et al 2018, specifically `subject_8` found [here](https://figshare.com/articles/dataset/Whole-brain_light-sheet_imaging_data/7272617?file=13474868). Download the zip file and unzip it (we found this works best with the tool 7zip)." 
146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "### folder with \"subject_8\" folder\n", 155 | "data_path = os.path.join(root, \"data\")\n", 156 | "\n", 157 | "spks, F, xyz, stims, swimming, eyepos = load_fish_data(data_path, subject=8)\n", 158 | "\n", 159 | "model = Rastermap(n_clusters=100, locality=0.1, time_lag_window=5, n_PCs=200).fit(spks)\n", 160 | "isort = model.isort \n", 161 | "cc_nodes = model.cc\n", 162 | "sn = zscore(utils.bin1d(zscore(spks[isort], axis=1), bin_size=50, axis=0), axis=1)\n", 163 | "\n", 164 | "np.savez(os.path.join(root, \"results\", \"fish_proc.npz\"),\n", 165 | " swimming=swimming, eyepos=eyepos, stims=stims, \n", 166 | " sn=sn, xyz=xyz, isort=isort, cc_nodes=cc_nodes)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "### make figure" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "# root path has folder \"results\" with saved results\n", 183 | "# will save figures to \"figures\" folder\n", 184 | "# will ignore panels that aren't processed\n", 185 | "import imp\n", 186 | "imp.reload(fig4)\n", 187 | "os.makedirs(os.path.join(root, \"figures/\"), exist_ok=True)\n", 188 | "fig4.fig4(root)" 189 | ] 190 | } 191 | ], 192 | "metadata": { 193 | "kernelspec": { 194 | "display_name": "Python 3.9.16 ('rastermap')", 195 | "language": "python", 196 | "name": "python3" 197 | }, 198 | "language_info": { 199 | "codemirror_mode": { 200 | "name": "ipython", 201 | "version": 3 202 | }, 203 | "file_extension": ".py", 204 | "mimetype": "text/x-python", 205 | "name": "python", 206 | "nbconvert_exporter": "python", 207 | "pygments_lexer": "ipython3", 208 | "version": "3.9.16" 209 | }, 210 | "orig_nbformat": 4, 211 | "vscode": { 212 | "interpreter": { 213 | "hash": 
"998540cc2fc2836a46e99cd3ca3c37c375205941b23fd1eb4b203c48f2be758f" 214 | } 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /paper/fig5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DMFC recordings in macaque (Sohn, Narain et al 2019)\n", 8 | "\n", 9 | "original data-loading notebook from: https://neurallatents.github.io/datasets#dmfcrsg" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "## Download dataset and required packages if necessary\n", 19 | "!pip install git+https://github.com/neurallatents/nlb_tools.git\n", 20 | "!pip install dandi\n", 21 | "!dandi download https://gui.dandiarchive.org/#/dandiset/000130" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "load dataset" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import sys\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from nlb_tools.nwb_interface import NWBDataset\n", 42 | "from scipy.stats import zscore\n", 43 | "\n", 44 | "sys.path.insert(0, \"/github/rastermap/paper\")\n", 45 | "import fig5\n", 46 | "\n", 47 | "# path to directory with data etc\n", 48 | "### *** CHANGE THIS TO WHEREEVER YOU ARE SAVING YOUR DATA OUTPUTS ***\n", 49 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 50 | "\n", 51 | "## Load dataset\n", 52 | "dataset = NWBDataset(\"000130/sub-Haydn/\", \"*train\", split_heldout=False)\n", 53 | "\n", 54 | "# bin at 20ms\n", 55 | "dataset.resample(20)\n", 56 | "\n", 57 | "# convert neural times from nanoseconds to seconds \n", 58 | "neural_time = 
(dataset.data.index.to_numpy() / 1e3).astype(\"float\") / 1e6\n", 59 | "\n", 60 | "# convert task times from nanoseconds to seconds for valid trials\n", 61 | "# (valid trials = set_time at least 3 second after start of exp and before end of exp)\n", 62 | "igood = ~dataset.trial_info.ready_time.isna()\n", 63 | "igood *= ~dataset.trial_info.set_time.isna()\n", 64 | "igood *= ~dataset.trial_info.go_time.isna()\n", 65 | "\n", 66 | "ready_time = (dataset.trial_info.ready_time.to_numpy() / 1e3).astype(\"float\") / 1e6\n", 67 | "set_time = (dataset.trial_info.set_time.to_numpy() / 1e3).astype(\"float\") / 1e6\n", 68 | "go_time = (dataset.trial_info.go_time.to_numpy() / 1e3).astype(\"float\") / 1e6\n", 69 | "\n", 70 | "nt_sec = 3\n", 71 | "igood *= (set_time - nt_sec) > 0\n", 72 | "igood *= (set_time + nt_sec - neural_time[-1]) < 0\n", 73 | "\n", 74 | "ready_time = ready_time[igood]\n", 75 | "set_time = set_time[igood]\n", 76 | "go_time = go_time[igood]\n", 77 | "is_short = dataset.trial_info.is_short.to_numpy()[igood]\n", 78 | "is_eye = dataset.trial_info.is_eye.to_numpy()[igood]\n", 79 | "iti = dataset.trial_info.iti.to_numpy()[igood]\n", 80 | "\n", 81 | "print(f\"number of trials: {len(set_time)}\")\n", 82 | "\n", 83 | "print(len(is_eye), len(is_short))\n", 84 | "\n", 85 | "# some spike timebins are NaN, replace with nearby values\n", 86 | "spks = dataset.data.to_numpy().T.copy()\n", 87 | "spks = spks.astype(\"float32\")\n", 88 | "ibad = np.isnan(spks[0])\n", 89 | "nbad = np.arange(0, spks.shape[-1])[~ibad]\n", 90 | "ibad = np.nonzero(ibad)[0]\n", 91 | "ireplace = np.array([nbad[np.abs(nbad - ibad[i]).argmin()] for i in range(len(ibad))])\n", 92 | "spks[:, ibad] = spks[:, ireplace]\n", 93 | "print(spks.shape)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "### run rastermap" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 
| "from rastermap import Rastermap\n", 110 | "model = Rastermap(n_clusters=None, # None turns off clustering and sorts single neurons\n", 111 | " n_PCs=48, # use fewer PCs than neurons\n", 112 | " locality=0.5, # some locality in sorting (this is a value from 0-1)\n", 113 | " time_lag_window=20, # use future timepoints to compute correlation\n", 114 | " grid_upsample=0, # 0 turns off upsampling since we're using single neurons\n", 115 | " mean_time=True,\n", 116 | " bin_size=1,\n", 117 | " time_bin=1\n", 118 | " ).fit(spks, compute_X_embedding=True)\n", 119 | "y = model.embedding # neurons x 1\n", 120 | "isort = model.isort" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "reshape spks into trials" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "set_idx = np.array([np.abs(neural_time - set_time[i]).argmin() for i in range(len(set_time))])\n", 137 | "nt = int(nt_sec/.02)\n", 138 | "print(nt)\n", 139 | "set_idx = np.arange(-nt, nt+1) + set_idx[:,np.newaxis]\n", 140 | "spks_trials = spks[:, set_idx].copy()\n", 141 | "\n", 142 | "\n", 143 | "# split trials into short and long prior blocks, and by set_time bins\n", 144 | "ttypes = [is_short, ~is_short]\n", 145 | "ttypes = [(is_eye)*(is_short), (is_eye)*(~is_short), (~is_eye)*(is_short), (~is_eye)*(~is_short)]\n", 146 | "bins = list(np.linspace(0.46, 0.85, 6))\n", 147 | "bins.extend(list(np.arange(0.95, 1.3, 0.1)))\n", 148 | "sr = np.digitize(set_time - ready_time, bins) - 1\n", 149 | "ttypes = [sr==i for i in range(9)]\n", 150 | "ttypes.insert(4, (sr==4)*(is_short))\n", 151 | "ttypes[5] = (sr==4)*(~is_short)\n", 152 | "rts = np.array(bins) + 0.05\n", 153 | "rts = (nt - rts/0.02).astype(\"int\")\n", 154 | "rts = list(rts)\n", 155 | "rts.insert(4, rts[4])\n", 156 | "gts = np.array([(go_time[ttypes[k]] - set_time[ttypes[k]]).mean() for k in range(len(ttypes))])\n", 157 
| "gts = (nt + gts/0.02).astype(\"int\")\n", 158 | "\n", 159 | "psths = []\n", 160 | "for k in range(len(ttypes)):\n", 161 | " psths.append(spks_trials[isort][:,ttypes[k]].mean(axis=1))\n", 162 | "psths = np.array(psths)\n", 163 | "psths = psths.transpose(1,0,2).reshape(spks.shape[0], -1)\n", 164 | "psths = zscore(psths, axis=1)\n", 165 | "psths = psths.reshape(spks.shape[0], len(ttypes), -1).transpose(1,0,2)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "### make figure" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "os.makedirs(os.path.join(root, \"figures/\"), exist_ok=True)\n", 182 | "fig5.fig5(root, psths, rts, gts, save_figure=True)" 183 | ] 184 | } 185 | ], 186 | "metadata": { 187 | "kernelspec": { 188 | "display_name": "rastermap", 189 | "language": "python", 190 | "name": "python3" 191 | }, 192 | "language_info": { 193 | "codemirror_mode": { 194 | "name": "ipython", 195 | "version": 3 196 | }, 197 | "file_extension": ".py", 198 | "mimetype": "text/x-python", 199 | "name": "python", 200 | "nbconvert_exporter": "python", 201 | "pygments_lexer": "ipython3", 202 | "version": "3.9.16" 203 | } 204 | }, 205 | "nbformat": 4, 206 | "nbformat_minor": 2 207 | } 208 | -------------------------------------------------------------------------------- /paper/fig5.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
3 | """ 4 | import os 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from matplotlib import patches 8 | 9 | from fig_utils import * 10 | 11 | 12 | def fig5(root, psths, rts, gts, save_figure=True): 13 | il = 0 14 | 15 | nk, nn, nt2 = psths.shape 16 | nt = nt2//2 17 | 18 | ci = np.hstack((np.linspace(0, 0.3, 5), np.linspace(0.7, 1., 5))) 19 | colors = plt.get_cmap("PiYG")(ci) 20 | c0 = colors[np.newaxis,:].copy() 21 | c0[:, 5:] = 1.0 22 | c1 = colors[np.newaxis,1:].copy() 23 | c1[:, :4] = 1.0 24 | cis = [c0, c1] 25 | 26 | c_t = plt.get_cmap("YlOrBr")([0.4, 0.6, 0.9])[::-1] 27 | 28 | fig = plt.figure(figsize=(14,6)) 29 | 30 | grid = plt.GridSpec(2,7, figure=fig, left=0.03, right=0.98, top=0.85, bottom=0.05, 31 | wspace = 0.2, hspace = 0.3) 32 | 33 | t_sample = np.hstack((np.linspace(480, 800, 5), 34 | np.linspace(800, 1200, 5))).astype("int") 35 | 36 | tstr = [r"short prior block, $t_s$ (ms)", r"long prior block, $t_s$ (ms)"] 37 | 38 | ax0 = plt.subplot(grid[0,:2]) 39 | pos = ax0.get_position().bounds 40 | ax0.set_position([pos[0], pos[1], pos[2]*0.85, pos[3]]) 41 | transl = mtransforms.ScaledTranslation(-18 / 72, 40 / 72, fig.dpi_scale_trans) 42 | il = plot_label(ltr, il, ax0, transl, fs_title) 43 | ax0.axis("off") 44 | 45 | ax = ax0.inset_axes([0., 0.8, 1., 0.15]) 46 | ax.plot([0, 1.2], [0, 0], ls="-", color="k") 47 | ax.plot([0, 0], [-1, 1], color="k") 48 | ax.plot([1, 1], [-1, 1], color="k") 49 | ax.plot([1.2, 1.8], [0, 0], ls="dotted", color="k") 50 | ax.plot([1.8, 2], [0, 0], ls="-", color="k") 51 | ax.plot([2, 2], [-1, 1], color="k") 52 | ax.text(0, 1.1, "ready cue", va="bottom", ha="center", color=c_t[0]) 53 | ax.text(0.5, -1.5, r"$t_s$", va="bottom", ha="center") 54 | ax.text(1, 1.1, "set cue", va="bottom", ha="center", color=c_t[1]) 55 | ax.text(1.5, -1.5, r"$t_p$", va="bottom", ha="center") 56 | ax.text(2, 1.1, "go (action)", va="bottom", ha="center", color=c_t[2]) 57 | ax.text(1.65, -3.5, r"reward $\propto$ |$t_p$ - $t_s$| / $t_s$", 
va="bottom", ha="center") 58 | ax.text(1, 3.5, "time-interval reproduction task\n(Sohn, Narain et al, 2019)", fontsize="large", 59 | va="bottom", ha="center") 60 | ax.set_xlim([-0.05,2.05]) 61 | ax.set_ylim([-1,1]) 62 | ax.axis("off") 63 | 64 | for j in range(2): 65 | ax = ax0.inset_axes([0., 0.35-0.3*j, 1., 0.125]) 66 | ax.imshow(cis[j], aspect="auto") 67 | ax.axis("off") 68 | for i in range(5): 69 | ax.text(i+4*j, 0, t_sample[i+5*j], ha="center", va="center", fontsize="small", 70 | color="w", fontweight="bold") 71 | if (i==4 and j==0) or (i==0 and j==1): 72 | width = 1 73 | ax.add_patch(patches.Rectangle(xy=(i+4*j-0.55, -0.5), width=width, 74 | fill=False, height=1, facecolor=None, edgecolor="k", lw=3)) 75 | ax.text(j*4+2, -0.8, tstr[j], ha="center", fontsize="large") 76 | ax.set_xlim([-0.5, 8.5]) 77 | 78 | kis = np.hstack((np.arange(4,-1, -1), np.arange(5, 10))) 79 | for k in range(10): 80 | ki = kis[k] 81 | ax = plt.subplot(grid[ki//5, 2+ki%5]) 82 | pos = ax.get_position().bounds 83 | im = ax.imshow(psths[k], aspect="auto", vmin=0, vmax=5, cmap="gray_r") 84 | for l,tt in enumerate([rts[k], nt, gts[k]]): 85 | ax.plot([tt, tt], [0, nn], ls="--", lw=2, color=c_t[l]) 86 | ax.set_ylim([0, nn]) 87 | ax.set_xlim([nt-80, nt+80]) 88 | ax.set_yticks([]) 89 | ax.set_xticks([]) 90 | ax.spines["top"].set_visible(True) 91 | ax.spines["right"].set_visible(True) 92 | ax.set_title(f"{t_sample[k]} ms", color=colors[k], fontweight="bold", 93 | loc="center", fontsize="medium") 94 | #ax.imshow(psth_s - psth_l, aspect="auto", vmin=-1, vmax=1, cmap="RdBu_r") 95 | if ki%5 == 2: 96 | ax.text(0.5, 1.15, ["short prior block", "long prior block"][k//5], 97 | transform=ax.transAxes, fontsize="large", ha="center") 98 | if ki==0: 99 | ax.text(-0.01, 1.22, "trial-averaged responses", 100 | transform=ax.transAxes, fontsize="large") 101 | transl = mtransforms.ScaledTranslation(-18 / 72, 40 / 72, fig.dpi_scale_trans) 102 | il = plot_label(ltr, il, ax, transl, fs_title) 103 | if ki==1: 104 | cax 
= fig.add_axes([pos[0]+0.*pos[2], pos[1] - pos[3]*0.08, 105 | pos[2]*0.5, 0.03*pos[3]]) 106 | plt.colorbar(im, cax, orientation="horizontal") 107 | 108 | 109 | ax = plt.subplot(grid[1,0]) 110 | transl = mtransforms.ScaledTranslation(-50 / 72, 10 / 72, fig.dpi_scale_trans) 111 | il = plot_label(ltr, il, ax, transl, fs_title) 112 | pos = ax.get_position().bounds 113 | ax.set_position([pos[0]+0.06, pos[1], pos[2], pos[3]]) 114 | pos = ax.get_position().bounds 115 | im = ax.imshow(psths[4] - psths[5], aspect="auto", vmin=-5, vmax=5, cmap="RdBu_r") 116 | k = 4 117 | for l,tt in enumerate([rts[k], nt, gts[k]]): 118 | ax.plot([tt, tt], [0, nn], ls="--", lw=2, color=c_t[l]) 119 | ax.set_ylim([0, nn]) 120 | ax.set_xlim([nt-80, nt+80]) 121 | ax.set_yticks([]) 122 | ax.set_xticks([]) 123 | ax.spines["top"].set_visible(True) 124 | ax.spines["right"].set_visible(True) 125 | ax.set_title(r"short - long prior, $t_p$ = 800 ms", loc="center") 126 | cax = fig.add_axes([pos[0] + 1.1*pos[2], pos[1] + pos[3]*0.65, 127 | pos[2]*0.04, 0.3*pos[3]]) 128 | plt.colorbar(im, cax) 129 | 130 | if save_figure: 131 | fig.savefig(os.path.join(root, "figures", "fig5.pdf"), dpi=200) -------------------------------------------------------------------------------- /paper/fig6.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# download and load (Steinmetz et al 2019)\n", 8 | "\n", 9 | "original loading notebook from: https://compneuro.neuromatch.io/projects/neurons/README.html#steinmetz" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from scipy.stats import zscore, ranksums\n", 22 | "from sklearn.decomposition import PCA\n", 23 | "from rastermap import Rastermap, utils\n", 24 | "from 
tqdm import trange \n", 25 | "\n", 26 | "sys.path.insert(0, \"/github/rastermap/paper\")\n", 27 | "import fig6\n", 28 | "\n", 29 | "# path to directory with data etc\n", 30 | "### *** CHANGE THIS TO WHEREVER YOU ARE SAVING YOUR DATA OUTPUTS ***\n", 31 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 32 | "\n", 33 | "# Data retrieval\n", 34 | "import os, requests\n", 35 | "\n", 36 | "fname = []\n", 37 | "for j in range(3):\n", 38 | " fname.append(os.path.join(root, \"data\", \"steinmetz_part%d.npz\"%j))\n", 39 | "url = [\"https://osf.io/agvxh/download\"]\n", 40 | "url.append(\"https://osf.io/uv3mw/download\")\n", 41 | "url.append(\"https://osf.io/ehmw2/download\")\n", 42 | "\n", 43 | "for j in range(len(url)):\n", 44 | " if not os.path.isfile(fname[j]):\n", 45 | " try:\n", 46 | " r = requests.get(url[j])\n", 47 | " except requests.ConnectionError:\n", 48 | " print(\"!!! Failed to download data !!!\")\n", 49 | " else:\n", 50 | " if r.status_code != requests.codes.ok:\n", 51 | " print(\"!!!
Failed to download data !!!\")\n", 52 | " else:\n", 53 | " with open(fname[j], \"wb\") as fid:\n", 54 | " fid.write(r.content)\n", 55 | "\n", 56 | "# Data loading\n", 57 | "alldat = np.array([])\n", 58 | "for j in range(len(fname)):\n", 59 | " alldat = np.hstack((alldat,\n", 60 | " np.load(os.path.join(root, \"data\", \"steinmetz_part%d.npz\"%j),\n", 61 | " allow_pickle=True)['dat']))" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "### sort trials from each recording" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "psigs = []\n", 78 | "brain_areas = []\n", 79 | "reaction_times = []\n", 80 | "pupil_speeds = []\n", 81 | "licks = []\n", 82 | "isorts = []\n", 83 | "itrials = []\n", 84 | "rewards = []\n", 85 | "face_motions = []\n", 86 | "wheel_moves = []\n", 87 | "ccfs = []\n", 88 | "# 6\n", 89 | "for d in trange(len(alldat)):\n", 90 | " spks = alldat[d][\"spks\"].copy().astype(\"float32\")\n", 91 | " nn,_,nt = spks.shape\n", 92 | " brain_area = alldat[d][\"brain_area\"]\n", 93 | " response = alldat[d][\"response\"].flatten()\n", 94 | " rtypes = [-1, 1]\n", 95 | " ttypes = [\"stim\", \"reaction\", \"feedback\"]\n", 96 | "\n", 97 | " spks = spks.reshape(nn, -1)\n", 98 | " igood = ((spks.mean(axis=-1)) / .01) > 0.1\n", 99 | " igood *= (brain_area != \"root\")\n", 100 | " spks = zscore(spks[igood], axis=-1)\n", 101 | " brain_area = brain_area[igood]\n", 102 | " ccfs.append(alldat[d][\"ccf\"][igood])\n", 103 | "\n", 104 | " n_PCs = 10\n", 105 | " spcs = PCA(n_components=n_PCs).fit_transform(spks.T).T\n", 106 | " U = spks @ (spcs.T / (spcs**2).sum(axis=1)**0.5)\n", 107 | " spcs_trials = spcs.reshape(n_PCs, -1, nt)\n", 108 | " spks = spks.reshape(spks.shape[0], -1, nt)\n", 109 | "\n", 110 | " brain_areas.append(brain_area)\n", 111 | " for r in [0, 1]:\n", 112 | " rtype = rtypes[r]\n", 113 | " if rtype != 0:\n", 114 | " itrial = 
np.logical_and(response==rtype, ~np.isinf(alldat[d][\"reaction_time\"][:,0]))\n", 115 | " else:\n", 116 | " itrial = response==rtype\n", 117 | " ntrials = itrial.sum()\n", 118 | " \n", 119 | " sresp = spcs_trials[:,itrial].copy()\n", 120 | " sresp -= sresp.mean(axis=(-2,-1), keepdims=True)\n", 121 | " sresp_std = sresp.std(axis=(-2,-1), keepdims=True)\n", 122 | " sresp /= sresp_std\n", 123 | "\n", 124 | " # flatten across PCs to sort trials\n", 125 | " sresp_trials = sresp.copy().transpose(1,0,2).reshape(ntrials, -1)\n", 126 | " \n", 127 | " model = Rastermap(n_clusters=None, n_PCs=64, \n", 128 | " time_lag_window=0, time_bin=0,\n", 129 | " locality=0.1, mean_time=False, verbose=False\n", 130 | " ).fit(sresp_trials)\n", 131 | " \n", 132 | " isorts.append(model.isort)\n", 133 | " itrials.append(itrial)\n", 134 | " pspeed = (np.diff(alldat[d][\"pupil\"][1:], axis=-1)**2).sum(axis=0)**0.5\n", 135 | " pupil_speeds.append(pspeed[itrial][model.isort].mean(axis=-1))\n", 136 | " wms = alldat[d][\"wheel\"].squeeze()[itrial][model.isort].mean(axis=-1)\n", 137 | " wheel_moves.append(wms)\n", 138 | " rts = alldat[d][\"reaction_time\"][itrial,0][model.isort]/1000\n", 139 | " reaction_times.append(rts)\n", 140 | " lcks = alldat[d][\"licks\"][0,itrial][model.isort].sum(axis=-1)\n", 141 | " licks.append(lcks)\n", 142 | " rwds = alldat[d][\"feedback_type\"][itrial][model.isort]\n", 143 | " rewards.append(rwds)\n", 144 | " fe = alldat[d][\"face\"][0,itrial][model.isort].mean(axis=-1)\n", 145 | " face_motions.append(fe)\n", 146 | " iswap = model.isort[:10].mean() > model.isort[-10:].mean()\n", 147 | "\n", 148 | " #ss = spks[:,itrial].copy()\n", 149 | " ss = alldat[d][\"spks\"][igood][:,itrial].copy()\n", 150 | " ss = ss.mean(axis=-1)\n", 151 | " s0 = ss[:,:20]\n", 152 | " s1 = ss[:,-20:]\n", 153 | "\n", 154 | " stats, p = ranksums(s0.T, s1.T)\n", 155 | " psig = np.sign(stats) * (p<0.05)\n", 156 | " psigs.append(psig)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | 
"metadata": {}, 162 | "source": [ 163 | "### compute percentage of (p<0.05) of late vs early per brain area" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "brain_groups = [[\"VISa\", \"VISam\", \"VISl\", \"VISp\", \"VISpm\", \"VISrl\"], # visual cortex\n", 173 | " [\"CL\", \"LD\", \"LGd\", \"LH\", \"LP\", \"MD\", \"MG\", \"PO\", \"POL\", \"PT\", \"RT\", \"SPF\", \"TH\", \"VAL\", \"VPL\", \"VPM\"], # thalamus\n", 174 | " [\"CA\", \"CA1\", \"CA2\", \"CA3\", \"DG\", \"SUB\", \"POST\"], # hippocampal\n", 175 | " [\"ACA\", \"AUD\", \"COA\", \"DP\", \"ILA\", \"MOp\", \"MOs\", \"OLF\", \"ORB\", \"ORBm\", \"PIR\", \"PL\", \"SSp\", \"SSs\", \"RSP\",\"TT\"], # non-visual cortex\n", 176 | " [\"APN\", \"IC\", \"MB\", \"MRN\", \"NB\", \"PAG\", \"RN\", \"SCs\", \"SCm\", \"SCig\", \"SCsg\", \"SNr\", \"ZI\"], # midbrain\n", 177 | " [\"ACB\", \"CP\", \"GPe\", \"LS\", \"LSc\", \"LSr\", \"MS\", \"OT\", \"SI\"], # basal ganglia \n", 178 | " [\"BLA\", \"BMA\", \"EP\", \"EPd\", \"MEA\"] # cortical subplate\n", 179 | " ]\n", 180 | "\n", 181 | "area = [\"vis ctx\", \"thalamus\", \"hippocampal\", \"other ctx\", \"midbrain\", \"striatum\", \"amygdala\"]\n", 182 | "\n", 183 | "perc = np.zeros((len(psigs), 3, 7))\n", 184 | "import itertools\n", 185 | "ccf_all = list(itertools.chain(*ccfs))\n", 186 | "for k in range(len(psigs)):\n", 187 | " bass = np.zeros(len(brain_areas[k//2]), \"int\")\n", 188 | " for i,b in enumerate(brain_groups):\n", 189 | " bind = np.isin(brain_areas[k//2], b)\n", 190 | " bass[bind] = i\n", 191 | " perc[k, 0] = np.histogram(bass[psigs[k] > 0], np.arange(0, 8))[0]\n", 192 | " perc[k, 1] = np.histogram(bass[psigs[k] < 0], np.arange(0, 8))[0]\n", 193 | " perc[k, 2] = np.histogram(bass[np.abs(psigs[k]) > 0], np.arange(0, 8))[0]\n", 194 | " nb = np.histogram(bass, np.arange(0, 8))[0]\n", 195 | " perc[k] /= nb \n", 196 | "perc *= 100\n" 197 | ] 198 | }, 199 | { 200 | 
"cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "# save results\n", 206 | "np.savez(os.path.join(root, \"results/steinmetz_proc.npz\"), \n", 207 | " ccf_all=ccf_all, itrials=np.array(itrials, dtype=object), isorts=np.array(isorts, dtype=object), \n", 208 | " reaction_times=np.array(reaction_times, dtype=object), rewards=np.array(rewards, dtype=object), \n", 209 | " licks=np.array(licks, dtype=object), wheel_moves=np.array(wheel_moves, dtype=object), \n", 210 | " face_motions=np.array(face_motions, dtype=object), pupil_speeds=np.array(pupil_speeds, dtype=object), \n", 211 | " perc=perc, area=np.array(area))" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "# will save figures to \"figures\" folder\n", 221 | "os.makedirs(os.path.join(root, \"figures/\"), exist_ok=True)\n", 222 | "fig6.fig6(root, alldat)" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "rastermap", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.9.16" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 2 247 | } 248 | -------------------------------------------------------------------------------- /paper/fig7.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "02c64150", 7 | "metadata": { 8 | "id": "02c64150", 9 | "outputId": "79cc9655-6aa3-48c1-c6db-04acb272a452" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import os, sys\n", 14 | "\n", 15 | "# clone 
rl-baselines3-zoo repo and checkout branch update/hf\n", 16 | "# !pip install opencv-python # not headless version\n", 17 | "rl_zoo3_path = \"/github/rl-baselines3-zoo/\"\n", 18 | "model_folder = os.path.join(rl_zoo3_path, \"rl-trained-agents/\")\n", 19 | "sys.path.insert(0, rl_zoo3_path)\n", 20 | "\n", 21 | "import torch\n", 22 | "# install torch with cuda support\n", 23 | "device = torch.device('cuda')\n", 24 | "num_threads = 16\n", 25 | "torch.set_num_threads(num_threads)\n", 26 | "\n", 27 | "sys.path.insert(0, \"/github/rastermap/paper\")\n", 28 | "import qrdqn # has functions that wrap rl-baselines3-zoo and stable_baselines3\n", 29 | "import fig7\n", 30 | "\n", 31 | "# path to directory with data etc\n", 32 | "### *** CHANGE THIS TO WHEREVER YOU ARE SAVING YOUR MODEL OUTPUTS ***\n", 33 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 34 | "# (in this folder we have a \"simulations\" folder)\n", 35 | "os.makedirs(os.path.join(root, \"simulations\"), exist_ok=True)\n", 36 | "\n", 37 | "env_ids = [\"PongNoFrameskip-v4\", \"SpaceInvadersNoFrameskip-v4\", \n", 38 | " \"EnduroNoFrameskip-v4\", \"SeaquestNoFrameskip-v4\"]\n", 39 | "n_seeds = 10\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "3ae4ab6c", 45 | "metadata": {}, 46 | "source": [ 47 | "### run DQN\n", 48 | "\n", 49 | "This code will run the DQN and save the activations for `n_seeds` worth of runs and save to the `simulations` folder."
50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "78e816ab", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "### IMPORTANT:\n", 60 | "### need to run one game at a time then restart notebook because hooks stick around\n", 61 | "qrdqn.run_qrdqn(model_folder, root, env_id=env_ids[0], n_seeds=n_seeds, device=device)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "acf413db", 67 | "metadata": {}, 68 | "source": [ 69 | "### run rastermap and save\n", 70 | "\n", 71 | "this function will run rastermap on the activations from each game and save to the `simulations` folder." 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "860b8722", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "### process qrdqn outputs\n", 82 | "for env_id in env_ids:\n", 83 | " print(env_id)\n", 84 | " qrdqn.sort_spks(root, env_id)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "4678b1a0", 90 | "metadata": {}, 91 | "source": [ 92 | "### make figure" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "6ab59517", 99 | "metadata": { 100 | "id": "6ab59517", 101 | "outputId": "7dc6125b-ec4a-4983-f6c5-e6b8402fe3a2" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "# root path has folder \"simulations\" with saved results\n", 106 | "# will save figures to \"figures\" folder\n", 107 | "os.makedirs(os.path.join(root, \"figures/\"), exist_ok=True)\n", 108 | "fig7.fig7(root)" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "colab": { 114 | "name": "atari_pretrained.ipynb", 115 | "provenance": [] 116 | }, 117 | "kernelspec": { 118 | "display_name": "Python 3.8.17 ('RL')", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 
130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.9.16" 133 | }, 134 | "vscode": { 135 | "interpreter": { 136 | "hash": "cdf17a3778e5017f066e6f3db18157f3af3a0506de593e09339227d49ead1beb" 137 | } 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 5 142 | } 143 | -------------------------------------------------------------------------------- /paper/fig7.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | import os 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from fig_utils import * 9 | 10 | 11 | def fig7(root, save_figure=True): 12 | env_ids = ["PongNoFrameskip-v4", "SpaceInvadersNoFrameskip-v4", 13 | "EnduroNoFrameskip-v4", "SeaquestNoFrameskip-v4"] 14 | fig = plt.figure(figsize=(14,7)) 15 | grid = plt.GridSpec(2,6, figure=fig, left=0.04, right=0.98, top=0.96, bottom=0.07, 16 | wspace = 0.15, hspace = 0.25) 17 | transl = mtransforms.ScaledTranslation(-13 / 72, 20 / 72, fig.dpi_scale_trans) 18 | il = 0 19 | layer_cols = cmap_emb(np.array([0.55, 0.65, 0.75, 0.9, 0])) 20 | layer_names = ["conv1", "conv2", "conv3", "linear", "valuenet"] 21 | for igame, env_id in enumerate(env_ids): 22 | print(env_id) 23 | i0, j0 = igame//2, 3*(igame%2) 24 | 25 | d = np.load(os.path.join(root, "simulations/", f"qrdqn_{env_id}_results.npz")) 26 | X_embedding = d["X_embedding"] 27 | nn, nt = X_embedding.shape 28 | emb_layer = d["emb_layer"] 29 | ex_frames = d["ex_frames"] 30 | iframes = d["iframes"] 31 | 32 | ax = plt.subplot(grid[i0, j0+1:j0+3]) 33 | pos = ax.get_position().bounds 34 | ax.imshow(X_embedding, aspect='auto', 35 | vmax=2.5, vmin=-0., 36 | cmap='gray_r') 37 | 38 | for k in range(4): 39 | ik = iframes[k] 40 | ax.plot(ik*np.ones(2), [0, nn], color="b", ls="--") 41 | ax.spines["left"].set_visible(False) 42 | ax.set_yticks([]) 43 | ax.set_ylim([0, 
nn]) 44 | if env_id=="EnduroNoFrameskip-v4": 45 | ax.set_xlim([780, nt]) 46 | ax.invert_yaxis() 47 | ax.set_xlabel("timepoint in episode") 48 | if igame==0: 49 | ax.text(0.28, 1.02, "layers in DQN: ", color="k", 50 | transform=ax.transAxes, ha="right") 51 | for l, lcol in enumerate(layer_cols): 52 | ax.text(0.3+l*0.13, 1.02, layer_names[l], color=lcol, transform=ax.transAxes) 53 | if l<4: 54 | ax.text(0.3+(l+1)*0.13-0.02, 1.02, ",", color="k", transform=ax.transAxes) 55 | 56 | cax = fig.add_axes([pos[0]+pos[2]*1.015, pos[1], pos[2]*0.015, pos[3]]) 57 | cax.imshow(layer_cols[emb_layer][:,np.newaxis], aspect="auto") 58 | cax.axis("off") 59 | 60 | ax = plt.subplot(grid[i0,j0]) 61 | pos = ax.get_position().bounds 62 | ax.set_position([pos[0]+0.08*pos[2], pos[1]-0.08*pos[3], pos[2], pos[3]]) 63 | grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2,2, subplot_spec=ax, 64 | wspace=0.05, hspace=0.15) 65 | ax.remove() 66 | for k in range(4): 67 | ax = plt.subplot(grid1[k//2, k%2]) 68 | ax.imshow(ex_frames[k]) 69 | ax.set_title(f"frame {iframes[k]}", fontsize="medium", color="b") 70 | ax.axis("off") 71 | if k==0: 72 | ax.text(0, 1.23, env_id[:-14], fontsize="large", transform=ax.transAxes) 73 | il = plot_label(ltr, il, ax, transl, fs_title) 74 | 75 | if save_figure: 76 | fig.savefig(os.path.join(root, "figures", "fig7.pdf"), dpi=200) -------------------------------------------------------------------------------- /paper/fig8.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import sys, os\n", 12 | "from scipy.stats import zscore \n", 13 | "\n", 14 | "from rastermap import Rastermap, utils\n", 15 | "\n", 16 | "# path to paper code\n", 17 | "sys.path.insert(0, \"/github/rastermap/paper\")\n", 18 | "import simulations, metrics, 
fig8\n", 19 | "from loaders import load_visual_data, load_alexnet_data\n", 20 | "\n", 21 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 22 | "\n", 23 | "os.makedirs(os.path.join(root, \"simulations/\"), exist_ok=True)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### 2D simulations" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "\n", 40 | "filename = os.path.join(root, \"simulations\", \"sim2D.npz\")\n", 41 | "if not os.path.exists(filename):\n", 42 | " # create simulated data with intrinsic dimensionality of 2\n", 43 | " simulations.make_2D_simulation(filename)\n", 44 | "\n", 45 | "dat = np.load(filename)\n", 46 | "spks = dat[\"spks\"]\n", 47 | "xi = dat[\"xi\"]\n", 48 | "\n", 49 | "### run algorithms\n", 50 | "model = Rastermap(n_clusters=100, n_splits=0, n_PCs=400).fit(spks, normalize=False)\n", 51 | "isort0 = model.isort \n", 52 | "\n", 53 | "model = Rastermap(n_clusters=100, n_splits=3, n_PCs=400).fit(spks, normalize=False)\n", 54 | "isort_split = model.isort \n", 55 | "X_embedding = model.X_embedding\n", 56 | "\n", 57 | "perplexities = [[10, 100], [10], [30], [100], [300]]\n", 58 | "isorts_tsne = []\n", 59 | "for i, perplexity in enumerate(perplexities):\n", 60 | " print(perplexity)\n", 61 | " y_tsne = metrics.run_TSNE(model.Usv, perplexities=perplexity, verbose=False)\n", 62 | " if i==0:\n", 63 | " isort_tsne = y_tsne[:,0].argsort()\n", 64 | " isorts_tsne.append(y_tsne[:,0].argsort())\n", 65 | "\n", 66 | "isorts = [isort0, isort_split, *isorts_tsne]\n", 67 | "\n", 68 | "### benchmark\n", 69 | "knn_score, knn, rhos = simulations.benchmark_2D(xi, isorts)\n", 70 | " \n", 71 | "np.savez(os.path.join(root, \"simulations\", \"sim2D_results.npz\"),\n", 72 | " X_embedding=X_embedding, isorts=np.array(isorts), \n", 73 | " knn_score=knn_score, knn=knn, rhos=rhos, \n", 74 | " xi=xi)" 75 | ] 76 | }, 77 | { 78 | 
"cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### visual cortex data\n", 82 | "\n", 83 | "(this data will be shared upon publication)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "filename = os.path.join(root, \"data/\", \"TX61_3x.npz\")\n", 93 | "stim_filename = os.path.join(root, \"data/\", \"text5k_3x.mat\")\n", 94 | "\n", 95 | "out = load_visual_data(filename, stim_filename)\n", 96 | "spks, istim, stim_times, xpos, ypos, run, ex_stim, img_pca, img_U, Ly, Lx = out\n", 97 | "\n", 98 | "# run rastermap \n", 99 | "# neuron bin in rastermap\n", 100 | "n_neurons = spks.shape[0]\n", 101 | "n_bins = 500\n", 102 | "bin_size = n_neurons // n_bins\n", 103 | "model = Rastermap(n_clusters=100, n_splits=3, nc_splits=25, locality=0., bin_size=bin_size,\n", 104 | " n_PCs=400, mean_time=True).fit(spks, compute_X_embedding=True)\n", 105 | "isort = model.isort\n", 106 | "\n", 107 | "X_embedding = model.X_embedding\n", 108 | "\n", 109 | "# compute stimulus responses sresp and average over the three repeats\n", 110 | "iss = np.zeros((3,5000), \"int\")\n", 111 | "for j in range(5000):\n", 112 | " iss[:,j] = (istim==j).nonzero()[0][:3]\n", 113 | "sresp = spks[:, stim_times]\n", 114 | "sresp = sresp[:, iss].transpose((1,0,2))\n", 115 | "snr_neurons = (zscore(sresp[0], axis=-1) * zscore(sresp[1], axis=-1)).mean(axis=1)\n", 116 | "\n", 117 | "# bin rastermap by neurons\n", 118 | "n_stim = sresp.shape[-1]\n", 119 | "n_bins = 500\n", 120 | "bin_size = n_neurons // n_bins\n", 121 | "x = sresp[:, isort[:(n_neurons // bin_size) * bin_size]]\n", 122 | "x = x.reshape(3, -1, bin_size, n_stim).mean(axis=2)\n", 123 | "n_bins = x.shape[1]\n", 124 | "snr = (zscore(x[0], axis=-1) * zscore(x[1], axis=-1)).mean(axis=-1)\n", 125 | "\n", 126 | "isort2 = []\n", 127 | "\n", 128 | "# mean over 3 repeats\n", 129 | "sresp = sresp.mean(axis=0)\n", 130 | "sresp = zscore(sresp, axis=1)\n", 131 
| "x = x.mean(axis=0)\n", 132 | "x = zscore(x, axis=-1)\n", 133 | "\n", 134 | "# ridge regression from 200 image PCs to 1000 rastermap components\n", 135 | "itrain = np.arange(5000)%5>0\n", 136 | "itest = ~itrain\n", 137 | "\n", 138 | "# ridge regression on training data with regularizer of 1e4\n", 139 | "imgTimg = (img_pca[itrain].T @ img_pca[itrain])/itrain.sum()\n", 140 | "imgTx = (img_pca[itrain].T @ x[:, itrain].T)/itrain.sum()\n", 141 | "B = np.linalg.solve(imgTimg + 1e4 * np.eye(200), imgTx)\n", 142 | "\n", 143 | "# reconstruct the receptive fields from the PCs\n", 144 | "rfs = B.T @ img_U\n", 145 | "rfs = np.reshape(rfs, (n_bins, Ly, Lx))\n", 146 | "\n", 147 | "# evaluate model on test data\n", 148 | "rpred = img_pca[itest] @ B\n", 149 | "cpred = (zscore(rpred.T, 1) * zscore(x[:,itest], 1)).mean(1)\n", 150 | "\n", 151 | "print(f\"mean r on test data {cpred.mean()}\")\n", 152 | "\n", 153 | "np.savez(os.path.join(root, \"results\", \"v1stimresp_proc.npz\"),\n", 154 | " X_embedding=X_embedding, bin_size=bin_size, isort=isort, isort2=isort2, \n", 155 | " xpos=xpos, ypos=ypos, x=x,\n", 156 | " stim_times=stim_times, run=run, ex_stim=ex_stim, rfs=rfs)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "\n", 166 | "np.savez(os.path.join(root, \"results\", \"v1stimresp_proc.npz\"),\n", 167 | " X_embedding=X_embedding, bin_size=bin_size, isort=isort, isort2=np.zeros(len(isort)), \n", 168 | " xpos=xpos, ypos=ypos, x=x,\n", 169 | " stim_times=stim_times, run=run, ex_stim=ex_stim, rfs=rfs)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "### alexnet activations to same images\n", 177 | "\n", 178 | "(this data will be shared upon publication)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "filename = 
os.path.join(root, \"data\", \"ann_fvs_Grayscale(224)_TX61_3X.npz\")\n", 188 | "sresp, ilayer, ipos, iconv, nmax = load_alexnet_data(filename)\n", 189 | "\n", 190 | "# run rastermap\n", 191 | "bin_size = 24\n", 192 | "model = Rastermap(n_clusters=100, n_splits=3, nc_splits=25, locality=0., bin_size=bin_size,\n", 193 | " n_PCs=400, mean_time=True).fit(sresp, compute_X_embedding=True)\n", 194 | "isort = model.isort\n", 195 | "\n", 196 | "isort2 = np.zeros(len(isort))\n", 197 | "\n", 198 | "np.savez(os.path.join(root, \"results\", \"alexnet_proc.npz\"),\n", 199 | " X_embedding=model.X_embedding, bin_size=bin_size, isort=isort, isort2=isort2,\n", 200 | " ilayer=ilayer, ipos=ipos, iconv=iconv, nmax=nmax)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "### make figure" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "fig8.fig_all(root, False)" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3 (ipykernel)", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.9.16" 237 | }, 238 | "varInspector": { 239 | "cols": { 240 | "lenName": 16, 241 | "lenType": 16, 242 | "lenVar": 40 243 | }, 244 | "kernels_config": { 245 | "python": { 246 | "delete_cmd_postfix": "", 247 | "delete_cmd_prefix": "del ", 248 | "library": "var_list.py", 249 | "varRefreshCmd": "print(var_dic_list())" 250 | }, 251 | "r": { 252 | "delete_cmd_postfix": ") ", 253 | "delete_cmd_prefix": "rm(", 254 | "library": "var_list.r", 255 | "varRefreshCmd": "cat(var_dic_list()) " 256 | } 257 | }, 258 | "position": { 259 | 
"height": "546px", 260 | "left": "845px", 261 | "right": "20px", 262 | "top": "120px", 263 | "width": "344px" 264 | }, 265 | "types_to_exclude": [ 266 | "module", 267 | "function", 268 | "builtin_function_or_method", 269 | "instance", 270 | "_Feature" 271 | ], 272 | "window_display": false 273 | }, 274 | "vscode": { 275 | "interpreter": { 276 | "hash": "998540cc2fc2836a46e99cd3ca3c37c375205941b23fd1eb4b203c48f2be758f" 277 | } 278 | } 279 | }, 280 | "nbformat": 4, 281 | "nbformat_minor": 2 282 | } 283 | -------------------------------------------------------------------------------- /paper/fig_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | import string 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import matplotlib.transforms as mtransforms 8 | import numpy as np 9 | from matplotlib import rcParams 10 | from matplotlib.colors import ListedColormap 11 | 12 | cmap_emb = ListedColormap(plt.get_cmap("gist_ncar")(np.linspace(0.05, 0.95), 100)) 13 | 14 | 15 | kp_colors = np.array([[0.55,0.55,0.55], 16 | [0.,0.,1], 17 | [0.8,0,0], 18 | [1.,0.4,0.2], 19 | [0,0.6,0.4], 20 | [0.2,1,0.5], 21 | ]) 22 | 23 | default_font = 12 24 | rcParams["font.family"] = "Arial" 25 | rcParams["savefig.dpi"] = 300 26 | rcParams["axes.spines.top"] = False 27 | rcParams["axes.spines.right"] = False 28 | rcParams["axes.titlelocation"] = "left" 29 | rcParams["axes.titleweight"] = "normal" 30 | rcParams["font.size"] = default_font 31 | 32 | ltr = string.ascii_lowercase 33 | fs_title = 16 34 | weight_title = "normal" 35 | 36 | 37 | def add_apml(ax, xpos, ypos, dx=300, dy=300, tp=30): 38 | x0, x1, y0, y1 = ( 39 | xpos.min() - dx / 2, 40 | xpos.min() + dx / 2, 41 | ypos.max(), 42 | ypos.max() + dy, 43 | ) 44 | ax.plot(np.ones(2) * (y0 + dy / 2), [x0, x1], color="k") 45 | ax.plot([y0, y1], np.ones(2) * (x0 + dx / 2), 
color="k") 46 | ax.text(y0 + dy / 2, x0 - tp, "P", ha="center", va="top", fontsize="small") 47 | ax.text(y0 + dy / 2, x0 + dx + tp, "A", ha="center", va="bottom", fontsize="small") 48 | ax.text(y0 - tp, x0 + dx / 2, "M", ha="right", va="center", fontsize="small") 49 | ax.text(y0 + dy + tp, x0 + dx / 2, "L", ha="left", va="center", fontsize="small") 50 | print(x0, y0) 51 | 52 | def plot_label(ltr, il, ax, trans, fs_title=20): 53 | ax.text( 54 | 0.0, 55 | 1.0, 56 | ltr[il], 57 | transform=ax.transAxes + trans, 58 | va="bottom", 59 | fontsize=fs_title, 60 | fontweight="bold", 61 | ) 62 | il += 1 63 | return il 64 | 65 | 66 | def plot_raster(ax, X, xmin, xmax, vmax=1.5, symmetric=False, cax=None, nper=30, 67 | label=False, n_neurons=500, n_sec=10, fs=20, padding=0.025, 68 | padding_x = 0.005, xlabel="sec.", 69 | label_pos="left", axis_off=False, 70 | cax_label="x", cax_orientation="horizontal"): 71 | xr = xmax - xmin 72 | nn = X.shape[0] 73 | if n_neurons is None: 74 | xmin0, xmax0 = 0, X.shape[1] 75 | else: 76 | xmin0, xmax0 = xmin, xmax 77 | xmin, xmax = 0, xmax - xmin 78 | im = ax.imshow(X[:, xmin0:xmax0], vmin=-vmax if symmetric else 0, vmax=vmax, 79 | cmap="RdBu_r" if symmetric else "gray_r", aspect="auto") 80 | ax.axis("off") 81 | if label_pos=="left": 82 | if n_neurons is not None: 83 | ax.plot(-padding_x*xr * np.ones(2), nn - np.array([0, n_neurons/nper]), color="k") 84 | if n_sec is not None: 85 | ax.plot(xmin + np.array([0, fs*n_sec]), nn*(1+padding/2) + np.zeros(2), color="k") 86 | else: 87 | if n_neurons is not None: 88 | ax.plot((1+padding_x)*xr * np.ones(2), nn - np.array([0, n_neurons/nper]), color="k") 89 | if n_sec is not None: 90 | ax.plot(xmin + np.array([xr-fs*n_sec, xr]), nn*(1+padding/2) + np.zeros(2), color="k") 91 | ax.set_ylim([0, nn*(1+padding)]) 92 | ax.invert_yaxis() 93 | if cax is not None: 94 | plt.colorbar(im, cax, orientation=cax_orientation) 95 | if cax_label=="x": 96 | cax.set_xlabel("z-scored\n ") 97 | else: 98 | 
cax.text(-0.2,0,"z-scored", transform=cax.transAxes, 99 | ha="right", 100 | rotation=90 if cax_orientation=="vertical" else 0) 101 | if n_neurons is None: 102 | ax.set_xlim([xmin, xmax]) 103 | else: 104 | if label_pos=="left": 105 | ax.set_xlim([-2*padding_x*xr, xr]) 106 | else: 107 | ax.set_xlim([0*xr, xr*(1+padding_x*2)]) 108 | if label: 109 | if label_pos=="left": 110 | if n_neurons is not None: 111 | ht=ax.text(-2*padding_x*xr, X.shape[0], f"{n_neurons} neurons", ha="right") 112 | ht.set_rotation(90) 113 | ax.text(xmin, nn*(1+padding), f"{n_sec} {xlabel}", va="top") 114 | else: 115 | if n_neurons is not None: 116 | ht=ax.text((1+2*padding_x)*xr, X.shape[0], f"{n_neurons} neurons", ha="left") 117 | ht.set_rotation(90) 118 | ax.text(xr, nn*(1+padding), f"{n_sec} {xlabel}", 119 | va="top", ha="right") 120 | -------------------------------------------------------------------------------- /paper/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
3 | """ 4 | from sklearn.manifold import SpectralEmbedding, Isomap, LocallyLinearEmbedding 5 | import time 6 | from scipy.stats import zscore, spearmanr 7 | from multiprocessing import Pool 8 | from scipy.spatial.distance import pdist 9 | import numpy as np 10 | import scipy 11 | from openTSNE import TSNE, affinity, TSNEEmbedding 12 | from umap import UMAP 13 | 14 | def emb_to_idx(emb): 15 | if emb.ndim==2: 16 | embs = emb[:,0] 17 | else: 18 | embs = emb 19 | isort = embs.argsort() 20 | idx = np.zeros_like(isort) 21 | idx[isort] = np.arange(0, len(isort)) 22 | return idx 23 | 24 | def triplet_order(gt, emb): 25 | """ benchmarking triplet score for embedding with ground truth""" 26 | if (gt<1).sum() == len(gt): 27 | idx_gt = emb_to_idx(gt) 28 | idx_emb = emb_to_idx(emb) 29 | nn = len(idx_gt) 30 | correct_triplets = 0 31 | nrand = nn * 10 32 | for i in range(nrand): 33 | i_gt = np.random.choice(nn, size=3, replace=False) 34 | i_gt = i_gt[i_gt.argsort()] 35 | triplet_gt = np.array([np.nonzero(idx_gt==k)[0][0] for k in i_gt]) 36 | triplet_emb = idx_emb[triplet_gt] 37 | if ((triplet_emb[0] < triplet_emb[1] and triplet_emb[1] < triplet_emb[2]) or 38 | (triplet_emb[0] > triplet_emb[1] and triplet_emb[1] > triplet_emb[2])): 39 | correct_triplets += 1 40 | return correct_triplets / nrand 41 | else: 42 | n_modules = int(np.ceil(gt.max())) 43 | triplet_scores = np.zeros(n_modules) 44 | for k in range(n_modules): 45 | inds = np.floor(gt) == k 46 | triplet_scores[k] = triplet_order(gt[inds]-k, emb[inds]) 47 | return triplet_scores 48 | 49 | def embedding_contamination(gt, emb): 50 | """ benchmarking contamination score for embedding with ground truth""" 51 | n_modules = int(np.ceil(gt.max())) 52 | contamination_scores = np.zeros(n_modules) 53 | isort = emb.flatten().argsort() 54 | for k in range(n_modules): 55 | imod = np.floor(gt) == k 56 | in_mod = np.nonzero(imod)[0] 57 | nn = len(in_mod) 58 | nrand = nn * 10 59 | nk = 0 60 | for i in range(nrand): 61 | i_gt = 
in_mod[np.random.choice(nn, size=2, replace=False)] 62 | i0 = np.nonzero(isort == i_gt[0])[0][0] 63 | i1 = np.nonzero(isort == i_gt[1])[0][0] 64 | if np.abs(i1 - i0) > 5: 65 | i0, i1 = min(i0, i1), max(i0, i1) 66 | contamination_scores[k] += 1 - imod[isort[i0+1 : i1-1]].mean() 67 | nk+=1 68 | contamination_scores[k] /= nk 69 | return contamination_scores 70 | 71 | def benchmarks(xi_all, embs): 72 | n_modules = int(np.ceil(xi_all.max())) 73 | emb_rand = np.random.rand(len(embs[0]), 1).astype("float32") 74 | contamination_scores = np.zeros((len(embs)+1, n_modules)) 75 | for k, emb in enumerate(embs): 76 | contamination_scores[k] = embedding_contamination(xi_all, emb) 77 | contamination_scores[k+1] = embedding_contamination(xi_all, emb_rand) 78 | 79 | triplet_scores = np.zeros((len(embs)+1, n_modules)) 80 | for k, emb in enumerate(embs): 81 | triplet_scores[k] = triplet_order(xi_all, emb) 82 | triplet_scores[k+1] = triplet_order(xi_all, emb_rand) 83 | 84 | return contamination_scores, triplet_scores 85 | 86 | def embedding_quality_gt(gt, embs, knn=[10,50,100,200,500]): 87 | """ benchmarking local and global scores for embedding with ground truth """ 88 | idx_gt = emb_to_idx(gt)[:,np.newaxis] 89 | mnn = np.zeros((len(embs), len(knn))) 90 | rho = np.zeros((len(embs),)) 91 | for k,emb in enumerate(embs): 92 | idx_emb = emb_to_idx(emb)[:,np.newaxis] 93 | mnn[k], rho[k] = embedding_quality(idx_gt, idx_emb, knn=knn) 94 | return mnn, rho 95 | 96 | def distance_matrix(Z, n_X=None, wrapping=False, correlation=False): 97 | if wrapping: 98 | #n_X = int(np.floor(Z.max() + 1)) 99 | dists = (Z - Z[:, np.newaxis, :]) % n_X 100 | Zdist = (np.minimum(dists, n_X - dists)**2).sum(axis=-1) 101 | else: 102 | if correlation: 103 | Zdist = Z @ Z.T 104 | Z2 = 1e-10 + np.diag(Zdist)**.5 105 | Zdist = 1 - Zdist / np.outer(Z2, Z2) 106 | else: 107 | #Zdist = ((Z - Z[:,np.newaxis,:])**2).sum(axis=-1) 108 | Z2 = np.sum(Z**2, 1) 109 | Zdist = Z2 + Z2[:, np.newaxis] - 2 * Z @ Z.T 110 | Zdist = 
np.maximum(0, Zdist) 111 | 112 | #import pdb; pdb.set_trace(); 113 | return Zdist 114 | 115 | def embedding_quality(X, Z, knn=[10,50,100,200,500], subsetsize=2000, 116 | wrapping=False, correlation=False, n_X=None): 117 | """ changed correlation to False and n_X from 0 to None""" 118 | from sklearn.neighbors import NearestNeighbors 119 | np.random.seed(101) 120 | if subsetsize < X.shape[0]: 121 | subset = np.random.choice(X.shape[0], size=subsetsize, replace=False) 122 | else: 123 | subsetsize = X.shape[0] 124 | subset = slice(0, X.shape[0]) 125 | Xdist = distance_matrix(X[subset], correlation=correlation) 126 | Zdist = distance_matrix(Z[subset], n_X=n_X, wrapping=wrapping) 127 | xd = Xdist[np.tril_indices(Xdist.shape[0], -1)] 128 | zd = Zdist[np.tril_indices(Xdist.shape[0], -1)] 129 | mnn = [] 130 | if not isinstance(knn, (np.ndarray, list)): 131 | knn = [knn] 132 | elif isinstance(knn, np.ndarray): 133 | knn = list(knn) 134 | 135 | for kni in knn: 136 | nbrs1 = NearestNeighbors(n_neighbors=kni, metric="precomputed").fit(Xdist) 137 | ind1 = nbrs1.kneighbors(return_distance=False) 138 | 139 | nbrs2 = NearestNeighbors(n_neighbors=kni, metric="precomputed").fit(Zdist) 140 | ind2 = nbrs2.kneighbors(return_distance=False) 141 | 142 | intersections = 0.0 143 | for i in range(subsetsize): 144 | intersections += len(set(ind1[i]) & set(ind2[i])) 145 | mnn.append(intersections / subsetsize / kni) 146 | 147 | rho = spearmanr(xd, zd).correlation 148 | 149 | return (mnn, rho) 150 | 151 | def run_TSNE(U, perplexities=[30], metric="cosine", verbose=False): 152 | if len(perplexities) > 1: 153 | affinities_annealing = affinity.PerplexityBasedNN( 154 | U, 155 | perplexity=perplexities[1], 156 | metric=metric, 157 | n_jobs=16, 158 | random_state=1, 159 | verbose=verbose 160 | ) 161 | embedding = TSNEEmbedding( 162 | U[:,:1]*0.0001, 163 | affinities_annealing, 164 | negative_gradient_method="fft", 165 | random_state=1, 166 | n_jobs=16, 167 | verbose=verbose 168 | ) 169 | embedding1 = 
embedding.optimize(n_iter=250, exaggeration=12, momentum=0.5) 170 | embedding2 = embedding1.optimize(n_iter=750, exaggeration=1, momentum=0.8) 171 | 172 | affinities_annealing.set_perplexity(perplexities[0]) 173 | embeddingOPENTSNE = embedding2.optimize(n_iter=500, momentum=0.8) 174 | else: 175 | tsne = TSNE( 176 | perplexity=perplexities[0], 177 | metric=metric, 178 | n_jobs=16, 179 | random_state=1, 180 | verbose=verbose, 181 | n_components = 1, 182 | initialization = .0001 * U[:,:1], 183 | ) 184 | embeddingOPENTSNE = tsne.fit(U) 185 | 186 | return embeddingOPENTSNE 187 | 188 | def run_UMAP(U, n_neighbors=15, min_dist=0.1, metric="cosine"): 189 | embeddingUMAP = UMAP(n_components=1, n_neighbors=n_neighbors, random_state=1, 190 | min_dist=min_dist, init=U[:,:1], metric=metric).fit_transform(U) 191 | return embeddingUMAP 192 | 193 | def run_LE(U): 194 | LE = SpectralEmbedding(n_components=1, n_jobs=16, random_state=1).fit(U) 195 | return LE.embedding_ 196 | 197 | def run_LLE(U, n_neighbors=5): 198 | LLE = LocallyLinearEmbedding(n_components=1, n_jobs=16, n_neighbors=n_neighbors, random_state=1).fit(U) 199 | return LLE.embedding_ 200 | 201 | def run_isomap(U, n_neighbors=5, metric="cosine"): 202 | IM = Isomap(n_components=1, n_jobs=16, n_neighbors=n_neighbors, 203 | metric=metric).fit(U) 204 | return IM.embedding_ 205 | 206 | -------------------------------------------------------------------------------- /paper/other_upsampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
3 | """ 4 | import numpy as np 5 | from scipy.stats import zscore 6 | 7 | #if self.quadratic_upsample: 8 | # Y = quadratic_upsampling1D(cc, g) 9 | #elif self.gradient_upsample: 10 | # Y = upsample_grad(cc, self.n_components, self.n_X) 11 | 12 | def quadratic_upsampling1D(cc, grid, npts=10): 13 | """ upsample grid using quadratic approximation 14 | sample correlation with grid is cc - ngrid x n_samples """ 15 | npts = max(3, npts) 16 | if npts%2!=1: 17 | npts += 1 18 | 19 | dims, n_X = grid.shape 20 | n_samples = cc.shape[1] 21 | 22 | cbest = cc.argmax(axis=0) 23 | 24 | # find peaks and shift to have at least 5 pts 25 | ibest = cbest 26 | imin = np.maximum(0, ibest-npts//2) 27 | ishift = n_X - (ibest+npts//2+1) < 0 28 | imin[ishift] -= (ibest[ishift]+npts//2+1) - n_X 29 | icent = imin + npts//2 30 | 31 | # create grid of points 32 | igrid = np.arange(0,npts) 33 | # convert to cc inds 34 | cinds = igrid + imin[:,np.newaxis] 35 | 36 | # make float and mean centered for regression 37 | igrid = igrid.astype(np.float32) - npts//2 38 | 39 | C = cc[cinds, np.tile(np.arange(0, n_samples)[:,np.newaxis], (1, igrid.size))] 40 | IJ = np.stack((np.ones_like(igrid), igrid**2, igrid), axis=1) 41 | 42 | A = np.linalg.solve(IJ.T @ IJ, IJ.T @ C.T) 43 | 44 | xmax = np.clip(-A[2] / (2*A[1]), -npts//2, npts//2) 45 | 46 | xdelta = np.diff(grid[0,:]).mean() 47 | xmax = xmax*xdelta + grid[0,icent] 48 | Y = xmax[:,np.newaxis] 49 | 50 | return Y 51 | 52 | def quadratic_upsampling2D(X, cc, x_m, y_m): 53 | n_X = x_m.shape[0] 54 | n_samples = X.shape[0] 55 | 56 | cbest = cc.argmax(axis=0) 57 | 58 | ibest, jbest = np.unravel_index(cbest, (n_X, n_X)) 59 | imin = np.maximum(0, ibest-1) 60 | imin[n_X - (ibest+2) < 0] -= 1 61 | jmin = np.maximum(0, jbest-1) 62 | jmin[n_X - (jbest+2) < 0] -= 1 63 | icent, jcent = imin+1, jmin+1 64 | 65 | igrid, jgrid = np.meshgrid(np.arange(0,3), np.arange(0,3), indexing='ij') 66 | igrid, jgrid = igrid.flatten(), jgrid.flatten() 67 | iinds = igrid + 
imin[:,np.newaxis] 68 | jinds = jgrid + jmin[:,np.newaxis] 69 | igrid = igrid.astype(np.float32) - 1. 70 | jgrid = jgrid.astype(np.float32) - 1. 71 | 72 | cinds = np.ravel_multi_index((iinds, jinds), (n_X, n_X)) 73 | C = cc[cinds, np.tile(np.arange(0, n_samples)[:,np.newaxis], (1, 9))] 74 | IJ = np.stack((np.ones_like(igrid), igrid**2 + jgrid**2, igrid, jgrid), axis=1) 75 | A = np.linalg.solve(IJ.T @ IJ, IJ.T @ C.T) 76 | xmax = np.clip(-A[2] / (2*A[1]), -1, 1) 77 | ymax = np.clip(-A[3] / (2*A[1]), -1, 1) 78 | 79 | # put in original space 80 | xdelta = np.diff(x_m[:,0]).mean() 81 | ydelta = np.diff(y_m[0]).mean() 82 | xmax = xmax*xdelta + x_m[icent,0] 83 | ymax = ymax*ydelta + y_m[0,jcent] 84 | 85 | Y = np.stack((xmax, ymax), axis=1) 86 | return Y 87 | 88 | def grid_upsampling2(X, X_nodes, Y_nodes, n_X=41, n_neighbors=50): 89 | n_X = 41 90 | x_m = np.linspace(Y_nodes[:,0].min(), Y_nodes[:,0].max(), n_X) 91 | y_m = np.linspace(Y_nodes[:,1].min(), Y_nodes[:,1].max(), n_X) 92 | 93 | x_m, y_m = np.meshgrid(x_m, y_m, indexing='ij') 94 | xy = np.vstack((x_m.flatten(), y_m.flatten())) 95 | 96 | ds = (xy[0][:,np.newaxis] - Y_nodes[:,0])**2 + (xy[1][:,np.newaxis] - Y_nodes[:,1])**2 97 | isort = np.argsort(ds, 1)[:,:n_neighbors] 98 | nraster = xy.shape[1] 99 | Xrec = np.zeros((nraster, X_nodes.shape[1])) 100 | for j in range(nraster): 101 | ineigh = isort[j] 102 | dists = ds[j, ineigh] 103 | w = np.exp(-dists / dists[7]) 104 | M, N = X_nodes[ineigh], Y_nodes[ineigh] 105 | N = np.concatenate((N, np.ones((n_neighbors,1))), axis=1) 106 | R = np.linalg.solve((N.T * w) @ N, (N.T * w) @ M) 107 | Xrec[j] = xy[:,j] @ R[:2] + R[-1] 108 | 109 | Xrec = Xrec / (Xrec**2).sum(1)[:,np.newaxis]**.5 110 | cc = Xrec @ zscore(X, 1).T 111 | cc = np.maximum(0, cc) 112 | imax = np.argmax(cc, 0) 113 | Y = xy[:, imax].T 114 | 115 | return Y, cc, x_m, y_m 116 | 117 | def kriging_upsampling(X, X_nodes, Y_nodes, grid_upsample=10, sig=0.5): 118 | # assume the input is 5 by 5 by 5 by 5.... 
vectorized 119 | if (Y_nodes==-1).sum()>0: 120 | Xn = X_nodes.copy()[X_nodes!=-1] 121 | Yn = Y_nodes.copy()[Y_nodes!=-1] 122 | 123 | nclust = Y_nodes.max()+1 124 | xs = np.arange(0, nclust) 125 | xu = np.arange(0, nclust, 1./grid_upsample) 126 | Kxx = np.exp(-(xs[:,np.newaxis] - xs)**2 / sig) 127 | Kxu = np.exp(-(xs[:,np.newaxis] - xu)**2 / sig) 128 | Km = np.linalg.solve(Kxx + np.eye(Kxx.shape[0]), Kxu) 129 | 130 | Xrec = X_nodes[Y_nodes[:,0].argsort()].T @ Km 131 | Xrec = Xrec.T 132 | Xrec = Xrec / (1e-10 + (Xrec**2).sum(axis=1)[:,np.newaxis]**.5) 133 | 134 | cc = Xrec @ zscore(X, axis=1).T 135 | cc = np.maximum(0, cc) 136 | imax = np.argmax(cc, 0) 137 | Y = xu[imax].T 138 | Y = Y[:,np.newaxis] 139 | 140 | return Y, cc, xu, Xrec 141 | 142 | 143 | 144 | def upsample_grad(CC, dims, nX): 145 | CC /= np.amax(CC, axis=1)[:, np.newaxis] 146 | xid = np.argmax(CC, axis=1) 147 | if dims==2: 148 | ys, xs = np.meshgrid(np.arange(nX), np.arange(nX)) 149 | y0 = np.vstack((xs.flatten(), ys.flatten())) 150 | else: 151 | ys = np.arange(nX) 152 | y0 = ys[np.newaxis,:] 153 | 154 | eta = .1 155 | y = optimize_neurons(CC, y0[:,xid], y0, eta) 156 | return y 157 | 158 | def gradient_descent_neurons(inputs): 159 | CC, yinit, ynodes, eta = inputs 160 | flag = 1 161 | niter = 201 # 201 162 | alpha = 1. 163 | y = yinit 164 | x = 1. 165 | sig = 1. 
166 | eta = np.linspace(eta, eta/10, niter) 167 | for j in range(niter): 168 | yy0 = y[:, np.newaxis] - ynodes 169 | if flag: 170 | K = np.exp(-np.sum(yy0**2, axis=0)/(2*sig**2)) 171 | else: 172 | yyi = 1 + np.sum(yy0**2, axis=0) 173 | K = 1/yyi**alpha 174 | x = np.sum(K*CC)/np.sum(K**2) 175 | err = (x*K - CC) 176 | if flag: 177 | Kprime = - x * yy0 * K 178 | else: 179 | Kprime = - yy0 * alpha * 1/yyi**(alpha+1) 180 | dy = np.sum(Kprime *err, axis=-1) 181 | y = y - eta[j] * dy 182 | return y 183 | 184 | def optimize_neurons(CC, y, ynodes, eta): 185 | inputs = [] 186 | for j in range(CC.shape[0]): 187 | inputs.append((CC[j,:], y[:, j], ynodes, eta)) 188 | 189 | num_cores = multiprocessing.cpu_count() 190 | with Pool(num_cores) as p: 191 | y = p.map(gradient_descent_neurons, inputs) 192 | #y = gradient_descent_neurons((CC, y, ynodes, eta)) 193 | 194 | y = np.array(y).T 195 | return y 196 | 197 | 198 | def LLE_upsampling(X, X_nodes, Y_nodes, n_neighbors=10, LLE = 1): 199 | """ X is original space points, X_nodes nodes in original space, Y_nodes nodes in embedding space """ 200 | e_dists = ((Y_nodes[:,:,np.newaxis] - Y_nodes.T)**2).sum(axis=1) 201 | cc = -np.sum(X_nodes**2, 1)[:,np.newaxis] - np.sum(X**2, 1) + 2 * X_nodes @ X.T 202 | y = np.zeros((X.shape[0],2)) 203 | for i in range(X.shape[0]): 204 | x = X[i] 205 | ineigh0 = cc[:,i].argmax() #cc[:,i].argsort()[::-1][:n_neighbors] 206 | ineigh = e_dists[ineigh0].argsort()[:n_neighbors] 207 | if LLE: 208 | z = X_nodes[ineigh] - x 209 | G = z @ z.T 210 | alpha = 1e-8 211 | w = np.linalg.solve(G + alpha*np.eye(n_neighbors), np.ones(n_neighbors, np.float32)) 212 | else: 213 | w = np.linalg.solve(X_nodes[ineigh] @ X_nodes[ineigh].T, X_nodes[ineigh] @ x) 214 | w /= w.sum() 215 | y[i] = w @ Y_nodes[ineigh] 216 | return y 217 | 218 | def PCA_upsampling(X, X_nodes, Y_nodes): 219 | from sklearn.decomposition import PCA 220 | n_samples, n_features = X.shape 221 | n_nodes = X_nodes.shape[0] 222 | n_components = Y_nodes.shape[1] 
223 | 224 | cc = -np.sum(X_nodes**2, 1)[:,np.newaxis] - np.sum(X**2, 1) + 2 * X_nodes @ X.T 225 | inode = cc.argmax(axis=0) 226 | Y = np.zeros((n_samples, 2)) 227 | for n in range(n_nodes): 228 | pts = X[inode==n] 229 | delta = PCA(n_components=2).fit_transform(pts) 230 | Y[inode==n] = Y_nodes[n] + 1e-2 * delta 231 | 232 | return Y 233 | 234 | def knn_upsampling(X, X_nodes, Y_nodes, n_neighbors=10): 235 | n_samples = X.shape[0] 236 | Ndist = np.sum(X_nodes**2, 1)[:,np.newaxis] + np.sum(X**2, 1) - 2 * X_nodes @ X.T 237 | inds_k = Ndist.argsort(axis=0)[:n_neighbors] 238 | Ndist_k = np.sort(Ndist, axis=0)[:n_neighbors] 239 | sigma = Ndist_k[0] 240 | w = np.exp(-1 * Ndist_k / sigma) 241 | w /= w.sum(axis=0) 242 | Y_knn = (w[...,np.newaxis] * Y_nodes[inds_k]).sum(axis=0) 243 | return Y_knn 244 | 245 | def subspace_upsampling2(X, X_nodes, Y_nodes, n_neighbors=10): 246 | e_dists = ((Y_nodes[:,:,np.newaxis] - Y_nodes.T)**2).sum(axis=1) 247 | cc = -np.sum(X_nodes**2, 1)[:,np.newaxis] - np.sum(X**2, 1) + 2 * X_nodes @ X.T 248 | #cc = zscore(X_nodes,axis=1) @ zscore(X,axis=1).T 249 | 250 | n_samples, n_features = X.shape 251 | n_nodes = X_nodes.shape[0] 252 | n_components = Y_nodes.shape[1] 253 | 254 | Y = np.zeros((n_samples, n_components)) 255 | for n in range(n_nodes): 256 | ineigh = e_dists[n].argsort()[:n_neighbors] 257 | # min || M - (a * N @ R + b) || 258 | M, N = X_nodes[ineigh], Y_nodes[ineigh] 259 | ones = np.ones((n_neighbors,1)) 260 | a = 1 261 | b = 0 262 | for k in range(1): 263 | cov = M.T @ N 264 | model = PCA(n_components=n_components).fit(cov) 265 | vv = model.components_.T 266 | uv = (cov @ vv.T) / model.singular_values_ 267 | R = vv @ uv.T 268 | 269 | a = ((M.T @ (N @ R)).sum() - (R.T @ N.T * b).sum()) / (R.T @ N.T @ N @ R).sum() 270 | 271 | R1 = np.linalg.solve(N.T @ N, N.T @ M) 272 | print( ((M - N @ R)**2).sum(), ((M - (a *N @ R + b))**2).sum(), ((M - N @ R1)**2).sum()) 273 | return Y 274 | 275 | if 0: 276 | Y = np.zeros((n_nodes, n_samples, 
n_components)) 277 | rerr = np.zeros((n_nodes, n_samples)) 278 | for n in range(n_nodes): 279 | ineigh = e_dists[n].argsort()[:n_neighbors] 280 | M, N = X_nodes[ineigh], Y_nodes[ineigh] 281 | N = np.concatenate((N, np.ones((n_neighbors,1))), axis=1) 282 | R = np.linalg.solve(N.T @ N, N.T @ M) 283 | Xz = X.copy() - R[2] 284 | Y_est[n] = np.linalg.solve(R[:2] @ R[:2].T, R[:2] @ Xz.T).T 285 | rerr[n] = ((Xz - Y_est[n] @ R[:2])**2).sum(axis=1) 286 | 287 | # Y = Y_est[rerr.argmin(axis=0), np.arange(0, n_samples)] 288 | Y = Y_est[cc.argmax(axis=0), np.arange(0, n_samples)] 289 | #return Y -------------------------------------------------------------------------------- /paper/qrdqn.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from scipy.stats import zscore, mode 5 | import torch 6 | from torch import nn 7 | from rastermap import Rastermap , utils 8 | 9 | from rl_zoo3.utils import ALGOS 10 | from stable_baselines3.common.preprocessing import preprocess_obs 11 | from stable_baselines3.common.env_util import make_atari_env 12 | from stable_baselines3.common.vec_env import VecFrameStack 13 | 14 | 15 | # setup model with learning_rate OFF 16 | custom_objects = { 17 | "learning_rate": 0.0, 18 | "lr_schedule": lambda _: 0.0, 19 | "clip_range": lambda _: 0.0, 20 | "optimize_memory_usage": False 21 | } 22 | 23 | vis_cnn = {} 24 | def hook_fn_cnn(m, i, o): 25 | vis_cnn[m] = o.detach().clone().cpu() 26 | 27 | vis_mlp = {} 28 | def hook_fn_mlp(m, i, o): 29 | vis_mlp[m] = o.detach().clone().cpu() 30 | 31 | def get_all_layers(model, hook_fn): 32 | for name, layer in model._modules.items(): 33 | if isinstance(layer, nn.Sequential): 34 | get_all_layers(layer, hook_fn) 35 | else: 36 | layer.register_forward_hook(hook_fn) 37 | 38 | def get_all_activations_qrdqn(cnn_net, mlp_net, obs_torch): 39 | ilayer = np.zeros((0,), "int") 40 | cnn = np.zeros((0,), np.float32) 41 | 
mlp = np.zeros((0,), np.float32) 42 | out = cnn_net(obs_torch) 43 | out2 = mlp_net(out) 44 | i = 0 45 | for layer_name in vis_cnn.keys(): 46 | if "ReLU()" in str(layer_name): 47 | act = vis_cnn[layer_name].flatten().detach().numpy() 48 | cnn = np.concatenate((cnn, act), axis=0) 49 | ilayer = np.concatenate((ilayer, i*np.ones(len(act), "int")), axis=0) 50 | i += 1 51 | mlp = out2.cpu().flatten().detach().numpy() 52 | ilayer = np.concatenate((ilayer, i*np.ones(len(mlp), "int")), axis=0) 53 | return cnn, mlp, ilayer 54 | 55 | def run_qrdqn(model_folder, root, env_id, n_seeds=10, device=torch.device("cuda")): 56 | algo = "qrdqn" 57 | 58 | log_path = os.path.join(model_folder, algo, env_id+"_1") 59 | model_path = os.path.join(log_path, f"{env_id}.zip") 60 | 61 | print(f"using Atari env {env_id}") 62 | env = make_atari_env(env_id, n_envs=1, seed=0) 63 | # Frame-stacking with 4 frames 64 | env = VecFrameStack(env, n_stack=4) 65 | model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects) 66 | cnn_net = model.policy.quantile_net.features_extractor 67 | print(cnn_net) 68 | mlp_net = model.policy.quantile_net.quantile_net 69 | print(mlp_net) 70 | get_all_layers(cnn_net, hook_fn_cnn) 71 | 72 | spks, eps_len, actions = [], [], [] 73 | for seed in range(n_seeds): 74 | print(f"seed {seed}") 75 | env = make_atari_env(env_id, n_envs=1, seed=seed) 76 | # Frame-stacking with 4 frames 77 | env = VecFrameStack(env, n_stack=4) 78 | obs = env.reset() 79 | 80 | deterministic = False 81 | state = None 82 | episode_reward = 0.0 83 | ep_len = 0 84 | n_iterations = 4000 85 | 86 | # don't show game 87 | render = False 88 | 89 | obs_all, actions_all, states_all = [], [], [] 90 | iteration = 0 91 | for _ in range(n_iterations): 92 | 93 | action, state = model.predict(obs, state=state, deterministic=deterministic) 94 | obs_torch = torch.from_numpy(obs.transpose(0,3,1,2).astype(np.float32).copy()).to(device) 95 | preprocessed_obs = preprocess_obs(obs_torch, 96 | 
model.policy.observation_space, 97 | normalize_images=model.policy.normalize_images) 98 | obs_all.append(env.get_images()[0]) 99 | #actions = torch.from_numpy(action[np.newaxis,:]).to(device) 100 | #out = cnn_net(preprocessed_obs) 101 | cnn, mlp, ilayer = get_all_activations_qrdqn(cnn_net, mlp_net, preprocessed_obs) 102 | 103 | if iteration==0: 104 | activations = np.zeros((len(cnn)+len(mlp), n_iterations), "float32") 105 | print(activations.shape, ilayer.max()+1) 106 | activations[:len(cnn), iteration] = cnn 107 | activations[len(cnn):, iteration] = mlp 108 | 109 | obs, reward, done, infos = env.step(action) 110 | 111 | if render: 112 | env.render("human") 113 | 114 | episode_reward += reward[0] 115 | ep_len += 1 116 | iteration += 1 117 | actions_all.append(action) 118 | states_all.append(state) 119 | 120 | if infos is not None: 121 | episode_infos = infos[0].get("episode") 122 | if episode_infos is not None: 123 | print(f"Atari Episode Score: {episode_infos['r']:.2f}") 124 | print("Atari Episode Length", episode_infos["l"]) 125 | break; 126 | activations = activations[:, 4:iteration] 127 | env.close() 128 | spks.append(activations) 129 | actions.append(np.array(actions_all)) 130 | eps_len.append(activations.shape[1]) 131 | 132 | obs = np.stack(tuple(obs_all), axis=0).squeeze()[4:] 133 | spks = np.concatenate(tuple(spks), axis=1) 134 | actions = np.concatenate(tuple(actions), axis=0) 135 | eps_len = np.array(eps_len) 136 | 137 | np.savez(os.path.join(root, "simulations/", f"qrdqn_{env_id}.npz"), 138 | spks=spks, ilayer=ilayer, obs=obs, actions=actions, 139 | eps_len=eps_len) 140 | 141 | def sort_spks(root, env_id): 142 | dat = np.load(os.path.join(root, "simulations/", f"qrdqn_{env_id}.npz")) 143 | spks = dat["spks"] 144 | obs = dat["obs"] 145 | ilayer = dat["ilayer"] 146 | eps_len = dat["eps_len"] 147 | 148 | x_std = spks.std(axis=1) 149 | igood = x_std > 1e-3 150 | print(igood.mean()) 151 | 152 | S = zscore(spks[igood], axis=1) 153 | rm_model = 
Rastermap(time_lag_window=10, locality=0.75).fit(S[:,:]) 154 | isort = rm_model.isort 155 | 156 | # show last episode 157 | bin_size = 50 158 | X_embedding = zscore(utils.bin1d(S[:,-eps_len[-1]:][isort], bin_size, axis=0), axis=1) 159 | #if env_id=="EnduroNoFrameskip-v4": 160 | # X_embedding = X_embedding[:,780:] 161 | # obs = obs[780:] 162 | nn, nt = X_embedding.shape 163 | if env_id=="EnduroNoFrameskip-v4": 164 | nt = nt-900 165 | iframes = np.linspace(780 + nt*0.1, 780 + nt*0.9, 4).astype("int") 166 | else: 167 | iframes = np.linspace(nt*0.1, nt*0.9, 4).astype("int") 168 | print(iframes) 169 | emb_layer = mode(ilayer[igood][isort][:nn*bin_size].reshape(-1, bin_size), axis=1, keepdims=False).mode 170 | ex_frames = obs[iframes] 171 | print(ex_frames.shape) 172 | np.savez(os.path.join(root, "simulations/", f"qrdqn_{env_id}_results.npz"), 173 | X_embedding=X_embedding, emb_layer=emb_layer, 174 | ex_frames=ex_frames, iframes=iframes) -------------------------------------------------------------------------------- /paper/splitting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import sys, os\n", 12 | "from scipy.stats import zscore \n", 13 | "\n", 14 | "from rastermap import Rastermap, utils\n", 15 | "\n", 16 | "# path to paper code\n", 17 | "sys.path.insert(0, \"/github/rastermap/paper\")\n", 18 | "import simulations, metrics, fig_splitting\n", 19 | "from loaders import load_visual_data, load_alexnet_data\n", 20 | "\n", 21 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 22 | "\n", 23 | "os.makedirs(os.path.join(root, \"simulations/\"), exist_ok=True)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### 2D simulations" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 
35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "\n", 40 | "filename = os.path.join(root, \"simulations\", \"sim2D.npz\")\n", 41 | "if not os.path.exists(filename):\n", 42 | " # create simulated data with intrinsic dimensionality of 2\n", 43 | " simulations.make_2D_simulation(filename)\n", 44 | "\n", 45 | "dat = np.load(filename)\n", 46 | "spks = dat[\"spks\"]\n", 47 | "xi = dat[\"xi\"]\n", 48 | "\n", 49 | "### run algorithms\n", 50 | "model = Rastermap(n_clusters=100, n_splits=0, n_PCs=400).fit(spks, normalize=False)\n", 51 | "isort0 = model.isort \n", 52 | "\n", 53 | "model = Rastermap(n_clusters=100, n_splits=3, n_PCs=400).fit(spks, normalize=False)\n", 54 | "isort_split = model.isort \n", 55 | "X_embedding = model.X_embedding\n", 56 | "\n", 57 | "perplexities = [[10, 100], [10], [30], [100], [300]]\n", 58 | "isorts_tsne = []\n", 59 | "for i, perplexity in enumerate(perplexities):\n", 60 | " print(perplexity)\n", 61 | " y_tsne = metrics.run_TSNE(model.Usv, perplexities=perplexity, verbose=False)\n", 62 | " if i==0:\n", 63 | " isort_tsne = y_tsne[:,0].argsort()\n", 64 | " isorts_tsne.append(y_tsne[:,0].argsort())\n", 65 | "\n", 66 | "isorts = [isort0, isort_split, *isorts_tsne]\n", 67 | "\n", 68 | "### benchmark\n", 69 | "knn_score, knn, rhos = simulations.benchmark_2D(xi, isorts)\n", 70 | " \n", 71 | "np.savez(os.path.join(root, \"simulations\", \"sim2D_results.npz\"),\n", 72 | " X_embedding=X_embedding, isorts=np.array(isorts), \n", 73 | " knn_score=knn_score, knn=knn, rhos=rhos, \n", 74 | " xi=xi)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### visual cortex data\n", 82 | "\n", 83 | "(this data will be shared upon publication)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "filename = os.path.join(root, \"data/\", \"TX61_3x.npz\")\n", 93 | "stim_filename = 
os.path.join(root, \"data/\", \"text5k_3x.mat\")\n", 94 | "\n", 95 | "out = load_visual_data(filename, stim_filename)\n", 96 | "spks, istim, stim_times, xpos, ypos, run, ex_stim, img_pca, img_U, Ly, Lx = out\n", 97 | "\n", 98 | "# run rastermap \n", 99 | "model = Rastermap(n_clusters=100, n_splits=3, nc_splits=25, locality=0.,\n", 100 | " n_PCs=400, mean_time=False).fit(spks)\n", 101 | "isort = model.isort\n", 102 | "# bin by rastermap\n", 103 | "n_neurons = len(isort)\n", 104 | "n_bins = 500\n", 105 | "bin_size = n_neurons // n_bins\n", 106 | "X_embedding = zscore(utils.bin1d(spks[isort], bin_size, axis=0), axis=1)\n", 107 | "\n", 108 | "# compute stimulus responses sresp and average over the three repeats\n", 109 | "iss = np.zeros((3,5000), \"int\")\n", 110 | "for j in range(5000):\n", 111 | " iss[:,j] = (istim==j).nonzero()[0][:3]\n", 112 | "sresp = spks[:, stim_times]\n", 113 | "sresp = sresp[:, iss].transpose((1,0,2))\n", 114 | "snr_neurons = (zscore(sresp[0], axis=-1) * zscore(sresp[1], axis=-1)).mean(axis=1)\n", 115 | "\n", 116 | "# bin rastermap by neurons\n", 117 | "n_stim = sresp.shape[-1]\n", 118 | "n_bins = 500\n", 119 | "bin_size = n_neurons // n_bins\n", 120 | "x = sresp[:, isort[:(n_neurons // bin_size) * bin_size]]\n", 121 | "x = x.reshape(3, -1, bin_size, n_stim).mean(axis=2)\n", 122 | "n_bins = x.shape[1]\n", 123 | "snr = (zscore(x[0], axis=-1) * zscore(x[1], axis=-1)).mean(axis=-1)\n", 124 | "\n", 125 | "# sort stimuli\n", 126 | "model2 = Rastermap(n_clusters=100, n_splits=0, locality=0.,\n", 127 | " n_PCs=400).fit(x.T)\n", 128 | "isort2 = model2.isort\n", 129 | "\n", 130 | "# mean over 3 repeats\n", 131 | "sresp = sresp.mean(axis=0)\n", 132 | "sresp = zscore(sresp, axis=1)\n", 133 | "x = x.mean(axis=0)\n", 134 | "x = zscore(x, axis=-1)\n", 135 | "\n", 136 | "# ridge regression from 200 image PCs to 1000 rastermap components\n", 137 | "itrain = np.arange(5000)%5>0\n", 138 | "itest = ~itrain\n", 139 | "\n", 140 | "# ridge regression on training 
data with regularizer of 1e4\n", 141 | "imgTimg = (img_pca[itrain].T @ img_pca[itrain])/itrain.sum()\n", 142 | "imgTx = (img_pca[itrain].T @ x[:, itrain].T)/itrain.sum()\n", 143 | "B = np.linalg.solve(imgTimg + 1e4 * np.eye(200), imgTx)\n", 144 | "\n", 145 | "# reconstruct the receptive fields from the PCs\n", 146 | "rfs = B.T @ img_U\n", 147 | "rfs = np.reshape(rfs, (n_bins, Ly, Lx))\n", 148 | "\n", 149 | "# evaluate model on test data\n", 150 | "rpred = img_pca[itest] @ B\n", 151 | "cpred = (zscore(rpred.T, 1) * zscore(x[:,itest], 1)).mean(1)\n", 152 | "\n", 153 | "print(f\"mean r on test data {cpred.mean()}\")\n", 154 | "\n", 155 | "np.savez(os.path.join(root, \"results\", \"v1stimresp_proc.npz\"),\n", 156 | " X_embedding=X_embedding, bin_size=bin_size, isort=isort, isort2=isort2, \n", 157 | " xpos=xpos, ypos=ypos, x=x,\n", 158 | " stim_times=stim_times, run=run, ex_stim=ex_stim, rfs=rfs)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "### alexnet activations to same images\n", 166 | "\n", 167 | "(this data will be shared upon publication)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "filename = os.path.join(root, \"data\", \"ann_fvs_Grayscale(224)_TX61_3X.npz\")\n", 177 | "sresp, ilayer, ipos, iconv, nmax = load_alexnet_data(filename)\n", 178 | "\n", 179 | "# run rastermap\n", 180 | "model = Rastermap(n_clusters=100, n_splits=3, nc_splits=25, locality=0.,\n", 181 | " n_PCs=400, mean_time=False).fit(sresp)\n", 182 | "isort = model.isort\n", 183 | "\n", 184 | "# bin by rastermap\n", 185 | "bin_size = 24\n", 186 | "X_embedding = zscore(utils.bin1d(sresp[isort], bin_size, axis=0), axis=1)\n", 187 | "\n", 188 | "# sort stimuli\n", 189 | "model2 = Rastermap(n_clusters=100, n_splits=0, locality=0.,\n", 190 | " n_PCs=400).fit(X_embedding.T)\n", 191 | "isort2 = model2.isort\n", 192 | "\n", 193 | 
"np.savez(os.path.join(root, \"results\", \"alexnet_proc.npz\"),\n", 194 | " X_embedding=X_embedding, bin_size=bin_size, isort=isort, isort2=isort2,\n", 195 | " ilayer=ilayer, ipos=ipos, iconv=iconv, nmax=nmax)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "### make figures" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "# root path has folders \"simulations\" and \"results\" with saved results\n", 212 | "# will save figures to \"figures\" folder\n", 213 | "fig = plt.figure(figsize=(14,8/3))\n", 214 | "grid = plt.GridSpec(1,5, figure=fig, left=0.01, right=0.99, top=0.85, bottom=0.15, \n", 215 | " wspace = 0.15, hspace = 0.25)\n", 216 | "il = 0\n", 217 | "d = np.load(os.path.join(root, \"simulations\", \"sim2D_results.npz\"))\n", 218 | "il = fig_splitting.panels_sim2d(fig, grid, il, **d)\n", 219 | "\n", 220 | "fig.savefig(os.path.join(root, \"figures\", \"suppfig_sim2D.pdf\"), dpi=200)\n" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "from scipy.io import loadmat\n", 230 | "\n", 231 | "fig = plt.figure(figsize=(14, 10))\n", 232 | "yratio = 14 / 10\n", 233 | "grid = plt.GridSpec(4,5, figure=fig, left=0.01, right=0.99, top=0.94, bottom=0.07, \n", 234 | " wspace = 0.15, hspace = 0.25)\n", 235 | "il = 0\n", 236 | "areas = loadmat(os.path.join(root, \"figures\", \"ctxOutlines.mat\"), \n", 237 | " squeeze_me=True)[\"coords\"]\n", 238 | "d = np.load(os.path.join(root, \"results\", \"v1stimresp_proc.npz\"))\n", 239 | "il = fig_splitting.panels_v1stimresp(fig, grid, il, yratio, areas, **d, g0=0)\n", 240 | "\n", 241 | "d = np.load(os.path.join(root, \"results\", \"alexnet_proc.npz\"))\n", 242 | "il = fig_splitting.panels_alexnet(fig, grid, il, **d, g0=2) \n", 243 | "\n", 244 | "fig.savefig(os.path.join(root, 
\"figures\", \"suppfig_visual.pdf\"), dpi=200)\n" 245 | ] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3 (ipykernel)", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.9.16" 265 | }, 266 | "varInspector": { 267 | "cols": { 268 | "lenName": 16, 269 | "lenType": 16, 270 | "lenVar": 40 271 | }, 272 | "kernels_config": { 273 | "python": { 274 | "delete_cmd_postfix": "", 275 | "delete_cmd_prefix": "del ", 276 | "library": "var_list.py", 277 | "varRefreshCmd": "print(var_dic_list())" 278 | }, 279 | "r": { 280 | "delete_cmd_postfix": ") ", 281 | "delete_cmd_prefix": "rm(", 282 | "library": "var_list.r", 283 | "varRefreshCmd": "cat(var_dic_list()) " 284 | } 285 | }, 286 | "position": { 287 | "height": "546px", 288 | "left": "845px", 289 | "right": "20px", 290 | "top": "120px", 291 | "width": "344px" 292 | }, 293 | "types_to_exclude": [ 294 | "module", 295 | "function", 296 | "builtin_function_or_method", 297 | "instance", 298 | "_Feature" 299 | ], 300 | "window_display": false 301 | }, 302 | "vscode": { 303 | "interpreter": { 304 | "hash": "998540cc2fc2836a46e99cd3ca3c37c375205941b23fd1eb4b203c48f2be758f" 305 | } 306 | } 307 | }, 308 | "nbformat": 4, 309 | "nbformat_minor": 2 310 | } 311 | -------------------------------------------------------------------------------- /paper/svca.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys, os\n", 10 | "import numpy as np\n", 11 | "from scipy.stats import zscore\n", 12 | "from neuropop import 
dimensionality\n", 13 | "\n", 14 | "# path to paper code\n", 15 | "sys.path.insert(0, '/github/rastermap/paper')\n", 16 | "from loaders import load_fish_data, load_visual_data, load_alexnet_data\n", 17 | "\n", 18 | "# path to directory with data etc\n", 19 | "### *** CHANGE THIS TO WHEREEVER YOU ARE DOWNLOADING THE DATA ***\n", 20 | "root = \"/media/carsen/ssd2/rastermap_paper/\"\n", 21 | "# (in this folder we have a \"data\" folder and a \"results\" folder)\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "### compute SVCA for large datasets\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "scovs = []\n", 38 | "\n", 39 | "dat = np.load(os.path.join(root, \"data/\", \"spont_data.npz\"))\n", 40 | "spks = dat[\"spks\"]\n", 41 | "spks = zscore(spks, axis=1)\n", 42 | "scov, varcov = dimensionality.SVCA(spks)\n", 43 | "scovs.append(scov)\n", 44 | "\n", 45 | "dat = np.load(os.path.join(root, \"data/\", \"corridor_neur.npz\"))\n", 46 | "xpos, ypos, spks = dat[\"xpos\"], dat[\"ypos\"], dat[\"spks\"]\n", 47 | "spks = zscore(spks, axis=1)\n", 48 | "scov, varcov = dimensionality.SVCA(spks)\n", 49 | "scovs.append(scov)\n", 50 | "\n", 51 | "### folder with \"subject_8\" folder\n", 52 | "data_path = os.path.join(root, \"data\")\n", 53 | "spks, F, xyz, stims, swimming, eyepos = load_fish_data(data_path, subject=8)\n", 54 | "spks = zscore(spks, axis=1)\n", 55 | "scov, varcov = dimensionality.SVCA(spks)\n", 56 | "scovs.append(scov)\n", 57 | "\n", 58 | "filename = os.path.join(root, \"data/\", \"TX61_3x.npz\")\n", 59 | "stim_filename = os.path.join(root, \"data/\", \"text5k_3x.mat\")\n", 60 | "out = load_visual_data(filename, stim_filename)\n", 61 | "spks = out[0]\n", 62 | "spks = zscore(spks, axis=1)\n", 63 | "scov, varcov = dimensionality.SVCA(spks)\n", 64 | "scovs.append(scov)\n", 65 | "\n", 66 | "env_ids = [\"PongNoFrameskip-v4\", 
\"SpaceInvadersNoFrameskip-v4\", \n", 67 | " \"EnduroNoFrameskip-v4\", \"SeaquestNoFrameskip-v4\"]\n", 68 | "\n", 69 | "for env_id in env_ids:\n", 70 | " dat = np.load(os.path.join(root, \"simulations/\", f\"qrdqn_{env_id}.npz\"))\n", 71 | " spks = dat[\"spks\"]\n", 72 | "\n", 73 | " x_std = spks.std(axis=1)\n", 74 | " igood = x_std > 1e-3\n", 75 | " print(igood.mean())\n", 76 | "\n", 77 | " S = zscore(spks[igood], axis=1)\n", 78 | " scov, varcov = dimensionality.SVCA(S)\n", 79 | " scovs.append(scov)\n", 80 | "\n", 81 | "filename = os.path.join(root, \"data\", \"ann_fvs_Grayscale(224)_TX61_3X.npz\")\n", 82 | "sresp, ilayer, ipos, iconv, nmax = load_alexnet_data(filename)\n", 83 | "for l in range(5):\n", 84 | " S = zscore(sresp[ilayer==l].copy(), axis=1)\n", 85 | " scov, varcov = dimensionality.SVCA(S)\n", 86 | " scovs.append(scov)\n", 87 | "\n", 88 | "scovs_all = np.nan * np.zeros((13, 1024))\n", 89 | "for k, scov in enumerate(scovs):\n", 90 | " scovs_all[k, :len(scov)] = scov\n", 91 | "\n", 92 | "np.save(os.path.join(root, \"results\", \"scovs.npy\"), scovs_all)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "### make figure" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "from fig_utils import *\n", 109 | "\n", 110 | "scovs_all = np.load(os.path.join(root, \"results\", \"scovs.npy\"))\n", 111 | "\n", 112 | "fig = plt.figure(figsize=(14, 7.5))\n", 113 | "grid = plt.GridSpec(3,5, figure=fig, left=0.06, right=0.98, top=0.96, bottom=0.08, \n", 114 | " wspace = 0.8, hspace = 0.8)\n", 115 | "transl = mtransforms.ScaledTranslation(-13 / 72, 20 / 72, fig.dpi_scale_trans)\n", 116 | "il = 0\n", 117 | "\n", 118 | "titles = [\"spontaneous activity, mouse\", \"virtual reality, mouse\", \"zebrafish wholebrain activity\", \"flashed images, mouse\",\n", 119 | " \"QRDQN - Pong\", \"QRDQN - SpaceInvaders\", \"QRDQN - Enduro\", 
\"QRDQN - Seaquest\",\n", 120 | " \"Alexnet - layer 1\", \"Alexnet - layer 2\", \"Alexnet - layer 3\", \"Alexnet - layer 4\", \"Alexnet - layer 5\"]\n", 121 | "\n", 122 | "for k, scov in enumerate(scovs_all):\n", 123 | " if k<8:\n", 124 | " if k==2 or k==3:\n", 125 | " ki = 3 if k==2 else 2\n", 126 | " else:\n", 127 | " ki = k\n", 128 | " ax = plt.subplot(grid[ki//4, ki%4])\n", 129 | " else:\n", 130 | " ax = plt.subplot(grid[2, k-8])\n", 131 | " ss = scov.copy()\n", 132 | " ss /= ss[0]\n", 133 | " #ss *= len(ss)\n", 134 | " alpha, ypred = dimensionality.get_powerlaw(ss, np.arange(11, 500))\n", 135 | " #print(alpha)\n", 136 | " ax.loglog(np.arange(1, len(ss)+1), ss, color=[0,0.5,1])\n", 137 | " ax.plot(np.arange(len(ypred))+1, ypred, color=\"k\", lw=1)\n", 138 | " ax.text(0.5, 0.7, rf\"$\\alpha$={alpha:.2f}\", transform=ax.transAxes)\n", 139 | " ax.set_title(titles[k], fontsize=\"medium\")#, fontsize=\"small\")\n", 140 | " ax.set_ylim([1/1024, 1.5])\n", 141 | " ax.set_xlim([1, 1024])\n", 142 | " ax.set_xticks([1, 10, 100, 1000])\n", 143 | " ax.set_xlabel(\"PC #\")\n", 144 | " ax.set_ylabel(\"shared covariance\")\n", 145 | " ax.xaxis.get_minor_locator().set_params(numticks=99, subs=np.arange(0.1, 1, 0.1))\n", 146 | "\n", 147 | "fig.savefig(os.path.join(root, \"figures\", \"suppfig_powerlaws.pdf\"))" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "rastermap", 154 | "language": "python", 155 | "name": "python3" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.9.16" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 2 172 | } 173 | -------------------------------------------------------------------------------- /rastermap/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | from rastermap.rastermap import Rastermap, default_settings, settings_info -------------------------------------------------------------------------------- /rastermap/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | from scipy.stats import zscore 5 | import numpy as np 6 | import argparse 7 | import os 8 | from rastermap import Rastermap 9 | from rastermap.io import load_activity, load_spike_times 10 | 11 | try: 12 | from rastermap.gui import gui 13 | GUI_ENABLED = True 14 | except ImportError as err: 15 | GUI_ERROR = err 16 | GUI_ENABLED = False 17 | GUI_IMPORT = True 18 | except Exception as err: 19 | GUI_ENABLED = False 20 | GUI_ERROR = err 21 | GUI_IMPORT = False 22 | raise 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser(description="spikes") 26 | parser.add_argument("--S", default=[], type=str, help="spiking matrix") 27 | parser.add_argument("--spike_times", default=[], type=str, help="spike_times.npy") 28 | parser.add_argument("--spike_clusters", default=[], type=str, help="spike_clusters.npy") 29 | parser.add_argument("--st_bin", default=100, type=float, help="bin size in milliseconds for spike times") 30 | parser.add_argument("--proc", default=[], type=str, 31 | help="processed data file 'embedding.npy'") 32 | parser.add_argument("--ops", default=[], type=str, help="options file 'ops.npy'") 33 | parser.add_argument("--iscell", default=[], type=str, 34 | help="which cells to select for processing") 35 | args = parser.parse_args() 36 | 37 | if len(args.ops) > 0 and (len(args.S) > 0 or 38 | (len(args.spike_times) > 0 and len(args.spike_clusters) > 0)): 39 | if 
len(args.S) > 0: 40 | X, Usv, Vsv, xy = load_activity(args.S) 41 | else: 42 | Usv, Vsv, xy = None, None, None 43 | X = load_spike_times(args.spike_times, args.spike_clusters, args.st_bin) 44 | ops = np.load(args.ops, allow_pickle=True).item() 45 | if len(args.iscell) > 0: 46 | iscell = np.load(args.iscell) 47 | if iscell.ndim > 1: 48 | iscell = iscell[:, 0].astype("bool") 49 | else: 50 | iscell = iscell.astype("bool") 51 | if iscell.size == X.shape[0]: 52 | X = X[iscell, :] 53 | print("iscell found and used to select neurons") 54 | 55 | if Usv is not None and Usv.ndim==3: 56 | Usv = Usv.reshape(-1, Usv.shape[-1]) 57 | 58 | model = Rastermap(**ops) 59 | train_time = np.ones(X.shape[1] if X is not None 60 | else Vsv.shape[0], "bool") 61 | if X is not None: 62 | if ("end_time" in ops and ops["end_time"] == -1) or "end_time" not in ops: 63 | ops["end_time"] = X.shape[1] 64 | ops["start_time"] = 0 65 | else: 66 | train_time = np.zeros(X.shape[1], "bool") 67 | train_time[np.arange(ops["start_time"], ops["end_time"]).astype(int)] = 1 68 | X = X[:, train_time] 69 | 70 | model.fit(data=X, Usv=Usv, Vsv=Vsv) 71 | 72 | proc = { 73 | "filename": args.S if len(args.S) > 0 else args.spike_times, 74 | "filename_cluid": args.spike_clusters if args.spike_clusters else None, 75 | "st_bin": args.st_bin if args.spike_clusters else None, 76 | "save_path": os.path.split(args.S)[0] if args.S else os.path.split(args.spike_times)[0], 77 | "isort": model.isort, 78 | "embedding": model.embedding, 79 | "user_clusters": None, 80 | "ops": ops, 81 | } 82 | basename, fname = os.path.split(args.S) if args.S else os.path.split(args.spike_times) 83 | fname = os.path.splitext(fname)[0] 84 | try: 85 | np.save(os.path.join(basename, f"{fname}_embedding.npy"), proc) 86 | except Exception as e: 87 | print("ERROR: no permission to write to data folder") 88 | #os.path.dirname(args.ops) 89 | np.save("embedding.npy", proc) 90 | else: 91 | if not GUI_ENABLED: 92 | print("GUI ERROR: %s" % GUI_ERROR) 93 | if 
GUI_IMPORT: 94 | print( 95 | "GUI FAILED: GUI dependencies may not be installed, to install, run" 96 | ) 97 | print(" pip install rastermap[gui]") 98 | else: 99 | # use proc path if it exists, else use S path 100 | filename = args.proc if len(args.proc) > 0 else None 101 | filename = args.S if len(args.S) > 0 and filename is None else filename 102 | gui.run(filename=filename, proc=len(args.proc) > 0) 103 | -------------------------------------------------------------------------------- /rastermap/cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | import numpy as np 5 | import warnings 6 | from sklearn.cluster import MiniBatchKMeans as KMeans 7 | from scipy.stats import zscore 8 | 9 | def _scaled_kmeans_init(X, n_clusters=100, n_local_trials=None): 10 | """Init n_clusters seeds according to k-means++ using correlations 11 | 12 | adapted from scikit-learn's code for correlation distance 13 | 14 | Parameters 15 | ----------- 16 | X : array shape (n_samples, n_features) 17 | The data to pick seeds for. 18 | n_clusters : integer 19 | The number of seeds to choose 20 | n_local_trials : integer, optional 21 | The number of seeding trials for each center (except the first), 22 | of which the one reducing inertia the most is greedily chosen. 23 | Set to None to make the number of trials depend logarithmically 24 | on the number of seeds (2+log(k)); this is the default. 25 | 26 | Returns 27 | ----------- 28 | X_nodes : array shape (n_clusters, n_features) 29 | cluster centers for initializing kmeans 30 | 31 | Notes 32 | ----- 33 | Selects initial cluster centers for k-mean clustering in a smart way 34 | to speed up convergence. see: Arthur, D. and Vassilvitskii, S. 35 | "k-means++: the advantages of careful seeding". ACM-SIAM symposium 36 | on Discrete algorithms. 
2007 37 | Version ported from http://www.stanford.edu/~darthur/kMeansppTest.zip, 38 | which is the implementation used in the aforementioned paper. 39 | """ 40 | n_samples, n_features = X.shape 41 | X_nodes = np.empty((n_clusters, n_features), dtype=X.dtype) 42 | if n_local_trials is None: 43 | n_local_trials = min(2 + 10 * int(np.log(n_clusters)), n_samples - 1) 44 | else: 45 | n_local_trials = min(n_samples-1, n_local_trials) 46 | 47 | # Pick first center randomly 48 | center_id = np.random.randint(n_samples) 49 | X_nodes[0] = X[center_id] 50 | 51 | # Initialize list of closest distances and calculate current potential 52 | X_norm = (X**2).sum(axis=1)**0.5 53 | closest_dist_sq = 1 - (X @ X_nodes[0]) / (X_norm * (X_nodes[0]**2).sum()**0.5) 54 | current_pot = closest_dist_sq.sum() 55 | 56 | # Pick the remaining n_clusters-1 points 57 | for c in range(1, n_clusters): 58 | # Choose center candidates by sampling with probability proportional 59 | # to the squared distance to the closest existing center 60 | rand_vals = np.random.random_sample(n_local_trials) * current_pot 61 | candidate_ids = np.searchsorted(np.cumsum(closest_dist_sq.astype(np.float64)), 62 | rand_vals) 63 | 64 | # Compute distances to center candidates 65 | X_candidates = X[candidate_ids] 66 | X_candidates_norm = (X_candidates**2).sum(axis=1)**0.5 67 | distance_to_candidates = 1 - (X @ X_candidates.T) / np.outer( 68 | X_norm, X_candidates_norm) 69 | 70 | # Decide which candidate is the best 71 | best_candidate = None 72 | best_pot = None 73 | best_dist_sq = None 74 | for trial in range(n_local_trials): 75 | # Compute potential when including center candidate 76 | new_dist_sq = np.minimum(closest_dist_sq, distance_to_candidates[:, trial]) 77 | new_pot = new_dist_sq.sum() 78 | 79 | # Store result if it is the best local trial so far 80 | if (best_candidate is None) or (new_pot < best_pot): 81 | best_candidate = candidate_ids[trial] 82 | best_pot = new_pot 83 | best_dist_sq = new_dist_sq 84 | 85 | # 
Permanently add best center candidate found in local tries 86 | X_nodes[c] = X[best_candidate] 87 | current_pot = best_pot 88 | closest_dist_sq = best_dist_sq 89 | 90 | return X_nodes 91 | 92 | 93 | def scaled_kmeans(X, n_clusters=100, n_iter=50, n_local_trials=100, 94 | init="kmeans++", random_state=0): 95 | """ kmeans using correlation distance 96 | 97 | Parameters 98 | ----------- 99 | X : array shape (n_samples, n_features) 100 | The data to cluster 101 | n_clusters : integer, optional (default=100) 102 | The number of clusters 103 | n_iter : integer, optional (default=50) 104 | number of iterations 105 | random_state : integer, optional (default=0) 106 | seed of numpy random number generator 107 | 108 | Returns 109 | ----------- 110 | X_nodes : array shape (n_clusters, n_features) 111 | cluster centers found by kmeans 112 | imax : array (n_samples) 113 | best cluster for each data point 114 | 115 | """ 116 | n_samples, n_features = X.shape 117 | # initialize with kmeans++ 118 | np.random.seed(random_state) 119 | if init == "kmeans++": 120 | X_nodes = _scaled_kmeans_init(X, n_clusters=n_clusters, 121 | n_local_trials=n_local_trials) 122 | else: 123 | X_nodes = np.random.randn(n_clusters, n_features) * (X**2).sum(axis=0)**0.5 124 | X_nodes = X_nodes / (1e-4 + (X_nodes**2).sum(axis=1)[:, np.newaxis])**.5 125 | 126 | # iterate and reassign neurons 127 | for j in range(n_iter): 128 | cc = X @ X_nodes.T 129 | imax = np.argmax(cc, axis=1) 130 | cc = cc * (cc > np.max(cc, 1)[:, np.newaxis] - 1e-6) 131 | X_nodes = cc.T @ X 132 | X_nodes = X_nodes / (1e-10 + (X_nodes**2).sum(axis=1)[:, np.newaxis])**.5 133 | X_nodes_norm = (X_nodes**2).sum(axis=1)**0.5 134 | X_nodes = X_nodes[X_nodes_norm > 0] 135 | X_nodes = X_nodes[X_nodes[:, 0].argsort()] 136 | 137 | if X_nodes.shape[0] < n_clusters // 2 and init == "kmeans++": 138 | warnings.warn( 139 | "found fewer than half the n_clusters that the user specified, rerunning with random initialization" 140 | ) 141 | X_nodes, imax = 
scaled_kmeans(X, n_clusters=n_clusters, n_iter=n_iter, 142 | init="random", random_state=random_state) 143 | else: 144 | cc = X @ X_nodes.T 145 | imax = cc.argmax(axis=1) 146 | 147 | if X_nodes.shape[0] < n_clusters and init != "kmeans++": 148 | warnings.warn( 149 | "found fewer clusters than user specified, try reducing n_clusters and/or reduce n_splits and/or increase n_PCs" 150 | ) 151 | 152 | return X_nodes, imax 153 | 154 | def kmeans(X, n_clusters=100, random_state=0): 155 | np.random.seed(random_state) 156 | #X_nodes = (np.random.randn(n_clusters, n_features) / 157 | # (1 + np.arange(n_features))**0.5) 158 | #X_nodes = X_nodes / (1e-4 + (X_nodes**2).sum(axis=1)[:,np.newaxis])**.5 159 | model = KMeans(n_init=1, init="k-means++", n_clusters=n_clusters, 160 | random_state=random_state).fit(X) 161 | X_nodes = model.cluster_centers_ 162 | X_nodes = X_nodes / (1e-10 + ((X_nodes**2).sum(axis=1))[:, np.newaxis])**.5 163 | imax = model.labels_ 164 | X_nodes = X_nodes[X_nodes[:, 0].argsort()] 165 | cc = X @ X_nodes.T 166 | imax = cc.argmax(axis=1) 167 | return X_nodes, imax 168 | 169 | def compute_cc_tdelay(V, U_nodes, time_lag_window=5, symmetric=False): 170 | """ compute correlation matrix of clusters at time offsets and take max """ 171 | X_nodes = U_nodes @ V.T 172 | X_nodes = zscore(X_nodes, axis=1) 173 | n_nodes, nt = X_nodes.shape 174 | 175 | tshifts = np.arange(-time_lag_window * symmetric, time_lag_window + 1) 176 | cc_tdelay = np.zeros((n_nodes, n_nodes, len(tshifts)), np.float32) 177 | for i, tshift in enumerate(tshifts): 178 | if tshift < 0: 179 | cc_tdelay[:, :, i] = ((X_nodes[:, :nt + tshift] @ X_nodes[:, -tshift:].T) / 180 | (nt - tshift)) 181 | else: 182 | cc_tdelay[:, :, i] = ((X_nodes[:, tshift:] @ X_nodes[:, :nt - tshift].T) / 183 | (nt - tshift)) 184 | 185 | return cc_tdelay.max(axis=-1) -------------------------------------------------------------------------------- /rastermap/gui/guiparts.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | from qtpy import QtGui, QtCore, QtWidgets 5 | from qtpy.QtWidgets import QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox, QStyle, QStyleOptionSlider 6 | import pyqtgraph as pg 7 | from pyqtgraph import functions as fn 8 | from pyqtgraph import Point 9 | import numpy as np 10 | 11 | 12 | class TimeROI(pg.LinearRegionItem): 13 | 14 | def __init__(self, parent=None, color=[128, 128, 255, 50], bounds=[0, 100], 15 | roi_id=0): 16 | self.color = color 17 | self.parent = parent 18 | self.pen = pg.mkPen(pg.mkColor(*self.color), width=2, style=QtCore.Qt.SolidLine) 19 | self.brush = pg.mkBrush(pg.mkColor(*self.color[:3], 50)) 20 | self.hover_brush = pg.mkBrush(pg.mkColor(*self.color)) 21 | self.roi_id = roi_id 22 | self.bounds = bounds 23 | super().__init__(orientation="vertical", bounds=bounds, pen=self.pen, 24 | brush=self.brush, hoverBrush=self.hover_brush) 25 | 26 | def time_set(self): 27 | region = self.getRegion() 28 | region = [int(region[0]), int(region[1])] 29 | region[0] = max(self.bounds[0], region[0]) 30 | region[1] = min(self.bounds[1] - 1, region[1]) 31 | x0, x1 = region[0], region[1] + 1 32 | 33 | self.parent.xrange = slice(x0, x1) 34 | # Update zoom in plot 35 | self.parent.imgROI.setImage(self.parent.sp_smoothed[:, self.parent.xrange]) 36 | self.parent.imgROI.setLevels( 37 | [self.parent.sat[0] / 100., self.parent.sat[1] / 100.]) 38 | self.parent.p2.setXRange(0, x1 - x0, padding=0) 39 | self.parent.p2.show() 40 | 41 | # update other plots 42 | self.parent.p3.setLimits(xMin=x0, xMax=x1) 43 | self.parent.p3.setXRange(x0, x1) 44 | self.parent.p3.show() 45 | self.parent.p4.setLimits(xMin=x0, xMax=x1) 46 | self.parent.p4.setXRange(x0, x1) 47 | 
self.parent.p4.show() 48 | 49 | 50 | class ClusterROI(pg.LinearRegionItem): 51 | 52 | def __init__(self, parent=None, color=[128, 128, 255, 50], bounds=[0, 100], 53 | roi_id=0): 54 | self.color = color 55 | self.parent = parent 56 | self.pen = pg.mkPen(pg.mkColor(*self.color), width=2, style=QtCore.Qt.SolidLine) 57 | self.brush = pg.mkBrush(pg.mkColor(*self.color[:3], 50)) 58 | self.hover_brush = pg.mkBrush(pg.mkColor(*self.color)) 59 | self.roi_id = roi_id 60 | self.bounds = bounds 61 | super().__init__(orientation="horizontal", bounds=bounds, pen=self.pen, 62 | brush=self.brush, hoverBrush=self.hover_brush) 63 | self.sigRegionChanged.connect(self.cluster_set) 64 | 65 | def cluster_set(self): 66 | region = self.getRegion() 67 | region = [int(region[0]), int(region[1])] 68 | region[0] = max(self.bounds[0], region[0]) 69 | region[1] = min(self.bounds[1], region[1]) 70 | if len(self.parent.cluster_slices) > self.roi_id: 71 | self.parent.cluster_slices[self.roi_id] = slice(region[0], region[1]) 72 | self.parent.selected = self.parent.cluster_slices[self.roi_id] 73 | self.parent.plot_traces(roi_id=self.roi_id) 74 | if self.parent.neuron_pos is not None or self.parent.behav_data is not None: 75 | self.parent.update_scatter(roi_id=self.roi_id) 76 | if hasattr(self.parent, "PlaneWindow"): 77 | self.parent.PlaneWindow.update_plots(roi_id=self.roi_id) 78 | self.parent.p3.show() 79 | 80 | def mouseClickEvent(self, ev): 81 | if self.moving and ev.button() == QtCore.Qt.RightButton: 82 | ev.accept() 83 | for i, l in enumerate(self.lines): 84 | l.setPos(self.startPositions[i]) 85 | self.moving = False 86 | self.cluster_set() 87 | self.sigRegionChanged.emit(self) 88 | self.sigRegionChangeFinished.emit(self) 89 | elif ev.button() == QtCore.Qt.LeftButton and ev.modifiers( 90 | ) == QtCore.Qt.ControlModifier: 91 | if len(self.parent.cluster_rois) > 0: 92 | print("removing cluster roi") 93 | self.remove() 94 | 95 | def remove(self): 96 | # delete color and add to end of list 97 | del 
self.parent.colors[self.roi_id] 98 | self.parent.colors.append(self.color) 99 | 100 | # delete slice and ROI 101 | del self.parent.cluster_slices[self.roi_id] 102 | self.parent.p2.removeItem(self.parent.cluster_rois[self.roi_id]) 103 | del self.parent.cluster_rois[self.roi_id] 104 | 105 | # remove scatter plots 106 | self.parent.p5.removeItem(self.parent.scatter_plots[0][self.roi_id + 1]) 107 | del self.parent.scatter_plots[0][self.roi_id + 1] 108 | self.parent.scatter_plots[0].append(pg.ScatterPlotItem()) 109 | self.parent.p5.addItem(self.parent.scatter_plots[0][-1]) 110 | self.parent.p5.addItem(self.parent.cluster_plots[-1]) 111 | 112 | # remove avg activity 113 | self.parent.p3.removeItem(self.parent.cluster_plots[self.roi_id]) 114 | del self.parent.cluster_plots[self.roi_id] 115 | self.parent.cluster_plots.append(pg.PlotDataItem()) 116 | self.parent.p3.addItem(self.parent.cluster_plots[-1]) 117 | 118 | # reindex roi_id 119 | for i in range(len(self.parent.cluster_rois)): 120 | self.parent.cluster_rois[i].roi_id = i 121 | 122 | # update avg plot 123 | self.parent.plot_traces() 124 | 125 | 126 | # custom vertical label 127 | class VerticalLabel(QWidget): 128 | 129 | def __init__(self, text=None): 130 | super(self.__class__, self).__init__() 131 | self.text = text 132 | 133 | def paintEvent(self, event): 134 | painter = QtGui.QPainter(self) 135 | painter.setPen(QtCore.Qt.white) 136 | painter.translate(0, 0) 137 | painter.rotate(90) 138 | if self.text: 139 | painter.drawText(0, 0, self.text) 140 | painter.end() 141 | -------------------------------------------------------------------------------- /rastermap/gui/menus.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | from qtpy import QtGui, QtCore, QtWidgets 5 | from qtpy.QtWidgets import QAction 6 | import pyqtgraph as pg 7 | import numpy as np 8 | from . 
def mainmenu(parent):
    """Build the File / Run / Save menus on the main window.

    Actions that must be toggled later (enabled once data is loaded) are
    stored as attributes on ``parent``; purely local actions are not.
    """
    main_menu = parent.menuBar()

    file_menu = main_menu.addMenu("&File")

    # load a raw neurons-by-time matrix
    loadMat = QAction("&Load data matrix (neurons by time)", parent)
    loadMat.setShortcut("Ctrl+L")
    loadMat.triggered.connect(lambda: io.load_mat(parent, name=None))
    parent.addAction(loadMat)
    file_menu.addAction(loadMat)

    # load spike times + cluster ids (e.g. phy/kilosort output)
    loadSt = QAction("Load spike_times and spike_&Clusters", parent)
    loadSt.setShortcut("Ctrl+C")
    loadSt.triggered.connect(lambda: io.load_st_clu(parent, name=None))
    parent.addAction(loadSt)
    file_menu.addAction(loadSt)

    parent.loadXY = QAction("Load &XY(z) positions of neurons", parent)
    parent.loadXY.setShortcut("Ctrl+X")
    parent.loadXY.triggered.connect(lambda: io.load_neuron_pos(parent))
    parent.addAction(parent.loadXY)
    file_menu.addAction(parent.loadXY)

    # load Z-stack
    # NOTE(review): this attribute is overwritten by the "load processed
    # data" action below — after this function runs, parent.loadProc refers
    # to the Ctrl+P action only, so code elsewhere that enables/disables
    # parent.loadProc cannot reach this Z-stack action. Probably this one
    # should have its own name (e.g. parent.loadZ) — confirm against gui.py.
    parent.loadProc = QAction("Load &Z-stack (mean images)", parent)
    parent.loadProc.setShortcut("Ctrl+Z")
    parent.loadProc.triggered.connect(lambda: io.load_zstack(parent, name=None))
    parent.addAction(parent.loadProc)
    file_menu.addAction(parent.loadProc)

    # load an n-dimensional behavior variable; disabled until data is loaded
    parent.loadNd = QAction("Load &N-d variable (times or cont.)", parent)
    parent.loadNd.setShortcut("Ctrl+N")
    parent.loadNd.triggered.connect(lambda: io.get_behav_data(parent))
    parent.loadNd.setEnabled(False)
    parent.addAction(parent.loadNd)
    file_menu.addAction(parent.loadNd)

    # load processed data (overwrites parent.loadProc, see note above)
    parent.loadProc = QAction("&Load processed data", parent)
    parent.loadProc.setShortcut("Ctrl+P")
    parent.loadProc.triggered.connect(lambda: io.load_proc(parent, name=None))
    parent.addAction(parent.loadProc)
    file_menu.addAction(parent.loadProc)

    # export figure
    # NOTE(review): unlike the actions above, exportFig is never added to
    # file_menu and has no shortcut, so it appears unreachable from the UI —
    # a file_menu.addAction(exportFig) seems to be missing; confirm.
    exportFig = QAction("Export as image (svg)", parent)
    exportFig.triggered.connect(lambda: export_fig(parent))
    exportFig.setEnabled(True)
    parent.addAction(exportFig)

    run_menu = main_menu.addMenu("&Run")
    # run rastermap on the loaded data; disabled until data is loaded
    parent.runRmap = QAction("&Run rastermap", parent)
    parent.runRmap.setShortcut("Ctrl+R")
    parent.runRmap.triggered.connect(lambda: run.RunWindow(parent))
    parent.runRmap.setEnabled(False)
    parent.addAction(parent.runRmap)
    run_menu.addAction(parent.runRmap)

    #view_menu = main_menu.addMenu("&Views")
    #parent.view3D = QAction("&View multi-plane data", parent)
    #parent.view3D.setShortcut("Ctrl+V")
    #parent.view3D.triggered.connect(lambda: plane_window(parent))
    #parent.view3D.setEnabled(False)
    #parent.addAction(parent.view3D)
    #view_menu.addAction(parent.view3D)

    save_menu = main_menu.addMenu("&Save")
    # Save processed data
    parent.saveProc = QAction("&Save processed data", parent)
    parent.saveProc.setShortcut("Ctrl+S")
    parent.saveProc.triggered.connect(lambda: io.save_proc(parent))
    parent.addAction(parent.saveProc)
    save_menu.addAction(parent.saveProc)

    #help_menu = main_menu.addMenu("&Help")
    ## Save processed data
    #helpmenu = QAction("&Save processed data", parent)
    #parent.saveProc.setShortcut("Ctrl+S")
    #parent.saveProc.triggered.connect(lambda: io.save_proc(parent))
    #parent.addAction(parent.saveProc)
    #save_menu.addAction(parent.saveProc)


def export_fig(parent):
    """Open pyqtgraph's export dialog targeted at the main raster plot p0."""
    parent.win.scene().contextMenuItem = parent.p0
    parent.win.scene().showExportDialog()
3 | """ 4 | import numpy as np 5 | import os, sys 6 | from qtpy import QtGui, QtCore 7 | from qtpy.QtWidgets import QMainWindow, QApplication, QSizePolicy, QDialog, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox, QButtonGroup, QRadioButton, QStatusBar, QTextEdit 8 | from . import io 9 | 10 | 11 | ### custom QDialog which allows user to fill in ops and run rastermap 12 | class RunWindow(QDialog): 13 | 14 | def __init__(self, parent=None): 15 | super(RunWindow, self).__init__(parent) 16 | self.setGeometry(50, 50, 600, 600) 17 | self.setWindowTitle("Choose rastermap run options") 18 | self.win = QWidget(self) 19 | self.layout = QGridLayout() 20 | self.layout.setHorizontalSpacing(25) 21 | self.win.setLayout(self.layout) 22 | 23 | print( 24 | ">>> importing rastermap functions (will be slow if you haven't run rastermap before) <<<" 25 | ) 26 | from rastermap import default_settings, settings_info, Rastermap 27 | # default ops 28 | self.ops = default_settings() 29 | info = settings_info() 30 | keys = [ 31 | "n_clusters", "n_PCs", "time_lag_window", "locality", "grid_upsample", 32 | "time_bin", "n_splits" 33 | ] 34 | tooltips = [info[key] for key in keys] 35 | bigfont = QtGui.QFont("Arial", 10, QtGui.QFont.Bold) 36 | l = 0 37 | self.keylist = [] 38 | self.editlist = [] 39 | k = 0 40 | for key in keys: 41 | qedit = LineEdit(k, key, self) 42 | qlabel = QLabel(key) 43 | qlabel.setToolTip(tooltips[k]) 44 | qedit.set_text(self.ops) 45 | qedit.setFixedWidth(90) 46 | self.layout.addWidget(qlabel, k, 0, 1, 1) 47 | self.layout.addWidget(qedit, k, 1, 1, 1) 48 | self.keylist.append(key) 49 | self.editlist.append(qedit) 50 | k += 1 51 | 52 | #for j in range(10): 53 | # self.layout.addWidget(QLabel("."),19,4+j,1,1) 54 | 55 | self.layout.setColumnStretch(4, 10) 56 | self.runButton = QPushButton("RUN") 57 | self.runButton.clicked.connect(lambda: self.run_RMAP(parent)) 58 | 
self.layout.addWidget(self.runButton, 19, 0, 1, 1) 59 | #self.runButton.setEnabled(False) 60 | self.textEdit = QTextEdit() 61 | self.textEdit.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 62 | self.layout.addWidget(self.textEdit, 20, 0, 30, 14) 63 | self.process = QtCore.QProcess(self) 64 | self.process.readyReadStandardOutput.connect(self.stdout_write) 65 | self.process.readyReadStandardError.connect(self.stderr_write) 66 | # disable the button when running the rastermap process 67 | self.process.started.connect(self.started) 68 | self.process.finished.connect(lambda: self.finished(parent)) 69 | self.process.errorOccurred.connect(self.errored) 70 | # stop process 71 | self.stopButton = QPushButton("STOP") 72 | self.stopButton.setEnabled(False) 73 | self.layout.addWidget(self.stopButton, 19, 1, 1, 1) 74 | self.stopButton.clicked.connect(self.stop) 75 | 76 | self.show() 77 | 78 | def run_RMAP(self, parent): 79 | self.finish = True 80 | self.error = False 81 | self.save_text() 82 | ops_path = os.path.join(os.getcwd(), "rmap_ops.npy") 83 | np.save(ops_path, self.ops) 84 | print("Running rastermap with command:") 85 | if parent.from_spike_times: 86 | cmd = f"-u -W ignore -m rastermap --ops {ops_path} --spike_times {parent.fname} --spike_clusters {parent.fname_cluid} --st_bin {parent.st_bin}" 87 | else: 88 | cmd = f"-u -W ignore -m rastermap --ops {ops_path} --S {parent.fname}" 89 | if parent.file_iscell is not None: 90 | cmd += f" --iscell {parent.file_iscell}" 91 | print("python " + cmd) 92 | self.process.start(sys.executable, cmd.split(" ")) 93 | 94 | def stop(self): 95 | self.finish = False 96 | self.process.kill() 97 | 98 | def errored(self, error): 99 | print("ERROR") 100 | process = self.process 101 | print("error: ", error, "-", " ".join([process.program()] + process.arguments())) 102 | 103 | def started(self): 104 | self.runButton.setEnabled(False) 105 | self.stopButton.setEnabled(True) 106 | 107 | def finished(self, parent): 108 | 
class LineEdit(QLineEdit):
    """Text box for a single settings (ops) value, keyed by the entry it edits.

    ``set_text`` renders an ops value into the box; ``get_text`` parses the
    box back, using the current ops value only as a type template.
    """

    def __init__(self, k, key, parent=None):
        super(LineEdit, self).__init__(parent)
        self.key = key  # name of the ops entry this box edits
        #self.textEdited.connect(lambda: self.edit_changed(parent.ops, k))

    def get_text(self, okey):
        """Parse the current text into the type of ``okey``.

        Parameters
        ----------
        okey : the existing ops value for this key (type template only).

        Returns
        -------
        The parsed value: a list/int for "diameter"/"block_size", otherwise
        float/str matching ``okey``'s type, or int for anything else.
        """
        key = self.key
        if key == "diameter" or key == "block_size":
            # comma-separated pair allowed, e.g. "12, 14"
            diams = self.text().replace(" ", "").split(",")
            if len(diams) > 1:
                okey = [int(diams[0]), int(diams[1])]
            else:
                okey = int(diams[0])
        else:
            if type(okey) is float:
                okey = float(self.text())
            elif type(okey) is str:
                okey = self.text()
            else:
                # BUG FIX: the original condition `type(okey) is int or bool`
                # was always True (`or bool` evaluates the truthy class), so
                # every non-float/str value was coerced with int(). Keep that
                # catch-all behavior (covers int and bool defaults), but make
                # it explicit instead of hiding it behind a dead test.
                okey = int(self.text())
        return okey

    def set_text(self, ops):
        """Render ``ops[self.key]`` into the text box as a string."""
        key = self.key
        if key == "diameter" or key == "block_size":
            # these may be either a single int or a pair
            if (type(ops[key]) is not int) and (len(ops[key]) > 1):
                dstr = str(int(ops[key][0])) + ", " + str(int(ops[key][1]))
            else:
                dstr = str(int(ops[key]))
        else:
            if type(ops[key]) is bool:
                # bools shown as 0/1 so get_text round-trips them as ints
                dstr = str(int(ops[key]))
            elif type(ops[key]) is str:
                dstr = ops[key]
            else:
                dstr = str(ops[key])
        self.setText(dstr)
import colormaps 12 | 13 | nclust_max = 100 14 | 15 | 16 | class PlaneWindow(QMainWindow): 17 | 18 | def __init__(self, parent=None): 19 | super(PlaneWindow, self).__init__(parent) 20 | self.parent = parent 21 | self.setGeometry(50, 50, 800, 800) 22 | self.setWindowTitle("view neurons multi-plane") 23 | self.win = QWidget(self) 24 | self.layout = QGridLayout() 25 | self.layout.setHorizontalSpacing(25) 26 | self.win.setLayout(self.layout) 27 | 28 | self.cwidget = QWidget(self) 29 | self.setCentralWidget(self.cwidget) 30 | self.l0 = QGridLayout() 31 | self.cwidget.setLayout(self.l0) 32 | self.win = pg.GraphicsLayoutWidget() 33 | self.l0.addWidget(self.win, 0, 0, 1, 1) 34 | layout = self.win.ci.layout 35 | 36 | self.menuBar().clear() 37 | 38 | neuron_pos = self.parent.neuron_pos.copy() 39 | if neuron_pos.shape[-1] == 3: 40 | y, x, z = neuron_pos.T 41 | else: 42 | y, x = neuron_pos.T 43 | z = np.zeros_like(y) 44 | 45 | zgt, z = np.unique(z, return_inverse=True) 46 | zgt = zgt.astype("int") 47 | n_planes = z.max() + 1 48 | if n_planes > 30: 49 | bins = np.linspace(0, n_planes + 1, 30) 50 | bin_centers = bins[:-1] + np.diff(bins)[0] / 2 51 | zgt = zgt[bin_centers.astype(int)] 52 | z = np.digitize(z, bins) 53 | z = np.unique(z, return_inverse=True)[1].astype(int) 54 | n_planes = z.max() + 1 55 | 56 | Ly, Lx = y.max(), x.max() 57 | 58 | nX = np.ceil(np.sqrt(float(Ly) * float(Lx) * n_planes) / float(Lx)) 59 | nX = int(nX) 60 | nY = n_planes // nX 61 | print(n_planes, nX, nY) 62 | self.x, self.y, self.z = x, y, z 63 | self.nX, self.nY = nX, nY 64 | self.plots = [] 65 | self.scatter_plots = [] 66 | self.imgs = [] 67 | 68 | self.parent.PlaneWindow = self 69 | 70 | self.all_neurons_checkBox = self.parent.all_neurons_checkBox 71 | self.embedding = self.parent.embedding 72 | self.sorting = self.parent.sorting 73 | self.cluster_rois = self.parent.cluster_rois 74 | self.cluster_slices = self.parent.cluster_slices 75 | self.colors = self.parent.colors 76 | self.smooth_bin = 
self.parent.smooth_bin 77 | self.zstack = self.parent.zstack 78 | 79 | for ii in range(self.nY): 80 | for jj in range(self.nX): 81 | iplane = ii * self.nX + jj 82 | self.plots.append( 83 | self.win.addPlot(title=f"z = {iplane}", row=ii, col=jj, rowspan=1, 84 | colspan=1)) 85 | self.scatter_plots.append([]) 86 | self.imgs.append(pg.ImageItem()) 87 | self.plots[-1].addItem(self.imgs[-1]) 88 | if self.zstack is not None: 89 | self.imgs[-1].setImage(self.zstack[:, :, zgt[iplane]]) 90 | for i in range(nclust_max + 1): 91 | self.scatter_plots[-1].append(pg.ScatterPlotItem()) 92 | self.plots[-1].addItem(self.scatter_plots[-1][-1]) 93 | 94 | self.update_plots() 95 | self.win.show() 96 | self.show() 97 | 98 | def update_plots(self, roi_id=None): 99 | for ii in range(self.nY): 100 | for jj in range(self.nX): 101 | iplane = ii * self.nX + jj 102 | ip = self.z == iplane 103 | self.plot_scatter(self.x[ip], self.y[ip], iplane=iplane, neurons=ip, 104 | roi_id=roi_id) 105 | self.plots[iplane].show() 106 | 107 | def neurons_selected(self, selected=None, neurons=None): 108 | selected = selected if selected is not None else self.selected 109 | neurons_select = np.zeros(len(self.sorting), "bool") 110 | neurons_select[self.sorting[selected.start * self.smooth_bin:selected.stop * 111 | self.smooth_bin]] = True 112 | if neurons is not None: 113 | neurons_select = neurons_select[neurons] 114 | return neurons_select 115 | 116 | def plot_scatter(self, x, y, roi_id=None, iplane=0, neurons=None): 117 | if self.all_neurons_checkBox.isChecked() and roi_id is None: 118 | colors = colormaps.gist_ncar[np.linspace( 119 | 0, 254, len(x)).astype("int")][self.sorting] 120 | brushes = [pg.mkBrush(color=c) for c in colors] 121 | self.scatter_plots[iplane][0].setData(x, y, symbol="o", brush=brushes, 122 | hoverable=True) 123 | for i in range(1, nclust_max + 1): 124 | self.scatter_plots[iplane][i].setData([], []) 125 | else: 126 | if roi_id is None: 127 | self.scatter_plots[iplane][0].setData( 128 | x, y, 
symbol="o", brush=pg.mkBrush(color=(180, 180, 180)), 129 | hoverable=True) 130 | for roi_id in range(nclust_max): 131 | if roi_id < len(self.cluster_rois): 132 | selected = self.neurons_selected(self.cluster_slices[roi_id], 133 | neurons=neurons) 134 | self.scatter_plots[iplane][roi_id + 1].setData( 135 | x[selected], y[selected], symbol="o", 136 | brush=pg.mkBrush(color=self.colors[roi_id][:3]), 137 | hoverable=True) 138 | else: 139 | self.scatter_plots[iplane][roi_id + 1].setData([], []) 140 | else: 141 | selected = self.neurons_selected(self.cluster_slices[roi_id], 142 | neurons=neurons) 143 | self.scatter_plots[iplane][roi_id + 1].setData( 144 | x[selected], y[selected], symbol="o", 145 | brush=pg.mkBrush(color=self.colors[roi_id][:3]), hoverable=True) 146 | -------------------------------------------------------------------------------- /rastermap/io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
3 | """ 4 | import os, warnings 5 | import numpy as np 6 | import scipy.io as sio 7 | from scipy.stats import zscore 8 | from scipy.sparse import csr_array 9 | 10 | def _load_dict(dat, keys): 11 | X, Usv, Vsv, xpos, ypos, xy = None, None, None, None, None, None 12 | other_keys = [] 13 | for key in keys: 14 | if key=="Usv": 15 | Usv = dat["Usv"] 16 | elif key=="Vsv": 17 | Vsv = dat["Vsv"] 18 | elif key=="U": 19 | U = dat["U"] 20 | elif key=="U0": 21 | U = dat["U0"] 22 | elif key=="V": 23 | V = dat["V"] 24 | elif key=="V0": 25 | V = dat["V0"] 26 | elif key=="Sv": 27 | Sv = dat["Sv"] 28 | elif key=="sv": 29 | Sv = dat["sv"] 30 | elif key=="X": 31 | X = dat["X"] 32 | elif key=="spks": 33 | X = dat["spks"] 34 | elif key=="xpos": 35 | xpos = dat["xpos"] 36 | elif key=="ypos": 37 | ypos = dat["ypos"] 38 | elif key=="xy": 39 | xy = dat["xy"] 40 | elif key=="xyz": 41 | xy = dat["xyz"] 42 | else: 43 | other_keys.append(key) 44 | 45 | if X is None: 46 | if Usv is None and U is not None and Sv is None: 47 | if Vsv is not None: 48 | Sv = (Vsv**2).sum(axis=0)**0.5 49 | else: 50 | raise ValueError("no Sv scaling for PCs available") 51 | elif Vsv is None and V is not None and Sv is None: 52 | if Usv is not None: 53 | Sv = (Usv**2).sum(axis=0)**0.5 54 | else: 55 | raise ValueError("no Sv scaling for PCs available") 56 | if Usv is None and U is not None and Sv is not None: 57 | Usv = U * Sv 58 | if Vsv is None and V is not None and Sv is not None: 59 | Vsv = V * Sv 60 | 61 | if Usv.shape[-1] != Vsv.shape[-1]: 62 | raise ValueError("Usv and Vsv must have the same number of components") 63 | if Usv.ndim > 3: 64 | raise ValueError("Usv cannot have more than 3 dimensions") 65 | if Vsv.ndim != 2: 66 | raise ValueError("Vsv must have 2 dimensions") 67 | 68 | if xpos is not None and xy is None: 69 | xy = np.stack((ypos, xpos), axis=1) 70 | 71 | if xy is not None: 72 | if xy.ndim != 2: 73 | print("cannot use xy from file: x and y positions of neurons must be 2-dimensional") 74 | xy = None 
75 | elif xy.shape[0]==2 or xy.shape[0]==3: 76 | xy = xy.T 77 | if xy is not None: 78 | if X is not None and X.shape[0]!=xy.shape[0]: 79 | xy = None 80 | elif Usv is not None and Usv.shape[0]!=xy.shape[0]: 81 | xy = None 82 | if xy is None: 83 | print("cannot use xy from file: x and y positions of neurons are not same size as activity") 84 | 85 | return X, Usv, Vsv, xy 86 | 87 | def load_activity(filename): 88 | ext = os.path.splitext(filename)[-1] 89 | print("Loading " + filename) 90 | Usv, Vsv, xy = None, None, None 91 | if ext == ".mat": 92 | try: 93 | X = sio.loadmat(filename) 94 | if isinstance(X, dict): 95 | for i, key in enumerate(X.keys()): 96 | if key not in ["__header__", "__version__", "__globals__"]: 97 | X = X[key] 98 | except NotImplementedError: 99 | try: 100 | import mat73 101 | except ImportError: 102 | print("please 'pip install mat73'") 103 | X = mat73.loadmat(filename) 104 | if isinstance(X, dict): 105 | keys = [] 106 | for i, key in enumerate(X.keys()): 107 | if key not in ["__header__", "__version__", "__globals__"]: 108 | keys.append(key) 109 | X, Usv, Vsv, xy = _load_dict(X, keys) 110 | elif ext == ".npy": 111 | X = np.load(filename, allow_pickle=True) 112 | try: 113 | if isinstance(X.item(), dict): 114 | dat = X.item() 115 | keys = dat.keys() 116 | X, Usv, Vsv, xy = _load_dict(dat, keys) 117 | except: 118 | print("data matrix:") 119 | try: 120 | print(X.shape) 121 | except: 122 | raise ValueError(".npy file does not contain an array or a dictionary") 123 | elif ext == ".npz": 124 | dat = np.load(filename, allow_pickle=True) 125 | keys = dat.files 126 | X, Usv, Vsv, xy = _load_dict(dat, keys) 127 | elif ext == ".nwb": 128 | X, xy = _load_nwb(filename) 129 | else: 130 | raise Exception("Invalid file type") 131 | 132 | if X is None and (Usv is None or Vsv is None): 133 | return 134 | if X is not None: 135 | if X.ndim == 1: 136 | raise ValueError( 137 | "ERROR: 1D array provided, but rastermap requires 2D array" 138 | ) 139 | elif X.ndim > 3: 
def load_spike_times(fname, fname_cluid, st_bin=100):
    """Bin spike times into a dense (n_clusters, n_bins) count matrix.

    Parameters
    ----------
    fname : str
        .npy file of spike times (presumably in seconds — the scaling
        ``st / st_bin * 1000`` implies ``st_bin`` is in ms; confirm).
    fname_cluid : str
        .npy file of the cluster id of each spike.
    st_bin : float (default 100)
        time bin width used to discretize spike times.

    Returns
    -------
    float32 array of per-bin spike counts, one row per cluster id.

    NOTE(review): counts are accumulated from uint8 ones — more than 255
    spikes of one cluster in one bin would wrap before the float cast.
    """
    print("Loading " + fname)
    spike_times = np.load(fname).squeeze()
    spike_clusters = np.load(fname_cluid).squeeze()
    if len(spike_times) != len(spike_clusters):
        raise ValueError("spike times and clusters must have same length")
    # bin index for every spike; duplicate (row, col) pairs are summed by
    # the sparse constructor, yielding counts
    bin_ids = np.floor(spike_times / st_bin * 1000).astype("int")
    ones = np.ones(len(spike_times), "uint8")
    counts = csr_array((ones, (spike_clusters, bin_ids)))
    return counts.todense().astype("float32")
[x.name for x in read_nwbfile.objects.values() if isinstance(x, DfOverF)] 196 | if len(X) == 0: 197 | X = [x for x in read_nwbfile.objects.values() if isinstance(x, RoiResponseSeries)] 198 | names = [x.name for x in read_nwbfile.objects.values() if isinstance(x, RoiResponseSeries)] 199 | 200 | 201 | if len(X) > 0: 202 | if len(X) == 3 and "Deconvolved" in names: 203 | X = X[names.index("Deconvolved")] 204 | elif len(X) > 1: 205 | # todo: allow user to select series 206 | print(f"more than one series to choose from, taking first series {names[0]}") 207 | X = X[0] 208 | elif len(X) == 1: 209 | X = X[0] 210 | 211 | planes = list(X.roi_response_series.keys()) 212 | 213 | spks = np.concatenate(([X[plane].data[:] for plane in planes]), 214 | axis=1).T 215 | spks = spks.astype("float32") 216 | ids = np.concatenate(([X[plane].rois.data[:] for plane in planes]), 217 | axis=0) 218 | 219 | if hasattr(X[planes[0]].rois[0], "image_mask"): 220 | roikey = "image_mask" 221 | elif hasattr(X[planes[0]].rois[0], "voxel_mask"): 222 | roikey = "voxel_mask" 223 | else: 224 | roikey = None 225 | 226 | if roikey is not None: 227 | xy = np.concatenate([np.array([_cell_center(roi[roikey].values[0]) 228 | for roi in X[plane].rois]) 229 | for plane in planes]) 230 | else: 231 | voxel_masks = np.concatenate([np.array([roi 232 | for roi in X[plane].rois]) 233 | for plane in planes]) 234 | xy = np.stack([_cell_center(vm[0][0]) for vm in voxel_masks], 235 | axis=0) 236 | else: 237 | raise ValueError("not an ophys NWB file with a Fluorescence or DfOverF roi_response_series") 238 | 239 | return spks, xy 240 | 241 | 242 | def _load_iscell(filename): 243 | basename = os.path.split(filename)[0] 244 | try: 245 | file_iscell = os.path.join(basename, "iscell.npy") 246 | iscell = np.load(file_iscell) 247 | probcell = iscell[:, 1] 248 | iscell = iscell[:, 0].astype("bool") 249 | except (ValueError, OSError, RuntimeError, TypeError, NameError): 250 | iscell = None 251 | file_iscell = None 252 | return 
iscell, file_iscell 253 | 254 | def _load_stat(filename): 255 | basename = os.path.split(filename)[0] 256 | try: 257 | file_stat = os.path.join(basename, "stat.npy") 258 | stat = np.load(file_stat, allow_pickle=True) 259 | xy = np.array([s["med"] for s in stat]) 260 | except (ValueError, OSError, RuntimeError, TypeError, NameError): 261 | xy = None 262 | file_stat = None 263 | return xy, file_stat 264 | -------------------------------------------------------------------------------- /rastermap/svd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | from sklearn.decomposition import TruncatedSVD 5 | import numpy as np 6 | from tqdm import trange 7 | from .utils import bin1d 8 | 9 | def subsampled_mean(X, n_mean=1000): 10 | n_frames = X.shape[0] 11 | n_mean = min(n_mean, n_frames) 12 | # load in chunks of up to 100 frames (for speed) 13 | nt = 100 14 | n_batches = int(np.floor(n_mean / nt)) 15 | chunk_len = n_frames // n_batches 16 | avgframe = np.zeros(X.shape[1:], dtype='float32') 17 | stdframe = np.zeros(X.shape[1:], dtype='float32') 18 | for n in trange(n_batches): 19 | Xb = X[n*chunk_len : n*chunk_len + nt].astype('float32') 20 | avgXb = Xb.mean(axis=0) 21 | avgframe += avgXb 22 | stdframe += ((Xb - avgXb[np.newaxis,...])**2).mean(axis=0) 23 | avgframe /= n_batches 24 | stdframe /= n_batches 25 | 26 | return avgframe, stdframe 27 | 28 | def SVD(X, n_components=250, return_USV=False, transpose=False): 29 | nmin = np.min(X.shape) 30 | nmin = min(nmin, n_components) 31 | 32 | Xt = X.T if transpose else X 33 | U = TruncatedSVD(n_components=nmin, 34 | random_state=0).fit_transform(Xt) 35 | 36 | if transpose: 37 | sv = (U**2).sum(axis=0)**0.5 38 | U /= sv 39 | V = (X @ U) / sv 40 | if return_USV: 41 | return V, sv, U 42 | else: 43 | return V 44 | else: 45 | if return_USV: 46 | sv = (U**2).sum(axis=0)**0.5 47 | U 
/= sv 48 | V = (X.T @ U) / sv 49 | return U, sv, V 50 | else: 51 | return U 52 | 53 | def subsampled_SVD(X, n_components=500, n_mean=1000, 54 | n_svd=15000, batch_size=1000, exclude=2): 55 | """ X is frames by voxels / pixels """ 56 | avgframe, stdframe = subsampled_mean(X) 57 | if exclude > 0: 58 | smin, smax = np.percentile(stdframe,1), np.percentile(stdframe,99) 59 | cutoff = np.linspace(smin, smax, 100//exclude + 1)[1] 60 | exclude = (stdframe < cutoff).flatten() 61 | print(f'{exclude.sum()} voxels excluded') 62 | else: 63 | exclude = np.zeros(avgframe.size, 'bool') 64 | 65 | n_voxels = np.prod(X.shape[1:]) 66 | n_frames = X.shape[0] 67 | batch_size = min(batch_size, n_frames) 68 | n_batches = int(min(np.floor(n_svd / batch_size), np.floor(n_frames / batch_size))) 69 | chunk_len = n_frames // n_batches 70 | nc = int(250) # <- how many PCs to keep in each chunk 71 | nc = min(nc, batch_size - 1) 72 | if n_batches == 1: 73 | nc = min(n_components, batch_size - 1) 74 | n_components = min(nc*n_batches, n_components) 75 | 76 | U = np.zeros(((~exclude).sum(), nc*n_batches), 'float32') 77 | avgframe_f = avgframe.flatten()[~exclude][np.newaxis,...] 
def grid_upsampling(X, X_nodes, Y_nodes, n_X, n_neighbors=50, e_neighbor=1):
    """Upsample a coarse cluster embedding onto a finer grid and re-assign
    each sample to its best grid point.

    Parameters
    ----------
    X : 2D array (n_samples, n_features)
        data to embed; rows are z-scored along axis 1 before correlation.
    X_nodes : 2D array (n_nodes, n_features)
        activity profile of each cluster node.
    Y_nodes : 2D array (n_nodes, 1 or 2)
        embedding position of each cluster node (integer grid — n_clusters
        is taken as ``Y_nodes.max() + 1``).
    n_X : int
        target grid resolution per dimension.
    n_neighbors : int (default 50)
        number of nearest nodes used for the local regression at each grid
        point.
    e_neighbor : int (default 1)
        index of the sorted neighbor distance used as the kernel bandwidth.

    Returns
    -------
    Y : (n_samples, dims) upsampled embedding position of each sample.
    cc : (n_gridpoints, n_samples) rectified correlation of each sample with
        each interpolated grid profile.
    xy : (dims, n_gridpoints) grid coordinates.
    Xrec : (n_gridpoints, n_features) normalized interpolated profiles.
    """
    # bandwidth index must stay within the neighbor set
    e_neighbor = min(n_neighbors - 1, e_neighbor)
    xy = []
    n_clusters = Y_nodes.max() + 1
    # number of grid points inserted between consecutive cluster positions
    grid_upsample = np.round(n_X / n_clusters)
    for i in range(Y_nodes.shape[1]):
        xy.append(np.arange(0, n_clusters, 1. / grid_upsample))
        #xy.append(np.linspace(Y_nodes[:,i].min(), Y_nodes[:,i].max(), n_X))
    if Y_nodes.shape[1] == 2:
        # 2D embedding: full cartesian grid, flattened to (2, n_gridpoints)
        x_m, y_m = np.meshgrid(xy[0], xy[1], indexing="ij")
        xy = np.vstack((x_m.flatten(), y_m.flatten()))
    else:
        xy = xy[0][np.newaxis, :]

    # squared euclidean distance from every grid point to every node
    ds = np.zeros((xy.shape[1], Y_nodes.shape[0]))
    n_components = len(xy)
    for i in range(len(xy)):
        ds += (xy[i][:, np.newaxis] - Y_nodes[:, i])**2
    isort = np.argsort(ds, 1)[:, :n_neighbors]
    nraster = xy.shape[1]
    Xrec = np.zeros((nraster, X_nodes.shape[1]))
    for j in range(nraster):
        # locally-weighted linear regression of node profiles onto node
        # positions, evaluated at grid point j
        ineigh = isort[j]
        dists = ds[j, ineigh]
        # kernel width set by the distance to the e_neighbor-th neighbor
        w = np.exp(-dists / dists[e_neighbor])
        M, N = X_nodes[ineigh], Y_nodes[ineigh]
        # affine fit: append a constant column for the intercept
        N = np.concatenate((N, np.ones((n_neighbors, 1))), axis=1)
        R = np.linalg.solve((N.T * w) @ N, (N.T * w) @ M)
        Xrec[j] = xy[:, j] @ R[:-1] + R[-1]

    # L2-normalize grid profiles so cc is a (scaled) correlation
    Xrec = Xrec / (Xrec**2).sum(1)[:, np.newaxis]**.5
    cc = Xrec @ zscore(X, 1).T
    cc = np.maximum(0, cc)
    # each sample goes to the grid point whose profile matches it best
    imax = np.argmax(cc, 0)
    Y = xy[:, imax].T

    return Y, cc, xy, Xrec
3 | """ 4 | import numpy as np 5 | import os, tempfile, shutil 6 | from urllib.request import urlopen 7 | from tqdm import tqdm 8 | from pathlib import Path 9 | 10 | def bin1d(X, bin_size, axis=0): 11 | """ mean bin over axis of data with bin bin_size """ 12 | if bin_size > 0: 13 | size = list(X.shape) 14 | Xb = X.swapaxes(0, axis) 15 | size_new = Xb.shape 16 | Xb = Xb[:size[axis]//bin_size*bin_size].reshape((size[axis]//bin_size, bin_size, *size_new[1:])).mean(axis=1) 17 | Xb = Xb.swapaxes(axis, 0) 18 | return Xb 19 | else: 20 | return X 21 | 22 | def split_traintest(n_t, frac=0.25, n_segs=20, pad=3, split_time=False): 23 | """this returns deterministic split of train and test in time chunks 24 | 25 | Parameters 26 | ---------- 27 | n_t : int 28 | number of timepoints to split 29 | frac : float (optional, default 0.25) 30 | fraction of points to put in test set 31 | pad : int (optional, default 3) 32 | number of timepoints to exclude from test set before and after training segment 33 | split_time : bool (optional, default False) 34 | split train and test into beginning and end of experiment 35 | Returns 36 | -------- 37 | itrain: 2D int array 38 | times in train set, arranged in chunks 39 | 40 | itest: 2D int array 41 | times in test set, arranged in chunks 42 | """ 43 | #usu want 20 segs, but might not have enough frames for that 44 | n_segs = int(min(n_segs, n_t/4)) 45 | n_len = int(np.floor(n_t/n_segs)) 46 | inds_train = np.linspace(0, n_t - n_len - 5, n_segs).astype(int) 47 | if not split_time: 48 | l_train = int(np.floor(n_len * (1-frac))) 49 | inds_test = inds_train + l_train + pad 50 | l_test = np.diff(np.stack((inds_train, inds_train + l_train)).T.flatten()).min() - pad 51 | else: 52 | inds_test = inds_train[:int(np.floor(n_segs*frac))] 53 | inds_train = inds_train[int(np.floor(n_segs*frac)):] 54 | l_train = n_len - 10 55 | l_test = l_train 56 | itrain = (inds_train[:,np.newaxis] + np.arange(0, l_train, 1, int)) 57 | itest = (inds_test[:,np.newaxis] + 
np.arange(0, l_test, 1, int)) 58 | return itrain, itest 59 | 60 | def download_url_to_file(url, dst, progress=True): 61 | r"""Download object at the given URL to a local path. 62 | Thanks to torch, slightly modified 63 | Args: 64 | url (string): URL of the object to download 65 | dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` 66 | progress (bool, optional): whether or not to display a progress bar to stderr 67 | Default: True 68 | """ 69 | file_size = None 70 | import ssl 71 | ssl._create_default_https_context = ssl._create_unverified_context 72 | u = urlopen(url) 73 | meta = u.info() 74 | if hasattr(meta, 'getheaders'): 75 | content_length = meta.getheaders("Content-Length") 76 | else: 77 | content_length = meta.get_all("Content-Length") 78 | if content_length is not None and len(content_length) > 0: 79 | file_size = int(content_length[0]) 80 | # We deliberately save it in a temp file and move it after 81 | dst = os.path.expanduser(dst) 82 | dst_dir = os.path.dirname(dst) 83 | f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) 84 | try: 85 | with tqdm(total=file_size, disable=not progress, 86 | unit='B', unit_scale=True, unit_divisor=1024) as pbar: 87 | while True: 88 | buffer = u.read(8192) 89 | if len(buffer) == 0: 90 | break 91 | f.write(buffer) 92 | pbar.update(len(buffer)) 93 | f.close() 94 | shutil.move(f.name, dst) 95 | finally: 96 | f.close() 97 | if os.path.exists(f.name): 98 | os.remove(f.name) 99 | 100 | def download_data(data_type="hippocampus"): 101 | if data_type=="widefield": 102 | url = "https://osf.io/5d8q7/download" 103 | elif data_type=="spont2": 104 | url = "https://osf.io/8xg7n/download" 105 | elif data_type=="hippocampus": 106 | url = "https://osf.io/szmw6/download" 107 | elif data_type=="fish": 108 | url = "https://osf.io/2w8pa/download" 109 | ddir = Path.home().joinpath('.rastermap') 110 | ddir.mkdir(exist_ok=True) 111 | data_dir = ddir.joinpath('data') 112 | data_dir.mkdir(exist_ok=True) 113 | 
data_file = str(data_dir.joinpath(f"{data_type}_data.npz")) 114 | if not os.path.exists(data_file): 115 | download_url_to_file(url, data_file) 116 | return data_file 117 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 3 | """ 4 | import setuptools 5 | 6 | install_deps = [ 7 | "numpy>=1.24.0", 8 | "scipy", 9 | "scikit-learn", 10 | "numba>=0.57.0", 11 | "natsort", 12 | "tqdm" 13 | ] 14 | 15 | gui_deps = [ 16 | "pyqtgraph>=0.11.0rc0", 17 | "pyqt6", 18 | "pyqt6.sip", 19 | "qtpy", 20 | "superqt", 21 | ] 22 | 23 | try: 24 | import PyQt5 25 | gui_deps.remove("pyqt6") 26 | gui_deps.remove("pyqt6.sip") 27 | except: 28 | pass 29 | 30 | try: 31 | import PySide2 32 | gui_deps.remove("pyqt6") 33 | gui_deps.remove("pyqt6.sip") 34 | except: 35 | pass 36 | 37 | try: 38 | import PySide6 39 | gui_deps.remove("pyqt6") 40 | gui_deps.remove("pyqt6.sip") 41 | except: 42 | pass 43 | 44 | 45 | with open("README.md", "r") as fh: 46 | long_description = fh.read() 47 | 48 | setuptools.setup( 49 | name="rastermap", 50 | use_scm_version=True, 51 | author="Marius Pachitariu and Carsen Stringer", 52 | author_email="carsen.stringer@gmail.com", 53 | description="Unsupervised clustering algorithm for 2D data (neurons by time)", 54 | long_description=long_description, 55 | long_description_content_type="text/markdown", 56 | url="https://github.com/MouseLand/rastermap", 57 | packages=setuptools.find_packages(), 58 | install_requires = install_deps, 59 | extras_require = { 60 | "gui": gui_deps 61 | }, 62 | tests_require = ["pytest"], 63 | include_package_data=True, 64 | classifiers=( 65 | "Programming Language :: Python :: 3", 66 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 67 | "Operating System :: OS Independent", 68 | ), 69 | ) 70 | 
-------------------------------------------------------------------------------- /tests/test_import.py: -------------------------------------------------------------------------------- 1 | 2 | def test_gui(): 3 | from rastermap.gui import gui 4 | 5 | def test_class(): 6 | from rastermap import Rastermap -------------------------------------------------------------------------------- /tests/test_rastermap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rastermap import Rastermap 3 | 4 | 5 | def test_rastermap(test_file): 6 | dat = np.load(test_file) 7 | spks = dat["spks"] 8 | 9 | model = Rastermap().fit(data=spks) 10 | 11 | assert hasattr(model, "embedding") 12 | assert hasattr(model, "isort") 13 | assert hasattr(model, "Usv") 14 | assert hasattr(model, "Vsv") 15 | 16 | def test_rastermap_splits(test_file): 17 | dat = np.load(test_file) 18 | spks = dat["spks"] 19 | 20 | model = Rastermap(n_splits=2, n_clusters=20, nc_splits=10).fit(data=spks) 21 | 22 | assert hasattr(model, "embedding") 23 | assert hasattr(model, "isort") 24 | assert hasattr(model, "Usv") 25 | assert hasattr(model, "Vsv") -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/ 2 | [tox] 3 | envlist = py{38,39,310}-{linux,macos,windows} 4 | 5 | [gh-actions] 6 | python = 7 | 3.8: py38 8 | 3.9: py39 9 | 3.10: py310 10 | 11 | [gh-actions:env] 12 | PLATFORM = 13 | ubuntu-latest: linux 14 | macos-latest: macos 15 | windows-latest: windows 16 | 17 | [testenv] 18 | platform = 19 | macos: darwin 20 | linux: linux 21 | windows: win32 22 | passenv = 23 | CI 24 | GITHUB_ACTIONS 25 | DISPLAY,XAUTHORITY 26 | NUMPY_EXPERIMENTAL_ARRAY_FUNCTION 27 | PYVISTA_OFF_SCREEN 28 | conda_deps = 29 | pytest 30 | conda_channels = 31 | pytorch 32 | deps = 33 | .[gui] 
34 | pytest 35 | pytest-cov 36 | pytest-xvfb 37 | tqdm 38 | commands = 39 | pytest -v --color=yes --cov=rastermap --cov-report=xml 40 | --------------------------------------------------------------------------------