├── .DS_Store ├── LICENSE ├── README.md ├── checkpoint └── preact_resnet18.pth ├── data └── .DS_Store ├── otdd ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── plotting-checkpoint.py │ └── utils-checkpoint.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── plotting.cpython-38.pyc │ ├── plotting.cpython-39.pyc │ ├── utils.cpython-310.pyc │ ├── utils.cpython-36.pyc │ ├── utils.cpython-38.pyc │ └── utils.cpython-39.pyc ├── plotting.py ├── pytorch │ ├── .ipynb_checkpoints │ │ ├── Untitled-checkpoint.ipynb │ │ ├── __init__-checkpoint.py │ │ ├── datasets-checkpoint.py │ │ ├── distance-checkpoint.py │ │ ├── distance_double-checkpoint.py │ │ ├── distance_fast-checkpoint.py │ │ ├── flows-checkpoint.py │ │ ├── functionals-checkpoint.py │ │ ├── moments-checkpoint.py │ │ ├── nets-checkpoint.py │ │ ├── old_distance-checkpoint.py │ │ ├── sqrtm-checkpoint.py │ │ ├── utils-checkpoint.py │ │ └── wasserstein-checkpoint.py │ ├── Untitled.ipynb │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── datasets.cpython-310.pyc │ │ ├── datasets.cpython-36.pyc │ │ ├── datasets.cpython-38.pyc │ │ ├── datasets.cpython-39.pyc │ │ ├── datasets_2.cpython-38.pyc │ │ ├── distance.cpython-38.pyc │ │ ├── distance.cpython-39.pyc │ │ ├── distance_2.cpython-38.pyc │ │ ├── distance_double.cpython-38.pyc │ │ ├── distance_fast.cpython-38.pyc │ │ ├── distance_fast.cpython-39.pyc │ │ ├── moments.cpython-38.pyc │ │ ├── moments.cpython-39.pyc │ │ ├── nets.cpython-38.pyc │ │ ├── nets.cpython-39.pyc │ │ ├── sqrtm.cpython-38.pyc │ │ ├── sqrtm.cpython-39.pyc │ │ ├── utils.cpython-310.pyc │ │ ├── utils.cpython-38.pyc │ │ ├── utils.cpython-39.pyc │ │ ├── utils_2.cpython-38.pyc │ │ ├── wasserstein.cpython-38.pyc │ │ └── wasserstein.cpython-39.pyc │ ├── datasets.py │ ├── datasets3.py 
│ ├── datasets_2.py │ ├── distance.py │ ├── distance_2.py │ ├── distance_double.py │ ├── distance_fast.py │ ├── flows.py │ ├── functionals.py │ ├── moments.py │ ├── nets.py │ ├── old_distance.py │ ├── sqrtm.py │ ├── utils.py │ ├── utils_2.py │ └── wasserstein.py └── utils.py ├── pipeline_projektor.png ├── prep_train_data.py └── projektor_cifar10_example.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 REDS LAB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Projektor-NeurIPS 2023] Performance Scaling via Optimal Transport: Enabling Data Selection from Partially Revealed Sources 2 | ![Python 3.8.10](https://img.shields.io/badge/python-3.8.10-DodgerBlue.svg?style=plastic) 3 | 4 | This repository is the official implementation of the "[Performance Scaling via Optimal Transport: Enabling Data Selection from Partially Revealed Sources](https://arxiv.org/abs/2307.02460)" (NeurIPS 2023). 5 | 6 | 7 | ![projektor](pipeline_projektor.png) 8 | 9 | 10 | We propose a performance estimator for a model trained on any data source composition given limited sample information. 11 | We further develop a novel optimal transport based scaling law to predict performance on larger scales, which effectively finds the optimal composition of data sources for 12 | any target data size. 13 | 14 | 15 | 16 | 17 | 18 | # Getting Started 19 | 20 | 21 | ## Examples 22 | 23 | For better understanding of applying **projektor** to data source selection and performance scaling, we have provided a tutorial Jupyter notebook `projektor_cifar10_example.ipynb`. 24 | 25 | ## Data 26 | 27 | The datasets should be placed in the folder ['data'](data). 28 | Please download the necessary datasets, e.g. CIFAR10. 29 | 30 | 31 | ## Acknowledgment 32 | 33 | RJ and the ReDS lab acknowledge support through grants from the Amazon-Virginia Tech Initiative 34 | for Efficient and Robust Machine Learning, the National Science Foundation under Grant No. 
35 | IIS-2312794, NSF IIS-2313130, NSF OAC-2239622, and the Commonwealth Cyber Initiative 36 | 37 | -------------------------------------------------------------------------------- /checkpoint/preact_resnet18.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/checkpoint/preact_resnet18.pth -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/data/.DS_Store -------------------------------------------------------------------------------- /otdd/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import dirname, abspath 3 | import logging 4 | # Defaults 5 | ROOT_DIR = dirname(dirname(abspath(__file__))) # Project Root 6 | HOME_DIR = os.getenv("HOME") # User home dir 7 | DATA_DIR = os.path.join(ROOT_DIR, 'data') 8 | OUTPUT_DIR = os.path.join(ROOT_DIR, 'out') 9 | MODELS_DIR = os.path.join(ROOT_DIR, 'models') 10 | from .utils import launch_logger 11 | -------------------------------------------------------------------------------- /otdd/.ipynb_checkpoints/plotting-checkpoint.py: -------------------------------------------------------------------------------- 1 | """Plotting tools for Optimal Transport Dataset Distance. 
def gaussian_density_plot(P=None, X=None, method = 'exact', nsamples = 1000,
                          color='blue', label_means=True, cmap='coolwarm',ax=None,eps=1e-4):
    """Draw filled and line contours of a 2D density on a matplotlib axes.

    Args:
        P: Distribution object exposing ``sample``, ``loc`` and
            ``covariance_matrix`` whose results support ``.numpy()``
            (presumably a ``torch.distributions.MultivariateNormal`` --
            confirm against callers). Required when ``method='exact'``.
        X: Samples indexed as ``X[:, 0]`` / ``X[:, 1]``. When omitted and
            ``P`` is given, ``nsamples`` points are drawn from ``P``.
        method: ``'exact'`` evaluates the Gaussian pdf from ``P``;
            ``'samples'`` fits a Gaussian KDE to ``X``.
        nsamples: Number of samples drawn from ``P`` when ``X`` is None.
        color: Accepted for interface compatibility; not used by the body.
        label_means: If True (``'exact'`` only), write the mean next to its marker.
        cmap: Colormap name or colormap object for the filled contours.
        ax: Axes to draw on. If None a new 8x8 figure is created and its
            limits are derived from ``X`` via ``get_plot_ranges``.
        eps: Offset added to the mean coordinates when placing the label.

    Raises:
        ValueError: On an unknown ``method`` or when the inputs the chosen
            method needs (``P`` and/or ``X``) are missing.
    """
    if method not in ('samples', 'exact'):
        raise ValueError(
            "Unrecognized method {!r}; expected 'samples' or 'exact'".format(method))
    if method == 'exact' and P is None:
        raise ValueError("method='exact' requires a distribution P")

    if X is None and P is not None:
        X = P.sample(sample_shape=torch.Size([nsamples])).numpy()
    # Samples are needed to derive axis limits (ax is None) and for the KDE.
    if X is None and (ax is None or method == 'samples'):
        raise ValueError("Either samples X or a distribution P must be provided")

    if ax is None:
        fig = plt.figure(figsize=(8,8))
        ax = fig.gca()
        xmin, xmax, ymin, ymax = get_plot_ranges(X)
        # The first positional argument of logger.info is the format string;
        # the original passed the four numbers with no format at all.
        logger.info("plot ranges: x=[%s, %s], y=[%s, %s]", xmin, xmax, ymin, ymax)
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)
    else:
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()

    # 100 x 100 evaluation grid spanning the axes limits.
    XY = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    xx, yy = XY[0,:,:], XY[1,:,:]

    if method == 'samples':
        positions = np.vstack([xx.ravel(), yy.ravel()])
        kernel = scipy.stats.gaussian_kde(X.T)
        f = np.reshape(kernel(positions).T, xx.shape)
    else:  # 'exact'
        mean, cov = P.loc.numpy(), P.covariance_matrix.numpy()
        f = scipy.stats.multivariate_normal.pdf(XY.transpose(1,2,0), mean, cov)

    step = 0.01
    levels = np.arange(0, np.amax(f), step) + step
    if len(levels) < 2:
        # contour/contourf need at least two increasing levels; also covers
        # the empty case (max density <= 0), which used to raise IndexError.
        top = levels[0] if len(levels) else step
        levels = [step/2, top]

    ax.contourf(xx, yy, f, levels, cmap=cmap, alpha=0.5)
    cset = ax.contour(xx, yy, f, levels, colors='k', alpha=0.5)
    ax.clabel(cset, inline=1, fontsize=10)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    if method == 'samples':
        # cmap defaults to a *name*; resolve it before sampling a color
        # (calling the string directly raised TypeError).
        palette = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
        ax.scatter(X[:,0], X[:,1], color=palette(0.8))
        ax.set_title('2D Gaussian Kernel density estimation')
    else:
        ax.scatter(mean[0], mean[1], s=5, c= 'black')
        if label_means:
            ax.text(mean[0]+eps, mean[1]+eps,
                    'μ=({:.2},{:.2})'.format(mean[0], mean[1]), fontsize=12)
        ax.set_title('Exact Gaussian Density')
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=["black", "white"],
                     threshold=None, **textkw):
    """ Write the value of every heatmap cell as text on top of the image.

    Args:
        im: The AxesImage to be labeled.
        data: Data used to annotate, as a 2D array or nested list. If None
            (or any other type), the image's own data is used. Optional.
        valfmt: The format of the annotations inside the heatmap. This should either
            use the string format method, e.g. "$ {x:.2f}", or be a
            `matplotlib.ticker.Formatter`. Optional.
        textcolors: A list or array of two color specifications. The first is used for
            values below a threshold, the second for those above. Optional.
        threshold: Value in data units according to which the colors from textcolors are
            applied. If None (the default) uses the middle of the colormap as
            separation. Optional.
        **textkw: All other arguments are forwarded to each call to `text` used to create
            the text labels. A `color` entry is overridden per cell.

    Returns:
        The list of created text artists, in row-major order.
    """
    if isinstance(data, list):
        # Lists were accepted by the type check but then failed on
        # `.max()` / `.shape`; normalise them to an array first.
        data = np.asarray(data)
    elif not isinstance(data, np.ndarray):
        data = im.get_array()

    # Express the threshold in normalised colormap units.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    text_kwargs = dict(horizontalalignment="center",
                       verticalalignment="center")
    text_kwargs.update(textkw)

    if isinstance(valfmt, str):
        valfmt = mpl.ticker.StrMethodFormatter(valfmt)

    texts = []
    n_rows, n_cols = data.shape[0], data.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            # bool -> 0/1 index: first color below the threshold, second above.
            text_kwargs.update(color=textcolors[int(im.norm(value) > threshold)])
            texts.append(im.axes.text(j, i, valfmt(value, None), **text_kwargs))

    return texts
def dist_adapt_joinplot(df, yvar='delta', show=True, type='joinplot', save_path = None):
    """Joint regression plot of OT distance (column ``'dist'``) against ``yvar``.

    Args:
        df: DataFrame with a ``'dist'`` column and a ``yvar`` column.
        yvar: Name of the column plotted on the y axis.
        show: If True, call ``plt.show()`` at the end.
        type: Accepted for interface compatibility; not used by the body.
        save_path: If given, the current figure is saved there as PDF.
    """
    j = sns.jointplot(x='dist', y=yvar, data=df, kind="reg", height=7)
    if hasattr(j, 'annotate'):
        j.annotate(scipy.stats.pearsonr)
    else:
        # Newer seaborn releases no longer provide JointGrid.annotate;
        # reproduce the correlation annotation by hand instead of crashing.
        r, p = scipy.stats.pearsonr(df['dist'], df[yvar])
        j.ax_joint.annotate('pearsonr = {:.2g}; p = {:.2g}'.format(r, p),
                            xy=(0.05, 0.95), xycoords='axes fraction',
                            ha='left', va='top')
    y_label = 'Acc. Improvement w/ Adapt'
    j.set_axis_labels('OT Task Distance', y_label)
    if save_path:
        plt.savefig(save_path, format='pdf', dpi=300) #bbox_inches='tight',
    if show:
        plt.show()


def dist_adapt_regplot(df, yvar, xvar='dist', xerrvar=None, yerrvar=None,
                       figsize=(6,5), title=None,
                       show_correlation=True, corrtype='pearson', sci_pval=True,
                       annotate=True, annotation_arrows=True, annotation_fontsize=12,
                       force_text=0.5,
                       legend_fontsize=12,
                       title_fontsize=12,
                       marker_size=10,
                       arrowcolor='gray',
                       barcolor='gray',
                       xlabel = 'OT Dataset Distance',
                       ylabel = r'Relative Drop in Test Error ($\%$)',
                       color='#1f77b4',
                       lw=1,
                       ax=None,
                       show=True,
                       save_path=None):
    """Scatter + regression line of ``df[yvar]`` against ``df[xvar]``.

    Optionally adds a correlation legend, error bars and per-point labels.

    Args:
        df: DataFrame holding the plotted columns. When ``annotate`` is True
            each row must expose ``src`` and ``tgt`` attributes.
        yvar, xvar: Column names for the y and x axes.
        xerrvar, yerrvar: Optional column names holding error-bar sizes.
        figsize: Figure size used when ``ax`` is None.
        title: Optional suffix appended to the plot title.
        show_correlation: If True, compute the correlation and show it in the legend.
        corrtype: ``'pearson'`` or ``'spearman'``.
        sci_pval: If True, p-values below 0.01 are shown in scientific notation.
        annotate: If True, label each point as ``src -> tgt`` (or ``src`` alone).
        annotation_arrows: If True, connect moved labels to their points.
        annotation_fontsize, legend_fontsize, title_fontsize: Font sizes.
        force_text: Forwarded to ``adjust_text``.
        marker_size: Scatter marker size.
        arrowcolor: Color of the annotation connectors.
        barcolor: Accepted for interface compatibility; the error bars
            currently use a fixed color.
        xlabel, ylabel: Axis labels.
        color: Color of the scatter points and regression line.
        lw: Line width of the regression line.
        ax: Axes to draw on. If given, ``show`` is forced to False.
        show: If True, call ``plt.show()`` at the end.
        save_path: If given, save ``<save_path>.pdf`` and ``<save_path>.png``.

    Returns:
        The axes the plot was drawn on.

    Raises:
        ValueError: If ``corrtype`` is not ``'pearson'`` or ``'spearman'``.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # The caller owns the figure, so never pop it up from here.
        show = False

    #### Compute Correlation
    if show_correlation:
        if corrtype == 'spearman':
            corr, p = spearmanr(df[xvar], df[yvar])
            corrsymbol = '\\rho'
        elif corrtype == 'pearson':
            corr, p = pearsonr(df[xvar], df[yvar])
            corrsymbol = 'r'
        else:
            raise ValueError('Unrecognized correlation type')
        corr_text = r"${}: {:2.2f}$".format(corrsymbol,corr)
        if p < 0.01 and sci_pval:
            legend_label = corr_text + "\n" + r"p-value: ${:s}$".format(as_si(p,1))
        else:
            legend_label = corr_text + "\n" + r"p-value: ${:2.2f}$".format(p)
    else:
        legend_label = None

    ### Scatter and regression line in a single call
    sns.regplot(x=xvar, y=yvar, data=df, ax = ax, color=color, label=legend_label,
                scatter_kws={'s':marker_size},
                line_kws={'lw': lw}  # was hard-coded to 1, ignoring `lw`
                )

    ### Add Error Bars
    if xerrvar or yerrvar:
        xerr = df[xerrvar] if xerrvar else None
        yerr = df[yerrvar] if yerrvar else None
        ax.errorbar(df[xvar], df[yvar], xerr=xerr, yerr=yerr, fmt='none', ecolor='#d6d4d4', alpha=0.75,elinewidth=0.75)

    ### Annotate Points
    if annotate:
        texts = []
        for _, row in df.iterrows():
            if row.tgt is not None:
                lab = r'{}$\rightarrow${}'.format(row.src, row.tgt)
            else:
                lab = r'{}'.format(row.src)
            texts.append(ax.text(row[xvar], row[yvar], lab, fontsize=annotation_fontsize))
        if annotation_arrows:
            adjust_text(texts, force_text=force_text, arrowprops=dict(arrowstyle="-", color=arrowcolor, alpha=0.5, lw=0.5))
        else:
            adjust_text(texts, force_text=force_text)

    ### Fix Legend for Correlation (otherwise don't show)
    if show_correlation:
        plt.rc('legend',fontsize=legend_fontsize)
        # Attach the correlation label to the first line drawn on the axes.
        ax.legend([ax.get_lines()[0]], ax.get_legend_handles_labels()[-1],handlelength=1.0,loc='best')

    ### Add title and labels
    ax.set_xlabel(xlabel, fontsize=title_fontsize)
    ax.set_ylabel(ylabel, fontsize=title_fontsize)
    ax.set_title(r'Distance vs Adaptation' + (': {}'.format(title) if title else ''), fontsize=title_fontsize)

    if save_path:
        plt.savefig(save_path+'.pdf', dpi=300, bbox_inches = "tight")
        plt.savefig(save_path+'.png', dpi=300, bbox_inches = "tight")

    if show: plt.show()

    return ax
def annotate_group(name, span, ax=None, orient='h', side=None):
    """Label a span of the x-axis (or of the y-axis when ``orient`` is not 'h').

    Two annotations carrying the same text are drawn, one anchored at each
    end of ``span`` with the text at the span's centre.

    Returns:
        Tuple ``(left_arrow, right_arrow)`` of the two annotation artists.
    """
    horizontal = orient == 'h'
    if not side:
        side = 'left' if orient == 'v' else 'bottom'

    def draw_half(anchor, text_pos, level, pad):
        # `level` is the coordinate on the axis perpendicular to the span.
        if horizontal:
            xy = (anchor, level)
            xytext = (text_pos, level+pad)
            valign = 'top'
            connectionstyle = 'angle,angleB=90,angleA=0,rad=5'
        else:
            xy = (level, anchor)
            xytext = (level+pad, text_pos)
            valign = 'center'
            connectionstyle = 'angle,angleB=0,angleA=-90,rad=5'
        return ax.annotate(
            name,
            xy=xy, xycoords='data',
            xytext=xytext, textcoords='data',
            annotation_clip=False, verticalalignment=valign,
            horizontalalignment='center', linespacing=2.0,
            arrowprops=dict(arrowstyle='-', shrinkA=0, shrinkB=0,
                            connectionstyle=connectionstyle),
            fontsize=8, rotation=0,
        )

    if ax is None:
        ax = plt.gca()

    lims = ax.get_ylim() if horizontal else ax.get_xlim()
    extent = np.abs(lims[1] - lims[0])
    lim = lims[0] if side in ('bottom', 'left') else lims[1]

    # Offsets are fractions of the perpendicular axis extent.
    if side in ('bottom', 'right'):
        arrow_coord = lim + 0.01*extent
        text_pad = 0.02*extent
    elif side == 'top':
        arrow_coord = lim - 0.01*extent
        text_pad = -0.05*extent
    else:  # left
        arrow_coord = lim - 0.01*extent
        text_pad = -0.02*extent

    center = np.mean(span)
    left_arrow = draw_half(span[0], center, arrow_coord, text_pad)
    right_arrow = draw_half(span[1], center, arrow_coord, text_pad)
    return left_arrow, right_arrow


def imshow_group_boundaries(ax, gU, gV, group_names, side = 'both', alpha=0.2, lw=0.5):
    """Draw dashed group separators (and optional group labels) on an imshow.

    The image must already be sorted according to the order of the groups.
    ``gU`` / ``gV`` are cumulatively summed to locate the boundaries;
    ``group_names[0]`` labels the ``gU`` groups, ``group_names[1]`` the
    ``gV`` groups.
    """
    if side in ['source','both']:
        x_lo, x_hi = ax.get_xlim()
        ax.hlines(np.cumsum(gU[:-1]) - 0.5, xmin=x_lo, xmax=x_hi, lw=lw,
                  linestyles='dashed', alpha = alpha)
    if side in ['target','both']:
        y_lo, y_hi = ax.get_ylim()
        ax.vlines(np.cumsum(gV[:-1]) - 0.5, ymin=y_lo, ymax=y_hi, lw=lw,
                  linestyles='dashed', alpha=alpha)

    if not group_names:
        return

    offset = -0.5
    # NOTE(review): np.insert casts the inserted value to the array dtype, so
    # for integer group sizes the leading edge likely becomes 0, not -0.5 -- confirm intent.
    edges_u = np.insert(np.cumsum(gU), 0, offset)
    edges_v = np.insert(np.cumsum(gV), 0, offset)
    for i in range(len(edges_v) - 1):
        annotate_group(group_names[1][i], (edges_v[i], edges_v[i+1]), ax, orient='h', side = 'top')
    for i in range(len(edges_u) - 1):
        annotate_group(group_names[0][i], (edges_u[i], edges_u[i+1]), ax, orient='v', side = 'right')
def launch_logger(console_level='warning'):
    """Reset the root logger to log to a fresh temp file and to stdout.

    All existing root handlers are removed. A file handler (level INFO)
    writing to a new ``otddlog_*`` file in the system temp directory and a
    stdout handler at ``console_level`` are installed.

    Args:
        console_level: ``'warning'`` or ``'info'``.

    Returns:
        The configured root logger.

    Raises:
        ValueError: If ``console_level`` is not recognised. Raised before
            any logging state is modified.
    """
    console_levels = {'warning': logging.WARNING, 'info': logging.INFO}
    if console_level not in console_levels:
        raise ValueError(
            "Unrecognized console_level {!r}; expected 'warning' or 'info'".format(console_level))

    ## Remove all handlers of root logger object -> needed to override basicConfig
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    _logger = logging.getLogger()
    _logger.setLevel(logging.INFO) # Has to be min of all the others

    ## Random logfile name. mkstemp creates the file and hands back a
    ## descriptor we close explicitly, instead of leaking an open
    ## NamedTemporaryFile that deletes the path when garbage-collected.
    fd, logfile = tempfile.mkstemp(prefix="otddlog_", dir=tempfile.gettempdir())
    os.close(fd)
    fh = logging.FileHandler(logfile)
    fh.setLevel(logging.INFO)

    ## Console handler, possibly at a higher level than the file handler
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(console_levels[console_level])

    formatter = logging.Formatter('%(asctime)s:%(name)s:%(levelname)s: %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    _logger.addHandler(fh)
    _logger.addHandler(ch)
    return _logger


def safedump(d,f):
    """Pickle object ``d`` to the file path ``f``.

    On failure the error is logged; an interactive session is dropped into
    the debugger (as before), otherwise the exception is re-raised rather
    than blocking on a debugger prompt nobody can answer.
    """
    try:
        with open(f, 'wb') as fh:  # close the handle even if dump fails
            pkl.dump(d, fh)
    except Exception:
        logging.getLogger(__name__).exception("Failed to pickle object to %s", f)
        if sys.stdin is not None and sys.stdin.isatty():
            pdb.set_trace()
        else:
            raise


def append_to_file(fname, l):
    """Append the items of ``l`` to ``fname`` as one tab-separated line."""
    with open(fname, "a") as f:
        # str() lets non-string fields through; strings are unchanged.
        f.write('\t'.join(map(str, l)) + '\n')


def delete_if_exists(path, typ='f'):
    """Delete ``path`` if it exists; do nothing when it is absent.

    Args:
        path: File or directory path.
        typ: ``'f'`` to remove a file, ``'d'`` to remove a directory tree.

    Raises:
        ValueError: If ``typ`` is neither ``'f'`` nor ``'d'``.
    """
    # Validate the type on its own: previously a missing path fell through to
    # the "Unrecognized path type" error, defeating the "if exists" contract.
    if typ not in ('f', 'd'):
        raise ValueError("Unrecognized path type")
    if typ == 'f':
        if os.path.exists(path):
            os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
launch_logger 11 | -------------------------------------------------------------------------------- /otdd/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/plotting.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/plotting.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/plotting.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/plotting.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/plotting.py: -------------------------------------------------------------------------------- 1 | """Plotting tools for Optimal Transport Dataset Distance. 
def as_si(x, ndp):
    r"""Format a number as a LaTeX-style ``m\times 10^{e}`` string.

    Args:
        x: Number to format.
        ndp: Number of decimal places kept in the mantissa.

    Returns:
        The formatted string, e.g. ``1.2\times 10^{-3}``. Non-finite values
        (``inf``, ``nan``) have no exponent and are returned as plain text.
    """
    formatted = '{x:0.{ndp:d}e}'.format(x=x, ndp=ndp)
    mantissa, sep, exponent = formatted.partition('e')
    if not sep:
        # 'inf' / 'nan' contain no 'e'; unpacking split('e') used to crash here.
        return mantissa
    return r'{m:s}\times 10^{{{e:d}}}'.format(m=mantissa, e=int(exponent))


def get_plot_ranges(X):
    """Return ``(xmin, xmax, ymin, ymax)`` padded by 10% of each data span.

    ``X`` is indexed as ``X[:, 0]`` (x values) and ``X[:, 1]`` (y values).
    """
    xs, ys = X[:,0], X[:,1]
    x_lo, x_hi = xs.min(), xs.max()
    y_lo, y_hi = ys.min(), ys.max()
    pad_x = (x_hi - x_lo)/10
    pad_y = (y_hi - y_lo)/10
    return (x_lo - pad_x, x_hi + pad_x, y_lo - pad_y, y_hi + pad_y)
def heatmap(data, row_labels, col_labels, ax=None, cbar=True,
            cbar_kw=None, cbarlabel="", **kwargs):
    """ Create a heatmap from a numpy array and two lists of labels.

    Args:
        data: A 2D numpy array of shape (N, M).
        row_labels: A list or array of length N with the labels for the rows.
        col_labels: A list or array of length M with the labels for the columns.
        ax: A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
            not provided, use current axes or create a new one. Optional.
        cbar: A boolean value, whether to display a colorbar or not.
        cbar_kw: A dictionary with arguments to `matplotlib.Figure.colorbar`.
            It is copied, never modified. Entries override the built-in
            ``fraction``/``pad`` defaults. Optional.
        cbarlabel: The label for the colorbar. Optional.
        **kwargs: All other arguments are forwarded to `imshow`.

    Returns:
        A tuple ``(im, colorbar)``: the image returned by `imshow` and the
        colorbar that was created. When no colorbar is drawn the second
        element is the (falsy) ``cbar`` argument, as before.
    """
    if ax is None:
        ax = plt.gca()

    im = ax.imshow(data, **kwargs)

    colorbar = cbar
    if cbar:
        # Build a fresh dict on every call: the previous `cbar_kw={}` default
        # was mutated below, so an `alpha` from one call leaked into the next.
        colorbar_args = {'fraction': 0.046, 'pad': 0.04}
        colorbar_args.update(cbar_kw or {})
        if 'alpha' in kwargs:
            # Keep the colorbar as transparent as the image itself.
            colorbar_args['alpha'] = kwargs['alpha']
        colorbar = ax.figure.colorbar(im, ax=ax, **colorbar_args)
        colorbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # One major tick per cell, labelled with the supplied names.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Horizontal axis labels go at the bottom.
    ax.tick_params(top=False, bottom=True,
                   labeltop=False, labelbottom=True)

    plt.setp(ax.get_xticklabels(), rotation=0, ha="right", rotation_mode="anchor")

    # Hide the frame and draw a white grid on the minor ticks instead, so the
    # cells appear separated.
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, colorbar
def distance_scatter(d, topk=10, show=True, save_path=None,
                     title='Pairwise Distance Between MNIST Binary Classification Tasks'):
    """ Distance vs adaptation scatter plots as used in the OTDD paper.

    Entries are sorted by distance and spread evenly along the x-axis; the
    `topk` smallest and `topk` largest distances are annotated.

    Args:
        d (dict): dictionary of task pair (string or 2-tuple), distance (float)
        topk (int): number k of top/bottom distances that will be annotated
        show (bool): whether to call `plt.show()` at the end
        save_path (str): if given, the figure is saved there as a PDF
        title (str): title of the plot (defaults to the historical one)

    Raises:
        ValueError: if `d` is empty.
    """
    if not d:
        # Previously this surfaced as an opaque tuple-unpacking error.
        raise ValueError('distance_scatter needs at least one (task, distance) entry')

    def _label(key):
        # Pairs are rendered as 'p<->q'; anything else through str().
        if isinstance(key, tuple) and len(key) == 2:
            return '{}<->{}'.format(*key)
        return str(key)

    keys, dists = zip(*sorted(d.items(), key=lambda kv: kv[1]))
    n = len(keys)
    x_coord = np.linspace(0, 1, n)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(x_coord, dists, s=min(100/n, 1))

    # Only label the extremes, otherwise the plot becomes unreadable.
    texts = [ax.text(x, y, _label(key))
             for i, (x, y, key) in enumerate(zip(x_coord, dists, keys))
             if i < topk or i >= n - topk]
    adjust_text(texts, force_text=0.05, arrowprops=dict(arrowstyle="-|>",
                                                        color='r', alpha=0.5))

    ax.set_title(title)
    ax.set_ylabel('Dataset Distance')
    if save_path:
        # Save the figure created here rather than whatever is "current".
        fig.savefig(save_path, format='pdf', dpi=300)
    if show:
        plt.show()
def plot2D_samples_mat(xs, xt, G, thr=1e-8, ax=None, **kwargs):
    """ (ADAPTED FROM PYTHON OT LIBRARY).
    Plot matrix G in 2D with lines using alpha values

    Plot lines between source and target 2D samples with an opacity
    proportional to the value of the matrix G between samples.

    Parameters
    ----------
    xs : ndarray, shape (ns,2)
        Source samples positions
    xt : ndarray, shape (nt,2)
        Target samples positions
    G : ndarray, shape (ns,nt)
        OT matrix
    thr : float, optional
        threshold (relative to the largest entry of G) above which the line
        is drawn
    ax : matplotlib axes, optional
        axes to draw on; a new figure is created when omitted
    **kwargs : dict
        parameters given to the plot functions (default color is gray if
        nothing given)

    Returns
    -------
    ax : the axes the lines were drawn on
    """
    if ('color' not in kwargs) and ('c' not in kwargs):
        kwargs['color'] = 'gray'
    if ax is None:
        fig, ax = plt.subplots()

    if torch.is_tensor(G):
        G = G.detach().cpu().numpy()
    G = np.asarray(G)

    mx = G.max()
    if not mx > 0:
        # All-zero (or NaN) plan: nothing to draw, and avoids a 0/0 division.
        return ax

    # Only the entries that pair an existing source with an existing target.
    weights = G[:xs.shape[0], :xt.shape[0]] / mx
    # Visit just the couplings above the threshold (row-major, same drawing
    # order as a nested i/j loop) instead of every (i, j) pair.
    for i, j in zip(*np.nonzero(weights > thr)):
        ax.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
                alpha=weights[i, j], **kwargs)

    return ax
def method_comparison_plot(df, hue_var='method', style_var='method',
                           figsize=(15, 4), ax=None, save_path=None, show=True):
    """ Produce plots comparing OTDD variants in terms of runtime and distance

    The left panel shows distance vs dataset size, the right panel runtime vs
    dataset size (log-log). Only the left panel keeps its legend.

    Args:
        df: data with columns 'n', 'dist', 'time' plus the hue/style columns
        hue_var (str): column used to color the lines
        style_var (str): column used for line/marker style
        figsize (tuple): size of the figure created when `ax` is None
        ax: pair of matplotlib axes to draw on. Optional.
        save_path (str): if given, '<save_path>.pdf' and '<save_path>.png'
            are written
        show (bool): whether to call `plt.show()` (default True, the
            historical behaviour)

    Returns:
        The pair of axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 2, figsize=figsize)

    lplot_args = {
        'hue': hue_var,
        'style': style_var,
        'data': df,
        'x': 'n',
        'markers': True
    }

    sns.lineplot(y='dist', ax=ax[0], **lplot_args)
    ax[0].legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=16)
    ax[0].set_ylabel('Dataset Distance')
    ax[0].set_xlabel('Dataset Size')
    ax[0].set_xscale("log")
    ax[0].grid(True, which="both", ls="--", c='gray')

    sns.lineplot(y='time', ax=ax[1], **lplot_args)
    ax[1].set_ylabel('Runtime (s)')
    ax[1].set_xlabel('Dataset Size')
    ax[1].set_xscale("log")
    ax[1].set_yscale("log")
    ax[1].grid(True, which="both", ls="--", c='gray')

    # Both panels share the same legend entries; keep only the left one.
    legend = ax[1].get_legend()
    if legend is not None:
        legend.remove()

    # Operate on the figure that owns the axes, which is not necessarily the
    # "current" pyplot figure when `ax` is supplied by the caller.
    figure = ax[0].figure
    figure.tight_layout()
    if save_path:
        figure.savefig(save_path + '.pdf', dpi=300)
        figure.savefig(save_path + '.png', dpi=300)
    if show:
        plt.show()

    return ax
class Functional():
    """
    Defines a JKO functional over measures implicitly by defining it over
    individual particles (points).

    The input should be a full dataset: points X (n x d) with labels Y (n x 1).
    Optionally, the means/variances associated with each class can be passed.

    Each of the three terms is a callable with signature ``(x, y, μ, Σ)``;
    terms left as None are skipped.
    """
    def __init__(self, V=None, W=None, f=None, weights=None):
        self.V = V  # The functional on Z space in potential energy 𝒱() = V
        self.W = W  # The bi-linear form on ZxZ spaces in interaction energy 𝒲
        self.f = f  # The scalar-valued function in the internal energy term ℱ
        # Stored for callers; not applied in __call__.
        # NOTE(review): intended use of `weights` is not visible here - confirm.
        self.weights = weights

    def __call__(self, x, y, μ=None, Σ=None):
        """ Evaluate the functional: sum of the internal, potential and
        interaction terms that were provided (0 if none was).
        """
        total = 0
        # Internal energy first, then potential, then interaction.
        for term in (self.f, self.V, self.W):
            if term is not None:
                total = total + term(x, y, μ, Σ)
        return total
def binary_hyperplane_margin(X, Y, w, b, weight=1.0):
    """ Hinge-type potential measuring margin violation with respect to a
    given, fixed hyperplane (w, b):

        v(x,y) = max(0, 1 - ŷ(x'w - b)),  with ŷ = 2y - 1,

    so that V(ρ) = ∫ max(0, 1 - ŷ(x'w - b)) dρ(x,y), approximated by the
    mean over the provided points and scaled by `weight`.

    Returns 0 if every point lies on its own side of the hyperplane with a
    margin of at least 1.

    Note that y is expected to be in {0,1}.
    """
    # Map labels {0, 1} -> {-1, +1}, as required by the SVM-style objective.
    signs = 2*Y - 1
    # Signed score of each point with respect to the hyperplane.
    scores = torch.matmul(X, w) - b
    hinge = torch.relu(1 - signs*scores)
    return weight*hinge.mean()
def binary_cluster_margin(X, Y, μ=None, weight=1.0):
    """ Similar to binary_hyperplane_margin but does not require a separating
    hyperplane be provided in advance. Instead, it compares, for every pair
    of points, the direction between the two points with the direction
    between the means of their respective classes.

    For each ordered pair (i, j) the cosine c_ij between (x_i - x_j) and
    (μ_{y_i} - μ_{y_j}) is computed, and the value returned is

        weight * mean_ij exp(max(0, 1 - c_ij)).

    Arguments:
        X (tensor): points, shape (n, d)
        Y (tensor): binary labels in {0, 1}, shape (n,)
        μ: ignored; class means are always recomputed from X and Y
            (parameter kept for interface compatibility)
        weight (float): multiplicative factor applied to the result

    Returns:
        Scalar tensor.
    """
    μ_0 = X[Y == 0].mean(0)
    μ_1 = X[Y == 1].mean(0)

    n, d = X.shape
    # Pairwise differences via broadcasting: diffs_x[i, j] = X[i] - X[j].
    diffs_x = X.unsqueeze(1) - X.unsqueeze(0)
    diffs_x = torch.nn.functional.normalize(diffs_x, dim=2, p=2)

    # Per-point class mean, allocated like X so device and dtype agree.
    centers = X.new_zeros(n, d)
    centers[Y == 0, :] = μ_0
    centers[Y == 1, :] = μ_1

    # Zero for same-class pairs (normalize leaves zero vectors at zero).
    diffs_μ = centers.unsqueeze(1) - centers.unsqueeze(0)
    diffs_μ = torch.nn.functional.normalize(diffs_μ, dim=2, p=2)

    # True inner product over the feature axis. The former einsum
    # "ijk,ijl->ij" multiplied the two coordinate sums instead.
    inner_prod = (diffs_x * diffs_μ).sum(dim=2)

    out = torch.relu(1 - inner_prod)
    margin = torch.exp(out)
    return weight*margin.mean()
class OnlineStatsRecorder:
    """ Online batch estimation of multivariate sample mean and covariance matrix.

    Alleviates numerical instability due to catastrophic cancellation that
    the naive estimation suffers from.

    Two pass approach first computes population mean, and then uses stable
    one pass algorithm on residuals x' = (x - μ). Uses the fact that Cov is
    translation invariant, and less cancellation happens if E[XX'] and
    E[X]E[X]' are far apart, which is the case for centered data.

    Ideas from:
        - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
        - https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html
    """
    def __init__(self, data=None, twopass=True, centered_cov=False,
                 diagonal_cov=False, embedding=None,
                 device='cpu', dtype=torch.FloatTensor):
        """
        Arguments:
            data (torch tensor): accepted for interface compatibility but
                currently unused; feed data via `update` or
                `compute_from_loader`
            twopass (bool): whether to use the two-pass approach (recommended)
            centered_cov (bool): whether covariance matrix is centered throughout
                                 the iterations. If false, centering happens once,
                                 at the end.
            diagonal_cov (bool): whether covariance matrix should be diagonal
                                 (i.e. ignore cross-correlation terms). In this
                                 case only diagonal (1xdim) tensor retrieved.
            embedding (callable): if provided, will map features using this
            device (str): device for storage of computed statistics
            dtype (torch data type): data type for computed statistics. If
                None, batches keep the type they come with.
        """
        self.device = device
        self.centered_cov = centered_cov
        self.diagonal_cov = diagonal_cov
        self.twopass = twopass
        self.dtype = dtype
        self.embedding = embedding

        self._init_values()

    def _init_values(self):
        """ Reset the running statistics. """
        self.μ = None
        self.Σ = None
        self.n = 0

    def _prepare_batch(self, x, device):
        """ Cast/move a raw batch, embed it if needed and flatten it to 2D. """
        # `x.type(None)` would return the type *name* instead of a tensor, so
        # the cast is skipped when no dtype was requested.
        if self.dtype is not None:
            x = x.type(self.dtype)
        x = x.to(device)
        if self.embedding is not None:
            x = self.embedding(x).detach()
        return x.view(x.shape[0], -1)

    def compute_from_loader(self, dataloader):
        """ Compute statistics from dataloader.

        The loader must yield (input, target) pairs. With `twopass` the
        loader is traversed twice, and `self.centered_cov` is left set to
        False afterwards (the second pass accumulates uncentered moments of
        the residuals).

        Returns:
            μ, Σ: mean and covariance (or variances if `diagonal_cov`)
        """
        device = process_device_arg(self.device)
        for x, _ in dataloader:
            self.update(self._prepare_batch(x, device))
        μ, Σ = self.retrieve()
        if self.twopass:
            self._init_values()
            self.centered_cov = False
            for x, _ in dataloader:
                # We compute cov on residuals
                self.update(self._prepare_batch(x, device) - μ)
            _, Σ = self.retrieve()
        return μ, Σ

    def update(self, batch):
        """ Update statistics using batch of data.

        Arguments:
            batch (tensor): tensor of shape (nobservations, ndimensions)

        Raises:
            ValueError: if the number of dimensions differs from that of the
                batches seen before.
        """
        if self.n == 0:
            self.n, self.d = batch.shape
            self.μ = batch.mean(axis=0)
            if self.diagonal_cov and self.centered_cov:
                self.Σ = torch.var(batch, axis=0, unbiased=True)
                ## unbiased is default in pytorch, shown here just to be explicit
            elif self.diagonal_cov and not self.centered_cov:
                self.Σ = batch.pow(2).sum(axis=0)/(1.0*self.n-1)
            elif self.centered_cov:
                self.Σ = ((batch-self.μ).T).matmul(batch-self.μ)/(1.0*self.n-1)
            else:
                self.Σ = (batch.T).matmul(batch)/(1.0*self.n-1)
                ## note that this not really covariance yet (not centered)
        else:
            if batch.shape[1] != self.d:
                raise ValueError("Data dims don't match prev observations.")

            ### Dimensions (as floats, for the weights below)
            m = self.n * 1.0
            n = batch.shape[0] * 1.0

            ### Mean Update
            self.μ = self.μ + (batch-self.μ).sum(axis=0)/(m+n)  # Stable Algo

            ### Cov Update
            # NOTE(review): the two centered_cov branches do not look like the
            # standard pooled-scatter update; left unchanged - confirm before
            # relying on centered_cov=True without twopass.
            if self.diagonal_cov and self.centered_cov:
                self.Σ = ((m-1)*self.Σ + ((m-1)/(m+n-1))*((batch-self.μ).pow(2).sum(axis=0)))/(m+n-1)
            elif self.diagonal_cov and not self.centered_cov:
                self.Σ = (m-1)/(m+n-1)*self.Σ + 1/(m+n-1)*(batch.pow(2).sum(axis=0))
            elif self.centered_cov:
                self.Σ = ((m-1)*self.Σ + ((m-1)/(m+n-1))*((batch-self.μ).T).matmul(batch-self.μ))/(m+n-1)
            else:
                self.Σ = (m-1)/(m+n-1)*self.Σ + 1/(m+n-1)*(batch.T).matmul(batch)

            ### Update total number of examples seen (kept as an int)
            self.n += batch.shape[0]

    def retrieve(self, verbose=False):
        """ Retrieve current statistics.

        Returns:
            μ, Σ: running mean and covariance (variances if `diagonal_cov`).
            Both are None if no data has been seen yet.
        """
        if verbose: print('Mean and Covariance computed on {} samples'.format(int(self.n)))
        if self.n == 0:
            # Nothing accumulated yet; avoid operating on None below.
            return self.μ, self.Σ
        if self.centered_cov:
            return self.μ, self.Σ
        elif self.diagonal_cov:
            # Uncentered second moment -> variance: subtract n/(n-1) * μ².
            Σ = self.Σ - self.μ.pow(2)*self.n/(self.n-1)
            Σ = torch.nn.functional.relu(Σ)  # To avoid negative variances due to rounding
            return self.μ, Σ
        else:
            # Same centering for the full matrix, using the outer product μμ'.
            return self.μ, self.Σ - torch.ger(self.μ, self.μ)*self.n/(self.n-1)
def compute_label_stats(data, targets=None, indices=None, classnames=None,
                        online=True, batch_size=100, to_tensor=True,
                        eigen_correction=False,
                        eigen_correction_scale=1.0,
                        nworkers=0, diagonal_cov=False,
                        embedding=None,
                        device=None, dtype=torch.FloatTensor):
    """
    Computes mean/covariance of examples grouped by label. Data can be passed as
    a pytorch dataset or a dataloader. Uses dataloader to avoid loading all
    classes at once.

    Arguments:
        data (pytorch Dataset or Dataloader): data to compute stats on
        targets (Tensor, optional): If provided, will use this target array to
            avoid re-extracting targets.
        indices (array-like, optional): If provided, filtering is based on these
            indices (useful if e.g. dataloader has subsampler). Required when
            `targets` is given.
        classnames (list, optional): class names; derived from `targets` if
            omitted.
        online (bool): whether statistics are accumulated batch by batch.
        batch_size (int): batch size used when a loader has to be created.
        to_tensor (bool): if True, stack per-class results into tensors.
        eigen_correction (bool or str, optional): ``True`` or ``'constant'``
            shifts the covariance diagonal by :attr:`eigen_correction_scale`;
            ``'jitter'`` shifts it by that scale times a random factor in
            [0.99, 1.01]; ``'exact'`` is not implemented.
        eigen_correction_scale (numeric, optional): Magnitude of eigenvalue
            correction (used only if :attr:`eigen_correction` is set)
        nworkers (int): if > 1, one process per class is spawned.
        diagonal_cov (bool): compute only per-dimension variances.
        embedding (callable, optional): feature map applied before computing
            statistics.
        device: device specification, resolved with `process_device_arg`.
        dtype (torch data type): data type for computed statistics.

    Returns:
        M (dict or Tensor): sample means indexed by target class
        S (dict or Tensor): sample covariances indexed by target class

    Raises:
        ValueError: if `targets` is given without `indices`, or if
            `eigen_correction` is not a recognized mode.
        NotImplementedError: if `eigen_correction` is ``'exact'``.
    """

    device = process_device_arg(device)
    M = {}  # Means
    S = {}  # Covariances

    ## We need to get all targets in advance, in order to filter.
    ## Here we assume targets is the full dataset targets (ignoring subsets, etc)
    ## so we need to find effective targets.
    if targets is None:
        targets, classnames, indices = extract_data_targets(data)
    elif indices is None:
        raise ValueError("If targets are provided, so must be indices")
    if classnames is None:
        classnames = sorted([a.item() for a in torch.unique(targets)])

    effective_targets = targets[indices]

    ## Same settings for the serial and the multi-process paths. The latter
    ## used to drop dtype/embedding/diagonal_cov, and batch_size was never
    ## forwarded at all.
    stats_kwargs = {'device': device, 'dtype': dtype, 'embedding': embedding,
                    'online': online, 'diagonal_cov': diagonal_cov,
                    'batch_size': batch_size}

    if nworkers > 1:
        import torch.multiprocessing as mp  # Ugly, sure. But useful.
        mp.set_start_method('spawn', force=True)
        manager = mp.Manager()
        M = manager.dict()  # Shared across the worker processes
        S = manager.dict()
        processes = []
        for i, c in enumerate(classnames):  # One process per class
            label_indices = indices[effective_targets == i]
            p = mp.Process(target=_single_label_stats,
                           args=(data, i, c, label_indices, M, S),
                           kwargs=stats_kwargs)
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
    else:
        for i, c in enumerate(classnames):
            label_indices = indices[effective_targets == i]
            μ, Σ, n = _single_label_stats(data, i, c, label_indices, **stats_kwargs)
            M[i], S[i] = μ, Σ

    if to_tensor:
        ## Warning: this assumes classes are *exactly* {0,...,n}, might break things
        ## downstream if data is missing some classes
        M = torch.stack([μ.to(device) for i, μ in sorted(M.items()) if μ is not None], dim=0)
        S = torch.stack([Σ.to(device) for i, Σ in sorted(S.items()) if Σ is not None], dim=0)

    ### Shift the Covariance matrix's diagonal to ensure PSD'ness
    if eigen_correction:
        ## `True` is the documented way to ask for the constant shift.
        mode = 'constant' if eigen_correction is True else eigen_correction
        if mode == 'exact':
            raise NotImplementedError("'exact' eigenvalue correction is not implemented")
        if mode not in ('constant', 'jitter'):
            raise ValueError('Unrecognized eigen_correction: {}'.format(eigen_correction))

        logger.warning('Applying eigenvalue correction to Covariance Matrix')
        λ = eigen_correction_scale
        ## Works both on the stacked tensor and on the per-class dictionary.
        keys = range(S.shape[0]) if torch.is_tensor(S) else list(S.keys())
        for i in keys:
            Σ = S[i]
            if Σ is None:
                continue
            ones = torch.ones(Σ.shape[-1], device=Σ.device, dtype=Σ.dtype)
            if mode == 'jitter':
                ones = ones.uniform_(0.99, 1.01)
            shift = λ*ones
            ## Diagonal covariances are stored as vectors: shift them directly.
            S[i] = Σ + (shift if Σ.dim() == 1 else torch.diag(shift))
    return M, S
representation of Labels. 351 | 352 | Arguments: 353 | Means (tensor or list of tensors): original mean vectors 354 | Covs (tensor or list of tensors): original covariances matrices 355 | redtype (str): dimensionality reduction methods, one of 'diagonal', 'mds' 356 | or 'distance_embedding'. 357 | 358 | Returns: 359 | Means (tensor or list of tensors): dimensionality-reduced mean vectors 360 | Covs (tensor or list of tensors): dimensionality-reduced covariance matrices 361 | 362 | """ 363 | n1, d1 = Means[0].shape 364 | n2, d2 = Means[1].shape 365 | k = d1 366 | 367 | print(n1, d1, n2, d2) 368 | if redtype == 'diagonal': 369 | ## Leave Means As Is, Keep Only Diag of Covariance Matrices, Independent DR for Each Task 370 | Covs[0] = torch.stack([torch.diag(C) for C in Covs[0]]) 371 | Covs[1] = torch.stack([torch.diag(C) for C in Covs[1]]) 372 | elif redtype == 'mds': 373 | ## Leave Means As Is, Use MDS to DimRed Covariance Matrices, Independent DR for Each Task 374 | Covs[0] = mds(Covs[0].view(Covs[0].shape[0], -1), output_dim=k) 375 | Covs[1] = mds(Covs[1].view(Covs[1].shape[0], -1), output_dim=k) 376 | elif redtype == 'distance_embedding': 377 | ## Leaves Means As Is, Use Bipartitie MSE Embedding, Which Embeds the Pairwise Distance Matrix, Rather than the Cov Matrices Directly 378 | print('Will reduce dimension of Σs by embedding pairwise distance matrix...') 379 | D = torch.zeros(n1, n2) 380 | print('... computing pairwise bures distances ...') 381 | for (i, j) in tqdm(itertools.product(range(n1), range(n2))): 382 | D[i, j] = bures_distance(Covs[0][i], Covs[1][j]) 383 | print('... embedding distance matrix ...') 384 | U, V = bipartite_mse_embedding(D, k=k) 385 | Covs = [U, V] 386 | print("Done! 
def pairwise_distance_mse(U, V, D, reg=1):
    """ Regularised mean-squared distortion between a target distance matrix
    and the pairwise distances of two point sets.

    Arguments:
        U, V (tensor): the two sets of points (one point per row).
        D (tensor): target distances, compared entry-wise with cdist(U, V).
        reg (numeric): weight of the penalty on the mean squared entries of
            U and V.

    Returns:
        scalar tensor: ||D - cdist(U,V)||^2 / |D| + reg * (||U||^2/|U| + ||V||^2/|V|)
    """
    residual = D - torch.cdist(U, V)
    distortion = torch.norm(residual)**2 / D.numel()  # MSE per entry
    penalty = torch.norm(U)**2 / U.numel() + torch.norm(V)**2 / V.numel()
    return distortion + reg * penalty
def reset_parameters(m):
    """ Re-initialise `m` in place if it is a Conv2d or Linear layer.

    Any other module type is left untouched, which makes this function
    suitable as the callback of ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    m.reset_parameters()
class MNIST_MLP(nn.Module):
    """ One-hidden-layer perceptron: Linear -> ReLU -> Dropout -> Linear -> softmax.

    Arguments:
        input_dim (int): number of input features; inputs are reshaped to
            (-1, input_dim) in ``forward``.
        hidden_dim (int): width of the hidden layer.
        output_dim (int): number of outputs.
        dropout (float): dropout probability applied after the hidden layer.
    """
    def __init__(
            self,
            input_dim=MNIST_FLAT_DIM,
            hidden_dim=98,
            output_dim=10,
            dropout=0.5,
    ):
        ## Was `super(ClassifierModule, self)`: that class is not defined in
        ## this module, so constructing the model raised NameError.
        super(MNIST_MLP, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, X, **kwargs):
        """ Returns softmax probabilities over the last dimension
        (not logits); extra keyword arguments are ignored. """
        X = X.reshape(-1, self.hidden.in_features)
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X
class BoWSentenceEmbedding():
    """ Bag-of-words sentence embedding built on pretrained word vectors.

    With ``method == 'bag'`` an ``nn.EmbeddingBag`` (which receives
    `padding_idx`) does the pooling; with any other value a plain
    ``nn.Embedding`` is used and the word vectors are averaged over dim 1.
    `vocab_size` and `embedding_dim` are accepted but not used.
    """
    def __init__(self, vocab_size, embedding_dim, pretrained_vec, padding_idx=None, method = 'naive'):
        self.method = method
        self.emb = (
            nn.EmbeddingBag.from_pretrained(pretrained_vec, padding_idx=padding_idx)
            if method == 'bag'
            else nn.Embedding.from_pretrained(pretrained_vec)
        )

    def __call__(self, x):
        embedded = self.emb(x)
        ## EmbeddingBag has already pooled; otherwise average over dim 1
        return embedded if self.method == 'bag' else embedded.mean(dim=1)
int(0.5*_d) 192 | 193 | self.mapping = nn.Sequential(*_layers) 194 | 195 | def forward(self, x): 196 | return self.mapping(x) 197 | 198 | def reset_parameters(self): 199 | self.mapping.apply(reset_parameters) 200 | 201 | 202 | class ConvPushforward(nn.Module): 203 | def __init__(self, input_size=28, channels = 1, nlayers_conv = 2, nlayers_mlp = 3, **kwargs): 204 | super(ConvPushforward, self).__init__() 205 | self.input_size = input_size 206 | self.channels = channels 207 | if input_size == 32: 208 | self.upconv1 = nn.Conv2d(1, 6, 3) 209 | self.upconv2 = nn.Conv2d(6, 16, 3) 210 | feat_dim = 16*5*5 211 | ## decoder layers ## 212 | self.dnconv1 = nn.ConvTranspose2d(4, 16, 2, stride=2) 213 | self.dnconv2 = nn.ConvTranspose2d(16, 1, 2, stride=2) 214 | elif input_size == 28: 215 | self.upconv1 = nn.Conv2d(1, 6, 5) 216 | self.upconv2 = nn.Conv2d(6, 16, 5) 217 | feat_dim = 16*4*4 218 | self.dnconv1 = nn.ConvTranspose2d(16, 6, 5) 219 | self.dnconv2 = nn.ConvTranspose2d(6, 1, 5) 220 | else: 221 | raise NotImplemented("Can only do LeNet on 28x28 or 32x32 for now.") 222 | self.feat_dim = feat_dim 223 | 224 | self.mlp = MLPPushforward(input_size = feat_dim, layers = nlayers_mlp) 225 | 226 | def forward(self, x): 227 | _orig_shape = x.shape 228 | x = x.reshape(-1, self.channels, self.input_size, self.input_size) 229 | x, idx1 = F.max_pool2d(F.relu(self.upconv1(x)), 2, return_indices=True) 230 | x, idx2 = F.max_pool2d(F.relu(self.upconv2(x)), 2, return_indices=True) 231 | _nonflat_shape = x.shape 232 | x = x.view(-1, self.num_flat_features(x)) 233 | x = self.mlp(x).reshape(_nonflat_shape) 234 | x = F.relu(self.dnconv1(F.max_unpool2d(x, idx2, kernel_size=2))) 235 | x = torch.tanh(self.dnconv2(F.max_unpool2d(x, idx1, kernel_size=2))) 236 | return x.reshape(_orig_shape) 237 | 238 | def num_flat_features(self, x): 239 | size = x.size()[1:] # all dimensions except the batch dimension 240 | num_features = 1 241 | for s in size: 242 | num_features *= s 243 | return num_features 244 | 
class ConvPushforward2(nn.Module):
    """ Convolutional encoder -> MLP on the flattened features -> transposed
    convolution decoder, with a tanh on the output.

    Arguments:
        input_size (int): side of the (square) input images; 28 or 32.
        channels (int): number of channels used when reshaping the input.
        nlayers_conv (int): accepted but not used by this class.
        nlayers_mlp (int): number of layers passed to the inner MLPPushforward.

    Raises:
        NotImplementedError: for any `input_size` other than 28 or 32.

    NOTE(review): the ``input_size == 32`` branch looks unfinished: it builds
    no ``dnconv3`` although ``forward`` uses it, and ``dnconv1`` takes 4 input
    channels while ``upconv2`` produces 16. Only the 28x28 branch is coherent.
    """
    def __init__(self, input_size=28, channels = 1, nlayers_conv = 2, nlayers_mlp = 3, **kwargs):
        super(ConvPushforward2, self).__init__()
        self.input_size = input_size
        self.channels = channels
        if input_size == 32:
            self.upconv1 = nn.Conv2d(1, 6, 3)
            self.upconv2 = nn.Conv2d(6, 16, 3)
            feat_dim = 16*5*5
            ## decoder layers ##
            self.dnconv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
            self.dnconv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)
        elif input_size == 28:
            self.upconv1 = nn.Conv2d(1, 16, 3, stride=3, padding=1)  # b, 16, 10, 10
            self.upconv2 = nn.Conv2d(16, 8, 3, stride=2, padding=1)  # b, 8, 3, 3
            feat_dim = 8*2*2
            self.dnconv1 = nn.ConvTranspose2d(8, 16, 3, stride=2)  # b, 16, 5, 5
            self.dnconv2 = nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1)  # b, 8, 15, 15
            self.dnconv3 = nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1)  # b, 1, 28, 28
        else:
            ## `NotImplemented` is a constant, not an exception: calling it
            ## raised an unrelated TypeError.
            raise NotImplementedError("Can only do LeNet on 28x28 or 32x32 for now.")
        self.feat_dim = feat_dim

        ## The keyword is `nlayers`; `layers=` was swallowed by **kwargs so
        ## nlayers_mlp never had any effect.
        self.mlp = MLPPushforward(input_size = feat_dim, nlayers = nlayers_mlp)

    def forward(self, x):
        x = x.reshape(-1, self.channels, self.input_size, self.input_size)
        x = F.max_pool2d(F.relu(self.upconv1(x)), 2, stride=2)
        x = F.max_pool2d(F.relu(self.upconv2(x)), 2, stride=1)
        _nonflat_shape = x.shape
        x = x.view(-1, self.num_flat_features(x))
        ## MLP acts on flattened features, then we restore the conv layout
        x = self.mlp(x).reshape(_nonflat_shape)
        x = F.relu(self.dnconv1(x))
        x = F.relu(self.dnconv2(x))
        x = torch.tanh(self.dnconv3(x))
        return x

    def num_flat_features(self, x):
        """ Number of entries per example (product of all non-batch dims). """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def reset_parameters(self):
        """ Re-initialise every direct child module. """
        ## Was `T.named_children()`: `T` is undefined here (NameError).
        for name, module in self.named_children():
            print('resetting ', name)
            module.reset_parameters()
def symsqrt_v1(A, func='symeig'):
    """Compute the square root of a symmetric positive (semi-)definite matrix.

    Works on a single matrix or on a batch (leading dimensions).

    Arguments:
        A (tensor): symmetric PSD matrix / batch of matrices.
        func (str): 'symeig' for a symmetric eigendecomposition, 'svd' for a
            singular value decomposition.

    Returns:
        tensor: R such that R @ R is (numerically) A.

    Raises:
        ValueError: if `func` is neither 'symeig' nor 'svd'.
    """
    ## https://github.com/pytorch/pytorch/issues/25481#issuecomment-576493693
    ## Recall that for Sym Real matrices, SVD, EVD coincide, |λ_i| = σ_i, so
    ## for PSD matrices, these are equal and coincide, so we can use either.
    if func == 'symeig':
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
            ## Tensor.symeig was removed from recent PyTorch releases.
            ## UPLO='U' mirrors symeig's default (upper=True).
            s, v = torch.linalg.eigh(A, UPLO='U')
        else:
            s, v = A.symeig(eigenvectors=True)
        ## Eigenvalues come back in ASCENDING order, whereas the truncation
        ## below keeps the leading components and therefore needs them
        ## DESCENDING (as svd returns them). Without this flip a
        ## rank-deficient input kept its smallest eigenpairs.
        s, v = s.flip(-1), v.flip(-1)
    elif func == 'svd':
        _, s, v = A.svd()                 # But this passes torch.autograd.gradcheck()
    else:
        raise ValueError("func must be 'symeig' or 'svd', got {!r}".format(func))

    ## truncate small components
    good = s > s.max(-1, True).values * s.size(-1) * torch.finfo(s.dtype).eps
    components = good.sum(-1)
    common = components.max()
    unbalanced = common != components.min()
    if common < s.size(-1):
        s = s[..., :common]
        v = v[..., :common]
        if unbalanced:
            good = good[..., :common]
    if unbalanced:
        ## batch entries of lower rank: zero out their negligible components
        s = s.where(good, torch.zeros((), device=s.device, dtype=s.dtype))
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)
def sqrtm_newton_schulz(A, numIters, reg=None, return_error=False, return_inverse=False):
    """ Matrix squareroot based on Newton-Schulz method

    Arguments:
        A (tensor): a single (d x d) matrix or a batch (b x d x d).
        numIters (int): number of Newton-Schulz iterations.
        reg (numeric, optional): if given, each matrix is only rescaled when
            reg * ||A||_F exceeds 1 (and then by that amount); otherwise every
            matrix is rescaled by its Frobenius norm.
        return_error (bool): also return compute_error(A, sqrt(A)).
        return_inverse (bool): also return the inverse square root.

    Returns:
        sA, optionally followed by sAinv and/or the error, in that order.
    """
    batched = A.ndim > 2
    if not batched:  # Non-batched mode
        A = A.unsqueeze(0)
    batchSize = A.shape[0]
    dim = A.shape[1]
    normA = (A**2).sum((-2,-1)).sqrt() # Slightly faster than : A.mul(A).sum((-2,-1)).sqrt()

    if reg:
        ## Renormalize so that the each matrix has a norm lesser than 1/reg,
        ## but only normalize when necessary. Pure torch (no cupy/numpy
        ## indexing) and out-of-place, so it is device- and autograd-safe.
        scaled = normA * reg
        renorm = torch.where(scaled > 1.0, scaled, torch.ones_like(scaled))
    else:
        renorm = normA
    renorm = renorm.view(batchSize, 1, 1)

    Y = A / renorm
    ## Identity must share A's dtype and device, otherwise bmm fails for
    ## e.g. double-precision input.
    I = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).repeat(batchSize, 1, 1)
    Z = I
    for _ in range(numIters):
        T = 0.5*(3.0*I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)

    ## Undo exactly the scaling applied above (renorm, not normA: they
    ## differ whenever `reg` is used).
    scale = torch.sqrt(renorm)
    sA = Y * scale
    sAinv = Z / scale

    ## The error needs the batched (3-D) tensors, so compute it before the
    ## batch dimension is dropped.
    error = compute_error(A, sA) if return_error else None

    if not batched:
        sA = sA[0,:,:]
        sAinv = sAinv[0,:,:]

    if return_inverse and return_error:
        return sA, sAinv, error
    if return_inverse:
        return sA, sAinv
    if return_error:
        return sA, error
    return sA
155 | 156 | """ 157 | @staticmethod 158 | def forward(ctx, input, method = 'numpy'): 159 | _dev = input.device 160 | if method == 'numpy': 161 | m = input.cpu().detach().numpy().astype(np.float_) 162 | sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).type_as(input) 163 | elif method == 'pytorch': 164 | sqrtm = symsqrt(input) 165 | ctx.save_for_backward(sqrtm) 166 | return sqrtm 167 | 168 | @staticmethod 169 | def backward(ctx, grad_output, method = 'numpy'): 170 | grad_input = None 171 | if ctx.needs_input_grad[0]: 172 | sqrtm, = ctx.saved_tensors 173 | if method == 'numpy': 174 | sqrtm = sqrtm.data.numpy().astype(np.float_) 175 | gm = grad_output.data.numpy().astype(np.float_) 176 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) 177 | grad_input = torch.from_numpy(grad_sqrtm).type_as(grad_output.data) 178 | elif method == 'pytorch': 179 | grad_input = special_sylvester(sqrtm, grad_output) 180 | return grad_input 181 | 182 | 183 | ## ========================================================================== ## 184 | ## NOTE: Must pick which version of matrix square root to use!!!! 
def main():
    """ Smoke test: checks that symsqrt_v1 squares back to its input with both
    decompositions and runs autograd gradchecks on them. Prints the results. """
    from torch.autograd import gradcheck

    k = torch.randn(5, 20, 20).double()
    M = k @ k.transpose(-1,-2)

    s1 = symsqrt_v1(M, func='symeig')
    test = torch.allclose(M, s1 @ s1.transpose(-1,-2))
    print('Via symeig:', test)

    s2 = symsqrt_v1(M, func='svd')
    test = torch.allclose(M, s2 @ s2.transpose(-1,-2))
    print('Via svd: ', test)

    print('Sqrtm with symeig and svd match:', torch.allclose(s1,s2))

    M.requires_grad = True

    ## Check gradients for symsqrt
    ## (`symsqrt` is not defined in this module; symsqrt_v1 is the function
    ## exercised above and accepts the same `func` argument.)
    _sqrt = partial(symsqrt_v1, func='svd')
    test = gradcheck(_sqrt, (M,))
    print('Grach Check for sqrtm/svd:', test)

    ## Check symeig itself
    S = torch.rand(5,20,20, requires_grad=True).double()
    def func(S):
        x = 0.5 * (S + S.transpose(-2, -1))
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
            ## torch.symeig no longer exists in recent PyTorch releases
            return tuple(torch.linalg.eigh(x, UPLO='U'))
        return torch.symeig(x, eigenvectors=True)
    print('Grad check for symeig', gradcheck(func, [S]))

    ## Check gradients for symsqrt with symeig
    _sqrt = partial(symsqrt_v1, func='symeig')
    test = gradcheck(_sqrt, (M,))
    print('Grach Check for sqrtm/symeig:', test)
def bures_distance(Σ1, Σ2, sqrtΣ1=None, commute=False, squared=True):
    """ Bures distance between PSD matrices. Simple, non-batch version.
        Potentially deprecated.

    Arguments:
        Σ1, Σ2 (tensor): the two (square) matrices.
        sqrtΣ1 (tensor, optional): precomputed square root of Σ1; computed
            with `sqrtm` when missing. Only used if `commute` is False.
        commute (bool): use the simplified expression
            || sqrt(Σ1) - sqrt(Σ2) ||_F^2.
        squared (bool): return the squared distance (default) or its root.

    Returns:
        scalar tensor, clamped to be non-negative.
    """
    if not commute:
        if sqrtΣ1 is None:
            sqrtΣ1 = sqrtm(Σ1)
        cross = sqrtm(torch.mm(torch.mm(sqrtΣ1, Σ2), sqrtΣ1))
        bures = torch.trace(Σ1 + Σ2 - 2 * cross)
    else:
        bures = ((sqrtm(Σ1) - sqrtm(Σ2))**2).sum()
    ## Clamp BEFORE the root: round-off can make the squared value slightly
    ## negative, and sqrt of that is NaN (which relu does not repair).
    bures = torch.relu(bures)  # i.e., max(bures,0)
    if not squared:
        bures = torch.sqrt(bures)
    return bures
def pwdist_gauss(M1, S1, M2, S2, symmetric=False, return_dmeans=False, nworkers=1,
                 commute=False):
    """ POTENTIALLY DEPRECATED.
        Computes Wasserstein Distance between collections of Gaussians,
        represented in terms of their means (M1,M2) and Covariances (S1,S2).

    Arguments:
        M1, M2: means of the Gaussians on each side (indexable by cluster).
        S1, S2: matching covariances.
        symmetric (bool): if True only pairs (i, j) with i < j taken from the
            first collection are computed, and mirrored into D[j, i].
        return_dmeans (bool): also return torch.cdist(M1, M2).
        nworkers (int): if > 1, pairs are evaluated through joblib with the
            threading backend.
        commute (bool): forwarded to wasserstein_gauss_distance.

    Returns:
        D (tensor): n1 x n2 matrix of squared distances, and additionally
        D_means if `return_dmeans` is set.
    """
    n1, n2 = len(M1), len(M2)  # Number of clusters in each side
    if symmetric:
        ## If tasks are symmetric (same data on both sides) only need combinations
        pairs = list(itertools.combinations(range(n1), 2))
    else:
        ## If tasks are asymmetric, need n1 x n2 comparisons
        pairs = list(itertools.product(range(n1), range(n2)))

    ## The original referenced an undefined `device` here (NameError).
    ## Follow the means instead; None falls back to torch's defaults when
    ## M1 is not a tensor.
    D = torch.zeros((n1, n2), device=getattr(M1, 'device', None),
                    dtype=getattr(M1, 'dtype', None))

    if nworkers > 1:
        results = Parallel(n_jobs=nworkers, verbose=1, backend="threading")(
            delayed(wasserstein_gauss_distance)(M1[i], M2[j], S1[i], S2[j],
                                                squared=True, commute=commute)
            for i, j in pairs)
        for (i, j), d in zip(pairs, results):
            D[i, j] = d
            if symmetric:
                D[j, i] = D[i, j]
    else:
        for i, j in tqdm(pairs, leave=False):
            D[i, j] = wasserstein_gauss_distance(
                M1[i], M2[j], S1[i], S2[j], squared=True, commute=commute)
            if symmetric:
                D[j, i] = D[i, j]

    if return_dmeans:
        D_means = torch.cdist(M1, M2)  # For viz purposes only
        return D, D_means
    else:
        return D
Saves 137 | computation by precomputing and storing covariance square roots.""" 138 | if M2 is None: 139 | symmetric = True 140 | M2, S2 = M1, S1 141 | 142 | n1, n2 = len(M1), len(M2) # Number of clusters in each side 143 | if symmetric: 144 | ## If tasks are symmetric (same data on both sides) only need combinations 145 | pairs = list(itertools.combinations(range(n1), 2)) 146 | else: 147 | ## If tasks are assymetric, need n1 x n2 comparisons 148 | pairs = list(itertools.product(range(n1), range(n2))) 149 | 150 | D = torch.zeros((n1, n2), device = device, dtype=M1.dtype) 151 | 152 | sqrtS = [] 153 | ## Note that we need inverses of only one of two datasets. 154 | ## If sqrtS of S1 provided, use those. If S2 provided, flip roles of covs in Bures 155 | both_sqrt = (sqrtS1 is not None) and (sqrtS2 is not None) 156 | if (both_sqrt and sqrt_pref==0) or (sqrtS1 is not None): 157 | ## Either both were provided and S1 (idx=0) is prefered, or only S1 provided 158 | flip = False 159 | sqrtS = sqrtS1 160 | elif sqrtS2 is not None: 161 | ## S1 wasn't provided 162 | if sqrt_pref == 0: logger.warning('sqrt_pref=0 but S1 not provided!') 163 | flip = True 164 | sqrtS = sqrtS2 # S2 playes role of S1 165 | elif len(S1) <= len(S2): # No precomputed squareroots provided. Compute, but choose smaller of the two! 166 | flip = False 167 | S = S1 168 | else: 169 | flip = True 170 | S = S2 # S2 playes role of S1 171 | 172 | if not sqrtS: 173 | logger.info('Precomputing covariance matrix square roots...') 174 | for i, Σ in tqdm(enumerate(S), leave=False): 175 | if diagonal_cov: 176 | assert Σ.ndim == 1 177 | sqrtS.append(torch.sqrt(Σ)) # This is actually not needed. 
def pwdist_means_only(M1, M2=None, symmetric=False, device=None):
    """ Pairwise Euclidean distances between cluster means.

    Arguments:
        M1 (tensor): means of the first side.
        M2 (tensor): means of the second side. Ignored (replaced by M1) when
            it is None or when ``symmetric`` is True.
        symmetric (bool): compare M1 against itself.
        device: if not None, the result is moved to this device. An explicit
            ``is not None`` check is used so that the CUDA index ``0`` is
            honored instead of being treated as "no device".

    Returns:
        tensor: ``torch.cdist(M1, M2)`` (distances are not squared).
    """
    if M2 is None or symmetric:
        M2 = M1
    D = torch.cdist(M1, M2)
    if device is not None:
        D = D.to(device)
    return D
def pwdist_upperbound(M1, S1, M2=None, S2=None,symmetric=False, means_only=False,
                      diagonal_cov=False, commute=False, device=None,
                      return_dmeans=False):
    """ Computes upper bound of the Wasserstein distance between distributions
    with given mean and covariance.

    For each pair (i, j) the bound is ``||M1[i] - M2[j]||^2 + tr(S1[i] + S2[j])``.

    Arguments:
        M1, S1: means and covariances of the clusters on the first side.
        M2, S2: same for the second side. If M2 is None, the first side is
            compared against itself and ``symmetric`` is forced to True.
        symmetric (bool): only i < j pairs are computed and mirrored; the
            diagonal is left at zero.
        means_only (bool): skip covariances and return ``torch.cdist(M1, M2)``.
            NOTE: that matrix holds non-squared distances and is not moved to
            ``device``, unlike the full bound.
        diagonal_cov (bool): covariances are 1-D vectors (diagonals), so the
            trace is their plain sum.
        commute: accepted for signature compatibility; not used here.
        device: device on which the bound matrix is allocated.
        return_dmeans (bool): also return ``torch.cdist(M1, M2)``.

    Returns:
        ``D``, or ``(D, D_means)`` when ``return_dmeans`` is True.
    """

    if M2 is None:
        symmetric = True
        M2, S2 = M1, S1

    logger.info('Computing gaussian-to-gaussian wasserstein distances...')

    # Computed once and reused, both as the means-only result and for viz.
    D_means = torch.cdist(M1, M2) if (means_only or return_dmeans) else None

    if means_only:
        # Keep D and D_means distinct objects when both are handed back.
        D = D_means.clone() if return_dmeans else D_means
    else:
        n1, n2 = len(M1), len(M2)  # Number of clusters in each side
        if symmetric:
            ## If tasks are symmetric (same data on both sides) only need combinations
            pairs = list(itertools.combinations(range(n1), 2))
        else:
            ## If tasks are asymmetric, need n1 x n2 comparisons
            pairs = list(itertools.product(range(n1), range(n2)))

        D = torch.zeros((n1, n2), device=device, dtype=M1.dtype)

        pbar = tqdm(pairs, leave=False)
        pbar.set_description('Computing label-to-label distances')
        for i, j in pbar:
            sq_mean_dist = ((M1[i] - M2[j]) ** 2).sum(axis=-1)
            cov_sum = S1[i] + S2[j]
            if diagonal_cov and cov_sum.ndim == 1:
                trace = cov_sum.sum(-1)
            else:
                trace = cov_sum.diagonal(dim1=-2, dim2=-1).sum(-1)
            D[i, j] = sq_mean_dist + trace
            if symmetric:
                D[j, i] = D[i, j]

    if return_dmeans:
        return D, D_means
    return D
def pwdist_exact(X1, Y1, X2=None, Y2=None, symmetric=False, loss='sinkhorn',
                 cost_function='euclidean', p=2, debias=True, entreg=1e-1, device='cpu'):

    """ Efficient computation of pairwise label-to-label Wasserstein distances
    between multiple distributions, without using Gaussian assumption.

    Args:
        X1,X2 (tensor): n x d matrix with features
        Y1,Y2 (tensor): labels corresponding to samples
        symmetric (bool): whether X1/Y1 and X2/Y2 are to be treated as the same dataset
        loss (str): 'sinkhorn' (geomloss.SamplesLoss) or 'wasserstein'
            (ot.emd2 with uniform weights).
        cost_function (callable/string): the 'ground metric' between features to
            be used in optimal transport problem. If callable, should follow
            the convention of the cost argument in geomloss.SamplesLoss
        p (int): power of the cost (i.e. order of p-Wasserstein distance). Ignored
            if cost_function is a callable.
        debias (bool): Only relevant for Sinkhorn. If true, uses debiased sinkhorn
            divergence.
        entreg (float): Only relevant for Sinkhorn. Passed to SamplesLoss as
            ``blur=entreg**(1 / p)``.
        device: passed through ``process_device_arg``; class subsets are moved
            to it before each distance computation.

    Returns:
        tensor: n1 x n2 matrix of distances between the unique labels of Y1
        and Y2. In the symmetric case the diagonal is left at zero.

    Raises:
        ValueError: unsupported ``p`` for the euclidean cost, or unknown ``loss``.
        SystemExit: if a distance computation fails (the cause is logged).
    """
    device = process_device_arg(device)
    if X2 is None:
        symmetric = True
        X2, Y2 = X1, Y1

    c1 = torch.unique(Y1)
    c2 = torch.unique(Y2)
    n1, n2 = len(c1), len(c2)

    ## We account for the possibility that labels are shifted (c1[0]!=0), see below

    if symmetric:
        ## If tasks are symmetric (same data on both sides) only need combinations
        pairs = list(itertools.combinations(range(n1), 2))
    else:
        ## If tasks are asymmetric, need n1 x n2 comparisons
        pairs = list(itertools.product(range(n1), range(n2)))

    if cost_function == 'euclidean':
        if p == 1:
            cost_function = lambda x, y: geomloss.utils.distances(x, y)
        elif p == 2:
            cost_function = lambda x, y: geomloss.utils.squared_distances(x, y)
        else:
            raise ValueError(
                "cost_function='euclidean' supports only p=1 or p=2, got p={!r}".format(p))

    if loss == 'sinkhorn':
        distance = geomloss.SamplesLoss(
            loss=loss, p=p,
            cost=cost_function,
            debias=debias,
            blur=entreg**(1 / p),
        )
    elif loss == 'wasserstein':
        def distance(Xa, Xb):
            C = cost_function(Xa, Xb).cpu()
            return torch.tensor(ot.emd2(ot.unif(Xa.shape[0]), ot.unif(Xb.shape[0]), C))
    else:
        raise ValueError('Wrong loss')

    logger.info('Computing label-to-label (exact) wasserstein distances...')
    pbar = tqdm(pairs, leave=False)
    pbar.set_description('Computing label-to-label distances')
    D = torch.zeros((n1, n2), device=device, dtype=X1.dtype)
    for i, j in pbar:
        try:
            # Index through c1/c2 so shifted or non-contiguous labels work.
            D[i, j] = distance(X1[Y1 == c1[i]].to(device), X2[Y2 == c2[j]].to(device)).item()
        except Exception:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit pass
            # through, and record the real cause before aborting.
            logger.exception('Distance computation failed for label pair (%d, %d)', i, j)
            print("This is awkward. Distance computation failed. Geomloss is hard to debug. "
                  "But here's a few things that might be happening: "
                  " 1. Too many samples with this label, causing memory issues"
                  " 2. Datatype errors, e.g., if the two datasets have different type")
            sys.exit('Distance computation failed. Aborting.')
        if symmetric:
            D[j, i] = D[i, j]
    return D
BatchNormalization\n", 45 | "from tensorflow.keras.layers import Input, Flatten, Dense, LeakyReLU, Dropout, Activation\n", 46 | "from tensorflow.keras.models import Sequential\n", 47 | "from tensorflow.keras.layers import Conv2D, MaxPooling2D\n", 48 | "from tensorflow.keras import backend as K\n", 49 | "from tensorflow.keras import regularizers\n", 50 | "from tensorflow.keras import initializers\n", 51 | "import tensorflow.keras.utils as np_utils\n", 52 | "import tensorflow\n", 53 | "\n", 54 | "from utility_functions import *\n", 55 | "from pubfig83_utility import *\n", 56 | "from deepset import *\n", 57 | "\n", 58 | "import argparse\n", 59 | "\n", 60 | "\n", 61 | "def getSelectedAcc(rank, target_size, x_train, y_train, x_val, y_val, utilityFunc, interval=50):\n", 62 | " if rank is None or utilityFunc is None:\n", 63 | " return None\n", 64 | " ret = np.zeros(int(target_size/interval))\n", 65 | " for i in range(1, int(target_size/interval)+1):\n", 66 | " ret[i-1] = utilityFunc(x_train[rank[::-1][interval*(i):]], y_train[rank[::-1][interval*(i):]], x_val, y_val)\n", 67 | " return ret\n", 68 | "\n", 69 | "print(tf.config.list_physical_devices('GPU'))\n", 70 | "\n", 71 | "parser = argparse.ArgumentParser('')\n", 72 | "\n", 73 | "parser.add_argument('--attack_type', type=str)\n", 74 | "\n", 75 | "parser.add_argument('--lc', action='store_true')\n", 76 | "parser.add_argument('--sv', action='store_true')\n", 77 | "parser.add_argument('--loo', action='store_true')\n", 78 | "parser.add_argument('--tmc', action='store_true')\n", 79 | "parser.add_argument('--gshap', action='store_true')\n", 80 | "parser.add_argument('--inf', action='store_true')\n", 81 | "parser.add_argument('--tracin_self', action='store_true')\n", 82 | "parser.add_argument('--tracin_clean', action='store_true')\n", 83 | "parser.add_argument('--knn', action='store_true')\n", 84 | "parser.add_argument('--deepset', action='store_true')\n", 85 | "parser.add_argument('--random', action='store_true')\n", 86 | 
"\n", 87 | "args = parser.parse_args()\n", 88 | "\n", 89 | "rank_coll_dir = 'rank-collections/'\n", 90 | "\n", 91 | "\n", 92 | "post_fix = ''\n", 93 | "if args.lc:\n", 94 | " post_fix += 'LC_'\n", 95 | "if args.sv:\n", 96 | " post_fix += 'SV_'\n", 97 | "if args.loo:\n", 98 | " post_fix += 'LOO_'\n", 99 | "if args.tmc:\n", 100 | " post_fix += 'TMC_'\n", 101 | "if args.gshap:\n", 102 | " post_fix += 'GSHAP_'\n", 103 | "if args.inf:\n", 104 | " post_fix += 'INF_'\n", 105 | "if args.tracin_self:\n", 106 | " post_fix += 'TRACINSELF_'\n", 107 | "if args.tracin_clean:\n", 108 | " post_fix += 'TRACINCLEAN_'\n", 109 | "if args.knn:\n", 110 | " post_fix += 'KNN_'\n", 111 | "if args.random:\n", 112 | " post_fix += 'RANDOM_'\n", 113 | "if args.deepset:\n", 114 | " post_fix += 'DEEPSET'\n", 115 | "\n", 116 | "\n", 117 | "print('POST FIX:', post_fix)\n", 118 | "\n", 119 | "attack_type = args.attack_type\n", 120 | "\n", 121 | "\n", 122 | "\n", 123 | "if attack_type == 'CIFAR_Backdoor':\n", 124 | "\n", 125 | " data_file = 'backdoor_cifar_N50000_poison2500_val1000.data'\n", 126 | " x_train_few, y_train_few, x_val_poi, y_val_poi, x_val, y_val = pickle.load( \n", 127 | " open('low-quality-data/'+data_file, 'rb') )\n", 128 | "\n", 129 | " y_train_few = y_train_few.reshape(-1)\n", 130 | " y_val_poi = y_val_poi.reshape(-1)\n", 131 | " y_val = y_val.reshape(-1)\n", 132 | "\n", 133 | " n_data_deepset = 2000\n", 134 | " n_epoch = 16\n", 135 | " n_set = 64\n", 136 | " n_hext = 64\n", 137 | " n_hreg = 64\n", 138 | " LR = 1e-4\n", 139 | "\n", 140 | " deepset_dir = 'saved-deepset/Rebuttal_Backdoor_CIFAR_N{}_DataSeed100_Nepoch{}_Nset{}_Next{}_Nreg{}_LR{}.state_dict'.format(\n", 141 | " n_data_deepset, n_epoch, n_set, n_hext, n_hreg, LR)\n", 142 | "\n", 143 | " func_data_to_acc = None\n", 144 | " # func_data_to_atkacc = torch_cifar_smallCNN_data_to_acc_multiple\n", 145 | " func_data_to_atkacc = torch_cifar_vit_data_to_acc\n", 146 | "\n", 147 | "\n", 148 | "elif attack_type == '...'\n", 149 | 
"\n", 150 | "\n", 151 | "\n", 152 | "acc_coll_dir = rank_coll_dir+'ICMLRebuttal2_{}_{}.acccoll'.format(attack_type, post_fix)\n", 153 | "\n", 154 | "#x_train_few, y_train_few = x_train_few[:n_data_deepset], y_train_few[:n_data_deepset]\n", 155 | "\n", 156 | "n_data = x_train_few.shape[0]\n", 157 | "target_size = x_train_few.shape[0]\n", 158 | "\n", 159 | "\n", 160 | "if args.inf:\n", 161 | "\n", 162 | " [score] = pickle.load( open(\"saved-samples/ICML_{}_INF.sample\".format(attack_type), 'rb') )\n", 163 | " inf_rank = np.argsort(score)\n", 164 | "\n", 165 | "if args.tracin_self:\n", 166 | "\n", 167 | " [score] = pickle.load( open(\"saved-samples/ICML_{}_TRACINSELF.sample\".format(attack_type), 'rb') )\n", 168 | " tracinself_rank = np.argsort(score)[::-1]\n", 169 | "\n", 170 | "if args.tracin_clean:\n", 171 | "\n", 172 | " [score] = pickle.load( open(\"saved-samples/ICML_{}_TRACINCLEAN.sample\".format(attack_type), 'rb') )\n", 173 | " tracinclean_rank = np.argsort(score)\n", 174 | "\n", 175 | "if args.knn:\n", 176 | "\n", 177 | " [score] = pickle.load( open(\"saved-samples/ICML_{}_KNN.sample\".format(attack_type), 'rb') )\n", 178 | " knn_rank = np.argsort(score)\n", 179 | "\n", 180 | "\n", 181 | "\n", 182 | "\n", 183 | "if not args.lc:\n", 184 | " lc_rank = None\n", 185 | "if not args.sv:\n", 186 | " sv_rank = None\n", 187 | "if not args.loo:\n", 188 | " loo_rank = None\n", 189 | "if not args.tmc:\n", 190 | " tmc_rank = None\n", 191 | "if not args.gshap:\n", 192 | " gshap_rank = None\n", 193 | "if not args.inf:\n", 194 | " inf_rank = None\n", 195 | "if not args.tracin_self:\n", 196 | " tracinself_rank = None\n", 197 | "if not args.tracin_clean:\n", 198 | " tracinclean_rank = None\n", 199 | "if not args.knn:\n", 200 | " knn_rank = None\n", 201 | "if not args.deepset:\n", 202 | " deepset_rank = None\n", 203 | "if not args.random:\n", 204 | " random_rank = None\n", 205 | "\n", 206 | "\n", 207 | "if args.deepset:\n", 208 | "\n", 209 | " if attack_type=='CIFAR_Mislabel' 
or attack_type=='CIFAR_Poison' or attack_type=='CIFAR_Noisy':\n", 210 | " opt = tf.train.AdamOptimizer()\n", 211 | " extractor = CifarExtractor.build(numChannels=3, imgRows=32, imgCols=32, numClasses=10)\n", 212 | " extractor.compile(loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"])\n", 213 | " load_status = extractor.load_weights('cifar_featureExtractor.ckpt')\n", 214 | " lastLayerOp = K.function([extractor.layers[0].input], [extractor.layers[-5].output])\n", 215 | " x_train_few_cnnFeature = extractFeatures(lastLayerOp, x_train_few)\n", 216 | "\n", 217 | " print('FeatureExtractor', x_train_few_cnnFeature.shape)\n", 218 | "\n", 219 | " n_cls = 10\n", 220 | " eps = 1e-3\n", 221 | "\n", 222 | " if len(y_train_few.shape) < 2 or y_train_few.shape[1]!=10:\n", 223 | " y_train_few_hot = np_utils.to_categorical(y_train_few, 10)\n", 224 | " else:\n", 225 | " y_train_few_hot = y_train_few\n", 226 | "\n", 227 | " elif attack_type[:5]=='CIFAR' or attack_type[:6]=='DOGCAT':\n", 228 | "\n", 229 | " extractor_savename = 'saved-deepset/Rebuttal_Backdoor_CIFAR_Extractor.state_dict'\n", 230 | "\n", 231 | " extractor = SmallCNN_CIFAR().cuda()\n", 232 | " extractor.load_state_dict( torch.load( extractor_savename ) )\n", 233 | "\n", 234 | " if x_train_few.shape[1]==32:\n", 235 | " x_train_few_cf = np.moveaxis(x_train_few, 3, 1)\n", 236 | " tensor_x = torch.Tensor(x_train_few_cf).cuda()\n", 237 | " x_train_few_cnnFeature = extractor.getFeature(tensor_x)\n", 238 | " x_train_few_cnnFeature = x_train_few_cnnFeature.cpu().detach().numpy()\n", 239 | "\n", 240 | " print('FeatureExtractor', x_train_few_cnnFeature.shape)\n", 241 | "\n", 242 | " n_cls = 10\n", 243 | " eps = 1e-3\n", 244 | "\n", 245 | " if len(y_train_few.shape) < 2 or y_train_few.shape[1]!=10:\n", 246 | " y_train_few_hot = np_utils.to_categorical(y_train_few, 10)\n", 247 | " else:\n", 248 | " y_train_few_hot = y_train_few\n", 249 | "\n", 250 | "\n", 251 | " elif attack_type[:4]=='SPAM':\n", 252 | "\n", 
253 | " from sklearn.feature_selection import SelectPercentile\n", 254 | " from sklearn.feature_selection import chi2\n", 255 | " selector = SelectPercentile(score_func=chi2, percentile=10)\n", 256 | " x_train_few_cnnFeature = selector.fit_transform(x_train_clean, y_train_clean)\n", 257 | " x_train_few_cnnFeature = x_train_few_cnnFeature[train_ind]\n", 258 | " n_cls = 2\n", 259 | " y_train_few_hot = y_train_few.reshape((len(y_train_few), 1))\n", 260 | "\n", 261 | " eps = 1e-3\n", 262 | "\n", 263 | "\n", 264 | " elif attack_type[:5]=='MNIST':\n", 265 | " opt = tf.train.AdamOptimizer()\n", 266 | " extractor = KerasLeNet.build(numChannels=1, imgRows=28, imgCols=28, numClasses=10)\n", 267 | " extractor.compile(loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"])\n", 268 | " load_status = extractor.load_weights('mnist_featureExtractor.h5')\n", 269 | " lastLayerOp = K.function([extractor.layers[0].input], [extractor.layers[6].output])\n", 270 | " x_train_few_cnnFeature = extractFeatures(lastLayerOp, x_train_few)\n", 271 | " n_cls = 10\n", 272 | " eps = 1e-3\n", 273 | "\n", 274 | " if len(y_train_few.shape) < 2 or y_train_few.shape[1]!=10:\n", 275 | " y_train_few_hot = np_utils.to_categorical(y_train_few, 10)\n", 276 | " else:\n", 277 | " y_train_few_hot = y_train_few\n", 278 | "\n", 279 | "\n", 280 | " elif attack_type[:6]=='PUBFIG':\n", 281 | " opt = tf.train.AdamOptimizer()\n", 282 | " extractor = CifarExtractor.build(numChannels=3, imgRows=32, imgCols=32, numClasses=83)\n", 283 | " extractor.compile(loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"])\n", 284 | " load_status = extractor.load_weights('pubfig_featureExtractor.h5')\n", 285 | " lastLayerOp = K.function([extractor.layers[0].input], [extractor.layers[-5].output])\n", 286 | " x_train_few_cnnFeature = extractFeatures(lastLayerOp, x_train_few)\n", 287 | "\n", 288 | " n_cls = 83\n", 289 | " y_train_few_hot = np_utils.to_categorical(y_train_few)\n", 290 | " eps = 
1e-3\n", 291 | " \n", 292 | "\n", 293 | "\n", 294 | "\n", 295 | "\n", 296 | "deepset_rank_coll = np.zeros((10, n_data))\n", 297 | "random_rank_coll = np.zeros((10, n_data))\n", 298 | "acc_coll = {}\n", 299 | "atkacc_coll = {}\n", 300 | "\n", 301 | "acc_coll['deepset'] = []\n", 302 | "acc_coll['lc'] = []\n", 303 | "acc_coll['sv'] = []\n", 304 | "acc_coll['loo'] = []\n", 305 | "acc_coll['tmc'] = []\n", 306 | "acc_coll['gshap'] = []\n", 307 | "acc_coll['inf'] = []\n", 308 | "acc_coll['tracinself'] = []\n", 309 | "acc_coll['tracinclean'] = []\n", 310 | "acc_coll['knn'] = []\n", 311 | "acc_coll['random'] = []\n", 312 | "\n", 313 | "atkacc_coll['deepset'] = []\n", 314 | "atkacc_coll['lc'] = []\n", 315 | "atkacc_coll['sv'] = []\n", 316 | "atkacc_coll['loo'] = []\n", 317 | "atkacc_coll['tmc'] = []\n", 318 | "atkacc_coll['gshap'] = []\n", 319 | "atkacc_coll['inf'] = []\n", 320 | "atkacc_coll['tracinself'] = []\n", 321 | "atkacc_coll['tracinclean'] = []\n", 322 | "atkacc_coll['knn'] = []\n", 323 | "atkacc_coll['random'] = []\n", 324 | "\n", 325 | "\n", 326 | "\n", 327 | "\n", 328 | "# Load (x_bad, y_bad)\n", 329 | "# Load Data Utility Model\n", 330 | "\n", 331 | "\n", 332 | "\n", 333 | "\n", 334 | "\n", 335 | "for select_seed in range(10):\n", 336 | "#for select_seed in range(1):\n", 337 | "\n", 338 | " if args.deepset and attack_type[:6]!='PUBFIG':\n", 339 | "\n", 340 | " n_block = int(n_data/n_data_deepset)\n", 341 | " deepset_rank_matrix = np.zeros((n_block, n_data_deepset))\n", 342 | " random_perm = np.random.permutation(range(n_data))\n", 343 | "\n", 344 | " for i in range(n_block):\n", 345 | " \n", 346 | " random_ind = random_perm[n_data_deepset*(i):n_data_deepset*(i+1)]\n", 347 | "\n", 348 | " if attack_type == 'MNIST_Noisy':\n", 349 | "\n", 350 | " _, deepset_rank_small, _ = findMostValuableSample_deepset_stochasticgreedy_OLD(deepset_model.model, \n", 351 | " x_train_few_cnnFeature[random_ind], \n", 352 | " n_data_deepset, epsilon=eps, seed=select_seed)\n", 353 | " 
else:\n", 354 | " _, deepset_rank_small, _ , _ = findMostValuableSample_deepset_stochasticgreedy(deepset_model.model, \n", 355 | " x_train_few_cnnFeature[random_ind], \n", 356 | " y_train_few_hot[random_ind], \n", 357 | " n_data_deepset, epsilon=eps, seed=select_seed)\n", 358 | "\n", 359 | " deepset_rank_matrix[i] = random_ind[deepset_rank_small]\n", 360 | "\n", 361 | " deepset_rank = ((deepset_rank_matrix.T).reshape(-1)).astype(int)\n", 362 | " deepset_rank_coll[select_seed] = deepset_rank\n", 363 | "\n", 364 | " if args.random:\n", 365 | " random_rank = np.random.permutation(range(n_data))\n", 366 | " random_rank_coll[select_seed] = random_rank\n", 367 | "\n", 368 | " if args.deepset and attack_type[:6]=='PUBFIG':\n", 369 | " deepset_rank = np.random.permutation(range(n_data))\n", 370 | " deepset_rank_coll[select_seed] = deepset_rank\n", 371 | "\n", 372 | "\n", 373 | " half_size = int(target_size / 2)\n", 374 | " interval = int(half_size / 10)\n", 375 | " # interval = int(half_size / 5)\n", 376 | "\n", 377 | " print('Select Seed', select_seed)\n", 378 | "\n", 379 | " deepset_acc = getSelectedAcc(deepset_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval) \n", 380 | " \n", 381 | " lc_acc = getSelectedAcc(lc_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 382 | " sv_acc = getSelectedAcc(sv_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 383 | " loo_acc = getSelectedAcc(loo_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 384 | " tmc_acc = getSelectedAcc(tmc_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 385 | " gshap_acc = getSelectedAcc(gshap_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 386 | " inf_acc = getSelectedAcc(inf_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 387 | " 
tracinself_acc = getSelectedAcc(tracinself_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 388 | " tracinclean_acc = getSelectedAcc(tracinclean_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 389 | " \n", 390 | " knn_acc = getSelectedAcc(knn_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 391 | " \n", 392 | " # random_acc = getSelectedAcc(random_rank, half_size, x_train_few, y_train_few, x_val, y_val, func_data_to_acc, interval)\n", 393 | "\n", 394 | " acc_coll['deepset'].append(deepset_acc)\n", 395 | " acc_coll['lc'].append(lc_acc)\n", 396 | " acc_coll['sv'].append(sv_acc)\n", 397 | " acc_coll['loo'].append(loo_acc)\n", 398 | " acc_coll['tmc'].append(tmc_acc)\n", 399 | " acc_coll['gshap'].append(gshap_acc)\n", 400 | " acc_coll['inf'].append(inf_acc)\n", 401 | " acc_coll['tracinself'].append(tracinself_acc)\n", 402 | " acc_coll['tracinclean'].append(tracinclean_acc)\n", 403 | " acc_coll['knn'].append(knn_acc)\n", 404 | " acc_coll['random'].append(random_acc)\n", 405 | "\n", 406 | " deepset_acc = getSelectedAcc(deepset_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 407 | " lc_acc = getSelectedAcc(lc_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 408 | " sv_acc = getSelectedAcc(sv_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 409 | " loo_acc = getSelectedAcc(loo_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 410 | " tmc_acc = getSelectedAcc(tmc_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 411 | " gshap_acc = getSelectedAcc(gshap_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 412 | " inf_acc = getSelectedAcc(inf_rank, half_size, 
x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 413 | " tracinself_acc = getSelectedAcc(tracinself_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 414 | " tracinclean_acc = getSelectedAcc(tracinclean_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 415 | " knn_acc = getSelectedAcc(knn_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 416 | " random_acc = getSelectedAcc(random_rank, half_size, x_train_few, y_train_few, x_val_poi, y_val_poi, func_data_to_atkacc, interval)\n", 417 | "\n", 418 | " atkacc_coll['deepset'].append(deepset_acc)\n", 419 | " atkacc_coll['lc'].append(lc_acc)\n", 420 | " atkacc_coll['sv'].append(sv_acc)\n", 421 | " atkacc_coll['loo'].append(loo_acc)\n", 422 | " atkacc_coll['tmc'].append(tmc_acc)\n", 423 | " atkacc_coll['gshap'].append(gshap_acc)\n", 424 | " atkacc_coll['inf'].append(inf_acc)\n", 425 | " atkacc_coll['tracinself'].append(tracinself_acc)\n", 426 | " atkacc_coll['tracinclean'].append(tracinclean_acc)\n", 427 | " atkacc_coll['knn'].append(knn_acc)\n", 428 | " atkacc_coll['random'].append(random_acc)\n", 429 | "\n", 430 | " pickle.dump([deepset_rank_coll, random_rank_coll, acc_coll, atkacc_coll], open(acc_coll_dir, 'wb') )\n", 431 | "\n", 432 | " print('save!!!')\n", 433 | "\n", 434 | "\n", 435 | "\n", 436 | "\n", 437 | "\n", 438 | "\n", 439 | "\n", 440 | "\n" 441 | ] 442 | } 443 | ], 444 | "metadata": { 445 | "kernelspec": { 446 | "display_name": "Python 3", 447 | "language": "python", 448 | "name": "python3" 449 | }, 450 | "language_info": { 451 | "codemirror_mode": { 452 | "name": "ipython", 453 | "version": 3 454 | }, 455 | "file_extension": ".py", 456 | "mimetype": "text/x-python", 457 | "name": "python", 458 | "nbconvert_exporter": "python", 459 | "pygments_lexer": "ipython3", 460 | "version": "3.8.10" 461 | } 462 | }, 463 | "nbformat": 4, 464 
| "nbformat_minor": 5 465 | } 466 | -------------------------------------------------------------------------------- /otdd/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__init__.py -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/datasets.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/datasets.cpython-310.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/datasets.cpython-36.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/datasets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/datasets.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/datasets_2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/datasets_2.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/distance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/distance.cpython-38.pyc 
-------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/distance.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/distance.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/distance_2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/distance_2.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/distance_double.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/distance_double.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/distance_fast.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/distance_fast.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/distance_fast.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/distance_fast.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/moments.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/moments.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/moments.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/moments.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/nets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/nets.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/nets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/nets.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/sqrtm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/sqrtm.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/sqrtm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/sqrtm.cpython-39.pyc 
-------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/utils_2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/utils_2.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/wasserstein.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reds-lab/projektor/6da679ff9c43ef43819e3dcfd7189daebecfc431/otdd/pytorch/__pycache__/wasserstein.cpython-38.pyc -------------------------------------------------------------------------------- /otdd/pytorch/__pycache__/wasserstein.cpython-39.pyc: -------------------------------------------------------------------------------- 
class Functional():
    """Sum of up to three energy terms evaluated on a labelled dataset.

    Defines a JKO functional over measures implicitly by defining it over
    individual particles (points).

    The input should be a full dataset: points X (n x d) with labels Y (n x 1).
    Optionally, the means/variances associated with each class can be passed.

    Arguments:
        V (callable, optional): potential energy term, called as V(x, y, μ, Σ).
        W (callable, optional): interaction energy term, called as W(x, y, μ, Σ).
        f (callable, optional): internal energy term, called as f(x, y, μ, Σ).
        weights: accepted for interface compatibility; currently unused.
    """
    def __init__(self, V=None, W=None, f=None, weights=None):
        self.V = V  # The functional on Z space in the potential energy term
        self.W = W  # The bi-linear form on ZxZ in the interaction energy term
        self.f = f  # The scalar-valued function in the internal energy term

    def __call__(self, x, y, μ=None, Σ=None):
        """Return the sum of every configured term evaluated at (x, y, μ, Σ).

        Terms left as None are skipped; with no term configured the result is 0.
        """
        total = 0
        # Same evaluation order as before: internal, potential, interaction.
        for term in (self.f, self.V, self.W):
            if term is not None:
                total += term(x, y, μ, Σ)
        return total


################################################################################
#######    Potential energy functionals (denoted by V in the paper)    #########
################################################################################

def affine_feature_norm(X, Y=None, A=None, b=None, threshold=None, weight=1.0):
    """ A simple (feature-only) potential energy based on affine transform + norm:

            v(x,y) = || xA - b ||,     so that        V(ρ) = ∫|| xA - b || dρ(x,y)

    where the integral is approximated by empirical expectation (mean).

    Arguments:
        X (tensor): points, one per row (n x d).
        Y: labels; unused, kept so all functionals share a signature.
        A (tensor, optional): matrix right-multiplied onto X (``X @ A``).
            Identity when None.
        b (tensor, optional): offset subtracted after the linear map. Zero
            when None.
        threshold (float, optional): if truthy, row norms that are not
            strictly greater than it are replaced by 0.
        weight (float): multiplier applied to the mean norm.

    Returns:
        Scalar tensor ``weight * mean_i || x_i A - b ||``.
    """
    transformed = X if A is None else X @ A
    if b is not None:
        transformed = transformed - b
    norm = transformed.norm(dim=1)
    if threshold:
        norm = torch.nn.functional.threshold(norm, threshold, 0)
    return weight * norm.mean()
def binary_hyperplane_margin(X, Y, w, b, weight=1.0):
    """ A potential function based on margin separation according to a (given
    and fixed) hyperplane:

        v(x,y) = max(0, 1 - y(x'w - b) ), so that V(ρ) = ∫ max(0, 1 - y(x'w - b) ) dρ(x,y)

    Returns 0 if all points are at least 1 away from margin.

    Note that y is expected to be {0,1}; it is mapped to {-1,1} internally.

    Needs separation hyperplane be determined by (w, b) parameters.

    Arguments:
        X (tensor): points, one per row.
        Y (tensor): binary labels in {0,1}.
        w (tensor): hyperplane normal, multiplied as ``X @ w``.
        b (tensor or float): hyperplane offset.
        weight (float): multiplier applied to the mean hinge loss.
    """
    Y_hat = 2*Y - 1  # To map Y to {-1, 1}, required by the SVM-type margin obj we use
    margin = torch.relu(1 - Y_hat*(torch.matmul(X, w) - b))
    return weight*margin.mean()


def dimension_collapse(X, Y, dim=1, v=None, weight=1.0):
    """ Potential function to induce a dimension collapse.

    Penalises the mean squared deviation of feature column `dim` from the
    target value `v` (0 when `v` is None). `Y` is unused.
    """
    if v is None:
        v = 0
    deviation = (X[:, dim] - v)**2
    return weight*deviation.mean()


def cluster_repulsion(X, Y):
    """ Placeholder for a cluster-repulsion potential; not implemented.

    The previous body was a debugger breakpoint that referenced an unimported
    module, so calling it could only fail. It now fails explicitly.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("cluster_repulsion is not implemented")


################################################################################
########  Interaction energy functionals (denoted by W in the paper)   #########
################################################################################

def interaction_fun(X, Y, weight=1.0):
    """ Interaction energy that rewards distance between differently-labelled points.

    For every ordered pair (i, j) the feature difference x_i - x_j is kept only
    when y_i != y_j; the negated norm of that masked difference is averaged
    over all pairs, so minimising the value pushes different classes apart.

    Arguments:
        X (tensor): points, one per row (n x d).
        Y (tensor): labels (n,); cast to float and compared for equality.
        weight (float): multiplier applied to the result.

    Returns:
        Scalar tensor. It is 0 when all labels coincide.
    """
    Z = torch.cat((X, Y.float().unsqueeze(1)), -1)

    # Broadcasting yields diffs[i, j] = Z[i] - Z[j] without building the two
    # (n, n, d+1) copies that repeat() would allocate.
    diffs = Z.unsqueeze(1) - Z.unsqueeze(0)
    δx, δy = diffs[..., :-1], diffs[..., -1:]

    # 0/1 indicator of "labels differ". Comparing against zero (rather than
    # normalising by δy.max()) stays finite when every label is the same.
    differs = (δy != 0).to(δx.dtype)

    val = -(δx*differs).norm(dim=-1).mean(dim=-1)   # Enforces cluster repulsion
    return val.mean()*weight
def binary_cluster_margin(X, Y, μ=None, weight=1.0):
    """ Similar to binary_hyperplane_margin but does not require a separating
    hyperplane be provided in advance. Instead, it is based on the current
    datapoints, using the direction between their class means.

    For every ordered pair (i, j) the unit vector from x_j to x_i is compared
    (inner product) with the unit vector between the class means assigned to
    i and j. The penalty is exp(relu(1 - inner product)), averaged over pairs,
    so it requires point-to-point comparison (interaction).

    Arguments:
        X (tensor): points, one per row (n x d).
        Y (tensor): binary labels in {0,1}.
        μ: ignored; class means are always recomputed from X and Y.
        weight (float): multiplier applied to the mean penalty.

    Returns:
        Scalar tensor with the weighted mean penalty.
    """
    μ_0 = X[Y == 0].mean(0)
    μ_1 = X[Y == 1].mean(0)

    n, d = X.shape
    # diffs_x[i, j] = unit(x_i - x_j); zero vectors stay zero under normalize.
    diffs_x = torch.nn.functional.normalize(X.unsqueeze(1) - X.unsqueeze(0), dim=2, p=2)

    # Per-point class mean, allocated with X's dtype and device.
    centers = X.new_zeros(n, d)
    centers[Y == 0, :] = μ_0
    centers[Y == 1, :] = μ_1

    diffs_μ = torch.nn.functional.normalize(
        centers.unsqueeze(1) - centers.unsqueeze(0), dim=2, p=2)

    # Contract the SAME feature index on both operands: a true inner product.
    # ("ijk,ijl->ij" would multiply the two component sums instead.)
    inner_prod = torch.einsum("ijk,ijk->ij", diffs_x, diffs_μ)

    out = torch.relu(-inner_prod + 1)
    margin = torch.exp(out)
    return weight*margin.mean()


def cov(m, mean=None, rowvar=True, inplace=False):
    """ Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    Arguments:
        m (tensor): A 1-D or 2-D array containing multiple variables and observations.
            Each row of `m` represents a variable, and each column a single
            observation of all those variables.
        mean (tensor, optional): precomputed per-variable mean (1-D). When
            None it is computed from `m`.
        rowvar (bool): If `rowvar` is True, then each row represents a
            variable, with observations in the columns. Otherwise, the
            relationship is transposed: each column represents a variable,
            while the rows contain observations.
        inplace (bool): if True, the mean is subtracted from `m` in place.

    Returns:
        The covariance matrix of the variables (normalised by N - 1), with
        singleton dimensions squeezed out.

    Raises:
        ValueError: if `m` has more than 2 dimensions.
        ZeroDivisionError: if there is a single observation (N - 1 == 0).
    """
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    fact = 1.0 / (m.size(1) - 1)
    if mean is None:
        mean = torch.mean(m, dim=1, keepdim=True)
    else:
        mean = mean.unsqueeze(1)  # For broadcasting
    if inplace:
        m -= mean
    else:
        m = m - mean
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()
class OnlineStatsRecorder:
    """ Online batch estimation of multivariate sample mean and covariance matrix.

    Alleviates numerical instability due to catastrophic cancellation that
    the naive estimation suffers from.

    Two pass approach first computes population mean, and then uses stable
    one pass algorithm on residuals x' = (x - μ). Uses the fact that Cov is
    translation invariant, and less cancellation happens if E[XX'] and
    E[X]E[X]' are far apart, which is the case for centered data.

    Ideas from:
        - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
        - https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html
    """
    def __init__(self, data=None, twopass=True, centered_cov=False,
                diagonal_cov=False, embedding=None,
                device='cpu', dtype=torch.FloatTensor):
        """
        Arguments:
            data (torch tensor): batch of data of shape (nobservations, ndimensions).
                Accepted but not used by the constructor; feed data through
                `update` or `compute_from_loader`.
            twopass (bool): whether to use the two-pass approach (recommended)
            centered_cov (bool): whether covariance matrix is centered throughout
                                 the iterations. If false, centering happens once,
                                 at the end.
            diagonal_cov (bool): whether covariance matrix should be diagonal
                                 (i.e. ignore cross-correlation terms). In this
                                 case only diagonal (1xdim) tensor retrieved.
            embedding (callable): if provided, will map features using this
            device (str): device for storage of computed statistics
            dtype (torch data type): data type for computed statistics. If
                                 None, batches keep the dtype they arrive with.

        """
        self.device = device
        self.centered_cov = centered_cov
        self.diagonal_cov = diagonal_cov
        self.twopass = twopass
        self.dtype = dtype
        self.embedding = embedding

        self._init_values()

    def _init_values(self):
        """ Reset the running mean, covariance accumulator and sample count. """
        self.μ = None
        self.Σ = None
        self.n = 0

    def _prepare(self, x, device):
        """ Cast, move, optionally embed and flatten one batch to (batch, features). """
        # Tensor.type(None) returns the type *name* (a str), so the cast must
        # be skipped when no dtype was requested.
        if self.dtype is not None:
            x = x.type(self.dtype)
        x = x.to(device)
        if self.embedding is not None:
            x = self.embedding(x).detach()
        return x.view(x.shape[0], -1)

    def compute_from_loader(self, dataloader):
        """ Compute statistics from dataloader.

        The loader must yield (inputs, target) pairs; targets are ignored.
        With `twopass` the accumulators are reset after the first pass,
        `self.centered_cov` is switched off (and stays off), and the
        covariance is re-estimated on residuals x - μ.

        Returns:
            μ, Σ: mean from the first pass and covariance from the last pass.
        """
        device = process_device_arg(self.device)
        for x, _ in dataloader:
            self.update(self._prepare(x, device))
        μ, Σ = self.retrieve()
        if self.twopass:
            self._init_values()
            self.centered_cov = False
            for x, _ in dataloader:
                self.update(self._prepare(x, device) - μ)  # We compute cov on residuals
            _, Σ = self.retrieve()
        return μ, Σ

    def update(self, batch):
        """ Update statistics using batch of data.

        Arguments:
            batch (tensor): tensor of shape (nobservations, ndimensions)

        Raises:
            ValueError: if the feature dimension differs from earlier batches.
        """
        if self.n == 0:
            self.n, self.d = batch.shape
            self.μ = batch.mean(dim=0)
            if self.diagonal_cov and self.centered_cov:
                self.Σ = torch.var(batch, dim=0, unbiased=True)
                ## unbiased is default in pytorch, shown here just to be explicit
            elif self.diagonal_cov and not self.centered_cov:
                self.Σ = batch.pow(2).sum(dim=0)/(1.0*self.n-1)
            elif self.centered_cov:
                self.Σ = ((batch-self.μ).T).matmul(batch-self.μ)/(1.0*self.n-1)
            else:
                self.Σ = (batch.T).matmul(batch)/(1.0*self.n-1)
                ## note that this not really covariance yet (not centered)
        else:
            if batch.shape[1] != self.d:
                raise ValueError("Data dims don't match prev observations.")

            ### Dimensions
            m = self.n * 1.0
            n = batch.shape[0] * 1.0

            ### Mean Update
            self.μ = self.μ + (batch-self.μ).sum(dim=0)/(m+n)  # Stable Algo

            ### Cov Update
            # NOTE(review): the (m-1)/(m+n-1) weight on the batch term in the two
            # centered branches is kept exactly as found; it does not look like
            # the usual pooled-covariance update - confirm before relying on
            # centered_cov=True with twopass=False.
            if self.diagonal_cov and self.centered_cov:
                self.Σ = ((m-1)*self.Σ + ((m-1)/(m+n-1))*((batch-self.μ).pow(2).sum(dim=0)))/(m+n-1)
            elif self.diagonal_cov and not self.centered_cov:
                self.Σ = (m-1)/(m+n-1)*self.Σ + 1/(m+n-1)*(batch.pow(2).sum(dim=0))
            elif self.centered_cov:
                self.Σ = ((m-1)*self.Σ + ((m-1)/(m+n-1))*((batch-self.μ).T).matmul(batch-self.μ))/(m+n-1)
            else:
                self.Σ = (m-1)/(m+n-1)*self.Σ + 1/(m+n-1)*(batch.T).matmul(batch)

            ### Update total number of examples seen
            self.n += n

    def retrieve(self, verbose=False):
        """ Retrieve current statistics.

        When the accumulator is uncentered, the mean outer product (or squared
        mean, for diagonal covariance) is subtracted here.

        Returns:
            μ, Σ: current mean and covariance (1-D Σ when `diagonal_cov`).
        """
        if verbose: print('Mean and Covariance computed on {} samples'.format(int(self.n)))
        if self.centered_cov:
            return self.μ, self.Σ
        elif self.diagonal_cov:
            Σ = self.Σ - self.μ.pow(2)*self.n/(self.n-1)
            Σ = torch.nn.functional.relu(Σ)  # To avoid negative variances due to rounding
            return self.μ, Σ
        else:
            # torch.outer replaces torch.ger(μ.T, μ): same result for a 1-D μ
            # without the deprecated transpose of a 1-D tensor.
            return self.μ, self.Σ - torch.outer(self.μ, self.μ)*self.n/(self.n-1)
def _single_label_stats(data, i, c, label_indices, M=None, S=None, batch_size=256,
                        embedding=None, online=True, diagonal_cov=False,
                        dtype=None, device=None):
    """ Computes mean/covariance of examples that have a given label. Note that
    classname c is only needed for vanity printing. Device info needed here since
    dataloaders are used inside.

    Arguments:
        data (pytorch Dataset or Dataloader): data to compute stats on
        i (int): index of label (a.k.a class) to filter
        c (int/str): value of label (a.k.a class) to filter
        label_indices (array-like): dataset indices of the examples with this label
        M (dict-like, optional): if given, the mean is stored in M[i] (on CPU)
            instead of being returned
        S (dict-like, optional): if M is given, the covariance is stored in S[i]
        batch_size (int): batch size used when a loader has to be created
        embedding (callable, optional): feature map applied before computing stats
        online (bool): accumulate statistics batch by batch instead of loading
            every example of the class at once
        diagonal_cov (bool): compute only per-feature variances
        dtype: forwarded to OnlineStatsRecorder (online mode only)
        device: device specification, resolved with process_device_arg

    Returns:
        μ (torch tensor): empirical mean of samples with given label
        Σ (torch tensor): empirical covariance of samples with given label
        n (int): number of samples with given label

        When `M` is provided nothing is returned. For classes with fewer than
        two examples, (None, None, n) is returned and M/S are left untouched.

    Raises:
        ValueError: if a diagonal covariance contains a negative (or NaN) entry.
    """
    device = process_device_arg(device)
    n_found = len(label_indices)
    if n_found < 2:
        logger.warning(" -- Class '{:10}' has too few examples ({})." \
                       "  Ignoring it.".format(c, n_found))
        return (None, None, n_found) if M is None else None

    _orig_indices = None
    if isinstance(data, dataloader.DataLoader):
        ## We'll reuse the provided dataloader, just setting indices.
        ## If loader had indices before, we restore them when we're done
        # NOTE(review): this only filters if the loader's sampler actually reads
        # `.indices` (e.g. SubsetRandomSampler) - confirm against callers.
        filtered_loader = data
        _orig_indices = getattr(data.sampler, 'indices', None)
        filtered_loader.sampler.indices = label_indices
    else:
        ## Create our own loader
        filtered_loader = dataloader.DataLoader(data, batch_size=batch_size,
                                        sampler=SubsetRandomSampler(label_indices))

    try:
        if online:
            ## Will compute online (i.e. without loading all the data at once)
            stats_rec = OnlineStatsRecorder(centered_cov=True, twopass=True,
                                            diagonal_cov=diagonal_cov, device=device,
                                            embedding=embedding,
                                            dtype=dtype)
            μ, Σ = stats_rec.compute_from_loader(filtered_loader)

            n = int(stats_rec.n)
        else:
            X = torch.cat([d[0].to(device) for d in filtered_loader]).squeeze()
            X = embedding(X) if embedding is not None else X
            μ = torch.mean(X, dim=0).flatten()
            if diagonal_cov:
                Σ = torch.var(X, dim=0).flatten()
            else:
                Σ = cov(X.view(X.shape[0], -1).t())
            n = X.shape[0]
    finally:
        ## Reinstate original indices in sampler, even if the computation failed
        if _orig_indices is not None:
            data.sampler.indices = _orig_indices

    logger.info(' -> class {:10} (id {:2}): {} examples'.format(c, i, n))

    # Fail loudly instead of dropping into an interactive debugger, which
    # would hang non-interactive runs and worker processes.
    if diagonal_cov and not bool(Σ.min() >= 0):
        raise ValueError(
            "Diagonal covariance of class {} (id {}) has negative or NaN "
            "entries".format(c, i))

    if M is not None:
        M[i], S[i] = μ.cpu(), Σ.cpu()  # To avoid GPU parallelism problems
        return None
    return μ, Σ, n
def _apply_eigen_correction(S, mode, scale):
    """ Add a diagonal ridge to every covariance held in S (tensor or dict). """
    if mode is True:
        mode = 'constant'  # documented boolean form: plain constant shift
    if mode == 'exact':
        raise NotImplementedError("'exact' eigenvalue correction is not implemented")
    if mode not in ('constant', 'jitter'):
        raise ValueError("Unknown eigen_correction mode: {!r}".format(mode))

    if torch.is_tensor(S):
        keys = range(S.shape[0])
    else:
        keys = [k for k in S.keys() if S[k] is not None]
    for k in keys:
        Σ = S[k]
        ridge = torch.ones(Σ.shape[-1], device=Σ.device, dtype=Σ.dtype)
        if mode == 'jitter':
            ridge = ridge.uniform_(0.99, 1.01)
        ridge = scale * ridge
        # 1-D entries are diagonal covariances: shift the variances directly.
        S[k] = Σ + (ridge if Σ.dim() == 1 else torch.diag(ridge))
    return S


def compute_label_stats(data, targets=None,indices=None,classnames=None,
                        online=True, batch_size=100, to_tensor=True,
                        eigen_correction=False,
                        eigen_correction_scale=1.0,
                        nworkers=0, diagonal_cov = False,
                        embedding=None,
                        device=None, dtype = torch.FloatTensor):
    """
    Computes mean/covariance of examples grouped by label. Data can be passed as
    a pytorch dataset or a dataloader. Uses dataloader to avoid loading all
    classes at once.

    Arguments:
        data (pytorch Dataset or Dataloader): data to compute stats on
        targets (Tensor, optional): If provided, will use this target array to
            avoid re-extracting targets.
        indices (array-like, optional): If provided, filtering is based on these
            indices (useful if e.g. dataloader has subsampler). Required when
            `targets` is given.
        classnames (list, optional): class names; derived from the unique
            target values when not given.
        online (bool): compute statistics batch by batch.
        batch_size (int): batch size for loaders created internally.
        to_tensor (bool): stack the per-class results into tensors.
        eigen_correction (bool or str, optional): ``True`` or ``'constant'``
            shifts the covariance diagonal by :attr:`eigen_correction_scale` to
            ensure PSD'ness; ``'jitter'`` scales that shift per-dimension by a
            uniform factor in [0.99, 1.01]; ``'exact'`` is not implemented.
        eigen_correction_scale (numeric, optional): Magnitude of eigenvalue
            correction (used only if :attr:`eigen_correction` is set)
        nworkers (int): if > 1, one process per class is spawned.
        diagonal_cov (bool): compute only per-feature variances.
        embedding (callable, optional): feature map applied before computing stats.
        device: device specification, resolved with process_device_arg.
        dtype: data type forwarded to the statistics recorder.

    Returns:
        M (dict or Tensor): sample means indexed by target class (stacked into
            a tensor when `to_tensor`)
        S (dict or Tensor): sample covariances indexed by target class

    Raises:
        NotImplementedError: for ``eigen_correction='exact'``.
        ValueError: for an unrecognised `eigen_correction` value.
    """
    logger.debug("compute_label_stats: eigen_correction=%s", eigen_correction)
    device = process_device_arg(device)
    M = {}  # Means
    S = {}  # Covariances

    ## We need to get all targets in advance, in order to filter.
    ## Here we assume targets is the full dataset targets (ignoring subsets, etc)
    ## so we need to find effective targets.
    if targets is None:
        targets, classnames, indices = extract_data_targets(data)
    else:
        assert (indices is not None), "If targets are provided, so must be indices"
    if classnames is None:
        classnames = sorted([a.item() for a in torch.unique(targets)])

    effective_targets = targets[indices]

    # Both execution paths must see the same options; previously the parallel
    # path dropped dtype/embedding/diagonal_cov and batch_size was never used.
    stats_kwargs = {'device': device, 'dtype': dtype, 'embedding': embedding,
                    'online': online, 'diagonal_cov': diagonal_cov,
                    'batch_size': batch_size}

    if nworkers > 1:
        import torch.multiprocessing as mp  # Ugly, sure. But useful.
        mp.set_start_method('spawn', force=True)
        manager = mp.Manager()
        M = manager.dict()
        S = manager.dict()
        processes = []
        for i, c in enumerate(classnames):  # One process per class
            label_indices = indices[effective_targets == i]
            p = mp.Process(target=_single_label_stats,
                           args=(data, i, c, label_indices, M, S),
                           kwargs=stats_kwargs)
            p.start()
            processes.append(p)
        for p in processes: p.join()
    else:
        for i, c in enumerate(classnames):
            label_indices = indices[effective_targets == i]
            μ, Σ, n = _single_label_stats(data, i, c, label_indices, **stats_kwargs)
            M[i], S[i] = μ, Σ

    if to_tensor:
        ## Warning: this assumes classes are *exactly* {0,...,n}, might break things
        ## downstream if data is missing some classes
        M = torch.stack([μ.to(device) for i, μ in sorted(M.items()) if μ is not None], dim=0)
        S = torch.stack([Σ.to(device) for i, Σ in sorted(S.items()) if Σ is not None], dim=0)

    ### Shift the Covariance matrix's diagonal to ensure PSD'ness
    if eigen_correction:
        logger.warning('Applying eigenvalue correction to Covariance Matrix')
        S = _apply_eigen_correction(S, eigen_correction, eigen_correction_scale)
    return M, S
def dimreduce_means_covs(Means, Covs, redtype='diagonal'):
    """ Methods to reduce the dimensionality of the Feature-Mean/Covariance
    representation of Labels.

    Arguments:
        Means (tensor or list of tensors): original mean vectors, one entry
            per task; each entry is 2-D (classes x features)
        Covs (tensor or list of tensors): original covariances matrices, one
            entry per task
        redtype (str): dimensionality reduction methods, one of 'diagonal', 'mds'
            or 'distance_embedding'.

    Returns:
        Means (tensor or list of tensors): dimensionality-reduced mean vectors
            (returned unchanged)
        Covs (tensor or list of tensors): dimensionality-reduced covariance
            matrices. For 'diagonal' and 'mds' the entries of the `Covs`
            argument are replaced in place.

    Raises:
        ValueError: if `redtype` is not recognised.
    """
    n1, d1 = Means[0].shape
    n2, d2 = Means[1].shape
    k = d1

    logging.getLogger(__name__).debug(
        "dimreduce_means_covs: n1=%s d1=%s n2=%s d2=%s", n1, d1, n2, d2)
    if redtype == 'diagonal':
        ## Leave Means As Is, Keep Only Diag of Covariance Matrices, Independent DR for Each Task
        Covs[0] = torch.stack([torch.diag(C) for C in Covs[0]])
        Covs[1] = torch.stack([torch.diag(C) for C in Covs[1]])
    elif redtype == 'mds':
        ## Leave Means As Is, Use MDS to DimRed Covariance Matrices, Independent DR for Each Task
        # NOTE(review): `mds` is not among the names this module visibly imports;
        # confirm where it is supposed to come from.
        Covs[0] = mds(Covs[0].view(Covs[0].shape[0], -1), output_dim=k)
        Covs[1] = mds(Covs[1].view(Covs[1].shape[0], -1), output_dim=k)
    elif redtype == 'distance_embedding':
        ## Leaves Means As Is, Use Bipartitie MSE Embedding, Which Embeds the Pairwise Distance Matrix, Rather than the Cov Matrices Directly
        import itertools  # local: the module's import block does not provide it

        print('Will reduce dimension of Σs by embedding pairwise distance matrix...')
        D = torch.zeros(n1, n2)
        print('... computing pairwise bures distances ...')
        # NOTE(review): `bures_distance` is not visibly imported by this module
        # either - confirm. The tqdm progress wrapper was dropped because tqdm
        # is not imported here.
        for (i, j) in itertools.product(range(n1), range(n2)):
            D[i, j] = bures_distance(Covs[0][i], Covs[1][j])
        print('... embedding distance matrix ...')
        U, V = bipartite_mse_embedding(D, k=k)
        Covs = [U, V]
        print("Done! Σ's Dimensions: {} (Task 1) and {} (Task 2)".format(
            list(U.shape), list(V.shape)))
    else:
        raise ValueError('Reduction type not recognized')
    return Means, Covs
def pairwise_distance_mse(U, V, D, reg=1):
    """ Regularised mean squared error between D and the distances of U vs V.

    Arguments:
        U (tensor): (n x k) embeddings of the first set.
        V (tensor): (m x k) embeddings of the second set.
        D (tensor): (n x m) target distance matrix.
        reg (float): weight of the penalty on the mean squared entries of U and V.

    Returns:
        Scalar tensor ``||D - cdist(U, V)||² / |D| + reg * (||U||²/|U| + ||V||²/|V|)``.
    """
    d_uv = torch.cdist(U, V)
    distortion = torch.norm(D - d_uv)**2 / D.numel()  # MSE per entry
    penalty = torch.norm(U)**2 / U.numel() + torch.norm(V)**2 / V.numel()
    return distortion + reg * penalty


def bipartite_mse_embedding(D, k=100, niters=10000):
    """ Embed the rows and columns of D so Euclidean distances approximate D.

    Runs `niters` steps of SGD (lr 0.1) on `pairwise_distance_mse`, printing
    the loss every 100 iterations and the final unregularised distortion.

    Arguments:
        D (tensor): (n x m) matrix of target distances.
        k (int): embedding dimension.
        niters (int): number of optimisation steps.

    Returns:
        U (tensor): (n x k) detached row embeddings.
        V (tensor): (m x k) detached column embeddings.
    """
    n, m = D.shape
    # Follow D's device (and floating dtype) so the loss never mixes devices.
    opts = {'device': D.device, 'requires_grad': True}
    if D.is_floating_point():
        opts['dtype'] = D.dtype
    U = torch.randn(n, k, **opts)
    V = torch.randn(m, k, **opts)
    optim = torch.optim.SGD([U, V], lr=1e-1)
    for i in range(niters):
        optim.zero_grad()
        loss = pairwise_distance_mse(U, V, D)
        loss.backward()
        if i % 100 == 0:
            print(i, loss.item())
        optim.step()
    loss = pairwise_distance_mse(U, V, D, reg=0)
    print(
        "Final distortion: ||D - D'||\u00b2/|D| = {:4.2f}".format(loss.item()))
    return U.detach(), V.detach()
def reset_parameters(m):
    """ Re-initialise `m` if it is a Conv2d or Linear layer; otherwise do nothing.

    Intended for use with ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        m.reset_parameters()


class LeNet(nn.Module):
    """ LeNet-style classifier for single-channel 28x28 or 32x32 images.

    Two convolution + max-pool stages followed by a three-layer classifier
    with dropout. 28x28 inputs use 5x5 kernels, 32x32 inputs use 3x3 kernels.

    Arguments:
        pretrained (bool): load weights from `self.model_path` after building.
        num_classes (int): size of the output layer.
        input_size (int): side length of the input images; 28 or 32.

    Raises:
        ValueError: if `input_size` is neither 28 nor 32.
    """
    def __init__(self, pretrained=False, num_classes = 10, input_size=28, **kwargs):
        super(LeNet, self).__init__()
        suffix = f'dim{input_size}_nc{num_classes}'
        self.model_path = os.path.join(MODELS_DIR, f'lenet_mnist_{suffix}.pt')
        if input_size not in (28, 32):
            raise ValueError("Can only do LeNet on 28x28 or 32x32 for now.")

        if input_size == 32:
            # 3x3 kernels: 32 -> 30 -> pool 15 -> 13 -> pool 6, i.e. 16 maps of 6x6
            self.conv1 = nn.Conv2d(1, 6, 3)
            self.conv2 = nn.Conv2d(6, 16, 3)
            feat_dim = 16*6*6
        else:
            # 5x5 kernels: 28 -> 24 -> pool 12 -> 8 -> pool 4, i.e. 16 maps of 4x4
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            feat_dim = 16*4*4
        self.feat_dim = feat_dim
        self.num_classes = num_classes

        self._init_classifier()

        if pretrained:
            state_dict = torch.load(self.model_path)
            self.load_state_dict(state_dict)

    def _init_classifier(self, num_classes=None):
        """ Useful for fine-tuning: (re)build the classifier head.

        Uses `self.num_classes` outputs unless `num_classes` is given.
        """
        num_classes = self.num_classes if num_classes is None else num_classes
        self.classifier = nn.Sequential(
            nn.Linear(self.feat_dim, 120),  # feat_dim = channels * height * width
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(84, num_classes)
        )

    def forward(self, x):
        """ Return class scores (no softmax applied) for a batch of images. """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        return self.classifier(x)

    def num_flat_features(self, x):
        """ Number of elements per example, i.e. product of all non-batch dims. """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def save(self):
        """ Write the state dict to `self.model_path`. """
        state_dict = self.state_dict()
        torch.save(state_dict, self.model_path)
class MNIST_MLP(nn.Module):
    """ One-hidden-layer MLP that outputs softmax class probabilities.

    Arguments:
        input_dim (int): number of input features after flattening.
        hidden_dim (int): width of the hidden layer.
        output_dim (int): number of classes.
        dropout (float): dropout probability applied after the hidden layer.
    """
    def __init__(
            self,
            input_dim=MNIST_FLAT_DIM,
            hidden_dim=98,
            output_dim=10,
            dropout=0.5,
    ):
        # Was super(ClassifierModule, self): an undefined name, so the class
        # could never be instantiated.
        super(MNIST_MLP, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, X, **kwargs):
        """ Flatten X to (-1, input_dim) and return class probabilities. """
        X = X.reshape(-1, self.hidden.in_features)
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X


class MNIST_CNN(nn.Module):
    """ Small two-convolution CNN that outputs log-probabilities.

    The first linear layer expects 1600 features, which is what the two 3x3
    convolutions and 2x2 poolings produce for 28x28 inputs (64 * 5 * 5).

    Arguments:
        input_size (int): only used to build the checkpoint file name.
        dropout (float): dropout probability for both dropout layers.
        nclasses (int): number of output classes.
        pretrained (bool): load weights from `self.model_path` after building.
    """
    def __init__(self, input_size=28, dropout=0.3, nclasses=10, pretrained=False):
        super(MNIST_CNN, self).__init__()
        self.nclasses = nclasses
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d(p=dropout)
        self.fc1 = nn.Linear(1600, 100)  # 1600 = number channels * width * height
        self.logit = nn.Linear(100, self.nclasses)
        self.fc1_drop = nn.Dropout(p=dropout)
        suffix = f'dim{input_size}_nc{nclasses}'
        self.model_path = os.path.join(MODELS_DIR, f'cnn_mnist_{suffix}.pt')
        if pretrained:
            state_dict = torch.load(self.model_path)
            self.load_state_dict(state_dict)

    def forward(self, x):
        """ Return log-softmax class scores for a batch of images. """
        x = torch.relu(F.max_pool2d(self.conv1(x), 2))
        x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = torch.relu(self.fc1_drop(self.fc1(x)))
        x = self.logit(x)
        x = F.log_softmax(x, dim=-1)
        return x

    def save(self):
        """ Write the state dict to `self.model_path`. """
        state_dict = self.state_dict()
        torch.save(state_dict, self.model_path)


class MLPClassifier(nn.Module):
    """ MLP with three shrinking hidden layers and a small classifier head.

    Hidden widths are hidden_size, hidden_size/2 and hidden_size/4 (truncated
    to int), each followed by ReLU and dropout.

    Arguments:
        input_size (int): number of input features; must be provided.
        hidden_size (int): width of the first hidden layer.
        num_classes (int): number of output classes.
        dropout (float): dropout probability.
        pretrained (bool): accepted but not used by this class.
    """
    def __init__(
            self,
            input_size=None,
            hidden_size=400,
            num_classes=2,
            dropout=0.2,
            pretrained=False,
    ):
        super(MLPClassifier, self).__init__()
        self.num_classes = num_classes
        self.hidden_sizes = [hidden_size, int(hidden_size/2), int(hidden_size/4)]

        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(input_size, self.hidden_sizes[0])
        self.fc2 = nn.Linear(self.hidden_sizes[0], self.hidden_sizes[1])
        self.fc3 = nn.Linear(self.hidden_sizes[1], self.hidden_sizes[2])

        self._init_classifier()

    def _init_classifier(self, num_classes=None):
        """ (Re)build the classifier head; defaults to `self.num_classes` outputs. """
        num_classes = self.num_classes if num_classes is None else num_classes
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_sizes[-1], 20),
            nn.ReLU(),
            nn.Linear(20, num_classes)
        )

    def forward(self, x, **kwargs):
        """ Return class scores (no softmax applied). """
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = self.classifier(x)
        return x
class BoWSentenceEmbedding():
    """ Bag-of-words sentence embedding built on frozen pretrained vectors.

    With ``method='bag'`` an ``nn.EmbeddingBag`` does the pooling; any other
    method embeds each token and averages over dimension 1.

    Arguments:
        vocab_size: unused.
        embedding_dim: unused.
        pretrained_vec (tensor): pretrained embedding matrix.
        padding_idx (int, optional): forwarded to the EmbeddingBag only.
        method (str): 'bag' or anything else for the plain average.
    """
    def __init__(self, vocab_size, embedding_dim, pretrained_vec, padding_idx=None, method = 'naive'):
        self.method = method
        self.emb = (
            nn.EmbeddingBag.from_pretrained(pretrained_vec, padding_idx=padding_idx)
            if method == 'bag'
            else nn.Embedding.from_pretrained(pretrained_vec)
        )

    def __call__(self, x):
        """ Return one pooled embedding per row of token indices `x`. """
        embedded = self.emb(x)
        # The EmbeddingBag output is already pooled; token embeddings are not.
        return embedded if self.method == 'bag' else embedded.mean(dim=1)


class MLPPushforward(nn.Module):
    """ Hourglass-in-reverse MLP mapping R^d to R^d.

    `nlayers` linear layers double the width each time, then `nlayers` more
    halve it back to `input_size`. Every linear layer except the last is
    followed by ReLU; every linear layer is followed by a no-op Dropout(0.0).

    Arguments:
        input_size (int): input (and output) dimension.
        nlayers (int): number of expanding layers (and of contracting layers).
        **kwargs: ignored.
    """
    def __init__(self, input_size=2, nlayers = 3, **kwargs):
        super().__init__()
        stack = []
        width = input_size

        # Expanding half: width doubles at every step.
        for _ in range(nlayers):
            stack += [nn.Linear(width, 2*width), nn.ReLU(), nn.Dropout(0.0)]
            width = 2*width

        # Contracting half: width halves; no activation after the final layer.
        final_step = nlayers - 1
        for step in range(nlayers):
            narrower = int(0.5*width)
            stack.append(nn.Linear(width, narrower))
            if step < final_step:
                stack.append(nn.ReLU())
            stack.append(nn.Dropout(0.0))
            width = narrower

        self.mapping = nn.Sequential(*stack)

    def forward(self, x):
        """ Apply the layer stack to `x`. """
        return self.mapping(x)

    def reset_parameters(self):
        """ Re-initialise every layer via the module-level `reset_parameters`. """
        self.mapping.apply(reset_parameters)
class ConvPushforward(nn.Module):
    """ Convolutional encoder -> MLP -> transposed-convolution decoder.

    The encoder mirrors LeNet (two conv + max-pool stages), the flattened
    features go through an `MLPPushforward`, and the decoder undoes the
    pooling with `max_unpool2d` using the saved indices.

    Arguments:
        input_size (int): side length of the (square) input images; 28 or 32.
        channels (int): channel count used when reshaping the input. The
            convolutions themselves are built for a single channel.
        nlayers_conv (int): accepted but not used.
        nlayers_mlp (int): number of expanding/contracting layers of the MLP.
        **kwargs: ignored.

    Raises:
        NotImplementedError: if `input_size` is neither 28 nor 32.
    """
    def __init__(self, input_size=28, channels = 1, nlayers_conv = 2, nlayers_mlp = 3, **kwargs):
        super(ConvPushforward, self).__init__()
        self.input_size = input_size
        self.channels = channels
        if input_size == 32:
            # NOTE(review): this branch looks unfinished and is left as found.
            # 3x3 convs + 2x2 pools on 32x32 give 16 maps of 6x6 (576 features),
            # not the 16*5*5 declared below, and dnconv1 expects 4 input
            # channels while forward() feeds it 16 - confirm the intended design.
            self.upconv1 = nn.Conv2d(1, 6, 3)
            self.upconv2 = nn.Conv2d(6, 16, 3)
            feat_dim = 16*5*5
            ## decoder layers ##
            self.dnconv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
            self.dnconv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)
        elif input_size == 28:
            # 28 -> 24 -> pool 12 -> 8 -> pool 4, and back up in the decoder.
            self.upconv1 = nn.Conv2d(1, 6, 5)
            self.upconv2 = nn.Conv2d(6, 16, 5)
            feat_dim = 16*4*4
            self.dnconv1 = nn.ConvTranspose2d(16, 6, 5)
            self.dnconv2 = nn.ConvTranspose2d(6, 1, 5)
        else:
            # NotImplemented is a constant, not an exception; calling it
            # raised a TypeError instead of the intended error.
            raise NotImplementedError("Can only do LeNet on 28x28 or 32x32 for now.")
        self.feat_dim = feat_dim

        # The keyword is `nlayers`; `layers=` was swallowed by **kwargs, so
        # nlayers_mlp had no effect.
        self.mlp = MLPPushforward(input_size = feat_dim, nlayers = nlayers_mlp)

    def forward(self, x):
        """ Map a batch of images to a batch of the same shape (tanh range). """
        _orig_shape = x.shape
        x = x.reshape(-1, self.channels, self.input_size, self.input_size)
        x, idx1 = F.max_pool2d(F.relu(self.upconv1(x)), 2, return_indices=True)
        x, idx2 = F.max_pool2d(F.relu(self.upconv2(x)), 2, return_indices=True)
        _nonflat_shape = x.shape
        x = x.view(-1, self.num_flat_features(x))
        x = self.mlp(x).reshape(_nonflat_shape)
        x = F.relu(self.dnconv1(F.max_unpool2d(x, idx2, kernel_size=2)))
        x = torch.tanh(self.dnconv2(F.max_unpool2d(x, idx1, kernel_size=2)))
        return x.reshape(_orig_shape)

    def num_flat_features(self, x):
        """ Number of elements per example, i.e. product of all non-batch dims. """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def reset_parameters(self):
        """ Re-initialise every direct child module. """
        for name, module in self.named_children():
            module.reset_parameters()
class ConvPushforward2(nn.Module):
    """Strided-convolution encoder -> MLP -> transposed-convolution decoder.

    For 28x28 inputs the encoder reduces an image to 8x2x2 features, an
    `MLPPushforward` maps the flattened features, and three transposed
    convolutions decode back to a 1x28x28 image squashed by `tanh`.

    Arguments:
        input_size (int): side length of the (square) input; 28 or 32.
        channels (int): channel count used when reshaping the input in
            `forward`.  NOTE(review): the conv layers are built for 1 channel,
            so other values look unsupported -- confirm before using.
        nlayers_conv (int): currently unused.
        nlayers_mlp (int): number of layers passed to `MLPPushforward`.

    Raises:
        NotImplementedError: if `input_size` is neither 28 nor 32.

    NOTE(review): the 32x32 branch does not define `dnconv3`, which `forward`
    requires, so only the 28x28 configuration appears to be usable.
    """
    def __init__(self, input_size=28, channels = 1, nlayers_conv = 2, nlayers_mlp = 3, **kwargs):
        super(ConvPushforward2, self).__init__()
        self.input_size = input_size
        self.channels = channels
        if input_size == 32:
            self.upconv1 = nn.Conv2d(1, 6, 3)
            self.upconv2 = nn.Conv2d(6, 16, 3)
            feat_dim = 16*5*5
            ## decoder layers ##
            self.dnconv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
            self.dnconv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)
        elif input_size == 28:
            self.upconv1 = nn.Conv2d(1, 16, 3, stride=3, padding=1)  # b, 16, 10, 10
            self.upconv2 = nn.Conv2d(16, 8, 3, stride=2, padding=1)  # b, 8, 3, 3
            feat_dim = 8*2*2
            self.dnconv1 = nn.ConvTranspose2d(8, 16, 3, stride=2)  # b, 16, 5, 5
            self.dnconv2 = nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1)  # b, 8, 15, 15
            self.dnconv3 = nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1)  # b, 1, 28, 28
        else:
            # `NotImplemented` is a constant, not an exception class.
            raise NotImplementedError("Can only do LeNet on 28x28 or 32x32 for now.")
        self.feat_dim = feat_dim

        # The keyword must be `nlayers`; `layers=` was swallowed by **kwargs
        # and the requested depth was silently ignored.
        self.mlp = MLPPushforward(input_size = feat_dim, nlayers = nlayers_mlp)

    def forward(self, x):
        """Map a batch of images through encoder, MLP and decoder."""
        x = x.reshape(-1, self.channels, self.input_size, self.input_size)
        x = F.max_pool2d(F.relu(self.upconv1(x)), 2, stride=2)
        x = F.max_pool2d(F.relu(self.upconv2(x)), 2, stride=1)
        _nonflat_shape = x.shape
        x = x.view(-1, self.num_flat_features(x))
        x = self.mlp(x).reshape(_nonflat_shape)
        x = F.relu(self.dnconv1(x))
        x = F.relu(self.dnconv2(x))
        x = torch.tanh(self.dnconv3(x))
        return x

    def num_flat_features(self, x):
        """Return the number of elements per sample (all dims but the first)."""
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def reset_parameters(self):
        """Re-initialise every direct child that exposes `reset_parameters`."""
        for name, module in self.named_children():
            reset = getattr(module, 'reset_parameters', None)
            if callable(reset):
                reset()
def symsqrt_v1(A, func='symeig'):
    """Compute the square root of a symmetric positive (semi-)definite matrix.

    Based on https://github.com/pytorch/pytorch/issues/25481#issuecomment-576493693

    For symmetric real matrices the SVD and the eigendecomposition coincide up
    to sign (|λ_i| = σ_i), so for PSD matrices either can be used.

    Arguments:
        A (tensor): symmetric PSD matrix, or a batch of them (..., d, d).
        func (str): 'symeig' for a symmetric eigendecomposition or 'svd'.
            The original code notes that 'symeig' is faster on GPU but failed
            gradcheck, whereas 'svd' passed it.

    Returns:
        tensor with the same shape as `A` whose matrix square equals `A` (up
        to the truncation of numerically-zero components).

    Raises:
        ValueError: if `func` is not one of the supported values.
    """
    if func == 'symeig':
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
            # `symeig` is gone from recent PyTorch; UPLO='U' mirrors its
            # default of reading the upper triangle.
            s, v = torch.linalg.eigh(A, UPLO='U')
        else:
            s, v = A.symeig(eigenvectors=True)
        # Eigenvalues come back in ascending order. The truncation below keeps
        # the *leading* components, so reorder to descending (as the SVD
        # does); otherwise the dominant eigenpairs of a rank-deficient matrix
        # would be the ones thrown away.
        s = s.flip(-1)
        v = v.flip(-1)
    elif func == 'svd':
        _, s, v = A.svd()
    else:
        raise ValueError("func must be 'symeig' or 'svd', got %r" % (func,))

    ## truncate small components
    good = s > s.max(-1, True).values * s.size(-1) * torch.finfo(s.dtype).eps
    components = good.sum(-1)
    common = components.max()
    unbalanced = common != components.min()
    if common < s.size(-1):
        s = s[..., :common]
        v = v[..., :common]
        if unbalanced:
            good = good[..., :common]
    if unbalanced:
        # Batch entries with lower rank: zero their negligible components.
        s = s.where(good, torch.zeros((), device=s.device, dtype=s.dtype))
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)
def sqrtm_newton_schulz(A, numIters, reg=None, return_error=False, return_inverse=False):
    """ Matrix squareroot based on Newton-Schulz method

    Based on
    https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py and
    https://github.com/BorisMuzellec/EllipticalEmbeddings/blob/master/utils.py

    Arguments:
        A (tensor): a (d, d) matrix or a (b, d, d) batch of matrices.
        numIters (int): number of Newton-Schulz iterations.
        reg (float, optional): if given (and non-zero), each matrix is rescaled
            only when `reg * ||A||_F > 1`, by exactly that factor. Otherwise
            every matrix is normalised by its Frobenius norm.
        return_error (bool): also return `compute_error(A, sA)`.
        return_inverse (bool): also return the inverse square root.

    Returns:
        `sA`, optionally followed by `sAinv` and/or the error, in that order.
        Results are un-batched again if `A` was passed without a batch dim.
    """
    batched = A.ndim > 2
    if not batched:  # Non-batched mode
        A = A.unsqueeze(0)
    batchSize = A.shape[0]
    dim = A.shape[1]
    normA = (A**2).sum((-2,-1)).sqrt()  # Frobenius norm per matrix

    if reg:
        ## Renormalize so that the each matrix has a norm lesser than 1/reg,
        ## but only normalize when necessary
        scaled = normA * reg
        renorm = torch.where(scaled > 1.0, scaled, torch.ones_like(scaled))
    else:
        renorm = normA
    renorm = renorm.view(batchSize, 1, 1)

    Y = A / renorm
    # Build the identity with A's dtype/device so float64 inputs work.
    I = torch.eye(dim, dtype=A.dtype, device=A.device).expand(batchSize, dim, dim)
    Z = I
    for _ in range(numIters):
        T = 0.5*(3.0*I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)

    # Undo the normalisation with the factor that was actually applied.
    scale = torch.sqrt(renorm)
    sA = Y * scale
    sAinv = Z / scale

    # Measure the error while everything still carries the batch dimension.
    error = compute_error(A, sA) if return_error else None

    if not batched:
        sA = sA[0,:,:]
        sAinv = sAinv[0,:,:]

    if not return_inverse and not return_error:
        return sA
    elif not return_inverse and return_error:
        return sA, error
    elif return_inverse and not return_error:
        return sA, sAinv
    else:
        return sA, sAinv, error
155 | 156 | """ 157 | @staticmethod 158 | def forward(ctx, input, method = 'numpy'): 159 | _dev = input.device 160 | if method == 'numpy': 161 | m = input.cpu().detach().numpy().astype(np.float_) 162 | sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).type_as(input) 163 | elif method == 'pytorch': 164 | sqrtm = symsqrt(input) 165 | ctx.save_for_backward(sqrtm) 166 | return sqrtm 167 | 168 | @staticmethod 169 | def backward(ctx, grad_output, method = 'numpy'): 170 | grad_input = None 171 | if ctx.needs_input_grad[0]: 172 | sqrtm, = ctx.saved_tensors 173 | if method == 'numpy': 174 | sqrtm = sqrtm.data.numpy().astype(np.float_) 175 | gm = grad_output.data.numpy().astype(np.float_) 176 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) 177 | grad_input = torch.from_numpy(grad_sqrtm).type_as(grad_output.data) 178 | elif method == 'pytorch': 179 | grad_input = special_sylvester(sqrtm, grad_output) 180 | return grad_input 181 | 182 | 183 | ## ========================================================================== ## 184 | ## NOTE: Must pick which version of matrix square root to use!!!! 
185 | 186 | ## sqrtm = MatrixSquareRoot.apply 187 | sqrtm = symsqrt_v2 188 | ## sqrtm = symsqrt_v1 189 | ## sqrtm = symsqrt_diff 190 | ## ========================================================================== ## 191 | 192 | def main(): 193 | from torch.autograd import gradcheck 194 | 195 | k = torch.randn(5, 20, 20).double() 196 | M = k @ k.transpose(-1,-2) 197 | 198 | s1 = symsqrt_v1(M, func='symeig') 199 | test = torch.allclose(M, s1 @ s1.transpose(-1,-2)) 200 | print('Via symeig:', test) 201 | 202 | s2 = symsqrt_v1(M, func='svd') 203 | test = torch.allclose(M, s2 @ s2.transpose(-1,-2)) 204 | print('Via svd: ', test) 205 | 206 | print('Sqrtm with symeig and svd match:', torch.allclose(s1,s2)) 207 | 208 | M.requires_grad = True 209 | 210 | ## Check gradients for symsqrt 211 | _sqrt = partial(symsqrt, func='svd') 212 | test = gradcheck(_sqrt, (M,)) 213 | print('Grach Check for sqrtm/svd:', test) 214 | 215 | ## Check symeig itself 216 | S = torch.rand(5,20,20, requires_grad=True).double() 217 | def func(S): 218 | x = 0.5 * (S + S.transpose(-2, -1)) 219 | return torch.symeig(x, eigenvectors=True) 220 | print('Grad check for symeig', gradcheck(func, [S])) 221 | 222 | ## Check gradients for symsqrt with symeig 223 | _sqrt = partial(symsqrt, func='symeig') 224 | test = gradcheck(_sqrt, (M,)) 225 | print('Grach Check for sqrtm/symeig:', test) 226 | 227 | if __name__ == '__main__': 228 | main() 229 | -------------------------------------------------------------------------------- /otdd/pytorch/wasserstein.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import pdb 4 | import itertools 5 | import numpy as np 6 | import torch 7 | from tqdm.autonotebook import tqdm 8 | from joblib import Parallel, delayed 9 | import geomloss 10 | # import ot 11 | 12 | from .sqrtm import sqrtm, sqrtm_newton_schulz 13 | from .utils import process_device_arg 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 
def bbures_distance(Σ1, Σ2, sqrtΣ1=None, inv_sqrtΣ1=None,
                    diagonal_cov=False, commute=False, squared=True, sqrt_method='spectral',
                    sqrt_niters=20):
    """ Bures distance between covariance matrices. Batched version.

    Arguments:
        Σ1, Σ2 (tensor): covariance matrices (..., d, d), or their diagonals
            (..., d) when `diagonal_cov` is True.
        sqrtΣ1 (tensor, optional): precomputed square root of Σ1; computed
            here when missing (not needed for diagonal covariances).
        inv_sqrtΣ1: currently unused.
        diagonal_cov (bool): treat the inputs as diagonals.
        commute (bool): assume Σ1 and Σ2 commute, in which case the distance
            is the squared Frobenius norm of the difference of the roots.
        squared (bool): return the squared distance (default) or its root.
        sqrt_method (str): 'spectral' uses `sqrtm`; anything else uses
            `sqrtm_newton_schulz` with `sqrt_niters` iterations.

    Returns:
        tensor of non-negative (squared) Bures distances, one per batch entry.
    """
    def _matrix_sqrt(Σ):
        # One place to honour `sqrt_method` for every root taken below.
        if sqrt_method == 'spectral':
            return sqrtm(Σ)
        return sqrtm_newton_schulz(Σ, sqrt_niters)

    if diagonal_cov:
        bures = ((torch.sqrt(Σ1) - torch.sqrt(Σ2))**2).sum(-1)
    else:
        if sqrtΣ1 is None:
            sqrtΣ1 = _matrix_sqrt(Σ1)
        if commute:
            bures = ((sqrtΣ1 - _matrix_sqrt(Σ2))**2).sum((-2, -1))
        else:
            cross = _matrix_sqrt(torch.matmul(torch.matmul(sqrtΣ1, Σ2), sqrtΣ1))
            ## pytorch doesn't have batched trace yet!
            bures = (Σ1 + Σ2 - 2 * cross).diagonal(dim1=-2, dim2=-1).sum(-1)

    # Clamp round-off negatives *before* the square root; the other order
    # would turn them into NaN.
    bures = torch.relu(bures)
    if not squared:
        bures = torch.sqrt(bures)
    return bures
def pwdist_gauss(M1, S1, M2, S2, symmetric=False, return_dmeans=False, nworkers=1,
                 commute=False, device=None):
    """ POTENTIALLY DEPRECATED.
        Computes Wasserstein Distance between collections of Gaussians,
        represented in terms of their means (M1,M2) and Covariances (S1,S2).

        Arguments:
            M1, M2: means of the Gaussians on each side (indexable, one per cluster).
            S1, S2: matching covariances.
            symmetric (bool): both sides hold the same data; only unordered
                pairs are computed and mirrored (the diagonal stays zero).
            return_dmeans (bool): also return `torch.cdist(M1, M2)`.
            nworkers (int): if > 1, pairs are dispatched through joblib
                (threading backend); otherwise they are computed in a loop.
            commute (bool): forwarded to `wasserstein_gauss_distance`.
            device: where to allocate the result; defaults to `M1.device`
                when M1 has one, else PyTorch's default device.

        Returns:
            D (tensor): n1 x n2 matrix of squared distances, optionally
            followed by the matrix of distances between means.
    """
    n1, n2 = len(M1), len(M2)  # Number of clusters in each side
    if symmetric:
        ## If tasks are symmetric (same data on both sides) only need combinations
        pairs = list(itertools.combinations(range(n1), 2))
    else:
        ## If tasks are assymetric, need n1 x n2 comparisons
        pairs = list(itertools.product(range(n1), range(n2)))

    if device is None:
        # The original read an undefined global `device`.
        device = getattr(M1, 'device', None)
    D = torch.zeros((n1, n2), device=device)

    if nworkers > 1:
        results = Parallel(n_jobs=nworkers, verbose=1, backend="threading")(
            delayed(wasserstein_gauss_distance)(
                M1[i], M2[j], S1[i], S2[j], squared=True, commute=commute)
            for i, j in pairs)
        computed = zip(pairs, results)
    else:
        computed = (
            ((i, j), wasserstein_gauss_distance(
                M1[i], M2[j], S1[i], S2[j], squared=True, commute=commute))
            for i, j in tqdm(pairs, leave=False))

    for (i, j), d in computed:
        D[i, j] = d
        if symmetric:
            D[j, i] = D[i, j]

    if return_dmeans:
        D_means = torch.cdist(M1, M2)  # For viz purposes only
        return D, D_means
    else:
        return D
def pwdist_means_only(M1, M2=None, symmetric=False, device=None):
    """Pairwise Euclidean distances between two collections of means.

    When `M2` is omitted or `symmetric` is set, `M1` is compared against
    itself. The result is moved to `device` if one is given.
    """
    other = M1 if (symmetric or M2 is None) else M2
    dists = torch.cdist(M1, other)
    return dists.to(device) if device else dists
def pwdist_exact(X1, Y1, X2=None, Y2=None, symmetric=False, loss='sinkhorn',
                 cost_function='euclidean', p=2, debias=True, entreg=1e-1, device='cpu'):

    """ Efficient computation of pairwise label-to-label Wasserstein distances
    between multiple distributions, without using Gaussian assumption.

    Args:
        X1,X2 (tensor): n x d matrix with features
        Y1,Y2 (tensor): labels corresponding to samples
        symmetric (bool): whether X1/Y1 and X2/Y2 are to be treated as the same dataset
        loss (str): only 'sinkhorn' is supported.
        cost_function (callable/string): the 'ground metric' between features to
            be used in optimal transport problem. If callable, should follow
            the convention of the cost argument in geomloss.SamplesLoss
        p (int): power of the cost (i.e. order of p-Wasserstein distance). Ignored
            if cost_function is a callable.
        debias (bool): Only relevant for Sinkhorn. If true, uses debiased sinkhorn
            divergence.
        entreg (float): entropic regularisation; the Sinkhorn blur is entreg**(1/p).
        device: passed through `process_device_arg`; samples are moved there.

    Returns:
        D (tensor): n1 x n2 matrix of distances between the sorted unique
        labels of Y1 and Y2 (diagonal left at zero in the symmetric case).

    Raises:
        ValueError: unsupported `p` for the euclidean cost, or unsupported `loss`.
        SystemExit: if a distance computation fails.
    """
    device = process_device_arg(device)
    if X2 is None:
        symmetric = True
        X2, Y2 = X1, Y1

    c1 = torch.unique(Y1)
    c2 = torch.unique(Y2)
    n1, n2 = len(c1), len(c2)

    ## We account for the possibility that labels are shifted (c1[0]!=0):
    ## rows/columns of D are indexed by position in c1/c2, not by label value.

    if symmetric:
        ## If tasks are symmetric (same data on both sides) only need combinations
        pairs = list(itertools.combinations(range(n1), 2))
    else:
        ## If tasks are assymetric, need n1 x n2 comparisons
        pairs = list(itertools.product(range(n1), range(n2)))

    if cost_function == 'euclidean':
        if p == 1:
            cost_function = lambda x, y: geomloss.utils.distances(x, y)
        elif p == 2:
            cost_function = lambda x, y: geomloss.utils.squared_distances(x, y)
        else:
            raise ValueError('Euclidean cost only supports p=1 or p=2, got p=%r' % (p,))

    if loss == 'sinkhorn':
        distance = geomloss.SamplesLoss(
            loss=loss, p=p,
            cost=cost_function,
            debias=debias,
            blur=entreg**(1 / p),
        )
    else:
        raise ValueError('Wrong loss')

    logger.info('Computing label-to-label (exact) wasserstein distances...')
    pbar = tqdm(pairs, leave=False)
    pbar.set_description('Computing label-to-label distances')
    D = torch.zeros((n1, n2), device = device, dtype=X1.dtype)
    for i, j in pbar:
        try:
            D[i, j] = distance(X1[Y1==c1[i]].to(device), X2[Y2==c2[j]].to(device)).item()
        except Exception as err:
            # Catch Exception only, so Ctrl-C is not reported as a failure,
            # and keep the underlying error visible.
            logger.exception('Distance computation failed for label pair (%s, %s)', i, j)
            print("This is awkward. Distance computation failed. Geomloss is hard to debug. "
                  "But here's a few things that might be happening: "
                  " 1. Too many samples with this label, causing memory issues"
                  " 2. Datatype errors, e.g., if the two datasets have different type")
            sys.exit('Distance computation failed (%r). Aborting.' % (err,))
        if symmetric:
            D[j, i] = D[i, j]
    return D
def safedump(d,f):
    """Pickle object `d` to the file path `f` without raising on failure.

    Returns True on success. If the file cannot be written or the object
    cannot be pickled, the error is logged and False is returned (instead of
    dropping into an interactive debugger, which hangs unattended jobs).
    """
    try:
        # Context manager guarantees the handle is closed even on failure.
        with open(f, 'wb') as fh:
            pkl.dump(d, fh)
    except (OSError, pkl.PicklingError, TypeError, AttributeError, RecursionError):
        logging.getLogger(__name__).exception('Could not pickle object to %s', f)
        return False
    return True

def append_to_file(fname, l):
    """Append the items of `l` to `fname` as one tab-separated line."""
    with open(fname, "a") as f:
        # str() each field so numeric entries do not break the join.
        f.write('\t'.join(str(item) for item in l) + '\n')

def delete_if_exists(path, typ='f'):
    """Delete `path` if it exists; do nothing otherwise.

    Arguments:
        path: file or directory to remove.
        typ (str): 'f' removes a file, 'd' removes a directory tree.

    Raises:
        ValueError: if `typ` is neither 'f' nor 'd'.
    """
    if typ == 'f':
        if os.path.exists(path):
            os.remove(path)
    elif typ == 'd':
        if os.path.isdir(path):
            shutil.rmtree(path)
    else:
        # Only an unknown type is an error; a missing path is a no-op.
        raise ValueError("Unrecognized path type")
30 | 31 | 32 | parser = argparse.ArgumentParser() 33 | # add_dataset_model_arguments(parser) 34 | 35 | parser.add_argument('--cnum', type=int, required=True, 36 | help='number of cuda in the server') 37 | 38 | parser.add_argument('--n', type=int, required=True, 39 | help='number of data') 40 | arg = parser.parse_args() # args conflict with other argument 41 | 42 | print(f"procs cnum {arg.cnum}") 43 | 44 | print(f"data cnum {arg.n}") 45 | 46 | print("end") 47 | 48 | 49 | cuda_num = arg.cnum 50 | import torch 51 | print(torch.__version__) 52 | import os 53 | os.environ["CUDA_VISIBLE_DEVICES"]=str(cuda_num) 54 | print(os.environ["CUDA_VISIBLE_DEVICES"]) 55 | torch.cuda.set_device(cuda_num) 56 | print("Cuda device: ", torch.cuda.current_device()) 57 | print("cude devices: ", torch.cuda.device_count()) 58 | device = torch.device('cuda:' + str(cuda_num) if torch.cuda.is_available() else 'cpu') 59 | 60 | 61 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 62 | 'dog', 'frog', 'horse', 'ship', 'truck') 63 | data_all = pickle.load( open('data/cifar10.data', 'rb') ) 64 | train_features, train_labels, test_features, test_labels = data_all 65 | 66 | 67 | # data_all = pickle.load(open('Baselines/datasets/clean_cifar.data', 'rb')) 68 | # train_features, train_labels, test_features, test_labels = data_all 69 | 70 | label_idx = [] 71 | for i in range(10): 72 | label_idx.append((train_labels==i).nonzero()[0]) 73 | 74 | test_label_idx = [] 75 | for i in range(10): 76 | test_label_idx.append((test_labels==i).nonzero()[0]) 77 | 78 | class PreActBlock(nn.Module): 79 | '''Pre-activation version of the BasicBlock.''' 80 | expansion = 1 81 | 82 | def __init__(self, in_planes, planes, stride=1): 83 | super(PreActBlock, self).__init__() 84 | self.bn1 = nn.BatchNorm2d(in_planes) 85 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 86 | self.bn2 = nn.BatchNorm2d(planes) 87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 
class PreActResNet(nn.Module):
    """Pre-activation ResNet: a 3x3 stem, four residual stages, a linear head.

    Arguments:
        block: residual block class; must expose an `expansion` attribute and
            accept `(in_planes, planes, stride)`.
        num_blocks (sequence of int): number of blocks in each of the 4 stages.
        num_classes (int): size of the output layer.

    The stem expects 3-channel input. NOTE(review): the fixed 4x4 average
    pool before the classifier presumably targets 32x32 images -- confirm
    before feeding other resolutions.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Honour `num_classes`; the head used to be hard-coded to 10 outputs.
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one uses `stride`."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            # Track the channel count the next block will receive.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
targets.size(0) 200 | correct += predicted.eq(targets).sum().item() 201 | end_time = time.time() 202 | if epoch % 10 == 0: 203 | print('%.1f . TrainLoss: %.3f | TrainAcc: %.3f%% (%d/%d) | Time Elapsed %.3f sec ' % (epoch, train_loss/(batch_idx+1), 100.*correct/total, correct, total, end_time-start_time)) 204 | best_train_loss = min(best_train_loss, train_loss/(batch_idx+1)) 205 | 206 | 207 | # net.eval() 208 | test_loss = 0 209 | correct = 0 210 | total = 0 211 | # acc = [0 for c in list_of_classes] 212 | 213 | with torch.no_grad(): 214 | for batch_idx, (inputs, targets) in enumerate(test_loader): 215 | inputs, targets = inputs.to(device), targets.to(device) 216 | outputs = net(inputs) 217 | loss = test_criterion(outputs, targets) 218 | 219 | test_loss += loss.item() 220 | _, predicted = outputs.max(1) 221 | total += targets.size(0) 222 | correct += predicted.eq(targets).sum().item() 223 | 224 | # class wise accuracy 225 | 226 | 227 | if epoch % 10 == 0: 228 | print('TestLoss: %.3f | TestAcc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 229 | 230 | 231 | test_loss /= (batch_idx+1) 232 | print(f"test loss {test_loss} train loss {best_train_loss}") 233 | 234 | return test_loss - best_train_loss, 100.*correct/total 235 | 236 | def get_ot_dist(train_loader, test_loader, n=5000): 237 | 238 | 239 | net_test = PreActResNet18() 240 | net_test = net_test.to(device) 241 | net_test.load_state_dict(torch.load('checkpoint/preact_resnet18.pth', map_location=str('cuda:'+str(cuda_num)))) 242 | net_test.eval() 243 | 244 | embedder = net_test.to(device) 245 | embedder.fc = torch.nn.Identity() 246 | for p in embedder.parameters(): 247 | p.requires_grad = False 248 | 249 | # Here we use same embedder for both datasets 250 | feature_cost = FeatureCost(src_embedding = embedder, 251 | src_dim = (3,32,32), 252 | tgt_embedding = embedder, 253 | tgt_dim = (3,32,32), 254 | p = 2, 255 | device='cuda') 256 | 257 | dist = DatasetDistance(train_loader, 
test_loader, 258 | inner_ot_method = 'exact', 259 | debiased_loss = True, 260 | feature_cost = feature_cost, 261 | λ_x=1.0, λ_y=1.0, 262 | sqrt_method = 'spectral', 263 | sqrt_niters=10, 264 | precision='single', 265 | p = 2, entreg = 1e-2, 266 | device='cuda') 267 | k = dist.distance(maxsamples = n, return_coupling = True) 268 | 269 | return k[0].item() 270 | 271 | def dataset_q(q1_amt, q2_amt, num, train_feats, train_labels): 272 | # two datasets, q=0 -> dataset2, q=1 -> dataset1 273 | # validation set: unbiased sample from MNIST validation set 274 | # dataset1: class 0-4: 99% (19.8% each class), class 5-9: 1% (0.2% each class) 275 | # dataset2: class 0-4: 2% (0.4% each class), class 5-9: 98% (19.6% each class) 276 | # near balance at q=0.5 277 | 278 | ds1_idx = [] 279 | ds2_idx = [] 280 | ds3_idx = [] 281 | ds1_labels = [] 282 | ds2_labels = [] 283 | ds3_labels = [] 284 | # ds1_features = [] 285 | # ds2_features = [] 286 | 287 | d1c1 = 0.2425 288 | d1c2 = 0.005 289 | d1c3 = 0.005 290 | 291 | d2c1 = 0.0057 292 | d2c2 = 0.32 293 | d2c3 = 0.0057 294 | 295 | d3c1 = 0.0014 296 | d3c2 = 0.0014 297 | d3c3 = 0.33 298 | 299 | 300 | 301 | # sample size 302 | n = num # size of dataset for training (use for construct) 303 | # ratio 304 | q1 = q1_amt # q * dataset 1 305 | q2 = q2_amt # q * dataset 1 306 | q3 = 1-q1-q2 # q * dataset 1 307 | 308 | for i in range(4): 309 | ds1_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q1*d1c1)))]) 310 | ds2_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q2*d2c1)))]) 311 | ds3_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q3*d3c1)))]) 312 | ds1_labels.append(np.ones(int(np.rint(n*q1*d1c1)))*i) 313 | ds2_labels.append(np.ones(int(np.rint(n*q2*d2c1)))*i) 314 | ds3_labels.append(np.ones(int(np.rint(n*q3*d3c1)))*i) 315 | for i in range(4, 7): 316 | ds1_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q1*d1c2)))]) 
317 | ds2_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q2*d2c2)))]) 318 | ds3_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q3*d3c2)))]) 319 | ds1_labels.append(np.ones(int(np.rint(n*q1*d1c2)))*i) 320 | ds2_labels.append(np.ones(int(np.rint(n*q2*d2c2)))*i) 321 | ds3_labels.append(np.ones(int(np.rint(n*q3*d3c2)))*i) 322 | for i in range(7, 10): 323 | ds1_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q1*d1c3)))]) 324 | ds2_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q2*d2c3)))]) 325 | ds3_idx.append(label_idx[i][np.random.randint(len(label_idx[i]), size=int(np.rint(n*q3*d3c3)))]) 326 | ds1_labels.append(np.ones(int(np.rint(n*q1*d1c3)))*i) 327 | ds2_labels.append(np.ones(int(np.rint(n*q2*d2c3)))*i) 328 | ds3_labels.append(np.ones(int(np.rint(n*q3*d3c3)))*i) 329 | 330 | ds1_features_fl = train_feats[np.concatenate(ds1_idx)] 331 | ds2_features_fl = train_feats[np.concatenate(ds2_idx)] 332 | ds3_features_fl = train_feats[np.concatenate(ds3_idx)] 333 | ds1_features = train_feats[np.concatenate(ds1_idx)] 334 | ds2_features = train_feats[np.concatenate(ds2_idx)] 335 | ds3_features = train_feats[np.concatenate(ds3_idx)] 336 | train_x_2d = np.concatenate([ds1_features, ds2_features, ds3_features]) 337 | 338 | ds1_labels = np.concatenate(ds1_labels) 339 | ds2_labels = np.concatenate(ds2_labels) 340 | ds3_labels = np.concatenate(ds3_labels) 341 | 342 | train_x = np.concatenate([ds1_features_fl, ds2_features_fl, ds3_features_fl]) 343 | train_y = np.concatenate([ds1_labels, ds2_labels, ds3_labels]) 344 | 345 | 346 | return train_x, train_y 347 | 348 | 349 | n = arg.n 350 | q = 0.0 351 | batch_size = 256 352 | 353 | breaks = 10 354 | reps = 3 355 | 356 | # make test dataloader 357 | 358 | test_loader = torch.utils.data.DataLoader(dataset=TensorDataset(torch.Tensor(test_features).permute(0,3,1,2), torch.LongTensor(test_labels)), 359 | 
batch_size=batch_size, 360 | shuffle=False) 361 | 362 | qsreserrlog = [] 363 | qsotlog = [] 364 | qsaccs = [] 365 | for l in range(breaks+1): 366 | reserrlog = [] 367 | otlog = [] 368 | accs = [] 369 | 370 | for j in range(breaks+1): # going through q, from 0 to 1 - 20 points 371 | start_t = time.time() 372 | q1 = l/10 373 | q2 = j/10 374 | q3 = 1-q1-q2 375 | if q3<0: 376 | break 377 | 378 | cacheerr = [] 379 | cacheot = [] 380 | cacheacc = [] 381 | 382 | # create dataset 383 | train_x, train_y = dataset_q(q1, q2, n, train_features, train_labels) 384 | 385 | # make train dataloader 386 | train_loader = torch.utils.data.DataLoader(dataset=TensorDataset(torch.Tensor(train_x).permute(0,3,1,2), 387 | torch.LongTensor(train_y)), 388 | batch_size=batch_size, 389 | shuffle=True) 390 | for rep in range(reps): 391 | # get OT dist 392 | cacheot.append(get_ot_dist(train_loader, test_loader, n=n)) 393 | loss, acc = get_model_log_err(train_loader, test_loader) 394 | # get model error (test loss - train loss) 395 | cacheacc.append(acc) 396 | cacheerr.append(loss) 397 | print("cacheerr: ", cacheerr) 398 | print("cacheot: ", cacheot) 399 | print("cacheacc: ", cacheacc) 400 | # add median of vals 401 | reserrlog.append(np.median(cacheerr)) # median then loss + no need for log 402 | otlog.append(np.median(cacheot)) 403 | accs.append(np.median(cacheacc)) 404 | print("j: ", j, " it took: ", time.time() - start_t) 405 | 406 | qsreserrlog.append(reserrlog) 407 | qsotlog.append(otlog) 408 | qsaccs.append(accs) 409 | pickle.dump([qsreserrlog,qsotlog,qsaccs], open(f'projektor_data/cif10_3sources_unbalanced_{n}_br_{breaks}_rep_{reps}.res', 'wb' )) 410 | --------------------------------------------------------------------------------