├── .gitignore
├── README.md
├── ckpts
│   ├── 110.pickle
│   └── 78.pickle
├── knot_theory.ipynb
├── requirements.txt
└── test_ckpt.py
/.gitignore:
--------------------------------------------------------------------------------
1 | knot_theory_invariants.csv
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CompactNet
2 | This repo is a slight modification of DeepMind's _Advancing mathematics by
3 | guiding human intuition with AI_
4 | - [Original Notebook](https://colab.research.google.com/github/deepmind/mathematics_conjectures/blob/main/knot_theory.ipynb)
5 | - [Original Repo](https://github.com/google-deepmind/mathematics_conjectures)
6 | - [Nature Paper](https://www.nature.com/articles/s41586-021-04086-x)
7 |
8 | We noticed some issues when the [KAN](https://arxiv.org/abs/2404.19756) paper
9 | cited this work and found that the comparisons had some errors.
10 | We found that we could match KAN's 81.6% accuracy on this dataset with as few as
11 | 122 parameters.
12 | We did not make any major modifications to the DeepMind code.
13 | To achieve this result we only decreased the network size, changed the random seed,
14 | and increased the training time.
15 | Keeping the same seed and the same training cutoff, we could get a
16 | matching result with a network with 204 parameters.
17 |
18 | The table below shows some of our results.
19 | There is some variance, so numbers may change slightly between runs.
20 | Over several runs, your results should be quite similar to ours.
21 | These results maintain the same random seed and the same training limit.
22 |
23 | | Network | Number of Hidden Neurons| Number of Parameters | Accuracy Pre Salient | Accuracy Post Salient |
24 | |:--------:|:---------------------:|:---------------------:|:----------------------:|:-----------------------:|
25 | | [300, 300, 300] | 900 | 190,214 | 81.38% | 80.14% |
26 | | [100, 100, 100] | 300 | 23,414 | 82.79% | 82.04% |
27 | | [50, 50, 50] | 150 | 6,714 | 85.13% | 81.65% |
28 | | [10, 10, 10] | 30 | 554 | 84.45% | 82.30% |
29 | | [5, 5, 5] | 15 | 234 | 83.06% | 80.42% |
30 | | [4, 4, 4] | 12 | 182 | 76.73% | 65.19% |
31 | | [3, 3, 3] | 9 | 134 | 66.33% | 74.93% |
32 | | [50, 50] | 100 | 4,164 | 87.15% | 82.65% |
33 | | [10, 10] | 20 | 444 | 83.02% | 81.50% |
34 | | [5, 5] | 10 | 204 | 82.19% | 81.33% |
35 | | [4, 4] | 8 | 162 | 81.89% | 81.03% |
36 | | [3, 3] | 6 | 122 | 77.72% | 76.24% |
37 | | Baseline (direct calculate) | 0 | 0 | 73.82% | 73.82% |
38 | | DeepMind's 4 layer reported | 900 | 190,214 | 78% | 78% |
39 | | KAN | N/A | 200 | 81.6% | 78.2% |
40 |
41 | Here are some results where we have changed the random seed and training length.
42 | We set `num_training_steps` to 50k for an arbitrarily long run and report how many steps the network trained before stopping early (`Steps`).
43 | | Network | Number of Hidden Neurons|Number of Parameters | seed | Accuracy Pre Salient | Steps | Accuracy Post Salient | Steps |
44 | |:--------:|:----:|:---------------------:|:----:|:---------------------:|:-----:|:----------------------:|:-----:|
45 | | [3,3] | 6 | 122 | 552 | 81.60% | 20700 | 81.69% | 22100 |
46 | | [2,2] | 4 | 84 | 8110 | 81.33% | 22700 | 80.44% | 23300 |
47 |
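The `Steps` columns above come from early stopping. Here is a minimal sketch of that rule, assuming the validation loss is checked every `validation_interval` steps (100 in the notebook) and training stops the first time it increases. The helper `first_stop_step` and the loss values are hypothetical and only illustrate the idea; they are not taken from the notebook.

```python
def first_stop_step(val_losses, validation_interval=100, max_steps=50_000):
    """Return the step at which training would stop.

    val_losses[i] is the validation loss measured at step
    i * validation_interval (hypothetical values, for illustration only).
    """
    best = float("inf")
    for i, loss in enumerate(val_losses):
        step = i * validation_interval
        if step >= max_steps:
            break
        if loss > best:
            # Validation loss went up compared to the best seen so far.
            return step
        best = loss
    # No increase observed: training runs for the full budget.
    return max_steps


print(first_stop_step([2.8, 2.4, 2.1, 2.2]))
```

Because the check only happens at validation intervals, the reported step counts are multiples of `validation_interval`.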
48 | We also have advanced methods to train an extremely small two-layer MLP in 1k steps that achieves more than 80% accuracy on average (over 10 runs). You may check the models in the folder `ckpts` for details. You can run inference with the following commands.
49 | ```bash
50 | python test_ckpt.py -hn 3
51 | python test_ckpt.py -hn 2
52 | ```
53 |
54 | | Network | Number of Neurons|Number of Parameters | Accuracy Pre Salient | Steps | Accuracy Post Salient | Steps |
55 | |:--------:|:---------------------:|:----:|:---------------------:|:-----:|:----------------------:|:-----:|
56 | | [3] | 3 | 110 | 82.42% | ~1000 | 80.33% | ~1000 |
57 | | [2] | 2 | 78 | 80.85% | ~1000 | 80.23% | ~1000 |
58 |
59 |
60 | ## Running
61 | ```
62 | pip install -r requirements.txt
63 | ```
64 | You will also need the dataset, which is downloaded with
65 | [gsutil](https://cloud.google.com/storage/docs/gsutil_install).
66 | If `gsutil` is installed, the notebook will automatically download the dataset for you.
67 | Make sure `gsutil` is in your path before opening the notebook or it may not be
68 | able to download it for you.
69 |
70 | The first code cell of the notebook contains the network definition (`hidden_layers`), where you define how
71 | many hidden neurons you want per hidden layer.
72 | The length of the list determines the number of hidden layers.
73 | For example `[2,2]` means two hidden layers with 2 neurons each.
74 |
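The parameter counts in the tables above follow directly from the layer sizes. Here is a minimal sketch, assuming the 17 input features and 14 output classes used in the notebook (with `use_fix` enabled). `count_params` is a hypothetical helper and is not part of this repo.

```python
def count_params(hidden_layers, n_inputs=17, n_outputs=14):
    """Count the weights and biases of a fully connected MLP."""
    sizes = [n_inputs, *hidden_layers, n_outputs]
    # Each linear layer has fan_in * fan_out weights plus fan_out biases.
    return sum(fan_in * fan_out + fan_out
               for fan_in, fan_out in zip(sizes, sizes[1:]))


for layers in ([300, 300, 300], [5, 5], [3, 3], [2, 2], [3], [2]):
    print(layers, count_params(layers))
```

For example, `[5, 5]` gives `17*5 + 5 = 90`, `5*5 + 5 = 30`, and `5*14 + 14 = 84`, for a total of 204 parameters.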
--------------------------------------------------------------------------------
/ckpts/110.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/CompactNet/cba12b447a654b154051eeb63caff6088c022a15/ckpts/110.pickle
--------------------------------------------------------------------------------
/ckpts/78.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/CompactNet/cba12b447a654b154051eeb63caff6088c022a15/ckpts/78.pickle
--------------------------------------------------------------------------------
/knot_theory.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "n-vSyzB-qmir"
7 | },
8 | "source": [
9 | "\n",
10 | "Copyright 2021 DeepMind Technologies Limited\n",
11 | "\n",
12 | "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
13 | "\n",
14 | "https://www.apache.org/licenses/LICENSE-2.0\n",
15 | "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "# Reimplementation/Validation of DeepMind Colab\n",
23 | "- [Original Notebook](https://colab.research.google.com/github/deepmind/mathematics_conjectures/blob/main/knot_theory.ipynb)\n",
24 | "- [Original Repo](https://github.com/google-deepmind/mathematics_conjectures)\n",
25 | "- [Nature Paper](https://www.nature.com/articles/s41586-021-04086-x)\n",
26 | "\n",
27 | "The code is nearly identical to the original notebook which you can compare with the link above. We've only made some minor changes (the rest is unmodified)\n",
28 | "- We [rewrote the network](#Network-Definition-\\(modified\\)) to be defined with a loop so you can use the array on the next line to define the architecture\n",
29 | "- We noticed an off-by-one error on the last layer. Original has an output of 13 and it should be 14 (even). In most cases this seems not to make a big difference.\n",
30 | "\n",
31 | "We noticed that we could shrink DeepMind's network and still get >80% accuracy. We find that you can reliably get networks with <200 parameters to classify with >80% accuracy. We've included some of our results [here](#Our-Results). If you find better results, let us know! :)\n",
32 | " \n",
33 | "# How To Use?\n",
34 | "1) First, make sure that you install the requirements (if you haven't, create a new cell and run `!pip install -r requirements.txt`)\n",
35 | "2) Download [gsutil](https://cloud.google.com/storage/docs/gsutil_install) (this is called [in this cell](#Run-me-first-to-get-data)). Make sure that `gsutil` is in your path (the installation should handle this, but you may need to close Jupyter and reopen it in a new terminal). We will download the dataset for you if we don't find it.\n",
36 | "\n",
37 | "For quick validation you can just hit run and you will get an output from a 4 layer network (2 hidden layers) where each hidden layer has 5 neurons. Your result should match the shown outputs of this notebook (within small variance)\n",
38 | "\n",
39 | "To define a different MLP structure edit `hidden_layers` on the cell below. The numbers represent the hidden neurons per hidden layer. This is arbitrary so `[4,5,6]` would define a 5 layer network (3 hidden layers) where there are 4 neurons in the first hidden layer, 5 in the second hidden layer, and 6 in the third hidden layer. The input layer is always 17 neurons and the output is 14 (when using `use_fix`). You can see the summary of your network in the output just above [this cell](#Network-Setup)\n",
40 | "\n",
41 | "When using smaller networks you may want to edit some of the user defined parameters, specifically `num_training_steps` (we suggest the commented-out `50_000`). This defines the *maximum* number of training steps that will be used. The network will stop early if the validation loss increases.\n",
42 | "\n",
43 | "We also suggest trying different random seeds. There is some variance to the results so this may be good to check. You can find the user defined variables [here](#User-Defined-Parameters)\n",
44 | "\n",
45 | "### What We Don't Do\n",
46 | "We didn't want to stray from DeepMind's code much so we don't implement any different training techniques or anything else. It is likely that results could be improved with changes such as different optimizers, learning rates (learning schedulers), or any number of modern DL training techniques. "
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 1,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "# Length of array = number of hidden layers\n",
56 | "# Array entries = number of hidden neurons per hidden layer\n",
57 | "# e.g. [5,5] = 4 layer network with [17, 5, 5, 14] (17 input neurons, 2 hidden layers w/ 5 hidden neurons, 14 output neurons)\n",
58 | "hidden_layers = [5,5]\n",
59 | "# Set True to use the off-by-one error fix\n",
60 | "use_fix = True"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "# Our Results\n",
68 | "Numbers may vary. With smaller networks the network will likely stop training at the max number of training steps (10k in original)\n",
69 | "All results in this table are with identical settings except changing the hidden layer.\n",
70 | "(Your numbers may vary slightly, especially in smaller networks. See tips below)\n",
71 | "| Network | Number of Parameters | Accuracy Pre Salient | Accuracy Post Salient |\n",
72 | "|:--------|---------------------:|---------------------:|----------------------:|\n",
73 | "|[300,300,300] | 190,214 | 0.8137779 | 0.8013719 |\n",
74 | "|[100,100,100] | 23,414 | 0.82789063 | 0.8204076 |\n",
75 | "|[50,50,50] | 6,714 | 0.8512915 | 0.8164692 |\n",
76 | "|[10,10,10] | 554 | 0.8444977 | 0.82296765 |\n",
77 | "|[5,5,5] | 234 | 0.83061475 | 0.80422723 |\n",
78 | "|[4,4,4] | 182 | 0.7673045 | 0.6519413 |\n",
79 | "|[3,3,3] | 134 | 0.66332996 | 0.74925333 | \n",
80 | "| - | - | - | - |\n",
81 | "|[50,50] | 4,164 | 0.87154156 | 0.8265122 |\n",
82 | "|[10,10] | 444 | 0.8302209 | 0.8149923 |\n",
83 | "|[5,5] | 204 | 0.82191736 | 0.8132528 |\n",
84 | "|[4,4] | 162 | 0.81886506 | 0.810299 |\n",
85 | "|[3,3] | 122 | 0.77718335 | 0.76238143 |\n",
86 | "\n",
87 | "Here are some results where we have changed the random seed and training length. We set `num_training_steps` to 50k for an arbitrarily long run and report how many steps the network trained before stopping early (`Steps`)\n",
88 | "| Network | Number of Parameters | seed | Accuracy Pre Salient | Steps | Accuracy Post Salient | Steps |\n",
89 | "|:--------|---------------------:|:----:|---------------------:|:-----:|----------------------:|:-----:|\n",
90 | "| [3,3] | 122 | 552 | 0.8160097 | 20700 | 0.81686306 | 22100 |\n",
91 | "| [2,2] | 84 | 8110 | 0.81328565 | 22700 | 0.8043914 | 23300 |\n",
92 | " "
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "# User Defined Parameters"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "We've moved user defined parameters here to make this easier to run\n",
107 | "#### Tips\n",
108 | "- Change random seed\n",
109 | "- Increase `num_training_steps`\n",
110 | "\n",
111 | "Most runs smaller than `[50,50,50]` will likely run the full `num_training_steps` so it is a good idea to increase this value to see the actual potential of the network. Larger networks will likely early stop. This can also result in a small network getting a very different result when the random seed is different.\n",
112 | "\n",
113 | "Defaults\n",
114 | "| Variable | Value | Description |\n",
115 | "|:---------|-------:|-----------:|\n",
116 | "| `random_seed` | 2 | Random seed for notebook |\n",
117 | "| `num_training_steps` | 10_000 | Maximum number of training steps (will early stop) |\n",
118 | "| `network_weight_rng` | 1 | [Seed for generating network's initial random weights](#Network-Setup) |\n",
119 | "| `batch_size` | 64 | Training Batch size |\n",
120 | "| `learning_rate` | 0.001 | Training Learning Rate |\n",
121 | "| `validation_interval` | 100 | Frequency to check against validation split |\n",
122 | "\n",
123 | "`network_weight_rng` is used [here](#Network-Setup)\n"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 2,
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "name": "stdout",
133 | "output_type": "stream",
134 | "text": [
135 | "This run was performed with random_seed=2\n"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "# These are default values provided by the original colab\n",
141 | "import random\n",
142 | "#random_seed = random.randint(0,10000)\n",
143 | "random_seed = 2 # @param {type: \"integer\"}\n",
144 | "print(f\"This run was performed with {random_seed=}\")\n",
145 | "#\n",
146 | "# *maximum* number of training steps (will early stop!)\n",
147 | "num_training_steps = 10_000 # @param {type: \"integer\"}\n",
148 | "### Suggest using this, especially if smaller network\n",
149 | "#num_training_steps = 50_000 # @param {type: \"integer\"}\n",
150 | "#\n",
151 | "# Random number for initializing the network weights\n",
152 | "network_weight_rng = 1 # @param {type: \"integer\"}\n",
153 | "#\n",
154 | "#\n",
155 | "batch_size = 64 # @param {type: \"integer\"}\n",
156 | "learning_rate = 0.001 # @param {type: \"float\"}\n",
157 | "validation_interval = 100 # @param {type: \"integer\"}\n",
158 | "\n"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {},
164 | "source": [
165 | "# Run me first to get data\n",
166 | "\n",
167 | "You may need to download [gsutil](https://cloud.google.com/storage/docs/gsutil_install)\n",
168 | "\n",
169 | "After doing so, close Jupyter and open it back up in a new terminal, making sure the binary is in your path."
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 3,
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "# @title Download data\n",
179 | "import pandas as pd\n",
180 | "\n",
181 | "#_, input_filename = tempfile.mkstemp()\n",
182 | "input_filename = \"./knot_theory_invariants.csv\"\n",
183 | "# USE YOUR OWN gsutil cp COMMAND HERE\n",
184 | "!(if [[ ! -a {input_filename} ]]; then \\\n",
185 | " gsutil cp \"gs://maths_conjectures/knot_theory/knot_theory_invariants.csv\" {input_filename}; \\\n",
186 | "fi)\n",
187 | "\n",
188 | "full_df = pd.read_csv(input_filename)"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {},
194 | "source": [
195 | "# Run (Mostly Unmodified DeepMind Code)"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | "### Pip installs\n",
203 | "Uncomment me if first time!"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 4,
209 | "metadata": {
210 | "cellView": "form",
211 | "id": "xQnp1V1cvdxy"
212 | },
213 | "outputs": [],
214 | "source": [
215 | "# @title Install required modules\n",
216 | "from IPython.display import clear_output\n",
217 | "\n",
218 | "#!pip install dm-haiku\n",
219 | "#!pip install optax\n",
220 | "clear_output()"
221 | ]
222 | },
223 | {
224 | "cell_type": "markdown",
225 | "metadata": {},
226 | "source": [
227 | "### Imports"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 5,
233 | "metadata": {
234 | "cellView": "form",
235 | "id": "3NFhPsG4u1L1"
236 | },
237 | "outputs": [],
238 | "source": [
239 | "# @title Imports\n",
240 | "import tempfile\n",
241 | "\n",
242 | "import haiku as hk\n",
243 | "import jax\n",
244 | "import jax.numpy as jnp\n",
245 | "import matplotlib.pyplot as plt\n",
246 | "import numpy as np\n",
247 | "import optax\n",
248 | "import seaborn as sns\n",
249 | "from sklearn.model_selection import train_test_split"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {},
255 | "source": [
256 | "### Load and Preprocess Data"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 6,
262 | "metadata": {
263 | "cellView": "form",
264 | "id": "C-kZ_rW-h8Nk"
265 | },
266 | "outputs": [],
267 | "source": [
268 | "# @title Load and preprocess data\n",
269 | "\n",
270 | "#@markdown The columns of the dataset which will make up the inputs to the network.\n",
271 | "#@markdown In other words, for a knot k, X(k) will be a vector consisting of these quantities. In this case, these are the geometric invariants of each knot.\n",
272 | "#@markdown For descriptions of these invariants see https://knotinfo.math.indiana.edu/\n",
273 | "display_name_from_short_name = {\n",
274 | " 'chern_simons': 'Chern-Simons',\n",
275 | " 'cusp_volume': 'Cusp volume',\n",
276 | " 'hyperbolic_adjoint_torsion_degree': 'Adjoint Torsion Degree',\n",
277 | " 'hyperbolic_torsion_degree': 'Torsion Degree',\n",
278 | " 'injectivity_radius': 'Injectivity radius',\n",
279 | " 'longitudinal_translation': 'Longitudinal translation',\n",
280 | " 'meridinal_translation_imag': 'Im(Meridional translation)',\n",
281 | " 'meridinal_translation_real': 'Re(Meridional translation)',\n",
282 | " 'short_geodesic_imag_part': 'Im(Short geodesic)',\n",
283 | " 'short_geodesic_real_part': 'Re(Short geodesic)',\n",
284 | " 'Symmetry_0': 'Symmetry: $0$',\n",
285 | " 'Symmetry_D3': 'Symmetry: $D_3$',\n",
286 | " 'Symmetry_D4': 'Symmetry: $D_4$',\n",
287 | " 'Symmetry_D6': 'Symmetry: $D_6$',\n",
288 | " 'Symmetry_D8': 'Symmetry: $D_8$',\n",
289 | " 'Symmetry_Z/2 + Z/2': 'Symmetry: $\\\\frac{Z}{2} + \\\\frac{Z}{2}$',\n",
290 | " 'volume': 'Volume',\n",
291 | "}\n",
292 | "column_names = list(display_name_from_short_name)\n",
293 | "target = 'signature'\n",
294 | "\n",
295 | "#@markdown Split the data into a training, a validation and a holdout test set. To check\n",
296 | "#@markdown the robustness of the model and any proposed relationship, the training\n",
297 | "#@markdown process can be repeated with multiple different train/validation/test splits.\n",
298 | "\n",
299 | "random_state = np.random.RandomState(random_seed)\n",
300 | "train_df, validation_and_test_df = train_test_split(\n",
301 | " full_df, random_state=random_state)\n",
302 | "validation_df, test_df = train_test_split(\n",
303 | " validation_and_test_df, test_size=.5, random_state=random_state)\n",
304 | "\n",
305 | "# Find bounds for the signature in the training dataset.\n",
306 | "max_signature = train_df[target].max()\n",
307 | "min_signature = train_df[target].min()"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 7,
313 | "metadata": {},
314 | "outputs": [
315 | {
316 | "data": {
317 | "text/html": [
318 | ""
319 | ],
320 | "text/plain": [
321 | ""
322 | ]
323 | },
324 | "metadata": {},
325 | "output_type": "display_data"
326 | }
327 | ],
328 | "source": [
329 | "# I'm here because Haiku's display is wide. I'll add horizontal scrolling support to output cells\n",
330 | "from IPython.display import display, HTML\n",
331 | "display(HTML(\"\"))"
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "metadata": {},
337 | "source": [
338 | "# Network Definition (modified)\n",
339 | "Modification in `net_forward` to allow user to modify based on `hidden_layers` from above"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 8,
345 | "metadata": {
346 | "cellView": "form",
347 | "id": "jmBkFuYGu_j5"
348 | },
349 | "outputs": [
350 | {
351 | "name": "stdout",
352 | "output_type": "stream",
353 | "text": [
354 | "+----------------------------+------------------------------------------------------------------+-----------------+-----------+-----------+---------------+---------------+\n",
355 | "| Module | Config | Module params | Input | Output | Param count | Param bytes |\n",
356 | "+============================+==================================================================+=================+===========+===========+===============+===============+\n",
357 | "| sequential (Sequential) | Sequential( | | f32[1,17] | f32[1,14] | 204 | 816.00 B |\n",
358 | "| | layers=[Linear(output_size=5), | | | | | |\n",
359 | "| | >, | | | | | |\n",
360 | "| | Linear(output_size=5), | | | | | |\n",
361 | "| | >, | | | | | |\n",
362 | "| | Linear(output_size=14)], | | | | | |\n",
363 | "| | ) | | | | | |\n",
364 | "+----------------------------+------------------------------------------------------------------+-----------------+-----------+-----------+---------------+---------------+\n",
365 | "| linear (Linear) | Linear(output_size=5) | w: f32[17,5] | f32[1,17] | f32[1,5] | 90 | 360.00 B |\n",
366 | "| └ sequential (Sequential) | | b: f32[5] | | | | |\n",
367 | "+----------------------------+------------------------------------------------------------------+-----------------+-----------+-----------+---------------+---------------+\n",
368 | "| linear_1 (Linear) | Linear(output_size=5) | w: f32[5,5] | f32[1,5] | f32[1,5] | 30 | 120.00 B |\n",
369 | "| └ sequential (Sequential) | | b: f32[5] | | | | |\n",
370 | "+----------------------------+------------------------------------------------------------------+-----------------+-----------+-----------+---------------+---------------+\n",
371 | "| linear_2 (Linear) | Linear(output_size=14) | w: f32[5,14] | f32[1,5] | f32[1,14] | 84 | 336.00 B |\n",
372 | "| └ sequential (Sequential) | | b: f32[14] | | | | |\n",
373 | "+----------------------------+------------------------------------------------------------------+-----------------+-----------+-----------+---------------+---------------+\n"
374 | ]
375 | }
376 | ],
377 | "source": [
378 | "# @title Network Definition\n",
379 | "\n",
380 | "#@markdown Create a simple feedforward network, using the DM Haiku library\n",
381 | "#@markdown (https://github.com/deepmind/dm-haiku).\n",
382 | "\n",
383 | "#@markdown The output of the network is a predicted categorical distribution, represented\n",
384 | "#@markdown by a vector q, where softmax(q)[i] is the predicted probability that the\n",
385 | "#@markdown signature of the knot is equal to 2*i + min_signature. Note that signature is\n",
386 | "#@markdown always an even integer.\n",
387 | "\n",
388 | "#@markdown We take the cross entropy between this distribution and the true distribution\n",
389 | "#@markdown (i.e. 1 at the true value of the signature, 0 everywhere else) as the loss\n",
390 | "#@markdown function.\n",
391 | "def net_forward(inp, use_fix=use_fix):\n",
392 | " net = []\n",
393 | " for layer in hidden_layers:\n",
394 | " net.append(hk.Linear(layer))\n",
395 | " net.append(jax.nn.sigmoid)\n",
396 | " net.append(hk.Linear(int((max_signature - min_signature) / 2) + (1 if use_fix else 0)))\n",
397 | " return hk.Sequential(net)(inp)\n",
398 | "\n",
399 | "# Print out network and parameter info\n",
400 | "x = jnp.zeros((1, 17))\n",
401 | "print(hk.experimental.tabulate(net_forward)(x))\n",
402 | "# Uncomment if you want line by line\n",
403 | "#for method_invocation in hk.experimental.eval_summary(net_forward)(x):\n",
404 | "# print(method_invocation)\n",
405 | "\n",
406 | "def softmax_cross_entropy(logits, labels):\n",
407 | " # Labels are the true values of the signature\n",
408 | " one_hot = jax.nn.one_hot((labels - min_signature) / 2, logits.shape[-1])\n",
409 | " return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)\n",
410 | "\n",
411 | "\n",
412 | "# The cross-entropy loss is composed with the network predictions, to define\n",
413 | "# `loss_fn` as a function on X and y.\n",
414 | "def loss_fn(inputs, labels):\n",
415 | " return jnp.mean(softmax_cross_entropy(net_forward(inputs), labels))\n",
416 | "\n",
417 | "\n",
418 | "# Haiku network transformation steps.\n",
419 | "loss_fn_t = hk.without_apply_rng(hk.transform(loss_fn))\n",
420 | "net_forward_t = hk.without_apply_rng(hk.transform(net_forward))\n",
421 | "\n",
422 | "\n",
423 | "@jax.jit\n",
424 | "def predict(params, data_X):\n",
425 | " return (jnp.argmax(net_forward_t.apply(params, data_X), axis=1) * 2 +\n",
426 | " min_signature)\n",
427 | "\n",
428 | "\n",
429 | "#@markdown Calculate the mean and standard deviation over each column in the training\n",
430 | "#@markdown dataset. We use this to normalize each feature, this is best practice for\n",
431 | "#@markdown inputting features into a network, but is also very important in this case\n",
432 | "#@markdown to ensure the gradients used for saliency are meaningfully comparable.\n",
433 | "def normalize_features(df, cols, add_target=True):\n",
434 | " features = df[cols]\n",
435 | " sigma = features.std()\n",
436 | " if any(sigma == 0):\n",
437 | " print(sigma)\n",
438 | " raise RuntimeError(\n",
439 | " \"A poor data stratification has led to no variation in one of the data \"\n",
440 | " \"splits for at least one feature (ie std=0). Restratify and try again.\")\n",
441 | " mu = features.mean()\n",
442 | " normed_df = (features - mu) / sigma\n",
443 | " if add_target:\n",
444 | " normed_df[target] = df[target]\n",
445 | " return normed_df\n",
446 | "\n",
447 | "\n",
448 | "def get_batch(df, cols, size=None):\n",
449 | " batch_df = df if size is None else df.sample(size)\n",
450 | " X = batch_df[cols].to_numpy()\n",
451 | " y = batch_df[target].to_numpy()\n",
452 | " return X, y\n",
453 | "\n",
454 | "\n",
455 | "normed_train_df = normalize_features(train_df, column_names)\n",
456 | "# Sometimes this line will cause a RuntimeError when running smaller networks. Just restart the kernel and try again.\n",
457 | "normed_validation_df = normalize_features(validation_df, column_names)\n",
458 | "normed_test_df = normalize_features(test_df, column_names)"
459 | ]
460 | },
461 | {
462 | "cell_type": "markdown",
463 | "metadata": {},
464 | "source": [
465 | "# Network Setup"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": 9,
471 | "metadata": {
472 | "cellView": "form",
473 | "id": "_ZyvOA1g_08D"
474 | },
475 | "outputs": [
476 | {
477 | "name": "stderr",
478 | "output_type": "stream",
479 | "text": [
480 | "/opt/homebrew/anaconda3/lib/python3.11/site-packages/haiku/_src/initializers.py:127: UserWarning: Explicitly requested dtype float64 is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
481 | " unscaled = jax.random.truncated_normal(\n",
482 | "/opt/homebrew/anaconda3/lib/python3.11/site-packages/haiku/_src/base.py:658: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
483 | " param = init(shape, dtype)\n"
484 | ]
485 | }
486 | ],
487 | "source": [
488 | "# @title Network Setup (re-run before re-training)\n",
489 | "\n",
490 | "train_X, train_y = get_batch(normed_train_df, column_names, batch_size)\n",
491 | "\n",
492 | "# Pick a random seed for the network weights. To check the robustness of the\n",
493 | "# model, the training process can be repeated with multiple different random\n",
494 | "# seeds.\n",
495 | "rng = jax.random.PRNGKey(network_weight_rng)\n",
496 | "init_params = loss_fn_t.init(rng, train_X, train_y)\n",
497 | "\n",
498 | "opt_init, opt_update = optax.adam(learning_rate)\n",
499 | "opt_state = opt_init(init_params)"
500 | ]
501 | },
502 | {
503 | "cell_type": "markdown",
504 | "metadata": {},
505 | "source": [
506 | "### Network Training"
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": 10,
512 | "metadata": {
513 | "cellView": "form",
514 | "id": "sofhcPK4vZWl"
515 | },
516 | "outputs": [
517 | {
518 | "name": "stdout",
519 | "output_type": "stream",
520 | "text": [
521 | "Step count: 0\n",
522 | "Train loss: 2.8116812705993652\n",
523 | "Validation loss: 2.7941524982452393\n",
524 | "Step count: 100\n",
525 | "Train loss: 2.384979724884033\n",
526 | "Validation loss: 2.407904863357544\n",
527 | "Step count: 200\n",
528 | "Train loss: 2.1264162063598633\n",
529 | "Validation loss: 2.1435587406158447\n",
530 | "Step count: 300\n",
531 | "Train loss: 2.0061676502227783\n",
532 | "Validation loss: 1.9773170948028564\n",
533 | "Step count: 400\n",
534 | "Train loss: 1.8468494415283203\n",
535 | "Validation loss: 1.8776098489761353\n",
536 | "Step count: 500\n",
537 | "Train loss: 1.8516018390655518\n",
538 | "Validation loss: 1.811793327331543\n",
539 | "Step count: 600\n",
540 | "Train loss: 1.715728759765625\n",
541 | "Validation loss: 1.760715365409851\n",
542 | "Step count: 700\n",
543 | "Train loss: 1.6884984970092773\n",
544 | "Validation loss: 1.717073678970337\n",
545 | "Step count: 800\n",
546 | "Train loss: 1.6176589727401733\n",
547 | "Validation loss: 1.6789257526397705\n",
548 | "Step count: 900\n",
549 | "Train loss: 1.4370570182800293\n",
550 | "Validation loss: 1.6448426246643066\n",
551 | "Step count: 1000\n",
552 | "Train loss: 1.5930187702178955\n",
553 | "Validation loss: 1.6133735179901123\n",
554 | "Step count: 1100\n",
555 | "Train loss: 1.6224822998046875\n",
556 | "Validation loss: 1.584986925125122\n",
557 | "Step count: 1200\n",
558 | "Train loss: 1.622300148010254\n",
559 | "Validation loss: 1.5556803941726685\n",
560 | "Step count: 1300\n",
561 | "Train loss: 1.444936990737915\n",
562 | "Validation loss: 1.5245068073272705\n",
563 | "Step count: 1400\n",
564 | "Train loss: 1.627079963684082\n",
565 | "Validation loss: 1.491458773612976\n",
566 | "Step count: 1500\n",
567 | "Train loss: 1.406960129737854\n",
568 | "Validation loss: 1.4536099433898926\n",
569 | "Step count: 1600\n",
570 | "Train loss: 1.3565831184387207\n",
571 | "Validation loss: 1.4108202457427979\n",
572 | "Step count: 1700\n",
573 | "Train loss: 1.4146857261657715\n",
574 | "Validation loss: 1.3642162084579468\n",
575 | "Step count: 1800\n",
576 | "Train loss: 1.1749764680862427\n",
577 | "Validation loss: 1.3149152994155884\n",
578 | "Step count: 1900\n",
579 | "Train loss: 1.188778042793274\n",
580 | "Validation loss: 1.264180302619934\n",
581 | "Step count: 2000\n",
582 | "Train loss: 1.3288445472717285\n",
583 | "Validation loss: 1.2146958112716675\n",
584 | "Step count: 2100\n",
585 | "Train loss: 1.3042057752609253\n",
586 | "Validation loss: 1.1659657955169678\n",
587 | "Step count: 2200\n",
588 | "Train loss: 1.248848557472229\n",
589 | "Validation loss: 1.1203685998916626\n",
590 | "Step count: 2300\n",
591 | "Train loss: 1.1984355449676514\n",
592 | "Validation loss: 1.078598141670227\n",
593 | "Step count: 2400\n",
594 | "Train loss: 1.2042495012283325\n",
595 | "Validation loss: 1.0411516427993774\n",
596 | "Step count: 2500\n",
597 | "Train loss: 1.1562697887420654\n",
598 | "Validation loss: 1.0058705806732178\n",
599 | "Step count: 2600\n",
600 | "Train loss: 0.8232065439224243\n",
601 | "Validation loss: 0.9747156500816345\n",
602 | "Step count: 2700\n",
603 | "Train loss: 0.9761694669723511\n",
604 | "Validation loss: 0.9461800456047058\n",
605 | "Step count: 2800\n",
606 | "Train loss: 0.7596954107284546\n",
607 | "Validation loss: 0.9209511876106262\n",
608 | "Step count: 2900\n",
609 | "Train loss: 0.9237159490585327\n",
610 | "Validation loss: 0.8975797891616821\n",
611 | "Step count: 3000\n",
612 | "Train loss: 0.8497984409332275\n",
613 | "Validation loss: 0.8760898113250732\n",
614 | "Step count: 3100\n",
615 | "Train loss: 0.7388738393783569\n",
616 | "Validation loss: 0.8562992215156555\n",
617 | "Step count: 3200\n",
618 | "Train loss: 0.9391216039657593\n",
619 | "Validation loss: 0.8379098773002625\n",
620 | "Step count: 3300\n",
621 | "Train loss: 0.8053942918777466\n",
622 | "Validation loss: 0.820518970489502\n",
623 | "Step count: 3400\n",
624 | "Train loss: 1.1734966039657593\n",
625 | "Validation loss: 0.8035964369773865\n",
626 | "Step count: 3500\n",
627 | "Train loss: 0.897958517074585\n",
628 | "Validation loss: 0.7881501317024231\n",
629 | "Step count: 3600\n",
630 | "Train loss: 0.7107032537460327\n",
631 | "Validation loss: 0.7737591862678528\n",
632 | "Step count: 3700\n",
633 | "Train loss: 0.7133588790893555\n",
634 | "Validation loss: 0.7591407895088196\n",
635 | "Step count: 3800\n",
636 | "Train loss: 0.8879278302192688\n",
637 | "Validation loss: 0.745370090007782\n",
638 | "Step count: 3900\n",
639 | "Train loss: 0.7890059947967529\n",
640 | "Validation loss: 0.7329174876213074\n",
641 | "Step count: 4000\n",
642 | "Train loss: 0.5831533670425415\n",
643 | "Validation loss: 0.721225917339325\n",
644 | "Step count: 4100\n",
645 | "Train loss: 0.7699029445648193\n",
646 | "Validation loss: 0.7102410197257996\n",
647 | "Step count: 4200\n",
648 | "Train loss: 0.6830381751060486\n",
649 | "Validation loss: 0.6998180150985718\n",
650 | "Step count: 4300\n",
651 | "Train loss: 0.6156393885612488\n",
652 | "Validation loss: 0.6894598603248596\n",
653 | "Step count: 4400\n",
654 | "Train loss: 0.8743906021118164\n",
655 | "Validation loss: 0.6800310015678406\n",
656 | "Step count: 4500\n",
657 | "Train loss: 0.6473797559738159\n",
658 | "Validation loss: 0.6715149283409119\n",
659 | "Step count: 4600\n",
660 | "Train loss: 0.7979230880737305\n",
661 | "Validation loss: 0.6632071733474731\n",
662 | "Step count: 4700\n",
663 | "Train loss: 0.6013624668121338\n",
664 | "Validation loss: 0.6549491286277771\n",
665 | "Step count: 4800\n",
666 | "Train loss: 0.6401097774505615\n",
667 | "Validation loss: 0.6469190120697021\n",
668 | "Step count: 4900\n",
669 | "Train loss: 0.7246673107147217\n",
670 | "Validation loss: 0.6397194266319275\n",
671 | "Step count: 5000\n",
672 | "Train loss: 0.5493844747543335\n",
673 | "Validation loss: 0.6325175762176514\n",
674 | "Step count: 5100\n",
675 | "Train loss: 0.6013330817222595\n",
676 | "Validation loss: 0.6253847479820251\n",
677 | "Step count: 5200\n",
678 | "Train loss: 0.5418999195098877\n",
679 | "Validation loss: 0.6195732951164246\n",
680 | "Step count: 5300\n",
681 | "Train loss: 0.5872135162353516\n",
682 | "Validation loss: 0.6133323311805725\n",
683 | "Step count: 5400\n",
684 | "Train loss: 0.6665270328521729\n",
685 | "Validation loss: 0.6069185137748718\n",
686 | "Step count: 5500\n",
687 | "Train loss: 0.544267475605011\n",
688 | "Validation loss: 0.6018220782279968\n",
689 | "Step count: 5600\n",
690 | "Train loss: 0.5228955745697021\n",
691 | "Validation loss: 0.5958737730979919\n",
692 | "Step count: 5700\n",
693 | "Train loss: 0.5577840805053711\n",
694 | "Validation loss: 0.5900106430053711\n",
695 | "Step count: 5800\n",
696 | "Train loss: 0.5805840492248535\n",
697 | "Validation loss: 0.5850721597671509\n",
698 | "Step count: 5900\n",
699 | "Train loss: 0.8313626050949097\n",
700 | "Validation loss: 0.5806062817573547\n",
701 | "Step count: 6000\n",
702 | "Train loss: 0.5863834619522095\n",
703 | "Validation loss: 0.5754520893096924\n",
704 | "Step count: 6100\n",
705 | "Train loss: 0.46375811100006104\n",
706 | "Validation loss: 0.5714226365089417\n",
707 | "Step count: 6200\n",
708 | "Train loss: 0.5163490772247314\n",
709 | "Validation loss: 0.5665392279624939\n",
710 | "Step count: 6300\n",
711 | "Train loss: 0.5626010894775391\n",
712 | "Validation loss: 0.5617250204086304\n",
713 | "Step count: 6400\n",
714 | "Train loss: 0.484112024307251\n",
715 | "Validation loss: 0.5571351647377014\n",
716 | "Step count: 6500\n",
717 | "Train loss: 0.5184279680252075\n",
718 | "Validation loss: 0.5530323386192322\n",
719 | "Step count: 6600\n",
720 | "Train loss: 0.5501329898834229\n",
721 | "Validation loss: 0.548738956451416\n",
722 | "Step count: 6700\n",
723 | "Train loss: 0.5456504821777344\n",
724 | "Validation loss: 0.5449445843696594\n",
725 | "Step count: 6800\n",
726 | "Train loss: 0.454240083694458\n",
727 | "Validation loss: 0.5415924787521362\n",
728 | "Step count: 6900\n",
729 | "Train loss: 0.5784645080566406\n",
730 | "Validation loss: 0.5375749468803406\n",
731 | "Step count: 7000\n",
732 | "Train loss: 0.492913156747818\n",
733 | "Validation loss: 0.5337445735931396\n",
734 | "Step count: 7100\n",
735 | "Train loss: 0.6137053966522217\n",
736 | "Validation loss: 0.5305295586585999\n",
737 | "Step count: 7200\n",
738 | "Train loss: 0.6540536880493164\n",
739 | "Validation loss: 0.5265972018241882\n",
740 | "Step count: 7300\n",
741 | "Train loss: 0.5247725248336792\n",
742 | "Validation loss: 0.5231927633285522\n",
743 | "Step count: 7400\n",
744 | "Train loss: 0.6034173965454102\n",
745 | "Validation loss: 0.520044207572937\n",
746 | "Step count: 7500\n",
747 | "Train loss: 0.5097901821136475\n",
748 | "Validation loss: 0.5166345834732056\n",
749 | "Step count: 7600\n",
750 | "Train loss: 0.5273083448410034\n",
751 | "Validation loss: 0.5132884979248047\n",
752 | "Step count: 7700\n",
753 | "Train loss: 0.6435624361038208\n",
754 | "Validation loss: 0.5110858082771301\n",
755 | "Step count: 7800\n",
756 | "Train loss: 0.4839707911014557\n",
757 | "Validation loss: 0.5074858665466309\n",
758 | "Step count: 7900\n",
759 | "Train loss: 0.5335568189620972\n",
760 | "Validation loss: 0.5046110153198242\n",
761 | "Step count: 8000\n",
762 | "Train loss: 0.44616270065307617\n",
763 | "Validation loss: 0.5017329454421997\n",
764 | "Step count: 8100\n",
765 | "Train loss: 0.5774402618408203\n",
766 | "Validation loss: 0.4991530478000641\n",
767 | "Step count: 8200\n",
768 | "Train loss: 0.36768782138824463\n",
769 | "Validation loss: 0.4969697594642639\n",
770 | "Step count: 8300\n",
771 | "Train loss: 0.44016486406326294\n",
772 | "Validation loss: 0.4940197467803955\n",
773 | "Step count: 8400\n",
774 | "Train loss: 0.5363975167274475\n",
775 | "Validation loss: 0.4918510913848877\n",
776 | "Step count: 8500\n",
777 | "Train loss: 0.424763023853302\n",
778 | "Validation loss: 0.48892942070961\n",
779 | "Step count: 8600\n",
780 | "Train loss: 0.4969513416290283\n",
781 | "Validation loss: 0.48722168803215027\n",
782 | "Step count: 8700\n",
783 | "Train loss: 0.37056827545166016\n",
784 | "Validation loss: 0.4848025143146515\n",
785 | "Step count: 8800\n",
786 | "Train loss: 0.391757071018219\n",
787 | "Validation loss: 0.482174813747406\n",
788 | "Step count: 8900\n",
789 | "Train loss: 0.4947183132171631\n",
790 | "Validation loss: 0.48007288575172424\n",
791 | "Step count: 9000\n",
792 | "Train loss: 0.45929282903671265\n",
793 | "Validation loss: 0.4781815707683563\n",
794 | "Step count: 9100\n",
795 | "Train loss: 0.47612375020980835\n",
796 | "Validation loss: 0.4760727882385254\n",
797 | "Step count: 9200\n",
798 | "Train loss: 0.44101282954216003\n",
799 | "Validation loss: 0.4746829867362976\n",
800 | "Step count: 9300\n",
801 | "Train loss: 0.37658315896987915\n",
802 | "Validation loss: 0.47280576825141907\n",
803 | "Step count: 9400\n",
804 | "Train loss: 0.3759828209877014\n",
805 | "Validation loss: 0.47096705436706543\n",
806 | "Step count: 9500\n",
807 | "Train loss: 0.5561549067497253\n",
808 | "Validation loss: 0.4691046476364136\n",
809 | "Step count: 9600\n",
810 | "Train loss: 0.5512958765029907\n",
811 | "Validation loss: 0.4677446782588959\n",
812 | "Step count: 9700\n",
813 | "Train loss: 0.4368234872817993\n",
814 | "Validation loss: 0.46550050377845764\n",
815 | "Step count: 9800\n",
816 | "Train loss: 0.4498073160648346\n",
817 | "Validation loss: 0.4638981521129608\n",
818 | "Step count: 9900\n",
819 | "Train loss: 0.45500627160072327\n",
820 | "Validation loss: 0.4625166654586792\n",
821 | "(5, 5)\n",
822 | "Test Accuracy: 0.819423\n"
823 | ]
824 | }
825 | ],
826 | "source": [
827 | "# @title Network Training\n",
828 | "\n",
829 | "# We train until the validation loss stops decreasing, checking every\n",
830 | "# `validation_interval` steps, up to a maximum of 10k steps.\n",
831 | "\n",
832 | "\n",
833 | "@jax.jit\n",
834 | "def update(params, opt_state, batch_X, batch_y):\n",
835 | " grads = jax.grad(loss_fn_t.apply)(params, batch_X, batch_y)\n",
836 | " upds, new_opt_state = opt_update(grads, opt_state)\n",
837 | " new_params = optax.apply_updates(params, upds)\n",
838 | " return new_params, new_opt_state\n",
839 | "\n",
840 | "\n",
841 | "def train(columns_to_train_on, params, opt_state, update_fn):\n",
842 | " best_validation_loss = np.inf\n",
843 | " for i in range(num_training_steps):\n",
844 | " train_X, train_y = get_batch(normed_train_df, columns_to_train_on,\n",
845 | " batch_size)\n",
846 | " params, opt_state = update_fn(params, opt_state, train_X, train_y)\n",
847 | "\n",
848 | " if i % validation_interval == 0:\n",
849 | " # Run validation on the full validation dataset.\n",
850 | " validation_X, validation_y = get_batch(normed_validation_df,\n",
851 | " columns_to_train_on)\n",
852 | " train_loss = loss_fn_t.apply(params, train_X, train_y)\n",
853 | " validation_loss = loss_fn_t.apply(params, validation_X, validation_y)\n",
854 | " print(f\"Step count: {i}\")\n",
855 | " print(f\"Train loss: {train_loss}\")\n",
856 | " print(f\"Validation loss: {validation_loss}\")\n",
857 | "\n",
858 | " if validation_loss > best_validation_loss:\n",
859 | " print(\"Validation loss increased. Stopping!\")\n",
860 | " return params\n",
861 | " else:\n",
862 | " best_validation_loss = validation_loss\n",
863 | " return params\n",
864 | "\n",
865 | "\n",
866 | "trained_params = train(column_names, init_params, opt_state, update)\n",
867 | "# Print the test accuracy, i.e. the proportion of the knots for which the\n",
868 | "# network predicts the correct signature.\n",
869 | "test_X, test_y = get_batch(normed_test_df, column_names)\n",
870 | "print(trained_params[\"linear_1\"][\"w\"].shape)\n",
871 | "# The final accuracy below should be in the low 80% range.\n",
872 | "print(\"Test Accuracy: \",\n",
873 | " np.mean((predict(trained_params, test_X) - test_y) == 0))"
874 | ]
875 | },
876 | {
877 | "cell_type": "markdown",
878 | "metadata": {
879 | "id": "qj0TjRTnAmGW"
880 | },
881 | "source": [
882 | "The cell below replicates Figure 2a from the paper, omitting the error bars for simplicity.\n",
883 | "\n",
884 | "To compute the saliency, we take the gradient of the loss function with respect to each component of the input X to the network (i.e. the geometric invariants of each knot), and average its absolute value over the dataset.\n",
885 | "\n",
886 | "We plot the feature saliencies in decreasing order. The plot (should!) show that the overall loss is influenced far more by the longitudinal translation and the real and imaginary parts of the meridional translation than any of the other invariants."
887 | ]
888 | },
889 | {
890 | "cell_type": "markdown",
891 | "metadata": {},
892 | "source": [
893 | "### Saliency Analysis"
894 | ]
895 | },
896 | {
897 | "cell_type": "code",
898 | "execution_count": 11,
899 | "metadata": {
900 | "cellView": "form",
901 | "id": "j1CEbkUgxtGZ",
902 | "scrolled": true
903 | },
904 | "outputs": [
905 | {
906 | "data": {
907 | "image/png": "",
908 | "text/plain": [
909 | ""
910 | ]
911 | },
912 | "metadata": {},
913 | "output_type": "display_data"
914 | }
915 | ],
916 | "source": [
917 | "# @title Saliency Analysis\n",
918 | "train_X = normalize_features(train_df, column_names, add_target=False).to_numpy()\n",
919 | "train_y = train_df[target].to_numpy()\n",
920 | "\n",
921 | "\n",
922 | "saliencies = np.mean(\n",
923 | " np.abs(jax.grad(loss_fn_t.apply, 1)(trained_params, train_X, train_y)), axis=0)\n",
924 | "\n",
925 | "\n",
926 | "decreasing_saliency = reversed(sorted(zip(saliencies, display_name_from_short_name.values())))\n",
927 | "sorted_saliencies, sorted_columns = zip(*decreasing_saliency)\n",
928 | "\n",
929 | "fig, ax = plt.subplots(figsize=(6,8))\n",
930 | "sns.barplot(y=np.array(sorted_columns),\n",
931 | " x=np.array(sorted_saliencies) / max(sorted_saliencies),\n",
932 | " color=\"#0077c6\");\n",
933 | "\n",
934 | "ax.tick_params(labelsize=15);\n",
935 | "ax.set_ylabel('Geometric invariants X(z)', fontsize=20);\n",
936 | "plt.xlabel('Normalized attribution score', fontsize=20);"
937 | ]
938 | },
939 | {
940 | "cell_type": "markdown",
941 | "metadata": {},
942 | "source": [
943 | "### Confirming The Feature Saliency"
944 | ]
945 | },
946 | {
947 | "cell_type": "code",
948 | "execution_count": 12,
949 | "metadata": {
950 | "cellView": "form",
951 | "id": "EsDbwItnlNUY"
952 | },
953 | "outputs": [],
954 | "source": [
955 | "# @title Confirming the Feature Saliency\n",
956 | "\n",
957 | "#@markdown To confirm the results of the saliency analysis, we re-train the network using\n",
958 | "#@markdown only these three features as input.\n",
959 | "salient_column_names = ['longitudinal_translation',\n",
960 | " 'meridinal_translation_imag',\n",
961 | " 'meridinal_translation_real']\n",
962 | "target = 'signature'"
963 | ]
964 | },
965 | {
966 | "cell_type": "markdown",
967 | "metadata": {},
968 | "source": [
969 | "### Confirming the Feature Saliency: Network Setup"
970 | ]
971 | },
972 | {
973 | "cell_type": "code",
974 | "execution_count": 13,
975 | "metadata": {
976 | "cellView": "form",
977 | "id": "SKY75X_sEsy2"
978 | },
979 | "outputs": [
980 | {
981 | "name": "stderr",
982 | "output_type": "stream",
983 | "text": [
984 | "/opt/homebrew/anaconda3/lib/python3.11/site-packages/haiku/_src/initializers.py:127: UserWarning: Explicitly requested dtype float64 is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
985 | " unscaled = jax.random.truncated_normal(\n",
986 | "/opt/homebrew/anaconda3/lib/python3.11/site-packages/haiku/_src/base.py:658: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
987 | " param = init(shape, dtype)\n"
988 | ]
989 | }
990 | ],
991 | "source": [
992 | "# @title Confirming the Feature Saliency: Network Setup (re-run before re-training)\n",
993 | "\n",
994 | "train_X, train_y = get_batch(normed_train_df, salient_column_names, batch_size)\n",
995 | "\n",
996 | "init_params_salient = loss_fn_t.init(rng, train_X, train_y)\n",
997 | "\n",
998 | "opt_init_salient, opt_update_salient = optax.adam(learning_rate)\n",
999 | "opt_state_salient = opt_init_salient(init_params_salient)"
1000 | ]
1001 | },
1002 | {
1003 | "cell_type": "markdown",
1004 | "metadata": {},
1005 | "source": [
1006 | "### Confirming the Feature Saliency: Network Training"
1007 | ]
1008 | },
1009 | {
1010 | "cell_type": "code",
1011 | "execution_count": 14,
1012 | "metadata": {
1013 | "cellView": "form",
1014 | "id": "lGOfm424lXtR"
1015 | },
1016 | "outputs": [
1017 | {
1018 | "name": "stdout",
1019 | "output_type": "stream",
1020 | "text": [
1021 | "Step count: 0\n",
1022 | "Train loss: 2.7614235877990723\n",
1023 | "Validation loss: 2.805657386779785\n",
1024 | "Step count: 100\n",
1025 | "Train loss: 2.388213872909546\n",
1026 | "Validation loss: 2.419311285018921\n",
1027 | "Step count: 200\n",
1028 | "Train loss: 2.1496756076812744\n",
1029 | "Validation loss: 2.161698579788208\n",
1030 | "Step count: 300\n",
1031 | "Train loss: 1.9615678787231445\n",
1032 | "Validation loss: 1.9971750974655151\n",
1033 | "Step count: 400\n",
1034 | "Train loss: 1.9326486587524414\n",
1035 | "Validation loss: 1.896746039390564\n",
1036 | "Step count: 500\n",
1037 | "Train loss: 1.7300117015838623\n",
1038 | "Validation loss: 1.826549768447876\n",
1039 | "Step count: 600\n",
1040 | "Train loss: 1.8089721202850342\n",
1041 | "Validation loss: 1.7697826623916626\n",
1042 | "Step count: 700\n",
1043 | "Train loss: 1.8630174398422241\n",
1044 | "Validation loss: 1.7183831930160522\n",
1045 | "Step count: 800\n",
1046 | "Train loss: 1.7237074375152588\n",
1047 | "Validation loss: 1.6680281162261963\n",
1048 | "Step count: 900\n",
1049 | "Train loss: 1.5447618961334229\n",
1050 | "Validation loss: 1.618795394897461\n",
1051 | "Step count: 1000\n",
1052 | "Train loss: 1.5312609672546387\n",
1053 | "Validation loss: 1.5699970722198486\n",
1054 | "Step count: 1100\n",
1055 | "Train loss: 1.4775203466415405\n",
1056 | "Validation loss: 1.5210233926773071\n",
1057 | "Step count: 1200\n",
1058 | "Train loss: 1.4869463443756104\n",
1059 | "Validation loss: 1.4717947244644165\n",
1060 | "Step count: 1300\n",
1061 | "Train loss: 1.3452095985412598\n",
1062 | "Validation loss: 1.4213618040084839\n",
1063 | "Step count: 1400\n",
1064 | "Train loss: 1.3345348834991455\n",
1065 | "Validation loss: 1.3705226182937622\n",
1066 | "Step count: 1500\n",
1067 | "Train loss: 1.1913044452667236\n",
1068 | "Validation loss: 1.3196864128112793\n",
1069 | "Step count: 1600\n",
1070 | "Train loss: 1.2670183181762695\n",
1071 | "Validation loss: 1.2692861557006836\n",
1072 | "Step count: 1700\n",
1073 | "Train loss: 1.1062930822372437\n",
1074 | "Validation loss: 1.2203614711761475\n",
1075 | "Step count: 1800\n",
1076 | "Train loss: 1.1311137676239014\n",
1077 | "Validation loss: 1.1733242273330688\n",
1078 | "Step count: 1900\n",
1079 | "Train loss: 1.1015247106552124\n",
1080 | "Validation loss: 1.1291357278823853\n",
1081 | "Step count: 2000\n",
1082 | "Train loss: 1.2747454643249512\n",
1083 | "Validation loss: 1.0881578922271729\n",
1084 | "Step count: 2100\n",
1085 | "Train loss: 1.1164498329162598\n",
1086 | "Validation loss: 1.0503908395767212\n",
1087 | "Step count: 2200\n",
1088 | "Train loss: 0.9105275273323059\n",
1089 | "Validation loss: 1.0164942741394043\n",
1090 | "Step count: 2300\n",
1091 | "Train loss: 0.8813854455947876\n",
1092 | "Validation loss: 0.9849182963371277\n",
1093 | "Step count: 2400\n",
1094 | "Train loss: 0.9708192348480225\n",
1095 | "Validation loss: 0.956310510635376\n",
1096 | "Step count: 2500\n",
1097 | "Train loss: 0.8696013689041138\n",
1098 | "Validation loss: 0.9300026297569275\n",
1099 | "Step count: 2600\n",
1100 | "Train loss: 1.024273157119751\n",
1101 | "Validation loss: 0.9060561060905457\n",
1102 | "Step count: 2700\n",
1103 | "Train loss: 0.8595995903015137\n",
1104 | "Validation loss: 0.8840814232826233\n",
1105 | "Step count: 2800\n",
1106 | "Train loss: 0.7574213743209839\n",
1107 | "Validation loss: 0.8634595274925232\n",
1108 | "Step count: 2900\n",
1109 | "Train loss: 0.8112626075744629\n",
1110 | "Validation loss: 0.8452709317207336\n",
1111 | "Step count: 3000\n",
1112 | "Train loss: 0.8701050281524658\n",
1113 | "Validation loss: 0.8277906179428101\n",
1114 | "Step count: 3100\n",
1115 | "Train loss: 0.7930698394775391\n",
1116 | "Validation loss: 0.8116493225097656\n",
1117 | "Step count: 3200\n",
1118 | "Train loss: 0.8960763216018677\n",
1119 | "Validation loss: 0.795965850353241\n",
1120 | "Step count: 3300\n",
1121 | "Train loss: 0.7834410071372986\n",
1122 | "Validation loss: 0.7817390561103821\n",
1123 | "Step count: 3400\n",
1124 | "Train loss: 0.6940096616744995\n",
1125 | "Validation loss: 0.7686640620231628\n",
1126 | "Step count: 3500\n",
1127 | "Train loss: 0.7357538938522339\n",
1128 | "Validation loss: 0.7559365034103394\n",
1129 | "Step count: 3600\n",
1130 | "Train loss: 0.7318865060806274\n",
1131 | "Validation loss: 0.7439627647399902\n",
1132 | "Step count: 3700\n",
1133 | "Train loss: 0.7224804162979126\n",
1134 | "Validation loss: 0.7326063513755798\n",
1135 | "Step count: 3800\n",
1136 | "Train loss: 0.7348959445953369\n",
1137 | "Validation loss: 0.7221681475639343\n",
1138 | "Step count: 3900\n",
1139 | "Train loss: 0.6049644947052002\n",
1140 | "Validation loss: 0.7122893929481506\n",
1141 | "Step count: 4000\n",
1142 | "Train loss: 0.7843313813209534\n",
1143 | "Validation loss: 0.7031309008598328\n",
1144 | "Step count: 4100\n",
1145 | "Train loss: 0.7371485233306885\n",
1146 | "Validation loss: 0.6940381526947021\n",
1147 | "Step count: 4200\n",
1148 | "Train loss: 0.8364322185516357\n",
1149 | "Validation loss: 0.6851128935813904\n",
1150 | "Step count: 4300\n",
1151 | "Train loss: 0.7326945066452026\n",
1152 | "Validation loss: 0.6773727536201477\n",
1153 | "Step count: 4400\n",
1154 | "Train loss: 0.6783174276351929\n",
1155 | "Validation loss: 0.6696005463600159\n",
1156 | "Step count: 4500\n",
1157 | "Train loss: 0.614048957824707\n",
1158 | "Validation loss: 0.6618369221687317\n",
1159 | "Step count: 4600\n",
1160 | "Train loss: 0.7186390161514282\n",
1161 | "Validation loss: 0.6548063158988953\n",
1162 | "Step count: 4700\n",
1163 | "Train loss: 0.6813069581985474\n",
1164 | "Validation loss: 0.648128867149353\n",
1165 | "Step count: 4800\n",
1166 | "Train loss: 0.6567203998565674\n",
1167 | "Validation loss: 0.6416744589805603\n",
1168 | "Step count: 4900\n",
1169 | "Train loss: 0.6926349401473999\n",
1170 | "Validation loss: 0.6355624198913574\n",
1171 | "Step count: 5000\n",
1172 | "Train loss: 0.6608371138572693\n",
1173 | "Validation loss: 0.6299481987953186\n",
1174 | "Step count: 5100\n",
1175 | "Train loss: 0.8321611285209656\n",
1176 | "Validation loss: 0.6240999102592468\n",
1177 | "Step count: 5200\n",
1178 | "Train loss: 0.6036645174026489\n",
1179 | "Validation loss: 0.6185488700866699\n",
1180 | "Step count: 5300\n",
1181 | "Train loss: 0.6060733795166016\n",
1182 | "Validation loss: 0.6134203672409058\n",
1183 | "Step count: 5400\n",
1184 | "Train loss: 0.6677653193473816\n",
1185 | "Validation loss: 0.6078833937644958\n",
1186 | "Step count: 5500\n",
1187 | "Train loss: 0.4742782413959503\n",
1188 | "Validation loss: 0.6031250953674316\n",
1189 | "Step count: 5600\n",
1190 | "Train loss: 0.49872714281082153\n",
1191 | "Validation loss: 0.5985592603683472\n",
1192 | "Step count: 5700\n",
1193 | "Train loss: 0.6913710832595825\n",
1194 | "Validation loss: 0.593771755695343\n",
1195 | "Step count: 5800\n",
1196 | "Train loss: 0.6784710884094238\n",
1197 | "Validation loss: 0.5893182158470154\n",
1198 | "Step count: 5900\n",
1199 | "Train loss: 0.6741566061973572\n",
1200 | "Validation loss: 0.5848686695098877\n",
1201 | "Step count: 6000\n",
1202 | "Train loss: 0.6259003281593323\n",
1203 | "Validation loss: 0.5804926156997681\n",
1204 | "Step count: 6100\n",
1205 | "Train loss: 0.6865454316139221\n",
1206 | "Validation loss: 0.5764020085334778\n",
1207 | "Step count: 6200\n",
1208 | "Train loss: 0.5772806406021118\n",
1209 | "Validation loss: 0.5719350576400757\n",
1210 | "Step count: 6300\n",
1211 | "Train loss: 0.6780864596366882\n",
1212 | "Validation loss: 0.5680485367774963\n",
1213 | "Step count: 6400\n",
1214 | "Train loss: 0.4631125032901764\n",
1215 | "Validation loss: 0.5639679431915283\n",
1216 | "Step count: 6500\n",
1217 | "Train loss: 0.5941550731658936\n",
1218 | "Validation loss: 0.5604357719421387\n",
1219 | "Step count: 6600\n",
1220 | "Train loss: 0.6307265758514404\n",
1221 | "Validation loss: 0.5566143989562988\n",
1222 | "Step count: 6700\n",
1223 | "Train loss: 0.7405992746353149\n",
1224 | "Validation loss: 0.5531129837036133\n",
1225 | "Step count: 6800\n",
1226 | "Train loss: 0.526777982711792\n",
1227 | "Validation loss: 0.5495466589927673\n",
1228 | "Step count: 6900\n",
1229 | "Train loss: 0.6490035057067871\n",
1230 | "Validation loss: 0.5466834306716919\n",
1231 | "Step count: 7000\n",
1232 | "Train loss: 0.5304993987083435\n",
1233 | "Validation loss: 0.5427576303482056\n",
1234 | "Step count: 7100\n",
1235 | "Train loss: 0.53404301404953\n",
1236 | "Validation loss: 0.5399003028869629\n",
1237 | "Step count: 7200\n",
1238 | "Train loss: 0.487496554851532\n",
1239 | "Validation loss: 0.5371031165122986\n",
1240 | "Step count: 7300\n",
1241 | "Train loss: 0.45802122354507446\n",
1242 | "Validation loss: 0.5341663360595703\n",
1243 | "Step count: 7400\n",
1244 | "Train loss: 0.44041725993156433\n",
1245 | "Validation loss: 0.5314267873764038\n",
1246 | "Step count: 7500\n",
1247 | "Train loss: 0.531814455986023\n",
1248 | "Validation loss: 0.5284692645072937\n",
1249 | "Step count: 7600\n",
1250 | "Train loss: 0.47992217540740967\n",
1251 | "Validation loss: 0.526004433631897\n",
1252 | "Step count: 7700\n",
1253 | "Train loss: 0.5088521242141724\n",
1254 | "Validation loss: 0.5230643153190613\n",
1255 | "Step count: 7800\n",
1256 | "Train loss: 0.5828921794891357\n",
1257 | "Validation loss: 0.520179808139801\n",
1258 | "Step count: 7900\n",
1259 | "Train loss: 0.4943484663963318\n",
1260 | "Validation loss: 0.5175743103027344\n",
1261 | "Step count: 8000\n",
1262 | "Train loss: 0.552828311920166\n",
1263 | "Validation loss: 0.5151614546775818\n",
1264 | "Step count: 8100\n",
1265 | "Train loss: 0.6291115880012512\n",
1266 | "Validation loss: 0.5133667588233948\n",
1267 | "Step count: 8200\n",
1268 | "Train loss: 0.4879392683506012\n",
1269 | "Validation loss: 0.510779082775116\n",
1270 | "Step count: 8300\n",
1271 | "Train loss: 0.4349018633365631\n",
1272 | "Validation loss: 0.5087208151817322\n",
1273 | "Step count: 8400\n",
1274 | "Train loss: 0.41613534092903137\n",
1275 | "Validation loss: 0.5061413049697876\n",
1276 | "Step count: 8500\n",
1277 | "Train loss: 0.46150243282318115\n",
1278 | "Validation loss: 0.504405677318573\n",
1279 | "Step count: 8600\n",
1280 | "Train loss: 0.5156277418136597\n",
1281 | "Validation loss: 0.5025164484977722\n",
1282 | "Step count: 8700\n",
1283 | "Train loss: 0.6169106364250183\n",
1284 | "Validation loss: 0.5001534819602966\n",
1285 | "Step count: 8800\n",
1286 | "Train loss: 0.35082823038101196\n",
1287 | "Validation loss: 0.49816426634788513\n",
1288 | "Step count: 8900\n",
1289 | "Train loss: 0.5226665139198303\n",
1290 | "Validation loss: 0.4966347813606262\n",
1291 | "Step count: 9000\n",
1292 | "Train loss: 0.6887170672416687\n",
1293 | "Validation loss: 0.4947109520435333\n",
1294 | "Step count: 9100\n",
1295 | "Train loss: 0.6093366146087646\n",
1296 | "Validation loss: 0.49301907420158386\n",
1297 | "Step count: 9200\n",
1298 | "Train loss: 0.6447827219963074\n",
1299 | "Validation loss: 0.491354763507843\n",
1300 | "Step count: 9300\n",
1301 | "Train loss: 0.4595358967781067\n",
1302 | "Validation loss: 0.48962050676345825\n",
1303 | "Step count: 9400\n",
1304 | "Train loss: 0.564563512802124\n",
1305 | "Validation loss: 0.4886232018470764\n",
1306 | "Step count: 9500\n",
1307 | "Train loss: 0.5287905335426331\n",
1308 | "Validation loss: 0.48654934763908386\n",
1309 | "Step count: 9600\n",
1310 | "Train loss: 0.41884922981262207\n",
1311 | "Validation loss: 0.4851319193840027\n",
1312 | "Step count: 9700\n",
1313 | "Train loss: 0.3431329131126404\n",
1314 | "Validation loss: 0.48383355140686035\n",
1315 | "Step count: 9800\n",
1316 | "Train loss: 0.47992512583732605\n",
1317 | "Validation loss: 0.4823845624923706\n",
1318 | "Step count: 9900\n",
1319 | "Train loss: 0.39113670587539673\n",
1320 | "Validation loss: 0.4812586009502411\n",
1321 | "Test Accuracy: 0.81358105\n"
1322 | ]
1323 | }
1324 | ],
1325 | "source": [
1326 | "# @title Confirming the Feature Saliency: Network Training\n",
1327 | "\n",
1328 | "\n",
1329 | "#@markdown Re-train the network using only the most salient features.\n",
1330 | "@jax.jit\n",
1331 | "def update_salient(params, opt_state, batch_X, batch_y):\n",
1332 | " grads = jax.grad(loss_fn_t.apply)(params, batch_X, batch_y)\n",
1333 | " upds, new_opt_state = opt_update_salient(grads, opt_state)\n",
1334 | " new_params = optax.apply_updates(params, upds)\n",
1335 | " return new_params, new_opt_state\n",
1336 | "\n",
1337 | "\n",
1338 | "trained_params_salient = train(salient_column_names, init_params_salient,\n",
1339 | " opt_state_salient, update_salient)\n",
1340 | "\n",
1341 | "#@markdown Print the test accuracy. This should be very similar to the test accuracy in\n",
1342 | "#@markdown the case that all columns / invariants are included, demonstrating that most\n",
1343 | "#@markdown of the predictive information about the signature is contained in the three\n",
1344 | "#@markdown selected invariants.\n",
1345 | "test_X, test_y = get_batch(normed_test_df, salient_column_names)\n",
1346 | "\n",
1347 | "#@markdown The final accuracy below should be in the low 80% range, probably between 0.80 and 0.85.\n",
1348 | "print(\"Test Accuracy: \",\n",
1349 | " np.mean((predict(trained_params_salient, test_X) - test_y) == 0))"
1350 | ]
1351 | },
1352 | {
1353 | "cell_type": "markdown",
1354 | "metadata": {},
1355 | "source": [
1356 | "### Slope vs. Signature: Proposed Linear Relationship"
1357 | ]
1358 | },
1359 | {
1360 | "cell_type": "code",
1361 | "execution_count": 15,
1362 | "metadata": {
1363 | "cellView": "form",
1364 | "id": "ZIqvjO_O8Zti"
1365 | },
1366 | "outputs": [],
1367 | "source": [
1368 | "# @title Slope vs. Signature: Proposed Linear Relationship\n",
1369 | "\n",
1370 | "#@markdown The quantity we proposed was the \"natural slope\", given by\n",
1371 | "#@markdown real(longitudinal_translation / meridinal_translation). We show that\n",
1372 | "#@markdown this is approximately twice the signature (up to a correction term based on\n",
1373 | "#@markdown other hyperbolic invariants) which we can check by comparing the predictions\n",
1374 | "#@markdown made by this rule to those made by the previously trained models.\n",
1375 | "\n",
1376 | "\n",
1377 | "def predict_signature_from_slope(data_X, min_signature, max_signature):\n",
1378 | " meridinal_translation = (\n",
1379 | " data_X['meridinal_translation_real'] +\n",
1380 | " 1j * data_X['meridinal_translation_imag'])\n",
1381 | " slope = data_X['longitudinal_translation'] / meridinal_translation\n",
1382 | " return slope.real / 2"
1383 | ]
1384 | },
1385 | {
1386 | "cell_type": "markdown",
1387 | "metadata": {},
1388 | "source": [
1389 | "### Proposed Linear Relationship: Scatter plot"
1390 | ]
1391 | },
1392 | {
1393 | "cell_type": "code",
1394 | "execution_count": 16,
1395 | "metadata": {
1396 | "cellView": "form",
1397 | "id": "KAwLtNFxTwCc"
1398 | },
1399 | "outputs": [
1400 | {
1401 | "data": {
1402 | "text/plain": [
1403 | "Text(0, 0.5, 'Signature')"
1404 | ]
1405 | },
1406 | "execution_count": 16,
1407 | "metadata": {},
1408 | "output_type": "execute_result"
1409 | },
1410 | {
1411 | "data": {
1412 | "image/png": "",
1413 | "text/plain": [
1414 | ""
1415 | ]
1416 | },
1417 | "metadata": {},
1418 | "output_type": "display_data"
1419 | }
1420 | ],
1421 | "source": [
1422 | "# @title Proposed Linear Relationship: Scatter plot\n",
1423 | "\n",
1424 | "#@markdown Scatter plot of the signature predicted from the slope against the true signature.\n",
1425 | "predictions = [\n",
1426 | " predict_signature_from_slope(x, min_signature, max_signature)\n",
1427 | " for _, x in test_df.iterrows()\n",
1428 | "]\n",
1429 | "\n",
1430 | "fig, ax = plt.subplots(figsize=(8, 8))\n",
1431 | "sns.scatterplot(\n",
1432 | " x=predictions, y=test_df[target], alpha=0.2)\n",
1433 | "ax.set_xlabel('Predicted Signature')\n",
1434 | "ax.set_ylabel('Signature')"
1435 | ]
1436 | },
1437 | {
1438 | "cell_type": "markdown",
1439 | "metadata": {},
1440 | "source": [
1441 | "### Proposed Linear Relationship: Test Accuracy\n"
1442 | ]
1443 | },
1444 | {
1445 | "cell_type": "code",
1446 | "execution_count": 17,
1447 | "metadata": {
1448 | "cellView": "form",
1449 | "id": "HkoLDdYV_Idv"
1450 | },
1451 | "outputs": [],
1452 | "source": [
1453 | "# @title Proposed Linear Relationship: Test Accuracy\n",
1454 | "\n",
1455 | "\n",
1456 | "#@markdown In order to compute the \"test accuracy\" in the same way as before, we quantize\n",
1457 | "#@markdown the predicted signature values to even integers, between min_signature and\n",
1458 | "#@markdown max_signature.\n",
1459 | "def quantize(x, min_signature, max_signature):\n",
1460 | " return min(max(2 * round(x / 2), min_signature), max_signature)\n",
1461 | "\n",
1462 | "\n",
1463 | "quantized_predictions = [\n",
1464 | " quantize(x, min_signature, max_signature) for x in predictions\n",
1465 | "]"
1466 | ]
1467 | },
1468 | {
1469 | "cell_type": "markdown",
1470 | "metadata": {
1471 | "id": "ZhJAExt6739Y"
1472 | },
1473 | "source": [
1474 | "Now we can compute the \"test accuracy\" of this prediction.\n",
1475 | "\n",
1476 | "The value is slightly lower than for the trained network predictors (although still far higher than chance), but this is not unexpected.\n",
1477 | "\n",
1478 | "Indeed, the proposed rule gives a provable bound on the signature over all knots, instead of maximizing prediction performance over a given finite dataset, as the networks are doing. Although we do use separate training, validation and test datasets for the networks, these are all drawn from approximately the same distribution, whereas the proposed rule could be considered (in some imprecise sense) to have been \"trained\" on the set of all knots, a very different distribution.\n",
1479 | "\n",
1480 | "The correction term may also have some bias, for example tending to be positive more often than it is negative, information which the network predictors would be able to use to increase their prediction performance relative to that of the proposed rule."
1481 | ]
1482 | },
1483 | {
1484 | "cell_type": "markdown",
1485 | "metadata": {},
1486 | "source": [
1487 | "### Network output here is fixed (unmodified)\n",
1488 | "We did not modify this part, but with a random seed of 2 the result is always `0.738192917391447`."
1489 | ]
1490 | },
1491 | {
1492 | "cell_type": "code",
1493 | "execution_count": 18,
1494 | "metadata": {
1495 | "id": "2m-RE-PQ74GC"
1496 | },
1497 | "outputs": [
1498 | {
1499 | "name": "stdout",
1500 | "output_type": "stream",
1501 | "text": [
1502 | "Test Accuracy: 0.738192917391447\n"
1503 | ]
1504 | }
1505 | ],
1506 | "source": [
1507 | "#@markdown The below accuracy will probably be lower than the previous ~80%, but not by much, likely still >70%\n",
1508 | "print(\"Test Accuracy: \", np.mean(test_df[target] - quantized_predictions == 0))"
1509 | ]
1510 | }
1511 | ],
1512 | "metadata": {
1513 | "colab": {
1514 | "collapsed_sections": [],
1515 | "name": "Knot_theory",
1516 | "private_outputs": true,
1517 | "provenance": [
1518 | {
1519 | "file_id": "1tKuhjAHpuTUr5fs683eWFNPAuhmLt0uP",
1520 | "timestamp": 1632418370828
1521 | }
1522 | ]
1523 | },
1524 | "kernelspec": {
1525 | "display_name": "Python 3 (ipykernel)",
1526 | "language": "python",
1527 | "name": "python3"
1528 | },
1529 | "language_info": {
1530 | "codemirror_mode": {
1531 | "name": "ipython",
1532 | "version": 3
1533 | },
1534 | "file_extension": ".py",
1535 | "mimetype": "text/x-python",
1536 | "name": "python",
1537 | "nbconvert_exporter": "python",
1538 | "pygments_lexer": "ipython3",
1539 | "version": "3.11.7"
1540 | }
1541 | },
1542 | "nbformat": 4,
1543 | "nbformat_minor": 4
1544 | }
1545 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax
2 | dm-haiku
3 | optax
4 | matplotlib
5 | numpy
6 | pandas
7 | seaborn
8 | scikit-learn
9 |
--------------------------------------------------------------------------------
/test_ckpt.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 |
3 | import haiku as hk
4 | import jax
5 | import jax.numpy as jnp
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import optax
9 | import seaborn as sns
10 | from sklearn.model_selection import train_test_split
11 |
12 | # These are default values provided by the original colab
13 | import random
14 | # random_seed = random.randint(0,10000)
15 | # print(f"{random_seed=}")
16 | random_seed = 2 # @param {type: "integer"}
17 | batch_size = 64 # @param {type: "integer"}
18 | learning_rate = 0.001 # @param {type: "float"}
19 | num_training_steps = 50_000 # @param {type: "integer"}
20 | validation_interval = 100 # @param {type: "integer"}
21 |
22 | prngkey = 1 # @param {type: "integer"}
23 |
24 | # @title Download data
25 | import pandas as pd
26 | import os
27 | #_, input_filename = tempfile.mkstemp()
28 | input_filename = "./knot_theory_invariants.csv"
29 | # USE YOUR OWN gsutil cp COMMAND HERE
30 | if not os.path.exists(input_filename):
31 | os.system(f"~/software/google-cloud-sdk/bin/gsutil cp gs://maths_conjectures/knot_theory/knot_theory_invariants.csv {input_filename}")
32 |
33 | full_df = pd.read_csv(input_filename)
34 |
35 | # @title Load and preprocess data
36 |
37 | #@markdown The columns of the dataset which will make up the inputs to the network.
38 | #@markdown In other words, for a knot k, X(k) will be a vector consisting of these quantities. In this case, these are the geometric invariants of each knot.
39 | #@markdown For descriptions of these invariants see https://knotinfo.math.indiana.edu/
40 | display_name_from_short_name = {
41 | 'chern_simons': 'Chern-Simons',
42 | 'cusp_volume': 'Cusp volume',
43 | 'hyperbolic_adjoint_torsion_degree': 'Adjoint Torsion Degree',
44 | 'hyperbolic_torsion_degree': 'Torsion Degree',
45 | 'injectivity_radius': 'Injectivity radius',
46 | 'longitudinal_translation': 'Longitudinal translation',
47 | 'meridinal_translation_imag': 'Im(Meridional translation)',
48 | 'meridinal_translation_real': 'Re(Meridional translation)',
49 | 'short_geodesic_imag_part': 'Im(Short geodesic)',
50 | 'short_geodesic_real_part': 'Re(Short geodesic)',
51 | 'Symmetry_0': 'Symmetry: $0$',
52 | 'Symmetry_D3': 'Symmetry: $D_3$',
53 | 'Symmetry_D4': 'Symmetry: $D_4$',
54 | 'Symmetry_D6': 'Symmetry: $D_6$',
55 | 'Symmetry_D8': 'Symmetry: $D_8$',
56 | 'Symmetry_Z/2 + Z/2': 'Symmetry: $\\frac{Z}{2} + \\frac{Z}{2}$',
57 | 'volume': 'Volume',
58 | }
59 | column_names = list(display_name_from_short_name)
60 | target = 'signature'
61 |
62 | #@markdown Split the data into a training, a validation and a holdout test set. To check
63 | #@markdown the robustness of the model and any proposed relationship, the training
64 | #@markdown process can be repeated with multiple different train/validation/test splits.
65 |
66 | #@markdown Calculate the mean and standard deviation over each column in the training
67 | #@markdown dataset. We use this to normalize each feature. This is best practice for
68 | #@markdown inputting features into a network, but is also very important in this case
69 | #@markdown to ensure the gradients used for saliency are meaningfully comparable.
70 | def normalize_features(df, cols, add_target=True):
71 | features = df[cols]
72 | sigma = features.std()
73 | if any(sigma == 0):
74 | print(sigma)
75 | raise RuntimeError(
76 | "A poor data stratification has led to no variation in one of the data "
77 | "splits for at least one feature (ie std=0). Restratify and try again.")
78 | mu = features.mean()
79 | normed_df = (features - mu) / sigma
80 | if add_target:
81 | normed_df[target] = df[target]
82 | return normed_df
83 |
84 |
85 | def get_batch(df, cols, size=None):
86 | batch_df = df if size is None else df.sample(size)
87 | X = batch_df[cols].to_numpy()
88 | y = batch_df[target].to_numpy()
89 | return X, y
90 |
91 |
92 |
93 | random_state = np.random.RandomState(random_seed)
94 | train_df, validation_and_test_df = train_test_split(
95 | full_df, random_state=random_state)
96 | validation_df, test_df = train_test_split(
97 | validation_and_test_df, test_size=.5, random_state=random_state)
98 |
99 |
100 | # normed_train_df = normalize_features(train_df, column_names)
101 | # Sometimes this line will cause a RuntimeError when running smaller networks. Just restart the kernel and try again.
102 | # normed_validation_df = normalize_features(validation_df, column_names)
103 | normed_test_df = normalize_features(test_df, column_names)
104 |
105 | test_X, test_y = get_batch(normed_test_df, column_names)
106 |
107 |
108 |
109 | # Find bounds for the signature in the training dataset.
110 | max_signature = train_df[target].max()
111 | min_signature = train_df[target].min()
112 |
113 | import argparse
114 | parser = argparse.ArgumentParser()
115 | parser.add_argument("-hn", type=int, default=3, choices=[2, 3], help="number of hidden neurons; checkpoints exist only for 2 and 3")
116 | args = parser.parse_args()
117 | hid_dim = args.hn
118 |
119 | def compacted_net_forward(inp):
120 | return hk.Sequential([
121 | hk.Linear(hid_dim),
122 | jax.nn.sigmoid,
123 | hk.Linear(int((max_signature - min_signature) / 2)+1),
124 | ])(inp)
125 | compacted_net_forward_t = hk.without_apply_rng(hk.transform(compacted_net_forward))
126 |
127 | @jax.jit
128 | def compacted_predict(params, data_X):
129 | return (jnp.argmax(compacted_net_forward_t.apply(params, data_X), axis=1) * 2 +
130 | min_signature)
131 |
132 | if hid_dim == 2:
133 | num_ckpt = 78
134 | elif hid_dim == 3:
135 | num_ckpt = 110
136 |
137 | import pickle
138 | with open(f'ckpts/{num_ckpt}.pickle', 'rb') as f:
139 | hk_params = pickle.load(f)
140 | print(np.mean((compacted_predict(hk_params, test_X) - test_y) == 0))
--------------------------------------------------------------------------------