├── .gitignore ├── LICENSE ├── README.md ├── bin └── reformat_STEAD.py ├── data └── ridgecrest.hdf5 ├── fastmap ├── __init__.py ├── _version.py ├── core.py └── test.py ├── resources ├── confusion_matrix.png ├── readme_figure.png ├── supervised_clustering.png └── unsupervised_clustering.png ├── setup.py └── tutorial ├── fastmapsvm.joblib └── tutorial.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | *.egg-info
3 | *.ipynb_checkpoints
4 | .spyproject
5 | __pycache__
6 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Malcolm White
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FastMapSVM: An Algorithm for Classifying Complex Objects
2 | This is the official repository for the FastMapSVM algorithm associated with **"Classifying seismograms using the FastMap algorithm and support-vector machines"** (White et al., 2023). If you make use of this code, please cite White et al. (2023) where appropriate (full reference at bottom of page).
3 | 
4 | ![Perspicuous Visualization](resources/readme_figure.png)
5 | 
6 | # Installation
7 | ```bash
8 | >$ pip install .
9 | ```
10 | 
11 | # Tutorial
12 | This tutorial is available as an IPython notebook in `tutorial/tutorial.ipynb`.
13 | ```python
14 | import fastmap
15 | import h5py
16 | import matplotlib.pyplot as plt
17 | import numpy as np
18 | import sklearn.metrics
19 | import sklearn.pipeline
20 | import sklearn.preprocessing
21 | import sklearn.svm
22 | ```
23 | ## 1. Introduction: $X \rightarrow$ FastMap $\rightarrow$ SVM $\rightarrow \widehat{y}$
24 | This tutorial demonstrates how to train and deploy the FastMapSVM classification model. As the name implies, FastMapSVM comprises two critical components: (1) the FastMap algorithm, and (2) an SVM classifier. The FastMap algorithm is implemented as a data transformation compatible with the `sklearn.pipeline.Pipeline` API. This allows the FastMapSVM model to be implemented as a simple `Pipeline` with optional intermediate transformations, such as data re-scaling.
25 | 
26 | ## 2. Setup
27 | ### 2.1 Define the distance function
28 | To deploy the FastMapSVM algorithm, the user must define the distance function that quantifies the dissimilarity between any pair of objects in the train/test data. The distance function must adhere to NumPy's [broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html): Given input arrays `a` and `b` with shapes $(M, 1, ...)$ and $(1, N, ...)$, respectively, it should return the $M \times N$ distance matrix.
29 | ```python
30 | def correlation_distance(a, b, axis=-1):
31 |     '''
32 |     Compute the pair-wise correlation distance matrix.
33 |     '''
34 |     xcorr = correlate(a, b, axis=axis)
35 |     xcorr = np.abs(xcorr)
36 |     xcorr = np.nanmean(xcorr, axis=-2)
37 |     xcorr = np.max(xcorr, axis=-1)
38 |     xcorr = np.clip(xcorr, 0, 1)
39 | 
40 |     return 1 - xcorr
41 | 
42 | 
43 | def correlate(a, b, axis=-1):
44 |     A = np.fft.rfft(a, axis=axis)
45 |     B = np.fft.rfft(b, axis=axis)
46 |     X = A * np.conj(B)
47 |     x = np.fft.irfft(X)
48 |     x = np.fft.fftshift(x, axes=axis)
49 |     norm = np.sqrt(
50 |         a.shape[-1] * np.var(a, axis=axis)
51 |         *
52 |         b.shape[-1] * np.var(b, axis=axis)
53 |     )
54 |     norm = norm[..., np.newaxis]
55 | 
56 |     return np.nan_to_num(x / norm, neginf=0, posinf=0)
57 | ```
58 | 
59 | **Note**: If your distance function cannot be easily vectorized, the code below implements a generic loop that applies the necessary broadcasting rules and calls the distance function in singleton fashion (i.e., on individual pairs of objects).
60 | 
61 | ```python
62 | def generic_distance(a, b, axis=-1):
63 |     # Build the output array with broadcasting rules applied.
64 |     shape = np.broadcast_shapes(a.shape, b.shape)
65 |     axis = axis if axis > -1 else len(shape) + axis
66 |     shape = shape[:axis] + shape[axis+1:]
67 |     output = np.empty(shape)
68 |     n_dim = output.ndim
69 | 
70 |     # Loop over elements and compute distances serially.
71 |     for ijk in np.ndindex(*output.shape):
72 |         ijk_a = tuple([ijk[i] if a.shape[i] != 1 else 0 for i in range(len(ijk))])
73 |         ijk_b = tuple([ijk[i] if b.shape[i] != 1 else 0 for i in range(len(ijk))])
74 |         output[ijk] = dist(a[ijk_a], b[ijk_b])
75 | 
76 |     return output
77 | 
78 | def dist(a, b, axis=-1):
79 |     '''
80 |     Return the distance between objects a and b.
81 |     '''
82 |     return np.linalg.norm(a - b, axis=axis)
83 | ```
84 | 
85 | ### 2.2 Implement concrete FastMap class.
86 | The `fastmap` module provides an abstract base class `fastmap.FastMapABC` that is not intended to be used directly. The user should define a child class that adds a `_distance_func` attribute to the abstract base class. Implementing the model this way supports model persistence.
87 | ```python
88 | class FastMap(fastmap.FastMapABC):
89 |     def __init__(self, *args, **kwargs):
90 |         super().__init__(*args, **kwargs)
91 |         self._distance_func = correlation_distance
92 | ```
93 | ## 3. Model training
94 | ### 3.1 Load the train and test data.
95 | ```python
96 | with h5py.File('data/ridgecrest.hdf5', mode='r') as in_file:
97 |     X_train = in_file['/X/train'][:]
98 |     y_train = in_file['/y/train'][:]
99 | 
100 |     X_test = in_file['/X/test'][:]
101 |     y_test = in_file['/y/test'][:]
102 | ```
103 | 
104 | ### 3.2 Build a `sklearn.pipeline.Pipeline`
105 | The FastMapSVM model benefits from rescaling the extracted features before SVM classification.
106 | ```python
107 | n_dim = 2 # The number of dimensions for the Euclidean embedding.
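# Note: each FastMap dimension is extracted using one pair of "pivot" objects,
# so n_dim also sets how many pivot pairs are selected (and how many extra
# distance evaluations are performed) during fit(). Two dimensions are used
# here because they are easy to visualize; a larger n_dim may help on harder
# data sets at proportionally higher cost.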
108 | fastmapsvm = sklearn.pipeline.Pipeline([ 109 | ('fastmap', FastMap(n_dim)), 110 | ('scaler', sklearn.preprocessing.StandardScaler()), 111 | ('svc', sklearn.svm.SVC()) 112 | ]) 113 | ``` 114 | 115 | ### 3.3 Train and score the model using the train data 116 | ```python 117 | fastmapsvm.fit(X_train, y_train); 118 | fastmapsvm.score(X_train, y_train) 119 | ``` 120 | 121 | ### 3.4 View the embedding of the train data 122 | Sub-components of the Pipeline can be extracted via indexing. 123 | ```python 124 | # Compute the embedding of the train data 125 | W = fastmapsvm[0].transform(X_train) 126 | 127 | plt.close('all') 128 | fig, ax = plt.subplots() 129 | for label in range(2): 130 | idxs = np.argwhere(y_train == label).flatten() 131 | ax.scatter(W[idxs, 0], W[idxs, 1]) 132 | ax.set_aspect(1) 133 | ax.set_xlabel('$w_0$') 134 | ax.set_ylabel('$w_1$') 135 | ``` 136 | ![Supervised clustering](resources/supervised_clustering.png) 137 | 138 | ## 4. Model testing 139 | ### 4.1 Score the model on the test data. 140 | ```python 141 | # For bigger data sets, it is helpful to have a progress bar 142 | fastmapsvm['fastmap'].show_progress = True 143 | 144 | fastmapsvm.score(X_test, y_test) 145 | ``` 146 | 147 | ### 4.2 Plot the confusion matrix for the test data 148 | ```python 149 | y_hat = fastmapsvm.predict(X_test) 150 | sklearn.metrics.ConfusionMatrixDisplay.from_predictions( 151 | y_test, 152 | y_hat 153 | ) 154 | ``` 155 | ![Confusion matrix](resources/confusion_matrix.png) 156 | 157 | ## 5. Model persistence 158 | ### 5.1 Store the trained model on disk 159 | ```python 160 | import joblib 161 | joblib.dump(fastmapsvm, 'fastmapsvm.joblib') 162 | 163 | del(fastmapsvm) 164 | ``` 165 | ### 5.2 Load a pre-trained model from disk 166 | **Note:** The distance function must be defined within the scope in which you load the model. So, if you train the model in one script, and then try to reload it in another script in which the distance function is not defined, it will not work. For instance, if you restart your Python kernel and immediately run the cell below, before running anything else, you will get `AttributeError: Can't get attribute 'FastMap' on `. There is, to my knowledge, no clean way of storing user-defined code and its dependencies alongside the model. The abstract base class paradigm is used to force users to write code in a way that will make it easier for them to reload the model later. 167 | 168 | If you restart your Python kernel, you need to run the code in sections 2.1 and 2.2 (along with the necessary imports) before running the code in this section. 169 | ```python 170 | import joblib 171 | fastmapsvm = joblib.load('fastmapsvm.joblib') 172 | ``` 173 | ## 6. Unsupervised clustering 174 | FastMap was originally designed for unsupervised cluster analysis, and can be trained in unsupervised mode by omitting the labels on training. 175 | ```python 176 | fm = FastMap(2) 177 | fm.fit(X_train) 178 | W = fm.transform(X_train) 179 | 180 | plt.close('all') 181 | fig, ax = plt.subplots() 182 | ax.scatter(W[:, 0], W[:, 1]) 183 | ``` 184 | ![Unsupervised clustering](resources/unsupervised_clustering.png) 185 | 186 | # References 187 | 188 | 1. Faloutsos, C. & Lin, K.-I. FastMap: A Fast Algorithm for Indexing, Data-Mining and Visualization of Traditional and Multimedia Datasets. _Sigmod Rec_ **24**, 163–174 (1995). 189 | 2. White, M. C. A., Sharma, K., Li, A., Kumar, T. K. S. & Nakata, N. Classifying seismograms using the FastMap algorithm and support-vector machines. _Commun. 
Eng._ **2**, 46 (2023).
190 | 
--------------------------------------------------------------------------------
/bin/reformat_STEAD.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import pandas as pd
4 | import scipy.signal
5 | import tqdm
6 | 
7 | nnoise_train = 8192
8 | neq_train = 8192
9 | sos = scipy.signal.butter(2, [1, 20], btype="bandpass", output="sos", fs=100)
10 | # Combine the metadata tables from all six STEAD chunks.
11 | dataf0 = pd.concat(
12 |     [pd.read_csv(f"/home/malcolmw/proj/fastmapsvm/data/stead/chunk{i}.csv") for i in range(1, 7)],
13 |     ignore_index=True
14 | )
15 | 
16 | dataf_noise = dataf0[dataf0["chunk"] == 1]
17 | dataf_noise_train = dataf_noise.sample(n=nnoise_train)
18 | dataf_noise_test = dataf_noise[~dataf_noise.index.isin(dataf_noise_train.index)]
19 | dataf_noise_train = dataf_noise_train.reset_index(drop=True)
20 | dataf_noise_test = dataf_noise_test.reset_index(drop=True)
21 | 
22 | dataf_eq = dataf0[dataf0["chunk"] > 1]
23 | dataf_eq_train = dataf_eq.sample(n=neq_train)
24 | dataf_eq_test = dataf_eq[~dataf_eq.index.isin(dataf_eq_train.index)]
25 | dataf_eq_train = dataf_eq_train.reset_index(drop=True)
26 | dataf_eq_test = dataf_eq_test.reset_index(drop=True)
27 | 
28 | n_train = len(dataf_noise_train) + len(dataf_eq_train)
29 | n_test = len(dataf_noise_test) + len(dataf_eq_test)
30 | # Band-pass filter, normalize, and store each trace with its class label.
31 | with h5py.File("/home/malcolmw/proj/fastmapsvm/data/stead/train.hdf5", mode="w") as f5out:
32 |     X = f5out.create_dataset("X", shape=(n_train, 3, 6000), dtype=np.float32)
33 |     y = f5out.create_dataset("y", shape=(n_train,), dtype=np.uint8)
34 |     i = 0
35 |     for label, dataf in enumerate((dataf_noise_train, dataf_eq_train)):
36 |         for j, row in tqdm.tqdm(dataf.iterrows(), total=len(dataf)):
37 |             ichunk = row["chunk"]
38 |             handle = row["trace_name"]
39 |             with h5py.File(f"/home/malcolmw/proj/fastmapsvm/data/stead/chunk{ichunk}.hdf5", mode="r") as f5in:
40 |                 x = f5in[f"/data/{handle}"][:]
41 |             x = scipy.signal.sosfiltfilt(sos, x, axis=0)
42 |             try:
43 |                 x = (x - np.mean(x, axis=0)) / np.std(x, axis=0)
44 |             except Exception:  # Fall back to demeaning only if normalization fails.
45 |                 x = x - np.mean(x, axis=0)
46 |             X[i] = x.T
47 |             y[i] = label
48 |             i += 1
49 | 
50 | with h5py.File("/home/malcolmw/proj/fastmapsvm/data/stead/test.hdf5", mode="w") as f5out:
51 |     X = f5out.create_dataset("X", shape=(n_test, 3, 6000), dtype=np.float32)
52 |     y = f5out.create_dataset("y", shape=(n_test,), dtype=np.uint8)
53 |     i = 0
54 |     for label, dataf in enumerate((dataf_noise_test, dataf_eq_test)):
55 |         for j, row in tqdm.tqdm(dataf.iterrows(), total=len(dataf)):
56 |             ichunk = row["chunk"]
57 |             handle = row["trace_name"]
58 |             with h5py.File(f"/home/malcolmw/proj/fastmapsvm/data/stead/chunk{ichunk}.hdf5", mode="r") as f5in:
59 |                 x = f5in[f"/data/{handle}"][:]
60 |             x = scipy.signal.sosfiltfilt(sos, x, axis=0)
61 |             try:
62 |                 x = (x - np.mean(x, axis=0)) / np.std(x, axis=0)
63 |             except Exception:  # Fall back to demeaning only if normalization fails.
64 |                 x = x - np.mean(x, axis=0)
65 |             X[i] = x.T
66 |             y[i] = label
67 |             i += 1
68 | 
--------------------------------------------------------------------------------
/data/ridgecrest.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/malcolmw/FastMapSVM/90ad89db0c6154e23f3afbf46f12a84b9db42caa/data/ridgecrest.hdf5
--------------------------------------------------------------------------------
/fastmap/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import FastMapABC
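2 | # Concrete models subclass FastMapABC and set a `_distance_func` attribute;
3 | # see tutorial/tutorial.ipynb and fastmap/test.py for working examples.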
-------------------------------------------------------------------------------- /fastmap/_version.py: -------------------------------------------------------------------------------- 1 | # Don't forget to update version in pyproject.toml. 2 | __version__ = '0.1.0beta' 3 | -------------------------------------------------------------------------------- /fastmap/core.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Feb 16 14:24:10 2023 5 | 6 | @author: malcolmw 7 | """ 8 | 9 | print('Importing fastmap.core') 10 | 11 | import numpy as np 12 | import tqdm 13 | 14 | DEFAULT_BATCH_SIZE = 1024 15 | EPSILON = 1e-9 16 | 17 | class FastMapABC: 18 | 19 | def __init__( 20 | self, 21 | n_dim, 22 | show_progress=False, 23 | batch_size=DEFAULT_BATCH_SIZE, 24 | cupy=False 25 | ): 26 | ''' 27 | Implements the FastMap algorithm. 28 | 29 | Parameters 30 | ---------- 31 | n_dim : int 32 | The number of Euclidean dimensions. 33 | model_path : str, pathlib.Path 34 | Path to store model. 35 | show_progress: bool, optional 36 | Show TQDM progress bar. The default is False. 37 | cupy: bool, optional 38 | Use cupy backend. The default is False. 39 | 40 | Returns 41 | ------- 42 | None. 43 | 44 | ''' 45 | self._ihyprpln = 0 46 | self._n_dim = n_dim 47 | self._batch_size = batch_size 48 | self.show_progress = show_progress 49 | if cupy is False: 50 | self.numpy() 51 | else: 52 | self.cupy() 53 | 54 | 55 | @property 56 | def batch_size(self): 57 | ''' 58 | Returns 59 | ------- 60 | int 61 | Batch size. 62 | 63 | ''' 64 | return self._batch_size 65 | 66 | @batch_size.setter 67 | def batch_size(self, value): 68 | if not isinstance(value, int): 69 | raise TypeError('batch_size must be an int.') 70 | self._batch_size = value 71 | 72 | 73 | @property 74 | def n_dim(self): 75 | ''' 76 | Returns 77 | ------- 78 | int 79 | Dimensionality of embedding. 80 | 81 | ''' 82 | return self._n_dim 83 | 84 | 85 | @property 86 | def n_obj(self): 87 | ''' 88 | Returns 89 | ------- 90 | int 91 | The number of objects in the train set. 92 | 93 | ''' 94 | return len(self.X) 95 | 96 | @property 97 | def pivot_ids(self): 98 | ''' 99 | Returns 100 | ------- 101 | h5py.DataSet 102 | Indices of pivot objects. 103 | 104 | ''' 105 | if not hasattr(self, '_pivot_ids'): 106 | self._pivot_ids = np.full( 107 | (self.n_dim, 2), 108 | np.nan, 109 | dtype=np.uint16, 110 | ) 111 | return self._pivot_ids 112 | 113 | 114 | @property 115 | def show_progress(self): 116 | return self._show_progress 117 | 118 | @show_progress.setter 119 | def show_progress(self, value): 120 | if value not in (True, False): 121 | raise(ValueError('show_progress must be either True or False.')) 122 | self._show_progress = value 123 | 124 | @property 125 | def supervised(self): 126 | ''' 127 | Returns 128 | ------- 129 | bool 130 | Whether the embedding is supervised. 131 | 132 | ''' 133 | return self._supervised 134 | 135 | 136 | @property 137 | def W_piv(self): 138 | if not hasattr(self, '_W_piv'): 139 | self._W_piv = np.full( 140 | (self.n_dim, 2, self.n_dim), 141 | np.nan, 142 | dtype=np.float32 143 | ) 144 | return self._W_piv 145 | 146 | 147 | @property 148 | def X(self): 149 | ''' 150 | Returns 151 | ------- 152 | numpy.array or cupy.array 153 | Embedded objects in original data domain. 
154 | 155 | ''' 156 | return self._X 157 | 158 | @X.setter 159 | def X(self, value): 160 | self._X = value 161 | 162 | 163 | @property 164 | def X_piv(self): 165 | if not hasattr(self, '_X_piv'): 166 | self._X_piv = np.full( 167 | (self.n_dim, 2, *self.X.shape[1:]), 168 | np.nan, 169 | dtype=self.X.dtype 170 | ) 171 | return self._X_piv 172 | 173 | 174 | @property 175 | def y(self): 176 | ''' 177 | Returns 178 | ------- 179 | numpy.array 180 | Class labels of training data if run in supervised mode. 181 | 182 | ''' 183 | return self._y 184 | 185 | @y.setter 186 | def y(self, value): 187 | if value is not None: 188 | self._y = np.array(value) 189 | self._supervised = True 190 | else: 191 | self._supervised = False 192 | 193 | 194 | def _choose_pivots(self, n_proc=None): 195 | ''' 196 | A heuristic algorithm to choose distant pivot objects adapted 197 | from Faloutsos and Lin (1995). 198 | 199 | Parameters 200 | ---------- 201 | n_proc : int, optional 202 | The number of processors to use if running in multiprocessing mode. 203 | The default is None. 204 | 205 | Returns 206 | ------- 207 | i_obj : int 208 | The index of pivot object #1. 209 | j_obj : int 210 | The index of pivot object #2. 211 | 212 | ''' 213 | 214 | forbidden = self.pivot_ids[:self._ihyprpln].flatten() 215 | 216 | while True: 217 | if self.supervised is True: 218 | idxs = np.argwhere(self.y == 1).flatten() 219 | else: 220 | idxs = np.arange(self.n_obj) 221 | j_obj = np.random.choice(idxs) 222 | if j_obj not in forbidden: 223 | break 224 | 225 | furthest = self.furthest( 226 | j_obj, 227 | label=0 if self.supervised else None, 228 | n_proc=n_proc 229 | ) 230 | for i_obj in furthest: 231 | if i_obj not in forbidden: 232 | break 233 | 234 | furthest = self.furthest( 235 | i_obj, 236 | label=1 if self.supervised else None, 237 | n_proc=n_proc 238 | ) 239 | for j_obj in furthest: 240 | if j_obj not in forbidden: 241 | break 242 | 243 | return i_obj, j_obj 244 | 245 | 246 | 247 | def numpy(self): 248 | self.xp = np 249 | self.get_array = lambda array: array 250 | 251 | def cupy(self): 252 | import cupy as xp 253 | self.xp = xp 254 | self.get_array = lambda array: array.get() if isinstance(array, xp.ndarray) else array 255 | 256 | def distance_matrix( 257 | self, 258 | i_objs, 259 | j_objs, 260 | X_1=None, 261 | X_2=None, 262 | W_1=None, 263 | W_2=None 264 | ): 265 | # """ 266 | # Return the distance between objects at indices i_objs and kernel object at 267 | # index ikernel on the ihyprpln^th hyperplane. 268 | 269 | # Arguments: 270 | # - iobj: int 271 | # Index of first object to consider. 272 | # - jobj: int 273 | # Index of second object to consider. 274 | 275 | # Keyword arguments: 276 | # - ihyprpln: int=0 277 | # Index of hyperplane on which to compute distance. 
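        #
        # Note: the raw distances returned by self._distance_func are projected
        # onto the current hyperplane using the FastMap recurrence,
        #     d_{i+1}(a, b)^2 = d_i(a, b)^2 - (w_i(a) - w_i(b))^2,
        # which is what the loop over self._ihyprpln below implements; the
        # result is clipped at zero to guard against small negative values
        # arising from round-off.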
278 | # """ 279 | 280 | if X_1 is None: 281 | X_1 = self.X 282 | if X_2 is None: 283 | X_2 = self.X 284 | if W_1 is None: 285 | W_1 = self.W 286 | if W_2 is None: 287 | W_2 = self.W 288 | 289 | X_j = self.xp.array(X_2[j_objs]) 290 | dW = self.xp.square(self.xp.array(W_1[i_objs]) - self.xp.array(W_2[j_objs])) 291 | 292 | dist = [ 293 | self._distance_func( 294 | self.xp.array(X_1[i_objs[i: i+self.batch_size]]), 295 | self.xp.array(X_j) 296 | ) 297 | for i in range(0, len(i_objs), self.batch_size) 298 | ] 299 | dist = self.xp.concatenate(dist) if len(dist) > 1 else dist[0] 300 | 301 | for i in range(self._ihyprpln): 302 | dist = self.xp.sqrt(self.xp.clip(dist**2 - dW[:, i], 0, self.xp.inf)) 303 | 304 | return dist 305 | 306 | 307 | def furthest(self, i_obj, label=None, n_proc=None): 308 | """ 309 | Return the index of the object furthest from object with index 310 | *i_obj*. 311 | """ 312 | 313 | if label is None: 314 | idxs = np.arange(self.n_obj) 315 | else: 316 | idxs = np.argwhere(self.y == label).flatten() 317 | 318 | dW = self.xp.square(self.xp.array(self.W[idxs] - self.W[[i_obj]])) 319 | dist = self._distance_func( 320 | self.xp.array(self.X[idxs]), 321 | self.xp.array(self.X[[i_obj]]) 322 | ) 323 | for i in range(self._ihyprpln): 324 | dist = self.xp.sqrt(self.xp.clip(dist**2 - dW[:, i], 0, self.xp.inf)) 325 | 326 | idxs = idxs[self.get_array(self.xp.argsort(dist))] 327 | return idxs[-1::-1] 328 | 329 | 330 | def fit( 331 | self, 332 | X, 333 | y=None, 334 | n_proc=None 335 | ): 336 | ''' 337 | Train the FastMap embedding using the input X, y data. 338 | 339 | Parameters 340 | ---------- 341 | X : numpy.array or cupy.array 342 | Objects to embed. These objects must be represented as an 343 | n-D array. 344 | y : array-like, optional 345 | Binary Class labels for supervised mode. The default is None. 346 | n_proc : int, optional 347 | Number of processes to use if running multiprocessing mode. 348 | The default is None. 349 | 350 | Returns 351 | ------- 352 | None. 353 | 354 | ''' 355 | self.X = X 356 | self.y = y 357 | 358 | self.W = np.full( 359 | (self.n_obj, self.n_dim), 360 | np.nan, 361 | dtype=np.float32 362 | ) 363 | 364 | wrapper = tqdm.tqdm if self.show_progress is True else lambda x: x 365 | for self._ihyprpln in wrapper(range(self.n_dim)): 366 | i_piv, j_piv = self._choose_pivots(n_proc=n_proc) 367 | self.pivot_ids[self._ihyprpln] = [i_piv, j_piv] 368 | self.X_piv[self._ihyprpln, 0] = self.get_array(self.X[i_piv]) 369 | self.X_piv[self._ihyprpln, 1] = self.get_array(self.X[j_piv]) 370 | 371 | d_ij = self.distance_matrix([i_piv], [j_piv]) 372 | d = self.xp.square(self.distance_matrix(np.arange(self.n_obj), i_piv)) 373 | d -= self.xp.square(self.distance_matrix(np.arange(self.n_obj), j_piv)) 374 | # d = d.get() 375 | d += d_ij ** 2 376 | ####### Avoid divide by zero. 377 | d /= (2 * d_ij + EPSILON) 378 | #### Hack for negative distances. 379 | d = self.xp.clip(d, 0, self.xp.inf) 380 | #### 381 | self.W[:, self._ihyprpln] = self.get_array(d) 382 | 383 | for i_dim, (i_piv, j_piv) in enumerate(self.pivot_ids): 384 | self.W_piv[i_dim, 0] = self.W[i_piv] 385 | self.W_piv[i_dim, 1] = self.W[j_piv] 386 | 387 | del(self._pivot_ids, self._X, self.W) 388 | if hasattr(self, '_y'): 389 | del(self._y) 390 | 391 | self._ihyprpln = 0 392 | return self 393 | 394 | 395 | def transform(self, X): 396 | """ 397 | Return the embedding (images) of the given objects, `X`. 
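        `X` must hold objects with the same trailing shape as the training
        objects, one object per row along the first axis. Returns an array of
        shape (len(X), n_dim) containing the Euclidean image (coordinates) of
        each object; these are the features consumed by the downstream scaler
        and SVC in the pipeline.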
398 | """ 399 | 400 | n_obj = len(X) 401 | 402 | W = self.xp.zeros((n_obj, self.n_dim), dtype=(self.xp).float32) 403 | X_piv = self.xp.array(self.X_piv[:]) 404 | W_piv = self.xp.array(self.W_piv[:]) 405 | wrapper = tqdm.tqdm if self.show_progress is True else lambda x: x 406 | for i_batch, i_start in enumerate(wrapper(range( 407 | 0, 408 | n_obj, 409 | self.batch_size 410 | ))): 411 | X_batch = self.xp.array(X[i_start: i_start+self.batch_size]) 412 | W_batch = W[i_start: i_start+self.batch_size] 413 | d_ij0 = self._distance_func(X_piv[:, [0]], X_piv[:, [1]]) 414 | d_k0 = self._distance_func( 415 | X_batch[:, self.xp.newaxis, self.xp.newaxis], 416 | X_piv[self.xp.newaxis] 417 | ) 418 | for self._ihyprpln in range(self.n_dim): 419 | dW_ij = self.xp.square(W_piv[self._ihyprpln, [0]] - W_piv[self._ihyprpln, 1]) 420 | dW_ik = self.xp.square(W_batch - W_piv[self._ihyprpln, 0]) 421 | dW_jk = self.xp.square(W_batch - W_piv[self._ihyprpln, 1]) 422 | d_ij = d_ij0[self._ihyprpln].copy() 423 | d_ik = d_k0[:, self._ihyprpln, 0].copy() 424 | d_jk = d_k0[:, self._ihyprpln, 1].copy() 425 | for i in range(self._ihyprpln): 426 | d_ij = self.xp.sqrt(self.xp.clip(d_ij**2 - dW_ij[:, i], 0, self.xp.inf)) 427 | d_ik = self.xp.sqrt(self.xp.clip(d_ik**2 - dW_ik[:, i], 0, self.xp.inf)) 428 | d_jk = self.xp.sqrt(self.xp.clip(d_jk**2 - dW_jk[:, i], 0, self.xp.inf)) 429 | W_batch[:, self._ihyprpln] = self.xp.square(d_ik) 430 | W_batch[:, self._ihyprpln] += self.xp.square(d_ij) 431 | W_batch[:, self._ihyprpln] -= self.xp.square(d_jk) 432 | W_batch[:, self._ihyprpln] /= (d_ij * 2 + EPSILON) 433 | W[i_start: i_start+self.batch_size, self._ihyprpln] = W_batch[:, self._ihyprpln] 434 | 435 | return self.get_array(W) 436 | 437 | 438 | if __name__ == '__main__': 439 | pass 440 | -------------------------------------------------------------------------------- /fastmap/test.py: -------------------------------------------------------------------------------- 1 | import core 2 | 3 | import h5py 4 | import numpy as np 5 | 6 | get_array_module = lambda array: np 7 | 8 | def correlation_distance(a, b, axis=-1): 9 | ''' 10 | Compute the pair-wise correlation distance matrix. 
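    Inputs follow NumPy broadcasting: `a` and `b` typically have shapes
    (M, 1, n_channels, n_samples) and (1, N, n_channels, n_samples), and the
    return value is the (M, N) matrix whose entries are one minus the peak
    (over lag) of the channel-averaged absolute normalized cross-correlation.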
11 | ''' 12 | xp = get_array_module(a) 13 | return 1 - xp.clip( 14 | xp.max( 15 | xp.nanmean( 16 | xp.abs( 17 | correlate(a, b, axis=axis) 18 | ), 19 | axis=-2 20 | ), 21 | axis=-1 22 | ), 23 | 0, 1 24 | ) 25 | 26 | 27 | def correlate(a, b, axis=-1): 28 | xp = get_array_module(a) 29 | 30 | z = xp.fft.fftshift( 31 | xp.fft.irfft( 32 | xp.fft.rfft(a, axis=axis) 33 | * 34 | xp.conj( 35 | xp.fft.rfft(b, axis=axis) 36 | ) 37 | ), 38 | axes=axis 39 | ) 40 | norm = xp.sqrt( 41 | a.shape[-1] * xp.var(a, axis=axis) 42 | * 43 | b.shape[-1] * xp.var(b, axis=axis) 44 | ) 45 | norm = norm[..., xp.newaxis] 46 | 47 | return xp.nan_to_num(z / norm, neginf=0, posinf=0) 48 | 49 | 50 | class FastMap(core.FastMapABC): 51 | def __init__(self, *args, **kwargs): 52 | super().__init__(*args, **kwargs) 53 | self._distance_func = correlation_distance 54 | 55 | 56 | def test(): 57 | 58 | import sklearn.pipeline 59 | import sklearn.preprocessing 60 | import sklearn.svm 61 | 62 | data_path = '../data/ridgecrest.hdf5' 63 | with h5py.File(data_path, mode='r') as in_file: 64 | X_train = in_file['/X/train'][:] 65 | y_train = in_file['/y/train'][:] 66 | 67 | X_test = in_file['/X/test'][:] 68 | y_test = in_file['/y/test'][:] 69 | 70 | 71 | pipe = sklearn.pipeline.Pipeline([ 72 | ('fastmap', FastMap(2)), 73 | ('scaler', sklearn.preprocessing.StandardScaler()), 74 | ('svc', sklearn.svm.SVC()) 75 | ]) 76 | pipe.fit(X_train, y_train) 77 | print(pipe.score(X_test, y_test)) 78 | 79 | W = pipe['fastmap'].transform(X_train) 80 | W_test = pipe['fastmap'].transform(X_test) 81 | 82 | import matplotlib.pyplot as plt 83 | fig, ax = plt.subplots() 84 | for i in range(2): 85 | idxs = np.argwhere(y_train == i).flatten() 86 | ax.scatter( 87 | W[idxs, 0], 88 | W[idxs, 1] 89 | ) 90 | plt.show() 91 | 92 | 93 | fig, ax = plt.subplots() 94 | for i in range(2): 95 | idxs = np.argwhere(y_test == i).flatten() 96 | ax.scatter( 97 | W_test[idxs, 0], 98 | W_test[idxs, 1] 99 | ) 100 | plt.show() 101 | 102 | return pipe 103 | 104 | 105 | 106 | if __name__ == '__main__': 107 | pipe = test() 108 | -------------------------------------------------------------------------------- /resources/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/malcolmw/FastMapSVM/90ad89db0c6154e23f3afbf46f12a84b9db42caa/resources/confusion_matrix.png -------------------------------------------------------------------------------- /resources/readme_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/malcolmw/FastMapSVM/90ad89db0c6154e23f3afbf46f12a84b9db42caa/resources/readme_figure.png -------------------------------------------------------------------------------- /resources/supervised_clustering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/malcolmw/FastMapSVM/90ad89db0c6154e23f3afbf46f12a84b9db42caa/resources/supervised_clustering.png -------------------------------------------------------------------------------- /resources/unsupervised_clustering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/malcolmw/FastMapSVM/90ad89db0c6154e23f3afbf46f12a84b9db42caa/resources/unsupervised_clustering.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import 
re 2 | import setuptools 3 | 4 | version_file = 'fastmap/_version.py' 5 | version_line = open(version_file, 'r').read() 6 | version_re = r'^__version__ = ["\']([^"\']*)["\']' 7 | mo = re.search(version_re, version_line, re.M) 8 | if mo: 9 | version = mo.group(1) 10 | else: 11 | raise RuntimeError(f'Unable to find version string in {version_file}.') 12 | 13 | def configure(): 14 | # Initialize the setup kwargs 15 | kwargs = { 16 | 'name': 'FastMap', 17 | 'version': version, 18 | 'author': 'Malcolm C. A. White', 19 | 'author_email': 'malcolmw@mit.edu', 20 | 'maintainer': 'Malcolm C. A. White', 21 | 'maintainer_email': 'malcolmw@mit.edu', 22 | 'url': 'http://malcolmw.github.io/FastMapSVM', 23 | 'description': 'Official implementation of FastMapSVM algorithm ' 24 | 'for classifying complex objects (White et al., 2023).', 25 | 'download_url': 'https://github.com/malcolmw/FastMapSVM.git', 26 | 'platforms': ['linux'], 27 | 'install_requires': ['numpy', 'tqdm'], 28 | 'packages': ['fastmap'] 29 | } 30 | return kwargs 31 | 32 | if __name__ == '__main__': 33 | kwargs = configure() 34 | setuptools.setup(**kwargs) 35 | -------------------------------------------------------------------------------- /tutorial/fastmapsvm.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/malcolmw/FastMapSVM/90ad89db0c6154e23f3afbf46f12a84b9db42caa/tutorial/fastmapsvm.joblib -------------------------------------------------------------------------------- /tutorial/tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "81779b31-b85b-47c9-8384-d39b38896dbd", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib inline" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "585d6b98-c21f-48fd-83df-91d7e02dbc00", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import fastmap\n", 21 | "import h5py\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import numpy as np\n", 24 | "import sklearn.metrics\n", 25 | "import sklearn.pipeline\n", 26 | "import sklearn.preprocessing\n", 27 | "import sklearn.svm" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "e62ad6de-2bc3-4cf9-982c-ffaa238543f8", 33 | "metadata": {}, 34 | "source": [ 35 | "# 1. Introduction: $X \\rightarrow$ FastMap $\\rightarrow$ SVM $\\rightarrow \\widehat{y}$\n", 36 | "This tutorial demonstrates how to train and deploy the FastMapSVM classification model. As the name implies, FastMapSVM comprises two critical components: (1) the FastMap algorithm, and (2) an SVM classifier. The FastMap algorithm is implemented as a data transformation compatible with the `sklearn.pipeline.Pipeline` API. This allows the FastMapSVM model to be implemented as a simple `Pipeline` with optional intermediate transformations, such as data re-scaling." 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "a80d8831-6dfc-4e8a-93f1-b16b514feebb", 42 | "metadata": {}, 43 | "source": [ 44 | "# 2. Setup\n", 45 | "## 2.1 Define the distance function\n", 46 | "To deploy the FastMapSVM algorithm, the user must define the distance function that quantifies the dissimilarity between any pair of objects in the train/test data. 
The distance function must adhere to NumPy's [broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html): Given input arrays `a` and `b` with shapes $(M, 1, ...)$ and $(1, N, ...)$, respectively, it should return the $M \\times N$ distance matrix." 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "03885713-4034-4da1-8767-dd6c1db127f2", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "def correlation_distance(a, b, axis=-1):\n", 57 | " '''\n", 58 | " Compute the pair-wise correlation distance matrix.\n", 59 | " '''\n", 60 | " xcorr = correlate(a, b, axis=axis)\n", 61 | " xcorr = np.abs(xcorr)\n", 62 | " xcorr = np.nanmean(xcorr, axis=-2)\n", 63 | " xcorr = np.max(xcorr, axis=-1)\n", 64 | " xcorr = np.clip(xcorr, 0, 1)\n", 65 | " \n", 66 | " return 1 - xcorr\n", 67 | " \n", 68 | "\n", 69 | "def correlate(a, b, axis=-1):\n", 70 | " A = np.fft.rfft(a, axis=axis)\n", 71 | " B = np.fft.rfft(b, axis=axis)\n", 72 | " X = A * np.conj(B)\n", 73 | " x = np.fft.irfft(X)\n", 74 | " x = np.fft.fftshift(x, axes=axis)\n", 75 | " norm = np.sqrt(\n", 76 | " a.shape[-1] * np.var(a, axis=axis)\n", 77 | " *\n", 78 | " b.shape[-1] * np.var(b, axis=axis)\n", 79 | " )\n", 80 | " norm = norm[..., np.newaxis]\n", 81 | "\n", 82 | " return np.nan_to_num(x / norm, neginf=0, posinf=0)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "id": "a1f42654-2de9-47ac-b093-1f1e8f570a1b", 88 | "metadata": {}, 89 | "source": [ 90 | "**Note**: If your distance cannot be easily vectorized, the code below implements a generic loop that applies the necessary broadcasting rules and calls the distance function in singleton fashion (i.e., on individual pairs of objects)." 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "8a0d4f1c-b946-4642-920d-bdb714d0f684", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "def generic_distance(A, B, axis=-1):\n", 101 | " '''\n", 102 | " Return the (broadcasted) distance matrix between multidimensional\n", 103 | " arrays of objects A and B.\n", 104 | " '''\n", 105 | " # Build the output array with broadcasting rules applied.\n", 106 | " shape = np.broadcast_shapes(A.shape, B.shape)\n", 107 | " axis = axis if axis > -1 else len(shape) + axis\n", 108 | " shape = shape[:axis] + shape[axis+1:]\n", 109 | " output = np.empty(shape)\n", 110 | " n_dim = output.ndim\n", 111 | "\n", 112 | " # Loop over elements and compute distances.\n", 113 | " for ijk in np.ndindex(*output.shape):\n", 114 | " ijk_A = tuple([ijk[i] if A.shape[i] != 1 else 0 for i in range(len(ijk))])\n", 115 | " ijk_B = tuple([ijk[i] if B.shape[i] != 1 else 0 for i in range(len(ijk))])\n", 116 | " output[ijk] = _distance_singleton(A[ijk_A], B[ijk_B])\n", 117 | " \n", 118 | " return output\n", 119 | " \n", 120 | "def _distance_singleton(a, b, axis=-1):\n", 121 | " '''\n", 122 | " Return the distance between a pair of objects a and b.\n", 123 | " '''\n", 124 | " return np.linalg.norm(a - b, axis=axis)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "6ec9a53f-03d7-4541-a6be-63d8cc68a59f", 130 | "metadata": {}, 131 | "source": [ 132 | "## 2.2 Implement concrete FastMap class.\n", 133 | "The `fastmap` module provides an abstract base class `fastmap.FastMapABC` that is not intended to be used directly. The user should define a child class that adds a `_distance_function` attribute to the abstract base class. Implementing the model this way supports models persistence." 
134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "abf45e35-bc06-4c7e-bfe0-277c633f399c", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "class FastMap(fastmap.FastMapABC):\n", 144 | " def __init__(self, *args, **kwargs):\n", 145 | " super().__init__(*args, **kwargs)\n", 146 | " self._distance_func = correlation_distance" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "f4a9ef04-98a9-45b3-9bd8-8de42c8f362e", 152 | "metadata": {}, 153 | "source": [ 154 | "# 3. Model training\n", 155 | "## 3.1 Load the train and test data." 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "2e796a34-45ae-45aa-b3f6-c041acf8dbc8", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "with h5py.File('../data/ridgecrest.hdf5', mode='r') as in_file:\n", 166 | " X_train = in_file['/X/train'][:]\n", 167 | " y_train = in_file['/y/train'][:]\n", 168 | "\n", 169 | " X_test = in_file['/X/test'][:]\n", 170 | " y_test = in_file['/y/test'][:]" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "id": "ed814540-85eb-4d98-a16d-0771a97a4498", 176 | "metadata": {}, 177 | "source": [ 178 | "## 3.2 Build a `sklearn.pipeline.Pipeline`\n", 179 | "The FastMapSVM model benefits from rescaling the extracted features before SVM classification." 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "87b2702e-02de-491d-9e3f-7776de3c03ea", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "n_dim = 2 # The number of dimensions for the Euclidean embedding.\n", 190 | "fastmapsvm = sklearn.pipeline.Pipeline([\n", 191 | " ('fastmap', FastMap(n_dim)),\n", 192 | " ('scaler', sklearn.preprocessing.StandardScaler()),\n", 193 | " ('svc', sklearn.svm.SVC())\n", 194 | "])\n", 195 | "fastmapsvm" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "id": "136e2fb4-786c-44d9-b919-8389a66fe635", 201 | "metadata": {}, 202 | "source": [ 203 | "## 3.3 Train and score the model using the train data" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "ea3fcc98-3fb7-4cb6-8a8a-4cf9631f76c4", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "fastmapsvm.fit(X_train, y_train);\n", 214 | "fastmapsvm.score(X_train, y_train)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "19ffabdb-afcd-4cad-9883-adaeed37f970", 220 | "metadata": {}, 221 | "source": [ 222 | "## 3.4 View the embedding of the train data\n", 223 | "Sub-components of the Pipeline can be extracted via indexing." 
224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "0daeeca9-e323-4fc4-b3a0-f592bebef9e1", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "fastmapsvm[:2]" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "id": "dd33b06c-ff7a-47b6-b3d7-a432566d03f9", 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "# Compute the embedding of the train data\n", 244 | "W = fastmapsvm[0].transform(X_train)\n", 245 | "\n", 246 | "plt.close('all')\n", 247 | "fig, ax = plt.subplots()\n", 248 | "for label in range(2):\n", 249 | " idxs = np.argwhere(y_train == label).flatten()\n", 250 | " ax.scatter(W[idxs, 0], W[idxs, 1])\n", 251 | "ax.set_aspect(1)\n", 252 | "ax.set_xlabel('$w_0$')\n", 253 | "ax.set_ylabel('$w_1$')" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "id": "cf9e48db-6141-469e-8e41-aa94324d9d18", 259 | "metadata": {}, 260 | "source": [ 261 | "# 4. Model testing\n", 262 | "## 4.1 Score the model on the test data." 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "id": "5332934c-cf0c-4dee-b588-bae61f4b7b27", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "# For bigger data sets, it is helpful to have a progress bar\n", 273 | "fastmapsvm['fastmap'].show_progress = True\n", 274 | "\n", 275 | "fastmapsvm.score(X_test, y_test)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "id": "fa766269-fb09-408f-ab6b-86662347da16", 281 | "metadata": {}, 282 | "source": [ 283 | "## 4.2 Plot the confusion matrix for the test data" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "id": "affb0826-acb2-43c1-8963-006f7b56b10f", 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "y_hat = fastmapsvm.predict(X_test)\n", 294 | "sklearn.metrics.ConfusionMatrixDisplay.from_predictions(\n", 295 | " y_test, \n", 296 | " y_hat\n", 297 | ");" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "id": "4b63536b-8e84-48f3-aaad-35329565bafe", 303 | "metadata": {}, 304 | "source": [ 305 | "# 5. Model persistence\n", 306 | "## 5.1 Store the trained model on disk" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "id": "d10f7d42-4654-4871-9bd1-07af6ef4a319", 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "import joblib\n", 317 | "joblib.dump(fastmapsvm, 'fastmapsvm.joblib')\n", 318 | "\n", 319 | "del(fastmapsvm)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "7a2dabba-9c7e-41cb-9361-8b4c9f9be0eb", 325 | "metadata": {}, 326 | "source": [ 327 | "## 5.2 Load a pre-trained model from disk\n", 328 | "**Note:** The distance function must be defined within the scope in which you load the model. So, if you train the model in one script, and then try to reload it in another script in which the distance function is not defined, it will not work. For instance, if you restart your Python kernel and immediately run the cell below, before running anything else, you will get `AttributeError: Can't get attribute 'FastMap' on `. There is, to my knowledge, no clean way of storing user-defined code and its dependencies alongside the model. 
The abstract base class paradigm is used to force users to write code in a way that will make it easier for them to reload the model later.\n", 329 | "\n", 330 | "If you restart your Python kernel, you need to run the code in sections 2.1 and 2.2 (along with the necessary imports) before running the code in this section." 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "id": "8bb92e07-9419-40f3-963d-31b5567ced2e", 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "import joblib\n", 341 | "fastmapsvm = joblib.load('fastmapsvm.joblib')" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "id": "8abd77b5-61b6-4465-80c6-b4800b2989bb", 347 | "metadata": {}, 348 | "source": [ 349 | "# 6. Unsupervised clustering\n", 350 | "FastMap was originally designed for unsupervised cluster analysis, and can be trained in unsupervised mode by omitting the labels on training." 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "id": "51258d0d-66a2-4aa6-b9eb-87b8dc14b2dc", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "fm = FastMap(2)\n", 361 | "fm.fit(X_train)\n", 362 | "W = fm.transform(X_train)\n", 363 | "\n", 364 | "plt.close('all')\n", 365 | "fig, ax = plt.subplots()\n", 366 | "ax.scatter(W[:, 0], W[:, 1])" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "id": "66e90e9e-f8b1-45ae-87f4-1797eaa7f569", 372 | "metadata": {}, 373 | "source": [ 374 | "# 7. Minkowski" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "id": "d584a21c-0a91-469c-9ec1-dbe2e5012204", 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "a = np.random.rand(3, 3, 3)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "id": "b853433a-3225-4293-9cdf-2844148ad5eb", 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "np.sum(a, axis=(-1, -2))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "id": "31cbb245-4cfa-4f17-b058-425d9427be3a", 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "def minkowski_distance(a, b):\n", 405 | " p = 1\n", 406 | " d = np.power(np.sum(np.power(np.abs(a-b), p), axis=-1), 1/p)\n", 407 | " return np.sqrt(np.sum(np.square(d), axis=-1))\n", 408 | "\n", 409 | "class FastMap(fastmap.FastMapABC):\n", 410 | " def __init__(self, *args, **kwargs):\n", 411 | " super().__init__(*args, **kwargs)\n", 412 | " self._distance_func = minkowski_distance" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "id": "c0c70f7d-7d0a-47bd-9638-86bcb96c3ffa", 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "fm = FastMap(2)\n", 423 | "fm.fit(X_train, y=y_train)\n", 424 | "W = fm.transform(X_train)\n", 425 | "\n", 426 | "plt.close('all')\n", 427 | "fig, ax = plt.subplots()\n", 428 | "for i in range(2):\n", 429 | " idxs = np.argwhere(y_train == i).flatten()\n", 430 | " ax.scatter(W[idxs, 0], W[idxs, 1])" 431 | ] 432 | } 433 | ], 434 | "metadata": { 435 | "kernelspec": { 436 | "display_name": "Python [conda env:py310]", 437 | "language": "python", 438 | "name": "conda-env-py310-py" 439 | }, 440 | "language_info": { 441 | "codemirror_mode": { 442 | "name": "ipython", 443 | "version": 3 444 | }, 445 | "file_extension": ".py", 446 | "mimetype": "text/x-python", 447 | "name": "python", 448 | "nbconvert_exporter": "python", 449 | "pygments_lexer": "ipython3", 450 | "version": "3.10.8" 451 | } 
452 | }, 453 | "nbformat": 4, 454 | "nbformat_minor": 5 455 | } 456 | --------------------------------------------------------------------------------