├── .dockerignore
├── .gitignore
├── CITATION.cff
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── configs
│   ├── .gitignore
│   ├── distillation_001.py
│   └── true_batch_001.py
├── data
│   ├── .gitignore
│   ├── experiments.torrent
│   └── readme_images
│       ├── architecture.png
│       ├── preview.png
│       └── softplus.png
├── requirements.txt
├── scripts
│   ├── download_data.py
│   ├── ensemble.py
│   ├── predict.py
│   └── train.py
├── setup.cfg
└── src
    ├── __init__.py
    ├── argus_models.py
    ├── constants.py
    ├── data.py
    ├── datasets.py
    ├── ema.py
    ├── indexes.py
    ├── inputs.py
    ├── losses.py
    ├── metrics.py
    ├── mixers.py
    ├── models
    │   ├── __init__.py
    │   └── dwiseneuro.py
    ├── phash.py
    ├── predictors.py
    ├── responses.py
    ├── submission.py
    └── utils.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | data/
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: Solution for Sensorium 2023 Competition
3 | message: 'If you use this work, please cite it using these metadata.'
4 | type: software
5 | authors:
6 | - given-names: Ruslan
7 | family-names: Baikulov
8 | email: ruslan1123@gmail.com
9 | orcid: 'https://orcid.org/0009-0003-4400-0619'
10 | repository-code: 'https://github.com/lRomul/sensorium'
11 | abstract: >-
12 | This repository contains the winning solution for the Sensorium 2023 Competition.
13 | The competition aimed to predict the activity of neurons in the primary visual
14 | cortex of mice in response to videos. The proposed solution includes a novel
15 | model architecture called DwiseNeuro and a training pipeline with a solid
16 | cross-validation strategy and knowledge distillation.
17 | license: MIT
18 | version: v23.11.22
19 | date-released: '2023-11-22'
20 | doi: 10.5281/zenodo.10155151
21 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM osaiai/dokai:23.10-pytorch
2 |
3 | COPY requirements.txt requirements.txt
4 | RUN pip3 install --no-cache-dir -r requirements.txt
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ruslan Baikulov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | NAME?=sensorium
2 | COMMAND?=bash
3 | OPTIONS?=
4 |
5 | GPUS?=all
6 | ifeq ($(GPUS),none)
7 | GPUS_OPTION=
8 | else
9 | GPUS_OPTION=--gpus=$(GPUS)
10 | endif
11 |
12 | .PHONY: all
13 | all: stop build run
14 |
15 | .PHONY: build
16 | build:
17 | docker build -t $(NAME) .
18 |
19 | .PHONY: stop
20 | stop:
21 | -docker stop $(NAME)
22 | -docker rm $(NAME)
23 |
24 | .PHONY: run
25 | run:
26 | docker run --rm -dit \
27 | --net=host \
28 | --ipc=host \
29 | $(OPTIONS) \
30 | $(GPUS_OPTION) \
31 | -v $(shell pwd):/workdir \
32 | --name=$(NAME) \
33 | $(NAME) \
34 | $(COMMAND)
35 | docker attach $(NAME)
36 |
37 | .PHONY: attach
38 | attach:
39 | docker attach $(NAME)
40 |
41 | .PHONY: logs
42 | logs:
43 | docker logs -f $(NAME)
44 |
45 | .PHONY: exec
46 | exec:
47 | docker exec -it $(OPTIONS) $(NAME) $(COMMAND)
48 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Solution for Sensorium 2023 Competition
4 |
5 | [MIT License](LICENSE)
6 | [DOI 10.5281/zenodo.10155151](https://doi.org/10.5281/zenodo.10155151)
7 |
8 |
9 |
10 | This repository contains the winning solution for the Sensorium 2023 Competition, part of the NeurIPS 2023 competition track.
11 | The competition aimed to find the best model that can predict the activity of neurons in the primary visual cortex of mice in response to videos.
12 | The competition introduced a temporal component using dynamic stimuli (videos) instead of static stimuli (images) used in Sensorium 2022, making the task more challenging.
13 |
14 | The primary metric of the competition was single-trial correlation.
15 | You can read about the metric, the data, and the task in the competition paper [^1].
16 | It is important to note that additional data for five mice was introduced during the competition, which doubled the dataset's size ([old](https://gin.g-node.org/pollytur/Sensorium2023Data) and [new](https://gin.g-node.org/pollytur/sensorium_2023_dataset) data).
17 |
18 | ## Solution
19 |
20 | Key points:
21 | * [DwiseNeuro](src/models/dwiseneuro.py) - a novel model architecture for predicting neural activity in the mouse primary visual cortex.
22 | * Solid cross-validation strategy with splitting folds by perceptual video hash.
23 | * Training on all mice with an option to fill unlabeled samples via distillation.
24 |
25 | ## Model Architecture
26 |
27 | During the competition, I dedicated most of my time to designing the model architecture since it significantly impacted the solution's outcome compared to other components.
28 | I iteratively tested various computer vision and deep learning techniques, integrating them into the architecture as the correlation metric improved.
29 |
30 | The diagram below illustrates the final architecture, which I named DwiseNeuro:
31 |
32 | ![DwiseNeuro architecture diagram](data/readme_images/architecture.png)
33 |
34 | DwiseNeuro consists of three main parts: core, cortex, and readouts.
35 | The core consumes sequences of video frames and mouse behavior activity in separate channels, processing temporal and spatial features.
36 | Produced features pass through global pooling over spatial dimensions to aggregate them.
37 | The cortex processes the pooled features independently for each timestep, significantly increasing the channels.
38 | Finally, each readout predicts the activation of neurons for the corresponding mouse.
39 |
40 | In the following sections, we will delve deeper into each part of the architecture.
41 |
42 | ### Core
43 |
44 | The first layer of the module is the stem. It's a point-wise 3D convolution for increasing the number of channels, followed by batch normalization.
45 | The rest of the core consists of inverted residual blocks [^2][^3] with a `narrow -> wide -> narrow` channel structure.
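
As a minimal sketch, assuming the channel counts from the training configs (`in_channels=5`, 64 channels in the first core block), the stem could look like this; the exact layer arrangement is in [src/models/dwiseneuro.py](src/models/dwiseneuro.py):

```Python
from torch import nn

# Point-wise (kernel size 1) 3D convolution that only mixes channels,
# followed by batch normalization; spatial and temporal sizes are unchanged.
stem = nn.Sequential(
    nn.Conv3d(in_channels=5, out_channels=64, kernel_size=1, bias=False),
    nn.BatchNorm3d(64),
)
```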
46 |
47 | #### Techniques
48 |
49 | Several methods were added to the inverted residual block, which was rewritten with 3D layers:
50 | * **Absolute Position Encoding** [^4] - summing the encoding to the input of each block allows convolutions to accumulate position information. It's quite important because of the subsequent spatial pooling after the core.
51 | * **Factorized (2+1)D convolution** [^5] - 3D depth-wise convolution was replaced with a spatial 2D depth-wise convolution followed by a temporal 1D depth-wise convolution. There are spatial convolutions with stride two in some blocks to compress output size.
52 | * **Shortcut Connections** - completely parameter-free residual shortcuts with three operations:
53 |   * Identity mapping if the input and output dimensions are equal. It's the same as the connection proposed in ResNet [^6].
54 |   * Nearest interpolation in case of different spatial sizes.
55 |   * Cyclic repetition of channels if their counts don't match (see the shortcut sketch below).
56 | * **Squeeze-and-Excitation** [^7] - dynamic channel-wise feature recalibration.
57 | * **DropPath (Stochastic Depth)** [^8][^9] - regularization that randomly drops the block's main path for each sample in the batch.
58 |
59 | Batch normalization is applied after each layer, including the shortcut.
60 | SiLU activation is used after expansion and depth-wise convolutions.
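
Below is a minimal sketch of the parameter-free shortcut described above; the tensor layout and the exact interpolation and repetition details are assumptions, and the reference implementation lives in [src/models/dwiseneuro.py](src/models/dwiseneuro.py):

```Python
import torch
import torch.nn.functional as F


def shortcut(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Parameter-free residual shortcut for (batch, channels, t, h, w) tensors."""
    # Nearest interpolation if the spatial/temporal sizes differ
    if x.shape[2:] != out.shape[2:]:
        x = F.interpolate(x, size=out.shape[2:], mode="nearest")
    # Cyclic repetition of channels if their counts differ
    if x.shape[1] != out.shape[1]:
        repeats = -(-out.shape[1] // x.shape[1])  # ceil division
        x = x.repeat(1, repeats, 1, 1, 1)[:, :out.shape[1]]
    # Identity mapping otherwise
    return out + x
```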
61 |
62 | #### Hyperparameters
63 |
64 | I found that the number of core blocks and their parameters dramatically affect the outcome.
65 | It's possible to tune channels, strides, expansion ratio, and spatial/temporal kernel sizes.
66 | Obviously, it is almost impossible to start experiments with optimal values.
67 | The problem is mentioned in the EfficientNet [^3] paper, which concluded that it is essential to carefully balance model width, depth, and resolution.
68 |
69 | After conducting a lot of experiments, I chose the following parameters:
70 | * Four blocks with 64 output channels, three with 128, and two with 256.
71 | * Three blocks have stride two. They are the first in each subgroup from the point above.
72 | * Expansion ratio of the inverted residual block is six.
73 | * Kernel of spatial depth-wise convolution is (1, 3, 3).
74 | * Kernel of temporal depth-wise convolution is (5, 1, 1).
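
For reference, these choices map directly onto the core settings shared by [configs/true_batch_001.py](configs/true_batch_001.py) and [configs/distillation_001.py](configs/distillation_001.py):

```Python
core_features = (64, 64, 64, 64, 128, 128, 128, 256, 256)
spatial_strides = (2, 1, 1, 1, 2, 1, 1, 2, 1)  # stride two opens each subgroup
spatial_kernel = 3   # (1, 3, 3) depth-wise convolution
temporal_kernel = 5  # (5, 1, 1) depth-wise convolution
```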
75 |
76 | ### Cortex
77 |
78 | Compared with related works [^10][^11], the model architecture includes a new part called the cortex.
79 | Like the core, it is shared across all mice.
80 | The cortex receives features with channels and temporal dimensions only.
81 | Spatial information was accumulated through position encoding applied earlier in the core and compressed by average pooling after the core.
82 | The primary purpose of the cortex is to smoothly increase the number of channels, which the readouts will further use.
83 |
84 | The building element of the module is a grouped 1D convolution followed by the channel shuffle operation [^12].
85 | Batch normalization, SiLU activation, and shortcut connections with stochastic depth were applied similarly to the core.
86 |
87 | #### Hyperparameters
88 |
89 | Hyperparameters of the cortex were also important:
90 | * Convolution with two groups and kernel size one (a bigger kernel size over the temporal dimension did not lead to better results).
91 | * Three layers with 1024, 2048, and 4096 channels.
92 |
93 | As you can see, the number of channels is quite large.
94 | Groups help optimize computation and memory efficiency.
95 | The channel shuffle operation allows information to be shared between groups across layers.
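
As an illustration, here is a minimal sketch of a ShuffleNet-style channel shuffle for the `(batch, channels, time)` tensors used in the cortex; this is a hypothetical standalone function, while the actual module lives in [src/models/dwiseneuro.py](src/models/dwiseneuro.py):

```Python
import torch


def channel_shuffle_1d(x: torch.Tensor, groups: int) -> torch.Tensor:
    # Interleave channels so that the next grouped convolution
    # sees channels coming from every group of the previous layer.
    batch, channels, time = x.shape
    x = x.view(batch, groups, channels // groups, time)
    x = x.transpose(1, 2).contiguous()
    return x.view(batch, channels, time)
```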
96 |
97 | ### Readouts
98 |
99 | The readout is a single 1D convolution with two groups and kernel size one, followed by Softplus activation.
100 | Each of the ten mice has its own readout, with the number of output channels equal to the number of neurons (7863, 7908, 8202, 7939, 8122, 7440, 7928, 8285, 7671, 7495, respectively).
101 |
102 | #### Softplus
103 |
104 | Keeping the response positive by using Softplus [^10] was essential in my pipeline.
105 | It works much better than `ELU + 1` [^11], especially with tuning the Softplus beta parameter.
106 | In my case, the optimal beta value was about 0.07, which resulted in a 0.018 increase in the correlation metric.
107 |
108 | You can see a comparison of `ELU + 1` and Softplus in the plot below:
109 |
110 | ![Comparison of ELU + 1 and Softplus](data/readme_images/softplus.png)
111 |
112 | #### Learnable Softplus
113 |
114 | I also conducted an experiment where the beta parameter was trainable.
115 | Interestingly, the trained value converged approximately to the optimal value I had found by grid search.
116 | I omitted the learnable Softplus from the solution because it resulted in a slightly worse score.
117 | But this may be an excellent way to quickly and automatically find a good beta.
118 |
119 | Here's a numerically stable implementation of learnable Softplus in PyTorch:
120 |
121 | ```Python
122 | import torch
123 | from torch import nn
124 |
125 |
126 | class LearnableSoftplus(nn.Module):
127 |     def __init__(self, beta: float):
128 |         super().__init__()
129 |         self.beta = nn.Parameter(torch.tensor(float(beta)))
130 |
131 |     def forward(self, x):
132 |         xb = x * self.beta
133 |         return (torch.clamp(xb, 0) + torch.minimum(xb, -xb).exp().log1p()) / self.beta
134 | ```
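
A quick sanity check of the class above against the built-in Softplus with a fixed beta (a usage sketch, not part of the repository):

```Python
import torch
import torch.nn.functional as F

softplus = LearnableSoftplus(beta=0.07)
x = torch.randn(1000)
# With beta fixed, the learnable version should match F.softplus closely.
assert torch.allclose(softplus(x), F.softplus(x, beta=0.07), atol=1e-5)
```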
135 |
136 | ## Training
137 |
138 | ### Validation strategy
139 |
140 | At the end of the competition, I employed 7-fold cross-validation to check hypotheses and tune hyperparameters more precisely.
141 | I used all available labeled data to make folds.
142 | Random splitting gave an overly optimistic metric estimation because some videos were duplicated (e.g., in the original validation split or between old and new datasets).
143 | To solve this issue, I created group folds with non-overlapping videos.
144 | Similar videos were found using [perceptual hashes](src/phash.py) of several frames fetched deterministically.
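
Here is a rough sketch of how fold assignment by hash can work; the frame selection and hashing details are assumptions, while the actual logic lives in [src/phash.py](src/phash.py) and [src/data.py](src/data.py), where the fold is computed as `hash % num_folds`:

```Python
import imagehash
import numpy as np
from PIL import Image


def video_fold(video: np.ndarray, num_folds: int = 7) -> int:
    # Hash a few deterministically chosen frames so duplicated videos
    # always land in the same fold, even across datasets.
    frame_indexes = np.linspace(0, video.shape[-1] - 1, num=4, dtype=int)
    hash_value = 0
    for index in frame_indexes:
        frame = Image.fromarray(video[..., index].astype(np.uint8))
        hash_value ^= int(str(imagehash.phash(frame)), 16)
    return hash_value % num_folds
```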
145 |
146 | ### Basic training ([config](configs/true_batch_001.py))
147 |
148 | The training was performed in two stages. The first stage is basic training with the following pipeline parameters:
149 | * Learning rate warmup from 0 to 2.4e-03 over the first three epochs, then cosine annealing to 2.4e-05 over the last 18 epochs
150 | * Batch size 32, one training epoch comprises 72000 samples
151 | * Optimizer AdamW [^13] with weight decay 0.05
152 | * Poisson loss
153 | * Model EMA with decay 0.999
154 | * CutMix [^14] with alpha 1.0 and usage probability 0.5
155 | * Mice are sampled into the batch uniformly at random
156 |
157 | Each dataset sample consists of a grayscale video, behavior activity measurements (pupil center, pupil dilation, and running speed), and the neuron responses of a single mouse.
158 | All data is presented at 30 FPS. During training, the model consumes 16 frames, skipping every second frame (equivalent to 16 neighboring frames at 15 FPS).
159 | The video frames were zero-padded to 64x64 pixels. The behavior activities were added as separate channels,
160 | each filled entirely with the value of the corresponding behavior measurement.
161 | No normalization is applied to the target and input tensors during training.
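
As an illustration of the frame sampling, the sketch below builds a 16-frame window with step two, mirroring the `frame_stack` settings (`size=16, step=2, position="last"`) from the configs; the function is hypothetical, and the real index generation lives in [src/indexes.py](src/indexes.py):

```Python
def make_frame_indexes(last_index: int, size: int = 16, step: int = 2) -> list[int]:
    # Window of `size` frames ending at `last_index`, taking every `step`-th frame.
    first_index = last_index - step * (size - 1)
    return list(range(first_index, last_index + 1, step))


print(make_frame_indexes(62))  # [32, 34, ..., 62] - 16 indexes, i.e. 15 FPS
```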
162 |
163 | The ensemble of models from all folds gets 0.2905 single-trial correlation on the main track and 0.2207 on the bonus track in the final phase of the competition.
164 | This result would be enough to take first place in the main and bonus (out-of-distribution) competition tracks.
165 |
166 | ### Knowledge Distillation ([config](configs/distillation_001.py))
167 |
168 | For an individual sample in the batch, the loss was calculated for the responses of only one mouse
169 | because the input tensor is associated with a single mouse trial, and there is no neural activity data for the other mice.
170 | However, the model can predict responses for all mice from the input tensor. In the second stage of training, I used a method similar to knowledge distillation [^15].
171 | I created a pipeline where models from the first stage predict unlabeled responses during training.
172 | As a result, the second-stage models trained all their readouts on every batch sample.
173 | The loss on distilled predictions was weighted to make up 0.36 of the overall loss.
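
A minimal sketch of that weighting, following my reading of `add_distill_predictions` in [src/argus_models.py](src/argus_models.py) (tensor names are assumptions):

```Python
import torch


def distill_weight(mice_weights: torch.Tensor, ratio: float = 0.36) -> torch.Tensor:
    # mice_weights: (batch, mice) tensor with 1.0 for the labeled mouse of each
    # sample and 0.0 elsewhere; distilled targets get a weight such that they
    # contribute `ratio` of the overall loss.
    distill_mask = mice_weights == 0.0
    return ratio / (1.0 - ratio) * mice_weights.sum() / distill_mask.sum()
```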
174 |
175 | The hyperparameters were identical, except for the expansion ratio in inverted residual blocks: seven in the first stage and six in the second.
176 |
177 | In the second stage, the ensemble of models achieves nearly the same single-trial correlation as the ensemble from the first stage.
178 | However, what is fascinating is that each fold model performs better than the corresponding first-stage model by an average score of 0.007.
179 | Thus, a distilled model works like an ensemble of undistilled models.
180 | According to [^16], during knowledge distillation an individual model is forced to learn the ensemble's performance, and ensembling distilled models offers no further boost.
181 | I observed the same behavior in my solution.
182 |
183 | Distillation can be a great practice if you need one good model.
184 | But in ensembles, this leads to minor changes in performance.
185 |
186 | ## Prediction
187 |
188 | The ensembles were produced by taking the arithmetic mean of predictions from multiple steps:
189 | * A sliding window over every possible sequence of frames, averaging overlapping predictions.
190 | * Models from cross-validations of one training stage.
191 | * Training stages (optional).
192 |
193 | I used the same model weights for both competition tracks.
194 | The competition only evaluated the responses of five mice from the new dataset, so I only predicted those.
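
Here is a rough sketch of the first averaging step, blending overlapping window predictions frame by frame; the function and argument names are assumptions, and the actual logic lives in [src/predictors.py](src/predictors.py):

```Python
import numpy as np


def blend_windows(window_preds: list[np.ndarray], window_starts: list[int],
                  num_neurons: int, length: int) -> np.ndarray:
    # Each prediction covers frames [start, start + window_length); overlapping
    # frames are averaged by dividing by their hit counts. Assumes every frame
    # is covered by at least one window.
    total = np.zeros((num_neurons, length), dtype=np.float32)
    counts = np.zeros(length, dtype=np.float32)
    for pred, start in zip(window_preds, window_starts):
        stop = start + pred.shape[-1]
        total[:, start:stop] += pred
        counts[start:stop] += 1.0
    return total / counts
```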
195 |
196 | ## Competition progress
197 |
198 | You can see the progress of solution development during the competition in [spreadsheets](https://docs.google.com/spreadsheets/d/1xJTB6lZvtjNSQbYQiB_hgKIrL9hUB0uL3xhUAK1Pmxw/edit#gid=0) (the document consists of multiple sheets).
199 | Unfortunately, the document contains less than half of the conducted experiments because sometimes I was too lazy to fill it :)
200 | However, if you need a complete chronology of valuable changes, you can see it in git history.
201 |
202 | To summarize, an early model with depth-wise 3D convolution blocks achieved a score of around 0.19 on the main track during the live phase of the competition.
203 | Subsequently, implementing techniques from the core section, tuning hyperparameters, and training on all available data boosted the score to 0.25.
204 | Applying the non-standard normalization expected by the evaluation servers during postprocessing improved the score to 0.27.
205 | The cortex and CutMix increased the score to 0.276.
206 | Then, the beta value of Softplus was tuned, resulting in a score of 0.294.
207 | Lastly, adjusting drop rate and batch size parameters helped to achieve a score of 0.3 on the main track during the live phase.
208 |
209 | The ensemble of the basic and distillation training stages achieved a single-trial correlation of 0.2913 on the main track and 0.2215 on the bonus track in the final phase (0.3005 and 0.2173 in the live phase, respectively).
210 | This result is just a bit better than the basic training result alone, but I should provide it because it was the best submission in the competition.
211 | In addition, it was interesting to research the relationship between ensembling and distillation, which I wrote about above.
212 |
213 | Thanks to the Sensorium organizers and participants for the excellent competition. Thanks to my family and friends who supported me during the competition!
214 |
215 |
216 |
217 | [^1]: Turishcheva, Polina, et al. (2023). The Dynamic Sensorium competition for predicting large-scale mouse visual cortex activity from videos. https://arxiv.org/abs/2305.19654
218 | [^2]: Sandler, Mark, et al. (2018). Mobilenetv2: Inverted residuals and linear bottlenecks. https://arxiv.org/abs/1801.04381
219 | [^3]: Tan, Mingxing, and Quoc Le. (2019). Efficientnet: Rethinking model scaling for convolutional neural networks. https://arxiv.org/abs/1905.11946
220 | [^4]: Vaswani, Ashish, et al. (2017). Attention is all you need. https://arxiv.org/abs/1706.03762
221 | [^5]: Tran, Du, et al. (2018). A closer look at spatiotemporal convolutions for action recognition. https://arxiv.org/abs/1711.11248
222 | [^6]: He, Kaiming, et al. (2016). Deep residual learning for image recognition. https://arxiv.org/abs/1512.03385
223 | [^7]: Hu, Jie, Li Shen, and Gang Sun. (2018). Squeeze-and-excitation networks. https://arxiv.org/abs/1709.01507
224 | [^8]: Larsson, Gustav, Michael Maire, and Gregory Shakhnarovich. (2016). Fractalnet: Ultra-deep neural networks without residuals. https://arxiv.org/abs/1605.07648
225 | [^9]: Huang, Gao, et al. (2016). Deep networks with stochastic depth. https://arxiv.org/abs/1603.09382
226 | [^10]: Höfling, Larissa, et al. (2022). A chromatic feature detector in the retina signals visual context changes. https://www.biorxiv.org/content/10.1101/2022.11.30.518492
227 | [^11]: Lurz, Konstantin-Klemens, et al. (2020). Generalization in data-driven models of primary visual cortex. https://www.biorxiv.org/content/10.1101/2020.10.05.326256
228 | [^12]: Zhang, Xiangyu, et al. (2018). Shufflenet: An extremely efficient convolutional neural network for mobile devices. https://arxiv.org/abs/1707.01083
229 | [^13]: Loshchilov, Ilya, and Frank Hutter. (2017). Decoupled weight decay regularization. https://arxiv.org/abs/1711.05101
230 | [^14]: Yun, Sangdoo, et al. (2019). Cutmix: Regularization strategy to train strong classifiers with localizable features. https://arxiv.org/abs/1905.04899
231 | [^15]: Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. (2015). Distilling the knowledge in a neural network. https://arxiv.org/abs/1503.02531
232 | [^16]: Allen-Zhu, Zeyuan, and Yuanzhi Li. (2020). Towards understanding ensemble, knowledge distillation and self-distillation in deep learning. https://arxiv.org/abs/2012.09816
233 |
234 | ## Quick setup and start
235 |
236 | ### Requirements
237 |
238 | * Linux (tested on Ubuntu 20.04 and 22.04)
239 | * NVIDIA GPU (models trained on RTX A6000)
240 | * NVIDIA Drivers >= 535, CUDA >= 12.2
241 | * [Docker](https://docs.docker.com/engine/install/)
242 | * [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
243 |
244 | The pipeline is tuned for training on a single RTX A6000 with 48 GB of memory.
245 | If your GPU has less memory, you can use gradient accumulation by increasing the `iter_size` parameter in the training configs.
246 | It will worsen the result (by about 0.002 score for `"iter_size": 2`), but less than reducing the batch size would.
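
For example, to roughly halve the memory footprint of a training step, you could change the following values inside `argus_params` in the experiment config (a hypothetical tweak, not a setting used for the published results):

```Python
argus_params = {
    # ... other parameters unchanged ...
    "amp": True,
    "iter_size": 2,  # split each batch into two chunks and accumulate gradients
}
```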
247 |
248 | ### Run
249 |
250 | Clone the repo and enter the folder.
251 |
252 | ```bash
253 | git clone git@github.com:lRomul/sensorium.git
254 | cd sensorium
255 | ```
256 |
257 | Build a Docker image and run a container.
258 |
259 | Here is a small guide on how to use the provided Makefile:
260 |
261 | ```bash
262 | make # stop, build, run
263 |
264 | # do the same
265 | make stop
266 | make build
267 | make run
268 |
269 | make # by default all GPUs passed
270 | make GPUS=all # do the same
271 | make GPUS=none # without GPUs
272 |
273 | make run GPUS=2 # pass the first two GPUs
274 | make run GPUS='\"device=1,2\"' # pass GPUs numbered 1 and 2
275 |
276 | make logs
277 | make exec # run a new command in a running container
278 | make exec COMMAND="bash" # do the same
279 | make stop
280 | ```
281 |
282 |
283 |
284 | ```bash
285 | make
286 | ```
287 |
288 | From now on, you should run all commands inside the docker container.
289 |
290 | If you already have the Sensorium 2023 dataset (148 GB), copy it to the folder `./data/sensorium_all_2023/`.
291 | Otherwise, use the script for downloading:
292 |
293 | ```bash
294 | python scripts/download_data.py
295 | ```
296 |
297 | You can now reproduce the final results of the solution using the following commands:
298 | ```bash
299 | # Train
300 | # The training time is 3.5 days (12 hours per fold) for each experiment on a single A6000
301 | # You can speed up the process by using the --folds argument to train folds in parallel
302 | # Or just download trained weights in the section below
303 | python scripts/train.py -e true_batch_001
304 | python scripts/train.py -e distillation_001
305 |
306 | # Predict
307 | # Any GPU with more than 6 GB memory will be enough
308 | python scripts/predict.py -e true_batch_001 -s live_test_main
309 | python scripts/predict.py -e true_batch_001 -s live_test_bonus
310 | python scripts/predict.py -e true_batch_001 -s final_test_main
311 | python scripts/predict.py -e true_batch_001 -s final_test_bonus
312 | python scripts/predict.py -e distillation_001 -s live_test_main
313 | python scripts/predict.py -e distillation_001 -s live_test_bonus
314 | python scripts/predict.py -e distillation_001 -s final_test_main
315 | python scripts/predict.py -e distillation_001 -s final_test_bonus
316 |
317 | # Ensemble predictions of two experiments
318 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s live_test_main
319 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s live_test_bonus
320 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s final_test_main
321 | python scripts/ensemble.py -e distillation_001,true_batch_001 -s final_test_bonus
322 |
323 | # Final predictions will be there
324 | cd data/predictions/distillation_001,true_batch_001
325 | ```
326 |
327 | ### Trained model weights
328 |
329 | You can skip the training step by downloading model weights (9.5 GB) using [torrent file](data/experiments.torrent).
330 |
331 | Place the files in the data directory so that the folder structure is as follows:
332 |
333 | ```
334 | data
335 | ├── experiments
336 | │ ├── distillation_001
337 | │ └── true_batch_001
338 | └── sensorium_all_2023
339 | ├── dynamic29156-11-10-Video-8744edeac3b4d1ce16b680916b5267ce
340 | ├── dynamic29228-2-10-Video-8744edeac3b4d1ce16b680916b5267ce
341 | ├── dynamic29234-6-9-Video-8744edeac3b4d1ce16b680916b5267ce
342 | ├── dynamic29513-3-5-Video-8744edeac3b4d1ce16b680916b5267ce
343 | ├── dynamic29514-2-9-Video-8744edeac3b4d1ce16b680916b5267ce
344 | ├── dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20
345 | ├── dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20
346 | ├── dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20
347 | ├── dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20
348 | └── dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20
349 | ```
350 |
--------------------------------------------------------------------------------
/configs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !*.py
3 |
--------------------------------------------------------------------------------
/configs/distillation_001.py:
--------------------------------------------------------------------------------
1 | from src.utils import get_lr
2 | from src import constants
3 |
4 |
5 | image_size = (64, 64)
6 | batch_size = 32
7 | base_lr = 3e-4
8 | frame_stack_size = 16
9 | config = dict(
10 | image_size=image_size,
11 | batch_size=batch_size,
12 | base_lr=base_lr,
13 | min_base_lr=base_lr * 0.01,
14 | ema_decay=0.999,
15 | train_epoch_size=72000,
16 | num_epochs=[3, 18],
17 | stages=["warmup", "train"],
18 | num_dataloader_workers=8,
19 | init_weights=True,
20 | argus_params={
21 | "nn_module": ("dwiseneuro", {
22 | "readout_outputs": constants.num_neurons,
23 | "in_channels": 5,
24 | "core_features": (64, 64, 64, 64,
25 | 128, 128, 128,
26 | 256, 256),
27 | "spatial_strides": (2, 1, 1, 1,
28 | 2, 1, 1,
29 | 2, 1),
30 | "spatial_kernel": 3,
31 | "temporal_kernel": 5,
32 | "expansion_ratio": 6,
33 | "se_reduce_ratio": 32,
34 | "cortex_features": (512 * 2, 1024 * 2, 2048 * 2),
35 | "groups": 2,
36 | "softplus_beta": 0.07,
37 | "drop_rate": 0.4,
38 | "drop_path_rate": 0.1,
39 | }),
40 | "loss": ("mice_poisson", {
41 | "log_input": False,
42 | "full": False,
43 | "eps": 1e-8,
44 | }),
45 | "optimizer": ("AdamW", {
46 | "lr": get_lr(base_lr, batch_size),
47 | "weight_decay": 0.05,
48 | }),
49 | "device": "cuda:0",
50 | "frame_stack": {
51 | "size": frame_stack_size,
52 | "step": 2,
53 | "position": "last",
54 | },
55 | "inputs_processor": ("stack_inputs", {
56 | "size": image_size,
57 | "pad_fill_value": 0.,
58 | }),
59 | "responses_processor": ("identity", {}),
60 | "amp": True,
61 | "iter_size": 1,
62 | },
63 | cutmix={
64 | "alpha": 1.0,
65 | "prob": 0.5,
66 | },
67 | distill={
68 | "experiment": "true_batch_001",
69 | "ratio": 0.36,
70 | },
71 | )
72 |
--------------------------------------------------------------------------------
/configs/true_batch_001.py:
--------------------------------------------------------------------------------
1 | from src.utils import get_lr
2 | from src import constants
3 |
4 |
5 | image_size = (64, 64)
6 | batch_size = 32
7 | base_lr = 3e-4
8 | frame_stack_size = 16
9 | config = dict(
10 | image_size=image_size,
11 | batch_size=batch_size,
12 | base_lr=base_lr,
13 | min_base_lr=base_lr * 0.01,
14 | ema_decay=0.999,
15 | train_epoch_size=72000,
16 | num_epochs=[3, 18],
17 | stages=["warmup", "train"],
18 | num_dataloader_workers=8,
19 | init_weights=True,
20 | argus_params={
21 | "nn_module": ("dwiseneuro", {
22 | "readout_outputs": constants.num_neurons,
23 | "in_channels": 5,
24 | "core_features": (64, 64, 64, 64,
25 | 128, 128, 128,
26 | 256, 256),
27 | "spatial_strides": (2, 1, 1, 1,
28 | 2, 1, 1,
29 | 2, 1),
30 | "spatial_kernel": 3,
31 | "temporal_kernel": 5,
32 | "expansion_ratio": 7,
33 | "se_reduce_ratio": 32,
34 | "cortex_features": (512 * 2, 1024 * 2, 2048 * 2),
35 | "groups": 2,
36 | "softplus_beta": 0.07,
37 | "drop_rate": 0.4,
38 | "drop_path_rate": 0.1,
39 | }),
40 | "loss": ("mice_poisson", {
41 | "log_input": False,
42 | "full": False,
43 | "eps": 1e-8,
44 | }),
45 | "optimizer": ("AdamW", {
46 | "lr": get_lr(base_lr, batch_size),
47 | "weight_decay": 0.05,
48 | }),
49 | "device": "cuda:0",
50 | "frame_stack": {
51 | "size": frame_stack_size,
52 | "step": 2,
53 | "position": "last",
54 | },
55 | "inputs_processor": ("stack_inputs", {
56 | "size": image_size,
57 | "pad_fill_value": 0.,
58 | }),
59 | "responses_processor": ("identity", {}),
60 | "amp": True,
61 | "iter_size": 1,
62 | },
63 | cutmix={
64 | "alpha": 1.0,
65 | "prob": 0.5,
66 | },
67 | )
68 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !experiments.torrent
4 | !readme_images/
5 |
--------------------------------------------------------------------------------
/data/experiments.torrent:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/experiments.torrent
--------------------------------------------------------------------------------
/data/readme_images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/readme_images/architecture.png
--------------------------------------------------------------------------------
/data/readme_images/preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/readme_images/preview.png
--------------------------------------------------------------------------------
/data/readme_images/softplus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/data/readme_images/softplus.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.66.1
2 | requests==2.31.0
3 | numpy==1.26.1
4 | pandas==2.1.1
5 | pyarrow==14.0.1
6 | torch>=2.0.0
7 | pytorch-argus==1.0.0
8 | deeplake==3.8.6
9 | ImageHash==4.3.1
10 |
--------------------------------------------------------------------------------
/scripts/download_data.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import zipfile
3 | import argparse
4 | from pathlib import Path
5 |
6 | import deeplake
7 | import numpy as np
8 | import requests
9 | from tqdm import tqdm
10 |
11 | from src import constants
12 |
13 |
14 | def parse_arguments():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("-p", "--path", default=constants.sensorium_dir, type=Path)
17 | return parser.parse_args()
18 |
19 |
20 | if __name__ == "__main__":
21 | args = parse_arguments()
22 |
23 | sensorium_dir = args.path
24 | sensorium_dir.mkdir(parents=True, exist_ok=True)
25 |
26 | for mouse in constants.mice:
27 | file_name = f"{mouse}.zip"
28 | dataset = constants.mouse2dataset[mouse]
29 | url = constants.dataset2url_format[dataset].format(file_name=file_name)
30 | zip_path = sensorium_dir / file_name
31 | mouse_dir = sensorium_dir / mouse
32 |
33 | if mouse_dir.exists():
34 | print(f"Folder '{str(mouse_dir)}' already exists, skip download")
35 | continue
36 |
37 | print(f"Download '{url}' to '{zip_path}'")
38 | zip_path.unlink(missing_ok=True)
39 | with requests.get(url, stream=True) as r:
40 | total_length = int(r.headers.get("Content-Length"))
41 | with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
42 | with open(zip_path, 'wb') as output:
43 | shutil.copyfileobj(raw, output)
44 |
45 | print("Unzip", zip_path)
46 | with zipfile.ZipFile(zip_path, 'r') as zip_file:
47 | zip_file.extractall(sensorium_dir)
48 |
49 | print("Delete", zip_path)
50 | zip_path.unlink()
51 | shutil.rmtree(sensorium_dir / "__MACOSX", ignore_errors=True)
52 |
53 | if mouse in constants.new_mice:
54 | continue
55 | for split in constants.unlabeled_splits:
56 | dataset = deeplake.load(f"hub://sinzlab/Sensorium_2023_{mouse}_{split}")
57 | trials_ids = dataset.id.numpy().astype(int).ravel().tolist()
58 | for index, trial_id in enumerate(trials_ids):
59 | responses_path = mouse_dir / "data" / "responses" / f"{trial_id}.npy"
60 | responses = dataset.responses[index].numpy()
61 | np.save(str(responses_path), responses)
62 |
--------------------------------------------------------------------------------
/scripts/ensemble.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from tqdm import tqdm
4 | import numpy as np
5 |
6 | from src.submission import evaluate_folds_predictions, make_submission
7 | from src.data import get_mouse_data
8 | from src import constants
9 |
10 |
11 | def parse_arguments():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("-e", "--experiments", required=True, type=str)
14 | parser.add_argument("-s", "--split", required=True,
15 | choices=["folds"] + constants.unlabeled_splits, type=str)
16 | parser.add_argument("-d", "--dataset", default="new", choices=["new", "old"], type=str)
17 | return parser.parse_args()
18 |
19 |
20 | def ensemble_experiments(experiments: list[str], split: str, dataset: str):
21 | assert len(experiments) > 1
22 | print(f"Ensemble experiments: {experiments=}, {split=}, {dataset=}")
23 | split_dir_name = "out-of-fold" if split == "folds" else split
24 | splits = constants.folds_splits if split == "folds" else [split]
25 | ensemble_dir = constants.predictions_dir / ",".join(experiments) / split_dir_name
26 | for mouse in constants.dataset2mice[dataset]:
27 | ensemble_mouse_dir = ensemble_dir / mouse
28 | print(f"Ensemble mouse: {mouse=}, {str(ensemble_mouse_dir)=}")
29 | ensemble_mouse_dir.mkdir(parents=True, exist_ok=True)
30 | mouse_data = get_mouse_data(mouse=mouse, splits=splits)
31 |
32 | for trial_data in tqdm(mouse_data["trials"]):
33 | pred_filename = f"{trial_data['trial_id']}.npy"
34 | responses_lst = []
35 | for experiment in experiments:
36 | responses = np.load(
37 | str(constants.predictions_dir / experiment / split_dir_name / mouse / pred_filename)
38 | )
39 | responses_lst.append(responses)
40 | blend_responses = np.mean(responses_lst, axis=0)
41 | np.save(str(ensemble_mouse_dir / pred_filename), blend_responses)
42 |
43 |
44 | if __name__ == "__main__":
45 | args = parse_arguments()
46 | experiments_lst = sorted(args.experiments.split(','))
47 | experiment_name = ",".join(experiments_lst)
48 | ensemble_experiments(experiments_lst, args.split, args.dataset)
49 | if args.split == "folds":
50 | evaluate_folds_predictions(experiment_name, args.dataset)
51 | elif args.dataset == "new":
52 | make_submission(experiment_name, args.split)
53 |
--------------------------------------------------------------------------------
/scripts/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from tqdm import tqdm
5 | import numpy as np
6 |
7 | from src.submission import evaluate_folds_predictions, make_submission
8 | from src.utils import get_best_model_path
9 | from src.predictors import Predictor
10 | from src.data import get_mouse_data
11 | from src import constants
12 |
13 |
14 | def parse_arguments():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("-e", "--experiment", required=True, type=str)
17 | parser.add_argument("-s", "--split", required=True,
18 | choices=["folds"] + constants.unlabeled_splits, type=str)
19 | parser.add_argument("-d", "--dataset", default="new", choices=["new", "old"], type=str)
20 | parser.add_argument("--device", default="cuda:0", type=str)
21 | return parser.parse_args()
22 |
23 |
24 | def predict_trial(trial_data: dict, predictor: Predictor, mouse_index: int):
25 | length = trial_data["length"]
26 | video = np.load(trial_data["video_path"])[..., :length]
27 | behavior = np.load(trial_data["behavior_path"])[..., :length]
28 | pupil_center = np.load(trial_data["pupil_center_path"])[..., :length]
29 | responses = predictor.predict_trial(
30 | video=video,
31 | behavior=behavior,
32 | pupil_center=pupil_center,
33 | mouse_index=mouse_index,
34 | )
35 | return responses
36 |
37 |
38 | def predict_mouse_split(mouse: str, split: str,
39 | predictors: list[Predictor], save_dir: Path):
40 | mouse_index = constants.mouse2index[mouse]
41 | print(f"Predict mouse split: {mouse=} {split=} {len(predictors)=} {str(save_dir)=}")
42 | mouse_data = get_mouse_data(mouse=mouse, splits=[split])
43 |
44 | for trial_data in tqdm(mouse_data["trials"]):
45 | responses_lst = []
46 | for predictor in predictors:
47 | responses = predict_trial(trial_data, predictor, mouse_index)
48 | responses_lst.append(responses)
49 | blend_responses = np.mean(responses_lst, axis=0)
50 | np.save(str(save_dir / f"{trial_data['trial_id']}.npy"), blend_responses)
51 |
52 |
53 | def predict_folds(experiment: str, dataset: str, device: str):
54 | print(f"Predict folds: {experiment=}, {dataset=}, {device=}")
55 | for mouse in constants.dataset2mice[dataset]:
56 | mouse_prediction_dir = constants.predictions_dir / experiment / "out-of-fold" / mouse
57 | mouse_prediction_dir.mkdir(parents=True, exist_ok=True)
58 | for fold_split in constants.folds_splits:
59 | model_path = get_best_model_path(constants.experiments_dir / experiment / fold_split)
60 | print("Model path:", str(model_path))
61 | predictor = Predictor(model_path=model_path, device=device, blend_weights="ones")
62 | predict_mouse_split(mouse, fold_split, [predictor], mouse_prediction_dir)
63 |
64 |
65 | def predict_unlabeled_split(experiment: str, split: str, dataset: str, device: str):
66 | print(f"Predict unlabeled split: {experiment=}, {split=}, {dataset=}, {device=}")
67 | predictors = []
68 | for fold_split in constants.folds_splits:
69 | model_path = get_best_model_path(constants.experiments_dir / experiment / fold_split)
70 | print("Model path:", str(model_path))
71 | predictor = Predictor(model_path=model_path, device=device, blend_weights="ones")
72 | predictors.append(predictor)
73 | for mouse in constants.dataset2mice[dataset]:
74 | mouse_prediction_dir = constants.predictions_dir / experiment / split / mouse
75 | mouse_prediction_dir.mkdir(parents=True, exist_ok=True)
76 | predict_mouse_split(mouse, split, predictors, mouse_prediction_dir)
77 |
78 |
79 | if __name__ == "__main__":
80 | args = parse_arguments()
81 |
82 | if args.split == "folds":
83 | predict_folds(args.experiment, args.dataset, args.device)
84 | evaluate_folds_predictions(args.experiment, args.dataset)
85 | elif args.dataset == "new":
86 | predict_unlabeled_split(args.experiment, args.split, args.dataset, args.device)
87 | make_submission(args.experiment, args.split)
88 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import copy
3 | import json
4 | import argparse
5 | from pathlib import Path
6 | from pprint import pprint
7 | from importlib.machinery import SourceFileLoader
8 |
9 | import torch
10 | from torch.utils.data import DataLoader
11 |
12 | from argus import load_model
13 | from argus.callbacks import (
14 | LoggingToFile,
15 | LoggingToCSV,
16 | CosineAnnealingLR,
17 | Checkpoint,
18 | LambdaLR,
19 | )
20 |
21 | from src.datasets import TrainMouseVideoDataset, ValMouseVideoDataset, ConcatMiceVideoDataset
22 | from src.utils import get_lr, init_weights, get_best_model_path
23 | from src.responses import get_responses_processor
24 | from src.ema import ModelEma, EmaCheckpoint
25 | from src.inputs import get_inputs_processor
26 | from src.metrics import CorrelationMetric
27 | from src.indexes import IndexesGenerator
28 | from src.argus_models import MouseModel
29 | from src.data import get_mouse_data
30 | from src.mixers import CutMix
31 | from src import constants
32 |
33 |
34 | def parse_arguments():
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument("-e", "--experiment", required=True, type=str)
37 | parser.add_argument("-f", "--folds", default="all", type=str)
38 | return parser.parse_args()
39 |
40 |
41 | def train_mouse(config: dict, save_dir: Path, train_splits: list[str], val_splits: list[str]):
42 | config = copy.deepcopy(config)
43 | argus_params = config["argus_params"]
44 |
45 | model = MouseModel(argus_params)
46 |
47 | if config["init_weights"]:
48 | print("Weight initialization")
49 | init_weights(model.nn_module)
50 |
51 | if config["ema_decay"]:
52 | print("EMA decay:", config["ema_decay"])
53 | model.model_ema = ModelEma(model.nn_module, decay=config["ema_decay"])
54 | checkpoint_class = EmaCheckpoint
55 | else:
56 | checkpoint_class = Checkpoint
57 |
58 | if "distill" in config:
59 | distill_params = config["distill"]
60 | distill_experiment_dir = constants.experiments_dir / distill_params["experiment"] / val_splits[0]
61 | distill_model_path = get_best_model_path(distill_experiment_dir)
62 | distill_model = load_model(distill_model_path, device=argus_params["device"])
63 | distill_model.eval()
64 | model.distill_model = distill_model.nn_module
65 | model.distill_ratio = distill_params["ratio"]
66 | print(f"Distillation model {str(distill_model_path)}, ratio {model.distill_ratio}")
67 |
68 | indexes_generator = IndexesGenerator(**argus_params["frame_stack"])
69 | inputs_processor = get_inputs_processor(*argus_params["inputs_processor"])
70 | responses_processor = get_responses_processor(*argus_params["responses_processor"])
71 |
72 | cutmix = CutMix(**config["cutmix"])
73 | train_datasets = []
74 | mouse_epoch_size = config["train_epoch_size"] // constants.num_mice
75 | for mouse in constants.mice:
76 | train_datasets += [
77 | TrainMouseVideoDataset(
78 | mouse_data=get_mouse_data(mouse=mouse, splits=train_splits),
79 | indexes_generator=indexes_generator,
80 | inputs_processor=inputs_processor,
81 | responses_processor=responses_processor,
82 | epoch_size=mouse_epoch_size,
83 | mixer=cutmix,
84 | )
85 | ]
86 | train_dataset = ConcatMiceVideoDataset(train_datasets)
87 | print("Train dataset len:", len(train_dataset))
88 | val_datasets = []
89 | for mouse in constants.mice:
90 | val_datasets += [
91 | ValMouseVideoDataset(
92 | mouse_data=get_mouse_data(mouse=mouse, splits=val_splits),
93 | indexes_generator=indexes_generator,
94 | inputs_processor=inputs_processor,
95 | responses_processor=responses_processor,
96 | )
97 | ]
98 | val_dataset = ConcatMiceVideoDataset(val_datasets)
99 | print("Val dataset len:", len(val_dataset))
100 |
101 | train_loader = DataLoader(
102 | train_dataset,
103 | batch_size=config["batch_size"],
104 | num_workers=config["num_dataloader_workers"],
105 | shuffle=True,
106 | )
107 | val_loader = DataLoader(
108 | val_dataset,
109 | batch_size=config["batch_size"] // argus_params["iter_size"],
110 | num_workers=config["num_dataloader_workers"],
111 | shuffle=False,
112 | )
113 |
114 | for num_epochs, stage in zip(config["num_epochs"], config["stages"]):
115 | callbacks = [
116 | LoggingToFile(save_dir / "log.txt", append=True),
117 | LoggingToCSV(save_dir / "log.csv", append=True),
118 | ]
119 |
120 | num_iterations = (len(train_dataset) // config["batch_size"]) * num_epochs
121 | if stage == "warmup":
122 | callbacks += [
123 | LambdaLR(lambda x: x / num_iterations,
124 | step_on_iteration=True),
125 | ]
126 | elif stage == "train":
127 | checkpoint_format = "model-{epoch:03d}-{val_corr:.6f}.pth"
128 | callbacks += [
129 | checkpoint_class(save_dir, file_format=checkpoint_format, max_saves=1),
130 | CosineAnnealingLR(
131 | T_max=num_iterations,
132 | eta_min=get_lr(config["min_base_lr"], config["batch_size"]),
133 | step_on_iteration=True,
134 | ),
135 | ]
136 |
137 | metrics = [
138 | CorrelationMetric(),
139 | ]
140 |
141 | model.fit(train_loader,
142 | val_loader=val_loader,
143 | num_epochs=num_epochs,
144 | callbacks=callbacks,
145 | metrics=metrics)
146 |
147 |
148 | if __name__ == "__main__":
149 | args = parse_arguments()
150 | print("Experiment:", args.experiment)
151 |
152 | config_path = constants.configs_dir / f"{args.experiment}.py"
153 | if not config_path.exists():
154 |         raise RuntimeError(f"Config '{config_path}' does not exist")
155 |
156 | train_config = SourceFileLoader(args.experiment, str(config_path)).load_module().config
157 | print("Experiment config:")
158 | pprint(train_config, sort_dicts=False)
159 |
160 | experiment_dir = constants.experiments_dir / args.experiment
161 | print("Experiment dir:", experiment_dir)
162 | if not experiment_dir.exists():
163 | experiment_dir.mkdir(parents=True, exist_ok=True)
164 | else:
165 | print(f"Folder '{experiment_dir}' already exists.")
166 |
167 | with open(experiment_dir / "train.py", "w") as outfile:
168 | outfile.write(open(__file__).read())
169 |
170 | with open(experiment_dir / "config.json", "w") as outfile:
171 | json.dump(train_config, outfile, indent=4)
172 |
173 | if args.folds == "all":
174 | folds_splits = constants.folds_splits
175 | else:
176 | folds_splits = [f"fold_{fold}" for fold in args.folds.split(",")]
177 |
178 | for fold_split in folds_splits:
179 | fold_experiment_dir = experiment_dir / fold_split
180 |
181 | val_folds_splits = [fold_split]
182 | train_folds_splits = sorted(set(constants.folds_splits) - set(val_folds_splits))
183 |
184 | print(f"Val fold: {val_folds_splits}, train folds: {train_folds_splits}")
185 | print(f"Fold experiment dir: {fold_experiment_dir}")
186 | train_mouse(train_config, fold_experiment_dir, train_folds_splits, val_folds_splits)
187 |
188 | torch.cuda.empty_cache()
189 | time.sleep(12)
190 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | exclude = __pycache__,data,notebooks
4 | per-file-ignores = __init__.py:F401
5 |
6 | [mypy]
7 | files = src
8 | ignore_missing_imports = True
9 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from src.argus_models import MouseModel
2 |
--------------------------------------------------------------------------------
/src/argus_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import argus
4 | from argus.engine import State
5 | from argus.loss import pytorch_losses
6 | from argus.utils import deep_to, deep_detach, deep_chunk
7 |
8 | from src.ema import ModelEma
9 | from src.losses import MicePoissonLoss
10 | from src.models.dwiseneuro import DwiseNeuro
11 |
12 |
13 | class MouseModel(argus.Model):
14 | nn_module = {
15 | "dwiseneuro": DwiseNeuro,
16 | }
17 | loss = {
18 | **pytorch_losses,
19 | "mice_poisson": MicePoissonLoss,
20 | }
21 |
22 | def __init__(self, params: dict):
23 | super().__init__(params)
24 | self.iter_size = int(params.get('iter_size', 1))
25 | self.amp = bool(params.get('amp', False))
26 | self.grad_scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
27 | self.model_ema: ModelEma | None = None
28 | self.distill_model: torch.nn.Module | None = None
29 | self.distill_ratio: float = 0.
30 |
31 | @torch.no_grad()
32 | def add_distill_predictions(self, input, target):
33 | if self.distill_model is not None and self.distill_ratio:
34 | distill_prediction = self.distill_model(input)
35 | target_tensors, mice_weights = target
36 | distill_mask = mice_weights == 0.
37 | distill_weight = (self.distill_ratio / (1. - self.distill_ratio)
38 | * mice_weights.sum() / distill_mask.sum())
39 | for batch_idx, mouse_idx in torch.argwhere(distill_mask):
40 | target_tensors[mouse_idx][batch_idx] = distill_prediction[mouse_idx][batch_idx]
41 | mice_weights[batch_idx, mouse_idx] = distill_weight
42 |
43 | def train_step(self, batch, state: State) -> dict:
44 | self.train()
45 | self.optimizer.zero_grad()
46 |
47 | loss_value = 0
48 | for i, chunk_batch in enumerate(deep_chunk(batch, self.iter_size)):
49 | input, target = deep_to(chunk_batch, self.device, non_blocking=True)
50 | with torch.cuda.amp.autocast(enabled=self.amp):
51 | self.add_distill_predictions(input, target)
52 | prediction = self.nn_module(input)
53 | loss = self.loss(prediction, target)
54 | loss = loss / self.iter_size
55 | self.grad_scaler.scale(loss).backward()
56 | loss_value += loss.item()
57 |
58 | self.grad_scaler.step(self.optimizer)
59 | self.grad_scaler.update()
60 |
61 | if self.model_ema is not None:
62 | self.model_ema.update(self.nn_module)
63 |
64 | prediction = deep_detach(prediction)
65 | target = deep_detach(target)
66 | prediction = self.prediction_transform(prediction)
67 | return {
68 | 'prediction': prediction,
69 | 'target': target,
70 | 'loss': loss_value
71 | }
72 |
73 | def val_step(self, batch, state: State) -> dict:
74 | self.eval()
75 | with torch.no_grad():
76 | input, target = deep_to(batch, device=self.device, non_blocking=True)
77 | if self.model_ema is None:
78 | prediction = self.nn_module(input)
79 | else:
80 | prediction = self.model_ema.ema(input)
81 | loss = self.loss(prediction, target)
82 | prediction = self.prediction_transform(prediction)
83 | return {
84 | 'prediction': prediction,
85 | 'target': target,
86 | 'loss': loss.item()
87 | }
88 |
89 | def predict(self, input, mouse_index: int | None = None):
90 | self._check_predict_ready()
91 | with torch.no_grad():
92 | self.eval()
93 | input = deep_to(input, self.device)
94 | if self.model_ema is None:
95 | prediction = self.nn_module(input, mouse_index)
96 | else:
97 | prediction = self.model_ema.ema(input, mouse_index)
98 | prediction = self.prediction_transform(prediction)
99 | return prediction
100 |
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | work_dir = Path("/workdir")
4 | data_dir = work_dir / "data"
5 | sensorium_dir = data_dir / "sensorium_all_2023"
6 |
7 | configs_dir = work_dir / "configs"
8 | experiments_dir = data_dir / "experiments"
9 | predictions_dir = data_dir / "predictions"
10 |
11 | new_mice = [
12 | "dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20",
13 | "dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20",
14 | "dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20",
15 | "dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20",
16 | "dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20",
17 | ]
18 | new_num_neurons = [7863, 7908, 8202, 7939, 8122]
19 | old_mice = [
20 | "dynamic29156-11-10-Video-8744edeac3b4d1ce16b680916b5267ce",
21 | "dynamic29228-2-10-Video-8744edeac3b4d1ce16b680916b5267ce",
22 | "dynamic29234-6-9-Video-8744edeac3b4d1ce16b680916b5267ce",
23 | "dynamic29513-3-5-Video-8744edeac3b4d1ce16b680916b5267ce",
24 | "dynamic29514-2-9-Video-8744edeac3b4d1ce16b680916b5267ce",
25 | ]
26 | old_num_neurons = [7440, 7928, 8285, 7671, 7495]
27 | dataset2mice = {
28 | "new": new_mice,
29 | "old": old_mice,
30 | }
31 | mouse2dataset = {m: d for d, mc in dataset2mice.items() for m in mc}
32 | dataset2url_format = {
33 | "new": "https://gin.g-node.org/pollytur/sensorium_2023_dataset/raw/master/{file_name}",
34 | "old": "https://gin.g-node.org/pollytur/Sensorium2023Data/raw/master/{file_name}",
35 | }
36 |
37 | mice = new_mice + old_mice
38 | num_neurons = new_num_neurons + old_num_neurons
39 |
40 | num_mice = len(mice)
41 | index2mouse: dict[int, str] = {index: mouse for index, mouse in enumerate(mice)}
42 | mouse2index: dict[str, int] = {mouse: index for index, mouse in enumerate(mice)}
43 | mouse2num_neurons: dict[str, int] = {mouse: num for mouse, num in zip(mice, num_neurons)}
44 | mice_indexes = list(range(num_mice))
45 |
46 | unlabeled_splits = ["live_test_main", "live_test_bonus", "final_test_main", "final_test_bonus"]
47 |
48 | num_folds = 7
49 | folds = list(range(num_folds))
50 | folds_splits = [f"fold_{fold}" for fold in folds]
51 |
52 | submission_limit_length = 300
53 | submission_skip_first = 50
54 | submission_skip_last = 1
55 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from src.phash import calculate_video_phash
4 | from src.utils import get_length_without_nan
5 | from src import constants
6 |
7 |
8 | def create_videos_phashes(mouse: str) -> np.ndarray:
9 | mouse_dir = constants.sensorium_dir / mouse
10 | tiers = np.load(str(mouse_dir / "meta" / "trials" / "tiers.npy"))
11 | phashes = np.zeros(tiers.shape[0], dtype=np.uint64)
12 | for trial_id, tier in enumerate(tiers):
13 | if tier == "none":
14 | continue
15 | video = np.load(str(mouse_dir / "data" / "videos" / f"{trial_id}.npy"))
16 | phashes[trial_id] = calculate_video_phash(video)
17 | return phashes
18 |
19 |
20 | def get_folds_tiers(mouse: str, num_folds: int):
21 | tiers = np.load(str(constants.sensorium_dir / mouse / "meta" / "trials" / "tiers.npy"))
22 | phashes = create_videos_phashes(mouse)
23 | if mouse in constants.new_mice:
24 | trial_ids = np.argwhere((tiers == "train") | (tiers == "oracle")).ravel()
25 | else:
26 | trial_ids = np.argwhere(tiers != "none").ravel()
27 | for trial_id in trial_ids:
28 | fold = int(phashes[trial_id]) % num_folds # group k-fold by video hash
29 | tiers[trial_id] = f"fold_{fold}"
30 | return tiers
31 |
32 |
33 | def get_mouse_data(mouse: str, splits: list[str]) -> dict:
34 | assert mouse in constants.mice
35 | tiers = get_folds_tiers(mouse, constants.num_folds)
36 | mouse_dir = constants.sensorium_dir / mouse
37 | neuron_ids = np.load(str(mouse_dir / "meta" / "neurons" / "unit_ids.npy"))
38 | cell_motor_coords = np.load(str(mouse_dir / "meta" / "neurons" / "cell_motor_coordinates.npy"))
39 |
40 | mouse_data = {
41 | "mouse": mouse,
42 | "splits": splits,
43 | "neuron_ids": neuron_ids,
44 | "num_neurons": neuron_ids.shape[0],
45 | "cell_motor_coordinates": cell_motor_coords,
46 | "trials": [],
47 | }
48 |
49 | for split in splits:
50 | if split in constants.folds_splits:
51 | labeled_split = True
52 | elif split in constants.unlabeled_splits:
53 | labeled_split = False
54 | else:
55 | raise ValueError(f"Unknown data split '{split}'")
56 | trial_ids = np.argwhere(tiers == split).ravel().tolist()
57 |
58 | for trial_id in trial_ids:
59 | behavior_path = str(mouse_dir / "data" / "behavior" / f"{trial_id}.npy")
60 | trial_data = {
61 | "trial_id": trial_id,
62 | "length": get_length_without_nan(np.load(behavior_path)[0]),
63 | "video_path": str(mouse_dir / "data" / "videos" / f"{trial_id}.npy"),
64 | "behavior_path": behavior_path,
65 | "pupil_center_path": str(mouse_dir / "data" / "pupil_center" / f"{trial_id}.npy"),
66 | }
67 | if labeled_split:
68 | response_path = str(mouse_dir / "data" / "responses" / f"{trial_id}.npy")
69 | trial_data["response_path"] = response_path
70 | trial_data["length"] = get_length_without_nan(np.load(response_path)[0])
71 | mouse_data["trials"].append(trial_data)
72 |
73 | return mouse_data
74 |
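Usage sketch for get_mouse_data, assuming the competition data has already been downloaded under constants.sensorium_dir (for example with scripts/download_data.py); the chosen mouse and split are illustrative.

    from src.data import get_mouse_data
    from src import constants

    mouse = constants.new_mice[0]
    mouse_data = get_mouse_data(mouse=mouse, splits=["fold_0"])
    print(mouse_data["num_neurons"], len(mouse_data["trials"]))
    trial = mouse_data["trials"][0]
    print(trial["trial_id"], trial["length"], trial["video_path"])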
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import random
3 |
4 | import numpy as np
5 |
6 | import torch
7 | from torch import nn
8 | from torch import Tensor
9 | from torch.utils.data import Dataset
10 |
11 | from src.mixers import Mixer
12 | from src.utils import set_random_seed
13 | from src.inputs import InputsProcessor
14 | from src.indexes import IndexesGenerator
15 | from src.responses import ResponsesProcessor
16 | from src import constants
17 |
18 |
19 | class MouseVideoDataset(Dataset, metaclass=abc.ABCMeta):
20 | def __init__(self,
21 | mouse_data: dict,
22 | indexes_generator: IndexesGenerator,
23 | inputs_processor: InputsProcessor,
24 | responses_processor: ResponsesProcessor):
25 | self.mouse_data = mouse_data
26 | self.mouse = mouse_data["mouse"]
27 | self.mouse_index = constants.mouse2index[self.mouse]
28 | self.indexes_generator = indexes_generator
29 | self.inputs_processor = inputs_processor
30 | self.responses_processor = responses_processor
31 |
32 | self.trials = self.mouse_data["trials"]
33 | self.num_trials = len(self.trials)
34 | self.trials_lengths = [t["length"] for t in self.trials]
35 | self.num_neurons = self.mouse_data["num_neurons"]
36 |
37 | def get_frames(self, trial_index: int, indexes: list[int]) -> np.ndarray:
38 | frames = np.load(self.trials[trial_index]["video_path"])[..., indexes]
39 | return frames
40 |
41 | def get_responses(self, trial_index: int, indexes: list[int]) -> np.ndarray:
42 | responses = np.load(self.trials[trial_index]["response_path"])[..., indexes]
43 | return responses
44 |
45 | def get_behavior(self, trial_index: int, indexes: list[int]) -> np.ndarray:
46 | behavior = np.load(self.trials[trial_index]["behavior_path"])[..., indexes]
47 | return behavior
48 |
49 | def get_pupil_center(self, trial_index: int, indexes: list[int]) -> np.ndarray:
50 | pupil_center = np.load(self.trials[trial_index]["pupil_center_path"])[..., indexes]
51 | return pupil_center
52 |
53 | def get_inputs_responses(
54 | self,
55 | trial_index: int,
56 | indexes: list[int],
57 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
58 | frames = self.get_frames(trial_index, indexes)
59 | responses = self.get_responses(trial_index, indexes)
60 | behavior = self.get_behavior(trial_index, indexes)
61 | pupil_center = self.get_pupil_center(trial_index, indexes)
62 | return frames, behavior, pupil_center, responses
63 |
64 | def process_inputs_responses(self,
65 | frames: np.ndarray,
66 | behavior: np.ndarray,
67 | pupil_center: np.ndarray,
68 | responses: np.ndarray) -> tuple[Tensor, Tensor]:
69 | input_tensor = self.inputs_processor(frames, behavior, pupil_center)
70 | target_tensor = self.responses_processor(responses)
71 | return input_tensor, target_tensor
72 |
73 | @abc.abstractmethod
74 | def __len__(self) -> int:
75 | pass
76 |
77 | @abc.abstractmethod
78 | def get_indexes(self, index: int) -> tuple[int, list[int]]:
79 | pass
80 |
81 | def get_sample_tensors(self, index: int) -> tuple[Tensor, Tensor]:
82 | trial_index, indexes = self.get_indexes(index)
83 | frames, behavior, pupil_center, responses = self.get_inputs_responses(trial_index, indexes)
84 | return self.process_inputs_responses(frames, behavior, pupil_center, responses)
85 |
86 | def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
87 | return self.get_sample_tensors(index)
88 |
89 |
90 | class TrainMouseVideoDataset(MouseVideoDataset):
91 | def __init__(self,
92 | mouse_data: dict,
93 | indexes_generator: IndexesGenerator,
94 | inputs_processor: InputsProcessor,
95 | responses_processor: ResponsesProcessor,
96 | epoch_size: int,
97 | augmentations: nn.Module | None = None,
98 | mixer: Mixer | None = None):
99 | super().__init__(mouse_data, indexes_generator, inputs_processor, responses_processor)
100 | self.epoch_size = epoch_size
101 | self.augmentations = augmentations
102 | self.mixer = mixer
103 |
104 | def __len__(self) -> int:
105 | return self.epoch_size
106 |
107 | def get_indexes(self, index: int) -> tuple[int, list[int]]:
108 | set_random_seed(index)
109 | trial_index = random.randrange(0, self.num_trials)
110 | num_frames = self.trials[trial_index]["length"]
111 | frame_index = random.randrange(
112 | self.indexes_generator.behind,
113 | num_frames - self.indexes_generator.ahead
114 | )
115 | indexes = self.indexes_generator.make_indexes(frame_index)
116 | return trial_index, indexes
117 |
118 | def get_sample_tensors(self, index: int) -> tuple[Tensor, Tensor]:
119 | frames, responses = super().get_sample_tensors(index)
120 | if self.augmentations is not None:
121 | frames = self.augmentations(frames[None])[0]
122 | return frames, responses
123 |
124 | def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
125 | sample = self.get_sample_tensors(index)
126 | if self.mixer is not None and self.mixer.use():
127 | random_sample = self.get_sample_tensors(index + 1)
128 | sample = self.mixer(sample, random_sample)
129 | return sample
130 |
131 |
132 | class ValMouseVideoDataset(MouseVideoDataset):
133 | def __init__(self,
134 | mouse_data: dict,
135 | indexes_generator: IndexesGenerator,
136 | inputs_processor: InputsProcessor,
137 | responses_processor: ResponsesProcessor):
138 | super().__init__(mouse_data, indexes_generator, inputs_processor, responses_processor)
139 | self.window_size = self.indexes_generator.width
140 | self.samples_per_trials = [length // self.window_size for length in self.trials_lengths]
141 | self.num_samples = sum(self.samples_per_trials)
142 |
143 | def __len__(self) -> int:
144 | return self.num_samples
145 |
146 | def get_indexes(self, index: int) -> tuple[int, list[int]]:
147 | assert 0 <= index < self.__len__()
148 | trial_sample_index = index
149 | trial_index = 0
150 | for trial_index, num_trial_samples in enumerate(self.samples_per_trials):
151 | if trial_sample_index >= num_trial_samples:
152 | trial_sample_index -= num_trial_samples
153 | else:
154 | break
155 |
156 | frame_index = self.indexes_generator.behind + trial_sample_index * self.window_size
157 | indexes = self.indexes_generator.make_indexes(frame_index)
158 | return trial_index, indexes
159 |
160 |
161 | class ConcatMiceVideoDataset(Dataset):
162 | def __init__(self, mice_datasets: list[MouseVideoDataset]):
163 | self.mice_indexes = [d.mouse_index for d in mice_datasets]
164 | assert self.mice_indexes == constants.mice_indexes
165 | self.mice_datasets = mice_datasets
166 | self.samples_per_dataset = [len(d) for d in mice_datasets]
167 | self.num_samples = sum(self.samples_per_dataset)
168 |
169 | def __len__(self):
170 | return self.num_samples
171 |
172 | def construct_mice_sample(
173 | self, mouse_index: int, mouse_sample: tuple[Tensor, Tensor]
174 | ) -> tuple[Tensor, tuple[list[Tensor], Tensor]]:
175 | input_tensor, target_tensor = mouse_sample
176 | target_tensors = []
177 | for index in self.mice_indexes:
178 | if index == mouse_index:
179 | target_tensors.append(target_tensor)
180 | else:
181 | temporal_shape = [target_tensor.shape[-1]] if len(target_tensor.shape) == 2 else []
182 | target_tensors.append(
183 | torch.zeros(constants.num_neurons[index], *temporal_shape, dtype=torch.float32)
184 | )
185 | mice_weights = torch.zeros(constants.num_mice, dtype=torch.float32)
186 | mice_weights[mouse_index] = 1.0
187 | return input_tensor, (target_tensors, mice_weights)
188 |
189 | def __getitem__(self, index: int) -> tuple[Tensor, tuple[list[Tensor], Tensor]]:
190 | assert 0 <= index < self.__len__()
191 | sample_index = index
192 | mouse_index = 0
193 | for mouse_index, num_trial_samples in enumerate(self.samples_per_dataset):
194 | if sample_index >= num_trial_samples:
195 | sample_index -= num_trial_samples
196 | else:
197 | break
198 | mouse_sample = self.mice_datasets[mouse_index][sample_index]
199 | mice_sample = self.construct_mice_sample(mouse_index, mouse_sample)
200 | return mice_sample
201 |
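A sketch of how these dataset classes compose, assuming the data is on disk; the window size, image size, held-out fold, and epoch size below are illustrative rather than the competition configuration. ConcatMiceVideoDataset requires one dataset per mouse in the order of constants.mice.

    from src.data import get_mouse_data
    from src.datasets import TrainMouseVideoDataset, ConcatMiceVideoDataset
    from src.indexes import IndexesGenerator
    from src.inputs import StackInputsProcessor
    from src.responses import IdentityResponsesProcessor
    from src import constants

    mice_datasets = [
        TrainMouseVideoDataset(
            mouse_data=get_mouse_data(mouse=mouse, splits=constants.folds_splits[1:]),
            indexes_generator=IndexesGenerator(size=16, step=2),
            inputs_processor=StackInputsProcessor(size=(64, 64)),
            responses_processor=IdentityResponsesProcessor(),
            epoch_size=1000,
        )
        for mouse in constants.mice  # order must match constants.mice_indexes
    ]
    dataset = ConcatMiceVideoDataset(mice_datasets)
    inputs, (targets, mice_weights) = dataset[0]   # targets: one tensor per mouse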
--------------------------------------------------------------------------------
/src/ema.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn.parallel import DataParallel, DistributedDataParallel
6 |
7 | from argus.utils import deep_to
8 | from argus.engine import State
9 | from argus.callbacks import Checkpoint
10 |
11 |
12 | class ModelEma(nn.Module):
13 | """ Model Exponential Moving Average V2
14 |
15 | Keep a moving average of everything in the model state_dict (parameters and buffers).
16 |     V2 of this module is simpler; it does not match params/buffers based on name but simply
17 | iterates in order. It works with torchscript (JIT of full model).
18 |
19 | This is intended to allow functionality like
20 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
21 |
22 | A smoothed version of the weights is necessary for some training schemes to perform well.
23 |     E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc. that use
24 |     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 require EMA
25 | smoothing of weights to match results. Pay attention to the decay constant you are using
26 | relative to your update count per epoch.
27 |
28 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
29 | disable validation of the EMA weights. Validation will have to be done manually in a separate
30 | process, or after the training stops converging.
31 |
32 |     This class is sensitive to where it is initialized in the sequence of model init,
33 | GPU assignment and distributed training wrappers.
34 |
35 | Source: https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
36 | """
37 | def __init__(self, model, decay=0.9999, device=None):
38 | super().__init__()
39 | # make a copy of the model for accumulating moving average of weights
40 | self.ema = deepcopy(model)
41 | self.ema.eval()
42 | self.decay = decay
43 | self.device = device # perform ema on different device from model if set
44 | if self.device is not None:
45 | self.ema.to(device=device)
46 |
47 | def _update(self, model, update_fn):
48 | with torch.no_grad():
49 | for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
50 | if self.device is not None:
51 | model_v = model_v.to(device=self.device)
52 | ema_v.copy_(update_fn(ema_v, model_v))
53 |
54 | def update(self, model):
55 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
56 |
57 | def set(self, model):
58 | self._update(model, update_fn=lambda e, m: m)
59 |
60 |
61 | class EmaCheckpoint(Checkpoint):
62 | def save_model(self, state: State, file_path):
63 | nn_module = state.model.model_ema.ema
64 | if isinstance(nn_module, (DataParallel, DistributedDataParallel)):
65 | nn_module = nn_module.module
66 |
67 | torch_state = {
68 | 'model_name': state.model.__class__.__name__,
69 | 'params': state.model.params,
70 | 'nn_state_dict': deep_to(nn_module.state_dict(), 'cpu'),
71 | }
72 | torch.save(torch_state, file_path)
73 | state.logger.info(f"Model saved to '{file_path}'")
74 |
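A minimal sketch of the EMA update loop with a toy model and optimizer (purely illustrative; in this repository the update is driven by the training code, and EmaCheckpoint above saves the smoothed copy held in model.model_ema).

    import torch
    from torch import nn
    from src.ema import ModelEma

    model = nn.Linear(8, 2)
    model_ema = ModelEma(model, decay=0.999)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for _ in range(100):
        loss = model(torch.randn(4, 8)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model_ema.update(model)   # ema = decay * ema + (1 - decay) * model

    # model_ema.ema is the smoothed copy used for validation or checkpointing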
--------------------------------------------------------------------------------
/src/indexes.py:
--------------------------------------------------------------------------------
1 | class IndexesGenerator:
2 | def __init__(self, size: int, step: int, position: str = "last"):
3 | self.size = size
4 | self.step = step
5 |
6 | if position == "first":
7 | self.behind = 0
8 | self.ahead = self.size - 1
9 | elif position == "middle":
10 | self.behind = self.size // 2
11 | self.ahead = self.size - self.behind - 1
12 | elif position == "last":
13 | self.behind = self.size - 1
14 | self.ahead = 0
15 | else:
16 | raise ValueError(
17 | f"Index position value should be one of {'first', 'middle', 'last'}"
18 | )
19 | self.behind *= self.step
20 | self.ahead *= self.step
21 | self.width = self.behind + self.ahead + 1
22 |
23 | def make_indexes(self, index: int) -> list[int]:
24 | return list(
25 | range(
26 | index - self.behind,
27 | index + self.ahead + 1,
28 | self.step,
29 | )
30 | )
31 |
32 | def clip_index(self, index: int, length: int, save_zone: int = 0) -> int:
33 | behind_frames = self.behind + save_zone
34 | ahead_frames = self.ahead + save_zone
35 | if index < behind_frames:
36 | index = behind_frames
37 | elif index >= length - ahead_frames:
38 | index = length - ahead_frames - 1
39 | return index
40 |
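A worked example of the index arithmetic; the numbers follow directly from the definitions above.

    from src.indexes import IndexesGenerator

    gen = IndexesGenerator(size=4, step=2, position="last")
    print(gen.behind, gen.ahead, gen.width)   # 6 0 7
    print(gen.make_indexes(10))               # [4, 6, 8, 10]
    print(gen.clip_index(2, length=20))       # 6, clipped so the window stays in range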
--------------------------------------------------------------------------------
/src/inputs.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Type
3 |
4 | import numpy as np
5 |
6 | import torch
7 |
8 |
9 | class InputsProcessor(metaclass=abc.ABCMeta):
10 | @abc.abstractmethod
11 | def __call__(self, frames: np.ndarray, behavior: np.ndarray, pupil_center: np.ndarray) -> torch.Tensor:
12 | pass
13 |
14 |
15 | class StackInputsProcessor(InputsProcessor):
16 | def __init__(self,
17 | size: tuple[int, int],
18 | pad_fill_value: int = 0):
19 | self.size = size
20 | self.pad_fill_value = pad_fill_value
21 |
22 | def __call__(self, frames: np.ndarray, behavior: np.ndarray, pupil_center: np.ndarray) -> torch.Tensor:
23 | length = frames.shape[-1]
24 | input_array = np.full((5, length, self.size[1], self.size[0]), self.pad_fill_value, dtype=np.float32)
25 |
26 | frames = np.transpose(frames.astype(np.float32), (2, 0, 1))
27 | height, width = frames.shape[-2:]
28 | height_start = (self.size[1] - height) // 2
29 | width_start = (self.size[0] - width) // 2
30 | input_array[0, :, height_start: height_start + height, width_start: width_start + width] = frames
31 |
32 | input_array[1:3] = behavior[:, :, None, None]
33 | input_array[3:] = pupil_center[:, :, None, None]
34 |
35 | tensor_frames = torch.from_numpy(input_array)
36 | return tensor_frames
37 |
38 |
39 | _INPUTS_PROCESSOR_REGISTRY: dict[str, Type[InputsProcessor]] = dict(
40 | stack_inputs=StackInputsProcessor,
41 | )
42 |
43 |
44 | def get_inputs_processor(name: str, processor_params: dict) -> InputsProcessor:
45 | assert name in _INPUTS_PROCESSOR_REGISTRY
46 | return _INPUTS_PROCESSOR_REGISTRY[name](**processor_params)
47 |
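A sketch with synthetic arrays (shapes illustrative): the five stacked channels are the center-padded video frame plus the broadcast behavior and pupil-center signals.

    import numpy as np
    from src.inputs import get_inputs_processor

    processor = get_inputs_processor("stack_inputs", {"size": (64, 64)})
    frames = np.random.rand(36, 64, 16)     # (height, width, time)
    behavior = np.random.rand(2, 16)        # two behavior signals over time
    pupil_center = np.random.rand(2, 16)    # pupil x/y position over time
    inputs = processor(frames, behavior, pupil_center)
    print(inputs.shape)                     # torch.Size([5, 16, 64, 64])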
--------------------------------------------------------------------------------
/src/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class MicePoissonLoss(nn.Module):
6 | def __init__(self, log_input: bool = False, full: bool = False, eps: float = 1e-8):
7 | super().__init__()
8 | self.poisson = nn.PoissonNLLLoss(log_input=log_input, full=full, eps=eps, reduction="none")
9 |
10 | def forward(self, inputs, targets):
11 | target_tensors, mice_weights = targets
12 | mice_weights = mice_weights / mice_weights.sum()
13 | loss_value = 0
14 | for mouse_index, (input_tensor, target_tensor) in enumerate(zip(inputs, target_tensors)):
15 | mouse_weights = mice_weights[..., mouse_index]
16 | mask = mouse_weights != 0.0
17 | if torch.any(mask):
18 | loss = self.poisson(input_tensor[mask], target_tensor[mask])
19 | loss *= mouse_weights[mask].view(-1, *[1] * (len(loss.shape) - 1))
20 | loss_value += loss.sum()
21 | return loss_value
22 |
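A sketch of the loss on random tensors; the per-sample mice_weights select which mouse each batch element belongs to, and shapes follow the dataset code above (batch and time sizes are illustrative).

    import torch
    from src.losses import MicePoissonLoss
    from src import constants

    criterion = MicePoissonLoss()
    batch_size, time = 2, 16
    predictions = [torch.rand(batch_size, n, time) for n in constants.num_neurons]
    targets = [torch.rand(batch_size, n, time) for n in constants.num_neurons]
    mice_weights = torch.zeros(batch_size, constants.num_mice)
    mice_weights[0, 0] = 1.0   # first sample belongs to mouse 0
    mice_weights[1, 3] = 1.0   # second sample belongs to mouse 3
    loss = criterion(predictions, (targets, mice_weights))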
--------------------------------------------------------------------------------
/src/metrics.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Union, Tuple
3 |
4 | import numpy as np
5 |
6 | import torch
7 |
8 | from argus.metrics import Metric
9 |
10 |
11 | def corr(
12 | y1: np.ndarray, y2: np.ndarray, axis: Union[None, int, Tuple[int]] = -1, eps: float = 1e-8, **kwargs
13 | ) -> np.ndarray:
14 | """
15 | Compute the correlation between two NumPy arrays along the specified dimension(s).
16 |
17 | Args:
18 | y1: first NumPy array
19 | y2: second NumPy array
20 | axis: dimension(s) along which the correlation is computed. Any valid NumPy
21 | axis spec works here
22 |         eps: offset added to the standard deviation to avoid exploding the correlation when
23 |             dividing by very small values (default 1e-8)
24 | **kwargs: passed to final numpy.mean operation over standardized y1 * y2
25 |
26 | Returns: correlation array
27 | """
28 |
29 | y1 = (y1 - y1.mean(axis=axis, keepdims=True)) / (y1.std(axis=axis, keepdims=True, ddof=0) + eps)
30 | y2 = (y2 - y2.mean(axis=axis, keepdims=True)) / (y2.std(axis=axis, keepdims=True, ddof=0) + eps)
31 | return (y1 * y2).mean(axis=axis, **kwargs)
32 |
33 |
34 | class CorrelationMetric(Metric):
35 | name: str = "corr"
36 | better: str = "max"
37 |
38 | def __init__(self):
39 | super().__init__()
40 | self.predictions = defaultdict(list)
41 | self.targets = defaultdict(list)
42 | self.weights = defaultdict(list)
43 |
44 | def reset(self):
45 | self.predictions = defaultdict(list)
46 | self.targets = defaultdict(list)
47 | self.weights = defaultdict(list)
48 |
49 | def update(self, step_output: dict):
50 | pred_tensors = step_output["prediction"]
51 | target_tensors, mice_weights = step_output["target"]
52 |
53 | for mouse_index, (pred, target) in enumerate(zip(pred_tensors, target_tensors)):
54 | mouse_weight = mice_weights[..., mouse_index]
55 | mask = mouse_weight != 0.0
56 | if torch.any(mask):
57 | pred, target = pred[mask], target[mask]
58 |
59 | if len(target.shape) == 3:
60 | pred = torch.transpose(pred, 1, 2)
61 | pred = pred.reshape(-1, pred.shape[-1])
62 | target = torch.transpose(target, 1, 2)
63 | target = target.reshape(-1, target.shape[-1])
64 |
65 | self.predictions[mouse_index].append(pred.cpu().numpy())
66 | self.targets[mouse_index].append(target.cpu().numpy())
67 |
68 | def compute(self):
69 | mice_corr = dict()
70 | for mouse_index in self.predictions:
71 | targets = np.concatenate(self.targets[mouse_index], axis=0)
72 | predictions = np.concatenate(self.predictions[mouse_index], axis=0)
73 | mice_corr[mouse_index] = corr(predictions, targets, axis=0).mean()
74 | return mice_corr
75 |
76 | def epoch_complete(self, state):
77 | with torch.no_grad():
78 | mice_corr = self.compute()
79 | name_prefix = f"{state.phase}_" if state.phase else ''
80 | for mouse_index, mouse_corr in mice_corr.items():
81 | state.metrics[name_prefix + self.name + f"_mouse_{mouse_index}"] = mouse_corr
82 | state.metrics[name_prefix + self.name] = np.mean(list(mice_corr.values()))
83 |
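A sketch of the standalone corr function on synthetic data: axis=0 yields one correlation per neuron, which CorrelationMetric then averages per mouse.

    import numpy as np
    from src.metrics import corr

    rng = np.random.default_rng(0)
    targets = rng.random((300, 10))                       # (time, neurons)
    predictions = targets + 0.1 * rng.random((300, 10))   # noisy copy of the targets
    per_neuron = corr(predictions, targets, axis=0)
    print(per_neuron.shape, per_neuron.mean())            # (10,) and a value close to 1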
--------------------------------------------------------------------------------
/src/mixers.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import torch
4 |
5 | import numpy as np
6 |
7 | SampleType = tuple[torch.Tensor, torch.Tensor]
8 |
9 |
10 | class Mixer(metaclass=abc.ABCMeta):
11 | def __init__(self, prob: float):
12 | self.prob = prob
13 |
14 | def use(self):
15 | return np.random.random() < self.prob
16 |
17 | @abc.abstractmethod
18 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType:
19 | pass
20 |
21 |
22 | class Mixup(Mixer):
23 | def __init__(self, alpha: float = 0.4, prob: float = 1.0):
24 | super().__init__(prob)
25 | self.alpha = alpha
26 |
27 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType:
28 | inputs1, target1 = sample1
29 | inputs2, target2 = sample2
30 | lam = np.random.beta(self.alpha, self.alpha)
31 | inputs = (1 - lam) * inputs1 + lam * inputs2
32 | target = (1 - lam) * target1 + lam * target2
33 | return inputs, target
34 |
35 |
36 | def rand_bbox(height: int, width: int, lam: float):
37 | cut_rat = np.sqrt(lam)
38 | cut_w = (width * cut_rat).astype(int)
39 | cut_h = (height * cut_rat).astype(int)
40 |
41 | cx = np.random.randint(width)
42 | cy = np.random.randint(height)
43 |
44 | bbx1 = np.clip(cx - cut_w // 2, 0, width)
45 | bby1 = np.clip(cy - cut_h // 2, 0, height)
46 | bbx2 = np.clip(cx + cut_w // 2, 0, width)
47 | bby2 = np.clip(cy + cut_h // 2, 0, height)
48 |
49 | return bbx1, bby1, bbx2, bby2
50 |
51 |
52 | class CutMix(Mixer):
53 | def __init__(self, alpha: float = 1.0, prob: float = 1.0):
54 | super().__init__(prob)
55 | self.alpha = alpha
56 |
57 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType:
58 | inputs1, target1 = sample1
59 | inputs2, target2 = sample2
60 | inputs = inputs1.clone().detach()
61 | lam = np.random.beta(self.alpha, self.alpha)
62 | h, w = inputs1.shape[-2:]
63 | bbx1, bby1, bbx2, bby2 = rand_bbox(h, w, lam)
64 | inputs[..., bbx1: bbx2, bby1: bby2] = inputs2[..., bbx1: bbx2, bby1: bby2]
65 | lam = (bbx2 - bbx1) * (bby2 - bby1) / (h * w)
66 | target = (1 - lam) * target1 + lam * target2
67 | return inputs, target
68 |
69 |
70 | class RandomChoiceMixer(Mixer):
71 | def __init__(self, mixers: list[Mixer], choice_probs: list[float], prob: float = 1.0):
72 | super().__init__(prob)
73 | self.mixers = mixers
74 | self.choice_probs = choice_probs
75 |
76 | def __call__(self, sample1: SampleType, sample2: SampleType) -> SampleType:
77 | mixer_index = np.random.choice(range(len(self.mixers)), p=self.choice_probs)
78 | mixer = self.mixers[mixer_index]
79 | return mixer(sample1, sample2)
80 |
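A sketch of composing the mixers (shapes and probabilities illustrative); as in TrainMouseVideoDataset above, use() is checked first and, if it fires, the current sample is mixed with another one.

    import torch
    from src.mixers import Mixup, CutMix, RandomChoiceMixer

    mixer = RandomChoiceMixer([Mixup(alpha=0.4), CutMix(alpha=1.0)],
                              choice_probs=[0.5, 0.5], prob=0.5)
    sample1 = (torch.rand(5, 16, 64, 64), torch.rand(8000, 16))
    sample2 = (torch.rand(5, 16, 64, 64), torch.rand(8000, 16))
    if mixer.use():                      # fires with probability prob
        inputs, target = mixer(sample1, sample2)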
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lRomul/sensorium/6849050e74e2843fbd70b33af110d71638aa37bb/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/dwiseneuro.py:
--------------------------------------------------------------------------------
1 | import math
2 | import functools
3 | from typing import Callable
4 |
5 | import torch
6 | from torch import nn
7 |
8 |
9 | class BatchNormAct(nn.Module):
10 | def __init__(self,
11 | num_features: int,
12 | bn_layer: Callable = nn.BatchNorm3d,
13 | act_layer: Callable = nn.ReLU,
14 | apply_act: bool = True):
15 | super().__init__()
16 | self.bn = bn_layer(num_features)
17 | self.act = act_layer() if apply_act else nn.Identity()
18 |
19 | def forward(self, x):
20 | x = self.bn(x)
21 | x = self.act(x)
22 | return x
23 |
24 |
25 | class SqueezeExcite3d(nn.Module):
26 | def __init__(self,
27 | in_features: int,
28 | reduce_ratio: int = 16,
29 | act_layer: Callable = nn.ReLU,
30 | gate_layer: Callable = nn.Sigmoid):
31 | super().__init__()
32 | rd_channels = in_features // reduce_ratio
33 | self.conv_reduce = nn.Conv3d(in_features, rd_channels, (1, 1, 1), bias=True)
34 | self.act1 = act_layer()
35 | self.conv_expand = nn.Conv3d(rd_channels, in_features, (1, 1, 1), bias=True)
36 | self.gate = gate_layer()
37 |
38 | def forward(self, x):
39 | x_se = x.mean((2, 3, 4), keepdim=True)
40 | x_se = self.conv_reduce(x_se)
41 | x_se = self.act1(x_se)
42 | x_se = self.conv_expand(x_se)
43 | return x * self.gate(x_se)
44 |
45 |
46 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
47 | if drop_prob == 0. or not training:
48 | return x
49 | keep_prob = 1 - drop_prob
50 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
51 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
52 | if keep_prob > 0. and scale_by_keep:
53 | random_tensor.div_(keep_prob)
54 | return x * random_tensor
55 |
56 |
57 | class DropPath(nn.Module):
58 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
59 | super(DropPath, self).__init__()
60 | self.drop_prob = drop_prob
61 | self.scale_by_keep = scale_by_keep
62 |
63 | def forward(self, x):
64 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
65 |
66 | def extra_repr(self):
67 | return f"drop_prob={round(self.drop_prob,3):0.3f}"
68 |
69 |
70 | class InvertedResidual3d(nn.Module):
71 | def __init__(self,
72 | in_features: int,
73 | out_features: int,
74 | spatial_kernel: int = 3,
75 | temporal_kernel: int = 3,
76 | spatial_stride: int = 1,
77 | expansion_ratio: int = 3,
78 | se_reduce_ratio: int = 16,
79 | act_layer: Callable = nn.ReLU,
80 | bn_layer: Callable = nn.BatchNorm3d,
81 | drop_path_rate: float = 0.,
82 | bias: bool = False):
83 | super().__init__()
84 | self.spatial_stride = spatial_stride
85 | self.out_features = out_features
86 | mid_features = in_features * expansion_ratio
87 | stride = (1, spatial_stride, spatial_stride)
88 |
89 | # Point-wise expansion
90 | self.conv_pw = nn.Sequential(
91 | nn.Conv3d(in_features, mid_features, (1, 1, 1), bias=bias),
92 | BatchNormAct(mid_features, bn_layer=bn_layer, act_layer=act_layer),
93 | )
94 |
95 | # Spatial depth-wise convolution
96 | spatial_padding = spatial_kernel // 2
97 | self.spat_covn_dw = nn.Sequential(
98 | nn.Conv3d(mid_features, mid_features, (1, spatial_kernel, spatial_kernel),
99 | stride=stride, padding=(0, spatial_padding, spatial_padding),
100 | groups=mid_features, bias=bias),
101 | BatchNormAct(mid_features, bn_layer=bn_layer, act_layer=act_layer),
102 | )
103 |
104 | # Temporal depth-wise convolution
105 | temporal_padding = temporal_kernel // 2
106 | self.temp_covn_dw = nn.Sequential(
107 | nn.Conv3d(mid_features, mid_features, (temporal_kernel, 1, 1),
108 | stride=(1, 1, 1), padding=(temporal_padding, 0, 0),
109 | groups=mid_features, bias=bias),
110 | BatchNormAct(mid_features, bn_layer=bn_layer, act_layer=act_layer),
111 | )
112 |
113 | # Squeeze-and-excitation
114 | self.se = SqueezeExcite3d(mid_features, act_layer=act_layer, reduce_ratio=se_reduce_ratio)
115 |
116 | # Point-wise linear projection
117 | self.conv_pwl = nn.Sequential(
118 | nn.Conv3d(mid_features, out_features, (1, 1, 1), bias=bias),
119 | BatchNormAct(out_features, bn_layer=bn_layer, apply_act=False),
120 | )
121 |
122 | self.drop_path = DropPath(drop_prob=drop_path_rate)
123 | self.bn_sc = BatchNormAct(out_features, bn_layer=bn_layer, apply_act=False)
124 |
125 | def interpolate_shortcut(self, shortcut):
126 | _, c, t, h, w = shortcut.shape
127 | if self.spatial_stride > 1:
128 | size = (t, math.ceil(h / self.spatial_stride), math.ceil(w / self.spatial_stride))
129 | shortcut = nn.functional.interpolate(shortcut, size=size, mode="nearest")
130 | if c != self.out_features:
131 | tile_dims = (1, math.ceil(self.out_features / c), 1, 1, 1)
132 | shortcut = torch.tile(shortcut, tile_dims)[:, :self.out_features]
133 | shortcut = self.bn_sc(shortcut)
134 | return shortcut
135 |
136 | def forward(self, x):
137 | shortcut = x
138 | x = self.conv_pw(x)
139 | x = self.spat_covn_dw(x)
140 | x = self.temp_covn_dw(x)
141 | x = self.se(x)
142 | x = self.conv_pwl(x)
143 | x = self.drop_path(x) + self.interpolate_shortcut(shortcut)
144 | return x
145 |
146 |
147 | class PositionalEncoding3d(nn.Module):
148 | def __init__(self, channels: int):
149 | super().__init__()
150 | self.orig_channels = channels
151 | channels = math.ceil(channels / 6) * 2
152 | if channels % 2:
153 | channels += 1
154 | self.channels = channels
155 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
156 | self.register_buffer("inv_freq", inv_freq)
157 | self.register_buffer("cached_encoding", None, persistent=False)
158 |
159 | def get_emb(self, sin_inp):
160 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=0)
161 | return torch.flatten(emb, 0, 1)
162 |
163 | def create_cached_encoding(self, tensor):
164 | _, orig_ch, x, y, z = tensor.shape
165 | assert orig_ch == self.orig_channels
166 | self.cached_encoding = None
167 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
168 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
169 | pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
170 | sin_inp_x = torch.einsum("i,j->ij", self.inv_freq, pos_x)
171 | sin_inp_y = torch.einsum("i,j->ij", self.inv_freq, pos_y)
172 | sin_inp_z = torch.einsum("i,j->ij", self.inv_freq, pos_z)
173 | emb_x = self.get_emb(sin_inp_x).unsqueeze(-1).unsqueeze(-1)
174 | emb_y = self.get_emb(sin_inp_y).unsqueeze(1).unsqueeze(-1)
175 | emb_z = self.get_emb(sin_inp_z).unsqueeze(1).unsqueeze(1)
176 | emb = torch.zeros((self.channels * 3, x, y, z), dtype=tensor.dtype, device=tensor.device)
177 | emb[:self.channels] = emb_x
178 | emb[self.channels: 2 * self.channels] = emb_y
179 | emb[2 * self.channels:] = emb_z
180 | emb = emb[None, :self.orig_channels].contiguous()
181 | self.cached_encoding = emb
182 | return emb
183 |
184 | def forward(self, x):
185 | if len(x.shape) != 5:
186 | raise RuntimeError("The input tensor has to be 5D")
187 |
188 | cached_encoding = self.cached_encoding
189 | if cached_encoding is None or cached_encoding.shape[1:] != x.shape[1:]:
190 | cached_encoding = self.create_cached_encoding(x)
191 |
192 | return x + cached_encoding.expand_as(x)
193 |
194 |
195 | class ShuffleLayer(nn.Module):
196 | def __init__(self,
197 | in_features: int,
198 | out_features: int,
199 | groups: int = 1,
200 | act_layer: Callable = nn.ReLU,
201 | bn_layer: Callable = nn.BatchNorm1d,
202 | drop_path_rate: float = 0.):
203 | super().__init__()
204 | self.in_features = in_features
205 | self.out_features = out_features
206 | self.groups = groups
207 | self.conv = nn.Conv1d(in_features, out_features, (1,), groups=groups, bias=False)
208 | self.bn = BatchNormAct(out_features, bn_layer=bn_layer, act_layer=act_layer)
209 | self.drop_path = DropPath(drop_prob=drop_path_rate)
210 | self.bn_sc = BatchNormAct(out_features, bn_layer=bn_layer, apply_act=False)
211 |
212 | def shuffle_channels(self, x):
213 | if self.groups > 1:
214 | # Shuffle channels between groups
215 | b, c, t = x.shape
216 | x = x.view(b, self.groups, -1, t)
217 | x = torch.transpose(x, 1, 2)
218 | x = x.reshape(b, -1, t)
219 | return x
220 |
221 | def tile_shortcut(self, shortcut):
222 | if self.in_features != self.out_features:
223 | tile_dims = (1, math.ceil(self.out_features / self.in_features), 1)
224 | shortcut = torch.tile(shortcut, tile_dims)[:, :self.out_features]
225 | shortcut = self.bn_sc(shortcut)
226 | return shortcut
227 |
228 | def forward(self, x):
229 | shortcut = x
230 | x = self.conv(x)
231 | x = self.bn(x)
232 | x = self.shuffle_channels(x)
233 | x = self.drop_path(x) + self.tile_shortcut(shortcut)
234 | return x
235 |
236 |
237 | class Cortex(nn.Module):
238 | def __init__(self,
239 | in_features: int,
240 | features: tuple[int, ...],
241 | groups: int = 1,
242 | act_layer: Callable = nn.ReLU,
243 | bn_layer: Callable = nn.BatchNorm1d,
244 | drop_path_rate: float = 0.):
245 | super().__init__()
246 | self.layers = nn.Sequential()
247 | prev_num_features = in_features
248 | for num_features in features:
249 | self.layers.append(
250 | ShuffleLayer(
251 | in_features=prev_num_features,
252 | out_features=num_features,
253 | groups=groups,
254 | act_layer=act_layer,
255 | bn_layer=bn_layer,
256 | drop_path_rate=drop_path_rate,
257 | )
258 | )
259 | prev_num_features = num_features
260 |
261 | def forward(self, x):
262 | x = self.layers(x)
263 | return x
264 |
265 |
266 | class Readout(nn.Module):
267 | def __init__(self,
268 | in_features: int,
269 | out_features: int,
270 | groups: int = 1,
271 | softplus_beta: float = 1.0,
272 | drop_rate: float = 0.):
273 | super().__init__()
274 | self.out_features = out_features
275 | self.layer = nn.Sequential(
276 | nn.Dropout1d(p=drop_rate),
277 | nn.Conv1d(in_features,
278 | math.ceil(out_features / groups) * groups, (1,),
279 | groups=groups, bias=True),
280 | )
281 | self.gate = nn.Softplus(beta=softplus_beta) # type: ignore
282 |
283 | def forward(self, x):
284 | x = self.layer(x)
285 | x = x[:, :self.out_features]
286 | x = self.gate(x)
287 | return x
288 |
289 |
290 | class DepthwiseCore(nn.Module):
291 | def __init__(self,
292 | in_channels: int = 1,
293 | features: tuple[int, ...] = (64, 128, 256, 512),
294 | spatial_strides: tuple[int, ...] = (2, 2, 2, 2),
295 | spatial_kernel: int = 3,
296 | temporal_kernel: int = 3,
297 | expansion_ratio: int = 3,
298 | se_reduce_ratio: int = 16,
299 | act_layer: Callable = nn.ReLU,
300 | bn_layer: Callable = nn.BatchNorm3d,
301 | drop_path_rate: float = 0.):
302 | super().__init__()
303 | num_blocks = len(features)
304 | assert num_blocks and num_blocks == len(spatial_strides)
305 | next_num_features = features[0]
306 | self.stem = nn.Sequential(
307 | nn.Conv3d(in_channels, next_num_features, (1, 1, 1), bias=False),
308 | BatchNormAct(next_num_features, bn_layer=bn_layer, apply_act=False),
309 | )
310 |
311 | blocks = []
312 | for block_index in range(num_blocks):
313 | num_features = features[block_index]
314 | spatial_stride = spatial_strides[block_index]
315 | if block_index < num_blocks - 1:
316 | next_num_features = features[block_index + 1]
317 | block_drop_path_rate = drop_path_rate * block_index / num_blocks
318 |
319 | blocks += [
320 | PositionalEncoding3d(num_features),
321 | InvertedResidual3d(
322 | num_features,
323 | next_num_features,
324 | spatial_kernel=spatial_kernel,
325 | temporal_kernel=temporal_kernel,
326 | spatial_stride=spatial_stride,
327 | expansion_ratio=expansion_ratio,
328 | se_reduce_ratio=se_reduce_ratio,
329 | act_layer=act_layer,
330 | bn_layer=bn_layer,
331 | drop_path_rate=block_drop_path_rate,
332 | bias=False,
333 | )
334 | ]
335 | self.blocks = nn.Sequential(*blocks)
336 |
337 | def forward(self, x):
338 | x = self.stem(x)
339 | x = self.blocks(x)
340 | return x
341 |
342 |
343 | class DwiseNeuro(nn.Module):
344 | def __init__(self,
345 | readout_outputs: tuple[int, ...],
346 | in_channels: int = 5,
347 | core_features: tuple[int, ...] = (64, 64, 64, 64, 128, 128, 128, 256, 256),
348 | spatial_strides: tuple[int, ...] = (2, 1, 1, 1, 2, 1, 1, 2, 1),
349 | spatial_kernel: int = 3,
350 | temporal_kernel: int = 5,
351 | expansion_ratio: int = 6,
352 | se_reduce_ratio: int = 32,
353 | cortex_features: tuple[int, ...] = (1024, 2048, 4096),
354 | groups: int = 2,
355 | softplus_beta: float = 0.07,
356 | drop_rate: float = 0.4,
357 | drop_path_rate: float = 0.1):
358 | super().__init__()
359 | act_layer = functools.partial(nn.SiLU, inplace=True)
360 |
361 | self.core = DepthwiseCore(
362 | in_channels=in_channels,
363 | features=core_features,
364 | spatial_strides=spatial_strides,
365 | spatial_kernel=spatial_kernel,
366 | temporal_kernel=temporal_kernel,
367 | expansion_ratio=expansion_ratio,
368 | se_reduce_ratio=se_reduce_ratio,
369 | act_layer=act_layer,
370 | bn_layer=nn.BatchNorm3d,
371 | drop_path_rate=drop_path_rate,
372 | )
373 |
374 | self.pool = nn.AdaptiveAvgPool3d((None, 1, 1))
375 |
376 | self.cortex = Cortex(
377 | in_features=core_features[-1],
378 | features=cortex_features,
379 | groups=groups,
380 | act_layer=act_layer,
381 | bn_layer=nn.BatchNorm1d,
382 | drop_path_rate=drop_path_rate,
383 | )
384 |
385 | self.readouts = nn.ModuleList()
386 | for readout_output in readout_outputs:
387 | self.readouts.append(
388 | Readout(
389 | in_features=cortex_features[-1],
390 | out_features=readout_output,
391 | groups=groups,
392 | softplus_beta=softplus_beta,
393 | drop_rate=drop_rate,
394 | )
395 | )
396 |
397 | def forward(self, x: torch.Tensor, index: int | None = None) -> list[torch.Tensor] | torch.Tensor:
398 | # Input shape: (batch, channel, time, height, width), e.g. (32, 5, 16, 64, 64)
399 | x = self.core(x) # (32, 256, 16, 8, 8)
400 | x = self.pool(x).squeeze(-1).squeeze(-1) # (32, 256, 16)
401 | x = self.cortex(x) # (32, 4096, 16)
402 | if index is None:
403 | return [readout(x) for readout in self.readouts]
404 | else:
405 | return self.readouts[index](x) # (32, neurons, 16)
406 |
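A forward-pass sketch with the default hyper-parameters. Note that this builds the full-size model (several hundred megabytes of weights); the input shape is illustrative and matches the comments in forward above.

    import torch
    from src.models.dwiseneuro import DwiseNeuro
    from src import constants

    model = DwiseNeuro(readout_outputs=tuple(constants.num_neurons)).eval()
    video = torch.rand(1, 5, 16, 64, 64)   # (batch, channels, time, height, width)
    with torch.no_grad():
        response = model(video, index=0)   # predictions for the first mouse only
    print(response.shape)                  # torch.Size([1, 7863, 16])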
--------------------------------------------------------------------------------
/src/phash.py:
--------------------------------------------------------------------------------
1 | import imagehash
2 | import numpy as np
3 | from PIL import Image
4 |
5 | from src.utils import get_length_without_nan
6 |
7 |
8 | def binary_array_to_int(arr: np.ndarray) -> int:
9 | bit_string = ''.join(str(b) for b in 1 * arr.flatten())
10 | return int(bit_string, 2)
11 |
12 |
13 | def calculate_frame_phash(frame: np.ndarray) -> int:
14 | frame = Image.fromarray(frame.astype(np.uint8), 'L')
15 | phash = imagehash.phash(frame).hash
16 | return binary_array_to_int(phash.ravel())
17 |
18 |
19 | def calculate_video_phash(video: np.ndarray, num_hash_frames: int = 5) -> int:
20 | length = get_length_without_nan(video[0, 0])
21 | assert length >= num_hash_frames
22 | step = length // num_hash_frames
23 | video_hash: int = 0
24 | for frame_index in range(step // 2, length, step)[:num_hash_frames]:
25 | video_hash ^= calculate_frame_phash(video[..., frame_index])
26 | return video_hash
27 |
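A sketch hashing a random synthetic video (shape illustrative: height, width, time); identical videos hash to the same integer, which is what the hash-based fold assignment in src/data.py relies on.

    import numpy as np
    from src.phash import calculate_video_phash

    rng = np.random.default_rng(0)
    video = rng.integers(0, 256, size=(36, 64, 120)).astype(np.float32)
    print(hex(calculate_video_phash(video)))
    assert calculate_video_phash(video) == calculate_video_phash(video.copy())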
--------------------------------------------------------------------------------
/src/predictors.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import argus
7 |
8 | from src.indexes import IndexesGenerator
9 | from src.inputs import get_inputs_processor
10 | from src.argus_models import MouseModel
11 | from src import constants
12 |
13 |
14 | def get_blend_weights(name: str, size: int):
15 | if name == "ones":
16 | return np.ones(size, dtype=np.float32)
17 | elif name == "linear":
18 | return np.linspace(0, 1, num=size)
19 | else:
20 | raise ValueError(f"Blend weights '{name}' is not supported")
21 |
22 |
23 | class Predictor:
24 | def __init__(self, model_path: Path | str, device: str = "cuda:0", blend_weights="ones"):
25 | self.model: MouseModel = argus.load_model(model_path, device=device, optimizer=None, loss=None)
26 | self.model.eval()
27 | self.inputs_processor = get_inputs_processor(*self.model.params["inputs_processor"])
28 | self.frame_stack_size = self.model.params["frame_stack"]["size"]
29 | self.frame_stack_step = self.model.params["frame_stack"]["step"]
30 | assert self.model.params["frame_stack"]["position"] == "last"
31 | assert self.model.params["responses_processor"][0] == "identity"
32 | self.indexes_generator = IndexesGenerator(self.frame_stack_size,
33 | self.frame_stack_step)
34 | self.blend_weights = get_blend_weights(blend_weights, self.frame_stack_size)
35 |
36 | @torch.no_grad()
37 | def predict_trial(self,
38 | video: np.ndarray,
39 | behavior: np.ndarray,
40 | pupil_center: np.ndarray,
41 | mouse_index: int) -> np.ndarray:
42 | inputs = self.inputs_processor(video, behavior, pupil_center).to(self.model.device)
43 | length = video.shape[-1]
44 | responses = np.zeros((constants.num_neurons[mouse_index], length), dtype=np.float32)
45 | blend_weights = np.zeros(length, np.float32)
46 | for index in range(
47 | self.indexes_generator.behind,
48 | length - self.indexes_generator.ahead
49 | ):
50 | indexes = self.indexes_generator.make_indexes(index)
51 | prediction = self.model.predict(inputs[:, indexes].unsqueeze(0), mouse_index)[0]
52 | responses[..., indexes] += prediction.cpu().numpy()
53 | blend_weights[indexes] += self.blend_weights
54 | responses /= np.clip(blend_weights, 1.0, None)
55 | return responses
56 |
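A usage sketch, assuming a checkpoint produced by the training pipeline exists; the checkpoint path and input shapes below are placeholders, not actual file names.

    import numpy as np
    from src.predictors import Predictor

    # Hypothetical checkpoint path; any model saved by argus/EmaCheckpoint should work.
    predictor = Predictor("/workdir/data/experiments/some_experiment/fold_0/model.pth",
                          device="cuda:0", blend_weights="ones")
    video = np.random.rand(36, 64, 300)          # (height, width, time)
    behavior = np.random.rand(2, 300)
    pupil_center = np.random.rand(2, 300)
    responses = predictor.predict_trial(video, behavior, pupil_center, mouse_index=0)
    print(responses.shape)                       # (7863, 300)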
--------------------------------------------------------------------------------
/src/responses.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Type
3 |
4 | import numpy as np
5 |
6 | import torch
7 |
8 | from src import constants
9 |
10 |
11 | class ResponseNormalizer:
12 | def __init__(self, mouse: str):
13 | std = np.load(
14 | str(constants.sensorium_dir / mouse / "meta" / "statistics" / "responses" / "all" / "std.npy")
15 | )
16 | threshold = 0.01 * np.nanmean(std)
17 | idx = std > threshold
18 | self._response_precision = np.ones_like(std) / threshold
19 | self._response_precision[idx] = 1 / std[idx]
20 |
21 | def __call__(self, responses):
22 | return responses * self._response_precision[..., :responses.shape[-1]]
23 |
24 |
25 | def responses_to_tensor(responses: np.ndarray) -> torch.Tensor:
26 | responses = responses.astype(np.float32)
27 | responses_tensor = torch.from_numpy(responses)
28 | responses_tensor = torch.relu(responses_tensor)
29 | return responses_tensor
30 |
31 |
32 | class ResponsesProcessor(metaclass=abc.ABCMeta):
33 | @abc.abstractmethod
34 | def __call__(self, responses: np.ndarray) -> torch.Tensor:
35 | pass
36 |
37 |
38 | class IdentityResponsesProcessor(ResponsesProcessor):
39 | def __call__(self, responses: np.ndarray) -> torch.Tensor:
40 | return responses_to_tensor(responses)
41 |
42 |
43 | class IndexingResponsesProcessor(ResponsesProcessor):
44 | def __init__(self, index: int | list[int]):
45 | self.index = index
46 |
47 | def __call__(self, responses: np.ndarray) -> torch.Tensor:
48 | responses = responses[..., self.index]
49 | responses_tensor = responses_to_tensor(responses)
50 | return responses_tensor
51 |
52 |
53 | class SelectLastResponsesProcessor(IndexingResponsesProcessor):
54 | def __init__(self):
55 | super().__init__(index=-1)
56 |
57 |
58 | _RESPONSES_PROCESSOR_REGISTRY: dict[str, Type[ResponsesProcessor]] = dict(
59 | identity=IdentityResponsesProcessor,
60 | indexing=IndexingResponsesProcessor,
61 | last=SelectLastResponsesProcessor,
62 | )
63 |
64 |
65 | def get_responses_processor(name: str, processor_params: dict) -> ResponsesProcessor:
66 | assert name in _RESPONSES_PROCESSOR_REGISTRY
67 | return _RESPONSES_PROCESSOR_REGISTRY[name](**processor_params)
68 |
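A sketch of the processor registry with random responses (shapes illustrative): "identity" keeps the whole window, while "last" keeps only the final frame.

    import numpy as np
    from src.responses import get_responses_processor

    responses = np.random.rand(7863, 16)          # (neurons, time)
    identity = get_responses_processor("identity", {})
    last = get_responses_processor("last", {})
    print(identity(responses).shape)              # torch.Size([7863, 16])
    print(last(responses).shape)                  # torch.Size([7863])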
--------------------------------------------------------------------------------
/src/submission.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from src.responses import ResponseNormalizer
7 | from src.data import get_mouse_data
8 | from src.metrics import corr
9 | from src import constants
10 |
11 |
12 | def cut_responses_for_submission(prediction: np.ndarray):
13 | prediction = prediction[..., :constants.submission_limit_length]
14 | prediction = prediction[..., constants.submission_skip_first:]
15 | if constants.submission_skip_last:
16 | prediction = prediction[..., :-constants.submission_skip_last]
17 | return prediction
18 |
19 |
20 | def evaluate_folds_predictions(experiment: str, dataset: str):
21 | prediction_dir = constants.predictions_dir / experiment / "out-of-fold"
22 | correlations = dict()
23 | for mouse in constants.dataset2mice[dataset]:
24 | mouse_data = get_mouse_data(mouse=mouse, splits=constants.folds_splits)
25 | mouse_prediction_dir = prediction_dir / mouse
26 | predictions = []
27 | targets = []
28 | for trial_data in mouse_data["trials"]:
29 | trial_id = trial_data['trial_id']
30 | prediction = np.load(str(mouse_prediction_dir / f"{trial_id}.npy"))
31 | target = np.load(trial_data["response_path"])[..., :trial_data["length"]]
32 | prediction = cut_responses_for_submission(prediction)
33 | target = cut_responses_for_submission(target)
34 | predictions.append(prediction)
35 | targets.append(target)
36 | correlation = float(corr(
37 | np.concatenate(predictions, axis=1),
38 | np.concatenate(targets, axis=1),
39 | axis=1
40 | ).mean())
41 | print(f"Mouse {mouse} correlation: {correlation}")
42 | correlations[mouse] = correlation
43 | mean_correlation = float(np.mean(list(correlations.values())))
44 | print("Mean correlation:", mean_correlation)
45 |
46 | evaluate_result = {"correlations": correlations, "mean_correlation": mean_correlation}
47 | with open(prediction_dir / f"evaluate_{dataset}.json", "w") as outfile:
48 | json.dump(evaluate_result, outfile, indent=4)
49 |
50 |
51 | def make_submission(experiment: str, split: str):
52 | prediction_dir = constants.predictions_dir / experiment / split
53 | data = []
54 | for mouse in constants.new_mice:
55 | normalizer = ResponseNormalizer(mouse)
56 | mouse_data = get_mouse_data(mouse=mouse, splits=[split])
57 | neuron_ids = mouse_data["neuron_ids"].tolist()
58 | mouse_prediction_dir = prediction_dir / mouse
59 | for trial_data in mouse_data["trials"]:
60 | trial_id = trial_data['trial_id']
61 | prediction = np.load(str(mouse_prediction_dir / f"{trial_id}.npy"))
62 | prediction = normalizer(prediction)
63 | prediction = cut_responses_for_submission(prediction)
64 | data.append((mouse, trial_id, prediction.tolist(), neuron_ids))
65 | submission_df = pd.DataFrame.from_records(
66 | data,
67 | columns=['mouse', 'trial_indices', 'prediction', 'neuron_ids']
68 | )
69 | del data
70 | split = split.replace('_test_', '_').replace('bonus', 'test_bonus_ood')
71 | submission_path = prediction_dir / f"predictions_{split}.parquet.brotli"
72 | submission_df.to_parquet(submission_path, compression='brotli', engine='pyarrow', index=False)
73 | print(f"Submission saved to '{submission_path}'")
74 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import math
3 | import time
4 | import random
5 | from pathlib import Path
6 |
7 | import numpy as np
8 |
9 | from torch import nn
10 |
11 |
12 | def set_random_seed(index: int):
13 | seed = int(time.time() * 1000.0) + index
14 | random.seed(seed)
15 | np.random.seed(seed % (2 ** 32 - 1))
16 |
17 |
18 | def get_lr(base_lr: float, batch_size: int, base_batch_size: int = 4) -> float:
19 | return base_lr * (batch_size / base_batch_size)
20 |
21 |
22 | def get_best_model_path(dir_path, return_score=False, more_better=True):
23 | dir_path = Path(dir_path)
24 | model_scores = []
25 | for model_path in dir_path.glob('*.pth'):
26 | score = re.search(r'-(\d+(?:\.\d+)?).pth', str(model_path))
27 | if score is not None:
28 | score = float(score.group(0)[1:-4])
29 | model_scores.append((model_path, score))
30 |
31 | if not model_scores:
32 | if return_score:
33 | return None, -np.inf if more_better else np.inf
34 | else:
35 | return None
36 |
37 | model_score = sorted(model_scores, key=lambda x: x[1], reverse=more_better)
38 | best_model_path = model_score[0][0]
39 | if return_score:
40 | best_score = model_score[0][1]
41 | return best_model_path, best_score
42 | else:
43 | return best_model_path
44 |
45 |
46 | def init_weights(module: nn.Module):
47 | for m in module.modules():
48 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
49 | fan_out = math.prod(m.kernel_size) * m.out_channels
50 | fan_out //= m.groups
51 | nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
52 | if m.bias is not None:
53 | nn.init.zeros_(m.bias)
54 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
55 | nn.init.ones_(m.weight)
56 | nn.init.zeros_(m.bias)
57 | elif isinstance(m, nn.Linear):
58 | fan_out = m.weight.size(0)
59 | fan_in = 0
60 | init_range = 1.0 / math.sqrt(fan_in + fan_out)
61 | nn.init.uniform_(m.weight, -init_range, init_range)
62 | if m.bias is not None:
63 | nn.init.zeros_(m.bias)
64 |
65 |
66 | def get_length_without_nan(array: np.ndarray):
67 | nan_indexes = np.argwhere(np.isnan(array)).ravel()
68 | if nan_indexes.shape[0]:
69 | return nan_indexes[0]
70 | else:
71 | return array.shape[0]
72 |
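Two quick sketches: get_lr scales the base learning rate linearly with batch size, and get_best_model_path parses scores from checkpoint file names (the directory and file-name pattern below are illustrative).

    from src.utils import get_best_model_path, get_lr

    print(get_lr(3e-4, batch_size=32))    # 0.0024, i.e. base_lr * 32 / 4

    # Looks for files like "model-000036-0.290123.pth" and picks the best score.
    best_path, best_score = get_best_model_path(
        "/workdir/data/experiments/some_experiment/fold_0", return_score=True
    )
    print(best_path, best_score)          # (None, -inf) if the directory is empty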
--------------------------------------------------------------------------------