├── .gitignore ├── README.md ├── analysis ├── latex │ ├── vary_noise_cifar10 │ └── vary_noise_imagenet ├── markdown │ ├── vary_noise_cifar10 │ └── vary_noise_imagenet └── plots │ ├── high_prob.pdf │ ├── high_prob.png │ ├── vary_noise_cifar10.pdf │ ├── vary_noise_cifar10.png │ ├── vary_noise_imagenet.pdf │ ├── vary_noise_imagenet.png │ ├── vary_train_noise_cifar_050.pdf │ ├── vary_train_noise_cifar_050.png │ ├── vary_train_noise_imagenet_050.pdf │ └── vary_train_noise_imagenet_050.png ├── code ├── __pycache__ │ └── bounds.cpython-36.pyc ├── analyze.py ├── analyze_predict.py ├── architectures.py ├── archs │ └── cifar_resnet.py ├── certify.py ├── core.py ├── datasets.py ├── predict.py ├── train.py ├── train_utils.py └── visualize.py ├── data ├── certify │ ├── cifar10 │ │ └── resnet110 │ │ │ ├── noise_0.12 │ │ │ └── test │ │ │ │ ├── sigma_0.12 │ │ │ │ └── sigma_0.50 │ │ │ ├── noise_0.25 │ │ │ └── test │ │ │ │ ├── sigma_0.25 │ │ │ │ └── sigma_0.50 │ │ │ ├── noise_0.50 │ │ │ └── test │ │ │ │ └── sigma_0.50 │ │ │ └── noise_1.00 │ │ │ └── test │ │ │ ├── sigma_0.50 │ │ │ └── sigma_1.00 │ └── imagenet │ │ └── resnet50 │ │ ├── noise_0.25 │ │ └── test │ │ │ ├── sigma_0.25 │ │ │ └── sigma_0.50 │ │ ├── noise_0.50 │ │ ├── test │ │ │ └── sigma_0.50 │ │ └── train │ │ │ └── sigma_0.50 │ │ └── noise_1.00 │ │ └── test │ │ ├── sigma_0.50 │ │ └── sigma_1.00 └── predict │ └── imagenet │ └── resnet50 │ └── noise_0.25 │ └── test │ ├── N_100 │ ├── N_1000 │ ├── N_10000 │ └── N_100000 ├── experiments.MD └── figures ├── compare_bounds.pdf ├── example_images ├── cifar10 │ ├── 10_0.png │ ├── 10_100.png │ ├── 10_25.png │ ├── 10_50.png │ ├── 110_0.png │ ├── 110_100.png │ ├── 110_25.png │ ├── 110_50.png │ ├── 20_0.png │ ├── 20_100.png │ ├── 20_25.png │ ├── 20_50.png │ ├── 70_0.png │ ├── 70_100.png │ ├── 70_25.png │ └── 70_50.png └── imagenet │ ├── 100_0.png │ ├── 100_100.png │ ├── 100_25.png │ ├── 100_50.png │ ├── 19411_0.png │ ├── 19411_100.png │ ├── 19411_25.png │ ├── 19411_50.png │ ├── 
3300_0.png │ ├── 3300_100.png │ ├── 3300_25.png │ ├── 3300_50.png │ ├── 5400_0.png │ ├── 5400_100.png │ ├── 5400_25.png │ ├── 5400_50.png │ ├── 9067_0.png │ ├── 9067_100.png │ ├── 9067_25.png │ └── 9067_50.png ├── panda_0.25.gif ├── panda_0.50.gif ├── panda_1.00.gif ├── panda_577.png └── radiusslow.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *~ 3 | models 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Certified Adversarial Robustness via Randomized Smoothing 2 | 3 | This repository contains code and trained models for the paper [Certified Adversarial Robustness via Randomized Smoothing](https://arxiv.org/abs/1902.02918) by [Jeremy Cohen](http://cs.cmu.edu/~jeremiac), Elan Rosenfeld, and [Zico Kolter](http://zicokolter.com). 4 | 5 | Randomized smoothing is a **provable** adversarial defense in L2 norm which **scales to ImageNet.** 6 | It's also SOTA on the smaller datasets like CIFAR-10 and SVHN where other provable L2-robust classifiers are viable. 7 | 8 | ## How does it work? 9 | 10 | First, you train a neural network _f_ with Gaussian data augmentation at variance σ². 11 | Then you leverage _f_ to create a new, "smoothed" classifier _g_, defined as follows: 12 | _g(x)_ returns the class which _f_ is most likely to return when _x_ 13 | is corrupted by isotropic Gaussian noise with variance σ². 14 | 15 |

16 | 17 | 18 |

 19 | 20 | 21 | For example, let _x_ be the image above on the left. 22 | Suppose that when _f_ classifies _x_ corrupted by Gaussian noise (the GIF on the right), _f_ returns "panda" 23 | 98% of the time and "gibbon" 2% of the time. 24 | Then the prediction of _g_ at _x_ is defined to be "panda." 25 | 26 | 27 | Interestingly, _g_ is **provably** robust within an L2 norm ball around _x_, in the sense that for any perturbation 28 | δ with sufficiently small L2 norm, _g(x+δ)_ is guaranteed to be "panda." 29 | In this particular example, _g_ will be robust around _x_ within an L2 radius of σ Φ⁻¹(0.98) ≈ 2.05 σ, 30 | where Φ⁻¹ is the inverse CDF of the standard normal distribution. 31 | 32 | In general, suppose that when _f_ classifies noisy corruptions of _x_, the class "panda" is returned with probability _p_ (with _p_ > 0.5). 33 | Then _g_ is guaranteed to classify "panda" within an L2 ball around _x_ of radius σ Φ⁻¹(_p_). 34 | 35 | ### What's the intuition behind this bound? 36 | 37 | We know that _f_ classifies noisy corruptions of _x_ as "panda" with probability 0.98. 38 | An equivalent way of phrasing this is that the Gaussian distribution N(x, σ²I) puts measure 0.98 on 39 | the decision region of class "panda," defined as the set {x': f(x') = "panda"}. 40 | You can prove that no matter how the decision regions of _f_ are "shaped", for any δ with 41 | ||δ||₂ < σ Φ⁻¹(0.98), the translated Gaussian N(x+δ, σ²I) is guaranteed to put measure > 0.5 on the decision region of 42 | class "panda," implying that _g(x+δ)_ = "panda." 43 | 44 | ### Wait a minute... 45 | There's one catch: it's not possible to actually evaluate the smoothed classifier _g_. 46 | This is because it's not possible to exactly compute the probability distribution over the classes when _f_'s input is corrupted by Gaussian noise. 47 | For the same reason, it's not possible to exactly compute the radius in which _g_ is provably robust. 48 | 49 | Instead, we give Monte Carlo algorithms for both 50 | 1. 
**prediction**: evaluating _g(x)_ 51 | 2. **certification**: computing the L2 radius in which _g_ is robust around _x_ 52 | 53 | which are guaranteed to return a correct answer with arbitrarily high probability. 54 | 55 | The prediction algorithm does this by abstaining from making any prediction when it's a "close call," e.g. if 56 | 510 noisy corruptions of _x_ were classified as "panda" and 490 were classified as "gibbon." 57 | Prediction is pretty cheap, since you don't need to use very many samples. 58 | For example, with our ImageNet classifier, making a prediction using 1000 samples took 1.5 seconds, and our classifier abstained 3% of the time. 59 | 60 | On the other hand, certification is pretty slow, since you need _a lot_ of samples to say with high 61 | probability that the measure under N(x, σ²I) of the "panda" decision region is close to 1. 62 | In our experiments we used 100,000 samples, so making each certification took 150 seconds. 63 | 64 | ### Related work 65 | 66 | Randomized smoothing was first proposed in [Certified Robustness to Adversarial Examples with Differential Privacy](https://arxiv.org/abs/1802.03471) 67 | and later improved upon in [Second-Order Adversarial Attack and Certified Robustness](https://arxiv.org/abs/1809.03113). 68 | We simply tightened the analysis and showed that it outperforms the other provably L2-robust classifiers that have been proposed in the literature. 69 | 70 | ## ImageNet results 71 | 72 | We constructed three randomized smoothing classifiers for ImageNet, with the hyperparameter 73 | σ set to 0.25, 0.50, and 1.00. 74 | Here's what the panda image looks like under these three noise levels: 75 | 76 |

77 | 78 | 79 | 80 |

81 | 82 | 83 | The plot below shows the certified top-1 accuracy at various radii of these three classifiers. 84 | The "certified accuracy" of a classifier _g_ at radius _r_ is defined as test set accuracy that _g_ will 85 | provably attain under any possible adversarial attack with L2 norm less than _r_. 86 | As you can see, the hyperparameter σ controls a robustness/accuracy tradeoff: when 87 | σ is high, the standard accuracy is lower, but the classifier's correct predictions are 88 | robust within larger radii. 89 | 90 | 91 | 92 | To put these numbers in context: on ImageNet, random guessing would achieve a top-1 accuracy of 0.001. 93 | A perturbation with L2 norm of 1.0 could change one pixel by 255, ten pixels by 80, 100 pixels by 25, or 1000 pixels by 8. 94 | 95 | 96 | Here's the same data in tabular form. 97 | The best σ for each radius is denoted with an asterisk. 98 | 99 | | | r = 0.0 |r = 0.5 |r = 1.0 |r = 1.5 |r = 2.0 |r = 2.5 |r = 3.0 | 100 | | --- | --- | --- | --- | --- | --- | --- | --- | 101 | σ = 0.25 | 0.67* |0.49* |0.00 |0.00 |0.00 |0.00 |0.00 | 102 | σ = 0.50 | 0.57 |0.46 |0.38* |0.28* |0.00 |0.00 |0.00 | 103 | σ = 1.00 | 0.44 |0.38 |0.33 |0.26 |0.19* |0.15* |0.12* | 104 | 105 | 106 | 118 | 119 | ## This repository 120 | 121 | ### Outline 122 | 123 | The contents of this repository are as follows: 124 | 125 | * [code/](code) contains the code for our experiments. 126 | * [data/](data) contains the raw data from our experiments. 127 | * [analysis/](analysis) contains the plots and tables, based on the contents of [data](/data), that are shown in our paper. 128 | 129 | If you'd like to run our code, you need to download our models from [here](https://drive.google.com/file/d/1h_TpbXm5haY5f-l4--IKylmdz6tvPoR4/view?usp=sharing) 130 | and then move the directory `models` into the root directory of this repo. 131 | 132 | ### Smoothed classifiers 133 | 134 | Randomized smoothing is implemented in the `Smooth` class in [core.py](code/core.py). 
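Before the API details, the radius σ Φ⁻¹(_p_) from "How does it work?" can be sketched in a few lines. This is a standalone illustration built on Python's standard library, not the code in `core.py`. The function name `certified_radius` is hypothetical, and returning `0.0` when _p_ ≤ 0.5 is a choice made for this sketch only.

```python
from statistics import NormalDist


def certified_radius(p: float, sigma: float) -> float:
    """Return sigma * Phi^{-1}(p), the L2 radius from the README's bound.

    `p` is the probability (or a lower bound on it) that the base classifier
    returns the top class under Gaussian noise with standard deviation `sigma`.
    Hypothetical helper: not part of this repository's API.
    """
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    if p <= 0.5:
        # Assumption of this sketch: the bound certifies nothing here, so report 0.
        return 0.0
    # NormalDist() is the standard normal; inv_cdf is its inverse CDF (Phi^{-1}).
    return sigma * NormalDist().inv_cdf(p)


# The README's panda example: p = 0.98 gives a radius of about 2.05 * sigma.
print(round(certified_radius(0.98, sigma=1.0), 2))   # 2.05
print(round(certified_radius(0.98, sigma=0.50), 2))  # 1.03
```

Because the radius scales linearly with σ, a larger noise level can certify larger radii for the same _p_, which matches the robustness/accuracy tradeoff described under "ImageNet results".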
135 | 136 | * To instantiate a smoothed classifier _g_, use the constructor: 137 | 138 | ```def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float):``` 139 | 140 | where `base_classifier` is a PyTorch module that implements _f_, `num_classes` is the number of classes in the output 141 | space, and `sigma` is the noise hyperparameter σ. 142 | 143 | * To make a prediction at an input `x`, call: 144 | 145 | ```def predict(self, x: torch.tensor, n: int, alpha: float, batch_size: int) -> int:``` 146 | 147 | where `n` is the number of Monte Carlo samples and `alpha` is the allowed failure probability. 148 | This function will either (1) return `-1` to abstain or (2) return a class which equals _g(x)_ 149 | with probability at least `1 - alpha`. 150 | 151 | * To compute a radius in which _g_ is robust around an input `x`, call: 152 | 153 | ```def certify(self, x: torch.tensor, n0: int, n: int, alpha: float, batch_size: int) -> (int, float):``` 154 | 155 | where `n0` is the number of Monte Carlo samples to use for selection (see the paper), `n` is the number of Monte Carlo 156 | samples to use for estimation, and `alpha` is the allowed failure probability. 157 | This function will either return the pair `(-1, 0.0)` to abstain, or return a pair 158 | `(prediction, radius)`. The probability that `certify()` will return a class not equal to _g(x)_ is no greater than `alpha`. Another way to say this is that with probability at least `1 - alpha`, `certify()` will either abstain or return _g(x)_. 159 | 160 | ### Scripts 161 | 162 | * The program [train.py](code/train.py) trains a base classifier with Gaussian data augmentation: 163 | 164 | ```python code/train.py imagenet resnet50 model_output_dir --batch 400 --noise 0.50 ``` 165 | 166 | will train a ResNet-50 on ImageNet under Gaussian data augmentation with σ=0.50. 167 | 168 | * The program [predict.py](code/predict.py) makes predictions using _g_ on a bunch of inputs. 
For example, 169 | 170 | ```python code/predict.py imagenet model_output_dir/checkpoint.pth.tar 0.50 prediction_output --alpha 0.001 --N 1000 --skip 100 --batch 400``` 171 | 172 | will load the base classifier saved at `model_output_dir/checkpoint.pth.tar`, smooth it using noise level σ=0.50, 173 | and classify every 100th image from the ImageNet test set with parameters `N=1000` 174 | and `alpha=0.001`. 175 | 176 | * The program [certify.py](code/certify.py) certifies the robustness of _g_ on a bunch of inputs. For example, 177 | 178 | ```python code/certify.py imagenet model_output_dir/checkpoint.pth.tar 0.50 certification_output --alpha 0.001 --N0 100 --N 100000 --skip 100 --batch 400``` 179 | 180 | will load the base classifier saved at `model_output_dir/checkpoint.pth.tar`, smooth it using noise level σ=0.50, 181 | and certify every 100th image from the ImageNet test set with parameters `N0=100`, `N=100000` 182 | and `alpha=0.001`. 183 | 184 | * The program [visualize.py](code/visualize.py) outputs pictures of noisy examples. For example, 185 | 186 | ```python code/visualize.py imagenet visualize_output 100 0.0 0.25 0.5 1.0``` 187 | 188 | will visualize noisy corruptions of the 100th image from the ImageNet test set with noise levels 189 | σ=0.0, σ=0.25, σ=0.50, and σ=1.00. 190 | 191 | * The program [analyze.py](code/analyze.py) generates all of the certified accuracy plots and tables that appeared in the 192 | paper. 193 | 194 | Finally, we note that [this file](experiments.MD) describes exactly how to reproduce 195 | our experiments from the paper. 196 | 197 | 198 | We're not officially releasing code for the experiments where we compared randomized smoothing against the baselines, 199 | since that code involved a number of hacks, but feel free to get in touch if you'd like to see that code. 200 | 201 | ## Getting started 202 | 203 | 1. Clone this repository: `git clone git@github.com:locuslab/smoothing.git` 204 | 205 | 2. 
Install the dependencies: 206 | ``` 207 | conda create -n smoothing 208 | conda activate smoothing 209 | # below is for linux, with CUDA 10; see https://pytorch.org/ for the correct command for your system 210 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 211 | conda install scipy pandas statsmodels matplotlib seaborn 212 | pip install setGPU 213 | ``` 214 | 3. Download our trained models from [here](https://drive.google.com/file/d/1h_TpbXm5haY5f-l4--IKylmdz6tvPoR4/view?usp=sharing). 215 | 4. If you want to run ImageNet experiments, obtain a copy of ImageNet and preprocess the `val` directory to look 216 | like the `train` directory by running [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh). 217 | Finally, set the environment variable `IMAGENET_DIR` to the directory where ImageNet is located. 218 | 219 | 5. To get the hang of things, try running this command, which will certify the robustness of one of our pretrained CIFAR-10 models 220 | on the CIFAR test set. 221 | ``` 222 | model="models/cifar10/resnet110/noise_0.25/checkpoint.pth.tar" 223 | output="???" 224 | python code/certify.py cifar10 $model 0.25 $output --skip 20 --batch 400 225 | ``` 226 | where `???` is your desired output file. 
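Once the command above has written its output, the certified accuracy at a given radius can be read off that file. The sketch below mirrors the `ApproximateAccuracy` computation in `code/analyze.py`: the fraction of examples that are classified correctly with a certified radius of at least _r_. It assumes a tab-separated file with `radius` and `correct` columns, which are the names `analyze.py` reads. The `idx` column, the `1`/`0` (or `True`/`False`) encoding of `correct`, and the function name `certified_accuracy` are assumptions made for illustration. It uses the standard-library `csv` module rather than pandas so that it runs without extra dependencies.

```python
import csv
import io

# Tiny in-memory stand-in for a certification output file (hypothetical values).
SAMPLE = (
    "idx\tradius\tcorrect\n"
    "0\t0.6\t1\n"
    "1\t0.3\t1\n"
    "2\t0.9\t0\n"
    "3\t0.0\t0\n"
)


def certified_accuracy(tsv_file, radius: float) -> float:
    """Fraction of rows that are correct AND certified at `radius` or more."""
    rows = list(csv.DictReader(tsv_file, delimiter="\t"))
    if not rows:
        raise ValueError("no rows found in certification output")
    hits = sum(
        1
        for row in rows
        # Assumed encoding of the `correct` column: "1" or "True" means correct.
        if row["correct"].strip() in ("1", "True") and float(row["radius"]) >= radius
    )
    return hits / len(rows)


print(certified_accuracy(io.StringIO(SAMPLE), 0.5))   # 0.25
print(certified_accuracy(io.StringIO(SAMPLE), 0.25))  # 0.5
```

To run it on a real file, pass an open file handle instead of the `io.StringIO` wrapper, for example `open(output, newline="")` with `output` set to the path chosen above.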
227 | -------------------------------------------------------------------------------- /analysis/latex/vary_noise_cifar10: -------------------------------------------------------------------------------- 1 | & $r = 0.25$& $r = 0.5$& $r = 0.75$& $r = 1.0$& $r = 1.25$& $r = 1.5$\\ 2 | \midrule 3 | $\sigma = 0.12$ & 0.59 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00\\ 4 | $\sigma = 0.25$ & \textbf{0.60} & \textbf{0.43} & 0.27 & 0.00 & 0.00 & 0.00\\ 5 | $\sigma = 0.50$ & 0.55 & 0.41 & \textbf{0.32} & \textbf{0.23} & 0.15 & 0.09\\ 6 | $\sigma = 1.00$ & 0.39 & 0.34 & 0.28 & 0.22 & \textbf{0.17} & \textbf{0.14}\\ 7 | -------------------------------------------------------------------------------- /analysis/latex/vary_noise_imagenet: -------------------------------------------------------------------------------- 1 | & $r = 0.5$& $r = 1.0$& $r = 1.5$& $r = 2.0$& $r = 2.5$& $r = 3.0$\\ 2 | \midrule 3 | $\sigma = 0.25$ & \textbf{0.49} & 0.00 & 0.00 & 0.00 & 0.00 & 0.00\\ 4 | $\sigma = 0.50$ & 0.46 & \textbf{0.37} & \textbf{0.29} & 0.00 & 0.00 & 0.00\\ 5 | $\sigma = 1.00$ & 0.38 & 0.33 & 0.26 & \textbf{0.19} & \textbf{0.15} & \textbf{0.12}\\ 6 | -------------------------------------------------------------------------------- /analysis/markdown/vary_noise_cifar10: -------------------------------------------------------------------------------- 1 | | | r = 0.25 |r = 0.5 |r = 0.75 |r = 1.0 |r = 1.25 |r = 1.5 | 2 | | --- | --- | --- | --- | --- | --- | --- | 3 | σ = 0.12 | 0.59 |0.00 |0.00 |0.00 |0.00 |0.00 | 4 | σ = 0.25 | 0.60* |0.43* |0.27 |0.00 |0.00 |0.00 | 5 | σ = 0.50 | 0.55 |0.41 |0.32* |0.23* |0.15 |0.09 | 6 | σ = 1.00 | 0.39 |0.34 |0.28 |0.22 |0.17* |0.14* | 7 | -------------------------------------------------------------------------------- /analysis/markdown/vary_noise_imagenet: -------------------------------------------------------------------------------- 1 | | | r = 0.5 |r = 1.0 |r = 1.5 |r = 2.0 |r = 2.5 |r = 3.0 | 2 | | --- | --- | --- | --- | --- | --- | --- | 3 | σ = 
0.25 | 0.49* |0.00 |0.00 |0.00 |0.00 |0.00 | 4 | σ = 0.50 | 0.46 |0.37* |0.29* |0.00 |0.00 |0.00 | 5 | σ = 1.00 | 0.38 |0.33 |0.26 |0.19* |0.15* |0.12* | 6 | -------------------------------------------------------------------------------- /analysis/plots/high_prob.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/high_prob.pdf -------------------------------------------------------------------------------- /analysis/plots/high_prob.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/high_prob.png -------------------------------------------------------------------------------- /analysis/plots/vary_noise_cifar10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_cifar10.pdf -------------------------------------------------------------------------------- /analysis/plots/vary_noise_cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_cifar10.png -------------------------------------------------------------------------------- /analysis/plots/vary_noise_imagenet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_imagenet.pdf -------------------------------------------------------------------------------- /analysis/plots/vary_noise_imagenet.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_imagenet.png -------------------------------------------------------------------------------- /analysis/plots/vary_train_noise_cifar_050.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_cifar_050.pdf -------------------------------------------------------------------------------- /analysis/plots/vary_train_noise_cifar_050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_cifar_050.png -------------------------------------------------------------------------------- /analysis/plots/vary_train_noise_imagenet_050.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_imagenet_050.pdf -------------------------------------------------------------------------------- /analysis/plots/vary_train_noise_imagenet_050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_imagenet_050.png -------------------------------------------------------------------------------- /code/__pycache__/bounds.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/code/__pycache__/bounds.cpython-36.pyc -------------------------------------------------------------------------------- /code/analyze.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | 4 | matplotlib.use("TkAgg") 5 | import matplotlib.pyplot as plt 6 | from typing import * 7 | import pandas as pd 8 | import seaborn as sns 9 | import math 10 | 11 | sns.set() 12 | 13 | 14 | class Accuracy(object): 15 | def at_radii(self, radii: np.ndarray): 16 | raise NotImplementedError() 17 | 18 | 19 | class ApproximateAccuracy(Accuracy): 20 | def __init__(self, data_file_path: str): 21 | self.data_file_path = data_file_path 22 | 23 | def at_radii(self, radii: np.ndarray) -> np.ndarray: 24 | df = pd.read_csv(self.data_file_path, delimiter="\t") 25 | return np.array([self.at_radius(df, radius) for radius in radii]) 26 | 27 | def at_radius(self, df: pd.DataFrame, radius: float): 28 | return (df["correct"] & (df["radius"] >= radius)).mean() 29 | 30 | 31 | class HighProbAccuracy(Accuracy): 32 | def __init__(self, data_file_path: str, alpha: float, rho: float): 33 | self.data_file_path = data_file_path 34 | self.alpha = alpha 35 | self.rho = rho 36 | 37 | def at_radii(self, radii: np.ndarray) -> np.ndarray: 38 | df = pd.read_csv(self.data_file_path, delimiter="\t") 39 | return np.array([self.at_radius(df, radius) for radius in radii]) 40 | 41 | def at_radius(self, df: pd.DataFrame, radius: float): 42 | mean = (df["correct"] & (df["radius"] >= radius)).mean() 43 | num_examples = len(df) 44 | return (mean - self.alpha - math.sqrt(self.alpha * (1 - self.alpha) * math.log(1 / self.rho) / num_examples) 45 | - math.log(1 / self.rho) / (3 * num_examples)) 46 | 47 | 48 | class Line(object): 49 | def __init__(self, quantity: Accuracy, legend: str, plot_fmt: str = "", scale_x: float = 1): 50 | self.quantity = quantity 51 | self.legend = legend 52 | self.plot_fmt = plot_fmt 53 | self.scale_x = scale_x 54 | 55 | 56 | def plot_certified_accuracy(outfile: str, title: str, max_radius: float, 57 | lines: List[Line], radius_step: float = 0.01) -> None: 58 | 
radii = np.arange(0, max_radius + radius_step, radius_step) 59 | plt.figure() 60 | for line in lines: 61 | plt.plot(radii * line.scale_x, line.quantity.at_radii(radii), line.plot_fmt) 62 | 63 | plt.ylim((0, 1)) 64 | plt.xlim((0, max_radius)) 65 | plt.tick_params(labelsize=14) 66 | plt.xlabel("radius", fontsize=16) 67 | plt.ylabel("certified accuracy", fontsize=16) 68 | plt.legend([method.legend for method in lines], loc='upper right', fontsize=16) 69 | plt.savefig(outfile + ".pdf") 70 | plt.tight_layout() 71 | plt.title(title, fontsize=20) 72 | plt.tight_layout() 73 | plt.savefig(outfile + ".png", dpi=300) 74 | plt.close() 75 | 76 | 77 | def smallplot_certified_accuracy(outfile: str, title: str, max_radius: float, 78 | methods: List[Line], radius_step: float = 0.01, xticks=0.5) -> None: 79 | radii = np.arange(0, max_radius + radius_step, radius_step) 80 | plt.figure() 81 | for method in methods: 82 | plt.plot(radii, method.quantity.at_radii(radii), method.plot_fmt) 83 | 84 | plt.ylim((0, 1)) 85 | plt.xlim((0, max_radius)) 86 | plt.xlabel("radius", fontsize=22) 87 | plt.ylabel("certified accuracy", fontsize=22) 88 | plt.tick_params(labelsize=20) 89 | plt.gca().xaxis.set_major_locator(plt.MultipleLocator(xticks)) 90 | plt.legend([method.legend for method in methods], loc='upper right', fontsize=20) 91 | plt.tight_layout() 92 | plt.savefig(outfile + ".pdf") 93 | plt.close() 94 | 95 | 96 | def latex_table_certified_accuracy(outfile: str, radius_start: float, radius_stop: float, radius_step: float, 97 | methods: List[Line]): 98 | radii = np.arange(radius_start, radius_stop + radius_step, radius_step) 99 | accuracies = np.zeros((len(methods), len(radii))) 100 | for i, method in enumerate(methods): 101 | accuracies[i, :] = method.quantity.at_radii(radii) 102 | 103 | f = open(outfile, 'w') 104 | 105 | for radius in radii: 106 | f.write("& $r = {:.3}$".format(radius)) 107 | f.write("\\\\\n") 108 | 109 | f.write("\midrule\n") 110 | 111 | for i, method in enumerate(methods): 
112 | f.write(method.legend) 113 | for j, radius in enumerate(radii): 114 | if i == accuracies[:, j].argmax(): 115 | txt = r" & \textbf{" + "{:.2f}".format(accuracies[i, j]) + "}" 116 | else: 117 | txt = " & {:.2f}".format(accuracies[i, j]) 118 | f.write(txt) 119 | f.write("\\\\\n") 120 | f.close() 121 | 122 | 123 | def markdown_table_certified_accuracy(outfile: str, radius_start: float, radius_stop: float, radius_step: float, 124 | methods: List[Line]): 125 | radii = np.arange(radius_start, radius_stop + radius_step, radius_step) 126 | accuracies = np.zeros((len(methods), len(radii))) 127 | for i, method in enumerate(methods): 128 | accuracies[i, :] = method.quantity.at_radii(radii) 129 | 130 | f = open(outfile, 'w') 131 | f.write("| | ") 132 | for radius in radii: 133 | f.write("r = {:.3} |".format(radius)) 134 | f.write("\n") 135 | 136 | f.write("| --- | ") 137 | for i in range(len(radii)): 138 | f.write(" --- |") 139 | f.write("\n") 140 | 141 | for i, method in enumerate(methods): 142 | f.write(" {} | ".format(method.legend)) 143 | for j, radius in enumerate(radii): 144 | if i == accuracies[:, j].argmax(): 145 | txt = "{:.2f}* |".format(accuracies[i, j]) 146 | else: 147 | txt = "{:.2f} |".format(accuracies[i, j]) 148 | f.write(txt) 149 | f.write("\n") 150 | f.close() 151 | 152 | 153 | if __name__ == "__main__": 154 | latex_table_certified_accuracy( 155 | "analysis/latex/vary_noise_cifar10", 0.25, 1.5, 0.25, [ 156 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "$\sigma = 0.12$"), 157 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), 158 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), 159 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), 160 | ]) 161 | markdown_table_certified_accuracy( 162 | "analysis/markdown/vary_noise_cifar10", 0.25, 1.5, 
0.25, [ 163 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "σ = 0.12"), 164 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "σ = 0.25"), 165 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "σ = 0.50"), 166 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "σ = 1.00"), 167 | ]) 168 | latex_table_certified_accuracy( 169 | "analysis/latex/vary_noise_imagenet", 0.5, 3.0, 0.5, [ 170 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), 171 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), 172 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), 173 | ]) 174 | markdown_table_certified_accuracy( 175 | "analysis/markdown/vary_noise_imagenet", 0.5, 3.0, 0.5, [ 176 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "σ = 0.25"), 177 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "σ = 0.50"), 178 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "σ = 1.00"), 179 | ]) 180 | plot_certified_accuracy( 181 | "analysis/plots/vary_noise_cifar10", "CIFAR-10, vary $\sigma$", 1.5, [ 182 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "$\sigma = 0.12$"), 183 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), 184 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), 185 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), 186 | ]) 187 | plot_certified_accuracy( 188 | "analysis/plots/vary_train_noise_cifar_050", "CIFAR-10, vary train noise, 
$\sigma=0.5$", 1.5, [ 189 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50"), "train $\sigma = 0.25$"), 190 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "train $\sigma = 0.50$"), 191 | Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50"), "train $\sigma = 1.00$"), 192 | ]) 193 | plot_certified_accuracy( 194 | "analysis/plots/vary_train_noise_imagenet_050", "ImageNet, vary train noise, $\sigma=0.5$", 1.5, [ 195 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50"), "train $\sigma = 0.25$"), 196 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "train $\sigma = 0.50$"), 197 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50"), "train $\sigma = 1.00$"), 198 | ]) 199 | plot_certified_accuracy( 200 | "analysis/plots/vary_noise_imagenet", "ImageNet, vary $\sigma$", 4, [ 201 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), 202 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), 203 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), 204 | ]) 205 | plot_certified_accuracy( 206 | "analysis/plots/high_prob", "Approximate vs. 
High-Probability", 2.0, [ 207 | Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "Approximate"), 208 | Line(HighProbAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50", 0.001, 0.001), "High-Prob"), 209 | ]) 210 | -------------------------------------------------------------------------------- /code/analyze_predict.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | if __name__ == "__main__": 4 | output = pd.DataFrame(columns=["n", "correct, accurate", "correct, inaccurate", 5 | "incorrect, accurate", "incorrect, inaccurate", "abstain"]) 6 | 7 | results = [] 8 | 9 | gold_standard = pd.read_csv("data/predict/imagenet/resnet50/noise_0.25/test/N_100000", delimiter="\t")[:450] 10 | for N in [100, 1000, 10000]: 11 | df = pd.read_csv("data/predict/imagenet/resnet50/noise_0.25/test/N_{}".format(N), delimiter="\t")[:450] 12 | accurate = df["predict"] == gold_standard["predict"] 13 | abstain = df["predict"] == -1 14 | frac_abstain = abstain.mean() 15 | frac_correct_accurate = (df["correct"] & accurate & ~abstain).mean() 16 | frac_correct_inaccurate = (df["correct"] & ~accurate & ~abstain).mean() 17 | frac_incorrect_accurate = (~df["correct"] & accurate & ~abstain).mean() 18 | frac_incorrect_inaccurate = (~df["correct"] & ~accurate & ~abstain).mean() 19 | results.append((N, frac_correct_accurate, frac_correct_inaccurate, 20 | frac_incorrect_accurate, frac_incorrect_inaccurate, frac_abstain)) 21 | 22 | df = pd.DataFrame.from_records(results, "n", columns=["n", "correct, accurate", "correct, inaccurate", 23 | "incorrect, accurate", "incorrect, inaccurate", "abstain"]) 24 | print(df.to_latex(float_format=lambda f:"{:.2f}".format(f))) 25 | -------------------------------------------------------------------------------- /code/architectures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
torchvision.models.resnet import resnet50 3 | import torch.backends.cudnn as cudnn 4 | from archs.cifar_resnet import resnet as resnet_cifar 5 | from datasets import get_normalize_layer 6 | from torch.nn.functional import interpolate 7 | 8 | # resnet50 - the classic ResNet-50, sized for ImageNet 9 | # cifar_resnet20 - a 20-layer residual network sized for CIFAR 10 | # cifar_resnet110 - a 110-layer residual network sized for CIFAR 11 | ARCHITECTURES = ["resnet50", "cifar_resnet20", "cifar_resnet110"] 12 | 13 | def get_architecture(arch: str, dataset: str) -> torch.nn.Module: 14 | """ Return a neural network (with random weights) 15 | 16 | :param arch: the architecture - should be in the ARCHITECTURES list above 17 | :param dataset: the dataset - should be in the datasets.DATASETS list 18 | :return: a Pytorch module 19 | """ 20 | if arch == "resnet50" and dataset == "imagenet": 21 | model = torch.nn.DataParallel(resnet50(pretrained=False)).cuda() 22 | cudnn.benchmark = True 23 | elif arch == "cifar_resnet20": 24 | model = resnet_cifar(depth=20, num_classes=10).cuda() 25 | elif arch == "cifar_resnet110": 26 | model = resnet_cifar(depth=110, num_classes=10).cuda() 27 | normalize_layer = get_normalize_layer(dataset) 28 | return torch.nn.Sequential(normalize_layer, model) 29 | -------------------------------------------------------------------------------- /code/archs/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | ''' 4 | This file is from: https://raw.githubusercontent.com/bearpaw/pytorch-classification/master/models/cifar/resnet.py 5 | by Wei Yang 6 | ''' 7 | import torch.nn as nn 8 | import math 9 | 10 | 11 | __all__ = ['resnet'] 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 
| expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class ResNet(nn.Module): 91 | 92 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'): 93 | super(ResNet, self).__init__() 94 | # Model type specifies number 
of layers for CIFAR-10 model 95 | if block_name.lower() == 'basicblock': 96 | assert (depth - 2) % 6 == 0, 'When using BasicBlock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 97 | n = (depth - 2) // 6 98 | block = BasicBlock 99 | elif block_name.lower() == 'bottleneck': 100 | assert (depth - 2) % 9 == 0, 'When using Bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 101 | n = (depth - 2) // 9 102 | block = Bottleneck 103 | else: 104 | raise ValueError('block_name should be BasicBlock or Bottleneck') 105 | 106 | 107 | self.inplanes = 16 108 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 109 | bias=False) 110 | self.bn1 = nn.BatchNorm2d(16) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.layer1 = self._make_layer(block, 16, n) 113 | self.layer2 = self._make_layer(block, 32, n, stride=2) 114 | self.layer3 = self._make_layer(block, 64, n, stride=2) 115 | self.avgpool = nn.AvgPool2d(8) 116 | self.fc = nn.Linear(64 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) # 32x32 147 | 148 | x = self.layer1(x) # 32x32 149 | x = self.layer2(x) # 16x16 150 | x = self.layer3(x) # 8x8 151 | 152 | x = self.avgpool(x) 153 | x = x.view(x.size(0), -1) 154 | x = self.fc(x) 155 | 156 | return x 157 | 158 | 159 | def resnet(**kwargs): 160 | """ 161 | Constructs a ResNet model. 
162 | """ 163 | return ResNet(**kwargs) -------------------------------------------------------------------------------- /code/certify.py: -------------------------------------------------------------------------------- 1 | # evaluate a smoothed classifier on a dataset 2 | import argparse 3 | import os 4 | import setGPU 5 | from datasets import get_dataset, DATASETS, get_num_classes 6 | from core import Smooth 7 | from time import time 8 | import torch 9 | import datetime 10 | from architectures import get_architecture 11 | 12 | parser = argparse.ArgumentParser(description='Certify many examples') 13 | parser.add_argument("dataset", choices=DATASETS, help="which dataset") 14 | parser.add_argument("base_classifier", type=str, help="path to saved pytorch model of base classifier") 15 | parser.add_argument("sigma", type=float, help="noise hyperparameter") 16 | parser.add_argument("outfile", type=str, help="output file") 17 | parser.add_argument("--batch", type=int, default=1000, help="batch size") 18 | parser.add_argument("--skip", type=int, default=1, help="how many examples to skip") 19 | parser.add_argument("--max", type=int, default=-1, help="stop after this many examples") 20 | parser.add_argument("--split", choices=["train", "test"], default="test", help="train or test set") 21 | parser.add_argument("--N0", type=int, default=100) 22 | parser.add_argument("--N", type=int, default=100000, help="number of samples to use") 23 | parser.add_argument("--alpha", type=float, default=0.001, help="failure probability") 24 | args = parser.parse_args() 25 | 26 | if __name__ == "__main__": 27 | # load the base classifier 28 | checkpoint = torch.load(args.base_classifier) 29 | base_classifier = get_architecture(checkpoint["arch"], args.dataset) 30 | base_classifier.load_state_dict(checkpoint['state_dict']) 31 | 32 | # create the smoothed classifier g 33 | smoothed_classifier = Smooth(base_classifier, get_num_classes(args.dataset), args.sigma) 34 | 35 | # prepare output file
36 | f = open(args.outfile, 'w') 37 | print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=f, flush=True) 38 | 39 | # iterate through the dataset 40 | dataset = get_dataset(args.dataset, args.split) 41 | for i in range(len(dataset)): 42 | 43 | # only certify every args.skip examples, and stop after args.max examples 44 | if i % args.skip != 0: 45 | continue 46 | if i == args.max: 47 | break 48 | 49 | (x, label) = dataset[i] 50 | 51 | before_time = time() 52 | # certify the prediction of g around x 53 | x = x.cuda() 54 | prediction, radius = smoothed_classifier.certify(x, args.N0, args.N, args.alpha, args.batch) 55 | after_time = time() 56 | correct = int(prediction == label) 57 | 58 | time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time))) 59 | print("{}\t{}\t{}\t{:.3}\t{}\t{}".format( 60 | i, label, prediction, radius, correct, time_elapsed), file=f, flush=True) 61 | 62 | f.close() 63 | -------------------------------------------------------------------------------- /code/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.stats import norm, binom_test 3 | import numpy as np 4 | from math import ceil 5 | from statsmodels.stats.proportion import proportion_confint 6 | 7 | 8 | class Smooth(object): 9 | """A smoothed classifier g """ 10 | 11 | # to abstain, Smooth returns this int 12 | ABSTAIN = -1 13 | 14 | def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float): 15 | """ 16 | :param base_classifier: maps from [batch x channel x height x width] to [batch x num_classes] 17 | :param num_classes: 18 | :param sigma: the noise level hyperparameter 19 | """ 20 | self.base_classifier = base_classifier 21 | self.num_classes = num_classes 22 | self.sigma = sigma 23 | 24 | def certify(self, x: torch.tensor, n0: int, n: int, alpha: float, batch_size: int) -> (int, float): 25 | """ Monte Carlo algorithm for certifying that g's prediction around x is constant 
within some L2 radius. 26 | With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will 27 | be robust within an L2 ball of radius R around x. 28 | 29 | :param x: the input [channel x height x width] 30 | :param n0: the number of Monte Carlo samples to use for selection 31 | :param n: the number of Monte Carlo samples to use for estimation 32 | :param alpha: the failure probability 33 | :param batch_size: batch size to use when evaluating the base classifier 34 | :return: (predicted class, certified radius) 35 | in the case of abstention, the class will be ABSTAIN and the radius 0. 36 | """ 37 | self.base_classifier.eval() 38 | # draw samples of f(x + epsilon) 39 | counts_selection = self._sample_noise(x, n0, batch_size) 40 | # use these samples to take a guess at the top class 41 | cAHat = counts_selection.argmax().item() 42 | # draw more samples of f(x + epsilon) 43 | counts_estimation = self._sample_noise(x, n, batch_size) 44 | # use these samples to estimate a lower bound on pA 45 | nA = counts_estimation[cAHat].item() 46 | pABar = self._lower_confidence_bound(nA, n, alpha) 47 | if pABar < 0.5: 48 | return Smooth.ABSTAIN, 0.0 49 | else: 50 | radius = self.sigma * norm.ppf(pABar) 51 | return cAHat, radius 52 | 53 | def predict(self, x: torch.tensor, n: int, alpha: float, batch_size: int) -> int: 54 | """ Monte Carlo algorithm for evaluating the prediction of g at x. With probability at least 1 - alpha, the 55 | class returned by this method will equal g(x). 56 | 57 | This function uses the hypothesis test described in https://arxiv.org/abs/1610.03944 58 | for identifying the top category of a multinomial distribution.
59 | 60 | :param x: the input [channel x height x width] 61 | :param n: the number of Monte Carlo samples to use 62 | :param alpha: the failure probability 63 | :param batch_size: batch size to use when evaluating the base classifier 64 | :return: the predicted class, or ABSTAIN 65 | """ 66 | self.base_classifier.eval() 67 | counts = self._sample_noise(x, n, batch_size) 68 | top2 = counts.argsort()[::-1][:2] 69 | count1 = counts[top2[0]] 70 | count2 = counts[top2[1]] 71 | if binom_test(count1, count1 + count2, p=0.5) > alpha: 72 | return Smooth.ABSTAIN 73 | else: 74 | return top2[0] 75 | 76 | def _sample_noise(self, x: torch.tensor, num: int, batch_size) -> np.ndarray: 77 | """ Sample the base classifier's prediction under noisy corruptions of the input x. 78 | 79 | :param x: the input [channel x height x width] 80 | :param num: number of samples to collect 81 | :param batch_size: 82 | :return: an ndarray[int] of length num_classes containing the per-class counts 83 | """ 84 | with torch.no_grad(): 85 | counts = np.zeros(self.num_classes, dtype=int) 86 | for _ in range(ceil(num / batch_size)): 87 | this_batch_size = min(batch_size, num) 88 | num -= this_batch_size 89 | 90 | batch = x.repeat((this_batch_size, 1, 1, 1)) 91 | noise = torch.randn_like(batch, device='cuda') * self.sigma 92 | predictions = self.base_classifier(batch + noise).argmax(1) 93 | counts += self._count_arr(predictions.cpu().numpy(), self.num_classes) 94 | return counts 95 | 96 | def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray: 97 | counts = np.zeros(length, dtype=int) 98 | for idx in arr: 99 | counts[idx] += 1 100 | return counts 101 | 102 | def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float: 103 | """ Returns a (1 - alpha) lower confidence bound on a Bernoulli proportion. 104 | 105 | This function uses the Clopper-Pearson method.
106 | 107 | :param NA: the number of "successes" 108 | :param N: the number of total draws 109 | :param alpha: the failure probability (the bound holds with probability at least 1 - alpha) 110 | :return: a lower bound on the binomial proportion which holds true with probability at least (1 - alpha) over the samples 111 | """ 112 | return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0] 113 | -------------------------------------------------------------------------------- /code/datasets.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms, datasets 2 | from typing import * 3 | import torch 4 | import os 5 | from torch.utils.data import Dataset 6 | 7 | # set this environment variable to the location of your imagenet directory if you want to read ImageNet data. 8 | # make sure your val directory is preprocessed to look like the train directory, e.g. by running this script 9 | # https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh 10 | IMAGENET_LOC_ENV = "IMAGENET_DIR" 11 | 12 | # list of all datasets 13 | DATASETS = ["imagenet", "cifar10"] 14 | 15 | 16 | def get_dataset(dataset: str, split: str) -> Dataset: 17 | """Return the dataset as a PyTorch Dataset object""" 18 | if dataset == "imagenet": 19 | return _imagenet(split) 20 | elif dataset == "cifar10": 21 | return _cifar10(split) 22 | 23 | 24 | def get_num_classes(dataset: str): 25 | """Return the number of classes in the dataset.
""" 26 | if dataset == "imagenet": 27 | return 1000 28 | elif dataset == "cifar10": 29 | return 10 30 | 31 | 32 | def get_normalize_layer(dataset: str) -> torch.nn.Module: 33 | """Return the dataset's normalization layer""" 34 | if dataset == "imagenet": 35 | return NormalizeLayer(_IMAGENET_MEAN, _IMAGENET_STDDEV) 36 | elif dataset == "cifar10": 37 | return NormalizeLayer(_CIFAR10_MEAN, _CIFAR10_STDDEV) 38 | 39 | 40 | _IMAGENET_MEAN = [0.485, 0.456, 0.406] 41 | _IMAGENET_STDDEV = [0.229, 0.224, 0.225] 42 | 43 | _CIFAR10_MEAN = [0.4914, 0.4822, 0.4465] 44 | _CIFAR10_STDDEV = [0.2023, 0.1994, 0.2010] 45 | 46 | 47 | def _cifar10(split: str) -> Dataset: 48 | if split == "train": 49 | return datasets.CIFAR10("./dataset_cache", train=True, download=True, transform=transforms.Compose([ 50 | transforms.RandomCrop(32, padding=4), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor() 53 | ])) 54 | elif split == "test": 55 | return datasets.CIFAR10("./dataset_cache", train=False, download=True, transform=transforms.ToTensor()) 56 | 57 | 58 | def _imagenet(split: str) -> Dataset: 59 | if IMAGENET_LOC_ENV not in os.environ: 60 | raise RuntimeError("environment variable for ImageNet directory not set") 61 | 62 | dir = os.environ[IMAGENET_LOC_ENV] 63 | if split == "train": 64 | subdir = os.path.join(dir, "train") 65 | transform = transforms.Compose([ 66 | transforms.RandomResizedCrop(224), 67 | transforms.RandomHorizontalFlip(), 68 | transforms.ToTensor() 69 | ]) 70 | elif split == "test": 71 | subdir = os.path.join(dir, "val") 72 | transform = transforms.Compose([ 73 | transforms.Resize(256), 74 | transforms.CenterCrop(224), 75 | transforms.ToTensor() 76 | ]) 77 | return datasets.ImageFolder(subdir, transform) 78 | 79 | 80 | class NormalizeLayer(torch.nn.Module): 81 | """Standardize the channels of a batch of images by subtracting the dataset mean 82 | and dividing by the dataset standard deviation.
83 | 84 | In order to certify radii in original coordinates rather than standardized coordinates, we 85 | add the Gaussian noise _before_ standardizing, which is why we have standardization be the first 86 | layer of the classifier rather than as a part of preprocessing as is typical. 87 | """ 88 | 89 | def __init__(self, means: List[float], sds: List[float]): 90 | """ 91 | :param means: the channel means 92 | :param sds: the channel standard deviations 93 | """ 94 | super(NormalizeLayer, self).__init__() 95 | self.means = torch.tensor(means).cuda() 96 | self.sds = torch.tensor(sds).cuda() 97 | 98 | def forward(self, input: torch.tensor): 99 | (batch_size, num_channels, height, width) = input.shape 100 | means = self.means.repeat((batch_size, height, width, 1)).permute(0, 3, 1, 2) 101 | sds = self.sds.repeat((batch_size, height, width, 1)).permute(0, 3, 1, 2) 102 | return (input - means) / sds 103 | -------------------------------------------------------------------------------- /code/predict.py: -------------------------------------------------------------------------------- 1 | """ This script loads a base classifier and then runs PREDICT on many examples from a dataset. 
2 | """ 3 | import argparse 4 | import setGPU 5 | from datasets import get_dataset, DATASETS, get_num_classes 6 | from core import Smooth 7 | from time import time 8 | import torch 9 | from architectures import get_architecture 10 | import datetime 11 | 12 | parser = argparse.ArgumentParser(description='Predict on many examples') 13 | parser.add_argument("dataset", choices=DATASETS, help="which dataset") 14 | parser.add_argument("base_classifier", type=str, help="path to saved pytorch model of base classifier") 15 | parser.add_argument("sigma", type=float, help="noise hyperparameter") 16 | parser.add_argument("outfile", type=str, help="output file") 17 | parser.add_argument("--batch", type=int, default=1000, help="batch size") 18 | parser.add_argument("--skip", type=int, default=1, help="how many examples to skip") 19 | parser.add_argument("--max", type=int, default=-1, help="stop after this many examples") 20 | parser.add_argument("--split", choices=["train", "test"], default="test", help="train or test set") 21 | parser.add_argument("--N", type=int, default=100000, help="number of samples to use") 22 | parser.add_argument("--alpha", type=float, default=0.001, help="failure probability") 23 | args = parser.parse_args() 24 | 25 | if __name__ == "__main__": 26 | # load the base classifier 27 | checkpoint = torch.load(args.base_classifier) 28 | base_classifier = get_architecture(checkpoint["arch"], args.dataset) 29 | base_classifier.load_state_dict(checkpoint['state_dict']) 30 | 31 | # create the smoothed classifier g 32 | smoothed_classifier = Smooth(base_classifier, get_num_classes(args.dataset), args.sigma) 33 | 34 | # prepare output file 35 | f = open(args.outfile, 'w') 36 | print("idx\tlabel\tpredict\tcorrect\ttime", file=f, flush=True) 37 | 38 | # iterate through the dataset 39 | dataset = get_dataset(args.dataset, args.split) 40 | for i in range(len(dataset)): 41 | 42 | # only certify every args.skip examples, and stop after args.max examples 43 | if i % 
args.skip != 0: 44 | continue 45 | if i == args.max: 46 | break 47 | 48 | (x, label) = dataset[i] 49 | x = x.cuda() 50 | before_time = time() 51 | 52 | # make the prediction 53 | prediction = smoothed_classifier.predict(x, args.N, args.alpha, args.batch) 54 | 55 | after_time = time() 56 | correct = int(prediction == label) 57 | 58 | time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time))) 59 | 60 | # log the prediction and whether it was correct 61 | print("{}\t{}\t{}\t{}\t{}".format(i, label, prediction, correct, time_elapsed), file=f, flush=True) 62 | 63 | f.close() 64 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | # this file is based on code publicly available at 2 | # https://github.com/bearpaw/pytorch-classification 3 | # written by Wei Yang. 4 | 5 | import argparse 6 | import os 7 | import torch 8 | from torch.nn import CrossEntropyLoss 9 | from torch.utils.data import DataLoader 10 | from datasets import get_dataset, DATASETS 11 | from architectures import ARCHITECTURES, get_architecture 12 | from torch.optim import SGD, Optimizer 13 | from torch.optim.lr_scheduler import StepLR 14 | import time 15 | import datetime 16 | from train_utils import AverageMeter, accuracy, init_logfile, log 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 19 | parser.add_argument('dataset', type=str, choices=DATASETS) 20 | parser.add_argument('arch', type=str, choices=ARCHITECTURES) 21 | parser.add_argument('outdir', type=str, help='folder to save model and training log') 22 | parser.add_argument('--workers', default=4, type=int, metavar='N', 23 | help='number of data loading workers (default: 4)') 24 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 25 | help='number of total epochs to run') 26 | parser.add_argument('--batch', default=256, type=int, metavar='N', 27 |
help='batchsize (default: 256)') 28 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 29 | help='initial learning rate', dest='lr') 30 | parser.add_argument('--lr_step_size', type=int, default=30, 31 | help='How often to decrease learning by gamma.') 32 | parser.add_argument('--gamma', type=float, default=0.1, 33 | help='LR is multiplied by gamma on schedule.') 34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 35 | help='momentum') 36 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 37 | metavar='W', help='weight decay (default: 1e-4)') 38 | parser.add_argument('--noise_sd', default=0.0, type=float, 39 | help="standard deviation of Gaussian noise for data augmentation") 40 | parser.add_argument('--gpu', default=None, type=str, 41 | help='id(s) for CUDA_VISIBLE_DEVICES') 42 | parser.add_argument('--print-freq', default=10, type=int, 43 | metavar='N', help='print frequency (default: 10)') 44 | args = parser.parse_args() 45 | 46 | 47 | def main(): 48 | if args.gpu: 49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 50 | 51 | if not os.path.exists(args.outdir): 52 | os.mkdir(args.outdir) 53 | 54 | train_dataset = get_dataset(args.dataset, 'train') 55 | test_dataset = get_dataset(args.dataset, 'test') 56 | pin_memory = (args.dataset == "imagenet") 57 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch, 58 | num_workers=args.workers, pin_memory=pin_memory) 59 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch, 60 | num_workers=args.workers, pin_memory=pin_memory) 61 | 62 | model = get_architecture(args.arch, args.dataset) 63 | 64 | logfilename = os.path.join(args.outdir, 'log.txt') 65 | init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc") 66 | 67 | criterion = CrossEntropyLoss().cuda() 68 | optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 69 | scheduler = 
StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma) 70 | 71 | for epoch in range(args.epochs): 72 | scheduler.step(epoch) 73 | before = time.time() 74 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, args.noise_sd) 75 | test_loss, test_acc = test(test_loader, model, criterion, args.noise_sd) 76 | after = time.time() 77 | 78 | log(logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format( 79 | epoch, str(datetime.timedelta(seconds=(after - before))), 80 | scheduler.get_lr()[0], train_loss, train_acc, test_loss, test_acc)) 81 | 82 | torch.save({ 83 | 'epoch': epoch + 1, 84 | 'arch': args.arch, 85 | 'state_dict': model.state_dict(), 86 | 'optimizer': optimizer.state_dict(), 87 | }, os.path.join(args.outdir, 'checkpoint.pth.tar')) 88 | 89 | 90 | def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer, epoch: int, noise_sd: float): 91 | batch_time = AverageMeter() 92 | data_time = AverageMeter() 93 | losses = AverageMeter() 94 | top1 = AverageMeter() 95 | top5 = AverageMeter() 96 | end = time.time() 97 | 98 | # switch to train mode 99 | model.train() 100 | 101 | for i, (inputs, targets) in enumerate(loader): 102 | # measure data loading time 103 | data_time.update(time.time() - end) 104 | 105 | inputs = inputs.cuda() 106 | targets = targets.cuda() 107 | 108 | # augment inputs with noise 109 | inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd 110 | 111 | # compute output 112 | outputs = model(inputs) 113 | loss = criterion(outputs, targets) 114 | 115 | # measure accuracy and record loss 116 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) 117 | losses.update(loss.item(), inputs.size(0)) 118 | top1.update(acc1.item(), inputs.size(0)) 119 | top5.update(acc5.item(), inputs.size(0)) 120 | 121 | # compute gradient and do SGD step 122 | optimizer.zero_grad() 123 | loss.backward() 124 | optimizer.step() 125 | 126 | # measure elapsed time 127 | 
batch_time.update(time.time() - end) 128 | end = time.time() 129 | 130 | if i % args.print_freq == 0: 131 | print('Epoch: [{0}][{1}/{2}]\t' 132 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 133 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 134 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 135 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 136 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 137 | epoch, i, len(loader), batch_time=batch_time, 138 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 139 | 140 | return (losses.avg, top1.avg) 141 | 142 | 143 | def test(loader: DataLoader, model: torch.nn.Module, criterion, noise_sd: float): 144 | batch_time = AverageMeter() 145 | data_time = AverageMeter() 146 | losses = AverageMeter() 147 | top1 = AverageMeter() 148 | top5 = AverageMeter() 149 | end = time.time() 150 | 151 | # switch to eval mode 152 | model.eval() 153 | 154 | with torch.no_grad(): 155 | for i, (inputs, targets) in enumerate(loader): 156 | # measure data loading time 157 | data_time.update(time.time() - end) 158 | 159 | inputs = inputs.cuda() 160 | targets = targets.cuda() 161 | 162 | # augment inputs with noise 163 | inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd 164 | 165 | # compute output 166 | outputs = model(inputs) 167 | loss = criterion(outputs, targets) 168 | 169 | # measure accuracy and record loss 170 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) 171 | losses.update(loss.item(), inputs.size(0)) 172 | top1.update(acc1.item(), inputs.size(0)) 173 | top5.update(acc5.item(), inputs.size(0)) 174 | 175 | # measure elapsed time 176 | batch_time.update(time.time() - end) 177 | end = time.time() 178 | 179 | if i % args.print_freq == 0: 180 | print('Test: [{0}/{1}]\t' 181 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 182 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 183 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 184 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 185 | 'Acc@5 {top5.val:.3f} 
({top5.avg:.3f})'.format( 186 | i, len(loader), batch_time=batch_time, 187 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 188 | 189 | return (losses.avg, top1.avg) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /code/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class AverageMeter(object): 4 | """Computes and stores the average and current value""" 5 | def __init__(self): 6 | self.reset() 7 | 8 | def reset(self): 9 | self.val = 0 10 | self.avg = 0 11 | self.sum = 0 12 | self.count = 0 13 | 14 | def update(self, val, n=1): 15 | self.val = val 16 | self.sum += val * n 17 | self.count += n 18 | self.avg = self.sum / self.count 19 | 20 | 21 | def accuracy(output, target, topk=(1,)): 22 | """Computes the accuracy over the k top predictions for the specified values of k""" 23 | with torch.no_grad(): 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | 27 | _, pred = output.topk(maxk, 1, True, True) 28 | pred = pred.t() 29 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 34 | res.append(correct_k.mul_(100.0 / batch_size)) 35 | return res 36 | 37 | def init_logfile(filename: str, text: str): 38 | f = open(filename, 'w') 39 | f.write(text+"\n") 40 | f.close() 41 | 42 | def log(filename: str, text: str): 43 | f = open(filename, 'a') 44 | f.write(text+"\n") 45 | f.close() -------------------------------------------------------------------------------- /code/visualize.py: -------------------------------------------------------------------------------- 1 | # visualize noisy images 2 | import argparse 3 | from datasets import get_dataset, DATASETS 4 | import torch 5 | from torchvision.transforms import ToPILImage 6 | 7 | parser = argparse.ArgumentParser(description='visualize 
noisy images') 8 | parser.add_argument("dataset", type=str, choices=DATASETS) 9 | parser.add_argument("outdir", type=str, help="output directory") 10 | parser.add_argument("idx", type=int) 11 | parser.add_argument("noise_sds", nargs='+', type=float) 12 | parser.add_argument("--split", choices=["train", "test"], default="test") 13 | args = parser.parse_args() 14 | 15 | toPilImage = ToPILImage() 16 | dataset = get_dataset(args.dataset, args.split) 17 | image, _ = dataset[args.idx] 18 | noise = torch.randn_like(image) 19 | for noise_sd in args.noise_sds: 20 | noisy_image = torch.clamp(image + noise * noise_sd, min=0, max=1) 21 | pil = toPilImage(noisy_image) 22 | pil.save("{}/{}_{}.png".format(args.outdir, args.idx, int(noise_sd * 100))) 23 | -------------------------------------------------------------------------------- /data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25: -------------------------------------------------------------------------------- 1 | idx label predict radius correct time 2 | 0 3 3 0.378 1 15.4 3 | 20 7 -1 0.0 0 15.6 4 | 40 4 4 0.341 1 15.7 5 | 60 7 7 0.941 1 15.8 6 | 80 8 8 0.755 1 15.9 7 | 100 4 -1 0.0 0 15.9 8 | 120 8 8 0.163 1 16.0 9 | 140 6 6 0.203 1 16.3 10 | 160 2 2 0.548 1 16.4 11 | 180 0 0 0.557 1 16.4 12 | 200 5 5 0.313 1 16.6 13 | 220 7 7 0.134 1 16.7 14 | 240 1 1 0.978 1 16.8 15 | 260 8 8 0.563 1 16.8 16 | 280 9 9 0.283 1 16.8 17 | 300 6 6 0.955 1 16.9 18 | 320 3 3 0.224 1 16.9 19 | 340 2 4 0.152 0 16.9 20 | 360 9 9 0.611 1 16.9 21 | 380 6 6 0.444 1 17.0 22 | 400 9 9 0.753 1 17.0 23 | 420 4 4 0.699 1 17.0 24 | 440 1 1 0.708 1 17.0 25 | 460 5 5 0.616 1 17.0 26 | 480 8 9 0.211 0 17.0 27 | 500 4 -1 0.0 0 17.0 28 | 520 5 5 0.63 1 17.0 29 | 540 1 1 0.725 1 17.0 30 | 560 0 0 0.978 1 17.0 31 | 580 4 4 0.314 1 17.0 32 | 600 8 8 0.46 1 17.0 33 | 620 7 7 0.849 1 16.9 34 | 640 5 3 0.698 0 17.0 35 | 660 9 9 0.411 1 17.0 36 | 680 9 9 0.441 1 16.9 37 | 700 7 7 0.955 1 16.9 38 | 720 4 4 0.232 1 16.9 39 | 740 2 5 0.508 0 16.9 40 | 760 3 3 
0.237 1 16.9 41 | 780 9 9 0.929 1 16.9 42 | 800 7 7 0.978 1 16.9 43 | 820 6 6 0.254 1 16.9 44 | 840 7 7 0.978 1 16.9 45 | 860 4 4 0.0831 1 16.9 46 | 880 8 8 0.894 1 16.9 47 | 900 2 6 0.0322 0 16.9 48 | 920 6 6 0.417 1 16.9 49 | 940 9 9 0.759 1 16.9 50 | 960 9 9 0.245 1 16.9 51 | 980 2 2 0.0549 1 16.9 52 | 1000 5 5 0.978 1 16.9 53 | 1020 1 1 0.906 1 16.9 54 | 1040 7 5 0.17 0 16.9 55 | 1060 9 9 0.978 1 16.9 56 | 1080 6 6 0.411 1 16.9 57 | 1100 7 0 0.0632 0 16.9 58 | 1120 5 2 0.039 0 16.9 59 | 1140 6 6 0.257 1 16.9 60 | 1160 8 8 0.617 1 16.9 61 | 1180 3 3 0.558 1 16.9 62 | 1200 8 8 0.518 1 16.9 63 | 1220 9 6 0.0591 0 16.9 64 | 1240 2 6 0.195 0 16.9 65 | 1260 1 -1 0.0 0 16.9 66 | 1280 3 3 0.303 1 16.9 67 | 1300 4 3 0.347 0 16.9 68 | 1320 7 7 0.515 1 16.9 69 | 1340 1 1 0.648 1 16.9 70 | 1360 7 7 0.321 1 16.9 71 | 1380 4 4 0.0991 1 16.9 72 | 1400 5 3 0.01 0 16.9 73 | 1420 4 4 0.164 1 16.9 74 | 1440 0 0 0.323 1 16.9 75 | 1460 6 6 0.831 1 16.9 76 | 1480 1 1 0.26 1 16.9 77 | 1500 1 1 0.724 1 16.9 78 | 1520 7 7 0.522 1 16.9 79 | 1540 8 8 0.889 1 16.9 80 | 1560 7 7 0.978 1 16.9 81 | 1580 6 -1 0.0 0 16.9 82 | 1600 8 -1 0.0 0 16.9 83 | 1620 5 -1 0.0 0 16.9 84 | 1640 7 -1 0.0 0 16.9 85 | 1660 3 3 0.368 1 16.9 86 | 1680 6 6 0.474 1 16.9 87 | 1700 5 5 0.0619 1 16.9 88 | 1720 4 4 0.578 1 16.9 89 | 1740 1 1 0.477 1 16.9 90 | 1760 0 0 0.683 1 16.9 91 | 1780 7 7 0.662 1 16.9 92 | 1800 4 4 0.441 1 16.9 93 | 1820 8 8 0.781 1 16.9 94 | 1840 8 -1 0.0 0 16.9 95 | 1860 8 8 0.978 1 16.9 96 | 1880 1 1 0.978 1 16.9 97 | 1900 8 8 0.978 1 16.9 98 | 1920 2 2 0.978 1 16.9 99 | 1940 5 5 0.265 1 16.9 100 | 1960 2 -1 0.0 0 16.9 101 | 1980 9 9 0.729 1 16.9 102 | 2000 1 1 0.731 1 16.9 103 | 2020 9 9 0.978 1 16.9 104 | 2040 0 0 0.929 1 16.9 105 | 2060 5 3 0.538 0 16.9 106 | 2080 1 9 0.136 0 16.9 107 | 2100 2 2 0.483 1 16.9 108 | 2120 4 8 0.092 0 16.9 109 | 2140 9 9 0.978 1 16.9 110 | 2160 0 0 0.397 1 16.9 111 | 2180 4 4 0.558 1 16.9 112 | 2200 0 0 0.228 1 16.9 113 | 2220 1 1 0.978 1 16.9 114 | 2240 9 9 
0.268 1 16.9 115 | 2260 4 4 0.661 1 16.9 116 | 2280 4 4 0.144 1 16.9 117 | 2300 3 3 0.018 1 16.9 118 | 2320 5 2 0.653 0 16.9 119 | 2340 7 7 0.955 1 16.9 120 | 2360 7 7 0.019 1 16.9 121 | 2380 8 8 0.725 1 16.9 122 | 2400 0 0 0.978 1 16.9 123 | 2420 8 8 0.055 1 16.9 124 | 2440 7 7 0.978 1 16.9 125 | 2460 9 9 0.108 1 16.9 126 | 2480 8 8 0.817 1 16.9 127 | 2500 4 4 0.108 1 16.9 128 | 2520 1 1 0.147 1 16.9 129 | 2540 2 6 0.119 0 16.9 130 | 2560 7 7 0.312 1 16.9 131 | 2580 6 5 0.0511 0 16.9 132 | 2600 8 8 0.88 1 16.9 133 | 2620 1 1 0.637 1 16.9 134 | 2640 7 7 0.978 1 16.9 135 | 2660 3 -1 0.0 0 16.9 136 | 2680 0 0 0.978 1 16.9 137 | 2700 9 -1 0.0 0 16.9 138 | 2720 3 3 0.484 1 16.9 139 | 2740 1 1 0.164 1 16.9 140 | 2760 2 2 0.42 1 16.9 141 | 2780 7 7 0.889 1 16.9 142 | 2800 4 4 0.372 1 16.9 143 | 2820 6 6 0.978 1 16.9 144 | 2840 3 3 0.978 1 16.9 145 | 2860 5 4 0.279 0 16.9 146 | 2880 4 4 0.781 1 16.9 147 | 2900 3 7 0.127 0 16.9 148 | 2920 2 3 0.0245 0 16.9 149 | 2940 3 0 0.811 0 16.9 150 | 2960 9 0 0.155 0 16.9 151 | 2980 8 8 0.456 1 16.9 152 | 3000 5 5 0.806 1 16.9 153 | 3020 1 1 0.567 1 16.9 154 | 3040 7 7 0.978 1 16.9 155 | 3060 7 7 0.821 1 16.9 156 | 3080 1 1 0.0759 1 16.9 157 | 3100 0 0 0.62 1 16.9 158 | 3120 4 4 0.137 1 16.9 159 | 3140 8 8 0.589 1 16.9 160 | 3160 6 2 0.114 0 16.9 161 | 3180 3 3 0.0918 1 16.9 162 | 3200 5 5 0.577 1 16.9 163 | 3220 3 9 0.124 0 16.9 164 | 3240 4 -1 0.0 0 16.9 165 | 3260 7 7 0.522 1 16.9 166 | 3280 3 5 0.0818 0 16.9 167 | 3300 4 4 0.285 1 16.9 168 | 3320 2 7 0.335 0 16.9 169 | 3340 6 6 0.525 1 16.9 170 | 3360 4 4 0.442 1 16.9 171 | 3380 8 8 0.978 1 16.9 172 | 3400 6 4 0.49 0 16.9 173 | 3420 5 -1 0.0 0 16.9 174 | 3440 2 3 0.643 0 16.9 175 | 3460 1 1 0.978 1 16.9 176 | 3480 2 2 0.978 1 16.9 177 | 3500 1 1 0.118 1 16.9 178 | 3520 1 1 0.978 1 16.9 179 | 3540 6 6 0.736 1 16.9 180 | 3560 1 1 0.0284 1 16.9 181 | 3580 4 4 0.228 1 16.9 182 | 3600 4 2 0.467 0 16.9 183 | 3620 0 0 0.304 1 16.9 184 | 3640 6 6 0.0873 1 16.9 185 | 3660 0 0 0.506 1 16.9 
186 | 3680 0 0 0.0749 1 16.9 187 | 3700 3 3 0.759 1 16.9 188 | 3720 2 2 0.978 1 16.9 189 | 3740 0 0 0.955 1 16.9 190 | 3760 7 7 0.978 1 16.9 191 | 3780 4 4 0.349 1 16.9 192 | 3800 9 9 0.708 1 16.9 193 | 3820 0 0 0.133 1 16.9 194 | 3840 6 3 0.00299 0 16.9 195 | 3860 7 7 0.978 1 16.9 196 | 3880 6 6 0.288 1 16.9 197 | 3900 3 3 0.086 1 16.9 198 | 3920 7 7 0.27 1 16.9 199 | 3940 6 -1 0.0 0 16.9 200 | 3960 2 2 0.784 1 16.9 201 | 3980 9 9 0.163 1 16.9 202 | 4000 8 -1 0.0 0 16.9 203 | 4020 8 8 0.488 1 16.9 204 | 4040 0 0 0.92 1 16.9 205 | 4060 6 6 0.559 1 16.9 206 | 4080 1 1 0.978 1 16.9 207 | 4100 7 7 0.978 1 16.9 208 | 4120 4 4 0.929 1 16.9 209 | 4140 5 5 0.682 1 16.9 210 | 4160 5 -1 0.0 0 16.9 211 | 4180 0 8 0.155 0 16.9 212 | 4200 4 4 0.181 1 16.9 213 | 4220 4 0 0.113 0 16.9 214 | 4240 7 7 0.0474 1 16.9 215 | 4260 8 8 0.92 1 16.9 216 | 4280 8 0 0.394 0 16.9 217 | 4300 8 8 0.621 1 16.9 218 | 4320 1 -1 0.0 0 16.9 219 | 4340 0 0 0.859 1 16.9 220 | 4360 6 6 0.929 1 16.9 221 | 4380 9 9 0.642 1 16.9 222 | 4400 3 -1 0.0 0 16.9 223 | 4420 5 5 0.469 1 16.9 224 | 4440 2 2 0.208 1 16.9 225 | 4460 9 9 0.674 1 16.9 226 | 4480 9 9 0.978 1 16.9 227 | 4500 3 5 0.0371 0 16.9 228 | 4520 3 3 0.174 1 16.9 229 | 4540 9 9 0.784 1 16.9 230 | 4560 1 1 0.955 1 16.9 231 | 4580 6 6 0.412 1 16.9 232 | 4600 4 4 0.512 1 16.9 233 | 4620 7 7 0.125 1 16.9 234 | 4640 2 2 0.34 1 16.9 235 | 4660 7 -1 0.0 0 16.9 236 | 4680 9 9 0.617 1 16.9 237 | 4700 6 5 0.179 0 16.9 238 | 4720 8 8 0.299 1 16.9 239 | 4740 5 3 0.145 0 16.9 240 | 4760 3 2 0.436 0 16.9 241 | 4780 0 0 0.585 1 16.9 242 | 4800 9 9 0.978 1 16.9 243 | 4820 3 3 0.471 1 16.9 244 | 4840 0 0 0.978 1 16.9 245 | 4860 5 5 0.978 1 16.9 246 | 4880 0 8 0.178 0 16.9 247 | 4900 3 3 0.652 1 16.9 248 | 4920 7 7 0.978 1 16.9 249 | 4940 6 6 0.378 1 16.9 250 | 4960 4 4 0.318 1 16.9 251 | 4980 1 9 0.478 0 16.9 252 | 5000 7 7 0.978 1 16.9 253 | 5020 8 8 0.754 1 16.9 254 | 5040 3 3 0.336 1 16.9 255 | 5060 6 6 0.256 1 16.9 256 | 5080 7 7 0.731 1 16.9 257 | 5100 3 3 
0.731 1 16.9 258 | 5120 9 9 0.32 1 16.9 259 | 5140 8 8 0.772 1 16.9 260 | 5160 0 0 0.686 1 16.9 261 | 5180 9 9 0.0113 1 16.9 262 | 5200 3 3 0.663 1 16.9 263 | 5220 0 0 0.429 1 16.9 264 | 5240 1 1 0.423 1 16.9 265 | 5260 0 0 0.978 1 16.9 266 | 5280 0 -1 0.0 0 16.9 267 | 5300 9 9 0.743 1 16.9 268 | 5320 7 7 0.449 1 16.9 269 | 5340 3 4 0.135 0 16.9 270 | 5360 9 9 0.802 1 16.9 271 | 5380 1 9 0.267 0 16.9 272 | 5400 9 9 0.781 1 16.9 273 | 5420 6 6 0.844 1 16.9 274 | 5440 9 9 0.955 1 16.9 275 | 5460 0 0 0.978 1 16.9 276 | 5480 1 1 0.641 1 16.9 277 | 5500 8 8 0.336 1 16.9 278 | 5520 4 4 0.279 1 16.9 279 | 5540 5 5 0.104 1 16.9 280 | 5560 3 5 0.293 0 16.9 281 | 5580 5 -1 0.0 0 16.9 282 | 5600 6 6 0.187 1 16.9 283 | 5620 3 -1 0.0 0 16.9 284 | 5640 2 5 0.0519 0 16.9 285 | 5660 9 9 0.165 1 16.9 286 | 5680 2 3 0.155 0 16.9 287 | 5700 3 3 0.655 1 16.9 288 | 5720 9 9 0.978 1 16.9 289 | 5740 6 6 0.322 1 16.9 290 | 5760 2 2 0.26 1 16.9 291 | 5780 7 7 0.978 1 16.9 292 | 5800 2 2 0.517 1 16.9 293 | 5820 5 5 0.433 1 16.9 294 | 5840 2 2 0.322 1 16.9 295 | 5860 0 0 0.516 1 16.9 296 | 5880 0 0 0.889 1 16.9 297 | 5900 4 3 0.0461 0 16.9 298 | 5920 8 8 0.978 1 16.9 299 | 5940 7 7 0.0779 1 16.9 300 | 5960 2 6 0.224 0 16.9 301 | 5980 9 9 0.955 1 16.9 302 | 6000 8 8 0.478 1 16.9 303 | 6020 6 6 0.632 1 16.9 304 | 6040 2 -1 0.0 0 16.9 305 | 6060 3 4 0.0698 0 16.9 306 | 6080 1 1 0.585 1 16.9 307 | 6100 1 1 0.978 1 16.9 308 | 6120 5 5 0.539 1 16.9 309 | 6140 9 9 0.615 1 16.9 310 | 6160 2 -1 0.0 0 16.9 311 | 6180 0 0 0.389 1 16.9 312 | 6200 3 3 0.0459 1 16.9 313 | 6220 2 -1 0.0 0 16.9 314 | 6240 2 2 0.41 1 16.9 315 | 6260 2 2 0.176 1 16.9 316 | 6280 8 -1 0.0 0 16.9 317 | 6300 1 1 0.782 1 16.9 318 | 6320 0 0 0.37 1 16.9 319 | 6340 3 3 0.173 1 16.9 320 | 6360 2 2 0.247 1 16.9 321 | 6380 3 -1 0.0 0 16.9 322 | 6400 0 0 0.543 1 16.9 323 | 6420 7 7 0.978 1 16.9 324 | 6440 3 3 0.101 1 16.9 325 | 6460 3 3 0.724 1 16.9 326 | 6480 0 0 0.565 1 16.9 327 | 6500 7 -1 0.0 0 16.9 328 | 6520 6 6 0.186 1 16.9 329 | 
6540 8 8 0.45 1 16.9 330 | 6560 6 6 0.0883 1 16.9 331 | 6580 1 1 0.117 1 16.9 332 | 6600 7 7 0.955 1 16.9 333 | 6620 7 7 0.978 1 16.9 334 | 6640 5 3 0.496 0 16.9 335 | 6660 0 0 0.733 1 16.9 336 | 6680 3 5 0.27 0 16.9 337 | 6700 6 6 0.866 1 16.9 338 | 6720 2 2 0.304 1 16.9 339 | 6740 2 2 0.0894 1 16.9 340 | 6760 5 5 0.955 1 16.9 341 | 6780 7 7 0.978 1 16.9 342 | 6800 6 -1 0.0 0 16.9 343 | 6820 1 1 0.873 1 16.9 344 | 6840 9 9 0.752 1 16.9 345 | 6860 0 0 0.617 1 16.9 346 | 6880 2 2 0.572 1 16.9 347 | 6900 3 3 0.612 1 16.9 348 | 6920 5 4 0.199 0 16.9 349 | 6940 1 1 0.978 1 16.9 350 | 6960 9 9 0.403 1 16.9 351 | 6980 0 0 0.381 1 16.9 352 | 7000 2 0 0.0768 0 16.9 353 | 7020 7 7 0.978 1 16.9 354 | 7040 7 -1 0.0 0 16.9 355 | 7060 8 8 0.894 1 16.9 356 | 7080 5 4 0.175 0 16.9 357 | 7100 9 3 0.349 0 16.9 358 | 7120 7 7 0.316 1 16.9 359 | 7140 4 4 0.147 1 16.9 360 | 7160 5 5 0.688 1 16.9 361 | 7180 6 6 0.445 1 16.9 362 | 7200 4 4 0.105 1 16.9 363 | 7220 9 9 0.669 1 16.9 364 | 7240 0 0 0.414 1 16.9 365 | 7260 1 1 0.765 1 16.9 366 | 7280 1 1 0.231 1 16.9 367 | 7300 3 5 0.0342 0 16.9 368 | 7320 8 0 0.105 0 16.9 369 | 7340 2 2 0.236 1 16.9 370 | 7360 2 2 0.328 1 16.9 371 | 7380 7 7 0.844 1 16.9 372 | 7400 3 5 0.37 0 16.9 373 | 7420 3 4 0.187 0 16.9 374 | 7440 3 3 0.693 1 16.9 375 | 7460 5 5 0.978 1 16.9 376 | 7480 1 1 0.466 1 16.9 377 | 7500 6 6 0.978 1 16.9 378 | 7520 4 4 0.366 1 16.9 379 | 7540 5 5 0.978 1 16.9 380 | 7560 0 0 0.369 1 16.9 381 | 7580 8 8 0.978 1 16.9 382 | 7600 8 2 0.0625 0 16.9 383 | 7620 5 3 0.103 0 16.9 384 | 7640 9 9 0.114 1 16.9 385 | 7660 7 7 0.424 1 16.9 386 | 7680 3 5 0.0632 0 16.9 387 | 7700 6 6 0.92 1 16.9 388 | 7720 5 2 0.00494 0 16.9 389 | 7740 4 4 0.301 1 16.9 390 | 7760 2 0 0.143 0 16.9 391 | 7780 3 -1 0.0 0 16.9 392 | 7800 0 0 0.597 1 16.9 393 | 7820 5 5 0.18 1 16.9 394 | 7840 1 1 0.978 1 16.9 395 | 7860 8 8 0.749 1 16.9 396 | 7880 0 0 0.647 1 16.9 397 | 7900 0 -1 0.0 0 16.9 398 | 7920 9 9 0.849 1 16.9 399 | 7940 2 2 0.0572 1 16.9 400 | 7960 0 0 
0.138 1 16.9 401 | 7980 3 3 0.321 1 16.9 402 | 8000 9 9 0.762 1 16.9 403 | 8020 6 3 0.02 0 16.9 404 | 8040 2 2 0.729 1 16.9 405 | 8060 6 6 0.768 1 16.9 406 | 8080 4 4 0.231 1 16.9 407 | 8100 6 2 0.017 0 16.9 408 | 8120 8 8 0.00256 1 16.9 409 | 8140 5 7 0.0382 0 16.9 410 | 8160 6 6 0.978 1 16.9 411 | 8180 4 4 0.239 1 16.9 412 | 8200 3 -1 0.0 0 16.9 413 | 8220 6 5 0.0177 0 16.9 414 | 8240 7 7 0.572 1 16.9 415 | 8260 1 1 0.955 1 16.9 416 | 8280 4 4 0.863 1 16.9 417 | 8300 5 -1 0.0 0 16.9 418 | 8320 7 7 0.27 1 16.9 419 | 8340 5 5 0.104 1 16.9 420 | 8360 7 7 0.978 1 16.9 421 | 8380 9 9 0.978 1 16.9 422 | 8400 0 0 0.168 1 16.9 423 | 8420 3 6 0.573 0 16.9 424 | 8440 5 5 0.141 1 16.9 425 | 8460 8 8 0.771 1 16.9 426 | 8480 2 2 0.363 1 16.9 427 | 8500 4 4 0.64 1 16.9 428 | 8520 8 8 0.41 1 16.9 429 | 8540 4 -1 0.0 0 16.9 430 | 8560 7 7 0.941 1 16.9 431 | 8580 3 5 0.244 0 16.9 432 | 8600 3 -1 0.0 0 16.9 433 | 8620 3 -1 0.0 0 16.9 434 | 8640 1 1 0.579 1 16.9 435 | 8660 7 7 0.686 1 16.9 436 | 8680 3 3 0.00316 1 16.9 437 | 8700 3 3 0.552 1 16.9 438 | 8720 2 0 0.88 0 16.9 439 | 8740 2 2 0.253 1 16.9 440 | 8760 8 8 0.955 1 16.9 441 | 8780 4 4 0.255 1 16.9 442 | 8800 0 0 0.978 1 16.9 443 | 8820 2 2 0.978 1 16.9 444 | 8840 2 2 0.131 1 16.9 445 | 8860 2 2 0.941 1 16.9 446 | 8880 1 1 0.793 1 16.9 447 | 8900 2 4 0.19 0 16.9 448 | 8920 6 6 0.566 1 16.9 449 | 8940 6 3 0.393 0 16.9 450 | 8960 0 0 0.0635 1 16.9 451 | 8980 9 9 0.955 1 16.9 452 | 9000 8 8 0.978 1 16.9 453 | 9020 7 7 0.824 1 16.9 454 | 9040 2 5 0.725 0 16.9 455 | 9060 9 9 0.466 1 16.9 456 | 9080 3 3 0.182 1 16.9 457 | 9100 9 1 0.233 0 16.9 458 | 9120 3 -1 0.0 0 16.9 459 | 9140 7 7 0.455 1 16.9 460 | 9160 5 3 0.19 0 16.9 461 | 9180 5 5 0.978 1 16.9 462 | 9200 8 8 0.802 1 16.9 463 | 9220 8 8 0.821 1 16.9 464 | 9240 8 8 0.467 1 16.9 465 | 9260 5 5 0.319 1 16.9 466 | 9280 9 8 0.0567 0 16.9 467 | 9300 5 -1 0.0 0 16.9 468 | 9320 9 9 0.894 1 16.9 469 | 9340 1 1 0.978 1 16.9 470 | 9360 5 0 0.119 0 16.9 471 | 9380 5 3 0.0574 0 16.9 472 
| 9400 6 6 0.978 1 16.9 473 | 9420 5 5 0.787 1 16.9 474 | 9440 8 8 0.572 1 16.9 475 | 9460 3 3 0.14 1 16.9 476 | 9480 2 2 0.178 1 16.9 477 | 9500 9 9 0.644 1 16.9 478 | 9520 1 1 0.721 1 16.9 479 | 9540 5 5 0.659 1 16.9 480 | 9560 9 7 0.123 0 16.9 481 | 9580 1 1 0.978 1 16.9 482 | 9600 8 8 0.322 1 16.9 483 | 9620 4 4 0.352 1 16.9 484 | 9640 5 5 0.358 1 16.9 485 | 9660 6 6 0.906 1 16.9 486 | 9680 8 8 0.844 1 16.9 487 | 9700 0 0 0.709 1 16.9 488 | 9720 8 -1 0.0 0 16.9 489 | 9740 3 6 0.22 0 16.9 490 | 9760 9 -1 0.0 0 16.9 491 | 9780 4 4 0.192 1 16.9 492 | 9800 1 1 0.978 1 16.8 493 | 9820 0 8 0.0396 0 16.9 494 | 9840 4 7 0.563 0 16.9 495 | 9860 0 -1 0.0 0 16.8 496 | 9880 7 7 0.077 1 16.9 497 | 9900 8 8 0.929 1 16.9 498 | 9920 6 6 0.978 1 16.9 499 | 9940 4 4 0.0449 1 16.9 500 | 9960 2 0 0.941 0 16.9 501 | 9980 0 0 0.978 1 16.9 502 | -------------------------------------------------------------------------------- /data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50: -------------------------------------------------------------------------------- 1 | idx label predict radius correct time 2 | 0 3 3 0.486 1 16.6 3 | 20 7 7 0.413 1 17.2 4 | 40 4 4 0.312 1 17.5 5 | 60 7 7 1.22 1 17.6 6 | 80 8 8 1.2 1 17.6 7 | 100 4 -1 0.0 0 17.6 8 | 120 8 8 0.226 1 17.6 9 | 140 6 6 0.366 1 17.6 10 | 160 2 2 0.276 1 17.6 11 | 180 0 0 0.935 1 17.6 12 | 200 5 5 0.28 1 17.6 13 | 220 7 7 1.37 1 17.6 14 | 240 1 1 1.88 1 17.6 15 | 260 8 8 0.00902 1 17.6 16 | 280 9 9 1.04 1 17.6 17 | 300 6 6 0.684 1 17.6 18 | 320 3 3 0.592 1 17.5 19 | 340 2 4 0.172 0 17.5 20 | 360 9 9 0.301 1 17.6 21 | 380 6 6 0.454 1 17.5 22 | 400 9 9 1.17 1 17.5 23 | 420 4 4 0.803 1 17.5 24 | 440 1 1 0.441 1 17.5 25 | 460 5 5 1.13 1 17.5 26 | 480 8 9 0.556 0 17.5 27 | 500 4 6 0.373 0 17.5 28 | 520 5 5 0.324 1 17.5 29 | 540 1 1 0.685 1 17.5 30 | 560 0 0 1.36 1 17.5 31 | 580 4 4 0.507 1 17.5 32 | 600 8 8 1.15 1 17.5 33 | 620 7 7 0.479 1 17.5 34 | 640 5 3 0.917 0 17.5 35 | 660 9 9 0.84 1 17.5 36 | 680 9 9 0.112 1 17.5 37 | 700 7 
7 0.792 1 17.5 38 | 720 4 -1 0.0 0 17.5 39 | 740 2 3 0.0547 0 17.5 40 | 760 3 3 0.187 1 17.5 41 | 780 9 9 0.36 1 17.5 42 | 800 7 7 1.91 1 17.5 43 | 820 6 -1 0.0 0 17.5 44 | 840 7 7 1.77 1 17.5 45 | 860 4 4 0.398 1 17.5 46 | 880 8 8 1.28 1 17.5 47 | 900 2 2 0.0857 1 17.5 48 | 920 6 6 0.344 1 17.5 49 | 940 9 9 1.25 1 17.5 50 | 960 9 9 1.09 1 17.5 51 | 980 2 6 0.141 0 17.5 52 | 1000 5 5 1.34 1 17.5 53 | 1020 1 1 1.45 1 17.5 54 | 1040 7 5 0.775 0 17.5 55 | 1060 9 9 1.45 1 17.5 56 | 1080 6 6 0.546 1 17.5 57 | 1100 7 -1 0.0 0 17.5 58 | 1120 5 5 0.0396 1 17.5 59 | 1140 6 6 0.373 1 17.5 60 | 1160 8 8 0.096 1 17.5 61 | 1180 3 3 0.978 1 17.5 62 | 1200 8 8 0.442 1 17.5 63 | 1220 9 -1 0.0 0 17.5 64 | 1240 2 6 0.868 0 17.5 65 | 1260 1 9 0.423 0 17.5 66 | 1280 3 3 0.0999 1 17.5 67 | 1300 4 3 0.62 0 17.5 68 | 1320 7 -1 0.0 0 17.5 69 | 1340 1 1 1.33 1 17.5 70 | 1360 7 7 0.938 1 17.5 71 | 1380 4 5 0.315 0 17.5 72 | 1400 5 3 0.348 0 17.5 73 | 1420 4 7 0.0896 0 17.5 74 | 1440 0 0 0.137 1 17.5 75 | 1460 6 6 0.963 1 17.5 76 | 1480 1 1 0.336 1 17.5 77 | 1500 1 1 0.934 1 17.5 78 | 1520 7 7 0.876 1 17.5 79 | 1540 8 8 1.35 1 17.5 80 | 1560 7 7 1.96 1 17.5 81 | 1580 6 1 0.103 0 17.5 82 | 1600 8 9 0.116 0 17.5 83 | 1620 5 -1 0.0 0 17.5 84 | 1640 7 3 0.423 0 17.5 85 | 1660 3 3 0.0679 1 17.5 86 | 1680 6 6 0.867 1 17.5 87 | 1700 5 5 0.0387 1 17.5 88 | 1720 4 4 0.0682 1 17.5 89 | 1740 1 1 0.487 1 17.5 90 | 1760 0 0 0.566 1 17.5 91 | 1780 7 7 0.688 1 17.5 92 | 1800 4 -1 0.0 0 17.5 93 | 1820 8 8 0.452 1 17.5 94 | 1840 8 -1 0.0 0 17.5 95 | 1860 8 8 1.12 1 17.5 96 | 1880 1 1 1.01 1 17.5 97 | 1900 8 8 1.7 1 17.5 98 | 1920 2 2 1.02 1 17.5 99 | 1940 5 -1 0.0 0 17.5 100 | 1960 2 5 0.334 0 17.5 101 | 1980 9 9 1.6 1 17.5 102 | 2000 1 1 0.105 1 17.5 103 | 2020 9 9 1.42 1 17.5 104 | 2040 0 0 0.663 1 17.5 105 | 2060 5 3 0.606 0 17.5 106 | 2080 1 1 0.274 1 17.5 107 | 2100 2 2 0.606 1 17.5 108 | 2120 4 8 0.39 0 17.5 109 | 2140 9 9 1.23 1 17.5 110 | 2160 0 0 0.329 1 17.5 111 | 2180 4 4 0.785 1 17.5 112 | 2200 0 
-1 0.0 0 17.5 113 | 2220 1 1 1.63 1 17.5 114 | 2240 9 9 0.0244 1 17.5 115 | 2260 4 4 0.823 1 17.5 116 | 2280 4 -1 0.0 0 17.5 117 | 2300 3 -1 0.0 0 17.5 118 | 2320 5 2 0.682 0 17.5 119 | 2340 7 7 1.74 1 17.5 120 | 2360 7 7 0.149 1 17.5 121 | 2380 8 8 1.0 1 17.5 122 | 2400 0 0 0.826 1 17.5 123 | 2420 8 -1 0.0 0 17.5 124 | 2440 7 7 0.627 1 17.5 125 | 2460 9 -1 0.0 0 17.5 126 | 2480 8 8 1.2 1 17.5 127 | 2500 4 4 0.0665 1 17.5 128 | 2520 1 1 0.605 1 17.5 129 | 2540 2 -1 0.0 0 17.5 130 | 2560 7 7 0.364 1 17.5 131 | 2580 6 6 0.0011 1 17.5 132 | 2600 8 8 1.44 1 17.5 133 | 2620 1 1 0.438 1 17.5 134 | 2640 7 7 1.17 1 17.5 135 | 2660 3 5 0.0535 0 17.5 136 | 2680 0 0 1.86 1 17.5 137 | 2700 9 9 0.441 1 17.5 138 | 2720 3 3 0.57 1 17.5 139 | 2740 1 1 0.857 1 17.5 140 | 2760 2 2 0.643 1 17.5 141 | 2780 7 7 1.62 1 17.5 142 | 2800 4 4 0.11 1 17.5 143 | 2820 6 6 1.22 1 17.5 144 | 2840 3 3 1.96 1 17.5 145 | 2860 5 -1 0.0 0 17.5 146 | 2880 4 4 1.4 1 17.5 147 | 2900 3 5 0.0473 0 17.5 148 | 2920 2 3 0.458 0 17.5 149 | 2940 3 0 0.795 0 17.5 150 | 2960 9 9 0.476 1 17.5 151 | 2980 8 8 0.563 1 17.5 152 | 3000 5 5 1.42 1 17.5 153 | 3020 1 1 0.955 1 17.5 154 | 3040 7 7 1.96 1 17.5 155 | 3060 7 7 0.883 1 17.5 156 | 3080 1 1 0.361 1 17.5 157 | 3100 0 0 0.995 1 17.5 158 | 3120 4 -1 0.0 0 17.5 159 | 3140 8 8 0.722 1 17.5 160 | 3160 6 2 0.215 0 17.5 161 | 3180 3 3 0.26 1 17.5 162 | 3200 5 7 0.0822 0 17.5 163 | 3220 3 -1 0.0 0 17.5 164 | 3240 4 -1 0.0 0 17.5 165 | 3260 7 7 1.03 1 17.5 166 | 3280 3 5 0.405 0 17.5 167 | 3300 4 4 0.304 1 17.5 168 | 3320 2 7 1.11 0 17.5 169 | 3340 6 6 0.478 1 17.5 170 | 3360 4 4 0.15 1 17.5 171 | 3380 8 8 1.19 1 17.5 172 | 3400 6 4 0.498 0 17.5 173 | 3420 5 3 0.15 0 17.5 174 | 3440 2 3 0.445 0 17.5 175 | 3460 1 1 0.745 1 17.5 176 | 3480 2 2 1.81 1 17.5 177 | 3500 1 1 0.414 1 17.5 178 | 3520 1 1 1.47 1 17.5 179 | 3540 6 6 0.83 1 17.5 180 | 3560 1 -1 0.0 0 17.5 181 | 3580 4 4 0.0958 1 17.5 182 | 3600 4 2 0.95 0 17.5 183 | 3620 0 -1 0.0 0 17.5 184 | 3640 6 2 0.477 0 17.5 
185 | 3660 0 0 0.727 1 17.5 186 | 3680 0 -1 0.0 0 17.5 187 | 3700 3 3 0.648 1 17.5 188 | 3720 2 2 1.36 1 17.5 189 | 3740 0 0 0.924 1 17.5 190 | 3760 7 7 1.27 1 17.5 191 | 3780 4 4 0.215 1 17.5 192 | 3800 9 9 0.585 1 17.5 193 | 3820 0 0 0.00592 1 17.5 194 | 3840 6 3 0.436 0 17.5 195 | 3860 7 7 1.96 1 17.5 196 | 3880 6 -1 0.0 0 17.5 197 | 3900 3 6 0.535 0 17.5 198 | 3920 7 -1 0.0 0 17.5 199 | 3940 6 -1 0.0 0 17.5 200 | 3960 2 2 1.11 1 17.5 201 | 3980 9 9 0.254 1 17.5 202 | 4000 8 0 0.00403 0 17.5 203 | 4020 8 8 0.396 1 17.5 204 | 4040 0 0 1.88 1 17.5 205 | 4060 6 6 0.228 1 17.5 206 | 4080 1 1 0.821 1 17.5 207 | 4100 7 7 1.96 1 17.5 208 | 4120 4 4 0.783 1 17.5 209 | 4140 5 5 0.788 1 17.5 210 | 4160 5 -1 0.0 0 17.5 211 | 4180 0 0 0.0756 1 17.5 212 | 4200 4 2 0.0203 0 17.5 213 | 4220 4 0 1.09 0 17.5 214 | 4240 7 0 0.383 0 17.5 215 | 4260 8 8 1.21 1 17.5 216 | 4280 8 -1 0.0 0 17.5 217 | 4300 8 8 1.03 1 17.5 218 | 4320 1 -1 0.0 0 17.5 219 | 4340 0 0 0.912 1 17.5 220 | 4360 6 -1 0.0 0 17.5 221 | 4380 9 9 0.513 1 17.5 222 | 4400 3 -1 0.0 0 17.5 223 | 4420 5 5 0.29 1 17.5 224 | 4440 2 2 0.358 1 17.5 225 | 4460 9 9 0.362 1 17.5 226 | 4480 9 9 1.41 1 17.5 227 | 4500 3 5 0.4 0 17.5 228 | 4520 3 6 0.25 0 17.5 229 | 4540 9 9 1.48 1 17.5 230 | 4560 1 1 1.1 1 17.5 231 | 4580 6 6 0.0476 1 17.5 232 | 4600 4 7 0.112 0 17.5 233 | 4620 7 3 0.216 0 17.5 234 | 4640 2 2 0.299 1 17.5 235 | 4660 7 7 0.131 1 17.5 236 | 4680 9 9 0.0216 1 17.5 237 | 4700 6 -1 0.0 0 17.5 238 | 4720 8 8 0.173 1 17.5 239 | 4740 5 3 0.618 0 17.5 240 | 4760 3 -1 0.0 0 17.5 241 | 4780 0 0 0.787 1 17.5 242 | 4800 9 9 1.96 1 17.5 243 | 4820 3 3 0.958 1 17.5 244 | 4840 0 0 1.69 1 17.5 245 | 4860 5 5 0.841 1 17.5 246 | 4880 0 -1 0.0 0 17.5 247 | 4900 3 -1 0.0 0 17.5 248 | 4920 7 7 1.42 1 17.5 249 | 4940 6 6 0.169 1 17.5 250 | 4960 4 3 0.089 0 17.5 251 | 4980 1 9 0.555 0 17.5 252 | 5000 7 7 1.96 1 17.5 253 | 5020 8 8 1.34 1 17.5 254 | 5040 3 3 0.0419 1 17.5 255 | 5060 6 6 0.204 1 17.5 256 | 5080 7 7 1.36 1 17.5 257 | 5100 
3 3 1.33 1 17.5 258 | 5120 9 9 0.478 1 17.5 259 | 5140 8 8 0.877 1 17.5 260 | 5160 0 0 0.56 1 17.5 261 | 5180 9 -1 0.0 0 17.5 262 | 5200 3 3 1.33 1 17.5 263 | 5220 0 0 0.667 1 17.5 264 | 5240 1 -1 0.0 0 17.5 265 | 5260 0 0 1.73 1 17.5 266 | 5280 0 -1 0.0 0 17.5 267 | 5300 9 9 0.601 1 17.5 268 | 5320 7 7 0.94 1 17.5 269 | 5340 3 4 1.01 0 17.5 270 | 5360 9 9 1.13 1 17.5 271 | 5380 1 9 0.458 0 17.5 272 | 5400 9 9 0.956 1 17.5 273 | 5420 6 6 0.677 1 17.5 274 | 5440 9 9 1.91 1 17.5 275 | 5460 0 0 1.77 1 17.5 276 | 5480 1 1 0.593 1 17.5 277 | 5500 8 8 0.47 1 17.5 278 | 5520 4 -1 0.0 0 17.5 279 | 5540 5 -1 0.0 0 17.5 280 | 5560 3 -1 0.0 0 17.5 281 | 5580 5 3 0.121 0 17.5 282 | 5600 6 -1 0.0 0 17.5 283 | 5620 3 4 0.102 0 17.5 284 | 5640 2 2 0.491 1 17.5 285 | 5660 9 9 0.346 1 17.5 286 | 5680 2 3 0.417 0 17.5 287 | 5700 3 -1 0.0 0 17.5 288 | 5720 9 9 1.51 1 17.5 289 | 5740 6 6 0.379 1 17.5 290 | 5760 2 2 0.242 1 17.5 291 | 5780 7 7 1.84 1 17.5 292 | 5800 2 2 0.182 1 17.5 293 | 5820 5 5 0.268 1 17.5 294 | 5840 2 2 0.461 1 17.5 295 | 5860 0 0 0.146 1 17.5 296 | 5880 0 0 1.22 1 17.5 297 | 5900 4 -1 0.0 0 17.5 298 | 5920 8 8 1.96 1 17.5 299 | 5940 7 4 0.151 0 17.5 300 | 5960 2 6 0.152 0 17.5 301 | 5980 9 9 1.7 1 17.5 302 | 6000 8 8 0.115 1 17.5 303 | 6020 6 6 0.731 1 17.5 304 | 6040 2 5 0.00539 0 17.5 305 | 6060 3 3 0.0204 1 17.5 306 | 6080 1 1 0.625 1 17.5 307 | 6100 1 1 1.15 1 17.5 308 | 6120 5 5 0.741 1 17.5 309 | 6140 9 9 1.05 1 17.5 310 | 6160 2 2 0.258 1 17.5 311 | 6180 0 0 0.626 1 17.5 312 | 6200 3 3 0.49 1 17.5 313 | 6220 2 5 0.178 0 17.5 314 | 6240 2 2 0.0802 1 17.5 315 | 6260 2 -1 0.0 0 17.5 316 | 6280 8 -1 0.0 0 17.5 317 | 6300 1 1 0.493 1 17.5 318 | 6320 0 0 0.588 1 17.5 319 | 6340 3 3 0.827 1 17.5 320 | 6360 2 2 0.196 1 17.5 321 | 6380 3 -1 0.0 0 17.5 322 | 6400 0 -1 0.0 0 17.5 323 | 6420 7 7 1.68 1 17.5 324 | 6440 3 -1 0.0 0 17.5 325 | 6460 3 3 0.289 1 17.5 326 | 6480 0 0 0.66 1 17.5 327 | 6500 7 -1 0.0 0 17.5 328 | 6520 6 -1 0.0 0 17.5 329 | 6540 8 8 0.853 1 17.5 
330 | 6560 6 6 0.468 1 17.5 331 | 6580 1 1 0.127 1 17.5 332 | 6600 7 7 1.69 1 17.5 333 | 6620 7 7 1.2 1 17.5 334 | 6640 5 3 0.822 0 17.5 335 | 6660 0 0 1.06 1 17.5 336 | 6680 3 3 0.33 1 17.5 337 | 6700 6 6 0.22 1 17.5 338 | 6720 2 2 0.44 1 17.5 339 | 6740 2 2 0.35 1 17.5 340 | 6760 5 5 1.07 1 17.5 341 | 6780 7 7 1.54 1 17.5 342 | 6800 6 -1 0.0 0 17.5 343 | 6820 1 1 1.7 1 17.5 344 | 6840 9 9 0.669 1 17.5 345 | 6860 0 0 0.335 1 17.5 346 | 6880 2 2 0.562 1 17.5 347 | 6900 3 -1 0.0 0 17.5 348 | 6920 5 -1 0.0 0 17.5 349 | 6940 1 1 1.91 1 17.5 350 | 6960 9 9 0.065 1 17.5 351 | 6980 0 -1 0.0 0 17.5 352 | 7000 2 0 0.122 0 17.5 353 | 7020 7 7 1.47 1 17.5 354 | 7040 7 -1 0.0 0 17.5 355 | 7060 8 8 0.651 1 17.5 356 | 7080 5 4 0.403 0 17.5 357 | 7100 9 -1 0.0 0 17.5 358 | 7120 7 5 0.235 0 17.5 359 | 7140 4 -1 0.0 0 17.5 360 | 7160 5 5 0.97 1 17.5 361 | 7180 6 6 0.74 1 17.5 362 | 7200 4 -1 0.0 0 17.5 363 | 7220 9 9 1.04 1 17.5 364 | 7240 0 0 0.0983 1 17.5 365 | 7260 1 1 0.477 1 17.5 366 | 7280 1 9 0.391 0 17.5 367 | 7300 3 4 0.451 0 17.5 368 | 7320 8 8 0.423 1 17.5 369 | 7340 2 4 0.149 0 17.5 370 | 7360 2 -1 0.0 0 17.5 371 | 7380 7 7 0.911 1 17.5 372 | 7400 3 3 0.668 1 17.5 373 | 7420 3 4 0.118 0 17.5 374 | 7440 3 3 0.582 1 17.5 375 | 7460 5 5 1.15 1 17.5 376 | 7480 1 8 0.612 0 17.5 377 | 7500 6 6 1.88 1 17.5 378 | 7520 4 4 0.199 1 17.5 379 | 7540 5 5 1.4 1 17.5 380 | 7560 0 0 0.485 1 17.5 381 | 7580 8 8 1.12 1 17.5 382 | 7600 8 9 0.3 0 17.5 383 | 7620 5 3 0.206 0 17.5 384 | 7640 9 9 0.0615 1 17.5 385 | 7660 7 -1 0.0 0 17.5 386 | 7680 3 -1 0.0 0 17.5 387 | 7700 6 6 1.22 1 17.5 388 | 7720 5 -1 0.0 0 17.5 389 | 7740 4 4 0.244 1 17.5 390 | 7760 2 9 0.242 0 17.5 391 | 7780 3 -1 0.0 0 17.5 392 | 7800 0 0 0.957 1 17.5 393 | 7820 5 5 0.29 1 17.5 394 | 7840 1 1 1.52 1 17.5 395 | 7860 8 8 0.378 1 17.5 396 | 7880 0 0 0.232 1 17.5 397 | 7900 0 8 0.0991 0 17.5 398 | 7920 9 9 0.38 1 17.5 399 | 7940 2 -1 0.0 0 17.5 400 | 7960 0 0 0.453 1 17.4 401 | 7980 3 3 0.339 1 17.3 402 | 8000 9 9 0.973 1 
17.2 403 | 8020 6 3 0.618 0 17.1 404 | 8040 2 2 1.08 1 17.1 405 | 8060 6 6 0.736 1 17.0 406 | 8080 4 -1 0.0 0 17.0 407 | 8100 6 -1 0.0 0 17.0 408 | 8120 8 9 0.668 0 16.9 409 | 8140 5 -1 0.0 0 16.9 410 | 8160 6 6 1.08 1 16.9 411 | 8180 4 4 0.178 1 16.9 412 | 8200 3 4 0.355 0 16.9 413 | 8220 6 -1 0.0 0 16.9 414 | 8240 7 7 1.61 1 16.9 415 | 8260 1 1 1.91 1 16.9 416 | 8280 4 5 0.115 0 16.9 417 | 8300 5 -1 0.0 0 16.9 418 | 8320 7 -1 0.0 0 16.9 419 | 8340 5 3 0.45 0 16.7 420 | 8360 7 7 1.53 1 16.7 421 | 8380 9 9 1.15 1 16.7 422 | 8400 0 -1 0.0 0 16.7 423 | 8420 3 6 0.541 0 16.7 424 | 8440 5 5 0.0846 1 16.7 425 | 8460 8 8 1.28 1 16.7 426 | 8480 2 2 0.479 1 16.7 427 | 8500 4 4 0.725 1 16.7 428 | 8520 8 8 0.299 1 16.7 429 | 8540 4 -1 0.0 0 16.7 430 | 8560 7 7 1.05 1 16.7 431 | 8580 3 -1 0.0 0 16.7 432 | 8600 3 3 0.374 1 16.7 433 | 8620 3 -1 0.0 0 16.7 434 | 8640 1 1 1.01 1 16.7 435 | 8660 7 7 1.16 1 16.7 436 | 8680 3 -1 0.0 0 16.7 437 | 8700 3 3 0.522 1 16.7 438 | 8720 2 0 1.25 0 16.7 439 | 8740 2 2 0.308 1 16.7 440 | 8760 8 8 1.49 1 16.7 441 | 8780 4 4 0.358 1 16.7 442 | 8800 0 0 1.36 1 16.7 443 | 8820 2 2 1.81 1 16.7 444 | 8840 2 -1 0.0 0 16.7 445 | 8860 2 2 0.972 1 16.7 446 | 8880 1 1 0.935 1 16.7 447 | 8900 2 4 0.239 0 16.7 448 | 8920 6 6 0.906 1 16.7 449 | 8940 6 3 0.66 0 16.7 450 | 8960 0 0 0.921 1 16.7 451 | 8980 9 9 1.8 1 16.7 452 | 9000 8 8 1.96 1 16.7 453 | 9020 7 7 0.774 1 16.7 454 | 9040 2 5 0.219 0 16.7 455 | 9060 9 9 0.842 1 16.7 456 | 9080 3 3 0.178 1 16.7 457 | 9100 9 1 0.127 0 16.7 458 | 9120 3 -1 0.0 0 16.7 459 | 9140 7 7 0.0382 1 16.7 460 | 9160 5 3 0.457 0 16.7 461 | 9180 5 5 1.51 1 16.8 462 | 9200 8 8 1.39 1 16.7 463 | 9220 8 8 1.67 1 16.7 464 | 9240 8 8 0.747 1 16.7 465 | 9260 5 5 0.603 1 16.7 466 | 9280 9 -1 0.0 0 16.7 467 | 9300 5 -1 0.0 0 16.7 468 | 9320 9 9 1.13 1 16.7 469 | 9340 1 1 1.88 1 16.7 470 | 9360 5 0 0.00429 0 16.7 471 | 9380 5 3 0.613 0 16.7 472 | 9400 6 6 0.713 1 16.7 473 | 9420 5 5 1.12 1 16.7 474 | 9440 8 8 1.52 1 16.7 475 | 9460 3 7 
0.516 0 16.7 476 | 9480 2 2 0.0308 1 16.7 477 | 9500 9 9 0.661 1 16.7 478 | 9520 1 1 0.875 1 16.7 479 | 9540 5 5 0.206 1 16.7 480 | 9560 9 7 0.354 0 16.7 481 | 9580 1 1 0.695 1 16.7 482 | 9600 8 8 0.61 1 16.7 483 | 9620 4 4 0.436 1 16.7 484 | 9640 5 -1 0.0 0 16.7 485 | 9660 6 6 1.8 1 16.7 486 | 9680 8 8 1.52 1 16.7 487 | 9700 0 0 1.07 1 16.7 488 | 9720 8 1 0.0197 0 16.7 489 | 9740 3 6 0.317 0 16.7 490 | 9760 9 -1 0.0 0 16.7 491 | 9780 4 6 0.0379 0 16.7 492 | 9800 1 1 1.91 1 16.7 493 | 9820 0 8 0.336 0 16.7 494 | 9840 4 -1 0.0 0 16.7 495 | 9860 0 5 0.169 0 16.7 496 | 9880 7 7 0.207 1 16.7 497 | 9900 8 8 0.569 1 16.7 498 | 9920 6 6 0.807 1 16.7 499 | 9940 4 5 0.254 0 16.7 500 | 9960 2 0 1.34 0 16.7 501 | 9980 0 0 1.69 1 16.7 502 | -------------------------------------------------------------------------------- /data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50: -------------------------------------------------------------------------------- 1 | idx label predict radius correct time 2 | 0 3 2 0.502 0 0:00:14.260753 3 | 20 7 7 0.365 1 0:00:14.348896 4 | 40 4 8 0.625 0 0:00:14.554244 5 | 60 7 4 0.269 0 0:00:14.513019 6 | 80 8 8 1.91 1 0:00:14.514495 7 | 100 4 4 0.541 1 0:00:14.518535 8 | 120 8 8 1.75 1 0:00:14.489042 9 | 140 6 2 0.614 0 0:00:14.502152 10 | 160 2 2 1.34 1 0:00:14.536093 11 | 180 0 0 0.377 1 0:00:14.470639 12 | 200 5 3 1.03 0 0:00:14.535109 13 | 220 7 7 1.55 1 0:00:14.465855 14 | 240 1 1 0.908 1 0:00:14.472940 15 | 260 8 3 0.273 0 0:00:14.486534 16 | 280 9 9 1.1 1 0:00:14.460034 17 | 300 6 4 0.187 0 0:00:14.503257 18 | 320 3 2 0.183 0 0:00:14.505910 19 | 340 2 4 0.648 0 0:00:14.490895 20 | 360 9 6 0.0018 0 0:00:14.429836 21 | 380 6 4 1.2 0 0:00:14.490548 22 | 400 9 9 0.729 1 0:00:14.469219 23 | 420 4 4 0.501 1 0:00:14.469759 24 | 440 1 1 0.44 1 0:00:14.501306 25 | 460 5 5 0.532 1 0:00:14.481773 26 | 480 8 8 0.701 1 0:00:14.485138 27 | 500 4 -1 0.0 0 0:00:14.445102 28 | 520 5 5 0.596 1 0:00:14.509751 29 | 540 1 8 1.39 0 0:00:14.504498 30 | 560 0 0 
0.38 1 0:00:14.485221 31 | 580 4 4 1.91 1 0:00:14.442113 32 | 600 8 8 1.39 1 0:00:14.469329 33 | 620 7 4 0.556 0 0:00:14.470549 34 | 640 5 3 1.91 0 0:00:14.468914 35 | 660 9 9 1.19 1 0:00:14.499982 36 | 680 9 1 0.0681 0 0:00:14.482205 37 | 700 7 -1 0.0 0 0:00:14.514781 38 | 720 4 7 0.0293 0 0:00:14.452885 39 | 740 2 -1 0.0 0 0:00:14.456233 40 | 760 3 -1 0.0 0 0:00:14.514683 41 | 780 9 -1 0.0 0 0:00:14.449405 42 | 800 7 7 1.91 1 0:00:14.505088 43 | 820 6 -1 0.0 0 0:00:14.494100 44 | 840 7 7 1.91 1 0:00:14.526783 45 | 860 4 4 1.91 1 0:00:14.434383 46 | 880 8 8 1.91 1 0:00:14.472785 47 | 900 2 4 0.307 0 0:00:14.430315 48 | 920 6 2 0.518 0 0:00:14.539771 49 | 940 9 9 1.1 1 0:00:14.464529 50 | 960 9 9 0.799 1 0:00:14.456657 51 | 980 2 -1 0.0 0 0:00:14.484656 52 | 1000 5 5 1.87 1 0:00:14.493587 53 | 1020 1 8 0.806 0 0:00:14.385534 54 | 1040 7 5 0.438 0 0:00:14.523761 55 | 1060 9 -1 0.0 0 0:00:14.499828 56 | 1080 6 6 1.67 1 0:00:14.479753 57 | 1100 7 8 1.3 0 0:00:14.456918 58 | 1120 5 2 0.924 0 0:00:14.480105 59 | 1140 6 4 0.877 0 0:00:14.462265 60 | 1160 8 8 0.174 1 0:00:14.456139 61 | 1180 3 3 0.448 1 0:00:14.387904 62 | 1200 8 0 0.284 0 0:00:14.468246 63 | 1220 9 6 0.282 0 0:00:14.451054 64 | 1240 2 6 1.43 0 0:00:14.450309 65 | 1260 1 9 0.899 0 0:00:14.517980 66 | 1280 3 7 0.362 0 0:00:14.458498 67 | 1300 4 3 0.248 0 0:00:14.405770 68 | 1320 7 2 1.22 0 0:00:14.483056 69 | 1340 1 1 1.16 1 0:00:14.443302 70 | 1360 7 7 0.226 1 0:00:14.495132 71 | 1380 4 5 0.174 0 0:00:14.432075 72 | 1400 5 2 0.497 0 0:00:14.466615 73 | 1420 4 4 0.413 1 0:00:14.452952 74 | 1440 0 0 0.249 1 0:00:14.467673 75 | 1460 6 4 0.117 0 0:00:14.476620 76 | 1480 1 5 0.277 0 0:00:14.441596 77 | 1500 1 1 0.82 1 0:00:14.505503 78 | 1520 7 4 0.0535 0 0:00:14.376559 79 | 1540 8 8 1.64 1 0:00:14.537273 80 | 1560 7 7 1.91 1 0:00:14.505560 81 | 1580 6 1 0.113 0 0:00:14.487358 82 | 1600 8 -1 0.0 0 0:00:14.492057 83 | 1620 5 8 0.315 0 0:00:14.464475 84 | 1640 7 3 0.28 0 0:00:14.491488 85 | 1660 3 3 0.35 1 
0:00:14.431566 86 | 1680 6 6 1.91 1 0:00:14.479111 87 | 1700 5 2 1.87 0 0:00:14.417356 88 | 1720 4 2 0.821 0 0:00:14.442744 89 | 1740 1 1 0.189 1 0:00:14.430615 90 | 1760 0 8 0.709 0 0:00:14.403716 91 | 1780 7 -1 0.0 0 0:00:14.389048 92 | 1800 4 8 1.62 0 0:00:14.392776 93 | 1820 8 8 1.48 1 0:00:14.406297 94 | 1840 8 4 1.1 0 0:00:14.410980 95 | 1860 8 8 1.91 1 0:00:14.386441 96 | 1880 1 8 0.849 0 0:00:14.381203 97 | 1900 8 8 1.91 1 0:00:14.365559 98 | 1920 2 2 1.84 1 0:00:14.292651 99 | 1940 5 3 1.16 0 0:00:14.433078 100 | 1960 2 5 1.28 0 0:00:14.404337 101 | 1980 9 9 0.452 1 0:00:14.329074 102 | 2000 1 6 0.448 0 0:00:14.355776 103 | 2020 9 9 1.54 1 0:00:14.358637 104 | 2040 0 8 0.674 0 0:00:14.362852 105 | 2060 5 5 0.292 1 0:00:14.347701 106 | 2080 1 -1 0.0 0 0:00:14.358643 107 | 2100 2 4 1.21 0 0:00:14.328557 108 | 2120 4 8 1.84 0 0:00:14.351626 109 | 2140 9 9 0.6 1 0:00:14.322432 110 | 2160 0 8 1.77 0 0:00:14.340261 111 | 2180 4 4 1.77 1 0:00:14.317538 112 | 2200 0 2 0.0457 0 0:00:14.316711 113 | 2220 1 1 0.649 1 0:00:14.373689 114 | 2240 9 9 0.811 1 0:00:14.342507 115 | 2260 4 4 1.11 1 0:00:14.340267 116 | 2280 4 8 1.37 0 0:00:14.336520 117 | 2300 3 -1 0.0 0 0:00:14.311955 118 | 2320 5 -1 0.0 0 0:00:14.324267 119 | 2340 7 7 1.67 1 0:00:14.365062 120 | 2360 7 7 0.403 1 0:00:14.331631 121 | 2380 8 8 1.29 1 0:00:14.321607 122 | 2400 0 0 1.26 1 0:00:14.322513 123 | 2420 8 8 0.0765 1 0:00:14.381662 124 | 2440 7 7 0.833 1 0:00:14.278740 125 | 2460 9 3 0.191 0 0:00:14.368503 126 | 2480 8 8 1.91 1 0:00:14.331156 127 | 2500 4 2 0.599 0 0:00:14.336555 128 | 2520 1 8 0.178 0 0:00:14.301184 129 | 2540 2 3 0.5 0 0:00:14.321578 130 | 2560 7 4 0.648 0 0:00:14.333872 131 | 2580 6 4 0.906 0 0:00:14.320160 132 | 2600 8 8 1.91 1 0:00:14.321662 133 | 2620 1 4 0.701 0 0:00:14.360818 134 | 2640 7 7 0.14 1 0:00:14.314903 135 | 2660 3 2 0.986 0 0:00:14.345839 136 | 2680 0 0 1.87 1 0:00:14.299886 137 | 2700 9 0 0.277 0 0:00:14.322745 138 | 2720 3 2 0.798 0 0:00:14.336164 139 | 2740 1 1 
0.699 1 0:00:14.326773 140 | 2760 2 2 0.0434 1 0:00:14.276177 141 | 2780 7 7 1.47 1 0:00:14.379085 142 | 2800 4 4 0.576 1 0:00:14.301599 143 | 2820 6 6 0.701 1 0:00:14.323634 144 | 2840 3 3 1.91 1 0:00:14.282580 145 | 2860 5 6 0.533 0 0:00:14.347618 146 | 2880 4 4 1.91 1 0:00:14.312035 147 | 2900 3 2 0.156 0 0:00:14.324175 148 | 2920 2 2 0.411 1 0:00:14.322795 149 | 2940 3 0 1.46 0 0:00:14.330476 150 | 2960 9 9 0.232 1 0:00:14.315882 151 | 2980 8 8 1.35 1 0:00:14.342838 152 | 3000 5 5 1.91 1 0:00:14.315494 153 | 3020 1 8 0.92 0 0:00:14.314904 154 | 3040 7 7 1.91 1 0:00:14.352726 155 | 3060 7 7 1.52 1 0:00:14.321048 156 | 3080 1 4 0.0304 0 0:00:14.325140 157 | 3100 0 0 1.24 1 0:00:14.301325 158 | 3120 4 2 0.609 0 0:00:14.329913 159 | 3140 8 8 1.91 1 0:00:14.308403 160 | 3160 6 4 0.668 0 0:00:14.326945 161 | 3180 3 -1 0.0 0 0:00:14.317472 162 | 3200 5 7 1.18 0 0:00:14.326241 163 | 3220 3 3 0.702 1 0:00:14.328259 164 | 3240 4 -1 0.0 0 0:00:14.368996 165 | 3260 7 -1 0.0 0 0:00:14.326500 166 | 3280 3 0 0.254 0 0:00:14.310939 167 | 3300 4 4 1.49 1 0:00:14.349883 168 | 3320 2 7 0.254 0 0:00:14.333468 169 | 3340 6 4 0.208 0 0:00:14.339950 170 | 3360 4 4 0.384 1 0:00:14.316216 171 | 3380 8 8 1.91 1 0:00:14.329068 172 | 3400 6 4 1.87 0 0:00:14.332860 173 | 3420 5 4 0.347 0 0:00:14.340887 174 | 3440 2 3 0.45 0 0:00:14.356870 175 | 3460 1 1 0.545 1 0:00:14.332382 176 | 3480 2 2 1.91 1 0:00:14.341356 177 | 3500 1 1 0.689 1 0:00:14.350852 178 | 3520 1 1 1.91 1 0:00:14.368938 179 | 3540 6 6 0.538 1 0:00:14.325798 180 | 3560 1 -1 0.0 0 0:00:14.347779 181 | 3580 4 4 0.604 1 0:00:14.338479 182 | 3600 4 2 1.67 0 0:00:14.367562 183 | 3620 0 6 0.279 0 0:00:14.351847 184 | 3640 6 2 1.78 0 0:00:14.304122 185 | 3660 0 8 1.69 0 0:00:14.373169 186 | 3680 0 2 0.821 0 0:00:14.339939 187 | 3700 3 -1 0.0 0 0:00:14.349725 188 | 3720 2 2 1.91 1 0:00:14.330755 189 | 3740 0 0 0.748 1 0:00:14.348924 190 | 3760 7 7 1.91 1 0:00:14.318395 191 | 3780 4 2 1.5 0 0:00:14.333459 192 | 3800 9 -1 0.0 0 
0:00:14.349926 193 | 3820 0 0 0.684 1 0:00:14.338120 194 | 3840 6 3 0.817 0 0:00:14.363207 195 | 3860 7 7 1.91 1 0:00:14.307930 196 | 3880 6 4 1.18 0 0:00:14.353190 197 | 3900 3 6 0.0366 0 0:00:14.302730 198 | 3920 7 8 0.961 0 0:00:14.356417 199 | 3940 6 4 0.969 0 0:00:14.319504 200 | 3960 2 2 0.722 1 0:00:14.362688 201 | 3980 9 7 1.21 0 0:00:14.360320 202 | 4000 8 0 0.428 0 0:00:14.336081 203 | 4020 8 8 1.15 1 0:00:14.336947 204 | 4040 0 0 1.75 1 0:00:14.310399 205 | 4060 6 3 0.741 0 0:00:14.320710 206 | 4080 1 8 1.81 0 0:00:14.333818 207 | 4100 7 7 1.91 1 0:00:14.349823 208 | 4120 4 0 0.0273 0 0:00:14.382308 209 | 4140 5 3 1.58 0 0:00:14.298037 210 | 4160 5 2 1.27 0 0:00:14.329030 211 | 4180 0 8 1.23 0 0:00:14.365413 212 | 4200 4 2 0.51 0 0:00:14.322507 213 | 4220 4 0 1.91 0 0:00:14.341161 214 | 4240 7 0 0.777 0 0:00:14.345427 215 | 4260 8 8 1.3 1 0:00:14.329375 216 | 4280 8 8 1.17 1 0:00:14.338165 217 | 4300 8 8 1.91 1 0:00:14.329269 218 | 4320 1 8 1.22 0 0:00:14.351560 219 | 4340 0 0 0.125 1 0:00:14.292989 220 | 4360 6 6 0.697 1 0:00:14.384689 221 | 4380 9 8 0.715 0 0:00:14.319510 222 | 4400 3 4 0.415 0 0:00:14.321643 223 | 4420 5 8 0.802 0 0:00:14.342366 224 | 4440 2 2 1.42 1 0:00:14.354761 225 | 4460 9 8 1.13 0 0:00:14.340989 226 | 4480 9 -1 0.0 0 0:00:14.326163 227 | 4500 3 6 0.332 0 0:00:14.321603 228 | 4520 3 6 1.16 0 0:00:14.337091 229 | 4540 9 9 1.54 1 0:00:14.326691 230 | 4560 1 -1 0.0 0 0:00:14.384115 231 | 4580 6 4 1.3 0 0:00:14.301269 232 | 4600 4 4 0.729 1 0:00:14.361005 233 | 4620 7 3 1.56 0 0:00:14.350336 234 | 4640 2 4 0.616 0 0:00:14.306101 235 | 4660 7 -1 0.0 0 0:00:14.353105 236 | 4680 9 1 0.71 0 0:00:14.322191 237 | 4700 6 4 0.81 0 0:00:14.386648 238 | 4720 8 0 0.419 0 0:00:14.333860 239 | 4740 5 3 0.379 0 0:00:14.361800 240 | 4760 3 3 0.00686 1 0:00:14.356988 241 | 4780 0 0 0.395 1 0:00:14.363134 242 | 4800 9 9 1.91 1 0:00:14.315060 243 | 4820 3 5 0.162 0 0:00:14.337430 244 | 4840 0 0 1.84 1 0:00:14.336288 245 | 4860 5 5 0.407 1 
0:00:14.371648 246 | 4880 0 2 1.76 0 0:00:14.316225 247 | 4900 3 6 1.57 0 0:00:14.342275 248 | 4920 7 7 0.745 1 0:00:14.341017 249 | 4940 6 2 0.135 0 0:00:14.337153 250 | 4960 4 4 0.0881 1 0:00:14.338384 251 | 4980 1 9 0.184 0 0:00:14.318410 252 | 5000 7 7 1.84 1 0:00:14.383744 253 | 5020 8 8 1.59 1 0:00:14.338115 254 | 5040 3 4 0.0256 0 0:00:14.348839 255 | 5060 6 4 0.595 0 0:00:14.328453 256 | 5080 7 7 1.91 1 0:00:14.357483 257 | 5100 3 3 1.09 1 0:00:14.318982 258 | 5120 9 8 0.861 0 0:00:14.355429 259 | 5140 8 8 1.91 1 0:00:14.317059 260 | 5160 0 8 0.0582 0 0:00:14.346611 261 | 5180 9 4 0.613 0 0:00:14.328148 262 | 5200 3 3 1.91 1 0:00:14.334650 263 | 5220 0 0 1.01 1 0:00:14.337687 264 | 5240 1 2 0.296 0 0:00:14.352919 265 | 5260 0 0 1.91 1 0:00:14.321813 266 | 5280 0 8 1.87 0 0:00:14.334321 267 | 5300 9 9 1.07 1 0:00:14.369558 268 | 5320 7 7 0.974 1 0:00:14.362770 269 | 5340 3 4 0.599 0 0:00:14.341419 270 | 5360 9 9 1.04 1 0:00:14.339264 271 | 5380 1 9 0.928 0 0:00:14.331964 272 | 5400 9 8 0.495 0 0:00:14.324177 273 | 5420 6 6 1.59 1 0:00:14.327261 274 | 5440 9 9 1.44 1 0:00:14.355649 275 | 5460 0 7 0.598 0 0:00:14.368535 276 | 5480 1 8 0.666 0 0:00:14.335977 277 | 5500 8 8 1.46 1 0:00:14.335099 278 | 5520 4 3 0.469 0 0:00:14.362765 279 | 5540 5 2 0.484 0 0:00:14.320348 280 | 5560 3 6 0.197 0 0:00:14.324019 281 | 5580 5 6 1.28 0 0:00:14.369295 282 | 5600 6 0 0.278 0 0:00:14.300689 283 | 5620 3 2 0.875 0 0:00:14.361083 284 | 5640 2 3 0.101 0 0:00:14.344099 285 | 5660 9 8 0.924 0 0:00:14.362576 286 | 5680 2 8 0.551 0 0:00:14.323498 287 | 5700 3 -1 0.0 0 0:00:14.362411 288 | 5720 9 9 1.56 1 0:00:14.343616 289 | 5740 6 4 1.75 0 0:00:14.373607 290 | 5760 2 2 1.2 1 0:00:14.312141 291 | 5780 7 7 0.855 1 0:00:14.373281 292 | 5800 2 4 0.0133 0 0:00:14.327827 293 | 5820 5 -1 0.0 0 0:00:14.371184 294 | 5840 2 -1 0.0 0 0:00:14.343114 295 | 5860 0 8 0.668 0 0:00:14.339405 296 | 5880 0 0 0.728 1 0:00:14.347477 297 | 5900 4 4 1.03 1 0:00:14.363754 298 | 5920 8 8 1.66 1 
0:00:14.339641 299 | 5940 7 4 1.75 0 0:00:14.334054 300 | 5960 2 4 0.857 0 0:00:14.357543 301 | 5980 9 9 0.914 1 0:00:14.333889 302 | 6000 8 8 1.87 1 0:00:14.356345 303 | 6020 6 6 1.61 1 0:00:14.327467 304 | 6040 2 7 1.31 0 0:00:14.354960 305 | 6060 3 4 1.1 0 0:00:14.345023 306 | 6080 1 -1 0.0 0 0:00:14.262996 307 | 6100 1 8 0.64 0 0:00:14.397099 308 | 6120 5 5 1.12 1 0:00:14.358169 309 | 6140 9 9 0.819 1 0:00:14.319568 310 | 6160 2 5 1.35 0 0:00:14.346238 311 | 6180 0 -1 0.0 0 0:00:14.335939 312 | 6200 3 3 0.598 1 0:00:14.356852 313 | 6220 2 -1 0.0 0 0:00:14.314770 314 | 6240 2 4 0.402 0 0:00:14.368407 315 | 6260 2 4 0.398 0 0:00:14.316968 316 | 6280 8 4 0.127 0 0:00:14.379302 317 | 6300 1 4 0.272 0 0:00:14.336131 318 | 6320 0 0 1.91 1 0:00:14.365493 319 | 6340 3 3 1.31 1 0:00:14.328599 320 | 6360 2 2 1.48 1 0:00:14.334644 321 | 6380 3 6 0.616 0 0:00:14.351504 322 | 6400 0 8 0.176 0 0:00:14.331140 323 | 6420 7 7 0.062 1 0:00:14.382440 324 | 6440 3 8 0.689 0 0:00:14.325770 325 | 6460 3 4 0.0612 0 0:00:14.333105 326 | 6480 0 0 0.646 1 0:00:14.355450 327 | 6500 7 8 0.0157 0 0:00:14.338653 328 | 6520 6 4 1.02 0 0:00:14.324024 329 | 6540 8 8 1.91 1 0:00:14.344356 330 | 6560 6 -1 0.0 0 0:00:14.348085 331 | 6580 1 2 0.23 0 0:00:14.362120 332 | 6600 7 7 1.91 1 0:00:14.341597 333 | 6620 7 7 1.29 1 0:00:14.330906 334 | 6640 5 3 0.0195 0 0:00:14.333315 335 | 6660 0 0 0.483 1 0:00:14.349654 336 | 6680 3 3 0.127 1 0:00:14.345540 337 | 6700 6 -1 0.0 0 0:00:14.326991 338 | 6720 2 2 1.45 1 0:00:14.320856 339 | 6740 2 4 1.07 0 0:00:14.320175 340 | 6760 5 3 0.102 0 0:00:14.373966 341 | 6780 7 7 1.77 1 0:00:14.330199 342 | 6800 6 4 0.302 0 0:00:14.319868 343 | 6820 1 1 1.33 1 0:00:14.345896 344 | 6840 9 9 1.17 1 0:00:14.355399 345 | 6860 0 0 0.186 1 0:00:14.329840 346 | 6880 2 -1 0.0 0 0:00:14.349433 347 | 6900 3 4 0.59 0 0:00:14.356242 348 | 6920 5 2 1.74 0 0:00:14.330287 349 | 6940 1 1 1.87 1 0:00:14.334181 350 | 6960 9 8 1.83 0 0:00:14.340834 351 | 6980 0 0 0.00823 1 
0:00:14.300156 352 | 7000 2 8 0.736 0 0:00:14.351959 353 | 7020 7 7 1.13 1 0:00:14.353994 354 | 7040 7 8 1.91 0 0:00:14.363911 355 | 7060 8 8 1.05 1 0:00:14.333409 356 | 7080 5 4 0.832 0 0:00:14.334554 357 | 7100 9 8 0.26 0 0:00:14.336618 358 | 7120 7 7 0.101 1 0:00:14.341166 359 | 7140 4 4 1.68 1 0:00:14.362059 360 | 7160 5 4 1.15 0 0:00:14.344642 361 | 7180 6 6 1.84 1 0:00:14.328773 362 | 7200 4 2 0.524 0 0:00:14.325600 363 | 7220 9 8 0.351 0 0:00:14.343401 364 | 7240 0 8 1.91 0 0:00:14.312107 365 | 7260 1 8 0.462 0 0:00:14.358999 366 | 7280 1 6 0.399 0 0:00:14.305262 367 | 7300 3 4 0.56 0 0:00:14.331922 368 | 7320 8 8 0.999 1 0:00:14.351168 369 | 7340 2 4 0.931 0 0:00:14.327765 370 | 7360 2 4 1.83 0 0:00:14.334838 371 | 7380 7 7 1.46 1 0:00:14.339803 372 | 7400 3 3 1.29 1 0:00:14.346749 373 | 7420 3 0 0.0586 0 0:00:14.298488 374 | 7440 3 2 0.0534 0 0:00:14.372183 375 | 7460 5 5 1.91 1 0:00:14.344006 376 | 7480 1 8 0.563 0 0:00:14.355524 377 | 7500 6 6 1.91 1 0:00:14.338224 378 | 7520 4 4 1.56 1 0:00:14.324864 379 | 7540 5 5 1.91 1 0:00:14.354052 380 | 7560 0 -1 0.0 0 0:00:14.355483 381 | 7580 8 8 1.91 1 0:00:14.349985 382 | 7600 8 8 0.738 1 0:00:14.339299 383 | 7620 5 3 0.603 0 0:00:14.370959 384 | 7640 9 -1 0.0 0 0:00:14.277472 385 | 7660 7 4 1.37 0 0:00:14.379815 386 | 7680 3 3 1.38 1 0:00:14.355321 387 | 7700 6 6 0.809 1 0:00:14.360164 388 | 7720 5 -1 0.0 0 0:00:14.303431 389 | 7740 4 2 0.386 0 0:00:14.359398 390 | 7760 2 -1 0.0 0 0:00:14.368498 391 | 7780 3 8 0.344 0 0:00:14.322312 392 | 7800 0 0 0.665 1 0:00:14.327515 393 | 7820 5 8 0.881 0 0:00:14.336834 394 | 7840 1 1 1.87 1 0:00:14.372243 395 | 7860 8 4 0.346 0 0:00:14.353720 396 | 7880 0 0 1.09 1 0:00:14.330323 397 | 7900 0 8 0.477 0 0:00:14.353978 398 | 7920 9 8 0.818 0 0:00:14.339945 399 | 7940 2 -1 0.0 0 0:00:14.332996 400 | 7960 0 0 0.242 1 0:00:14.328747 401 | 7980 3 4 0.69 0 0:00:14.346883 402 | 8000 9 9 0.693 1 0:00:14.336344 403 | 8020 6 3 0.436 0 0:00:14.297441 404 | 8040 2 2 0.419 1 
0:00:14.342679 405 | 8060 6 -1 0.0 0 0:00:14.333190 406 | 8080 4 4 0.144 1 0:00:14.317881 407 | 8100 6 -1 0.0 0 0:00:14.312716 408 | 8120 8 8 0.715 1 0:00:14.359991 409 | 8140 5 4 1.84 0 0:00:14.328249 410 | 8160 6 6 1.52 1 0:00:14.355975 411 | 8180 4 4 0.372 1 0:00:14.319064 412 | 8200 3 -1 0.0 0 0:00:14.334832 413 | 8220 6 -1 0.0 0 0:00:14.314748 414 | 8240 7 7 1.5 1 0:00:14.361883 415 | 8260 1 1 0.17 1 0:00:14.346026 416 | 8280 4 -1 0.0 0 0:00:14.347736 417 | 8300 5 2 1.27 0 0:00:14.307997 418 | 8320 7 4 1.56 0 0:00:14.341300 419 | 8340 5 3 0.63 0 0:00:14.349405 420 | 8360 7 4 0.154 0 0:00:14.319071 421 | 8380 9 8 0.409 0 0:00:14.337004 422 | 8400 0 8 0.0526 0 0:00:14.339651 423 | 8420 3 4 0.233 0 0:00:14.340102 424 | 8440 5 3 0.0606 0 0:00:14.347626 425 | 8460 8 8 1.91 1 0:00:14.288629 426 | 8480 2 4 0.0412 0 0:00:14.316075 427 | 8500 4 4 0.225 1 0:00:14.347934 428 | 8520 8 4 0.776 0 0:00:14.337139 429 | 8540 4 2 0.325 0 0:00:14.312202 430 | 8560 7 7 0.194 1 0:00:14.322416 431 | 8580 3 2 0.037 0 0:00:14.377486 432 | 8600 3 4 0.159 0 0:00:14.339806 433 | 8620 3 4 0.606 0 0:00:14.314949 434 | 8640 1 1 1.38 1 0:00:14.371932 435 | 8660 7 7 0.665 1 0:00:14.338834 436 | 8680 3 4 0.0978 0 0:00:14.363549 437 | 8700 3 3 1.71 1 0:00:14.339277 438 | 8720 2 0 1.3 0 0:00:14.349474 439 | 8740 2 4 1.11 0 0:00:14.319043 440 | 8760 8 8 1.14 1 0:00:14.391077 441 | 8780 4 4 0.987 1 0:00:14.283131 442 | 8800 0 0 1.91 1 0:00:14.352883 443 | 8820 2 2 0.78 1 0:00:14.347805 444 | 8840 2 -1 0.0 0 0:00:14.362284 445 | 8860 2 2 0.552 1 0:00:14.330706 446 | 8880 1 1 0.415 1 0:00:14.352176 447 | 8900 2 4 1.5 0 0:00:14.350649 448 | 8920 6 6 1.13 1 0:00:14.360373 449 | 8940 6 6 0.115 1 0:00:14.365397 450 | 8960 0 0 1.91 1 0:00:14.340553 451 | 8980 9 9 0.951 1 0:00:14.368227 452 | 9000 8 8 1.91 1 0:00:14.250843 453 | 9020 7 -1 0.0 0 0:00:14.430918 454 | 9040 2 5 0.825 0 0:00:14.361758 455 | 9060 9 5 0.728 0 0:00:14.319256 456 | 9080 3 3 0.504 1 0:00:14.357976 457 | 9100 9 6 0.453 0 
0:00:14.367532 458 | 9120 3 -1 0.0 0 0:00:14.344448 459 | 9140 7 4 0.00065 0 0:00:14.291768 460 | 9160 5 5 0.545 1 0:00:14.383770 461 | 9180 5 5 1.91 1 0:00:14.345354 462 | 9200 8 8 1.91 1 0:00:14.334551 463 | 9220 8 8 1.91 1 0:00:14.365463 464 | 9240 8 8 1.77 1 0:00:14.340728 465 | 9260 5 5 1.2 1 0:00:14.337726 466 | 9280 9 8 0.272 0 0:00:14.350229 467 | 9300 5 2 0.451 0 0:00:14.318248 468 | 9320 9 8 0.16 0 0:00:14.353187 469 | 9340 1 1 1.91 1 0:00:14.365501 470 | 9360 5 -1 0.0 0 0:00:14.350924 471 | 9380 5 3 1.36 0 0:00:14.393360 472 | 9400 6 4 0.158 0 0:00:14.331827 473 | 9420 5 6 0.0157 0 0:00:14.326165 474 | 9440 8 8 1.91 1 0:00:14.375585 475 | 9460 3 6 0.267 0 0:00:14.351903 476 | 9480 2 4 1.39 0 0:00:14.323893 477 | 9500 9 -1 0.0 0 0:00:14.348245 478 | 9520 1 1 1.3 1 0:00:14.373777 479 | 9540 5 -1 0.0 0 0:00:14.321848 480 | 9560 9 7 0.843 0 0:00:14.350507 481 | 9580 1 4 0.841 0 0:00:14.364640 482 | 9600 8 8 1.91 1 0:00:14.331119 483 | 9620 4 4 0.28 1 0:00:14.336773 484 | 9640 5 6 0.0726 0 0:00:14.346887 485 | 9660 6 6 1.91 1 0:00:14.355296 486 | 9680 8 8 1.91 1 0:00:14.372817 487 | 9700 0 0 0.877 1 0:00:14.320988 488 | 9720 8 8 0.0438 1 0:00:14.341066 489 | 9740 3 4 0.859 0 0:00:14.335354 490 | 9760 9 7 0.249 0 0:00:14.370011 491 | 9780 4 4 0.536 1 0:00:14.319711 492 | 9800 1 1 1.91 1 0:00:14.356350 493 | 9820 0 8 0.899 0 0:00:14.352618 494 | 9840 4 5 0.548 0 0:00:14.345452 495 | 9860 0 8 0.865 0 0:00:14.327323 496 | 9880 7 4 0.644 0 0:00:14.355347 497 | 9900 8 8 0.8 1 0:00:14.371057 498 | 9920 6 2 0.424 0 0:00:14.328738 499 | 9940 4 -1 0.0 0 0:00:14.363681 500 | 9960 2 0 1.91 0 0:00:14.337633 501 | 9980 0 0 1.87 1 0:00:14.363302 502 | -------------------------------------------------------------------------------- /data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00: -------------------------------------------------------------------------------- 1 | idx label predict radius correct time 2 | 0 3 -1 0.0 0 15.3 3 | 20 7 -1 0.0 0 15.4 4 | 40 4 -1 0.0 0 
15.5 5 | 60 7 7 0.689 1 15.6 6 | 80 8 8 1.16 1 15.7 7 | 100 4 -1 0.0 0 15.8 8 | 120 8 8 0.938 1 15.8 9 | 140 6 -1 0.0 0 15.9 10 | 160 2 0 0.105 0 15.9 11 | 180 0 0 0.777 1 16.1 12 | 200 5 3 0.257 0 16.1 13 | 220 7 7 1.79 1 16.2 14 | 240 1 1 2.1 1 16.2 15 | 260 8 -1 0.0 0 16.2 16 | 280 9 9 2.46 1 16.2 17 | 300 6 6 0.181 1 16.1 18 | 320 3 -1 0.0 0 16.1 19 | 340 2 4 0.117 0 16.1 20 | 360 9 -1 0.0 0 16.2 21 | 380 6 6 0.0784 1 16.2 22 | 400 9 9 1.0 1 16.2 23 | 420 4 4 0.391 1 16.2 24 | 440 1 1 0.928 1 16.2 25 | 460 5 5 0.721 1 16.2 26 | 480 8 9 0.509 0 16.2 27 | 500 4 6 0.258 0 16.2 28 | 520 5 5 0.206 1 16.2 29 | 540 1 -1 0.0 0 16.2 30 | 560 0 0 1.47 1 16.2 31 | 580 4 4 0.458 1 16.2 32 | 600 8 8 0.475 1 16.2 33 | 620 7 -1 0.0 0 16.2 34 | 640 5 3 1.92 0 16.2 35 | 660 9 9 2.37 1 16.2 36 | 680 9 -1 0.0 0 16.2 37 | 700 7 7 0.326 1 16.2 38 | 720 4 9 0.165 0 16.2 39 | 740 2 5 0.222 0 16.2 40 | 760 3 9 0.0399 0 16.2 41 | 780 9 -1 0.0 0 16.2 42 | 800 7 7 2.0 1 16.2 43 | 820 6 -1 0.0 0 16.2 44 | 840 7 7 2.9 1 16.2 45 | 860 4 4 0.978 1 16.2 46 | 880 8 8 2.77 1 16.2 47 | 900 2 -1 0.0 0 16.2 48 | 920 6 6 0.125 1 16.2 49 | 940 9 9 1.07 1 16.2 50 | 960 9 9 1.86 1 16.2 51 | 980 2 6 0.842 0 16.2 52 | 1000 5 5 2.25 1 16.2 53 | 1020 1 1 0.734 1 16.2 54 | 1040 7 5 1.05 0 16.2 55 | 1060 9 9 1.44 1 16.2 56 | 1080 6 6 1.1 1 16.2 57 | 1100 7 8 0.482 0 16.2 58 | 1120 5 -1 0.0 0 16.2 59 | 1140 6 6 0.113 1 16.2 60 | 1160 8 8 0.106 1 16.2 61 | 1180 3 3 0.671 1 16.2 62 | 1200 8 8 0.376 1 16.2 63 | 1220 9 -1 0.0 0 16.2 64 | 1240 2 6 0.897 0 16.2 65 | 1260 1 9 0.994 0 16.2 66 | 1280 3 7 0.515 0 16.2 67 | 1300 4 3 0.608 0 16.2 68 | 1320 7 2 0.0536 0 16.2 69 | 1340 1 1 1.05 1 16.2 70 | 1360 7 7 0.912 1 16.2 71 | 1380 4 5 0.868 0 16.2 72 | 1400 5 -1 0.0 0 16.2 73 | 1420 4 7 0.00959 0 16.2 74 | 1440 0 0 0.0527 1 16.2 75 | 1460 6 6 0.53 1 16.2 76 | 1480 1 -1 0.0 0 16.2 77 | 1500 1 1 0.34 1 16.2 78 | 1520 7 7 0.618 1 16.2 79 | 1540 8 8 0.829 1 16.2 80 | 1560 7 7 3.05 1 16.2 81 | 1580 6 1 0.244 0 16.2 82 | 
1600 8 -1 0.0 0 16.2 83 | 1620 5 -1 0.0 0 16.2 84 | 1640 7 5 0.218 0 16.2 85 | 1660 3 -1 0.0 0 16.2 86 | 1680 6 6 2.04 1 16.2 87 | 1700 5 2 0.454 0 16.2 88 | 1720 4 -1 0.0 0 16.2 89 | 1740 1 -1 0.0 0 16.2 90 | 1760 0 -1 0.0 0 16.2 91 | 1780 7 -1 0.0 0 16.2 92 | 1800 4 8 0.362 0 16.2 93 | 1820 8 8 0.29 1 16.2 94 | 1840 8 -1 0.0 0 16.2 95 | 1860 8 8 1.71 1 16.2 96 | 1880 1 1 0.685 1 16.2 97 | 1900 8 8 2.34 1 16.2 98 | 1920 2 2 0.53 1 16.2 99 | 1940 5 3 1.21 0 16.2 100 | 1960 2 5 1.34 0 16.2 101 | 1980 9 9 1.45 1 16.2 102 | 2000 1 -1 0.0 0 16.2 103 | 2020 9 9 2.07 1 16.2 104 | 2040 0 8 0.536 0 16.2 105 | 2060 5 5 0.733 1 16.2 106 | 2080 1 1 0.0422 1 16.2 107 | 2100 2 4 0.462 0 16.2 108 | 2120 4 8 0.99 0 16.2 109 | 2140 9 9 1.23 1 16.2 110 | 2160 0 8 1.15 0 16.2 111 | 2180 4 4 0.581 1 16.2 112 | 2200 0 0 0.504 1 16.2 113 | 2220 1 1 1.32 1 16.2 114 | 2240 9 9 1.13 1 16.2 115 | 2260 4 -1 0.0 0 16.2 116 | 2280 4 8 0.0794 0 16.2 117 | 2300 3 -1 0.0 0 16.2 118 | 2320 5 -1 0.0 0 16.2 119 | 2340 7 7 1.32 1 16.2 120 | 2360 7 7 0.423 1 16.2 121 | 2380 8 8 0.817 1 16.2 122 | 2400 0 0 1.23 1 16.2 123 | 2420 8 -1 0.0 0 16.2 124 | 2440 7 7 0.705 1 16.2 125 | 2460 9 9 0.233 1 16.2 126 | 2480 8 8 1.21 1 16.2 127 | 2500 4 -1 0.0 0 16.2 128 | 2520 1 1 0.171 1 16.2 129 | 2540 2 3 0.718 0 16.2 130 | 2560 7 7 0.106 1 16.2 131 | 2580 6 -1 0.0 0 16.2 132 | 2600 8 8 2.39 1 16.2 133 | 2620 1 1 0.0818 1 16.2 134 | 2640 7 7 0.914 1 16.2 135 | 2660 3 -1 0.0 0 16.2 136 | 2680 0 0 2.5 1 16.2 137 | 2700 9 -1 0.0 0 16.2 138 | 2720 3 -1 0.0 0 16.2 139 | 2740 1 1 0.741 1 16.2 140 | 2760 2 2 0.484 1 16.2 141 | 2780 7 7 1.76 1 16.2 142 | 2800 4 6 0.207 0 16.2 143 | 2820 6 6 1.9 1 16.2 144 | 2840 3 3 2.8 1 16.2 145 | 2860 5 -1 0.0 0 16.2 146 | 2880 4 4 2.16 1 16.2 147 | 2900 3 2 0.00313 0 16.2 148 | 2920 2 -1 0.0 0 16.2 149 | 2940 3 0 1.32 0 16.2 150 | 2960 9 9 0.811 1 16.2 151 | 2980 8 -1 0.0 0 16.2 152 | 3000 5 5 2.45 1 16.2 153 | 3020 1 1 0.913 1 16.2 154 | 3040 7 7 3.22 1 16.2 155 | 3060 7 7 1.47 1 
16.2 156 | 3080 1 -1 0.0 0 16.2 157 | 3100 0 0 1.92 1 16.2 158 | 3120 4 6 0.659 0 16.2 159 | 3140 8 8 0.955 1 16.2 160 | 3160 6 4 0.319 0 16.2 161 | 3180 3 -1 0.0 0 16.2 162 | 3200 5 7 0.806 0 16.2 163 | 3220 3 -1 0.0 0 16.2 164 | 3240 4 -1 0.0 0 16.2 165 | 3260 7 -1 0.0 0 16.2 166 | 3280 3 5 0.347 0 16.2 167 | 3300 4 -1 0.0 0 16.2 168 | 3320 2 7 1.15 0 16.2 169 | 3340 6 6 0.255 1 16.2 170 | 3360 4 6 0.245 0 16.2 171 | 3380 8 8 0.329 1 16.2 172 | 3400 6 4 1.24 0 16.2 173 | 3420 5 -1 0.0 0 16.2 174 | 3440 2 -1 0.0 0 16.2 175 | 3460 1 1 1.75 1 16.2 176 | 3480 2 2 2.88 1 16.2 177 | 3500 1 1 1.14 1 16.2 178 | 3520 1 1 2.06 1 16.2 179 | 3540 6 6 0.575 1 16.2 180 | 3560 1 9 0.376 0 16.2 181 | 3580 4 -1 0.0 0 16.2 182 | 3600 4 2 0.824 0 16.2 183 | 3620 0 -1 0.0 0 16.2 184 | 3640 6 2 0.987 0 16.2 185 | 3660 0 8 0.101 0 16.2 186 | 3680 0 -1 0.0 0 16.2 187 | 3700 3 -1 0.0 0 16.2 188 | 3720 2 2 1.46 1 16.2 189 | 3740 0 0 1.27 1 16.2 190 | 3760 7 7 3.04 1 16.2 191 | 3780 4 -1 0.0 0 16.2 192 | 3800 9 9 0.55 1 16.2 193 | 3820 0 0 0.5 1 16.2 194 | 3840 6 3 0.31 0 16.2 195 | 3860 7 7 2.25 1 16.2 196 | 3880 6 4 0.19 0 16.2 197 | 3900 3 -1 0.0 0 16.2 198 | 3920 7 -1 0.0 0 16.2 199 | 3940 6 -1 0.0 0 16.2 200 | 3960 2 2 0.648 1 16.2 201 | 3980 9 7 0.666 0 16.2 202 | 4000 8 -1 0.0 0 16.2 203 | 4020 8 -1 0.0 0 16.2 204 | 4040 0 0 1.64 1 16.2 205 | 4060 6 3 0.65 0 16.2 206 | 4080 1 -1 0.0 0 16.2 207 | 4100 7 7 3.04 1 16.2 208 | 4120 4 0 1.17 0 16.2 209 | 4140 5 3 1.3 0 16.2 210 | 4160 5 2 0.0858 0 16.2 211 | 4180 0 0 0.406 1 16.2 212 | 4200 4 -1 0.0 0 16.2 213 | 4220 4 0 2.33 0 16.2 214 | 4240 7 0 0.775 0 16.2 215 | 4260 8 8 0.589 1 16.2 216 | 4280 8 8 0.0455 1 16.2 217 | 4300 8 8 1.29 1 16.2 218 | 4320 1 8 0.888 0 16.2 219 | 4340 0 0 0.723 1 16.2 220 | 4360 6 -1 0.0 0 16.2 221 | 4380 9 9 0.703 1 16.2 222 | 4400 3 -1 0.0 0 16.2 223 | 4420 5 -1 0.0 0 16.2 224 | 4440 2 2 0.21 1 16.2 225 | 4460 9 -1 0.0 0 16.2 226 | 4480 9 9 0.35 1 16.2 227 | 4500 3 3 0.129 1 16.2 228 | 4520 3 6 0.487 0 
16.2 229 | 4540 9 9 1.98 1 16.2 230 | 4560 1 1 1.05 1 16.2 231 | 4580 6 -1 0.0 0 16.2 232 | 4600 4 7 0.597 0 16.2 233 | 4620 7 3 0.944 0 16.2 234 | 4640 2 -1 0.0 0 16.2 235 | 4660 7 -1 0.0 0 16.2 236 | 4680 9 1 0.68 0 16.2 237 | 4700 6 -1 0.0 0 16.2 238 | 4720 8 -1 0.0 0 16.2 239 | 4740 5 3 0.253 0 16.2 240 | 4760 3 3 0.146 1 16.2 241 | 4780 0 0 1.03 1 16.2 242 | 4800 9 9 2.91 1 16.2 243 | 4820 3 -1 0.0 0 16.2 244 | 4840 0 0 2.04 1 16.2 245 | 4860 5 5 0.274 1 16.2 246 | 4880 0 2 1.08 0 16.2 247 | 4900 3 6 0.748 0 16.2 248 | 4920 7 7 1.18 1 16.2 249 | 4940 6 6 0.00465 1 16.2 250 | 4960 4 -1 0.0 0 16.2 251 | 4980 1 9 0.828 0 16.2 252 | 5000 7 7 2.4 1 16.2 253 | 5020 8 8 0.421 1 16.2 254 | 5040 3 3 0.1 1 16.2 255 | 5060 6 -1 0.0 0 16.2 256 | 5080 7 7 2.45 1 16.2 257 | 5100 3 3 1.0 1 16.2 258 | 5120 9 9 0.812 1 16.2 259 | 5140 8 8 1.22 1 16.2 260 | 5160 0 0 0.357 1 16.2 261 | 5180 9 -1 0.0 0 16.2 262 | 5200 3 3 2.66 1 16.2 263 | 5220 0 0 1.31 1 16.2 264 | 5240 1 -1 0.0 0 16.2 265 | 5260 0 0 2.81 1 16.2 266 | 5280 0 8 0.979 0 16.2 267 | 5300 9 9 0.786 1 16.2 268 | 5320 7 7 0.801 1 16.2 269 | 5340 3 4 0.81 0 16.2 270 | 5360 9 9 1.49 1 16.2 271 | 5380 1 9 0.993 0 16.2 272 | 5400 9 9 0.265 1 16.2 273 | 5420 6 6 0.781 1 16.2 274 | 5440 9 9 2.32 1 16.2 275 | 5460 0 7 0.44 0 16.2 276 | 5480 1 1 0.451 1 16.2 277 | 5500 8 -1 0.0 0 16.2 278 | 5520 4 -1 0.0 0 16.2 279 | 5540 5 -1 0.0 0 16.2 280 | 5560 3 -1 0.0 0 16.2 281 | 5580 5 6 0.491 0 16.2 282 | 5600 6 -1 0.0 0 16.2 283 | 5620 3 2 0.545 0 16.2 284 | 5640 2 -1 0.0 0 16.2 285 | 5660 9 9 0.166 1 16.2 286 | 5680 2 -1 0.0 0 16.2 287 | 5700 3 -1 0.0 0 16.2 288 | 5720 9 9 1.97 1 16.2 289 | 5740 6 4 0.244 0 16.2 290 | 5760 2 2 0.143 1 16.2 291 | 5780 7 7 1.56 1 16.2 292 | 5800 2 -1 0.0 0 16.2 293 | 5820 5 -1 0.0 0 16.2 294 | 5840 2 -1 0.0 0 16.2 295 | 5860 0 -1 0.0 0 16.2 296 | 5880 0 0 0.259 1 16.2 297 | 5900 4 -1 0.0 0 16.2 298 | 5920 8 8 2.45 1 16.2 299 | 5940 7 4 0.529 0 16.2 300 | 5960 2 -1 0.0 0 16.2 301 | 5980 9 9 1.72 1 16.2 
302 | 6000 8 -1 0.0 0 16.2 303 | 6020 6 6 0.335 1 16.2 304 | 6040 2 7 1.06 0 16.2 305 | 6060 3 -1 0.0 0 16.2 306 | 6080 1 -1 0.0 0 16.2 307 | 6100 1 1 0.83 1 16.2 308 | 6120 5 5 1.05 1 16.2 309 | 6140 9 9 1.53 1 16.2 310 | 6160 2 5 1.05 0 16.2 311 | 6180 0 -1 0.0 0 16.2 312 | 6200 3 -1 0.0 0 16.2 313 | 6220 2 6 0.215 0 16.2 314 | 6240 2 -1 0.0 0 16.2 315 | 6260 2 -1 0.0 0 16.2 316 | 6280 8 -1 0.0 0 16.2 317 | 6300 1 -1 0.0 0 16.2 318 | 6320 0 0 2.29 1 16.2 319 | 6340 3 3 0.665 1 16.2 320 | 6360 2 2 0.166 1 16.2 321 | 6380 3 -1 0.0 0 16.2 322 | 6400 0 1 0.423 0 16.2 323 | 6420 7 7 0.522 1 16.2 324 | 6440 3 -1 0.0 0 16.2 325 | 6460 3 -1 0.0 0 16.2 326 | 6480 0 0 0.485 1 16.2 327 | 6500 7 -1 0.0 0 16.2 328 | 6520 6 4 0.309 0 16.2 329 | 6540 8 8 1.16 1 16.2 330 | 6560 6 6 0.6 1 16.2 331 | 6580 1 -1 0.0 0 16.2 332 | 6600 7 7 3.56 1 16.2 333 | 6620 7 7 1.42 1 16.2 334 | 6640 5 3 0.413 0 16.2 335 | 6660 0 0 0.983 1 16.2 336 | 6680 3 5 0.988 0 16.2 337 | 6700 6 6 0.164 1 16.2 338 | 6720 2 2 0.528 1 16.2 339 | 6740 2 -1 0.0 0 16.2 340 | 6760 5 5 0.0267 1 16.2 341 | 6780 7 7 2.17 1 16.2 342 | 6800 6 6 0.053 1 16.2 343 | 6820 1 1 2.02 1 16.2 344 | 6840 9 9 0.907 1 16.2 345 | 6860 0 0 0.985 1 16.2 346 | 6880 2 7 0.0815 0 16.2 347 | 6900 3 -1 0.0 0 16.2 348 | 6920 5 2 0.978 0 16.2 349 | 6940 1 1 2.84 1 16.2 350 | 6960 9 8 1.29 0 16.2 351 | 6980 0 9 0.00741 0 16.2 352 | 7000 2 -1 0.0 0 16.2 353 | 7020 7 7 0.983 1 16.2 354 | 7040 7 8 0.289 0 16.2 355 | 7060 8 8 0.456 1 16.2 356 | 7080 5 4 0.294 0 16.2 357 | 7100 9 3 0.239 0 16.2 358 | 7120 7 7 0.0409 1 16.2 359 | 7140 4 -1 0.0 0 16.2 360 | 7160 5 -1 0.0 0 16.2 361 | 7180 6 6 1.64 1 16.2 362 | 7200 4 6 0.209 0 16.2 363 | 7220 9 9 0.884 1 16.2 364 | 7240 0 8 0.393 0 16.2 365 | 7260 1 -1 0.0 0 16.2 366 | 7280 1 1 0.041 1 16.2 367 | 7300 3 -1 0.0 0 16.2 368 | 7320 8 8 0.0225 1 16.2 369 | 7340 2 4 0.358 0 16.2 370 | 7360 2 4 0.573 0 16.2 371 | 7380 7 7 0.799 1 16.2 372 | 7400 3 3 0.642 1 16.2 373 | 7420 3 0 0.398 0 16.2 374 | 7440 3 3 
0.196 1 16.2 375 | 7460 5 5 2.81 1 16.2 376 | 7480 1 8 0.601 0 16.2 377 | 7500 6 6 3.08 1 16.2 378 | 7520 4 -1 0.0 0 16.2 379 | 7540 5 5 2.78 1 16.2 380 | 7560 0 -1 0.0 0 16.2 381 | 7580 8 8 0.896 1 16.2 382 | 7600 8 9 0.755 0 16.2 383 | 7620 5 3 0.129 0 16.2 384 | 7640 9 9 0.49 1 16.2 385 | 7660 7 -1 0.0 0 16.2 386 | 7680 3 -1 0.0 0 16.2 387 | 7700 6 6 1.56 1 16.2 388 | 7720 5 5 0.385 1 16.2 389 | 7740 4 -1 0.0 0 16.2 390 | 7760 2 -1 0.0 0 16.2 391 | 7780 3 -1 0.0 0 16.2 392 | 7800 0 0 1.43 1 16.2 393 | 7820 5 8 0.148 0 16.2 394 | 7840 1 1 2.08 1 16.2 395 | 7860 8 -1 0.0 0 16.2 396 | 7880 0 0 0.967 1 16.2 397 | 7900 0 8 0.0134 0 16.2 398 | 7920 9 -1 0.0 0 16.2 399 | 7940 2 6 0.077 0 16.2 400 | 7960 0 0 0.774 1 16.2 401 | 7980 3 -1 0.0 0 16.2 402 | 8000 9 9 1.18 1 16.2 403 | 8020 6 3 0.342 0 16.2 404 | 8040 2 2 0.554 1 16.2 405 | 8060 6 6 0.644 1 16.2 406 | 8080 4 -1 0.0 0 16.2 407 | 8100 6 -1 0.0 0 16.2 408 | 8120 8 9 0.875 0 16.2 409 | 8140 5 4 0.269 0 16.2 410 | 8160 6 6 0.175 1 16.2 411 | 8180 4 4 0.6 1 16.2 412 | 8200 3 -1 0.0 0 16.2 413 | 8220 6 6 0.041 1 16.2 414 | 8240 7 7 2.49 1 16.2 415 | 8260 1 1 0.683 1 16.2 416 | 8280 4 -1 0.0 0 16.2 417 | 8300 5 2 0.151 0 16.2 418 | 8320 7 -1 0.0 0 16.2 419 | 8340 5 3 1.03 0 16.2 420 | 8360 7 7 0.765 1 16.2 421 | 8380 9 9 0.762 1 16.2 422 | 8400 0 0 0.049 1 16.2 423 | 8420 3 6 0.197 0 16.2 424 | 8440 5 -1 0.0 0 16.2 425 | 8460 8 8 1.73 1 16.2 426 | 8480 2 -1 0.0 0 16.2 427 | 8500 4 4 0.114 1 16.2 428 | 8520 8 1 0.00909 0 16.2 429 | 8540 4 -1 0.0 0 16.2 430 | 8560 7 7 0.841 1 16.2 431 | 8580 3 -1 0.0 0 16.2 432 | 8600 3 -1 0.0 0 16.2 433 | 8620 3 -1 0.0 0 16.2 434 | 8640 1 1 2.53 1 16.2 435 | 8660 7 7 0.893 1 16.2 436 | 8680 3 9 0.221 0 16.2 437 | 8700 3 3 1.35 1 16.2 438 | 8720 2 0 2.39 0 16.2 439 | 8740 2 -1 0.0 0 16.2 440 | 8760 8 8 1.27 1 16.2 441 | 8780 4 4 0.0863 1 16.2 442 | 8800 0 0 3.19 1 16.2 443 | 8820 2 2 0.818 1 16.2 444 | 8840 2 1 0.302 0 16.2 445 | 8860 2 2 0.121 1 16.2 446 | 8880 1 1 0.362 1 16.2 447 | 
8900 2 -1 0.0 0 16.2 448 | 8920 6 6 1.72 1 16.2 449 | 8940 6 3 0.226 0 16.2 450 | 8960 0 0 1.93 1 16.2 451 | 8980 9 9 1.69 1 16.1 452 | 9000 8 8 3.12 1 16.2 453 | 9020 7 9 0.268 0 16.2 454 | 9040 2 5 1.2 0 16.2 455 | 9060 9 5 0.162 0 16.2 456 | 9080 3 3 0.143 1 16.2 457 | 9100 9 -1 0.0 0 16.2 458 | 9120 3 -1 0.0 0 16.2 459 | 9140 7 -1 0.0 0 16.2 460 | 9160 5 5 0.198 1 16.2 461 | 9180 5 5 3.27 1 16.2 462 | 9200 8 8 2.57 1 16.2 463 | 9220 8 8 2.43 1 16.2 464 | 9240 8 8 0.796 1 16.2 465 | 9260 5 5 0.637 1 16.2 466 | 9280 9 -1 0.0 0 16.2 467 | 9300 5 -1 0.0 0 16.2 468 | 9320 9 9 1.14 1 16.2 469 | 9340 1 1 2.12 1 16.2 470 | 9360 5 -1 0.0 0 16.2 471 | 9380 5 3 0.0606 0 16.2 472 | 9400 6 6 0.0884 1 16.2 473 | 9420 5 5 0.529 1 16.2 474 | 9440 8 8 2.55 1 16.2 475 | 9460 3 7 0.0941 0 16.2 476 | 9480 2 -1 0.0 0 16.2 477 | 9500 9 -1 0.0 0 16.2 478 | 9520 1 1 1.03 1 16.2 479 | 9540 5 5 0.0202 1 16.2 480 | 9560 9 7 1.52 0 16.2 481 | 9580 1 -1 0.0 0 16.2 482 | 9600 8 8 1.46 1 16.2 483 | 9620 4 -1 0.0 0 16.2 484 | 9640 5 -1 0.0 0 16.2 485 | 9660 6 6 1.73 1 16.2 486 | 9680 8 8 1.4 1 16.2 487 | 9700 0 0 1.05 1 16.2 488 | 9720 8 -1 0.0 0 16.2 489 | 9740 3 6 0.356 0 16.2 490 | 9760 9 9 0.344 1 16.2 491 | 9780 4 6 0.813 0 16.2 492 | 9800 1 1 3.18 1 16.2 493 | 9820 0 8 0.0948 0 16.2 494 | 9840 4 5 0.56 0 16.2 495 | 9860 0 -1 0.0 0 16.2 496 | 9880 7 7 0.0843 1 16.2 497 | 9900 8 0 0.179 0 16.2 498 | 9920 6 6 0.104 1 16.2 499 | 9940 4 -1 0.0 0 16.2 500 | 9960 2 0 3.3 0 16.2 501 | 9980 0 0 2.12 1 16.2 502 | -------------------------------------------------------------------------------- /data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25: -------------------------------------------------------------------------------- 1 | idx label predict radius correct time 2 | 0 0 0 0.165 1 0:02:31.238689 3 | 100 2 2 0.978 1 0:02:28.472100 4 | 200 4 395 0.0504 0 0:02:29.734734 5 | 300 6 6 0.978 1 0:02:29.500387 6 | 400 8 8 0.978 1 0:02:30.348481 7 | 500 10 10 0.849 1 0:02:29.899296 8 | 600 12 12 
0.0511 1 0:02:29.689873 9 | 700 14 14 0.978 1 0:02:29.582000 10 | 800 16 16 0.978 1 0:02:29.844638 11 | 900 18 18 0.117 1 0:02:30.467950 12 | 1000 20 20 0.978 1 0:02:30.492715 13 | 1100 22 21 0.271 0 0:02:29.971339 14 | 1200 24 24 0.978 1 0:02:29.882296 15 | 1300 26 26 0.464 1 0:02:30.184608 16 | 1400 28 -1 0.0 0 0:02:29.938915 17 | 1500 30 30 0.287 1 0:02:29.372000 18 | 1600 32 -1 0.0 0 0:02:30.198865 19 | 1700 34 327 0.0542 0 0:02:30.576400 20 | 1800 36 35 0.609 0 0:02:29.825960 21 | 1900 38 38 0.226 1 0:02:29.747150 22 | 2000 40 31 0.0497 0 0:02:29.953930 23 | 2100 42 48 0.839 0 0:02:30.242376 24 | 2200 44 44 0.681 1 0:02:30.185557 25 | 2300 46 46 0.537 1 0:02:30.888912 26 | 2400 48 48 0.17 1 0:02:30.444757 27 | 2500 50 50 0.682 1 0:02:30.507672 28 | 2600 52 52 0.405 1 0:02:29.848361 29 | 2700 54 57 0.978 0 0:02:30.142747 30 | 2800 56 56 0.48 1 0:02:30.734643 31 | 2900 58 -1 0.0 0 0:02:30.443910 32 | 3000 60 62 0.37 0 0:02:30.121695 33 | 3100 62 62 0.449 1 0:02:30.041115 34 | 3200 64 64 0.634 1 0:02:29.907086 35 | 3300 66 63 0.133 0 0:02:30.320275 36 | 3400 68 -1 0.0 0 0:02:30.168557 37 | 3500 70 70 0.978 1 0:02:29.595999 38 | 3600 72 72 0.978 1 0:02:29.768712 39 | 3700 74 74 0.499 1 0:02:30.160280 40 | 3800 76 76 0.641 1 0:02:30.447175 41 | 3900 78 78 0.631 1 0:02:30.357008 42 | 4000 80 80 0.978 1 0:02:30.246374 43 | 4100 82 82 0.978 1 0:02:30.204693 44 | 4200 84 84 0.405 1 0:02:30.568214 45 | 4300 86 86 0.571 1 0:02:30.145539 46 | 4400 88 88 0.128 1 0:02:29.993359 47 | 4500 90 90 0.74 1 0:02:30.215480 48 | 4600 92 92 0.503 1 0:02:29.838759 49 | 4700 94 94 0.978 1 0:02:29.745489 50 | 4800 96 96 0.978 1 0:02:30.194775 51 | 4900 98 98 0.722 1 0:02:29.985465 52 | 5000 100 100 0.978 1 0:02:30.678385 53 | 5100 102 102 0.978 1 0:02:30.334541 54 | 5200 104 342 0.216 0 0:02:30.607674 55 | 5300 106 106 0.561 1 0:02:30.052348 56 | 5400 108 470 0.256 0 0:02:30.179028 57 | 5500 110 110 0.978 1 0:02:29.893491 58 | 5600 112 112 0.906 1 0:02:30.182807 59 | 5700 114 113 0.391 
0 0:02:30.503177 60 | 5800 116 116 0.978 1 0:02:30.496567 61 | 5900 118 118 0.662 1 0:02:30.373342 62 | 6000 120 119 0.0777 0 0:02:30.121950 63 | 6100 122 122 0.978 1 0:02:30.427012 64 | 6200 124 123 0.113 0 0:02:30.054732 65 | 6300 126 126 0.528 1 0:02:30.416313 66 | 6400 128 128 0.978 1 0:02:30.562088 67 | 6500 130 130 0.978 1 0:02:30.231175 68 | 6600 132 132 0.978 1 0:02:30.159542 69 | 6700 134 132 0.208 0 0:02:29.646543 70 | 6800 136 136 0.978 1 0:02:30.380730 71 | 6900 138 138 0.978 1 0:02:30.572920 72 | 7000 140 140 0.978 1 0:02:30.700227 73 | 7100 142 142 0.814 1 0:02:30.641635 74 | 7200 144 144 0.978 1 0:02:30.563427 75 | 7300 146 146 0.0569 1 0:02:30.143445 76 | 7400 148 148 0.978 1 0:02:30.023675 77 | 7500 150 360 0.498 0 0:02:30.621709 78 | 7600 152 152 0.793 1 0:02:30.327601 79 | 7700 154 154 0.913 1 0:02:30.553908 80 | 7800 156 156 0.978 1 0:02:30.480134 81 | 7900 158 158 0.685 1 0:02:30.272967 82 | 8000 160 160 0.978 1 0:02:30.440335 83 | 8100 162 162 0.692 1 0:02:29.968419 84 | 8200 164 164 0.25 1 0:02:31.014798 85 | 8300 166 166 0.978 1 0:02:30.850247 86 | 8400 168 211 0.0559 0 0:02:30.812654 87 | 8500 170 172 0.345 0 0:02:30.567950 88 | 8600 172 172 0.623 1 0:02:31.028928 89 | 8700 174 174 0.687 1 0:02:30.706730 90 | 8800 176 176 0.388 1 0:02:30.524550 91 | 8900 178 178 0.978 1 0:02:30.185986 92 | 9000 180 243 0.285 0 0:02:30.189492 93 | 9100 182 182 0.0441 1 0:02:29.980940 94 | 9200 184 -1 0.0 0 0:02:30.320398 95 | 9300 186 192 0.596 0 0:02:30.886985 96 | 9400 188 188 0.165 1 0:02:30.411439 97 | 9500 190 190 0.978 1 0:02:29.880804 98 | 9600 192 199 0.672 0 0:02:30.719151 99 | 9700 194 -1 0.0 0 0:02:30.716428 100 | 9800 196 199 0.978 0 0:02:30.519637 101 | 9900 198 196 0.0359 0 0:02:30.775083 102 | 10000 200 200 0.817 1 0:02:30.284424 103 | 10100 202 202 0.955 1 0:02:29.999943 104 | 10200 204 204 0.247 1 0:02:30.015042 105 | 10300 206 206 0.978 1 0:02:29.928960 106 | 10400 208 208 0.395 1 0:02:30.028211 107 | 10500 210 210 0.51 1 0:02:30.097306 108 
| 10600 212 212 0.219 1 0:02:30.861089 109 | 10700 214 214 0.577 1 0:02:30.404881 110 | 10800 216 216 0.846 1 0:02:30.240213 111 | 10900 218 218 0.978 1 0:02:30.392252 112 | 11000 220 220 0.535 1 0:02:30.464677 113 | 11100 222 222 0.0577 1 0:02:30.620631 114 | 11200 224 233 0.117 0 0:02:30.720985 115 | 11300 226 226 0.69 1 0:02:30.625103 116 | 11400 228 228 0.978 1 0:02:30.223536 117 | 11500 230 230 0.188 1 0:02:30.951324 118 | 11600 232 232 0.941 1 0:02:30.144850 119 | 11700 234 234 0.437 1 0:02:31.048280 120 | 11800 236 234 0.0155 0 0:02:30.477529 121 | 11900 238 238 0.712 1 0:02:30.344499 122 | 12000 240 238 0.691 0 0:02:30.270159 123 | 12100 242 242 0.786 1 0:02:30.272286 124 | 12200 244 244 0.978 1 0:02:30.038246 125 | 12300 246 251 0.61 0 0:02:29.815927 126 | 12400 248 248 0.0843 1 0:02:29.774322 127 | 12500 250 169 0.238 0 0:02:30.566967 128 | 12600 252 252 0.726 1 0:02:29.879580 129 | 12700 254 254 0.235 1 0:02:29.971380 130 | 12800 256 256 0.978 1 0:02:29.825671 131 | 12900 258 258 0.517 1 0:02:30.061793 132 | 13000 260 260 0.537 1 0:02:30.049401 133 | 13100 262 262 0.978 1 0:02:30.337835 134 | 13200 264 264 0.307 1 0:02:30.565490 135 | 13300 266 266 0.175 1 0:02:29.914956 136 | 13400 268 268 0.978 1 0:02:29.795518 137 | 13500 270 270 0.978 1 0:02:30.268059 138 | 13600 272 274 0.306 0 0:02:30.197188 139 | 13700 274 274 0.978 1 0:02:30.016229 140 | 13800 276 276 0.978 1 0:02:29.347983 141 | 13900 278 278 0.707 1 0:02:30.143017 142 | 14000 280 280 0.623 1 0:02:30.471658 143 | 14100 282 282 0.929 1 0:02:29.568346 144 | 14200 284 284 0.978 1 0:02:29.389408 145 | 14300 286 286 0.978 1 0:02:29.676501 146 | 14400 288 288 0.716 1 0:02:30.132765 147 | 14500 290 290 0.978 1 0:02:29.294689 148 | 14600 292 290 0.609 0 0:02:30.056545 149 | 14700 294 294 0.649 1 0:02:30.130922 150 | 14800 296 296 0.638 1 0:02:29.760217 151 | 14900 298 -1 0.0 0 0:02:30.193345 152 | 15000 300 300 0.978 1 0:02:29.807024 153 | 15100 302 302 0.0493 1 0:02:30.157539 154 | 15200 304 301 0.978 
0 0:02:30.279750 155 | 15300 306 306 0.978 1 0:02:29.992867 156 | 15400 308 308 0.844 1 0:02:30.478700 157 | 15500 310 310 0.978 1 0:02:29.860750 158 | 15600 312 -1 0.0 0 0:02:29.935535 159 | 15700 314 314 0.735 1 0:02:30.038218 160 | 15800 316 316 0.978 1 0:02:30.223525 161 | 15900 318 318 0.46 1 0:02:30.016658 162 | 16000 320 320 0.978 1 0:02:30.401388 163 | 16100 322 322 0.754 1 0:02:30.413642 164 | 16200 324 324 0.978 1 0:02:30.292821 165 | 16300 326 326 0.955 1 0:02:30.410080 166 | 16400 328 328 0.978 1 0:02:30.249984 167 | 16500 330 330 0.405 1 0:02:29.675360 168 | 16600 332 -1 0.0 0 0:02:29.048663 169 | 16700 334 334 0.978 1 0:02:30.071187 170 | 16800 336 336 0.844 1 0:02:29.878231 171 | 16900 338 617 0.464 0 0:02:30.147896 172 | 17000 340 340 0.978 1 0:02:30.650256 173 | 17100 342 287 0.16 0 0:02:30.627570 174 | 17200 344 344 0.941 1 0:02:30.470581 175 | 17300 346 344 0.0442 0 0:02:29.950092 176 | 17400 348 348 0.797 1 0:02:30.853736 177 | 17500 350 350 0.978 1 0:02:30.505603 178 | 17600 352 352 0.857 1 0:02:30.227803 179 | 17700 354 354 0.978 1 0:02:30.097232 180 | 17800 356 -1 0.0 0 0:02:29.830953 181 | 17900 358 359 0.181 0 0:02:29.647446 182 | 18000 360 360 0.863 1 0:02:30.357880 183 | 18100 362 362 0.978 1 0:02:30.611909 184 | 18200 364 364 0.718 1 0:02:30.395142 185 | 18300 366 366 0.407 1 0:02:30.572165 186 | 18400 368 368 0.955 1 0:02:30.181935 187 | 18500 370 370 0.978 1 0:02:30.173064 188 | 18600 372 372 0.529 1 0:02:30.302947 189 | 18700 374 -1 0.0 0 0:02:30.235269 190 | 18800 376 376 0.772 1 0:02:30.473644 191 | 18900 378 378 0.235 1 0:02:30.372144 192 | 19000 380 380 0.837 1 0:02:30.160681 193 | 19100 382 382 0.978 1 0:02:30.570127 194 | 19200 384 -1 0.0 0 0:02:30.820962 195 | 19300 386 101 0.402 0 0:02:29.788515 196 | 19400 388 388 0.541 1 0:02:29.797744 197 | 19500 390 390 0.775 1 0:02:30.239534 198 | 19600 392 397 0.562 0 0:02:30.206156 199 | 19700 394 -1 0.0 0 0:02:30.558219 200 | 19800 396 396 0.978 1 0:02:30.153110 201 | 19900 398 398 
0.978 1 0:02:30.476912 202 | 20000 400 400 0.186 1 0:02:30.401398 203 | 20100 402 402 0.245 1 0:02:30.091169 204 | 20200 404 404 0.978 1 0:02:30.319886 205 | 20300 406 406 0.607 1 0:02:29.894624 206 | 20400 408 847 0.57 0 0:02:30.479393 207 | 20500 410 410 0.0798 1 0:02:30.255137 208 | 20600 412 412 0.331 1 0:02:30.360466 209 | 20700 414 414 0.39 1 0:02:29.998817 210 | 20800 416 416 0.88 1 0:02:30.508972 211 | 20900 418 563 0.978 0 0:02:30.698495 212 | 21000 420 420 0.242 1 0:02:30.715746 213 | 21100 422 422 0.261 1 0:02:30.667740 214 | 21200 424 -1 0.0 0 0:02:30.284119 215 | 21300 426 426 0.978 1 0:02:30.628197 216 | 21400 428 428 0.978 1 0:02:30.471429 217 | 21500 430 430 0.51 1 0:02:29.978379 218 | 21600 432 432 0.978 1 0:02:29.980470 219 | 21700 434 -1 0.0 0 0:02:30.307042 220 | 21800 436 436 0.978 1 0:02:30.765092 221 | 21900 438 438 0.262 1 0:02:30.637522 222 | 22000 440 737 0.0713 0 0:02:30.458820 223 | 22100 442 442 0.929 1 0:02:30.332701 224 | 22200 444 -1 0.0 0 0:02:30.527043 225 | 22300 446 446 0.929 1 0:02:29.991854 226 | 22400 448 448 0.543 1 0:02:30.389721 227 | 22500 450 407 0.154 0 0:02:29.747789 228 | 22600 452 452 0.913 1 0:02:30.578689 229 | 22700 454 454 0.9 1 0:02:30.596173 230 | 22800 456 -1 0.0 0 0:02:29.819309 231 | 22900 458 458 0.9 1 0:02:29.640807 232 | 23000 460 978 0.12 0 0:02:29.570771 233 | 23100 462 462 0.978 1 0:02:29.782221 234 | 23200 464 439 0.522 0 0:02:30.267272 235 | 23300 466 466 0.978 1 0:02:30.125070 236 | 23400 468 468 0.978 1 0:02:29.601588 237 | 23500 470 624 0.0437 0 0:02:29.220088 238 | 23600 472 472 0.978 1 0:02:30.266124 239 | 23700 474 841 0.14 0 0:02:30.019980 240 | 23800 476 476 0.929 1 0:02:29.595922 241 | 23900 478 478 0.39 1 0:02:29.594913 242 | 24000 480 509 0.509 0 0:02:30.068407 243 | 24100 482 481 0.238 0 0:02:30.223905 244 | 24200 484 871 0.842 0 0:02:30.230005 245 | 24300 486 486 0.104 1 0:02:30.039842 246 | 24400 488 488 0.468 1 0:02:30.693576 247 | 24500 490 490 0.629 1 0:02:30.275327 248 | 24600 492 
492 0.826 1 0:02:30.023680 249 | 24700 494 398 0.941 0 0:02:30.105409 250 | 24800 496 496 0.131 1 0:02:30.163059 251 | 24900 498 498 0.978 1 0:02:30.074041 252 | 25000 500 500 0.978 1 0:02:29.717768 253 | 25100 502 502 0.873 1 0:02:29.803116 254 | 25200 504 504 0.0495 1 0:02:30.441310 255 | 25300 506 506 0.653 1 0:02:30.144799 256 | 25400 508 508 0.978 1 0:02:29.834378 257 | 25500 510 510 0.978 1 0:02:30.429432 258 | 25600 512 740 0.254 0 0:02:29.822186 259 | 25700 514 514 0.978 1 0:02:29.765429 260 | 25800 516 431 0.0175 0 0:02:29.624502 261 | 25900 518 518 0.978 1 0:02:29.319035 262 | 26000 520 520 0.978 1 0:02:29.907698 263 | 26100 522 522 0.978 1 0:02:30.069973 264 | 26200 524 461 0.664 0 0:02:29.917895 265 | 26300 526 526 0.0395 1 0:02:30.180328 266 | 26400 528 -1 0.0 0 0:02:30.060835 267 | 26500 530 531 0.978 0 0:02:30.186140 268 | 26600 532 532 0.142 1 0:02:30.598862 269 | 26700 534 -1 0.0 0 0:02:30.541164 270 | 26800 536 403 0.755 0 0:02:30.070715 271 | 26900 538 538 0.978 1 0:02:30.112502 272 | 27000 540 540 0.978 1 0:02:30.576435 273 | 27100 542 -1 0.0 0 0:02:30.259104 274 | 27200 544 926 0.371 0 0:02:30.055312 275 | 27300 546 546 0.718 1 0:02:30.321178 276 | 27400 548 548 0.9 1 0:02:30.384792 277 | 27500 550 505 0.251 0 0:02:30.494111 278 | 27600 552 552 0.978 1 0:02:30.445498 279 | 27700 554 554 0.802 1 0:02:30.662148 280 | 27800 556 421 0.0586 0 0:02:30.567910 281 | 27900 558 251 0.228 0 0:02:30.566867 282 | 28000 560 768 0.254 0 0:02:30.054083 283 | 28100 562 562 0.978 1 0:02:30.009979 284 | 28200 564 564 0.189 1 0:02:30.454839 285 | 28300 566 566 0.831 1 0:02:30.514242 286 | 28400 568 399 0.163 0 0:02:30.355165 287 | 28500 570 -1 0.0 0 0:02:30.256466 288 | 28600 572 418 0.123 0 0:02:29.945449 289 | 28700 574 574 0.978 1 0:02:29.825920 290 | 28800 576 -1 0.0 0 0:02:30.626047 291 | 28900 578 578 0.457 1 0:02:29.899034 292 | 29000 580 738 0.978 0 0:02:30.362939 293 | 29100 582 -1 0.0 0 0:02:30.624694 294 | 29200 584 754 0.102 0 0:02:30.366516 295 | 
29300 586 586 0.955 1 0:02:30.131823 296 | 29400 588 588 0.681 1 0:02:30.329224 297 | 29500 590 590 0.71 1 0:02:29.902602 298 | 29600 592 592 0.978 1 0:02:30.362917 299 | 29700 594 843 0.378 0 0:02:30.272107 300 | 29800 596 56 0.0281 0 0:02:29.961047 301 | 29900 598 664 0.603 0 0:02:29.538639 302 | 30000 600 517 0.205 0 0:02:29.747167 303 | 30100 602 602 0.978 1 0:02:29.913053 304 | 30200 604 604 0.366 1 0:02:29.841973 305 | 30300 606 606 0.978 1 0:02:30.086977 306 | 30400 608 608 0.441 1 0:02:30.778591 307 | 30500 610 800 0.138 0 0:02:30.530804 308 | 30600 612 612 0.978 1 0:02:29.529376 309 | 30700 614 614 0.805 1 0:02:29.651345 310 | 30800 616 534 0.121 0 0:02:30.312252 311 | 30900 618 -1 0.0 0 0:02:30.016361 312 | 31000 620 681 0.116 0 0:02:30.335139 313 | 31100 622 622 0.978 1 0:02:29.938906 314 | 31200 624 453 0.851 0 0:02:29.743702 315 | 31300 626 542 0.127 0 0:02:30.080677 316 | 31400 628 -1 0.0 0 0:02:29.638628 317 | 31500 630 703 0.672 0 0:02:30.238523 318 | 31600 632 632 0.866 1 0:02:30.075643 319 | 31700 634 489 0.978 0 0:02:30.278016 320 | 31800 636 636 0.616 1 0:02:30.159042 321 | 31900 638 638 0.92 1 0:02:30.447832 322 | 32000 640 640 0.978 1 0:02:30.080730 323 | 32100 642 642 0.978 1 0:02:29.710235 324 | 32200 644 644 0.978 1 0:02:30.165180 325 | 32300 646 646 0.978 1 0:02:29.225623 326 | 32400 648 -1 0.0 0 0:02:30.410912 327 | 32500 650 546 0.0652 0 0:02:30.158323 328 | 32600 652 652 0.553 1 0:02:29.836980 329 | 32700 654 436 0.103 0 0:02:29.863310 330 | 32800 656 -1 0.0 0 0:02:30.554843 331 | 32900 658 658 0.978 1 0:02:30.161191 332 | 33000 660 660 0.437 1 0:02:29.909711 333 | 33100 662 662 0.0149 1 0:02:29.551033 334 | 33200 664 851 0.117 0 0:02:29.760597 335 | 33300 666 659 0.259 0 0:02:29.734256 336 | 33400 668 -1 0.0 0 0:02:30.239919 337 | 33500 670 670 0.978 1 0:02:30.034608 338 | 33600 672 672 0.353 1 0:02:30.049591 339 | 33700 674 674 0.426 1 0:02:30.018736 340 | 33800 676 560 0.0283 0 0:02:30.365235 341 | 33900 678 678 0.978 1 
0:02:30.710799 342 | 34000 680 -1 0.0 0 0:02:29.993227 343 | 34100 682 682 0.299 1 0:02:29.944957 344 | 34200 684 684 0.929 1 0:02:29.973918 345 | 34300 686 -1 0.0 0 0:02:29.893355 346 | 34400 688 688 0.978 1 0:02:29.698450 347 | 34500 690 -1 0.0 0 0:02:29.653637 348 | 34600 692 -1 0.0 0 0:02:29.897104 349 | 34700 694 694 0.978 1 0:02:29.589318 350 | 34800 696 -1 0.0 0 0:02:29.911197 351 | 34900 698 698 0.0859 1 0:02:29.422916 352 | 35000 700 700 0.459 1 0:02:29.795737 353 | 35100 702 422 0.846 0 0:02:30.083227 354 | 35200 704 704 0.731 1 0:02:30.325363 355 | 35300 706 536 0.217 0 0:02:30.080998 356 | 35400 708 682 0.794 0 0:02:30.274108 357 | 35500 710 710 0.366 1 0:02:30.420304 358 | 35600 712 -1 0.0 0 0:02:30.628211 359 | 35700 714 714 0.597 1 0:02:30.303162 360 | 35800 716 716 0.955 1 0:02:29.972523 361 | 35900 718 449 0.212 0 0:02:29.445828 362 | 36000 720 720 0.372 1 0:02:30.154242 363 | 36100 722 722 0.978 1 0:02:30.139431 364 | 36200 724 780 0.978 0 0:02:30.145026 365 | 36300 726 726 0.978 1 0:02:30.133129 366 | 36400 728 728 0.0609 1 0:02:30.022333 367 | 36500 730 856 0.859 0 0:02:30.329501 368 | 36600 732 732 0.978 1 0:02:30.794844 369 | 36700 734 734 0.978 1 0:02:30.110396 370 | 36800 736 736 0.978 1 0:02:30.063162 371 | 36900 738 738 0.617 1 0:02:30.438124 372 | 37000 740 -1 0.0 0 0:02:30.778301 373 | 37100 742 -1 0.0 0 0:02:30.334797 374 | 37200 744 657 0.913 0 0:02:30.138139 375 | 37300 746 746 0.978 1 0:02:29.461292 376 | 37400 748 748 0.978 1 0:02:29.794650 377 | 37500 750 172 0.016 0 0:02:30.684972 378 | 37600 752 -1 0.0 0 0:02:30.102833 379 | 37700 754 848 0.331 0 0:02:29.873319 380 | 37800 756 756 0.978 1 0:02:29.837130 381 | 37900 758 758 0.527 1 0:02:30.398277 382 | 38000 760 651 0.246 0 0:02:30.244316 383 | 38100 762 762 0.455 1 0:02:30.474714 384 | 38200 764 593 0.0208 0 0:02:30.163622 385 | 38300 766 766 0.873 1 0:02:30.341681 386 | 38400 768 768 0.56 1 0:02:30.791192 387 | 38500 770 770 0.208 1 0:02:30.626651 388 | 38600 772 772 0.773 1 
0:02:30.237842 389 | 38700 774 774 0.906 1 0:02:30.539333 390 | 38800 776 683 0.206 0 0:02:30.403928 391 | 38900 778 707 0.273 0 0:02:30.200670 392 | 39000 780 780 0.103 1 0:02:30.389691 393 | 39100 782 664 0.566 0 0:02:30.260540 394 | 39200 784 505 0.0573 0 0:02:30.160525 395 | 39300 786 -1 0.0 0 0:02:30.366705 396 | 39400 788 743 0.101 0 0:02:30.461723 397 | 39500 790 791 0.271 0 0:02:30.898356 398 | 39600 792 795 0.195 0 0:02:30.090714 399 | 39700 794 794 0.831 1 0:02:30.063199 400 | 39800 796 796 0.978 1 0:02:30.189355 401 | 39900 798 798 0.978 1 0:02:30.551748 402 | 40000 800 800 0.978 1 0:02:30.192058 403 | 40100 802 802 0.685 1 0:02:29.962413 404 | 40200 804 804 0.213 1 0:02:29.795746 405 | 40300 806 806 0.978 1 0:02:30.063303 406 | 40400 808 808 0.955 1 0:02:29.804598 407 | 40500 810 878 0.33 0 0:02:29.777533 408 | 40600 812 812 0.978 1 0:02:29.472982 409 | 40700 814 814 0.172 1 0:02:29.290164 410 | 40800 816 816 0.978 1 0:02:30.354582 411 | 40900 818 745 0.523 0 0:02:29.823832 412 | 41000 820 820 0.262 1 0:02:29.811310 413 | 41100 822 822 0.978 1 0:02:30.542749 414 | 41200 824 824 0.696 1 0:02:30.533368 415 | 41300 826 826 0.47 1 0:02:30.509891 416 | 41400 828 618 0.0484 0 0:02:30.638973 417 | 41500 830 830 0.209 1 0:02:29.854785 418 | 41600 832 832 0.443 1 0:02:30.028128 419 | 41700 834 417 0.337 0 0:02:30.216051 420 | 41800 836 444 0.423 0 0:02:30.478095 421 | 41900 838 631 0.92 0 0:02:30.660583 422 | 42000 840 840 0.978 1 0:02:29.944967 423 | 42100 842 842 0.0816 1 0:02:29.631471 424 | 42200 844 844 0.978 1 0:02:30.017588 425 | 42300 846 583 0.192 0 0:02:29.849304 426 | 42400 848 848 0.455 1 0:02:30.021253 427 | 42500 850 850 0.423 1 0:02:30.437045 428 | 42600 852 852 0.978 1 0:02:30.529743 429 | -------------------------------------------------------------------------------- /data/predict/imagenet/resnet50/noise_0.25/test/N_100: -------------------------------------------------------------------------------- 1 | idx label predict correct time 2 | 0 0 0 
1 0:00:01.737771 3 | 100 2 2 1 0:00:00.158496 4 | 200 4 395 0 0:00:00.149257 5 | 300 6 6 1 0:00:00.149002 6 | 400 8 8 1 0:00:00.149901 7 | 500 10 10 1 0:00:00.148993 8 | 600 12 12 1 0:00:00.150077 9 | 700 14 14 1 0:00:00.149649 10 | 800 16 16 1 0:00:00.148875 11 | 900 18 18 1 0:00:00.148872 12 | 1000 20 20 1 0:00:00.148681 13 | 1100 22 21 0 0:00:00.148832 14 | 1200 24 24 1 0:00:00.148842 15 | 1300 26 26 1 0:00:00.149388 16 | 1400 28 28 1 0:00:00.149335 17 | 1500 30 30 1 0:00:00.149273 18 | 1600 32 -1 0 0:00:00.150587 19 | 1700 34 327 0 0:00:00.149775 20 | 1800 36 35 0 0:00:00.149858 21 | 1900 38 38 1 0:00:00.150131 22 | 2000 40 -1 0 0:00:00.149192 23 | 2100 42 48 0 0:00:00.149575 24 | 2200 44 44 1 0:00:00.150103 25 | 2300 46 46 1 0:00:00.150445 26 | 2400 48 48 1 0:00:00.148776 27 | 2500 50 50 1 0:00:00.148915 28 | 2600 52 52 1 0:00:00.148730 29 | 2700 54 57 0 0:00:00.148792 30 | 2800 56 56 1 0:00:00.149241 31 | 2900 58 -1 0 0:00:00.149086 32 | 3000 60 62 0 0:00:00.149118 33 | 3100 62 62 1 0:00:00.149509 34 | 3200 64 64 1 0:00:00.150243 35 | 3300 66 63 0 0:00:00.149570 36 | 3400 68 -1 0 0:00:00.149943 37 | 3500 70 70 1 0:00:00.149646 38 | 3600 72 72 1 0:00:00.150097 39 | 3700 74 74 1 0:00:00.150264 40 | 3800 76 76 1 0:00:00.149668 41 | 3900 78 78 1 0:00:00.148759 42 | 4000 80 80 1 0:00:00.148977 43 | 4100 82 82 1 0:00:00.149063 44 | 4200 84 84 1 0:00:00.149432 45 | 4300 86 86 1 0:00:00.149689 46 | 4400 88 88 1 0:00:00.149581 47 | 4500 90 90 1 0:00:00.149732 48 | 4600 92 92 1 0:00:00.150014 49 | 4700 94 94 1 0:00:00.149836 50 | 4800 96 96 1 0:00:00.149859 51 | 4900 98 98 1 0:00:00.149908 52 | 5000 100 100 1 0:00:00.150221 53 | 5100 102 102 1 0:00:00.150292 54 | 5200 104 342 0 0:00:00.149715 55 | 5300 106 106 1 0:00:00.149746 56 | 5400 108 470 0 0:00:00.149747 57 | 5500 110 110 1 0:00:00.149787 58 | 5600 112 112 1 0:00:00.150302 59 | 5700 114 113 0 0:00:00.149334 60 | 5800 116 116 1 0:00:00.149520 61 | 5900 118 118 1 0:00:00.149666 62 | 6000 120 119 0 0:00:00.149083 
63 | 6100 122 122 1 0:00:00.149483 64 | 6200 124 123 0 0:00:00.150204 65 | 6300 126 126 1 0:00:00.149876 66 | 6400 128 128 1 0:00:00.150150 67 | 6500 130 130 1 0:00:00.149888 68 | 6600 132 132 1 0:00:00.148795 69 | 6700 134 132 0 0:00:00.149421 70 | 6800 136 136 1 0:00:00.149319 71 | 6900 138 138 1 0:00:00.149000 72 | 7000 140 140 1 0:00:00.149684 73 | 7100 142 142 1 0:00:00.149306 74 | 7200 144 144 1 0:00:00.149648 75 | 7300 146 146 1 0:00:00.149898 76 | 7400 148 148 1 0:00:00.150061 77 | 7500 150 360 0 0:00:00.150225 78 | 7600 152 152 1 0:00:00.150093 79 | 7700 154 154 1 0:00:00.150203 80 | 7800 156 156 1 0:00:00.150502 81 | 7900 158 158 1 0:00:00.150271 82 | 8000 160 160 1 0:00:00.150902 83 | 8100 162 162 1 0:00:00.150535 84 | 8200 164 164 1 0:00:00.150811 85 | 8300 166 166 1 0:00:00.150555 86 | 8400 168 211 0 0:00:00.149288 87 | 8500 170 172 0 0:00:00.149561 88 | 8600 172 172 1 0:00:00.149710 89 | 8700 174 174 1 0:00:00.149799 90 | 8800 176 176 1 0:00:00.149869 91 | 8900 178 178 1 0:00:00.149695 92 | 9000 180 243 0 0:00:00.150275 93 | 9100 182 182 1 0:00:00.150397 94 | 9200 184 -1 0 0:00:00.150378 95 | 9300 186 192 0 0:00:00.150749 96 | 9400 188 188 1 0:00:00.150607 97 | 9500 190 190 1 0:00:00.150677 98 | 9600 192 199 0 0:00:00.150700 99 | 9700 194 202 0 0:00:00.150573 100 | 9800 196 199 0 0:00:00.150710 101 | 9900 198 -1 0 0:00:00.150579 102 | 10000 200 200 1 0:00:00.149348 103 | 10100 202 202 1 0:00:00.150144 104 | 10200 204 204 1 0:00:00.149713 105 | 10300 206 206 1 0:00:00.149793 106 | 10400 208 208 1 0:00:00.149739 107 | 10500 210 210 1 0:00:00.150280 108 | 10600 212 212 1 0:00:00.150422 109 | 10700 214 214 1 0:00:00.150694 110 | 10800 216 216 1 0:00:00.150631 111 | 10900 218 218 1 0:00:00.150435 112 | 11000 220 220 1 0:00:00.151002 113 | 11100 222 -1 0 0:00:00.150712 114 | 11200 224 233 0 0:00:00.151255 115 | 11300 226 226 1 0:00:00.150797 116 | 11400 228 228 1 0:00:00.150893 117 | 11500 230 230 1 0:00:00.150637 118 | 11600 232 232 1 0:00:00.151438 119 | 
11700 234 234 1 0:00:00.150964 120 | 11800 236 234 0 0:00:00.149482 121 | 11900 238 238 1 0:00:00.150126 122 | 12000 240 238 0 0:00:00.149885 123 | 12100 242 242 1 0:00:00.149924 124 | 12200 244 244 1 0:00:00.149931 125 | 12300 246 251 0 0:00:00.150398 126 | 12400 248 -1 0 0:00:00.150118 127 | 12500 250 169 0 0:00:00.150278 128 | 12600 252 252 1 0:00:00.151481 129 | 12700 254 254 1 0:00:00.150805 130 | 12800 256 256 1 0:00:00.150324 131 | 12900 258 258 1 0:00:00.150588 132 | 13000 260 260 1 0:00:00.151438 133 | 13100 262 262 1 0:00:00.150877 134 | 13200 264 264 1 0:00:00.151444 135 | 13300 266 266 1 0:00:00.150706 136 | 13400 268 268 1 0:00:00.150965 137 | 13500 270 270 1 0:00:00.150461 138 | 13600 272 274 0 0:00:00.149865 139 | 13700 274 274 1 0:00:00.150135 140 | 13800 276 276 1 0:00:00.150259 141 | 13900 278 278 1 0:00:00.149917 142 | 14000 280 280 1 0:00:00.150193 143 | 14100 282 282 1 0:00:00.151167 144 | 14200 284 284 1 0:00:00.151545 145 | 14300 286 286 1 0:00:00.151211 146 | 14400 288 288 1 0:00:00.151570 147 | 14500 290 290 1 0:00:00.151209 148 | 14600 292 290 0 0:00:00.150754 149 | 14700 294 294 1 0:00:00.150950 150 | 14800 296 296 1 0:00:00.151226 151 | 14900 298 298 1 0:00:00.150931 152 | 15000 300 300 1 0:00:00.151120 153 | 15100 302 -1 0 0:00:00.151035 154 | 15200 304 301 0 0:00:00.149800 155 | 15300 306 306 1 0:00:00.150287 156 | 15400 308 308 1 0:00:00.150084 157 | 15500 310 310 1 0:00:00.150211 158 | 15600 312 -1 0 0:00:00.150407 159 | 15700 314 314 1 0:00:00.150743 160 | 15800 316 316 1 0:00:00.151449 161 | 15900 318 318 1 0:00:00.151332 162 | 16000 320 320 1 0:00:00.151387 163 | 16100 322 322 1 0:00:00.151472 164 | 16200 324 324 1 0:00:00.150631 165 | 16300 326 326 1 0:00:00.151067 166 | 16400 328 328 1 0:00:00.150776 167 | 16500 330 330 1 0:00:00.150908 168 | 16600 332 -1 0 0:00:00.150953 169 | 16700 334 334 1 0:00:00.150952 170 | 16800 336 336 1 0:00:00.151640 171 | 16900 338 617 0 0:00:00.151710 172 | 17000 340 340 1 0:00:00.149796 173 | 17100 
342 287 0 0:00:00.150234 174 | 17200 344 344 1 0:00:00.150386 175 | 17300 346 344 0 0:00:00.150476 176 | 17400 348 348 1 0:00:00.150613 177 | 17500 350 350 1 0:00:00.150906 178 | 17600 352 352 1 0:00:00.150985 179 | 17700 354 354 1 0:00:00.151042 180 | 17800 356 -1 0 0:00:00.151376 181 | 17900 358 359 0 0:00:00.151409 182 | 18000 360 360 1 0:00:00.151763 183 | 18100 362 362 1 0:00:00.151103 184 | 18200 364 364 1 0:00:00.151692 185 | 18300 366 366 1 0:00:00.151635 186 | 18400 368 368 1 0:00:00.151221 187 | 18500 370 370 1 0:00:00.152012 188 | 18600 372 372 1 0:00:00.151617 189 | 18700 374 -1 0 0:00:00.151739 190 | 18800 376 376 1 0:00:00.150095 191 | 18900 378 378 1 0:00:00.150199 192 | 19000 380 380 1 0:00:00.150291 193 | 19100 382 382 1 0:00:00.150382 194 | 19200 384 -1 0 0:00:00.150969 195 | 19300 386 101 0 0:00:00.151030 196 | 19400 388 388 1 0:00:00.151301 197 | 19500 390 390 1 0:00:00.151015 198 | 19600 392 397 0 0:00:00.151571 199 | 19700 394 -1 0 0:00:00.151451 200 | 19800 396 396 1 0:00:00.151190 201 | 19900 398 398 1 0:00:00.151696 202 | 20000 400 400 1 0:00:00.151555 203 | 20100 402 402 1 0:00:00.151069 204 | 20200 404 404 1 0:00:00.152052 205 | 20300 406 406 1 0:00:00.151742 206 | 20400 408 847 0 0:00:00.151591 207 | 20500 410 410 1 0:00:00.151356 208 | 20600 412 412 1 0:00:00.151769 209 | 20700 414 414 1 0:00:00.151783 210 | 20800 416 416 1 0:00:00.151651 211 | 20900 418 563 0 0:00:00.151843 212 | 21000 420 420 1 0:00:00.151830 213 | 21100 422 422 1 0:00:00.151571 214 | 21200 424 -1 0 0:00:00.151089 215 | 21300 426 426 1 0:00:00.151694 216 | 21400 428 428 1 0:00:00.151853 217 | 21500 430 430 1 0:00:00.151803 218 | 21600 432 432 1 0:00:00.150823 219 | 21700 434 -1 0 0:00:00.150586 220 | 21800 436 436 1 0:00:00.150602 221 | 21900 438 438 1 0:00:00.150879 222 | 22000 440 737 0 0:00:00.151455 223 | 22100 442 442 1 0:00:00.151263 224 | 22200 444 870 0 0:00:00.151276 225 | 22300 446 446 1 0:00:00.151532 226 | 22400 448 448 1 0:00:00.151578 227 | 22500 450 407 
0 0:00:00.151591 228 | 22600 452 452 1 0:00:00.151208 229 | 22700 454 454 1 0:00:00.152079 230 | 22800 456 -1 0 0:00:00.151654 231 | 22900 458 458 1 0:00:00.151926 232 | 23000 460 978 0 0:00:00.151490 233 | 23100 462 462 1 0:00:00.152122 234 | 23200 464 439 0 0:00:00.151895 235 | 23300 466 466 1 0:00:00.151898 236 | 23400 468 468 1 0:00:00.150604 237 | 23500 470 624 0 0:00:00.150587 238 | 23600 472 472 1 0:00:00.150816 239 | 23700 474 841 0 0:00:00.150733 240 | 23800 476 476 1 0:00:00.151181 241 | 23900 478 478 1 0:00:00.151568 242 | 24000 480 509 0 0:00:00.151719 243 | 24100 482 481 0 0:00:00.151719 244 | 24200 484 871 0 0:00:00.152006 245 | 24300 486 486 1 0:00:00.151621 246 | 24400 488 488 1 0:00:00.152247 247 | 24500 490 490 1 0:00:00.152318 248 | 24600 492 492 1 0:00:00.151755 249 | 24700 494 398 0 0:00:00.152137 250 | 24800 496 496 1 0:00:00.152608 251 | 24900 498 498 1 0:00:00.151824 252 | 25000 500 500 1 0:00:00.151747 253 | 25100 502 502 1 0:00:00.152223 254 | 25200 504 -1 0 0:00:00.152050 255 | 25300 506 506 1 0:00:00.151769 256 | 25400 508 508 1 0:00:00.152360 257 | 25500 510 510 1 0:00:00.150740 258 | 25600 512 740 0 0:00:00.150476 259 | 25700 514 514 1 0:00:00.151328 260 | 25800 516 431 0 0:00:00.150751 261 | 25900 518 518 1 0:00:00.150894 262 | 26000 520 520 1 0:00:00.151777 263 | 26100 522 522 1 0:00:00.151428 264 | 26200 524 461 0 0:00:00.151602 265 | 26300 526 -1 0 0:00:00.151999 266 | 26400 528 -1 0 0:00:00.152343 267 | 26500 530 531 0 0:00:00.151699 268 | 26600 532 532 1 0:00:00.152420 269 | 26700 534 534 1 0:00:00.152150 270 | 26800 536 403 0 0:00:00.152110 271 | 26900 538 538 1 0:00:00.151867 272 | 27000 540 540 1 0:00:00.151880 273 | 27100 542 -1 0 0:00:00.151522 274 | 27200 544 926 0 0:00:00.151950 275 | 27300 546 546 1 0:00:00.152192 276 | 27400 548 548 1 0:00:00.151759 277 | 27500 550 505 0 0:00:00.151389 278 | 27600 552 552 1 0:00:00.151038 279 | 27700 554 554 1 0:00:00.151238 280 | 27800 556 421 0 0:00:00.152079 281 | 27900 558 251 0 
0:00:00.151511 282 | 28000 560 768 0 0:00:00.151679 283 | 28100 562 562 1 0:00:00.151944 284 | 28200 564 564 1 0:00:00.151967 285 | 28300 566 566 1 0:00:00.152348 286 | 28400 568 399 0 0:00:00.152331 287 | 28500 570 -1 0 0:00:00.152074 288 | 28600 572 418 0 0:00:00.152075 289 | 28700 574 574 1 0:00:00.152514 290 | 28800 576 -1 0 0:00:00.152667 291 | 28900 578 578 1 0:00:00.151901 292 | 29000 580 738 0 0:00:00.152458 293 | 29100 582 -1 0 0:00:00.151789 294 | 29200 584 754 0 0:00:00.152026 295 | 29300 586 586 1 0:00:00.152489 296 | 29400 588 588 1 0:00:00.152318 297 | 29500 590 590 1 0:00:00.151197 298 | 29600 592 592 1 0:00:00.151341 299 | 29700 594 843 0 0:00:00.151311 300 | 29800 596 56 0 0:00:00.151415 301 | 29900 598 664 0 0:00:00.151570 302 | 30000 600 517 0 0:00:00.151897 303 | 30100 602 602 1 0:00:00.152012 304 | 30200 604 604 1 0:00:00.152546 305 | 30300 606 606 1 0:00:00.152328 306 | 30400 608 608 1 0:00:00.152009 307 | 30500 610 800 0 0:00:00.152771 308 | 30600 612 612 1 0:00:00.152460 309 | 30700 614 614 1 0:00:00.152576 310 | 30800 616 534 0 0:00:00.152400 311 | 30900 618 -1 0 0:00:00.152495 312 | 31000 620 681 0 0:00:00.152324 313 | 31100 622 622 1 0:00:00.151970 314 | 31200 624 453 0 0:00:00.152316 315 | 31300 626 -1 0 0:00:00.151186 316 | 31400 628 624 0 0:00:00.151331 317 | 31500 630 703 0 0:00:00.151542 318 | 31600 632 632 1 0:00:00.151527 319 | 31700 634 489 0 0:00:00.151511 320 | 31800 636 636 1 0:00:00.151963 321 | 31900 638 638 1 0:00:00.151848 322 | 32000 640 640 1 0:00:00.152098 323 | 32100 642 642 1 0:00:00.152818 324 | 32200 644 644 1 0:00:00.152130 325 | 32300 646 646 1 0:00:00.152788 326 | 32400 648 -1 0 0:00:00.152037 327 | 32500 650 546 0 0:00:00.152443 328 | 32600 652 652 1 0:00:00.151977 329 | 32700 654 436 0 0:00:00.152806 330 | 32800 656 -1 0 0:00:00.152032 331 | 32900 658 658 1 0:00:00.152533 332 | 33000 660 660 1 0:00:00.152316 333 | 33100 662 662 1 0:00:00.151066 334 | 33200 664 851 0 0:00:00.151540 335 | 33300 666 659 0 
0:00:00.151820 336 | 33400 668 -1 0 0:00:00.151236 337 | 33500 670 670 1 0:00:00.151784 338 | 33600 672 672 1 0:00:00.151690 339 | 33700 674 674 1 0:00:00.152147 340 | 33800 676 560 0 0:00:00.152476 341 | 33900 678 678 1 0:00:00.152597 342 | 34000 680 -1 0 0:00:00.152415 343 | 34100 682 682 1 0:00:00.152823 344 | 34200 684 684 1 0:00:00.153091 345 | 34300 686 -1 0 0:00:00.152630 346 | 34400 688 688 1 0:00:00.152025 347 | 34500 690 -1 0 0:00:00.152708 348 | 34600 692 -1 0 0:00:00.153099 349 | 34700 694 694 1 0:00:00.152318 350 | 34800 696 -1 0 0:00:00.152753 351 | 34900 698 698 1 0:00:00.152175 352 | 35000 700 700 1 0:00:00.152845 353 | 35100 702 422 0 0:00:00.153542 354 | 35200 704 704 1 0:00:00.151337 355 | 35300 706 536 0 0:00:00.151342 356 | 35400 708 682 0 0:00:00.151542 357 | 35500 710 710 1 0:00:00.151536 358 | 35600 712 -1 0 0:00:00.152408 359 | 35700 714 714 1 0:00:00.151919 360 | 35800 716 716 1 0:00:00.152362 361 | 35900 718 449 0 0:00:00.152079 362 | 36000 720 720 1 0:00:00.152772 363 | 36100 722 722 1 0:00:00.152386 364 | 36200 724 780 0 0:00:00.152625 365 | 36300 726 726 1 0:00:00.152840 366 | 36400 728 728 1 0:00:00.152847 367 | 36500 730 856 0 0:00:00.152109 368 | 36600 732 732 1 0:00:00.152683 369 | 36700 734 734 1 0:00:00.152587 370 | 36800 736 736 1 0:00:00.153082 371 | 36900 738 738 1 0:00:00.152080 372 | 37000 740 -1 0 0:00:00.152445 373 | 37100 742 466 0 0:00:00.153498 374 | 37200 744 657 0 0:00:00.152382 375 | 37300 746 746 1 0:00:00.152394 376 | 37400 748 748 1 0:00:00.151623 377 | 37500 750 -1 0 0:00:00.151844 378 | 37600 752 -1 0 0:00:00.151828 379 | 37700 754 848 0 0:00:00.151490 380 | 37800 756 756 1 0:00:00.151816 381 | 37900 758 758 1 0:00:00.152165 382 | 38000 760 651 0 0:00:00.152320 383 | 38100 762 762 1 0:00:00.152745 384 | 38200 764 -1 0 0:00:00.152301 385 | 38300 766 766 1 0:00:00.152556 386 | 38400 768 768 1 0:00:00.152708 387 | 38500 770 770 1 0:00:00.152466 388 | 38600 772 772 1 0:00:00.153146 389 | 38700 774 774 1 
0:00:00.152727 390 | 38800 776 683 0 0:00:00.153644 391 | 38900 778 707 0 0:00:00.152719 392 | 39000 780 780 1 0:00:00.152416 393 | 39100 782 664 0 0:00:00.152261 394 | 39200 784 505 0 0:00:00.151717 395 | 39300 786 -1 0 0:00:00.153320 396 | 39400 788 743 0 0:00:00.151936 397 | 39500 790 791 0 0:00:00.151433 398 | 39600 792 795 0 0:00:00.151931 399 | 39700 794 794 1 0:00:00.152070 400 | 39800 796 796 1 0:00:00.151815 401 | 39900 798 798 1 0:00:00.152158 402 | 40000 800 800 1 0:00:00.152520 403 | 40100 802 802 1 0:00:00.152387 404 | 40200 804 804 1 0:00:00.152754 405 | 40300 806 806 1 0:00:00.152673 406 | 40400 808 808 1 0:00:00.152710 407 | 40500 810 878 0 0:00:00.152405 408 | 40600 812 812 1 0:00:00.152728 409 | 40700 814 814 1 0:00:00.152647 410 | 40800 816 816 1 0:00:00.153116 411 | 40900 818 745 0 0:00:00.153242 412 | 41000 820 820 1 0:00:00.153136 413 | 41100 822 822 1 0:00:00.152745 414 | 41200 824 824 1 0:00:00.153423 415 | 41300 826 826 1 0:00:00.151895 416 | 41400 828 618 0 0:00:00.151875 417 | 41500 830 830 1 0:00:00.152070 418 | 41600 832 832 1 0:00:00.151996 419 | 41700 834 417 0 0:00:00.151796 420 | 41800 836 444 0 0:00:00.152159 421 | 41900 838 631 0 0:00:00.152218 422 | 42000 840 840 1 0:00:00.152527 423 | 42100 842 842 1 0:00:00.152467 424 | 42200 844 844 1 0:00:00.152677 425 | 42300 846 583 0 0:00:00.152786 426 | 42400 848 848 1 0:00:00.152760 427 | 42500 850 850 1 0:00:00.152573 428 | 42600 852 852 1 0:00:00.152428 429 | 42700 854 854 1 0:00:00.152375 430 | 42800 856 856 1 0:00:00.153162 431 | 42900 858 832 0 0:00:00.153394 432 | 43000 860 892 0 0:00:00.152900 433 | 43100 862 862 1 0:00:00.152719 434 | 43200 864 864 1 0:00:00.153751 435 | 43300 866 866 1 0:00:00.152864 436 | 43400 868 968 0 0:00:00.153205 437 | 43500 870 870 1 0:00:00.153201 438 | 43600 872 840 0 0:00:00.153276 439 | 43700 874 874 1 0:00:00.151293 440 | 43800 876 648 0 0:00:00.151649 441 | 43900 878 878 1 0:00:00.151755 442 | 44000 880 870 0 0:00:00.152305 443 | 44100 882 -1 0 
0:00:00.152210 444 | 44200 884 884 1 0:00:00.152426 445 | 44300 886 886 1 0:00:00.152342 446 | 44400 888 718 0 0:00:00.152582 447 | 44500 890 890 1 0:00:00.153204 448 | 44600 892 442 0 0:00:00.153057 449 | 44700 894 894 1 0:00:00.152637 450 | 44800 896 896 1 0:00:00.152978 451 | 44900 898 508 0 0:00:00.152954 452 | 45000 900 900 1 0:00:00.153266 453 | 45100 902 902 1 0:00:00.152574 454 | 45200 904 905 0 0:00:00.153147 455 | 45300 906 906 1 0:00:00.152289 456 | 45400 908 895 0 0:00:00.153968 457 | 45500 910 -1 0 0:00:00.153294 458 | 45600 912 912 1 0:00:00.153271 459 | 45700 914 914 1 0:00:00.153157 460 | 45800 916 916 1 0:00:00.153229 461 | 45900 918 918 1 0:00:00.151889 462 | 46000 920 672 0 0:00:00.151853 463 | 46100 922 922 1 0:00:00.151801 464 | 46200 924 924 1 0:00:00.151895 465 | 46300 926 926 1 0:00:00.152917 466 | 46400 928 927 0 0:00:00.152268 467 | 46500 930 -1 0 0:00:00.152658 468 | 46600 932 932 1 0:00:00.152935 469 | 46700 934 934 1 0:00:00.152805 470 | 46800 936 936 1 0:00:00.153057 471 | 46900 938 938 1 0:00:00.153356 472 | 47000 940 940 1 0:00:00.152951 473 | 47100 942 942 1 0:00:00.152692 474 | 47200 944 946 0 0:00:00.152845 475 | 47300 946 946 1 0:00:00.152305 476 | 47400 948 948 1 0:00:00.153133 477 | 47500 950 950 1 0:00:00.152508 478 | 47600 952 90 0 0:00:00.153198 479 | 47700 954 -1 0 0:00:00.153289 480 | 47800 956 956 1 0:00:00.153317 481 | 47900 958 483 0 0:00:00.153305 482 | 48000 960 960 1 0:00:00.151798 483 | 48100 962 962 1 0:00:00.153157 484 | 48200 964 961 0 0:00:00.152369 485 | 48300 966 630 0 0:00:00.151834 486 | 48400 968 968 1 0:00:00.152382 487 | 48500 970 795 0 0:00:00.152050 488 | 48600 972 979 0 0:00:00.151806 489 | 48700 974 974 1 0:00:00.152487 490 | 48800 976 -1 0 0:00:00.152602 491 | 48900 978 701 0 0:00:00.152845 492 | 49000 980 980 1 0:00:00.153392 493 | 49100 982 982 1 0:00:00.153532 494 | 49200 984 984 1 0:00:00.153388 495 | 49300 986 986 1 0:00:00.153524 496 | 49400 988 988 1 0:00:00.153043 497 | 49500 990 990 1 
0:00:00.153023 498 | 49600 992 992 1 0:00:00.153144 499 | 49700 994 994 1 0:00:00.153317 500 | 49800 996 996 1 0:00:00.153331 501 | 49900 998 998 1 0:00:00.153559 502 | -------------------------------------------------------------------------------- /data/predict/imagenet/resnet50/noise_0.25/test/N_1000: -------------------------------------------------------------------------------- 1 | idx label predict correct time 2 | 0 0 0 1 0:00:06.513843 3 | 100 2 2 1 0:00:01.455222 4 | 200 4 395 0 0:00:01.448558 5 | 300 6 6 1 0:00:01.448352 6 | 400 8 8 1 0:00:01.450362 7 | 500 10 10 1 0:00:01.449924 8 | 600 12 12 1 0:00:01.452259 9 | 700 14 14 1 0:00:01.453799 10 | 800 16 16 1 0:00:01.455621 11 | 900 18 18 1 0:00:01.454016 12 | 1000 20 20 1 0:00:01.457207 13 | 1100 22 21 0 0:00:01.454456 14 | 1200 24 24 1 0:00:01.457761 15 | 1300 26 26 1 0:00:01.459710 16 | 1400 28 28 1 0:00:01.457183 17 | 1500 30 30 1 0:00:01.458956 18 | 1600 32 103 0 0:00:01.459355 19 | 1700 34 327 0 0:00:01.462396 20 | 1800 36 35 0 0:00:01.460968 21 | 1900 38 38 1 0:00:01.462716 22 | 2000 40 31 0 0:00:01.464007 23 | 2100 42 48 0 0:00:01.463212 24 | 2200 44 44 1 0:00:01.464057 25 | 2300 46 46 1 0:00:01.463487 26 | 2400 48 48 1 0:00:01.466431 27 | 2500 50 50 1 0:00:01.464084 28 | 2600 52 52 1 0:00:01.465358 29 | 2700 54 57 0 0:00:01.465160 30 | 2800 56 56 1 0:00:01.471741 31 | 2900 58 58 1 0:00:01.469234 32 | 3000 60 62 0 0:00:01.469421 33 | 3100 62 62 1 0:00:01.467561 34 | 3200 64 64 1 0:00:01.468968 35 | 3300 66 63 0 0:00:01.469123 36 | 3400 68 -1 0 0:00:01.470740 37 | 3500 70 70 1 0:00:01.473675 38 | 3600 72 72 1 0:00:01.472703 39 | 3700 74 74 1 0:00:01.473779 40 | 3800 76 76 1 0:00:01.469706 41 | 3900 78 78 1 0:00:01.471689 42 | 4000 80 80 1 0:00:01.474229 43 | 4100 82 82 1 0:00:01.475134 44 | 4200 84 84 1 0:00:01.475369 45 | 4300 86 86 1 0:00:01.476813 46 | 4400 88 88 1 0:00:01.477425 47 | 4500 90 90 1 0:00:01.479196 48 | 4600 92 92 1 0:00:01.496022 49 | 4700 94 94 1 0:00:01.477135 50 | 4800 96 96 1 
0:00:01.480331 51 | 4900 98 98 1 0:00:01.495252 52 | 5000 100 100 1 0:00:01.479198 53 | 5100 102 102 1 0:00:01.491304 54 | 5200 104 342 0 0:00:01.555276 55 | 5300 106 106 1 0:00:01.495784 56 | 5400 108 470 0 0:00:01.479194 57 | 5500 110 110 1 0:00:01.476201 58 | 5600 112 112 1 0:00:01.483646 59 | 5700 114 113 0 0:00:01.485140 60 | 5800 116 116 1 0:00:01.539842 61 | 5900 118 118 1 0:00:01.550041 62 | 6000 120 119 0 0:00:01.498474 63 | 6100 122 122 1 0:00:01.477521 64 | 6200 124 123 0 0:00:01.484329 65 | 6300 126 126 1 0:00:01.484162 66 | 6400 128 128 1 0:00:01.481536 67 | 6500 130 130 1 0:00:01.483169 68 | 6600 132 132 1 0:00:01.503742 69 | 6700 134 132 0 0:00:01.518725 70 | 6800 136 136 1 0:00:01.539357 71 | 6900 138 138 1 0:00:01.498044 72 | 7000 140 140 1 0:00:01.480387 73 | 7100 142 142 1 0:00:01.483995 74 | 7200 144 144 1 0:00:01.486113 75 | 7300 146 146 1 0:00:01.488695 76 | 7400 148 148 1 0:00:01.489084 77 | 7500 150 360 0 0:00:01.483568 78 | 7600 152 152 1 0:00:01.479528 79 | 7700 154 154 1 0:00:01.478619 80 | 7800 156 156 1 0:00:01.537528 81 | 7900 158 158 1 0:00:01.496571 82 | 8000 160 160 1 0:00:01.498672 83 | 8100 162 162 1 0:00:01.481299 84 | 8200 164 164 1 0:00:01.496998 85 | 8300 166 166 1 0:00:01.490325 86 | 8400 168 211 0 0:00:01.503927 87 | 8500 170 172 0 0:00:01.536041 88 | 8600 172 172 1 0:00:01.500019 89 | 8700 174 174 1 0:00:01.520921 90 | 8800 176 176 1 0:00:01.487111 91 | 8900 178 178 1 0:00:01.488966 92 | 9000 180 243 0 0:00:01.550351 93 | 9100 182 182 1 0:00:01.526359 94 | 9200 184 -1 0 0:00:01.557193 95 | 9300 186 192 0 0:00:01.499761 96 | 9400 188 188 1 0:00:01.524055 97 | 9500 190 190 1 0:00:01.483822 98 | 9600 192 199 0 0:00:01.479605 99 | 9700 194 202 0 0:00:01.554249 100 | 9800 196 199 0 0:00:01.496503 101 | 9900 198 196 0 0:00:01.492259 102 | 10000 200 200 1 0:00:01.479200 103 | 10100 202 202 1 0:00:01.482847 104 | 10200 204 204 1 0:00:01.479599 105 | 10300 206 206 1 0:00:01.520266 106 | 10400 208 208 1 0:00:01.481036 107 | 10500 210 
210 1 0:00:01.489356 108 | 10600 212 212 1 0:00:01.555681 109 | 10700 214 214 1 0:00:01.491593 110 | 10800 216 216 1 0:00:01.484073 111 | 10900 218 218 1 0:00:01.485102 112 | 11000 220 220 1 0:00:01.479046 113 | 11100 222 222 1 0:00:01.498316 114 | 11200 224 233 0 0:00:01.481419 115 | 11300 226 226 1 0:00:01.489648 116 | 11400 228 228 1 0:00:01.475325 117 | 11500 230 230 1 0:00:01.479492 118 | 11600 232 232 1 0:00:01.485672 119 | 11700 234 234 1 0:00:01.478974 120 | 11800 236 234 0 0:00:01.532255 121 | 11900 238 238 1 0:00:01.501435 122 | 12000 240 238 0 0:00:01.478093 123 | 12100 242 242 1 0:00:01.487937 124 | 12200 244 244 1 0:00:01.555950 125 | 12300 246 251 0 0:00:01.497940 126 | 12400 248 248 1 0:00:01.484743 127 | 12500 250 169 0 0:00:01.480882 128 | 12600 252 252 1 0:00:01.485409 129 | 12700 254 254 1 0:00:01.482774 130 | 12800 256 256 1 0:00:01.554899 131 | 12900 258 258 1 0:00:01.504591 132 | 13000 260 260 1 0:00:01.565463 133 | 13100 262 262 1 0:00:01.493973 134 | 13200 264 264 1 0:00:01.525658 135 | 13300 266 266 1 0:00:01.486248 136 | 13400 268 268 1 0:00:01.481760 137 | 13500 270 270 1 0:00:01.481478 138 | 13600 272 274 0 0:00:01.479039 139 | 13700 274 274 1 0:00:01.475843 140 | 13800 276 276 1 0:00:01.475403 141 | 13900 278 278 1 0:00:01.478996 142 | 14000 280 280 1 0:00:01.481032 143 | 14100 282 282 1 0:00:01.483646 144 | 14200 284 284 1 0:00:01.479185 145 | 14300 286 286 1 0:00:01.492340 146 | 14400 288 288 1 0:00:01.487615 147 | 14500 290 290 1 0:00:01.489121 148 | 14600 292 290 0 0:00:01.480814 149 | 14700 294 294 1 0:00:01.490083 150 | 14800 296 296 1 0:00:01.488762 151 | 14900 298 298 1 0:00:01.481069 152 | 15000 300 300 1 0:00:01.520467 153 | 15100 302 302 1 0:00:01.482847 154 | 15200 304 301 0 0:00:01.492529 155 | 15300 306 306 1 0:00:01.558305 156 | 15400 308 308 1 0:00:01.488705 157 | 15500 310 310 1 0:00:01.480214 158 | 15600 312 315 0 0:00:01.496566 159 | 15700 314 314 1 0:00:01.540013 160 | 15800 316 316 1 0:00:01.516623 161 | 15900 318 
318 1 0:00:01.477651 162 | 16000 320 320 1 0:00:01.477822 163 | 16100 322 322 1 0:00:01.483212 164 | 16200 324 324 1 0:00:01.489052 165 | 16300 326 326 1 0:00:01.480788 166 | 16400 328 328 1 0:00:01.482190 167 | 16500 330 330 1 0:00:01.491863 168 | 16600 332 153 0 0:00:01.478761 169 | 16700 334 334 1 0:00:01.515207 170 | 16800 336 336 1 0:00:01.554691 171 | 16900 338 617 0 0:00:01.541147 172 | 17000 340 340 1 0:00:01.489892 173 | 17100 342 287 0 0:00:01.481903 174 | 17200 344 344 1 0:00:01.502365 175 | 17300 346 344 0 0:00:01.534523 176 | 17400 348 348 1 0:00:01.505174 177 | 17500 350 350 1 0:00:01.502758 178 | 17600 352 352 1 0:00:01.481927 179 | 17700 354 354 1 0:00:01.484131 180 | 17800 356 -1 0 0:00:01.481266 181 | 17900 358 359 0 0:00:01.480267 182 | 18000 360 360 1 0:00:01.478824 183 | 18100 362 362 1 0:00:01.518305 184 | 18200 364 364 1 0:00:01.560162 185 | 18300 366 366 1 0:00:01.496499 186 | 18400 368 368 1 0:00:01.482130 187 | 18500 370 370 1 0:00:01.477426 188 | 18600 372 372 1 0:00:01.481619 189 | 18700 374 374 1 0:00:01.488400 190 | 18800 376 376 1 0:00:01.496681 191 | 18900 378 378 1 0:00:01.485233 192 | 19000 380 380 1 0:00:01.531544 193 | 19100 382 382 1 0:00:01.499931 194 | 19200 384 283 0 0:00:01.480609 195 | 19300 386 101 0 0:00:01.477893 196 | 19400 388 388 1 0:00:01.481483 197 | 19500 390 390 1 0:00:01.485835 198 | 19600 392 397 0 0:00:01.494839 199 | 19700 394 467 0 0:00:01.479918 200 | 19800 396 396 1 0:00:01.486316 201 | 19900 398 398 1 0:00:01.479316 202 | 20000 400 400 1 0:00:01.482861 203 | 20100 402 402 1 0:00:01.481416 204 | 20200 404 404 1 0:00:01.500082 205 | 20300 406 406 1 0:00:01.481807 206 | 20400 408 847 0 0:00:01.477773 207 | 20500 410 410 1 0:00:01.550446 208 | 20600 412 412 1 0:00:01.495554 209 | 20700 414 414 1 0:00:01.490323 210 | 20800 416 416 1 0:00:01.479424 211 | 20900 418 563 0 0:00:01.479409 212 | 21000 420 420 1 0:00:01.480183 213 | 21100 422 422 1 0:00:01.522102 214 | 21200 424 454 0 0:00:01.480354 215 | 21300 426 
426 1 0:00:01.484423 216 | 21400 428 428 1 0:00:01.481469 217 | 21500 430 430 1 0:00:01.498621 218 | 21600 432 432 1 0:00:01.483243 219 | 21700 434 434 1 0:00:01.485271 220 | 21800 436 436 1 0:00:01.498695 221 | 21900 438 438 1 0:00:01.554712 222 | 22000 440 737 0 0:00:01.493426 223 | 22100 442 442 1 0:00:01.476355 224 | 22200 444 870 0 0:00:01.484731 225 | 22300 446 446 1 0:00:01.481968 226 | 22400 448 448 1 0:00:01.480012 227 | 22500 450 407 0 0:00:01.479037 228 | 22600 452 452 1 0:00:01.491563 229 | 22700 454 454 1 0:00:01.554456 230 | 22800 456 645 0 0:00:01.526294 231 | 22900 458 458 1 0:00:01.496040 232 | 23000 460 978 0 0:00:01.482221 233 | 23100 462 462 1 0:00:01.502633 234 | 23200 464 439 0 0:00:01.496977 235 | 23300 466 466 1 0:00:01.474824 236 | 23400 468 468 1 0:00:01.479537 237 | 23500 470 624 0 0:00:01.482006 238 | 23600 472 472 1 0:00:01.491713 239 | 23700 474 841 0 0:00:01.481372 240 | 23800 476 476 1 0:00:01.483987 241 | 23900 478 478 1 0:00:01.484918 242 | 24000 480 509 0 0:00:01.480896 243 | 24100 482 481 0 0:00:01.481517 244 | 24200 484 871 0 0:00:01.480409 245 | 24300 486 486 1 0:00:01.551933 246 | 24400 488 488 1 0:00:01.501731 247 | 24500 490 490 1 0:00:01.552063 248 | 24600 492 492 1 0:00:01.505758 249 | 24700 494 398 0 0:00:01.498369 250 | 24800 496 496 1 0:00:01.481360 251 | 24900 498 498 1 0:00:01.477659 252 | 25000 500 500 1 0:00:01.479922 253 | 25100 502 502 1 0:00:01.477720 254 | 25200 504 504 1 0:00:01.481018 255 | 25300 506 506 1 0:00:01.480943 256 | 25400 508 508 1 0:00:01.478332 257 | 25500 510 510 1 0:00:01.496781 258 | 25600 512 740 0 0:00:01.484283 259 | 25700 514 514 1 0:00:01.492787 260 | 25800 516 431 0 0:00:01.491374 261 | 25900 518 518 1 0:00:01.478206 262 | 26000 520 520 1 0:00:01.491399 263 | 26100 522 522 1 0:00:01.481239 264 | 26200 524 461 0 0:00:01.516815 265 | 26300 526 526 1 0:00:01.478231 266 | 26400 528 -1 0 0:00:01.489215 267 | 26500 530 531 0 0:00:01.529016 268 | 26600 532 532 1 0:00:01.488698 269 | 26700 534 
534 1 0:00:01.483696 270 | 26800 536 403 0 0:00:01.500773 271 | 26900 538 538 1 0:00:01.479615 272 | 27000 540 540 1 0:00:01.483645 273 | 27100 542 -1 0 0:00:01.483379 274 | 27200 544 926 0 0:00:01.480354 275 | 27300 546 546 1 0:00:01.483093 276 | 27400 548 548 1 0:00:01.480865 277 | 27500 550 505 0 0:00:01.479713 278 | 27600 552 552 1 0:00:01.484718 279 | 27700 554 554 1 0:00:01.486498 280 | 27800 556 421 0 0:00:01.481885 281 | 27900 558 251 0 0:00:01.481816 282 | 28000 560 768 0 0:00:01.479113 283 | 28100 562 562 1 0:00:01.487936 284 | 28200 564 564 1 0:00:01.482777 285 | 28300 566 566 1 0:00:01.512465 286 | 28400 568 399 0 0:00:01.488383 287 | 28500 570 -1 0 0:00:01.482008 288 | 28600 572 418 0 0:00:01.514439 289 | 28700 574 574 1 0:00:01.477766 290 | 28800 576 -1 0 0:00:01.496104 291 | 28900 578 578 1 0:00:01.551244 292 | 29000 580 738 0 0:00:01.521960 293 | 29100 582 -1 0 0:00:01.507269 294 | 29200 584 754 0 0:00:01.478857 295 | 29300 586 586 1 0:00:01.477593 296 | 29400 588 588 1 0:00:01.493497 297 | 29500 590 590 1 0:00:01.477899 298 | 29600 592 592 1 0:00:01.479885 299 | 29700 594 843 0 0:00:01.481439 300 | 29800 596 56 0 0:00:01.541896 301 | 29900 598 664 0 0:00:01.526365 302 | 30000 600 517 0 0:00:01.531583 303 | 30100 602 602 1 0:00:01.484584 304 | 30200 604 604 1 0:00:01.480353 305 | 30300 606 606 1 0:00:01.523729 306 | 30400 608 608 1 0:00:01.481934 307 | 30500 610 800 0 0:00:01.478992 308 | 30600 612 612 1 0:00:01.548418 309 | 30700 614 614 1 0:00:01.494345 310 | 30800 616 534 0 0:00:01.500657 311 | 30900 618 828 0 0:00:01.487541 312 | 31000 620 681 0 0:00:01.490442 313 | 31100 622 622 1 0:00:01.480001 314 | 31200 624 453 0 0:00:01.520598 315 | 31300 626 542 0 0:00:01.492061 316 | 31400 628 624 0 0:00:01.485516 317 | 31500 630 703 0 0:00:01.493293 318 | 31600 632 632 1 0:00:01.499207 319 | 31700 634 489 0 0:00:01.482136 320 | 31800 636 636 1 0:00:01.522583 321 | 31900 638 638 1 0:00:01.480404 322 | 32000 640 640 1 0:00:01.478007 323 | 32100 642 642 1 
0:00:01.490945 324 | 32200 644 644 1 0:00:01.483194 325 | 32300 646 646 1 0:00:01.477241 326 | 32400 648 -1 0 0:00:01.480447 327 | 32500 650 546 0 0:00:01.557726 328 | 32600 652 652 1 0:00:01.505422 329 | 32700 654 436 0 0:00:01.477040 330 | 32800 656 -1 0 0:00:01.509807 331 | 32900 658 658 1 0:00:01.490112 332 | 33000 660 660 1 0:00:01.517296 333 | 33100 662 662 1 0:00:01.485711 334 | 33200 664 851 0 0:00:01.490185 335 | 33300 666 659 0 0:00:01.522107 336 | 33400 668 668 1 0:00:01.478890 337 | 33500 670 670 1 0:00:01.487738 338 | 33600 672 672 1 0:00:01.479811 339 | 33700 674 674 1 0:00:01.478646 340 | 33800 676 560 0 0:00:01.485127 341 | 33900 678 678 1 0:00:01.490292 342 | 34000 680 793 0 0:00:01.484289 343 | 34100 682 682 1 0:00:01.478803 344 | 34200 684 684 1 0:00:01.478796 345 | 34300 686 -1 0 0:00:01.479543 346 | 34400 688 688 1 0:00:01.537382 347 | 34500 690 690 1 0:00:01.555078 348 | 34600 692 481 0 0:00:01.524606 349 | 34700 694 694 1 0:00:01.484595 350 | 34800 696 -1 0 0:00:01.480659 351 | 34900 698 698 1 0:00:01.495338 352 | 35000 700 700 1 0:00:01.491358 353 | 35100 702 422 0 0:00:01.479972 354 | 35200 704 704 1 0:00:01.559804 355 | 35300 706 536 0 0:00:01.515814 356 | 35400 708 682 0 0:00:01.494998 357 | 35500 710 710 1 0:00:01.483101 358 | 35600 712 622 0 0:00:01.489805 359 | 35700 714 714 1 0:00:01.485512 360 | 35800 716 716 1 0:00:01.480383 361 | 35900 718 449 0 0:00:01.481020 362 | 36000 720 720 1 0:00:01.480507 363 | 36100 722 722 1 0:00:01.484626 364 | 36200 724 780 0 0:00:01.492352 365 | 36300 726 726 1 0:00:01.494925 366 | 36400 728 728 1 0:00:01.488405 367 | 36500 730 856 0 0:00:01.491160 368 | 36600 732 732 1 0:00:01.488379 369 | 36700 734 734 1 0:00:01.480097 370 | 36800 736 736 1 0:00:01.483129 371 | 36900 738 738 1 0:00:01.524685 372 | 37000 740 -1 0 0:00:01.483893 373 | 37100 742 466 0 0:00:01.481901 374 | 37200 744 657 0 0:00:01.477067 375 | 37300 746 746 1 0:00:01.481244 376 | 37400 748 748 1 0:00:01.481573 377 | 37500 750 172 0 
0:00:01.551862 378 | 37600 752 830 0 0:00:01.493996 379 | 37700 754 848 0 0:00:01.478429 380 | 37800 756 756 1 0:00:01.479473 381 | 37900 758 758 1 0:00:01.480465 382 | 38000 760 651 0 0:00:01.498277 383 | 38100 762 762 1 0:00:01.558768 384 | 38200 764 593 0 0:00:01.497985 385 | 38300 766 766 1 0:00:01.568286 386 | 38400 768 768 1 0:00:01.531938 387 | 38500 770 770 1 0:00:01.481052 388 | 38600 772 772 1 0:00:01.479920 389 | 38700 774 774 1 0:00:01.492791 390 | 38800 776 683 0 0:00:01.536620 391 | 38900 778 707 0 0:00:01.480408 392 | 39000 780 780 1 0:00:01.478122 393 | 39100 782 664 0 0:00:01.551872 394 | 39200 784 505 0 0:00:01.496936 395 | 39300 786 513 0 0:00:01.480556 396 | 39400 788 743 0 0:00:01.480324 397 | 39500 790 791 0 0:00:01.475573 398 | 39600 792 795 0 0:00:01.483009 399 | 39700 794 794 1 0:00:01.480900 400 | 39800 796 796 1 0:00:01.482380 401 | 39900 798 798 1 0:00:01.480042 402 | 40000 800 800 1 0:00:01.481433 403 | 40100 802 802 1 0:00:01.481994 404 | 40200 804 804 1 0:00:01.482339 405 | 40300 806 806 1 0:00:01.484827 406 | 40400 808 808 1 0:00:01.480378 407 | 40500 810 878 0 0:00:01.550252 408 | 40600 812 812 1 0:00:01.497617 409 | 40700 814 814 1 0:00:01.543693 410 | 40800 816 816 1 0:00:01.528976 411 | 40900 818 745 0 0:00:01.549425 412 | 41000 820 820 1 0:00:01.488208 413 | 41100 822 822 1 0:00:01.479282 414 | 41200 824 824 1 0:00:01.494460 415 | 41300 826 826 1 0:00:01.482595 416 | 41400 828 618 0 0:00:01.493755 417 | 41500 830 830 1 0:00:01.560996 418 | 41600 832 832 1 0:00:01.494581 419 | 41700 834 417 0 0:00:01.516018 420 | 41800 836 444 0 0:00:01.479565 421 | 41900 838 631 0 0:00:01.486431 422 | 42000 840 840 1 0:00:01.481183 423 | 42100 842 842 1 0:00:01.477463 424 | 42200 844 844 1 0:00:01.486127 425 | 42300 846 583 0 0:00:01.483045 426 | 42400 848 848 1 0:00:01.525781 427 | 42500 850 850 1 0:00:01.483320 428 | 42600 852 852 1 0:00:01.478466 429 | 42700 854 854 1 0:00:01.513655 430 | 42800 856 856 1 0:00:01.486674 431 | 42900 858 832 0 
0:00:01.482063 432 | 43000 860 892 0 0:00:01.480499 433 | 43100 862 862 1 0:00:01.480461 434 | 43200 864 864 1 0:00:01.480838 435 | 43300 866 866 1 0:00:01.479955 436 | 43400 868 968 0 0:00:01.479614 437 | 43500 870 870 1 0:00:01.492825 438 | 43600 872 840 0 0:00:01.494615 439 | 43700 874 874 1 0:00:01.480262 440 | 43800 876 648 0 0:00:01.530230 441 | 43900 878 878 1 0:00:01.477895 442 | 44000 880 870 0 0:00:01.481497 443 | 44100 882 -1 0 0:00:01.523555 444 | 44200 884 884 1 0:00:01.546123 445 | 44300 886 886 1 0:00:01.534611 446 | 44400 888 718 0 0:00:01.528654 447 | 44500 890 890 1 0:00:01.478625 448 | 44600 892 442 0 0:00:01.489870 449 | 44700 894 894 1 0:00:01.485551 450 | 44800 896 896 1 0:00:01.486424 451 | 44900 898 508 0 0:00:01.481154 452 | 45000 900 900 1 0:00:01.479834 453 | 45100 902 902 1 0:00:01.494956 454 | 45200 904 905 0 0:00:01.534988 455 | 45300 906 906 1 0:00:01.480348 456 | 45400 908 895 0 0:00:01.480904 457 | 45500 910 -1 0 0:00:01.549871 458 | 45600 912 912 1 0:00:01.518482 459 | 45700 914 914 1 0:00:01.496823 460 | 45800 916 916 1 0:00:01.484683 461 | 45900 918 918 1 0:00:01.482222 462 | 46000 920 672 0 0:00:01.481179 463 | 46100 922 922 1 0:00:01.480635 464 | 46200 924 924 1 0:00:01.480506 465 | 46300 926 926 1 0:00:01.481995 466 | 46400 928 927 0 0:00:01.482548 467 | 46500 930 964 0 0:00:01.478524 468 | 46600 932 932 1 0:00:01.481828 469 | 46700 934 934 1 0:00:01.492927 470 | 46800 936 936 1 0:00:01.485990 471 | 46900 938 938 1 0:00:01.477883 472 | 47000 940 940 1 0:00:01.495854 473 | 47100 942 942 1 0:00:01.483602 474 | 47200 944 946 0 0:00:01.485325 475 | 47300 946 946 1 0:00:01.479533 476 | 47400 948 948 1 0:00:01.482796 477 | 47500 950 950 1 0:00:01.481151 478 | 47600 952 90 0 0:00:01.479927 479 | 47700 954 -1 0 0:00:01.484572 480 | 47800 956 956 1 0:00:01.490380 481 | 47900 958 483 0 0:00:01.493005 482 | 48000 960 960 1 0:00:01.479488 483 | 48100 962 962 1 0:00:01.508343 484 | 48200 964 961 0 0:00:01.556295 485 | 48300 966 630 0 
0:00:01.503184 486 | 48400 968 968 1 0:00:01.483091 487 | 48500 970 795 0 0:00:01.480517 488 | 48600 972 979 0 0:00:01.479620 489 | 48700 974 974 1 0:00:01.479991 490 | 48800 976 724 0 0:00:01.489057 491 | 48900 978 701 0 0:00:01.489683 492 | 49000 980 980 1 0:00:01.482108 493 | 49100 982 982 1 0:00:01.483436 494 | 49200 984 984 1 0:00:01.503881 495 | 49300 986 986 1 0:00:01.484524 496 | 49400 988 988 1 0:00:01.484175 497 | 49500 990 990 1 0:00:01.543809 498 | 49600 992 992 1 0:00:01.552771 499 | 49700 994 994 1 0:00:01.492512 500 | 49800 996 996 1 0:00:01.485028 501 | 49900 998 998 1 0:00:01.477876 502 | -------------------------------------------------------------------------------- /data/predict/imagenet/resnet50/noise_0.25/test/N_10000: -------------------------------------------------------------------------------- 1 | idx label predict correct time 2 | 0 0 0 1 0:00:17.938670 3 | 100 2 2 1 0:00:14.529888 4 | 200 4 395 0 0:00:14.602542 5 | 300 6 6 1 0:00:14.658964 6 | 400 8 8 1 0:00:14.701330 7 | 500 10 10 1 0:00:14.710515 8 | 600 12 12 1 0:00:14.762375 9 | 700 14 14 1 0:00:14.895730 10 | 800 16 16 1 0:00:15.036877 11 | 900 18 18 1 0:00:14.915746 12 | 1000 20 20 1 0:00:14.775350 13 | 1100 22 21 0 0:00:14.954231 14 | 1200 24 24 1 0:00:14.953908 15 | 1300 26 26 1 0:00:14.813852 16 | 1400 28 28 1 0:00:14.843823 17 | 1500 30 30 1 0:00:14.890527 18 | 1600 32 103 0 0:00:14.992512 19 | 1700 34 327 0 0:00:14.821474 20 | 1800 36 35 0 0:00:14.935570 21 | 1900 38 38 1 0:00:14.801161 22 | 2000 40 31 0 0:00:14.919142 23 | 2100 42 48 0 0:00:14.847678 24 | 2200 44 44 1 0:00:15.038948 25 | 2300 46 46 1 0:00:14.826487 26 | 2400 48 48 1 0:00:14.995221 27 | 2500 50 50 1 0:00:15.031976 28 | 2600 52 52 1 0:00:14.921636 29 | 2700 54 57 0 0:00:14.942702 30 | 2800 56 56 1 0:00:14.990683 31 | 2900 58 58 1 0:00:14.999382 32 | 3000 60 62 0 0:00:14.885931 33 | 3100 62 62 1 0:00:14.838300 34 | 3200 64 64 1 0:00:14.967360 35 | 3300 66 63 0 0:00:14.891294 36 | 3400 68 -1 0 0:00:14.835415 37 
| 3500 70 70 1 0:00:14.876681 38 | 3600 72 72 1 0:00:15.101473 39 | 3700 74 74 1 0:00:14.888600 40 | 3800 76 76 1 0:00:14.864935 41 | 3900 78 78 1 0:00:14.870341 42 | 4000 80 80 1 0:00:14.896723 43 | 4100 82 82 1 0:00:14.920049 44 | 4200 84 84 1 0:00:14.901886 45 | 4300 86 86 1 0:00:14.940372 46 | 4400 88 88 1 0:00:14.945328 47 | 4500 90 90 1 0:00:14.846179 48 | 4600 92 92 1 0:00:15.061382 49 | 4700 94 94 1 0:00:14.843592 50 | 4800 96 96 1 0:00:14.833731 51 | 4900 98 98 1 0:00:14.789376 52 | 5000 100 100 1 0:00:14.819640 53 | 5100 102 102 1 0:00:14.878288 54 | 5200 104 342 0 0:00:14.806604 55 | 5300 106 106 1 0:00:14.844456 56 | 5400 108 470 0 0:00:15.005162 57 | 5500 110 110 1 0:00:14.846514 58 | 5600 112 112 1 0:00:14.857095 59 | 5700 114 113 0 0:00:14.923831 60 | 5800 116 116 1 0:00:14.970319 61 | 5900 118 118 1 0:00:14.890198 62 | 6000 120 119 0 0:00:14.816716 63 | 6100 122 122 1 0:00:15.101482 64 | 6200 124 123 0 0:00:14.836912 65 | 6300 126 126 1 0:00:15.015560 66 | 6400 128 128 1 0:00:14.807324 67 | 6500 130 130 1 0:00:14.792057 68 | 6600 132 132 1 0:00:15.131752 69 | 6700 134 132 0 0:00:14.876376 70 | 6800 136 136 1 0:00:14.973896 71 | 6900 138 138 1 0:00:14.950263 72 | 7000 140 140 1 0:00:14.883024 73 | 7100 142 142 1 0:00:14.867703 74 | 7200 144 144 1 0:00:14.824093 75 | 7300 146 146 1 0:00:14.969388 76 | 7400 148 148 1 0:00:15.022812 77 | 7500 150 360 0 0:00:14.857029 78 | 7600 152 152 1 0:00:15.055737 79 | 7700 154 154 1 0:00:14.930024 80 | 7800 156 156 1 0:00:14.998052 81 | 7900 158 158 1 0:00:14.854780 82 | 8000 160 160 1 0:00:15.022626 83 | 8100 162 162 1 0:00:14.812995 84 | 8200 164 164 1 0:00:14.955143 85 | 8300 166 166 1 0:00:15.046064 86 | 8400 168 211 0 0:00:14.830755 87 | 8500 170 172 0 0:00:14.842812 88 | 8600 172 172 1 0:00:14.830592 89 | 8700 174 174 1 0:00:15.064394 90 | 8800 176 176 1 0:00:14.872763 91 | 8900 178 178 1 0:00:14.958781 92 | 9000 180 243 0 0:00:14.998049 93 | 9100 182 182 1 0:00:14.956381 94 | 9200 184 188 0 0:00:14.997293 95 
| 9300 186 192 0 0:00:14.842927 96 | 9400 188 188 1 0:00:14.969200 97 | 9500 190 190 1 0:00:14.819773 98 | 9600 192 199 0 0:00:14.841370 99 | 9700 194 202 0 0:00:14.891114 100 | 9800 196 199 0 0:00:15.028255 101 | 9900 198 196 0 0:00:14.930591 102 | 10000 200 200 1 0:00:15.043234 103 | 10100 202 202 1 0:00:14.797675 104 | 10200 204 204 1 0:00:15.117067 105 | 10300 206 206 1 0:00:14.826504 106 | 10400 208 208 1 0:00:14.879246 107 | 10500 210 210 1 0:00:14.878064 108 | 10600 212 212 1 0:00:14.987449 109 | 10700 214 214 1 0:00:14.888630 110 | 10800 216 216 1 0:00:14.790781 111 | 10900 218 218 1 0:00:14.845216 112 | 11000 220 220 1 0:00:15.162716 113 | 11100 222 222 1 0:00:14.845916 114 | 11200 224 233 0 0:00:14.967866 115 | 11300 226 226 1 0:00:15.006416 116 | 11400 228 228 1 0:00:14.967813 117 | 11500 230 230 1 0:00:14.985305 118 | 11600 232 232 1 0:00:14.927306 119 | 11700 234 234 1 0:00:14.946466 120 | 11800 236 234 0 0:00:15.021441 121 | 11900 238 238 1 0:00:14.916477 122 | 12000 240 238 0 0:00:14.907923 123 | 12100 242 242 1 0:00:14.984931 124 | 12200 244 244 1 0:00:14.801400 125 | 12300 246 251 0 0:00:14.846193 126 | 12400 248 248 1 0:00:15.127333 127 | 12500 250 169 0 0:00:14.818779 128 | 12600 252 252 1 0:00:14.976161 129 | 12700 254 254 1 0:00:14.848175 130 | 12800 256 256 1 0:00:14.817288 131 | 12900 258 258 1 0:00:14.941558 132 | 13000 260 260 1 0:00:14.882449 133 | 13100 262 262 1 0:00:14.845883 134 | 13200 264 264 1 0:00:15.053814 135 | 13300 266 266 1 0:00:14.963253 136 | 13400 268 268 1 0:00:14.906438 137 | 13500 270 270 1 0:00:14.821655 138 | 13600 272 274 0 0:00:14.912062 139 | 13700 274 274 1 0:00:15.016080 140 | 13800 276 276 1 0:00:14.981570 141 | 13900 278 278 1 0:00:14.983741 142 | 14000 280 280 1 0:00:14.931658 143 | 14100 282 282 1 0:00:14.885947 144 | 14200 284 284 1 0:00:14.956033 145 | 14300 286 286 1 0:00:14.804141 146 | 14400 288 288 1 0:00:14.923306 147 | 14500 290 290 1 0:00:14.974497 148 | 14600 292 290 0 0:00:15.046874 149 | 14700 294 
294 1 0:00:14.817052 150 | 14800 296 296 1 0:00:15.010772 151 | 14900 298 298 1 0:00:14.779613 152 | 15000 300 300 1 0:00:15.087663 153 | 15100 302 302 1 0:00:14.960548 154 | 15200 304 301 0 0:00:14.948146 155 | 15300 306 306 1 0:00:14.997912 156 | 15400 308 308 1 0:00:14.958719 157 | 15500 310 310 1 0:00:14.940277 158 | 15600 312 315 0 0:00:14.941935 159 | 15700 314 314 1 0:00:14.972534 160 | 15800 316 316 1 0:00:15.045166 161 | 15900 318 318 1 0:00:14.816161 162 | 16000 320 320 1 0:00:14.905789 163 | 16100 322 322 1 0:00:14.972884 164 | 16200 324 324 1 0:00:14.849005 165 | 16300 326 326 1 0:00:15.113016 166 | 16400 328 328 1 0:00:14.819477 167 | 16500 330 330 1 0:00:14.900786 168 | 16600 332 153 0 0:00:14.899500 169 | 16700 334 334 1 0:00:15.079960 170 | 16800 336 336 1 0:00:14.817020 171 | 16900 338 617 0 0:00:14.896904 172 | 17000 340 340 1 0:00:15.101224 173 | 17100 342 287 0 0:00:14.789878 174 | 17200 344 344 1 0:00:15.015266 175 | 17300 346 344 0 0:00:14.911635 176 | 17400 348 348 1 0:00:15.007726 177 | 17500 350 350 1 0:00:15.036709 178 | 17600 352 352 1 0:00:15.037419 179 | 17700 354 354 1 0:00:14.989679 180 | 17800 356 -1 0 0:00:14.995699 181 | 17900 358 359 0 0:00:14.847023 182 | 18000 360 360 1 0:00:14.942973 183 | 18100 362 362 1 0:00:14.966301 184 | 18200 364 364 1 0:00:14.917976 185 | 18300 366 366 1 0:00:14.879833 186 | 18400 368 368 1 0:00:15.137949 187 | 18500 370 370 1 0:00:14.916142 188 | 18600 372 372 1 0:00:15.003802 189 | 18700 374 374 1 0:00:14.929094 190 | 18800 376 376 1 0:00:14.956331 191 | 18900 378 378 1 0:00:14.873452 192 | 19000 380 380 1 0:00:14.860228 193 | 19100 382 382 1 0:00:14.906570 194 | 19200 384 283 0 0:00:14.863993 195 | 19300 386 101 0 0:00:15.087887 196 | 19400 388 388 1 0:00:14.998850 197 | 19500 390 390 1 0:00:14.789777 198 | 19600 392 397 0 0:00:15.002732 199 | 19700 394 467 0 0:00:14.836903 200 | 19800 396 396 1 0:00:14.804411 201 | 19900 398 398 1 0:00:15.011312 202 | 20000 400 400 1 0:00:15.001071 203 | 20100 402 
402 1 0:00:14.913904 204 | 20200 404 404 1 0:00:14.936100 205 | 20300 406 406 1 0:00:14.876220 206 | 20400 408 847 0 0:00:14.978644 207 | 20500 410 410 1 0:00:14.988457 208 | 20600 412 412 1 0:00:14.980928 209 | 20700 414 414 1 0:00:14.982497 210 | 20800 416 416 1 0:00:14.813848 211 | 20900 418 563 0 0:00:15.099722 212 | 21000 420 420 1 0:00:14.814135 213 | 21100 422 422 1 0:00:14.884715 214 | 21200 424 454 0 0:00:14.864878 215 | 21300 426 426 1 0:00:15.047416 216 | 21400 428 428 1 0:00:14.912828 217 | 21500 430 430 1 0:00:14.957142 218 | 21600 432 432 1 0:00:14.899586 219 | 21700 434 434 1 0:00:14.843033 220 | 21800 436 436 1 0:00:14.993052 221 | 21900 438 438 1 0:00:15.054873 222 | 22000 440 737 0 0:00:14.859920 223 | 22100 442 442 1 0:00:14.820955 224 | 22200 444 870 0 0:00:15.013117 225 | 22300 446 446 1 0:00:15.022072 226 | 22400 448 448 1 0:00:14.867987 227 | 22500 450 407 0 0:00:14.980548 228 | 22600 452 452 1 0:00:14.831118 229 | 22700 454 454 1 0:00:15.062666 230 | 22800 456 645 0 0:00:15.022581 231 | 22900 458 458 1 0:00:15.048858 232 | 23000 460 978 0 0:00:14.951904 233 | 23100 462 462 1 0:00:15.014614 234 | 23200 464 439 0 0:00:15.039416 235 | 23300 466 466 1 0:00:14.933714 236 | 23400 468 468 1 0:00:15.138598 237 | 23500 470 624 0 0:00:15.030802 238 | 23600 472 472 1 0:00:15.028487 239 | 23700 474 841 0 0:00:14.842982 240 | 23800 476 476 1 0:00:15.076778 241 | 23900 478 478 1 0:00:14.985259 242 | 24000 480 509 0 0:00:14.983824 243 | 24100 482 481 0 0:00:14.831785 244 | 24200 484 871 0 0:00:15.156861 245 | 24300 486 486 1 0:00:15.030223 246 | 24400 488 488 1 0:00:14.944442 247 | 24500 490 490 1 0:00:15.050789 248 | 24600 492 492 1 0:00:15.071643 249 | 24700 494 398 0 0:00:14.878071 250 | 24800 496 496 1 0:00:15.084099 251 | 24900 498 498 1 0:00:14.986719 252 | 25000 500 500 1 0:00:15.008059 253 | 25100 502 502 1 0:00:14.938144 254 | 25200 504 504 1 0:00:14.811468 255 | 25300 506 506 1 0:00:15.105104 256 | 25400 508 508 1 0:00:14.966163 257 | 25500 510 
510 1 0:00:15.080637 258 | 25600 512 740 0 0:00:14.948434 259 | 25700 514 514 1 0:00:14.932904 260 | 25800 516 431 0 0:00:14.829576 261 | 25900 518 518 1 0:00:14.968691 262 | 26000 520 520 1 0:00:14.801010 263 | 26100 522 522 1 0:00:14.894444 264 | 26200 524 461 0 0:00:15.070504 265 | 26300 526 526 1 0:00:14.958834 266 | 26400 528 478 0 0:00:14.938254 267 | 26500 530 531 0 0:00:14.946788 268 | 26600 532 532 1 0:00:15.041339 269 | 26700 534 534 1 0:00:14.967383 270 | 26800 536 403 0 0:00:14.814197 271 | 26900 538 538 1 0:00:14.946787 272 | 27000 540 540 1 0:00:14.909997 273 | 27100 542 477 0 0:00:15.057261 274 | 27200 544 926 0 0:00:14.985470 275 | 27300 546 546 1 0:00:15.035207 276 | 27400 548 548 1 0:00:14.893974 277 | 27500 550 505 0 0:00:14.990738 278 | 27600 552 552 1 0:00:15.071906 279 | 27700 554 554 1 0:00:14.799304 280 | 27800 556 421 0 0:00:15.033077 281 | 27900 558 251 0 0:00:15.025178 282 | 28000 560 768 0 0:00:14.976117 283 | 28100 562 562 1 0:00:14.799901 284 | 28200 564 564 1 0:00:14.838296 285 | 28300 566 566 1 0:00:15.089116 286 | 28400 568 399 0 0:00:14.900311 287 | 28500 570 -1 0 0:00:15.044947 288 | 28600 572 418 0 0:00:14.849125 289 | 28700 574 574 1 0:00:14.926119 290 | 28800 576 -1 0 0:00:14.965969 291 | 28900 578 578 1 0:00:14.859858 292 | 29000 580 738 0 0:00:15.022806 293 | 29100 582 582 1 0:00:14.958454 294 | 29200 584 754 0 0:00:14.904257 295 | 29300 586 586 1 0:00:15.074695 296 | 29400 588 588 1 0:00:14.977583 297 | 29500 590 590 1 0:00:14.941264 298 | 29600 592 592 1 0:00:14.933073 299 | 29700 594 843 0 0:00:15.028679 300 | 29800 596 56 0 0:00:15.067989 301 | 29900 598 664 0 0:00:14.945151 302 | 30000 600 517 0 0:00:14.936215 303 | 30100 602 602 1 0:00:14.965253 304 | 30200 604 604 1 0:00:14.856193 305 | 30300 606 606 1 0:00:15.016406 306 | 30400 608 608 1 0:00:15.112201 307 | 30500 610 800 0 0:00:14.789717 308 | 30600 612 612 1 0:00:15.035326 309 | 30700 614 614 1 0:00:15.098388 310 | 30800 616 534 0 0:00:14.907665 311 | 30900 618 828 
0 0:00:14.832339 312 | 31000 620 681 0 0:00:15.147113 313 | 31100 622 622 1 0:00:14.945833 314 | 31200 624 453 0 0:00:14.915479 315 | 31300 626 542 0 0:00:14.887706 316 | 31400 628 624 0 0:00:15.180945 317 | 31500 630 703 0 0:00:14.891687 318 | 31600 632 632 1 0:00:15.057439 319 | 31700 634 489 0 0:00:15.016377 320 | 31800 636 636 1 0:00:14.963135 321 | 31900 638 638 1 0:00:14.799041 322 | 32000 640 640 1 0:00:14.927938 323 | 32100 642 642 1 0:00:15.120875 324 | 32200 644 644 1 0:00:14.877543 325 | 32300 646 646 1 0:00:14.940699 326 | 32400 648 729 0 0:00:14.874189 327 | 32500 650 546 0 0:00:15.097237 328 | 32600 652 652 1 0:00:14.962884 329 | 32700 654 436 0 0:00:14.872689 330 | 32800 656 656 1 0:00:15.005286 331 | 32900 658 658 1 0:00:14.917542 332 | 33000 660 660 1 0:00:14.839606 333 | 33100 662 662 1 0:00:14.892756 334 | 33200 664 851 0 0:00:14.911993 335 | 33300 666 659 0 0:00:14.796421 336 | 33400 668 668 1 0:00:14.824733 337 | 33500 670 670 1 0:00:15.041363 338 | 33600 672 672 1 0:00:15.056527 339 | 33700 674 674 1 0:00:14.825686 340 | 33800 676 560 0 0:00:14.931751 341 | 33900 678 678 1 0:00:14.853488 342 | 34000 680 793 0 0:00:14.958912 343 | 34100 682 682 1 0:00:15.030425 344 | 34200 684 684 1 0:00:14.845501 345 | 34300 686 550 0 0:00:15.033074 346 | 34400 688 688 1 0:00:14.988053 347 | 34500 690 690 1 0:00:14.943735 348 | 34600 692 481 0 0:00:14.928473 349 | 34700 694 694 1 0:00:14.829778 350 | 34800 696 -1 0 0:00:15.074387 351 | 34900 698 698 1 0:00:15.019027 352 | 35000 700 700 1 0:00:15.032050 353 | 35100 702 422 0 0:00:15.006292 354 | 35200 704 704 1 0:00:14.928123 355 | 35300 706 536 0 0:00:14.818038 356 | 35400 708 682 0 0:00:15.055249 357 | 35500 710 710 1 0:00:14.948738 358 | 35600 712 622 0 0:00:15.094625 359 | 35700 714 714 1 0:00:14.915263 360 | 35800 716 716 1 0:00:14.976116 361 | 35900 718 449 0 0:00:14.864276 362 | 36000 720 720 1 0:00:15.096854 363 | 36100 722 722 1 0:00:14.921822 364 | 36200 724 780 0 0:00:14.832085 365 | 36300 726 726 1 
0:00:15.146349 366 | 36400 728 728 1 0:00:14.939208 367 | 36500 730 856 0 0:00:14.883153 368 | 36600 732 732 1 0:00:15.047479 369 | 36700 734 734 1 0:00:14.925721 370 | 36800 736 736 1 0:00:14.927071 371 | 36900 738 738 1 0:00:14.864711 372 | 37000 740 796 0 0:00:15.075991 373 | 37100 742 466 0 0:00:14.987401 374 | 37200 744 657 0 0:00:14.830023 375 | 37300 746 746 1 0:00:14.953343 376 | 37400 748 748 1 0:00:14.797044 377 | 37500 750 172 0 0:00:14.995230 378 | 37600 752 830 0 0:00:14.965986 379 | 37700 754 848 0 0:00:14.849808 380 | 37800 756 756 1 0:00:14.833296 381 | 37900 758 758 1 0:00:14.914641 382 | 38000 760 651 0 0:00:14.921158 383 | 38100 762 762 1 0:00:15.031013 384 | 38200 764 593 0 0:00:14.855857 385 | 38300 766 766 1 0:00:14.836674 386 | 38400 768 768 1 0:00:14.937352 387 | 38500 770 770 1 0:00:14.985036 388 | 38600 772 772 1 0:00:15.019809 389 | 38700 774 774 1 0:00:14.958676 390 | 38800 776 683 0 0:00:14.881106 391 | 38900 778 707 0 0:00:15.088828 392 | 39000 780 780 1 0:00:14.941449 393 | 39100 782 664 0 0:00:14.912092 394 | 39200 784 505 0 0:00:15.034298 395 | 39300 786 513 0 0:00:14.912165 396 | 39400 788 743 0 0:00:14.926766 397 | 39500 790 791 0 0:00:14.947411 398 | 39600 792 795 0 0:00:14.923461 399 | 39700 794 794 1 0:00:15.040500 400 | 39800 796 796 1 0:00:14.828018 401 | 39900 798 798 1 0:00:15.105424 402 | 40000 800 800 1 0:00:15.035649 403 | 40100 802 802 1 0:00:14.857424 404 | 40200 804 804 1 0:00:14.936726 405 | 40300 806 806 1 0:00:14.901794 406 | 40400 808 808 1 0:00:14.920217 407 | 40500 810 878 0 0:00:14.899759 408 | 40600 812 812 1 0:00:14.947197 409 | 40700 814 814 1 0:00:14.900511 410 | 40800 816 816 1 0:00:15.095273 411 | 40900 818 745 0 0:00:15.008224 412 | 41000 820 820 1 0:00:15.002912 413 | 41100 822 822 1 0:00:14.927240 414 | 41200 824 824 1 0:00:14.942341 415 | 41300 826 826 1 0:00:14.871397 416 | 41400 828 618 0 0:00:14.840422 417 | 41500 830 830 1 0:00:14.942988 418 | 41600 832 832 1 0:00:15.013025 419 | 41700 834 417 0 
0:00:15.061307 420 | 41800 836 444 0 0:00:14.924838 421 | 41900 838 631 0 0:00:14.804935 422 | 42000 840 840 1 0:00:14.906883 423 | 42100 842 842 1 0:00:14.838626 424 | 42200 844 844 1 0:00:15.125461 425 | 42300 846 583 0 0:00:14.892340 426 | 42400 848 848 1 0:00:14.995593 427 | 42500 850 850 1 0:00:15.004535 428 | 42600 852 852 1 0:00:15.016846 429 | 42700 854 854 1 0:00:14.969559 430 | 42800 856 856 1 0:00:14.853972 431 | 42900 858 832 0 0:00:15.011433 432 | 43000 860 892 0 0:00:14.879105 433 | 43100 862 862 1 0:00:14.824035 434 | 43200 864 864 1 0:00:15.063193 435 | 43300 866 866 1 0:00:15.115438 436 | 43400 868 968 0 0:00:15.104931 437 | 43500 870 870 1 0:00:14.852551 438 | 43600 872 840 0 0:00:15.085425 439 | 43700 874 874 1 0:00:15.004121 440 | 43800 876 648 0 0:00:14.948926 441 | 43900 878 878 1 0:00:14.924355 442 | 44000 880 870 0 0:00:15.053208 443 | 44100 882 882 1 0:00:14.870412 444 | 44200 884 884 1 0:00:15.067243 445 | 44300 886 886 1 0:00:15.063395 446 | 44400 888 718 0 0:00:14.863382 447 | 44500 890 890 1 0:00:14.956937 448 | 44600 892 442 0 0:00:15.065577 449 | 44700 894 894 1 0:00:14.913900 450 | 44800 896 896 1 0:00:14.845646 451 | 44900 898 508 0 0:00:15.041700 452 | 45000 900 900 1 0:00:15.036083 453 | 45100 902 902 1 0:00:14.869172 454 | 45200 904 905 0 0:00:14.926249 455 | 45300 906 906 1 0:00:14.980194 456 | 45400 908 895 0 0:00:14.923142 457 | 45500 910 910 1 0:00:15.059851 458 | 45600 912 912 1 0:00:14.834315 459 | 45700 914 914 1 0:00:14.801381 460 | 45800 916 916 1 0:00:15.072167 461 | 45900 918 918 1 0:00:14.886890 462 | 46000 920 672 0 0:00:14.905368 463 | 46100 922 922 1 0:00:15.158672 464 | 46200 924 924 1 0:00:14.959707 465 | 46300 926 926 1 0:00:15.009355 466 | 46400 928 927 0 0:00:14.835334 467 | 46500 930 964 0 0:00:14.963948 468 | 46600 932 932 1 0:00:14.963856 469 | 46700 934 934 1 0:00:15.037371 470 | 46800 936 936 1 0:00:14.889781 471 | 46900 938 938 1 0:00:15.001535 472 | 47000 940 940 1 0:00:15.078541 473 | 47100 942 942 1 
0:00:14.961972 474 | 47200 944 946 0 0:00:14.824271 475 | 47300 946 946 1 0:00:15.006133 476 | 47400 948 948 1 0:00:14.943808 477 | 47500 950 950 1 0:00:14.915371 478 | 47600 952 90 0 0:00:14.829384 479 | 47700 954 954 1 0:00:14.983240 480 | 47800 956 956 1 0:00:15.007075 481 | 47900 958 483 0 0:00:14.995762 482 | 48000 960 960 1 0:00:14.903471 483 | 48100 962 962 1 0:00:14.842755 484 | 48200 964 961 0 0:00:15.000199 485 | 48300 966 630 0 0:00:14.848870 486 | 48400 968 968 1 0:00:15.047162 487 | 48500 970 795 0 0:00:14.873642 488 | 48600 972 979 0 0:00:14.915262 489 | 48700 974 974 1 0:00:15.077598 490 | 48800 976 724 0 0:00:14.963078 491 | 48900 978 701 0 0:00:14.855375 492 | 49000 980 980 1 0:00:14.798299 493 | 49100 982 982 1 0:00:15.025065 494 | 49200 984 984 1 0:00:14.829287 495 | 49300 986 986 1 0:00:14.985315 496 | 49400 988 988 1 0:00:14.922611 497 | 49500 990 990 1 0:00:15.019418 498 | 49600 992 992 1 0:00:14.923599 499 | 49700 994 994 1 0:00:14.888789 500 | 49800 996 996 1 0:00:14.961667 501 | 49900 998 998 1 0:00:14.941916 502 | -------------------------------------------------------------------------------- /data/predict/imagenet/resnet50/noise_0.25/test/N_100000: -------------------------------------------------------------------------------- 1 | idx label predict correct time 2 | 0 0 0 1 0:02:33.474448 3 | 100 2 2 1 0:02:31.267941 4 | 200 4 395 0 0:02:30.693998 5 | 300 6 6 1 0:02:31.397363 6 | 400 8 8 1 0:02:31.735987 7 | 500 10 10 1 0:02:31.113407 8 | 600 12 12 1 0:02:30.951247 9 | 700 14 14 1 0:02:31.248816 10 | 800 16 16 1 0:02:31.478386 11 | 900 18 18 1 0:02:31.303808 12 | 1000 20 20 1 0:02:31.340990 13 | 1100 22 21 0 0:02:31.461657 14 | 1200 24 24 1 0:02:31.166599 15 | 1300 26 26 1 0:02:31.435601 16 | 1400 28 28 1 0:02:31.746241 17 | 1500 30 30 1 0:02:31.415279 18 | 1600 32 103 0 0:02:31.048611 19 | 1700 34 327 0 0:02:31.615164 20 | 1800 36 35 0 0:02:31.247375 21 | 1900 38 38 1 0:02:31.175008 22 | 2000 40 31 0 0:02:31.079274 23 | 2100 42 48 0 
0:02:31.412859 24 | 2200 44 44 1 0:02:31.535267 25 | 2300 46 46 1 0:02:31.560232 26 | 2400 48 48 1 0:02:31.527727 27 | 2500 50 50 1 0:02:31.395422 28 | 2600 52 52 1 0:02:31.524686 29 | 2700 54 57 0 0:02:31.278730 30 | 2800 56 56 1 0:02:31.561788 31 | 2900 58 58 1 0:02:31.469573 32 | 3000 60 62 0 0:02:31.659518 33 | 3100 62 62 1 0:02:31.540611 34 | 3200 64 64 1 0:02:31.528662 35 | 3300 66 63 0 0:02:31.605767 36 | 3400 68 111 0 0:02:31.528680 37 | 3500 70 70 1 0:02:31.388872 38 | 3600 72 72 1 0:02:31.339453 39 | 3700 74 74 1 0:02:31.459056 40 | 3800 76 76 1 0:02:31.597488 41 | 3900 78 78 1 0:02:31.864876 42 | 4000 80 80 1 0:02:31.824401 43 | 4100 82 82 1 0:02:31.586974 44 | 4200 84 84 1 0:02:31.467104 45 | 4300 86 86 1 0:02:31.449934 46 | 4400 88 88 1 0:02:31.694547 47 | 4500 90 90 1 0:02:31.714680 48 | 4600 92 92 1 0:02:31.601055 49 | 4700 94 94 1 0:02:31.759537 50 | 4800 96 96 1 0:02:31.570559 51 | 4900 98 98 1 0:02:30.771919 52 | 5000 100 100 1 0:02:29.871229 53 | 5100 102 102 1 0:02:29.855997 54 | 5200 104 342 0 0:02:30.007782 55 | 5300 106 106 1 0:02:29.889595 56 | 5400 108 470 0 0:02:29.999557 57 | 5500 110 110 1 0:02:30.123106 58 | 5600 112 112 1 0:02:30.096990 59 | 5700 114 113 0 0:02:29.985586 60 | 5800 116 116 1 0:02:30.067943 61 | 5900 118 118 1 0:02:30.032640 62 | 6000 120 119 0 0:02:29.894344 63 | 6100 122 122 1 0:02:30.045305 64 | 6200 124 123 0 0:02:29.928218 65 | 6300 126 126 1 0:02:29.826290 66 | 6400 128 128 1 0:02:29.844152 67 | 6500 130 130 1 0:02:29.743992 68 | 6600 132 132 1 0:02:29.878203 69 | 6700 134 132 0 0:02:29.878030 70 | 6800 136 136 1 0:02:29.977701 71 | 6900 138 138 1 0:02:29.961860 72 | 7000 140 140 1 0:02:29.987991 73 | 7100 142 142 1 0:02:29.983496 74 | 7200 144 144 1 0:02:29.934271 75 | 7300 146 146 1 0:02:30.003505 76 | 7400 148 148 1 0:02:30.013247 77 | 7500 150 360 0 0:02:30.092299 78 | 7600 152 152 1 0:02:30.050198 79 | 7700 154 154 1 0:02:29.898821 80 | 7800 156 156 1 0:02:29.955264 81 | 7900 158 158 1 0:02:30.089475 82 | 8000 
160 160 1 0:02:30.007120 83 | 8100 162 162 1 0:02:29.897341 84 | 8200 164 164 1 0:02:29.929127 85 | 8300 166 166 1 0:02:29.948314 86 | 8400 168 211 0 0:02:29.853677 87 | 8500 170 172 0 0:02:29.910112 88 | 8600 172 172 1 0:02:29.825488 89 | 8700 174 174 1 0:02:29.876307 90 | 8800 176 176 1 0:02:29.820867 91 | 8900 178 178 1 0:02:29.995679 92 | 9000 180 243 0 0:02:29.832195 93 | 9100 182 182 1 0:02:29.864980 94 | 9200 184 188 0 0:02:29.972162 95 | 9300 186 192 0 0:02:29.873754 96 | 9400 188 188 1 0:02:29.842876 97 | 9500 190 190 1 0:02:29.861335 98 | 9600 192 199 0 0:02:30.022910 99 | 9700 194 202 0 0:02:30.002872 100 | 9800 196 199 0 0:02:30.027588 101 | 9900 198 196 0 0:02:29.992388 102 | 10000 200 200 1 0:02:29.942592 103 | 10100 202 202 1 0:02:29.897604 104 | 10200 204 204 1 0:02:29.856708 105 | 10300 206 206 1 0:02:29.787749 106 | 10400 208 208 1 0:02:29.881501 107 | 10500 210 210 1 0:02:29.846165 108 | 10600 212 212 1 0:02:29.949987 109 | 10700 214 214 1 0:02:29.946611 110 | 10800 216 216 1 0:02:29.918387 111 | 10900 218 218 1 0:02:30.054871 112 | 11000 220 220 1 0:02:30.114067 113 | 11100 222 222 1 0:02:29.941162 114 | 11200 224 233 0 0:02:30.051554 115 | 11300 226 226 1 0:02:30.096485 116 | 11400 228 228 1 0:02:29.820301 117 | 11500 230 230 1 0:02:29.818161 118 | 11600 232 232 1 0:02:30.023631 119 | 11700 234 234 1 0:02:29.966839 120 | 11800 236 234 0 0:02:30.053770 121 | 11900 238 238 1 0:02:29.941641 122 | 12000 240 238 0 0:02:29.944469 123 | 12100 242 242 1 0:02:29.903479 124 | 12200 244 244 1 0:02:30.013111 125 | 12300 246 251 0 0:02:29.945586 126 | 12400 248 248 1 0:02:30.064157 127 | 12500 250 169 0 0:02:29.971363 128 | 12600 252 252 1 0:02:30.167051 129 | 12700 254 254 1 0:02:30.059607 130 | 12800 256 256 1 0:02:30.098028 131 | 12900 258 258 1 0:02:30.070731 132 | 13000 260 260 1 0:02:30.081086 133 | 13100 262 262 1 0:02:30.182859 134 | 13200 264 264 1 0:02:30.012478 135 | 13300 266 266 1 0:02:29.938711 136 | 13400 268 268 1 0:02:30.146840 137 | 13500 
270 270 1 0:02:30.072387 138 | 13600 272 274 0 0:02:30.042817 139 | 13700 274 274 1 0:02:30.089741 140 | 13800 276 276 1 0:02:30.056051 141 | 13900 278 278 1 0:02:30.044791 142 | 14000 280 280 1 0:02:30.148042 143 | 14100 282 282 1 0:02:30.111140 144 | 14200 284 284 1 0:02:30.040564 145 | 14300 286 286 1 0:02:30.077097 146 | 14400 288 288 1 0:02:30.130719 147 | 14500 290 290 1 0:02:30.088434 148 | 14600 292 290 0 0:02:30.055848 149 | 14700 294 294 1 0:02:29.988939 150 | 14800 296 296 1 0:02:29.772256 151 | 14900 298 298 1 0:02:29.949858 152 | 15000 300 300 1 0:02:30.033037 153 | 15100 302 302 1 0:02:30.070987 154 | 15200 304 301 0 0:02:30.089978 155 | 15300 306 306 1 0:02:30.067981 156 | 15400 308 308 1 0:02:30.042348 157 | 15500 310 310 1 0:02:30.006768 158 | 15600 312 315 0 0:02:30.003502 159 | 15700 314 314 1 0:02:29.950488 160 | 15800 316 316 1 0:02:30.047839 161 | 15900 318 318 1 0:02:30.017167 162 | 16000 320 320 1 0:02:29.991655 163 | 16100 322 322 1 0:02:30.104580 164 | 16200 324 324 1 0:02:29.992422 165 | 16300 326 326 1 0:02:29.870930 166 | 16400 328 328 1 0:02:29.891489 167 | 16500 330 330 1 0:02:29.890625 168 | 16600 332 153 0 0:02:29.929097 169 | 16700 334 334 1 0:02:29.997736 170 | 16800 336 336 1 0:02:30.078620 171 | 16900 338 617 0 0:02:29.958100 172 | 17000 340 340 1 0:02:29.913225 173 | 17100 342 287 0 0:02:29.886427 174 | 17200 344 344 1 0:02:30.053031 175 | 17300 346 344 0 0:02:29.932767 176 | 17400 348 348 1 0:02:30.012097 177 | 17500 350 350 1 0:02:30.171809 178 | 17600 352 352 1 0:02:30.087436 179 | 17700 354 354 1 0:02:29.947427 180 | 17800 356 357 0 0:02:30.191751 181 | 17900 358 359 0 0:02:30.051043 182 | 18000 360 360 1 0:02:30.036187 183 | 18100 362 362 1 0:02:29.914627 184 | 18200 364 364 1 0:02:29.950982 185 | 18300 366 366 1 0:02:29.918943 186 | 18400 368 368 1 0:02:29.854911 187 | 18500 370 370 1 0:02:29.836591 188 | 18600 372 372 1 0:02:29.889367 189 | 18700 374 374 1 0:02:29.889299 190 | 18800 376 376 1 0:02:29.975704 191 | 18900 
378 378 1 0:02:30.087036 192 | 19000 380 380 1 0:02:30.009402 193 | 19100 382 382 1 0:02:29.955765 194 | 19200 384 283 0 0:02:30.055064 195 | 19300 386 101 0 0:02:29.988809 196 | 19400 388 388 1 0:02:29.895886 197 | 19500 390 390 1 0:02:29.903945 198 | 19600 392 397 0 0:02:30.082440 199 | 19700 394 467 0 0:02:29.983311 200 | 19800 396 396 1 0:02:29.908767 201 | 19900 398 398 1 0:02:30.002472 202 | 20000 400 400 1 0:02:30.054005 203 | 20100 402 402 1 0:02:29.959074 204 | 20200 404 404 1 0:02:29.876369 205 | 20300 406 406 1 0:02:30.077703 206 | 20400 408 847 0 0:02:30.047797 207 | 20500 410 410 1 0:02:30.044914 208 | 20600 412 412 1 0:02:30.040994 209 | 20700 414 414 1 0:02:29.985378 210 | 20800 416 416 1 0:02:29.940233 211 | 20900 418 563 0 0:02:29.924270 212 | 21000 420 420 1 0:02:29.946433 213 | 21100 422 422 1 0:02:29.951842 214 | 21200 424 454 0 0:02:29.952231 215 | 21300 426 426 1 0:02:29.728515 216 | 21400 428 428 1 0:02:29.755238 217 | 21500 430 430 1 0:02:29.863588 218 | 21600 432 432 1 0:02:29.885156 219 | 21700 434 434 1 0:02:29.853207 220 | 21800 436 436 1 0:02:29.946874 221 | 21900 438 438 1 0:02:29.944616 222 | 22000 440 737 0 0:02:29.930953 223 | 22100 442 442 1 0:02:29.986371 224 | 22200 444 870 0 0:02:29.998457 225 | 22300 446 446 1 0:02:29.934899 226 | 22400 448 448 1 0:02:29.868617 227 | 22500 450 407 0 0:02:29.756166 228 | 22600 452 452 1 0:02:29.863274 229 | 22700 454 454 1 0:02:29.971981 230 | 22800 456 645 0 0:02:30.003376 231 | 22900 458 458 1 0:02:29.914280 232 | 23000 460 978 0 0:02:30.226463 233 | 23100 462 462 1 0:02:29.935791 234 | 23200 464 439 0 0:02:29.818337 235 | 23300 466 466 1 0:02:29.930532 236 | 23400 468 468 1 0:02:29.931208 237 | 23500 470 624 0 0:02:30.035671 238 | 23600 472 472 1 0:02:29.943951 239 | 23700 474 841 0 0:02:29.934437 240 | 23800 476 476 1 0:02:30.122416 241 | 23900 478 478 1 0:02:29.828464 242 | 24000 480 509 0 0:02:30.027221 243 | 24100 482 481 0 0:02:29.997922 244 | 24200 484 871 0 0:02:29.806696 245 | 24300 
486 486 1 0:02:29.874212 246 | 24400 488 488 1 0:02:29.793973 247 | 24500 490 490 1 0:02:29.869747 248 | 24600 492 492 1 0:02:29.841316 249 | 24700 494 398 0 0:02:29.811148 250 | 24800 496 496 1 0:02:29.910695 251 | 24900 498 498 1 0:02:29.966648 252 | 25000 500 500 1 0:02:29.991739 253 | 25100 502 502 1 0:02:30.041048 254 | 25200 504 504 1 0:02:30.038781 255 | 25300 506 506 1 0:02:29.979415 256 | 25400 508 508 1 0:02:30.084711 257 | 25500 510 510 1 0:02:30.006354 258 | 25600 512 740 0 0:02:30.084266 259 | 25700 514 514 1 0:02:30.032617 260 | 25800 516 431 0 0:02:29.789236 261 | 25900 518 518 1 0:02:29.920027 262 | 26000 520 520 1 0:02:29.996022 263 | 26100 522 522 1 0:02:30.083467 264 | 26200 524 461 0 0:02:30.074240 265 | 26300 526 526 1 0:02:29.917449 266 | 26400 528 478 0 0:02:29.928078 267 | 26500 530 531 0 0:02:29.899624 268 | 26600 532 532 1 0:02:30.009423 269 | 26700 534 534 1 0:02:30.048200 270 | 26800 536 403 0 0:02:30.077787 271 | 26900 538 538 1 0:02:30.006213 272 | 27000 540 540 1 0:02:30.003153 273 | 27100 542 477 0 0:02:30.079738 274 | 27200 544 926 0 0:02:29.935786 275 | 27300 546 546 1 0:02:30.040292 276 | 27400 548 548 1 0:02:29.942731 277 | 27500 550 505 0 0:02:30.162599 278 | 27600 552 552 1 0:02:30.197462 279 | 27700 554 554 1 0:02:29.881878 280 | 27800 556 421 0 0:02:30.003082 281 | 27900 558 251 0 0:02:29.969123 282 | 28000 560 768 0 0:02:30.061316 283 | 28100 562 562 1 0:02:29.891487 284 | 28200 564 564 1 0:02:29.852345 285 | 28300 566 566 1 0:02:29.913517 286 | 28400 568 399 0 0:02:29.937221 287 | 28500 570 -1 0 0:02:29.939591 288 | 28600 572 418 0 0:02:30.063215 289 | 28700 574 574 1 0:02:29.918890 290 | 28800 576 576 1 0:02:30.023425 291 | 28900 578 578 1 0:02:29.880543 292 | 29000 580 738 0 0:02:29.832065 293 | 29100 582 582 1 0:02:30.020106 294 | 29200 584 754 0 0:02:30.010541 295 | 29300 586 586 1 0:02:29.950467 296 | 29400 588 588 1 0:02:29.812419 297 | 29500 590 590 1 0:02:29.732527 298 | 29600 592 592 1 0:02:29.842051 299 | 29700 
594 843 0 0:02:29.837906 300 | 29800 596 56 0 0:02:29.837575 301 | 29900 598 664 0 0:02:29.940911 302 | 30000 600 517 0 0:02:29.977010 303 | 30100 602 602 1 0:02:29.953506 304 | 30200 604 604 1 0:02:30.107783 305 | 30300 606 606 1 0:02:30.022856 306 | 30400 608 608 1 0:02:30.136491 307 | 30500 610 800 0 0:02:29.965505 308 | 30600 612 612 1 0:02:30.079032 309 | 30700 614 614 1 0:02:30.212737 310 | 30800 616 534 0 0:02:30.179169 311 | 30900 618 828 0 0:02:29.963279 312 | 31000 620 681 0 0:02:30.197027 313 | 31100 622 622 1 0:02:29.996470 314 | 31200 624 453 0 0:02:30.116919 315 | 31300 626 542 0 0:02:30.134142 316 | 31400 628 624 0 0:02:30.071232 317 | 31500 630 703 0 0:02:30.077622 318 | 31600 632 632 1 0:02:29.860154 319 | 31700 634 489 0 0:02:30.070970 320 | 31800 636 636 1 0:02:30.025989 321 | 31900 638 638 1 0:02:30.038827 322 | 32000 640 640 1 0:02:29.923118 323 | 32100 642 642 1 0:02:29.869806 324 | 32200 644 644 1 0:02:29.827098 325 | 32300 646 646 1 0:02:29.858985 326 | 32400 648 729 0 0:02:29.997431 327 | 32500 650 546 0 0:02:30.006125 328 | 32600 652 652 1 0:02:29.955058 329 | 32700 654 436 0 0:02:29.910161 330 | 32800 656 656 1 0:02:29.853518 331 | 32900 658 658 1 0:02:29.789470 332 | 33000 660 660 1 0:02:29.846462 333 | 33100 662 662 1 0:02:29.912663 334 | 33200 664 851 0 0:02:29.911196 335 | 33300 666 659 0 0:02:29.956095 336 | 33400 668 668 1 0:02:29.996228 337 | 33500 670 670 1 0:02:29.986417 338 | 33600 672 672 1 0:02:29.827720 339 | 33700 674 674 1 0:02:29.967539 340 | 33800 676 560 0 0:02:30.012731 341 | 33900 678 678 1 0:02:29.908836 342 | 34000 680 793 0 0:02:29.788307 343 | 34100 682 682 1 0:02:29.894981 344 | 34200 684 684 1 0:02:29.966279 345 | 34300 686 550 0 0:02:29.915007 346 | 34400 688 688 1 0:02:29.914888 347 | 34500 690 690 1 0:02:30.002952 348 | 34600 692 481 0 0:02:29.898020 349 | 34700 694 694 1 0:02:29.990453 350 | 34800 696 -1 0 0:02:29.979518 351 | 34900 698 698 1 0:02:29.978442 352 | 35000 700 700 1 0:02:29.829307 353 | 35100 702 
422 0 0:02:29.926219 354 | 35200 704 704 1 0:02:30.086647 355 | 35300 706 536 0 0:02:30.101645 356 | 35400 708 682 0 0:02:29.961313 357 | 35500 710 710 1 0:02:30.063833 358 | 35600 712 622 0 0:02:29.946512 359 | 35700 714 714 1 0:02:29.815917 360 | 35800 716 716 1 0:02:29.909783 361 | 35900 718 449 0 0:02:29.789578 362 | 36000 720 720 1 0:02:29.964153 363 | 36100 722 722 1 0:02:29.866087 364 | 36200 724 780 0 0:02:29.745715 365 | 36300 726 726 1 0:02:29.968047 366 | 36400 728 728 1 0:02:29.892893 367 | 36500 730 856 0 0:02:29.996930 368 | 36600 732 732 1 0:02:30.164158 369 | 36700 734 734 1 0:02:30.062919 370 | 36800 736 736 1 0:02:29.980641 371 | 36900 738 738 1 0:02:30.107528 372 | 37000 740 796 0 0:02:30.099491 373 | 37100 742 466 0 0:02:29.944660 374 | 37200 744 657 0 0:02:29.855458 375 | 37300 746 746 1 0:02:29.868380 376 | 37400 748 748 1 0:02:29.954246 377 | 37500 750 172 0 0:02:29.896917 378 | 37600 752 830 0 0:02:29.989432 379 | 37700 754 848 0 0:02:29.823816 380 | 37800 756 756 1 0:02:29.879869 381 | 37900 758 758 1 0:02:29.909168 382 | 38000 760 651 0 0:02:29.897002 383 | 38100 762 762 1 0:02:29.806436 384 | 38200 764 593 0 0:02:29.910020 385 | 38300 766 766 1 0:02:29.998693 386 | 38400 768 768 1 0:02:30.028741 387 | 38500 770 770 1 0:02:29.972508 388 | 38600 772 772 1 0:02:29.958522 389 | 38700 774 774 1 0:02:29.937070 390 | 38800 776 683 0 0:02:29.718288 391 | 38900 778 707 0 0:02:29.885120 392 | 39000 780 780 1 0:02:29.704233 393 | 39100 782 664 0 0:02:29.688010 394 | 39200 784 505 0 0:02:29.858668 395 | 39300 786 513 0 0:02:29.847906 396 | 39400 788 743 0 0:02:29.873123 397 | 39500 790 791 0 0:02:29.793333 398 | 39600 792 795 0 0:02:30.019487 399 | 39700 794 794 1 0:02:30.052389 400 | 39800 796 796 1 0:02:29.873432 401 | 39900 798 798 1 0:02:29.815152 402 | 40000 800 800 1 0:02:29.987385 403 | 40100 802 802 1 0:02:29.892680 404 | 40200 804 804 1 0:02:29.874876 405 | 40300 806 806 1 0:02:29.806446 406 | 40400 808 808 1 0:02:29.953785 407 | 40500 810 
878 0 0:02:29.862812 408 | 40600 812 812 1 0:02:30.005352 409 | 40700 814 814 1 0:02:29.910957 410 | 40800 816 816 1 0:02:29.965954 411 | 40900 818 745 0 0:02:30.010599 412 | 41000 820 820 1 0:02:29.833271 413 | 41100 822 822 1 0:02:29.792163 414 | 41200 824 824 1 0:02:29.873111 415 | 41300 826 826 1 0:02:29.984101 416 | 41400 828 618 0 0:02:29.830313 417 | 41500 830 830 1 0:02:29.967694 418 | 41600 832 832 1 0:02:29.843198 419 | 41700 834 417 0 0:02:29.856908 420 | 41800 836 444 0 0:02:29.943907 421 | 41900 838 631 0 0:02:29.875778 422 | 42000 840 840 1 0:02:29.994829 423 | 42100 842 842 1 0:02:30.153473 424 | 42200 844 844 1 0:02:29.808288 425 | 42300 846 583 0 0:02:30.014881 426 | 42400 848 848 1 0:02:30.071110 427 | 42500 850 850 1 0:02:30.020923 428 | 42600 852 852 1 0:02:29.896731 429 | 42700 854 854 1 0:02:29.905206 430 | 42800 856 856 1 0:02:29.816986 431 | 42900 858 832 0 0:02:29.966872 432 | 43000 860 892 0 0:02:30.035207 433 | 43100 862 862 1 0:02:29.931289 434 | 43200 864 864 1 0:02:29.939961 435 | 43300 866 866 1 0:02:29.974598 436 | 43400 868 968 0 0:02:29.877698 437 | 43500 870 870 1 0:02:29.989906 438 | 43600 872 840 0 0:02:29.967932 439 | 43700 874 874 1 0:02:29.878560 440 | 43800 876 648 0 0:02:29.746216 441 | 43900 878 878 1 0:02:29.868720 442 | 44000 880 870 0 0:02:29.887129 443 | 44100 882 882 1 0:02:29.831625 444 | 44200 884 884 1 0:02:29.907423 445 | 44300 886 886 1 0:02:29.848836 446 | 44400 888 718 0 0:02:29.699474 447 | 44500 890 890 1 0:02:30.080145 448 | 44600 892 442 0 0:02:29.891949 449 | 44700 894 894 1 0:02:29.743512 450 | 44800 896 896 1 0:02:29.778889 451 | 44900 898 508 0 0:02:29.825377 452 | 45000 900 900 1 0:02:29.843363 453 | 45100 902 902 1 0:02:29.768915 454 | 45200 904 905 0 0:02:29.945003 455 | 45300 906 906 1 0:02:29.963234 456 | 45400 908 895 0 0:02:29.794284 457 | 45500 910 910 1 0:02:29.881790 458 | 45600 912 912 1 0:02:29.877916 459 | 45700 914 914 1 0:02:29.756244 460 | 45800 916 916 1 0:02:29.751879 461 | 45900 918 
918 1 0:02:29.787862 462 | 46000 920 672 0 0:02:29.753602 463 | 46100 922 922 1 0:02:29.704010 464 | 46200 924 924 1 0:02:29.787180 465 | 46300 926 926 1 0:02:29.900919 466 | 46400 928 927 0 0:02:29.837027 467 | 46500 930 964 0 0:02:29.734897 468 | 46600 932 932 1 0:02:29.875139 469 | 46700 934 934 1 0:02:29.888191 470 | 46800 936 936 1 0:02:29.858023 471 | 46900 938 938 1 0:02:29.856562 472 | 47000 940 940 1 0:02:29.896224 473 | 47100 942 942 1 0:02:29.828978 474 | 47200 944 946 0 0:02:29.884413 475 | 47300 946 946 1 0:02:29.885259 476 | 47400 948 948 1 0:02:29.742021 477 | 47500 950 950 1 0:02:30.007190 478 | 47600 952 90 0 0:02:29.852872 479 | 47700 954 954 1 0:02:29.902331 480 | 47800 956 956 1 0:02:29.907287 481 | 47900 958 483 0 0:02:29.967186 482 | 48000 960 960 1 0:02:30.014844 483 | 48100 962 962 1 0:02:30.087343 484 | 48200 964 961 0 0:02:29.911554 485 | 48300 966 630 0 0:02:30.058535 486 | 48400 968 968 1 0:02:29.935223 487 | 48500 970 795 0 0:02:29.779368 488 | 48600 972 979 0 0:02:29.745436 489 | 48700 974 974 1 0:02:29.752444 490 | 48800 976 724 0 0:02:29.804450 491 | 48900 978 701 0 0:02:30.000623 492 | 49000 980 980 1 0:02:29.829837 493 | 49100 982 982 1 0:02:29.936440 494 | 49200 984 984 1 0:02:29.820193 495 | 49300 986 986 1 0:02:30.066937 496 | 49400 988 988 1 0:02:29.911113 497 | 49500 990 990 1 0:02:30.062292 498 | 49600 992 992 1 0:02:29.921232 499 | 49700 994 994 1 0:02:30.089058 500 | 49800 996 996 1 0:02:30.022316 501 | 49900 998 998 1 0:02:29.815372 502 | -------------------------------------------------------------------------------- /experiments.MD: -------------------------------------------------------------------------------- 1 | # Experiments 2 | 3 | This document describes how to replicate our results. 
4 | 5 | First, train the models on ImageNet and CIFAR-10: 6 | 7 | ``` 8 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_0.00 --batch 400 --noise 0.0 9 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_0.25 --batch 400 --noise 0.25 10 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_0.50 --batch 400 --noise 0.5 11 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_1.00 --batch 400 --noise 1.0 12 | 13 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.00 --batch 400 --noise 0.00 --gpu [num] 14 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.12 --batch 400 --noise 0.12 --gpu [num] 15 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.25 --batch 400 --noise 0.25 --gpu [num] 16 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.50 --batch 400 --noise 0.50 --gpu [num] 17 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_1.00 --batch 400 --noise 1.00 --gpu [num] 18 | ``` 19 | On ImageNet, `train.py` trains with synchronous SGD across all available GPUs; on CIFAR-10 it uses a single GPU. 
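As a rough illustration of what the `--noise` setting means during training, here is a minimal sketch of Gaussian data augmentation. It assumes, following the README's description, that `--noise` is the standard deviation σ of isotropic Gaussian noise added to each training input. The helper `add_gaussian_noise` and its list-of-lists image representation are hypothetical and chosen only to keep the example dependency-free; this is not the code in `train.py`.

```python
import random


def add_gaussian_noise(batch, sigma, rng=None):
    """Return a copy of `batch` with i.i.d. N(0, sigma^2) noise added to every value.

    `batch` is a list of flattened images (lists of floats). Hypothetical helper,
    not part of the repository.
    """
    rng = rng or random.Random()
    return [[pixel + rng.gauss(0.0, sigma) for pixel in image] for image in batch]


if __name__ == "__main__":
    # Two tiny "images" of four pixels each, perturbed at sigma = 0.25.
    clean = [[0.5] * 4 for _ in range(2)]
    noisy = add_gaussian_noise(clean, sigma=0.25, rng=random.Random(0))
    for image in noisy:
        print([round(pixel, 3) for pixel in image])
```

With `sigma=0.0` the batch is returned unchanged, which corresponds to the `noise_0.00` baseline models above.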
20 | 21 | Then, certify a subsample of the test set on ImageNet and CIFAR-10: 22 | 23 | ``` 24 | python code/certify.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25 --skip 100 --batch 400 25 | python code/certify.py imagenet models/imagenet/resnet50/noise_0.50/checkpoint.pth.tar 0.50 data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50 --skip 100 --batch 400 26 | python code/certify.py imagenet models/imagenet/resnet50/noise_1.00/checkpoint.pth.tar 1.00 data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00 --skip 100 --batch 400 27 | 28 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.12/checkpoint.pth.tar 0.12 data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12 --skip 20 --batch 400 29 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.25/checkpoint.pth.tar 0.25 data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25 --skip 20 --batch 400 30 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.50/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50 --skip 20 --batch 400 31 | python code/certify.py cifar10 models/cifar10/resnet110/noise_1.00/checkpoint.pth.tar 1.00 data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00 --skip 20 --batch 400 32 | ``` 33 | 34 | Then certify with mismatched training and testing noise levels: 35 | ``` 36 | python code/certify.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.50 data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50 --skip 100 --batch 400 37 | python code/certify.py imagenet models/imagenet/resnet50/noise_1.00/checkpoint.pth.tar 0.50 data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50 --skip 100 --batch 400 38 | 39 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.00/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.00/test/sigma_0.50 --skip 20 --batch 400 40 | python code/certify.py 
cifar10 models/cifar10/resnet110/noise_0.12/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.50 --skip 20 --batch 400 41 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.25/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50 --skip 20 --batch 400 42 | python code/certify.py cifar10 models/cifar10/resnet110/noise_1.00/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50 --skip 20 --batch 400 43 | ``` 44 | 45 | Prediction experiments on ImageNet: 46 | ``` 47 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_100 --N 100 --skip 100 --batch 400 48 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_1000 --N 1000 --skip 100 --batch 400 49 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_10000 --N 10000 --skip 100 --batch 400 50 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_100000 --N 100000 --skip 100 --batch 400 51 | ``` 52 | 53 | Finally, to visualize noisy images: 54 | ``` 55 | python code/visualize.py imagenet figures/example_images/imagenet 100 0.0 0.25 0.5 1.0 56 | python code/visualize.py imagenet figures/example_images/imagenet 5400 0.0 0.25 0.5 1.0 57 | python code/visualize.py imagenet figures/example_images/imagenet 9025 0.0 0.25 0.5 1.0 58 | python code/visualize.py imagenet figures/example_images/imagenet 19411 0.0 0.25 0.5 1.0 59 | 60 | python code/visualize.py cifar10 figures/example_images/cifar10 10 0.0 0.25 0.5 1.0 61 | python code/visualize.py cifar10 figures/example_images/cifar10 20 0.0 0.25 0.5 1.0 62 | python code/visualize.py cifar10 figures/example_images/cifar10 70 0.0 0.25 
0.5 1.0 63 | python code/visualize.py cifar10 figures/example_images/cifar10 110 0.0 0.25 0.5 1.0 64 | ``` -------------------------------------------------------------------------------- /figures/compare_bounds.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/compare_bounds.pdf -------------------------------------------------------------------------------- /figures/example_images/cifar10/10_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_0.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/10_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_100.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/10_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_25.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/10_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_50.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/110_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_0.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/110_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_100.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/110_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_25.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/110_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_50.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/20_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_0.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/20_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_100.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/20_25.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_25.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/20_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_50.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/70_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_0.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/70_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_100.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/70_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_25.png -------------------------------------------------------------------------------- /figures/example_images/cifar10/70_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_50.png -------------------------------------------------------------------------------- 
/figures/example_images/imagenet/100_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_0.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/100_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_100.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/100_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_25.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/100_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_50.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/19411_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_0.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/19411_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_100.png 
-------------------------------------------------------------------------------- /figures/example_images/imagenet/19411_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_25.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/19411_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_50.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/3300_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_0.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/3300_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_100.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/3300_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_25.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/3300_50.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_50.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/5400_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_0.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/5400_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_100.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/5400_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_25.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/5400_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_50.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/9067_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_0.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/9067_100.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_100.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/9067_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_25.png -------------------------------------------------------------------------------- /figures/example_images/imagenet/9067_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_50.png -------------------------------------------------------------------------------- /figures/panda_0.25.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_0.25.gif -------------------------------------------------------------------------------- /figures/panda_0.50.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_0.50.gif -------------------------------------------------------------------------------- /figures/panda_1.00.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_1.00.gif -------------------------------------------------------------------------------- /figures/panda_577.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_577.png -------------------------------------------------------------------------------- /figures/radiusslow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/radiusslow.pdf --------------------------------------------------------------------------------