├── .github
│   └── workflows
│       └── main.yml
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── github_image.png
│   ├── main_example.gif
│   ├── mainquant_example.gif
│   ├── train_example.gif
│   └── vprtempo_example.gif
├── main.py
├── orc_list.txt
├── pixi.lock
├── pixi.toml
├── requirements.txt
├── setup.py
├── tutorials
│   ├── 1_BasicDemo.ipynb
│   ├── 2_Training.ipynb
│   ├── 3_Modules.ipynb
│   └── mats
│       └── README.txt
└── vprtempo
    ├── VPRTempo.py
    ├── VPRTempoQuant.py
    ├── VPRTempoQuantTrain.py
    ├── VPRTempoTrain.py
    ├── __init__.py
    ├── dataset
    │   ├── nordland-fall.csv
    │   ├── nordland-spring.csv
    │   ├── nordland-summer.csv
    │   ├── orc-dusk.csv
    │   ├── orc-rain.csv
    │   └── orc-sun.csv
    ├── models
    │   ├── .gitkeep
    │   └── README.txt
    └── src
        ├── __init__.py
        ├── blitnet.py
        ├── create_data_csv.py
        ├── dataset.py
        ├── download.py
        ├── loggers.py
        ├── metrics.py
        ├── nordland.py
        └── process_orc.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPi
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Check out code
12 | uses: actions/checkout@v2
13 |
14 | - name: Set up Python
15 | uses: actions/setup-python@v2
16 | with:
17 | python-version: '3.x'
18 |
19 | - name: Install dependencies
20 | run: |
21 | pip install setuptools wheel twine
22 |
23 | - name: Build and publish
24 | env:
25 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
26 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
27 | run: |
28 | python setup.py sdist bdist_wheel
29 | twine upload dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .pytest_cache/
3 | vprtempo/__pycache__/
4 | vprtempo/dataset/fall/
5 | vprtempo/dataset/spring/
6 | vprtempo/dataset/summer/
7 | vprtempo/dataset/winter/
8 | vprtempo/dataset/event.csv
9 | vprtempo/output/
10 | vprtempo/src/__pycache__/
11 | *.pth
12 | tutorials/mats/1_BasicDemo/
13 | .vscode/
14 | *.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Adam Hines
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VPRTempo - A Temporally Encoded Spiking Neural Network for Visual Place Recognition
2 | 
3 | [License: CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/)
4 | [QUT Centre for Robotics](https://qcr.ai)
5 | [GitHub stars](https://github.com/QVPR/VPRTempo/stargazers)
6 | [Downloads](https://pepy.tech/project/vprtempo)
7 | [Pixi](https://pixi.sh)
8 | [conda-forge](https://anaconda.org/conda-forge/vprtempo)
9 | [README](./README.md)
10 |
11 |
12 | This repository contains code for [VPRTempo](https://vprtempo.github.io), a spiking neural network that uses temporal encoding to perform visual place recognition tasks. The network is based on [BLiTNet](https://arxiv.org/pdf/2208.01204.pdf) and adapted to the [VPRSNN](https://github.com/QVPR/VPRSNN) framework.
13 |
14 |
15 |
16 |
17 |
18 | VPRTempo is built on a [torch.nn](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html) framework and employs custom learning rules based on the temporal codes of spikes in order to train layer weights.
19 |
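As a rough intuition for the temporal code, the tutorial notebooks describe each pixel contributing a single spike whose amplitude sets its timing within a timestep: bright pixels spike early, dark pixels spike late. The snippet below is only an illustrative sketch of that idea; the function name and scaling are our own for illustration and are not VPRTempo's actual implementation.

```python
import numpy as np

def intensity_to_spike_time(pixels: np.ndarray, timestep: float = 1.0) -> np.ndarray:
    """Illustrative only: map normalized pixel intensities to spike times.

    Bright pixels (values near 1) fire early in the timestep; dark pixels fire late.
    """
    pixels = np.clip(pixels, 0.0, 1.0)
    return timestep * (1.0 - pixels)

print(intensity_to_spike_time(np.array([0.9, 0.5, 0.1])))  # -> [0.1 0.5 0.9]
```
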
20 | In this repository, we provide two networks:
21 | - `VPRTempo`: Our base network architecture to perform visual place recognition (fp32)
22 | - `VPRTempoQuant`: A modified base network with [Quantization Aware Training (QAT)](https://pytorch.org/docs/stable/quantization.html) enabled (int8)
23 |
24 | To use VPRTempo, please follow the instructions below for installation and usage.
25 |
26 | ## :star: Update v1.1.10: What's new?
27 | - Added [pixi](https://prefix.dev/) for easier setup and reproducibility :rocket:
28 | - Fixed repo size issue for a more compact download :chart_with_downwards_trend:
29 |
30 | ## Quick start
31 | For simplicity and reproducibility, VPRTempo uses [pixi](https://prefix.dev/) to install and manage dependencies. If you do not already have pixi installed, run the following in your command terminal:
32 |
33 | ```console
34 | curl -fsSL https://pixi.sh/install.sh | bash
35 | ```
36 | For more information, please refer to the [pixi documentation](https://prefix.dev/docs/prefix/overview).
37 |
38 | ### Get the repository
39 | Get the latest VPRTempo code and navigate to the project directory by running the following in your command terminal:
40 | ```console
41 | git clone https://github.com/QVPR/VPRTempo.git
42 | cd VPRTempo
43 | ```
44 |
45 | ### Run the demo
46 | To quickly evaluate VPRTempo, we provide a pre-trained network trained on 500 places from the [Nordland](https://nrkbeta.no/2013/01/15/nordlandsbanen-minute-by-minute-season-by-season/) dataset. Run the following in your command terminal to start the demo:
47 | ```console
48 | pixi run demo
49 | ```
50 | _Note: this will start a download of the models and datasets (~600MB); please ensure you have enough disk space before proceeding._
51 |
52 | ### Train and evaluate a new model
53 | Training and evaluating a new model is quick and easy; simply run the following in your command terminal to re-train and evaluate the demo model:
54 |
55 | ```console
56 | pixi run train
57 | pixi run eval
58 | ```
59 | _Note: You will be prompted to confirm whether you want to retrain the pre-existing network._
60 |
61 | ### Use the quantized models
62 | For training and evaluation of the 8-bit quantized model, run the following in your command terminal:
63 |
64 | ```console
65 | pixi run train_quant
66 | pixi run eval_quant
67 | ```
68 |
69 | ## Datasets
70 | VPRTempo was designed to be simple to train and test on a variety of datasets. Please see the information below on recreating our results for the Nordland and Oxford RobotCar datasets and on setting up custom datasets.
71 |
72 | ### Nordland
73 | VPRTempo was developed and tested using the [Nordland](https://nrkbeta.no/2013/01/15/nordlandsbanen-minute-by-minute-season-by-season/) dataset. To download the full dataset, please visit [this repository](https://huggingface.co/datasets/Somayeh-h/Nordland?row=0). Once downloaded, place dataset folders into the VPRTempo directory as follows:
74 |
75 | ```
76 | |__./vprtempo
77 | |___dataset
78 | |__summer
79 | |__spring
80 | |__fall
81 | |__winter
82 | ```
83 |
84 | To replicate the results in our paper, run the following in your command terminal:
85 |
86 | ```console
87 | pixi run nordland_train
88 | pixi run nordland_eval
89 | ```
90 |
91 | Alternatively, specify the data directory using the following argument:
92 |
93 | ```console
94 | pixi run nordland_train --data_dir <path_to_dataset>
95 | pixi run nordland_eval --data_dir <path_to_dataset>
96 | ```
97 |
98 | ### Oxford RobotCar
99 | In order to train and test on Oxford RobotCar, you will need to [register an account](https://mrgdatashare.robots.ox.ac.uk/register/) to gain access to the dataset downloads, and process the images before proceeding. For more information, please refer to the [documentation]().
100 |
101 | Once fully processed, to replicate the results in our paper run the following in your command terminal:
102 |
103 | ```console
104 | pixi run orc_train
105 | pixi run orc_eval
106 | ```
107 |
108 | ### Custom Datasets
109 | To define your own custom dataset to use with VPRTempo, simply follow the same dataset structure defined above for Nordland. A `.csv` file of the image names will be required for the dataloader.
110 |
111 | We have included a convenient script, `./vprtempo/src/create_data_csv.py`, which will generate the necessary file. Simply modify the `dataset_name` variable to point to the folder containing your images.
112 |
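As a rough sketch of what such a helper does, the snippet below builds an image-name CSV with pandas. The column name, file extensions, and paths here are assumptions for illustration only; refer to `./vprtempo/src/create_data_csv.py` for the actual format expected by the dataloader.

```python
import os
import pandas as pd

dataset_name = "my_dataset"  # hypothetical folder under ./vprtempo/dataset/ containing your images
image_dir = os.path.join("./vprtempo/dataset", dataset_name)

# Collect image filenames in a stable order and write them to a CSV for the dataloader
image_names = sorted(
    f for f in os.listdir(image_dir)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
)
pd.DataFrame({"image_name": image_names}).to_csv(
    os.path.join("./vprtempo/dataset", f"{dataset_name}.csv"), index=False
)
```
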
113 | To train a new model with a custom dataset, you can do the following.
114 |
115 | ```console
116 | pixi run train --dataset <dataset_name> --database_dirs <database_dirs>
117 | pixi run eval --database_dirs <database_dirs> --dataset <dataset_name> --query_dir <query_dir>
118 | ```
119 |
120 | ## License & Citation
121 | This repository is licensed under the [MIT License](./LICENSE). If you use our code, please cite our IEEE ICRA [paper](https://ieeexplore.ieee.org/document/10610918):
122 | ```
123 | @inproceedings{hines2024vprtempo,
124 | title={VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition},
125 | author={Adam D. Hines and Peter G. Stratton and Michael Milford and Tobias Fischer},
126 | year={2024},
127 | pages={10200-10207},
128 | booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)}
129 | }
130 | ```
131 |
132 | ## Documentation
133 |
134 | For more detailed information on usage, please visit the [documentation]().
135 |
136 | ## Tutorials
137 | We provide a series of Jupyter Notebook [tutorials](https://github.com/QVPR/VPRTempo/tree/main/tutorials) that go through the basic operations and logic for VPRTempo and VPRTempoQuant.
138 |
139 | ## Issues, bugs, and feature requests
140 | If you encounter problems whilst running the code or if you have a suggestion for a feature or improvement, please report it as an [issue](https://github.com/QVPR/VPRTempo/issues).
141 |
--------------------------------------------------------------------------------
/assets/github_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/github_image.png
--------------------------------------------------------------------------------
/assets/main_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/main_example.gif
--------------------------------------------------------------------------------
/assets/mainquant_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/mainquant_example.gif
--------------------------------------------------------------------------------
/assets/train_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/train_example.gif
--------------------------------------------------------------------------------
/assets/vprtempo_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/assets/vprtempo_example.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #MIT License
2 |
3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer
4 |
5 | #Permission is hereby granted, free of charge, to any person obtaining a copy
6 | #of this software and associated documentation files (the "Software"), to deal
7 | #in the Software without restriction, including without limitation the rights
8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | #copies of the Software, and to permit persons to whom the Software is
10 | #furnished to do so, subject to the following conditions:
11 |
12 | #The above copyright notice and this permission notice shall be included in all
13 | #copies or substantial portions of the Software.
14 |
15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | #SOFTWARE.
22 |
23 | '''
24 | Imports
25 | '''
26 | import os
27 | import sys
28 | import torch
29 | import argparse
30 |
31 | import torch.quantization as quantization
32 |
33 | from tqdm import tqdm
34 | from vprtempo.VPRTempo import VPRTempo, run_inference
35 | from vprtempo.VPRTempoTrain import VPRTempoTrain, train_new_model
36 | from vprtempo.src.loggers import model_logger, model_logger_quant
37 | from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
38 | from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, train_new_model_quant
39 |
40 | def generate_model_name(model,quant=False):
41 | """
42 | Generate the model name based on its parameters.
43 | """
44 | if quant:
45 | model_name = (''.join(model.database_dirs)+"_"+
46 | "VPRTempoQuant_" +
47 | "IN"+str(model.input)+"_" +
48 | "FN"+str(model.feature)+"_" +
49 | "DB"+str(model.database_places) +
50 | ".pth")
51 | else:
52 | model_name = (''.join(model.database_dirs)+"_"+
53 | "VPRTempo_" +
54 | "IN"+str(model.input)+"_" +
55 | "FN"+str(model.feature)+"_" +
56 | "DB"+str(model.database_places) +
57 | ".pth")
58 | return model_name
59 |
60 | def check_pretrained_model(model_name):
61 | """
62 | Check if a pre-trained model exists and prompt the user to retrain if desired.
63 | """
64 | if os.path.exists(os.path.join('./vprtempo/models', model_name)):
65 | prompt = "A network with these parameters exists, re-train network? (y/n):\n"
66 | retrain = input(prompt).strip().lower()
67 | if retrain == 'y':
68 | return True
69 | elif retrain == 'n':
70 | print('Training new model cancelled')
71 | sys.exit()
72 |
73 | def initialize_and_run_model(args,dims):
74 | """
75 | Run the VPRTempo/VPRTempoQuant training or inference models.
76 |
77 | :param args: Arguments set for the network
78 | :param dims: Dimensions of the network
79 | """
80 | # Determine number of modules to generate based on user input
81 | places = args.database_places # Copy out number of database places
82 |
83 | # Calculate number of modules
84 | num_modules = 1
85 | while places > args.max_module:
86 | places -= args.max_module
87 | num_modules += 1
88 |
89 | # If the final module has less than max_module, reduce the dim of the output layer
90 | remainder = args.database_places % args.max_module
91 | if remainder != 0: # There are remainders, adjust output neuron count in final module
92 | out_dim = int((args.database_places - remainder) / (num_modules - 1))
93 | final_out_dim = remainder
94 | else: # No remainders, all modules are even
95 | out_dim = int(args.database_places / num_modules)
96 | final_out_dim = out_dim
97 |
98 | # If user wants to train a new network
99 | if args.train_new_model:
100 | # If using quantization aware training
101 | if args.quantize:
102 | models = []
103 | logger = model_logger_quant() # Initialize the logger
104 | qconfig = quantization.get_default_qat_qconfig('fbgemm')
105 | # Create the modules
106 | final_out = None
107 | for mod in tqdm(range(num_modules), desc="Initializing modules"):
108 | model = VPRTempoQuantTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
109 | model.train()
110 | model.qconfig = qconfig
111 | quantization.prepare_qat(model, inplace=True)
112 | models.append(model) # Create module list
113 | if mod == num_modules - 2:
114 | final_out = final_out_dim
115 | # Generate the model name
116 | model_name = generate_model_name(model,args.quantize)
117 | # Check if the model has been trained before
118 | check_pretrained_model(model_name)
119 | # Get the quantization config
120 | qconfig = quantization.get_default_qat_qconfig('fbgemm')
121 | # Train the model
122 | train_new_model_quant(models, model_name)
123 |
124 | # Base model
125 | else:
126 | models = []
127 | logger = model_logger() # Initialize the logger
128 |
129 | # Create the modules
130 | final_out = None
131 | for mod in tqdm(range(num_modules), desc="Initializing modules"):
132 | model = VPRTempoTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
133 | model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
134 | models.append(model) # Create module list
135 | if mod == num_modules - 2:
136 | final_out = final_out_dim
137 |
138 | # Generate the model name
139 | model_name = generate_model_name(model)
140 | print(f"Model name: {model_name}")
141 | # Check if the model has been trained before
142 | check_pretrained_model(model_name)
143 | # Train the model
144 | train_new_model(models, model_name)
145 |
146 | # Run the inference network
147 | else:
148 | # Set the quantization configuration
149 | if args.quantize:
150 | models = []
151 | logger, output_folder = model_logger_quant()
152 | qconfig = quantization.get_default_qat_qconfig('fbgemm')
153 | final_out = None
154 | for _ in tqdm(range(num_modules), desc="Initializing modules"):
155 | # Initialize the model
156 | model = VPRTempoQuant(
157 | args,
158 | dims,
159 | logger,
160 | num_modules,
161 | output_folder,
162 | out_dim,
163 | out_dim_remainder=final_out
164 | )
165 | model.eval()
166 | model.qconfig = qconfig
167 | quantization.prepare(model, inplace=True)
168 | quantization.convert(model, inplace=True)
169 | models.append(model)
170 | # Generate the model name
171 | model_name = generate_model_name(model, args.quantize)
172 | # Run the quantized inference model
173 | run_inference_quant(models, model_name)
174 | else:
175 | models = []
176 | logger, output_folder = model_logger() # Initialize the logger
177 | places = args.database_places # Copy out number of database places
178 |
179 | # Create the modules
180 | final_out = None
181 | for mod in tqdm(range(num_modules), desc="Initializing modules"):
182 | model = VPRTempo(
183 | args,
184 | dims,
185 | logger,
186 | num_modules,
187 | output_folder,
188 | out_dim,
189 | out_dim_remainder=final_out
190 | )
191 | model.eval()
192 | model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
193 | models.append(model) # Create module list
194 | if mod == num_modules - 2:
195 | final_out = final_out_dim
196 | # Generate the model name
197 | model_name = generate_model_name(model)
198 | print(f"Model name: {model_name}")
199 | # Run the inference model
200 | run_inference(models, model_name)
201 |
202 | def parse_network():
203 | '''
204 | Define the base parameter parser (configurable by the user)
205 | '''
206 | parser = argparse.ArgumentParser(description="Args for base configuration file")
207 |
208 | # Define the dataset arguments
209 | parser.add_argument('--dataset', type=str, default='nordland',
210 | help="Dataset to use for training and/or inferencing")
211 | parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/',
212 | help="Directory where dataset files are stored")
213 | parser.add_argument('--database_places', type=int, default=500,
214 | help="Number of places to use for training")
215 | parser.add_argument('--query_places', type=int, default=500,
216 | help="Number of places to use for inferencing")
217 | parser.add_argument('--max_module', type=int, default=500,
218 | help="Maximum number of images per module")
219 | parser.add_argument('--database_dirs', type=str, default='spring,fall',
220 | help="Directories to use for training")
221 | parser.add_argument('--query_dir', type=str, default='summer',
222 | help="Directories to use for testing")
223 | parser.add_argument('--GT_tolerance', type=int, default=0,
224 | help="Ground truth tolerance for matching")
225 | parser.add_argument('--skip', type=int, default=0,
226 | help="Images to skip for training and/or inferencing")
227 |
228 | # Define training parameters
229 | parser.add_argument('--filter', type=int, default=8,
230 | help="Images to skip for training and/or inferencing")
231 | parser.add_argument('--epoch', type=int, default=4,
232 | help="Number of epochs to train the model")
233 |
234 | # Define image transformation parameters
235 | parser.add_argument('--patches', type=int, default=15,
236 | help="Number of patches to generate for patch normalization image into")
237 | parser.add_argument('--dims', type=str, default="56,56",
238 | help="Dimensions to resize the image to")
239 |
240 | # Define the network functionality
241 | parser.add_argument('--train_new_model', action='store_true',
242 | help="Flag to run the training or inferencing model")
243 | parser.add_argument('--quantize', action='store_true',
244 | help="Enable/disable quantization for the model")
245 |
246 | # Define metrics functionality
247 | parser.add_argument('--PR_curve', action='store_true',
248 | help="Flag to generate a Precision-Recall curve")
249 | parser.add_argument('--sim_mat', action='store_true',
250 | help="Flag to plot the similarity matrix, GT, and GTsoft")
251 |
252 | # Output base configuration
253 | args = parser.parse_args()
254 | dims = [int(x) for x in args.dims.split(",")]
255 |
256 | # Run the network with the desired settings
257 | initialize_and_run_model(args,dims)
258 |
259 | if __name__ == "__main__":
260 | parse_network()
--------------------------------------------------------------------------------
/orc_list.txt:
--------------------------------------------------------------------------------
1 | 2015-08-12-15-04-18
2 | 2014-11-21-16-07-03
3 | 2015-10-29-12-18-17
--------------------------------------------------------------------------------
/pixi.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "vprtempo"
3 | version = "1.1.10"
4 | description = "Temporally encoded spiking neural network for fast visual place recognition"
5 | authors = ["Adam D Hines ", "Peter G Stratton ", "Michael Milford "]
6 | channels = ["conda-forge", "pytorch"]
7 | platforms = ["linux-64", "osx-arm64", "win-64"]
8 | preview = ["pixi-build"]
9 |
10 | [feature.cuda]
11 | platforms = ["linux-64"]
12 | system-requirements = {cuda = "12"}
13 |
14 | [dependencies]
15 | python = ">=3.6,<3.13"
16 | pytorch = ">=2.4.0"
17 | torchvision = ">=0.19.0"
18 | numpy = ">=1.26.2,<2"
19 | pandas = ">=2.2.2"
20 | tqdm = ">=4.66.5"
21 | prettytable = ">=3.11.0"
22 | matplotlib-base = ">=3.9.2"
23 | requests = ">=2.32.3"
24 |
25 | [feature.cuda.dependencies]
26 | cuda-version = "12.*"
27 | pytorch-gpu = "*"
28 | cuda-cudart-dev = "*"
29 | cuda-crt = "*"
30 | cudnn = "*"
31 | libcusparse-dev = "*"
32 | cuda-driver-dev = "*"
33 | cuda-nvcc = "*"
34 | cuda-nvrtc-dev = "*"
35 | cuda-nvtx-dev = "*"
36 | cuda-nvml-dev = "*"
37 | cuda-profiler-api = "*"
38 | cusparselt = "*"
39 | libcublas-dev = "*"
40 | libcudss-dev = "*"
41 | libcufile-dev = "*"
42 | libcufft-dev = "*"
43 | libcurand-dev = "*"
44 | libcusolver-dev = "*"
45 |
46 | [environments]
47 | cuda = ["cuda"]
48 |
49 | [tasks]
50 | # run demo - downloads pretrained models and Nordland dataset
51 | demo = {cmd = "pixi run --frozen python main.py --PR_curve --sim_mat"}
52 |
53 | # run the evaluation networks
54 | eval = {cmd = "pixi run python main.py"}
55 | eval_quant = {cmd = "pixi run python main.py --quantize"}
56 |
57 | # train network
58 | train = {cmd = "pixi run python main.py --train_new_model"}
59 | train_quant = {cmd = "pixi run python main.py --train_new_model --quantize"}
60 |
61 | # replicate conference proceedings results
62 | nordland_train = {cmd = "pixi run python main.py --train_new_model --database_places 3300 --database_dirs spring,fall --skip 0 --max_module 1100 --dataset nordland --dims 28,28 --patches 7 --filter 8"}
63 | nordland_eval = {cmd = "pixi run python main.py --database_places 3300 --database_dirs spring,fall --skip 4800 --dataset nordland --dims 28,28 --patches 7 --filter 8 --query_dir summer --query_places 2700 --sim_mat --max_module 1100"}
64 | oxford_train = {cmd = "pixi run python main.py --train_new_model --dataset orc --database_places 450 --database_dirs sun,rain --skip 0 --max_module 450 --dataset orc --dims 28,28 --patches 7 --filter 7"}
65 | oxford_eval = {cmd = "python main.py --dataset orc --database_places 450 --database_dirs sun,rain --skip 630 --dataset orc --dims 28,28 --patches 7 --filter 7 --query_dir dusk --query_places 360 --sim_mat --max_module 450"}
66 |
67 | # helper functions for datasets
68 | scrape_oxford = {cmd = "python scrape_mrgdatashare.py --choice_sensors stereo_left --choice_runs_file orc_list.txt --downloads_dir ~/VPRTempo/vprtempo/dataset --datasets_file datasets.csv"}
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | pandas
5 | tqdm
6 | prettytable
7 | matplotlib
8 | requests
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | from setuptools import setup, find_packages
4 |
5 | here = os.path.abspath(os.path.dirname(__file__))
6 | # Get the long description from the README file
7 | with open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
8 | long_description = f.read()
9 |
10 | # define the base requires needed for the repo
11 | requirements = [
12 | 'torch',
13 | 'torchvision',
14 | 'numpy',
15 | 'pandas',
16 | 'tqdm',
17 | 'prettytable',
18 | 'matplotlib',
19 | 'requests'
20 | ]
21 |
22 | # define the setup
23 | setup(
24 | name="VPRTempo",
25 | version="1.1.9",
26 | description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition',
27 | long_description=long_description,
28 | long_description_content_type='text/markdown',
29 | author='Adam D Hines, Peter G Stratton, Michael Milford and Tobias Fischer',
30 | author_email='adam.hines@qut.edu.au',
31 | url='https://github.com/QVPR/VPRTempo',
32 | license='MIT',
33 | install_requires=requirements,
34 | python_requires='>=3.6, <3.13',
35 | classifiers=[
36 | # 3 - Alpha
37 | # 4 - Beta
38 | # 5 - Production/Stable
39 | 'Development Status :: 4 - Beta',
40 |
41 | # Indicate who your project is intended for
42 | 'Intended Audience :: Developers',
43 | # Pick your license as you wish (should match "license" above)
44 | 'License :: OSI Approved :: MIT License',
45 |
46 | # Specify the Python versions you support here. In particular, ensure
47 | # that you indicate whether you support Python 2, Python 3 or both.
48 | 'Programming Language :: Python :: 3.6',
49 | 'Programming Language :: Python :: 3.7',
50 | 'Programming Language :: Python :: 3.8',
51 | 'Programming Language :: Python :: 3.9',
52 | 'Programming Language :: Python :: 3.10',
53 | 'Programming Language :: Python :: 3.11',
54 | 'Programming Language :: Python :: 3.12',
55 | ],
56 | packages=find_packages(),
57 | keywords=['python', 'place recognition', 'spiking neural networks',
58 | 'computer vision', 'robotics'],
59 | scripts=['main.py'],
60 | )
61 |
--------------------------------------------------------------------------------
/tutorials/1_BasicDemo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b34c7b8a-e7bb-47f4-b558-be1bde9a7b37",
6 | "metadata": {},
7 | "source": [
8 | "## VPRTempo & VPRTempoQuant - Basic Demo\n",
9 | "\n",
10 | "### By Adam D Hines (https://research.qut.edu.au/qcr/people/adam-hines/)\n",
11 | "\n",
12 | "VPRTempo is based on the following paper, if you use or find this code helpful for your research please consider citing the source:\n",
13 | " \n",
14 | "[Adam D Hines, Peter G Stratton, Michael Milford, & Tobias Fischer. \"VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition. arXiv September 2023](https://arxiv.org/abs/2309.10225)\n",
15 | "\n",
16 | "### Introduction\n",
17 | "\n",
18 | "This is a basic, extremely simplified version of VPRTempo that highlights how images are transformed, spikes and weights are used, and the readout for performance using a model trained using our base system and the Quantized Aware Training (QAT) version. This is a basic, extremely simplified version of VPRTempo that highlights how images are transformed, spikes and weights are used, and the readout for performance. Although the proper implementation is in [PyTorch](https://pytorch.org/), we present a simple NumPy example to get started. As in the paper, we will present a simple example using the [Nordland](https://webdiis.unizar.es/~jmfacil/pr-nordland/#download-dataset) dataset with pre-trained set of weights.\n",
19 | "\n",
20 | "Before starting, make sure the following packages are installed and imported:"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "c879cd02-82db-441d-9476-fff1925bf494",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# Imprt opencv-python, NumPy, and matplotlib.pyplot\n",
31 | "try:\n",
32 | " import cv2\n",
33 | " import numpy as np\n",
34 | " import matplotlib.pyplot as plt\n",
35 | "except:\n",
36 | " !pip3 install numpy, opencv-python, matplotlib # pip install if modules not present\n",
37 | " import cv2\n",
38 | " import numpy as np\n",
39 | " import matplotlib.pyplot as plt"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "id": "66b11853-6e17-4884-ac92-d35d814add42",
45 | "metadata": {},
46 | "source": [
47 | "Next, we will need to get the pretrained weights for the model. To get them and the other materials, please download them from [here](https://www.dropbox.com/scl/fi/bxbzm47kxl24x979q5r5s/1_BasicDemo.zip?rlkey=0umij016whwgm11frzlk63v5k&st=hncbx0ld&dl=0). Unzip the files into the `./tutorials/mats/` folder."
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "id": "bb45df38-e333-46b2-9161-80e6ac367532",
53 | "metadata": {},
54 | "source": [
55 | "### Image processing\n",
56 | "\n",
57 | "Let's have a look at how we process our images to run through VPRTempo. We utilize a technique called *patch normalization* to resize input images and normalize the pixel intensities. To start, let's see what the original image looks like before patch normalization."
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "id": "67f129b5-9a7a-4b50-9d94-b9bf512f8b70",
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "# Load the input image\n",
68 | "raw_img = cv2.imread('./mats/1_BasicDemo/summer.png')\n",
69 | "rgb_img = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB) # Convert to RGB\n",
70 | "\n",
71 | "# Plot the image\n",
72 | "plt.imshow(rgb_img)\n",
73 | "plt.title('Nordland Summer')\n",
74 | "plt.show()"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "id": "b68cf25e-35ae-4885-9cf1-c1b09ce4ad42",
80 | "metadata": {},
81 | "source": [
82 | "What we have here is a 360x640 RGB image, which for processing through neural networks is too big (230,400 total pixels). So instead, we'll use patch normalization to reduce the image size down to a grayscale 56x56 image to just 3136 pixels in total."
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "id": "6f67656a-3ba4-4374-b780-4e8bac4ec2d2",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "# Load the patch normalized image\n",
93 | "patch_img = np.load('./mats/1_BasicDemo/summer_patchnorm.npy', allow_pickle=True)\n",
94 | "\n",
95 | "# Plot the image\n",
96 | "plt.matshow(patch_img)\n",
97 | "plt.title('Nordland Summer Patch Normalized')\n",
98 | "plt.colorbar(shrink=0.75, label=\"Pixel intensity\")\n",
99 | "plt.show()"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "id": "d404dfd2-10bd-4981-96ef-6092a9866fc6",
105 | "metadata": {},
106 | "source": [
107 | "The reduced image dimensions with patch normalization allows for a decent representation of the full scene, despite the smaller size.\n",
108 | "\n",
109 | "### Convert images to spikes\n",
110 | "\n",
111 | "'Spikes' in the context of VPRTempo are a little different than conventional spiking neural networks. Typically, spikes from image datasets are converted into Poisson spike trains where the pixel intensity determines the number of spikes to propagate throughout a network. VPRTempo only considers each pixel as a single spike, but considers the *amplitude* of the spike to determine the timing within a single timestep - where large amplitudes (high pixel intensity) spike early in a timestep, and vice versa for small amplitudes. \n",
112 | "\n",
113 | "Let's flatten the patch normalized image into a 1D-array so we can apply our network weights."
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "id": "2bd6ae95-2a79-4b45-8a60-503079339739",
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "# Convert 2D image to a 1D-array\n",
124 | "patch_1d = np.reshape(patch_img, (3136,))"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "id": "9d9a5eaf-1de3-461f-b138-3ac820da8bae",
130 | "metadata": {},
131 | "source": [
132 | "### Load the pre-trained network weights\n",
133 | "\n",
134 | "Our network consists of the following architecture:\n",
135 | "\n",
136 | " - An input layer sparsely connected to a feature layer, 3136 input neurons to 6272 feature neurons\n",
137 | " - The feature layer fully connected to a one-hot-encoded output layer, 6272 feature neurons to 500 output neurons\n",
138 | "\n",
139 | "Each layer connection is trained separately and stored in different weight matrices for excitatory (positive) and inhibitory (negative) connections. "
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "id": "6d98749a-8f28-477b-871c-93626e96786c",
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "# Load the input to feature excitatory and inhibitory network weights\n",
150 | "featureW = np.load('./mats/1_BasicDemo/featureW.npy')\n",
151 | "\n",
152 | "# Plot the weights\n",
153 | "plt.matshow(featureW.T)\n",
154 | "plt.title('Input > Feature Weights')\n",
155 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n",
156 | "\n",
157 | "# Display the plots\n",
158 | "plt.show()"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "id": "826213d7-7721-440c-b1a8-47fb613339eb",
164 | "metadata": {},
165 | "source": [
166 | "Whilst it might be a little difficult to see, our excitatory connection amplitudes are on average a little higher than our inhibitiory. However, we overall have more inhibitiory connections that positive to balance the system.\n",
167 | "\n",
168 | "This is because when we set up our connections we use a probability of connections for both excitation and inbhition. In this case, we have a 10% connection probability for excitatory weights and a 50% probability for inhibitiory. This means as well there will be a high number of neurons without connections."
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "id": "94acb2f4",
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "# In this function, we will plot and visualize the distribution of weights and connections.\n",
179 | "def count_and_plot(array):\n",
180 | " # Flatten the 2D array and count positive, negative, and zero values\n",
181 | " flattened_array = array.flatten()\n",
182 | " positive_count = np.sum(flattened_array > 0)\n",
183 | " negative_count = np.sum(flattened_array < 0)\n",
184 | " zero_count = np.sum(flattened_array == 0)\n",
185 | " \n",
186 | " # Calculate percentages\n",
187 | " total_count = flattened_array.size\n",
188 | " positive_percentage = (positive_count / total_count) * 100\n",
189 | " negative_percentage = (negative_count / total_count) * 100\n",
190 | " zero_percentage = (zero_count / total_count) * 100\n",
191 | "\n",
192 | " # Print the results\n",
193 | " print(f\"Excitatory Connections: {positive_count} ({positive_percentage:.2f}%)\")\n",
194 | " print(f\"Inhibitory Conncetions: {negative_count} ({negative_percentage:.2f}%)\")\n",
195 | " print(f\"Zero Connections: {zero_count} ({zero_percentage:.2f}%)\")\n",
196 | "\n",
197 | " # Create a bar plot of the percentages\n",
198 | " categories = ['Excitatory', 'Inhibitory', 'Zero']\n",
199 | " percentages = [positive_percentage, negative_percentage, zero_percentage]\n",
200 | "\n",
201 | " plt.bar(categories, percentages)\n",
202 | " plt.xlabel('Category')\n",
203 | " plt.ylabel('Percentage')\n",
204 | " plt.title('Percentage of Excitatory, Inhibitiory, and Zero Connections')\n",
205 | " plt.ylim(0, 60) # Set the y-axis limit to 0-60%\n",
206 | " plt.show()\n",
207 | "\n",
208 | "if __name__ == \"__main__\":\n",
209 | "\n",
210 | " # Call the function to count and plot\n",
211 | " count_and_plot(featureW)"
212 | ]
213 | },
214 | {
215 | "cell_type": "markdown",
216 | "id": "f180220e",
217 | "metadata": {},
218 | "source": [
219 | "Now let's have a look at the feature to the output weights, and see how the distribution of excitiatory and inhibitory connections differs."
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "id": "7609eae5-f584-4c98-9eb6-3c3cf1e2aa04",
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "# Load the input to feature excitatory and inhibitory network weights\n",
230 | "outputW = np.load('./mats/1_BasicDemo/outputW.npy')\n",
231 | "\n",
232 | "# Plot the weights\n",
233 | "plt.matshow(outputW)\n",
234 | "plt.title('Feature > Output Weights')\n",
235 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n",
236 | "\n",
237 | "# Display the plots\n",
238 | "plt.show()\n",
239 | "\n",
240 | "# Plot the distributions\n",
241 | "count_and_plot(outputW)"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "id": "d591969a-e72e-43b2-8c89-16a13bb29fe6",
247 | "metadata": {},
248 | "source": [
249 | "### Propagate network spikes\n",
250 | "\n",
251 | "Now we'll propagate the input spikes across the layers to get the output. All we have to do is multiply the input spikes by the Input > Feature weights for both excitatory and inhibitory matrices and add them, then take the feature spikes and multiply them by the Feature > Output weights and do the smae thing. We'll also clamp spikes in the range of [0, 0.9] to prevent negative spikes and spike explosions.\n",
252 | "\n",
253 | "Let's do that and visualize the spikes as they're going through, we'll start with the Input to Feature layer."
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "id": "f6c84239-c176-48c3-8954-25da5f989d61",
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "# Calculate feature spikes (positive and negative weights)\n",
264 | "feature_spikes = np.matmul(featureW,patch_1d)\n",
265 | "feature_spikes = np.clip(feature_spikes, 0, 0.9)\n",
266 | "\n",
267 | "# Now create the line plot\n",
268 | "plt.plot(np.arange(len(feature_spikes)), feature_spikes)\n",
269 | "\n",
270 | "# Add title and labels if you wish\n",
271 | "plt.title('Feature Layer Spikes')\n",
272 | "plt.xlabel('Neuron ID')\n",
273 | "plt.ylabel('Spike Amplitude')\n",
274 | "\n",
275 | "# Show the plot\n",
276 | "plt.show()"
277 | ]
278 | },
279 | {
280 | "cell_type": "markdown",
281 | "id": "4ea0b0a3-66fc-4202-963c-cbd05114d283",
282 | "metadata": {},
283 | "source": [
284 | "This looks a little homogenous, but this is the feature representation of our input image. \n",
285 | "\n",
286 | "Now let's propagate the feature layer spikes through to the output layer to get our corresponding place match."
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "id": "f5d4dc99-c7b9-4e9b-ba7c-58f6e30631cb",
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "# Calculate output spikes (positive and negative weights)\n",
297 | "output_spikes = np.matmul(outputW,feature_spikes)\n",
298 | "output_spikes = np.clip(output_spikes, 0, 0.9)\n",
299 | "\n",
300 | "# Now create the line plot\n",
301 | "plt.plot(np.arange(len(output_spikes)), output_spikes)\n",
302 | "\n",
303 | "# Add title and labels if you wish\n",
304 | "plt.title('Output Layer Spikes')\n",
305 | "plt.xlabel('Neuron ID')\n",
306 | "plt.ylabel('Spike Amplitude')\n",
307 | "\n",
308 | "# Show the plot\n",
309 | "plt.show()"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "id": "54b4f5f6-017b-4d7d-812a-c96baf9cb39f",
315 | "metadata": {},
316 | "source": [
317 | "Success! We have propagated our input spikes across the layers to reach this output. Clearly, one of the output spikes has the highest amplitude. Our network weights were trained on 500 locations from a Fall and Spring traversal of Nordland. For this example, we passed the first location from the Summer traversal through the network to achieve this output - which clearly looks to have spikes Neuron ID '0' the highest!\n",
318 | "\n",
319 | "Let's prove that."
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "id": "780371ca-9dfe-4dd7-857d-e35be73ffd23",
326 | "metadata": {},
327 | "outputs": [],
328 | "source": [
329 | "# Output the argmax from the output spikes\n",
330 | "prediction = np.argmax(output_spikes)\n",
331 | "print(f\"Neuron ID with the highest output is {prediction}\")"
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "id": "0c8c82d7",
337 | "metadata": {},
338 | "source": [
339 | "## Quantized model example\n",
340 | "\n",
341 | "Now that we have seen how our base model works, let's look at how our int8 quantized model performs by comparison. Working in the in8 space has a few benefits, like faster inferencing time and smaller model sizes. There are a couple differences however when feeding spikes throughout the system that PyTorch performs in the backend.\n",
342 | "\n",
343 | "Let's start by converting our input image into int8 spikes."
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": null,
349 | "id": "5c893e76",
350 | "metadata": {},
351 | "outputs": [],
352 | "source": [
353 | "# Converting fp32 spikes to int8 uses a learned scale factor during quantization aware training\n",
354 | "spike_scale = 133\n",
355 | "patch_img_int = patch_img*spike_scale\n",
356 | "\n",
357 | "# Plot the converted int8 image\n",
358 | "plt.matshow(patch_img_int)\n",
359 | "plt.title('Nordland Summer Patch Normalized Int8')\n",
360 | "plt.colorbar(shrink=0.75, label=\"Pixel intensity\")\n",
361 | "plt.show()\n",
362 | "\n",
363 | "# Convert 2D image to a 1D-array\n",
364 | "patch_1d_int = np.reshape(patch_img_int, (3136,))"
365 | ]
366 | },
367 | {
368 | "cell_type": "markdown",
369 | "id": "9f68d3dc",
370 | "metadata": {},
371 | "source": [
372 | "Now we'll load in and plot our integer based weights, as well as some scale factors which will be important to reduce the size of our spikes after multiplying them with our weights."
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "id": "70c99f47",
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "# Load the scales for the feature and output spikes\n",
383 | "feature_scales = np.load('./mats/1_BasicDemo/featureScales.npy',allow_pickle=True)\n",
384 | "output_scales = np.load('./mats/1_BasicDemo/outputScales.npy',allow_pickle=True)\n",
385 | "\n",
386 | "# Load the int8 weights and plot them\n",
387 | "featureQuantW = np.load('./mats/1_BasicDemo/featureQuantW.npy')\n",
388 | "outputQuantW = np.load('./mats/1_BasicDemo/outputQuantW.npy')\n",
389 | "\n",
390 | "# Plot the feature weights\n",
391 | "plt.matshow(featureQuantW.T)\n",
392 | "plt.title('Input > Feature Weights')\n",
393 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n",
394 | "\n",
395 | "# Display the plots\n",
396 | "plt.show()\n",
397 | "\n",
398 | "# Plot the output weights\n",
399 | "plt.matshow(outputQuantW)\n",
400 | "plt.title('Feature > Output Weights')\n",
401 | "plt.colorbar(shrink=0.8, label=\"Weight strength\")\n",
402 | "\n",
403 | "# Display the plots\n",
404 | "plt.show()"
405 | ]
406 | },
407 | {
408 | "cell_type": "markdown",
409 | "id": "f185e3ff",
410 | "metadata": {},
411 | "source": [
412 | "Now as above, let's propagate the input spikes throughout the network."
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "id": "ddb1ef65",
419 | "metadata": {},
420 | "outputs": [],
421 | "source": [
422 | "# Get the feature spikes\n",
423 | "feature_spikes_int = np.matmul(featureQuantW,patch_1d_int)\n",
424 | "\n",
425 | "# Now create the line plot\n",
426 | "plt.plot(np.arange(len(feature_spikes_int)), feature_spikes_int)\n",
427 | "\n",
428 | "# Add title and labels if you wish\n",
429 | "plt.title('Output Layer Spikes')\n",
430 | "plt.xlabel('Neuron ID')\n",
431 | "plt.ylabel('Spike Amplitude')\n",
432 | "\n",
433 | "# Show the plot\n",
434 | "plt.show()"
435 | ]
436 | },
437 | {
438 | "cell_type": "markdown",
439 | "id": "6e9189aa",
440 | "metadata": {},
441 | "source": [
442 | "Those are some big spikes! We're going to have to scale these spikes back down before we forward them to the output layer, otherwise we'll have some huge activations. Let's take those scales we loaded in earlier and apply them to the feature spikes.\n",
443 | "\n",
444 | "We have three things to consider here:\n",
445 | " - A slice scale factor (per neuronal connection scale)\n",
446 | " - A zero point (a factor to change where 'zero' is)\n",
447 | " \n",
448 | " Let's print out these three factors and see how they scale our spikes."
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": null,
454 | "id": "95032ac7",
455 | "metadata": {},
456 | "outputs": [],
457 | "source": [
458 | "# Print out the individual scales\n",
459 | "print(f\"The slice scale factor is {feature_scales[1]}\")\n",
460 | "print(f\"The zero point is {feature_scales[2]}\")"
461 | ]
462 | },
463 | {
464 | "cell_type": "markdown",
465 | "id": "9f62b909",
466 | "metadata": {},
467 | "source": [
468 | "Now we'll modify and scale our spikes to then pass them on to the feature layer."
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "id": "9a7dd83d",
475 | "metadata": {},
476 | "outputs": [],
477 | "source": [
478 | "# Scale the feature spikes\n",
479 | "scaled_feature_spikes = (feature_spikes_int//(feature_scales[1]))+feature_scales[2]\n",
480 | "scaled_feature_spikes = np.clip(scaled_feature_spikes,0,255)\n",
481 | "\n",
482 | "# Plot the scaled feature spikes\n",
483 | "plt.plot(np.arange(len(scaled_feature_spikes)), scaled_feature_spikes)\n",
484 | "\n",
485 | "# Add title and labels if you wish\n",
486 | "plt.title('Output Layer Spikes')\n",
487 | "plt.xlabel('Neuron ID')\n",
488 | "plt.ylabel('Spike Amplitude')\n",
489 | "\n",
490 | "# Show the plot\n",
491 | "plt.show()"
492 | ]
493 | },
494 | {
495 | "cell_type": "markdown",
496 | "id": "2a767c92",
497 | "metadata": {},
498 | "source": [
499 | "Now that we've scaled our feature spikes, let's pass them through to the output layer and get our match!"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": null,
505 | "id": "7519c07f",
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "# Get the output spikes\n",
510 | "output_spikes_int = np.matmul(outputQuantW,scaled_feature_spikes)\n",
511 | "\n",
512 | "# Scale the output spikes\n",
513 | "scaled_output_spikes = output_spikes_int//(output_scales[1]) + output_scales[2]\n",
514 | "\n",
515 | "# Plot the scaled feature spikes\n",
516 | "plt.plot(np.arange(len(scaled_output_spikes)), scaled_output_spikes)\n",
517 | "\n",
518 | "# Add title and labels if you wish\n",
519 | "plt.title('Output Layer Spikes')\n",
520 | "plt.xlabel('Neuron ID')\n",
521 | "plt.ylabel('Spike Amplitude')\n",
522 | "\n",
523 | "# Show the plot\n",
524 | "plt.show()"
525 | ]
526 | },
527 | {
528 | "cell_type": "markdown",
529 | "id": "3cc4588a",
530 | "metadata": {},
531 | "source": [
532 | "And once again, as in the base model, we can see that output neuron 0 is the highest respondant.\n",
533 | "\n",
534 | "Let's prove it!"
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": null,
540 | "id": "67e457e4",
541 | "metadata": {},
542 | "outputs": [],
543 | "source": [
544 | "# Output the argmax from the output spikes\n",
545 | "prediction = np.argmax(scaled_output_spikes)\n",
546 | "print(f\"Neuron ID with the highest output is {prediction}\")"
547 | ]
548 | },
549 | {
550 | "cell_type": "markdown",
551 | "id": "7bc8a7fb-66b4-455b-922e-b0fdc38b53c5",
552 | "metadata": {},
553 | "source": [
554 | "### Conclusions\n",
555 | "\n",
556 | "We have gone through a very basic demo of how VPRTempo takes input images, patch normalizes them, and propagates the spikes throughout the weights to achieve the desired matching output. Although this demonstration was performed using NumPy, the torch implementation is virtually the same except we use tensors with or without quantization. \n",
557 | "\n",
558 | "We also went through how the quantization version of the network handled weights and spikes in the integer domain.\n",
559 | "\n",
560 | "If you would like to go more in-depth with training and inferencing, checkout some of the [other tutorials](https://github.com/AdamDHines/VPRTempo-quant/tree/main/tutorials) which show you how to train your own model and goes through the more sophisticated implementation of VPRTempo."
561 | ]
562 | }
563 | ],
564 | "metadata": {
565 | "kernelspec": {
566 | "display_name": "Python 3 (ipykernel)",
567 | "language": "python",
568 | "name": "python3"
569 | },
570 | "language_info": {
571 | "codemirror_mode": {
572 | "name": "ipython",
573 | "version": 3
574 | },
575 | "file_extension": ".py",
576 | "mimetype": "text/x-python",
577 | "name": "python",
578 | "nbconvert_exporter": "python",
579 | "pygments_lexer": "ipython3",
580 | "version": "3.11.4"
581 | }
582 | },
583 | "nbformat": 4,
584 | "nbformat_minor": 5
585 | }
586 |
--------------------------------------------------------------------------------
/tutorials/2_Training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "2a09c9e8-68f6-4edd-a8c1-1cc7f6b8bb00",
6 | "metadata": {},
7 | "source": [
8 | "## Training a new VPRTempo & VPRTempoQuant network\n",
9 | "\n",
10 | "### By Adam D Hines (https://research.qut.edu.au/qcr/people/adam-hines/)\n",
11 | "\n",
12 | "VPRTempo is based on the following paper, if you use or find this code helpful for your research please consider citing the source:\n",
13 | " \n",
14 | "[Adam D Hines, Peter G Stratton, Michael Milford, & Tobias Fischer. \"VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition. arXiv September 2023](https://arxiv.org/abs/2309.10225)\n",
15 | "\n",
16 | "### Introduction\n",
17 | "\n",
18 | "In this tutorial, we will go through how to train your own model for both the base and quantized version of VPRTempo. \n",
19 | "\n",
20 | "Before starting, make sure you have [installed the dependencies](https://github.com/AdamDHines/VPRTempo-quant#installation-and-setup) and/or activated the conda environment. You will also need the [Nordland](https://github.com/AdamDHines/VPRTempo-quant#nordland) dataset before proceeding, as this tutorial will cover training the network using this as an example.\n",
21 | "\n",
22 | "### Training new models for VPRTempo and VPRTempoQuant\n",
23 | "\n",
24 | "Let's start by training the base model with the default settings (if you have pre-trained a model, it will get removed for the purpose of the tutorial)."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "a757316e-4aa8-41aa-b03f-1ce4489d3705",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "# Change the working directory to the main folder from tutorials\n",
35 | "import os\n",
36 | "os.chdir('../')"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "id": "eff21590-9d00-4d8f-b860-ca1c0c849187",
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "# Train the base model with the default settings\n",
47 | "# If the pre-trained model already exists, we will remove it for the tutorial\n",
48 | "file_path = './models/VPRTempo313662725001.pth'\n",
49 | "\n",
50 | "if os.path.exists(file_path):\n",
51 | " os.remove(file_path)\n",
52 | " print(\"The file has been deleted.\")\n",
53 | "\n",
54 | "# Run the training paradigm\n",
55 | "!python main.py --train_new_model"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "285ca349-6295-4f04-bbe1-360d43111f82",
61 | "metadata": {},
62 | "source": [
63 | "Now we'll run the inferencing model to check and make sure our model trained ok."
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "id": "1037e4ee-0361-4aa1-bfba-f979ffa51b88",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "# Run the base inferencing network\n",
74 | "!python main.py"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "id": "cbf8cd37-1af8-420f-ad76-481157a2920e",
80 | "metadata": {},
81 | "source": [
82 | "Great! Now let's have a look at changing a few of the default settings and training different kinds of networks. The default settings train 500 places, so if we want to only look at a smaller number of places we can parse the `--num_places` argument and specify how many places to learn."
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "id": "81daf9c8-ea06-4e69-82a5-9179a7c38ea1",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "# Train a new model with 250 places\n",
93 | "!python main.py --num_places 250 --train_new_model"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "id": "a46cde0d-6343-48a4-87f8-aa5c9f56b231",
99 | "metadata": {},
100 | "source": [
101 | "And we can now inference using this smaller model."
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "id": "559e2391-d0bc-4d58-bb35-8cf69ae08b5e",
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "# Run the inference for a model with 250 places\n",
112 | "!python main.py --num_places 250"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "id": "aa835469-90cc-4f09-b7b4-688465321b18",
118 | "metadata": {},
119 | "source": [
120 | "Arguments for the base network work the same for VPRTempoQuant, we just need to also parse the `--quantize` argument. Let's now train another 250 place network, but also change a couple of other parameters. The default VPRTempo settings is a little slow to train on CPU, so let's reduce the image size from 56x56 to 28x28 and change the number of patches for patch normalization."
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "id": "eee89998-d334-4cdf-ac5b-88aed367ceab",
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "# Train a 250 place network with VPRTempoQuant\n",
131 | "!python main.py --quantize --num_places 250 --patches 7 --dims 28,28 --train_new_model"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "id": "42d5f3ce-acbb-435d-b197-f7e5841dcd70",
137 | "metadata": {},
138 | "source": [
139 | "And now we can inference this model!"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "id": "c6a9e948-5349-4714-b7d0-03ce694409b8",
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "# Run inference on newly trained VPRTempoQuant model\n",
150 | "!python main.py --quantize --num_places 250 --patches 7 --dims 28,28"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "id": "2648e38e-a0d7-481a-9009-3c4cf3e3fff0",
156 | "metadata": {},
157 | "source": [
158 | "### List of arguments you can parse\n",
159 | "\n",
160 | "The full list of arguments that can parsed to VPRTempo can be found in the `parse_network` function of `main.py`. Hyperparameters for VPRTempo are hardcoded into the layers and are not recommended to be changed since they generalize fairly well across multiple different datasets. \n",
161 | "\n",
162 | "### Conclusions\n",
163 | "\n",
164 | "This tutorial provided a simple overview of how you can train your own models for both VPRTempo and VPRTempoQuant, and changing a few of the network parameters.\n",
165 | "\n",
166 | "If you would like to go more in-depth, checkout some of the [other tutorials](https://github.com/AdamDHines/VPRTempo-quant/tree/main/tutorials) where we cover how to define your own custom dataset and work with expert modules."
167 | ]
168 | },
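{
 "cell_type": "markdown",
 "id": "added-help-note",
 "metadata": {},
 "source": [
  "As a quick follow-up to the argument list mentioned above: assuming `main.py` builds its arguments with `argparse` (which is what the `parse_network` function suggests), the standard `--help` flag will print every available argument and its default."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "added-help-code",
 "metadata": {},
 "outputs": [],
 "source": [
  "# List the available command-line arguments (assumes main.py uses argparse)\n",
  "!python main.py --help"
 ]
},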
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "id": "a2bdc6d1-4690-4b48-a75f-f9d716cb7154",
173 | "metadata": {},
174 | "outputs": [],
175 | "source": []
176 | }
177 | ],
178 | "metadata": {
179 | "kernelspec": {
180 | "display_name": "Python 3 (ipykernel)",
181 | "language": "python",
182 | "name": "python3"
183 | },
184 | "language_info": {
185 | "codemirror_mode": {
186 | "name": "ipython",
187 | "version": 3
188 | },
189 | "file_extension": ".py",
190 | "mimetype": "text/x-python",
191 | "name": "python",
192 | "nbconvert_exporter": "python",
193 | "pygments_lexer": "ipython3",
194 | "version": "3.11.4"
195 | }
196 | },
197 | "nbformat": 4,
198 | "nbformat_minor": 5
199 | }
200 |
--------------------------------------------------------------------------------
/tutorials/3_Modules.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "7b7b7556-f8df-4e50-801b-7e215e46a415",
6 | "metadata": {},
7 | "source": [
8 | "## Using modules with VPRTempo and VPRTempoQuant\n",
9 | "\n",
10 | "### By Adam D Hines (https://research.qut.edu.au/qcr/people/adam-hines/)\n",
11 | "\n",
12 | "VPRTempo is based on the following paper, if you use or find this code helpful for your research please consider citing the source:\n",
13 | " \n",
14 | "[Adam D Hines, Peter G Stratton, Michael Milford, & Tobias Fischer. \"VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition. arXiv September 2023](https://arxiv.org/abs/2309.10225)\n",
15 | "\n",
16 | "### Introduction\n",
17 | "\n",
18 | "In this tutorial, we will go through how to use modules with VPRTempo. Modules break up the training data into multiple networks, which has been shown to [improve the overall performance](https://towardsdatascience.com/machine-learning-with-expert-models-a-primer-6c74585f223f) and accuracy of larger models.\n",
19 | "\n",
20 | "Before starting, make sure you have [installed the dependencies](https://github.com/AdamDHines/VPRTempo-quant#installation-and-setup) and/or activated the conda environment. You will also need the [Nordland](https://github.com/AdamDHines/VPRTempo-quant#nordland) dataset before proceeding, as this tutorial will cover training the network using this as an example.\n",
21 | "\n",
22 | "### Comparing results using expert modules for VPRTempo\n",
23 | "\n",
24 | "Let's start by training the base model with 1000 places, which is 500 more than the default settings. We will need to parse the `--train_new_model`, `--num_places`, as well as another argument we haven't seen yet `--max_module`. \n",
25 | "\n",
26 | "`--max_module` tells the network how many places each expert module should learn, which by default is set to `500`. So if we're training a new, singular network with 1000 places we need to increase `max_module` to 1000."
27 | ]
28 | },
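{
 "cell_type": "markdown",
 "id": "added-module-split-note",
 "metadata": {},
 "source": [
  "To make the relationship between `--num_places` and `--max_module` concrete, here is a small illustrative sketch of how places get divided across expert modules (the variable names below are ours for illustration, not taken from `main.py`)."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "added-module-split-code",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Illustrative sketch: dividing places across expert modules\n",
  "num_places = 1000\n",
  "max_module = 500\n",
  "num_modules = -(-num_places // max_module)  # ceiling division\n",
  "module_sizes = [min(max_module, num_places - m * max_module) for m in range(num_modules)]\n",
  "print(num_modules, module_sizes)  # 2 [500, 500]"
 ]
},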
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "id": "9e0d1a7b-8788-436a-8dd5-27f60077c524",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "# Change the working directory to the main folder from tutorials\n",
37 | "import os\n",
38 | "os.chdir('../')"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "c0981001-fa1c-49e0-bbe3-e5a628098eba",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "# Train a single network with 1000 places\n",
49 | "!python main.py --num_places 1000 --max_module 1000 --train_new_model"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "id": "e7d8287b-b47b-4f7e-9d1c-4b309746b284",
55 | "metadata": {},
56 | "source": [
57 | "Now let's see how this performs."
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "id": "32435d47-e583-4768-86ed-2998ec78fdd6",
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "# Run the inferencing network on the singular 1000 place trained model\n",
68 | "!python main.py --num_places 1000 --max_module 1000"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "id": "845764ee-1492-449c-9816-a45ccfe043a1",
74 | "metadata": {},
75 | "source": [
76 | "Performance here is still pretty good, but let's see if we can improve it by splitting up the network into modules!\n",
77 | "\n",
78 | "Now that splitting up our 1000 place network into 2 networks, we can remove the `--max_module` argument because the default is set to 500. Instead what we will parse is `--num_modules` to tell the network to split things up into two models."
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "id": "d7cf368b-2f7f-4345-b882-7bf51622c0d6",
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# Train a new 1000 place model with 2 modules\n",
89 | "!python main.py --num_places 1000 --num_modules 2 --train_new_model"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "id": "84ce9e04-ff62-45ce-83b3-89e4623f0fd1",
95 | "metadata": {},
96 | "source": [
97 | "Now let's see how it compares with the singular model."
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "id": "c6ec3ec7-cca7-45a0-906b-cc3e6d14b5f8",
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "# Run the network with 2 modules\n",
108 | "!python main.py --num_places 1000 --num_modules 2"
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "id": "8da07161-81e9-4773-90f1-e9d675e8c071",
114 | "metadata": {},
115 | "source": [
116 | "A modest boost to performance, however you have to imagine how this scales to much larger networks - especially when considering training times. Because the output layer is one-hot encoded, you need to increase the number of output neurons with each place you want to learn. Splitting up networks has a key benefit for VPRTempo to reduce overall training times with little impact on inference speeds. \n",
117 | "\n",
118 | "(Optional) Run a single network for 2500 places or 5 expert modules for 500 places each (reduced dimensionality to speed things up)."
119 | ]
120 | },
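{
 "cell_type": "markdown",
 "id": "added-flops-note",
 "metadata": {},
 "source": [
  "To make the scaling argument concrete, here is a rough back-of-the-envelope sketch (our own illustrative numbers, assuming the per-sample cost is dominated by the feature-to-output weight matrix, whose size follows the layer definitions in `VPRTempoTrain.py`)."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "added-flops-code",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Back-of-the-envelope: single 2500-place network vs 5 expert modules of 500 places\n",
  "input_dim = 28 * 28  # --dims 28,28\n",
  "feature_dim = input_dim * 2  # feature layer is 2x the input\n",
  "single_cost = 2500 * (feature_dim * 2500)  # 2500 samples x 1568x2500 output weights\n",
  "modular_cost = 5 * 500 * (feature_dim * 500)  # 5 modules, 500 samples each, 1568x500 weights\n",
  "print(f'single network  : {single_cost:,} multiply-accumulates')\n",
  "print(f'5 expert modules: {modular_cost:,} multiply-accumulates')  # ~5x cheaper, same total weight count"
 ]
},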
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "ffd18263-190c-47f4-a37f-130a040d7a67",
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "# Optional: run a 2500 place comparison for singular vs modular networks\n",
129 | "# Train networks\n",
130 | "!python main.py --num_places 2500 --max_module 2500 --dims 28,28 --patches 7 --train_new_model\n",
131 | "!python main.py --num_places 2500 --num_modules 5 --dims 28,28 --patches 7 --train_new_model\n",
132 | "# Run inference\n",
133 | "!python main.py --num_places 2500 --max_module 2500 --dims 28,28 --patches 7\n",
134 | "!python main.py --num_places 2500 --num_modules 5 --dims 28,28 --patches 7"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "id": "02edb3da-9d34-4709-950d-d202c5e902e7",
140 | "metadata": {},
141 | "source": [
142 | "As in the other tutorials, parsing the `--quantize` argument will run exactly the same but for VPRTempoQuant. Let's do a quick comparison."
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "id": "2e520f0a-7b6f-4546-8169-14f20f2f9d18",
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "# Train networks\n",
153 | "#!python main.py --num_places 1500 --max_module 1500 --dims 28,28 --patches 7 --train_new_model --quantize\n",
154 | "#!python main.py --num_places 1500 --num_modules 3 --dims 28,28 --patches 7 --train_new_model --quantize\n",
155 | "# Run inference\n",
156 | "!python main.py --num_places 1500 --max_module 1500 --dims 28,28 --patches 7 --quantize\n",
157 | "!python main.py --num_places 1500 --num_modules 3 --dims 28,28 --patches 7 --quantize"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "id": "58581b5c-70a8-4708-bda1-1ccca1edc70b",
163 | "metadata": {},
164 | "source": [
165 | "Once again, we can see that whilst there's a modest boost to the accuracy result the clear improve is the training speed. Because each network is smaller, the opeations on CPU are a lot less computationally heavy when splitting the networks up.\n",
166 | "\n",
167 | "### Conclusions\n",
168 | "\n",
169 | "This tutorial provided a simple overview of how you can train network models using expert modules. \n",
170 | "\n",
171 | "If you would like to go more in-depth, checkout some of the [other tutorials](https://github.com/AdamDHines/VPRTempo-quant/tree/main/tutorials)."
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "id": "0309b80b-05c2-4577-8ea0-2e45bf2d6aef",
178 | "metadata": {},
179 | "outputs": [],
180 | "source": []
181 | }
182 | ],
183 | "metadata": {
184 | "kernelspec": {
185 | "display_name": "Python 3 (ipykernel)",
186 | "language": "python",
187 | "name": "python3"
188 | },
189 | "language_info": {
190 | "codemirror_mode": {
191 | "name": "ipython",
192 | "version": 3
193 | },
194 | "file_extension": ".py",
195 | "mimetype": "text/x-python",
196 | "name": "python",
197 | "nbconvert_exporter": "python",
198 | "pygments_lexer": "ipython3",
199 | "version": "3.11.4"
200 | }
201 | },
202 | "nbformat": 4,
203 | "nbformat_minor": 5
204 | }
205 |
--------------------------------------------------------------------------------
/tutorials/mats/README.txt:
--------------------------------------------------------------------------------
1 | To download materials for BasicDemo, please visit -> https://www.dropbox.com/scl/fi/bxbzm47kxl24x979q5r5s/1_BasicDemo.zip?rlkey=0umij016whwgm11frzlk63v5k&st=52otrem5&dl=0
2 |
3 | Extract into the `./tutorials/mats` folder.
4 |
--------------------------------------------------------------------------------
/vprtempo/VPRTempo.py:
--------------------------------------------------------------------------------
1 | #MIT License
2 |
3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer
4 |
5 | #Permission is hereby granted, free of charge, to any person obtaining a copy
6 | #of this software and associated documentation files (the "Software"), to deal
7 | #in the Software without restriction, including without limitation the rights
8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | #copies of the Software, and to permit persons to whom the Software is
10 | #furnished to do so, subject to the following conditions:
11 |
12 | #The above copyright notice and this permission notice shall be included in all
13 | #copies or substantial portions of the Software.
14 |
15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | #SOFTWARE.
22 |
23 | '''
24 | Imports
25 | '''
26 |
27 | import os
28 | import json
29 | import torch
30 |
31 | import numpy as np
32 | import torch.nn as nn
33 | import matplotlib.pyplot as plt
34 | import vprtempo.src.blitnet as bn
35 |
36 | from tqdm import tqdm
37 | from prettytable import PrettyTable
38 | from torch.utils.data import DataLoader
39 | from vprtempo.src.download import get_data_model
40 | from vprtempo.src.metrics import recallAtK, createPR
41 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage
42 |
43 | class VPRTempo(nn.Module):
44 | def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None):
45 | super(VPRTempo, self).__init__()
46 |
47 | # Set the args
48 | if args is not None:
49 | self.args = args
50 | for arg in vars(args):
51 | setattr(self, arg, getattr(args, arg))
52 | setattr(self, 'dims', dims)
53 |
54 | # Set the device
55 | if torch.cuda.is_available():
56 | self.device = "cuda:0"
57 | elif torch.backends.mps.is_available():
58 | self.device = "mps"
59 | else:
60 | self.device = "cpu"
61 |
62 | # Set input args
63 | self.logger = logger
64 | self.num_modules = num_modules
65 | self.output_folder = output_folder
66 |
67 | # Set the dataset file
68 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.query_dir}' + '.csv')
69 | self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
70 |
71 | # Layer dict to keep track of layer names and their order
72 | self.layer_dict = {}
73 | self.layer_counter = 0
74 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
75 |
76 | # Define layer architecture
77 | self.input = int(self.dims[0]*self.dims[1])
78 | self.feature = int(self.input * 2)
79 |
80 | # Output dimension changes for final module if not an even distribution of places
81 | if not out_dim_remainder is None:
82 | self.output = out_dim_remainder
83 | else:
84 | self.output = out_dim
85 |
86 | # set model name for default demo
87 | self.demo = './vprtempo/models/springfall_VPRTempo_IN3136_FN6272_DB500.pth'
88 |
89 | """
90 | Define trainable layers here
91 | """
92 | self.add_layer(
93 | 'feature_layer',
94 | dims=[self.input, self.feature],
95 | device=self.device,
96 | inference=True
97 | )
98 | self.add_layer(
99 | 'output_layer',
100 | dims=[self.feature, self.output],
101 | device=self.device,
102 | inference=True
103 | )
104 |
105 | def add_layer(self, name, **kwargs):
106 | """
107 | Dynamically add a layer with given name and keyword arguments.
108 |
109 | :param name: Name of the layer to be added
110 | :type name: str
111 | :param kwargs: Hyperparameters for the layer
112 | """
113 | # Check for layer name duplicates
114 | if name in self.layer_dict:
115 | raise ValueError(f"Layer with name {name} already exists.")
116 |
117 | # Add a new SNNLayer with provided kwargs
118 | setattr(self, name, bn.SNNLayer(**kwargs))
119 |
120 | # Add layer name and index to the layer_dict
121 | self.layer_dict[name] = self.layer_counter
122 | self.layer_counter += 1
123 |
124 | def evaluate(self, models, test_loader):
125 | """
126 | Run the inferencing model and calculate the accuracy.
127 |
128 | :param models: Models to run inference on, each model is a VPRTempo module
129 | :param test_loader: Testing data loader
130 | """
131 | # Initialize the tqdm progress bar
132 | pbar = tqdm(total=self.query_places,
133 | desc="Running the test network",
134 | position=0)
135 | self.inferences = []
136 | for model in models:
137 | self.inferences.append(nn.Sequential(
138 | model.feature_layer.w,
139 | model.output_layer.w,
140 | ))
141 | self.inferences[-1].to(torch.device(self.device))
142 |         # Initialize the output spikes variable
143 | out = []
144 | labels = []
145 |
146 | # Run inference for the specified number of timesteps
147 | for spikes, label in test_loader:
148 | # Set device
149 | spikes = spikes.to(self.device)
150 | labels.append(label.detach().cpu().item())
151 | # Forward pass
152 | spikes = self.forward(spikes)
153 | # Add output spikes to list
154 | out.append(spikes.detach().cpu())
155 | pbar.update(1)
156 |
157 | # Close the tqdm progress bar
158 | pbar.close()
159 |         # Reshape output spikes into a similarity matrix
160 | out = torch.stack(out, dim=2)
161 | out = out.squeeze(0).numpy()
162 |
163 | if self.skip != 0:
164 | GT = np.zeros((model.database_places, model.query_places))
165 | skip = model.skip // model.filter
166 | # Create an array of indices for the query dimension
167 | query_indices = np.arange(model.query_places)
168 |
169 | # Set the ones on the diagonal starting at row `skip`
170 | GT[skip + query_indices, query_indices] = 1
171 | else:
172 | GT = np.eye(model.database_places, model.query_places)
173 |
174 | # Apply GT tolerance
175 | if self.GT_tolerance > 0:
176 | # Get the number of rows and columns
177 | num_rows, num_cols = GT.shape
178 |
179 | # Iterate over each column
180 | for col in range(num_cols):
181 | # Find the indices of rows where GT has a 1 in the current column
182 | ones_indices = np.where(GT[:, col] == 1)[0]
183 |
184 | # For each index with a 1, set 1s in GTtol within the specified vertical distance
185 | for row in ones_indices:
186 | # Determine the start and end rows, ensuring they are within bounds
187 | start_row = max(row - self.GT_tolerance, 0)
188 | end_row = min(row + self.GT_tolerance + 1, num_rows) # +1 because upper bound is exclusive
189 |
190 | # Set the range in GTtol to 1
191 | GT[start_row:end_row, col] = 1
192 |
193 | # If user specified, generate a PR curve
194 | if model.PR_curve:
195 | # Create PR curve
196 | P, R = createPR(out, GT, matching='single', n_thresh=100)
197 | # Combine P and R into a list of lists
198 | PR_data = {
199 | "Precision": P,
200 | "Recall": R
201 | }
202 | output_file = "PR_curve_data.json"
203 | # Construct the full path
204 | full_path = f"{model.output_folder}/{output_file}"
205 | # Write the data to a JSON file
206 | with open(full_path, 'w') as file:
207 | json.dump(PR_data, file)
208 | # Plot PR curve
209 | plt.plot(R,P)
210 | plt.xlabel('Recall')
211 | plt.ylabel('Precision')
212 | plt.title('Precision-Recall Curve')
213 | plt.show()
214 |
215 | plt.close()
216 |
217 | if model.sim_mat:
218 | # Create a figure and a set of subplots
219 | fig, axs = plt.subplots(1, 2, figsize=(15, 5))
220 |
221 | # Plot each matrix using matshow
222 | cax1 = axs[0].matshow(out, cmap='viridis')
223 | fig.colorbar(cax1, ax=axs[0], shrink=0.8)
224 | axs[0].set_title('Similarity matrix')
225 |
226 | cax2 = axs[1].matshow(GT, cmap='plasma')
227 | fig.colorbar(cax2, ax=axs[1], shrink=0.8)
228 | axs[1].set_title('GT')
229 |
230 | # Adjust layout
231 | plt.tight_layout()
232 | plt.show()
233 |
234 | plt.close()
235 |
236 | # Recall@N
237 | N = [1,5,10,15,20,25] # N values to calculate
238 | R = [] # Recall@N values
239 | # Calculate Recall@N
240 | for n in N:
241 | R.append(round(recallAtK(out, GT, K=n),2))
242 | # Print the results
243 | table = PrettyTable()
244 | table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
245 | table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
246 | self.logger.info(table)
247 |
248 | def forward(self, spikes):
249 | """
250 | Compute the forward pass of the model.
251 |
252 | Parameters:
253 | - spikes (Tensor): Input spikes.
254 |
255 | Returns:
256 | - Tensor: Output after processing.
257 | """
258 | # Initialize the output spikes tensor
259 | in_spikes = spikes.detach().clone()
260 | outputs = [] # List to collect output tensors
261 |
262 | # Run inferencing across modules
263 | for inference in self.inferences:
264 | out_spikes = inference(in_spikes)
265 | outputs.append(out_spikes) # Append the output tensor to the list
266 |
267 | # Concatenate along the desired dimension
268 | concatenated_output = torch.cat(outputs, dim=1)
269 |
270 | return concatenated_output
271 |
272 | def load_model(self, models, model_path):
273 | """
274 | Load pre-trained model and set the state dictionary keys.
275 | """
276 | # check if model exists, download the demo data and model if defaults are set
277 | if not os.path.exists(model_path):
278 | if model_path == self.demo:
279 | get_data_model()
280 | else:
281 | raise ValueError(f"Model path {model_path} does not exist.")
282 |
283 | combined_state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
284 |
285 | for i, model in enumerate(models): # models_classes is a list of model classes
286 | model.load_state_dict(combined_state_dict[f'model_{i}'])
287 | model.eval() # Set the model to inference mode
288 |
289 | def run_inference(models, model_name):
290 | """
291 | Run inference on a pre-trained model.
292 |
293 | :param models: Models to run inference on, each model is a VPRTempo module
294 | :param model_name: Name of the model to load
295 | """
296 | # Set first index model as the main model for parameters
297 | model = models[0]
298 | # Initialize the image transforms
299 | image_transform = ProcessImage(model.dims, model.patches)
300 |
301 | # Initialize the test dataset
302 | test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
303 | base_dir=model.data_dir,
304 | img_dirs=model.query_dir,
305 | transform=image_transform,
306 | max_samples=model.query_places,
307 | filter=model.filter,
308 | skip=model.skip)
309 |
310 | # Initialize the data loader
311 | test_loader = DataLoader(test_dataset,
312 | batch_size=1,
313 | num_workers=4,
314 | persistent_workers=True)
315 |
316 | # Load the model
317 | model.load_model(models, os.path.join('./vprtempo/models', model_name))
318 |
319 | # Use evaluate method for inference accuracy
320 | with torch.no_grad():
321 | model.evaluate(models, test_loader)
--------------------------------------------------------------------------------
/vprtempo/VPRTempoQuant.py:
--------------------------------------------------------------------------------
1 | #MIT License
2 |
3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer
4 |
5 | #Permission is hereby granted, free of charge, to any person obtaining a copy
6 | #of this software and associated documentation files (the "Software"), to deal
7 | #in the Software without restriction, including without limitation the rights
8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | #copies of the Software, and to permit persons to whom the Software is
10 | #furnished to do so, subject to the following conditions:
11 |
12 | #The above copyright notice and this permission notice shall be included in all
13 | #copies or substantial portions of the Software.
14 |
15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | #SOFTWARE.
22 |
23 | '''
24 | Imports
25 | '''
26 |
27 | import os
28 | import json
29 | import torch
30 | import random
31 |
32 | import numpy as np
33 | import torch.nn as nn
34 | import matplotlib.pyplot as plt
35 | import vprtempo.src.blitnet as bn
36 |
37 | from tqdm import tqdm
38 | from prettytable import PrettyTable
39 | from torch.utils.data import DataLoader
40 | from vprtempo.src.download import get_data_model
41 | from vprtempo.src.metrics import recallAtK, createPR
42 | from torch.ao.quantization import QuantStub, DeQuantStub
43 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage
44 |
45 | #from main import parse_network
46 |
47 | class VPRTempoQuant(nn.Module):
48 | def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None):
49 | super(VPRTempoQuant, self).__init__()
50 |
51 | # Set the args
52 | if args is not None:
53 | self.args = args
54 | for arg in vars(args):
55 | setattr(self, arg, getattr(args, arg))
56 | setattr(self, 'dims', dims)
57 |
58 | # Set the device
59 | self.device = "cpu"
60 |
61 | # Set input args
62 | self.logger = logger
63 | self.num_modules = num_modules
64 | self.output_folder = output_folder
65 |
66 | self.quant = QuantStub()
67 | self.dequant = DeQuantStub()
68 |
69 | # Set the dataset file
70 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.query_dir}' + '.csv')
71 | self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
72 |
73 | # Layer dict to keep track of layer names and their order
74 | self.layer_dict = {}
75 | self.layer_counter = 0
76 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
77 |
78 | # Define layer architecture
79 | self.input = int(self.dims[0]*self.dims[1])
80 | self.feature = int(self.input * 2)
81 |
82 | # Output dimension changes for final module if not an even distribution of places
83 | if not out_dim_remainder is None:
84 | self.output = out_dim_remainder
85 | else:
86 | self.output = out_dim
87 |
88 | # set model name for default demo
89 | self.demo = './vprtempo/models/springfall_VPRTempoQuant_IN3136_FN6272_DB500.pth'
90 |
91 | """
92 | Define trainable layers here
93 | """
94 | self.add_layer(
95 | 'feature_layer',
96 | dims=[self.input, self.feature],
97 | device=self.device,
98 | inference=True
99 | )
100 | self.add_layer(
101 | 'output_layer',
102 | dims=[self.feature, self.output],
103 | device=self.device,
104 | inference=True
105 | )
106 |
107 | def add_layer(self, name, **kwargs):
108 | """
109 | Dynamically add a layer with given name and keyword arguments.
110 |
111 | :param name: Name of the layer to be added
112 | :type name: str
113 | :param kwargs: Hyperparameters for the layer
114 | """
115 | # Check for layer name duplicates
116 | if name in self.layer_dict:
117 | raise ValueError(f"Layer with name {name} already exists.")
118 |
119 | # Add a new SNNLayer with provided kwargs
120 | setattr(self, name, bn.SNNLayer(**kwargs))
121 |
122 | # Add layer name and index to the layer_dict
123 | self.layer_dict[name] = self.layer_counter
124 | self.layer_counter += 1
125 |
126 | def evaluate(self, models, test_loader, layers=None):
127 | """
128 | Run the inferencing model and calculate the accuracy.
129 |
130 | :param test_loader: Testing data loader
131 | :param layers: Layers to pass data through
132 | """
133 |         # Determine the Hardtanh max value
134 | maxSpike = (1//models[0].quant.scale).item()
135 | # Define the sequential inference model
136 | self.inferences = []
137 | for model in models:
138 | self.inferences.append(nn.Sequential(
139 | model.feature_layer.w,
140 | model.output_layer.w,
141 | ))
142 | # Initialize the tqdm progress bar
143 | pbar = tqdm(total=self.query_places,
144 | desc="Running the test network",
145 | position=0)
146 |         # Initialize the output spikes variable
147 | out = []
148 | labels = []
149 | # Run inference for the specified number of timesteps
150 | for spikes, label in test_loader:
151 | # Set device
152 | spikes = spikes.to(self.device)
153 | labels.append(label.detach().item())
154 | # Pass through previous layers if they exist
155 | spikes = self.forward(spikes)
156 | # Add output spikes to list
157 | out.append(spikes.detach().cpu())
158 | pbar.update(1)
159 |
160 | # Close the tqdm progress bar
161 | pbar.close()
162 |         # Reshape output spikes into a similarity matrix
163 | out = torch.stack(out, dim=2)
164 | out = out.squeeze(0).numpy()
165 |
166 | if self.skip != 0:
167 | GT = np.zeros((model.database_places, model.query_places))
168 | skip = model.skip // model.filter
169 | # Create an array of indices for the query dimension
170 | query_indices = np.arange(model.query_places)
171 |
172 | # Set the ones on the diagonal starting at row `skip`
173 | GT[skip + query_indices, query_indices] = 1
174 | else:
175 | GT = np.eye(model.database_places, model.query_places)
176 |
177 | # Apply GT tolerance
178 | if self.GT_tolerance > 0:
179 | # Get the number of rows and columns
180 | num_rows, num_cols = GT.shape
181 |
182 | # Iterate over each column
183 | for col in range(num_cols):
184 | # Find the indices of rows where GT has a 1 in the current column
185 | ones_indices = np.where(GT[:, col] == 1)[0]
186 |
187 | # For each index with a 1, set 1s in GTtol within the specified vertical distance
188 | for row in ones_indices:
189 | # Determine the start and end rows, ensuring they are within bounds
190 | start_row = max(row - self.GT_tolerance, 0)
191 | end_row = min(row + self.GT_tolerance + 1, num_rows) # +1 because upper bound is exclusive
192 |
193 | # Set the range in GTtol to 1
194 | GT[start_row:end_row, col] = 1
195 |
196 | # If user specified, generate a PR curve
197 | if model.PR_curve:
198 | # Create PR curve
199 | P, R = createPR(out, GT, matching='single', n_thresh=100)
200 | # Combine P and R into a list of lists
201 | PR_data = {
202 | "Precision": P,
203 | "Recall": R
204 | }
205 | output_file = "PR_curve_data.json"
206 | # Construct the full path
207 | full_path = f"{model.output_folder}/{output_file}"
208 | # Write the data to a JSON file
209 | with open(full_path, 'w') as file:
210 | json.dump(PR_data, file)
211 | # Plot PR curve
212 | plt.plot(R,P)
213 | plt.xlabel('Recall')
214 | plt.ylabel('Precision')
215 | plt.title('Precision-Recall Curve')
216 | plt.show()
217 |
218 | if model.sim_mat:
219 | # Create a figure and a set of subplots
220 | fig, axs = plt.subplots(1, 2, figsize=(15, 5))
221 |
222 | # Plot each matrix using matshow
223 | cax1 = axs[0].matshow(out, cmap='viridis')
224 | fig.colorbar(cax1, ax=axs[0], shrink=0.8)
225 | axs[0].set_title('Similarity matrix')
226 |
227 | cax2 = axs[1].matshow(GT, cmap='plasma')
228 | fig.colorbar(cax2, ax=axs[1], shrink=0.8)
229 | axs[1].set_title('GT')
230 |
231 | # Adjust layout
232 | plt.tight_layout()
233 | plt.show()
234 |
235 | # Recall@N
236 | N = [1,5,10,15,20,25] # N values to calculate
237 | R = [] # Recall@N values
238 | # Calculate Recall@N
239 | for n in N:
240 | R.append(round(recallAtK(out,GT,K=n),2))
241 | # Print the results
242 | table = PrettyTable()
243 | table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
244 | table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
245 | self.logger.info(table)
246 |
247 | def forward(self, spikes):
248 | """
249 | Compute the forward pass of the model.
250 |
251 | Parameters:
252 | - spikes (Tensor): Input spikes.
253 |
254 | Returns:
255 | - Tensor: Output after processing.
256 | """
257 | spikes = self.quant(spikes)
258 | in_spikes = spikes.detach().clone()
259 | outputs = [] # List to collect output tensors
260 |
261 | for inference in self.inferences:
262 | out_spikes = inference(in_spikes)
263 | outputs.append(out_spikes) # Append the output tensor to the list
264 |
265 | # Concatenate along the desired dimension
266 | concatenated_output = torch.cat(outputs, dim=1)
267 | spikes = self.dequant(concatenated_output)
268 |
269 | return spikes
270 |
271 | def load_model(self, models, model_path):
272 | """
273 | Load pre-trained model and set the state dictionary keys.
274 | """
275 | # check if model exists, download the demo data and model if defaults are set
276 | if not os.path.exists(model_path):
277 | if model_path == self.demo:
278 | get_data_model()
279 | else:
280 | raise ValueError(f"Model path {model_path} does not exist.")
281 | combined_state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
282 |
283 | for i, model in enumerate(models): # models_classes is a list of model classes
284 |
285 | model.load_state_dict(combined_state_dict[f'model_{i}'])
286 | model.eval() # Set the model to inference mode
287 |
288 | def run_inference_quant(models, model_name):
289 | """
290 | Run inference on a pre-trained model.
291 |
292 | :param models: Models to run inference on, each model is a VPRTempo module
293 | :param model_name: Name of the model to load
294 | """
295 | # Set first index model as the main model for parameters
296 | model = models[0]
297 | # Initialize the image transforms
298 | image_transform = ProcessImage(model.dims, model.patches)
299 |
300 | # Initialize the test dataset
301 | test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
302 | base_dir=model.data_dir,
303 | img_dirs=model.query_dir,
304 | transform=image_transform,
305 | max_samples=model.query_places,
306 | filter=model.filter,
307 | skip=model.skip)
308 |
309 | # Initialize the data loader
310 | test_loader = DataLoader(test_dataset,
311 | batch_size=1,
312 | num_workers=8,
313 | persistent_workers=True)
314 |
315 | # Load the model
316 | model.load_model(models, os.path.join('./vprtempo/models', model_name))
317 |
318 | # Use evaluate method for inference accuracy
319 | with torch.no_grad():
320 | model.evaluate(models, test_loader)
--------------------------------------------------------------------------------
/vprtempo/VPRTempoQuantTrain.py:
--------------------------------------------------------------------------------
1 | #MIT License
2 |
3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer
4 |
5 | #Permission is hereby granted, free of charge, to any person obtaining a copy
6 | #of this software and associated documentation files (the "Software"), to deal
7 | #in the Software without restriction, including without limitation the rights
8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | #copies of the Software, and to permit persons to whom the Software is
10 | #furnished to do so, subject to the following conditions:
11 |
12 | #The above copyright notice and this permission notice shall be included in all
13 | #copies or substantial portions of the Software.
14 |
15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | #SOFTWARE.
22 |
23 | '''
24 | Imports
25 | '''
26 |
27 | import os
28 | import torch
29 |
30 | import numpy as np
31 | import torch.nn as nn
32 | import vprtempo.src.blitnet as bn
33 | import torch.quantization as quantization
34 | import torchvision.transforms as transforms
35 |
36 | from tqdm import tqdm
37 | from torch.utils.data import DataLoader
38 | from torch.ao.quantization import QuantStub, DeQuantStub
39 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage
40 |
41 | class VPRTempoQuantTrain(nn.Module):
42 | def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
43 | super(VPRTempoQuantTrain, self).__init__()
44 |
45 | # Set the arguments
46 | self.args = args
47 | for arg in vars(args):
48 | setattr(self, arg, getattr(args, arg))
49 | setattr(self, 'dims', dims)
50 |
51 | self.device = "cpu"
52 | self.logger = logger
53 | self.num_modules = num_modules
54 | self.quant = QuantStub()
55 | self.dequant = DeQuantStub()
56 |
57 | # Set the dataset file
58 | fields = self.database_dirs.split(',')
59 | if len(fields) > 1:
60 | self.dataset_file = []
61 | for field in fields:
62 | self.dataset_file.append(os.path.join('./vprtempo/dataset', f'{self.dataset}-{field}' + '.csv'))
63 | else:
64 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.database_dirs}' + '.csv')
65 |
66 | # Layer dict to keep track of layer names and their order
67 | self.layer_dict = {}
68 | self.layer_counter = 0
69 |
70 | # Define layer architecture
71 | self.input = int(dims[0]*dims[1])
72 | self.feature = int(self.input * 2)
73 | if not out_dim_remainder is None:
74 | self.output = out_dim_remainder
75 | else:
76 | self.output = out_dim
77 |
78 | # Set the total timestep count
79 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
80 | self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations
81 | if not out_dim_remainder is None:
82 | self.T = int(out_dim_remainder * self.location_repeat * self.epoch)
83 | else:
84 | self.T = int(self.max_module * self.location_repeat * self.epoch)
85 |
86 | """
87 | Define trainable layers here
88 | """
89 | self.add_layer(
90 | 'feature_layer',
91 | dims=[self.input, self.feature],
92 | thr_range=[0, 0.5],
93 | fire_rate=[0.2, 0.9],
94 | ip_rate=0.15,
95 | stdp_rate=0.005,
96 | p=[0.1, 0.5],
97 | device=self.device
98 | )
99 | self.add_layer(
100 | 'output_layer',
101 | dims=[self.feature, self.output],
102 | ip_rate=0.15,
103 | stdp_rate=0.005,
104 | spk_force=True,
105 | p=[1.0, 1.0],
106 | device=self.device
107 | )
108 |
109 | def add_layer(self, name, **kwargs):
110 | """
111 | Dynamically add a layer with given name and keyword arguments.
112 |
113 | :param name: Name of the layer to be added
114 | :type name: str
115 | :param kwargs: Hyperparameters for the layer
116 | """
117 | # Check for layer name duplicates
118 | if name in self.layer_dict:
119 | raise ValueError(f"Layer with name {name} already exists.")
120 |
121 | # Add a new SNNLayer with provided kwargs
122 | setattr(self, name, bn.SNNLayer(**kwargs))
123 |
124 | # Add layer name and index to the layer_dict
125 | self.layer_dict[name] = self.layer_counter
126 | self.layer_counter += 1
127 |
128 | def _anneal_learning_rate(self, layer, mod, itp, stdp):
129 | """
130 | Anneal the learning rate for the current layer.
131 | """
132 | if np.mod(mod, 100) == 0: # Modify learning rate every 100 timesteps
133 | pt = pow(float(self.T - mod) / self.T, 2)
134 | layer.eta_ip = torch.mul(itp, pt) # Anneal intrinsic threshold plasticity learning rate
135 | layer.eta_stdp = torch.mul(stdp, pt) # Anneal STDP learning rate
136 |
137 | return layer
138 |
139 | def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
140 | """
141 | Train a layer of the network model.
142 |
143 | :param train_loader: Training data loader
144 | :param layer: Layer to train
145 | :param prev_layers: Previous layers to pass data through
146 | """
147 |
148 | # Initialize the tqdm progress bar
149 | pbar = tqdm(total=int(self.T),
150 | desc="Training ",
151 | position=0)
152 |
153 | # Initialize the learning rates for each layer (used for annealment)
154 | init_itp = layer.eta_ip.detach()
155 | init_stdp = layer.eta_stdp.detach()
156 | mod = 0 # Used to determine the learning rate annealment, resets at each epoch
157 | # idx scale factor for different modules
158 | idx_scale = (self.max_module*self.filter)*model_num
159 | # Run training for the specified number of epochs
160 | for _ in range(self.epoch):
161 | # Run training for the specified number of timesteps
162 | for spikes, labels in train_loader:
163 | spikes, labels = spikes.to(self.device), labels.to(self.device)
164 | idx = torch.round((labels - idx_scale) / self.filter) # Set output index for spike forcing
165 | # Pass through previous layers if they exist
166 | if prev_layers:
167 | with torch.no_grad():
168 | for prev_layer_name in prev_layers:
169 | prev_layer = getattr(self, prev_layer_name) # Get the previous layer object
170 | spikes = self.forward(spikes, prev_layer) # Pass spikes through the previous layer
171 | spikes = bn.clamp_spikes(spikes, prev_layer) # Clamp spikes [0, 0.9]
172 | else:
173 | prev_layer = None
174 | # Get the output spikes from the current layer
175 | pre_spike = spikes.detach() # Previous layer spikes for STDP
176 | spikes = self.forward(spikes, layer) # Current layer spikes
177 | spikes_noclp = spikes.detach() # Used for inhibitory homeostasis
178 | spikes = bn.clamp_spikes(spikes, layer) # Clamp spikes [0, 0.9]
179 | # Calculate STDP
180 | layer = bn.calc_stdp(pre_spike,spikes,spikes_noclp,layer, idx, prev_layer=prev_layer)
181 | # Adjust learning rates
182 | layer = self._anneal_learning_rate(layer, mod, init_itp, init_stdp)
183 | # Update the annealing mod & progress bar
184 | mod += 1
185 | pbar.update(1)
186 |
187 | # Close the tqdm progress bar
188 | pbar.close()
189 |
190 | def forward(self, spikes, layer):
191 | """
192 | Compute the forward pass of the model.
193 |
194 | Parameters:
195 | - spikes (Tensor): Input spikes.
196 |
197 | Returns:
198 | - Tensor: Output after processing.
199 | """
200 |
201 | spikes = self.quant(spikes)
202 | spikes = layer.w(spikes)
203 | spikes = self.dequant(spikes)
204 |
205 | return spikes
206 |
207 | def save_model(self, models, model_out):
208 | """
209 | Save the trained model to models output folder.
210 | """
211 | state_dicts = {}
212 | for i, model in enumerate(models): # Assuming models_list is your list of models
213 | state_dicts[f'model_{i}'] = model.state_dict()
214 |
215 | torch.save(state_dicts, model_out)
216 |
217 | def generate_model_name_quant(model):
218 | """
219 | Generate the model name based on its parameters.
220 | """
221 | return ("VPRTempoQuant" +
222 | str(model.input) +
223 | str(model.feature) +
224 | str(model.output) +
225 | str(model.num_modules) +
226 | '.pth')
227 |
228 | def check_pretrained_model(model_name):
229 | """
230 | Check if a pre-trained model exists and prompt the user to retrain if desired.
231 | """
232 | if os.path.exists(os.path.join('./models', model_name)):
233 | prompt = "A network with these parameters exists, re-train network? (y/n):\n"
234 | retrain = input(prompt).strip().lower()
235 | return retrain == 'n'
236 | return False
237 |
238 | def train_new_model_quant(models, model_name):
239 | """
240 | Train a new model.
241 |
242 | :param model: Model to train
243 | :param model_name: Name of the model to save after training
244 | """
245 | # Set first index model as the main model for parameters
246 | model = models[0]
247 | # Initialize the image transforms and datasets
248 | image_transform = transforms.Compose([
249 | ProcessImage(model.dims, model.patches)
250 | ])
251 | # Automatically generate user_input_ranges
252 | user_input_ranges = []
253 | start_idx = 0
254 | # Generate the image ranges for each module
255 | for _ in range(model.num_modules):
256 | range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)]
257 | user_input_ranges.append(range_temp)
258 | start_idx = range_temp[1] + model.filter
259 |
260 | # Keep track of trained layers to pass data through them
261 | trained_layers = []
262 |
263 | # Training each layer
264 | for layer_name, _ in sorted(model.layer_dict.items(), key=lambda item: item[1]):
265 | print(f"Training layer: {layer_name}")
266 | # Retrieve the layer object
267 | for i, model in enumerate(models):
268 | model.train()
269 | model.to(torch.device(model.device))
270 | layer = (getattr(model, layer_name))
271 | # Determine the maximum samples for the DataLoader
272 | if model.database_places < model.max_module:
273 | max_samples = model.database_places
274 | elif model.output < model.max_module:
275 | max_samples = model.output
276 | else:
277 | max_samples = model.max_module
278 | # Initialize new dataset with unique range for each module
279 | img_range=user_input_ranges[i]
280 | train_dataset = CustomImageDataset(annotations_file=model.dataset_file,
281 | base_dir=model.data_dir,
282 | img_dirs=model.database_dirs,
283 | transform=image_transform,
284 | filter=models[0].filter,
285 | skip=models[0].skip,
286 | test=False,
287 | img_range=img_range,
288 | max_samples=max_samples)
289 | # Initialize the data loader
290 | train_loader = DataLoader(train_dataset,
291 | batch_size=1,
292 | shuffle=True,
293 | num_workers=8,
294 | persistent_workers=True)
295 | # Train the layers
296 | model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
297 | trained_layers.append(layer_name)
298 |
299 | # Convert the model to evaluation mode
300 | for model in models:
301 | quantization.convert(model, inplace=True)
302 |
303 | # Save the model
304 | model.save_model(models,os.path.join('./vprtempo/models', model_name))
--------------------------------------------------------------------------------
/vprtempo/VPRTempoTrain.py:
--------------------------------------------------------------------------------
1 | #MIT License
2 |
3 | #Copyright (c) 2023 Adam Hines, Peter G Stratton, Michael Milford, Tobias Fischer
4 |
5 | #Permission is hereby granted, free of charge, to any person obtaining a copy
6 | #of this software and associated documentation files (the "Software"), to deal
7 | #in the Software without restriction, including without limitation the rights
8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | #copies of the Software, and to permit persons to whom the Software is
10 | #furnished to do so, subject to the following conditions:
11 |
12 | #The above copyright notice and this permission notice shall be included in all
13 | #copies or substantial portions of the Software.
14 |
15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | #SOFTWARE.
22 |
23 | '''
24 | Imports
25 | '''
26 |
27 | import os
28 | import gc
29 | import torch
30 | import sys
31 |
32 | import numpy as np
33 | import torch.nn as nn
34 | import vprtempo.src.blitnet as bn
35 | import torchvision.transforms as transforms
36 |
37 | from tqdm import tqdm
38 | from torch.utils.data import DataLoader
39 | from vprtempo.src.loggers import model_logger
40 | from vprtempo.src.dataset import CustomImageDataset, ProcessImage
41 |
42 | class VPRTempoTrain(nn.Module):
43 | def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
44 | super(VPRTempoTrain, self).__init__()
45 |
46 | # Set the arguments
47 | self.args = args
48 | for arg in vars(args):
49 | setattr(self, arg, getattr(args, arg))
50 | setattr(self, 'dims', dims)
51 | # Set the device
52 | if torch.cuda.is_available():
53 | self.device = "cuda:0"
54 | elif torch.backends.mps.is_available():
55 | self.device = "mps"
56 | else:
57 | self.device = "cpu"
58 | self.logger = logger
59 | self.num_modules = num_modules
60 |
61 | # Set the dataset file
62 | fields = self.database_dirs.split(',')
63 | if len(fields) > 1:
64 | self.dataset_file = []
65 | for field in fields:
66 | self.dataset_file.append(os.path.join('./vprtempo/dataset', f'{self.dataset}-{field}' + '.csv'))
67 | else:
68 | self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.database_dirs}' + '.csv')
69 |
70 | # Layer dict to keep track of layer names and their order
71 | self.layer_dict = {}
72 | self.layer_counter = 0
73 |
74 | # Define layer architecture
75 | self.input = int(dims[0]*dims[1])
76 | self.feature = int(self.input * 2)
77 | if not out_dim_remainder is None:
78 | self.output = out_dim_remainder
79 | else:
80 | self.output = out_dim
81 |
82 | # Set the total timestep count
83 | self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
84 | self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations
85 | if not out_dim_remainder is None:
86 | self.T = int(out_dim_remainder * self.location_repeat * self.epoch)
87 | else:
88 | self.T = int(self.max_module * self.location_repeat * self.epoch)
89 |
90 | """
91 | Define trainable layers here
92 | """
93 | self.add_layer(
94 | 'feature_layer',
95 | dims=[self.input, self.feature],
96 | thr_range=[0, 0.5],
97 | fire_rate=[0.2, 0.9],
98 | ip_rate=0.15,
99 | stdp_rate=0.005,
100 | p=[0.1, 0.5],
101 | device=self.device
102 | )
103 | self.add_layer(
104 | 'output_layer',
105 | dims=[self.feature, self.output],
106 | ip_rate=0.15,
107 | stdp_rate=0.005,
108 | p=[1.0, 1.0],
109 | spk_force=True,
110 | device=self.device
111 | )
112 |
113 | def add_layer(self, name, **kwargs):
114 | """
115 | Dynamically add a layer with given name and keyword arguments.
116 |
117 | :param name: Name of the layer to be added
118 | :type name: str
119 | :param kwargs: Hyperparameters for the layer
120 | """
121 | # Check for layer name duplicates
122 | if name in self.layer_dict:
123 | raise ValueError(f"Layer with name {name} already exists.")
124 |
125 | # Add a new SNNLayer with provided kwargs
126 | setattr(self, name, bn.SNNLayer(**kwargs))
127 |
128 | # Add layer name and index to the layer_dict
129 | self.layer_dict[name] = self.layer_counter
130 | self.layer_counter += 1
131 |
132 | def model_logger(self):
133 | """
134 | Log the model configuration to the console.
135 | """
136 | model_logger(self)
137 |
138 | def _anneal_learning_rate(self, layer, mod, itp, stdp):
139 | """
140 | Anneal the learning rate for the current layer.
141 | """
142 | if np.mod(mod, 100) == 0: # Modify learning rate every 100 timesteps
143 | pt = pow(float(self.T - mod) / self.T, 2)
144 | layer.eta_ip = torch.mul(itp, pt) # Anneal intrinsic threshold plasticity learning rate
145 | layer.eta_stdp = torch.mul(stdp, pt) # Anneal STDP learning rate
146 |
147 | return layer
148 |
149 | def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
150 | """
151 | Train a layer of the network model.
152 |
153 | :param train_loader: Training data loader
154 | :param layer: Layer to train
155 | :param prev_layers: Previous layers to pass data through
156 | """
157 |
158 | # Initialize the tqdm progress bar
159 | pbar = tqdm(total=self.T,
160 | desc=f"Module {model_num+1}",
161 | position=0)
162 |
163 | # Initialize the learning rates for each layer (used for annealment)
164 | init_itp = layer.eta_ip.detach()
165 | init_stdp = layer.eta_stdp.detach()
166 | mod = 0 # Used to determine the learning rate annealment, resets at each epoch
167 |
168 | # idx scale factor for different modules
169 | idx_scale = (self.max_module*self.filter)*model_num
170 |
171 | # Run training for the specified number of epochs
172 | for _ in range(self.epoch):
173 | # Run training for the specified number of timesteps
174 | for spikes, labels in train_loader:
175 | spikes, labels = spikes.to(self.device), labels.to(self.device)
176 | idx = torch.round((labels - idx_scale) / self.filter) # Set output index for spike forcing
177 | # Pass through previous layers if they exist
178 | if prev_layers:
179 | with torch.no_grad():
180 | for prev_layer_name in prev_layers:
181 | prev_layer = getattr(model, prev_layer_name) # Get the previous layer object
182 | spikes = self.forward(spikes, prev_layer) # Pass spikes through the previous layer
183 | spikes = bn.clamp_spikes(spikes, prev_layer) # Clamp spikes [0, 0.9]
184 | else:
185 | prev_layer = None
186 | # Get the output spikes from the current layer
187 | pre_spike = spikes.detach() # Previous layer spikes for STDP
188 | spikes = self.forward(spikes, layer) # Current layer spikes
189 | spikes_noclp = spikes.detach() # Used for inhibitory homeostasis
190 | spikes = bn.clamp_spikes(spikes, layer) # Clamp spikes [0, 0.9]
191 | # Calculate STDP
192 | layer = bn.calc_stdp(pre_spike,spikes,spikes_noclp,layer, idx, prev_layer=prev_layer)
193 | # Adjust learning rates
194 | layer = self._anneal_learning_rate(layer, mod, init_itp, init_stdp)
195 | # Update the annealing mod & progress bar
196 | mod += 1
197 | pbar.update(1)
198 |
199 | # Close the tqdm progress bar
200 | pbar.close()
201 |
202 | # Free up memory
203 | if self.device == "cuda:0":
204 | torch.cuda.empty_cache()
205 | gc.collect()
206 |
207 | def forward(self, spikes, layer):
208 | """
209 | Compute the forward pass of the model.
210 |
211 | Parameters:
212 | - spikes (Tensor): Input spikes.
213 |
214 | Returns:
215 | - Tensor: Output after processing.
216 | """
217 |
218 | spikes = layer.w(spikes)
219 |
220 | return spikes
221 |
222 | def save_model(self,models, model_out):
223 | """
224 | Save the trained model to models output folder.
225 | """
226 | state_dicts = {}
227 | for i, model in enumerate(models): # Assuming models_list is your list of models
228 | state_dicts[f'model_{i}'] = model.state_dict()
229 |
230 | torch.save(state_dicts, model_out)
231 |
232 |
233 | def check_pretrained_model(model_name):
234 | """
235 | Check if a pre-trained model exists and prompt the user to retrain if desired.
236 | """
237 | if os.path.exists(os.path.join('./vprtempo/models', model_name)):
238 | prompt = "A network with these parameters exists, re-train network? (y/n):\n"
239 | retrain = input(prompt).strip().lower()
240 | if retrain == 'y':
241 | return True
242 | elif retrain == 'n':
243 | print('Training new model cancelled')
244 | sys.exit()
245 |
246 | def train_new_model(models, model_name):
247 | """
248 | Train a new model.
249 |
250 | :param model: Model to train
251 | :param model_name: Name of the model to save after training
252 | :param qconfig: Quantization configuration
253 | """
254 | # Initialize the image transforms and datasets
255 | image_transform = transforms.Compose([
256 | ProcessImage(models[0].dims, models[0].patches)
257 | ])
258 | # Automatically generate user_input_ranges
259 | user_input_ranges = []
260 | start_idx = 0
261 | # Generate the image ranges for each module
262 | for _ in range(models[0].num_modules):
263 | range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)]
264 | user_input_ranges.append(range_temp)
265 | start_idx = range_temp[1] + models[0].filter
266 |
267 | # Keep track of trained layers to pass data through them
268 | trained_layers = []
269 | # Training each layer
270 | for layer_name, _ in sorted(models[0].layer_dict.items(), key=lambda item: item[1]):
271 | print(f"Training layer: {layer_name}")
272 | # Retrieve the layer object
273 | for i, model in enumerate(models):
274 | model.train()
275 | model.to(torch.device(model.device))
276 | layer = (getattr(model, layer_name))
277 | # Determine the maximum samples for the DataLoader
278 | if model.database_places < model.max_module:
279 | max_samples = model.database_places
280 | elif model.output < model.max_module:
281 | max_samples = model.output
282 | else:
283 | max_samples = model.max_module
284 | # Initialize new dataset with unique range for each module
285 | img_range=user_input_ranges[i]
286 | train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
287 | base_dir=models[0].data_dir,
288 | img_dirs=models[0].database_dirs,
289 | transform=image_transform,
290 | filter=models[0].filter,
291 | skip=models[0].skip,
292 | test=False,
293 | img_range=img_range,
294 | max_samples=max_samples)
295 | # Initialize the data loader
296 | train_loader = DataLoader(train_dataset,
297 | batch_size=1,
298 | shuffle=True,
299 | num_workers=8,
300 | persistent_workers=True)
301 | # Train the layers
302 | model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
303 | model.to(torch.device("cpu"))
304 | # After training the current layer, add it to the list of trained layers
305 | trained_layers.append(layer_name)
306 | # Convert the model to evaluation mode
307 | for model in models:
308 | model.eval()
309 | # Save the model
310 | model.save_model(models,os.path.join('./vprtempo/models', model_name))
--------------------------------------------------------------------------------
/vprtempo/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.1.9'
2 |
--------------------------------------------------------------------------------
/vprtempo/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/vprtempo/models/.gitkeep
--------------------------------------------------------------------------------
/vprtempo/models/README.txt:
--------------------------------------------------------------------------------
1 | To download pretrained models, please visit -> https://www.dropbox.com/scl/fi/ysfz7t7ek6h0pslwq9hd4/VPRTempo_pretrained_models.zip?rlkey=thg0rhn0hjsyov6zov63ni11o&st=nvimet71&dl=0
2 |
--------------------------------------------------------------------------------
/vprtempo/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QVPR/VPRTempo/52fb2ecbbbe924430021b54e309fe2a5c8d18672/vprtempo/src/__init__.py
--------------------------------------------------------------------------------
/vprtempo/src/blitnet.py:
--------------------------------------------------------------------------------
1 | #MIT License
2 |
3 | #Copyright (c) 2023 Adam Hines, Peter Stratton, Michael Milford, Tobias Fischer
4 |
5 | #Permission is hereby granted, free of charge, to any person obtaining a copy
6 | #of this software and associated documentation files (the "Software"), to deal
7 | #in the Software without restriction, including without limitation the rights
8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | #copies of the Software, and to permit persons to whom the Software is
10 | #furnished to do so, subject to the following conditions:
11 |
12 | #The above copyright notice and this permission notice shall be included in all
13 | #copies or substantial portions of the Software.
14 |
15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | #SOFTWARE.
22 |
23 | '''
24 | Imports
25 | '''
26 | import torch
27 |
28 | import torch.nn as nn
29 | import numpy as np
30 |
31 |
32 | class SNNLayer(nn.Module):
33 | def __init__(self, dims=[0,0],thr_range=[0,0],fire_rate=[0,0],ip_rate=0,
34 | stdp_rate=0,const_inp=[0,0],p=[1,1],spk_force=False,device=None,inference=False,args=None):
35 | super(SNNLayer, self).__init__()
36 | """
37 | dims: [input, output] dimensions of the layer
38 | thr_range: [min, max] range of thresholds
39 | fire_rate: [min, max] range of firing rates
40 | ip_rate: learning rate for input threshold plasticity
41 | stdp_rate: learning rate for stdp
42 | const_inp: [min, max] range of constant input
43 | p: [min, max] range of connection probabilities
44 | spk_force: boolean to force spikes
45 | """
46 |
47 | # Device
48 | self.device = device
49 |         # Add different parameters depending on whether training or running an inference model
50 | if inference: # If running inference model
51 | self.w = nn.Linear(dims[0], dims[1], bias=False) # Combined weight tensors
52 | self.w.to(device)
53 | self.thr = nn.Parameter(torch.zeros([1, dims[-1]],
54 | device=self.device).uniform_(thr_range[0],
55 | thr_range[1]))
56 | else: # If training new model
57 | # Check constraints etc
58 | if np.isscalar(thr_range): thr_range = [thr_range, thr_range]
59 | if np.isscalar(fire_rate): fire_rate = [fire_rate, fire_rate]
60 | if np.isscalar(const_inp): const_inp = [const_inp, const_inp]
61 |
62 | # Initialize Tensors
63 | self.x = torch.zeros([1, dims[-1]], device=self.device)
64 | self.eta_ip = torch.tensor(ip_rate, device=self.device)
65 | self.eta_stdp = torch.tensor(stdp_rate, device=self.device)
66 |
67 | # Initialize Parameters
68 | self.thr = nn.Parameter(torch.zeros([1, dims[-1]],
69 | device=self.device).uniform_(thr_range[0],
70 | thr_range[1]))
71 | self.fire_rate = torch.zeros([1,dims[-1]], device=self.device).uniform_(fire_rate[0], fire_rate[1])
72 |
73 | # Sequentially set the feature firing rates (if any)
74 | if not torch.all(self.fire_rate==0).item():
75 | fstep = (fire_rate[1]-fire_rate[0])/dims[-1]
76 |
77 | for i in range(dims[-1]):
78 | self.fire_rate[:,i] = fire_rate[0]+fstep*(i+1)
79 |
80 | self.have_rate = torch.any(self.fire_rate[:,0] > 0.0).to(self.device)
81 | self.const_inp = torch.zeros([1, dims[-1]], device=self.device).uniform_(const_inp[0], const_inp[1])
82 | self.p = p
83 | self.dims = dims
84 |
85 | # Additional State Variables
86 | self.set_spks = []
87 | self.sspk_idx = 0
88 | self.spikes = torch.empty([], dtype=torch.float64)
89 | self.spk_force = spk_force
90 |
91 | # Create the excitatory weights
92 | self.exc = nn.Linear(dims[0], dims[1], bias=False)
93 | self.exc.weight = self.addWeights(dims=dims,
94 | W_range=[0,1],
95 | p=p[0],
96 | device=device)
97 |
98 | # Create the inhibitory weights
99 | self.inh = nn.Linear(dims[0], dims[1], bias=False)
100 | self.inh.weight = self.addWeights(dims=dims,
101 | W_range=[-1,0],
102 | p=p[-1],
103 | device=device)
104 |
105 | # Output boolean reference of which neurons have connection weights
106 | self.havconnExc = self.exc.weight > 0
107 | self.havconnInh = self.inh.weight < 0
108 |
109 | # Combine weights into a single tensor
110 | self.w = nn.Linear(dims[0], dims[1], bias=False)
111 | self.w.weight = nn.Parameter(torch.add(self.exc.weight, self.inh.weight))
112 |
113 | self.havconnCombinedExc = self.w.weight > 0
114 | self.havconnCombinedInh = self.w.weight < 0
115 |
116 | del self.exc, self.inh
117 |
118 | def addWeights(self,W_range=[0,0],p=[0,0],dims=[0,0],device=None):
119 |
120 | # Get torch device
121 | device = device
122 |
123 | # Check constraints etc
124 | if np.isscalar(W_range): W_range = [W_range,W_range]
125 |
126 | # Determine dimensions of the weight matrices
127 | nrow = dims[1]
128 | ncol = dims[0]
129 |
130 | # Calculate mean and std for normal distributions
131 | Wmn = (W_range[0]+W_range[1])/2.0
132 | Wsd = (W_range[1]-W_range[0])/6.0
133 |
134 | # Initialize weights as empty tensors
135 | W = torch.empty((0, nrow, ncol), device=device)
136 |
137 |         # Normally distribute random weights
138 | W = torch.empty(nrow, ncol, device=device).normal_(mean=Wmn, std=Wsd)
139 |
140 | # Remove inappropriate weights based on sign from W_range
141 | if W_range[-1] != 0:
142 | # For excitatory weights
143 | W[W < 0] = 0.0
144 | else:
145 | # For inhibitory weights
146 | W[W > 0] = 0.0
147 |
148 | # Remove weights based on connection probability
149 | setzero = np.random.rand(nrow,ncol) > p
150 | if setzero.any():
151 | W[setzero] = 0.0
152 |
153 | # Normalise the weights
154 | nrm = torch.linalg.norm(W[len(W)-1],ord=1,axis=0)
155 | nrm[nrm==0.0] = 1.0
156 | W = nn.Parameter(W/nrm)
157 |
158 | return W
159 |
160 | def add_input(spikes, layer):
161 |
162 | # Add the constant input
163 | spikes += layer.const_inp
164 |
165 | return spikes
166 |
167 | def clamp_spikes(spikes, layer):
168 | # Clamp outputs between 0 and 0.9 after subtracting thresholds from input
169 | spikes = torch.clamp(torch.sub(spikes, layer.thr), min=0.0, max=0.9)
170 |
171 | return spikes
172 |
173 | def calc_stdp(prespike, spikes, noclp, layer, idx, prev_layer=None):
174 | # Spike Forcing has special rules to make calculated and forced spikes match
175 | if layer.spk_force:
176 |
177 | # Get layer dimensions
178 | shape = layer.w.weight.data.shape
179 |
180 | # Get the output neuron index
181 | idx_sel = torch.arange(int(idx[0]), int(idx[0]) + 1,
182 | device=layer.device,
183 | dtype=int)
184 |
185 | # Difference between forced and calculated spikes
186 | layer.x = torch.full_like(layer.x, 0)
187 | xdiff = layer.x.index_fill_(-1, idx_sel, 0.5) - spikes
188 |         xdiff = xdiff.clamp(min=0.0, max=0.9)
189 |
190 | # Pre and Post spikes tiled across and down for all synapses
191 |         if prev_layer.fire_rate is None:
192 | mpre = prespike
193 | else:
194 | # Modulate learning rate by firing rate (low firing rate = high learning rate)
195 | mpre = prespike/prev_layer.fire_rate
196 |
197 | # Tile out pre- and post- spikes for STDP weight updates
198 | pre = torch.tile(torch.reshape(mpre, (shape[1], 1)), (1, shape[0]))
199 | post = torch.tile(xdiff, (shape[1], 1))
200 |
201 | # Apply the weight changes
202 | layer.w.weight.data += ((pre * post * layer.havconnCombinedExc.T) *
203 | layer.eta_stdp).T
204 | layer.w.weight.data += ((-pre * post * layer.havconnCombinedInh.T) *
205 | (layer.eta_stdp * -1)).T
206 |
207 | # Normal STDP
208 | else:
209 |
210 | # Get layer dimensions
211 | shape = layer.w.weight.data.shape
212 |
213 | # Tile out pre- and post-spikes
214 | pre = torch.tile(torch.reshape(prespike, (shape[1], 1)), (1, shape[0]))
215 | post = torch.tile(spikes, (shape[1], 1))
216 |
217 | # Apply positive and negative weight changes
218 | layer.w.weight.data += (((0.5 - post) * (pre > 0) * (post > 0) *
219 | layer.havconnCombinedExc.T) * layer.eta_stdp).T
220 | layer.w.weight.data += (((0.5 - post) * (pre > 0) *
221 | (post > 0) * layer.havconnCombinedInh.T) * (layer.eta_stdp * -1)).T
222 |
223 | # Remove negative weights for excW and positive for inhW
224 | layer.w.weight.data[layer.havconnCombinedExc] = layer.w.weight.data[layer.havconnCombinedExc].clamp(min=1e-06, max=10)
225 | layer.w.weight.data[layer.havconnCombinedInh] = layer.w.weight.data[layer.havconnCombinedInh].clamp(min=-10, max=-1e-06)
226 |
227 | # Check if layer has target firing rate and an ITP learning rate
228 | if layer.have_rate and layer.eta_ip > 0.0:
229 |
230 | # Replace the original layer.thr with the updated one
231 | layer.thr.data += layer.eta_ip * (layer.x - layer.fire_rate)
232 | layer.thr.data[layer.thr.data < 0] = 0
233 |
234 |     # Check if the layer has any connection weights and a non-zero STDP learning rate
235 | if torch.any(layer.w.weight.data).item() and layer.eta_stdp != 0:
236 |
237 | # Normalize the inhibitory weights using homeostasis
238 | inhW = layer.w.weight.data.T.clone()
239 | inhW[inhW>0] = 0
240 | layer.w.weight.data += (torch.mul(noclp,inhW) * layer.eta_stdp*50).T
241 | #layer.w.weight.data[layer.w.weight.data > 0.0] = -1e-06
242 |
243 | return layer
--------------------------------------------------------------------------------
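A minimal usage sketch for blitnet.SNNLayer, assuming the repository root is on the Python path; the layer dimensions, rates and connection probabilities below are illustrative values, not the ones used for any dataset:

import torch
from vprtempo.src import blitnet

# Build a small training-mode layer (dimensions and rates are example values)
device = torch.device("cpu")
layer = blitnet.SNNLayer(dims=[784, 100],
                         thr_range=[0.0, 0.5],
                         fire_rate=[0.01, 0.2],
                         ip_rate=0.02,
                         stdp_rate=0.005,
                         const_inp=[0.0, 0.1],
                         p=[0.1, 0.5],
                         device=device)

# One forward pass: weighted input, constant input, threshold subtraction and clamping
pre_spikes = torch.rand(1, 784, device=device)   # input spike amplitudes in [0, 1]
post = layer.w(pre_spikes)                       # combined excitatory/inhibitory weights
post = blitnet.add_input(post, layer)            # add the constant input term
post = blitnet.clamp_spikes(post, layer)         # subtract thresholds, clamp to [0, 0.9]
print(post.shape)                                # torch.Size([1, 100])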
/vprtempo/src/create_data_csv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 |
4 | def create_csv_from_images(folder_path, csv_file_path):
5 | files = os.listdir(folder_path)
6 | png_files = sorted([f for f in files if f.endswith('.png')])
7 |
8 | with open(csv_file_path, 'w', newline='') as file:
9 | writer = csv.writer(file)
10 | writer.writerow(['Image_name', 'index'])
11 |
12 | for index, image_name in enumerate(png_files):
13 | writer.writerow([image_name, index])
14 |
15 | # Name of the dataset to create .csv for
16 | dataset_name = 'nordland-fall'
17 |
18 | # Generate paths (set folder_path to the directory containing the dataset images)
19 | folder_path = ''
20 | csv_file_path = os.path.join('./VPRTempo/vprtempo/dataset', f'{dataset_name}.csv')
21 |
22 | # Create .csv file
23 | create_csv_from_images(folder_path, csv_file_path)
--------------------------------------------------------------------------------
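For reference, a self-contained sketch of the CSV layout this script produces (one header row, then image name and index per row); the folder and file names below are made up for illustration:

import csv
import os
import tempfile

# Make a throwaway folder with a few empty .png files (names are made up)
folder_path = tempfile.mkdtemp()
for name in ['images-00002.png', 'images-00000.png', 'images-00001.png']:
    open(os.path.join(folder_path, name), 'w').close()

# Same logic as create_csv_from_images: header row, then name/index pairs
png_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.png'))
csv_file_path = os.path.join(folder_path, 'nordland-example.csv')
with open(csv_file_path, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Image_name', 'index'])
    for index, image_name in enumerate(png_files):
        writer.writerow([image_name, index])

print(open(csv_file_path).read())
# Image_name,index
# images-00000.png,0
# images-00001.png,1
# images-00002.png,2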
/vprtempo/src/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 |
5 | import pandas as pd
6 | import numpy as np
7 | import torch.nn.functional as F
8 |
9 | from torchvision.io import read_image
10 | from torch.utils.data import Dataset
11 |
12 | class GetPatches2D:
13 | def __init__(self, patch_size, image_pad):
14 | self.patch_size = patch_size
15 | self.image_pad = image_pad
16 |
17 | def __call__(self, img):
18 |
19 | # Assuming image_pad is already a PyTorch tensor. If not, you can convert it:
20 | # image_pad = torch.tensor(image_pad).to(torch.float64)
21 |
22 | # Using unfold to get 2D sliding windows.
23 | unfolded = self.image_pad.unfold(0, self.patch_size[0], 1).unfold(1, self.patch_size[1], 1)
24 | # The size of unfolded will be [nrows, ncols, patch_size[0], patch_size[1]]
25 |
26 | # Reshaping the tensor to the desired shape
27 | patches = unfolded.permute(2, 3, 0, 1).contiguous().view(self.patch_size[0]*self.patch_size[1], -1)
28 |
29 | return patches
30 |
31 |
32 | class PatchNormalisePad:
33 | def __init__(self, patches):
34 | self.patches = patches
35 |
36 |
37 | def nanstd(self,input_tensor, dim=None, unbiased=True):
38 | if dim is not None:
39 | valid_count = torch.sum(~torch.isnan(input_tensor), dim=dim, dtype=torch.float)
40 | mean = torch.nansum(input_tensor, dim=dim) / valid_count
41 | diff = input_tensor - mean.unsqueeze(dim)
42 | variance = torch.nansum(diff * diff, dim=dim) / valid_count
43 |
44 | # Bessel's correction for unbiased estimation
45 | if unbiased:
46 | variance = variance * (valid_count / (valid_count - 1))
47 | else:
48 | valid_count = torch.sum(~torch.isnan(input_tensor), dtype=torch.float)
49 | mean = torch.nansum(input_tensor) / valid_count
50 | diff = input_tensor - mean
51 | variance = torch.nansum(diff * diff) / valid_count
52 |
53 | # Bessel's correction for unbiased estimation
54 | if unbiased:
55 | variance = variance * (valid_count / (valid_count - 1))
56 |
57 | return torch.sqrt(variance)
58 |
59 | def __call__(self, img):
60 | img = torch.squeeze(img,0)
61 | patch_size = (self.patches, self.patches)
62 | patch_half_size = [int((p-1)/2) for p in patch_size ]
63 |
64 | # Compute the padding. If patch_half_size is a scalar, the same value will be used for all sides.
65 | if isinstance(patch_half_size, int):
66 | pad = (patch_half_size, patch_half_size, patch_half_size, patch_half_size) # left, right, top, bottom
67 | else:
68 | # If patch_half_size is a tuple, then we'll assume it's in the format (height, width)
69 | pad = (patch_half_size[1], patch_half_size[1], patch_half_size[0], patch_half_size[0]) # left, right, top, bottom
70 |
71 | # Apply padding
72 | image_pad = F.pad(img, pad, mode='constant', value=float('nan'))
73 |
74 | nrows = img.shape[0]
75 | ncols = img.shape[1]
76 | patcher = GetPatches2D(patch_size,image_pad)
77 | patches = patcher(img)
78 | mus = torch.nanmean(patches, dim=0)
79 | stds = self.nanstd(patches, dim=0)
80 | with np.errstate(divide='ignore', invalid='ignore'):
81 | im_norm = (img - mus.reshape(nrows, ncols)) / stds.reshape(nrows, ncols)
82 |
83 | im_norm[torch.isnan(im_norm)] = 0.0
84 | im_norm[im_norm < -1.0] = -1.0
85 | im_norm[im_norm > 1.0] = 1.0
86 |
87 | return im_norm
88 |
89 | class SetImageAsSpikes:
90 | def __init__(self, intensity=255, test=True):
91 | self.intensity = intensity
92 |
93 | # Setup QAT FakeQuantize for the activations (your spikes)
94 | self.fake_quantize = torch.quantization.FakeQuantize(
95 | observer=torch.quantization.MovingAverageMinMaxObserver,
96 | quant_min=0,
97 | quant_max=255,
98 | dtype=torch.quint8,
99 | qscheme=torch.per_tensor_affine,
100 | reduce_range=False
101 | )
102 |
103 | def train(self):
104 | self.fake_quantize.train()
105 |
106 | def eval(self):
107 | self.fake_quantize.eval()
108 |
109 | def __call__(self, img_tensor):
110 | N, W, H = img_tensor.shape
111 | reshaped_batch = img_tensor.view(N, 1, -1)
112 |
113 | # Divide all pixel values by 255
114 | normalized_batch = reshaped_batch / self.intensity
115 | normalized_batch = torch.squeeze(normalized_batch, 0)
116 |
117 | # Apply FakeQuantize
118 | spikes = self.fake_quantize(normalized_batch)
119 |
120 | if not self.fake_quantize.training:
121 | scale, zero_point = self.fake_quantize.calculate_qparams()
122 | spikes = torch.quantize_per_tensor(spikes, float(scale), int(zero_point), dtype=torch.quint8)
123 |
124 | return spikes
125 |
126 | class ProcessImage:
127 | def __init__(self, dims, patches):
128 | self.dims = dims
129 | self.patches = patches
130 |
131 | def __call__(self, img):
132 | # Convert the image to grayscale using the standard weights for RGB channels
133 | if img.shape[0] == 3:
134 | img = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
135 | # Add a channel dimension to the resulting grayscale image
136 | img= img.unsqueeze(0)
137 |
138 | # gamma correction
139 | mid = 0.5
140 | mean = torch.mean(img.float())
141 | gamma = math.log(mid * 255) / math.log(mean)
142 | img = torch.pow(img, gamma).clip(0, 255)
143 |
144 | # resize and patch normalize
145 | if len(img.shape) == 3:
146 | img = img.unsqueeze(0)
147 | img = F.interpolate(img, size=self.dims, mode='bilinear', align_corners=False)
148 | img = img.squeeze(0)
149 | patch_normaliser = PatchNormalisePad(self.patches)
150 | im_norm = patch_normaliser(img)
151 | img = (255.0 * (1 + im_norm) / 2.0).to(dtype=torch.uint8)
152 | img = torch.unsqueeze(img,0)
153 | spike_maker = SetImageAsSpikes()
154 | img = spike_maker(img)
155 | img = torch.squeeze(img,0)
156 |
157 | return img
158 |
159 | class CustomImageDataset(Dataset):
160 | def __init__(self, annotations_file, base_dir, img_dirs, transform=None, target_transform=None,
161 | filter=1, skip=0, max_samples=None, test=True, img_range=None):
162 | self.transform = transform
163 | self.target_transform = target_transform
164 | self.filter = filter
165 | self.img_range = img_range
166 | self.skip = skip
167 |
168 | # Load image labels from each directory, apply the skip and max_samples, and concatenate
169 | self.img_labels = []
170 | # if annotations_file is a single file, convert it to a list
171 | if not isinstance(annotations_file, list):
172 | annotations_file = [annotations_file]
173 | for idx, annotation in enumerate(annotations_file):
174 | img_labels = pd.read_csv(annotation)
175 | img_labels['file_path'] = img_labels.apply(lambda row: os.path.join(base_dir,img_dirs[idx], row.iloc[0]), axis=1)
176 | if self.img_range is not None:
177 | img_labels = img_labels.iloc[self.img_range[0]:self.img_range[1]+1]
178 | # Apply skip: start after the first 'skip' number of rows
179 | if self.skip > 0:
180 | img_labels = img_labels.iloc[self.skip:]
181 | # Select specific rows based on the skip parameter
182 | img_labels = img_labels.iloc[::filter]
183 | # Limit the number of samples to max_samples if specified
184 | if max_samples is not None:
185 | img_labels = img_labels.iloc[:max_samples]
186 | # Determine if the images being fed are training or testing
187 | if test:
188 | self.img_labels = img_labels
189 | else:
190 | self.img_labels.append(img_labels)
191 |
192 | if isinstance(self.img_labels,list):
193 | # Concatenate all the DataFrames
194 | self.img_labels = pd.concat(self.img_labels, ignore_index=True)
195 |
196 | def __len__(self):
197 | return len(self.img_labels)
198 |
199 | def __getitem__(self, idx):
200 | img_path = self.img_labels.iloc[idx]['file_path']
201 | if not os.path.exists(img_path):
202 | raise FileNotFoundError(f"No file found for index {idx} at {img_path}.")
203 |
204 | image = read_image(img_path)
205 | label = self.img_labels.iloc[idx, 1]
206 |
207 | if self.transform:
208 | image = self.transform(image)
209 | if self.target_transform:
210 | label = self.target_transform(label)
211 |
212 | return image, label
--------------------------------------------------------------------------------
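A usage sketch tying ProcessImage and CustomImageDataset together with a DataLoader, mirroring the training code earlier in the repository; the paths, image dimensions, patch size, filter and range below are example values and assume the Nordland spring images have already been downloaded:

from torch.utils.data import DataLoader
from torchvision import transforms
from vprtempo.src.dataset import CustomImageDataset, ProcessImage

# ProcessImage: grayscale, gamma correction, resize to dims, patch normalisation, spike encoding
image_transform = transforms.Compose([
    ProcessImage(dims=(56, 56), patches=15)          # example resize target and patch size
])

dataset = CustomImageDataset(
    annotations_file='./vprtempo/dataset/nordland-spring.csv',
    base_dir='./vprtempo/dataset',
    img_dirs=['spring'],
    transform=image_transform,
    filter=8,                                        # keep every 8th image (example)
    skip=0,
    test=True,
    img_range=[0, 799],                              # raw index range before filtering (example)
    max_samples=100)

loader = DataLoader(dataset, batch_size=1, shuffle=False)
image, label = next(iter(loader))                    # flattened, patch-normalised spike tensor and place index
print(image.shape, label)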
/vprtempo/src/download.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import zipfile
3 | import os
4 | from tqdm import tqdm
5 |
6 | def get_data_model():
7 | print('==== Downloading pre-trained models & Nordland images ====')
8 | # Download the pre-trained models
9 | dropbox_urls = [
10 | "https://www.dropbox.com/scl/fi/vb8ljhp5rm1cx4tbjfxx3/VPRTempo_pretrained_models.zip?rlkey=felsqy3qbapeeztgkfcdd1zix&st=xncoy7rg&dl=0",
11 | "https://www.dropbox.com/scl/fi/445psdi7srbhuqa807lyn/spring.zip?rlkey=5ciswaz0ygv107e6pzvxnm6ga&st=tdaslyuc&dl=0",
12 | "https://www.dropbox.com/scl/fi/l2dmccham4ifj0xf9p9jw/fall.zip?rlkey=gvmt5jvzdfw8p7008yfoxeb4s&st=z14ngqyx&dl=0",
13 | "https://www.dropbox.com/scl/fi/8ff3ozh6kujbg1vbnrasw/summer.zip?rlkey=563t03cd2vwfr32i9llg945m8&st=t96te8py&dl=0"
14 | ]
15 |
16 | folders = [
17 | "./vprtempo/models/",
18 | "./vprtempo/dataset/",
19 | "./vprtempo/dataset/",
20 | "./vprtempo/dataset/"
21 | ]
22 |
23 | names = [
24 | "VPRTempo_pretrained_models.zip",
25 | "spring.zip",
26 | "fall.zip",
27 | "summer.zip"
28 | ]
29 |
30 | for idx, url in enumerate(dropbox_urls):
31 | download_extract(url, folders[idx], names[idx])
32 |
33 | print('==== Downloading pre-trained models & Nordland images completed ====')
34 |
35 | def download_extract(url, folder, name):
36 | # Ensure the destination folder exists
37 | os.makedirs(folder, exist_ok=True)
38 |
39 | # Modify the URL for direct download
40 | direct_download_url = url.replace("dl=0", "dl=1")
41 |
42 | # Send a HEAD request to get the total file size
43 | with requests.head(direct_download_url, allow_redirects=True) as head:
44 | if head.status_code != 200:
45 | print(f"Failed to retrieve header for {name}. Status code: {head.status_code}")
46 | return
47 | total_size = int(head.headers.get('content-length', 0))
48 |
49 | # Initialize the progress bar for downloading
50 | with requests.get(direct_download_url, stream=True) as response, \
51 | open(os.path.join(folder, name), "wb") as file, \
52 | tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {name}", ncols=80) as progress_bar:
53 |
54 | if response.status_code != 200:
55 | print(f"Failed to download {name}. Status code: {response.status_code}")
56 | return
57 |
58 | for chunk in response.iter_content(chunk_size=8192):
59 | if chunk:
60 | file.write(chunk)
61 | progress_bar.update(len(chunk))
62 |
63 | # Determine extraction path
64 | if name == "VPRTempo_pretrained_models.zip":
65 | extract_path = folder
66 | else:
67 | extract_path = os.path.join(folder, name.replace('.zip', ''))
68 |
69 | # Open the zip file
70 | with zipfile.ZipFile(os.path.join(folder, name), 'r') as zip_ref:
71 | # Get list of files in the zip
72 | members = zip_ref.namelist()
73 | # Initialize the progress bar for extraction
74 | with tqdm(total=len(members), desc=f"Extracting {name}", unit='file', ncols=80) as extract_bar:
75 | for member in members:
76 | zip_ref.extract(member, path=extract_path)
77 | extract_bar.update(1)
78 |
79 | # Remove the zip file after extraction
80 | os.remove(os.path.join(folder, name))
81 | print(f"Completed {name}")
--------------------------------------------------------------------------------
/vprtempo/src/loggers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 |
5 | from datetime import datetime
6 |
7 | def model_logger():
8 | """
9 | Configure the model logger
10 | """
11 | now = datetime.now()
12 | output_base_folder = './vprtempo/output/'
13 | output_folder = output_base_folder + now.strftime("%d%m%y-%H-%M-%S")
14 |
15 | # Create the base output folder if it does not exist
16 | os.makedirs(output_base_folder, exist_ok=True)
17 |
18 | # Create the specific output folder
19 | os.mkdir(output_folder)
20 | # Create the logger
21 | logger = logging.getLogger("VPRTempo")
22 | if (logger.hasHandlers()):
23 | logger.handlers.clear()
24 | # Set the logger level
25 | logger.setLevel(logging.DEBUG)
26 | logging.basicConfig(filename=output_folder + "/logfile.log",
27 | filemode="a+",
28 | format="%(asctime)-15s %(levelname)-8s %(message)s")
29 | # Add the logger to the console (if specified)
30 | logger.addHandler(logging.StreamHandler())
31 |
32 | logger.info('')
33 | logger.info('██╗ ██╗██████╗ ██████╗ ████████╗███████╗███╗ ███╗██████╗ ██████╗')
34 | logger.info('██║ ██║██╔══██╗██╔══██╗╚══██╔══╝██╔════╝████╗ ████║██╔══██╗██╔═══██╗')
35 | logger.info('██║ ██║██████╔╝██████╔╝ ██║ █████╗ ██╔████╔██║██████╔╝██║ ██║')
36 | logger.info('╚██╗ ██╔╝██╔═══╝ ██╔══██╗ ██║ ██╔══╝ ██║╚██╔╝██║██╔═══╝ ██║ ██║')
37 | logger.info(' ╚████╔╝ ██║ ██║ ██║ ██║ ███████╗██║ ╚═╝ ██║██║ ╚██████╔╝')
38 | logger.info(' ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ')
39 | logger.info('-----------------------------------------------------------------------')
40 | logger.info('Temporally Encoded Spiking Neural Network for Visual Place Recognition v1.1.9')
41 | logger.info('Queensland University of Technology, Centre for Robotics')
42 | logger.info('')
43 | logger.info('© 2023 Adam D Hines, Peter G Stratton, Michael Milford, Tobias Fischer')
44 | logger.info('MIT license - https://github.com/QVPR/VPRTempo')
45 | logger.info('\\\\\\\\\\\\\\\\\\\\\\\\')
46 | logger.info('')
47 | if torch.cuda.is_available():
48 | logger.info('CUDA available: ' + str(torch.cuda.is_available()))
49 | current_device = torch.cuda.current_device()
50 | logger.info('Current device is: ' + str(torch.cuda.get_device_name(current_device)))
51 | elif torch.backends.mps.is_available():
52 | logger.info('MPS available: ' + str(torch.backends.mps.is_available()))
53 | logger.info('Current device is: MPS')
54 | else:
55 | logger.info('CUDA available: ' + str(torch.cuda.is_available()))
56 | logger.info('Current device is: CPU')
57 | logger.info('')
58 |
59 | return logger, output_folder
60 |
61 | def model_logger_quant():
62 | """
63 | Configure the logger
64 | """
65 |
66 | now = datetime.now()
67 | output_base_folder = './vprtempo/output/'
68 | output_folder = output_base_folder + now.strftime("%d%m%y-%H-%M-%S")
69 |
70 | # Create the base output folder if it does not exist
71 | os.makedirs(output_base_folder, exist_ok=True)
72 |
73 | # Create the specific output folder
74 | os.mkdir(output_folder)
75 | # Create the logger
76 | logger = logging.getLogger("VPRTempo")
77 | if (logger.hasHandlers()):
78 | logger.handlers.clear()
79 | # Set the logger level
80 | logger.setLevel(logging.DEBUG)
81 | logging.basicConfig(filename=output_folder + "/logfile.log",
82 | filemode="a+",
83 | format="%(asctime)-15s %(levelname)-8s %(message)s")
84 | # Add the logger to the console (if specified)
85 | logger.addHandler(logging.StreamHandler())
86 |
87 | logger.info('')
88 |
89 | logger.info('██╗ ██╗██████╗ ██████╗ ████████╗███████╗███╗ ███╗██████╗ ██████╗ ██████╗ ██╗ ██╗ █████╗ ███╗ ██╗████████╗')
90 | logger.info('██║ ██║██╔══██╗██╔══██╗╚══██╔══╝██╔════╝████╗ ████║██╔══██╗██╔═══██╗ ██╔═══██╗██║ ██║██╔══██╗████╗ ██║╚══██╔══╝')
91 | logger.info('██║ ██║██████╔╝██████╔╝ ██║ █████╗ ██╔████╔██║██████╔╝██║ ██║█████╗██║ ██║██║ ██║███████║██╔██╗ ██║ ██║')
92 | logger.info('╚██╗ ██╔╝██╔═══╝ ██╔══██╗ ██║ ██╔══╝ ██║╚██╔╝██║██╔═══╝ ██║ ██║╚════╝██║▄▄ ██║██║ ██║██╔══██║██║╚██╗██║ ██║')
93 | logger.info(' ╚████╔╝ ██║ ██║ ██║ ██║ ███████╗██║ ╚═╝ ██║██║ ╚██████╔╝ ╚██████╔╝╚██████╔╝██║ ██║██║ ╚████║ ██║')
94 | logger.info(' ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ╚══▀▀═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝ ╚═╝')
95 | logger.info('-----------------------------------------------------------------------')
96 |     logger.info('Temporally Encoded Spiking Neural Network for Visual Place Recognition v1.1.9')
97 | logger.info('Queensland University of Technology, Centre for Robotics')
98 | logger.info('')
99 | logger.info('© 2023 Adam D Hines, Peter G Stratton, Michael Milford, Tobias Fischer')
100 | logger.info('MIT license - https://github.com/QVPR/VPRTempo')
101 | logger.info('\\\\\\\\\\\\\\\\\\\\\\\\')
102 | logger.info('')
103 | logger.info('Quantization enabled')
104 | logger.info('Current device is: CPU')
105 | logger.info('')
106 |
107 | return logger, output_folder
108 |
--------------------------------------------------------------------------------
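A minimal sketch of how the logger is obtained elsewhere in the repository, assuming it is run from the repository root so that ./vprtempo/output/ can be created:

from vprtempo.src.loggers import model_logger

logger, output_folder = model_logger()               # creates ./vprtempo/output/<timestamp>/logfile.log
logger.info('Results will be written to: ' + output_folder)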
/vprtempo/src/metrics.py:
--------------------------------------------------------------------------------
1 | # =====================================================================
2 | # Copyright (C) 2023 Stefan Schubert, stefan.schubert@etit.tu-chemnitz.de
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | # =====================================================================
17 | #
18 | import numpy as np
19 |
20 |
21 | def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
22 | """
23 | Calculates the precision and recall at n_thresh equally spaced threshold values
24 | for a given similarity matrix S_in and ground truth matrices GThard and GTsoft for
25 | single-best-match VPR or multi-match VPR.
26 |
27 | The matrices S_in, GThard and GTsoft are two-dimensional and should all have the
28 | same shape.
29 | The matrices GThard and GTsoft should be binary matrices, where the entries are
30 | only zeros or ones.
31 | The matrix S_in should have continuous values between -Inf and Inf. Higher values
32 | indicate higher similarity.
33 | The string matching should be set to either "single" or "multi" for single-best-
34 | match VPR or multi-match VPR.
35 | The integer n_thresh controls the number of threshold values and should be >1.
36 | """
37 |
38 | assert (S_in.shape == GThard.shape),"S_in, GThard and GTsoft must have the same shape"
39 | assert (S_in.ndim == 2),"S_in, GThard and GTsoft must be two-dimensional"
40 | assert (matching in ['single', 'multi']),"matching should contain one of the following strings: [single, multi]"
41 | assert (n_thresh > 1),"n_thresh must be >1"
42 |
43 | if GTsoft is not None and matching == 'single':
44 | raise ValueError(
45 | "GTSoft with single matching is not supported. "
46 | "Please use dilated hard ground truth directly. "
47 | "For more details, visit: https://github.com/stschubert/VPR_Tutorial"
48 | )
49 |
50 | # ensure logical datatype in GT and GTsoft
51 | GT = GThard.astype('bool')
52 | if GTsoft is not None:
53 | GTsoft = GTsoft.astype('bool')
54 | GThard_orig = GThard.copy()
55 |
56 | # copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation
57 | S = S_in.copy()
58 |
59 | if GTsoft is not None:
60 | S[GTsoft & ~GT] = S.min()
61 |
62 | if matching == 'single':
63 | # GT-values for best match per query (i.e., per column)
64 | GT = GT[np.argmax(S, axis=0), np.arange(GT.shape[1])]
65 | # similarities for best match per query (i.e., per column)
66 | S = np.max(S, axis=0)
67 |
68 | # init precision and recall vectors
69 | R = [0, ]
70 | P = [1, ]
71 |
72 |     # select start and end threshold
73 |     startV = S.max()  # start-value for threshold
74 |     endV = S.min()  # end-value for threshold
75 | thresholds = np.linspace(startV, endV, n_thresh)
76 |
77 |     # Iterate over the threshold values from highest to lowest
78 | for i in thresholds:
79 | B = S >= i # Apply threshold
80 |
81 | TP = np.count_nonzero(GT & B) # True Positives
82 | FP = np.count_nonzero((~GT) & B) # False Positives
83 | FN = np.count_nonzero(GT & (~B)) # False Negatives
84 |
85 |         # Guard against division by zero for precision and recall
86 |         precision = TP / (TP + FP) if (TP + FP) > 0 else 1.0
87 |         recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
88 |
89 | P.append(precision) # Precision
90 | R.append(recall) # Recall
91 |
92 | return P, R
93 |
94 |
95 | def recallAt100precision(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
96 | """
97 | Calculates the maximum recall at 100% precision for a given similarity matrix S_in
98 | and ground truth matrices GThard and GTsoft for single-best-match VPR or multi-match
99 | VPR.
100 |
101 | The matrices S_in, GThard and GTsoft are two-dimensional and should all have the
102 | same shape.
103 | The matrices GThard and GTsoft should be binary matrices, where the entries are
104 | only zeros or ones.
105 | The matrix S_in should have continuous values between -Inf and Inf. Higher values
106 | indicate higher similarity.
107 | The string matching should be set to either "single" or "multi" for single-best-
108 | match VPR or multi-match VPR.
109 | The integer n_thresh controls the number of threshold values during the creation of
110 | the precision-recall curve and should be >1.
111 | """
112 |
113 | assert (S_in.shape == GThard.shape),"S_in and GThard must have the same shape"
114 | if GTsoft is not None:
115 | assert (S_in.shape == GTsoft.shape),"S_in and GTsoft must have the same shape"
116 | assert (S_in.ndim == 2),"S_in, GThard and GTsoft must be two-dimensional"
117 | assert (matching in ['single', 'multi']),"matching should contain one of the following strings: [single, multi]"
118 | assert (n_thresh > 1),"n_thresh must be >1"
119 |
120 | # get precision-recall curve
121 | P, R = createPR(S_in, GThard, GTsoft, matching=matching, n_thresh=n_thresh)
122 | P = np.array(P)
123 | R = np.array(R)
124 |
125 | # recall values at 100% precision
126 | R = R[P==1]
127 |
128 | # maximum recall at 100% precision
129 | R = R.max()
130 |
131 | return R
132 |
133 |
134 | def recallAtK(S, GT, K=1):
135 | """
136 | Calculates the recall@K for a given similarity matrix S and ground truth matrix GT.
137 | Note that this method does not support GTsoft - instead, please directly provide
138 | the dilated ground truth matrix as GT.
139 |
140 | The matrices S and GT are two-dimensional and should all have the same shape.
141 | The matrix GT should be binary, where the entries are only zeros or ones.
142 | The matrix S should have continuous values between -Inf and Inf. Higher values
143 | indicate higher similarity.
144 | The integer K>=1 defines the number of matching candidates that are selected and
145 | that must contain an actually matching image pair.
146 | """
147 |
148 | assert (S.shape == GT.shape),"S and GT must have the same shape"
149 | assert (S.ndim == 2),"S and GT must be two-dimensional"
150 | assert (K >= 1),"K must be >=1"
151 |
152 | # ensure logical datatype in GT
153 | GT = GT.astype('bool')
154 |
155 | # discard all query images without an actually matching database image
156 | j = GT.sum(0) > 0 # columns with matches
157 | S = S[:,j] # select columns with a match
158 | GT = GT[:,j] # select columns with a match
159 |
160 | # select K highest similarities
161 | i = S.argsort(0)[-K:,:]
162 | j = np.tile(np.arange(i.shape[1]), [K, 1])
163 | GT = GT[i, j]
164 |
165 | # recall@K
166 | RatK = np.sum(GT.sum(0) > 0) / GT.shape[1]
167 |
168 | return RatK
--------------------------------------------------------------------------------
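A toy, self-contained example (synthetic similarity matrix and ground truth, not repository data) exercising createPR, recallAt100precision and recallAtK; the import path assumes the repository root is on the Python path:

import numpy as np
from vprtempo.src.metrics import createPR, recallAt100precision, recallAtK

rng = np.random.default_rng(0)
n = 50
S = rng.random((n, n))                    # similarity matrix: higher = more similar
S[np.arange(n), np.arange(n)] += 1.0      # make the true matches score highest
GT = np.eye(n, dtype=bool)                # ground truth: query i matches reference i

P, R = createPR(S, GT, matching='multi', n_thresh=100)
print(f"Recall at 100% precision: {recallAt100precision(S, GT, matching='multi'):.3f}")
print(f"Recall@1: {recallAtK(S, GT, K=1):.3f}")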
/vprtempo/src/nordland.py:
--------------------------------------------------------------------------------
1 | '''
2 | Imports
3 | '''
4 | import os
5 | import re
6 | import shutil
7 | import zipfile
8 | import sys
9 | sys.path.append('../dataset')
10 |
11 | from os import walk
12 |
13 | def nord_sort():
14 | # load and sort the file names in order, not by how OS indexes them
15 | def atoi(text):
16 | return int(text) if text.isdigit() else text
17 |
18 | def natural_keys(text):
19 | return [ atoi(c) for c in re.split(r'(\d+)', text) ]
20 |
21 | # set the base path to the location of the downloaded Nordland datasets
22 | basePath = '../dataset/'
23 | assert(os.path.isdir(basePath)),"Please set the basePath to the location of the downloaded Nordland datasets"
24 |
25 | # define the subfolders of the Nordland datasets
26 | subPath = ["spring_images_train/section1/","spring_images_train/section2/",
27 | "fall_images_train/section1/","fall_images_train/section2/",
28 | "winter_images_train/section1/","winter_images_train/section2/",
29 | "summer_images_train/section1/","summer_images_train/section2/"]
30 |
31 | # set the desired output folder for unzipping and organization
32 | outDir = '../dataset/'
33 | assert(os.path.isdir(outDir)),"Please set the outDir to the desired output location for unzipping the Nordland datasets"
34 |
35 | # define output paths for the data
36 | outPath = [os.path.join(outDir,"spring/"),os.path.join(outDir,"fall/"),
37 | os.path.join(outDir,"winter/"),os.path.join(outDir,"summer/")]
38 |
39 | # check for existence of the zip folders, throw exception if missing
40 | zipNames = ["spring_images_train.zip","fall_images_train.zip",
41 | "winter_images_train.zip","summer_images_train.zip"]
42 | for n in zipNames:
43 | if not os.path.exists(basePath+n):
44 | raise Exception('Please ensure dataset .zip folders have been downloaded')
45 |
46 | # check if nordland data folders have already been unzipped
47 | zip_flag = []
48 | for n, ndx in enumerate(range(0,len(subPath),2)):
49 | print('Unzipping '+zipNames[n])
50 | if os.path.exists(basePath+subPath[ndx]):
51 | # check if the folder contains any files
52 | file_lst = os.listdir(basePath+subPath[ndx])
53 | # remove folder if it is empty and unzip the data folder
54 | if len(file_lst) == 0:
55 | shutil.rmtree(basePath+subPath[ndx].replace('section1/',''))
56 | with zipfile.ZipFile(basePath+zipNames[n],"r") as zip_ref:
57 | zip_ref.extractall(basePath)
58 | else:
59 | with zipfile.ZipFile(basePath+zipNames[n],"r") as zip_ref:
60 | zip_ref.extractall(basePath)
61 |
62 | # load image paths
63 | tempPaths = []
64 | imgPaths = []
65 | for n in range(0,len(subPath)):
66 | tempPaths = []
67 | for (path, dir_names, file_names) in walk(basePath+subPath[n]):
68 | tempPaths.extend(file_names)
69 | # sort image names
70 | tempPaths.sort(key=natural_keys)
71 | tempPaths = [basePath+subPath[n]+s for s in tempPaths]
72 | imgPaths = imgPaths + tempPaths
73 |
74 | # if output folders already exist, delete them
75 | for n in outPath:
76 | if os.path.exists(n):
77 | shutil.rmtree(n)
78 | print('Removed pre-existing output folder')
79 |
80 | # rename and move the training data to match the nordland_imageNames.txt file
81 | for n in outPath:
82 | os.mkdir(n)
83 | for n, filename in enumerate(imgPaths):
84 | base = os.path.basename(filename)
85 | split_base = os.path.splitext(base)
86 | if int(split_base[0]) < 10:
87 | my_dest = "images-0000"+split_base[0] + ".png"
88 | elif int(split_base[0]) < 100:
89 | my_dest = "images-000"+split_base[0] + ".png"
90 | elif int(split_base[0]) < 1000:
91 | my_dest = "images-00"+split_base[0] + ".png"
92 | elif int(split_base[0]) < 10000:
93 | my_dest = "images-0"+split_base[0] + ".png"
94 | else:
95 | my_dest = "images-"+split_base[0] + ".png"
96 | if "spring" in filename:
97 | out = outPath[0]
98 | elif "fall" in filename:
99 | out = outPath[1]
100 | elif "winter" in filename:
101 | out = outPath[2]
102 | else:
103 | out = outPath[-1]
104 |
105 | fileDest = out + my_dest
106 | os.rename(filename, fileDest)
107 |
108 | # remove the empty folders
109 | for n, ndx in enumerate(subPath):
110 | if n%2 == 0:
111 | shutil.rmtree(basePath+ndx.replace('section1/',''))
112 | else:
113 | continue
114 |
115 | print('Finished unzipping and organizing Nordland dataset')
--------------------------------------------------------------------------------
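The nested atoi/natural_keys helpers above implement a natural sort so that frame numbers are ordered numerically rather than lexicographically; the standalone copy below only illustrates the difference:

import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [atoi(c) for c in re.split(r'(\d+)', text)]

names = ['images-10.png', 'images-2.png', 'images-1.png']
print(sorted(names))                    # ['images-1.png', 'images-10.png', 'images-2.png']
print(sorted(names, key=natural_keys))  # ['images-1.png', 'images-2.png', 'images-10.png']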
/vprtempo/src/process_orc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from tqdm import tqdm
4 | import image  # from the robotcar-dataset-sdk python tools
5 | from camera_model import CameraModel  # from the robotcar-dataset-sdk python tools
6 |
7 | # Set this to the directory containing the images
8 | base_path = '/media/adam/vprdatasets/data/orc'
9 | # Modify if evaluating different ORC datasets
10 | datasets = [
11 | '2015-08-12-15-04-18',
12 | '2014-11-21-16-07-03',
13 | '2015-10-29-12-18-17'
14 | ]
15 | # Set up output path names
16 | processed_path = 'demosaiced'
17 |
18 | for dataset in datasets:
19 | # Define the dataset path
20 | dataset_path = os.path.join(base_path, dataset)
21 |
22 | # Create the folders
23 | os.makedirs(os.path.join(dataset_path, processed_path), exist_ok=True)
24 |
25 | # file path for the robotcar-dataset-sdk models folder
26 | model_dir = '/home/adam/repo/robotcar-dataset-sdk/models/'
27 | # file path for the left stereo images
28 | images_path = os.path.join(dataset_path, 'stereo/left')
29 | # Create a camera model object
30 | model = CameraModel(model_dir,images_path)
31 | # Get sorted list of PNG images
32 | images = sorted([os.path.join(images_path, image) for image in os.listdir(images_path) if image.endswith('.png')])
33 |
34 | # Process each image with a progress bar
35 | for img in tqdm(images, desc="Processing images", unit="image"):
36 | # Load the image
37 | processed_img = image.load_image(img, model=model)
38 | # Create the output file path
39 | output_img = os.path.join(dataset_path, processed_path, os.path.basename(img))
40 | processed_img = Image.fromarray(processed_img, mode="RGB")
41 | # Save the processed image as PNG
42 | processed_img.save(output_img, "PNG")
43 |
44 | print("Images processed and saved successfully.")
--------------------------------------------------------------------------------