├── .dockerignore ├── .gitignore ├── CITATION.cff ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── configs ├── .gitignore ├── distillation_001.py └── true_batch_001.py ├── data ├── .gitignore ├── experiments.torrent └── readme_images │ ├── architecture.png │ ├── preview.png │ └── softplus.png ├── requirements.txt ├── scripts ├── download_data.py ├── ensemble.py ├── predict.py └── train.py ├── setup.cfg └── src ├── __init__.py ├── argus_models.py ├── constants.py ├── data.py ├── datasets.py ├── ema.py ├── indexes.py ├── inputs.py ├── losses.py ├── metrics.py ├── mixers.py ├── models ├── __init__.py └── dwiseneuro.py ├── phash.py ├── predictors.py ├── responses.py ├── submission.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data/ 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: Solution for Sensorium 2023 Competition 3 | message: 'If you use this work, please cite it using these metadata.' 4 | type: software 5 | authors: 6 | - given-names: Ruslan 7 | family-names: Baikulov 8 | email: ruslan1123@gmail.com 9 | orcid: 'https://orcid.org/0009-0003-4400-0619' 10 | repository-code: 'https://github.com/lRomul/sensorium' 11 | abstract: >- 12 | This repository contains the winning solution for the Sensorium 2023 Competition. 13 | The competition aimed to predict the activity of neurons in the primary visual 14 | cortex of mice in response to videos. 
The proposed solution includes a novel 15 | model architecture called DwiseNeuro and a training pipeline with a solid 16 | cross-validation strategy and knowledge distillation. 17 | license: MIT 18 | version: v23.11.22 19 | date-released: '2023-11-22' 20 | doi: 10.5281/zenodo.10155151 21 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM osaiai/dokai:23.10-pytorch 2 | 3 | COPY requirements.txt requirements.txt 4 | RUN pip3 install --no-cache-dir -r requirements.txt 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ruslan Baikulov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | NAME?=sensorium 2 | COMMAND?=bash 3 | OPTIONS?= 4 | 5 | GPUS?=all 6 | ifeq ($(GPUS),none) 7 | GPUS_OPTION= 8 | else 9 | GPUS_OPTION=--gpus=$(GPUS) 10 | endif 11 | 12 | .PHONY: all 13 | all: stop build run 14 | 15 | .PHONY: build 16 | build: 17 | docker build -t $(NAME) . 18 | 19 | .PHONY: stop 20 | stop: 21 | -docker stop $(NAME) 22 | -docker rm $(NAME) 23 | 24 | .PHONY: run 25 | run: 26 | docker run --rm -dit \ 27 | --net=host \ 28 | --ipc=host \ 29 | $(OPTIONS) \ 30 | $(GPUS_OPTION) \ 31 | -v $(shell pwd):/workdir \ 32 | --name=$(NAME) \ 33 | $(NAME) \ 34 | $(COMMAND) 35 | docker attach $(NAME) 36 | 37 | .PHONY: attach 38 | attach: 39 | docker attach $(NAME) 40 | 41 | .PHONY: logs 42 | logs: 43 | docker logs -f $(NAME) 44 | 45 | .PHONY: exec 46 | exec: 47 | docker exec -it $(OPTIONS) $(NAME) $(COMMAND) 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Solution for Sensorium 2023 Competition 4 | 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) 6 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10155151.svg)](https://doi.org/10.5281/zenodo.10155151) 7 | 8 |
9 | 
10 | This repository contains the winning solution for the Sensorium 2023 Competition, part of the NeurIPS 2023 competition track. 
11 | The competition aimed to find the best model that can predict the activity of neurons in the primary visual cortex of mice in response to videos. 
12 | The competition introduced a temporal component by using dynamic stimuli (videos) instead of the static stimuli (images) used in Sensorium 2022, making the task more challenging. 
13 | 
14 | The primary metric of the competition was single-trial correlation. 
15 | You can read about the metric, the data, and the task in the competition paper [^1]. 
16 | It is important to note that additional data for five mice was introduced during the competition, which doubled the dataset's size ([old](https://gin.g-node.org/pollytur/Sensorium2023Data) and [new](https://gin.g-node.org/pollytur/sensorium_2023_dataset) data). 
17 | 
18 | ## Solution 
19 | 
20 | Key points: 
21 | * [DwiseNeuro](src/models/dwiseneuro.py) - a novel model architecture for predicting neural activity in the mouse primary visual cortex. 
22 | * A solid cross-validation strategy with folds split by perceptual video hash. 
23 | * Training on all mice with an option to fill in unlabeled samples via distillation. 
24 | 
25 | ## Model Architecture 
26 | 
27 | During the competition, I dedicated most of my time to designing the model architecture since it significantly impacted the solution's outcome compared to other components. 
28 | I iteratively tested various computer vision and deep learning techniques, integrating them into the architecture as the correlation metric improved. 
29 | 
30 | The diagram below illustrates the final architecture, which I named DwiseNeuro: 
31 | 
32 | ![architecture](data/readme_images/architecture.png) 
33 | 
34 | DwiseNeuro consists of three main parts: core, cortex, and readouts. 
35 | The core consumes sequences of video frames and mouse behavior activity in separate channels, processing temporal and spatial features. 
36 | The produced features pass through global pooling over the spatial dimensions to aggregate them. 
37 | The cortex processes the pooled features independently for each timestep, significantly increasing the number of channels. 
38 | Finally, each readout predicts the activation of neurons for the corresponding mouse. 
39 | 
40 | In the following sections, we will delve deeper into each part of the architecture. 
41 | 
42 | ### Core 
43 | 
44 | The first layer of the module is the stem: a point-wise 3D convolution that increases the number of channels, followed by batch normalization. 
45 | The rest of the core consists of inverted residual blocks [^2][^3] with a `narrow -> wide -> narrow` channel structure. 
46 | 
47 | #### Techniques 
48 | 
49 | Several methods were added to the inverted residual block, which was rewritten with 3D layers: 
50 | * **Absolute Position Encoding** [^4] - adding the encoding to the input of each block allows convolutions to accumulate position information. It's quite important because of the subsequent spatial pooling after the core. 
51 | * **Factorized (2+1)D convolution** [^5] - the 3D depth-wise convolution was replaced with a spatial 2D depth-wise convolution followed by a temporal 1D depth-wise convolution. Some blocks use spatial convolutions with stride two to compress the output size. 
52 | * **Shortcut Connections** - completely parameter-free residual shortcuts with three operations: 
53 |   * Identity mapping if the input and output dimensions are equal. It's the same as the connection proposed in ResNet [^6]. 
54 |   * Nearest-neighbor interpolation in case of different spatial sizes. 
55 |   * Cyclic repetition of channels if the numbers of channels don't match. 
56 | * **Squeeze-and-Excitation** [^7] - dynamic channel-wise feature recalibration. 
57 | * **DropPath (Stochastic Depth)** [^8][^9] - regularization that randomly drops the block's main path for each sample in the batch. 
58 | 
59 | Batch normalization is applied after each layer, including the shortcut. 
60 | SiLU activation is used after the expansion and depth-wise convolutions. 
61 | 
62 | #### Hyperparameters 
63 | 
64 | I found that the number of core blocks and their parameters dramatically affect the outcome. 
65 | It's possible to tune the channels, strides, expansion ratio, and spatial/temporal kernel sizes. 
66 | Obviously, it is almost impossible to start experiments with optimal values. 
67 | The problem is mentioned in the EfficientNet [^3] paper, which concluded that it is essential to carefully balance model width, depth, and resolution. 
68 | 
69 | After conducting a lot of experiments, I chose the following parameters: 
70 | * Four blocks with 64 output channels, three with 128, and two with 256. 
71 | * Three blocks have stride two; they are the first in each subgroup from the point above. 
72 | * The expansion ratio of the inverted residual block is six. 
73 | * The kernel of the spatial depth-wise convolution is (1, 3, 3). 
74 | * The kernel of the temporal depth-wise convolution is (5, 1, 1). 
75 | 
76 | ### Cortex 
77 | 
78 | Compared with related works [^10][^11], the model architecture includes a new part called the cortex. 
79 | Like the core, it is shared across all mice. 
80 | The cortex receives features with channel and temporal dimensions only. 
81 | Spatial information was accumulated through the position encoding applied earlier in the core and compressed by average pooling after the core. 
82 | The primary purpose of the cortex is to smoothly increase the number of channels, which the readouts will further use. 
83 | 
84 | The building element of the module is a grouped 1D convolution followed by the channel shuffle operation [^12]. 
85 | Batch normalization, SiLU activation, and shortcut connections with stochastic depth were applied similarly to the core. 
86 | 
87 | #### Hyperparameters 
88 | 
89 | Hyperparameters of the cortex were also important: 
90 | * Convolution with two groups and kernel size one (a bigger kernel size over the temporal dimension has not led to better results). 
91 | * Three layers with 1024, 2048, and 4096 channels. 
92 | 
93 | As you can see, the number of channels is quite large. 
94 | Groups help optimize computation and memory efficiency. 
95 | The channel shuffle operation allows the sharing of information between the groups of different layers. 
96 | 
97 | ### Readouts 
98 | 
99 | The readout is a single 1D convolution with two groups and kernel size one, followed by Softplus activation. 
100 | Each of the ten mice has its own readout with the number of output channels equal to the number of neurons (7863, 7908, 8202, 7939, 8122, 7440, 7928, 8285, 7671, and 7495, respectively). 
101 | 
102 | #### Softplus 
103 | 
104 | Keeping the response positive by using Softplus [^10] was essential in my pipeline. 
105 | It works much better than `ELU + 1` [^11], especially when the Softplus beta parameter is tuned. 
106 | In my case, the optimal beta value was about 0.07, which resulted in a 0.018 increase in the correlation metric. 
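To make the readout design above concrete, here is a minimal sketch (not the actual implementation from `src/models/dwiseneuro.py`) of a single readout head: a grouped 1D convolution with kernel size one followed by Softplus with the tuned beta. The channel counts are illustrative; in particular, the number of outputs is a round hypothetical value, since a grouped convolution requires the channel counts to be divisible by the number of groups, while the real readouts map to each mouse's exact neuron count.

```Python
import torch
from torch import nn


class ReadoutSketch(nn.Module):
    """Single-mouse readout sketch: grouped 1x1 convolution + Softplus."""

    def __init__(self, in_channels: int = 4096, num_outputs: int = 8000,
                 groups: int = 2, softplus_beta: float = 0.07):
        super().__init__()
        # One output channel per predicted neuron, kernel size one over the temporal axis
        self.conv = nn.Conv1d(in_channels, num_outputs, kernel_size=1, groups=groups)
        self.softplus = nn.Softplus(beta=softplus_beta)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time) features produced by the cortex
        return self.softplus(self.conv(x))


# (batch, cortex channels, timesteps) -> (batch, predicted responses, timesteps)
responses = ReadoutSketch()(torch.randn(2, 4096, 16))
```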
107 | 
108 | You can see a comparison of `ELU + 1` and Softplus in the plot below: 
109 | 
110 | ![softplus](data/readme_images/softplus.png) 
111 | 
112 | #### Learnable Softplus 
113 | 
114 | I also conducted an experiment where the beta parameter was trainable. 
115 | Interestingly, the trained value converged approximately to the optimal one, which I had found by grid search. 
116 | I omitted the learnable Softplus from the solution because it resulted in a slightly worse score. 
117 | But this may be an excellent way to quickly and automatically find a good beta. 
118 | 
119 | Here's a numerically stable implementation of learnable Softplus in PyTorch: 
120 | 
121 | ```Python 
122 | import torch 
123 | from torch import nn 
124 | 
125 | 
126 | class LearnableSoftplus(nn.Module): 
127 |     def __init__(self, beta: float): 
128 |         super().__init__() 
129 |         self.beta = nn.Parameter(torch.tensor(float(beta))) 
130 | 
131 |     def forward(self, x): 
132 |         xb = x * self.beta 
133 |         return (torch.clamp(xb, 0) + torch.minimum(xb, -xb).exp().log1p()) / self.beta 
134 | ``` 
135 | 
136 | ## Training 
137 | 
138 | ### Validation strategy 
139 | 
140 | At the end of the competition, I employed 7-fold cross-validation to check hypotheses and tune hyperparameters more precisely. 
141 | I used all available labeled data to make the folds. 
142 | Random splitting gave an overly optimistic metric estimation because some videos were duplicated (e.g., in the original validation split or between the old and new datasets). 
143 | To solve this issue, I created group folds with non-overlapping videos. 
144 | Similar videos were found using [perceptual hashes](src/phash.py) of several frames fetched deterministically. 
145 | 
146 | ### Basic training ([config](configs/true_batch_001.py)) 
147 | 
148 | The training was performed in two stages. The first stage is basic training with the following pipeline parameters: 
149 | * Learning rate warmup from 0 to 2.4e-03 over the first three epochs, then cosine annealing down to 2.4e-05 over the last 18 epochs 
150 | * Batch size 32; one training epoch comprises 72000 samples 
151 | * AdamW [^13] optimizer with weight decay 0.05 
152 | * Poisson loss 
153 | * Model EMA with decay 0.999 
154 | * CutMix [^14] with alpha 1.0 and usage probability 0.5 
155 | * Mice are sampled uniformly at random within each batch 
156 | 
157 | Each dataset sample consists of a grayscale video, behavior activity measurements (pupil center, pupil dilation, and running speed), and the neuron responses of a single mouse. 
158 | All data is presented at 30 FPS. During training, the model consumes 16 frames, skipping every second frame (equivalent to 16 neighboring frames at 15 FPS). 
159 | The video frames were zero-padded to 64x64 pixels. The behavior activities were added as separate channels: 
160 | for each behavior measurement, an entire input channel was filled with its value. 
161 | No normalization is applied to the target and input tensors during training. 
162 | 
163 | The ensemble of models from all folds achieves a 0.2905 single-trial correlation on the main track and 0.2207 on the bonus track in the final phase of the competition. 
164 | This result alone would be enough to take first place in both the main and bonus (out-of-distribution) competition tracks. 
165 | 
166 | ### Knowledge Distillation ([config](configs/distillation_001.py)) 
167 | 
168 | For an individual sample in the batch, the loss was calculated for the responses of only one mouse, 
169 | because the input tensor is associated with a single mouse trial and there is no neural activity data for the other mice. 
170 | However, the model can predict responses for all mice from the input tensor. In the second stage of training, I used a method similar to knowledge distillation [^15]. 
171 | I created a pipeline where models from the first stage predict the unlabeled responses during training. 
172 | As a result, the second-stage models trained all their readouts on each batch sample. 
173 | The loss on the distilled predictions was weighted to contribute 0.36 of the overall loss (the `ratio` parameter in the distillation config). 
174 | 
175 | The hyperparameters were identical, except for the expansion ratio in the inverted residual blocks: seven in the first stage and six in the second. 
176 | 
177 | In the second stage, the ensemble of models achieves nearly the same single-trial correlation as the ensemble from the first stage. 
178 | However, what is fascinating is that each fold model performs better, by an average score of 0.007, than the corresponding model from the first stage. 
179 | Thus, the distilled model works like an ensemble of undistilled models. 
180 | According to the work [^16], the individual model is forced to learn the ensemble's performance during knowledge distillation, and an ensemble of distilled models offers no further performance boost. 
181 | I observed the same behavior in my solution. 
182 | 
183 | Distillation can be a great practice if you need one good model. 
184 | But in ensembles, it leads to only minor changes in performance. 
185 | 
186 | ## Prediction 
187 | 
188 | The ensembles were produced by taking the arithmetic mean of predictions from multiple sources: 
189 | * Overlapping positions of a sliding window over each possible sequence of frames. 
190 | * Models from the cross-validation folds of one training stage. 
191 | * Training stages (optional). 
192 | 
193 | I used the same model weights for both competition tracks. 
194 | The competition only evaluated the responses of the five mice from the new dataset, so I only predicted those. 
195 | 
196 | ## Competition progress 
197 | 
198 | You can see the progress of solution development during the competition in [spreadsheets](https://docs.google.com/spreadsheets/d/1xJTB6lZvtjNSQbYQiB_hgKIrL9hUB0uL3xhUAK1Pmxw/edit#gid=0) (the document consists of multiple sheets). 
199 | Unfortunately, the document contains less than half of the conducted experiments because sometimes I was too lazy to fill it in :) 
200 | However, if you need a complete chronology of valuable changes, you can see it in the git history. 
201 | 
202 | To summarize, an early model with depth-wise 3D convolution blocks achieved a score of around 0.19 on the main track during the live phase of the competition. 
203 | Subsequently, implementing techniques from the core section, tuning hyperparameters, and training on all available data boosted the score to 0.25. 
204 | Applying the non-standard normalization expected by the evaluation servers during postprocessing improved the score to 0.27. 
205 | The cortex and CutMix increased the score to 0.276. 
206 | Then, the beta value of Softplus was tuned, resulting in a score of 0.294. 
207 | Lastly, adjusting the drop rate and batch size parameters helped to achieve a score of 0.3 on the main track during the live phase. 
208 | 
209 | The ensemble of the basic and distillation training stages achieved a single-trial correlation of 0.2913 on the main track and 0.2215 on the bonus track in the final phase (0.3005 and 0.2173 in the live phase, respectively). 
210 | This result is just a bit better than the basic training result alone, but I should provide it because it was the best submission in the competition. 
211 | In addition, it was interesting to investigate the relation between ensembling and distillation, as discussed above. 
212 | 
213 | Thanks to the Sensorium organizers and participants for the excellent competition. Thanks to my family and friends who supported me during the competition! 
214 | 
215 | 
216 | 
217 | [^1]: Turishcheva, Polina, et al. (2023). The Dynamic Sensorium competition for predicting large-scale mouse visual cortex activity from videos. https://arxiv.org/abs/2305.19654 
218 | [^2]: Sandler, Mark, et al. (2018). Mobilenetv2: Inverted residuals and linear bottlenecks. https://arxiv.org/abs/1801.04381 
219 | [^3]: Tan, Mingxing, and Quoc Le. (2019). Efficientnet: Rethinking model scaling for convolutional neural networks. https://arxiv.org/abs/1905.11946 
220 | [^4]: Vaswani, Ashish, et al. (2017). Attention is all you need. https://arxiv.org/abs/1706.03762 
221 | [^5]: Tran, Du, et al. (2018). A closer look at spatiotemporal convolutions for action recognition. https://arxiv.org/abs/1711.11248 
222 | [^6]: He, Kaiming, et al. (2016). Deep residual learning for image recognition. https://arxiv.org/abs/1512.03385 
223 | [^7]: Hu, Jie, Li Shen, and Gang Sun. (2018). Squeeze-and-excitation networks. https://arxiv.org/abs/1709.01507 
224 | [^8]: Larsson, Gustav, Michael Maire, and Gregory Shakhnarovich. (2016). Fractalnet: Ultra-deep neural networks without residuals. https://arxiv.org/abs/1605.07648 
225 | [^9]: Huang, Gao, et al. (2016). Deep networks with stochastic depth. https://arxiv.org/abs/1603.09382 
226 | [^10]: Höfling, Larissa, et al. (2022). A chromatic feature detector in the retina signals visual context changes. https://www.biorxiv.org/content/10.1101/2022.11.30.518492 
227 | [^11]: Lurz, Konstantin-Klemens, et al. (2020). Generalization in data-driven models of primary visual cortex. https://www.biorxiv.org/content/10.1101/2020.10.05.326256 
228 | [^12]: Zhang, Xiangyu, et al. (2018). Shufflenet: An extremely efficient convolutional neural network for mobile devices. https://arxiv.org/abs/1707.01083 
229 | [^13]: Loshchilov, Ilya, and Frank Hutter. (2017). Decoupled weight decay regularization. https://arxiv.org/abs/1711.05101 
230 | [^14]: Yun, Sangdoo, et al. (2019). Cutmix: Regularization strategy to train strong classifiers with localizable features. https://arxiv.org/abs/1905.04899 
231 | [^15]: Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. (2015). Distilling the knowledge in a neural network. https://arxiv.org/abs/1503.02531 
232 | [^16]: Allen-Zhu, Zeyuan, and Yuanzhi Li. (2020). Towards understanding ensemble, knowledge distillation and self-distillation in deep learning. https://arxiv.org/abs/2012.09816 
233 | 
234 | ## Quick setup and start 
235 | 
236 | ### Requirements 
237 | 
238 | * Linux (tested on Ubuntu 20.04 and 22.04) 
239 | * NVIDIA GPU (models trained on RTX A6000) 
240 | * NVIDIA Drivers >= 535, CUDA >= 12.2 
241 | * [Docker](https://docs.docker.com/engine/install/) 
242 | * [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) 
243 | 
244 | The pipeline is tuned for training on a single RTX A6000 with 48 GB of memory. 
245 | If your GPU has less memory, you can use gradient accumulation by increasing the `iter_size` parameter in the training configs. 
246 | This slightly worsens the result (by about a 0.002 score for `"iter_size": 2`), but the impact is smaller than that of reducing the batch size. 
247 | 
248 | ### Run 
249 | 
250 | Clone the repo and enter the folder. 
251 | 252 | ```bash 253 | git clone git@github.com:lRomul/sensorium.git 254 | cd sensorium 255 | ``` 256 | 257 | Build a Docker image and run a container. 258 | 259 |
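If you prefer to see what `make` does under the hood, the roughly equivalent raw Docker commands (derived from the Makefile above, with its default image/container name `sensorium`, default command `bash`, and all GPUs passed) are:

```bash
docker build -t sensorium .
docker run --rm -dit \
    --net=host \
    --ipc=host \
    --gpus=all \
    -v "$(pwd)":/workdir \
    --name=sensorium \
    sensorium \
    bash
docker attach sensorium
```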
Here is a small guide on how to use the provided Makefile 260 | 261 | ```bash 262 | make # stop, build, run 263 | 264 | # do the same 265 | make stop 266 | make build 267 | make run 268 | 269 | make # by default all GPUs passed 270 | make GPUS=all # do the same 271 | make GPUS=none # without GPUs 272 | 273 | make run GPUS=2 # pass the first two GPUs 274 | make run GPUS='\"device=1,2\"' # pass GPUs numbered 1 and 2 275 | 276 | make logs 277 | make exec # run a new command in a running container 278 | make exec COMMAND="bash" # do the same 279 | make stop 280 | ``` 281 | 282 |
283 | 284 | ```bash 285 | make 286 | ``` 287 | 288 | From now on, you should run all commands inside the docker container. 289 | 290 | If you already have the Sensorium 2023 dataset (148 GB), copy it to the folder `./data/sensorium_all_2023/`. 291 | Otherwise, use the script for downloading: 292 | 293 | ```bash 294 | python scripts/download_data.py 295 | ``` 296 | 297 | You can now reproduce the final results of the solution using the following commands: 298 | ```bash 299 | # Train 300 | # The training time is 3.5 days (12 hours per fold) for each experiment on a single A6000 301 | # You can speed up the process by using the --folds argument to train folds in parallel 302 | # Or just download trained weights in the section below 303 | python scripts/train.py -e true_batch_001 304 | python scripts/train.py -e distillation_001 305 | 306 | # Predict 307 | # Any GPU with more than 6 GB memory will be enough 308 | python scripts/predict.py -e true_batch_001 -s live_test_main 309 | python scripts/predict.py -e true_batch_001 -s live_test_bonus 310 | python scripts/predict.py -e true_batch_001 -s final_test_main 311 | python scripts/predict.py -e true_batch_001 -s final_test_bonus 312 | python scripts/predict.py -e distillation_001 -s live_test_main 313 | python scripts/predict.py -e distillation_001 -s live_test_bonus 314 | python scripts/predict.py -e distillation_001 -s final_test_main 315 | python scripts/predict.py -e distillation_001 -s final_test_bonus 316 | 317 | # Ensemble predictions of two experiments 318 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s live_test_main 319 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s live_test_bonus 320 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s final_test_main 321 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s final_test_bonus 322 | 323 | # Final predictions will be there 324 | cd data/predictions/distillation_001,true_batch_001 325 | ``` 326 | 327 | ### Trained model weights 328 | 329 | You can skip the training step by downloading model weights (9.5 GB) using [torrent file](data/experiments.torrent). 
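Any BitTorrent client can fetch the weights; for example, assuming a client such as `aria2` is available on the host (it is not assumed to be part of the provided Docker image), a download could look like this:

```bash
# Hypothetical example with aria2; any BitTorrent client works.
# --seed-time=0 stops seeding right after the download completes,
# and the payload is expected to land in data/experiments.
aria2c --seed-time=0 --dir=data data/experiments.torrent
```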
330 | 331 | Place the files in the data directory so that the folder structure is as follows: 332 | 333 | ``` 334 | data 335 | ├── experiments 336 | │ ├── distillation_001 337 | │ └── true_batch_001 338 | └── sensorium_all_2023 339 | ├── dynamic29156-11-10-Video-8744edeac3b4d1ce16b680916b5267ce 340 | ├── dynamic29228-2-10-Video-8744edeac3b4d1ce16b680916b5267ce 341 | ├── dynamic29234-6-9-Video-8744edeac3b4d1ce16b680916b5267ce 342 | ├── dynamic29513-3-5-Video-8744edeac3b4d1ce16b680916b5267ce 343 | ├── dynamic29514-2-9-Video-8744edeac3b4d1ce16b680916b5267ce 344 | ├── dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20 345 | ├── dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20 346 | ├── dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20 347 | ├── dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20 348 | └── dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20 349 | ``` 350 | -------------------------------------------------------------------------------- /configs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !*.py 3 | -------------------------------------------------------------------------------- /configs/distillation_001.py: -------------------------------------------------------------------------------- 1 | from src.utils import get_lr 2 | from src import constants 3 | 4 | 5 | image_size = (64, 64) 6 | batch_size = 32 7 | base_lr = 3e-4 8 | frame_stack_size = 16 9 | config = dict( 10 | image_size=image_size, 11 | batch_size=batch_size, 12 | base_lr=base_lr, 13 | min_base_lr=base_lr * 0.01, 14 | ema_decay=0.999, 15 | train_epoch_size=72000, 16 | num_epochs=[3, 18], 17 | stages=["warmup", "train"], 18 | num_dataloader_workers=8, 19 | init_weights=True, 20 | argus_params={ 21 | "nn_module": ("dwiseneuro", { 22 | "readout_outputs": constants.num_neurons, 23 | "in_channels": 5, 24 | "core_features": (64, 64, 64, 64, 25 | 128, 128, 128, 26 | 256, 256), 27 | "spatial_strides": (2, 1, 1, 1, 28 | 2, 1, 1, 29 | 2, 1), 30 | "spatial_kernel": 3, 31 | "temporal_kernel": 5, 32 | "expansion_ratio": 6, 33 | "se_reduce_ratio": 32, 34 | "cortex_features": (512 * 2, 1024 * 2, 2048 * 2), 35 | "groups": 2, 36 | "softplus_beta": 0.07, 37 | "drop_rate": 0.4, 38 | "drop_path_rate": 0.1, 39 | }), 40 | "loss": ("mice_poisson", { 41 | "log_input": False, 42 | "full": False, 43 | "eps": 1e-8, 44 | }), 45 | "optimizer": ("AdamW", { 46 | "lr": get_lr(base_lr, batch_size), 47 | "weight_decay": 0.05, 48 | }), 49 | "device": "cuda:0", 50 | "frame_stack": { 51 | "size": frame_stack_size, 52 | "step": 2, 53 | "position": "last", 54 | }, 55 | "inputs_processor": ("stack_inputs", { 56 | "size": image_size, 57 | "pad_fill_value": 0., 58 | }), 59 | "responses_processor": ("identity", {}), 60 | "amp": True, 61 | "iter_size": 1, 62 | }, 63 | cutmix={ 64 | "alpha": 1.0, 65 | "prob": 0.5, 66 | }, 67 | distill={ 68 | "experiment": "true_batch_001", 69 | "ratio": 0.36, 70 | }, 71 | ) 72 | -------------------------------------------------------------------------------- /configs/true_batch_001.py: -------------------------------------------------------------------------------- 1 | from src.utils import get_lr 2 | from src import constants 3 | 4 | 5 | image_size = (64, 64) 6 | batch_size = 32 7 | base_lr = 3e-4 8 | frame_stack_size = 16 9 | config = dict( 10 | image_size=image_size, 11 | batch_size=batch_size, 12 | base_lr=base_lr, 13 | min_base_lr=base_lr * 0.01, 14 | ema_decay=0.999, 15 | train_epoch_size=72000, 16 | num_epochs=[3, 18], 17 | 
stages=["warmup", "train"], 18 | num_dataloader_workers=8, 19 | init_weights=True, 20 | argus_params={ 21 | "nn_module": ("dwiseneuro", { 22 | "readout_outputs": constants.num_neurons, 23 | "in_channels": 5, 24 | "core_features": (64, 64, 64, 64, 25 | 128, 128, 128, 26 | 256, 256), 27 | "spatial_strides": (2, 1, 1, 1, 28 | 2, 1, 1, 29 | 2, 1), 30 | "spatial_kernel": 3, 31 | "temporal_kernel": 5, 32 | "expansion_ratio": 7, 33 | "se_reduce_ratio": 32, 34 | "cortex_features": (512 * 2, 1024 * 2, 2048 * 2), 35 | "groups": 2, 36 | "softplus_beta": 0.07, 37 | "drop_rate": 0.4, 38 | "drop_path_rate": 0.1, 39 | }), 40 | "loss": ("mice_poisson", { 41 | "log_input": False, 42 | "full": False, 43 | "eps": 1e-8, 44 | }), 45 | "optimizer": ("AdamW", { 46 | "lr": get_lr(base_lr, batch_size), 47 | "weight_decay": 0.05, 48 | }), 49 | "device": "cuda:0", 50 | "frame_stack": { 51 | "size": frame_stack_size, 52 | "step": 2, 53 | "position": "last", 54 | }, 55 | "inputs_processor": ("stack_inputs", { 56 | "size": image_size, 57 | "pad_fill_value": 0., 58 | }), 59 | "responses_processor": ("identity", {}), 60 | "amp": True, 61 | "iter_size": 1, 62 | }, 63 | cutmix={ 64 | "alpha": 1.0, 65 | "prob": 0.5, 66 | }, 67 | ) 68 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !experiments.torrent 4 | !readme_images/ 5 | -------------------------------------------------------------------------------- /data/experiments.torrent: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/experiments.torrent -------------------------------------------------------------------------------- /data/readme_images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/readme_images/architecture.png -------------------------------------------------------------------------------- /data/readme_images/preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/readme_images/preview.png -------------------------------------------------------------------------------- /data/readme_images/softplus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/readme_images/softplus.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.66.1 2 | requests==2.31.0 3 | numpy==1.26.1 4 | pandas==2.1.1 5 | pyarrow==14.0.1 6 | torch>=2.0.0 7 | pytorch-argus==1.0.0 8 | deeplake==3.8.6 9 | ImageHash==4.3.1 10 | -------------------------------------------------------------------------------- /scripts/download_data.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import zipfile 3 | import argparse 4 | from pathlib import Path 5 | 6 | import deeplake 7 | import numpy as np 8 | import requests 9 | from tqdm import tqdm 10 | 11 | from src import constants 12 | 13 | 14 | def parse_arguments(): 15 | 
parser = argparse.ArgumentParser() 16 | parser.add_argument("-p", "--path", default=constants.sensorium_dir, type=Path) 17 | return parser.parse_args() 18 | 19 | 20 | if __name__ == "__main__": 21 | args = parse_arguments() 22 | 23 | sensorium_dir = args.path 24 | sensorium_dir.mkdir(parents=True, exist_ok=True) 25 | 26 | for mouse in constants.mice: 27 | file_name = f"{mouse}.zip" 28 | dataset = constants.mouse2dataset[mouse] 29 | url = constants.dataset2url_format[dataset].format(file_name=file_name) 30 | zip_path = sensorium_dir / file_name 31 | mouse_dir = sensorium_dir / mouse 32 | 33 | if mouse_dir.exists(): 34 | print(f"Folder '{str(mouse_dir)}' already exists, skip download") 35 | continue 36 | 37 | print(f"Download '{url}' to '{zip_path}'") 38 | zip_path.unlink(missing_ok=True) 39 | with requests.get(url, stream=True) as r: 40 | total_length = int(r.headers.get("Content-Length")) 41 | with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw: 42 | with open(zip_path, 'wb') as output: 43 | shutil.copyfileobj(raw, output) 44 | 45 | print("Unzip", zip_path) 46 | with zipfile.ZipFile(zip_path, 'r') as zip_file: 47 | zip_file.extractall(sensorium_dir) 48 | 49 | print("Delete", zip_path) 50 | zip_path.unlink() 51 | shutil.rmtree(sensorium_dir / "__MACOSX", ignore_errors=True) 52 | 53 | if mouse in constants.new_mice: 54 | continue 55 | for split in constants.unlabeled_splits: 56 | dataset = deeplake.load(f"hub://sinzlab/Sensorium_2023_{mouse}_{split}") 57 | trials_ids = dataset.id.numpy().astype(int).ravel().tolist() 58 | for index, trial_id in enumerate(trials_ids): 59 | responses_path = mouse_dir / "data" / "responses" / f"{trial_id}.npy" 60 | responses = dataset.responses[index].numpy() 61 | np.save(str(responses_path), responses) 62 | -------------------------------------------------------------------------------- /scripts/ensemble.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | from src.submission import evaluate_folds_predictions, make_submission 7 | from src.data import get_mouse_data 8 | from src import constants 9 | 10 | 11 | def parse_arguments(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("-e", "--experiments", required=True, type=str) 14 | parser.add_argument("-s", "--split", required=True, 15 | choices=["folds"] + constants.unlabeled_splits, type=str) 16 | parser.add_argument("-d", "--dataset", default="new", choices=["new", "old"], type=str) 17 | return parser.parse_args() 18 | 19 | 20 | def ensemble_experiments(experiments: list[str], split: str, dataset: str): 21 | assert len(experiments) > 1 22 | print(f"Ensemble experiments: {experiments=}, {split=}, {dataset=}") 23 | split_dir_name = "out-of-fold" if split == "folds" else split 24 | splits = constants.folds_splits if split == "folds" else [split] 25 | ensemble_dir = constants.predictions_dir / ",".join(experiments) / split_dir_name 26 | for mouse in constants.dataset2mice[dataset]: 27 | ensemble_mouse_dir = ensemble_dir / mouse 28 | print(f"Ensemble mouse: {mouse=}, {str(ensemble_mouse_dir)=}") 29 | ensemble_mouse_dir.mkdir(parents=True, exist_ok=True) 30 | mouse_data = get_mouse_data(mouse=mouse, splits=splits) 31 | 32 | for trial_data in tqdm(mouse_data["trials"]): 33 | pred_filename = f"{trial_data['trial_id']}.npy" 34 | responses_lst = [] 35 | for experiment in experiments: 36 | responses = np.load( 37 | str(constants.predictions_dir / experiment / split_dir_name / 
mouse / pred_filename) 38 | ) 39 | responses_lst.append(responses) 40 | blend_responses = np.mean(responses_lst, axis=0) 41 | np.save(str(ensemble_mouse_dir / pred_filename), blend_responses) 42 | 43 | 44 | if __name__ == "__main__": 45 | args = parse_arguments() 46 | experiments_lst = sorted(args.experiments.split(',')) 47 | experiment_name = ",".join(experiments_lst) 48 | ensemble_experiments(experiments_lst, args.split, args.dataset) 49 | if args.split == "folds": 50 | evaluate_folds_predictions(experiment_name, args.dataset) 51 | elif args.dataset == "new": 52 | make_submission(experiment_name, args.split) 53 | -------------------------------------------------------------------------------- /scripts/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | from src.submission import evaluate_folds_predictions, make_submission 8 | from src.utils import get_best_model_path 9 | from src.predictors import Predictor 10 | from src.data import get_mouse_data 11 | from src import constants 12 | 13 | 14 | def parse_arguments(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("-e", "--experiment", required=True, type=str) 17 | parser.add_argument("-s", "--split", required=True, 18 | choices=["folds"] + constants.unlabeled_splits, type=str) 19 | parser.add_argument("-d", "--dataset", default="new", choices=["new", "old"], type=str) 20 | parser.add_argument("--device", default="cuda:0", type=str) 21 | return parser.parse_args() 22 | 23 | 24 | def predict_trial(trial_data: dict, predictor: Predictor, mouse_index: int): 25 | length = trial_data["length"] 26 | video = np.load(trial_data["video_path"])[..., :length] 27 | behavior = np.load(trial_data["behavior_path"])[..., :length] 28 | pupil_center = np.load(trial_data["pupil_center_path"])[..., :length] 29 | responses = predictor.predict_trial( 30 | video=video, 31 | behavior=behavior, 32 | pupil_center=pupil_center, 33 | mouse_index=mouse_index, 34 | ) 35 | return responses 36 | 37 | 38 | def predict_mouse_split(mouse: str, split: str, 39 | predictors: list[Predictor], save_dir: Path): 40 | mouse_index = constants.mouse2index[mouse] 41 | print(f"Predict mouse split: {mouse=} {split=} {len(predictors)=} {str(save_dir)=}") 42 | mouse_data = get_mouse_data(mouse=mouse, splits=[split]) 43 | 44 | for trial_data in tqdm(mouse_data["trials"]): 45 | responses_lst = [] 46 | for predictor in predictors: 47 | responses = predict_trial(trial_data, predictor, mouse_index) 48 | responses_lst.append(responses) 49 | blend_responses = np.mean(responses_lst, axis=0) 50 | np.save(str(save_dir / f"{trial_data['trial_id']}.npy"), blend_responses) 51 | 52 | 53 | def predict_folds(experiment: str, dataset: str, device: str): 54 | print(f"Predict folds: {experiment=}, {dataset=}, {device=}") 55 | for mouse in constants.dataset2mice[dataset]: 56 | mouse_prediction_dir = constants.predictions_dir / experiment / "out-of-fold" / mouse 57 | mouse_prediction_dir.mkdir(parents=True, exist_ok=True) 58 | for fold_split in constants.folds_splits: 59 | model_path = get_best_model_path(constants.experiments_dir / experiment / fold_split) 60 | print("Model path:", str(model_path)) 61 | predictor = Predictor(model_path=model_path, device=device, blend_weights="ones") 62 | predict_mouse_split(mouse, fold_split, [predictor], mouse_prediction_dir) 63 | 64 | 65 | def predict_unlabeled_split(experiment: str, split: str, dataset: 
str, device: str): 66 | print(f"Predict unlabeled split: {experiment=}, {split=}, {dataset=}, {device=}") 67 | predictors = [] 68 | for fold_split in constants.folds_splits: 69 | model_path = get_best_model_path(constants.experiments_dir / experiment / fold_split) 70 | print("Model path:", str(model_path)) 71 | predictor = Predictor(model_path=model_path, device=device, blend_weights="ones") 72 | predictors.append(predictor) 73 | for mouse in constants.dataset2mice[dataset]: 74 | mouse_prediction_dir = constants.predictions_dir / experiment / split / mouse 75 | mouse_prediction_dir.mkdir(parents=True, exist_ok=True) 76 | predict_mouse_split(mouse, split, predictors, mouse_prediction_dir) 77 | 78 | 79 | if __name__ == "__main__": 80 | args = parse_arguments() 81 | 82 | if args.split == "folds": 83 | predict_folds(args.experiment, args.dataset, args.device) 84 | evaluate_folds_predictions(args.experiment, args.dataset) 85 | elif args.dataset == "new": 86 | predict_unlabeled_split(args.experiment, args.split, args.dataset, args.device) 87 | make_submission(args.experiment, args.split) 88 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import json 4 | import argparse 5 | from pathlib import Path 6 | from pprint import pprint 7 | from importlib.machinery import SourceFileLoader 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | from argus import load_model 13 | from argus.callbacks import ( 14 | LoggingToFile, 15 | LoggingToCSV, 16 | CosineAnnealingLR, 17 | Checkpoint, 18 | LambdaLR, 19 | ) 20 | 21 | from src.datasets import TrainMouseVideoDataset, ValMouseVideoDataset, ConcatMiceVideoDataset 22 | from src.utils import get_lr, init_weights, get_best_model_path 23 | from src.responses import get_responses_processor 24 | from src.ema import ModelEma, EmaCheckpoint 25 | from src.inputs import get_inputs_processor 26 | from src.metrics import CorrelationMetric 27 | from src.indexes import IndexesGenerator 28 | from src.argus_models import MouseModel 29 | from src.data import get_mouse_data 30 | from src.mixers import CutMix 31 | from src import constants 32 | 33 | 34 | def parse_arguments(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("-e", "--experiment", required=True, type=str) 37 | parser.add_argument("-f", "--folds", default="all", type=str) 38 | return parser.parse_args() 39 | 40 | 41 | def train_mouse(config: dict, save_dir: Path, train_splits: list[str], val_splits: list[str]): 42 | config = copy.deepcopy(config) 43 | argus_params = config["argus_params"] 44 | 45 | model = MouseModel(argus_params) 46 | 47 | if config["init_weights"]: 48 | print("Weight initialization") 49 | init_weights(model.nn_module) 50 | 51 | if config["ema_decay"]: 52 | print("EMA decay:", config["ema_decay"]) 53 | model.model_ema = ModelEma(model.nn_module, decay=config["ema_decay"]) 54 | checkpoint_class = EmaCheckpoint 55 | else: 56 | checkpoint_class = Checkpoint 57 | 58 | if "distill" in config: 59 | distill_params = config["distill"] 60 | distill_experiment_dir = constants.experiments_dir / distill_params["experiment"] / val_splits[0] 61 | distill_model_path = get_best_model_path(distill_experiment_dir) 62 | distill_model = load_model(distill_model_path, device=argus_params["device"]) 63 | distill_model.eval() 64 | model.distill_model = distill_model.nn_module 65 | model.distill_ratio = 
distill_params["ratio"] 66 | print(f"Distillation model {str(distill_model_path)}, ratio {model.distill_ratio}") 67 | 68 | indexes_generator = IndexesGenerator(**argus_params["frame_stack"]) 69 | inputs_processor = get_inputs_processor(*argus_params["inputs_processor"]) 70 | responses_processor = get_responses_processor(*argus_params["responses_processor"]) 71 | 72 | cutmix = CutMix(**config["cutmix"]) 73 | train_datasets = [] 74 | mouse_epoch_size = config["train_epoch_size"] // constants.num_mice 75 | for mouse in constants.mice: 76 | train_datasets += [ 77 | TrainMouseVideoDataset( 78 | mouse_data=get_mouse_data(mouse=mouse, splits=train_splits), 79 | indexes_generator=indexes_generator, 80 | inputs_processor=inputs_processor, 81 | responses_processor=responses_processor, 82 | epoch_size=mouse_epoch_size, 83 | mixer=cutmix, 84 | ) 85 | ] 86 | train_dataset = ConcatMiceVideoDataset(train_datasets) 87 | print("Train dataset len:", len(train_dataset)) 88 | val_datasets = [] 89 | for mouse in constants.mice: 90 | val_datasets += [ 91 | ValMouseVideoDataset( 92 | mouse_data=get_mouse_data(mouse=mouse, splits=val_splits), 93 | indexes_generator=indexes_generator, 94 | inputs_processor=inputs_processor, 95 | responses_processor=responses_processor, 96 | ) 97 | ] 98 | val_dataset = ConcatMiceVideoDataset(val_datasets) 99 | print("Val dataset len:", len(val_dataset)) 100 | 101 | train_loader = DataLoader( 102 | train_dataset, 103 | batch_size=config["batch_size"], 104 | num_workers=config["num_dataloader_workers"], 105 | shuffle=True, 106 | ) 107 | val_loader = DataLoader( 108 | val_dataset, 109 | batch_size=config["batch_size"] // argus_params["iter_size"], 110 | num_workers=config["num_dataloader_workers"], 111 | shuffle=False, 112 | ) 113 | 114 | for num_epochs, stage in zip(config["num_epochs"], config["stages"]): 115 | callbacks = [ 116 | LoggingToFile(save_dir / "log.txt", append=True), 117 | LoggingToCSV(save_dir / "log.csv", append=True), 118 | ] 119 | 120 | num_iterations = (len(train_dataset) // config["batch_size"]) * num_epochs 121 | if stage == "warmup": 122 | callbacks += [ 123 | LambdaLR(lambda x: x / num_iterations, 124 | step_on_iteration=True), 125 | ] 126 | elif stage == "train": 127 | checkpoint_format = "model-{epoch:03d}-{val_corr:.6f}.pth" 128 | callbacks += [ 129 | checkpoint_class(save_dir, file_format=checkpoint_format, max_saves=1), 130 | CosineAnnealingLR( 131 | T_max=num_iterations, 132 | eta_min=get_lr(config["min_base_lr"], config["batch_size"]), 133 | step_on_iteration=True, 134 | ), 135 | ] 136 | 137 | metrics = [ 138 | CorrelationMetric(), 139 | ] 140 | 141 | model.fit(train_loader, 142 | val_loader=val_loader, 143 | num_epochs=num_epochs, 144 | callbacks=callbacks, 145 | metrics=metrics) 146 | 147 | 148 | if __name__ == "__main__": 149 | args = parse_arguments() 150 | print("Experiment:", args.experiment) 151 | 152 | config_path = constants.configs_dir / f"{args.experiment}.py" 153 | if not config_path.exists(): 154 | raise RuntimeError(f"Config '{config_path}' is not exists") 155 | 156 | train_config = SourceFileLoader(args.experiment, str(config_path)).load_module().config 157 | print("Experiment config:") 158 | pprint(train_config, sort_dicts=False) 159 | 160 | experiment_dir = constants.experiments_dir / args.experiment 161 | print("Experiment dir:", experiment_dir) 162 | if not experiment_dir.exists(): 163 | experiment_dir.mkdir(parents=True, exist_ok=True) 164 | else: 165 | print(f"Folder '{experiment_dir}' already exists.") 166 | 167 | with 
open(experiment_dir / "train.py", "w") as outfile: 168 | outfile.write(open(__file__).read()) 169 | 170 | with open(experiment_dir / "config.json", "w") as outfile: 171 | json.dump(train_config, outfile, indent=4) 172 | 173 | if args.folds == "all": 174 | folds_splits = constants.folds_splits 175 | else: 176 | folds_splits = [f"fold_{fold}" for fold in args.folds.split(",")] 177 | 178 | for fold_split in folds_splits: 179 | fold_experiment_dir = experiment_dir / fold_split 180 | 181 | val_folds_splits = [fold_split] 182 | train_folds_splits = sorted(set(constants.folds_splits) - set(val_folds_splits)) 183 | 184 | print(f"Val fold: {val_folds_splits}, train folds: {train_folds_splits}") 185 | print(f"Fold experiment dir: {fold_experiment_dir}") 186 | train_mouse(train_config, fold_experiment_dir, train_folds_splits, val_folds_splits) 187 | 188 | torch.cuda.empty_cache() 189 | time.sleep(12) 190 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | exclude = __pycache__,data,notebooks 4 | per-file-ignores = __init__.py:F401 5 | 6 | [mypy] 7 | files = src 8 | ignore_missing_imports = True 9 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from src.argus_models import MouseModel 2 | -------------------------------------------------------------------------------- /src/argus_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import argus 4 | from argus.engine import State 5 | from argus.loss import pytorch_losses 6 | from argus.utils import deep_to, deep_detach, deep_chunk 7 | 8 | from src.ema import ModelEma 9 | from src.losses import MicePoissonLoss 10 | from src.models.dwiseneuro import DwiseNeuro 11 | 12 | 13 | class MouseModel(argus.Model): 14 | nn_module = { 15 | "dwiseneuro": DwiseNeuro, 16 | } 17 | loss = { 18 | **pytorch_losses, 19 | "mice_poisson": MicePoissonLoss, 20 | } 21 | 22 | def __init__(self, params: dict): 23 | super().__init__(params) 24 | self.iter_size = int(params.get('iter_size', 1)) 25 | self.amp = bool(params.get('amp', False)) 26 | self.grad_scaler = torch.cuda.amp.GradScaler(enabled=self.amp) 27 | self.model_ema: ModelEma | None = None 28 | self.distill_model: torch.nn.Module | None = None 29 | self.distill_ratio: float = 0. 30 | 31 | @torch.no_grad() 32 | def add_distill_predictions(self, input, target): 33 | if self.distill_model is not None and self.distill_ratio: 34 | distill_prediction = self.distill_model(input) 35 | target_tensors, mice_weights = target 36 | distill_mask = mice_weights == 0. 37 | distill_weight = (self.distill_ratio / (1. 
- self.distill_ratio) 38 | * mice_weights.sum() / distill_mask.sum()) 39 | for batch_idx, mouse_idx in torch.argwhere(distill_mask): 40 | target_tensors[mouse_idx][batch_idx] = distill_prediction[mouse_idx][batch_idx] 41 | mice_weights[batch_idx, mouse_idx] = distill_weight 42 | 43 | def train_step(self, batch, state: State) -> dict: 44 | self.train() 45 | self.optimizer.zero_grad() 46 | 47 | loss_value = 0 48 | for i, chunk_batch in enumerate(deep_chunk(batch, self.iter_size)): 49 | input, target = deep_to(chunk_batch, self.device, non_blocking=True) 50 | with torch.cuda.amp.autocast(enabled=self.amp): 51 | self.add_distill_predictions(input, target) 52 | prediction = self.nn_module(input) 53 | loss = self.loss(prediction, target) 54 | loss = loss / self.iter_size 55 | self.grad_scaler.scale(loss).backward() 56 | loss_value += loss.item() 57 | 58 | self.grad_scaler.step(self.optimizer) 59 | self.grad_scaler.update() 60 | 61 | if self.model_ema is not None: 62 | self.model_ema.update(self.nn_module) 63 | 64 | prediction = deep_detach(prediction) 65 | target = deep_detach(target) 66 | prediction = self.prediction_transform(prediction) 67 | return { 68 | 'prediction': prediction, 69 | 'target': target, 70 | 'loss': loss_value 71 | } 72 | 73 | def val_step(self, batch, state: State) -> dict: 74 | self.eval() 75 | with torch.no_grad(): 76 | input, target = deep_to(batch, device=self.device, non_blocking=True) 77 | if self.model_ema is None: 78 | prediction = self.nn_module(input) 79 | else: 80 | prediction = self.model_ema.ema(input) 81 | loss = self.loss(prediction, target) 82 | prediction = self.prediction_transform(prediction) 83 | return { 84 | 'prediction': prediction, 85 | 'target': target, 86 | 'loss': loss.item() 87 | } 88 | 89 | def predict(self, input, mouse_index: int | None = None): 90 | self._check_predict_ready() 91 | with torch.no_grad(): 92 | self.eval() 93 | input = deep_to(input, self.device) 94 | if self.model_ema is None: 95 | prediction = self.nn_module(input, mouse_index) 96 | else: 97 | prediction = self.model_ema.ema(input, mouse_index) 98 | prediction = self.prediction_transform(prediction) 99 | return prediction 100 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | work_dir = Path("/workdir") 4 | data_dir = work_dir / "data" 5 | sensorium_dir = data_dir / "sensorium_all_2023" 6 | 7 | configs_dir = work_dir / "configs" 8 | experiments_dir = data_dir / "experiments" 9 | predictions_dir = data_dir / "predictions" 10 | 11 | new_mice = [ 12 | "dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20", 13 | "dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20", 14 | "dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20", 15 | "dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20", 16 | "dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20", 17 | ] 18 | new_num_neurons = [7863, 7908, 8202, 7939, 8122] 19 | old_mice = [ 20 | "dynamic29156-11-10-Video-8744edeac3b4d1ce16b680916b5267ce", 21 | "dynamic29228-2-10-Video-8744edeac3b4d1ce16b680916b5267ce", 22 | "dynamic29234-6-9-Video-8744edeac3b4d1ce16b680916b5267ce", 23 | "dynamic29513-3-5-Video-8744edeac3b4d1ce16b680916b5267ce", 24 | "dynamic29514-2-9-Video-8744edeac3b4d1ce16b680916b5267ce", 25 | ] 26 | old_num_neurons = [7440, 7928, 8285, 7671, 7495] 27 | dataset2mice = { 28 | "new": new_mice, 29 | "old": old_mice, 30 | } 31 | mouse2dataset = {m: d 
for d, mc in dataset2mice.items() for m in mc} 32 | dataset2url_format = { 33 | "new": "https://gin.g-node.org/pollytur/sensorium_2023_dataset/raw/master/{file_name}", 34 | "old": "https://gin.g-node.org/pollytur/Sensorium2023Data/raw/master/{file_name}", 35 | } 36 | 37 | mice = new_mice + old_mice 38 | num_neurons = new_num_neurons + old_num_neurons 39 | 40 | num_mice = len(mice) 41 | index2mouse: dict[int, str] = {index: mouse for index, mouse in enumerate(mice)} 42 | mouse2index: dict[str, int] = {mouse: index for index, mouse in enumerate(mice)} 43 | mouse2num_neurons: dict[str, int] = {mouse: num for mouse, num in zip(mice, num_neurons)} 44 | mice_indexes = list(range(num_mice)) 45 | 46 | unlabeled_splits = ["live_test_main", "live_test_bonus", "final_test_main", "final_test_bonus"] 47 | 48 | num_folds = 7 49 | folds = list(range(num_folds)) 50 | folds_splits = [f"fold_{fold}" for fold in folds] 51 | 52 | submission_limit_length = 300 53 | submission_skip_first = 50 54 | submission_skip_last = 1 55 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from src.phash import calculate_video_phash 4 | from src.utils import get_length_without_nan 5 | from src import constants 6 | 7 | 8 | def create_videos_phashes(mouse: str) -> np.ndarray: 9 | mouse_dir = constants.sensorium_dir / mouse 10 | tiers = np.load(str(mouse_dir / "meta" / "trials" / "tiers.npy")) 11 | phashes = np.zeros(tiers.shape[0], dtype=np.uint64) 12 | for trial_id, tier in enumerate(tiers): 13 | if tier == "none": 14 | continue 15 | video = np.load(str(mouse_dir / "data" / "videos" / f"{trial_id}.npy")) 16 | phashes[trial_id] = calculate_video_phash(video) 17 | return phashes 18 | 19 | 20 | def get_folds_tiers(mouse: str, num_folds: int): 21 | tiers = np.load(str(constants.sensorium_dir / mouse / "meta" / "trials" / "tiers.npy")) 22 | phashes = create_videos_phashes(mouse) 23 | if mouse in constants.new_mice: 24 | trial_ids = np.argwhere((tiers == "train") | (tiers == "oracle")).ravel() 25 | else: 26 | trial_ids = np.argwhere(tiers != "none").ravel() 27 | for trial_id in trial_ids: 28 | fold = int(phashes[trial_id]) % num_folds # group k-fold by video hash 29 | tiers[trial_id] = f"fold_{fold}" 30 | return tiers 31 | 32 | 33 | def get_mouse_data(mouse: str, splits: list[str]) -> dict: 34 | assert mouse in constants.mice 35 | tiers = get_folds_tiers(mouse, constants.num_folds) 36 | mouse_dir = constants.sensorium_dir / mouse 37 | neuron_ids = np.load(str(mouse_dir / "meta" / "neurons" / "unit_ids.npy")) 38 | cell_motor_coords = np.load(str(mouse_dir / "meta" / "neurons" / "cell_motor_coordinates.npy")) 39 | 40 | mouse_data = { 41 | "mouse": mouse, 42 | "splits": splits, 43 | "neuron_ids": neuron_ids, 44 | "num_neurons": neuron_ids.shape[0], 45 | "cell_motor_coordinates": cell_motor_coords, 46 | "trials": [], 47 | } 48 | 49 | for split in splits: 50 | if split in constants.folds_splits: 51 | labeled_split = True 52 | elif split in constants.unlabeled_splits: 53 | labeled_split = False 54 | else: 55 | raise ValueError(f"Unknown data split '{split}'") 56 | trial_ids = np.argwhere(tiers == split).ravel().tolist() 57 | 58 | for trial_id in trial_ids: 59 | behavior_path = str(mouse_dir / "data" / "behavior" / f"{trial_id}.npy") 60 | trial_data = { 61 | "trial_id": trial_id, 62 | "length": get_length_without_nan(np.load(behavior_path)[0]), 63 | "video_path": str(mouse_dir / 
"data" / "videos" / f"{trial_id}.npy"), 64 | "behavior_path": behavior_path, 65 | "pupil_center_path": str(mouse_dir / "data" / "pupil_center" / f"{trial_id}.npy"), 66 | } 67 | if labeled_split: 68 | response_path = str(mouse_dir / "data" / "responses" / f"{trial_id}.npy") 69 | trial_data["response_path"] = response_path 70 | trial_data["length"] = get_length_without_nan(np.load(response_path)[0]) 71 | mouse_data["trials"].append(trial_data) 72 | 73 | return mouse_data 74 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import random 3 | 4 | import numpy as np 5 | 6 | import torch 7 | from torch import nn 8 | from torch import Tensor 9 | from torch.utils.data import Dataset 10 | 11 | from src.mixers import Mixer 12 | from src.utils import set_random_seed 13 | from src.inputs import InputsProcessor 14 | from src.indexes import IndexesGenerator 15 | from src.responses import ResponsesProcessor 16 | from src import constants 17 | 18 | 19 | class MouseVideoDataset(Dataset, metaclass=abc.ABCMeta): 20 | def __init__(self, 21 | mouse_data: dict, 22 | indexes_generator: IndexesGenerator, 23 | inputs_processor: InputsProcessor, 24 | responses_processor: ResponsesProcessor): 25 | self.mouse_data = mouse_data 26 | self.mouse = mouse_data["mouse"] 27 | self.mouse_index = constants.mouse2index[self.mouse] 28 | self.indexes_generator = indexes_generator 29 | self.inputs_processor = inputs_processor 30 | self.responses_processor = responses_processor 31 | 32 | self.trials = self.mouse_data["trials"] 33 | self.num_trials = len(self.trials) 34 | self.trials_lengths = [t["length"] for t in self.trials] 35 | self.num_neurons = self.mouse_data["num_neurons"] 36 | 37 | def get_frames(self, trial_index: int, indexes: list[int]) -> np.ndarray: 38 | frames = np.load(self.trials[trial_index]["video_path"])[..., indexes] 39 | return frames 40 | 41 | def get_responses(self, trial_index: int, indexes: list[int]) -> np.ndarray: 42 | responses = np.load(self.trials[trial_index]["response_path"])[..., indexes] 43 | return responses 44 | 45 | def get_behavior(self, trial_index: int, indexes: list[int]) -> np.ndarray: 46 | behavior = np.load(self.trials[trial_index]["behavior_path"])[..., indexes] 47 | return behavior 48 | 49 | def get_pupil_center(self, trial_index: int, indexes: list[int]) -> np.ndarray: 50 | pupil_center = np.load(self.trials[trial_index]["pupil_center_path"])[..., indexes] 51 | return pupil_center 52 | 53 | def get_inputs_responses( 54 | self, 55 | trial_index: int, 56 | indexes: list[int], 57 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 58 | frames = self.get_frames(trial_index, indexes) 59 | responses = self.get_responses(trial_index, indexes) 60 | behavior = self.get_behavior(trial_index, indexes) 61 | pupil_center = self.get_pupil_center(trial_index, indexes) 62 | return frames, behavior, pupil_center, responses 63 | 64 | def process_inputs_responses(self, 65 | frames: np.ndarray, 66 | behavior: np.ndarray, 67 | pupil_center: np.ndarray, 68 | responses: np.ndarray) -> tuple[Tensor, Tensor]: 69 | input_tensor = self.inputs_processor(frames, behavior, pupil_center) 70 | target_tensor = self.responses_processor(responses) 71 | return input_tensor, target_tensor 72 | 73 | @abc.abstractmethod 74 | def __len__(self) -> int: 75 | pass 76 | 77 | @abc.abstractmethod 78 | def get_indexes(self, index: int) -> tuple[int, list[int]]: 79 | pass 
80 | 81 | def get_sample_tensors(self, index: int) -> tuple[Tensor, Tensor]: 82 | trial_index, indexes = self.get_indexes(index) 83 | frames, behavior, pupil_center, responses = self.get_inputs_responses(trial_index, indexes) 84 | return self.process_inputs_responses(frames, behavior, pupil_center, responses) 85 | 86 | def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: 87 | return self.get_sample_tensors(index) 88 | 89 | 90 | class TrainMouseVideoDataset(MouseVideoDataset): 91 | def __init__(self, 92 | mouse_data: dict, 93 | indexes_generator: IndexesGenerator, 94 | inputs_processor: InputsProcessor, 95 | responses_processor: ResponsesProcessor, 96 | epoch_size: int, 97 | augmentations: nn.Module | None = None, 98 | mixer: Mixer | None = None): 99 | super().__init__(mouse_data, indexes_generator, inputs_processor, responses_processor) 100 | self.epoch_size = epoch_size 101 | self.augmentations = augmentations 102 | self.mixer = mixer 103 | 104 | def __len__(self) -> int: 105 | return self.epoch_size 106 | 107 | def get_indexes(self, index: int) -> tuple[int, list[int]]: 108 | set_random_seed(index) 109 | trial_index = random.randrange(0, self.num_trials) 110 | num_frames = self.trials[trial_index]["length"] 111 | frame_index = random.randrange( 112 | self.indexes_generator.behind, 113 | num_frames - self.indexes_generator.ahead 114 | ) 115 | indexes = self.indexes_generator.make_indexes(frame_index) 116 | return trial_index, indexes 117 | 118 | def get_sample_tensors(self, index: int) -> tuple[Tensor, Tensor]: 119 | frames, responses = super().get_sample_tensors(index) 120 | if self.augmentations is not None: 121 | frames = self.augmentations(frames[None])[0] 122 | return frames, responses 123 | 124 | def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: 125 | sample = self.get_sample_tensors(index) 126 | if self.mixer is not None and self.mixer.use(): 127 | random_sample = self.get_sample_tensors(index + 1) 128 | sample = self.mixer(sample, random_sample) 129 | return sample 130 | 131 | 132 | class ValMouseVideoDataset(MouseVideoDataset): 133 | def __init__(self, 134 | mouse_data: dict, 135 | indexes_generator: IndexesGenerator, 136 | inputs_processor: InputsProcessor, 137 | responses_processor: ResponsesProcessor): 138 | super().__init__(mouse_data, indexes_generator, inputs_processor, responses_processor) 139 | self.window_size = self.indexes_generator.width 140 | self.samples_per_trials = [length // self.window_size for length in self.trials_lengths] 141 | self.num_samples = sum(self.samples_per_trials) 142 | 143 | def __len__(self) -> int: 144 | return self.num_samples 145 | 146 | def get_indexes(self, index: int) -> tuple[int, list[int]]: 147 | assert 0 <= index < self.__len__() 148 | trial_sample_index = index 149 | trial_index = 0 150 | for trial_index, num_trial_samples in enumerate(self.samples_per_trials): 151 | if trial_sample_index >= num_trial_samples: 152 | trial_sample_index -= num_trial_samples 153 | else: 154 | break 155 | 156 | frame_index = self.indexes_generator.behind + trial_sample_index * self.window_size 157 | indexes = self.indexes_generator.make_indexes(frame_index) 158 | return trial_index, indexes 159 | 160 | 161 | class ConcatMiceVideoDataset(Dataset): 162 | def __init__(self, mice_datasets: list[MouseVideoDataset]): 163 | self.mice_indexes = [d.mouse_index for d in mice_datasets] 164 | assert self.mice_indexes == constants.mice_indexes 165 | self.mice_datasets = mice_datasets 166 | self.samples_per_dataset = [len(d) for d in mice_datasets] 167 
| self.num_samples = sum(self.samples_per_dataset) 168 | 169 | def __len__(self): 170 | return self.num_samples 171 | 172 | def construct_mice_sample( 173 | self, mouse_index: int, mouse_sample: tuple[Tensor, Tensor] 174 | ) -> tuple[Tensor, tuple[list[Tensor], Tensor]]: 175 | input_tensor, target_tensor = mouse_sample 176 | target_tensors = [] 177 | for index in self.mice_indexes: 178 | if index == mouse_index: 179 | target_tensors.append(target_tensor) 180 | else: 181 | temporal_shape = [target_tensor.shape[-1]] if len(target_tensor.shape) == 2 else [] 182 | target_tensors.append( 183 | torch.zeros(constants.num_neurons[index], *temporal_shape, dtype=torch.float32) 184 | ) 185 | mice_weights = torch.zeros(constants.num_mice, dtype=torch.float32) 186 | mice_weights[mouse_index] = 1.0 187 | return input_tensor, (target_tensors, mice_weights) 188 | 189 | def __getitem__(self, index: int) -> tuple[Tensor, tuple[list[Tensor], Tensor]]: 190 | assert 0 <= index < self.__len__() 191 | sample_index = index 192 | mouse_index = 0 193 | for mouse_index, num_trial_samples in enumerate(self.samples_per_dataset): 194 | if sample_index >= num_trial_samples: 195 | sample_index -= num_trial_samples 196 | else: 197 | break 198 | mouse_sample = self.mice_datasets[mouse_index][sample_index] 199 | mice_sample = self.construct_mice_sample(mouse_index, mouse_sample) 200 | return mice_sample 201 | -------------------------------------------------------------------------------- /src/ema.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.parallel import DataParallel, DistributedDataParallel 6 | 7 | from argus.utils import deep_to 8 | from argus.engine import State 9 | from argus.callbacks import Checkpoint 10 | 11 | 12 | class ModelEma(nn.Module): 13 | """ Model Exponential Moving Average V2 14 | 15 | Keep a moving average of everything in the model state_dict (parameters and buffers). 16 | V2 of this module is simpler, it does not match params/buffers based on name but simply 17 | iterates in order. It works with torchscript (JIT of full model). 18 | 19 | This is intended to allow functionality like 20 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 21 | 22 | A smoothed version of the weights is necessary for some training schemes to perform well. 23 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use 24 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA 25 | smoothing of weights to match results. Pay attention to the decay constant you are using 26 | relative to your update count per epoch. 27 | 28 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but 29 | disable validation of the EMA weights. Validation will have to be done manually in a separate 30 | process, or after the training stops converging. 31 | 32 | This class is sensitive where it is initialized in the sequence of model init, 33 | GPU assignment and distributed training wrappers. 
34 | 35 | Source: https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py 36 | """ 37 | def __init__(self, model, decay=0.9999, device=None): 38 | super().__init__() 39 | # make a copy of the model for accumulating moving average of weights 40 | self.ema = deepcopy(model) 41 | self.ema.eval() 42 | self.decay = decay 43 | self.device = device # perform ema on different device from model if set 44 | if self.device is not None: 45 | self.ema.to(device=device) 46 | 47 | def _update(self, model, update_fn): 48 | with torch.no_grad(): 49 | for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()): 50 | if self.device is not None: 51 | model_v = model_v.to(device=self.device) 52 | ema_v.copy_(update_fn(ema_v, model_v)) 53 | 54 | def update(self, model): 55 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) 56 | 57 | def set(self, model): 58 | self._update(model, update_fn=lambda e, m: m) 59 | 60 | 61 | class EmaCheckpoint(Checkpoint): 62 | def save_model(self, state: State, file_path): 63 | nn_module = state.model.model_ema.ema 64 | if isinstance(nn_module, (DataParallel, DistributedDataParallel)): 65 | nn_module = nn_module.module 66 | 67 | torch_state = { 68 | 'model_name': state.model.__class__.__name__, 69 | 'params': state.model.params, 70 | 'nn_state_dict': deep_to(nn_module.state_dict(), 'cpu'), 71 | } 72 | torch.save(torch_state, file_path) 73 | state.logger.info(f"Model saved to '{file_path}'") 74 | -------------------------------------------------------------------------------- /src/indexes.py: -------------------------------------------------------------------------------- 1 | class IndexesGenerator: 2 | def __init__(self, size: int, step: int, position: str = "last"): 3 | self.size = size 4 | self.step = step 5 | 6 | if position == "first": 7 | self.behind = 0 8 | self.ahead = self.size - 1 9 | elif position == "middle": 10 | self.behind = self.size // 2 11 | self.ahead = self.size - self.behind - 1 12 | elif position == "last": 13 | self.behind = self.size - 1 14 | self.ahead = 0 15 | else: 16 | raise ValueError( 17 | f"Index position value should be one of {'first', 'middle', 'last'}" 18 | ) 19 | self.behind *= self.step 20 | self.ahead *= self.step 21 | self.width = self.behind + self.ahead + 1 22 | 23 | def make_indexes(self, index: int) -> list[int]: 24 | return list( 25 | range( 26 | index - self.behind, 27 | index + self.ahead + 1, 28 | self.step, 29 | ) 30 | ) 31 | 32 | def clip_index(self, index: int, length: int, save_zone: int = 0) -> int: 33 | behind_frames = self.behind + save_zone 34 | ahead_frames = self.ahead + save_zone 35 | if index < behind_frames: 36 | index = behind_frames 37 | elif index >= length - ahead_frames: 38 | index = length - ahead_frames - 1 39 | return index 40 | -------------------------------------------------------------------------------- /src/inputs.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Type 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | class InputsProcessor(metaclass=abc.ABCMeta): 10 | @abc.abstractmethod 11 | def __call__(self, frames: np.ndarray, behavior: np.ndarray, pupil_center: np.ndarray) -> torch.Tensor: 12 | pass 13 | 14 | 15 | class StackInputsProcessor(InputsProcessor): 16 | def __init__(self, 17 | size: tuple[int, int], 18 | pad_fill_value: int = 0): 19 | self.size = size 20 | self.pad_fill_value = pad_fill_value 21 | 22 | def __call__(self, 
frames: np.ndarray, behavior: np.ndarray, pupil_center: np.ndarray) -> torch.Tensor: 23 | length = frames.shape[-1] 24 | input_array = np.full((5, length, self.size[1], self.size[0]), self.pad_fill_value, dtype=np.float32) 25 | 26 | frames = np.transpose(frames.astype(np.float32), (2, 0, 1)) 27 | height, width = frames.shape[-2:] 28 | height_start = (self.size[1] - height) // 2 29 | width_start = (self.size[0] - width) // 2 30 | input_array[0, :, height_start: height_start + height, width_start: width_start + width] = frames 31 | 32 | input_array[1:3] = behavior[:, :, None, None] 33 | input_array[3:] = pupil_center[:, :, None, None] 34 | 35 | tensor_frames = torch.from_numpy(input_array) 36 | return tensor_frames 37 | 38 | 39 | _INPUTS_PROCESSOR_REGISTRY: dict[str, Type[InputsProcessor]] = dict( 40 | stack_inputs=StackInputsProcessor, 41 | ) 42 | 43 | 44 | def get_inputs_processor(name: str, processor_params: dict) -> InputsProcessor: 45 | assert name in _INPUTS_PROCESSOR_REGISTRY 46 | return _INPUTS_PROCESSOR_REGISTRY[name](**processor_params) 47 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MicePoissonLoss(nn.Module): 6 | def __init__(self, log_input: bool = False, full: bool = False, eps: float = 1e-8): 7 | super().__init__() 8 | self.poisson = nn.PoissonNLLLoss(log_input=log_input, full=full, eps=eps, reduction="none") 9 | 10 | def forward(self, inputs, targets): 11 | target_tensors, mice_weights = targets 12 | mice_weights = mice_weights / mice_weights.sum() 13 | loss_value = 0 14 | for mouse_index, (input_tensor, target_tensor) in enumerate(zip(inputs, target_tensors)): 15 | mouse_weights = mice_weights[..., mouse_index] 16 | mask = mouse_weights != 0.0 17 | if torch.any(mask): 18 | loss = self.poisson(input_tensor[mask], target_tensor[mask]) 19 | loss *= mouse_weights[mask].view(-1, *[1] * (len(loss.shape) - 1)) 20 | loss_value += loss.sum() 21 | return loss_value 22 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Union, Tuple 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from argus.metrics import Metric 9 | 10 | 11 | def corr( 12 | y1: np.ndarray, y2: np.ndarray, axis: Union[None, int, Tuple[int]] = -1, eps: float = 1e-8, **kwargs 13 | ) -> np.ndarray: 14 | """ 15 | Compute the correlation between two NumPy arrays along the specified dimension(s). 16 | 17 | Args: 18 | y1: first NumPy array 19 | y2: second NumPy array 20 | axis: dimension(s) along which the correlation is computed. 
Any valid NumPy 21 | axis spec works here 22 | eps: offset to the standard deviation to avoid exploding the correlation due 23 | to small division (default 1e-8) 24 | **kwargs: passed to final numpy.mean operation over standardized y1 * y2 25 | 26 | Returns: correlation array 27 | """ 28 | 29 | y1 = (y1 - y1.mean(axis=axis, keepdims=True)) / (y1.std(axis=axis, keepdims=True, ddof=0) + eps) 30 | y2 = (y2 - y2.mean(axis=axis, keepdims=True)) / (y2.std(axis=axis, keepdims=True, ddof=0) + eps) 31 | return (y1 * y2).mean(axis=axis, **kwargs) 32 | 33 | 34 | class CorrelationMetric(Metric): 35 | name: str = "corr" 36 | better: str = "max" 37 | 38 | def __init__(self): 39 | super().__init__() 40 | self.predictions = defaultdict(list) 41 | self.targets = defaultdict(list) 42 | self.weights = defaultdict(list) 43 | 44 | def reset(self): 45 | self.predictions = defaultdict(list) 46 | self.targets = defaultdict(list) 47 | self.weights = defaultdict(list) 48 | 49 | def update(self, step_output: dict): 50 | pred_tensors = step_output["prediction"] 51 | target_tensors, mice_weights = step_output["target"] 52 | 53 | for mouse_index, (pred, target) in enumerate(zip(pred_tensors, target_tensors)): 54 | mouse_weight = mice_weights[..., mouse_index] 55 | mask = mouse_weight != 0.0 56 | if torch.any(mask): 57 | pred, target = pred[mask], target[mask] 58 | 59 | if len(target.shape) == 3: 60 | pred = torch.transpose(pred, 1, 2) 61 | pred = pred.reshape(-1, pred.shape[-1]) 62 | target = torch.transpose(target, 1, 2) 63 | target = target.reshape(-1, target.shape[-1]) 64 | 65 | self.predictions[mouse_index].append(pred.cpu().numpy()) 66 | self.targets[mouse_index].append(target.cpu().numpy()) 67 | 68 | def compute(self): 69 | mice_corr = dict() 70 | for mouse_index in self.predictions: 71 | targets = np.concatenate(self.targets[mouse_index], axis=0) 72 | predictions = np.concatenate(self.predictions[mouse_index], axis=0) 73 | mice_corr[mouse_index] = corr(predictions, targets, axis=0).mean() 74 | return mice_corr 75 | 76 | def epoch_complete(self, state): 77 | with torch.no_grad(): 78 | mice_corr = self.compute() 79 | name_prefix = f"{state.phase}_" if state.phase else '' 80 | for mouse_index, mouse_corr in mice_corr.items(): 81 | state.metrics[name_prefix + self.name + f"_mouse_{mouse_index}"] = mouse_corr 82 | state.metrics[name_prefix + self.name] = np.mean(list(mice_corr.values())) 83 | -------------------------------------------------------------------------------- /src/mixers.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | 5 | import numpy as np 6 | 7 | SampleType = tuple[torch.Tensor, torch.Tensor] 8 | 9 | 10 | class Mixer(metaclass=abc.ABCMeta): 11 | def __init__(self, prob: float): 12 | self.prob = prob 13 | 14 | def use(self): 15 | return np.random.random() < self.prob 16 | 17 | @abc.abstractmethod 18 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType: 19 | pass 20 | 21 | 22 | class Mixup(Mixer): 23 | def __init__(self, alpha: float = 0.4, prob: float = 1.0): 24 | super().__init__(prob) 25 | self.alpha = alpha 26 | 27 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType: 28 | inputs1, target1 = sample1 29 | inputs2, target2 = sample2 30 | lam = np.random.beta(self.alpha, self.alpha) 31 | inputs = (1 - lam) * inputs1 + lam * inputs2 32 | target = (1 - lam) * target1 + lam * target2 33 | return inputs, target 34 | 35 | 36 | def rand_bbox(height: int, width: int, lam: float): 37 | 
cut_rat = np.sqrt(lam) 38 | cut_w = (width * cut_rat).astype(int) 39 | cut_h = (height * cut_rat).astype(int) 40 | 41 | cx = np.random.randint(width) 42 | cy = np.random.randint(height) 43 | 44 | bbx1 = np.clip(cx - cut_w // 2, 0, width) 45 | bby1 = np.clip(cy - cut_h // 2, 0, height) 46 | bbx2 = np.clip(cx + cut_w // 2, 0, width) 47 | bby2 = np.clip(cy + cut_h // 2, 0, height) 48 | 49 | return bbx1, bby1, bbx2, bby2 50 | 51 | 52 | class CutMix(Mixer): 53 | def __init__(self, alpha: float = 1.0, prob: float = 1.0): 54 | super().__init__(prob) 55 | self.alpha = alpha 56 | 57 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType: 58 | inputs1, target1 = sample1 59 | inputs2, target2 = sample2 60 | inputs = inputs1.clone().detach() 61 | lam = np.random.beta(self.alpha, self.alpha) 62 | h, w = inputs1.shape[-2:] 63 | bbx1, bby1, bbx2, bby2 = rand_bbox(h, w, lam) 64 | inputs[..., bbx1: bbx2, bby1: bby2] = inputs2[..., bbx1: bbx2, bby1: bby2] 65 | lam = (bbx2 - bbx1) * (bby2 - bby1) / (h * w) 66 | target = (1 - lam) * target1 + lam * target2 67 | return inputs, target 68 | 69 | 70 | class RandomChoiceMixer(Mixer): 71 | def __init__(self, mixers: list[Mixer], choice_probs: list[float], prob: float = 1.0): 72 | super().__init__(prob) 73 | self.mixers = mixers 74 | self.choice_probs = choice_probs 75 | 76 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType: 77 | mixer_index = np.random.choice(range(len(self.mixers)), p=self.choice_probs) 78 | mixer = self.mixers[mixer_index] 79 | return mixer(sample1, sample2) 80 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/dwiseneuro.py: -------------------------------------------------------------------------------- 1 | import math 2 | import functools 3 | from typing import Callable 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class BatchNormAct(nn.Module): 10 | def __init__(self, 11 | num_features: int, 12 | bn_layer: Callable = nn.BatchNorm3d, 13 | act_layer: Callable = nn.ReLU, 14 | apply_act: bool = True): 15 | super().__init__() 16 | self.bn = bn_layer(num_features) 17 | self.act = act_layer() if apply_act else nn.Identity() 18 | 19 | def forward(self, x): 20 | x = self.bn(x) 21 | x = self.act(x) 22 | return x 23 | 24 | 25 | class SqueezeExcite3d(nn.Module): 26 | def __init__(self, 27 | in_features: int, 28 | reduce_ratio: int = 16, 29 | act_layer: Callable = nn.ReLU, 30 | gate_layer: Callable = nn.Sigmoid): 31 | super().__init__() 32 | rd_channels = in_features // reduce_ratio 33 | self.conv_reduce = nn.Conv3d(in_features, rd_channels, (1, 1, 1), bias=True) 34 | self.act1 = act_layer() 35 | self.conv_expand = nn.Conv3d(rd_channels, in_features, (1, 1, 1), bias=True) 36 | self.gate = gate_layer() 37 | 38 | def forward(self, x): 39 | x_se = x.mean((2, 3, 4), keepdim=True) 40 | x_se = self.conv_reduce(x_se) 41 | x_se = self.act1(x_se) 42 | x_se = self.conv_expand(x_se) 43 | return x * self.gate(x_se) 44 | 45 | 46 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 47 | if drop_prob == 0. 
or not training: 48 | return x 49 | keep_prob = 1 - drop_prob 50 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 51 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 52 | if keep_prob > 0. and scale_by_keep: 53 | random_tensor.div_(keep_prob) 54 | return x * random_tensor 55 | 56 | 57 | class DropPath(nn.Module): 58 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 59 | super(DropPath, self).__init__() 60 | self.drop_prob = drop_prob 61 | self.scale_by_keep = scale_by_keep 62 | 63 | def forward(self, x): 64 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 65 | 66 | def extra_repr(self): 67 | return f"drop_prob={round(self.drop_prob,3):0.3f}" 68 | 69 | 70 | class InvertedResidual3d(nn.Module): 71 | def __init__(self, 72 | in_features: int, 73 | out_features: int, 74 | spatial_kernel: int = 3, 75 | temporal_kernel: int = 3, 76 | spatial_stride: int = 1, 77 | expansion_ratio: int = 3, 78 | se_reduce_ratio: int = 16, 79 | act_layer: Callable = nn.ReLU, 80 | bn_layer: Callable = nn.BatchNorm3d, 81 | drop_path_rate: float = 0., 82 | bias: bool = False): 83 | super().__init__() 84 | self.spatial_stride = spatial_stride 85 | self.out_features = out_features 86 | mid_features = in_features * expansion_ratio 87 | stride = (1, spatial_stride, spatial_stride) 88 | 89 | # Point-wise expansion 90 | self.conv_pw = nn.Sequential( 91 | nn.Conv3d(in_features, mid_features, (1, 1, 1), bias=bias), 92 | BatchNormAct(mid_features, bn_layer=bn_layer, act_layer=act_layer), 93 | ) 94 | 95 | # Spatial depth-wise convolution 96 | spatial_padding = spatial_kernel // 2 97 | self.spat_covn_dw = nn.Sequential( 98 | nn.Conv3d(mid_features, mid_features, (1, spatial_kernel, spatial_kernel), 99 | stride=stride, padding=(0, spatial_padding, spatial_padding), 100 | groups=mid_features, bias=bias), 101 | BatchNormAct(mid_features, bn_layer=bn_layer, act_layer=act_layer), 102 | ) 103 | 104 | # Temporal depth-wise convolution 105 | temporal_padding = temporal_kernel // 2 106 | self.temp_covn_dw = nn.Sequential( 107 | nn.Conv3d(mid_features, mid_features, (temporal_kernel, 1, 1), 108 | stride=(1, 1, 1), padding=(temporal_padding, 0, 0), 109 | groups=mid_features, bias=bias), 110 | BatchNormAct(mid_features, bn_layer=bn_layer, act_layer=act_layer), 111 | ) 112 | 113 | # Squeeze-and-excitation 114 | self.se = SqueezeExcite3d(mid_features, act_layer=act_layer, reduce_ratio=se_reduce_ratio) 115 | 116 | # Point-wise linear projection 117 | self.conv_pwl = nn.Sequential( 118 | nn.Conv3d(mid_features, out_features, (1, 1, 1), bias=bias), 119 | BatchNormAct(out_features, bn_layer=bn_layer, apply_act=False), 120 | ) 121 | 122 | self.drop_path = DropPath(drop_prob=drop_path_rate) 123 | self.bn_sc = BatchNormAct(out_features, bn_layer=bn_layer, apply_act=False) 124 | 125 | def interpolate_shortcut(self, shortcut): 126 | _, c, t, h, w = shortcut.shape 127 | if self.spatial_stride > 1: 128 | size = (t, math.ceil(h / self.spatial_stride), math.ceil(w / self.spatial_stride)) 129 | shortcut = nn.functional.interpolate(shortcut, size=size, mode="nearest") 130 | if c != self.out_features: 131 | tile_dims = (1, math.ceil(self.out_features / c), 1, 1, 1) 132 | shortcut = torch.tile(shortcut, tile_dims)[:, :self.out_features] 133 | shortcut = self.bn_sc(shortcut) 134 | return shortcut 135 | 136 | def forward(self, x): 137 | shortcut = x 138 | x = self.conv_pw(x) 139 | x = self.spat_covn_dw(x) 140 | x = self.temp_covn_dw(x) 141 | x = self.se(x) 142 | x = self.conv_pwl(x) 143 | x = 
self.drop_path(x) + self.interpolate_shortcut(shortcut) 144 | return x 145 | 146 | 147 | class PositionalEncoding3d(nn.Module): 148 | def __init__(self, channels: int): 149 | super().__init__() 150 | self.orig_channels = channels 151 | channels = math.ceil(channels / 6) * 2 152 | if channels % 2: 153 | channels += 1 154 | self.channels = channels 155 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 156 | self.register_buffer("inv_freq", inv_freq) 157 | self.register_buffer("cached_encoding", None, persistent=False) 158 | 159 | def get_emb(self, sin_inp): 160 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=0) 161 | return torch.flatten(emb, 0, 1) 162 | 163 | def create_cached_encoding(self, tensor): 164 | _, orig_ch, x, y, z = tensor.shape 165 | assert orig_ch == self.orig_channels 166 | self.cached_encoding = None 167 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 168 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 169 | pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type()) 170 | sin_inp_x = torch.einsum("i,j->ij", self.inv_freq, pos_x) 171 | sin_inp_y = torch.einsum("i,j->ij", self.inv_freq, pos_y) 172 | sin_inp_z = torch.einsum("i,j->ij", self.inv_freq, pos_z) 173 | emb_x = self.get_emb(sin_inp_x).unsqueeze(-1).unsqueeze(-1) 174 | emb_y = self.get_emb(sin_inp_y).unsqueeze(1).unsqueeze(-1) 175 | emb_z = self.get_emb(sin_inp_z).unsqueeze(1).unsqueeze(1) 176 | emb = torch.zeros((self.channels * 3, x, y, z), dtype=tensor.dtype, device=tensor.device) 177 | emb[:self.channels] = emb_x 178 | emb[self.channels: 2 * self.channels] = emb_y 179 | emb[2 * self.channels:] = emb_z 180 | emb = emb[None, :self.orig_channels].contiguous() 181 | self.cached_encoding = emb 182 | return emb 183 | 184 | def forward(self, x): 185 | if len(x.shape) != 5: 186 | raise RuntimeError("The input tensor has to be 5D") 187 | 188 | cached_encoding = self.cached_encoding 189 | if cached_encoding is None or cached_encoding.shape[1:] != x.shape[1:]: 190 | cached_encoding = self.create_cached_encoding(x) 191 | 192 | return x + cached_encoding.expand_as(x) 193 | 194 | 195 | class ShuffleLayer(nn.Module): 196 | def __init__(self, 197 | in_features: int, 198 | out_features: int, 199 | groups: int = 1, 200 | act_layer: Callable = nn.ReLU, 201 | bn_layer: Callable = nn.BatchNorm1d, 202 | drop_path_rate: float = 0.): 203 | super().__init__() 204 | self.in_features = in_features 205 | self.out_features = out_features 206 | self.groups = groups 207 | self.conv = nn.Conv1d(in_features, out_features, (1,), groups=groups, bias=False) 208 | self.bn = BatchNormAct(out_features, bn_layer=bn_layer, act_layer=act_layer) 209 | self.drop_path = DropPath(drop_prob=drop_path_rate) 210 | self.bn_sc = BatchNormAct(out_features, bn_layer=bn_layer, apply_act=False) 211 | 212 | def shuffle_channels(self, x): 213 | if self.groups > 1: 214 | # Shuffle channels between groups 215 | b, c, t = x.shape 216 | x = x.view(b, self.groups, -1, t) 217 | x = torch.transpose(x, 1, 2) 218 | x = x.reshape(b, -1, t) 219 | return x 220 | 221 | def tile_shortcut(self, shortcut): 222 | if self.in_features != self.out_features: 223 | tile_dims = (1, math.ceil(self.out_features / self.in_features), 1) 224 | shortcut = torch.tile(shortcut, tile_dims)[:, :self.out_features] 225 | shortcut = self.bn_sc(shortcut) 226 | return shortcut 227 | 228 | def forward(self, x): 229 | shortcut = x 230 | x = self.conv(x) 231 | x = self.bn(x) 232 | x = 
self.shuffle_channels(x) 233 | x = self.drop_path(x) + self.tile_shortcut(shortcut) 234 | return x 235 | 236 | 237 | class Cortex(nn.Module): 238 | def __init__(self, 239 | in_features: int, 240 | features: tuple[int, ...], 241 | groups: int = 1, 242 | act_layer: Callable = nn.ReLU, 243 | bn_layer: Callable = nn.BatchNorm1d, 244 | drop_path_rate: float = 0.): 245 | super().__init__() 246 | self.layers = nn.Sequential() 247 | prev_num_features = in_features 248 | for num_features in features: 249 | self.layers.append( 250 | ShuffleLayer( 251 | in_features=prev_num_features, 252 | out_features=num_features, 253 | groups=groups, 254 | act_layer=act_layer, 255 | bn_layer=bn_layer, 256 | drop_path_rate=drop_path_rate, 257 | ) 258 | ) 259 | prev_num_features = num_features 260 | 261 | def forward(self, x): 262 | x = self.layers(x) 263 | return x 264 | 265 | 266 | class Readout(nn.Module): 267 | def __init__(self, 268 | in_features: int, 269 | out_features: int, 270 | groups: int = 1, 271 | softplus_beta: float = 1.0, 272 | drop_rate: float = 0.): 273 | super().__init__() 274 | self.out_features = out_features 275 | self.layer = nn.Sequential( 276 | nn.Dropout1d(p=drop_rate), 277 | nn.Conv1d(in_features, 278 | math.ceil(out_features / groups) * groups, (1,), 279 | groups=groups, bias=True), 280 | ) 281 | self.gate = nn.Softplus(beta=softplus_beta) # type: ignore 282 | 283 | def forward(self, x): 284 | x = self.layer(x) 285 | x = x[:, :self.out_features] 286 | x = self.gate(x) 287 | return x 288 | 289 | 290 | class DepthwiseCore(nn.Module): 291 | def __init__(self, 292 | in_channels: int = 1, 293 | features: tuple[int, ...] = (64, 128, 256, 512), 294 | spatial_strides: tuple[int, ...] = (2, 2, 2, 2), 295 | spatial_kernel: int = 3, 296 | temporal_kernel: int = 3, 297 | expansion_ratio: int = 3, 298 | se_reduce_ratio: int = 16, 299 | act_layer: Callable = nn.ReLU, 300 | bn_layer: Callable = nn.BatchNorm3d, 301 | drop_path_rate: float = 0.): 302 | super().__init__() 303 | num_blocks = len(features) 304 | assert num_blocks and num_blocks == len(spatial_strides) 305 | next_num_features = features[0] 306 | self.stem = nn.Sequential( 307 | nn.Conv3d(in_channels, next_num_features, (1, 1, 1), bias=False), 308 | BatchNormAct(next_num_features, bn_layer=bn_layer, apply_act=False), 309 | ) 310 | 311 | blocks = [] 312 | for block_index in range(num_blocks): 313 | num_features = features[block_index] 314 | spatial_stride = spatial_strides[block_index] 315 | if block_index < num_blocks - 1: 316 | next_num_features = features[block_index + 1] 317 | block_drop_path_rate = drop_path_rate * block_index / num_blocks 318 | 319 | blocks += [ 320 | PositionalEncoding3d(num_features), 321 | InvertedResidual3d( 322 | num_features, 323 | next_num_features, 324 | spatial_kernel=spatial_kernel, 325 | temporal_kernel=temporal_kernel, 326 | spatial_stride=spatial_stride, 327 | expansion_ratio=expansion_ratio, 328 | se_reduce_ratio=se_reduce_ratio, 329 | act_layer=act_layer, 330 | bn_layer=bn_layer, 331 | drop_path_rate=block_drop_path_rate, 332 | bias=False, 333 | ) 334 | ] 335 | self.blocks = nn.Sequential(*blocks) 336 | 337 | def forward(self, x): 338 | x = self.stem(x) 339 | x = self.blocks(x) 340 | return x 341 | 342 | 343 | class DwiseNeuro(nn.Module): 344 | def __init__(self, 345 | readout_outputs: tuple[int, ...], 346 | in_channels: int = 5, 347 | core_features: tuple[int, ...] = (64, 64, 64, 64, 128, 128, 128, 256, 256), 348 | spatial_strides: tuple[int, ...] 
= (2, 1, 1, 1, 2, 1, 1, 2, 1), 349 | spatial_kernel: int = 3, 350 | temporal_kernel: int = 5, 351 | expansion_ratio: int = 6, 352 | se_reduce_ratio: int = 32, 353 | cortex_features: tuple[int, ...] = (1024, 2048, 4096), 354 | groups: int = 2, 355 | softplus_beta: float = 0.07, 356 | drop_rate: float = 0.4, 357 | drop_path_rate: float = 0.1): 358 | super().__init__() 359 | act_layer = functools.partial(nn.SiLU, inplace=True) 360 | 361 | self.core = DepthwiseCore( 362 | in_channels=in_channels, 363 | features=core_features, 364 | spatial_strides=spatial_strides, 365 | spatial_kernel=spatial_kernel, 366 | temporal_kernel=temporal_kernel, 367 | expansion_ratio=expansion_ratio, 368 | se_reduce_ratio=se_reduce_ratio, 369 | act_layer=act_layer, 370 | bn_layer=nn.BatchNorm3d, 371 | drop_path_rate=drop_path_rate, 372 | ) 373 | 374 | self.pool = nn.AdaptiveAvgPool3d((None, 1, 1)) 375 | 376 | self.cortex = Cortex( 377 | in_features=core_features[-1], 378 | features=cortex_features, 379 | groups=groups, 380 | act_layer=act_layer, 381 | bn_layer=nn.BatchNorm1d, 382 | drop_path_rate=drop_path_rate, 383 | ) 384 | 385 | self.readouts = nn.ModuleList() 386 | for readout_output in readout_outputs: 387 | self.readouts.append( 388 | Readout( 389 | in_features=cortex_features[-1], 390 | out_features=readout_output, 391 | groups=groups, 392 | softplus_beta=softplus_beta, 393 | drop_rate=drop_rate, 394 | ) 395 | ) 396 | 397 | def forward(self, x: torch.Tensor, index: int | None = None) -> list[torch.Tensor] | torch.Tensor: 398 | # Input shape: (batch, channel, time, height, width), e.g. (32, 5, 16, 64, 64) 399 | x = self.core(x) # (32, 256, 16, 8, 8) 400 | x = self.pool(x).squeeze(-1).squeeze(-1) # (32, 256, 16) 401 | x = self.cortex(x) # (32, 4096, 16) 402 | if index is None: 403 | return [readout(x) for readout in self.readouts] 404 | else: 405 | return self.readouts[index](x) # (32, neurons, 16) 406 | -------------------------------------------------------------------------------- /src/phash.py: -------------------------------------------------------------------------------- 1 | import imagehash 2 | import numpy as np 3 | from PIL import Image 4 | 5 | from src.utils import get_length_without_nan 6 | 7 | 8 | def binary_array_to_int(arr: np.ndarray) -> int: 9 | bit_string = ''.join(str(b) for b in 1 * arr.flatten()) 10 | return int(bit_string, 2) 11 | 12 | 13 | def calculate_frame_phash(frame: np.ndarray) -> int: 14 | frame = Image.fromarray(frame.astype(np.uint8), 'L') 15 | phash = imagehash.phash(frame).hash 16 | return binary_array_to_int(phash.ravel()) 17 | 18 | 19 | def calculate_video_phash(video: np.ndarray, num_hash_frames: int = 5) -> int: 20 | length = get_length_without_nan(video[0, 0]) 21 | assert length >= num_hash_frames 22 | step = length // num_hash_frames 23 | video_hash: int = 0 24 | for frame_index in range(step // 2, length, step)[:num_hash_frames]: 25 | video_hash ^= calculate_frame_phash(video[..., frame_index]) 26 | return video_hash 27 | -------------------------------------------------------------------------------- /src/predictors.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import argus 7 | 8 | from src.indexes import IndexesGenerator 9 | from src.inputs import get_inputs_processor 10 | from src.argus_models import MouseModel 11 | from src import constants 12 | 13 | 14 | def get_blend_weights(name: str, size: int): 15 | if name == "ones": 16 | return np.ones(size, 
dtype=np.float32) 17 | elif name == "linear": 18 | return np.linspace(0, 1, num=size) 19 | else: 20 | raise ValueError(f"Blend weights '{name}' is not supported") 21 | 22 | 23 | class Predictor: 24 | def __init__(self, model_path: Path | str, device: str = "cuda:0", blend_weights="ones"): 25 | self.model: MouseModel = argus.load_model(model_path, device=device, optimizer=None, loss=None) 26 | self.model.eval() 27 | self.inputs_processor = get_inputs_processor(*self.model.params["inputs_processor"]) 28 | self.frame_stack_size = self.model.params["frame_stack"]["size"] 29 | self.frame_stack_step = self.model.params["frame_stack"]["step"] 30 | assert self.model.params["frame_stack"]["position"] == "last" 31 | assert self.model.params["responses_processor"][0] == "identity" 32 | self.indexes_generator = IndexesGenerator(self.frame_stack_size, 33 | self.frame_stack_step) 34 | self.blend_weights = get_blend_weights(blend_weights, self.frame_stack_size) 35 | 36 | @torch.no_grad() 37 | def predict_trial(self, 38 | video: np.ndarray, 39 | behavior: np.ndarray, 40 | pupil_center: np.ndarray, 41 | mouse_index: int) -> np.ndarray: 42 | inputs = self.inputs_processor(video, behavior, pupil_center).to(self.model.device) 43 | length = video.shape[-1] 44 | responses = np.zeros((constants.num_neurons[mouse_index], length), dtype=np.float32) 45 | blend_weights = np.zeros(length, np.float32) 46 | for index in range( 47 | self.indexes_generator.behind, 48 | length - self.indexes_generator.ahead 49 | ): 50 | indexes = self.indexes_generator.make_indexes(index) 51 | prediction = self.model.predict(inputs[:, indexes].unsqueeze(0), mouse_index)[0] 52 | responses[..., indexes] += prediction.cpu().numpy() 53 | blend_weights[indexes] += self.blend_weights 54 | responses /= np.clip(blend_weights, 1.0, None) 55 | return responses 56 | -------------------------------------------------------------------------------- /src/responses.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Type 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from src import constants 9 | 10 | 11 | class ResponseNormalizer: 12 | def __init__(self, mouse: str): 13 | std = np.load( 14 | str(constants.sensorium_dir / mouse / "meta" / "statistics" / "responses" / "all" / "std.npy") 15 | ) 16 | threshold = 0.01 * np.nanmean(std) 17 | idx = std > threshold 18 | self._response_precision = np.ones_like(std) / threshold 19 | self._response_precision[idx] = 1 / std[idx] 20 | 21 | def __call__(self, responses): 22 | return responses * self._response_precision[..., :responses.shape[-1]] 23 | 24 | 25 | def responses_to_tensor(responses: np.ndarray) -> torch.Tensor: 26 | responses = responses.astype(np.float32) 27 | responses_tensor = torch.from_numpy(responses) 28 | responses_tensor = torch.relu(responses_tensor) 29 | return responses_tensor 30 | 31 | 32 | class ResponsesProcessor(metaclass=abc.ABCMeta): 33 | @abc.abstractmethod 34 | def __call__(self, responses: np.ndarray) -> torch.Tensor: 35 | pass 36 | 37 | 38 | class IdentityResponsesProcessor(ResponsesProcessor): 39 | def __call__(self, responses: np.ndarray) -> torch.Tensor: 40 | return responses_to_tensor(responses) 41 | 42 | 43 | class IndexingResponsesProcessor(ResponsesProcessor): 44 | def __init__(self, index: int | list[int]): 45 | self.index = index 46 | 47 | def __call__(self, responses: np.ndarray) -> torch.Tensor: 48 | responses = responses[..., self.index] 49 | responses_tensor = responses_to_tensor(responses) 50 
| return responses_tensor 51 | 52 | 53 | class SelectLastResponsesProcessor(IndexingResponsesProcessor): 54 | def __init__(self): 55 | super().__init__(index=-1) 56 | 57 | 58 | _RESPONSES_PROCESSOR_REGISTRY: dict[str, Type[ResponsesProcessor]] = dict( 59 | identity=IdentityResponsesProcessor, 60 | indexing=IndexingResponsesProcessor, 61 | last=SelectLastResponsesProcessor, 62 | ) 63 | 64 | 65 | def get_responses_processor(name: str, processor_params: dict) -> ResponsesProcessor: 66 | assert name in _RESPONSES_PROCESSOR_REGISTRY 67 | return _RESPONSES_PROCESSOR_REGISTRY[name](**processor_params) 68 | -------------------------------------------------------------------------------- /src/submission.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from src.responses import ResponseNormalizer 7 | from src.data import get_mouse_data 8 | from src.metrics import corr 9 | from src import constants 10 | 11 | 12 | def cut_responses_for_submission(prediction: np.ndarray): 13 | prediction = prediction[..., :constants.submission_limit_length] 14 | prediction = prediction[..., constants.submission_skip_first:] 15 | if constants.submission_skip_last: 16 | prediction = prediction[..., :-constants.submission_skip_last] 17 | return prediction 18 | 19 | 20 | def evaluate_folds_predictions(experiment: str, dataset: str): 21 | prediction_dir = constants.predictions_dir / experiment / "out-of-fold" 22 | correlations = dict() 23 | for mouse in constants.dataset2mice[dataset]: 24 | mouse_data = get_mouse_data(mouse=mouse, splits=constants.folds_splits) 25 | mouse_prediction_dir = prediction_dir / mouse 26 | predictions = [] 27 | targets = [] 28 | for trial_data in mouse_data["trials"]: 29 | trial_id = trial_data['trial_id'] 30 | prediction = np.load(str(mouse_prediction_dir / f"{trial_id}.npy")) 31 | target = np.load(trial_data["response_path"])[..., :trial_data["length"]] 32 | prediction = cut_responses_for_submission(prediction) 33 | target = cut_responses_for_submission(target) 34 | predictions.append(prediction) 35 | targets.append(target) 36 | correlation = float(corr( 37 | np.concatenate(predictions, axis=1), 38 | np.concatenate(targets, axis=1), 39 | axis=1 40 | ).mean()) 41 | print(f"Mouse {mouse} correlation: {correlation}") 42 | correlations[mouse] = correlation 43 | mean_correlation = float(np.mean(list(correlations.values()))) 44 | print("Mean correlation:", mean_correlation) 45 | 46 | evaluate_result = {"correlations": correlations, "mean_correlation": mean_correlation} 47 | with open(prediction_dir / f"evaluate_{dataset}.json", "w") as outfile: 48 | json.dump(evaluate_result, outfile, indent=4) 49 | 50 | 51 | def make_submission(experiment: str, split: str): 52 | prediction_dir = constants.predictions_dir / experiment / split 53 | data = [] 54 | for mouse in constants.new_mice: 55 | normalizer = ResponseNormalizer(mouse) 56 | mouse_data = get_mouse_data(mouse=mouse, splits=[split]) 57 | neuron_ids = mouse_data["neuron_ids"].tolist() 58 | mouse_prediction_dir = prediction_dir / mouse 59 | for trial_data in mouse_data["trials"]: 60 | trial_id = trial_data['trial_id'] 61 | prediction = np.load(str(mouse_prediction_dir / f"{trial_id}.npy")) 62 | prediction = normalizer(prediction) 63 | prediction = cut_responses_for_submission(prediction) 64 | data.append((mouse, trial_id, prediction.tolist(), neuron_ids)) 65 | submission_df = pd.DataFrame.from_records( 66 | data, 67 | columns=['mouse', 
'trial_indices', 'prediction', 'neuron_ids'] 68 | ) 69 | del data 70 | split = split.replace('_test_', '_').replace('bonus', 'test_bonus_ood') 71 | submission_path = prediction_dir / f"predictions_{split}.parquet.brotli" 72 | submission_df.to_parquet(submission_path, compression='brotli', engine='pyarrow', index=False) 73 | print(f"Submission saved to '{submission_path}'") 74 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import time 4 | import random 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | 9 | from torch import nn 10 | 11 | 12 | def set_random_seed(index: int): 13 | seed = int(time.time() * 1000.0) + index 14 | random.seed(seed) 15 | np.random.seed(seed % (2 ** 32 - 1)) 16 | 17 | 18 | def get_lr(base_lr: float, batch_size: int, base_batch_size: int = 4) -> float: 19 | return base_lr * (batch_size / base_batch_size) 20 | 21 | 22 | def get_best_model_path(dir_path, return_score=False, more_better=True): 23 | dir_path = Path(dir_path) 24 | model_scores = [] 25 | for model_path in dir_path.glob('*.pth'): 26 | score = re.search(r'-(\d+(?:\.\d+)?).pth', str(model_path)) 27 | if score is not None: 28 | score = float(score.group(0)[1:-4]) 29 | model_scores.append((model_path, score)) 30 | 31 | if not model_scores: 32 | if return_score: 33 | return None, -np.inf if more_better else np.inf 34 | else: 35 | return None 36 | 37 | model_score = sorted(model_scores, key=lambda x: x[1], reverse=more_better) 38 | best_model_path = model_score[0][0] 39 | if return_score: 40 | best_score = model_score[0][1] 41 | return best_model_path, best_score 42 | else: 43 | return best_model_path 44 | 45 | 46 | def init_weights(module: nn.Module): 47 | for m in module.modules(): 48 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 49 | fan_out = math.prod(m.kernel_size) * m.out_channels 50 | fan_out //= m.groups 51 | nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out)) 52 | if m.bias is not None: 53 | nn.init.zeros_(m.bias) 54 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 55 | nn.init.ones_(m.weight) 56 | nn.init.zeros_(m.bias) 57 | elif isinstance(m, nn.Linear): 58 | fan_out = m.weight.size(0) 59 | fan_in = 0 60 | init_range = 1.0 / math.sqrt(fan_in + fan_out) 61 | nn.init.uniform_(m.weight, -init_range, init_range) 62 | if m.bias is not None: 63 | nn.init.zeros_(m.bias) 64 | 65 | 66 | def get_length_without_nan(array: np.ndarray): 67 | nan_indexes = np.argwhere(np.isnan(array)).ravel() 68 | if nan_indexes.shape[0]: 69 | return nan_indexes[0] 70 | else: 71 | return array.shape[0] 72 | --------------------------------------------------------------------------------
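
A minimal sketch (illustrative only, not a file from the repository) of how tensors flow through `DwiseNeuro` in `src/models/dwiseneuro.py`: core → pooling → cortex → per-mouse readouts, matching the shape comments in its `forward` method. The sizes below — two readouts of 10 and 12 neurons, a 32×32 clip, reduced feature widths — are hypothetical and chosen only to keep the example light; real runs use the neuron counts from `src/constants.py` and the default architecture parameters.

```python
# Editorial sketch with hypothetical, deliberately small sizes.
import torch

from src.models.dwiseneuro import DwiseNeuro

# Two hypothetical "mice" with 10 and 12 neurons each.
model = DwiseNeuro(
    readout_outputs=(10, 12),
    core_features=(16, 16, 32),
    spatial_strides=(2, 1, 2),
    cortex_features=(64, 128),
    groups=2,
).eval()

# (batch, channels, time, height, width);
# 5 channels = grayscale video + 2 behavior + 2 pupil-center,
# as assembled by StackInputsProcessor in src/inputs.py.
video_clip = torch.randn(1, 5, 8, 32, 32)

with torch.no_grad():
    all_mice = model(video_clip)             # list: one tensor per readout
    one_mouse = model(video_clip, index=0)   # only the first readout

print([tuple(t.shape) for t in all_mice])    # [(1, 10, 8), (1, 12, 8)]
print(tuple(one_mouse.shape))                # (1, 10, 8)
```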
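
A minimal sketch, assuming a toy linear model, of how `ModelEma` from `src/ema.py` is meant to be used: keep a deep-copied, exponentially smoothed set of weights next to the trained model, update it after every optimizer step (as `MouseModel.train_step` does via `self.model_ema.update(self.nn_module)`), and run validation or prediction through the smoothed copy.

```python
# Editorial sketch; the nn.Linear model and sizes are hypothetical.
import torch
from torch import nn

from src.ema import ModelEma

model = nn.Linear(4, 2)
ema = ModelEma(model, decay=0.99)   # deep copy that tracks the moving average
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)               # ema_w = decay * ema_w + (1 - decay) * w

# Validation/prediction uses the smoothed copy, as in val_step/predict.
with torch.no_grad():
    smoothed_out = ema.ema(torch.randn(1, 4))
```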
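
A minimal sketch of the `corr` helper from `src/metrics.py`: both arrays are standardized along the chosen axis and the mean of their product is taken, giving a per-neuron Pearson correlation when `axis=0`, which is how `CorrelationMetric.compute` and `evaluate_folds_predictions` use it. The array shapes here are hypothetical.

```python
# Editorial sketch with hypothetical (time, neurons) arrays.
import numpy as np

from src.metrics import corr

pred = np.random.rand(100, 8)                 # (time, neurons)
true = pred + 0.1 * np.random.rand(100, 8)    # noisy copy of the prediction

per_neuron = corr(pred, true, axis=0)         # shape (8,), one value per neuron
print(per_neuron.mean())
```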