├── .github └── workflows │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── github_image.png ├── main_example.gif ├── mainquant_example.gif ├── train_example.gif └── vprtempo_example.gif ├── main.py ├── orc_list.txt ├── pixi.lock ├── pixi.toml ├── requirements.txt ├── setup.py ├── tutorials ├── 1_BasicDemo.ipynb ├── 2_Training.ipynb ├── 3_Modules.ipynb └── mats │ └── README.txt └── vprtempo ├── VPRTempo.py ├── VPRTempoQuant.py ├── VPRTempoQuantTrain.py ├── VPRTempoTrain.py ├── __init__.py ├── dataset ├── nordland-fall.csv ├── nordland-spring.csv ├── nordland-summer.csv ├── orc-dusk.csv ├── orc-rain.csv └── orc-sun.csv ├── models ├── .gitkeep └── README.txt └── src ├── __init__.py ├── blitnet.py ├── create_data_csv.py ├── dataset.py ├── download.py ├── loggers.py ├── metrics.py ├── nordland.py └── process_orc.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Check out code 12 | uses: actions/checkout@v2 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: '3.x' 18 | 19 | - name: Install dependencies 20 | run: | 21 | pip install setuptools wheel twine 22 | 23 | - name: Build and publish 24 | env: 25 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 26 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 27 | run: | 28 | python setup.py sdist bdist_wheel 29 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pyest_cache/ 3 | vprtempo/__pycache__/ 4 | vprtempo/dataset/fall/ 5 | vprtempo/dataset/spring/ 6 | vprtempo/dataset/summer/ 7 | vprtempo/dataset/winter/ 8 | vprtempo/dataset/event.csv 9 | vprtempo/output/ 10 | vprtempo/src/__pycache__/ 11 | *.pth 12 | tutorials/mats/1_BasicDemo/ 13 | .vscode/ 14 | *.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Adam Hines 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VPRTempo - A Temporally Encoded Spiking Neural Network for Visual Place Recognition 2 | ![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white) 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=flat-square)](./LICENSE) 4 | [![QUT Centre for Robotics](https://img.shields.io/badge/collection-QUT%20Robotics-%23043d71?style=flat-square)](https://qcr.ai) 5 | [![stars](https://img.shields.io/github/stars/QVPR/VPRTempo.svg?style=flat-square)](https://github.com/QVPR/VPRTempo/stargazers) 6 | [![Downloads](https://static.pepy.tech/badge/vprtempo)](https://pepy.tech/project/vprtempo) 7 | [![Pixi Badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/prefix-dev/pixi/main/assets/badge/v0.json)](https://pixi.sh) 8 | [![Conda Version](https://img.shields.io/conda/vn/conda-forge/vprtempo.svg)](https://anaconda.org/conda-forge/vprtempo) 9 | [![GitHub repo size](https://img.shields.io/github/repo-size/QVPR/VPRTempo.svg?style=flat-square)](./README.md) 10 | 11 | 12 | This repository contains code for [VPRTempo](https://vprtempo.github.io), a spiking neural network that uses temporal encoding to perform visual place recognition tasks. The network is based on [BLiTNet](https://arxiv.org/pdf/2208.01204.pdf) and adapted to the [VPRSNN](https://github.com/QVPR/VPRSNN) framework. 13 | 14 |

15 | *(figure: VPRTempo method diagram)* 16 |

17 | 18 | VPRTempo is built on a [torch.nn](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html) framework and employs custom learning rules based on the temporal codes of spikes in order to train layer weights. 19 | 20 | In this repository, we provide two networks: 21 | - `VPRTempo`: Our base network architecture to perform visual place recognition (fp32) 22 | - `VPRTempoQuant`: A modified base network with [Quantization Aware Training (QAT)](https://pytorch.org/docs/stable/quantization.html) enabled (int8) 23 | 24 | To use VPRTempo, please follow the instructions below for installation and usage. 25 | 26 | ## :star: Update v1.1.10: What's new? 27 | - Adding [pixi](https://prefix.dev/) for easier setup and reproducibility :rocket: 28 | - Fixed repo size issue for a more compact download :chart_with_downwards_trend: 29 | 30 | ## Quick start 31 | For simplicity and reproducibility, VPRTempo uses [pixi](https://prefix.dev/) to install and manage dependencies. If you do not already have pixi installed, run the following in your command terminal: 32 | 33 | ```console 34 | curl -fsSL https://pixi.sh/install.sh | bash 35 | ``` 36 | For more information, please refer to the [pixi documentation](https://prefix.dev/docs/prefix/overview). 37 | 38 | ### Get the repository 39 | Get the latest VPRTempo code and navigate to the project directory by running the following in your command terminal: 40 | ```console 41 | git clone https://github.com/QVPR/VPRTempo.git 42 | cd VPRTempo 43 | ``` 44 | 45 | ### Run the demo 46 | To quickly evaluate VPRTempo, we provide a pre-trained network trained on 500 places from the [Nordland](https://nrkbeta.no/2013/01/15/nordlandsbanen-minute-by-minute-season-by-season/) dataset. Run the following in your command terminal to run the demo: 47 | ```console 48 | pixi run demo 49 | ``` 50 | _Note: this will start a download of the models and datasets (~600MB), please ensure you have enough disk space before proceeding._ 51 | 52 | ### Train and evaluate a new model 53 | Training and evaluating a new model is quick and easy, simply run the following in your command terminal to re-train and evaluate the demo model: 54 | 55 | ```console 56 | pixi run train 57 | pixi run eval 58 | ``` 59 | _Note: You will be prompted if you want to retrain the pre-existing network._ 60 | 61 | ### Use the quantized models 62 | For training and evaluation of the 8-bit quantized model, run the following in your command terminal: 63 | 64 | ```console 65 | pixi run train_quant 66 | pixi run eval_quant 67 | ``` 68 | 69 | ## Datasets 70 | VPRTempo was developed to be simple to train and test a variety of datasets. Please see the information below about recreating our results for the Nordland and Oxford RobotCar datasets and setting up custom datasets. 71 | 72 | ### Nordland 73 | VPRTempo was developed and tested using the [Nordland](https://nrkbeta.no/2013/01/15/nordlandsbanen-minute-by-minute-season-by-season/) dataset. To download the full dataset, please visit [this repository](https://huggingface.co/datasets/Somayeh-h/Nordland?row=0). 
Once downloaded, place dataset folders into the VPRTempo directory as follows: 74 | 75 | ``` 76 | |__./vprtempo 77 | |___dataset 78 | |__summer 79 | |__spring 80 | |__fall 81 | |__winter 82 | ``` 83 | 84 | To replicate the results in our paper, run the following in your command terminal: 85 | 86 | ```console 87 | pixi run nordland_train 88 | pixi run nordland_eval 89 | ``` 90 | 91 | Alternatively, specify the data directory using the following argument: 92 | 93 | ```console 94 | pixi run nordland_train --data_dir 95 | pixi run nordland_eval --data_dir 96 | ``` 97 | 98 | ### Oxford RobotCar 99 | In order to train and test on Oxford RobotCar, you will need to [register an account](https://mrgdatashare.robots.ox.ac.uk/register/) to get access to download the dataset and process the images before proceeding. For more information, please refer to the [documentation](). 100 | 101 | Once fully processed, to replicate the results in our paper run the following in your command terminal: 102 | 103 | ```console 104 | pixi run orc_train 105 | pixi run orc_eval 106 | ``` 107 | 108 | ### Custom Datasets 109 | To define your own custom dataset to use with VPRTempo, simply follow the same dataset structure defined above for Nordland. A `.csv` file of the image names will be required for the dataloader. 110 | 111 | We have included a convenient script `./vprtempo/src/create_data_csv.py` which will generate the necessary file. Simply modify the `dataset_name` variable to the folder containing your images. 112 | 113 | To train a new model with a custom dataset, you can do the following. 114 | 115 | ```console 116 | pixi run train --dataset --database_dirs 117 | pixi run eval --database_dirs --dataset --query_dir 118 | ``` 119 | 120 | ## License & Citation 121 | This repository is licensed under the [MIT License](./LICENSE). If you use our code, please cite our IEEE ICRA [paper](https://ieeexplore.ieee.org/document/10610918): 122 | ``` 123 | @inproceedings{hines2024vprtempo, 124 | title={VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition}, 125 | author={Adam D. Hines and Peter G. Stratton and Michael Milford and Tobias Fischer}, 126 | year={2024}, 127 | pages={10200-10207}, 128 | booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)} 129 | } 130 | ``` 131 | 132 | ## Documentation 133 | 134 | For more detailed information on usage, please visit the [documentation](). 135 | 136 | ## Tutorials 137 | We provide a series of Jupyter Notebook [tutorials](https://github.com/QVPR/VPRTempo/tree/main/tutorials) that go through the basic operations and logic for VPRTempo and VPRTempoQuant. 138 | 139 | ## Issues, bugs, and feature requests 140 | If you encounter problems whilst running the code or if you have a suggestion for a feature or improvement, please report it as an [issue](https://github.com/QVPR/VPRTempo/issues). 
141 | -------------------------------------------------------------------------------- /assets/github_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/github_image.png -------------------------------------------------------------------------------- /assets/main_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/main_example.gif -------------------------------------------------------------------------------- /assets/mainquant_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/mainquant_example.gif -------------------------------------------------------------------------------- /assets/train_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/train_example.gif -------------------------------------------------------------------------------- /assets/vprtempo_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/vprtempo_example.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #MIT License 2 | 3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer 4 | 5 | #Permission is hereby granted, free of charge, to any person obtaining a copy 6 | #of this software and associated documentation files (the "Software"), to deal 7 | #in the Software without restriction, including without limitation the rights 8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | #copies of the Software, and to permit persons to whom the Software is 10 | #furnished to do so, subject to the following conditions: 11 | 12 | #The above copyright notice and this permission notice shall be included in all 13 | #copies or substantial portions of the Software. 14 | 15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | #SOFTWARE. 
22 | 23 | ''' 24 | Imports 25 | ''' 26 | import os 27 | import sys 28 | import torch 29 | import argparse 30 | 31 | import torch.quantization as quantization 32 | 33 | from tqdm import tqdm 34 | from vprtempo.VPRTempo import VPRTempo, run_inference 35 | from vprtempo.VPRTempoTrain import VPRTempoTrain, train_new_model 36 | from vprtempo.src.loggers import model_logger, model_logger_quant 37 | from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant 38 | from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, train_new_model_quant 39 | 40 | def generate_model_name(model,quant=False): 41 | """ 42 | Generate the model name based on its parameters. 43 | """ 44 | if quant: 45 | model_name = (''.join(model.database_dirs)+"_"+ 46 | "VPRTempoQuant_" + 47 | "IN"+str(model.input)+"_" + 48 | "FN"+str(model.feature)+"_" + 49 | "DB"+str(model.database_places) + 50 | ".pth") 51 | else: 52 | model_name = (''.join(model.database_dirs)+"_"+ 53 | "VPRTempo_" + 54 | "IN"+str(model.input)+"_" + 55 | "FN"+str(model.feature)+"_" + 56 | "DB"+str(model.database_places) + 57 | ".pth") 58 | return model_name 59 | 60 | def check_pretrained_model(model_name): 61 | """ 62 | Check if a pre-trained model exists and prompt the user to retrain if desired. 63 | """ 64 | if os.path.exists(os.path.join('./vprtempo/models', model_name)): 65 | prompt = "A network with these parameters exists, re-train network? (y/n):\n" 66 | retrain = input(prompt).strip().lower() 67 | if retrain == 'y': 68 | return True 69 | elif retrain == 'n': 70 | print('Training new model cancelled') 71 | sys.exit() 72 | 73 | def initialize_and_run_model(args,dims): 74 | """ 75 | Run the VPRTempo/VPRTempoQuant training or inference models. 76 | 77 | :param args: Arguments set for the network 78 | :param dims: Dimensions of the network 79 | """ 80 | # Determine number of modules to generate based on user input 81 | places = args.database_places # Copy out number of database places 82 | 83 | # Caclulate number of modules 84 | num_modules = 1 85 | while places > args.max_module: 86 | places -= args.max_module 87 | num_modules += 1 88 | 89 | # If the final module has less than max_module, reduce the dim of the output layer 90 | remainder = args.database_places % args.max_module 91 | if remainder != 0: # There are remainders, adjust output neuron count in final module 92 | out_dim = int((args.database_places - remainder) / (num_modules - 1)) 93 | final_out_dim = remainder 94 | else: # No remainders, all modules are even 95 | out_dim = int(args.database_places / num_modules) 96 | final_out_dim = out_dim 97 | 98 | # If user wants to train a new network 99 | if args.train_new_model: 100 | # If using quantization aware training 101 | if args.quantize: 102 | models = [] 103 | logger = model_logger_quant() # Initialize the logger 104 | qconfig = quantization.get_default_qat_qconfig('fbgemm') 105 | # Create the modules 106 | final_out = None 107 | for mod in tqdm(range(num_modules), desc="Initializing modules"): 108 | model = VPRTempoQuantTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model 109 | model.train() 110 | model.qconfig = qconfig 111 | quantization.prepare_qat(model, inplace=True) 112 | models.append(model) # Create module list 113 | if mod == num_modules - 2: 114 | final_out = final_out_dim 115 | # Generate the model name 116 | model_name = generate_model_name(model,args.quantize) 117 | # Check if the model has been trained before 118 | check_pretrained_model(model_name) 119 | # Get the 
quantization config 120 | qconfig = quantization.get_default_qat_qconfig('fbgemm') 121 | # Train the model 122 | train_new_model_quant(models, model_name) 123 | 124 | # Base model 125 | else: 126 | models = [] 127 | logger = model_logger() # Initialize the logger 128 | 129 | # Create the modules 130 | final_out = None 131 | for mod in tqdm(range(num_modules), desc="Initializing modules"): 132 | model = VPRTempoTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model 133 | model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models) 134 | models.append(model) # Create module list 135 | if mod == num_modules - 2: 136 | final_out = final_out_dim 137 | 138 | # Generate the model name 139 | model_name = generate_model_name(model) 140 | print(f"Model name: {model_name}") 141 | # Check if the model has been trained before 142 | check_pretrained_model(model_name) 143 | # Train the model 144 | train_new_model(models, model_name) 145 | 146 | # Run the inference network 147 | else: 148 | # Set the quantization configuration 149 | if args.quantize: 150 | models = [] 151 | logger, output_folder = model_logger_quant() 152 | qconfig = quantization.get_default_qat_qconfig('fbgemm') 153 | final_out = None 154 | for _ in tqdm(range(num_modules), desc="Initializing modules"): 155 | # Initialize the model 156 | model = VPRTempoQuant( 157 | args, 158 | dims, 159 | logger, 160 | num_modules, 161 | output_folder, 162 | out_dim, 163 | out_dim_remainder=final_out 164 | ) 165 | model.eval() 166 | model.qconfig = qconfig 167 | quantization.prepare(model, inplace=True) 168 | quantization.convert(model, inplace=True) 169 | models.append(model) 170 | # Generate the model name 171 | model_name = generate_model_name(model, args.quantize) 172 | # Run the quantized inference model 173 | run_inference_quant(models, model_name) 174 | else: 175 | models = [] 176 | logger, output_folder = model_logger() # Initialize the logger 177 | places = args.database_places # Copy out number of database places 178 | 179 | # Create the modules 180 | final_out = None 181 | for mod in tqdm(range(num_modules), desc="Initializing modules"): 182 | model = VPRTempo( 183 | args, 184 | dims, 185 | logger, 186 | num_modules, 187 | output_folder, 188 | out_dim, 189 | out_dim_remainder=final_out 190 | ) 191 | model.eval() 192 | model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models) 193 | models.append(model) # Create module list 194 | if mod == num_modules - 2: 195 | final_out = final_out_dim 196 | # Generate the model name 197 | model_name = generate_model_name(model) 198 | print(f"Model name: {model_name}") 199 | # Run the inference model 200 | run_inference(models, model_name) 201 | 202 | def parse_network(): 203 | ''' 204 | Define the base parameter parser (configurable by the user) 205 | ''' 206 | parser = argparse.ArgumentParser(description="Args for base configuration file") 207 | 208 | # Define the dataset arguments 209 | parser.add_argument('--dataset', type=str, default='nordland', 210 | help="Dataset to use for training and/or inferencing") 211 | parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/', 212 | help="Directory where dataset files are stored") 213 | parser.add_argument('--database_places', type=int, default=500, 214 | help="Number of places to use for training") 215 | parser.add_argument('--query_places', type=int, default=500, 216 | help="Number of places to use for inferencing") 217 | 
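    # Note: --database_places and --max_module together determine how many expert
    # modules initialize_and_run_model() builds; max_module is repeatedly subtracted
    # from database_places to obtain the module count. For example, 3300 places with
    # --max_module 1100 gives 3 modules of 1100 output neurons each, while 1200 places
    # with the default of 500 gives 3 modules of 500, 500 and 200 output neurons.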
parser.add_argument('--max_module', type=int, default=500, 218 | help="Maximum number of images per module") 219 | parser.add_argument('--database_dirs', type=str, default='spring,fall', 220 | help="Directories to use for training") 221 | parser.add_argument('--query_dir', type=str, default='summer', 222 | help="Directories to use for testing") 223 | parser.add_argument('--GT_tolerance', type=int, default=0, 224 | help="Ground truth tolerance for matching") 225 | parser.add_argument('--skip', type=int, default=0, 226 | help="Images to skip for training and/or inferencing") 227 | 228 | # Define training parameters 229 | parser.add_argument('--filter', type=int, default=8, 230 | help="Images to skip for training and/or inferencing") 231 | parser.add_argument('--epoch', type=int, default=4, 232 | help="Number of epochs to train the model") 233 | 234 | # Define image transformation parameters 235 | parser.add_argument('--patches', type=int, default=15, 236 | help="Number of patches to generate for patch normalization image into") 237 | parser.add_argument('--dims', type=str, default="56,56", 238 | help="Dimensions to resize the image to") 239 | 240 | # Define the network functionality 241 | parser.add_argument('--train_new_model', action='store_true', 242 | help="Flag to run the training or inferencing model") 243 | parser.add_argument('--quantize', action='store_true', 244 | help="Enable/disable quantization for the model") 245 | 246 | # Define metrics functionality 247 | parser.add_argument('--PR_curve', action='store_true', 248 | help="Flag to generate a Precision-Recall curve") 249 | parser.add_argument('--sim_mat', action='store_true', 250 | help="Flag to plot the similarity matrix, GT, and GTsoft") 251 | 252 | # Output base configuration 253 | args = parser.parse_args() 254 | dims = [int(x) for x in args.dims.split(",")] 255 | 256 | # Run the network with the desired settings 257 | initialize_and_run_model(args,dims) 258 | 259 | if __name__ == "__main__": 260 | parse_network() -------------------------------------------------------------------------------- /orc_list.txt: -------------------------------------------------------------------------------- 1 | 2015-08-12-15-04-18 2 | 2014-11-21-16-07-03 3 | 2015-10-29-12-18-17 -------------------------------------------------------------------------------- /pixi.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "vprtempo" 3 | version = "1.1.10" 4 | description = "Temporally encoded spiking neural network for fast visual place recognition" 5 | authors = ["Adam D Hines ", "Peter G Stratton ", "Michael Milford "] 6 | channels = ["conda-forge", "pytorch"] 7 | platforms = ["linux-64", "osx-arm64", "win-64"] 8 | preview = ["pixi-build"] 9 | 10 | [feature.cuda] 11 | platforms = ["linux-64"] 12 | system-requirements = {cuda = "12"} 13 | 14 | [dependencies] 15 | python = ">=3.6,<3.13" 16 | pytorch = ">=2.4.0" 17 | torchvision = ">=0.19.0" 18 | numpy = ">=1.26.2,<2" 19 | pandas = ">=2.2.2" 20 | tqdm = ">=4.66.5" 21 | prettytable = ">=3.11.0" 22 | matplotlib-base = ">=3.9.2" 23 | requests = ">=2.32.3" 24 | 25 | [feature.cuda.dependencies] 26 | cuda-version = "12.*" 27 | pytorch-gpu = "*" 28 | cuda-cudart-dev = "*" 29 | cuda-crt = "*" 30 | cudnn = "*" 31 | libcusparse-dev = "*" 32 | cuda-driver-dev = "*" 33 | cuda-nvcc = "*" 34 | cuda-nvrtc-dev = "*" 35 | cuda-nvtx-dev = "*" 36 | cuda-nvml-dev = "*" 37 | cuda-profiler-api = "*" 38 | cusparselt = "*" 39 | libcublas-dev = "*" 40 | libcudss-dev = 
"*" 41 | libcufile-dev = "*" 42 | libcufft-dev = "*" 43 | libcurand-dev = "*" 44 | libcusolver-dev = "*" 45 | 46 | [environments] 47 | cuda = ["cuda"] 48 | 49 | [tasks] 50 | # run demo - downloads pretrained models and Nordland dataset 51 | demo = {cmd = "pixi run --frozen python main.py --PR_curve --sim_mat"} 52 | 53 | # run the evaluation networks 54 | eval = {cmd = "pixi run python main.py"} 55 | eval_quant = {cmd = "pixi run python main.py --quantize"} 56 | 57 | # train network 58 | train = {cmd = "pixi run python main.py --train_new_model"} 59 | train_quant = {cmd = "pixi run python main.py --train_new_model --quantize"} 60 | 61 | # replicate conference proceedings results 62 | nordland_train = {cmd = "pixi run python main.py --train_new_model --database_places 3300 --database_dirs spring,fall --skip 0 --max_module 1100 --dataset nordland --dims 28,28 --patches 7 --filter 8"} 63 | nordland_eval = {cmd = "pixi run python main.py --database_places 3300 --database_dirs spring,fall --skip 4800 --dataset nordland --dims 28,28 --patches 7 --filter 8 --query_dir summer --query_places 2700 --sim_mat --max_module 1100"} 64 | oxford_train = {cmd = "pixi run python main.py --train_new_model --dataset orc --database_places 450 --database_dirs sun,rain --skip 0 --max_module 450 --dataset orc --dims 28,28 --patches 7 --filter 7"} 65 | oxford_eval = {cmd = "python main.py --dataset orc --database_places 450 --database_dirs sun,rain --skip 630 --dataset orc --dims 28,28 --patches 7 --filter 7 --query_dir dusk --query_places 360 --sim_mat --max_module 450"} 66 | 67 | # helper functions for datasets 68 | scrape_oxford = {cmd = "python scrape_mrgdatashare.py --choice_sensors stereo_left --choice_runs_file orc_list.txt --downloads_dir ~/VPRTempo/vprtempo/dataset --datasets_file datasets.csv"} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | pandas 5 | tqdm 6 | prettytable 7 | matplotlib 8 | requests -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | from setuptools import setup, find_packages 4 | 5 | here = os.path.abspath(os.path.dirname(__file__)) 6 | # Get the long description from the README file 7 | with open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 8 | long_description = f.read() 9 | 10 | # define the base requires needed for the repo 11 | requirements = [ 12 | 'torch', 13 | 'torchvision', 14 | 'numpy', 15 | 'pandas', 16 | 'tqdm', 17 | 'prettytable', 18 | 'matplotlib', 19 | 'requests' 20 | ] 21 | 22 | # define the setup 23 | setup( 24 | name="VPRTempo", 25 | version="1.1.9", 26 | description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition', 27 | long_description=long_description, 28 | long_description_content_type='text/markdown', 29 | author='Adam D Hines, Peter G Stratton, Michael Milford and Tobias Fischer', 30 | author_email='adam.hines@qut.edu.au', 31 | url='https://github.com/QVPR/VPRTempo', 32 | license='MIT', 33 | install_requires=requirements, 34 | python_requires='>=3.6, <3.13', 35 | classifiers=[ 36 | # 3 - Alpha 37 | # 4 - Beta 38 | # 5 - Production/Stable 39 | 'Development Status :: 4 - Beta', 40 | 41 | # Indicate who your project is intended for 42 | 'Intended Audience :: Developers', 43 | # Pick your license as 
you wish (should match "license" above) 44 | 'License :: OSI Approved :: MIT License', 45 | 46 | # Specify the Python versions you support here. In particular, ensure 47 | # that you indicate whether you support Python 2, Python 3 or both. 48 | 'Programming Language :: Python :: 3.6', 49 | 'Programming Language :: Python :: 3.7', 50 | 'Programming Language :: Python :: 3.8', 51 | 'Programming Language :: Python :: 3.9', 52 | 'Programming Language :: Python :: 3.10', 53 | 'Programming Language :: Python :: 3.11', 54 | 'Programming Language :: Python :: 3.12', 55 | ], 56 | packages=find_packages(), 57 | keywords=['python', 'place recognition', 'spiking neural networks', 58 | 'computer vision', 'robotics'], 59 | scripts=['main.py'], 60 | ) 61 | -------------------------------------------------------------------------------- /tutorials/1_BasicDemo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b34c7b8a-e7bb-47f4-b558-be1bde9a7b37", 6 | "metadata": {}, 7 | "source": [ 8 | "## VPRTempo & VPRTempoQuant - Basic Demo\n", 9 | "\n", 10 | "### By Adam D Hines (https://research.qut.edu.au/qcr/people/adam-hines/)\n", 11 | "\n", 12 | "VPRTempo is based on the following paper, if you use or find this code helpful for your research please consider citing the source:\n", 13 | " \n", 14 | "[Adam D Hines, Peter G Stratton, Michael Milford, & Tobias Fischer. \"VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition. arXiv September 2023](https://arxiv.org/abs/2309.10225)\n", 15 | "\n", 16 | "### Introduction\n", 17 | "\n", 18 | "This is a basic, extremely simplified version of VPRTempo that highlights how images are transformed, spikes and weights are used, and the readout for performance using a model trained using our base system and the Quantized Aware Training (QAT) version. This is a basic, extremely simplified version of VPRTempo that highlights how images are transformed, spikes and weights are used, and the readout for performance. Although the proper implementation is in [PyTorch](https://pytorch.org/), we present a simple NumPy example to get started. As in the paper, we will present a simple example using the [Nordland](https://webdiis.unizar.es/~jmfacil/pr-nordland/#download-dataset) dataset with pre-trained set of weights.\n", 19 | "\n", 20 | "Before starting, make sure the following packages are installed and imported:" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "c879cd02-82db-441d-9476-fff1925bf494", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Imprt opencv-python, NumPy, and matplotlib.pyplot\n", 31 | "try:\n", 32 | " import cv2\n", 33 | " import numpy as np\n", 34 | " import matplotlib.pyplot as plt\n", 35 | "except:\n", 36 | " !pip3 install numpy, opencv-python, matplotlib # pip install if modules not present\n", 37 | " import cv2\n", 38 | " import numpy as np\n", 39 | " import matplotlib.pyplot as plt" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "66b11853-6e17-4884-ac92-d35d814add42", 45 | "metadata": {}, 46 | "source": [ 47 | "Next, we will need to get the pretrained weights for the model. To get them and the other materials, please download them from [here](https://www.dropbox.com/scl/fi/bxbzm47kxl24x979q5r5s/1_BasicDemo.zip?rlkey=0umij016whwgm11frzlk63v5k&st=hncbx0ld&dl=0). Unzip the files into the `./tutorials/mats/` folder." 
48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "bb45df38-e333-46b2-9161-80e6ac367532", 53 | "metadata": {}, 54 | "source": [ 55 | "### Image processing\n", 56 | "\n", 57 | "Let's have a look at how we process our images to run through VPRTempo. We utilize a technique called *patch normalization* to resize input images and normalize the pixel intensities. To start, let's see what the original image looks like before patch normalization." 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "67f129b5-9a7a-4b50-9d94-b9bf512f8b70", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# Load the input image\n", 68 | "raw_img = cv2.imread('./mats/1_BasicDemo/summer.png')\n", 69 | "rgb_img = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB) # Convert to RGB\n", 70 | "\n", 71 | "# Plot the image\n", 72 | "plt.imshow(rgb_img)\n", 73 | "plt.title('Nordland Summer')\n", 74 | "plt.show()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "id": "b68cf25e-35ae-4885-9cf1-c1b09ce4ad42", 80 | "metadata": {}, 81 | "source": [ 82 | "What we have here is a 360x640 RGB image, which for processing through neural networks is too big (230,400 total pixels). So instead, we'll use patch normalization to reduce the image size down to a grayscale 56x56 image to just 3136 pixels in total." 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "id": "6f67656a-3ba4-4374-b780-4e8bac4ec2d2", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# Load the patch normalized image\n", 93 | "patch_img = np.load('./mats/1_BasicDemo/summer_patchnorm.npy', allow_pickle=True)\n", 94 | "\n", 95 | "# Plot the image\n", 96 | "plt.matshow(patch_img)\n", 97 | "plt.title('Nordland Summer Patch Normalized')\n", 98 | "plt.colorbar(shrink=0.75, label=\"Pixel intensity\")\n", 99 | "plt.show()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "d404dfd2-10bd-4981-96ef-6092a9866fc6", 105 | "metadata": {}, 106 | "source": [ 107 | "The reduced image dimensions with patch normalization allows for a decent representation of the full scene, despite the smaller size.\n", 108 | "\n", 109 | "### Convert images to spikes\n", 110 | "\n", 111 | "'Spikes' in the context of VPRTempo are a little different than conventional spiking neural networks. Typically, spikes from image datasets are converted into Poisson spike trains where the pixel intensity determines the number of spikes to propagate throughout a network. VPRTempo only considers each pixel as a single spike, but considers the *amplitude* of the spike to determine the timing within a single timestep - where large amplitudes (high pixel intensity) spike early in a timestep, and vice versa for small amplitudes. \n", 112 | "\n", 113 | "Let's flatten the patch normalized image into a 1D-array so we can apply our network weights." 
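    "\n",
    "For intuition, this temporal code can be pictured as mapping each normalized pixel amplitude to a firing time within a single timestep. A minimal, purely illustrative sketch of that idea (the demo below never computes explicit spike times; it works directly with the amplitudes):\n",
    "\n",
    "```python\n",
    "# Hypothetical mapping: amplitudes in [0, 1] become firing times, so bright pixels\n",
    "# (large amplitude) fire early in the timestep and dark pixels fire late.\n",
    "spike_times = 1.0 - patch_img  # assumes patch_img values lie in [0, 1]\n",
    "```\n"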
114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "2bd6ae95-2a79-4b45-8a60-503079339739", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# Convert 2D image to a 1D-array\n", 124 | "patch_1d = np.reshape(patch_img, (3136,))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "9d9a5eaf-1de3-461f-b138-3ac820da8bae", 130 | "metadata": {}, 131 | "source": [ 132 | "### Load the pre-trained network weights\n", 133 | "\n", 134 | "Our network consists of the following architecture:\n", 135 | "\n", 136 | " - An input layer sparsely connected to a feature layer, 3136 input neurons to 6272 feature neurons\n", 137 | " - The feature layer fully connected to a one-hot-encoded output layer, 6272 feature neurons to 500 output neurons\n", 138 | "\n", 139 | "Each layer connection is trained separately and stored in different weight matrices for excitatory (positive) and inhibitory (negative) connections. " 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "id": "6d98749a-8f28-477b-871c-93626e96786c", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "# Load the input to feature excitatory and inhibitory network weights\n", 150 | "featureW = np.load('./mats/1_BasicDemo/featureW.npy')\n", 151 | "\n", 152 | "# Plot the weights\n", 153 | "plt.matshow(featureW.T)\n", 154 | "plt.title('Input > Feature Weights')\n", 155 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n", 156 | "\n", 157 | "# Display the plots\n", 158 | "plt.show()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "id": "826213d7-7721-440c-b1a8-47fb613339eb", 164 | "metadata": {}, 165 | "source": [ 166 | "Whilst it might be a little difficult to see, our excitatory connection amplitudes are on average a little higher than our inhibitiory. However, we overall have more inhibitiory connections that positive to balance the system.\n", 167 | "\n", 168 | "This is because when we set up our connections we use a probability of connections for both excitation and inbhition. In this case, we have a 10% connection probability for excitatory weights and a 50% probability for inhibitiory. This means as well there will be a high number of neurons without connections." 
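    "\n",
    "As a rough sketch of what connectivity drawn with these probabilities could look like (illustrative only; the repository's actual weight initialization may differ in detail):\n",
    "\n",
    "```python\n",
    "# Illustrative sparse weight matrix: ~10% excitatory, ~50% inhibitory, rest zero\n",
    "rng = np.random.default_rng(0)\n",
    "n_in, n_feat = 3136, 6272\n",
    "exc = rng.random((n_feat, n_in)) < 0.10\n",
    "inh = (rng.random((n_feat, n_in)) < 0.50) & ~exc\n",
    "W_example = np.zeros((n_feat, n_in), dtype=np.float32)\n",
    "W_example[exc] = rng.uniform(0.0, 0.25, exc.sum())   # positive (excitatory) weights\n",
    "W_example[inh] = -rng.uniform(0.0, 0.25, inh.sum())  # negative (inhibitory) weights\n",
    "```\n"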
169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "id": "94acb2f4", 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "# In this function, we will plot and visualize the distribution of weights and connections.\n", 179 | "def count_and_plot(array):\n", 180 | " # Flatten the 2D array and count positive, negative, and zero values\n", 181 | " flattened_array = array.flatten()\n", 182 | " positive_count = np.sum(flattened_array > 0)\n", 183 | " negative_count = np.sum(flattened_array < 0)\n", 184 | " zero_count = np.sum(flattened_array == 0)\n", 185 | " \n", 186 | " # Calculate percentages\n", 187 | " total_count = flattened_array.size\n", 188 | " positive_percentage = (positive_count / total_count) * 100\n", 189 | " negative_percentage = (negative_count / total_count) * 100\n", 190 | " zero_percentage = (zero_count / total_count) * 100\n", 191 | "\n", 192 | " # Print the results\n", 193 | " print(f\"Excitatory Connections: {positive_count} ({positive_percentage:.2f}%)\")\n", 194 | " print(f\"Inhibitory Conncetions: {negative_count} ({negative_percentage:.2f}%)\")\n", 195 | " print(f\"Zero Connections: {zero_count} ({zero_percentage:.2f}%)\")\n", 196 | "\n", 197 | " # Create a bar plot of the percentages\n", 198 | " categories = ['Excitatory', 'Inhibitory', 'Zero']\n", 199 | " percentages = [positive_percentage, negative_percentage, zero_percentage]\n", 200 | "\n", 201 | " plt.bar(categories, percentages)\n", 202 | " plt.xlabel('Category')\n", 203 | " plt.ylabel('Percentage')\n", 204 | " plt.title('Percentage of Excitatory, Inhibitiory, and Zero Connections')\n", 205 | " plt.ylim(0, 60) # Set the y-axis limit to 0-60%\n", 206 | " plt.show()\n", 207 | "\n", 208 | "if __name__ == \"__main__\":\n", 209 | "\n", 210 | " # Call the function to count and plot\n", 211 | " count_and_plot(featureW)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "id": "f180220e", 217 | "metadata": {}, 218 | "source": [ 219 | "Now let's have a look at the feature to the output weights, and see how the distribution of excitiatory and inhibitory connections differs." 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "7609eae5-f584-4c98-9eb6-3c3cf1e2aa04", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "# Load the input to feature excitatory and inhibitory network weights\n", 230 | "outputW = np.load('./mats/1_BasicDemo/outputW.npy')\n", 231 | "\n", 232 | "# Plot the weights\n", 233 | "plt.matshow(outputW)\n", 234 | "plt.title('Feature > Output Weights')\n", 235 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n", 236 | "\n", 237 | "# Display the plots\n", 238 | "plt.show()\n", 239 | "\n", 240 | "# Plot the distributions\n", 241 | "count_and_plot(outputW)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "id": "d591969a-e72e-43b2-8c89-16a13bb29fe6", 247 | "metadata": {}, 248 | "source": [ 249 | "### Propagate network spikes\n", 250 | "\n", 251 | "Now we'll propagate the input spikes across the layers to get the output. All we have to do is multiply the input spikes by the Input > Feature weights for both excitatory and inhibitory matrices and add them, then take the feature spikes and multiply them by the Feature > Output weights and do the smae thing. 
We'll also clamp spikes in the range of [0, 0.9] to prevent negative spikes and spike explosions.\n", 252 | "\n", 253 | "Let's do that and visualize the spikes as they're going through, we'll start with the Input to Feature layer." 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "f6c84239-c176-48c3-8954-25da5f989d61", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "# Calculate feature spikes (positive and negative weights)\n", 264 | "feature_spikes = np.matmul(featureW,patch_1d)\n", 265 | "feature_spikes = np.clip(feature_spikes, 0, 0.9)\n", 266 | "\n", 267 | "# Now create the line plot\n", 268 | "plt.plot(np.arange(len(feature_spikes)), feature_spikes)\n", 269 | "\n", 270 | "# Add title and labels if you wish\n", 271 | "plt.title('Feature Layer Spikes')\n", 272 | "plt.xlabel('Neuron ID')\n", 273 | "plt.ylabel('Spike Amplitude')\n", 274 | "\n", 275 | "# Show the plot\n", 276 | "plt.show()" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "id": "4ea0b0a3-66fc-4202-963c-cbd05114d283", 282 | "metadata": {}, 283 | "source": [ 284 | "This looks a little homogenous, but this is the feature representation of our input image. \n", 285 | "\n", 286 | "Now let's propagate the feature layer spikes through to the output layer to get our corresponding place match." 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "id": "f5d4dc99-c7b9-4e9b-ba7c-58f6e30631cb", 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "# Calculate output spikes (positive and negative weights)\n", 297 | "output_spikes = np.matmul(outputW,feature_spikes)\n", 298 | "output_spikes = np.clip(output_spikes, 0, 0.9)\n", 299 | "\n", 300 | "# Now create the line plot\n", 301 | "plt.plot(np.arange(len(output_spikes)), output_spikes)\n", 302 | "\n", 303 | "# Add title and labels if you wish\n", 304 | "plt.title('Output Layer Spikes')\n", 305 | "plt.xlabel('Neuron ID')\n", 306 | "plt.ylabel('Spike Amplitude')\n", 307 | "\n", 308 | "# Show the plot\n", 309 | "plt.show()" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "id": "54b4f5f6-017b-4d7d-812a-c96baf9cb39f", 315 | "metadata": {}, 316 | "source": [ 317 | "Success! We have propagated our input spikes across the layers to reach this output. Clearly, one of the output spikes has the highest amplitude. Our network weights were trained on 500 locations from a Fall and Spring traversal of Nordland. For this example, we passed the first location from the Summer traversal through the network to achieve this output - which clearly looks to have spikes Neuron ID '0' the highest!\n", 318 | "\n", 319 | "Let's prove that." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "id": "780371ca-9dfe-4dd7-857d-e35be73ffd23", 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "# Output the argmax from the output spikes\n", 330 | "prediction = np.argmax(output_spikes)\n", 331 | "print(f\"Neuron ID with the highest output is {prediction}\")" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "id": "0c8c82d7", 337 | "metadata": {}, 338 | "source": [ 339 | "## Quantized model example\n", 340 | "\n", 341 | "Now that we have seen how our base model works, let's look at how our int8 quantized model performs by comparison. Working in the in8 space has a few benefits, like faster inferencing time and smaller model sizes. 
There are a couple differences however when feeding spikes throughout the system that PyTorch performs in the backend.\n", 342 | "\n", 343 | "Let's start by converting our input image into int8 spikes." 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "id": "5c893e76", 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "# Converting fp32 spikes to int8 uses a learned scale factor during quantization aware training\n", 354 | "spike_scale = 133\n", 355 | "patch_img_int = patch_img*spike_scale\n", 356 | "\n", 357 | "# Plot the converted int8 image\n", 358 | "plt.matshow(patch_img_int)\n", 359 | "plt.title('Nordland Summer Patch Normalized Int8')\n", 360 | "plt.colorbar(shrink=0.75, label=\"Pixel intensity\")\n", 361 | "plt.show()\n", 362 | "\n", 363 | "# Convert 2D image to a 1D-array\n", 364 | "patch_1d_int = np.reshape(patch_img_int, (3136,))" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "id": "9f68d3dc", 370 | "metadata": {}, 371 | "source": [ 372 | "Now we'll load in and plot our integer based weights, as well as some scale factors which will be important to reduce the size of our spikes after multiplying them with our weights." 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "id": "70c99f47", 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "# Load the scales for the feature and output spikes\n", 383 | "feature_scales = np.load('./mats/1_BasicDemo/featureScales.npy',allow_pickle=True)\n", 384 | "output_scales = np.load('./mats/1_BasicDemo/outputScales.npy',allow_pickle=True)\n", 385 | "\n", 386 | "# Load the int8 weights and plot them\n", 387 | "featureQuantW = np.load('./mats/1_BasicDemo/featureQuantW.npy')\n", 388 | "outputQuantW = np.load('./mats/1_BasicDemo/outputQuantW.npy')\n", 389 | "\n", 390 | "# Plot the feature weights\n", 391 | "plt.matshow(featureQuantW.T)\n", 392 | "plt.title('Input > Feature Weights')\n", 393 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n", 394 | "\n", 395 | "# Display the plots\n", 396 | "plt.show()\n", 397 | "\n", 398 | "# Plot the output weights\n", 399 | "plt.matshow(outputQuantW)\n", 400 | "plt.title('Feature > Output Weights')\n", 401 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n", 402 | "\n", 403 | "# Display the plots\n", 404 | "plt.show()" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "id": "f185e3ff", 410 | "metadata": {}, 411 | "source": [ 412 | "Now as above, let's propagate the input spikes throughout the network." 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "id": "ddb1ef65", 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "# Get the feature spikes\n", 423 | "feature_spikes_int = np.matmul(featureQuantW,patch_1d_int)\n", 424 | "\n", 425 | "# Now create the line plot\n", 426 | "plt.plot(np.arange(len(feature_spikes_int)), feature_spikes_int)\n", 427 | "\n", 428 | "# Add title and labels if you wish\n", 429 | "plt.title('Output Layer Spikes')\n", 430 | "plt.xlabel('Neuron ID')\n", 431 | "plt.ylabel('Spike Amplitude')\n", 432 | "\n", 433 | "# Show the plot\n", 434 | "plt.show()" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "id": "6e9189aa", 440 | "metadata": {}, 441 | "source": [ 442 | "Those are some big spikes! We're going to have to scale these spikes back down before we forward them to the output layer, otherwise we'll have some huge activations. 
Let's take those scales we loaded in earlier and apply them to the feature spikes.\n", 443 | "\n", 444 | "We have three things to consider here:\n", 445 | " - A slice scale factor (per neuronal connection scale)\n", 446 | " - A zero point (a factor to change where 'zero' is)\n", 447 | " \n", 448 | " Let's print out these three factors and see how they scale our spikes." 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "id": "95032ac7", 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [ 458 | "# Print out the individual scales\n", 459 | "print(f\"The slice scale factor is {feature_scales[1]}\")\n", 460 | "print(f\"The zero point is {feature_scales[2]}\")" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "id": "9f62b909", 466 | "metadata": {}, 467 | "source": [ 468 | "Now we'll modify and scale our spikes to then pass them on to the feature layer." 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "id": "9a7dd83d", 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "# Scale the feature spikes\n", 479 | "scaled_feature_spikes = (feature_spikes_int//(feature_scales[1]))+feature_scales[2]\n", 480 | "scaled_feature_spikes = np.clip(scaled_feature_spikes,0,255)\n", 481 | "\n", 482 | "# Plot the scaled feature spikes\n", 483 | "plt.plot(np.arange(len(scaled_feature_spikes)), scaled_feature_spikes)\n", 484 | "\n", 485 | "# Add title and labels if you wish\n", 486 | "plt.title('Output Layer Spikes')\n", 487 | "plt.xlabel('Neuron ID')\n", 488 | "plt.ylabel('Spike Amplitude')\n", 489 | "\n", 490 | "# Show the plot\n", 491 | "plt.show()" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "id": "2a767c92", 497 | "metadata": {}, 498 | "source": [ 499 | "Now that we've scaled our feature spikes, let's pass them through to the output layer and get our match!" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "id": "7519c07f", 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "# Get the output spikes\n", 510 | "output_spikes_int = np.matmul(outputQuantW,scaled_feature_spikes)\n", 511 | "\n", 512 | "# Scale the output spikes\n", 513 | "scaled_output_spikes = output_spikes_int//(output_scales[1]) + output_scales[2]\n", 514 | "\n", 515 | "# Plot the scaled feature spikes\n", 516 | "plt.plot(np.arange(len(scaled_output_spikes)), scaled_output_spikes)\n", 517 | "\n", 518 | "# Add title and labels if you wish\n", 519 | "plt.title('Output Layer Spikes')\n", 520 | "plt.xlabel('Neuron ID')\n", 521 | "plt.ylabel('Spike Amplitude')\n", 522 | "\n", 523 | "# Show the plot\n", 524 | "plt.show()" 525 | ] 526 | }, 527 | { 528 | "cell_type": "markdown", 529 | "id": "3cc4588a", 530 | "metadata": {}, 531 | "source": [ 532 | "And once again, as in the base model, we can see that output neuron 0 is the highest respondant.\n", 533 | "\n", 534 | "Let's prove it!" 
535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "id": "67e457e4", 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "# Output the argmax from the output spikes\n", 545 | "prediction = np.argmax(scaled_output_spikes)\n", 546 | "print(f\"Neuron ID with the highest output is {prediction}\")" 547 | ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "id": "7bc8a7fb-66b4-455b-922e-b0fdc38b53c5", 552 | "metadata": {}, 553 | "source": [ 554 | "### Conclusions\n", 555 | "\n", 556 | "We have gone through a very basic demo of how VPRTempo takes input images, patch normalizes them, and propagates the spikes throughout the weights to achieve the desired matching output. Although this demonstration was performed using NumPy, the torch implementation is virtually the same except we use tensors with or without quantization. \n", 557 | "\n", 558 | "We also went through how the quantization version of the network handled weights and spikes in the integer domain.\n", 559 | "\n", 560 | "If you would like to go more in-depth with training and inferencing, checkout some of the [other tutorials](https://github.com/AdamDHines/VPRTempo-quant/tree/main/tutorials) which show you how to train your own model and goes through the more sophisticated implementation of VPRTempo." 561 | ] 562 | } 563 | ], 564 | "metadata": { 565 | "kernelspec": { 566 | "display_name": "Python 3 (ipykernel)", 567 | "language": "python", 568 | "name": "python3" 569 | }, 570 | "language_info": { 571 | "codemirror_mode": { 572 | "name": "ipython", 573 | "version": 3 574 | }, 575 | "file_extension": ".py", 576 | "mimetype": "text/x-python", 577 | "name": "python", 578 | "nbconvert_exporter": "python", 579 | "pygments_lexer": "ipython3", 580 | "version": "3.11.4" 581 | } 582 | }, 583 | "nbformat": 4, 584 | "nbformat_minor": 5 585 | } 586 | -------------------------------------------------------------------------------- /tutorials/2_Training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2a09c9e8-68f6-4edd-a8c1-1cc7f6b8bb00", 6 | "metadata": {}, 7 | "source": [ 8 | "## Training a new VPRTempo & VPRTempoQuant network\n", 9 | "\n", 10 | "### By Adam D Hines (https://research.qut.edu.au/qcr/people/adam-hines/)\n", 11 | "\n", 12 | "VPRTempo is based on the following paper, if you use or find this code helpful for your research please consider citing the source:\n", 13 | " \n", 14 | "[Adam D Hines, Peter G Stratton, Michael Milford, & Tobias Fischer. \"VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition. arXiv September 2023](https://arxiv.org/abs/2309.10225)\n", 15 | "\n", 16 | "### Introduction\n", 17 | "\n", 18 | "In this tutorial, we will go through how to train your own model for both the base and quantized version of VPRTempo. \n", 19 | "\n", 20 | "Before starting, make sure you have [installed the dependencies](https://github.com/AdamDHines/VPRTempo-quant#installation-and-setup) and/or activated the conda environment. 
You will also need the [Nordland](https://github.com/AdamDHines/VPRTempo-quant#nordland) dataset before proceeding, as this tutorial will cover training the network using this as an example.\n", 21 | "\n", 22 | "### Training new models for VPRTempo and VPRTempoQuant\n", 23 | "\n", 24 | "Let's start by training the base model with the default settings (if you have pre-trained a model, it will get removed for the purpose of the tutorial)." 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "a757316e-4aa8-41aa-b03f-1ce4489d3705", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "# Change the working directory to the main folder from tutorials\n", 35 | "import os\n", 36 | "os.chdir('../')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "eff21590-9d00-4d8f-b860-ca1c0c849187", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# Train the base model with the default settings\n", 47 | "# If the pre-trained model already exists, we will remove it for the tutorial\n", 48 | "file_path = './models/VPRTempo313662725001.pth'\n", 49 | "\n", 50 | "if os.path.exists(file_path):\n", 51 | " os.remove(file_path)\n", 52 | " print(\"The file has been deleted.\")\n", 53 | "\n", 54 | "# Run the training paradigm\n", 55 | "!python main.py --train_new_model" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "285ca349-6295-4f04-bbe1-360d43111f82", 61 | "metadata": {}, 62 | "source": [ 63 | "Now we'll run the inferencing model to check and make sure our model trained ok." 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "1037e4ee-0361-4aa1-bfba-f979ffa51b88", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Run the base inferencing network\n", 74 | "!python main.py" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "id": "cbf8cd37-1af8-420f-ad76-481157a2920e", 80 | "metadata": {}, 81 | "source": [ 82 | "Great! Now let's have a look at changing a few of the default settings and training different kinds of networks. The default settings train 500 places, so if we want to only look at a smaller number of places we can parse the `--num_places` argument and specify how many places to learn." 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "id": "81daf9c8-ea06-4e69-82a5-9179a7c38ea1", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# Train a new model with 250 places\n", 93 | "!python main.py --num_places 250 --train_new_model" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "id": "a46cde0d-6343-48a4-87f8-aa5c9f56b231", 99 | "metadata": {}, 100 | "source": [ 101 | "And we can now inference using this smaller model." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": "559e2391-d0bc-4d58-bb35-8cf69ae08b5e", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "# Run the inference for a model with 250 places\n", 112 | "!python main.py --num_places 250" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "aa835469-90cc-4f09-b7b4-688465321b18", 118 | "metadata": {}, 119 | "source": [ 120 | "Arguments for the base network work the same for VPRTempoQuant, we just need to also parse the `--quantize` argument. Let's now train another 250 place network, but also change a couple of other parameters. 
The default VPRTempo settings is a little slow to train on CPU, so let's reduce the image size from 56x56 to 28x28 and change the number of patches for patch normalization." 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "eee89998-d334-4cdf-ac5b-88aed367ceab", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# Train a 250 place network with VPRTempoQuant\n", 131 | "!python main.py --quantize --num_places 250 --patches 7 --dims 28,28 --train_new_model" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "id": "42d5f3ce-acbb-435d-b197-f7e5841dcd70", 137 | "metadata": {}, 138 | "source": [ 139 | "And now we can inference this model!" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "id": "c6a9e948-5349-4714-b7d0-03ce694409b8", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "# Run inference on newly trained VPRTempoQuant model\n", 150 | "!python main.py --quantize --num_places 250 --patches 7 --dims 28,28" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "2648e38e-a0d7-481a-9009-3c4cf3e3fff0", 156 | "metadata": {}, 157 | "source": [ 158 | "### List of arguments you can parse\n", 159 | "\n", 160 | "The full list of arguments that can parsed to VPRTempo can be found in the `parse_network` function of `main.py`. Hyperparameters for VPRTempo are hardcoded into the layers and are not recommended to be changed since they generalize fairly well across multiple different datasets. \n", 161 | "\n", 162 | "### Conclusions\n", 163 | "\n", 164 | "This tutorial provided a simple overview of how you can train your own models for both VPRTempo and VPRTempoQuant, and changing a few of the network parameters.\n", 165 | "\n", 166 | "If you would like to go more in-depth, checkout some of the [other tutorials](https://github.com/AdamDHines/VPRTempo-quant/tree/main/tutorials) where we cover how to define your own custom dataset and work with expert modules." 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "a2bdc6d1-4690-4b48-a75f-f9d716cb7154", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3 (ipykernel)", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.11.4" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 5 199 | } 200 | -------------------------------------------------------------------------------- /tutorials/3_Modules.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7b7b7556-f8df-4e50-801b-7e215e46a415", 6 | "metadata": {}, 7 | "source": [ 8 | "## Using modules with VPRTempo and VPRTempoQuant\n", 9 | "\n", 10 | "### By Adam D Hines (https://research.qut.edu.au/qcr/people/adam-hines/)\n", 11 | "\n", 12 | "VPRTempo is based on the following paper, if you use or find this code helpful for your research please consider citing the source:\n", 13 | " \n", 14 | "[Adam D Hines, Peter G Stratton, Michael Milford, & Tobias Fischer. 
\"VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition. arXiv September 2023](https://arxiv.org/abs/2309.10225)\n", 15 | "\n", 16 | "### Introduction\n", 17 | "\n", 18 | "In this tutorial, we will go through how to use modules with VPRTempo. Modules break up the training data into multiple networks, which has been shown to [improve the overall performance](https://towardsdatascience.com/machine-learning-with-expert-models-a-primer-6c74585f223f) and accuracy of larger models.\n", 19 | "\n", 20 | "Before starting, make sure you have [installed the dependencies](https://github.com/AdamDHines/VPRTempo-quant#installation-and-setup) and/or activated the conda environment. You will also need the [Nordland](https://github.com/AdamDHines/VPRTempo-quant#nordland) dataset before proceeding, as this tutorial will cover training the network using this as an example.\n", 21 | "\n", 22 | "### Comparing results using expert modules for VPRTempo\n", 23 | "\n", 24 | "Let's start by training the base model with 1000 places, which is 500 more than the default settings. We will need to parse the `--train_new_model`, `--num_places`, as well as another argument we haven't seen yet `--max_module`. \n", 25 | "\n", 26 | "`--max_module` tells the network how many places each expert module should learn, which by default is set to `500`. So if we're training a new, singular network with 1000 places we need to increase `max_module` to 1000." 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "9e0d1a7b-8788-436a-8dd5-27f60077c524", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# Change the working directory to the main folder from tutorials\n", 37 | "import os\n", 38 | "os.chdir('../')" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "c0981001-fa1c-49e0-bbe3-e5a628098eba", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# Train a single network with 1000 places\n", 49 | "!python main.py --num_places 1000 --max_module 1000 --train_new_model" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "e7d8287b-b47b-4f7e-9d1c-4b309746b284", 55 | "metadata": {}, 56 | "source": [ 57 | "Now let's see how this performs." 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "32435d47-e583-4768-86ed-2998ec78fdd6", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# Run the inferencing network on the singular 1000 place trained model\n", 68 | "!python main.py --num_places 1000 --max_module 1000" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "845764ee-1492-449c-9816-a45ccfe043a1", 74 | "metadata": {}, 75 | "source": [ 76 | "Performance here is still pretty good, but let's see if we can improve it by splitting up the network into modules!\n", 77 | "\n", 78 | "Now that splitting up our 1000 place network into 2 networks, we can remove the `--max_module` argument because the default is set to 500. Instead what we will parse is `--num_modules` to tell the network to split things up into two models." 
79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "d7cf368b-2f7f-4345-b882-7bf51622c0d6", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# Train a new 1000 place model with 2 modules\n", 89 | "!python main.py --num_places 1000 --num_modules 2 --train_new_model" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "id": "84ce9e04-ff62-45ce-83b3-89e4623f0fd1", 95 | "metadata": {}, 96 | "source": [ 97 | "Now let's see how it compares with the singular model." 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "c6ec3ec7-cca7-45a0-906b-cc3e6d14b5f8", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "# Run the network with 2 modules\n", 108 | "!python main.py --num_places 1000 --num_modules 2" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "8da07161-81e9-4773-90f1-e9d675e8c071", 114 | "metadata": {}, 115 | "source": [ 116 | "A modest boost to performance, however you have to imagine how this scales to much larger networks - especially when considering training times. Because the output layer is one-hot encoded, you need to increase the number of output neurons with each place you want to learn. Splitting up networks has a key benefit for VPRTempo to reduce overall training times with little impact on inference speeds. \n", 117 | "\n", 118 | "(Optional) Run a single network for 2500 places or 5 expert modules for 500 places each (reduced dimensionality to speed things up)." 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "ffd18263-190c-47f4-a37f-130a040d7a67", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# Optional: run a 2500 place comparison for singular vs modular networks\n", 129 | "# Train networks\n", 130 | "!python main.py --num_places 2500 --max_module 2500 --dims 28,28 --patches 7 --train_new_model\n", 131 | "!python main.py --num_places 2500 --num_modules 5 --dims 28,28 --patches 7 --train_new_model\n", 132 | "# Run inference\n", 133 | "!python main.py --num_places 2500 --max_module 2500 --dims 28,28 --patches 7\n", 134 | "!python main.py --num_places 2500 --num_modules 5 --dims 28,28 --patches 7" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "02edb3da-9d34-4709-950d-d202c5e902e7", 140 | "metadata": {}, 141 | "source": [ 142 | "As in the other tutorials, parsing the `--quantize` argument will run exactly the same but for VPRTempoQuant. Let's do a quick comparison." 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "2e520f0a-7b6f-4546-8169-14f20f2f9d18", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "# Train networks\n", 153 | "#!python main.py --num_places 1500 --max_module 1500 --dims 28,28 --patches 7 --train_new_model --quantize\n", 154 | "#!python main.py --num_places 1500 --num_modules 3 --dims 28,28 --patches 7 --train_new_model --quantize\n", 155 | "# Run inference\n", 156 | "!python main.py --num_places 1500 --max_module 1500 --dims 28,28 --patches 7 --quantize\n", 157 | "!python main.py --num_places 1500 --num_modules 3 --dims 28,28 --patches 7 --quantize" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "58581b5c-70a8-4708-bda1-1ccca1edc70b", 163 | "metadata": {}, 164 | "source": [ 165 | "Once again, we can see that whilst there's a modest boost to the accuracy result the clear improve is the training speed. 
Because each network is smaller, the opeations on CPU are a lot less computationally heavy when splitting the networks up.\n", 166 | "\n", 167 | "### Conclusions\n", 168 | "\n", 169 | "This tutorial provided a simple overview of how you can train network models using expert modules. \n", 170 | "\n", 171 | "If you would like to go more in-depth, checkout some of the [other tutorials](https://github.com/AdamDHines/VPRTempo-quant/tree/main/tutorials)." 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "0309b80b-05c2-4577-8ea0-2e45bf2d6aef", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3 (ipykernel)", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.11.4" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 5 204 | } 205 | -------------------------------------------------------------------------------- /tutorials/mats/README.txt: -------------------------------------------------------------------------------- 1 | To download materials for BasicDemo, please visit -> https://www.dropbox.com/scl/fi/bxbzm47kxl24x979q5r5s/1_BasicDemo.zip?rlkey=0umij016whwgm11frzlk63v5k&st=52otrem5&dl=0 2 | 3 | Extract into the `./tutorials/mats` folder. 4 | -------------------------------------------------------------------------------- /vprtempo/VPRTempo.py: -------------------------------------------------------------------------------- 1 | #MIT License 2 | 3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer 4 | 5 | #Permission is hereby granted, free of charge, to any person obtaining a copy 6 | #of this software and associated documentation files (the "Software"), to deal 7 | #in the Software without restriction, including without limitation the rights 8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | #copies of the Software, and to permit persons to whom the Software is 10 | #furnished to do so, subject to the following conditions: 11 | 12 | #The above copyright notice and this permission notice shall be included in all 13 | #copies or substantial portions of the Software. 14 | 15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | #SOFTWARE. 
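# VPRTempo.py defines the inference-side network: the VPRTempo module builds its
# feature and output layers from pre-trained weights, and run_inference() evaluates a
# trained model on the query traverse to report Recall@N.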
22 | 23 | ''' 24 | Imports 25 | ''' 26 | 27 | import os 28 | import json 29 | import torch 30 | 31 | import numpy as np 32 | import torch.nn as nn 33 | import matplotlib.pyplot as plt 34 | import vprtempo.src.blitnet as bn 35 | 36 | from tqdm import tqdm 37 | from prettytable import PrettyTable 38 | from torch.utils.data import DataLoader 39 | from vprtempo.src.download import get_data_model 40 | from vprtempo.src.metrics import recallAtK, createPR 41 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage 42 | 43 | class VPRTempo(nn.Module): 44 | def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None): 45 | super(VPRTempo, self).__init__() 46 | 47 | # Set the args 48 | if args is not None: 49 | self.args = args 50 | for arg in vars(args): 51 | setattr(self, arg, getattr(args, arg)) 52 | setattr(self, 'dims', dims) 53 | 54 | # Set the device 55 | if torch.cuda.is_available(): 56 | self.device = "cuda:0" 57 | elif torch.backends.mps.is_available(): 58 | self.device = "mps" 59 | else: 60 | self.device = "cpu" 61 | 62 | # Set input args 63 | self.logger = logger 64 | self.num_modules = num_modules 65 | self.output_folder = output_folder 66 | 67 | # Set the dataset file 68 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.query_dir}' + '.csv') 69 | self.query_dir = [dir.strip() for dir in self.query_dir.split(',')] 70 | 71 | # Layer dict to keep track of layer names and their order 72 | self.layer_dict = {} 73 | self.layer_counter = 0 74 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')] 75 | 76 | # Define layer architecture 77 | self.input = int(self.dims[0]*self.dims[1]) 78 | self.feature = int(self.input * 2) 79 | 80 | # Output dimension changes for final module if not an even distribution of places 81 | if not out_dim_remainder is None: 82 | self.output = out_dim_remainder 83 | else: 84 | self.output = out_dim 85 | 86 | # set model name for default demo 87 | self.demo = './vprtempo/models/springfall_VPRTempo_IN3136_FN6272_DB500.pth' 88 | 89 | """ 90 | Define trainable layers here 91 | """ 92 | self.add_layer( 93 | 'feature_layer', 94 | dims=[self.input, self.feature], 95 | device=self.device, 96 | inference=True 97 | ) 98 | self.add_layer( 99 | 'output_layer', 100 | dims=[self.feature, self.output], 101 | device=self.device, 102 | inference=True 103 | ) 104 | 105 | def add_layer(self, name, **kwargs): 106 | """ 107 | Dynamically add a layer with given name and keyword arguments. 108 | 109 | :param name: Name of the layer to be added 110 | :type name: str 111 | :param kwargs: Hyperparameters for the layer 112 | """ 113 | # Check for layer name duplicates 114 | if name in self.layer_dict: 115 | raise ValueError(f"Layer with name {name} already exists.") 116 | 117 | # Add a new SNNLayer with provided kwargs 118 | setattr(self, name, bn.SNNLayer(**kwargs)) 119 | 120 | # Add layer name and index to the layer_dict 121 | self.layer_dict[name] = self.layer_counter 122 | self.layer_counter += 1 123 | 124 | def evaluate(self, models, test_loader): 125 | """ 126 | Run the inferencing model and calculate the accuracy. 
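        Each module's feature and output weights are chained into an nn.Sequential
        model, the query images are passed through to build a similarity matrix, and
        the matrix is compared against the ground truth (with optional tolerance) to
        report Recall@N and, if requested, a PR curve and similarity matrix plot.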
127 | 128 | :param models: Models to run inference on, each model is a VPRTempo module 129 | :param test_loader: Testing data loader 130 | """ 131 | # Initialize the tqdm progress bar 132 | pbar = tqdm(total=self.query_places, 133 | desc="Running the test network", 134 | position=0) 135 | self.inferences = [] 136 | for model in models: 137 | self.inferences.append(nn.Sequential( 138 | model.feature_layer.w, 139 | model.output_layer.w, 140 | )) 141 | self.inferences[-1].to(torch.device(self.device)) 142 | # Initiliaze the output spikes variable 143 | out = [] 144 | labels = [] 145 | 146 | # Run inference for the specified number of timesteps 147 | for spikes, label in test_loader: 148 | # Set device 149 | spikes = spikes.to(self.device) 150 | labels.append(label.detach().cpu().item()) 151 | # Forward pass 152 | spikes = self.forward(spikes) 153 | # Add output spikes to list 154 | out.append(spikes.detach().cpu()) 155 | pbar.update(1) 156 | 157 | # Close the tqdm progress bar 158 | pbar.close() 159 | # Rehsape output spikes into a similarity matrix 160 | out = torch.stack(out, dim=2) 161 | out = out.squeeze(0).numpy() 162 | 163 | if self.skip != 0: 164 | GT = np.zeros((model.database_places, model.query_places)) 165 | skip = model.skip // model.filter 166 | # Create an array of indices for the query dimension 167 | query_indices = np.arange(model.query_places) 168 | 169 | # Set the ones on the diagonal starting at row `skip` 170 | GT[skip + query_indices, query_indices] = 1 171 | else: 172 | GT = np.eye(model.database_places, model.query_places) 173 | 174 | # Apply GT tolerance 175 | if self.GT_tolerance > 0: 176 | # Get the number of rows and columns 177 | num_rows, num_cols = GT.shape 178 | 179 | # Iterate over each column 180 | for col in range(num_cols): 181 | # Find the indices of rows where GT has a 1 in the current column 182 | ones_indices = np.where(GT[:, col] == 1)[0] 183 | 184 | # For each index with a 1, set 1s in GTtol within the specified vertical distance 185 | for row in ones_indices: 186 | # Determine the start and end rows, ensuring they are within bounds 187 | start_row = max(row - self.GT_tolerance, 0) 188 | end_row = min(row + self.GT_tolerance + 1, num_rows) # +1 because upper bound is exclusive 189 | 190 | # Set the range in GTtol to 1 191 | GT[start_row:end_row, col] = 1 192 | 193 | # If user specified, generate a PR curve 194 | if model.PR_curve: 195 | # Create PR curve 196 | P, R = createPR(out, GT, matching='single', n_thresh=100) 197 | # Combine P and R into a list of lists 198 | PR_data = { 199 | "Precision": P, 200 | "Recall": R 201 | } 202 | output_file = "PR_curve_data.json" 203 | # Construct the full path 204 | full_path = f"{model.output_folder}/{output_file}" 205 | # Write the data to a JSON file 206 | with open(full_path, 'w') as file: 207 | json.dump(PR_data, file) 208 | # Plot PR curve 209 | plt.plot(R,P) 210 | plt.xlabel('Recall') 211 | plt.ylabel('Precision') 212 | plt.title('Precision-Recall Curve') 213 | plt.show() 214 | 215 | plt.close() 216 | 217 | if model.sim_mat: 218 | # Create a figure and a set of subplots 219 | fig, axs = plt.subplots(1, 2, figsize=(15, 5)) 220 | 221 | # Plot each matrix using matshow 222 | cax1 = axs[0].matshow(out, cmap='viridis') 223 | fig.colorbar(cax1, ax=axs[0], shrink=0.8) 224 | axs[0].set_title('Similarity matrix') 225 | 226 | cax2 = axs[1].matshow(GT, cmap='plasma') 227 | fig.colorbar(cax2, ax=axs[1], shrink=0.8) 228 | axs[1].set_title('GT') 229 | 230 | # Adjust layout 231 | plt.tight_layout() 232 | plt.show() 233 | 
234 | plt.close() 235 | 236 | # Recall@N 237 | N = [1,5,10,15,20,25] # N values to calculate 238 | R = [] # Recall@N values 239 | # Calculate Recall@N 240 | for n in N: 241 | R.append(round(recallAtK(out, GT, K=n),2)) 242 | # Print the results 243 | table = PrettyTable() 244 | table.field_names = ["N", "1", "5", "10", "15", "20", "25"] 245 | table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]]) 246 | self.logger.info(table) 247 | 248 | def forward(self, spikes): 249 | """ 250 | Compute the forward pass of the model. 251 | 252 | Parameters: 253 | - spikes (Tensor): Input spikes. 254 | 255 | Returns: 256 | - Tensor: Output after processing. 257 | """ 258 | # Initialize the output spikes tensor 259 | in_spikes = spikes.detach().clone() 260 | outputs = [] # List to collect output tensors 261 | 262 | # Run inferencing across modules 263 | for inference in self.inferences: 264 | out_spikes = inference(in_spikes) 265 | outputs.append(out_spikes) # Append the output tensor to the list 266 | 267 | # Concatenate along the desired dimension 268 | concatenated_output = torch.cat(outputs, dim=1) 269 | 270 | return concatenated_output 271 | 272 | def load_model(self, models, model_path): 273 | """ 274 | Load pre-trained model and set the state dictionary keys. 275 | """ 276 | # check if model exists, download the demo data and model if defaults are set 277 | if not os.path.exists(model_path): 278 | if model_path == self.demo: 279 | get_data_model() 280 | else: 281 | raise ValueError(f"Model path {model_path} does not exist.") 282 | 283 | combined_state_dict = torch.load(model_path, map_location=self.device, weights_only=True) 284 | 285 | for i, model in enumerate(models): # models_classes is a list of model classes 286 | model.load_state_dict(combined_state_dict[f'model_{i}']) 287 | model.eval() # Set the model to inference mode 288 | 289 | def run_inference(models, model_name): 290 | """ 291 | Run inference on a pre-trained model. 
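    The first model in the list supplies the shared parameters: a ProcessImage
    transform and CustomImageDataset are built from the query directories, images are
    loaded one at a time, the pre-trained weights are restored with load_model(), and
    evaluate() is run under torch.no_grad().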
292 | 293 | :param models: Models to run inference on, each model is a VPRTempo module 294 | :param model_name: Name of the model to load 295 | """ 296 | # Set first index model as the main model for parameters 297 | model = models[0] 298 | # Initialize the image transforms 299 | image_transform = ProcessImage(model.dims, model.patches) 300 | 301 | # Initialize the test dataset 302 | test_dataset = CustomImageDataset(annotations_file=model.dataset_file, 303 | base_dir=model.data_dir, 304 | img_dirs=model.query_dir, 305 | transform=image_transform, 306 | max_samples=model.query_places, 307 | filter=model.filter, 308 | skip=model.skip) 309 | 310 | # Initialize the data loader 311 | test_loader = DataLoader(test_dataset, 312 | batch_size=1, 313 | num_workers=4, 314 | persistent_workers=True) 315 | 316 | # Load the model 317 | model.load_model(models, os.path.join('./vprtempo/models', model_name)) 318 | 319 | # Use evaluate method for inference accuracy 320 | with torch.no_grad(): 321 | model.evaluate(models, test_loader) -------------------------------------------------------------------------------- /vprtempo/VPRTempoQuant.py: -------------------------------------------------------------------------------- 1 | #MIT License 2 | 3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer 4 | 5 | #Permission is hereby granted, free of charge, to any person obtaining a copy 6 | #of this software and associated documentation files (the "Software"), to deal 7 | #in the Software without restriction, including without limitation the rights 8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | #copies of the Software, and to permit persons to whom the Software is 10 | #furnished to do so, subject to the following conditions: 11 | 12 | #The above copyright notice and this permission notice shall be included in all 13 | #copies or substantial portions of the Software. 14 | 15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | #SOFTWARE. 
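# VPRTempoQuant.py mirrors VPRTempo.py for the quantized network: inference runs on
# CPU with QuantStub/DeQuantStub wrapped around the forward pass.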
22 | 23 | ''' 24 | Imports 25 | ''' 26 | 27 | import os 28 | import json 29 | import torch 30 | import random 31 | 32 | import numpy as np 33 | import torch.nn as nn 34 | import matplotlib.pyplot as plt 35 | import vprtempo.src.blitnet as bn 36 | 37 | from tqdm import tqdm 38 | from prettytable import PrettyTable 39 | from torch.utils.data import DataLoader 40 | from vprtempo.src.download import get_data_model 41 | from vprtempo.src.metrics import recallAtK, createPR 42 | from torch.ao.quantization import QuantStub, DeQuantStub 43 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage 44 | 45 | #from main import parse_network 46 | 47 | class VPRTempoQuant(nn.Module): 48 | def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None): 49 | super(VPRTempoQuant, self).__init__() 50 | 51 | # Set the args 52 | if args is not None: 53 | self.args = args 54 | for arg in vars(args): 55 | setattr(self, arg, getattr(args, arg)) 56 | setattr(self, 'dims', dims) 57 | 58 | # Set the device 59 | self.device = "cpu" 60 | 61 | # Set input args 62 | self.logger = logger 63 | self.num_modules = num_modules 64 | self.output_folder = output_folder 65 | 66 | self.quant = QuantStub() 67 | self.dequant = DeQuantStub() 68 | 69 | # Set the dataset file 70 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.query_dir}' + '.csv') 71 | self.query_dir = [dir.strip() for dir in self.query_dir.split(',')] 72 | 73 | # Layer dict to keep track of layer names and their order 74 | self.layer_dict = {} 75 | self.layer_counter = 0 76 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')] 77 | 78 | # Define layer architecture 79 | self.input = int(self.dims[0]*self.dims[1]) 80 | self.feature = int(self.input * 2) 81 | 82 | # Output dimension changes for final module if not an even distribution of places 83 | if not out_dim_remainder is None: 84 | self.output = out_dim_remainder 85 | else: 86 | self.output = out_dim 87 | 88 | # set model name for default demo 89 | self.demo = './vprtempo/models/springfall_VPRTempoQuant_IN3136_FN6272_DB500.pth' 90 | 91 | """ 92 | Define trainable layers here 93 | """ 94 | self.add_layer( 95 | 'feature_layer', 96 | dims=[self.input, self.feature], 97 | device=self.device, 98 | inference=True 99 | ) 100 | self.add_layer( 101 | 'output_layer', 102 | dims=[self.feature, self.output], 103 | device=self.device, 104 | inference=True 105 | ) 106 | 107 | def add_layer(self, name, **kwargs): 108 | """ 109 | Dynamically add a layer with given name and keyword arguments. 110 | 111 | :param name: Name of the layer to be added 112 | :type name: str 113 | :param kwargs: Hyperparameters for the layer 114 | """ 115 | # Check for layer name duplicates 116 | if name in self.layer_dict: 117 | raise ValueError(f"Layer with name {name} already exists.") 118 | 119 | # Add a new SNNLayer with provided kwargs 120 | setattr(self, name, bn.SNNLayer(**kwargs)) 121 | 122 | # Add layer name and index to the layer_dict 123 | self.layer_dict[name] = self.layer_counter 124 | self.layer_counter += 1 125 | 126 | def evaluate(self, models, test_loader, layers=None): 127 | """ 128 | Run the inferencing model and calculate the accuracy. 
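        Works like VPRTempo.evaluate(), except spikes are quantized on the way in and
        dequantized on the way out (see forward()), so evaluation runs entirely on CPU
        with the quantized weights.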
129 | 130 | :param test_loader: Testing data loader 131 | :param layers: Layers to pass data through 132 | """ 133 | # Determine the Hardtahn max value 134 | maxSpike = (1//models[0].quant.scale).item() 135 | # Define the sequential inference model 136 | self.inferences = [] 137 | for model in models: 138 | self.inferences.append(nn.Sequential( 139 | model.feature_layer.w, 140 | model.output_layer.w, 141 | )) 142 | # Initialize the tqdm progress bar 143 | pbar = tqdm(total=self.query_places, 144 | desc="Running the test network", 145 | position=0) 146 | # Initiliaze the output spikes variable 147 | out = [] 148 | labels = [] 149 | # Run inference for the specified number of timesteps 150 | for spikes, label in test_loader: 151 | # Set device 152 | spikes = spikes.to(self.device) 153 | labels.append(label.detach().item()) 154 | # Pass through previous layers if they exist 155 | spikes = self.forward(spikes) 156 | # Add output spikes to list 157 | out.append(spikes.detach().cpu()) 158 | pbar.update(1) 159 | 160 | # Close the tqdm progress bar 161 | pbar.close() 162 | # Rehsape output spikes into a similarity matrix 163 | out = torch.stack(out, dim=2) 164 | out = out.squeeze(0).numpy() 165 | 166 | if self.skip != 0: 167 | GT = np.zeros((model.database_places, model.query_places)) 168 | skip = model.skip // model.filter 169 | # Create an array of indices for the query dimension 170 | query_indices = np.arange(model.query_places) 171 | 172 | # Set the ones on the diagonal starting at row `skip` 173 | GT[skip + query_indices, query_indices] = 1 174 | else: 175 | GT = np.eye(model.database_places, model.query_places) 176 | 177 | # Apply GT tolerance 178 | if self.GT_tolerance > 0: 179 | # Get the number of rows and columns 180 | num_rows, num_cols = GT.shape 181 | 182 | # Iterate over each column 183 | for col in range(num_cols): 184 | # Find the indices of rows where GT has a 1 in the current column 185 | ones_indices = np.where(GT[:, col] == 1)[0] 186 | 187 | # For each index with a 1, set 1s in GTtol within the specified vertical distance 188 | for row in ones_indices: 189 | # Determine the start and end rows, ensuring they are within bounds 190 | start_row = max(row - self.GT_tolerance, 0) 191 | end_row = min(row + self.GT_tolerance + 1, num_rows) # +1 because upper bound is exclusive 192 | 193 | # Set the range in GTtol to 1 194 | GT[start_row:end_row, col] = 1 195 | 196 | # If user specified, generate a PR curve 197 | if model.PR_curve: 198 | # Create PR curve 199 | P, R = createPR(out, GT, matching='single', n_thresh=100) 200 | # Combine P and R into a list of lists 201 | PR_data = { 202 | "Precision": P, 203 | "Recall": R 204 | } 205 | output_file = "PR_curve_data.json" 206 | # Construct the full path 207 | full_path = f"{model.output_folder}/{output_file}" 208 | # Write the data to a JSON file 209 | with open(full_path, 'w') as file: 210 | json.dump(PR_data, file) 211 | # Plot PR curve 212 | plt.plot(R,P) 213 | plt.xlabel('Recall') 214 | plt.ylabel('Precision') 215 | plt.title('Precision-Recall Curve') 216 | plt.show() 217 | 218 | if model.sim_mat: 219 | # Create a figure and a set of subplots 220 | fig, axs = plt.subplots(1, 2, figsize=(15, 5)) 221 | 222 | # Plot each matrix using matshow 223 | cax1 = axs[0].matshow(out, cmap='viridis') 224 | fig.colorbar(cax1, ax=axs[0], shrink=0.8) 225 | axs[0].set_title('Similarity matrix') 226 | 227 | cax2 = axs[1].matshow(GT, cmap='plasma') 228 | fig.colorbar(cax2, ax=axs[1], shrink=0.8) 229 | axs[1].set_title('GT') 230 | 231 | # Adjust layout 232 
| plt.tight_layout() 233 | plt.show() 234 | 235 | # Recall@N 236 | N = [1,5,10,15,20,25] # N values to calculate 237 | R = [] # Recall@N values 238 | # Calculate Recall@N 239 | for n in N: 240 | R.append(round(recallAtK(out,GT,K=n),2)) 241 | # Print the results 242 | table = PrettyTable() 243 | table.field_names = ["N", "1", "5", "10", "15", "20", "25"] 244 | table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]]) 245 | self.logger.info(table) 246 | 247 | def forward(self, spikes): 248 | """ 249 | Compute the forward pass of the model. 250 | 251 | Parameters: 252 | - spikes (Tensor): Input spikes. 253 | 254 | Returns: 255 | - Tensor: Output after processing. 256 | """ 257 | spikes = self.quant(spikes) 258 | in_spikes = spikes.detach().clone() 259 | outputs = [] # List to collect output tensors 260 | 261 | for inference in self.inferences: 262 | out_spikes = inference(in_spikes) 263 | outputs.append(out_spikes) # Append the output tensor to the list 264 | 265 | # Concatenate along the desired dimension 266 | concatenated_output = torch.cat(outputs, dim=1) 267 | spikes = self.dequant(concatenated_output) 268 | 269 | return spikes 270 | 271 | def load_model(self, models, model_path): 272 | """ 273 | Load pre-trained model and set the state dictionary keys. 274 | """ 275 | # check if model exists, download the demo data and model if defaults are set 276 | if not os.path.exists(model_path): 277 | if model_path == self.demo: 278 | get_data_model() 279 | else: 280 | raise ValueError(f"Model path {model_path} does not exist.") 281 | combined_state_dict = torch.load(model_path, map_location=self.device, weights_only=True) 282 | 283 | for i, model in enumerate(models): # models_classes is a list of model classes 284 | 285 | model.load_state_dict(combined_state_dict[f'model_{i}']) 286 | model.eval() # Set the model to inference mode 287 | 288 | def run_inference_quant(models, model_name): 289 | """ 290 | Run inference on a pre-trained model. 
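    Follows the same steps as run_inference(), but loads the quantized weights and
    keeps all computation on the CPU.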
291 | 292 | :param models: Models to run inference on, each model is a VPRTempo module 293 | :param model_name: Name of the model to load 294 | """ 295 | # Set first index model as the main model for parameters 296 | model = models[0] 297 | # Initialize the image transforms 298 | image_transform = ProcessImage(model.dims, model.patches) 299 | 300 | # Initialize the test dataset 301 | test_dataset = CustomImageDataset(annotations_file=model.dataset_file, 302 | base_dir=model.data_dir, 303 | img_dirs=model.query_dir, 304 | transform=image_transform, 305 | max_samples=model.query_places, 306 | filter=model.filter, 307 | skip=model.skip) 308 | 309 | # Initialize the data loader 310 | test_loader = DataLoader(test_dataset, 311 | batch_size=1, 312 | num_workers=8, 313 | persistent_workers=True) 314 | 315 | # Load the model 316 | model.load_model(models, os.path.join('./vprtempo/models', model_name)) 317 | 318 | # Use evaluate method for inference accuracy 319 | with torch.no_grad(): 320 | model.evaluate(models, test_loader) -------------------------------------------------------------------------------- /vprtempo/VPRTempoQuantTrain.py: -------------------------------------------------------------------------------- 1 | #MIT License 2 | 3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer 4 | 5 | #Permission is hereby granted, free of charge, to any person obtaining a copy 6 | #of this software and associated documentation files (the "Software"), to deal 7 | #in the Software without restriction, including without limitation the rights 8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | #copies of the Software, and to permit persons to whom the Software is 10 | #furnished to do so, subject to the following conditions: 11 | 12 | #The above copyright notice and this permission notice shall be included in all 13 | #copies or substantial portions of the Software. 14 | 15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | #SOFTWARE. 
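# VPRTempoQuantTrain.py trains the quantized network: the feature and output layers
# are trained with STDP on CPU, then the models are converted with
# torch.quantization.convert() before saving.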
22 | 23 | ''' 24 | Imports 25 | ''' 26 | 27 | import os 28 | import torch 29 | 30 | import numpy as np 31 | import torch.nn as nn 32 | import vprtempo.src.blitnet as bn 33 | import torch.quantization as quantization 34 | import torchvision.transforms as transforms 35 | 36 | from tqdm import tqdm 37 | from torch.utils.data import DataLoader 38 | from torch.ao.quantization import QuantStub, DeQuantStub 39 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage 40 | 41 | class VPRTempoQuantTrain(nn.Module): 42 | def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None): 43 | super(VPRTempoQuantTrain, self).__init__() 44 | 45 | # Set the arguments 46 | self.args = args 47 | for arg in vars(args): 48 | setattr(self, arg, getattr(args, arg)) 49 | setattr(self, 'dims', dims) 50 | 51 | self.device = "cpu" 52 | self.logger = logger 53 | self.num_modules = num_modules 54 | self.quant = QuantStub() 55 | self.dequant = DeQuantStub() 56 | 57 | # Set the dataset file 58 | fields = self.database_dirs.split(',') 59 | if len(fields) > 1: 60 | self.dataset_file = [] 61 | for field in fields: 62 | self.dataset_file.append(os.path.join('./vprtempo/dataset', f'{self.dataset}-{field}' + '.csv')) 63 | else: 64 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.database_dirs}' + '.csv') 65 | 66 | # Layer dict to keep track of layer names and their order 67 | self.layer_dict = {} 68 | self.layer_counter = 0 69 | 70 | # Define layer architecture 71 | self.input = int(dims[0]*dims[1]) 72 | self.feature = int(self.input * 2) 73 | if not out_dim_remainder is None: 74 | self.output = out_dim_remainder 75 | else: 76 | self.output = out_dim 77 | 78 | # Set the total timestep count 79 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')] 80 | self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations 81 | if not out_dim_remainder is None: 82 | self.T = int(out_dim_remainder * self.location_repeat * self.epoch) 83 | else: 84 | self.T = int(self.max_module * self.location_repeat * self.epoch) 85 | 86 | """ 87 | Define trainable layers here 88 | """ 89 | self.add_layer( 90 | 'feature_layer', 91 | dims=[self.input, self.feature], 92 | thr_range=[0, 0.5], 93 | fire_rate=[0.2, 0.9], 94 | ip_rate=0.15, 95 | stdp_rate=0.005, 96 | p=[0.1, 0.5], 97 | device=self.device 98 | ) 99 | self.add_layer( 100 | 'output_layer', 101 | dims=[self.feature, self.output], 102 | ip_rate=0.15, 103 | stdp_rate=0.005, 104 | spk_force=True, 105 | p=[1.0, 1.0], 106 | device=self.device 107 | ) 108 | 109 | def add_layer(self, name, **kwargs): 110 | """ 111 | Dynamically add a layer with given name and keyword arguments. 112 | 113 | :param name: Name of the layer to be added 114 | :type name: str 115 | :param kwargs: Hyperparameters for the layer 116 | """ 117 | # Check for layer name duplicates 118 | if name in self.layer_dict: 119 | raise ValueError(f"Layer with name {name} already exists.") 120 | 121 | # Add a new SNNLayer with provided kwargs 122 | setattr(self, name, bn.SNNLayer(**kwargs)) 123 | 124 | # Add layer name and index to the layer_dict 125 | self.layer_dict[name] = self.layer_counter 126 | self.layer_counter += 1 127 | 128 | def _anneal_learning_rate(self, layer, mod, itp, stdp): 129 | """ 130 | Anneal the learning rate for the current layer. 
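
        Every 100 timesteps the initial ITP and STDP rates are rescaled by
        ((T - mod) / T) ** 2, so both decay quadratically towards zero over training.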
131 | """ 132 | if np.mod(mod, 100) == 0: # Modify learning rate every 100 timesteps 133 | pt = pow(float(self.T - mod) / self.T, 2) 134 | layer.eta_ip = torch.mul(itp, pt) # Anneal intrinsic threshold plasticity learning rate 135 | layer.eta_stdp = torch.mul(stdp, pt) # Anneal STDP learning rate 136 | 137 | return layer 138 | 139 | def train_model(self, train_loader, layer, model, model_num, prev_layers=None): 140 | """ 141 | Train a layer of the network model. 142 | 143 | :param train_loader: Training data loader 144 | :param layer: Layer to train 145 | :param prev_layers: Previous layers to pass data through 146 | """ 147 | 148 | # Initialize the tqdm progress bar 149 | pbar = tqdm(total=int(self.T), 150 | desc="Training ", 151 | position=0) 152 | 153 | # Initialize the learning rates for each layer (used for annealment) 154 | init_itp = layer.eta_ip.detach() 155 | init_stdp = layer.eta_stdp.detach() 156 | mod = 0 # Used to determine the learning rate annealment, resets at each epoch 157 | # idx scale factor for different modules 158 | idx_scale = (self.max_module*self.filter)*model_num 159 | # Run training for the specified number of epochs 160 | for _ in range(self.epoch): 161 | # Run training for the specified number of timesteps 162 | for spikes, labels in train_loader: 163 | spikes, labels = spikes.to(self.device), labels.to(self.device) 164 | idx = torch.round((labels - idx_scale) / self.filter) # Set output index for spike forcing 165 | # Pass through previous layers if they exist 166 | if prev_layers: 167 | with torch.no_grad(): 168 | for prev_layer_name in prev_layers: 169 | prev_layer = getattr(self, prev_layer_name) # Get the previous layer object 170 | spikes = self.forward(spikes, prev_layer) # Pass spikes through the previous layer 171 | spikes = bn.clamp_spikes(spikes, prev_layer) # Clamp spikes [0, 0.9] 172 | else: 173 | prev_layer = None 174 | # Get the output spikes from the current layer 175 | pre_spike = spikes.detach() # Previous layer spikes for STDP 176 | spikes = self.forward(spikes, layer) # Current layer spikes 177 | spikes_noclp = spikes.detach() # Used for inhibitory homeostasis 178 | spikes = bn.clamp_spikes(spikes, layer) # Clamp spikes [0, 0.9] 179 | # Calculate STDP 180 | layer = bn.calc_stdp(pre_spike,spikes,spikes_noclp,layer, idx, prev_layer=prev_layer) 181 | # Adjust learning rates 182 | layer = self._anneal_learning_rate(layer, mod, init_itp, init_stdp) 183 | # Update the annealing mod & progress bar 184 | mod += 1 185 | pbar.update(1) 186 | 187 | # Close the tqdm progress bar 188 | pbar.close() 189 | 190 | def forward(self, spikes, layer): 191 | """ 192 | Compute the forward pass of the model. 193 | 194 | Parameters: 195 | - spikes (Tensor): Input spikes. 196 | 197 | Returns: 198 | - Tensor: Output after processing. 199 | """ 200 | 201 | spikes = self.quant(spikes) 202 | spikes = layer.w(spikes) 203 | spikes = self.dequant(spikes) 204 | 205 | return spikes 206 | 207 | def save_model(self, models, model_out): 208 | """ 209 | Save the trained model to models output folder. 210 | """ 211 | state_dicts = {} 212 | for i, model in enumerate(models): # Assuming models_list is your list of models 213 | state_dicts[f'model_{i}'] = model.state_dict() 214 | 215 | torch.save(state_dicts, model_out) 216 | 217 | def generate_model_name_quant(model): 218 | """ 219 | Generate the model name based on its parameters. 
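
        The name concatenates the input, feature and output dimensions with the number
        of modules, i.e. "VPRTempoQuant<input><feature><output><num_modules>.pth".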
220 | """ 221 | return ("VPRTempoQuant" + 222 | str(model.input) + 223 | str(model.feature) + 224 | str(model.output) + 225 | str(model.num_modules) + 226 | '.pth') 227 | 228 | def check_pretrained_model(model_name): 229 | """ 230 | Check if a pre-trained model exists and prompt the user to retrain if desired. 231 | """ 232 | if os.path.exists(os.path.join('./models', model_name)): 233 | prompt = "A network with these parameters exists, re-train network? (y/n):\n" 234 | retrain = input(prompt).strip().lower() 235 | return retrain == 'n' 236 | return False 237 | 238 | def train_new_model_quant(models, model_name): 239 | """ 240 | Train a new model. 241 | 242 | :param model: Model to train 243 | :param model_name: Name of the model to save after training 244 | """ 245 | # Set first index model as the main model for parameters 246 | model = models[0] 247 | # Initialize the image transforms and datasets 248 | image_transform = transforms.Compose([ 249 | ProcessImage(model.dims, model.patches) 250 | ]) 251 | # Automatically generate user_input_ranges 252 | user_input_ranges = [] 253 | start_idx = 0 254 | # Generate the image ranges for each module 255 | for _ in range(model.num_modules): 256 | range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)] 257 | user_input_ranges.append(range_temp) 258 | start_idx = range_temp[1] + model.filter 259 | 260 | # Keep track of trained layers to pass data through them 261 | trained_layers = [] 262 | 263 | # Training each layer 264 | for layer_name, _ in sorted(model.layer_dict.items(), key=lambda item: item[1]): 265 | print(f"Training layer: {layer_name}") 266 | # Retrieve the layer object 267 | for i, model in enumerate(models): 268 | model.train() 269 | model.to(torch.device(model.device)) 270 | layer = (getattr(model, layer_name)) 271 | # Determine the maximum samples for the DataLoader 272 | if model.database_places < model.max_module: 273 | max_samples = model.database_places 274 | elif model.output < model.max_module: 275 | max_samples = model.output 276 | else: 277 | max_samples = model.max_module 278 | # Initialize new dataset with unique range for each module 279 | img_range=user_input_ranges[i] 280 | train_dataset = CustomImageDataset(annotations_file=model.dataset_file, 281 | base_dir=model.data_dir, 282 | img_dirs=model.database_dirs, 283 | transform=image_transform, 284 | filter=models[0].filter, 285 | skip=models[0].skip, 286 | test=False, 287 | img_range=img_range, 288 | max_samples=max_samples) 289 | # Initialize the data loader 290 | train_loader = DataLoader(train_dataset, 291 | batch_size=1, 292 | shuffle=True, 293 | num_workers=8, 294 | persistent_workers=True) 295 | # Train the layers 296 | model.train_model(train_loader, layer, model, i, prev_layers=trained_layers) 297 | trained_layers.append(layer_name) 298 | 299 | # Convert the model to evaluation mode 300 | for model in models: 301 | quantization.convert(model, inplace=True) 302 | 303 | # Save the model 304 | model.save_model(models,os.path.join('./vprtempo/models', model_name)) -------------------------------------------------------------------------------- /vprtempo/VPRTempoTrain.py: -------------------------------------------------------------------------------- 1 | #MIT License 2 | 3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer 4 | 5 | #Permission is hereby granted, free of charge, to any person obtaining a copy 6 | #of this software and associated documentation files (the "Software"), to deal 7 | #in the Software without 
restriction, including without limitation the rights 8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | #copies of the Software, and to permit persons to whom the Software is 10 | #furnished to do so, subject to the following conditions: 11 | 12 | #The above copyright notice and this permission notice shall be included in all 13 | #copies or substantial portions of the Software. 14 | 15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | #SOFTWARE. 22 | 23 | ''' 24 | Imports 25 | ''' 26 | 27 | import os 28 | import gc 29 | import torch 30 | import sys 31 | 32 | import numpy as np 33 | import torch.nn as nn 34 | import vprtempo.src.blitnet as bn 35 | import torchvision.transforms as transforms 36 | 37 | from tqdm import tqdm 38 | from torch.utils.data import DataLoader 39 | from vprtempo.src.loggers import model_logger 40 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage 41 | 42 | class VPRTempoTrain(nn.Module): 43 | def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None): 44 | super(VPRTempoTrain, self).__init__() 45 | 46 | # Set the arguments 47 | self.args = args 48 | for arg in vars(args): 49 | setattr(self, arg, getattr(args, arg)) 50 | setattr(self, 'dims', dims) 51 | # Set the device 52 | if torch.cuda.is_available(): 53 | self.device = "cuda:0" 54 | elif torch.backends.mps.is_available(): 55 | self.device = "mps" 56 | else: 57 | self.device = "cpu" 58 | self.logger = logger 59 | self.num_modules = num_modules 60 | 61 | # Set the dataset file 62 | fields = self.database_dirs.split(',') 63 | if len(fields) > 1: 64 | self.dataset_file = [] 65 | for field in fields: 66 | self.dataset_file.append(os.path.join('./vprtempo/dataset', f'{self.dataset}-{field}' + '.csv')) 67 | else: 68 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.database_dirs}' + '.csv') 69 | 70 | # Layer dict to keep track of layer names and their order 71 | self.layer_dict = {} 72 | self.layer_counter = 0 73 | 74 | # Define layer architecture 75 | self.input = int(dims[0]*dims[1]) 76 | self.feature = int(self.input * 2) 77 | if not out_dim_remainder is None: 78 | self.output = out_dim_remainder 79 | else: 80 | self.output = out_dim 81 | 82 | # Set the total timestep count 83 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')] 84 | self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations 85 | if not out_dim_remainder is None: 86 | self.T = int(out_dim_remainder * self.location_repeat * self.epoch) 87 | else: 88 | self.T = int(self.max_module * self.location_repeat * self.epoch) 89 | 90 | """ 91 | Define trainable layers here 92 | """ 93 | self.add_layer( 94 | 'feature_layer', 95 | dims=[self.input, self.feature], 96 | thr_range=[0, 0.5], 97 | fire_rate=[0.2, 0.9], 98 | ip_rate=0.15, 99 | stdp_rate=0.005, 100 | p=[0.1, 0.5], 101 | device=self.device 102 | ) 103 | self.add_layer( 104 | 'output_layer', 105 | dims=[self.feature, self.output], 106 | ip_rate=0.15, 107 | stdp_rate=0.005, 108 | p=[1.0, 1.0], 109 | spk_force=True, 110 
| device=self.device 111 | ) 112 | 113 | def add_layer(self, name, **kwargs): 114 | """ 115 | Dynamically add a layer with given name and keyword arguments. 116 | 117 | :param name: Name of the layer to be added 118 | :type name: str 119 | :param kwargs: Hyperparameters for the layer 120 | """ 121 | # Check for layer name duplicates 122 | if name in self.layer_dict: 123 | raise ValueError(f"Layer with name {name} already exists.") 124 | 125 | # Add a new SNNLayer with provided kwargs 126 | setattr(self, name, bn.SNNLayer(**kwargs)) 127 | 128 | # Add layer name and index to the layer_dict 129 | self.layer_dict[name] = self.layer_counter 130 | self.layer_counter += 1 131 | 132 | def model_logger(self): 133 | """ 134 | Log the model configuration to the console. 135 | """ 136 | model_logger(self) 137 | 138 | def _anneal_learning_rate(self, layer, mod, itp, stdp): 139 | """ 140 | Anneal the learning rate for the current layer. 141 | """ 142 | if np.mod(mod, 100) == 0: # Modify learning rate every 100 timesteps 143 | pt = pow(float(self.T - mod) / self.T, 2) 144 | layer.eta_ip = torch.mul(itp, pt) # Anneal intrinsic threshold plasticity learning rate 145 | layer.eta_stdp = torch.mul(stdp, pt) # Anneal STDP learning rate 146 | 147 | return layer 148 | 149 | def train_model(self, train_loader, layer, model, model_num, prev_layers=None): 150 | """ 151 | Train a layer of the network model. 152 | 153 | :param train_loader: Training data loader 154 | :param layer: Layer to train 155 | :param prev_layers: Previous layers to pass data through 156 | """ 157 | 158 | # Initialize the tqdm progress bar 159 | pbar = tqdm(total=self.T, 160 | desc=f"Module {model_num+1}", 161 | position=0) 162 | 163 | # Initialize the learning rates for each layer (used for annealment) 164 | init_itp = layer.eta_ip.detach() 165 | init_stdp = layer.eta_stdp.detach() 166 | mod = 0 # Used to determine the learning rate annealment, resets at each epoch 167 | 168 | # idx scale factor for different modules 169 | idx_scale = (self.max_module*self.filter)*model_num 170 | 171 | # Run training for the specified number of epochs 172 | for _ in range(self.epoch): 173 | # Run training for the specified number of timesteps 174 | for spikes, labels in train_loader: 175 | spikes, labels = spikes.to(self.device), labels.to(self.device) 176 | idx = torch.round((labels - idx_scale) / self.filter) # Set output index for spike forcing 177 | # Pass through previous layers if they exist 178 | if prev_layers: 179 | with torch.no_grad(): 180 | for prev_layer_name in prev_layers: 181 | prev_layer = getattr(model, prev_layer_name) # Get the previous layer object 182 | spikes = self.forward(spikes, prev_layer) # Pass spikes through the previous layer 183 | spikes = bn.clamp_spikes(spikes, prev_layer) # Clamp spikes [0, 0.9] 184 | else: 185 | prev_layer = None 186 | # Get the output spikes from the current layer 187 | pre_spike = spikes.detach() # Previous layer spikes for STDP 188 | spikes = self.forward(spikes, layer) # Current layer spikes 189 | spikes_noclp = spikes.detach() # Used for inhibitory homeostasis 190 | spikes = bn.clamp_spikes(spikes, layer) # Clamp spikes [0, 0.9] 191 | # Calculate STDP 192 | layer = bn.calc_stdp(pre_spike,spikes,spikes_noclp,layer, idx, prev_layer=prev_layer) 193 | # Adjust learning rates 194 | layer = self._anneal_learning_rate(layer, mod, init_itp, init_stdp) 195 | # Update the annealing mod & progress bar 196 | mod += 1 197 | pbar.update(1) 198 | 199 | # Close the tqdm progress bar 200 | pbar.close() 201 | 202 | # 
Free up memory 203 | if self.device == "cuda:0": 204 | torch.cuda.empty_cache() 205 | gc.collect() 206 | 207 | def forward(self, spikes, layer): 208 | """ 209 | Compute the forward pass of the model. 210 | 211 | Parameters: 212 | - spikes (Tensor): Input spikes. 213 | 214 | Returns: 215 | - Tensor: Output after processing. 216 | """ 217 | 218 | spikes = layer.w(spikes) 219 | 220 | return spikes 221 | 222 | def save_model(self,models, model_out): 223 | """ 224 | Save the trained model to models output folder. 225 | """ 226 | state_dicts = {} 227 | for i, model in enumerate(models): # Assuming models_list is your list of models 228 | state_dicts[f'model_{i}'] = model.state_dict() 229 | 230 | torch.save(state_dicts, model_out) 231 | 232 | 233 | def check_pretrained_model(model_name): 234 | """ 235 | Check if a pre-trained model exists and prompt the user to retrain if desired. 236 | """ 237 | if os.path.exists(os.path.join('./vprtempo/models', model_name)): 238 | prompt = "A network with these parameters exists, re-train network? (y/n):\n" 239 | retrain = input(prompt).strip().lower() 240 | if retrain == 'y': 241 | return True 242 | elif retrain == 'n': 243 | print('Training new model cancelled') 244 | sys.exit() 245 | 246 | def train_new_model(models, model_name): 247 | """ 248 | Train a new model. 249 | 250 | :param model: Model to train 251 | :param model_name: Name of the model to save after training 252 | :param qconfig: Quantization configuration 253 | """ 254 | # Initialize the image transforms and datasets 255 | image_transform = transforms.Compose([ 256 | ProcessImage(models[0].dims, models[0].patches) 257 | ]) 258 | # Automatically generate user_input_ranges 259 | user_input_ranges = [] 260 | start_idx = 0 261 | # Generate the image ranges for each module 262 | for _ in range(models[0].num_modules): 263 | range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)] 264 | user_input_ranges.append(range_temp) 265 | start_idx = range_temp[1] + models[0].filter 266 | 267 | # Keep track of trained layers to pass data through them 268 | trained_layers = [] 269 | # Training each layer 270 | for layer_name, _ in sorted(models[0].layer_dict.items(), key=lambda item: item[1]): 271 | print(f"Training layer: {layer_name}") 272 | # Retrieve the layer object 273 | for i, model in enumerate(models): 274 | model.train() 275 | model.to(torch.device(model.device)) 276 | layer = (getattr(model, layer_name)) 277 | # Determine the maximum samples for the DataLoader 278 | if model.database_places < model.max_module: 279 | max_samples = model.database_places 280 | elif model.output < model.max_module: 281 | max_samples = model.output 282 | else: 283 | max_samples = model.max_module 284 | # Initialize new dataset with unique range for each module 285 | img_range=user_input_ranges[i] 286 | train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file, 287 | base_dir=models[0].data_dir, 288 | img_dirs=models[0].database_dirs, 289 | transform=image_transform, 290 | filter=models[0].filter, 291 | skip=models[0].skip, 292 | test=False, 293 | img_range=img_range, 294 | max_samples=max_samples) 295 | # Initialize the data loader 296 | train_loader = DataLoader(train_dataset, 297 | batch_size=1, 298 | shuffle=True, 299 | num_workers=8, 300 | persistent_workers=True) 301 | # Train the layers 302 | model.train_model(train_loader, layer, model, i, prev_layers=trained_layers) 303 | model.to(torch.device("cpu")) 304 | # After training the current layer, add it to the list of trained 
layers 305 | trained_layers.append(layer_name) 306 | # Convert the model to evaluation mode 307 | for model in models: 308 | model.eval() 309 | # Save the model 310 | model.save_model(models,os.path.join('./vprtempo/models', model_name)) -------------------------------------------------------------------------------- /vprtempo/__init__.py: -------------------------------------------------------------------------------- 1 | _version__ = '1.1.9' 2 | -------------------------------------------------------------------------------- /vprtempo/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/vprtempo/models/.gitkeep -------------------------------------------------------------------------------- /vprtempo/models/README.txt: -------------------------------------------------------------------------------- 1 | To download pretrained models, please visit -> https://www.dropbox.com/scl/fi/ysfz7t7ek6h0pslwq9hd4/VPRTempo_pretrained_models.zip?rlkey=thg0rhn0hjsyov6zov63ni11o&st=nvimet71&dl=0 2 | -------------------------------------------------------------------------------- /vprtempo/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/vprtempo/src/__init__.py -------------------------------------------------------------------------------- /vprtempo/src/blitnet.py: -------------------------------------------------------------------------------- 1 | #MIT License 2 | 3 | #Copyright (c) 2023 Adam Hines, Peter Stratton, Michael Milford, Tobias Fischer 4 | 5 | #Permission is hereby granted, free of charge, to any person obtaining a copy 6 | #of this software and associated documentation files (the "Software"), to deal 7 | #in the Software without restriction, including without limitation the rights 8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | #copies of the Software, and to permit persons to whom the Software is 10 | #furnished to do so, subject to the following conditions: 11 | 12 | #The above copyright notice and this permission notice shall be included in all 13 | #copies or substantial portions of the Software. 14 | 15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | #SOFTWARE. 
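# blitnet.py contains the BLiTNet building blocks: SNNLayer (excitatory/inhibitory
# weights, thresholds and firing rates) and the helper functions add_input,
# clamp_spikes and calc_stdp used by the training loops.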
22 | 23 | ''' 24 | Imports 25 | ''' 26 | import torch 27 | 28 | import torch.nn as nn 29 | import numpy as np 30 | 31 | 32 | class SNNLayer(nn.Module): 33 | def __init__(self, dims=[0,0],thr_range=[0,0],fire_rate=[0,0],ip_rate=0, 34 | stdp_rate=0,const_inp=[0,0],p=[1,1],spk_force=False,device=None,inference=False,args=None): 35 | super(SNNLayer, self).__init__() 36 | """ 37 | dims: [input, output] dimensions of the layer 38 | thr_range: [min, max] range of thresholds 39 | fire_rate: [min, max] range of firing rates 40 | ip_rate: learning rate for input threshold plasticity 41 | stdp_rate: learning rate for stdp 42 | const_inp: [min, max] range of constant input 43 | p: [min, max] range of connection probabilities 44 | spk_force: boolean to force spikes 45 | """ 46 | 47 | # Device 48 | self.device = device 49 | # Add different parameters depending if trainnig or running inference model 50 | if inference: # If running inference model 51 | self.w = nn.Linear(dims[0], dims[1], bias=False) # Combined weight tensors 52 | self.w.to(device) 53 | self.thr = nn.Parameter(torch.zeros([1, dims[-1]], 54 | device=self.device).uniform_(thr_range[0], 55 | thr_range[1])) 56 | else: # If training new model 57 | # Check constraints etc 58 | if np.isscalar(thr_range): thr_range = [thr_range, thr_range] 59 | if np.isscalar(fire_rate): fire_rate = [fire_rate, fire_rate] 60 | if np.isscalar(const_inp): const_inp = [const_inp, const_inp] 61 | 62 | # Initialize Tensors 63 | self.x = torch.zeros([1, dims[-1]], device=self.device) 64 | self.eta_ip = torch.tensor(ip_rate, device=self.device) 65 | self.eta_stdp = torch.tensor(stdp_rate, device=self.device) 66 | 67 | # Initialize Parameters 68 | self.thr = nn.Parameter(torch.zeros([1, dims[-1]], 69 | device=self.device).uniform_(thr_range[0], 70 | thr_range[1])) 71 | self.fire_rate = torch.zeros([1,dims[-1]], device=self.device).uniform_(fire_rate[0], fire_rate[1]) 72 | 73 | # Sequentially set the feature firing rates (if any) 74 | if not torch.all(self.fire_rate==0).item(): 75 | fstep = (fire_rate[1]-fire_rate[0])/dims[-1] 76 | 77 | for i in range(dims[-1]): 78 | self.fire_rate[:,i] = fire_rate[0]+fstep*(i+1) 79 | 80 | self.have_rate = torch.any(self.fire_rate[:,0] > 0.0).to(self.device) 81 | self.const_inp = torch.zeros([1, dims[-1]], device=self.device).uniform_(const_inp[0], const_inp[1]) 82 | self.p = p 83 | self.dims = dims 84 | 85 | # Additional State Variables 86 | self.set_spks = [] 87 | self.sspk_idx = 0 88 | self.spikes = torch.empty([], dtype=torch.float64) 89 | self.spk_force = spk_force 90 | 91 | # Create the excitatory weights 92 | self.exc = nn.Linear(dims[0], dims[1], bias=False) 93 | self.exc.weight = self.addWeights(dims=dims, 94 | W_range=[0,1], 95 | p=p[0], 96 | device=device) 97 | 98 | # Create the inhibitory weights 99 | self.inh = nn.Linear(dims[0], dims[1], bias=False) 100 | self.inh.weight = self.addWeights(dims=dims, 101 | W_range=[-1,0], 102 | p=p[-1], 103 | device=device) 104 | 105 | # Output boolean reference of which neurons have connection weights 106 | self.havconnExc = self.exc.weight > 0 107 | self.havconnInh = self.inh.weight < 0 108 | 109 | # Combine weights into a single tensor 110 | self.w = nn.Linear(dims[0], dims[1], bias=False) 111 | self.w.weight = nn.Parameter(torch.add(self.exc.weight, self.inh.weight)) 112 | 113 | self.havconnCombinedExc = self.w.weight > 0 114 | self.havconnCombinedInh = self.w.weight < 0 115 | 116 | del self.exc, self.inh 117 | 118 | def addWeights(self,W_range=[0,0],p=[0,0],dims=[0,0],device=None): 119 | 
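        """
        Create a sparse, sign-constrained weight matrix as an nn.Parameter.

        W_range: [min, max] weight range; weights are drawn from a normal distribution
                 with mean (min+max)/2 and std (max-min)/6, then clipped to the
                 excitatory or inhibitory sign implied by the range
        p: connection probability; weights are zeroed with probability 1 - p
        dims: [input, output] dimensions (the matrix is output x input)
        device: torch device to allocate the weights on

        The surviving weights are normalised (L1) before the Parameter is returned.
        """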
120 | # Get torch device 121 | device = device 122 | 123 | # Check constraints etc 124 | if np.isscalar(W_range): W_range = [W_range,W_range] 125 | 126 | # Determine dimensions of the weight matrices 127 | nrow = dims[1] 128 | ncol = dims[0] 129 | 130 | # Calculate mean and std for normal distributions 131 | Wmn = (W_range[0]+W_range[1])/2.0 132 | Wsd = (W_range[1]-W_range[0])/6.0 133 | 134 | # Initialize weights as empty tensors 135 | W = torch.empty((0, nrow, ncol), device=device) 136 | 137 | # Normally disribute random weights 138 | W = torch.empty(nrow, ncol, device=device).normal_(mean=Wmn, std=Wsd) 139 | 140 | # Remove inappropriate weights based on sign from W_range 141 | if W_range[-1] != 0: 142 | # For excitatory weights 143 | W[W < 0] = 0.0 144 | else: 145 | # For inhibitory weights 146 | W[W > 0] = 0.0 147 | 148 | # Remove weights based on connection probability 149 | setzero = np.random.rand(nrow,ncol) > p 150 | if setzero.any(): 151 | W[setzero] = 0.0 152 | 153 | # Normalise the weights 154 | nrm = torch.linalg.norm(W[len(W)-1],ord=1,axis=0) 155 | nrm[nrm==0.0] = 1.0 156 | W = nn.Parameter(W/nrm) 157 | 158 | return W 159 | 160 | def add_input(spikes, layer): 161 | 162 | # Add the constant input 163 | spikes += layer.const_inp 164 | 165 | return spikes 166 | 167 | def clamp_spikes(spikes, layer): 168 | # Clamp outputs between 0 and 0.9 after subtracting thresholds from input 169 | spikes = torch.clamp(torch.sub(spikes, layer.thr), min=0.0, max=0.9) 170 | 171 | return spikes 172 | 173 | def calc_stdp(prespike, spikes, noclp, layer, idx, prev_layer=None): 174 | # Spike Forcing has special rules to make calculated and forced spikes match 175 | if layer.spk_force: 176 | 177 | # Get layer dimensions 178 | shape = layer.w.weight.data.shape 179 | 180 | # Get the output neuron index 181 | idx_sel = torch.arange(int(idx[0]), int(idx[0]) + 1, 182 | device=layer.device, 183 | dtype=int) 184 | 185 | # Difference between forced and calculated spikes 186 | layer.x = torch.full_like(layer.x, 0) 187 | xdiff = layer.x.index_fill_(-1, idx_sel, 0.5) - spikes 188 | xdiff.clamp(min=0.0, max=0.9) 189 | 190 | # Pre and Post spikes tiled across and down for all synapses 191 | if prev_layer.fire_rate == None: 192 | mpre = prespike 193 | else: 194 | # Modulate learning rate by firing rate (low firing rate = high learning rate) 195 | mpre = prespike/prev_layer.fire_rate 196 | 197 | # Tile out pre- and post- spikes for STDP weight updates 198 | pre = torch.tile(torch.reshape(mpre, (shape[1], 1)), (1, shape[0])) 199 | post = torch.tile(xdiff, (shape[1], 1)) 200 | 201 | # Apply the weight changes 202 | layer.w.weight.data += ((pre * post * layer.havconnCombinedExc.T) * 203 | layer.eta_stdp).T 204 | layer.w.weight.data += ((-pre * post * layer.havconnCombinedInh.T) * 205 | (layer.eta_stdp * -1)).T 206 | 207 | # Normal STDP 208 | else: 209 | 210 | # Get layer dimensions 211 | shape = layer.w.weight.data.shape 212 | 213 | # Tile out pre- and post-spikes 214 | pre = torch.tile(torch.reshape(prespike, (shape[1], 1)), (1, shape[0])) 215 | post = torch.tile(spikes, (shape[1], 1)) 216 | 217 | # Apply positive and negative weight changes 218 | layer.w.weight.data += (((0.5 - post) * (pre > 0) * (post > 0) * 219 | layer.havconnCombinedExc.T) * layer.eta_stdp).T 220 | layer.w.weight.data += (((0.5 - post) * (pre > 0) * 221 | (post > 0) * layer.havconnCombinedInh.T) * (layer.eta_stdp * -1)).T 222 | 223 | # Remove negative weights for excW and positive for inhW 224 | layer.w.weight.data[layer.havconnCombinedExc] = 
layer.w.weight.data[layer.havconnCombinedExc].clamp(min=1e-06, max=10) 225 | layer.w.weight.data[layer.havconnCombinedInh] = layer.w.weight.data[layer.havconnCombinedInh].clamp(min=-10, max=-1e-06) 226 | 227 | # Check if layer has target firing rate and an ITP learning rate 228 | if layer.have_rate and layer.eta_ip > 0.0: 229 | 230 | # Replace the original layer.thr with the updated one 231 | layer.thr.data += layer.eta_ip * (layer.x - layer.fire_rate) 232 | layer.thr.data[layer.thr.data < 0] = 0 233 | 234 | # Check if layer has inhibitory weights and an stdp learning rate 235 | if torch.any(layer.w.weight.data).item() and layer.eta_stdp != 0: 236 | 237 | # Normalize the inhibitory weights using homeostasis 238 | inhW = layer.w.weight.data.T.clone() 239 | inhW[inhW>0] = 0 240 | layer.w.weight.data += (torch.mul(noclp,inhW) * layer.eta_stdp*50).T 241 | #layer.w.weight.data[layer.w.weight.data > 0.0] = -1e-06 242 | 243 | return layer -------------------------------------------------------------------------------- /vprtempo/src/create_data_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | 4 | def create_csv_from_images(folder_path, csv_file_path): 5 | files = os.listdir(folder_path) 6 | png_files = sorted([f for f in files if f.endswith('.png')]) 7 | 8 | with open(csv_file_path, 'w', newline='') as file: 9 | writer = csv.writer(file) 10 | writer.writerow(['Image_name', 'index']) 11 | 12 | for index, image_name in enumerate(png_files): 13 | writer.writerow([image_name, index]) 14 | 15 | # Name of the dataset to create .csv for 16 | dataset_name = 'nordland-fall' 17 | 18 | # Generate paths 19 | folder_path = '' 20 | csv_file_path = os.path.join('./VPRTempo/vprtempo/dataset', f'{dataset_name}.csv') 21 | 22 | # Create .csv file 23 | create_csv_from_images(folder_path, csv_file_path) -------------------------------------------------------------------------------- /vprtempo/src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | 5 | import pandas as pd 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | from torchvision.io import read_image 10 | from torch.utils.data import Dataset 11 | 12 | class GetPatches2D: 13 | def __init__(self, patch_size, image_pad): 14 | self.patch_size = patch_size 15 | self.image_pad = image_pad 16 | 17 | def __call__(self, img): 18 | 19 | # Assuming image_pad is already a PyTorch tensor. If not, you can convert it: 20 | # image_pad = torch.tensor(image_pad).to(torch.float64) 21 | 22 | # Using unfold to get 2D sliding windows. 
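        # Each stride-1 window over the padded image becomes one column of the
        # output, i.e. one patch per pixel of the original image given the
        # (patch_size - 1) / 2 padding applied in PatchNormalisePad. Note that
        # the patches come from the padded tensor stored at construction time;
        # the `img` argument itself is not used here.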
23 | unfolded = self.image_pad.unfold(0, self.patch_size[0], 1).unfold(1, self.patch_size[1], 1) 24 | # The size of unfolded will be [nrows, ncols, patch_size[0], patch_size[1]] 25 | 26 | # Reshaping the tensor to the desired shape 27 | patches = unfolded.permute(2, 3, 0, 1).contiguous().view(self.patch_size[0]*self.patch_size[1], -1) 28 | 29 | return patches 30 | 31 | 32 | class PatchNormalisePad: 33 | def __init__(self, patches): 34 | self.patches = patches 35 | 36 | 37 | def nanstd(self,input_tensor, dim=None, unbiased=True): 38 | if dim is not None: 39 | valid_count = torch.sum(~torch.isnan(input_tensor), dim=dim, dtype=torch.float) 40 | mean = torch.nansum(input_tensor, dim=dim) / valid_count 41 | diff = input_tensor - mean.unsqueeze(dim) 42 | variance = torch.nansum(diff * diff, dim=dim) / valid_count 43 | 44 | # Bessel's correction for unbiased estimation 45 | if unbiased: 46 | variance = variance * (valid_count / (valid_count - 1)) 47 | else: 48 | valid_count = torch.sum(~torch.isnan(input_tensor), dtype=torch.float) 49 | mean = torch.nansum(input_tensor) / valid_count 50 | diff = input_tensor - mean 51 | variance = torch.nansum(diff * diff) / valid_count 52 | 53 | # Bessel's correction for unbiased estimation 54 | if unbiased: 55 | variance = variance * (valid_count / (valid_count - 1)) 56 | 57 | return torch.sqrt(variance) 58 | 59 | def __call__(self, img): 60 | img = torch.squeeze(img,0) 61 | patch_size = (self.patches, self.patches) 62 | patch_half_size = [int((p-1)/2) for p in patch_size ] 63 | 64 | # Compute the padding. If patch_half_size is a scalar, the same value will be used for all sides. 65 | if isinstance(patch_half_size, int): 66 | pad = (patch_half_size, patch_half_size, patch_half_size, patch_half_size) # left, right, top, bottom 67 | else: 68 | # If patch_half_size is a tuple, then we'll assume it's in the format (height, width) 69 | pad = (patch_half_size[1], patch_half_size[1], patch_half_size[0], patch_half_size[0]) # left, right, top, bottom 70 | 71 | # Apply padding 72 | image_pad = F.pad(img, pad, mode='constant', value=float('nan')) 73 | 74 | nrows = img.shape[0] 75 | ncols = img.shape[1] 76 | patcher = GetPatches2D(patch_size,image_pad) 77 | patches = patcher(img) 78 | mus = torch.nanmean(patches, dim=0) 79 | stds = self.nanstd(patches, dim=0) 80 | with np.errstate(divide='ignore', invalid='ignore'): 81 | im_norm = (img - mus.reshape(nrows, ncols)) / stds.reshape(nrows, ncols) 82 | 83 | im_norm[torch.isnan(im_norm)] = 0.0 84 | im_norm[im_norm < -1.0] = -1.0 85 | im_norm[im_norm > 1.0] = 1.0 86 | 87 | return im_norm 88 | 89 | class SetImageAsSpikes: 90 | def __init__(self, intensity=255, test=True): 91 | self.intensity = intensity 92 | 93 | # Setup QAT FakeQuantize for the activations (your spikes) 94 | self.fake_quantize = torch.quantization.FakeQuantize( 95 | observer=torch.quantization.MovingAverageMinMaxObserver, 96 | quant_min=0, 97 | quant_max=255, 98 | dtype=torch.quint8, 99 | qscheme=torch.per_tensor_affine, 100 | reduce_range=False 101 | ) 102 | 103 | def train(self): 104 | self.fake_quantize.train() 105 | 106 | def eval(self): 107 | self.fake_quantize.eval() 108 | 109 | def __call__(self, img_tensor): 110 | N, W, H = img_tensor.shape 111 | reshaped_batch = img_tensor.view(N, 1, -1) 112 | 113 | # Divide all pixel values by 255 114 | normalized_batch = reshaped_batch / self.intensity 115 | normalized_batch = torch.squeeze(normalized_batch, 0) 116 | 117 | # Apply FakeQuantize 118 | spikes = self.fake_quantize(normalized_batch) 119 | 120 | if not 
self.fake_quantize.training: 121 | scale, zero_point = self.fake_quantize.calculate_qparams() 122 | spikes = torch.quantize_per_tensor(spikes, float(scale), int(zero_point), dtype=torch.quint8) 123 | 124 | return spikes 125 | 126 | class ProcessImage: 127 | def __init__(self, dims, patches): 128 | self.dims = dims 129 | self.patches = patches 130 | 131 | def __call__(self, img): 132 | # Convert the image to grayscale using the standard weights for RGB channels 133 | if img.shape[0] == 3: 134 | img = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2] 135 | # Add a channel dimension to the resulting grayscale image 136 | img= img.unsqueeze(0) 137 | 138 | # gamma correction 139 | mid = 0.5 140 | mean = torch.mean(img.float()) 141 | gamma = math.log(mid * 255) / math.log(mean) 142 | img = torch.pow(img, gamma).clip(0, 255) 143 | 144 | # resize and patch normalize 145 | if len(img.shape) == 3: 146 | img = img.unsqueeze(0) 147 | img = F.interpolate(img, size=self.dims, mode='bilinear', align_corners=False) 148 | img = img.squeeze(0) 149 | patch_normaliser = PatchNormalisePad(self.patches) 150 | im_norm = patch_normaliser(img) 151 | img = (255.0 * (1 + im_norm) / 2.0).to(dtype=torch.uint8) 152 | img = torch.unsqueeze(img,0) 153 | spike_maker = SetImageAsSpikes() 154 | img = spike_maker(img) 155 | img = torch.squeeze(img,0) 156 | 157 | return img 158 | 159 | class CustomImageDataset(Dataset): 160 | def __init__(self, annotations_file, base_dir, img_dirs, transform=None, target_transform=None, 161 | filter=1, skip=0, max_samples=None, test=True, img_range=None): 162 | self.transform = transform 163 | self.target_transform = target_transform 164 | self.filter = filter 165 | self.img_range = img_range 166 | self.skip = skip 167 | 168 | # Load image labels from each directory, apply the skip and max_samples, and concatenate 169 | self.img_labels = [] 170 | # if annotations_file is a single file, convert it to a list 171 | if not isinstance(annotations_file, list): 172 | annotations_file = [annotations_file] 173 | for idx, annotation in enumerate(annotations_file): 174 | img_labels = pd.read_csv(annotation) 175 | img_labels['file_path'] = img_labels.apply(lambda row: os.path.join(base_dir,img_dirs[idx], row.iloc[0]), axis=1) 176 | if self.img_range is not None: 177 | img_labels = img_labels.iloc[self.img_range[0]:self.img_range[1]+1] 178 | # Apply skip: start after the first 'skip' number of rows 179 | if self.skip > 0: 180 | img_labels = img_labels.iloc[self.skip:] 181 | # Select specific rows based on the skip parameter 182 | img_labels = img_labels.iloc[::filter] 183 | # Limit the number of samples to max_samples if specified 184 | if max_samples is not None: 185 | img_labels = img_labels.iloc[:max_samples] 186 | # Determine if the images being fed are training or testing 187 | if test: 188 | self.img_labels = img_labels 189 | else: 190 | self.img_labels.append(img_labels) 191 | 192 | if isinstance(self.img_labels,list): 193 | # Concatenate all the DataFrames 194 | self.img_labels = pd.concat(self.img_labels, ignore_index=True) 195 | 196 | def __len__(self): 197 | return len(self.img_labels) 198 | 199 | def __getitem__(self, idx): 200 | img_path = self.img_labels.iloc[idx]['file_path'] 201 | if not os.path.exists(img_path): 202 | raise FileNotFoundError(f"No file found for index {idx} at {img_path}.") 203 | 204 | image = read_image(img_path) 205 | label = self.img_labels.iloc[idx, 1] 206 | 207 | if self.transform: 208 | image = self.transform(image) 209 | if self.target_transform: 210 | label = 
self.target_transform(label) 211 | 212 | return image, label -------------------------------------------------------------------------------- /vprtempo/src/download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import zipfile 3 | import os 4 | from tqdm import tqdm 5 | 6 | def get_data_model(): 7 | print('==== Downloading pre-trained models & Nordland images ====') 8 | # Download the pre-trained models 9 | dropbox_urls = [ 10 | "https://www.dropbox.com/scl/fi/vb8ljhp5rm1cx4tbjfxx3/VPRTempo_pretrained_models.zip?rlkey=felsqy3qbapeeztgkfcdd1zix&st=xncoy7rg&dl=0", 11 | "https://www.dropbox.com/scl/fi/445psdi7srbhuqa807lyn/spring.zip?rlkey=5ciswaz0ygv107e6pzvxnm6ga&st=tdaslyuc&dl=0", 12 | "https://www.dropbox.com/scl/fi/l2dmccham4ifj0xf9p9jw/fall.zip?rlkey=gvmt5jvzdfw8p7008yfoxeb4s&st=z14ngqyx&dl=0", 13 | "https://www.dropbox.com/scl/fi/8ff3ozh6kujbg1vbnrasw/summer.zip?rlkey=563t03cd2vwfr32i9llg945m8&st=t96te8py&dl=0" 14 | ] 15 | 16 | folders = [ 17 | "./vprtempo/models/", 18 | "./vprtempo/dataset/", 19 | "./vprtempo/dataset/", 20 | "./vprtempo/dataset/" 21 | ] 22 | 23 | names = [ 24 | "VPRTempo_pretrained_models.zip", 25 | "spring.zip", 26 | "fall.zip", 27 | "summer.zip" 28 | ] 29 | 30 | for idx, url in enumerate(dropbox_urls): 31 | download_extract(url, folders[idx], names[idx]) 32 | 33 | print('==== Downloading pre-trained models & Nordland images completed ====') 34 | 35 | def download_extract(url, folder, name): 36 | # Ensure the destination folder exists 37 | os.makedirs(folder, exist_ok=True) 38 | 39 | # Modify the URL for direct download 40 | direct_download_url = url.replace("dl=0", "dl=1") 41 | 42 | # Send a HEAD request to get the total file size 43 | with requests.head(direct_download_url, allow_redirects=True) as head: 44 | if head.status_code != 200: 45 | print(f"Failed to retrieve header for {name}. Status code: {head.status_code}") 46 | return 47 | total_size = int(head.headers.get('content-length', 0)) 48 | 49 | # Initialize the progress bar for downloading 50 | with requests.get(direct_download_url, stream=True) as response, \ 51 | open(os.path.join(folder, name), "wb") as file, \ 52 | tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {name}", ncols=80) as progress_bar: 53 | 54 | if response.status_code != 200: 55 | print(f"Failed to download {name}. 
Status code: {response.status_code}") 56 | return 57 | 58 | for chunk in response.iter_content(chunk_size=8192): 59 | if chunk: 60 | file.write(chunk) 61 | progress_bar.update(len(chunk)) 62 | 63 | # Determine extraction path 64 | if name == "VPRTempo_pretrained_models.zip": 65 | extract_path = folder 66 | else: 67 | extract_path = os.path.join(folder, name.replace('.zip', '')) 68 | 69 | # Open the zip file 70 | with zipfile.ZipFile(os.path.join(folder, name), 'r') as zip_ref: 71 | # Get list of files in the zip 72 | members = zip_ref.namelist() 73 | # Initialize the progress bar for extraction 74 | with tqdm(total=len(members), desc=f"Extracting {name}", unit='file', ncols=80) as extract_bar: 75 | for member in members: 76 | zip_ref.extract(member, path=extract_path) 77 | extract_bar.update(1) 78 | 79 | # Remove the zip file after extraction 80 | os.remove(os.path.join(folder, name)) 81 | print(f"Completed {name}") -------------------------------------------------------------------------------- /vprtempo/src/loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | 5 | from datetime import datetime 6 | 7 | def model_logger(): 8 | """ 9 | Configure the model logger 10 | """ 11 | now = datetime.now() 12 | output_base_folder = './vprtempo/output/' 13 | output_folder = output_base_folder + now.strftime("%d%m%y-%H-%M-%S") 14 | 15 | # Create the base output folder if it does not exist 16 | os.makedirs(output_base_folder, exist_ok=True) 17 | 18 | # Create the specific output folder 19 | os.mkdir(output_folder) 20 | # Create the logger 21 | logger = logging.getLogger("VPRTempo") 22 | if (logger.hasHandlers()): 23 | logger.handlers.clear() 24 | # Set the logger level 25 | logger.setLevel(logging.DEBUG) 26 | logging.basicConfig(filename=output_folder + "/logfile.log", 27 | filemode="a+", 28 | format="%(asctime)-15s %(levelname)-8s %(message)s") 29 | # Add the logger to the console (if specified) 30 | logger.addHandler(logging.StreamHandler()) 31 | 32 | logger.info('') 33 | logger.info('██╗ ██╗██████╗ ██████╗ ████████╗███████╗███╗ ███╗██████╗ ██████╗') 34 | logger.info('██║ ██║██╔══██╗██╔══██╗╚══██╔══╝██╔════╝████╗ ████║██╔══██╗██╔═══██╗') 35 | logger.info('██║ ██║██████╔╝██████╔╝ ██║ █████╗ ██╔████╔██║██████╔╝██║ ██║') 36 | logger.info('╚██╗ ██╔╝██╔═══╝ ██╔══██╗ ██║ ██╔══╝ ██║╚██╔╝██║██╔═══╝ ██║ ██║') 37 | logger.info(' ╚████╔╝ ██║ ██║ ██║ ██║ ███████╗██║ ╚═╝ ██║██║ ╚██████╔╝') 38 | logger.info(' ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ') 39 | logger.info('-----------------------------------------------------------------------') 40 | logger.info('Temporally Encoded Spiking Neural Network for Visual Place Recognition v1.1.9') 41 | logger.info('Queensland University of Technology, Centre for Robotics') 42 | logger.info('') 43 | logger.info('© 2023 Adam D Hines, Peter G Stratton, Michael Milford, Tobias Fischer') 44 | logger.info('MIT license - https://github.com/QVPR/VPRTempo') 45 | logger.info('\\\\\\\\\\\\\\\\\\\\\\\\') 46 | logger.info('') 47 | if torch.cuda.is_available(): 48 | logger.info('CUDA available: ' + str(torch.cuda.is_available())) 49 | current_device = torch.cuda.current_device() 50 | logger.info('Current device is: ' + str(torch.cuda.get_device_name(current_device))) 51 | elif torch.backends.mps.is_available(): 52 | logger.info('MPS available: ' + str(torch.backends.mps.is_available())) 53 | logger.info('Current device is: MPS') 54 | else: 55 | logger.info('CUDA available: ' + 
str(torch.cuda.is_available())) 56 | logger.info('Current device is: CPU') 57 | logger.info('') 58 | 59 | return logger, output_folder 60 | 61 | def model_logger_quant(): 62 | """ 63 | Configure the logger 64 | """ 65 | 66 | now = datetime.now() 67 | output_base_folder = './vprtempo/output/' 68 | output_folder = output_base_folder + now.strftime("%d%m%y-%H-%M-%S") 69 | 70 | # Create the base output folder if it does not exist 71 | os.makedirs(output_base_folder, exist_ok=True) 72 | 73 | # Create the specific output folder 74 | os.mkdir(output_folder) 75 | # Create the logger 76 | logger = logging.getLogger("VPRTempo") 77 | if (logger.hasHandlers()): 78 | logger.handlers.clear() 79 | # Set the logger level 80 | logger.setLevel(logging.DEBUG) 81 | logging.basicConfig(filename=output_folder + "/logfile.log", 82 | filemode="a+", 83 | format="%(asctime)-15s %(levelname)-8s %(message)s") 84 | # Add the logger to the console (if specified) 85 | logger.addHandler(logging.StreamHandler()) 86 | 87 | logger.info('') 88 | 89 | logger.info('██╗ ██╗██████╗ ██████╗ ████████╗███████╗███╗ ███╗██████╗ ██████╗ ██████╗ ██╗ ██╗ █████╗ ███╗ ██╗████████╗') 90 | logger.info('██║ ██║██╔══██╗██╔══██╗╚══██╔══╝██╔════╝████╗ ████║██╔══██╗██╔═══██╗ ██╔═══██╗██║ ██║██╔══██╗████╗ ██║╚══██╔══╝') 91 | logger.info('██║ ██║██████╔╝██████╔╝ ██║ █████╗ ██╔████╔██║██████╔╝██║ ██║█████╗██║ ██║██║ ██║███████║██╔██╗ ██║ ██║') 92 | logger.info('╚██╗ ██╔╝██╔═══╝ ██╔══██╗ ██║ ██╔══╝ ██║╚██╔╝██║██╔═══╝ ██║ ██║╚════╝██║▄▄ ██║██║ ██║██╔══██║██║╚██╗██║ ██║') 93 | logger.info(' ╚████╔╝ ██║ ██║ ██║ ██║ ███████╗██║ ╚═╝ ██║██║ ╚██████╔╝ ╚██████╔╝╚██████╔╝██║ ██║██║ ╚████║ ██║') 94 | logger.info(' ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ╚══▀▀═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝ ╚═╝') 95 | logger.info('-----------------------------------------------------------------------') 96 | logger.info('Temporally Encoded Spiking Neural Network for Visual Place Recognition v1.1.8') 97 | logger.info('Queensland University of Technology, Centre for Robotics') 98 | logger.info('') 99 | logger.info('© 2023 Adam D Hines, Peter G Stratton, Michael Milford, Tobias Fischer') 100 | logger.info('MIT license - https://github.com/QVPR/VPRTempo') 101 | logger.info('\\\\\\\\\\\\\\\\\\\\\\\\') 102 | logger.info('') 103 | logger.info('Quantization enabled') 104 | logger.info('Current device is: CPU') 105 | logger.info('') 106 | 107 | return logger, output_folder 108 | -------------------------------------------------------------------------------- /vprtempo/src/metrics.py: -------------------------------------------------------------------------------- 1 | # ===================================================================== 2 | # Copyright (C) 2023 Stefan Schubert, stefan.schubert@etit.tu-chemnitz.de 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 
16 | # =====================================================================
17 | #
18 | import numpy as np
19 | 
20 | 
21 | def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
22 |     """
23 |     Calculates the precision and recall at n_thresh equally spaced threshold values
24 |     for a given similarity matrix S_in and ground truth matrices GThard and GTsoft for
25 |     single-best-match VPR or multi-match VPR.
26 | 
27 |     The matrices S_in, GThard and GTsoft are two-dimensional and should all have the
28 |     same shape.
29 |     The matrices GThard and GTsoft should be binary matrices, where the entries are
30 |     only zeros or ones.
31 |     The matrix S_in should have continuous values between -Inf and Inf. Higher values
32 |     indicate higher similarity.
33 |     The string matching should be set to either "single" or "multi" for single-best-
34 |     match VPR or multi-match VPR.
35 |     The integer n_thresh controls the number of threshold values and should be >1.
36 |     """
37 | 
38 |     assert (S_in.shape == GThard.shape),"S_in, GThard and GTsoft must have the same shape"
39 |     assert (S_in.ndim == 2),"S_in, GThard and GTsoft must be two-dimensional"
40 |     assert (matching in ['single', 'multi']),"matching should contain one of the following strings: [single, multi]"
41 |     assert (n_thresh > 1),"n_thresh must be >1"
42 | 
43 |     if GTsoft is not None and matching == 'single':
44 |         raise ValueError(
45 |             "GTSoft with single matching is not supported. "
46 |             "Please use dilated hard ground truth directly. "
47 |             "For more details, visit: https://github.com/stschubert/VPR_Tutorial"
48 |         )
49 | 
50 |     # ensure logical datatype in GT and GTsoft
51 |     GT = GThard.astype('bool')
52 |     if GTsoft is not None:
53 |         GTsoft = GTsoft.astype('bool')
54 |         GThard_orig = GThard.copy()
55 | 
56 |     # copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation
57 |     S = S_in.copy()
58 | 
59 |     if GTsoft is not None:
60 |         S[GTsoft & ~GT] = S.min()
61 | 
62 |     if matching == 'single':
63 |         # GT-values for best match per query (i.e., per column)
64 |         GT = GT[np.argmax(S, axis=0), np.arange(GT.shape[1])]
65 |         # similarities for best match per query (i.e., per column)
66 |         S = np.max(S, axis=0)
67 | 
68 |     # init precision and recall vectors
69 |     R = [0, ]
70 |     P = [1, ]
71 | 
72 |     # select start and end threshold
73 |     startV = S.max()  # start-value for threshold
74 |     endV = S.min()  # end-value for threshold
75 |     thresholds = np.linspace(startV, endV, n_thresh)
76 | 
77 |     # Iterate over the thresholds
78 |     for i in thresholds:
79 |         B = S >= i  # Apply threshold
80 | 
81 |         TP = np.count_nonzero(GT & B)  # True Positives
82 |         FP = np.count_nonzero((~GT) & B)  # False Positives
83 |         FN = np.count_nonzero(GT & (~B))  # False Negatives
84 | 
85 |         # Compute precision and recall at this threshold
86 |         precision = TP / (TP + FP)
87 |         recall = TP / (TP + FN)
88 | 
89 |         P.append(precision)  # Precision
90 |         R.append(recall)  # Recall
91 | 
92 |     return P, R
93 | 
94 | 
95 | def recallAt100precision(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
96 |     """
97 |     Calculates the maximum recall at 100% precision for a given similarity matrix S_in
98 |     and ground truth matrices GThard and GTsoft for single-best-match VPR or multi-match
99 |     VPR.
100 | 
101 |     The matrices S_in, GThard and GTsoft are two-dimensional and should all have the
102 |     same shape.
103 |     The matrices GThard and GTsoft should be binary matrices, where the entries are
104 |     only zeros or ones.
105 |     The matrix S_in should have continuous values between -Inf and Inf. Higher values
106 |     indicate higher similarity.
107 |     The string matching should be set to either "single" or "multi" for single-best-
108 |     match VPR or multi-match VPR.
109 |     The integer n_thresh controls the number of threshold values during the creation of
110 |     the precision-recall curve and should be >1.
111 |     """
112 | 
113 |     assert (S_in.shape == GThard.shape),"S_in and GThard must have the same shape"
114 |     if GTsoft is not None:
115 |         assert (S_in.shape == GTsoft.shape),"S_in and GTsoft must have the same shape"
116 |     assert (S_in.ndim == 2),"S_in, GThard and GTsoft must be two-dimensional"
117 |     assert (matching in ['single', 'multi']),"matching should contain one of the following strings: [single, multi]"
118 |     assert (n_thresh > 1),"n_thresh must be >1"
119 | 
120 |     # get precision-recall curve
121 |     P, R = createPR(S_in, GThard, GTsoft, matching=matching, n_thresh=n_thresh)
122 |     P = np.array(P)
123 |     R = np.array(R)
124 | 
125 |     # recall values at 100% precision
126 |     R = R[P==1]
127 | 
128 |     # maximum recall at 100% precision
129 |     R = R.max()
130 | 
131 |     return R
132 | 
133 | 
134 | def recallAtK(S, GT, K=1):
135 |     """
136 |     Calculates the recall@K for a given similarity matrix S and ground truth matrix GT.
137 |     Note that this method does not support GTsoft - instead, please directly provide
138 |     the dilated ground truth matrix as GT.
139 | 
140 |     The matrices S and GT are two-dimensional and should both have the same shape.
141 |     The matrix GT should be binary, where the entries are only zeros or ones.
142 |     The matrix S should have continuous values between -Inf and Inf. Higher values
143 |     indicate higher similarity.
144 |     The integer K>=1 defines the number of matching candidates that are selected and
145 |     that must contain an actually matching image pair.
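    Illustrative example (made-up values, not from the VPRTempo datasets):

        S = np.array([[0.9, 0.1], [0.2, 0.8]])
        GT = np.array([[True, False], [False, True]])
        recallAtK(S, GT, K=1)  # returns 1.0 - the top candidate of each query is a true match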
146 | """ 147 | 148 | assert (S.shape == GT.shape),"S and GT must have the same shape" 149 | assert (S.ndim == 2),"S and GT must be two-dimensional" 150 | assert (K >= 1),"K must be >=1" 151 | 152 | # ensure logical datatype in GT 153 | GT = GT.astype('bool') 154 | 155 | # discard all query images without an actually matching database image 156 | j = GT.sum(0) > 0 # columns with matches 157 | S = S[:,j] # select columns with a match 158 | GT = GT[:,j] # select columns with a match 159 | 160 | # select K highest similarities 161 | i = S.argsort(0)[-K:,:] 162 | j = np.tile(np.arange(i.shape[1]), [K, 1]) 163 | GT = GT[i, j] 164 | 165 | # recall@K 166 | RatK = np.sum(GT.sum(0) > 0) / GT.shape[1] 167 | 168 | return RatK -------------------------------------------------------------------------------- /vprtempo/src/nordland.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Imports 3 | ''' 4 | import os 5 | import re 6 | import shutil 7 | import zipfile 8 | import sys 9 | sys.path.append('..//dataset') 10 | 11 | from os import walk 12 | 13 | def nord_sort(): 14 | # load and sort the file names in order, not by how OS indexes them 15 | def atoi(text): 16 | return int(text) if text.isdigit() else text 17 | 18 | def natural_keys(text): 19 | return [ atoi(c) for c in re.split(r'(\d+)', text) ] 20 | 21 | # set the base path to the location of the downloaded Nordland datasets 22 | basePath = '../dataset/' 23 | assert(os.path.isdir(basePath)),"Please set the basePath to the location of the downloaded Nordland datasets" 24 | 25 | # define the subfolders of the Nordland datasets 26 | subPath = ["spring_images_train/section1/","spring_images_train/section2/", 27 | "fall_images_train/section1/","fall_images_train/section2/", 28 | "winter_images_train/section1/","winter_images_train/section2/", 29 | "summer_images_train/section1/","summer_images_train/section2/"] 30 | 31 | # set the desired output folder for unzipping and organization 32 | outDir = '../dataset/' 33 | assert(os.path.isdir(outDir)),"Please set the outDir to the desired output location for unzipping the Nordland datasets" 34 | 35 | # define output paths for the data 36 | outPath = [os.path.join(outDir,"spring/"),os.path.join(outDir,"fall/"), 37 | os.path.join(outDir,"winter/"),os.path.join(outDir,"summer/")] 38 | 39 | # check for existence of the zip folders, throw exception if missing 40 | zipNames = ["spring_images_train.zip","fall_images_train.zip", 41 | "winter_images_train.zip","summer_images_train.zip"] 42 | for n in zipNames: 43 | if not os.path.exists(basePath+n): 44 | raise Exception('Please ensure dataset .zip folders have been downloaded') 45 | 46 | # check if nordland data folders have already been unzipped 47 | zip_flag = [] 48 | for n, ndx in enumerate(range(0,len(subPath),2)): 49 | print('Unzipping '+zipNames[n]) 50 | if os.path.exists(basePath+subPath[ndx]): 51 | # check if the folder contains any files 52 | file_lst = os.listdir(basePath+subPath[ndx]) 53 | # remove folder if it is empty and unzip the data folder 54 | if len(file_lst) == 0: 55 | shutil.rmtree(basePath+subPath[ndx].replace('section1/','')) 56 | with zipfile.ZipFile(basePath+zipNames[n],"r") as zip_ref: 57 | zip_ref.extractall(basePath) 58 | else: 59 | with zipfile.ZipFile(basePath+zipNames[n],"r") as zip_ref: 60 | zip_ref.extractall(basePath) 61 | 62 | # load image paths 63 | tempPaths = [] 64 | imgPaths = [] 65 | for n in range(0,len(subPath)): 66 | tempPaths = [] 67 | for (path, dir_names, file_names) in 
walk(basePath+subPath[n]): 68 | tempPaths.extend(file_names) 69 | # sort image names 70 | tempPaths.sort(key=natural_keys) 71 | tempPaths = [basePath+subPath[n]+s for s in tempPaths] 72 | imgPaths = imgPaths + tempPaths 73 | 74 | # if output folders already exist, delete them 75 | for n in outPath: 76 | if os.path.exists(n): 77 | shutil.rmtree(n) 78 | print('Removed pre-existing output folder') 79 | 80 | # rename and move the training data to match the nordland_imageNames.txt file 81 | for n in outPath: 82 | os.mkdir(n) 83 | for n, filename in enumerate(imgPaths): 84 | base = os.path.basename(filename) 85 | split_base = os.path.splitext(base) 86 | if int(split_base[0]) < 10: 87 | my_dest = "images-0000"+split_base[0] + ".png" 88 | elif int(split_base[0]) < 100: 89 | my_dest = "images-000"+split_base[0] + ".png" 90 | elif int(split_base[0]) < 1000: 91 | my_dest = "images-00"+split_base[0] + ".png" 92 | elif int(split_base[0]) < 10000: 93 | my_dest = "images-0"+split_base[0] + ".png" 94 | else: 95 | my_dest = "images-"+split_base[0] + ".png" 96 | if "spring" in filename: 97 | out = outPath[0] 98 | elif "fall" in filename: 99 | out = outPath[1] 100 | elif "winter" in filename: 101 | out = outPath[2] 102 | else: 103 | out = outPath[-1] 104 | 105 | fileDest = out + my_dest 106 | os.rename(filename, fileDest) 107 | 108 | # remove the empty folders 109 | for n, ndx in enumerate(subPath): 110 | if n%2 == 0: 111 | shutil.rmtree(basePath+ndx.replace('section1/','')) 112 | else: 113 | continue 114 | 115 | print('Finished unzipping and organizing Nordland dataset') -------------------------------------------------------------------------------- /vprtempo/src/process_orc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from tqdm import tqdm 4 | import image 5 | from camera_model import CameraModel 6 | 7 | # Set this to the directory containing the images 8 | base_path = '/media/adam/vprdatasets/data/orc' 9 | # Modify if evaluating different ORC datasets 10 | datasets = [ 11 | '2015-08-12-15-04-18', 12 | '2014-11-21-16-07-03', 13 | '2015-10-29-12-18-17' 14 | ] 15 | # Set up output path names 16 | processed_path = 'demosaiced' 17 | 18 | for dataset in datasets: 19 | # Define the dataset path 20 | dataset_path = os.path.join(base_path, dataset) 21 | 22 | # Create the folders 23 | os.makedirs(os.path.join(dataset_path, processed_path), exist_ok=True) 24 | 25 | # file path for the robotcar-dataset-sdk models folder 26 | model_dir = '/home/adam/repo/robotcar-dataset-sdk/models/' 27 | # file path for the left stereo images 28 | images_path = os.path.join(dataset_path, 'stereo/left') 29 | # Create a camera model object 30 | model = CameraModel(model_dir,images_path) 31 | # Get sorted list of PNG images 32 | images = sorted([os.path.join(images_path, image) for image in os.listdir(images_path) if image.endswith('.png')]) 33 | 34 | # Process each image with a progress bar 35 | for img in tqdm(images, desc="Processing images", unit="image"): 36 | # Load the image 37 | processed_img = image.load_image(img, model=model) 38 | # Create the output file path 39 | output_img = os.path.join(dataset_path, processed_path, os.path.basename(img)) 40 | processed_img = Image.fromarray(processed_img, mode="RGB") 41 | # Save the processed image as PNG 42 | processed_img.save(output_img, "PNG") 43 | 44 | print("Images processed and saved successfully.") --------------------------------------------------------------------------------