├── .gitignore
├── README.md
├── analysis
│   ├── latex
│   │   ├── vary_noise_cifar10
│   │   └── vary_noise_imagenet
│   ├── markdown
│   │   ├── vary_noise_cifar10
│   │   └── vary_noise_imagenet
│   └── plots
│       ├── high_prob.pdf
│       ├── high_prob.png
│       ├── vary_noise_cifar10.pdf
│       ├── vary_noise_cifar10.png
│       ├── vary_noise_imagenet.pdf
│       ├── vary_noise_imagenet.png
│       ├── vary_train_noise_cifar_050.pdf
│       ├── vary_train_noise_cifar_050.png
│       ├── vary_train_noise_imagenet_050.pdf
│       └── vary_train_noise_imagenet_050.png
├── code
│   ├── __pycache__
│   │   └── bounds.cpython-36.pyc
│   ├── analyze.py
│   ├── analyze_predict.py
│   ├── architectures.py
│   ├── archs
│   │   └── cifar_resnet.py
│   ├── certify.py
│   ├── core.py
│   ├── datasets.py
│   ├── predict.py
│   ├── train.py
│   ├── train_utils.py
│   └── visualize.py
├── data
│   ├── certify
│   │   ├── cifar10
│   │   │   └── resnet110
│   │   │       ├── noise_0.12
│   │   │       │   └── test
│   │   │       │       ├── sigma_0.12
│   │   │       │       └── sigma_0.50
│   │   │       ├── noise_0.25
│   │   │       │   └── test
│   │   │       │       ├── sigma_0.25
│   │   │       │       └── sigma_0.50
│   │   │       ├── noise_0.50
│   │   │       │   └── test
│   │   │       │       └── sigma_0.50
│   │   │       └── noise_1.00
│   │   │           └── test
│   │   │               ├── sigma_0.50
│   │   │               └── sigma_1.00
│   │   └── imagenet
│   │       └── resnet50
│   │           ├── noise_0.25
│   │           │   └── test
│   │           │       ├── sigma_0.25
│   │           │       └── sigma_0.50
│   │           ├── noise_0.50
│   │           │   ├── test
│   │           │   │   └── sigma_0.50
│   │           │   └── train
│   │           │       └── sigma_0.50
│   │           └── noise_1.00
│   │               └── test
│   │                   ├── sigma_0.50
│   │                   └── sigma_1.00
│   └── predict
│       └── imagenet
│           └── resnet50
│               └── noise_0.25
│                   └── test
│                       ├── N_100
│                       ├── N_1000
│                       ├── N_10000
│                       └── N_100000
├── experiments.MD
└── figures
    ├── compare_bounds.pdf
    ├── example_images
    │   ├── cifar10
    │   │   ├── 10_0.png
    │   │   ├── 10_100.png
    │   │   ├── 10_25.png
    │   │   ├── 10_50.png
    │   │   ├── 110_0.png
    │   │   ├── 110_100.png
    │   │   ├── 110_25.png
    │   │   ├── 110_50.png
    │   │   ├── 20_0.png
    │   │   ├── 20_100.png
    │   │   ├── 20_25.png
    │   │   ├── 20_50.png
    │   │   ├── 70_0.png
    │   │   ├── 70_100.png
    │   │   ├── 70_25.png
    │   │   └── 70_50.png
    │   └── imagenet
    │       ├── 100_0.png
    │       ├── 100_100.png
    │       ├── 100_25.png
    │       ├── 100_50.png
    │       ├── 19411_0.png
    │       ├── 19411_100.png
    │       ├── 19411_25.png
    │       ├── 19411_50.png
    │       ├── 3300_0.png
    │       ├── 3300_100.png
    │       ├── 3300_25.png
    │       ├── 3300_50.png
    │       ├── 5400_0.png
    │       ├── 5400_100.png
    │       ├── 5400_25.png
    │       ├── 5400_50.png
    │       ├── 9067_0.png
    │       ├── 9067_100.png
    │       ├── 9067_25.png
    │       └── 9067_50.png
    ├── panda_0.25.gif
    ├── panda_0.50.gif
    ├── panda_1.00.gif
    ├── panda_577.png
    └── radiusslow.pdf
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *~
3 | models
4 | .DS_Store
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Certified Adversarial Robustness via Randomized Smoothing
2 |
3 | This repository contains code and trained models for the paper [Certified Adversarial Robustness via Randomized Smoothing](https://arxiv.org/abs/1902.02918) by [Jeremy Cohen](http://cs.cmu.edu/~jeremiac), Elan Rosenfeld, and [Zico Kolter](http://zicokolter.com).
4 |
5 | Randomized smoothing is a **provable** adversarial defense in L2 norm which **scales to ImageNet.**
6 | It's also SOTA on the smaller datasets like CIFAR-10 and SVHN where other provable L2-robust classifiers are viable.
7 |
8 | ## How does it work?
9 |
10 | First, you train a neural network _f_ with Gaussian data augmentation at variance σ².
11 | Then you leverage _f_ to create a new, "smoothed" classifier _g_, defined as follows:
12 | _g(x)_ returns the class which _f_ is most likely to return when _x_
13 | is corrupted by isotropic Gaussian noise with variance σ².
14 | [Figure: the panda image _x_ (left) and a GIF of _x_ corrupted by Gaussian noise (right)]
21 | For example, let _x_ be the image above on the left.
22 | Suppose that when _f_ classifies _x_ corrupted by Gaussian noise (the GIF on the right), _f_ returns "panda"
23 | 98\% of the time and "gibbon" 2% of the time.
24 | Then the prediction of _g_ at _x_ is defined to be "panda."
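The following minimal Python sketch illustrates this definition. It assumes a toy base classifier on 2-D points in place of a neural network. The names `toy_f` and `smoothed_predict` are hypothetical and are not functions from this repository. The sketch estimates _g(x)_ by sampling noisy copies of _x_ and returning the class that _f_ outputs most often. Unlike the repository's `Smooth.predict`, it never abstains.

```python
import random
from collections import Counter
from typing import Callable, Sequence


def toy_f(point: Sequence[float]) -> str:
    """Hypothetical base classifier: 'panda' right of the vertical axis, else 'gibbon'."""
    return "panda" if point[0] > 0 else "gibbon"


def smoothed_predict(f: Callable[[Sequence[float]], str], x: Sequence[float],
                     sigma: float, n: int = 1000, seed: int = 0) -> str:
    """Monte Carlo estimate of g(x): the class f returns most often on x + N(0, sigma^2 I)."""
    rng = random.Random(seed)
    counts = Counter()
    for _ in range(n):
        # isotropic noise: every coordinate gets independent N(0, sigma^2) noise
        noisy = [xi + rng.gauss(0.0, sigma) for xi in x]
        counts[f(noisy)] += 1
    return counts.most_common(1)[0][0]


print(smoothed_predict(toy_f, [0.5, 0.0], sigma=0.25))  # prints "panda"
```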
25 |
26 |
27 | Interestingly, _g_ is **provably** robust within an L2 norm ball around _x_, in the sense that for any perturbation
28 | δ with sufficiently small L2 norm, _g(x+δ)_ is guaranteed to be "panda."
29 | In this particular example, _g_ will be robust around _x_ within an L2 radius of σ Φ⁻¹(0.98) ≈ 2.05 σ,
30 | where Φ⁻¹ is the inverse CDF of the standard normal distribution.
31 |
32 | In general, suppose that when _f_ classifies noisy corruptions of _x_, the class "panda" is returned with probability _p_ (with _p_ > 0.5).
33 | Then _g_ is guaranteed to classify "panda" within an L2 ball around _x_ of radius σ Φ⁻¹(_p_).
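The radius formula is easy to evaluate with the Python standard library (3.8+). The sketch below is a hypothetical helper and is not part of [core.py](code/core.py). It treats _p_ as known exactly. In practice _p_ must be estimated from samples, as explained under "Wait a minute..." further down.

```python
import math
from statistics import NormalDist


def certified_radius(p: float, sigma: float) -> float:
    """L2 radius sigma * Phi^-1(p) around x within which g is guaranteed to keep the top class.

    p is the probability that f returns the top class on Gaussian-corrupted copies of x.
    """
    if p <= 0.5:
        return 0.0  # no guarantee unless the top class has probability above 0.5
    if p >= 1.0:
        return math.inf  # Phi^-1(1) is infinite
    return sigma * NormalDist().inv_cdf(p)


print(round(certified_radius(0.98, sigma=1.0), 2))  # 2.05
```

The radius grows linearly in σ, and it grows without bound as _p_ approaches 1.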
34 |
35 | ### What's the intuition behind this bound?
36 |
37 | We know that _f_ classifies noisy corruptions of _x_ as "panda" with probability 0.98.
38 | An equivalent way of phrasing this is that the Gaussian distribution N(x, σ²I) puts measure 0.98 on
39 | the decision region of class "panda," defined as the set {x': f(x') = "panda"}.
40 | You can prove that no matter how the decision regions of _f_ are "shaped", for any δ with
41 | ||δ||₂ < σ Φ⁻¹(0.98), the translated Gaussian N(x+δ, σ²I) is guaranteed to put measure > 0.5 on the decision region of
42 | class "panda," implying that _g(x+δ)_ = "panda."
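Here is a quick numerical check of this claim for the special case where the "panda" region is a half-space. The bound is tight in that case. Work in one dimension with _x_ = 0 and place the boundary at σ Φ⁻¹(_p_), so that N(0, σ²) gives the region measure _p_. Shifting the Gaussian toward the boundary by δ leaves measure Φ(Φ⁻¹(_p_) - δ/σ), which stays above 0.5 exactly while δ < σ Φ⁻¹(_p_). The helper names are hypothetical. The Monte Carlo estimate is included only to cross-check the closed form.

```python
import random
from statistics import NormalDist

PHI = NormalDist()  # standard normal: PHI.cdf is Phi, PHI.inv_cdf is Phi^-1


def shifted_measure(p: float, delta: float, sigma: float) -> float:
    """Measure that N(delta, sigma^2) puts on {t < sigma * Phi^-1(p)}, a 1-D half-space
    that N(0, sigma^2) gives measure p. Closed form: Phi(Phi^-1(p) - delta / sigma)."""
    return PHI.cdf(PHI.inv_cdf(p) - delta / sigma)


def shifted_measure_mc(p: float, delta: float, sigma: float,
                       n: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same measure, as a cross-check."""
    rng = random.Random(seed)
    boundary = sigma * PHI.inv_cdf(p)
    hits = sum(rng.gauss(delta, sigma) < boundary for _ in range(n))
    return hits / n


p, sigma = 0.98, 1.0
print("radius:", round(sigma * PHI.inv_cdf(p), 3))
for delta in (0.0, 1.0, 2.0, 2.1):
    # the measure stays above 0.5 exactly while delta < sigma * Phi^-1(p)
    print(delta, round(shifted_measure(p, delta, sigma), 3), round(shifted_measure_mc(p, delta, sigma), 3))
```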
43 |
44 | ### Wait a minute...
45 | There's one catch: it's not possible to actually evaluate the smoothed classifier _g_.
46 | This is because it's not possible to exactly compute the probability distribution over the classes when _f_'s input is corrupted by Gaussian noise.
47 | For the same reason, it's not possible to exactly compute the radius in which _g_ is provably robust.
48 |
49 | Instead, we give Monte Carlo algorithms for both
50 | 1. **prediction**: evaluating _g_(x)
51 | 2. **certification**: computing the L2 radius in which _g_ is robust around _x_
52 |
53 | which are guaranteed to return a correct answer with arbitrarily high probability.
54 |
55 | The prediction algorithm does this by abstaining from making any prediction when it's a "close call," e.g. if
56 | 510 noisy corruptions of _x_ were classified as "panda" and 490 were classified as "gibbon."
57 | Prediction is pretty cheap, since you don't need to use very many samples.
58 | For example, with our ImageNet classifier, making a prediction using 1000 samples took 1.5 seconds, and our classifier abstained 3\% of the time.
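The following is a hedged sketch of the abstain-on-close-calls idea. It assumes a two-sided binomial test on the counts of the two most frequent classes. The sketch abstains unless the top class wins significantly at level `alpha`; it returns `None` to abstain, whereas `Smooth.predict` returns `-1`. The helper names are hypothetical, and the exact rule in [core.py](code/core.py) may differ.

```python
import math
from collections import Counter
from typing import Hashable, Iterable, Optional


def binom_two_sided_p_value(k: int, n: int) -> float:
    """Exact two-sided binomial test of H0: success probability = 0.5."""
    k = max(k, n - k)  # the null distribution is symmetric, so test the larger tail
    upper_tail = sum(math.comb(n, i) for i in range(k, n + 1)) / 2 ** n
    return min(1.0, 2 * upper_tail)


def predict_or_abstain(samples: Iterable[Hashable], alpha: float) -> Optional[Hashable]:
    """Return the top class among f's (non-empty) noisy predictions, or None on a close call."""
    counts = Counter(samples).most_common(2)
    top_class, n_a = counts[0]
    n_b = counts[1][1] if len(counts) > 1 else 0
    # abstain unless the top class beats the runner-up significantly at level alpha
    if binom_two_sided_p_value(n_a, n_a + n_b) > alpha:
        return None
    return top_class


print(predict_or_abstain(["panda"] * 510 + ["gibbon"] * 490, alpha=0.001))  # None
print(predict_or_abstain(["panda"] * 980 + ["gibbon"] * 20, alpha=0.001))   # panda
```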
59 |
60 | On the other hand, certification is pretty slow, since you need _a lot_ of samples to say with high
61 | probability that the measure under N(x, σ²I) of the "panda" decision region is close to 1.
62 | In our experiments we used 100,000 samples, so each certification took 150 seconds.
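To see why certification needs so many samples, here is a sketch of the estimation step. It lower-bounds _p_ with a one-sided Clopper-Pearson (exact binomial) confidence bound at level 1 - `alpha`. The bound is computed by bisection, so only the standard library is needed. The sketch returns σ Φ⁻¹ of that bound, or 0.0 to abstain when the bound is not above 0.5. It is an illustrative stand-in with hypothetical names, not the code in [core.py](code/core.py), and it skips the separate `n0` selection step.

```python
import math
from statistics import NormalDist


def binom_upper_tail(k: int, n: int, p: float) -> float:
    """P(X >= k) for X ~ Binomial(n, p) with 0 <= k <= n, summed in log space to avoid overflow."""
    if k <= 0:
        return 1.0
    if p <= 0.0:
        return 0.0
    if p >= 1.0:
        return 1.0
    log_p, log_q = math.log(p), math.log1p(-p)
    log_terms = [
        math.lgamma(n + 1) - math.lgamma(i + 1) - math.lgamma(n - i + 1) + i * log_p + (n - i) * log_q
        for i in range(k, n + 1)
    ]
    m = max(log_terms)
    return math.exp(m) * math.fsum(math.exp(t - m) for t in log_terms)


def clopper_pearson_lower(k: int, n: int, alpha: float) -> float:
    """One-sided (1 - alpha) Clopper-Pearson lower bound on p, given k successes in n trials."""
    if k == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    # P(X >= k) increases with p; the bound is the p at which it equals alpha
    for _ in range(100):
        mid = (lo + hi) / 2
        if binom_upper_tail(k, n, mid) < alpha:
            lo = mid
        else:
            hi = mid
    return lo


def certified_radius_from_counts(k: int, n: int, alpha: float, sigma: float) -> float:
    """Radius sigma * Phi^-1(p_lower), or 0.0 (abstain) when p_lower <= 0.5.

    k is how many of n noisy samples f assigned to the (already selected) top class.
    """
    p_lower = clopper_pearson_lower(k, n, alpha)
    if p_lower <= 0.5:
        return 0.0
    return sigma * NormalDist().inv_cdf(p_lower)


for n in (100, 1_000, 100_000):
    # every noisy sample classified as the top class, sigma = 0.5, alpha = 0.001
    print(n, round(certified_radius_from_counts(n, n, alpha=0.001, sigma=0.5), 3))
```

Take σ = 0.5 with every sample classified as "panda". Then 100 samples certify a radius of only about 0.75, while 100,000 samples certify about 1.9.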
63 |
64 | ### Related work
65 |
66 | Randomized smoothing was first proposed in [Certified Robustness to Adversarial Examples with Differential Privacy](https://arxiv.org/abs/1802.03471)
67 | and later improved upon in [Second-Order Adversarial Attack and Certified Robustness](https://arxiv.org/abs/1809.03113).
68 | We simply tightened the analysis and showed that it outperforms the other provably L2-robust classifiers that have been proposed in the literature.
69 |
70 | ## ImageNet results
71 |
72 | We constructed three randomized smoothing classifiers for ImageNet, with the hyperparameter
73 | σ set to 0.25, 0.50, and 1.00.
74 | Here's what the panda image looks like under these three noise levels:
75 |
76 | [Figure: the panda image corrupted by Gaussian noise with σ = 0.25, 0.50, and 1.00]
83 | The plot below shows the certified top-1 accuracy of these three classifiers at various radii.
84 | The "certified accuracy" of a classifier _g_ at radius _r_ is defined as the test set accuracy that _g_ will
85 | provably attain under any possible adversarial attack with L2 norm less than _r_.
86 | As you can see, the hyperparameter σ controls a robustness/accuracy tradeoff: when
87 | σ is high, the standard accuracy is lower, but the classifier's correct predictions are
88 | robust within larger radii.
89 |
90 | [Figure: certified top-1 accuracy as a function of radius for the three ImageNet classifiers]
92 | To put these numbers in context: on ImageNet, random guessing would achieve a top-1 accuracy of 0.001.
93 | A perturbation with L2 norm of 1.0 could change one pixel by 255, ten pixels by 80, 100 pixels by 25, or 1000 pixels by 8.
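These figures assume that pixel values are on a 0 to 255 scale and that the L2 norm is measured on images rescaled to [0, 1]. They also assume each "pixel" is a single color-channel value. This is our reading; the text above does not state it. Under that assumption, changing _k_ values by the same amount _c_ gives a norm of _c_·√_k_. The largest equal change is therefore 255/√_k_ on the 0 to 255 scale, as this hypothetical helper shows:

```python
import math


def per_pixel_change(num_pixels: int, l2_norm: float = 1.0) -> float:
    """Largest equal change (0-255 scale) to num_pixels values whose L2 norm,
    measured on [0, 1]-scaled images, is l2_norm (an assumed convention)."""
    # k equal changes of size c (in [0, 1] units) have L2 norm c * sqrt(k)
    return 255 * l2_norm / math.sqrt(num_pixels)


for k in (1, 10, 100, 1000):
    print(k, round(per_pixel_change(k), 1))
```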
94 |
95 |
96 | Here's the same data in tabular form.
97 | The best σ for each radius is denoted with an asterisk.
98 |
99 | | | r = 0.0 |r = 0.5 |r = 1.0 |r = 1.5 |r = 2.0 |r = 2.5 |r = 3.0 |
100 | | --- | --- | --- | --- | --- | --- | --- | --- |
101 | σ = 0.25 | 0.67* |0.49* |0.00 |0.00 |0.00 |0.00 |0.00 |
102 | σ = 0.50 | 0.57 |0.46 |0.38* |0.28* |0.00 |0.00 |0.00 |
103 | σ = 1.00 | 0.44 |0.38 |0.33 |0.26 |0.19* |0.15* |0.12* |
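Each entry in this table is a fraction over the test set. An example counts if _g_ classifies it correctly and its certified radius is at least _r_. [code/analyze.py](code/analyze.py) computes this with pandas in `ApproximateAccuracy.at_radius`, using the `correct` and `radius` columns of the certification output. The dependency-free sketch below does the same on a list of `(correct, radius)` pairs. It assumes abstentions are recorded as incorrect, and its data is made up for illustration.

```python
from typing import List, Tuple


def certified_accuracy(results: List[Tuple[bool, float]], r: float) -> float:
    """Fraction of examples that are classified correctly with certified radius >= r."""
    if not results:
        raise ValueError("need at least one result")
    hits = sum(1 for correct, radius in results if correct and radius >= r)
    return hits / len(results)


# made-up (correct, radius) pairs for four test images, not data from the paper
toy = [(True, 0.9), (True, 0.3), (False, 0.0), (False, 1.2)]
for r in (0.0, 0.5, 1.0):
    print(r, certified_accuracy(toy, r))
```

A wrong prediction never counts, however large its certified radius.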
104 |
105 |
106 |
118 |
119 | ## This repository
120 |
121 | ### Outline
122 |
123 | The contents of this repository are as follows:
124 |
125 | * [code/](code) contains the code for our experiments.
126 | * [data/](data) contains the raw data from our experiments.
127 | * [analysis/](analysis) contains the plots and tables, based on the contents of [data](/data), that are shown in our paper.
128 |
129 | If you'd like to run our code, you need to download our models from [here](https://drive.google.com/file/d/1h_TpbXm5haY5f-l4--IKylmdz6tvPoR4/view?usp=sharing)
130 | and then move the directory `models` into the root directory of this repo.
131 |
132 | ### Smoothed classifiers
133 |
134 | Randomized smoothing is implemented in the `Smooth` class in [core.py](code/core.py).
135 |
136 | * To instantiate a smoothed classifier _g_, use the constructor:
137 |
138 | ```def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float):```
139 |
140 | where `base_classifier` is a PyTorch module that implements _f_, `num_classes` is the number of classes in the output
141 | space, and `sigma` is the noise hyperparameter σ.
142 |
143 | * To make a prediction at an input `x`, call:
144 |
145 | ``` def predict(self, x: torch.tensor, n: int, alpha: float, batch_size: int) -> int:```
146 |
147 | where `n` is the number of Monte Carlo samples and `alpha` is the allowed failure probability (so `1 - alpha` is the confidence level).
148 | This function will either (1) return `-1` to abstain or (2) return a class which equals _g(x)_
149 | with probability at least `1 - alpha`.
150 |
151 | * To compute a radius in which _g_ is robust around an input `x`, call:
152 |
153 | ```def certify(self, x: torch.tensor, n0: int, n: int, alpha: float, batch_size: int) -> (int, float):```
154 |
155 | where `n0` is the number of Monte Carlo samples to use for selection (see the paper), `n` is the number of Monte Carlo
156 | samples to use for estimation, and `alpha` is the allowed failure probability (so `1 - alpha` is the confidence level).
157 | This function will either return the pair `(-1, 0.0)` to abstain, or return a pair
158 | `(prediction, radius)`. The probability that `certify()` will return a class not equal to _g(x)_ is no greater than `alpha`. Another way to say this is that with probability at least `1 - alpha`, `certify()` will either abstain or return _g(x)_.
159 |
160 | ### Scripts
161 |
162 | * The program [train.py](code/train.py) trains a base classifier with Gaussian data augmentation:
163 |
164 | ```python code/train.py imagenet resnet50 model_output_dir --batch 400 --noise 0.50 ```
165 |
166 | will train a ResNet-50 on ImageNet under Gaussian data augmentation with σ=0.50.
167 |
168 | * The program [predict.py](code/predict.py) makes predictions using _g_ on a bunch of inputs. For example,
169 |
170 | ```python code/predict.py imagenet model_output_dir/checkpoint.pth.tar 0.50 prediction_output --alpha 0.001 --N 1000 --skip 100 --batch 400```
171 |
172 | will load the base classifier saved at `model_output_dir/checkpoint.pth.tar`, smooth it using noise level σ=0.50,
173 | and classify every 100th image from the ImageNet test set with parameters `N=1000`
174 | and `alpha=0.001`.
175 |
176 | * The program [certify.py](code/certify.py) certifies the robustness of _g_ on a bunch of inputs. For example,
177 |
178 | ```python code/certify.py imagenet model_output_dir/checkpoint.pth.tar 0.50 certification_output --alpha 0.001 --N0 100 --N 100000 --skip 100 --batch 400```
179 |
180 | will load the base classifier saved at `model_output_dir/checkpoint.pth.tar`, smooth it using noise level σ=0.50,
181 | and certify every 100th image from the ImageNet test set with parameters `N0=100`, `N=100000`
182 | and `alpha=0.001`.
183 |
184 | * The program [visualize.py](code/visualize.py) outputs pictures of noisy examples. For example,
185 |
186 | ```python code/visualize.py imagenet visualize_output 100 0.0 0.25 0.5 1.0```
187 |
188 | will visualize noisy corruptions of the 100th image from the ImageNet test set with noise levels
189 | σ=0.0, σ=0.25, σ=0.50, and σ=1.00.
190 |
191 | * The program [analyze.py](code/analyze.py) generates all of the certified accuracy plots and tables that appeared in the
192 | paper.
193 |
194 | Finally, we note that [this file](experiments.MD) describes exactly how to reproduce
195 | our experiments from the paper.
196 |
197 |
198 | We're not officially releasing code for the experiments where we compared randomized smoothing against the baselines,
199 | since that code involved a number of hacks, but feel free to get in touch if you'd like to see that code.
200 |
201 | ## Getting started
202 |
203 | 1. Clone this repository: `git clone git@github.com:locuslab/smoothing.git`
204 |
205 | 2. Install the dependencies:
206 | ```
207 | conda create -n smoothing
208 | conda activate smoothing
209 | # below is for linux, with CUDA 10; see https://pytorch.org/ for the correct command for your system
210 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
211 | conda install scipy pandas statsmodels matplotlib seaborn
212 | pip install setGPU
213 | ```
214 | 3. Download our trained models from [here](https://drive.google.com/file/d/1h_TpbXm5haY5f-l4--IKylmdz6tvPoR4/view?usp=sharing).
215 | 4. If you want to run ImageNet experiments, obtain a copy of ImageNet and preprocess the `val` directory to look
216 | like the `train` directory by running [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
217 | Finally, set the environment variable `IMAGENET_DIR` to the directory where ImageNet is located.
218 |
219 | 5. To get the hang of things, try running this command, which will certify the robustness of one of our pretrained CIFAR-10 models
220 | on the CIFAR test set.
221 | ```
222 | model="models/cifar10/resnet110/noise_0.25/checkpoint.pth.tar"
223 | output="???"
224 | python code/certify.py cifar10 $model 0.25 $output --skip 20 --batch 400
225 | ```
226 | where `???` is your desired output file.
227 |
--------------------------------------------------------------------------------
/analysis/latex/vary_noise_cifar10:
--------------------------------------------------------------------------------
1 | & $r = 0.25$& $r = 0.5$& $r = 0.75$& $r = 1.0$& $r = 1.25$& $r = 1.5$\\
2 | \midrule
3 | $\sigma = 0.12$ & 0.59 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00\\
4 | $\sigma = 0.25$ & \textbf{0.60} & \textbf{0.43} & 0.27 & 0.00 & 0.00 & 0.00\\
5 | $\sigma = 0.50$ & 0.55 & 0.41 & \textbf{0.32} & \textbf{0.23} & 0.15 & 0.09\\
6 | $\sigma = 1.00$ & 0.39 & 0.34 & 0.28 & 0.22 & \textbf{0.17} & \textbf{0.14}\\
7 |
--------------------------------------------------------------------------------
/analysis/latex/vary_noise_imagenet:
--------------------------------------------------------------------------------
1 | & $r = 0.5$& $r = 1.0$& $r = 1.5$& $r = 2.0$& $r = 2.5$& $r = 3.0$\\
2 | \midrule
3 | $\sigma = 0.25$ & \textbf{0.49} & 0.00 & 0.00 & 0.00 & 0.00 & 0.00\\
4 | $\sigma = 0.50$ & 0.46 & \textbf{0.37} & \textbf{0.29} & 0.00 & 0.00 & 0.00\\
5 | $\sigma = 1.00$ & 0.38 & 0.33 & 0.26 & \textbf{0.19} & \textbf{0.15} & \textbf{0.12}\\
6 |
--------------------------------------------------------------------------------
/analysis/markdown/vary_noise_cifar10:
--------------------------------------------------------------------------------
1 | | | r = 0.25 |r = 0.5 |r = 0.75 |r = 1.0 |r = 1.25 |r = 1.5 |
2 | | --- | --- | --- | --- | --- | --- | --- |
3 | σ = 0.12 | 0.59 |0.00 |0.00 |0.00 |0.00 |0.00 |
4 | σ = 0.25 | 0.60* |0.43* |0.27 |0.00 |0.00 |0.00 |
5 | σ = 0.50 | 0.55 |0.41 |0.32* |0.23* |0.15 |0.09 |
6 | σ = 1.00 | 0.39 |0.34 |0.28 |0.22 |0.17* |0.14* |
7 |
--------------------------------------------------------------------------------
/analysis/markdown/vary_noise_imagenet:
--------------------------------------------------------------------------------
1 | | | r = 0.5 |r = 1.0 |r = 1.5 |r = 2.0 |r = 2.5 |r = 3.0 |
2 | | --- | --- | --- | --- | --- | --- | --- |
3 | σ = 0.25 | 0.49* |0.00 |0.00 |0.00 |0.00 |0.00 |
4 | σ = 0.50 | 0.46 |0.37* |0.29* |0.00 |0.00 |0.00 |
5 | σ = 1.00 | 0.38 |0.33 |0.26 |0.19* |0.15* |0.12* |
6 |
--------------------------------------------------------------------------------
/analysis/plots/high_prob.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/high_prob.pdf
--------------------------------------------------------------------------------
/analysis/plots/high_prob.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/high_prob.png
--------------------------------------------------------------------------------
/analysis/plots/vary_noise_cifar10.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_cifar10.pdf
--------------------------------------------------------------------------------
/analysis/plots/vary_noise_cifar10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_cifar10.png
--------------------------------------------------------------------------------
/analysis/plots/vary_noise_imagenet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_imagenet.pdf
--------------------------------------------------------------------------------
/analysis/plots/vary_noise_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_noise_imagenet.png
--------------------------------------------------------------------------------
/analysis/plots/vary_train_noise_cifar_050.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_cifar_050.pdf
--------------------------------------------------------------------------------
/analysis/plots/vary_train_noise_cifar_050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_cifar_050.png
--------------------------------------------------------------------------------
/analysis/plots/vary_train_noise_imagenet_050.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_imagenet_050.pdf
--------------------------------------------------------------------------------
/analysis/plots/vary_train_noise_imagenet_050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/analysis/plots/vary_train_noise_imagenet_050.png
--------------------------------------------------------------------------------
/code/__pycache__/bounds.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/code/__pycache__/bounds.cpython-36.pyc
--------------------------------------------------------------------------------
/code/analyze.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib

matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from typing import *
import pandas as pd
import seaborn as sns
import math

sns.set()


class Accuracy(object):
    def at_radii(self, radii: np.ndarray):
        raise NotImplementedError()


class ApproximateAccuracy(Accuracy):
    def __init__(self, data_file_path: str):
        self.data_file_path = data_file_path

    def at_radii(self, radii: np.ndarray) -> np.ndarray:
        df = pd.read_csv(self.data_file_path, delimiter="\t")
        return np.array([self.at_radius(df, radius) for radius in radii])

    def at_radius(self, df: pd.DataFrame, radius: float):
        # fraction of examples that are classified correctly AND certified at a radius >= `radius`
        return (df["correct"].astype(bool) & (df["radius"] >= radius)).mean()


class HighProbAccuracy(Accuracy):
    def __init__(self, data_file_path: str, alpha: float, rho: float):
        self.data_file_path = data_file_path
        self.alpha = alpha
        self.rho = rho

    def at_radii(self, radii: np.ndarray) -> np.ndarray:
        df = pd.read_csv(self.data_file_path, delimiter="\t")
        return np.array([self.at_radius(df, radius) for radius in radii])

    def at_radius(self, df: pd.DataFrame, radius: float):
        # a more conservative version of ApproximateAccuracy: subtract alpha (the per-example
        # failure probability of certification) and two deviation terms that shrink as the
        # number of examples grows and grow as rho decreases
        mean = (df["correct"].astype(bool) & (df["radius"] >= radius)).mean()
        num_examples = len(df)
        return (mean - self.alpha - math.sqrt(self.alpha * (1 - self.alpha) * math.log(1 / self.rho) / num_examples)
                - math.log(1 / self.rho) / (3 * num_examples))


class Line(object):
    def __init__(self, quantity: Accuracy, legend: str, plot_fmt: str = "", scale_x: float = 1):
        self.quantity = quantity
        self.legend = legend
        self.plot_fmt = plot_fmt
        self.scale_x = scale_x


def plot_certified_accuracy(outfile: str, title: str, max_radius: float,
                            lines: List[Line], radius_step: float = 0.01) -> None:
    radii = np.arange(0, max_radius + radius_step, radius_step)
    plt.figure()
    for line in lines:
        plt.plot(radii * line.scale_x, line.quantity.at_radii(radii), line.plot_fmt)

    plt.ylim((0, 1))
    plt.xlim((0, max_radius))
    plt.tick_params(labelsize=14)
    plt.xlabel("radius", fontsize=16)
    plt.ylabel("certified accuracy", fontsize=16)
    plt.legend([method.legend for method in lines], loc='upper right', fontsize=16)
    # the PDF is saved without a title; the PNG below gets one
    plt.savefig(outfile + ".pdf")
    plt.tight_layout()
    plt.title(title, fontsize=20)
    plt.tight_layout()
    plt.savefig(outfile + ".png", dpi=300)
    plt.close()


def smallplot_certified_accuracy(outfile: str, title: str, max_radius: float,
                                 methods: List[Line], radius_step: float = 0.01, xticks=0.5) -> None:
    radii = np.arange(0, max_radius + radius_step, radius_step)
    plt.figure()
    for method in methods:
        plt.plot(radii, method.quantity.at_radii(radii), method.plot_fmt)

    plt.ylim((0, 1))
    plt.xlim((0, max_radius))
    plt.xlabel("radius", fontsize=22)
    plt.ylabel("certified accuracy", fontsize=22)
    plt.tick_params(labelsize=20)
    plt.gca().xaxis.set_major_locator(plt.MultipleLocator(xticks))
    plt.legend([method.legend for method in methods], loc='upper right', fontsize=20)
    plt.tight_layout()
    plt.savefig(outfile + ".pdf")
    plt.close()


def latex_table_certified_accuracy(outfile: str, radius_start: float, radius_stop: float, radius_step: float,
                                   methods: List[Line]):
    radii = np.arange(radius_start, radius_stop + radius_step, radius_step)
    accuracies = np.zeros((len(methods), len(radii)))
    for i, method in enumerate(methods):
        accuracies[i, :] = method.quantity.at_radii(radii)

    f = open(outfile, 'w')

    for radius in radii:
        f.write("& $r = {:.3}$".format(radius))
    f.write("\\\\\n")

    f.write("\\midrule\n")

    for i, method in enumerate(methods):
        f.write(method.legend)
        for j, radius in enumerate(radii):
            # bold the best method at each radius
            if i == accuracies[:, j].argmax():
                txt = r" & \textbf{" + "{:.2f}".format(accuracies[i, j]) + "}"
            else:
                txt = " & {:.2f}".format(accuracies[i, j])
            f.write(txt)
        f.write("\\\\\n")
    f.close()


def markdown_table_certified_accuracy(outfile: str, radius_start: float, radius_stop: float, radius_step: float,
                                      methods: List[Line]):
    radii = np.arange(radius_start, radius_stop + radius_step, radius_step)
    accuracies = np.zeros((len(methods), len(radii)))
    for i, method in enumerate(methods):
        accuracies[i, :] = method.quantity.at_radii(radii)

    f = open(outfile, 'w')
    f.write("| | ")
    for radius in radii:
        f.write("r = {:.3} |".format(radius))
    f.write("\n")

    f.write("| --- | ")
    for i in range(len(radii)):
        f.write(" --- |")
    f.write("\n")

    for i, method in enumerate(methods):
        f.write(" {} | ".format(method.legend))
        for j, radius in enumerate(radii):
            # mark the best method at each radius with an asterisk
            if i == accuracies[:, j].argmax():
                txt = "{:.2f}* |".format(accuracies[i, j])
            else:
                txt = "{:.2f} |".format(accuracies[i, j])
            f.write(txt)
        f.write("\n")
    f.close()


if __name__ == "__main__":
    latex_table_certified_accuracy(
        "analysis/latex/vary_noise_cifar10", 0.25, 1.5, 0.25, [
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"$\sigma = 0.12$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
        ])
    markdown_table_certified_accuracy(
        "analysis/markdown/vary_noise_cifar10", 0.25, 1.5, 0.25, [
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "σ = 0.12"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "σ = 0.25"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "σ = 0.50"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "σ = 1.00"),
        ])
    latex_table_certified_accuracy(
        "analysis/latex/vary_noise_imagenet", 0.5, 3.0, 0.5, [
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
        ])
    markdown_table_certified_accuracy(
        "analysis/markdown/vary_noise_imagenet", 0.5, 3.0, 0.5, [
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "σ = 0.25"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "σ = 0.50"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "σ = 1.00"),
        ])
    plot_certified_accuracy(
        "analysis/plots/vary_noise_cifar10", r"CIFAR-10, vary $\sigma$", 1.5, [
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"$\sigma = 0.12$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
        ])
    plot_certified_accuracy(
        "analysis/plots/vary_train_noise_cifar_050", r"CIFAR-10, vary train noise, $\sigma=0.5$", 1.5, [
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50"), r"train $\sigma = 0.25$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"train $\sigma = 0.50$"),
            Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50"), r"train $\sigma = 1.00$"),
        ])
    plot_certified_accuracy(
        "analysis/plots/vary_train_noise_imagenet_050", r"ImageNet, vary train noise, $\sigma=0.5$", 1.5, [
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50"), r"train $\sigma = 0.25$"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"train $\sigma = 0.50$"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50"), r"train $\sigma = 1.00$"),
        ])
    plot_certified_accuracy(
        "analysis/plots/vary_noise_imagenet", r"ImageNet, vary $\sigma$", 4, [
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"),
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"),
        ])
    plot_certified_accuracy(
        "analysis/plots/high_prob", "Approximate vs. High-Probability", 2.0, [
            Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "Approximate"),
            Line(HighProbAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50", 0.001, 0.001), "High-Prob"),
        ])
--------------------------------------------------------------------------------
/code/analyze_predict.py:
--------------------------------------------------------------------------------
import pandas as pd

if __name__ == "__main__":
    output = pd.DataFrame(columns=["n", "correct, accurate", "correct, inaccurate",
                                   "incorrect, accurate", "incorrect, inaccurate", "abstain"])

    results = []

    # predictions made with N=100000 samples serve as the reference ("gold standard")
    gold_standard = pd.read_csv("data/predict/imagenet/resnet50/noise_0.25/test/N_100000", delimiter="\t")[:450]
    for N in [100, 1000, 10000]:
        df = pd.read_csv("data/predict/imagenet/resnet50/noise_0.25/test/N_{}".format(N), delimiter="\t")[:450]
        correct = df["correct"].astype(bool)
        # "accurate" means the prediction agrees with the gold-standard prediction
        accurate = df["predict"] == gold_standard["predict"]
        abstain = df["predict"] == -1
        frac_abstain = abstain.mean()
        frac_correct_accurate = (correct & accurate & ~abstain).mean()
        frac_correct_inaccurate = (correct & ~accurate & ~abstain).mean()
        frac_incorrect_accurate = (~correct & accurate & ~abstain).mean()
        frac_incorrect_inaccurate = (~correct & ~accurate & ~abstain).mean()
        results.append((N, frac_correct_accurate, frac_correct_inaccurate,
                        frac_incorrect_accurate, frac_incorrect_inaccurate, frac_abstain))

    df = pd.DataFrame.from_records(results, "n", columns=["n", "correct, accurate", "correct, inaccurate",
                                                          "incorrect, accurate", "incorrect, inaccurate", "abstain"])
    print(df.to_latex(float_format=lambda f: "{:.2f}".format(f)))
--------------------------------------------------------------------------------
/code/architectures.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.models.resnet import resnet50
3 | import torch.backends.cudnn as cudnn
4 | from archs.cifar_resnet import resnet as resnet_cifar
5 | from datasets import get_normalize_layer
6 | from torch.nn.functional import interpolate
7 |
8 | # resnet50 - the classic ResNet-50, sized for ImageNet
9 | # cifar_resnet20 - a 20-layer residual network sized for CIFAR
10 | # cifar_resnet110 - a 110-layer residual network sized for CIFAR
11 | ARCHITECTURES = ["resnet50", "cifar_resnet20", "cifar_resnet110"]
12 |
13 | def get_architecture(arch: str, dataset: str) -> torch.nn.Module:
14 | """ Return a neural network (with random weights)
15 |
16 | :param arch: the architecture - should be in the ARCHITECTURES list above
17 | :param dataset: the dataset - should be in the datasets.DATASETS list
18 | :return: a PyTorch module
19 | """
20 | if arch == "resnet50" and dataset == "imagenet":
21 | model = torch.nn.DataParallel(resnet50(pretrained=False)).cuda()
22 | cudnn.benchmark = True
23 | elif arch == "cifar_resnet20":
24 | model = resnet_cifar(depth=20, num_classes=10).cuda()
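25 | elif arch == "cifar_resnet110":
26 | model = resnet_cifar(depth=110, num_classes=10).cuda()
27 | else:
28 | raise ValueError("unsupported architecture/dataset combination: {} / {}".format(arch, dataset))
29 | normalize_layer = get_normalize_layer(dataset)
30 | return torch.nn.Sequential(normalize_layer, model)
31 |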
--------------------------------------------------------------------------------
/code/archs/cifar_resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''
4 | This file is from: https://raw.githubusercontent.com/bearpaw/pytorch-classification/master/models/cifar/resnet.py
5 | by Wei Yang
6 | '''
7 | import torch.nn as nn
8 | import math
9 |
10 |
11 | __all__ = ['resnet']
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | "3x3 convolution with padding"
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = conv3x3(inplanes, planes, stride)
25 | self.bn1 = nn.BatchNorm2d(planes)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.downsample = downsample
30 | self.stride = stride
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | out += residual
46 | out = self.relu(out)
47 |
48 | return out
49 |
50 |
51 | class Bottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, downsample=None):
55 | super(Bottleneck, self).__init__()
56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
57 | self.bn1 = nn.BatchNorm2d(planes)
58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
59 | padding=1, bias=False)
60 | self.bn2 = nn.BatchNorm2d(planes)
61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
62 | self.bn3 = nn.BatchNorm2d(planes * 4)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.downsample = downsample
65 | self.stride = stride
66 |
67 | def forward(self, x):
68 | residual = x
69 |
70 | out = self.conv1(x)
71 | out = self.bn1(out)
72 | out = self.relu(out)
73 |
74 | out = self.conv2(out)
75 | out = self.bn2(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv3(out)
79 | out = self.bn3(out)
80 |
81 | if self.downsample is not None:
82 | residual = self.downsample(x)
83 |
84 | out += residual
85 | out = self.relu(out)
86 |
87 | return out
88 |
89 |
90 | class ResNet(nn.Module):
91 |
92 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
93 | super(ResNet, self).__init__()
94 | # Model type specifies number of layers for CIFAR-10 model
95 | if block_name.lower() == 'basicblock':
96 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
97 | n = (depth - 2) // 6
98 | block = BasicBlock
99 | elif block_name.lower() == 'bottleneck':
100 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
101 | n = (depth - 2) // 9
102 | block = Bottleneck
103 | else:
104 | raise ValueError('block_name should be BasicBlock or Bottleneck')
105 |
106 |
107 | self.inplanes = 16
108 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
109 | bias=False)
110 | self.bn1 = nn.BatchNorm2d(16)
111 | self.relu = nn.ReLU(inplace=True)
112 | self.layer1 = self._make_layer(block, 16, n)
113 | self.layer2 = self._make_layer(block, 32, n, stride=2)
114 | self.layer3 = self._make_layer(block, 64, n, stride=2)
115 | self.avgpool = nn.AvgPool2d(8)
116 | self.fc = nn.Linear(64 * block.expansion, num_classes)
117 |
118 | for m in self.modules():
119 | if isinstance(m, nn.Conv2d):
120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
121 | m.weight.data.normal_(0, math.sqrt(2. / n))
122 | elif isinstance(m, nn.BatchNorm2d):
123 | m.weight.data.fill_(1)
124 | m.bias.data.zero_()
125 |
126 | def _make_layer(self, block, planes, blocks, stride=1):
127 | downsample = None
128 | if stride != 1 or self.inplanes != planes * block.expansion:
129 | downsample = nn.Sequential(
130 | nn.Conv2d(self.inplanes, planes * block.expansion,
131 | kernel_size=1, stride=stride, bias=False),
132 | nn.BatchNorm2d(planes * block.expansion),
133 | )
134 |
135 | layers = []
136 | layers.append(block(self.inplanes, planes, stride, downsample))
137 | self.inplanes = planes * block.expansion
138 | for i in range(1, blocks):
139 | layers.append(block(self.inplanes, planes))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | x = self.conv1(x)
145 | x = self.bn1(x)
146 | x = self.relu(x) # 32x32
147 |
148 | x = self.layer1(x) # 32x32
149 | x = self.layer2(x) # 16x16
150 | x = self.layer3(x) # 8x8
151 |
152 | x = self.avgpool(x)
153 | x = x.view(x.size(0), -1)
154 | x = self.fc(x)
155 |
156 | return x
157 |
158 |
159 | def resnet(**kwargs):
160 | """
161 | Constructs a ResNet model.
162 | """
163 | return ResNet(**kwargs)
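The depth assertions in `ResNet.__init__` come from simple layer counting. Each of the three stages holds n blocks: a BasicBlock has two convolutions and a Bottleneck has three. The "+2" accounts for the stem convolution `conv1` and the final fully connected layer `fc`. Because layer2 and layer3 each halve the resolution, a 32x32 input reaches 8x8 at the end, which is why `avgpool` uses kernel size 8. The sketch below only illustrates that arithmetic; the helper names are hypothetical and not part of the file above.

```python
from typing import List


def blocks_per_stage(depth: int, block_name: str = "BasicBlock") -> int:
    """Number of residual blocks in each of the three stages (layer1..layer3)."""
    name = block_name.lower()
    if name == "basicblock":
        convs_per_block = 2  # 3x3 conv, 3x3 conv
    elif name == "bottleneck":
        convs_per_block = 3  # 1x1, 3x3, 1x1 convs
    else:
        raise ValueError("block_name should be BasicBlock or Bottleneck")
    layers_per_n = 3 * convs_per_block  # three stages of n blocks each
    if (depth - 2) % layers_per_n != 0:
        raise ValueError("depth should be {}n+2".format(layers_per_n))
    return (depth - 2) // layers_per_n


def stage_resolutions(input_size: int = 32) -> List[int]:
    """Feature-map side length after layer1 (stride 1), layer2 and layer3 (stride 2 each)."""
    return [input_size, input_size // 2, input_size // 4]
```

For `cifar_resnet110`, which uses the default BasicBlock, this gives 18 blocks per stage.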
--------------------------------------------------------------------------------
/code/certify.py:
--------------------------------------------------------------------------------
1 | # evaluate a smoothed classifier on a dataset
2 | import argparse
3 | import os
4 | import setGPU
5 | from datasets import get_dataset, DATASETS, get_num_classes
6 | from core import Smooth
7 | from time import time
8 | import torch
9 | import datetime
10 | from architectures import get_architecture
11 |
12 | parser = argparse.ArgumentParser(description='Certify many examples')
13 | parser.add_argument("dataset", choices=DATASETS, help="which dataset")
14 | parser.add_argument("base_classifier", type=str, help="path to saved pytorch model of base classifier")
15 | parser.add_argument("sigma", type=float, help="noise hyperparameter")
16 | parser.add_argument("outfile", type=str, help="output file")
17 | parser.add_argument("--batch", type=int, default=1000, help="batch size")
18 | parser.add_argument("--skip", type=int, default=1, help="certify only every skip-th example")
19 | parser.add_argument("--max", type=int, default=-1, help="stop when the example index reaches this value")
20 | parser.add_argument("--split", choices=["train", "test"], default="test", help="train or test set")
21 | parser.add_argument("--N0", type=int, default=100, help="number of samples to use for selecting the top class")
22 | parser.add_argument("--N", type=int, default=100000, help="number of samples to use")
23 | parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")
24 | args = parser.parse_args()
25 |
26 | if __name__ == "__main__":
27 | # load the base classifier
28 | checkpoint = torch.load(args.base_classifier)
29 | base_classifier = get_architecture(checkpoint["arch"], args.dataset)
30 | base_classifier.load_state_dict(checkpoint['state_dict'])
31 |
32 | # create the smoothed classifier g
33 | smoothed_classifier = Smooth(base_classifier, get_num_classes(args.dataset), args.sigma)
34 |
35 | # prepare output file
36 | f = open(args.outfile, 'w')
37 | print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=f, flush=True)
38 |
39 | # iterate through the dataset
40 | dataset = get_dataset(args.dataset, args.split)
41 | for i in range(len(dataset)):
42 |
43 | # only certify every args.skip-th example, and stop when the example index reaches args.max
44 | if i % args.skip != 0:
45 | continue
46 | if i == args.max:
47 | break
48 |
49 | (x, label) = dataset[i]
50 |
51 | before_time = time()
52 | # certify the prediction of g around x
53 | x = x.cuda()
54 | prediction, radius = smoothed_classifier.certify(x, args.N0, args.N, args.alpha, args.batch)
55 | after_time = time()
56 | correct = int(prediction == label)
57 |
58 | time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
59 | print("{}\t{}\t{}\t{:.3}\t{}\t{}".format(
60 | i, label, prediction, radius, correct, time_elapsed), file=f, flush=True)
61 |
62 | f.close()
63 |
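Each row that certify.py writes records whether the smoothed prediction was correct and the certified L2 radius (abstentions are logged as incorrect with radius 0). A common summary is the certified accuracy at radius r: the fraction of examples that are both correct and certified to a radius of at least r. The simplified sketch below assumes the tab-separated format printed above. It is not the repository's own analysis code, and the function name is hypothetical.

```python
import csv
import io
from typing import Iterable, List


def certified_accuracy(tsv_text: str, radii: Iterable[float]) -> List[float]:
    """For each radius r, the fraction of logged examples that are correct with radius >= r."""
    rows = list(csv.DictReader(io.StringIO(tsv_text), delimiter="\t"))
    total = len(rows)
    result = []
    for r in radii:
        # abstentions have correct == 0, so they never count toward certified accuracy
        hits = sum(1 for row in rows if row["correct"] == "1" and float(row["radius"]) >= r)
        result.append(hits / total)
    return result
```

To use it on a real log, read the file contents (for example with `open(path).read()`) and pass a grid of radii. The `time` column is ignored.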
--------------------------------------------------------------------------------
/code/core.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scipy.stats import norm, binom_test
3 | import numpy as np
4 | from math import ceil
5 | from statsmodels.stats.proportion import proportion_confint
6 |
7 |
8 | class Smooth(object):
9 | """A smoothed classifier g """
10 |
11 | # to abstain, Smooth returns this int
12 | ABSTAIN = -1
13 |
14 | def __init__(self, base_classifier: torch.nn.Module, num_classes: int, sigma: float):
15 | """
16 | :param base_classifier: maps from [batch x channel x height x width] to [batch x num_classes]
17 | :param num_classes: the number of classes (the length of the base classifier's output)
18 | :param sigma: the noise level hyperparameter
19 | """
20 | self.base_classifier = base_classifier
21 | self.num_classes = num_classes
22 | self.sigma = sigma
23 |
24 | def certify(self, x: torch.tensor, n0: int, n: int, alpha: float, batch_size: int) -> (int, float):
25 | """ Monte Carlo algorithm for certifying that g's prediction around x is constant within some L2 radius.
26 | With probability at least 1 - alpha, the class returned by this method will equal g(x), and g's prediction will
27 | be robust within an L2 ball of radius R around x.
28 |
29 | :param x: the input [channel x height x width]
30 | :param n0: the number of Monte Carlo samples to use for selection
31 | :param n: the number of Monte Carlo samples to use for estimation
32 | :param alpha: the failure probability
33 | :param batch_size: batch size to use when evaluating the base classifier
34 | :return: (predicted class, certified radius)
35 | in the case of abstention, the class will be ABSTAIN and the radius 0.
36 | """
37 | self.base_classifier.eval()
38 | # draw samples of f(x + epsilon)
39 | counts_selection = self._sample_noise(x, n0, batch_size)
40 | # use these samples to take a guess at the top class
41 | cAHat = counts_selection.argmax().item()
42 | # draw more samples of f(x + epsilon)
43 | counts_estimation = self._sample_noise(x, n, batch_size)
44 | # use these samples to estimate a lower bound on pA
45 | nA = counts_estimation[cAHat].item()
46 | pABar = self._lower_confidence_bound(nA, n, alpha)
47 | if pABar < 0.5:
48 | return Smooth.ABSTAIN, 0.0
49 | else:
50 | radius = self.sigma * norm.ppf(pABar)
51 | return cAHat, radius
52 |
53 | def predict(self, x: torch.tensor, n: int, alpha: float, batch_size: int) -> int:
54 | """ Monte Carlo algorithm for evaluating the prediction of g at x. With probability at least 1 - alpha, the
55 | class returned by this method will equal g(x).
56 |
57 | This function uses the hypothesis test described in https://arxiv.org/abs/1610.03944
58 | for identifying the top category of a multinomial distribution.
59 |
60 | :param x: the input [channel x height x width]
61 | :param n: the number of Monte Carlo samples to use
62 | :param alpha: the failure probability
63 | :param batch_size: batch size to use when evaluating the base classifier
64 | :return: the predicted class, or ABSTAIN
65 | """
66 | self.base_classifier.eval()
67 | counts = self._sample_noise(x, n, batch_size)
68 | top2 = counts.argsort()[::-1][:2]
69 | count1 = counts[top2[0]]
70 | count2 = counts[top2[1]]
71 | if binom_test(count1, count1 + count2, p=0.5) > alpha:
72 | return Smooth.ABSTAIN
73 | else:
74 | return top2[0]
75 |
76 | def _sample_noise(self, x: torch.tensor, num: int, batch_size) -> np.ndarray:
77 | """ Sample the base classifier's prediction under noisy corruptions of the input x.
78 |
79 | :param x: the input [channel x width x height]
80 | :param num: number of samples to collect
81 | :param batch_size: batch size to use when evaluating the base classifier
82 | :return: an ndarray[int] of length num_classes containing the per-class counts
83 | """
84 | with torch.no_grad():
85 | counts = np.zeros(self.num_classes, dtype=int)
86 | for _ in range(ceil(num / batch_size)):
87 | this_batch_size = min(batch_size, num)
88 | num -= this_batch_size
89 |
90 | batch = x.repeat((this_batch_size, 1, 1, 1))
91 | noise = torch.randn_like(batch, device='cuda') * self.sigma
92 | predictions = self.base_classifier(batch + noise).argmax(1)
93 | counts += self._count_arr(predictions.cpu().numpy(), self.num_classes)
94 | return counts
95 |
96 | def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray:
97 | counts = np.zeros(length, dtype=int)
98 | for idx in arr:
99 | counts[idx] += 1
100 | return counts
101 |
102 | def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float:
103 | """ Returns a (1 - alpha) lower confidence bound on a Bernoulli proportion.
104 |
105 | This function uses the Clopper-Pearson method.
106 |
107 | :param NA: the number of "successes"
108 | :param N: the number of total draws
109 | :param alpha: the failure probability; the bound holds with confidence 1 - alpha
110 | :return: a lower bound on the binomial proportion which holds with probability at least (1 - alpha) over the samples
111 | """
112 | return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0]  # lower end of a two-sided (1 - 2*alpha) interval = one-sided (1 - alpha) lower bound
113 |
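`Smooth.certify` abstains unless the lower confidence bound pABar is at least 1/2. Otherwise it returns the radius sigma * Phi^{-1}(pABar), where Phi^{-1} is the standard normal quantile. The sketch below reproduces that arithmetic with the standard library only. It computes the one-sided Clopper-Pearson bound by bisection on the binomial upper tail, instead of calling statsmodels' `proportion_confint`, so treat it as an illustration rather than a drop-in replacement. All helper names are hypothetical. Because the bisection keeps its lower end below the exact bound, the result errs on the conservative side. Summing the tail term by term costs O(n) per step: fine for small n, but slow for the default N = 100,000.

```python
from math import exp, lgamma, log
from statistics import NormalDist
from typing import Optional


def binom_upper_tail(k: int, n: int, p: float) -> float:
    """P[X >= k] for X ~ Binomial(n, p) with 0 < p < 1, summed term by term in log space."""
    log_p, log_q = log(p), log(1.0 - p)
    log_n_fact = lgamma(n + 1)
    return sum(
        exp(log_n_fact - lgamma(i + 1) - lgamma(n - i + 1) + i * log_p + (n - i) * log_q)
        for i in range(k, n + 1)
    )


def clopper_pearson_lower(n_a: int, n: int, alpha: float, iters: int = 200) -> float:
    """One-sided (1 - alpha) Clopper-Pearson lower bound on a binomial proportion.

    The exact bound is the p at which P[X >= n_a] = alpha; the upper tail grows with p,
    so bisection applies. The returned `lo` never exceeds the exact bound.
    """
    if n_a == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if mid in (lo, hi):  # the interval can no longer be split in floating point
            break
        if binom_upper_tail(n_a, n, mid) < alpha:
            lo = mid  # observing >= n_a successes is too unlikely at mid: bound is higher
        else:
            hi = mid
    return lo


def certified_radius(n_a: int, n: int, alpha: float, sigma: float) -> Optional[float]:
    """Mirror of the last step of Smooth.certify; None stands for ABSTAIN."""
    p_a_bar = clopper_pearson_lower(n_a, n, alpha)
    if p_a_bar < 0.5:
        return None
    return sigma * NormalDist().inv_cdf(p_a_bar)
```

When every sample votes for the top class (n_a = n), the bound has the closed form alpha ** (1 / n). This makes it a convenient sanity check, and it shows that the largest certifiable radius is capped by n and alpha.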
--------------------------------------------------------------------------------
/code/datasets.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms, datasets
2 | from typing import *
3 | import torch
4 | import os
5 | from torch.utils.data import Dataset
6 |
7 | # set this environment variable to the location of your imagenet directory if you want to read ImageNet data.
8 | # make sure your val directory is preprocessed to look like the train directory, e.g. by running this script
9 | # https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
10 | IMAGENET_LOC_ENV = "IMAGENET_DIR"
11 |
12 | # list of all datasets
13 | DATASETS = ["imagenet", "cifar10"]
14 |
15 |
16 | def get_dataset(dataset: str, split: str) -> Dataset:
17 | """Return the dataset as a PyTorch Dataset object"""
18 | if dataset == "imagenet":
19 | return _imagenet(split)
20 | elif dataset == "cifar10":
21 | return _cifar10(split)
22 |
23 |
24 | def get_num_classes(dataset: str):
25 | """Return the number of classes in the dataset. """
26 | if dataset == "imagenet":
27 | return 1000
28 | elif dataset == "cifar10":
29 | return 10
30 |
31 |
32 | def get_normalize_layer(dataset: str) -> torch.nn.Module:
33 | """Return the dataset's normalization layer"""
34 | if dataset == "imagenet":
35 | return NormalizeLayer(_IMAGENET_MEAN, _IMAGENET_STDDEV)
36 | elif dataset == "cifar10":
37 | return NormalizeLayer(_CIFAR10_MEAN, _CIFAR10_STDDEV)
38 |
39 |
40 | _IMAGENET_MEAN = [0.485, 0.456, 0.406]
41 | _IMAGENET_STDDEV = [0.229, 0.224, 0.225]
42 |
43 | _CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
44 | _CIFAR10_STDDEV = [0.2023, 0.1994, 0.2010]
45 |
46 |
47 | def _cifar10(split: str) -> Dataset:
48 | if split == "train":
49 | return datasets.CIFAR10("./dataset_cache", train=True, download=True, transform=transforms.Compose([
50 | transforms.RandomCrop(32, padding=4),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor()
53 | ]))
54 | elif split == "test":
55 | return datasets.CIFAR10("./dataset_cache", train=False, download=True, transform=transforms.ToTensor())
56 |
57 |
58 | def _imagenet(split: str) -> Dataset:
59 | if IMAGENET_LOC_ENV not in os.environ:
60 | raise RuntimeError("environment variable for ImageNet directory not set")
61 |
62 | dir = os.environ[IMAGENET_LOC_ENV]
63 | if split == "train":
64 | subdir = os.path.join(dir, "train")
65 | transform = transforms.Compose([
66 | transforms.RandomResizedCrop(224),
67 | transforms.RandomHorizontalFlip(),
68 | transforms.ToTensor()
69 | ])
70 | elif split == "test":
71 | subdir = os.path.join(dir, "val")
72 | transform = transforms.Compose([
73 | transforms.Resize(256),
74 | transforms.CenterCrop(224),
75 | transforms.ToTensor()
76 | ])
77 | return datasets.ImageFolder(subdir, transform)
78 |
79 |
80 | class NormalizeLayer(torch.nn.Module):
81 | """Standardize the channels of a batch of images by subtracting the dataset mean
82 | and dividing by the dataset standard deviation.
83 |
84 | In order to certify radii in original coordinates rather than standardized coordinates, we
85 | add the Gaussian noise _before_ standardizing, which is why we have standardization be the first
86 | layer of the classifier rather than as a part of preprocessing as is typical.
87 | """
88 |
89 | def __init__(self, means: List[float], sds: List[float]):
90 | """
91 | :param means: the channel means
92 | :param sds: the channel standard deviations
93 | """
94 | super(NormalizeLayer, self).__init__()
95 | self.means = torch.tensor(means).cuda()
96 | self.sds = torch.tensor(sds).cuda()
97 |
98 | def forward(self, input: torch.tensor):
99 | (batch_size, num_channels, height, width) = input.shape
100 | means = self.means.repeat((batch_size, height, width, 1)).permute(0, 3, 1, 2)
101 | sds = self.sds.repeat((batch_size, height, width, 1)).permute(0, 3, 1, 2)
102 | return (input - means) / sds
103 |
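The `NormalizeLayer` docstring explains that noise is added in the original [0, 1] pixel coordinates, with standardization as the first layer of the network. The reason is that standardization rescales each channel by 1/sd, so one pixel-space perturbation has a different L2 norm in standardized coordinates depending on which channel it falls in. The pure-Python sketch below uses nested lists in place of tensors. `normalize` and `l2` are hypothetical helpers; only the CIFAR-10 constants come from the file above.

```python
from math import sqrt
from typing import List

Image = List[List[List[float]]]  # [channel][height][width], pixel values in [0, 1]


def normalize(img: Image, means: List[float], sds: List[float]) -> Image:
    """Per-channel standardization: (pixel - mean) / sd, as in NormalizeLayer.forward."""
    return [[[(v - m) / s for v in row] for row in chan]
            for chan, m, s in zip(img, means, sds)]


def l2(a: Image, b: Image) -> float:
    """Euclidean distance between two images of the same shape."""
    return sqrt(sum((x - y) ** 2
                    for ca, cb in zip(a, b)
                    for ra, rb in zip(ca, cb)
                    for x, y in zip(ra, rb)))


CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STDDEV = [0.2023, 0.1994, 0.2010]

# A 1x1 "image" with three channels, perturbed by 0.1 in the third channel only.
x = [[[0.5]], [[0.5]], [[0.5]]]
x_perturbed = [[[0.5]], [[0.5]], [[0.6]]]
pixel_dist = l2(x, x_perturbed)  # about 0.1
standardized_dist = l2(normalize(x, CIFAR10_MEAN, CIFAR10_STDDEV),
                       normalize(x_perturbed, CIFAR10_MEAN, CIFAR10_STDDEV))  # about 0.1 / 0.2010
```

A radius certified after standardization would therefore translate back to pixel space with a channel-dependent factor. Adding the noise before the normalization layer avoids that conversion entirely.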
--------------------------------------------------------------------------------
/code/predict.py:
--------------------------------------------------------------------------------
1 | """ This script loads a base classifier and then runs PREDICT on many examples from a dataset.
2 | """
3 | import argparse
4 | import setGPU
5 | from datasets import get_dataset, DATASETS, get_num_classes
6 | from core import Smooth
7 | from time import time
8 | import torch
9 | from architectures import get_architecture
10 | import datetime
11 |
12 | parser = argparse.ArgumentParser(description='Predict on many examples')
13 | parser.add_argument("dataset", choices=DATASETS, help="which dataset")
14 | parser.add_argument("base_classifier", type=str, help="path to saved pytorch model of base classifier")
15 | parser.add_argument("sigma", type=float, help="noise hyperparameter")
16 | parser.add_argument("outfile", type=str, help="output file")
17 | parser.add_argument("--batch", type=int, default=1000, help="batch size")
18 | parser.add_argument("--skip", type=int, default=1, help="predict on only every skip-th example")
19 | parser.add_argument("--max", type=int, default=-1, help="stop when the example index reaches this value")
20 | parser.add_argument("--split", choices=["train", "test"], default="test", help="train or test set")
21 | parser.add_argument("--N", type=int, default=100000, help="number of samples to use")
22 | parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")
23 | args = parser.parse_args()
24 |
25 | if __name__ == "__main__":
26 | # load the base classifier
27 | checkpoint = torch.load(args.base_classifier)
28 | base_classifier = get_architecture(checkpoint["arch"], args.dataset)
29 | base_classifier.load_state_dict(checkpoint['state_dict'])
30 |
31 | # create the smoothed classifier g
32 | smoothed_classifier = Smooth(base_classifier, get_num_classes(args.dataset), args.sigma)
33 |
34 | # prepare output file
35 | f = open(args.outfile, 'w')
36 | print("idx\tlabel\tpredict\tcorrect\ttime", file=f, flush=True)
37 |
38 | # iterate through the dataset
39 | dataset = get_dataset(args.dataset, args.split)
40 | for i in range(len(dataset)):
41 |
42 | # only predict on every args.skip-th example, and stop when the example index reaches args.max
43 | if i % args.skip != 0:
44 | continue
45 | if i == args.max:
46 | break
47 |
48 | (x, label) = dataset[i]
49 | x = x.cuda()
50 | before_time = time()
51 |
52 | # make the prediction
53 | prediction = smoothed_classifier.predict(x, args.N, args.alpha, args.batch)
54 |
55 | after_time = time()
56 | correct = int(prediction == label)
57 |
58 | time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
59 |
60 | # log the prediction and whether it was correct
61 | print("{}\t{}\t{}\t{}\t{}".format(i, label, prediction, correct, time_elapsed), file=f, flush=True)
62 |
63 | f.close()
64 |
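`Smooth.predict` abstains unless a two-sided binomial test rejects the hypothesis "the top two classes are equally likely" at level alpha. Under that null hypothesis p = 0.5, and the binomial distribution is symmetric, so the exact two-sided p-value equals twice the upper tail at max(k, n - k), capped at 1. The stdlib sketch below implements that test. It assumes `counts` is a list of per-class counts. The function names are hypothetical, and ties between classes are broken toward the lower class index, which may differ from the `argsort` order used in core.py.

```python
from math import comb
from typing import List

ABSTAIN = -1


def two_sided_binom_pvalue(k: int, n: int) -> float:
    """Exact two-sided p-value of H0: p = 0.5, given k successes in n trials."""
    m = max(k, n - k)
    upper_tail = sum(comb(n, i) for i in range(m, n + 1))  # exact integer arithmetic
    return min(1.0, 2 * upper_tail / 2 ** n)


def predict_from_counts(counts: List[int], alpha: float) -> int:
    """Return the top class, or ABSTAIN if it is not significantly ahead of the runner-up."""
    ranked = sorted(range(len(counts)), key=lambda c: counts[c], reverse=True)
    top, runner_up = ranked[0], ranked[1]
    n1, n2 = counts[top], counts[runner_up]
    if two_sided_binom_pvalue(n1, n1 + n2) > alpha:
        return ABSTAIN
    return top
```

Only the top two counts enter the test, so a larger N makes abstention less likely whenever the base classifier's votes are clearly split between two classes.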
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | # this file is based on code publicly available at
2 | # https://github.com/bearpaw/pytorch-classification
3 | # written by Wei Yang.
4 |
5 | import argparse
6 | import os
7 | import torch
8 | from torch.nn import CrossEntropyLoss
9 | from torch.utils.data import DataLoader
10 | from datasets import get_dataset, DATASETS
11 | from architectures import ARCHITECTURES, get_architecture
12 | from torch.optim import SGD, Optimizer
13 | from torch.optim.lr_scheduler import StepLR
14 | import time
15 | import datetime
16 | from train_utils import AverageMeter, accuracy, init_logfile, log
17 |
18 | parser = argparse.ArgumentParser(description='Train a base classifier on ImageNet or CIFAR-10')
19 | parser.add_argument('dataset', type=str, choices=DATASETS)
20 | parser.add_argument('arch', type=str, choices=ARCHITECTURES)
21 | parser.add_argument('outdir', type=str, help='folder to save model and training log')
22 | parser.add_argument('--workers', default=4, type=int, metavar='N',
23 | help='number of data loading workers (default: 4)')
24 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
25 | help='number of total epochs to run')
26 | parser.add_argument('--batch', default=256, type=int, metavar='N',
27 | help='batchsize (default: 256)')
28 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
29 | help='initial learning rate', dest='lr')
30 | parser.add_argument('--lr_step_size', type=int, default=30,
31 | help='How often (in epochs) to decrease the learning rate by gamma.')
32 | parser.add_argument('--gamma', type=float, default=0.1,
33 | help='LR is multiplied by gamma on schedule.')
34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
35 | help='momentum')
36 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
37 | metavar='W', help='weight decay (default: 1e-4)')
38 | parser.add_argument('--noise_sd', default=0.0, type=float,
39 | help="standard deviation of Gaussian noise for data augmentation")
40 | parser.add_argument('--gpu', default=None, type=str,
41 | help='id(s) for CUDA_VISIBLE_DEVICES')
42 | parser.add_argument('--print-freq', default=10, type=int,
43 | metavar='N', help='print frequency (default: 10)')
44 | args = parser.parse_args()
45 |
46 |
47 | def main():
48 | if args.gpu:
49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
50 |
51 | if not os.path.exists(args.outdir):
52 | os.makedirs(args.outdir)
53 |
54 | train_dataset = get_dataset(args.dataset, 'train')
55 | test_dataset = get_dataset(args.dataset, 'test')
56 | pin_memory = (args.dataset == "imagenet")
57 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
58 | num_workers=args.workers, pin_memory=pin_memory)
59 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
60 | num_workers=args.workers, pin_memory=pin_memory)
61 |
62 | model = get_architecture(args.arch, args.dataset)
63 |
64 | logfilename = os.path.join(args.outdir, 'log.txt')
65 | init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\ttest loss\ttest acc")
66 |
67 | criterion = CrossEntropyLoss().cuda()
68 | optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
69 | scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)
70 |
71 | for epoch in range(args.epochs):
72 | scheduler.step(epoch)
73 | before = time.time()
74 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, args.noise_sd)
75 | test_loss, test_acc = test(test_loader, model, criterion, args.noise_sd)
76 | after = time.time()
77 |
78 | log(logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
79 | epoch, str(datetime.timedelta(seconds=(after - before))),
80 | optimizer.param_groups[0]['lr'], train_loss, train_acc, test_loss, test_acc))
81 |
82 | torch.save({
83 | 'epoch': epoch + 1,
84 | 'arch': args.arch,
85 | 'state_dict': model.state_dict(),
86 | 'optimizer': optimizer.state_dict(),
87 | }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
88 |
89 |
90 | def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer, epoch: int, noise_sd: float):
91 | batch_time = AverageMeter()
92 | data_time = AverageMeter()
93 | losses = AverageMeter()
94 | top1 = AverageMeter()
95 | top5 = AverageMeter()
96 | end = time.time()
97 |
98 | # switch to train mode
99 | model.train()
100 |
101 | for i, (inputs, targets) in enumerate(loader):
102 | # measure data loading time
103 | data_time.update(time.time() - end)
104 |
105 | inputs = inputs.cuda()
106 | targets = targets.cuda()
107 |
108 | # augment inputs with noise
109 | inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd
110 |
111 | # compute output
112 | outputs = model(inputs)
113 | loss = criterion(outputs, targets)
114 |
115 | # measure accuracy and record loss
116 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
117 | losses.update(loss.item(), inputs.size(0))
118 | top1.update(acc1.item(), inputs.size(0))
119 | top5.update(acc5.item(), inputs.size(0))
120 |
121 | # compute gradient and do SGD step
122 | optimizer.zero_grad()
123 | loss.backward()
124 | optimizer.step()
125 |
126 | # measure elapsed time
127 | batch_time.update(time.time() - end)
128 | end = time.time()
129 |
130 | if i % args.print_freq == 0:
131 | print('Epoch: [{0}][{1}/{2}]\t'
132 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
133 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
134 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
135 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
136 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
137 | epoch, i, len(loader), batch_time=batch_time,
138 | data_time=data_time, loss=losses, top1=top1, top5=top5))
139 |
140 | return (losses.avg, top1.avg)
141 |
142 |
143 | def test(loader: DataLoader, model: torch.nn.Module, criterion, noise_sd: float):
144 | batch_time = AverageMeter()
145 | data_time = AverageMeter()
146 | losses = AverageMeter()
147 | top1 = AverageMeter()
148 | top5 = AverageMeter()
149 | end = time.time()
150 |
151 | # switch to eval mode
152 | model.eval()
153 |
154 | with torch.no_grad():
155 | for i, (inputs, targets) in enumerate(loader):
156 | # measure data loading time
157 | data_time.update(time.time() - end)
158 |
159 | inputs = inputs.cuda()
160 | targets = targets.cuda()
161 |
162 | # augment inputs with noise
163 | inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd
164 |
165 | # compute output
166 | outputs = model(inputs)
167 | loss = criterion(outputs, targets)
168 |
169 | # measure accuracy and record loss
170 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
171 | losses.update(loss.item(), inputs.size(0))
172 | top1.update(acc1.item(), inputs.size(0))
173 | top5.update(acc5.item(), inputs.size(0))
174 |
175 | # measure elapsed time
176 | batch_time.update(time.time() - end)
177 | end = time.time()
178 |
179 | if i % args.print_freq == 0:
180 | print('Test: [{0}/{1}]\t'
181 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
182 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
183 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
184 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
185 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
186 | i, len(loader), batch_time=batch_time,
187 | data_time=data_time, loss=losses, top1=top1, top5=top5))
188 |
189 | return (losses.avg, top1.avg)
190 |
191 |
192 | if __name__ == "__main__":
193 | main()
194 |
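Two pieces of train.py are easy to check in isolation. The first is the step learning-rate schedule, assuming it follows the closed form lr = lr0 * gamma ** (epoch // step_size), as `StepLR` does. The second is the Gaussian data augmentation, which adds N(0, noise_sd^2) noise to every input value without clamping. The stdlib sketch below uses flat lists instead of tensors, and its function names are hypothetical.

```python
import random
from typing import List


def step_lr(base_lr: float, epoch: int, step_size: int = 30, gamma: float = 0.1) -> float:
    """Closed-form step schedule: multiply the rate by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def add_gaussian_noise(pixels: List[float], noise_sd: float, rng: random.Random) -> List[float]:
    """Isotropic Gaussian augmentation; like train(), this does not clamp to [0, 1]."""
    return [p + rng.gauss(0.0, 1.0) * noise_sd for p in pixels]
```

With the defaults (`--lr 0.1`, `--lr_step_size 30`, `--gamma 0.1`, 90 epochs), the rate is 0.1 for epochs 0 to 29, 0.01 for 30 to 59, and 0.001 for 60 to 89. Training at a given `--noise_sd` and certifying at the same sigma matches the conditions the base classifier saw during training.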
--------------------------------------------------------------------------------
/code/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class AverageMeter(object):
4 | """Computes and stores the average and current value"""
5 | def __init__(self):
6 | self.reset()
7 |
8 | def reset(self):
9 | self.val = 0
10 | self.avg = 0
11 | self.sum = 0
12 | self.count = 0
13 |
14 | def update(self, val, n=1):
15 | self.val = val
16 | self.sum += val * n
17 | self.count += n
18 | self.avg = self.sum / self.count
19 |
20 |
21 | def accuracy(output, target, topk=(1,)):
22 | """Computes the accuracy over the k top predictions for the specified values of k"""
23 | with torch.no_grad():
24 | maxk = max(topk)
25 | batch_size = target.size(0)
26 |
27 | _, pred = output.topk(maxk, 1, True, True)
28 | pred = pred.t()
29 | correct = pred.eq(target.view(1, -1).expand_as(pred))
30 |
31 | res = []
32 | for k in topk:
33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34 | res.append(correct_k.mul_(100.0 / batch_size))
35 | return res
36 |
37 | def init_logfile(filename: str, text: str):
38 | f = open(filename, 'w')
39 | f.write(text+"\n")
40 | f.close()
41 |
42 | def log(filename: str, text: str):
43 | f = open(filename, 'a')
44 | f.write(text+"\n")
45 | f.close()
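`AverageMeter.update` weights each value by the batch size `n`, so a smaller final batch does not skew the epoch average. `accuracy` reports the percentage of rows whose target is among the k highest-scoring classes. The pure-Python sketch below shows both computations on plain lists. `RunningAverage` and `topk_accuracy` are hypothetical names, not the tensor-based functions above.

```python
from typing import List, Sequence, Tuple


class RunningAverage:
    """Sample-weighted running mean, the same bookkeeping as AverageMeter."""

    def __init__(self) -> None:
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1) -> None:
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int],
                  ks: Tuple[int, ...] = (1,)) -> List[float]:
    """Percentage of rows whose target is among the k highest-scoring classes, for each k."""
    results = []
    for k in ks:
        hits = 0
        for row, t in zip(scores, targets):
            top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
            hits += t in top
        results.append(100.0 * hits / len(targets))
    return results
```

For example, a 256-example batch at 50% accuracy followed by a 4-example batch at 100% averages to about 50.8%, not 75%.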
--------------------------------------------------------------------------------
/code/visualize.py:
--------------------------------------------------------------------------------
1 | # visualize noisy images
2 | import argparse
3 | from datasets import get_dataset, DATASETS
4 | import torch
5 | from torchvision.transforms import ToPILImage
6 |
7 | parser = argparse.ArgumentParser(description='visualize noisy images')
8 | parser.add_argument("dataset", type=str, choices=DATASETS)
9 | parser.add_argument("outdir", type=str, help="output directory")
10 | parser.add_argument("idx", type=int)
11 | parser.add_argument("noise_sds", nargs='+', type=float)
12 | parser.add_argument("--split", choices=["train", "test"], default="test")
13 | args = parser.parse_args()
14 |
15 | toPilImage = ToPILImage()
16 | dataset = get_dataset(args.dataset, args.split)
17 | image, _ = dataset[args.idx]
18 | noise = torch.randn_like(image)
19 | for noise_sd in args.noise_sds:
20 | noisy_image = torch.clamp(image + noise * noise_sd, min=0, max=1)
21 | pil = toPilImage(noisy_image)
22 | pil.save("{}/{}_{}.png".format(args.outdir, args.idx, round(noise_sd * 100)))
23 |
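visualize.py draws a single noise sample and rescales it for every requested standard deviation, so the images differ only in noise magnitude. Clamping to [0, 1] is applied only for display. The stdlib sketch below mirrors that logic on a flat list of pixels. Its helper names are hypothetical. For the file-name suffix it uses `round()` rather than `int()`, because in binary floating point `0.29 * 100` evaluates to 28.999999999999996, which `int()` truncates to 28.

```python
import random
from typing import Dict, List


def noisy_versions(pixels: List[float], noise_sds: List[float], seed: int = 0) -> Dict[float, List[float]]:
    """One shared noise draw, scaled by each sd and clamped to [0, 1] for display."""
    rng = random.Random(seed)
    noise = [rng.gauss(0.0, 1.0) for _ in pixels]
    return {sd: [min(1.0, max(0.0, p + z * sd)) for p, z in zip(pixels, noise)]
            for sd in noise_sds}


def output_name(idx: int, noise_sd: float) -> str:
    """File name of the form <idx>_<noise_sd in hundredths>.png."""
    return "{}_{}.png".format(idx, round(noise_sd * 100))
```

With sds of 0, 0.25, 0.5 and 1.0, this yields names such as `10_0.png`, `10_25.png`, `10_50.png` and `10_100.png`.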
--------------------------------------------------------------------------------
/data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25:
--------------------------------------------------------------------------------
1 | idx label predict radius correct time
2 | 0 3 3 0.378 1 15.4
3 | 20 7 -1 0.0 0 15.6
4 | 40 4 4 0.341 1 15.7
5 | 60 7 7 0.941 1 15.8
6 | 80 8 8 0.755 1 15.9
7 | 100 4 -1 0.0 0 15.9
8 | 120 8 8 0.163 1 16.0
9 | 140 6 6 0.203 1 16.3
10 | 160 2 2 0.548 1 16.4
11 | 180 0 0 0.557 1 16.4
12 | 200 5 5 0.313 1 16.6
13 | 220 7 7 0.134 1 16.7
14 | 240 1 1 0.978 1 16.8
15 | 260 8 8 0.563 1 16.8
16 | 280 9 9 0.283 1 16.8
17 | 300 6 6 0.955 1 16.9
18 | 320 3 3 0.224 1 16.9
19 | 340 2 4 0.152 0 16.9
20 | 360 9 9 0.611 1 16.9
21 | 380 6 6 0.444 1 17.0
22 | 400 9 9 0.753 1 17.0
23 | 420 4 4 0.699 1 17.0
24 | 440 1 1 0.708 1 17.0
25 | 460 5 5 0.616 1 17.0
26 | 480 8 9 0.211 0 17.0
27 | 500 4 -1 0.0 0 17.0
28 | 520 5 5 0.63 1 17.0
29 | 540 1 1 0.725 1 17.0
30 | 560 0 0 0.978 1 17.0
31 | 580 4 4 0.314 1 17.0
32 | 600 8 8 0.46 1 17.0
33 | 620 7 7 0.849 1 16.9
34 | 640 5 3 0.698 0 17.0
35 | 660 9 9 0.411 1 17.0
36 | 680 9 9 0.441 1 16.9
37 | 700 7 7 0.955 1 16.9
38 | 720 4 4 0.232 1 16.9
39 | 740 2 5 0.508 0 16.9
40 | 760 3 3 0.237 1 16.9
41 | 780 9 9 0.929 1 16.9
42 | 800 7 7 0.978 1 16.9
43 | 820 6 6 0.254 1 16.9
44 | 840 7 7 0.978 1 16.9
45 | 860 4 4 0.0831 1 16.9
46 | 880 8 8 0.894 1 16.9
47 | 900 2 6 0.0322 0 16.9
48 | 920 6 6 0.417 1 16.9
49 | 940 9 9 0.759 1 16.9
50 | 960 9 9 0.245 1 16.9
51 | 980 2 2 0.0549 1 16.9
52 | 1000 5 5 0.978 1 16.9
53 | 1020 1 1 0.906 1 16.9
54 | 1040 7 5 0.17 0 16.9
55 | 1060 9 9 0.978 1 16.9
56 | 1080 6 6 0.411 1 16.9
57 | 1100 7 0 0.0632 0 16.9
58 | 1120 5 2 0.039 0 16.9
59 | 1140 6 6 0.257 1 16.9
60 | 1160 8 8 0.617 1 16.9
61 | 1180 3 3 0.558 1 16.9
62 | 1200 8 8 0.518 1 16.9
63 | 1220 9 6 0.0591 0 16.9
64 | 1240 2 6 0.195 0 16.9
65 | 1260 1 -1 0.0 0 16.9
66 | 1280 3 3 0.303 1 16.9
67 | 1300 4 3 0.347 0 16.9
68 | 1320 7 7 0.515 1 16.9
69 | 1340 1 1 0.648 1 16.9
70 | 1360 7 7 0.321 1 16.9
71 | 1380 4 4 0.0991 1 16.9
72 | 1400 5 3 0.01 0 16.9
73 | 1420 4 4 0.164 1 16.9
74 | 1440 0 0 0.323 1 16.9
75 | 1460 6 6 0.831 1 16.9
76 | 1480 1 1 0.26 1 16.9
77 | 1500 1 1 0.724 1 16.9
78 | 1520 7 7 0.522 1 16.9
79 | 1540 8 8 0.889 1 16.9
80 | 1560 7 7 0.978 1 16.9
81 | 1580 6 -1 0.0 0 16.9
82 | 1600 8 -1 0.0 0 16.9
83 | 1620 5 -1 0.0 0 16.9
84 | 1640 7 -1 0.0 0 16.9
85 | 1660 3 3 0.368 1 16.9
86 | 1680 6 6 0.474 1 16.9
87 | 1700 5 5 0.0619 1 16.9
88 | 1720 4 4 0.578 1 16.9
89 | 1740 1 1 0.477 1 16.9
90 | 1760 0 0 0.683 1 16.9
91 | 1780 7 7 0.662 1 16.9
92 | 1800 4 4 0.441 1 16.9
93 | 1820 8 8 0.781 1 16.9
94 | 1840 8 -1 0.0 0 16.9
95 | 1860 8 8 0.978 1 16.9
96 | 1880 1 1 0.978 1 16.9
97 | 1900 8 8 0.978 1 16.9
98 | 1920 2 2 0.978 1 16.9
99 | 1940 5 5 0.265 1 16.9
100 | 1960 2 -1 0.0 0 16.9
101 | 1980 9 9 0.729 1 16.9
102 | 2000 1 1 0.731 1 16.9
103 | 2020 9 9 0.978 1 16.9
104 | 2040 0 0 0.929 1 16.9
105 | 2060 5 3 0.538 0 16.9
106 | 2080 1 9 0.136 0 16.9
107 | 2100 2 2 0.483 1 16.9
108 | 2120 4 8 0.092 0 16.9
109 | 2140 9 9 0.978 1 16.9
110 | 2160 0 0 0.397 1 16.9
111 | 2180 4 4 0.558 1 16.9
112 | 2200 0 0 0.228 1 16.9
113 | 2220 1 1 0.978 1 16.9
114 | 2240 9 9 0.268 1 16.9
115 | 2260 4 4 0.661 1 16.9
116 | 2280 4 4 0.144 1 16.9
117 | 2300 3 3 0.018 1 16.9
118 | 2320 5 2 0.653 0 16.9
119 | 2340 7 7 0.955 1 16.9
120 | 2360 7 7 0.019 1 16.9
121 | 2380 8 8 0.725 1 16.9
122 | 2400 0 0 0.978 1 16.9
123 | 2420 8 8 0.055 1 16.9
124 | 2440 7 7 0.978 1 16.9
125 | 2460 9 9 0.108 1 16.9
126 | 2480 8 8 0.817 1 16.9
127 | 2500 4 4 0.108 1 16.9
128 | 2520 1 1 0.147 1 16.9
129 | 2540 2 6 0.119 0 16.9
130 | 2560 7 7 0.312 1 16.9
131 | 2580 6 5 0.0511 0 16.9
132 | 2600 8 8 0.88 1 16.9
133 | 2620 1 1 0.637 1 16.9
134 | 2640 7 7 0.978 1 16.9
135 | 2660 3 -1 0.0 0 16.9
136 | 2680 0 0 0.978 1 16.9
137 | 2700 9 -1 0.0 0 16.9
138 | 2720 3 3 0.484 1 16.9
139 | 2740 1 1 0.164 1 16.9
140 | 2760 2 2 0.42 1 16.9
141 | 2780 7 7 0.889 1 16.9
142 | 2800 4 4 0.372 1 16.9
143 | 2820 6 6 0.978 1 16.9
144 | 2840 3 3 0.978 1 16.9
145 | 2860 5 4 0.279 0 16.9
146 | 2880 4 4 0.781 1 16.9
147 | 2900 3 7 0.127 0 16.9
148 | 2920 2 3 0.0245 0 16.9
149 | 2940 3 0 0.811 0 16.9
150 | 2960 9 0 0.155 0 16.9
151 | 2980 8 8 0.456 1 16.9
152 | 3000 5 5 0.806 1 16.9
153 | 3020 1 1 0.567 1 16.9
154 | 3040 7 7 0.978 1 16.9
155 | 3060 7 7 0.821 1 16.9
156 | 3080 1 1 0.0759 1 16.9
157 | 3100 0 0 0.62 1 16.9
158 | 3120 4 4 0.137 1 16.9
159 | 3140 8 8 0.589 1 16.9
160 | 3160 6 2 0.114 0 16.9
161 | 3180 3 3 0.0918 1 16.9
162 | 3200 5 5 0.577 1 16.9
163 | 3220 3 9 0.124 0 16.9
164 | 3240 4 -1 0.0 0 16.9
165 | 3260 7 7 0.522 1 16.9
166 | 3280 3 5 0.0818 0 16.9
167 | 3300 4 4 0.285 1 16.9
168 | 3320 2 7 0.335 0 16.9
169 | 3340 6 6 0.525 1 16.9
170 | 3360 4 4 0.442 1 16.9
171 | 3380 8 8 0.978 1 16.9
172 | 3400 6 4 0.49 0 16.9
173 | 3420 5 -1 0.0 0 16.9
174 | 3440 2 3 0.643 0 16.9
175 | 3460 1 1 0.978 1 16.9
176 | 3480 2 2 0.978 1 16.9
177 | 3500 1 1 0.118 1 16.9
178 | 3520 1 1 0.978 1 16.9
179 | 3540 6 6 0.736 1 16.9
180 | 3560 1 1 0.0284 1 16.9
181 | 3580 4 4 0.228 1 16.9
182 | 3600 4 2 0.467 0 16.9
183 | 3620 0 0 0.304 1 16.9
184 | 3640 6 6 0.0873 1 16.9
185 | 3660 0 0 0.506 1 16.9
186 | 3680 0 0 0.0749 1 16.9
187 | 3700 3 3 0.759 1 16.9
188 | 3720 2 2 0.978 1 16.9
189 | 3740 0 0 0.955 1 16.9
190 | 3760 7 7 0.978 1 16.9
191 | 3780 4 4 0.349 1 16.9
192 | 3800 9 9 0.708 1 16.9
193 | 3820 0 0 0.133 1 16.9
194 | 3840 6 3 0.00299 0 16.9
195 | 3860 7 7 0.978 1 16.9
196 | 3880 6 6 0.288 1 16.9
197 | 3900 3 3 0.086 1 16.9
198 | 3920 7 7 0.27 1 16.9
199 | 3940 6 -1 0.0 0 16.9
200 | 3960 2 2 0.784 1 16.9
201 | 3980 9 9 0.163 1 16.9
202 | 4000 8 -1 0.0 0 16.9
203 | 4020 8 8 0.488 1 16.9
204 | 4040 0 0 0.92 1 16.9
205 | 4060 6 6 0.559 1 16.9
206 | 4080 1 1 0.978 1 16.9
207 | 4100 7 7 0.978 1 16.9
208 | 4120 4 4 0.929 1 16.9
209 | 4140 5 5 0.682 1 16.9
210 | 4160 5 -1 0.0 0 16.9
211 | 4180 0 8 0.155 0 16.9
212 | 4200 4 4 0.181 1 16.9
213 | 4220 4 0 0.113 0 16.9
214 | 4240 7 7 0.0474 1 16.9
215 | 4260 8 8 0.92 1 16.9
216 | 4280 8 0 0.394 0 16.9
217 | 4300 8 8 0.621 1 16.9
218 | 4320 1 -1 0.0 0 16.9
219 | 4340 0 0 0.859 1 16.9
220 | 4360 6 6 0.929 1 16.9
221 | 4380 9 9 0.642 1 16.9
222 | 4400 3 -1 0.0 0 16.9
223 | 4420 5 5 0.469 1 16.9
224 | 4440 2 2 0.208 1 16.9
225 | 4460 9 9 0.674 1 16.9
226 | 4480 9 9 0.978 1 16.9
227 | 4500 3 5 0.0371 0 16.9
228 | 4520 3 3 0.174 1 16.9
229 | 4540 9 9 0.784 1 16.9
230 | 4560 1 1 0.955 1 16.9
231 | 4580 6 6 0.412 1 16.9
232 | 4600 4 4 0.512 1 16.9
233 | 4620 7 7 0.125 1 16.9
234 | 4640 2 2 0.34 1 16.9
235 | 4660 7 -1 0.0 0 16.9
236 | 4680 9 9 0.617 1 16.9
237 | 4700 6 5 0.179 0 16.9
238 | 4720 8 8 0.299 1 16.9
239 | 4740 5 3 0.145 0 16.9
240 | 4760 3 2 0.436 0 16.9
241 | 4780 0 0 0.585 1 16.9
242 | 4800 9 9 0.978 1 16.9
243 | 4820 3 3 0.471 1 16.9
244 | 4840 0 0 0.978 1 16.9
245 | 4860 5 5 0.978 1 16.9
246 | 4880 0 8 0.178 0 16.9
247 | 4900 3 3 0.652 1 16.9
248 | 4920 7 7 0.978 1 16.9
249 | 4940 6 6 0.378 1 16.9
250 | 4960 4 4 0.318 1 16.9
251 | 4980 1 9 0.478 0 16.9
252 | 5000 7 7 0.978 1 16.9
253 | 5020 8 8 0.754 1 16.9
254 | 5040 3 3 0.336 1 16.9
255 | 5060 6 6 0.256 1 16.9
256 | 5080 7 7 0.731 1 16.9
257 | 5100 3 3 0.731 1 16.9
258 | 5120 9 9 0.32 1 16.9
259 | 5140 8 8 0.772 1 16.9
260 | 5160 0 0 0.686 1 16.9
261 | 5180 9 9 0.0113 1 16.9
262 | 5200 3 3 0.663 1 16.9
263 | 5220 0 0 0.429 1 16.9
264 | 5240 1 1 0.423 1 16.9
265 | 5260 0 0 0.978 1 16.9
266 | 5280 0 -1 0.0 0 16.9
267 | 5300 9 9 0.743 1 16.9
268 | 5320 7 7 0.449 1 16.9
269 | 5340 3 4 0.135 0 16.9
270 | 5360 9 9 0.802 1 16.9
271 | 5380 1 9 0.267 0 16.9
272 | 5400 9 9 0.781 1 16.9
273 | 5420 6 6 0.844 1 16.9
274 | 5440 9 9 0.955 1 16.9
275 | 5460 0 0 0.978 1 16.9
276 | 5480 1 1 0.641 1 16.9
277 | 5500 8 8 0.336 1 16.9
278 | 5520 4 4 0.279 1 16.9
279 | 5540 5 5 0.104 1 16.9
280 | 5560 3 5 0.293 0 16.9
281 | 5580 5 -1 0.0 0 16.9
282 | 5600 6 6 0.187 1 16.9
283 | 5620 3 -1 0.0 0 16.9
284 | 5640 2 5 0.0519 0 16.9
285 | 5660 9 9 0.165 1 16.9
286 | 5680 2 3 0.155 0 16.9
287 | 5700 3 3 0.655 1 16.9
288 | 5720 9 9 0.978 1 16.9
289 | 5740 6 6 0.322 1 16.9
290 | 5760 2 2 0.26 1 16.9
291 | 5780 7 7 0.978 1 16.9
292 | 5800 2 2 0.517 1 16.9
293 | 5820 5 5 0.433 1 16.9
294 | 5840 2 2 0.322 1 16.9
295 | 5860 0 0 0.516 1 16.9
296 | 5880 0 0 0.889 1 16.9
297 | 5900 4 3 0.0461 0 16.9
298 | 5920 8 8 0.978 1 16.9
299 | 5940 7 7 0.0779 1 16.9
300 | 5960 2 6 0.224 0 16.9
301 | 5980 9 9 0.955 1 16.9
302 | 6000 8 8 0.478 1 16.9
303 | 6020 6 6 0.632 1 16.9
304 | 6040 2 -1 0.0 0 16.9
305 | 6060 3 4 0.0698 0 16.9
306 | 6080 1 1 0.585 1 16.9
307 | 6100 1 1 0.978 1 16.9
308 | 6120 5 5 0.539 1 16.9
309 | 6140 9 9 0.615 1 16.9
310 | 6160 2 -1 0.0 0 16.9
311 | 6180 0 0 0.389 1 16.9
312 | 6200 3 3 0.0459 1 16.9
313 | 6220 2 -1 0.0 0 16.9
314 | 6240 2 2 0.41 1 16.9
315 | 6260 2 2 0.176 1 16.9
316 | 6280 8 -1 0.0 0 16.9
317 | 6300 1 1 0.782 1 16.9
318 | 6320 0 0 0.37 1 16.9
319 | 6340 3 3 0.173 1 16.9
320 | 6360 2 2 0.247 1 16.9
321 | 6380 3 -1 0.0 0 16.9
322 | 6400 0 0 0.543 1 16.9
323 | 6420 7 7 0.978 1 16.9
324 | 6440 3 3 0.101 1 16.9
325 | 6460 3 3 0.724 1 16.9
326 | 6480 0 0 0.565 1 16.9
327 | 6500 7 -1 0.0 0 16.9
328 | 6520 6 6 0.186 1 16.9
329 | 6540 8 8 0.45 1 16.9
330 | 6560 6 6 0.0883 1 16.9
331 | 6580 1 1 0.117 1 16.9
332 | 6600 7 7 0.955 1 16.9
333 | 6620 7 7 0.978 1 16.9
334 | 6640 5 3 0.496 0 16.9
335 | 6660 0 0 0.733 1 16.9
336 | 6680 3 5 0.27 0 16.9
337 | 6700 6 6 0.866 1 16.9
338 | 6720 2 2 0.304 1 16.9
339 | 6740 2 2 0.0894 1 16.9
340 | 6760 5 5 0.955 1 16.9
341 | 6780 7 7 0.978 1 16.9
342 | 6800 6 -1 0.0 0 16.9
343 | 6820 1 1 0.873 1 16.9
344 | 6840 9 9 0.752 1 16.9
345 | 6860 0 0 0.617 1 16.9
346 | 6880 2 2 0.572 1 16.9
347 | 6900 3 3 0.612 1 16.9
348 | 6920 5 4 0.199 0 16.9
349 | 6940 1 1 0.978 1 16.9
350 | 6960 9 9 0.403 1 16.9
351 | 6980 0 0 0.381 1 16.9
352 | 7000 2 0 0.0768 0 16.9
353 | 7020 7 7 0.978 1 16.9
354 | 7040 7 -1 0.0 0 16.9
355 | 7060 8 8 0.894 1 16.9
356 | 7080 5 4 0.175 0 16.9
357 | 7100 9 3 0.349 0 16.9
358 | 7120 7 7 0.316 1 16.9
359 | 7140 4 4 0.147 1 16.9
360 | 7160 5 5 0.688 1 16.9
361 | 7180 6 6 0.445 1 16.9
362 | 7200 4 4 0.105 1 16.9
363 | 7220 9 9 0.669 1 16.9
364 | 7240 0 0 0.414 1 16.9
365 | 7260 1 1 0.765 1 16.9
366 | 7280 1 1 0.231 1 16.9
367 | 7300 3 5 0.0342 0 16.9
368 | 7320 8 0 0.105 0 16.9
369 | 7340 2 2 0.236 1 16.9
370 | 7360 2 2 0.328 1 16.9
371 | 7380 7 7 0.844 1 16.9
372 | 7400 3 5 0.37 0 16.9
373 | 7420 3 4 0.187 0 16.9
374 | 7440 3 3 0.693 1 16.9
375 | 7460 5 5 0.978 1 16.9
376 | 7480 1 1 0.466 1 16.9
377 | 7500 6 6 0.978 1 16.9
378 | 7520 4 4 0.366 1 16.9
379 | 7540 5 5 0.978 1 16.9
380 | 7560 0 0 0.369 1 16.9
381 | 7580 8 8 0.978 1 16.9
382 | 7600 8 2 0.0625 0 16.9
383 | 7620 5 3 0.103 0 16.9
384 | 7640 9 9 0.114 1 16.9
385 | 7660 7 7 0.424 1 16.9
386 | 7680 3 5 0.0632 0 16.9
387 | 7700 6 6 0.92 1 16.9
388 | 7720 5 2 0.00494 0 16.9
389 | 7740 4 4 0.301 1 16.9
390 | 7760 2 0 0.143 0 16.9
391 | 7780 3 -1 0.0 0 16.9
392 | 7800 0 0 0.597 1 16.9
393 | 7820 5 5 0.18 1 16.9
394 | 7840 1 1 0.978 1 16.9
395 | 7860 8 8 0.749 1 16.9
396 | 7880 0 0 0.647 1 16.9
397 | 7900 0 -1 0.0 0 16.9
398 | 7920 9 9 0.849 1 16.9
399 | 7940 2 2 0.0572 1 16.9
400 | 7960 0 0 0.138 1 16.9
401 | 7980 3 3 0.321 1 16.9
402 | 8000 9 9 0.762 1 16.9
403 | 8020 6 3 0.02 0 16.9
404 | 8040 2 2 0.729 1 16.9
405 | 8060 6 6 0.768 1 16.9
406 | 8080 4 4 0.231 1 16.9
407 | 8100 6 2 0.017 0 16.9
408 | 8120 8 8 0.00256 1 16.9
409 | 8140 5 7 0.0382 0 16.9
410 | 8160 6 6 0.978 1 16.9
411 | 8180 4 4 0.239 1 16.9
412 | 8200 3 -1 0.0 0 16.9
413 | 8220 6 5 0.0177 0 16.9
414 | 8240 7 7 0.572 1 16.9
415 | 8260 1 1 0.955 1 16.9
416 | 8280 4 4 0.863 1 16.9
417 | 8300 5 -1 0.0 0 16.9
418 | 8320 7 7 0.27 1 16.9
419 | 8340 5 5 0.104 1 16.9
420 | 8360 7 7 0.978 1 16.9
421 | 8380 9 9 0.978 1 16.9
422 | 8400 0 0 0.168 1 16.9
423 | 8420 3 6 0.573 0 16.9
424 | 8440 5 5 0.141 1 16.9
425 | 8460 8 8 0.771 1 16.9
426 | 8480 2 2 0.363 1 16.9
427 | 8500 4 4 0.64 1 16.9
428 | 8520 8 8 0.41 1 16.9
429 | 8540 4 -1 0.0 0 16.9
430 | 8560 7 7 0.941 1 16.9
431 | 8580 3 5 0.244 0 16.9
432 | 8600 3 -1 0.0 0 16.9
433 | 8620 3 -1 0.0 0 16.9
434 | 8640 1 1 0.579 1 16.9
435 | 8660 7 7 0.686 1 16.9
436 | 8680 3 3 0.00316 1 16.9
437 | 8700 3 3 0.552 1 16.9
438 | 8720 2 0 0.88 0 16.9
439 | 8740 2 2 0.253 1 16.9
440 | 8760 8 8 0.955 1 16.9
441 | 8780 4 4 0.255 1 16.9
442 | 8800 0 0 0.978 1 16.9
443 | 8820 2 2 0.978 1 16.9
444 | 8840 2 2 0.131 1 16.9
445 | 8860 2 2 0.941 1 16.9
446 | 8880 1 1 0.793 1 16.9
447 | 8900 2 4 0.19 0 16.9
448 | 8920 6 6 0.566 1 16.9
449 | 8940 6 3 0.393 0 16.9
450 | 8960 0 0 0.0635 1 16.9
451 | 8980 9 9 0.955 1 16.9
452 | 9000 8 8 0.978 1 16.9
453 | 9020 7 7 0.824 1 16.9
454 | 9040 2 5 0.725 0 16.9
455 | 9060 9 9 0.466 1 16.9
456 | 9080 3 3 0.182 1 16.9
457 | 9100 9 1 0.233 0 16.9
458 | 9120 3 -1 0.0 0 16.9
459 | 9140 7 7 0.455 1 16.9
460 | 9160 5 3 0.19 0 16.9
461 | 9180 5 5 0.978 1 16.9
462 | 9200 8 8 0.802 1 16.9
463 | 9220 8 8 0.821 1 16.9
464 | 9240 8 8 0.467 1 16.9
465 | 9260 5 5 0.319 1 16.9
466 | 9280 9 8 0.0567 0 16.9
467 | 9300 5 -1 0.0 0 16.9
468 | 9320 9 9 0.894 1 16.9
469 | 9340 1 1 0.978 1 16.9
470 | 9360 5 0 0.119 0 16.9
471 | 9380 5 3 0.0574 0 16.9
472 | 9400 6 6 0.978 1 16.9
473 | 9420 5 5 0.787 1 16.9
474 | 9440 8 8 0.572 1 16.9
475 | 9460 3 3 0.14 1 16.9
476 | 9480 2 2 0.178 1 16.9
477 | 9500 9 9 0.644 1 16.9
478 | 9520 1 1 0.721 1 16.9
479 | 9540 5 5 0.659 1 16.9
480 | 9560 9 7 0.123 0 16.9
481 | 9580 1 1 0.978 1 16.9
482 | 9600 8 8 0.322 1 16.9
483 | 9620 4 4 0.352 1 16.9
484 | 9640 5 5 0.358 1 16.9
485 | 9660 6 6 0.906 1 16.9
486 | 9680 8 8 0.844 1 16.9
487 | 9700 0 0 0.709 1 16.9
488 | 9720 8 -1 0.0 0 16.9
489 | 9740 3 6 0.22 0 16.9
490 | 9760 9 -1 0.0 0 16.9
491 | 9780 4 4 0.192 1 16.9
492 | 9800 1 1 0.978 1 16.8
493 | 9820 0 8 0.0396 0 16.9
494 | 9840 4 7 0.563 0 16.9
495 | 9860 0 -1 0.0 0 16.8
496 | 9880 7 7 0.077 1 16.9
497 | 9900 8 8 0.929 1 16.9
498 | 9920 6 6 0.978 1 16.9
499 | 9940 4 4 0.0449 1 16.9
500 | 9960 2 0 0.941 0 16.9
501 | 9980 0 0 0.978 1 16.9
502 |
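Each certify log row gives the test-set index, the true label, the smoothed classifier's prediction (`-1` appears to mark an abstention, and those rows always carry radius 0.0 and correct 0), the certified radius, and whether the prediction was correct. A common summary is approximate certified accuracy at radius r: the fraction of rows that are correct *and* certified to at least r. The sketch below assumes the raw file is whitespace-separated with the header shown above. The helper names are hypothetical and are not taken from `analyze.py`. The `time` column is ignored because its format differs between logs.

```python
def load_certify_log(text):
    """Parse a certify log (header line + rows) into a list of dicts."""
    lines = [ln.split() for ln in text.strip().splitlines()]
    header, rows = lines[0], lines[1:]
    records = []
    for fields in rows:
        rec = dict(zip(header, fields))
        records.append({
            "idx": int(rec["idx"]),
            "label": int(rec["label"]),
            "predict": int(rec["predict"]),  # -1 = abstain (assumed)
            "radius": float(rec["radius"]),
            "correct": int(rec["correct"]),
        })
    return records


def certified_accuracy(records, radius):
    """Fraction of examples that are correct with certified radius >= radius."""
    hits = sum(1 for r in records if r["correct"] == 1 and r["radius"] >= radius)
    return hits / len(records)


# First four rows of the sigma_0.25 log above.
sample = """idx label predict radius correct time
0 3 3 0.378 1 15.4
20 7 -1 0.0 0 15.6
40 4 4 0.341 1 15.7
60 7 7 0.941 1 15.8"""
recs = load_certify_log(sample)
print(certified_accuracy(recs, 0.0), certified_accuracy(recs, 0.5))  # 0.75 0.25
```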
--------------------------------------------------------------------------------
/data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50:
--------------------------------------------------------------------------------
1 | idx label predict radius correct time
2 | 0 3 3 0.486 1 16.6
3 | 20 7 7 0.413 1 17.2
4 | 40 4 4 0.312 1 17.5
5 | 60 7 7 1.22 1 17.6
6 | 80 8 8 1.2 1 17.6
7 | 100 4 -1 0.0 0 17.6
8 | 120 8 8 0.226 1 17.6
9 | 140 6 6 0.366 1 17.6
10 | 160 2 2 0.276 1 17.6
11 | 180 0 0 0.935 1 17.6
12 | 200 5 5 0.28 1 17.6
13 | 220 7 7 1.37 1 17.6
14 | 240 1 1 1.88 1 17.6
15 | 260 8 8 0.00902 1 17.6
16 | 280 9 9 1.04 1 17.6
17 | 300 6 6 0.684 1 17.6
18 | 320 3 3 0.592 1 17.5
19 | 340 2 4 0.172 0 17.5
20 | 360 9 9 0.301 1 17.6
21 | 380 6 6 0.454 1 17.5
22 | 400 9 9 1.17 1 17.5
23 | 420 4 4 0.803 1 17.5
24 | 440 1 1 0.441 1 17.5
25 | 460 5 5 1.13 1 17.5
26 | 480 8 9 0.556 0 17.5
27 | 500 4 6 0.373 0 17.5
28 | 520 5 5 0.324 1 17.5
29 | 540 1 1 0.685 1 17.5
30 | 560 0 0 1.36 1 17.5
31 | 580 4 4 0.507 1 17.5
32 | 600 8 8 1.15 1 17.5
33 | 620 7 7 0.479 1 17.5
34 | 640 5 3 0.917 0 17.5
35 | 660 9 9 0.84 1 17.5
36 | 680 9 9 0.112 1 17.5
37 | 700 7 7 0.792 1 17.5
38 | 720 4 -1 0.0 0 17.5
39 | 740 2 3 0.0547 0 17.5
40 | 760 3 3 0.187 1 17.5
41 | 780 9 9 0.36 1 17.5
42 | 800 7 7 1.91 1 17.5
43 | 820 6 -1 0.0 0 17.5
44 | 840 7 7 1.77 1 17.5
45 | 860 4 4 0.398 1 17.5
46 | 880 8 8 1.28 1 17.5
47 | 900 2 2 0.0857 1 17.5
48 | 920 6 6 0.344 1 17.5
49 | 940 9 9 1.25 1 17.5
50 | 960 9 9 1.09 1 17.5
51 | 980 2 6 0.141 0 17.5
52 | 1000 5 5 1.34 1 17.5
53 | 1020 1 1 1.45 1 17.5
54 | 1040 7 5 0.775 0 17.5
55 | 1060 9 9 1.45 1 17.5
56 | 1080 6 6 0.546 1 17.5
57 | 1100 7 -1 0.0 0 17.5
58 | 1120 5 5 0.0396 1 17.5
59 | 1140 6 6 0.373 1 17.5
60 | 1160 8 8 0.096 1 17.5
61 | 1180 3 3 0.978 1 17.5
62 | 1200 8 8 0.442 1 17.5
63 | 1220 9 -1 0.0 0 17.5
64 | 1240 2 6 0.868 0 17.5
65 | 1260 1 9 0.423 0 17.5
66 | 1280 3 3 0.0999 1 17.5
67 | 1300 4 3 0.62 0 17.5
68 | 1320 7 -1 0.0 0 17.5
69 | 1340 1 1 1.33 1 17.5
70 | 1360 7 7 0.938 1 17.5
71 | 1380 4 5 0.315 0 17.5
72 | 1400 5 3 0.348 0 17.5
73 | 1420 4 7 0.0896 0 17.5
74 | 1440 0 0 0.137 1 17.5
75 | 1460 6 6 0.963 1 17.5
76 | 1480 1 1 0.336 1 17.5
77 | 1500 1 1 0.934 1 17.5
78 | 1520 7 7 0.876 1 17.5
79 | 1540 8 8 1.35 1 17.5
80 | 1560 7 7 1.96 1 17.5
81 | 1580 6 1 0.103 0 17.5
82 | 1600 8 9 0.116 0 17.5
83 | 1620 5 -1 0.0 0 17.5
84 | 1640 7 3 0.423 0 17.5
85 | 1660 3 3 0.0679 1 17.5
86 | 1680 6 6 0.867 1 17.5
87 | 1700 5 5 0.0387 1 17.5
88 | 1720 4 4 0.0682 1 17.5
89 | 1740 1 1 0.487 1 17.5
90 | 1760 0 0 0.566 1 17.5
91 | 1780 7 7 0.688 1 17.5
92 | 1800 4 -1 0.0 0 17.5
93 | 1820 8 8 0.452 1 17.5
94 | 1840 8 -1 0.0 0 17.5
95 | 1860 8 8 1.12 1 17.5
96 | 1880 1 1 1.01 1 17.5
97 | 1900 8 8 1.7 1 17.5
98 | 1920 2 2 1.02 1 17.5
99 | 1940 5 -1 0.0 0 17.5
100 | 1960 2 5 0.334 0 17.5
101 | 1980 9 9 1.6 1 17.5
102 | 2000 1 1 0.105 1 17.5
103 | 2020 9 9 1.42 1 17.5
104 | 2040 0 0 0.663 1 17.5
105 | 2060 5 3 0.606 0 17.5
106 | 2080 1 1 0.274 1 17.5
107 | 2100 2 2 0.606 1 17.5
108 | 2120 4 8 0.39 0 17.5
109 | 2140 9 9 1.23 1 17.5
110 | 2160 0 0 0.329 1 17.5
111 | 2180 4 4 0.785 1 17.5
112 | 2200 0 -1 0.0 0 17.5
113 | 2220 1 1 1.63 1 17.5
114 | 2240 9 9 0.0244 1 17.5
115 | 2260 4 4 0.823 1 17.5
116 | 2280 4 -1 0.0 0 17.5
117 | 2300 3 -1 0.0 0 17.5
118 | 2320 5 2 0.682 0 17.5
119 | 2340 7 7 1.74 1 17.5
120 | 2360 7 7 0.149 1 17.5
121 | 2380 8 8 1.0 1 17.5
122 | 2400 0 0 0.826 1 17.5
123 | 2420 8 -1 0.0 0 17.5
124 | 2440 7 7 0.627 1 17.5
125 | 2460 9 -1 0.0 0 17.5
126 | 2480 8 8 1.2 1 17.5
127 | 2500 4 4 0.0665 1 17.5
128 | 2520 1 1 0.605 1 17.5
129 | 2540 2 -1 0.0 0 17.5
130 | 2560 7 7 0.364 1 17.5
131 | 2580 6 6 0.0011 1 17.5
132 | 2600 8 8 1.44 1 17.5
133 | 2620 1 1 0.438 1 17.5
134 | 2640 7 7 1.17 1 17.5
135 | 2660 3 5 0.0535 0 17.5
136 | 2680 0 0 1.86 1 17.5
137 | 2700 9 9 0.441 1 17.5
138 | 2720 3 3 0.57 1 17.5
139 | 2740 1 1 0.857 1 17.5
140 | 2760 2 2 0.643 1 17.5
141 | 2780 7 7 1.62 1 17.5
142 | 2800 4 4 0.11 1 17.5
143 | 2820 6 6 1.22 1 17.5
144 | 2840 3 3 1.96 1 17.5
145 | 2860 5 -1 0.0 0 17.5
146 | 2880 4 4 1.4 1 17.5
147 | 2900 3 5 0.0473 0 17.5
148 | 2920 2 3 0.458 0 17.5
149 | 2940 3 0 0.795 0 17.5
150 | 2960 9 9 0.476 1 17.5
151 | 2980 8 8 0.563 1 17.5
152 | 3000 5 5 1.42 1 17.5
153 | 3020 1 1 0.955 1 17.5
154 | 3040 7 7 1.96 1 17.5
155 | 3060 7 7 0.883 1 17.5
156 | 3080 1 1 0.361 1 17.5
157 | 3100 0 0 0.995 1 17.5
158 | 3120 4 -1 0.0 0 17.5
159 | 3140 8 8 0.722 1 17.5
160 | 3160 6 2 0.215 0 17.5
161 | 3180 3 3 0.26 1 17.5
162 | 3200 5 7 0.0822 0 17.5
163 | 3220 3 -1 0.0 0 17.5
164 | 3240 4 -1 0.0 0 17.5
165 | 3260 7 7 1.03 1 17.5
166 | 3280 3 5 0.405 0 17.5
167 | 3300 4 4 0.304 1 17.5
168 | 3320 2 7 1.11 0 17.5
169 | 3340 6 6 0.478 1 17.5
170 | 3360 4 4 0.15 1 17.5
171 | 3380 8 8 1.19 1 17.5
172 | 3400 6 4 0.498 0 17.5
173 | 3420 5 3 0.15 0 17.5
174 | 3440 2 3 0.445 0 17.5
175 | 3460 1 1 0.745 1 17.5
176 | 3480 2 2 1.81 1 17.5
177 | 3500 1 1 0.414 1 17.5
178 | 3520 1 1 1.47 1 17.5
179 | 3540 6 6 0.83 1 17.5
180 | 3560 1 -1 0.0 0 17.5
181 | 3580 4 4 0.0958 1 17.5
182 | 3600 4 2 0.95 0 17.5
183 | 3620 0 -1 0.0 0 17.5
184 | 3640 6 2 0.477 0 17.5
185 | 3660 0 0 0.727 1 17.5
186 | 3680 0 -1 0.0 0 17.5
187 | 3700 3 3 0.648 1 17.5
188 | 3720 2 2 1.36 1 17.5
189 | 3740 0 0 0.924 1 17.5
190 | 3760 7 7 1.27 1 17.5
191 | 3780 4 4 0.215 1 17.5
192 | 3800 9 9 0.585 1 17.5
193 | 3820 0 0 0.00592 1 17.5
194 | 3840 6 3 0.436 0 17.5
195 | 3860 7 7 1.96 1 17.5
196 | 3880 6 -1 0.0 0 17.5
197 | 3900 3 6 0.535 0 17.5
198 | 3920 7 -1 0.0 0 17.5
199 | 3940 6 -1 0.0 0 17.5
200 | 3960 2 2 1.11 1 17.5
201 | 3980 9 9 0.254 1 17.5
202 | 4000 8 0 0.00403 0 17.5
203 | 4020 8 8 0.396 1 17.5
204 | 4040 0 0 1.88 1 17.5
205 | 4060 6 6 0.228 1 17.5
206 | 4080 1 1 0.821 1 17.5
207 | 4100 7 7 1.96 1 17.5
208 | 4120 4 4 0.783 1 17.5
209 | 4140 5 5 0.788 1 17.5
210 | 4160 5 -1 0.0 0 17.5
211 | 4180 0 0 0.0756 1 17.5
212 | 4200 4 2 0.0203 0 17.5
213 | 4220 4 0 1.09 0 17.5
214 | 4240 7 0 0.383 0 17.5
215 | 4260 8 8 1.21 1 17.5
216 | 4280 8 -1 0.0 0 17.5
217 | 4300 8 8 1.03 1 17.5
218 | 4320 1 -1 0.0 0 17.5
219 | 4340 0 0 0.912 1 17.5
220 | 4360 6 -1 0.0 0 17.5
221 | 4380 9 9 0.513 1 17.5
222 | 4400 3 -1 0.0 0 17.5
223 | 4420 5 5 0.29 1 17.5
224 | 4440 2 2 0.358 1 17.5
225 | 4460 9 9 0.362 1 17.5
226 | 4480 9 9 1.41 1 17.5
227 | 4500 3 5 0.4 0 17.5
228 | 4520 3 6 0.25 0 17.5
229 | 4540 9 9 1.48 1 17.5
230 | 4560 1 1 1.1 1 17.5
231 | 4580 6 6 0.0476 1 17.5
232 | 4600 4 7 0.112 0 17.5
233 | 4620 7 3 0.216 0 17.5
234 | 4640 2 2 0.299 1 17.5
235 | 4660 7 7 0.131 1 17.5
236 | 4680 9 9 0.0216 1 17.5
237 | 4700 6 -1 0.0 0 17.5
238 | 4720 8 8 0.173 1 17.5
239 | 4740 5 3 0.618 0 17.5
240 | 4760 3 -1 0.0 0 17.5
241 | 4780 0 0 0.787 1 17.5
242 | 4800 9 9 1.96 1 17.5
243 | 4820 3 3 0.958 1 17.5
244 | 4840 0 0 1.69 1 17.5
245 | 4860 5 5 0.841 1 17.5
246 | 4880 0 -1 0.0 0 17.5
247 | 4900 3 -1 0.0 0 17.5
248 | 4920 7 7 1.42 1 17.5
249 | 4940 6 6 0.169 1 17.5
250 | 4960 4 3 0.089 0 17.5
251 | 4980 1 9 0.555 0 17.5
252 | 5000 7 7 1.96 1 17.5
253 | 5020 8 8 1.34 1 17.5
254 | 5040 3 3 0.0419 1 17.5
255 | 5060 6 6 0.204 1 17.5
256 | 5080 7 7 1.36 1 17.5
257 | 5100 3 3 1.33 1 17.5
258 | 5120 9 9 0.478 1 17.5
259 | 5140 8 8 0.877 1 17.5
260 | 5160 0 0 0.56 1 17.5
261 | 5180 9 -1 0.0 0 17.5
262 | 5200 3 3 1.33 1 17.5
263 | 5220 0 0 0.667 1 17.5
264 | 5240 1 -1 0.0 0 17.5
265 | 5260 0 0 1.73 1 17.5
266 | 5280 0 -1 0.0 0 17.5
267 | 5300 9 9 0.601 1 17.5
268 | 5320 7 7 0.94 1 17.5
269 | 5340 3 4 1.01 0 17.5
270 | 5360 9 9 1.13 1 17.5
271 | 5380 1 9 0.458 0 17.5
272 | 5400 9 9 0.956 1 17.5
273 | 5420 6 6 0.677 1 17.5
274 | 5440 9 9 1.91 1 17.5
275 | 5460 0 0 1.77 1 17.5
276 | 5480 1 1 0.593 1 17.5
277 | 5500 8 8 0.47 1 17.5
278 | 5520 4 -1 0.0 0 17.5
279 | 5540 5 -1 0.0 0 17.5
280 | 5560 3 -1 0.0 0 17.5
281 | 5580 5 3 0.121 0 17.5
282 | 5600 6 -1 0.0 0 17.5
283 | 5620 3 4 0.102 0 17.5
284 | 5640 2 2 0.491 1 17.5
285 | 5660 9 9 0.346 1 17.5
286 | 5680 2 3 0.417 0 17.5
287 | 5700 3 -1 0.0 0 17.5
288 | 5720 9 9 1.51 1 17.5
289 | 5740 6 6 0.379 1 17.5
290 | 5760 2 2 0.242 1 17.5
291 | 5780 7 7 1.84 1 17.5
292 | 5800 2 2 0.182 1 17.5
293 | 5820 5 5 0.268 1 17.5
294 | 5840 2 2 0.461 1 17.5
295 | 5860 0 0 0.146 1 17.5
296 | 5880 0 0 1.22 1 17.5
297 | 5900 4 -1 0.0 0 17.5
298 | 5920 8 8 1.96 1 17.5
299 | 5940 7 4 0.151 0 17.5
300 | 5960 2 6 0.152 0 17.5
301 | 5980 9 9 1.7 1 17.5
302 | 6000 8 8 0.115 1 17.5
303 | 6020 6 6 0.731 1 17.5
304 | 6040 2 5 0.00539 0 17.5
305 | 6060 3 3 0.0204 1 17.5
306 | 6080 1 1 0.625 1 17.5
307 | 6100 1 1 1.15 1 17.5
308 | 6120 5 5 0.741 1 17.5
309 | 6140 9 9 1.05 1 17.5
310 | 6160 2 2 0.258 1 17.5
311 | 6180 0 0 0.626 1 17.5
312 | 6200 3 3 0.49 1 17.5
313 | 6220 2 5 0.178 0 17.5
314 | 6240 2 2 0.0802 1 17.5
315 | 6260 2 -1 0.0 0 17.5
316 | 6280 8 -1 0.0 0 17.5
317 | 6300 1 1 0.493 1 17.5
318 | 6320 0 0 0.588 1 17.5
319 | 6340 3 3 0.827 1 17.5
320 | 6360 2 2 0.196 1 17.5
321 | 6380 3 -1 0.0 0 17.5
322 | 6400 0 -1 0.0 0 17.5
323 | 6420 7 7 1.68 1 17.5
324 | 6440 3 -1 0.0 0 17.5
325 | 6460 3 3 0.289 1 17.5
326 | 6480 0 0 0.66 1 17.5
327 | 6500 7 -1 0.0 0 17.5
328 | 6520 6 -1 0.0 0 17.5
329 | 6540 8 8 0.853 1 17.5
330 | 6560 6 6 0.468 1 17.5
331 | 6580 1 1 0.127 1 17.5
332 | 6600 7 7 1.69 1 17.5
333 | 6620 7 7 1.2 1 17.5
334 | 6640 5 3 0.822 0 17.5
335 | 6660 0 0 1.06 1 17.5
336 | 6680 3 3 0.33 1 17.5
337 | 6700 6 6 0.22 1 17.5
338 | 6720 2 2 0.44 1 17.5
339 | 6740 2 2 0.35 1 17.5
340 | 6760 5 5 1.07 1 17.5
341 | 6780 7 7 1.54 1 17.5
342 | 6800 6 -1 0.0 0 17.5
343 | 6820 1 1 1.7 1 17.5
344 | 6840 9 9 0.669 1 17.5
345 | 6860 0 0 0.335 1 17.5
346 | 6880 2 2 0.562 1 17.5
347 | 6900 3 -1 0.0 0 17.5
348 | 6920 5 -1 0.0 0 17.5
349 | 6940 1 1 1.91 1 17.5
350 | 6960 9 9 0.065 1 17.5
351 | 6980 0 -1 0.0 0 17.5
352 | 7000 2 0 0.122 0 17.5
353 | 7020 7 7 1.47 1 17.5
354 | 7040 7 -1 0.0 0 17.5
355 | 7060 8 8 0.651 1 17.5
356 | 7080 5 4 0.403 0 17.5
357 | 7100 9 -1 0.0 0 17.5
358 | 7120 7 5 0.235 0 17.5
359 | 7140 4 -1 0.0 0 17.5
360 | 7160 5 5 0.97 1 17.5
361 | 7180 6 6 0.74 1 17.5
362 | 7200 4 -1 0.0 0 17.5
363 | 7220 9 9 1.04 1 17.5
364 | 7240 0 0 0.0983 1 17.5
365 | 7260 1 1 0.477 1 17.5
366 | 7280 1 9 0.391 0 17.5
367 | 7300 3 4 0.451 0 17.5
368 | 7320 8 8 0.423 1 17.5
369 | 7340 2 4 0.149 0 17.5
370 | 7360 2 -1 0.0 0 17.5
371 | 7380 7 7 0.911 1 17.5
372 | 7400 3 3 0.668 1 17.5
373 | 7420 3 4 0.118 0 17.5
374 | 7440 3 3 0.582 1 17.5
375 | 7460 5 5 1.15 1 17.5
376 | 7480 1 8 0.612 0 17.5
377 | 7500 6 6 1.88 1 17.5
378 | 7520 4 4 0.199 1 17.5
379 | 7540 5 5 1.4 1 17.5
380 | 7560 0 0 0.485 1 17.5
381 | 7580 8 8 1.12 1 17.5
382 | 7600 8 9 0.3 0 17.5
383 | 7620 5 3 0.206 0 17.5
384 | 7640 9 9 0.0615 1 17.5
385 | 7660 7 -1 0.0 0 17.5
386 | 7680 3 -1 0.0 0 17.5
387 | 7700 6 6 1.22 1 17.5
388 | 7720 5 -1 0.0 0 17.5
389 | 7740 4 4 0.244 1 17.5
390 | 7760 2 9 0.242 0 17.5
391 | 7780 3 -1 0.0 0 17.5
392 | 7800 0 0 0.957 1 17.5
393 | 7820 5 5 0.29 1 17.5
394 | 7840 1 1 1.52 1 17.5
395 | 7860 8 8 0.378 1 17.5
396 | 7880 0 0 0.232 1 17.5
397 | 7900 0 8 0.0991 0 17.5
398 | 7920 9 9 0.38 1 17.5
399 | 7940 2 -1 0.0 0 17.5
400 | 7960 0 0 0.453 1 17.4
401 | 7980 3 3 0.339 1 17.3
402 | 8000 9 9 0.973 1 17.2
403 | 8020 6 3 0.618 0 17.1
404 | 8040 2 2 1.08 1 17.1
405 | 8060 6 6 0.736 1 17.0
406 | 8080 4 -1 0.0 0 17.0
407 | 8100 6 -1 0.0 0 17.0
408 | 8120 8 9 0.668 0 16.9
409 | 8140 5 -1 0.0 0 16.9
410 | 8160 6 6 1.08 1 16.9
411 | 8180 4 4 0.178 1 16.9
412 | 8200 3 4 0.355 0 16.9
413 | 8220 6 -1 0.0 0 16.9
414 | 8240 7 7 1.61 1 16.9
415 | 8260 1 1 1.91 1 16.9
416 | 8280 4 5 0.115 0 16.9
417 | 8300 5 -1 0.0 0 16.9
418 | 8320 7 -1 0.0 0 16.9
419 | 8340 5 3 0.45 0 16.7
420 | 8360 7 7 1.53 1 16.7
421 | 8380 9 9 1.15 1 16.7
422 | 8400 0 -1 0.0 0 16.7
423 | 8420 3 6 0.541 0 16.7
424 | 8440 5 5 0.0846 1 16.7
425 | 8460 8 8 1.28 1 16.7
426 | 8480 2 2 0.479 1 16.7
427 | 8500 4 4 0.725 1 16.7
428 | 8520 8 8 0.299 1 16.7
429 | 8540 4 -1 0.0 0 16.7
430 | 8560 7 7 1.05 1 16.7
431 | 8580 3 -1 0.0 0 16.7
432 | 8600 3 3 0.374 1 16.7
433 | 8620 3 -1 0.0 0 16.7
434 | 8640 1 1 1.01 1 16.7
435 | 8660 7 7 1.16 1 16.7
436 | 8680 3 -1 0.0 0 16.7
437 | 8700 3 3 0.522 1 16.7
438 | 8720 2 0 1.25 0 16.7
439 | 8740 2 2 0.308 1 16.7
440 | 8760 8 8 1.49 1 16.7
441 | 8780 4 4 0.358 1 16.7
442 | 8800 0 0 1.36 1 16.7
443 | 8820 2 2 1.81 1 16.7
444 | 8840 2 -1 0.0 0 16.7
445 | 8860 2 2 0.972 1 16.7
446 | 8880 1 1 0.935 1 16.7
447 | 8900 2 4 0.239 0 16.7
448 | 8920 6 6 0.906 1 16.7
449 | 8940 6 3 0.66 0 16.7
450 | 8960 0 0 0.921 1 16.7
451 | 8980 9 9 1.8 1 16.7
452 | 9000 8 8 1.96 1 16.7
453 | 9020 7 7 0.774 1 16.7
454 | 9040 2 5 0.219 0 16.7
455 | 9060 9 9 0.842 1 16.7
456 | 9080 3 3 0.178 1 16.7
457 | 9100 9 1 0.127 0 16.7
458 | 9120 3 -1 0.0 0 16.7
459 | 9140 7 7 0.0382 1 16.7
460 | 9160 5 3 0.457 0 16.7
461 | 9180 5 5 1.51 1 16.8
462 | 9200 8 8 1.39 1 16.7
463 | 9220 8 8 1.67 1 16.7
464 | 9240 8 8 0.747 1 16.7
465 | 9260 5 5 0.603 1 16.7
466 | 9280 9 -1 0.0 0 16.7
467 | 9300 5 -1 0.0 0 16.7
468 | 9320 9 9 1.13 1 16.7
469 | 9340 1 1 1.88 1 16.7
470 | 9360 5 0 0.00429 0 16.7
471 | 9380 5 3 0.613 0 16.7
472 | 9400 6 6 0.713 1 16.7
473 | 9420 5 5 1.12 1 16.7
474 | 9440 8 8 1.52 1 16.7
475 | 9460 3 7 0.516 0 16.7
476 | 9480 2 2 0.0308 1 16.7
477 | 9500 9 9 0.661 1 16.7
478 | 9520 1 1 0.875 1 16.7
479 | 9540 5 5 0.206 1 16.7
480 | 9560 9 7 0.354 0 16.7
481 | 9580 1 1 0.695 1 16.7
482 | 9600 8 8 0.61 1 16.7
483 | 9620 4 4 0.436 1 16.7
484 | 9640 5 -1 0.0 0 16.7
485 | 9660 6 6 1.8 1 16.7
486 | 9680 8 8 1.52 1 16.7
487 | 9700 0 0 1.07 1 16.7
488 | 9720 8 1 0.0197 0 16.7
489 | 9740 3 6 0.317 0 16.7
490 | 9760 9 -1 0.0 0 16.7
491 | 9780 4 6 0.0379 0 16.7
492 | 9800 1 1 1.91 1 16.7
493 | 9820 0 8 0.336 0 16.7
494 | 9840 4 -1 0.0 0 16.7
495 | 9860 0 5 0.169 0 16.7
496 | 9880 7 7 0.207 1 16.7
497 | 9900 8 8 0.569 1 16.7
498 | 9920 6 6 0.807 1 16.7
499 | 9940 4 5 0.254 0 16.7
500 | 9960 2 0 1.34 0 16.7
501 | 9980 0 0 1.69 1 16.7
502 |
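The `noise_1.00` log records `time` as `H:MM:SS.ffffff` (it looks like `str(datetime.timedelta)` output), while the two logs above store plain seconds such as `16.9`. To compare runtimes across logs, both formats need to be normalized. The sketch below is a hypothetical helper that assumes no run exceeds one day, since `timedelta` would then prepend a `"N day(s), "` prefix that this helper does not handle.

```python
def time_to_seconds(value):
    """Convert a certify-log time field to seconds.

    Accepts plain seconds ("16.9") or H:MM:SS(.ffffff) ("0:00:14.260753").
    """
    if ":" not in value:
        return float(value)
    hours, minutes, seconds = value.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


print(time_to_seconds("16.9"), time_to_seconds("0:00:14.260753"))
```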
--------------------------------------------------------------------------------
/data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50:
--------------------------------------------------------------------------------
1 | idx label predict radius correct time
2 | 0 3 2 0.502 0 0:00:14.260753
3 | 20 7 7 0.365 1 0:00:14.348896
4 | 40 4 8 0.625 0 0:00:14.554244
5 | 60 7 4 0.269 0 0:00:14.513019
6 | 80 8 8 1.91 1 0:00:14.514495
7 | 100 4 4 0.541 1 0:00:14.518535
8 | 120 8 8 1.75 1 0:00:14.489042
9 | 140 6 2 0.614 0 0:00:14.502152
10 | 160 2 2 1.34 1 0:00:14.536093
11 | 180 0 0 0.377 1 0:00:14.470639
12 | 200 5 3 1.03 0 0:00:14.535109
13 | 220 7 7 1.55 1 0:00:14.465855
14 | 240 1 1 0.908 1 0:00:14.472940
15 | 260 8 3 0.273 0 0:00:14.486534
16 | 280 9 9 1.1 1 0:00:14.460034
17 | 300 6 4 0.187 0 0:00:14.503257
18 | 320 3 2 0.183 0 0:00:14.505910
19 | 340 2 4 0.648 0 0:00:14.490895
20 | 360 9 6 0.0018 0 0:00:14.429836
21 | 380 6 4 1.2 0 0:00:14.490548
22 | 400 9 9 0.729 1 0:00:14.469219
23 | 420 4 4 0.501 1 0:00:14.469759
24 | 440 1 1 0.44 1 0:00:14.501306
25 | 460 5 5 0.532 1 0:00:14.481773
26 | 480 8 8 0.701 1 0:00:14.485138
27 | 500 4 -1 0.0 0 0:00:14.445102
28 | 520 5 5 0.596 1 0:00:14.509751
29 | 540 1 8 1.39 0 0:00:14.504498
30 | 560 0 0 0.38 1 0:00:14.485221
31 | 580 4 4 1.91 1 0:00:14.442113
32 | 600 8 8 1.39 1 0:00:14.469329
33 | 620 7 4 0.556 0 0:00:14.470549
34 | 640 5 3 1.91 0 0:00:14.468914
35 | 660 9 9 1.19 1 0:00:14.499982
36 | 680 9 1 0.0681 0 0:00:14.482205
37 | 700 7 -1 0.0 0 0:00:14.514781
38 | 720 4 7 0.0293 0 0:00:14.452885
39 | 740 2 -1 0.0 0 0:00:14.456233
40 | 760 3 -1 0.0 0 0:00:14.514683
41 | 780 9 -1 0.0 0 0:00:14.449405
42 | 800 7 7 1.91 1 0:00:14.505088
43 | 820 6 -1 0.0 0 0:00:14.494100
44 | 840 7 7 1.91 1 0:00:14.526783
45 | 860 4 4 1.91 1 0:00:14.434383
46 | 880 8 8 1.91 1 0:00:14.472785
47 | 900 2 4 0.307 0 0:00:14.430315
48 | 920 6 2 0.518 0 0:00:14.539771
49 | 940 9 9 1.1 1 0:00:14.464529
50 | 960 9 9 0.799 1 0:00:14.456657
51 | 980 2 -1 0.0 0 0:00:14.484656
52 | 1000 5 5 1.87 1 0:00:14.493587
53 | 1020 1 8 0.806 0 0:00:14.385534
54 | 1040 7 5 0.438 0 0:00:14.523761
55 | 1060 9 -1 0.0 0 0:00:14.499828
56 | 1080 6 6 1.67 1 0:00:14.479753
57 | 1100 7 8 1.3 0 0:00:14.456918
58 | 1120 5 2 0.924 0 0:00:14.480105
59 | 1140 6 4 0.877 0 0:00:14.462265
60 | 1160 8 8 0.174 1 0:00:14.456139
61 | 1180 3 3 0.448 1 0:00:14.387904
62 | 1200 8 0 0.284 0 0:00:14.468246
63 | 1220 9 6 0.282 0 0:00:14.451054
64 | 1240 2 6 1.43 0 0:00:14.450309
65 | 1260 1 9 0.899 0 0:00:14.517980
66 | 1280 3 7 0.362 0 0:00:14.458498
67 | 1300 4 3 0.248 0 0:00:14.405770
68 | 1320 7 2 1.22 0 0:00:14.483056
69 | 1340 1 1 1.16 1 0:00:14.443302
70 | 1360 7 7 0.226 1 0:00:14.495132
71 | 1380 4 5 0.174 0 0:00:14.432075
72 | 1400 5 2 0.497 0 0:00:14.466615
73 | 1420 4 4 0.413 1 0:00:14.452952
74 | 1440 0 0 0.249 1 0:00:14.467673
75 | 1460 6 4 0.117 0 0:00:14.476620
76 | 1480 1 5 0.277 0 0:00:14.441596
77 | 1500 1 1 0.82 1 0:00:14.505503
78 | 1520 7 4 0.0535 0 0:00:14.376559
79 | 1540 8 8 1.64 1 0:00:14.537273
80 | 1560 7 7 1.91 1 0:00:14.505560
81 | 1580 6 1 0.113 0 0:00:14.487358
82 | 1600 8 -1 0.0 0 0:00:14.492057
83 | 1620 5 8 0.315 0 0:00:14.464475
84 | 1640 7 3 0.28 0 0:00:14.491488
85 | 1660 3 3 0.35 1 0:00:14.431566
86 | 1680 6 6 1.91 1 0:00:14.479111
87 | 1700 5 2 1.87 0 0:00:14.417356
88 | 1720 4 2 0.821 0 0:00:14.442744
89 | 1740 1 1 0.189 1 0:00:14.430615
90 | 1760 0 8 0.709 0 0:00:14.403716
91 | 1780 7 -1 0.0 0 0:00:14.389048
92 | 1800 4 8 1.62 0 0:00:14.392776
93 | 1820 8 8 1.48 1 0:00:14.406297
94 | 1840 8 4 1.1 0 0:00:14.410980
95 | 1860 8 8 1.91 1 0:00:14.386441
96 | 1880 1 8 0.849 0 0:00:14.381203
97 | 1900 8 8 1.91 1 0:00:14.365559
98 | 1920 2 2 1.84 1 0:00:14.292651
99 | 1940 5 3 1.16 0 0:00:14.433078
100 | 1960 2 5 1.28 0 0:00:14.404337
101 | 1980 9 9 0.452 1 0:00:14.329074
102 | 2000 1 6 0.448 0 0:00:14.355776
103 | 2020 9 9 1.54 1 0:00:14.358637
104 | 2040 0 8 0.674 0 0:00:14.362852
105 | 2060 5 5 0.292 1 0:00:14.347701
106 | 2080 1 -1 0.0 0 0:00:14.358643
107 | 2100 2 4 1.21 0 0:00:14.328557
108 | 2120 4 8 1.84 0 0:00:14.351626
109 | 2140 9 9 0.6 1 0:00:14.322432
110 | 2160 0 8 1.77 0 0:00:14.340261
111 | 2180 4 4 1.77 1 0:00:14.317538
112 | 2200 0 2 0.0457 0 0:00:14.316711
113 | 2220 1 1 0.649 1 0:00:14.373689
114 | 2240 9 9 0.811 1 0:00:14.342507
115 | 2260 4 4 1.11 1 0:00:14.340267
116 | 2280 4 8 1.37 0 0:00:14.336520
117 | 2300 3 -1 0.0 0 0:00:14.311955
118 | 2320 5 -1 0.0 0 0:00:14.324267
119 | 2340 7 7 1.67 1 0:00:14.365062
120 | 2360 7 7 0.403 1 0:00:14.331631
121 | 2380 8 8 1.29 1 0:00:14.321607
122 | 2400 0 0 1.26 1 0:00:14.322513
123 | 2420 8 8 0.0765 1 0:00:14.381662
124 | 2440 7 7 0.833 1 0:00:14.278740
125 | 2460 9 3 0.191 0 0:00:14.368503
126 | 2480 8 8 1.91 1 0:00:14.331156
127 | 2500 4 2 0.599 0 0:00:14.336555
128 | 2520 1 8 0.178 0 0:00:14.301184
129 | 2540 2 3 0.5 0 0:00:14.321578
130 | 2560 7 4 0.648 0 0:00:14.333872
131 | 2580 6 4 0.906 0 0:00:14.320160
132 | 2600 8 8 1.91 1 0:00:14.321662
133 | 2620 1 4 0.701 0 0:00:14.360818
134 | 2640 7 7 0.14 1 0:00:14.314903
135 | 2660 3 2 0.986 0 0:00:14.345839
136 | 2680 0 0 1.87 1 0:00:14.299886
137 | 2700 9 0 0.277 0 0:00:14.322745
138 | 2720 3 2 0.798 0 0:00:14.336164
139 | 2740 1 1 0.699 1 0:00:14.326773
140 | 2760 2 2 0.0434 1 0:00:14.276177
141 | 2780 7 7 1.47 1 0:00:14.379085
142 | 2800 4 4 0.576 1 0:00:14.301599
143 | 2820 6 6 0.701 1 0:00:14.323634
144 | 2840 3 3 1.91 1 0:00:14.282580
145 | 2860 5 6 0.533 0 0:00:14.347618
146 | 2880 4 4 1.91 1 0:00:14.312035
147 | 2900 3 2 0.156 0 0:00:14.324175
148 | 2920 2 2 0.411 1 0:00:14.322795
149 | 2940 3 0 1.46 0 0:00:14.330476
150 | 2960 9 9 0.232 1 0:00:14.315882
151 | 2980 8 8 1.35 1 0:00:14.342838
152 | 3000 5 5 1.91 1 0:00:14.315494
153 | 3020 1 8 0.92 0 0:00:14.314904
154 | 3040 7 7 1.91 1 0:00:14.352726
155 | 3060 7 7 1.52 1 0:00:14.321048
156 | 3080 1 4 0.0304 0 0:00:14.325140
157 | 3100 0 0 1.24 1 0:00:14.301325
158 | 3120 4 2 0.609 0 0:00:14.329913
159 | 3140 8 8 1.91 1 0:00:14.308403
160 | 3160 6 4 0.668 0 0:00:14.326945
161 | 3180 3 -1 0.0 0 0:00:14.317472
162 | 3200 5 7 1.18 0 0:00:14.326241
163 | 3220 3 3 0.702 1 0:00:14.328259
164 | 3240 4 -1 0.0 0 0:00:14.368996
165 | 3260 7 -1 0.0 0 0:00:14.326500
166 | 3280 3 0 0.254 0 0:00:14.310939
167 | 3300 4 4 1.49 1 0:00:14.349883
168 | 3320 2 7 0.254 0 0:00:14.333468
169 | 3340 6 4 0.208 0 0:00:14.339950
170 | 3360 4 4 0.384 1 0:00:14.316216
171 | 3380 8 8 1.91 1 0:00:14.329068
172 | 3400 6 4 1.87 0 0:00:14.332860
173 | 3420 5 4 0.347 0 0:00:14.340887
174 | 3440 2 3 0.45 0 0:00:14.356870
175 | 3460 1 1 0.545 1 0:00:14.332382
176 | 3480 2 2 1.91 1 0:00:14.341356
177 | 3500 1 1 0.689 1 0:00:14.350852
178 | 3520 1 1 1.91 1 0:00:14.368938
179 | 3540 6 6 0.538 1 0:00:14.325798
180 | 3560 1 -1 0.0 0 0:00:14.347779
181 | 3580 4 4 0.604 1 0:00:14.338479
182 | 3600 4 2 1.67 0 0:00:14.367562
183 | 3620 0 6 0.279 0 0:00:14.351847
184 | 3640 6 2 1.78 0 0:00:14.304122
185 | 3660 0 8 1.69 0 0:00:14.373169
186 | 3680 0 2 0.821 0 0:00:14.339939
187 | 3700 3 -1 0.0 0 0:00:14.349725
188 | 3720 2 2 1.91 1 0:00:14.330755
189 | 3740 0 0 0.748 1 0:00:14.348924
190 | 3760 7 7 1.91 1 0:00:14.318395
191 | 3780 4 2 1.5 0 0:00:14.333459
192 | 3800 9 -1 0.0 0 0:00:14.349926
193 | 3820 0 0 0.684 1 0:00:14.338120
194 | 3840 6 3 0.817 0 0:00:14.363207
195 | 3860 7 7 1.91 1 0:00:14.307930
196 | 3880 6 4 1.18 0 0:00:14.353190
197 | 3900 3 6 0.0366 0 0:00:14.302730
198 | 3920 7 8 0.961 0 0:00:14.356417
199 | 3940 6 4 0.969 0 0:00:14.319504
200 | 3960 2 2 0.722 1 0:00:14.362688
201 | 3980 9 7 1.21 0 0:00:14.360320
202 | 4000 8 0 0.428 0 0:00:14.336081
203 | 4020 8 8 1.15 1 0:00:14.336947
204 | 4040 0 0 1.75 1 0:00:14.310399
205 | 4060 6 3 0.741 0 0:00:14.320710
206 | 4080 1 8 1.81 0 0:00:14.333818
207 | 4100 7 7 1.91 1 0:00:14.349823
208 | 4120 4 0 0.0273 0 0:00:14.382308
209 | 4140 5 3 1.58 0 0:00:14.298037
210 | 4160 5 2 1.27 0 0:00:14.329030
211 | 4180 0 8 1.23 0 0:00:14.365413
212 | 4200 4 2 0.51 0 0:00:14.322507
213 | 4220 4 0 1.91 0 0:00:14.341161
214 | 4240 7 0 0.777 0 0:00:14.345427
215 | 4260 8 8 1.3 1 0:00:14.329375
216 | 4280 8 8 1.17 1 0:00:14.338165
217 | 4300 8 8 1.91 1 0:00:14.329269
218 | 4320 1 8 1.22 0 0:00:14.351560
219 | 4340 0 0 0.125 1 0:00:14.292989
220 | 4360 6 6 0.697 1 0:00:14.384689
221 | 4380 9 8 0.715 0 0:00:14.319510
222 | 4400 3 4 0.415 0 0:00:14.321643
223 | 4420 5 8 0.802 0 0:00:14.342366
224 | 4440 2 2 1.42 1 0:00:14.354761
225 | 4460 9 8 1.13 0 0:00:14.340989
226 | 4480 9 -1 0.0 0 0:00:14.326163
227 | 4500 3 6 0.332 0 0:00:14.321603
228 | 4520 3 6 1.16 0 0:00:14.337091
229 | 4540 9 9 1.54 1 0:00:14.326691
230 | 4560 1 -1 0.0 0 0:00:14.384115
231 | 4580 6 4 1.3 0 0:00:14.301269
232 | 4600 4 4 0.729 1 0:00:14.361005
233 | 4620 7 3 1.56 0 0:00:14.350336
234 | 4640 2 4 0.616 0 0:00:14.306101
235 | 4660 7 -1 0.0 0 0:00:14.353105
236 | 4680 9 1 0.71 0 0:00:14.322191
237 | 4700 6 4 0.81 0 0:00:14.386648
238 | 4720 8 0 0.419 0 0:00:14.333860
239 | 4740 5 3 0.379 0 0:00:14.361800
240 | 4760 3 3 0.00686 1 0:00:14.356988
241 | 4780 0 0 0.395 1 0:00:14.363134
242 | 4800 9 9 1.91 1 0:00:14.315060
243 | 4820 3 5 0.162 0 0:00:14.337430
244 | 4840 0 0 1.84 1 0:00:14.336288
245 | 4860 5 5 0.407 1 0:00:14.371648
246 | 4880 0 2 1.76 0 0:00:14.316225
247 | 4900 3 6 1.57 0 0:00:14.342275
248 | 4920 7 7 0.745 1 0:00:14.341017
249 | 4940 6 2 0.135 0 0:00:14.337153
250 | 4960 4 4 0.0881 1 0:00:14.338384
251 | 4980 1 9 0.184 0 0:00:14.318410
252 | 5000 7 7 1.84 1 0:00:14.383744
253 | 5020 8 8 1.59 1 0:00:14.338115
254 | 5040 3 4 0.0256 0 0:00:14.348839
255 | 5060 6 4 0.595 0 0:00:14.328453
256 | 5080 7 7 1.91 1 0:00:14.357483
257 | 5100 3 3 1.09 1 0:00:14.318982
258 | 5120 9 8 0.861 0 0:00:14.355429
259 | 5140 8 8 1.91 1 0:00:14.317059
260 | 5160 0 8 0.0582 0 0:00:14.346611
261 | 5180 9 4 0.613 0 0:00:14.328148
262 | 5200 3 3 1.91 1 0:00:14.334650
263 | 5220 0 0 1.01 1 0:00:14.337687
264 | 5240 1 2 0.296 0 0:00:14.352919
265 | 5260 0 0 1.91 1 0:00:14.321813
266 | 5280 0 8 1.87 0 0:00:14.334321
267 | 5300 9 9 1.07 1 0:00:14.369558
268 | 5320 7 7 0.974 1 0:00:14.362770
269 | 5340 3 4 0.599 0 0:00:14.341419
270 | 5360 9 9 1.04 1 0:00:14.339264
271 | 5380 1 9 0.928 0 0:00:14.331964
272 | 5400 9 8 0.495 0 0:00:14.324177
273 | 5420 6 6 1.59 1 0:00:14.327261
274 | 5440 9 9 1.44 1 0:00:14.355649
275 | 5460 0 7 0.598 0 0:00:14.368535
276 | 5480 1 8 0.666 0 0:00:14.335977
277 | 5500 8 8 1.46 1 0:00:14.335099
278 | 5520 4 3 0.469 0 0:00:14.362765
279 | 5540 5 2 0.484 0 0:00:14.320348
280 | 5560 3 6 0.197 0 0:00:14.324019
281 | 5580 5 6 1.28 0 0:00:14.369295
282 | 5600 6 0 0.278 0 0:00:14.300689
283 | 5620 3 2 0.875 0 0:00:14.361083
284 | 5640 2 3 0.101 0 0:00:14.344099
285 | 5660 9 8 0.924 0 0:00:14.362576
286 | 5680 2 8 0.551 0 0:00:14.323498
287 | 5700 3 -1 0.0 0 0:00:14.362411
288 | 5720 9 9 1.56 1 0:00:14.343616
289 | 5740 6 4 1.75 0 0:00:14.373607
290 | 5760 2 2 1.2 1 0:00:14.312141
291 | 5780 7 7 0.855 1 0:00:14.373281
292 | 5800 2 4 0.0133 0 0:00:14.327827
293 | 5820 5 -1 0.0 0 0:00:14.371184
294 | 5840 2 -1 0.0 0 0:00:14.343114
295 | 5860 0 8 0.668 0 0:00:14.339405
296 | 5880 0 0 0.728 1 0:00:14.347477
297 | 5900 4 4 1.03 1 0:00:14.363754
298 | 5920 8 8 1.66 1 0:00:14.339641
299 | 5940 7 4 1.75 0 0:00:14.334054
300 | 5960 2 4 0.857 0 0:00:14.357543
301 | 5980 9 9 0.914 1 0:00:14.333889
302 | 6000 8 8 1.87 1 0:00:14.356345
303 | 6020 6 6 1.61 1 0:00:14.327467
304 | 6040 2 7 1.31 0 0:00:14.354960
305 | 6060 3 4 1.1 0 0:00:14.345023
306 | 6080 1 -1 0.0 0 0:00:14.262996
307 | 6100 1 8 0.64 0 0:00:14.397099
308 | 6120 5 5 1.12 1 0:00:14.358169
309 | 6140 9 9 0.819 1 0:00:14.319568
310 | 6160 2 5 1.35 0 0:00:14.346238
311 | 6180 0 -1 0.0 0 0:00:14.335939
312 | 6200 3 3 0.598 1 0:00:14.356852
313 | 6220 2 -1 0.0 0 0:00:14.314770
314 | 6240 2 4 0.402 0 0:00:14.368407
315 | 6260 2 4 0.398 0 0:00:14.316968
316 | 6280 8 4 0.127 0 0:00:14.379302
317 | 6300 1 4 0.272 0 0:00:14.336131
318 | 6320 0 0 1.91 1 0:00:14.365493
319 | 6340 3 3 1.31 1 0:00:14.328599
320 | 6360 2 2 1.48 1 0:00:14.334644
321 | 6380 3 6 0.616 0 0:00:14.351504
322 | 6400 0 8 0.176 0 0:00:14.331140
323 | 6420 7 7 0.062 1 0:00:14.382440
324 | 6440 3 8 0.689 0 0:00:14.325770
325 | 6460 3 4 0.0612 0 0:00:14.333105
326 | 6480 0 0 0.646 1 0:00:14.355450
327 | 6500 7 8 0.0157 0 0:00:14.338653
328 | 6520 6 4 1.02 0 0:00:14.324024
329 | 6540 8 8 1.91 1 0:00:14.344356
330 | 6560 6 -1 0.0 0 0:00:14.348085
331 | 6580 1 2 0.23 0 0:00:14.362120
332 | 6600 7 7 1.91 1 0:00:14.341597
333 | 6620 7 7 1.29 1 0:00:14.330906
334 | 6640 5 3 0.0195 0 0:00:14.333315
335 | 6660 0 0 0.483 1 0:00:14.349654
336 | 6680 3 3 0.127 1 0:00:14.345540
337 | 6700 6 -1 0.0 0 0:00:14.326991
338 | 6720 2 2 1.45 1 0:00:14.320856
339 | 6740 2 4 1.07 0 0:00:14.320175
340 | 6760 5 3 0.102 0 0:00:14.373966
341 | 6780 7 7 1.77 1 0:00:14.330199
342 | 6800 6 4 0.302 0 0:00:14.319868
343 | 6820 1 1 1.33 1 0:00:14.345896
344 | 6840 9 9 1.17 1 0:00:14.355399
345 | 6860 0 0 0.186 1 0:00:14.329840
346 | 6880 2 -1 0.0 0 0:00:14.349433
347 | 6900 3 4 0.59 0 0:00:14.356242
348 | 6920 5 2 1.74 0 0:00:14.330287
349 | 6940 1 1 1.87 1 0:00:14.334181
350 | 6960 9 8 1.83 0 0:00:14.340834
351 | 6980 0 0 0.00823 1 0:00:14.300156
352 | 7000 2 8 0.736 0 0:00:14.351959
353 | 7020 7 7 1.13 1 0:00:14.353994
354 | 7040 7 8 1.91 0 0:00:14.363911
355 | 7060 8 8 1.05 1 0:00:14.333409
356 | 7080 5 4 0.832 0 0:00:14.334554
357 | 7100 9 8 0.26 0 0:00:14.336618
358 | 7120 7 7 0.101 1 0:00:14.341166
359 | 7140 4 4 1.68 1 0:00:14.362059
360 | 7160 5 4 1.15 0 0:00:14.344642
361 | 7180 6 6 1.84 1 0:00:14.328773
362 | 7200 4 2 0.524 0 0:00:14.325600
363 | 7220 9 8 0.351 0 0:00:14.343401
364 | 7240 0 8 1.91 0 0:00:14.312107
365 | 7260 1 8 0.462 0 0:00:14.358999
366 | 7280 1 6 0.399 0 0:00:14.305262
367 | 7300 3 4 0.56 0 0:00:14.331922
368 | 7320 8 8 0.999 1 0:00:14.351168
369 | 7340 2 4 0.931 0 0:00:14.327765
370 | 7360 2 4 1.83 0 0:00:14.334838
371 | 7380 7 7 1.46 1 0:00:14.339803
372 | 7400 3 3 1.29 1 0:00:14.346749
373 | 7420 3 0 0.0586 0 0:00:14.298488
374 | 7440 3 2 0.0534 0 0:00:14.372183
375 | 7460 5 5 1.91 1 0:00:14.344006
376 | 7480 1 8 0.563 0 0:00:14.355524
377 | 7500 6 6 1.91 1 0:00:14.338224
378 | 7520 4 4 1.56 1 0:00:14.324864
379 | 7540 5 5 1.91 1 0:00:14.354052
380 | 7560 0 -1 0.0 0 0:00:14.355483
381 | 7580 8 8 1.91 1 0:00:14.349985
382 | 7600 8 8 0.738 1 0:00:14.339299
383 | 7620 5 3 0.603 0 0:00:14.370959
384 | 7640 9 -1 0.0 0 0:00:14.277472
385 | 7660 7 4 1.37 0 0:00:14.379815
386 | 7680 3 3 1.38 1 0:00:14.355321
387 | 7700 6 6 0.809 1 0:00:14.360164
388 | 7720 5 -1 0.0 0 0:00:14.303431
389 | 7740 4 2 0.386 0 0:00:14.359398
390 | 7760 2 -1 0.0 0 0:00:14.368498
391 | 7780 3 8 0.344 0 0:00:14.322312
392 | 7800 0 0 0.665 1 0:00:14.327515
393 | 7820 5 8 0.881 0 0:00:14.336834
394 | 7840 1 1 1.87 1 0:00:14.372243
395 | 7860 8 4 0.346 0 0:00:14.353720
396 | 7880 0 0 1.09 1 0:00:14.330323
397 | 7900 0 8 0.477 0 0:00:14.353978
398 | 7920 9 8 0.818 0 0:00:14.339945
399 | 7940 2 -1 0.0 0 0:00:14.332996
400 | 7960 0 0 0.242 1 0:00:14.328747
401 | 7980 3 4 0.69 0 0:00:14.346883
402 | 8000 9 9 0.693 1 0:00:14.336344
403 | 8020 6 3 0.436 0 0:00:14.297441
404 | 8040 2 2 0.419 1 0:00:14.342679
405 | 8060 6 -1 0.0 0 0:00:14.333190
406 | 8080 4 4 0.144 1 0:00:14.317881
407 | 8100 6 -1 0.0 0 0:00:14.312716
408 | 8120 8 8 0.715 1 0:00:14.359991
409 | 8140 5 4 1.84 0 0:00:14.328249
410 | 8160 6 6 1.52 1 0:00:14.355975
411 | 8180 4 4 0.372 1 0:00:14.319064
412 | 8200 3 -1 0.0 0 0:00:14.334832
413 | 8220 6 -1 0.0 0 0:00:14.314748
414 | 8240 7 7 1.5 1 0:00:14.361883
415 | 8260 1 1 0.17 1 0:00:14.346026
416 | 8280 4 -1 0.0 0 0:00:14.347736
417 | 8300 5 2 1.27 0 0:00:14.307997
418 | 8320 7 4 1.56 0 0:00:14.341300
419 | 8340 5 3 0.63 0 0:00:14.349405
420 | 8360 7 4 0.154 0 0:00:14.319071
421 | 8380 9 8 0.409 0 0:00:14.337004
422 | 8400 0 8 0.0526 0 0:00:14.339651
423 | 8420 3 4 0.233 0 0:00:14.340102
424 | 8440 5 3 0.0606 0 0:00:14.347626
425 | 8460 8 8 1.91 1 0:00:14.288629
426 | 8480 2 4 0.0412 0 0:00:14.316075
427 | 8500 4 4 0.225 1 0:00:14.347934
428 | 8520 8 4 0.776 0 0:00:14.337139
429 | 8540 4 2 0.325 0 0:00:14.312202
430 | 8560 7 7 0.194 1 0:00:14.322416
431 | 8580 3 2 0.037 0 0:00:14.377486
432 | 8600 3 4 0.159 0 0:00:14.339806
433 | 8620 3 4 0.606 0 0:00:14.314949
434 | 8640 1 1 1.38 1 0:00:14.371932
435 | 8660 7 7 0.665 1 0:00:14.338834
436 | 8680 3 4 0.0978 0 0:00:14.363549
437 | 8700 3 3 1.71 1 0:00:14.339277
438 | 8720 2 0 1.3 0 0:00:14.349474
439 | 8740 2 4 1.11 0 0:00:14.319043
440 | 8760 8 8 1.14 1 0:00:14.391077
441 | 8780 4 4 0.987 1 0:00:14.283131
442 | 8800 0 0 1.91 1 0:00:14.352883
443 | 8820 2 2 0.78 1 0:00:14.347805
444 | 8840 2 -1 0.0 0 0:00:14.362284
445 | 8860 2 2 0.552 1 0:00:14.330706
446 | 8880 1 1 0.415 1 0:00:14.352176
447 | 8900 2 4 1.5 0 0:00:14.350649
448 | 8920 6 6 1.13 1 0:00:14.360373
449 | 8940 6 6 0.115 1 0:00:14.365397
450 | 8960 0 0 1.91 1 0:00:14.340553
451 | 8980 9 9 0.951 1 0:00:14.368227
452 | 9000 8 8 1.91 1 0:00:14.250843
453 | 9020 7 -1 0.0 0 0:00:14.430918
454 | 9040 2 5 0.825 0 0:00:14.361758
455 | 9060 9 5 0.728 0 0:00:14.319256
456 | 9080 3 3 0.504 1 0:00:14.357976
457 | 9100 9 6 0.453 0 0:00:14.367532
458 | 9120 3 -1 0.0 0 0:00:14.344448
459 | 9140 7 4 0.00065 0 0:00:14.291768
460 | 9160 5 5 0.545 1 0:00:14.383770
461 | 9180 5 5 1.91 1 0:00:14.345354
462 | 9200 8 8 1.91 1 0:00:14.334551
463 | 9220 8 8 1.91 1 0:00:14.365463
464 | 9240 8 8 1.77 1 0:00:14.340728
465 | 9260 5 5 1.2 1 0:00:14.337726
466 | 9280 9 8 0.272 0 0:00:14.350229
467 | 9300 5 2 0.451 0 0:00:14.318248
468 | 9320 9 8 0.16 0 0:00:14.353187
469 | 9340 1 1 1.91 1 0:00:14.365501
470 | 9360 5 -1 0.0 0 0:00:14.350924
471 | 9380 5 3 1.36 0 0:00:14.393360
472 | 9400 6 4 0.158 0 0:00:14.331827
473 | 9420 5 6 0.0157 0 0:00:14.326165
474 | 9440 8 8 1.91 1 0:00:14.375585
475 | 9460 3 6 0.267 0 0:00:14.351903
476 | 9480 2 4 1.39 0 0:00:14.323893
477 | 9500 9 -1 0.0 0 0:00:14.348245
478 | 9520 1 1 1.3 1 0:00:14.373777
479 | 9540 5 -1 0.0 0 0:00:14.321848
480 | 9560 9 7 0.843 0 0:00:14.350507
481 | 9580 1 4 0.841 0 0:00:14.364640
482 | 9600 8 8 1.91 1 0:00:14.331119
483 | 9620 4 4 0.28 1 0:00:14.336773
484 | 9640 5 6 0.0726 0 0:00:14.346887
485 | 9660 6 6 1.91 1 0:00:14.355296
486 | 9680 8 8 1.91 1 0:00:14.372817
487 | 9700 0 0 0.877 1 0:00:14.320988
488 | 9720 8 8 0.0438 1 0:00:14.341066
489 | 9740 3 4 0.859 0 0:00:14.335354
490 | 9760 9 7 0.249 0 0:00:14.370011
491 | 9780 4 4 0.536 1 0:00:14.319711
492 | 9800 1 1 1.91 1 0:00:14.356350
493 | 9820 0 8 0.899 0 0:00:14.352618
494 | 9840 4 5 0.548 0 0:00:14.345452
495 | 9860 0 8 0.865 0 0:00:14.327323
496 | 9880 7 4 0.644 0 0:00:14.355347
497 | 9900 8 8 0.8 1 0:00:14.371057
498 | 9920 6 2 0.424 0 0:00:14.328738
499 | 9940 4 -1 0.0 0 0:00:14.363681
500 | 9960 2 0 1.91 0 0:00:14.337633
501 | 9980 0 0 1.87 1 0:00:14.363302
502 |
--------------------------------------------------------------------------------
/data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00:
--------------------------------------------------------------------------------
1 | idx label predict radius correct time
2 | 0 3 -1 0.0 0 15.3
3 | 20 7 -1 0.0 0 15.4
4 | 40 4 -1 0.0 0 15.5
5 | 60 7 7 0.689 1 15.6
6 | 80 8 8 1.16 1 15.7
7 | 100 4 -1 0.0 0 15.8
8 | 120 8 8 0.938 1 15.8
9 | 140 6 -1 0.0 0 15.9
10 | 160 2 0 0.105 0 15.9
11 | 180 0 0 0.777 1 16.1
12 | 200 5 3 0.257 0 16.1
13 | 220 7 7 1.79 1 16.2
14 | 240 1 1 2.1 1 16.2
15 | 260 8 -1 0.0 0 16.2
16 | 280 9 9 2.46 1 16.2
17 | 300 6 6 0.181 1 16.1
18 | 320 3 -1 0.0 0 16.1
19 | 340 2 4 0.117 0 16.1
20 | 360 9 -1 0.0 0 16.2
21 | 380 6 6 0.0784 1 16.2
22 | 400 9 9 1.0 1 16.2
23 | 420 4 4 0.391 1 16.2
24 | 440 1 1 0.928 1 16.2
25 | 460 5 5 0.721 1 16.2
26 | 480 8 9 0.509 0 16.2
27 | 500 4 6 0.258 0 16.2
28 | 520 5 5 0.206 1 16.2
29 | 540 1 -1 0.0 0 16.2
30 | 560 0 0 1.47 1 16.2
31 | 580 4 4 0.458 1 16.2
32 | 600 8 8 0.475 1 16.2
33 | 620 7 -1 0.0 0 16.2
34 | 640 5 3 1.92 0 16.2
35 | 660 9 9 2.37 1 16.2
36 | 680 9 -1 0.0 0 16.2
37 | 700 7 7 0.326 1 16.2
38 | 720 4 9 0.165 0 16.2
39 | 740 2 5 0.222 0 16.2
40 | 760 3 9 0.0399 0 16.2
41 | 780 9 -1 0.0 0 16.2
42 | 800 7 7 2.0 1 16.2
43 | 820 6 -1 0.0 0 16.2
44 | 840 7 7 2.9 1 16.2
45 | 860 4 4 0.978 1 16.2
46 | 880 8 8 2.77 1 16.2
47 | 900 2 -1 0.0 0 16.2
48 | 920 6 6 0.125 1 16.2
49 | 940 9 9 1.07 1 16.2
50 | 960 9 9 1.86 1 16.2
51 | 980 2 6 0.842 0 16.2
52 | 1000 5 5 2.25 1 16.2
53 | 1020 1 1 0.734 1 16.2
54 | 1040 7 5 1.05 0 16.2
55 | 1060 9 9 1.44 1 16.2
56 | 1080 6 6 1.1 1 16.2
57 | 1100 7 8 0.482 0 16.2
58 | 1120 5 -1 0.0 0 16.2
59 | 1140 6 6 0.113 1 16.2
60 | 1160 8 8 0.106 1 16.2
61 | 1180 3 3 0.671 1 16.2
62 | 1200 8 8 0.376 1 16.2
63 | 1220 9 -1 0.0 0 16.2
64 | 1240 2 6 0.897 0 16.2
65 | 1260 1 9 0.994 0 16.2
66 | 1280 3 7 0.515 0 16.2
67 | 1300 4 3 0.608 0 16.2
68 | 1320 7 2 0.0536 0 16.2
69 | 1340 1 1 1.05 1 16.2
70 | 1360 7 7 0.912 1 16.2
71 | 1380 4 5 0.868 0 16.2
72 | 1400 5 -1 0.0 0 16.2
73 | 1420 4 7 0.00959 0 16.2
74 | 1440 0 0 0.0527 1 16.2
75 | 1460 6 6 0.53 1 16.2
76 | 1480 1 -1 0.0 0 16.2
77 | 1500 1 1 0.34 1 16.2
78 | 1520 7 7 0.618 1 16.2
79 | 1540 8 8 0.829 1 16.2
80 | 1560 7 7 3.05 1 16.2
81 | 1580 6 1 0.244 0 16.2
82 | 1600 8 -1 0.0 0 16.2
83 | 1620 5 -1 0.0 0 16.2
84 | 1640 7 5 0.218 0 16.2
85 | 1660 3 -1 0.0 0 16.2
86 | 1680 6 6 2.04 1 16.2
87 | 1700 5 2 0.454 0 16.2
88 | 1720 4 -1 0.0 0 16.2
89 | 1740 1 -1 0.0 0 16.2
90 | 1760 0 -1 0.0 0 16.2
91 | 1780 7 -1 0.0 0 16.2
92 | 1800 4 8 0.362 0 16.2
93 | 1820 8 8 0.29 1 16.2
94 | 1840 8 -1 0.0 0 16.2
95 | 1860 8 8 1.71 1 16.2
96 | 1880 1 1 0.685 1 16.2
97 | 1900 8 8 2.34 1 16.2
98 | 1920 2 2 0.53 1 16.2
99 | 1940 5 3 1.21 0 16.2
100 | 1960 2 5 1.34 0 16.2
101 | 1980 9 9 1.45 1 16.2
102 | 2000 1 -1 0.0 0 16.2
103 | 2020 9 9 2.07 1 16.2
104 | 2040 0 8 0.536 0 16.2
105 | 2060 5 5 0.733 1 16.2
106 | 2080 1 1 0.0422 1 16.2
107 | 2100 2 4 0.462 0 16.2
108 | 2120 4 8 0.99 0 16.2
109 | 2140 9 9 1.23 1 16.2
110 | 2160 0 8 1.15 0 16.2
111 | 2180 4 4 0.581 1 16.2
112 | 2200 0 0 0.504 1 16.2
113 | 2220 1 1 1.32 1 16.2
114 | 2240 9 9 1.13 1 16.2
115 | 2260 4 -1 0.0 0 16.2
116 | 2280 4 8 0.0794 0 16.2
117 | 2300 3 -1 0.0 0 16.2
118 | 2320 5 -1 0.0 0 16.2
119 | 2340 7 7 1.32 1 16.2
120 | 2360 7 7 0.423 1 16.2
121 | 2380 8 8 0.817 1 16.2
122 | 2400 0 0 1.23 1 16.2
123 | 2420 8 -1 0.0 0 16.2
124 | 2440 7 7 0.705 1 16.2
125 | 2460 9 9 0.233 1 16.2
126 | 2480 8 8 1.21 1 16.2
127 | 2500 4 -1 0.0 0 16.2
128 | 2520 1 1 0.171 1 16.2
129 | 2540 2 3 0.718 0 16.2
130 | 2560 7 7 0.106 1 16.2
131 | 2580 6 -1 0.0 0 16.2
132 | 2600 8 8 2.39 1 16.2
133 | 2620 1 1 0.0818 1 16.2
134 | 2640 7 7 0.914 1 16.2
135 | 2660 3 -1 0.0 0 16.2
136 | 2680 0 0 2.5 1 16.2
137 | 2700 9 -1 0.0 0 16.2
138 | 2720 3 -1 0.0 0 16.2
139 | 2740 1 1 0.741 1 16.2
140 | 2760 2 2 0.484 1 16.2
141 | 2780 7 7 1.76 1 16.2
142 | 2800 4 6 0.207 0 16.2
143 | 2820 6 6 1.9 1 16.2
144 | 2840 3 3 2.8 1 16.2
145 | 2860 5 -1 0.0 0 16.2
146 | 2880 4 4 2.16 1 16.2
147 | 2900 3 2 0.00313 0 16.2
148 | 2920 2 -1 0.0 0 16.2
149 | 2940 3 0 1.32 0 16.2
150 | 2960 9 9 0.811 1 16.2
151 | 2980 8 -1 0.0 0 16.2
152 | 3000 5 5 2.45 1 16.2
153 | 3020 1 1 0.913 1 16.2
154 | 3040 7 7 3.22 1 16.2
155 | 3060 7 7 1.47 1 16.2
156 | 3080 1 -1 0.0 0 16.2
157 | 3100 0 0 1.92 1 16.2
158 | 3120 4 6 0.659 0 16.2
159 | 3140 8 8 0.955 1 16.2
160 | 3160 6 4 0.319 0 16.2
161 | 3180 3 -1 0.0 0 16.2
162 | 3200 5 7 0.806 0 16.2
163 | 3220 3 -1 0.0 0 16.2
164 | 3240 4 -1 0.0 0 16.2
165 | 3260 7 -1 0.0 0 16.2
166 | 3280 3 5 0.347 0 16.2
167 | 3300 4 -1 0.0 0 16.2
168 | 3320 2 7 1.15 0 16.2
169 | 3340 6 6 0.255 1 16.2
170 | 3360 4 6 0.245 0 16.2
171 | 3380 8 8 0.329 1 16.2
172 | 3400 6 4 1.24 0 16.2
173 | 3420 5 -1 0.0 0 16.2
174 | 3440 2 -1 0.0 0 16.2
175 | 3460 1 1 1.75 1 16.2
176 | 3480 2 2 2.88 1 16.2
177 | 3500 1 1 1.14 1 16.2
178 | 3520 1 1 2.06 1 16.2
179 | 3540 6 6 0.575 1 16.2
180 | 3560 1 9 0.376 0 16.2
181 | 3580 4 -1 0.0 0 16.2
182 | 3600 4 2 0.824 0 16.2
183 | 3620 0 -1 0.0 0 16.2
184 | 3640 6 2 0.987 0 16.2
185 | 3660 0 8 0.101 0 16.2
186 | 3680 0 -1 0.0 0 16.2
187 | 3700 3 -1 0.0 0 16.2
188 | 3720 2 2 1.46 1 16.2
189 | 3740 0 0 1.27 1 16.2
190 | 3760 7 7 3.04 1 16.2
191 | 3780 4 -1 0.0 0 16.2
192 | 3800 9 9 0.55 1 16.2
193 | 3820 0 0 0.5 1 16.2
194 | 3840 6 3 0.31 0 16.2
195 | 3860 7 7 2.25 1 16.2
196 | 3880 6 4 0.19 0 16.2
197 | 3900 3 -1 0.0 0 16.2
198 | 3920 7 -1 0.0 0 16.2
199 | 3940 6 -1 0.0 0 16.2
200 | 3960 2 2 0.648 1 16.2
201 | 3980 9 7 0.666 0 16.2
202 | 4000 8 -1 0.0 0 16.2
203 | 4020 8 -1 0.0 0 16.2
204 | 4040 0 0 1.64 1 16.2
205 | 4060 6 3 0.65 0 16.2
206 | 4080 1 -1 0.0 0 16.2
207 | 4100 7 7 3.04 1 16.2
208 | 4120 4 0 1.17 0 16.2
209 | 4140 5 3 1.3 0 16.2
210 | 4160 5 2 0.0858 0 16.2
211 | 4180 0 0 0.406 1 16.2
212 | 4200 4 -1 0.0 0 16.2
213 | 4220 4 0 2.33 0 16.2
214 | 4240 7 0 0.775 0 16.2
215 | 4260 8 8 0.589 1 16.2
216 | 4280 8 8 0.0455 1 16.2
217 | 4300 8 8 1.29 1 16.2
218 | 4320 1 8 0.888 0 16.2
219 | 4340 0 0 0.723 1 16.2
220 | 4360 6 -1 0.0 0 16.2
221 | 4380 9 9 0.703 1 16.2
222 | 4400 3 -1 0.0 0 16.2
223 | 4420 5 -1 0.0 0 16.2
224 | 4440 2 2 0.21 1 16.2
225 | 4460 9 -1 0.0 0 16.2
226 | 4480 9 9 0.35 1 16.2
227 | 4500 3 3 0.129 1 16.2
228 | 4520 3 6 0.487 0 16.2
229 | 4540 9 9 1.98 1 16.2
230 | 4560 1 1 1.05 1 16.2
231 | 4580 6 -1 0.0 0 16.2
232 | 4600 4 7 0.597 0 16.2
233 | 4620 7 3 0.944 0 16.2
234 | 4640 2 -1 0.0 0 16.2
235 | 4660 7 -1 0.0 0 16.2
236 | 4680 9 1 0.68 0 16.2
237 | 4700 6 -1 0.0 0 16.2
238 | 4720 8 -1 0.0 0 16.2
239 | 4740 5 3 0.253 0 16.2
240 | 4760 3 3 0.146 1 16.2
241 | 4780 0 0 1.03 1 16.2
242 | 4800 9 9 2.91 1 16.2
243 | 4820 3 -1 0.0 0 16.2
244 | 4840 0 0 2.04 1 16.2
245 | 4860 5 5 0.274 1 16.2
246 | 4880 0 2 1.08 0 16.2
247 | 4900 3 6 0.748 0 16.2
248 | 4920 7 7 1.18 1 16.2
249 | 4940 6 6 0.00465 1 16.2
250 | 4960 4 -1 0.0 0 16.2
251 | 4980 1 9 0.828 0 16.2
252 | 5000 7 7 2.4 1 16.2
253 | 5020 8 8 0.421 1 16.2
254 | 5040 3 3 0.1 1 16.2
255 | 5060 6 -1 0.0 0 16.2
256 | 5080 7 7 2.45 1 16.2
257 | 5100 3 3 1.0 1 16.2
258 | 5120 9 9 0.812 1 16.2
259 | 5140 8 8 1.22 1 16.2
260 | 5160 0 0 0.357 1 16.2
261 | 5180 9 -1 0.0 0 16.2
262 | 5200 3 3 2.66 1 16.2
263 | 5220 0 0 1.31 1 16.2
264 | 5240 1 -1 0.0 0 16.2
265 | 5260 0 0 2.81 1 16.2
266 | 5280 0 8 0.979 0 16.2
267 | 5300 9 9 0.786 1 16.2
268 | 5320 7 7 0.801 1 16.2
269 | 5340 3 4 0.81 0 16.2
270 | 5360 9 9 1.49 1 16.2
271 | 5380 1 9 0.993 0 16.2
272 | 5400 9 9 0.265 1 16.2
273 | 5420 6 6 0.781 1 16.2
274 | 5440 9 9 2.32 1 16.2
275 | 5460 0 7 0.44 0 16.2
276 | 5480 1 1 0.451 1 16.2
277 | 5500 8 -1 0.0 0 16.2
278 | 5520 4 -1 0.0 0 16.2
279 | 5540 5 -1 0.0 0 16.2
280 | 5560 3 -1 0.0 0 16.2
281 | 5580 5 6 0.491 0 16.2
282 | 5600 6 -1 0.0 0 16.2
283 | 5620 3 2 0.545 0 16.2
284 | 5640 2 -1 0.0 0 16.2
285 | 5660 9 9 0.166 1 16.2
286 | 5680 2 -1 0.0 0 16.2
287 | 5700 3 -1 0.0 0 16.2
288 | 5720 9 9 1.97 1 16.2
289 | 5740 6 4 0.244 0 16.2
290 | 5760 2 2 0.143 1 16.2
291 | 5780 7 7 1.56 1 16.2
292 | 5800 2 -1 0.0 0 16.2
293 | 5820 5 -1 0.0 0 16.2
294 | 5840 2 -1 0.0 0 16.2
295 | 5860 0 -1 0.0 0 16.2
296 | 5880 0 0 0.259 1 16.2
297 | 5900 4 -1 0.0 0 16.2
298 | 5920 8 8 2.45 1 16.2
299 | 5940 7 4 0.529 0 16.2
300 | 5960 2 -1 0.0 0 16.2
301 | 5980 9 9 1.72 1 16.2
302 | 6000 8 -1 0.0 0 16.2
303 | 6020 6 6 0.335 1 16.2
304 | 6040 2 7 1.06 0 16.2
305 | 6060 3 -1 0.0 0 16.2
306 | 6080 1 -1 0.0 0 16.2
307 | 6100 1 1 0.83 1 16.2
308 | 6120 5 5 1.05 1 16.2
309 | 6140 9 9 1.53 1 16.2
310 | 6160 2 5 1.05 0 16.2
311 | 6180 0 -1 0.0 0 16.2
312 | 6200 3 -1 0.0 0 16.2
313 | 6220 2 6 0.215 0 16.2
314 | 6240 2 -1 0.0 0 16.2
315 | 6260 2 -1 0.0 0 16.2
316 | 6280 8 -1 0.0 0 16.2
317 | 6300 1 -1 0.0 0 16.2
318 | 6320 0 0 2.29 1 16.2
319 | 6340 3 3 0.665 1 16.2
320 | 6360 2 2 0.166 1 16.2
321 | 6380 3 -1 0.0 0 16.2
322 | 6400 0 1 0.423 0 16.2
323 | 6420 7 7 0.522 1 16.2
324 | 6440 3 -1 0.0 0 16.2
325 | 6460 3 -1 0.0 0 16.2
326 | 6480 0 0 0.485 1 16.2
327 | 6500 7 -1 0.0 0 16.2
328 | 6520 6 4 0.309 0 16.2
329 | 6540 8 8 1.16 1 16.2
330 | 6560 6 6 0.6 1 16.2
331 | 6580 1 -1 0.0 0 16.2
332 | 6600 7 7 3.56 1 16.2
333 | 6620 7 7 1.42 1 16.2
334 | 6640 5 3 0.413 0 16.2
335 | 6660 0 0 0.983 1 16.2
336 | 6680 3 5 0.988 0 16.2
337 | 6700 6 6 0.164 1 16.2
338 | 6720 2 2 0.528 1 16.2
339 | 6740 2 -1 0.0 0 16.2
340 | 6760 5 5 0.0267 1 16.2
341 | 6780 7 7 2.17 1 16.2
342 | 6800 6 6 0.053 1 16.2
343 | 6820 1 1 2.02 1 16.2
344 | 6840 9 9 0.907 1 16.2
345 | 6860 0 0 0.985 1 16.2
346 | 6880 2 7 0.0815 0 16.2
347 | 6900 3 -1 0.0 0 16.2
348 | 6920 5 2 0.978 0 16.2
349 | 6940 1 1 2.84 1 16.2
350 | 6960 9 8 1.29 0 16.2
351 | 6980 0 9 0.00741 0 16.2
352 | 7000 2 -1 0.0 0 16.2
353 | 7020 7 7 0.983 1 16.2
354 | 7040 7 8 0.289 0 16.2
355 | 7060 8 8 0.456 1 16.2
356 | 7080 5 4 0.294 0 16.2
357 | 7100 9 3 0.239 0 16.2
358 | 7120 7 7 0.0409 1 16.2
359 | 7140 4 -1 0.0 0 16.2
360 | 7160 5 -1 0.0 0 16.2
361 | 7180 6 6 1.64 1 16.2
362 | 7200 4 6 0.209 0 16.2
363 | 7220 9 9 0.884 1 16.2
364 | 7240 0 8 0.393 0 16.2
365 | 7260 1 -1 0.0 0 16.2
366 | 7280 1 1 0.041 1 16.2
367 | 7300 3 -1 0.0 0 16.2
368 | 7320 8 8 0.0225 1 16.2
369 | 7340 2 4 0.358 0 16.2
370 | 7360 2 4 0.573 0 16.2
371 | 7380 7 7 0.799 1 16.2
372 | 7400 3 3 0.642 1 16.2
373 | 7420 3 0 0.398 0 16.2
374 | 7440 3 3 0.196 1 16.2
375 | 7460 5 5 2.81 1 16.2
376 | 7480 1 8 0.601 0 16.2
377 | 7500 6 6 3.08 1 16.2
378 | 7520 4 -1 0.0 0 16.2
379 | 7540 5 5 2.78 1 16.2
380 | 7560 0 -1 0.0 0 16.2
381 | 7580 8 8 0.896 1 16.2
382 | 7600 8 9 0.755 0 16.2
383 | 7620 5 3 0.129 0 16.2
384 | 7640 9 9 0.49 1 16.2
385 | 7660 7 -1 0.0 0 16.2
386 | 7680 3 -1 0.0 0 16.2
387 | 7700 6 6 1.56 1 16.2
388 | 7720 5 5 0.385 1 16.2
389 | 7740 4 -1 0.0 0 16.2
390 | 7760 2 -1 0.0 0 16.2
391 | 7780 3 -1 0.0 0 16.2
392 | 7800 0 0 1.43 1 16.2
393 | 7820 5 8 0.148 0 16.2
394 | 7840 1 1 2.08 1 16.2
395 | 7860 8 -1 0.0 0 16.2
396 | 7880 0 0 0.967 1 16.2
397 | 7900 0 8 0.0134 0 16.2
398 | 7920 9 -1 0.0 0 16.2
399 | 7940 2 6 0.077 0 16.2
400 | 7960 0 0 0.774 1 16.2
401 | 7980 3 -1 0.0 0 16.2
402 | 8000 9 9 1.18 1 16.2
403 | 8020 6 3 0.342 0 16.2
404 | 8040 2 2 0.554 1 16.2
405 | 8060 6 6 0.644 1 16.2
406 | 8080 4 -1 0.0 0 16.2
407 | 8100 6 -1 0.0 0 16.2
408 | 8120 8 9 0.875 0 16.2
409 | 8140 5 4 0.269 0 16.2
410 | 8160 6 6 0.175 1 16.2
411 | 8180 4 4 0.6 1 16.2
412 | 8200 3 -1 0.0 0 16.2
413 | 8220 6 6 0.041 1 16.2
414 | 8240 7 7 2.49 1 16.2
415 | 8260 1 1 0.683 1 16.2
416 | 8280 4 -1 0.0 0 16.2
417 | 8300 5 2 0.151 0 16.2
418 | 8320 7 -1 0.0 0 16.2
419 | 8340 5 3 1.03 0 16.2
420 | 8360 7 7 0.765 1 16.2
421 | 8380 9 9 0.762 1 16.2
422 | 8400 0 0 0.049 1 16.2
423 | 8420 3 6 0.197 0 16.2
424 | 8440 5 -1 0.0 0 16.2
425 | 8460 8 8 1.73 1 16.2
426 | 8480 2 -1 0.0 0 16.2
427 | 8500 4 4 0.114 1 16.2
428 | 8520 8 1 0.00909 0 16.2
429 | 8540 4 -1 0.0 0 16.2
430 | 8560 7 7 0.841 1 16.2
431 | 8580 3 -1 0.0 0 16.2
432 | 8600 3 -1 0.0 0 16.2
433 | 8620 3 -1 0.0 0 16.2
434 | 8640 1 1 2.53 1 16.2
435 | 8660 7 7 0.893 1 16.2
436 | 8680 3 9 0.221 0 16.2
437 | 8700 3 3 1.35 1 16.2
438 | 8720 2 0 2.39 0 16.2
439 | 8740 2 -1 0.0 0 16.2
440 | 8760 8 8 1.27 1 16.2
441 | 8780 4 4 0.0863 1 16.2
442 | 8800 0 0 3.19 1 16.2
443 | 8820 2 2 0.818 1 16.2
444 | 8840 2 1 0.302 0 16.2
445 | 8860 2 2 0.121 1 16.2
446 | 8880 1 1 0.362 1 16.2
447 | 8900 2 -1 0.0 0 16.2
448 | 8920 6 6 1.72 1 16.2
449 | 8940 6 3 0.226 0 16.2
450 | 8960 0 0 1.93 1 16.2
451 | 8980 9 9 1.69 1 16.1
452 | 9000 8 8 3.12 1 16.2
453 | 9020 7 9 0.268 0 16.2
454 | 9040 2 5 1.2 0 16.2
455 | 9060 9 5 0.162 0 16.2
456 | 9080 3 3 0.143 1 16.2
457 | 9100 9 -1 0.0 0 16.2
458 | 9120 3 -1 0.0 0 16.2
459 | 9140 7 -1 0.0 0 16.2
460 | 9160 5 5 0.198 1 16.2
461 | 9180 5 5 3.27 1 16.2
462 | 9200 8 8 2.57 1 16.2
463 | 9220 8 8 2.43 1 16.2
464 | 9240 8 8 0.796 1 16.2
465 | 9260 5 5 0.637 1 16.2
466 | 9280 9 -1 0.0 0 16.2
467 | 9300 5 -1 0.0 0 16.2
468 | 9320 9 9 1.14 1 16.2
469 | 9340 1 1 2.12 1 16.2
470 | 9360 5 -1 0.0 0 16.2
471 | 9380 5 3 0.0606 0 16.2
472 | 9400 6 6 0.0884 1 16.2
473 | 9420 5 5 0.529 1 16.2
474 | 9440 8 8 2.55 1 16.2
475 | 9460 3 7 0.0941 0 16.2
476 | 9480 2 -1 0.0 0 16.2
477 | 9500 9 -1 0.0 0 16.2
478 | 9520 1 1 1.03 1 16.2
479 | 9540 5 5 0.0202 1 16.2
480 | 9560 9 7 1.52 0 16.2
481 | 9580 1 -1 0.0 0 16.2
482 | 9600 8 8 1.46 1 16.2
483 | 9620 4 -1 0.0 0 16.2
484 | 9640 5 -1 0.0 0 16.2
485 | 9660 6 6 1.73 1 16.2
486 | 9680 8 8 1.4 1 16.2
487 | 9700 0 0 1.05 1 16.2
488 | 9720 8 -1 0.0 0 16.2
489 | 9740 3 6 0.356 0 16.2
490 | 9760 9 9 0.344 1 16.2
491 | 9780 4 6 0.813 0 16.2
492 | 9800 1 1 3.18 1 16.2
493 | 9820 0 8 0.0948 0 16.2
494 | 9840 4 5 0.56 0 16.2
495 | 9860 0 -1 0.0 0 16.2
496 | 9880 7 7 0.0843 1 16.2
497 | 9900 8 0 0.179 0 16.2
498 | 9920 6 6 0.104 1 16.2
499 | 9940 4 -1 0.0 0 16.2
500 | 9960 2 0 3.3 0 16.2
501 | 9980 0 0 2.12 1 16.2
502 |
--------------------------------------------------------------------------------
/data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25:
--------------------------------------------------------------------------------
1 | idx label predict radius correct time
2 | 0 0 0 0.165 1 0:02:31.238689
3 | 100 2 2 0.978 1 0:02:28.472100
4 | 200 4 395 0.0504 0 0:02:29.734734
5 | 300 6 6 0.978 1 0:02:29.500387
6 | 400 8 8 0.978 1 0:02:30.348481
7 | 500 10 10 0.849 1 0:02:29.899296
8 | 600 12 12 0.0511 1 0:02:29.689873
9 | 700 14 14 0.978 1 0:02:29.582000
10 | 800 16 16 0.978 1 0:02:29.844638
11 | 900 18 18 0.117 1 0:02:30.467950
12 | 1000 20 20 0.978 1 0:02:30.492715
13 | 1100 22 21 0.271 0 0:02:29.971339
14 | 1200 24 24 0.978 1 0:02:29.882296
15 | 1300 26 26 0.464 1 0:02:30.184608
16 | 1400 28 -1 0.0 0 0:02:29.938915
17 | 1500 30 30 0.287 1 0:02:29.372000
18 | 1600 32 -1 0.0 0 0:02:30.198865
19 | 1700 34 327 0.0542 0 0:02:30.576400
20 | 1800 36 35 0.609 0 0:02:29.825960
21 | 1900 38 38 0.226 1 0:02:29.747150
22 | 2000 40 31 0.0497 0 0:02:29.953930
23 | 2100 42 48 0.839 0 0:02:30.242376
24 | 2200 44 44 0.681 1 0:02:30.185557
25 | 2300 46 46 0.537 1 0:02:30.888912
26 | 2400 48 48 0.17 1 0:02:30.444757
27 | 2500 50 50 0.682 1 0:02:30.507672
28 | 2600 52 52 0.405 1 0:02:29.848361
29 | 2700 54 57 0.978 0 0:02:30.142747
30 | 2800 56 56 0.48 1 0:02:30.734643
31 | 2900 58 -1 0.0 0 0:02:30.443910
32 | 3000 60 62 0.37 0 0:02:30.121695
33 | 3100 62 62 0.449 1 0:02:30.041115
34 | 3200 64 64 0.634 1 0:02:29.907086
35 | 3300 66 63 0.133 0 0:02:30.320275
36 | 3400 68 -1 0.0 0 0:02:30.168557
37 | 3500 70 70 0.978 1 0:02:29.595999
38 | 3600 72 72 0.978 1 0:02:29.768712
39 | 3700 74 74 0.499 1 0:02:30.160280
40 | 3800 76 76 0.641 1 0:02:30.447175
41 | 3900 78 78 0.631 1 0:02:30.357008
42 | 4000 80 80 0.978 1 0:02:30.246374
43 | 4100 82 82 0.978 1 0:02:30.204693
44 | 4200 84 84 0.405 1 0:02:30.568214
45 | 4300 86 86 0.571 1 0:02:30.145539
46 | 4400 88 88 0.128 1 0:02:29.993359
47 | 4500 90 90 0.74 1 0:02:30.215480
48 | 4600 92 92 0.503 1 0:02:29.838759
49 | 4700 94 94 0.978 1 0:02:29.745489
50 | 4800 96 96 0.978 1 0:02:30.194775
51 | 4900 98 98 0.722 1 0:02:29.985465
52 | 5000 100 100 0.978 1 0:02:30.678385
53 | 5100 102 102 0.978 1 0:02:30.334541
54 | 5200 104 342 0.216 0 0:02:30.607674
55 | 5300 106 106 0.561 1 0:02:30.052348
56 | 5400 108 470 0.256 0 0:02:30.179028
57 | 5500 110 110 0.978 1 0:02:29.893491
58 | 5600 112 112 0.906 1 0:02:30.182807
59 | 5700 114 113 0.391 0 0:02:30.503177
60 | 5800 116 116 0.978 1 0:02:30.496567
61 | 5900 118 118 0.662 1 0:02:30.373342
62 | 6000 120 119 0.0777 0 0:02:30.121950
63 | 6100 122 122 0.978 1 0:02:30.427012
64 | 6200 124 123 0.113 0 0:02:30.054732
65 | 6300 126 126 0.528 1 0:02:30.416313
66 | 6400 128 128 0.978 1 0:02:30.562088
67 | 6500 130 130 0.978 1 0:02:30.231175
68 | 6600 132 132 0.978 1 0:02:30.159542
69 | 6700 134 132 0.208 0 0:02:29.646543
70 | 6800 136 136 0.978 1 0:02:30.380730
71 | 6900 138 138 0.978 1 0:02:30.572920
72 | 7000 140 140 0.978 1 0:02:30.700227
73 | 7100 142 142 0.814 1 0:02:30.641635
74 | 7200 144 144 0.978 1 0:02:30.563427
75 | 7300 146 146 0.0569 1 0:02:30.143445
76 | 7400 148 148 0.978 1 0:02:30.023675
77 | 7500 150 360 0.498 0 0:02:30.621709
78 | 7600 152 152 0.793 1 0:02:30.327601
79 | 7700 154 154 0.913 1 0:02:30.553908
80 | 7800 156 156 0.978 1 0:02:30.480134
81 | 7900 158 158 0.685 1 0:02:30.272967
82 | 8000 160 160 0.978 1 0:02:30.440335
83 | 8100 162 162 0.692 1 0:02:29.968419
84 | 8200 164 164 0.25 1 0:02:31.014798
85 | 8300 166 166 0.978 1 0:02:30.850247
86 | 8400 168 211 0.0559 0 0:02:30.812654
87 | 8500 170 172 0.345 0 0:02:30.567950
88 | 8600 172 172 0.623 1 0:02:31.028928
89 | 8700 174 174 0.687 1 0:02:30.706730
90 | 8800 176 176 0.388 1 0:02:30.524550
91 | 8900 178 178 0.978 1 0:02:30.185986
92 | 9000 180 243 0.285 0 0:02:30.189492
93 | 9100 182 182 0.0441 1 0:02:29.980940
94 | 9200 184 -1 0.0 0 0:02:30.320398
95 | 9300 186 192 0.596 0 0:02:30.886985
96 | 9400 188 188 0.165 1 0:02:30.411439
97 | 9500 190 190 0.978 1 0:02:29.880804
98 | 9600 192 199 0.672 0 0:02:30.719151
99 | 9700 194 -1 0.0 0 0:02:30.716428
100 | 9800 196 199 0.978 0 0:02:30.519637
101 | 9900 198 196 0.0359 0 0:02:30.775083
102 | 10000 200 200 0.817 1 0:02:30.284424
103 | 10100 202 202 0.955 1 0:02:29.999943
104 | 10200 204 204 0.247 1 0:02:30.015042
105 | 10300 206 206 0.978 1 0:02:29.928960
106 | 10400 208 208 0.395 1 0:02:30.028211
107 | 10500 210 210 0.51 1 0:02:30.097306
108 | 10600 212 212 0.219 1 0:02:30.861089
109 | 10700 214 214 0.577 1 0:02:30.404881
110 | 10800 216 216 0.846 1 0:02:30.240213
111 | 10900 218 218 0.978 1 0:02:30.392252
112 | 11000 220 220 0.535 1 0:02:30.464677
113 | 11100 222 222 0.0577 1 0:02:30.620631
114 | 11200 224 233 0.117 0 0:02:30.720985
115 | 11300 226 226 0.69 1 0:02:30.625103
116 | 11400 228 228 0.978 1 0:02:30.223536
117 | 11500 230 230 0.188 1 0:02:30.951324
118 | 11600 232 232 0.941 1 0:02:30.144850
119 | 11700 234 234 0.437 1 0:02:31.048280
120 | 11800 236 234 0.0155 0 0:02:30.477529
121 | 11900 238 238 0.712 1 0:02:30.344499
122 | 12000 240 238 0.691 0 0:02:30.270159
123 | 12100 242 242 0.786 1 0:02:30.272286
124 | 12200 244 244 0.978 1 0:02:30.038246
125 | 12300 246 251 0.61 0 0:02:29.815927
126 | 12400 248 248 0.0843 1 0:02:29.774322
127 | 12500 250 169 0.238 0 0:02:30.566967
128 | 12600 252 252 0.726 1 0:02:29.879580
129 | 12700 254 254 0.235 1 0:02:29.971380
130 | 12800 256 256 0.978 1 0:02:29.825671
131 | 12900 258 258 0.517 1 0:02:30.061793
132 | 13000 260 260 0.537 1 0:02:30.049401
133 | 13100 262 262 0.978 1 0:02:30.337835
134 | 13200 264 264 0.307 1 0:02:30.565490
135 | 13300 266 266 0.175 1 0:02:29.914956
136 | 13400 268 268 0.978 1 0:02:29.795518
137 | 13500 270 270 0.978 1 0:02:30.268059
138 | 13600 272 274 0.306 0 0:02:30.197188
139 | 13700 274 274 0.978 1 0:02:30.016229
140 | 13800 276 276 0.978 1 0:02:29.347983
141 | 13900 278 278 0.707 1 0:02:30.143017
142 | 14000 280 280 0.623 1 0:02:30.471658
143 | 14100 282 282 0.929 1 0:02:29.568346
144 | 14200 284 284 0.978 1 0:02:29.389408
145 | 14300 286 286 0.978 1 0:02:29.676501
146 | 14400 288 288 0.716 1 0:02:30.132765
147 | 14500 290 290 0.978 1 0:02:29.294689
148 | 14600 292 290 0.609 0 0:02:30.056545
149 | 14700 294 294 0.649 1 0:02:30.130922
150 | 14800 296 296 0.638 1 0:02:29.760217
151 | 14900 298 -1 0.0 0 0:02:30.193345
152 | 15000 300 300 0.978 1 0:02:29.807024
153 | 15100 302 302 0.0493 1 0:02:30.157539
154 | 15200 304 301 0.978 0 0:02:30.279750
155 | 15300 306 306 0.978 1 0:02:29.992867
156 | 15400 308 308 0.844 1 0:02:30.478700
157 | 15500 310 310 0.978 1 0:02:29.860750
158 | 15600 312 -1 0.0 0 0:02:29.935535
159 | 15700 314 314 0.735 1 0:02:30.038218
160 | 15800 316 316 0.978 1 0:02:30.223525
161 | 15900 318 318 0.46 1 0:02:30.016658
162 | 16000 320 320 0.978 1 0:02:30.401388
163 | 16100 322 322 0.754 1 0:02:30.413642
164 | 16200 324 324 0.978 1 0:02:30.292821
165 | 16300 326 326 0.955 1 0:02:30.410080
166 | 16400 328 328 0.978 1 0:02:30.249984
167 | 16500 330 330 0.405 1 0:02:29.675360
168 | 16600 332 -1 0.0 0 0:02:29.048663
169 | 16700 334 334 0.978 1 0:02:30.071187
170 | 16800 336 336 0.844 1 0:02:29.878231
171 | 16900 338 617 0.464 0 0:02:30.147896
172 | 17000 340 340 0.978 1 0:02:30.650256
173 | 17100 342 287 0.16 0 0:02:30.627570
174 | 17200 344 344 0.941 1 0:02:30.470581
175 | 17300 346 344 0.0442 0 0:02:29.950092
176 | 17400 348 348 0.797 1 0:02:30.853736
177 | 17500 350 350 0.978 1 0:02:30.505603
178 | 17600 352 352 0.857 1 0:02:30.227803
179 | 17700 354 354 0.978 1 0:02:30.097232
180 | 17800 356 -1 0.0 0 0:02:29.830953
181 | 17900 358 359 0.181 0 0:02:29.647446
182 | 18000 360 360 0.863 1 0:02:30.357880
183 | 18100 362 362 0.978 1 0:02:30.611909
184 | 18200 364 364 0.718 1 0:02:30.395142
185 | 18300 366 366 0.407 1 0:02:30.572165
186 | 18400 368 368 0.955 1 0:02:30.181935
187 | 18500 370 370 0.978 1 0:02:30.173064
188 | 18600 372 372 0.529 1 0:02:30.302947
189 | 18700 374 -1 0.0 0 0:02:30.235269
190 | 18800 376 376 0.772 1 0:02:30.473644
191 | 18900 378 378 0.235 1 0:02:30.372144
192 | 19000 380 380 0.837 1 0:02:30.160681
193 | 19100 382 382 0.978 1 0:02:30.570127
194 | 19200 384 -1 0.0 0 0:02:30.820962
195 | 19300 386 101 0.402 0 0:02:29.788515
196 | 19400 388 388 0.541 1 0:02:29.797744
197 | 19500 390 390 0.775 1 0:02:30.239534
198 | 19600 392 397 0.562 0 0:02:30.206156
199 | 19700 394 -1 0.0 0 0:02:30.558219
200 | 19800 396 396 0.978 1 0:02:30.153110
201 | 19900 398 398 0.978 1 0:02:30.476912
202 | 20000 400 400 0.186 1 0:02:30.401398
203 | 20100 402 402 0.245 1 0:02:30.091169
204 | 20200 404 404 0.978 1 0:02:30.319886
205 | 20300 406 406 0.607 1 0:02:29.894624
206 | 20400 408 847 0.57 0 0:02:30.479393
207 | 20500 410 410 0.0798 1 0:02:30.255137
208 | 20600 412 412 0.331 1 0:02:30.360466
209 | 20700 414 414 0.39 1 0:02:29.998817
210 | 20800 416 416 0.88 1 0:02:30.508972
211 | 20900 418 563 0.978 0 0:02:30.698495
212 | 21000 420 420 0.242 1 0:02:30.715746
213 | 21100 422 422 0.261 1 0:02:30.667740
214 | 21200 424 -1 0.0 0 0:02:30.284119
215 | 21300 426 426 0.978 1 0:02:30.628197
216 | 21400 428 428 0.978 1 0:02:30.471429
217 | 21500 430 430 0.51 1 0:02:29.978379
218 | 21600 432 432 0.978 1 0:02:29.980470
219 | 21700 434 -1 0.0 0 0:02:30.307042
220 | 21800 436 436 0.978 1 0:02:30.765092
221 | 21900 438 438 0.262 1 0:02:30.637522
222 | 22000 440 737 0.0713 0 0:02:30.458820
223 | 22100 442 442 0.929 1 0:02:30.332701
224 | 22200 444 -1 0.0 0 0:02:30.527043
225 | 22300 446 446 0.929 1 0:02:29.991854
226 | 22400 448 448 0.543 1 0:02:30.389721
227 | 22500 450 407 0.154 0 0:02:29.747789
228 | 22600 452 452 0.913 1 0:02:30.578689
229 | 22700 454 454 0.9 1 0:02:30.596173
230 | 22800 456 -1 0.0 0 0:02:29.819309
231 | 22900 458 458 0.9 1 0:02:29.640807
232 | 23000 460 978 0.12 0 0:02:29.570771
233 | 23100 462 462 0.978 1 0:02:29.782221
234 | 23200 464 439 0.522 0 0:02:30.267272
235 | 23300 466 466 0.978 1 0:02:30.125070
236 | 23400 468 468 0.978 1 0:02:29.601588
237 | 23500 470 624 0.0437 0 0:02:29.220088
238 | 23600 472 472 0.978 1 0:02:30.266124
239 | 23700 474 841 0.14 0 0:02:30.019980
240 | 23800 476 476 0.929 1 0:02:29.595922
241 | 23900 478 478 0.39 1 0:02:29.594913
242 | 24000 480 509 0.509 0 0:02:30.068407
243 | 24100 482 481 0.238 0 0:02:30.223905
244 | 24200 484 871 0.842 0 0:02:30.230005
245 | 24300 486 486 0.104 1 0:02:30.039842
246 | 24400 488 488 0.468 1 0:02:30.693576
247 | 24500 490 490 0.629 1 0:02:30.275327
248 | 24600 492 492 0.826 1 0:02:30.023680
249 | 24700 494 398 0.941 0 0:02:30.105409
250 | 24800 496 496 0.131 1 0:02:30.163059
251 | 24900 498 498 0.978 1 0:02:30.074041
252 | 25000 500 500 0.978 1 0:02:29.717768
253 | 25100 502 502 0.873 1 0:02:29.803116
254 | 25200 504 504 0.0495 1 0:02:30.441310
255 | 25300 506 506 0.653 1 0:02:30.144799
256 | 25400 508 508 0.978 1 0:02:29.834378
257 | 25500 510 510 0.978 1 0:02:30.429432
258 | 25600 512 740 0.254 0 0:02:29.822186
259 | 25700 514 514 0.978 1 0:02:29.765429
260 | 25800 516 431 0.0175 0 0:02:29.624502
261 | 25900 518 518 0.978 1 0:02:29.319035
262 | 26000 520 520 0.978 1 0:02:29.907698
263 | 26100 522 522 0.978 1 0:02:30.069973
264 | 26200 524 461 0.664 0 0:02:29.917895
265 | 26300 526 526 0.0395 1 0:02:30.180328
266 | 26400 528 -1 0.0 0 0:02:30.060835
267 | 26500 530 531 0.978 0 0:02:30.186140
268 | 26600 532 532 0.142 1 0:02:30.598862
269 | 26700 534 -1 0.0 0 0:02:30.541164
270 | 26800 536 403 0.755 0 0:02:30.070715
271 | 26900 538 538 0.978 1 0:02:30.112502
272 | 27000 540 540 0.978 1 0:02:30.576435
273 | 27100 542 -1 0.0 0 0:02:30.259104
274 | 27200 544 926 0.371 0 0:02:30.055312
275 | 27300 546 546 0.718 1 0:02:30.321178
276 | 27400 548 548 0.9 1 0:02:30.384792
277 | 27500 550 505 0.251 0 0:02:30.494111
278 | 27600 552 552 0.978 1 0:02:30.445498
279 | 27700 554 554 0.802 1 0:02:30.662148
280 | 27800 556 421 0.0586 0 0:02:30.567910
281 | 27900 558 251 0.228 0 0:02:30.566867
282 | 28000 560 768 0.254 0 0:02:30.054083
283 | 28100 562 562 0.978 1 0:02:30.009979
284 | 28200 564 564 0.189 1 0:02:30.454839
285 | 28300 566 566 0.831 1 0:02:30.514242
286 | 28400 568 399 0.163 0 0:02:30.355165
287 | 28500 570 -1 0.0 0 0:02:30.256466
288 | 28600 572 418 0.123 0 0:02:29.945449
289 | 28700 574 574 0.978 1 0:02:29.825920
290 | 28800 576 -1 0.0 0 0:02:30.626047
291 | 28900 578 578 0.457 1 0:02:29.899034
292 | 29000 580 738 0.978 0 0:02:30.362939
293 | 29100 582 -1 0.0 0 0:02:30.624694
294 | 29200 584 754 0.102 0 0:02:30.366516
295 | 29300 586 586 0.955 1 0:02:30.131823
296 | 29400 588 588 0.681 1 0:02:30.329224
297 | 29500 590 590 0.71 1 0:02:29.902602
298 | 29600 592 592 0.978 1 0:02:30.362917
299 | 29700 594 843 0.378 0 0:02:30.272107
300 | 29800 596 56 0.0281 0 0:02:29.961047
301 | 29900 598 664 0.603 0 0:02:29.538639
302 | 30000 600 517 0.205 0 0:02:29.747167
303 | 30100 602 602 0.978 1 0:02:29.913053
304 | 30200 604 604 0.366 1 0:02:29.841973
305 | 30300 606 606 0.978 1 0:02:30.086977
306 | 30400 608 608 0.441 1 0:02:30.778591
307 | 30500 610 800 0.138 0 0:02:30.530804
308 | 30600 612 612 0.978 1 0:02:29.529376
309 | 30700 614 614 0.805 1 0:02:29.651345
310 | 30800 616 534 0.121 0 0:02:30.312252
311 | 30900 618 -1 0.0 0 0:02:30.016361
312 | 31000 620 681 0.116 0 0:02:30.335139
313 | 31100 622 622 0.978 1 0:02:29.938906
314 | 31200 624 453 0.851 0 0:02:29.743702
315 | 31300 626 542 0.127 0 0:02:30.080677
316 | 31400 628 -1 0.0 0 0:02:29.638628
317 | 31500 630 703 0.672 0 0:02:30.238523
318 | 31600 632 632 0.866 1 0:02:30.075643
319 | 31700 634 489 0.978 0 0:02:30.278016
320 | 31800 636 636 0.616 1 0:02:30.159042
321 | 31900 638 638 0.92 1 0:02:30.447832
322 | 32000 640 640 0.978 1 0:02:30.080730
323 | 32100 642 642 0.978 1 0:02:29.710235
324 | 32200 644 644 0.978 1 0:02:30.165180
325 | 32300 646 646 0.978 1 0:02:29.225623
326 | 32400 648 -1 0.0 0 0:02:30.410912
327 | 32500 650 546 0.0652 0 0:02:30.158323
328 | 32600 652 652 0.553 1 0:02:29.836980
329 | 32700 654 436 0.103 0 0:02:29.863310
330 | 32800 656 -1 0.0 0 0:02:30.554843
331 | 32900 658 658 0.978 1 0:02:30.161191
332 | 33000 660 660 0.437 1 0:02:29.909711
333 | 33100 662 662 0.0149 1 0:02:29.551033
334 | 33200 664 851 0.117 0 0:02:29.760597
335 | 33300 666 659 0.259 0 0:02:29.734256
336 | 33400 668 -1 0.0 0 0:02:30.239919
337 | 33500 670 670 0.978 1 0:02:30.034608
338 | 33600 672 672 0.353 1 0:02:30.049591
339 | 33700 674 674 0.426 1 0:02:30.018736
340 | 33800 676 560 0.0283 0 0:02:30.365235
341 | 33900 678 678 0.978 1 0:02:30.710799
342 | 34000 680 -1 0.0 0 0:02:29.993227
343 | 34100 682 682 0.299 1 0:02:29.944957
344 | 34200 684 684 0.929 1 0:02:29.973918
345 | 34300 686 -1 0.0 0 0:02:29.893355
346 | 34400 688 688 0.978 1 0:02:29.698450
347 | 34500 690 -1 0.0 0 0:02:29.653637
348 | 34600 692 -1 0.0 0 0:02:29.897104
349 | 34700 694 694 0.978 1 0:02:29.589318
350 | 34800 696 -1 0.0 0 0:02:29.911197
351 | 34900 698 698 0.0859 1 0:02:29.422916
352 | 35000 700 700 0.459 1 0:02:29.795737
353 | 35100 702 422 0.846 0 0:02:30.083227
354 | 35200 704 704 0.731 1 0:02:30.325363
355 | 35300 706 536 0.217 0 0:02:30.080998
356 | 35400 708 682 0.794 0 0:02:30.274108
357 | 35500 710 710 0.366 1 0:02:30.420304
358 | 35600 712 -1 0.0 0 0:02:30.628211
359 | 35700 714 714 0.597 1 0:02:30.303162
360 | 35800 716 716 0.955 1 0:02:29.972523
361 | 35900 718 449 0.212 0 0:02:29.445828
362 | 36000 720 720 0.372 1 0:02:30.154242
363 | 36100 722 722 0.978 1 0:02:30.139431
364 | 36200 724 780 0.978 0 0:02:30.145026
365 | 36300 726 726 0.978 1 0:02:30.133129
366 | 36400 728 728 0.0609 1 0:02:30.022333
367 | 36500 730 856 0.859 0 0:02:30.329501
368 | 36600 732 732 0.978 1 0:02:30.794844
369 | 36700 734 734 0.978 1 0:02:30.110396
370 | 36800 736 736 0.978 1 0:02:30.063162
371 | 36900 738 738 0.617 1 0:02:30.438124
372 | 37000 740 -1 0.0 0 0:02:30.778301
373 | 37100 742 -1 0.0 0 0:02:30.334797
374 | 37200 744 657 0.913 0 0:02:30.138139
375 | 37300 746 746 0.978 1 0:02:29.461292
376 | 37400 748 748 0.978 1 0:02:29.794650
377 | 37500 750 172 0.016 0 0:02:30.684972
378 | 37600 752 -1 0.0 0 0:02:30.102833
379 | 37700 754 848 0.331 0 0:02:29.873319
380 | 37800 756 756 0.978 1 0:02:29.837130
381 | 37900 758 758 0.527 1 0:02:30.398277
382 | 38000 760 651 0.246 0 0:02:30.244316
383 | 38100 762 762 0.455 1 0:02:30.474714
384 | 38200 764 593 0.0208 0 0:02:30.163622
385 | 38300 766 766 0.873 1 0:02:30.341681
386 | 38400 768 768 0.56 1 0:02:30.791192
387 | 38500 770 770 0.208 1 0:02:30.626651
388 | 38600 772 772 0.773 1 0:02:30.237842
389 | 38700 774 774 0.906 1 0:02:30.539333
390 | 38800 776 683 0.206 0 0:02:30.403928
391 | 38900 778 707 0.273 0 0:02:30.200670
392 | 39000 780 780 0.103 1 0:02:30.389691
393 | 39100 782 664 0.566 0 0:02:30.260540
394 | 39200 784 505 0.0573 0 0:02:30.160525
395 | 39300 786 -1 0.0 0 0:02:30.366705
396 | 39400 788 743 0.101 0 0:02:30.461723
397 | 39500 790 791 0.271 0 0:02:30.898356
398 | 39600 792 795 0.195 0 0:02:30.090714
399 | 39700 794 794 0.831 1 0:02:30.063199
400 | 39800 796 796 0.978 1 0:02:30.189355
401 | 39900 798 798 0.978 1 0:02:30.551748
402 | 40000 800 800 0.978 1 0:02:30.192058
403 | 40100 802 802 0.685 1 0:02:29.962413
404 | 40200 804 804 0.213 1 0:02:29.795746
405 | 40300 806 806 0.978 1 0:02:30.063303
406 | 40400 808 808 0.955 1 0:02:29.804598
407 | 40500 810 878 0.33 0 0:02:29.777533
408 | 40600 812 812 0.978 1 0:02:29.472982
409 | 40700 814 814 0.172 1 0:02:29.290164
410 | 40800 816 816 0.978 1 0:02:30.354582
411 | 40900 818 745 0.523 0 0:02:29.823832
412 | 41000 820 820 0.262 1 0:02:29.811310
413 | 41100 822 822 0.978 1 0:02:30.542749
414 | 41200 824 824 0.696 1 0:02:30.533368
415 | 41300 826 826 0.47 1 0:02:30.509891
416 | 41400 828 618 0.0484 0 0:02:30.638973
417 | 41500 830 830 0.209 1 0:02:29.854785
418 | 41600 832 832 0.443 1 0:02:30.028128
419 | 41700 834 417 0.337 0 0:02:30.216051
420 | 41800 836 444 0.423 0 0:02:30.478095
421 | 41900 838 631 0.92 0 0:02:30.660583
422 | 42000 840 840 0.978 1 0:02:29.944967
423 | 42100 842 842 0.0816 1 0:02:29.631471
424 | 42200 844 844 0.978 1 0:02:30.017588
425 | 42300 846 583 0.192 0 0:02:29.849304
426 | 42400 848 848 0.455 1 0:02:30.021253
427 | 42500 850 850 0.423 1 0:02:30.437045
428 | 42600 852 852 0.978 1 0:02:30.529743
429 |
--------------------------------------------------------------------------------
/data/predict/imagenet/resnet50/noise_0.25/test/N_100:
--------------------------------------------------------------------------------
1 | idx label predict correct time
2 | 0 0 0 1 0:00:01.737771
3 | 100 2 2 1 0:00:00.158496
4 | 200 4 395 0 0:00:00.149257
5 | 300 6 6 1 0:00:00.149002
6 | 400 8 8 1 0:00:00.149901
7 | 500 10 10 1 0:00:00.148993
8 | 600 12 12 1 0:00:00.150077
9 | 700 14 14 1 0:00:00.149649
10 | 800 16 16 1 0:00:00.148875
11 | 900 18 18 1 0:00:00.148872
12 | 1000 20 20 1 0:00:00.148681
13 | 1100 22 21 0 0:00:00.148832
14 | 1200 24 24 1 0:00:00.148842
15 | 1300 26 26 1 0:00:00.149388
16 | 1400 28 28 1 0:00:00.149335
17 | 1500 30 30 1 0:00:00.149273
18 | 1600 32 -1 0 0:00:00.150587
19 | 1700 34 327 0 0:00:00.149775
20 | 1800 36 35 0 0:00:00.149858
21 | 1900 38 38 1 0:00:00.150131
22 | 2000 40 -1 0 0:00:00.149192
23 | 2100 42 48 0 0:00:00.149575
24 | 2200 44 44 1 0:00:00.150103
25 | 2300 46 46 1 0:00:00.150445
26 | 2400 48 48 1 0:00:00.148776
27 | 2500 50 50 1 0:00:00.148915
28 | 2600 52 52 1 0:00:00.148730
29 | 2700 54 57 0 0:00:00.148792
30 | 2800 56 56 1 0:00:00.149241
31 | 2900 58 -1 0 0:00:00.149086
32 | 3000 60 62 0 0:00:00.149118
33 | 3100 62 62 1 0:00:00.149509
34 | 3200 64 64 1 0:00:00.150243
35 | 3300 66 63 0 0:00:00.149570
36 | 3400 68 -1 0 0:00:00.149943
37 | 3500 70 70 1 0:00:00.149646
38 | 3600 72 72 1 0:00:00.150097
39 | 3700 74 74 1 0:00:00.150264
40 | 3800 76 76 1 0:00:00.149668
41 | 3900 78 78 1 0:00:00.148759
42 | 4000 80 80 1 0:00:00.148977
43 | 4100 82 82 1 0:00:00.149063
44 | 4200 84 84 1 0:00:00.149432
45 | 4300 86 86 1 0:00:00.149689
46 | 4400 88 88 1 0:00:00.149581
47 | 4500 90 90 1 0:00:00.149732
48 | 4600 92 92 1 0:00:00.150014
49 | 4700 94 94 1 0:00:00.149836
50 | 4800 96 96 1 0:00:00.149859
51 | 4900 98 98 1 0:00:00.149908
52 | 5000 100 100 1 0:00:00.150221
53 | 5100 102 102 1 0:00:00.150292
54 | 5200 104 342 0 0:00:00.149715
55 | 5300 106 106 1 0:00:00.149746
56 | 5400 108 470 0 0:00:00.149747
57 | 5500 110 110 1 0:00:00.149787
58 | 5600 112 112 1 0:00:00.150302
59 | 5700 114 113 0 0:00:00.149334
60 | 5800 116 116 1 0:00:00.149520
61 | 5900 118 118 1 0:00:00.149666
62 | 6000 120 119 0 0:00:00.149083
63 | 6100 122 122 1 0:00:00.149483
64 | 6200 124 123 0 0:00:00.150204
65 | 6300 126 126 1 0:00:00.149876
66 | 6400 128 128 1 0:00:00.150150
67 | 6500 130 130 1 0:00:00.149888
68 | 6600 132 132 1 0:00:00.148795
69 | 6700 134 132 0 0:00:00.149421
70 | 6800 136 136 1 0:00:00.149319
71 | 6900 138 138 1 0:00:00.149000
72 | 7000 140 140 1 0:00:00.149684
73 | 7100 142 142 1 0:00:00.149306
74 | 7200 144 144 1 0:00:00.149648
75 | 7300 146 146 1 0:00:00.149898
76 | 7400 148 148 1 0:00:00.150061
77 | 7500 150 360 0 0:00:00.150225
78 | 7600 152 152 1 0:00:00.150093
79 | 7700 154 154 1 0:00:00.150203
80 | 7800 156 156 1 0:00:00.150502
81 | 7900 158 158 1 0:00:00.150271
82 | 8000 160 160 1 0:00:00.150902
83 | 8100 162 162 1 0:00:00.150535
84 | 8200 164 164 1 0:00:00.150811
85 | 8300 166 166 1 0:00:00.150555
86 | 8400 168 211 0 0:00:00.149288
87 | 8500 170 172 0 0:00:00.149561
88 | 8600 172 172 1 0:00:00.149710
89 | 8700 174 174 1 0:00:00.149799
90 | 8800 176 176 1 0:00:00.149869
91 | 8900 178 178 1 0:00:00.149695
92 | 9000 180 243 0 0:00:00.150275
93 | 9100 182 182 1 0:00:00.150397
94 | 9200 184 -1 0 0:00:00.150378
95 | 9300 186 192 0 0:00:00.150749
96 | 9400 188 188 1 0:00:00.150607
97 | 9500 190 190 1 0:00:00.150677
98 | 9600 192 199 0 0:00:00.150700
99 | 9700 194 202 0 0:00:00.150573
100 | 9800 196 199 0 0:00:00.150710
101 | 9900 198 -1 0 0:00:00.150579
102 | 10000 200 200 1 0:00:00.149348
103 | 10100 202 202 1 0:00:00.150144
104 | 10200 204 204 1 0:00:00.149713
105 | 10300 206 206 1 0:00:00.149793
106 | 10400 208 208 1 0:00:00.149739
107 | 10500 210 210 1 0:00:00.150280
108 | 10600 212 212 1 0:00:00.150422
109 | 10700 214 214 1 0:00:00.150694
110 | 10800 216 216 1 0:00:00.150631
111 | 10900 218 218 1 0:00:00.150435
112 | 11000 220 220 1 0:00:00.151002
113 | 11100 222 -1 0 0:00:00.150712
114 | 11200 224 233 0 0:00:00.151255
115 | 11300 226 226 1 0:00:00.150797
116 | 11400 228 228 1 0:00:00.150893
117 | 11500 230 230 1 0:00:00.150637
118 | 11600 232 232 1 0:00:00.151438
119 | 11700 234 234 1 0:00:00.150964
120 | 11800 236 234 0 0:00:00.149482
121 | 11900 238 238 1 0:00:00.150126
122 | 12000 240 238 0 0:00:00.149885
123 | 12100 242 242 1 0:00:00.149924
124 | 12200 244 244 1 0:00:00.149931
125 | 12300 246 251 0 0:00:00.150398
126 | 12400 248 -1 0 0:00:00.150118
127 | 12500 250 169 0 0:00:00.150278
128 | 12600 252 252 1 0:00:00.151481
129 | 12700 254 254 1 0:00:00.150805
130 | 12800 256 256 1 0:00:00.150324
131 | 12900 258 258 1 0:00:00.150588
132 | 13000 260 260 1 0:00:00.151438
133 | 13100 262 262 1 0:00:00.150877
134 | 13200 264 264 1 0:00:00.151444
135 | 13300 266 266 1 0:00:00.150706
136 | 13400 268 268 1 0:00:00.150965
137 | 13500 270 270 1 0:00:00.150461
138 | 13600 272 274 0 0:00:00.149865
139 | 13700 274 274 1 0:00:00.150135
140 | 13800 276 276 1 0:00:00.150259
141 | 13900 278 278 1 0:00:00.149917
142 | 14000 280 280 1 0:00:00.150193
143 | 14100 282 282 1 0:00:00.151167
144 | 14200 284 284 1 0:00:00.151545
145 | 14300 286 286 1 0:00:00.151211
146 | 14400 288 288 1 0:00:00.151570
147 | 14500 290 290 1 0:00:00.151209
148 | 14600 292 290 0 0:00:00.150754
149 | 14700 294 294 1 0:00:00.150950
150 | 14800 296 296 1 0:00:00.151226
151 | 14900 298 298 1 0:00:00.150931
152 | 15000 300 300 1 0:00:00.151120
153 | 15100 302 -1 0 0:00:00.151035
154 | 15200 304 301 0 0:00:00.149800
155 | 15300 306 306 1 0:00:00.150287
156 | 15400 308 308 1 0:00:00.150084
157 | 15500 310 310 1 0:00:00.150211
158 | 15600 312 -1 0 0:00:00.150407
159 | 15700 314 314 1 0:00:00.150743
160 | 15800 316 316 1 0:00:00.151449
161 | 15900 318 318 1 0:00:00.151332
162 | 16000 320 320 1 0:00:00.151387
163 | 16100 322 322 1 0:00:00.151472
164 | 16200 324 324 1 0:00:00.150631
165 | 16300 326 326 1 0:00:00.151067
166 | 16400 328 328 1 0:00:00.150776
167 | 16500 330 330 1 0:00:00.150908
168 | 16600 332 -1 0 0:00:00.150953
169 | 16700 334 334 1 0:00:00.150952
170 | 16800 336 336 1 0:00:00.151640
171 | 16900 338 617 0 0:00:00.151710
172 | 17000 340 340 1 0:00:00.149796
173 | 17100 342 287 0 0:00:00.150234
174 | 17200 344 344 1 0:00:00.150386
175 | 17300 346 344 0 0:00:00.150476
176 | 17400 348 348 1 0:00:00.150613
177 | 17500 350 350 1 0:00:00.150906
178 | 17600 352 352 1 0:00:00.150985
179 | 17700 354 354 1 0:00:00.151042
180 | 17800 356 -1 0 0:00:00.151376
181 | 17900 358 359 0 0:00:00.151409
182 | 18000 360 360 1 0:00:00.151763
183 | 18100 362 362 1 0:00:00.151103
184 | 18200 364 364 1 0:00:00.151692
185 | 18300 366 366 1 0:00:00.151635
186 | 18400 368 368 1 0:00:00.151221
187 | 18500 370 370 1 0:00:00.152012
188 | 18600 372 372 1 0:00:00.151617
189 | 18700 374 -1 0 0:00:00.151739
190 | 18800 376 376 1 0:00:00.150095
191 | 18900 378 378 1 0:00:00.150199
192 | 19000 380 380 1 0:00:00.150291
193 | 19100 382 382 1 0:00:00.150382
194 | 19200 384 -1 0 0:00:00.150969
195 | 19300 386 101 0 0:00:00.151030
196 | 19400 388 388 1 0:00:00.151301
197 | 19500 390 390 1 0:00:00.151015
198 | 19600 392 397 0 0:00:00.151571
199 | 19700 394 -1 0 0:00:00.151451
200 | 19800 396 396 1 0:00:00.151190
201 | 19900 398 398 1 0:00:00.151696
202 | 20000 400 400 1 0:00:00.151555
203 | 20100 402 402 1 0:00:00.151069
204 | 20200 404 404 1 0:00:00.152052
205 | 20300 406 406 1 0:00:00.151742
206 | 20400 408 847 0 0:00:00.151591
207 | 20500 410 410 1 0:00:00.151356
208 | 20600 412 412 1 0:00:00.151769
209 | 20700 414 414 1 0:00:00.151783
210 | 20800 416 416 1 0:00:00.151651
211 | 20900 418 563 0 0:00:00.151843
212 | 21000 420 420 1 0:00:00.151830
213 | 21100 422 422 1 0:00:00.151571
214 | 21200 424 -1 0 0:00:00.151089
215 | 21300 426 426 1 0:00:00.151694
216 | 21400 428 428 1 0:00:00.151853
217 | 21500 430 430 1 0:00:00.151803
218 | 21600 432 432 1 0:00:00.150823
219 | 21700 434 -1 0 0:00:00.150586
220 | 21800 436 436 1 0:00:00.150602
221 | 21900 438 438 1 0:00:00.150879
222 | 22000 440 737 0 0:00:00.151455
223 | 22100 442 442 1 0:00:00.151263
224 | 22200 444 870 0 0:00:00.151276
225 | 22300 446 446 1 0:00:00.151532
226 | 22400 448 448 1 0:00:00.151578
227 | 22500 450 407 0 0:00:00.151591
228 | 22600 452 452 1 0:00:00.151208
229 | 22700 454 454 1 0:00:00.152079
230 | 22800 456 -1 0 0:00:00.151654
231 | 22900 458 458 1 0:00:00.151926
232 | 23000 460 978 0 0:00:00.151490
233 | 23100 462 462 1 0:00:00.152122
234 | 23200 464 439 0 0:00:00.151895
235 | 23300 466 466 1 0:00:00.151898
236 | 23400 468 468 1 0:00:00.150604
237 | 23500 470 624 0 0:00:00.150587
238 | 23600 472 472 1 0:00:00.150816
239 | 23700 474 841 0 0:00:00.150733
240 | 23800 476 476 1 0:00:00.151181
241 | 23900 478 478 1 0:00:00.151568
242 | 24000 480 509 0 0:00:00.151719
243 | 24100 482 481 0 0:00:00.151719
244 | 24200 484 871 0 0:00:00.152006
245 | 24300 486 486 1 0:00:00.151621
246 | 24400 488 488 1 0:00:00.152247
247 | 24500 490 490 1 0:00:00.152318
248 | 24600 492 492 1 0:00:00.151755
249 | 24700 494 398 0 0:00:00.152137
250 | 24800 496 496 1 0:00:00.152608
251 | 24900 498 498 1 0:00:00.151824
252 | 25000 500 500 1 0:00:00.151747
253 | 25100 502 502 1 0:00:00.152223
254 | 25200 504 -1 0 0:00:00.152050
255 | 25300 506 506 1 0:00:00.151769
256 | 25400 508 508 1 0:00:00.152360
257 | 25500 510 510 1 0:00:00.150740
258 | 25600 512 740 0 0:00:00.150476
259 | 25700 514 514 1 0:00:00.151328
260 | 25800 516 431 0 0:00:00.150751
261 | 25900 518 518 1 0:00:00.150894
262 | 26000 520 520 1 0:00:00.151777
263 | 26100 522 522 1 0:00:00.151428
264 | 26200 524 461 0 0:00:00.151602
265 | 26300 526 -1 0 0:00:00.151999
266 | 26400 528 -1 0 0:00:00.152343
267 | 26500 530 531 0 0:00:00.151699
268 | 26600 532 532 1 0:00:00.152420
269 | 26700 534 534 1 0:00:00.152150
270 | 26800 536 403 0 0:00:00.152110
271 | 26900 538 538 1 0:00:00.151867
272 | 27000 540 540 1 0:00:00.151880
273 | 27100 542 -1 0 0:00:00.151522
274 | 27200 544 926 0 0:00:00.151950
275 | 27300 546 546 1 0:00:00.152192
276 | 27400 548 548 1 0:00:00.151759
277 | 27500 550 505 0 0:00:00.151389
278 | 27600 552 552 1 0:00:00.151038
279 | 27700 554 554 1 0:00:00.151238
280 | 27800 556 421 0 0:00:00.152079
281 | 27900 558 251 0 0:00:00.151511
282 | 28000 560 768 0 0:00:00.151679
283 | 28100 562 562 1 0:00:00.151944
284 | 28200 564 564 1 0:00:00.151967
285 | 28300 566 566 1 0:00:00.152348
286 | 28400 568 399 0 0:00:00.152331
287 | 28500 570 -1 0 0:00:00.152074
288 | 28600 572 418 0 0:00:00.152075
289 | 28700 574 574 1 0:00:00.152514
290 | 28800 576 -1 0 0:00:00.152667
291 | 28900 578 578 1 0:00:00.151901
292 | 29000 580 738 0 0:00:00.152458
293 | 29100 582 -1 0 0:00:00.151789
294 | 29200 584 754 0 0:00:00.152026
295 | 29300 586 586 1 0:00:00.152489
296 | 29400 588 588 1 0:00:00.152318
297 | 29500 590 590 1 0:00:00.151197
298 | 29600 592 592 1 0:00:00.151341
299 | 29700 594 843 0 0:00:00.151311
300 | 29800 596 56 0 0:00:00.151415
301 | 29900 598 664 0 0:00:00.151570
302 | 30000 600 517 0 0:00:00.151897
303 | 30100 602 602 1 0:00:00.152012
304 | 30200 604 604 1 0:00:00.152546
305 | 30300 606 606 1 0:00:00.152328
306 | 30400 608 608 1 0:00:00.152009
307 | 30500 610 800 0 0:00:00.152771
308 | 30600 612 612 1 0:00:00.152460
309 | 30700 614 614 1 0:00:00.152576
310 | 30800 616 534 0 0:00:00.152400
311 | 30900 618 -1 0 0:00:00.152495
312 | 31000 620 681 0 0:00:00.152324
313 | 31100 622 622 1 0:00:00.151970
314 | 31200 624 453 0 0:00:00.152316
315 | 31300 626 -1 0 0:00:00.151186
316 | 31400 628 624 0 0:00:00.151331
317 | 31500 630 703 0 0:00:00.151542
318 | 31600 632 632 1 0:00:00.151527
319 | 31700 634 489 0 0:00:00.151511
320 | 31800 636 636 1 0:00:00.151963
321 | 31900 638 638 1 0:00:00.151848
322 | 32000 640 640 1 0:00:00.152098
323 | 32100 642 642 1 0:00:00.152818
324 | 32200 644 644 1 0:00:00.152130
325 | 32300 646 646 1 0:00:00.152788
326 | 32400 648 -1 0 0:00:00.152037
327 | 32500 650 546 0 0:00:00.152443
328 | 32600 652 652 1 0:00:00.151977
329 | 32700 654 436 0 0:00:00.152806
330 | 32800 656 -1 0 0:00:00.152032
331 | 32900 658 658 1 0:00:00.152533
332 | 33000 660 660 1 0:00:00.152316
333 | 33100 662 662 1 0:00:00.151066
334 | 33200 664 851 0 0:00:00.151540
335 | 33300 666 659 0 0:00:00.151820
336 | 33400 668 -1 0 0:00:00.151236
337 | 33500 670 670 1 0:00:00.151784
338 | 33600 672 672 1 0:00:00.151690
339 | 33700 674 674 1 0:00:00.152147
340 | 33800 676 560 0 0:00:00.152476
341 | 33900 678 678 1 0:00:00.152597
342 | 34000 680 -1 0 0:00:00.152415
343 | 34100 682 682 1 0:00:00.152823
344 | 34200 684 684 1 0:00:00.153091
345 | 34300 686 -1 0 0:00:00.152630
346 | 34400 688 688 1 0:00:00.152025
347 | 34500 690 -1 0 0:00:00.152708
348 | 34600 692 -1 0 0:00:00.153099
349 | 34700 694 694 1 0:00:00.152318
350 | 34800 696 -1 0 0:00:00.152753
351 | 34900 698 698 1 0:00:00.152175
352 | 35000 700 700 1 0:00:00.152845
353 | 35100 702 422 0 0:00:00.153542
354 | 35200 704 704 1 0:00:00.151337
355 | 35300 706 536 0 0:00:00.151342
356 | 35400 708 682 0 0:00:00.151542
357 | 35500 710 710 1 0:00:00.151536
358 | 35600 712 -1 0 0:00:00.152408
359 | 35700 714 714 1 0:00:00.151919
360 | 35800 716 716 1 0:00:00.152362
361 | 35900 718 449 0 0:00:00.152079
362 | 36000 720 720 1 0:00:00.152772
363 | 36100 722 722 1 0:00:00.152386
364 | 36200 724 780 0 0:00:00.152625
365 | 36300 726 726 1 0:00:00.152840
366 | 36400 728 728 1 0:00:00.152847
367 | 36500 730 856 0 0:00:00.152109
368 | 36600 732 732 1 0:00:00.152683
369 | 36700 734 734 1 0:00:00.152587
370 | 36800 736 736 1 0:00:00.153082
371 | 36900 738 738 1 0:00:00.152080
372 | 37000 740 -1 0 0:00:00.152445
373 | 37100 742 466 0 0:00:00.153498
374 | 37200 744 657 0 0:00:00.152382
375 | 37300 746 746 1 0:00:00.152394
376 | 37400 748 748 1 0:00:00.151623
377 | 37500 750 -1 0 0:00:00.151844
378 | 37600 752 -1 0 0:00:00.151828
379 | 37700 754 848 0 0:00:00.151490
380 | 37800 756 756 1 0:00:00.151816
381 | 37900 758 758 1 0:00:00.152165
382 | 38000 760 651 0 0:00:00.152320
383 | 38100 762 762 1 0:00:00.152745
384 | 38200 764 -1 0 0:00:00.152301
385 | 38300 766 766 1 0:00:00.152556
386 | 38400 768 768 1 0:00:00.152708
387 | 38500 770 770 1 0:00:00.152466
388 | 38600 772 772 1 0:00:00.153146
389 | 38700 774 774 1 0:00:00.152727
390 | 38800 776 683 0 0:00:00.153644
391 | 38900 778 707 0 0:00:00.152719
392 | 39000 780 780 1 0:00:00.152416
393 | 39100 782 664 0 0:00:00.152261
394 | 39200 784 505 0 0:00:00.151717
395 | 39300 786 -1 0 0:00:00.153320
396 | 39400 788 743 0 0:00:00.151936
397 | 39500 790 791 0 0:00:00.151433
398 | 39600 792 795 0 0:00:00.151931
399 | 39700 794 794 1 0:00:00.152070
400 | 39800 796 796 1 0:00:00.151815
401 | 39900 798 798 1 0:00:00.152158
402 | 40000 800 800 1 0:00:00.152520
403 | 40100 802 802 1 0:00:00.152387
404 | 40200 804 804 1 0:00:00.152754
405 | 40300 806 806 1 0:00:00.152673
406 | 40400 808 808 1 0:00:00.152710
407 | 40500 810 878 0 0:00:00.152405
408 | 40600 812 812 1 0:00:00.152728
409 | 40700 814 814 1 0:00:00.152647
410 | 40800 816 816 1 0:00:00.153116
411 | 40900 818 745 0 0:00:00.153242
412 | 41000 820 820 1 0:00:00.153136
413 | 41100 822 822 1 0:00:00.152745
414 | 41200 824 824 1 0:00:00.153423
415 | 41300 826 826 1 0:00:00.151895
416 | 41400 828 618 0 0:00:00.151875
417 | 41500 830 830 1 0:00:00.152070
418 | 41600 832 832 1 0:00:00.151996
419 | 41700 834 417 0 0:00:00.151796
420 | 41800 836 444 0 0:00:00.152159
421 | 41900 838 631 0 0:00:00.152218
422 | 42000 840 840 1 0:00:00.152527
423 | 42100 842 842 1 0:00:00.152467
424 | 42200 844 844 1 0:00:00.152677
425 | 42300 846 583 0 0:00:00.152786
426 | 42400 848 848 1 0:00:00.152760
427 | 42500 850 850 1 0:00:00.152573
428 | 42600 852 852 1 0:00:00.152428
429 | 42700 854 854 1 0:00:00.152375
430 | 42800 856 856 1 0:00:00.153162
431 | 42900 858 832 0 0:00:00.153394
432 | 43000 860 892 0 0:00:00.152900
433 | 43100 862 862 1 0:00:00.152719
434 | 43200 864 864 1 0:00:00.153751
435 | 43300 866 866 1 0:00:00.152864
436 | 43400 868 968 0 0:00:00.153205
437 | 43500 870 870 1 0:00:00.153201
438 | 43600 872 840 0 0:00:00.153276
439 | 43700 874 874 1 0:00:00.151293
440 | 43800 876 648 0 0:00:00.151649
441 | 43900 878 878 1 0:00:00.151755
442 | 44000 880 870 0 0:00:00.152305
443 | 44100 882 -1 0 0:00:00.152210
444 | 44200 884 884 1 0:00:00.152426
445 | 44300 886 886 1 0:00:00.152342
446 | 44400 888 718 0 0:00:00.152582
447 | 44500 890 890 1 0:00:00.153204
448 | 44600 892 442 0 0:00:00.153057
449 | 44700 894 894 1 0:00:00.152637
450 | 44800 896 896 1 0:00:00.152978
451 | 44900 898 508 0 0:00:00.152954
452 | 45000 900 900 1 0:00:00.153266
453 | 45100 902 902 1 0:00:00.152574
454 | 45200 904 905 0 0:00:00.153147
455 | 45300 906 906 1 0:00:00.152289
456 | 45400 908 895 0 0:00:00.153968
457 | 45500 910 -1 0 0:00:00.153294
458 | 45600 912 912 1 0:00:00.153271
459 | 45700 914 914 1 0:00:00.153157
460 | 45800 916 916 1 0:00:00.153229
461 | 45900 918 918 1 0:00:00.151889
462 | 46000 920 672 0 0:00:00.151853
463 | 46100 922 922 1 0:00:00.151801
464 | 46200 924 924 1 0:00:00.151895
465 | 46300 926 926 1 0:00:00.152917
466 | 46400 928 927 0 0:00:00.152268
467 | 46500 930 -1 0 0:00:00.152658
468 | 46600 932 932 1 0:00:00.152935
469 | 46700 934 934 1 0:00:00.152805
470 | 46800 936 936 1 0:00:00.153057
471 | 46900 938 938 1 0:00:00.153356
472 | 47000 940 940 1 0:00:00.152951
473 | 47100 942 942 1 0:00:00.152692
474 | 47200 944 946 0 0:00:00.152845
475 | 47300 946 946 1 0:00:00.152305
476 | 47400 948 948 1 0:00:00.153133
477 | 47500 950 950 1 0:00:00.152508
478 | 47600 952 90 0 0:00:00.153198
479 | 47700 954 -1 0 0:00:00.153289
480 | 47800 956 956 1 0:00:00.153317
481 | 47900 958 483 0 0:00:00.153305
482 | 48000 960 960 1 0:00:00.151798
483 | 48100 962 962 1 0:00:00.153157
484 | 48200 964 961 0 0:00:00.152369
485 | 48300 966 630 0 0:00:00.151834
486 | 48400 968 968 1 0:00:00.152382
487 | 48500 970 795 0 0:00:00.152050
488 | 48600 972 979 0 0:00:00.151806
489 | 48700 974 974 1 0:00:00.152487
490 | 48800 976 -1 0 0:00:00.152602
491 | 48900 978 701 0 0:00:00.152845
492 | 49000 980 980 1 0:00:00.153392
493 | 49100 982 982 1 0:00:00.153532
494 | 49200 984 984 1 0:00:00.153388
495 | 49300 986 986 1 0:00:00.153524
496 | 49400 988 988 1 0:00:00.153043
497 | 49500 990 990 1 0:00:00.153023
498 | 49600 992 992 1 0:00:00.153144
499 | 49700 994 994 1 0:00:00.153317
500 | 49800 996 996 1 0:00:00.153331
501 | 49900 998 998 1 0:00:00.153559
502 |
--------------------------------------------------------------------------------
/data/predict/imagenet/resnet50/noise_0.25/test/N_1000:
--------------------------------------------------------------------------------
1 | idx label predict correct time
2 | 0 0 0 1 0:00:06.513843
3 | 100 2 2 1 0:00:01.455222
4 | 200 4 395 0 0:00:01.448558
5 | 300 6 6 1 0:00:01.448352
6 | 400 8 8 1 0:00:01.450362
7 | 500 10 10 1 0:00:01.449924
8 | 600 12 12 1 0:00:01.452259
9 | 700 14 14 1 0:00:01.453799
10 | 800 16 16 1 0:00:01.455621
11 | 900 18 18 1 0:00:01.454016
12 | 1000 20 20 1 0:00:01.457207
13 | 1100 22 21 0 0:00:01.454456
14 | 1200 24 24 1 0:00:01.457761
15 | 1300 26 26 1 0:00:01.459710
16 | 1400 28 28 1 0:00:01.457183
17 | 1500 30 30 1 0:00:01.458956
18 | 1600 32 103 0 0:00:01.459355
19 | 1700 34 327 0 0:00:01.462396
20 | 1800 36 35 0 0:00:01.460968
21 | 1900 38 38 1 0:00:01.462716
22 | 2000 40 31 0 0:00:01.464007
23 | 2100 42 48 0 0:00:01.463212
24 | 2200 44 44 1 0:00:01.464057
25 | 2300 46 46 1 0:00:01.463487
26 | 2400 48 48 1 0:00:01.466431
27 | 2500 50 50 1 0:00:01.464084
28 | 2600 52 52 1 0:00:01.465358
29 | 2700 54 57 0 0:00:01.465160
30 | 2800 56 56 1 0:00:01.471741
31 | 2900 58 58 1 0:00:01.469234
32 | 3000 60 62 0 0:00:01.469421
33 | 3100 62 62 1 0:00:01.467561
34 | 3200 64 64 1 0:00:01.468968
35 | 3300 66 63 0 0:00:01.469123
36 | 3400 68 -1 0 0:00:01.470740
37 | 3500 70 70 1 0:00:01.473675
38 | 3600 72 72 1 0:00:01.472703
39 | 3700 74 74 1 0:00:01.473779
40 | 3800 76 76 1 0:00:01.469706
41 | 3900 78 78 1 0:00:01.471689
42 | 4000 80 80 1 0:00:01.474229
43 | 4100 82 82 1 0:00:01.475134
44 | 4200 84 84 1 0:00:01.475369
45 | 4300 86 86 1 0:00:01.476813
46 | 4400 88 88 1 0:00:01.477425
47 | 4500 90 90 1 0:00:01.479196
48 | 4600 92 92 1 0:00:01.496022
49 | 4700 94 94 1 0:00:01.477135
50 | 4800 96 96 1 0:00:01.480331
51 | 4900 98 98 1 0:00:01.495252
52 | 5000 100 100 1 0:00:01.479198
53 | 5100 102 102 1 0:00:01.491304
54 | 5200 104 342 0 0:00:01.555276
55 | 5300 106 106 1 0:00:01.495784
56 | 5400 108 470 0 0:00:01.479194
57 | 5500 110 110 1 0:00:01.476201
58 | 5600 112 112 1 0:00:01.483646
59 | 5700 114 113 0 0:00:01.485140
60 | 5800 116 116 1 0:00:01.539842
61 | 5900 118 118 1 0:00:01.550041
62 | 6000 120 119 0 0:00:01.498474
63 | 6100 122 122 1 0:00:01.477521
64 | 6200 124 123 0 0:00:01.484329
65 | 6300 126 126 1 0:00:01.484162
66 | 6400 128 128 1 0:00:01.481536
67 | 6500 130 130 1 0:00:01.483169
68 | 6600 132 132 1 0:00:01.503742
69 | 6700 134 132 0 0:00:01.518725
70 | 6800 136 136 1 0:00:01.539357
71 | 6900 138 138 1 0:00:01.498044
72 | 7000 140 140 1 0:00:01.480387
73 | 7100 142 142 1 0:00:01.483995
74 | 7200 144 144 1 0:00:01.486113
75 | 7300 146 146 1 0:00:01.488695
76 | 7400 148 148 1 0:00:01.489084
77 | 7500 150 360 0 0:00:01.483568
78 | 7600 152 152 1 0:00:01.479528
79 | 7700 154 154 1 0:00:01.478619
80 | 7800 156 156 1 0:00:01.537528
81 | 7900 158 158 1 0:00:01.496571
82 | 8000 160 160 1 0:00:01.498672
83 | 8100 162 162 1 0:00:01.481299
84 | 8200 164 164 1 0:00:01.496998
85 | 8300 166 166 1 0:00:01.490325
86 | 8400 168 211 0 0:00:01.503927
87 | 8500 170 172 0 0:00:01.536041
88 | 8600 172 172 1 0:00:01.500019
89 | 8700 174 174 1 0:00:01.520921
90 | 8800 176 176 1 0:00:01.487111
91 | 8900 178 178 1 0:00:01.488966
92 | 9000 180 243 0 0:00:01.550351
93 | 9100 182 182 1 0:00:01.526359
94 | 9200 184 -1 0 0:00:01.557193
95 | 9300 186 192 0 0:00:01.499761
96 | 9400 188 188 1 0:00:01.524055
97 | 9500 190 190 1 0:00:01.483822
98 | 9600 192 199 0 0:00:01.479605
99 | 9700 194 202 0 0:00:01.554249
100 | 9800 196 199 0 0:00:01.496503
101 | 9900 198 196 0 0:00:01.492259
102 | 10000 200 200 1 0:00:01.479200
103 | 10100 202 202 1 0:00:01.482847
104 | 10200 204 204 1 0:00:01.479599
105 | 10300 206 206 1 0:00:01.520266
106 | 10400 208 208 1 0:00:01.481036
107 | 10500 210 210 1 0:00:01.489356
108 | 10600 212 212 1 0:00:01.555681
109 | 10700 214 214 1 0:00:01.491593
110 | 10800 216 216 1 0:00:01.484073
111 | 10900 218 218 1 0:00:01.485102
112 | 11000 220 220 1 0:00:01.479046
113 | 11100 222 222 1 0:00:01.498316
114 | 11200 224 233 0 0:00:01.481419
115 | 11300 226 226 1 0:00:01.489648
116 | 11400 228 228 1 0:00:01.475325
117 | 11500 230 230 1 0:00:01.479492
118 | 11600 232 232 1 0:00:01.485672
119 | 11700 234 234 1 0:00:01.478974
120 | 11800 236 234 0 0:00:01.532255
121 | 11900 238 238 1 0:00:01.501435
122 | 12000 240 238 0 0:00:01.478093
123 | 12100 242 242 1 0:00:01.487937
124 | 12200 244 244 1 0:00:01.555950
125 | 12300 246 251 0 0:00:01.497940
126 | 12400 248 248 1 0:00:01.484743
127 | 12500 250 169 0 0:00:01.480882
128 | 12600 252 252 1 0:00:01.485409
129 | 12700 254 254 1 0:00:01.482774
130 | 12800 256 256 1 0:00:01.554899
131 | 12900 258 258 1 0:00:01.504591
132 | 13000 260 260 1 0:00:01.565463
133 | 13100 262 262 1 0:00:01.493973
134 | 13200 264 264 1 0:00:01.525658
135 | 13300 266 266 1 0:00:01.486248
136 | 13400 268 268 1 0:00:01.481760
137 | 13500 270 270 1 0:00:01.481478
138 | 13600 272 274 0 0:00:01.479039
139 | 13700 274 274 1 0:00:01.475843
140 | 13800 276 276 1 0:00:01.475403
141 | 13900 278 278 1 0:00:01.478996
142 | 14000 280 280 1 0:00:01.481032
143 | 14100 282 282 1 0:00:01.483646
144 | 14200 284 284 1 0:00:01.479185
145 | 14300 286 286 1 0:00:01.492340
146 | 14400 288 288 1 0:00:01.487615
147 | 14500 290 290 1 0:00:01.489121
148 | 14600 292 290 0 0:00:01.480814
149 | 14700 294 294 1 0:00:01.490083
150 | 14800 296 296 1 0:00:01.488762
151 | 14900 298 298 1 0:00:01.481069
152 | 15000 300 300 1 0:00:01.520467
153 | 15100 302 302 1 0:00:01.482847
154 | 15200 304 301 0 0:00:01.492529
155 | 15300 306 306 1 0:00:01.558305
156 | 15400 308 308 1 0:00:01.488705
157 | 15500 310 310 1 0:00:01.480214
158 | 15600 312 315 0 0:00:01.496566
159 | 15700 314 314 1 0:00:01.540013
160 | 15800 316 316 1 0:00:01.516623
161 | 15900 318 318 1 0:00:01.477651
162 | 16000 320 320 1 0:00:01.477822
163 | 16100 322 322 1 0:00:01.483212
164 | 16200 324 324 1 0:00:01.489052
165 | 16300 326 326 1 0:00:01.480788
166 | 16400 328 328 1 0:00:01.482190
167 | 16500 330 330 1 0:00:01.491863
168 | 16600 332 153 0 0:00:01.478761
169 | 16700 334 334 1 0:00:01.515207
170 | 16800 336 336 1 0:00:01.554691
171 | 16900 338 617 0 0:00:01.541147
172 | 17000 340 340 1 0:00:01.489892
173 | 17100 342 287 0 0:00:01.481903
174 | 17200 344 344 1 0:00:01.502365
175 | 17300 346 344 0 0:00:01.534523
176 | 17400 348 348 1 0:00:01.505174
177 | 17500 350 350 1 0:00:01.502758
178 | 17600 352 352 1 0:00:01.481927
179 | 17700 354 354 1 0:00:01.484131
180 | 17800 356 -1 0 0:00:01.481266
181 | 17900 358 359 0 0:00:01.480267
182 | 18000 360 360 1 0:00:01.478824
183 | 18100 362 362 1 0:00:01.518305
184 | 18200 364 364 1 0:00:01.560162
185 | 18300 366 366 1 0:00:01.496499
186 | 18400 368 368 1 0:00:01.482130
187 | 18500 370 370 1 0:00:01.477426
188 | 18600 372 372 1 0:00:01.481619
189 | 18700 374 374 1 0:00:01.488400
190 | 18800 376 376 1 0:00:01.496681
191 | 18900 378 378 1 0:00:01.485233
192 | 19000 380 380 1 0:00:01.531544
193 | 19100 382 382 1 0:00:01.499931
194 | 19200 384 283 0 0:00:01.480609
195 | 19300 386 101 0 0:00:01.477893
196 | 19400 388 388 1 0:00:01.481483
197 | 19500 390 390 1 0:00:01.485835
198 | 19600 392 397 0 0:00:01.494839
199 | 19700 394 467 0 0:00:01.479918
200 | 19800 396 396 1 0:00:01.486316
201 | 19900 398 398 1 0:00:01.479316
202 | 20000 400 400 1 0:00:01.482861
203 | 20100 402 402 1 0:00:01.481416
204 | 20200 404 404 1 0:00:01.500082
205 | 20300 406 406 1 0:00:01.481807
206 | 20400 408 847 0 0:00:01.477773
207 | 20500 410 410 1 0:00:01.550446
208 | 20600 412 412 1 0:00:01.495554
209 | 20700 414 414 1 0:00:01.490323
210 | 20800 416 416 1 0:00:01.479424
211 | 20900 418 563 0 0:00:01.479409
212 | 21000 420 420 1 0:00:01.480183
213 | 21100 422 422 1 0:00:01.522102
214 | 21200 424 454 0 0:00:01.480354
215 | 21300 426 426 1 0:00:01.484423
216 | 21400 428 428 1 0:00:01.481469
217 | 21500 430 430 1 0:00:01.498621
218 | 21600 432 432 1 0:00:01.483243
219 | 21700 434 434 1 0:00:01.485271
220 | 21800 436 436 1 0:00:01.498695
221 | 21900 438 438 1 0:00:01.554712
222 | 22000 440 737 0 0:00:01.493426
223 | 22100 442 442 1 0:00:01.476355
224 | 22200 444 870 0 0:00:01.484731
225 | 22300 446 446 1 0:00:01.481968
226 | 22400 448 448 1 0:00:01.480012
227 | 22500 450 407 0 0:00:01.479037
228 | 22600 452 452 1 0:00:01.491563
229 | 22700 454 454 1 0:00:01.554456
230 | 22800 456 645 0 0:00:01.526294
231 | 22900 458 458 1 0:00:01.496040
232 | 23000 460 978 0 0:00:01.482221
233 | 23100 462 462 1 0:00:01.502633
234 | 23200 464 439 0 0:00:01.496977
235 | 23300 466 466 1 0:00:01.474824
236 | 23400 468 468 1 0:00:01.479537
237 | 23500 470 624 0 0:00:01.482006
238 | 23600 472 472 1 0:00:01.491713
239 | 23700 474 841 0 0:00:01.481372
240 | 23800 476 476 1 0:00:01.483987
241 | 23900 478 478 1 0:00:01.484918
242 | 24000 480 509 0 0:00:01.480896
243 | 24100 482 481 0 0:00:01.481517
244 | 24200 484 871 0 0:00:01.480409
245 | 24300 486 486 1 0:00:01.551933
246 | 24400 488 488 1 0:00:01.501731
247 | 24500 490 490 1 0:00:01.552063
248 | 24600 492 492 1 0:00:01.505758
249 | 24700 494 398 0 0:00:01.498369
250 | 24800 496 496 1 0:00:01.481360
251 | 24900 498 498 1 0:00:01.477659
252 | 25000 500 500 1 0:00:01.479922
253 | 25100 502 502 1 0:00:01.477720
254 | 25200 504 504 1 0:00:01.481018
255 | 25300 506 506 1 0:00:01.480943
256 | 25400 508 508 1 0:00:01.478332
257 | 25500 510 510 1 0:00:01.496781
258 | 25600 512 740 0 0:00:01.484283
259 | 25700 514 514 1 0:00:01.492787
260 | 25800 516 431 0 0:00:01.491374
261 | 25900 518 518 1 0:00:01.478206
262 | 26000 520 520 1 0:00:01.491399
263 | 26100 522 522 1 0:00:01.481239
264 | 26200 524 461 0 0:00:01.516815
265 | 26300 526 526 1 0:00:01.478231
266 | 26400 528 -1 0 0:00:01.489215
267 | 26500 530 531 0 0:00:01.529016
268 | 26600 532 532 1 0:00:01.488698
269 | 26700 534 534 1 0:00:01.483696
270 | 26800 536 403 0 0:00:01.500773
271 | 26900 538 538 1 0:00:01.479615
272 | 27000 540 540 1 0:00:01.483645
273 | 27100 542 -1 0 0:00:01.483379
274 | 27200 544 926 0 0:00:01.480354
275 | 27300 546 546 1 0:00:01.483093
276 | 27400 548 548 1 0:00:01.480865
277 | 27500 550 505 0 0:00:01.479713
278 | 27600 552 552 1 0:00:01.484718
279 | 27700 554 554 1 0:00:01.486498
280 | 27800 556 421 0 0:00:01.481885
281 | 27900 558 251 0 0:00:01.481816
282 | 28000 560 768 0 0:00:01.479113
283 | 28100 562 562 1 0:00:01.487936
284 | 28200 564 564 1 0:00:01.482777
285 | 28300 566 566 1 0:00:01.512465
286 | 28400 568 399 0 0:00:01.488383
287 | 28500 570 -1 0 0:00:01.482008
288 | 28600 572 418 0 0:00:01.514439
289 | 28700 574 574 1 0:00:01.477766
290 | 28800 576 -1 0 0:00:01.496104
291 | 28900 578 578 1 0:00:01.551244
292 | 29000 580 738 0 0:00:01.521960
293 | 29100 582 -1 0 0:00:01.507269
294 | 29200 584 754 0 0:00:01.478857
295 | 29300 586 586 1 0:00:01.477593
296 | 29400 588 588 1 0:00:01.493497
297 | 29500 590 590 1 0:00:01.477899
298 | 29600 592 592 1 0:00:01.479885
299 | 29700 594 843 0 0:00:01.481439
300 | 29800 596 56 0 0:00:01.541896
301 | 29900 598 664 0 0:00:01.526365
302 | 30000 600 517 0 0:00:01.531583
303 | 30100 602 602 1 0:00:01.484584
304 | 30200 604 604 1 0:00:01.480353
305 | 30300 606 606 1 0:00:01.523729
306 | 30400 608 608 1 0:00:01.481934
307 | 30500 610 800 0 0:00:01.478992
308 | 30600 612 612 1 0:00:01.548418
309 | 30700 614 614 1 0:00:01.494345
310 | 30800 616 534 0 0:00:01.500657
311 | 30900 618 828 0 0:00:01.487541
312 | 31000 620 681 0 0:00:01.490442
313 | 31100 622 622 1 0:00:01.480001
314 | 31200 624 453 0 0:00:01.520598
315 | 31300 626 542 0 0:00:01.492061
316 | 31400 628 624 0 0:00:01.485516
317 | 31500 630 703 0 0:00:01.493293
318 | 31600 632 632 1 0:00:01.499207
319 | 31700 634 489 0 0:00:01.482136
320 | 31800 636 636 1 0:00:01.522583
321 | 31900 638 638 1 0:00:01.480404
322 | 32000 640 640 1 0:00:01.478007
323 | 32100 642 642 1 0:00:01.490945
324 | 32200 644 644 1 0:00:01.483194
325 | 32300 646 646 1 0:00:01.477241
326 | 32400 648 -1 0 0:00:01.480447
327 | 32500 650 546 0 0:00:01.557726
328 | 32600 652 652 1 0:00:01.505422
329 | 32700 654 436 0 0:00:01.477040
330 | 32800 656 -1 0 0:00:01.509807
331 | 32900 658 658 1 0:00:01.490112
332 | 33000 660 660 1 0:00:01.517296
333 | 33100 662 662 1 0:00:01.485711
334 | 33200 664 851 0 0:00:01.490185
335 | 33300 666 659 0 0:00:01.522107
336 | 33400 668 668 1 0:00:01.478890
337 | 33500 670 670 1 0:00:01.487738
338 | 33600 672 672 1 0:00:01.479811
339 | 33700 674 674 1 0:00:01.478646
340 | 33800 676 560 0 0:00:01.485127
341 | 33900 678 678 1 0:00:01.490292
342 | 34000 680 793 0 0:00:01.484289
343 | 34100 682 682 1 0:00:01.478803
344 | 34200 684 684 1 0:00:01.478796
345 | 34300 686 -1 0 0:00:01.479543
346 | 34400 688 688 1 0:00:01.537382
347 | 34500 690 690 1 0:00:01.555078
348 | 34600 692 481 0 0:00:01.524606
349 | 34700 694 694 1 0:00:01.484595
350 | 34800 696 -1 0 0:00:01.480659
351 | 34900 698 698 1 0:00:01.495338
352 | 35000 700 700 1 0:00:01.491358
353 | 35100 702 422 0 0:00:01.479972
354 | 35200 704 704 1 0:00:01.559804
355 | 35300 706 536 0 0:00:01.515814
356 | 35400 708 682 0 0:00:01.494998
357 | 35500 710 710 1 0:00:01.483101
358 | 35600 712 622 0 0:00:01.489805
359 | 35700 714 714 1 0:00:01.485512
360 | 35800 716 716 1 0:00:01.480383
361 | 35900 718 449 0 0:00:01.481020
362 | 36000 720 720 1 0:00:01.480507
363 | 36100 722 722 1 0:00:01.484626
364 | 36200 724 780 0 0:00:01.492352
365 | 36300 726 726 1 0:00:01.494925
366 | 36400 728 728 1 0:00:01.488405
367 | 36500 730 856 0 0:00:01.491160
368 | 36600 732 732 1 0:00:01.488379
369 | 36700 734 734 1 0:00:01.480097
370 | 36800 736 736 1 0:00:01.483129
371 | 36900 738 738 1 0:00:01.524685
372 | 37000 740 -1 0 0:00:01.483893
373 | 37100 742 466 0 0:00:01.481901
374 | 37200 744 657 0 0:00:01.477067
375 | 37300 746 746 1 0:00:01.481244
376 | 37400 748 748 1 0:00:01.481573
377 | 37500 750 172 0 0:00:01.551862
378 | 37600 752 830 0 0:00:01.493996
379 | 37700 754 848 0 0:00:01.478429
380 | 37800 756 756 1 0:00:01.479473
381 | 37900 758 758 1 0:00:01.480465
382 | 38000 760 651 0 0:00:01.498277
383 | 38100 762 762 1 0:00:01.558768
384 | 38200 764 593 0 0:00:01.497985
385 | 38300 766 766 1 0:00:01.568286
386 | 38400 768 768 1 0:00:01.531938
387 | 38500 770 770 1 0:00:01.481052
388 | 38600 772 772 1 0:00:01.479920
389 | 38700 774 774 1 0:00:01.492791
390 | 38800 776 683 0 0:00:01.536620
391 | 38900 778 707 0 0:00:01.480408
392 | 39000 780 780 1 0:00:01.478122
393 | 39100 782 664 0 0:00:01.551872
394 | 39200 784 505 0 0:00:01.496936
395 | 39300 786 513 0 0:00:01.480556
396 | 39400 788 743 0 0:00:01.480324
397 | 39500 790 791 0 0:00:01.475573
398 | 39600 792 795 0 0:00:01.483009
399 | 39700 794 794 1 0:00:01.480900
400 | 39800 796 796 1 0:00:01.482380
401 | 39900 798 798 1 0:00:01.480042
402 | 40000 800 800 1 0:00:01.481433
403 | 40100 802 802 1 0:00:01.481994
404 | 40200 804 804 1 0:00:01.482339
405 | 40300 806 806 1 0:00:01.484827
406 | 40400 808 808 1 0:00:01.480378
407 | 40500 810 878 0 0:00:01.550252
408 | 40600 812 812 1 0:00:01.497617
409 | 40700 814 814 1 0:00:01.543693
410 | 40800 816 816 1 0:00:01.528976
411 | 40900 818 745 0 0:00:01.549425
412 | 41000 820 820 1 0:00:01.488208
413 | 41100 822 822 1 0:00:01.479282
414 | 41200 824 824 1 0:00:01.494460
415 | 41300 826 826 1 0:00:01.482595
416 | 41400 828 618 0 0:00:01.493755
417 | 41500 830 830 1 0:00:01.560996
418 | 41600 832 832 1 0:00:01.494581
419 | 41700 834 417 0 0:00:01.516018
420 | 41800 836 444 0 0:00:01.479565
421 | 41900 838 631 0 0:00:01.486431
422 | 42000 840 840 1 0:00:01.481183
423 | 42100 842 842 1 0:00:01.477463
424 | 42200 844 844 1 0:00:01.486127
425 | 42300 846 583 0 0:00:01.483045
426 | 42400 848 848 1 0:00:01.525781
427 | 42500 850 850 1 0:00:01.483320
428 | 42600 852 852 1 0:00:01.478466
429 | 42700 854 854 1 0:00:01.513655
430 | 42800 856 856 1 0:00:01.486674
431 | 42900 858 832 0 0:00:01.482063
432 | 43000 860 892 0 0:00:01.480499
433 | 43100 862 862 1 0:00:01.480461
434 | 43200 864 864 1 0:00:01.480838
435 | 43300 866 866 1 0:00:01.479955
436 | 43400 868 968 0 0:00:01.479614
437 | 43500 870 870 1 0:00:01.492825
438 | 43600 872 840 0 0:00:01.494615
439 | 43700 874 874 1 0:00:01.480262
440 | 43800 876 648 0 0:00:01.530230
441 | 43900 878 878 1 0:00:01.477895
442 | 44000 880 870 0 0:00:01.481497
443 | 44100 882 -1 0 0:00:01.523555
444 | 44200 884 884 1 0:00:01.546123
445 | 44300 886 886 1 0:00:01.534611
446 | 44400 888 718 0 0:00:01.528654
447 | 44500 890 890 1 0:00:01.478625
448 | 44600 892 442 0 0:00:01.489870
449 | 44700 894 894 1 0:00:01.485551
450 | 44800 896 896 1 0:00:01.486424
451 | 44900 898 508 0 0:00:01.481154
452 | 45000 900 900 1 0:00:01.479834
453 | 45100 902 902 1 0:00:01.494956
454 | 45200 904 905 0 0:00:01.534988
455 | 45300 906 906 1 0:00:01.480348
456 | 45400 908 895 0 0:00:01.480904
457 | 45500 910 -1 0 0:00:01.549871
458 | 45600 912 912 1 0:00:01.518482
459 | 45700 914 914 1 0:00:01.496823
460 | 45800 916 916 1 0:00:01.484683
461 | 45900 918 918 1 0:00:01.482222
462 | 46000 920 672 0 0:00:01.481179
463 | 46100 922 922 1 0:00:01.480635
464 | 46200 924 924 1 0:00:01.480506
465 | 46300 926 926 1 0:00:01.481995
466 | 46400 928 927 0 0:00:01.482548
467 | 46500 930 964 0 0:00:01.478524
468 | 46600 932 932 1 0:00:01.481828
469 | 46700 934 934 1 0:00:01.492927
470 | 46800 936 936 1 0:00:01.485990
471 | 46900 938 938 1 0:00:01.477883
472 | 47000 940 940 1 0:00:01.495854
473 | 47100 942 942 1 0:00:01.483602
474 | 47200 944 946 0 0:00:01.485325
475 | 47300 946 946 1 0:00:01.479533
476 | 47400 948 948 1 0:00:01.482796
477 | 47500 950 950 1 0:00:01.481151
478 | 47600 952 90 0 0:00:01.479927
479 | 47700 954 -1 0 0:00:01.484572
480 | 47800 956 956 1 0:00:01.490380
481 | 47900 958 483 0 0:00:01.493005
482 | 48000 960 960 1 0:00:01.479488
483 | 48100 962 962 1 0:00:01.508343
484 | 48200 964 961 0 0:00:01.556295
485 | 48300 966 630 0 0:00:01.503184
486 | 48400 968 968 1 0:00:01.483091
487 | 48500 970 795 0 0:00:01.480517
488 | 48600 972 979 0 0:00:01.479620
489 | 48700 974 974 1 0:00:01.479991
490 | 48800 976 724 0 0:00:01.489057
491 | 48900 978 701 0 0:00:01.489683
492 | 49000 980 980 1 0:00:01.482108
493 | 49100 982 982 1 0:00:01.483436
494 | 49200 984 984 1 0:00:01.503881
495 | 49300 986 986 1 0:00:01.484524
496 | 49400 988 988 1 0:00:01.484175
497 | 49500 990 990 1 0:00:01.543809
498 | 49600 992 992 1 0:00:01.552771
499 | 49700 994 994 1 0:00:01.492512
500 | 49800 996 996 1 0:00:01.485028
501 | 49900 998 998 1 0:00:01.477876
502 |
--------------------------------------------------------------------------------
/data/predict/imagenet/resnet50/noise_0.25/test/N_10000:
--------------------------------------------------------------------------------
1 | idx label predict correct time
2 | 0 0 0 1 0:00:17.938670
3 | 100 2 2 1 0:00:14.529888
4 | 200 4 395 0 0:00:14.602542
5 | 300 6 6 1 0:00:14.658964
6 | 400 8 8 1 0:00:14.701330
7 | 500 10 10 1 0:00:14.710515
8 | 600 12 12 1 0:00:14.762375
9 | 700 14 14 1 0:00:14.895730
10 | 800 16 16 1 0:00:15.036877
11 | 900 18 18 1 0:00:14.915746
12 | 1000 20 20 1 0:00:14.775350
13 | 1100 22 21 0 0:00:14.954231
14 | 1200 24 24 1 0:00:14.953908
15 | 1300 26 26 1 0:00:14.813852
16 | 1400 28 28 1 0:00:14.843823
17 | 1500 30 30 1 0:00:14.890527
18 | 1600 32 103 0 0:00:14.992512
19 | 1700 34 327 0 0:00:14.821474
20 | 1800 36 35 0 0:00:14.935570
21 | 1900 38 38 1 0:00:14.801161
22 | 2000 40 31 0 0:00:14.919142
23 | 2100 42 48 0 0:00:14.847678
24 | 2200 44 44 1 0:00:15.038948
25 | 2300 46 46 1 0:00:14.826487
26 | 2400 48 48 1 0:00:14.995221
27 | 2500 50 50 1 0:00:15.031976
28 | 2600 52 52 1 0:00:14.921636
29 | 2700 54 57 0 0:00:14.942702
30 | 2800 56 56 1 0:00:14.990683
31 | 2900 58 58 1 0:00:14.999382
32 | 3000 60 62 0 0:00:14.885931
33 | 3100 62 62 1 0:00:14.838300
34 | 3200 64 64 1 0:00:14.967360
35 | 3300 66 63 0 0:00:14.891294
36 | 3400 68 -1 0 0:00:14.835415
37 | 3500 70 70 1 0:00:14.876681
38 | 3600 72 72 1 0:00:15.101473
39 | 3700 74 74 1 0:00:14.888600
40 | 3800 76 76 1 0:00:14.864935
41 | 3900 78 78 1 0:00:14.870341
42 | 4000 80 80 1 0:00:14.896723
43 | 4100 82 82 1 0:00:14.920049
44 | 4200 84 84 1 0:00:14.901886
45 | 4300 86 86 1 0:00:14.940372
46 | 4400 88 88 1 0:00:14.945328
47 | 4500 90 90 1 0:00:14.846179
48 | 4600 92 92 1 0:00:15.061382
49 | 4700 94 94 1 0:00:14.843592
50 | 4800 96 96 1 0:00:14.833731
51 | 4900 98 98 1 0:00:14.789376
52 | 5000 100 100 1 0:00:14.819640
53 | 5100 102 102 1 0:00:14.878288
54 | 5200 104 342 0 0:00:14.806604
55 | 5300 106 106 1 0:00:14.844456
56 | 5400 108 470 0 0:00:15.005162
57 | 5500 110 110 1 0:00:14.846514
58 | 5600 112 112 1 0:00:14.857095
59 | 5700 114 113 0 0:00:14.923831
60 | 5800 116 116 1 0:00:14.970319
61 | 5900 118 118 1 0:00:14.890198
62 | 6000 120 119 0 0:00:14.816716
63 | 6100 122 122 1 0:00:15.101482
64 | 6200 124 123 0 0:00:14.836912
65 | 6300 126 126 1 0:00:15.015560
66 | 6400 128 128 1 0:00:14.807324
67 | 6500 130 130 1 0:00:14.792057
68 | 6600 132 132 1 0:00:15.131752
69 | 6700 134 132 0 0:00:14.876376
70 | 6800 136 136 1 0:00:14.973896
71 | 6900 138 138 1 0:00:14.950263
72 | 7000 140 140 1 0:00:14.883024
73 | 7100 142 142 1 0:00:14.867703
74 | 7200 144 144 1 0:00:14.824093
75 | 7300 146 146 1 0:00:14.969388
76 | 7400 148 148 1 0:00:15.022812
77 | 7500 150 360 0 0:00:14.857029
78 | 7600 152 152 1 0:00:15.055737
79 | 7700 154 154 1 0:00:14.930024
80 | 7800 156 156 1 0:00:14.998052
81 | 7900 158 158 1 0:00:14.854780
82 | 8000 160 160 1 0:00:15.022626
83 | 8100 162 162 1 0:00:14.812995
84 | 8200 164 164 1 0:00:14.955143
85 | 8300 166 166 1 0:00:15.046064
86 | 8400 168 211 0 0:00:14.830755
87 | 8500 170 172 0 0:00:14.842812
88 | 8600 172 172 1 0:00:14.830592
89 | 8700 174 174 1 0:00:15.064394
90 | 8800 176 176 1 0:00:14.872763
91 | 8900 178 178 1 0:00:14.958781
92 | 9000 180 243 0 0:00:14.998049
93 | 9100 182 182 1 0:00:14.956381
94 | 9200 184 188 0 0:00:14.997293
95 | 9300 186 192 0 0:00:14.842927
96 | 9400 188 188 1 0:00:14.969200
97 | 9500 190 190 1 0:00:14.819773
98 | 9600 192 199 0 0:00:14.841370
99 | 9700 194 202 0 0:00:14.891114
100 | 9800 196 199 0 0:00:15.028255
101 | 9900 198 196 0 0:00:14.930591
102 | 10000 200 200 1 0:00:15.043234
103 | 10100 202 202 1 0:00:14.797675
104 | 10200 204 204 1 0:00:15.117067
105 | 10300 206 206 1 0:00:14.826504
106 | 10400 208 208 1 0:00:14.879246
107 | 10500 210 210 1 0:00:14.878064
108 | 10600 212 212 1 0:00:14.987449
109 | 10700 214 214 1 0:00:14.888630
110 | 10800 216 216 1 0:00:14.790781
111 | 10900 218 218 1 0:00:14.845216
112 | 11000 220 220 1 0:00:15.162716
113 | 11100 222 222 1 0:00:14.845916
114 | 11200 224 233 0 0:00:14.967866
115 | 11300 226 226 1 0:00:15.006416
116 | 11400 228 228 1 0:00:14.967813
117 | 11500 230 230 1 0:00:14.985305
118 | 11600 232 232 1 0:00:14.927306
119 | 11700 234 234 1 0:00:14.946466
120 | 11800 236 234 0 0:00:15.021441
121 | 11900 238 238 1 0:00:14.916477
122 | 12000 240 238 0 0:00:14.907923
123 | 12100 242 242 1 0:00:14.984931
124 | 12200 244 244 1 0:00:14.801400
125 | 12300 246 251 0 0:00:14.846193
126 | 12400 248 248 1 0:00:15.127333
127 | 12500 250 169 0 0:00:14.818779
128 | 12600 252 252 1 0:00:14.976161
129 | 12700 254 254 1 0:00:14.848175
130 | 12800 256 256 1 0:00:14.817288
131 | 12900 258 258 1 0:00:14.941558
132 | 13000 260 260 1 0:00:14.882449
133 | 13100 262 262 1 0:00:14.845883
134 | 13200 264 264 1 0:00:15.053814
135 | 13300 266 266 1 0:00:14.963253
136 | 13400 268 268 1 0:00:14.906438
137 | 13500 270 270 1 0:00:14.821655
138 | 13600 272 274 0 0:00:14.912062
139 | 13700 274 274 1 0:00:15.016080
140 | 13800 276 276 1 0:00:14.981570
141 | 13900 278 278 1 0:00:14.983741
142 | 14000 280 280 1 0:00:14.931658
143 | 14100 282 282 1 0:00:14.885947
144 | 14200 284 284 1 0:00:14.956033
145 | 14300 286 286 1 0:00:14.804141
146 | 14400 288 288 1 0:00:14.923306
147 | 14500 290 290 1 0:00:14.974497
148 | 14600 292 290 0 0:00:15.046874
149 | 14700 294 294 1 0:00:14.817052
150 | 14800 296 296 1 0:00:15.010772
151 | 14900 298 298 1 0:00:14.779613
152 | 15000 300 300 1 0:00:15.087663
153 | 15100 302 302 1 0:00:14.960548
154 | 15200 304 301 0 0:00:14.948146
155 | 15300 306 306 1 0:00:14.997912
156 | 15400 308 308 1 0:00:14.958719
157 | 15500 310 310 1 0:00:14.940277
158 | 15600 312 315 0 0:00:14.941935
159 | 15700 314 314 1 0:00:14.972534
160 | 15800 316 316 1 0:00:15.045166
161 | 15900 318 318 1 0:00:14.816161
162 | 16000 320 320 1 0:00:14.905789
163 | 16100 322 322 1 0:00:14.972884
164 | 16200 324 324 1 0:00:14.849005
165 | 16300 326 326 1 0:00:15.113016
166 | 16400 328 328 1 0:00:14.819477
167 | 16500 330 330 1 0:00:14.900786
168 | 16600 332 153 0 0:00:14.899500
169 | 16700 334 334 1 0:00:15.079960
170 | 16800 336 336 1 0:00:14.817020
171 | 16900 338 617 0 0:00:14.896904
172 | 17000 340 340 1 0:00:15.101224
173 | 17100 342 287 0 0:00:14.789878
174 | 17200 344 344 1 0:00:15.015266
175 | 17300 346 344 0 0:00:14.911635
176 | 17400 348 348 1 0:00:15.007726
177 | 17500 350 350 1 0:00:15.036709
178 | 17600 352 352 1 0:00:15.037419
179 | 17700 354 354 1 0:00:14.989679
180 | 17800 356 -1 0 0:00:14.995699
181 | 17900 358 359 0 0:00:14.847023
182 | 18000 360 360 1 0:00:14.942973
183 | 18100 362 362 1 0:00:14.966301
184 | 18200 364 364 1 0:00:14.917976
185 | 18300 366 366 1 0:00:14.879833
186 | 18400 368 368 1 0:00:15.137949
187 | 18500 370 370 1 0:00:14.916142
188 | 18600 372 372 1 0:00:15.003802
189 | 18700 374 374 1 0:00:14.929094
190 | 18800 376 376 1 0:00:14.956331
191 | 18900 378 378 1 0:00:14.873452
192 | 19000 380 380 1 0:00:14.860228
193 | 19100 382 382 1 0:00:14.906570
194 | 19200 384 283 0 0:00:14.863993
195 | 19300 386 101 0 0:00:15.087887
196 | 19400 388 388 1 0:00:14.998850
197 | 19500 390 390 1 0:00:14.789777
198 | 19600 392 397 0 0:00:15.002732
199 | 19700 394 467 0 0:00:14.836903
200 | 19800 396 396 1 0:00:14.804411
201 | 19900 398 398 1 0:00:15.011312
202 | 20000 400 400 1 0:00:15.001071
203 | 20100 402 402 1 0:00:14.913904
204 | 20200 404 404 1 0:00:14.936100
205 | 20300 406 406 1 0:00:14.876220
206 | 20400 408 847 0 0:00:14.978644
207 | 20500 410 410 1 0:00:14.988457
208 | 20600 412 412 1 0:00:14.980928
209 | 20700 414 414 1 0:00:14.982497
210 | 20800 416 416 1 0:00:14.813848
211 | 20900 418 563 0 0:00:15.099722
212 | 21000 420 420 1 0:00:14.814135
213 | 21100 422 422 1 0:00:14.884715
214 | 21200 424 454 0 0:00:14.864878
215 | 21300 426 426 1 0:00:15.047416
216 | 21400 428 428 1 0:00:14.912828
217 | 21500 430 430 1 0:00:14.957142
218 | 21600 432 432 1 0:00:14.899586
219 | 21700 434 434 1 0:00:14.843033
220 | 21800 436 436 1 0:00:14.993052
221 | 21900 438 438 1 0:00:15.054873
222 | 22000 440 737 0 0:00:14.859920
223 | 22100 442 442 1 0:00:14.820955
224 | 22200 444 870 0 0:00:15.013117
225 | 22300 446 446 1 0:00:15.022072
226 | 22400 448 448 1 0:00:14.867987
227 | 22500 450 407 0 0:00:14.980548
228 | 22600 452 452 1 0:00:14.831118
229 | 22700 454 454 1 0:00:15.062666
230 | 22800 456 645 0 0:00:15.022581
231 | 22900 458 458 1 0:00:15.048858
232 | 23000 460 978 0 0:00:14.951904
233 | 23100 462 462 1 0:00:15.014614
234 | 23200 464 439 0 0:00:15.039416
235 | 23300 466 466 1 0:00:14.933714
236 | 23400 468 468 1 0:00:15.138598
237 | 23500 470 624 0 0:00:15.030802
238 | 23600 472 472 1 0:00:15.028487
239 | 23700 474 841 0 0:00:14.842982
240 | 23800 476 476 1 0:00:15.076778
241 | 23900 478 478 1 0:00:14.985259
242 | 24000 480 509 0 0:00:14.983824
243 | 24100 482 481 0 0:00:14.831785
244 | 24200 484 871 0 0:00:15.156861
245 | 24300 486 486 1 0:00:15.030223
246 | 24400 488 488 1 0:00:14.944442
247 | 24500 490 490 1 0:00:15.050789
248 | 24600 492 492 1 0:00:15.071643
249 | 24700 494 398 0 0:00:14.878071
250 | 24800 496 496 1 0:00:15.084099
251 | 24900 498 498 1 0:00:14.986719
252 | 25000 500 500 1 0:00:15.008059
253 | 25100 502 502 1 0:00:14.938144
254 | 25200 504 504 1 0:00:14.811468
255 | 25300 506 506 1 0:00:15.105104
256 | 25400 508 508 1 0:00:14.966163
257 | 25500 510 510 1 0:00:15.080637
258 | 25600 512 740 0 0:00:14.948434
259 | 25700 514 514 1 0:00:14.932904
260 | 25800 516 431 0 0:00:14.829576
261 | 25900 518 518 1 0:00:14.968691
262 | 26000 520 520 1 0:00:14.801010
263 | 26100 522 522 1 0:00:14.894444
264 | 26200 524 461 0 0:00:15.070504
265 | 26300 526 526 1 0:00:14.958834
266 | 26400 528 478 0 0:00:14.938254
267 | 26500 530 531 0 0:00:14.946788
268 | 26600 532 532 1 0:00:15.041339
269 | 26700 534 534 1 0:00:14.967383
270 | 26800 536 403 0 0:00:14.814197
271 | 26900 538 538 1 0:00:14.946787
272 | 27000 540 540 1 0:00:14.909997
273 | 27100 542 477 0 0:00:15.057261
274 | 27200 544 926 0 0:00:14.985470
275 | 27300 546 546 1 0:00:15.035207
276 | 27400 548 548 1 0:00:14.893974
277 | 27500 550 505 0 0:00:14.990738
278 | 27600 552 552 1 0:00:15.071906
279 | 27700 554 554 1 0:00:14.799304
280 | 27800 556 421 0 0:00:15.033077
281 | 27900 558 251 0 0:00:15.025178
282 | 28000 560 768 0 0:00:14.976117
283 | 28100 562 562 1 0:00:14.799901
284 | 28200 564 564 1 0:00:14.838296
285 | 28300 566 566 1 0:00:15.089116
286 | 28400 568 399 0 0:00:14.900311
287 | 28500 570 -1 0 0:00:15.044947
288 | 28600 572 418 0 0:00:14.849125
289 | 28700 574 574 1 0:00:14.926119
290 | 28800 576 -1 0 0:00:14.965969
291 | 28900 578 578 1 0:00:14.859858
292 | 29000 580 738 0 0:00:15.022806
293 | 29100 582 582 1 0:00:14.958454
294 | 29200 584 754 0 0:00:14.904257
295 | 29300 586 586 1 0:00:15.074695
296 | 29400 588 588 1 0:00:14.977583
297 | 29500 590 590 1 0:00:14.941264
298 | 29600 592 592 1 0:00:14.933073
299 | 29700 594 843 0 0:00:15.028679
300 | 29800 596 56 0 0:00:15.067989
301 | 29900 598 664 0 0:00:14.945151
302 | 30000 600 517 0 0:00:14.936215
303 | 30100 602 602 1 0:00:14.965253
304 | 30200 604 604 1 0:00:14.856193
305 | 30300 606 606 1 0:00:15.016406
306 | 30400 608 608 1 0:00:15.112201
307 | 30500 610 800 0 0:00:14.789717
308 | 30600 612 612 1 0:00:15.035326
309 | 30700 614 614 1 0:00:15.098388
310 | 30800 616 534 0 0:00:14.907665
311 | 30900 618 828 0 0:00:14.832339
312 | 31000 620 681 0 0:00:15.147113
313 | 31100 622 622 1 0:00:14.945833
314 | 31200 624 453 0 0:00:14.915479
315 | 31300 626 542 0 0:00:14.887706
316 | 31400 628 624 0 0:00:15.180945
317 | 31500 630 703 0 0:00:14.891687
318 | 31600 632 632 1 0:00:15.057439
319 | 31700 634 489 0 0:00:15.016377
320 | 31800 636 636 1 0:00:14.963135
321 | 31900 638 638 1 0:00:14.799041
322 | 32000 640 640 1 0:00:14.927938
323 | 32100 642 642 1 0:00:15.120875
324 | 32200 644 644 1 0:00:14.877543
325 | 32300 646 646 1 0:00:14.940699
326 | 32400 648 729 0 0:00:14.874189
327 | 32500 650 546 0 0:00:15.097237
328 | 32600 652 652 1 0:00:14.962884
329 | 32700 654 436 0 0:00:14.872689
330 | 32800 656 656 1 0:00:15.005286
331 | 32900 658 658 1 0:00:14.917542
332 | 33000 660 660 1 0:00:14.839606
333 | 33100 662 662 1 0:00:14.892756
334 | 33200 664 851 0 0:00:14.911993
335 | 33300 666 659 0 0:00:14.796421
336 | 33400 668 668 1 0:00:14.824733
337 | 33500 670 670 1 0:00:15.041363
338 | 33600 672 672 1 0:00:15.056527
339 | 33700 674 674 1 0:00:14.825686
340 | 33800 676 560 0 0:00:14.931751
341 | 33900 678 678 1 0:00:14.853488
342 | 34000 680 793 0 0:00:14.958912
343 | 34100 682 682 1 0:00:15.030425
344 | 34200 684 684 1 0:00:14.845501
345 | 34300 686 550 0 0:00:15.033074
346 | 34400 688 688 1 0:00:14.988053
347 | 34500 690 690 1 0:00:14.943735
348 | 34600 692 481 0 0:00:14.928473
349 | 34700 694 694 1 0:00:14.829778
350 | 34800 696 -1 0 0:00:15.074387
351 | 34900 698 698 1 0:00:15.019027
352 | 35000 700 700 1 0:00:15.032050
353 | 35100 702 422 0 0:00:15.006292
354 | 35200 704 704 1 0:00:14.928123
355 | 35300 706 536 0 0:00:14.818038
356 | 35400 708 682 0 0:00:15.055249
357 | 35500 710 710 1 0:00:14.948738
358 | 35600 712 622 0 0:00:15.094625
359 | 35700 714 714 1 0:00:14.915263
360 | 35800 716 716 1 0:00:14.976116
361 | 35900 718 449 0 0:00:14.864276
362 | 36000 720 720 1 0:00:15.096854
363 | 36100 722 722 1 0:00:14.921822
364 | 36200 724 780 0 0:00:14.832085
365 | 36300 726 726 1 0:00:15.146349
366 | 36400 728 728 1 0:00:14.939208
367 | 36500 730 856 0 0:00:14.883153
368 | 36600 732 732 1 0:00:15.047479
369 | 36700 734 734 1 0:00:14.925721
370 | 36800 736 736 1 0:00:14.927071
371 | 36900 738 738 1 0:00:14.864711
372 | 37000 740 796 0 0:00:15.075991
373 | 37100 742 466 0 0:00:14.987401
374 | 37200 744 657 0 0:00:14.830023
375 | 37300 746 746 1 0:00:14.953343
376 | 37400 748 748 1 0:00:14.797044
377 | 37500 750 172 0 0:00:14.995230
378 | 37600 752 830 0 0:00:14.965986
379 | 37700 754 848 0 0:00:14.849808
380 | 37800 756 756 1 0:00:14.833296
381 | 37900 758 758 1 0:00:14.914641
382 | 38000 760 651 0 0:00:14.921158
383 | 38100 762 762 1 0:00:15.031013
384 | 38200 764 593 0 0:00:14.855857
385 | 38300 766 766 1 0:00:14.836674
386 | 38400 768 768 1 0:00:14.937352
387 | 38500 770 770 1 0:00:14.985036
388 | 38600 772 772 1 0:00:15.019809
389 | 38700 774 774 1 0:00:14.958676
390 | 38800 776 683 0 0:00:14.881106
391 | 38900 778 707 0 0:00:15.088828
392 | 39000 780 780 1 0:00:14.941449
393 | 39100 782 664 0 0:00:14.912092
394 | 39200 784 505 0 0:00:15.034298
395 | 39300 786 513 0 0:00:14.912165
396 | 39400 788 743 0 0:00:14.926766
397 | 39500 790 791 0 0:00:14.947411
398 | 39600 792 795 0 0:00:14.923461
399 | 39700 794 794 1 0:00:15.040500
400 | 39800 796 796 1 0:00:14.828018
401 | 39900 798 798 1 0:00:15.105424
402 | 40000 800 800 1 0:00:15.035649
403 | 40100 802 802 1 0:00:14.857424
404 | 40200 804 804 1 0:00:14.936726
405 | 40300 806 806 1 0:00:14.901794
406 | 40400 808 808 1 0:00:14.920217
407 | 40500 810 878 0 0:00:14.899759
408 | 40600 812 812 1 0:00:14.947197
409 | 40700 814 814 1 0:00:14.900511
410 | 40800 816 816 1 0:00:15.095273
411 | 40900 818 745 0 0:00:15.008224
412 | 41000 820 820 1 0:00:15.002912
413 | 41100 822 822 1 0:00:14.927240
414 | 41200 824 824 1 0:00:14.942341
415 | 41300 826 826 1 0:00:14.871397
416 | 41400 828 618 0 0:00:14.840422
417 | 41500 830 830 1 0:00:14.942988
418 | 41600 832 832 1 0:00:15.013025
419 | 41700 834 417 0 0:00:15.061307
420 | 41800 836 444 0 0:00:14.924838
421 | 41900 838 631 0 0:00:14.804935
422 | 42000 840 840 1 0:00:14.906883
423 | 42100 842 842 1 0:00:14.838626
424 | 42200 844 844 1 0:00:15.125461
425 | 42300 846 583 0 0:00:14.892340
426 | 42400 848 848 1 0:00:14.995593
427 | 42500 850 850 1 0:00:15.004535
428 | 42600 852 852 1 0:00:15.016846
429 | 42700 854 854 1 0:00:14.969559
430 | 42800 856 856 1 0:00:14.853972
431 | 42900 858 832 0 0:00:15.011433
432 | 43000 860 892 0 0:00:14.879105
433 | 43100 862 862 1 0:00:14.824035
434 | 43200 864 864 1 0:00:15.063193
435 | 43300 866 866 1 0:00:15.115438
436 | 43400 868 968 0 0:00:15.104931
437 | 43500 870 870 1 0:00:14.852551
438 | 43600 872 840 0 0:00:15.085425
439 | 43700 874 874 1 0:00:15.004121
440 | 43800 876 648 0 0:00:14.948926
441 | 43900 878 878 1 0:00:14.924355
442 | 44000 880 870 0 0:00:15.053208
443 | 44100 882 882 1 0:00:14.870412
444 | 44200 884 884 1 0:00:15.067243
445 | 44300 886 886 1 0:00:15.063395
446 | 44400 888 718 0 0:00:14.863382
447 | 44500 890 890 1 0:00:14.956937
448 | 44600 892 442 0 0:00:15.065577
449 | 44700 894 894 1 0:00:14.913900
450 | 44800 896 896 1 0:00:14.845646
451 | 44900 898 508 0 0:00:15.041700
452 | 45000 900 900 1 0:00:15.036083
453 | 45100 902 902 1 0:00:14.869172
454 | 45200 904 905 0 0:00:14.926249
455 | 45300 906 906 1 0:00:14.980194
456 | 45400 908 895 0 0:00:14.923142
457 | 45500 910 910 1 0:00:15.059851
458 | 45600 912 912 1 0:00:14.834315
459 | 45700 914 914 1 0:00:14.801381
460 | 45800 916 916 1 0:00:15.072167
461 | 45900 918 918 1 0:00:14.886890
462 | 46000 920 672 0 0:00:14.905368
463 | 46100 922 922 1 0:00:15.158672
464 | 46200 924 924 1 0:00:14.959707
465 | 46300 926 926 1 0:00:15.009355
466 | 46400 928 927 0 0:00:14.835334
467 | 46500 930 964 0 0:00:14.963948
468 | 46600 932 932 1 0:00:14.963856
469 | 46700 934 934 1 0:00:15.037371
470 | 46800 936 936 1 0:00:14.889781
471 | 46900 938 938 1 0:00:15.001535
472 | 47000 940 940 1 0:00:15.078541
473 | 47100 942 942 1 0:00:14.961972
474 | 47200 944 946 0 0:00:14.824271
475 | 47300 946 946 1 0:00:15.006133
476 | 47400 948 948 1 0:00:14.943808
477 | 47500 950 950 1 0:00:14.915371
478 | 47600 952 90 0 0:00:14.829384
479 | 47700 954 954 1 0:00:14.983240
480 | 47800 956 956 1 0:00:15.007075
481 | 47900 958 483 0 0:00:14.995762
482 | 48000 960 960 1 0:00:14.903471
483 | 48100 962 962 1 0:00:14.842755
484 | 48200 964 961 0 0:00:15.000199
485 | 48300 966 630 0 0:00:14.848870
486 | 48400 968 968 1 0:00:15.047162
487 | 48500 970 795 0 0:00:14.873642
488 | 48600 972 979 0 0:00:14.915262
489 | 48700 974 974 1 0:00:15.077598
490 | 48800 976 724 0 0:00:14.963078
491 | 48900 978 701 0 0:00:14.855375
492 | 49000 980 980 1 0:00:14.798299
493 | 49100 982 982 1 0:00:15.025065
494 | 49200 984 984 1 0:00:14.829287
495 | 49300 986 986 1 0:00:14.985315
496 | 49400 988 988 1 0:00:14.922611
497 | 49500 990 990 1 0:00:15.019418
498 | 49600 992 992 1 0:00:14.923599
499 | 49700 994 994 1 0:00:14.888789
500 | 49800 996 996 1 0:00:14.961667
501 | 49900 998 998 1 0:00:14.941916
502 |
--------------------------------------------------------------------------------
/data/predict/imagenet/resnet50/noise_0.25/test/N_100000:
--------------------------------------------------------------------------------
1 | idx label predict correct time
2 | 0 0 0 1 0:02:33.474448
3 | 100 2 2 1 0:02:31.267941
4 | 200 4 395 0 0:02:30.693998
5 | 300 6 6 1 0:02:31.397363
6 | 400 8 8 1 0:02:31.735987
7 | 500 10 10 1 0:02:31.113407
8 | 600 12 12 1 0:02:30.951247
9 | 700 14 14 1 0:02:31.248816
10 | 800 16 16 1 0:02:31.478386
11 | 900 18 18 1 0:02:31.303808
12 | 1000 20 20 1 0:02:31.340990
13 | 1100 22 21 0 0:02:31.461657
14 | 1200 24 24 1 0:02:31.166599
15 | 1300 26 26 1 0:02:31.435601
16 | 1400 28 28 1 0:02:31.746241
17 | 1500 30 30 1 0:02:31.415279
18 | 1600 32 103 0 0:02:31.048611
19 | 1700 34 327 0 0:02:31.615164
20 | 1800 36 35 0 0:02:31.247375
21 | 1900 38 38 1 0:02:31.175008
22 | 2000 40 31 0 0:02:31.079274
23 | 2100 42 48 0 0:02:31.412859
24 | 2200 44 44 1 0:02:31.535267
25 | 2300 46 46 1 0:02:31.560232
26 | 2400 48 48 1 0:02:31.527727
27 | 2500 50 50 1 0:02:31.395422
28 | 2600 52 52 1 0:02:31.524686
29 | 2700 54 57 0 0:02:31.278730
30 | 2800 56 56 1 0:02:31.561788
31 | 2900 58 58 1 0:02:31.469573
32 | 3000 60 62 0 0:02:31.659518
33 | 3100 62 62 1 0:02:31.540611
34 | 3200 64 64 1 0:02:31.528662
35 | 3300 66 63 0 0:02:31.605767
36 | 3400 68 111 0 0:02:31.528680
37 | 3500 70 70 1 0:02:31.388872
38 | 3600 72 72 1 0:02:31.339453
39 | 3700 74 74 1 0:02:31.459056
40 | 3800 76 76 1 0:02:31.597488
41 | 3900 78 78 1 0:02:31.864876
42 | 4000 80 80 1 0:02:31.824401
43 | 4100 82 82 1 0:02:31.586974
44 | 4200 84 84 1 0:02:31.467104
45 | 4300 86 86 1 0:02:31.449934
46 | 4400 88 88 1 0:02:31.694547
47 | 4500 90 90 1 0:02:31.714680
48 | 4600 92 92 1 0:02:31.601055
49 | 4700 94 94 1 0:02:31.759537
50 | 4800 96 96 1 0:02:31.570559
51 | 4900 98 98 1 0:02:30.771919
52 | 5000 100 100 1 0:02:29.871229
53 | 5100 102 102 1 0:02:29.855997
54 | 5200 104 342 0 0:02:30.007782
55 | 5300 106 106 1 0:02:29.889595
56 | 5400 108 470 0 0:02:29.999557
57 | 5500 110 110 1 0:02:30.123106
58 | 5600 112 112 1 0:02:30.096990
59 | 5700 114 113 0 0:02:29.985586
60 | 5800 116 116 1 0:02:30.067943
61 | 5900 118 118 1 0:02:30.032640
62 | 6000 120 119 0 0:02:29.894344
63 | 6100 122 122 1 0:02:30.045305
64 | 6200 124 123 0 0:02:29.928218
65 | 6300 126 126 1 0:02:29.826290
66 | 6400 128 128 1 0:02:29.844152
67 | 6500 130 130 1 0:02:29.743992
68 | 6600 132 132 1 0:02:29.878203
69 | 6700 134 132 0 0:02:29.878030
70 | 6800 136 136 1 0:02:29.977701
71 | 6900 138 138 1 0:02:29.961860
72 | 7000 140 140 1 0:02:29.987991
73 | 7100 142 142 1 0:02:29.983496
74 | 7200 144 144 1 0:02:29.934271
75 | 7300 146 146 1 0:02:30.003505
76 | 7400 148 148 1 0:02:30.013247
77 | 7500 150 360 0 0:02:30.092299
78 | 7600 152 152 1 0:02:30.050198
79 | 7700 154 154 1 0:02:29.898821
80 | 7800 156 156 1 0:02:29.955264
81 | 7900 158 158 1 0:02:30.089475
82 | 8000 160 160 1 0:02:30.007120
83 | 8100 162 162 1 0:02:29.897341
84 | 8200 164 164 1 0:02:29.929127
85 | 8300 166 166 1 0:02:29.948314
86 | 8400 168 211 0 0:02:29.853677
87 | 8500 170 172 0 0:02:29.910112
88 | 8600 172 172 1 0:02:29.825488
89 | 8700 174 174 1 0:02:29.876307
90 | 8800 176 176 1 0:02:29.820867
91 | 8900 178 178 1 0:02:29.995679
92 | 9000 180 243 0 0:02:29.832195
93 | 9100 182 182 1 0:02:29.864980
94 | 9200 184 188 0 0:02:29.972162
95 | 9300 186 192 0 0:02:29.873754
96 | 9400 188 188 1 0:02:29.842876
97 | 9500 190 190 1 0:02:29.861335
98 | 9600 192 199 0 0:02:30.022910
99 | 9700 194 202 0 0:02:30.002872
100 | 9800 196 199 0 0:02:30.027588
101 | 9900 198 196 0 0:02:29.992388
102 | 10000 200 200 1 0:02:29.942592
103 | 10100 202 202 1 0:02:29.897604
104 | 10200 204 204 1 0:02:29.856708
105 | 10300 206 206 1 0:02:29.787749
106 | 10400 208 208 1 0:02:29.881501
107 | 10500 210 210 1 0:02:29.846165
108 | 10600 212 212 1 0:02:29.949987
109 | 10700 214 214 1 0:02:29.946611
110 | 10800 216 216 1 0:02:29.918387
111 | 10900 218 218 1 0:02:30.054871
112 | 11000 220 220 1 0:02:30.114067
113 | 11100 222 222 1 0:02:29.941162
114 | 11200 224 233 0 0:02:30.051554
115 | 11300 226 226 1 0:02:30.096485
116 | 11400 228 228 1 0:02:29.820301
117 | 11500 230 230 1 0:02:29.818161
118 | 11600 232 232 1 0:02:30.023631
119 | 11700 234 234 1 0:02:29.966839
120 | 11800 236 234 0 0:02:30.053770
121 | 11900 238 238 1 0:02:29.941641
122 | 12000 240 238 0 0:02:29.944469
123 | 12100 242 242 1 0:02:29.903479
124 | 12200 244 244 1 0:02:30.013111
125 | 12300 246 251 0 0:02:29.945586
126 | 12400 248 248 1 0:02:30.064157
127 | 12500 250 169 0 0:02:29.971363
128 | 12600 252 252 1 0:02:30.167051
129 | 12700 254 254 1 0:02:30.059607
130 | 12800 256 256 1 0:02:30.098028
131 | 12900 258 258 1 0:02:30.070731
132 | 13000 260 260 1 0:02:30.081086
133 | 13100 262 262 1 0:02:30.182859
134 | 13200 264 264 1 0:02:30.012478
135 | 13300 266 266 1 0:02:29.938711
136 | 13400 268 268 1 0:02:30.146840
137 | 13500 270 270 1 0:02:30.072387
138 | 13600 272 274 0 0:02:30.042817
139 | 13700 274 274 1 0:02:30.089741
140 | 13800 276 276 1 0:02:30.056051
141 | 13900 278 278 1 0:02:30.044791
142 | 14000 280 280 1 0:02:30.148042
143 | 14100 282 282 1 0:02:30.111140
144 | 14200 284 284 1 0:02:30.040564
145 | 14300 286 286 1 0:02:30.077097
146 | 14400 288 288 1 0:02:30.130719
147 | 14500 290 290 1 0:02:30.088434
148 | 14600 292 290 0 0:02:30.055848
149 | 14700 294 294 1 0:02:29.988939
150 | 14800 296 296 1 0:02:29.772256
151 | 14900 298 298 1 0:02:29.949858
152 | 15000 300 300 1 0:02:30.033037
153 | 15100 302 302 1 0:02:30.070987
154 | 15200 304 301 0 0:02:30.089978
155 | 15300 306 306 1 0:02:30.067981
156 | 15400 308 308 1 0:02:30.042348
157 | 15500 310 310 1 0:02:30.006768
158 | 15600 312 315 0 0:02:30.003502
159 | 15700 314 314 1 0:02:29.950488
160 | 15800 316 316 1 0:02:30.047839
161 | 15900 318 318 1 0:02:30.017167
162 | 16000 320 320 1 0:02:29.991655
163 | 16100 322 322 1 0:02:30.104580
164 | 16200 324 324 1 0:02:29.992422
165 | 16300 326 326 1 0:02:29.870930
166 | 16400 328 328 1 0:02:29.891489
167 | 16500 330 330 1 0:02:29.890625
168 | 16600 332 153 0 0:02:29.929097
169 | 16700 334 334 1 0:02:29.997736
170 | 16800 336 336 1 0:02:30.078620
171 | 16900 338 617 0 0:02:29.958100
172 | 17000 340 340 1 0:02:29.913225
173 | 17100 342 287 0 0:02:29.886427
174 | 17200 344 344 1 0:02:30.053031
175 | 17300 346 344 0 0:02:29.932767
176 | 17400 348 348 1 0:02:30.012097
177 | 17500 350 350 1 0:02:30.171809
178 | 17600 352 352 1 0:02:30.087436
179 | 17700 354 354 1 0:02:29.947427
180 | 17800 356 357 0 0:02:30.191751
181 | 17900 358 359 0 0:02:30.051043
182 | 18000 360 360 1 0:02:30.036187
183 | 18100 362 362 1 0:02:29.914627
184 | 18200 364 364 1 0:02:29.950982
185 | 18300 366 366 1 0:02:29.918943
186 | 18400 368 368 1 0:02:29.854911
187 | 18500 370 370 1 0:02:29.836591
188 | 18600 372 372 1 0:02:29.889367
189 | 18700 374 374 1 0:02:29.889299
190 | 18800 376 376 1 0:02:29.975704
191 | 18900 378 378 1 0:02:30.087036
192 | 19000 380 380 1 0:02:30.009402
193 | 19100 382 382 1 0:02:29.955765
194 | 19200 384 283 0 0:02:30.055064
195 | 19300 386 101 0 0:02:29.988809
196 | 19400 388 388 1 0:02:29.895886
197 | 19500 390 390 1 0:02:29.903945
198 | 19600 392 397 0 0:02:30.082440
199 | 19700 394 467 0 0:02:29.983311
200 | 19800 396 396 1 0:02:29.908767
201 | 19900 398 398 1 0:02:30.002472
202 | 20000 400 400 1 0:02:30.054005
203 | 20100 402 402 1 0:02:29.959074
204 | 20200 404 404 1 0:02:29.876369
205 | 20300 406 406 1 0:02:30.077703
206 | 20400 408 847 0 0:02:30.047797
207 | 20500 410 410 1 0:02:30.044914
208 | 20600 412 412 1 0:02:30.040994
209 | 20700 414 414 1 0:02:29.985378
210 | 20800 416 416 1 0:02:29.940233
211 | 20900 418 563 0 0:02:29.924270
212 | 21000 420 420 1 0:02:29.946433
213 | 21100 422 422 1 0:02:29.951842
214 | 21200 424 454 0 0:02:29.952231
215 | 21300 426 426 1 0:02:29.728515
216 | 21400 428 428 1 0:02:29.755238
217 | 21500 430 430 1 0:02:29.863588
218 | 21600 432 432 1 0:02:29.885156
219 | 21700 434 434 1 0:02:29.853207
220 | 21800 436 436 1 0:02:29.946874
221 | 21900 438 438 1 0:02:29.944616
222 | 22000 440 737 0 0:02:29.930953
223 | 22100 442 442 1 0:02:29.986371
224 | 22200 444 870 0 0:02:29.998457
225 | 22300 446 446 1 0:02:29.934899
226 | 22400 448 448 1 0:02:29.868617
227 | 22500 450 407 0 0:02:29.756166
228 | 22600 452 452 1 0:02:29.863274
229 | 22700 454 454 1 0:02:29.971981
230 | 22800 456 645 0 0:02:30.003376
231 | 22900 458 458 1 0:02:29.914280
232 | 23000 460 978 0 0:02:30.226463
233 | 23100 462 462 1 0:02:29.935791
234 | 23200 464 439 0 0:02:29.818337
235 | 23300 466 466 1 0:02:29.930532
236 | 23400 468 468 1 0:02:29.931208
237 | 23500 470 624 0 0:02:30.035671
238 | 23600 472 472 1 0:02:29.943951
239 | 23700 474 841 0 0:02:29.934437
240 | 23800 476 476 1 0:02:30.122416
241 | 23900 478 478 1 0:02:29.828464
242 | 24000 480 509 0 0:02:30.027221
243 | 24100 482 481 0 0:02:29.997922
244 | 24200 484 871 0 0:02:29.806696
245 | 24300 486 486 1 0:02:29.874212
246 | 24400 488 488 1 0:02:29.793973
247 | 24500 490 490 1 0:02:29.869747
248 | 24600 492 492 1 0:02:29.841316
249 | 24700 494 398 0 0:02:29.811148
250 | 24800 496 496 1 0:02:29.910695
251 | 24900 498 498 1 0:02:29.966648
252 | 25000 500 500 1 0:02:29.991739
253 | 25100 502 502 1 0:02:30.041048
254 | 25200 504 504 1 0:02:30.038781
255 | 25300 506 506 1 0:02:29.979415
256 | 25400 508 508 1 0:02:30.084711
257 | 25500 510 510 1 0:02:30.006354
258 | 25600 512 740 0 0:02:30.084266
259 | 25700 514 514 1 0:02:30.032617
260 | 25800 516 431 0 0:02:29.789236
261 | 25900 518 518 1 0:02:29.920027
262 | 26000 520 520 1 0:02:29.996022
263 | 26100 522 522 1 0:02:30.083467
264 | 26200 524 461 0 0:02:30.074240
265 | 26300 526 526 1 0:02:29.917449
266 | 26400 528 478 0 0:02:29.928078
267 | 26500 530 531 0 0:02:29.899624
268 | 26600 532 532 1 0:02:30.009423
269 | 26700 534 534 1 0:02:30.048200
270 | 26800 536 403 0 0:02:30.077787
271 | 26900 538 538 1 0:02:30.006213
272 | 27000 540 540 1 0:02:30.003153
273 | 27100 542 477 0 0:02:30.079738
274 | 27200 544 926 0 0:02:29.935786
275 | 27300 546 546 1 0:02:30.040292
276 | 27400 548 548 1 0:02:29.942731
277 | 27500 550 505 0 0:02:30.162599
278 | 27600 552 552 1 0:02:30.197462
279 | 27700 554 554 1 0:02:29.881878
280 | 27800 556 421 0 0:02:30.003082
281 | 27900 558 251 0 0:02:29.969123
282 | 28000 560 768 0 0:02:30.061316
283 | 28100 562 562 1 0:02:29.891487
284 | 28200 564 564 1 0:02:29.852345
285 | 28300 566 566 1 0:02:29.913517
286 | 28400 568 399 0 0:02:29.937221
287 | 28500 570 -1 0 0:02:29.939591
288 | 28600 572 418 0 0:02:30.063215
289 | 28700 574 574 1 0:02:29.918890
290 | 28800 576 576 1 0:02:30.023425
291 | 28900 578 578 1 0:02:29.880543
292 | 29000 580 738 0 0:02:29.832065
293 | 29100 582 582 1 0:02:30.020106
294 | 29200 584 754 0 0:02:30.010541
295 | 29300 586 586 1 0:02:29.950467
296 | 29400 588 588 1 0:02:29.812419
297 | 29500 590 590 1 0:02:29.732527
298 | 29600 592 592 1 0:02:29.842051
299 | 29700 594 843 0 0:02:29.837906
300 | 29800 596 56 0 0:02:29.837575
301 | 29900 598 664 0 0:02:29.940911
302 | 30000 600 517 0 0:02:29.977010
303 | 30100 602 602 1 0:02:29.953506
304 | 30200 604 604 1 0:02:30.107783
305 | 30300 606 606 1 0:02:30.022856
306 | 30400 608 608 1 0:02:30.136491
307 | 30500 610 800 0 0:02:29.965505
308 | 30600 612 612 1 0:02:30.079032
309 | 30700 614 614 1 0:02:30.212737
310 | 30800 616 534 0 0:02:30.179169
311 | 30900 618 828 0 0:02:29.963279
312 | 31000 620 681 0 0:02:30.197027
313 | 31100 622 622 1 0:02:29.996470
314 | 31200 624 453 0 0:02:30.116919
315 | 31300 626 542 0 0:02:30.134142
316 | 31400 628 624 0 0:02:30.071232
317 | 31500 630 703 0 0:02:30.077622
318 | 31600 632 632 1 0:02:29.860154
319 | 31700 634 489 0 0:02:30.070970
320 | 31800 636 636 1 0:02:30.025989
321 | 31900 638 638 1 0:02:30.038827
322 | 32000 640 640 1 0:02:29.923118
323 | 32100 642 642 1 0:02:29.869806
324 | 32200 644 644 1 0:02:29.827098
325 | 32300 646 646 1 0:02:29.858985
326 | 32400 648 729 0 0:02:29.997431
327 | 32500 650 546 0 0:02:30.006125
328 | 32600 652 652 1 0:02:29.955058
329 | 32700 654 436 0 0:02:29.910161
330 | 32800 656 656 1 0:02:29.853518
331 | 32900 658 658 1 0:02:29.789470
332 | 33000 660 660 1 0:02:29.846462
333 | 33100 662 662 1 0:02:29.912663
334 | 33200 664 851 0 0:02:29.911196
335 | 33300 666 659 0 0:02:29.956095
336 | 33400 668 668 1 0:02:29.996228
337 | 33500 670 670 1 0:02:29.986417
338 | 33600 672 672 1 0:02:29.827720
339 | 33700 674 674 1 0:02:29.967539
340 | 33800 676 560 0 0:02:30.012731
341 | 33900 678 678 1 0:02:29.908836
342 | 34000 680 793 0 0:02:29.788307
343 | 34100 682 682 1 0:02:29.894981
344 | 34200 684 684 1 0:02:29.966279
345 | 34300 686 550 0 0:02:29.915007
346 | 34400 688 688 1 0:02:29.914888
347 | 34500 690 690 1 0:02:30.002952
348 | 34600 692 481 0 0:02:29.898020
349 | 34700 694 694 1 0:02:29.990453
350 | 34800 696 -1 0 0:02:29.979518
351 | 34900 698 698 1 0:02:29.978442
352 | 35000 700 700 1 0:02:29.829307
353 | 35100 702 422 0 0:02:29.926219
354 | 35200 704 704 1 0:02:30.086647
355 | 35300 706 536 0 0:02:30.101645
356 | 35400 708 682 0 0:02:29.961313
357 | 35500 710 710 1 0:02:30.063833
358 | 35600 712 622 0 0:02:29.946512
359 | 35700 714 714 1 0:02:29.815917
360 | 35800 716 716 1 0:02:29.909783
361 | 35900 718 449 0 0:02:29.789578
362 | 36000 720 720 1 0:02:29.964153
363 | 36100 722 722 1 0:02:29.866087
364 | 36200 724 780 0 0:02:29.745715
365 | 36300 726 726 1 0:02:29.968047
366 | 36400 728 728 1 0:02:29.892893
367 | 36500 730 856 0 0:02:29.996930
368 | 36600 732 732 1 0:02:30.164158
369 | 36700 734 734 1 0:02:30.062919
370 | 36800 736 736 1 0:02:29.980641
371 | 36900 738 738 1 0:02:30.107528
372 | 37000 740 796 0 0:02:30.099491
373 | 37100 742 466 0 0:02:29.944660
374 | 37200 744 657 0 0:02:29.855458
375 | 37300 746 746 1 0:02:29.868380
376 | 37400 748 748 1 0:02:29.954246
377 | 37500 750 172 0 0:02:29.896917
378 | 37600 752 830 0 0:02:29.989432
379 | 37700 754 848 0 0:02:29.823816
380 | 37800 756 756 1 0:02:29.879869
381 | 37900 758 758 1 0:02:29.909168
382 | 38000 760 651 0 0:02:29.897002
383 | 38100 762 762 1 0:02:29.806436
384 | 38200 764 593 0 0:02:29.910020
385 | 38300 766 766 1 0:02:29.998693
386 | 38400 768 768 1 0:02:30.028741
387 | 38500 770 770 1 0:02:29.972508
388 | 38600 772 772 1 0:02:29.958522
389 | 38700 774 774 1 0:02:29.937070
390 | 38800 776 683 0 0:02:29.718288
391 | 38900 778 707 0 0:02:29.885120
392 | 39000 780 780 1 0:02:29.704233
393 | 39100 782 664 0 0:02:29.688010
394 | 39200 784 505 0 0:02:29.858668
395 | 39300 786 513 0 0:02:29.847906
396 | 39400 788 743 0 0:02:29.873123
397 | 39500 790 791 0 0:02:29.793333
398 | 39600 792 795 0 0:02:30.019487
399 | 39700 794 794 1 0:02:30.052389
400 | 39800 796 796 1 0:02:29.873432
401 | 39900 798 798 1 0:02:29.815152
402 | 40000 800 800 1 0:02:29.987385
403 | 40100 802 802 1 0:02:29.892680
404 | 40200 804 804 1 0:02:29.874876
405 | 40300 806 806 1 0:02:29.806446
406 | 40400 808 808 1 0:02:29.953785
407 | 40500 810 878 0 0:02:29.862812
408 | 40600 812 812 1 0:02:30.005352
409 | 40700 814 814 1 0:02:29.910957
410 | 40800 816 816 1 0:02:29.965954
411 | 40900 818 745 0 0:02:30.010599
412 | 41000 820 820 1 0:02:29.833271
413 | 41100 822 822 1 0:02:29.792163
414 | 41200 824 824 1 0:02:29.873111
415 | 41300 826 826 1 0:02:29.984101
416 | 41400 828 618 0 0:02:29.830313
417 | 41500 830 830 1 0:02:29.967694
418 | 41600 832 832 1 0:02:29.843198
419 | 41700 834 417 0 0:02:29.856908
420 | 41800 836 444 0 0:02:29.943907
421 | 41900 838 631 0 0:02:29.875778
422 | 42000 840 840 1 0:02:29.994829
423 | 42100 842 842 1 0:02:30.153473
424 | 42200 844 844 1 0:02:29.808288
425 | 42300 846 583 0 0:02:30.014881
426 | 42400 848 848 1 0:02:30.071110
427 | 42500 850 850 1 0:02:30.020923
428 | 42600 852 852 1 0:02:29.896731
429 | 42700 854 854 1 0:02:29.905206
430 | 42800 856 856 1 0:02:29.816986
431 | 42900 858 832 0 0:02:29.966872
432 | 43000 860 892 0 0:02:30.035207
433 | 43100 862 862 1 0:02:29.931289
434 | 43200 864 864 1 0:02:29.939961
435 | 43300 866 866 1 0:02:29.974598
436 | 43400 868 968 0 0:02:29.877698
437 | 43500 870 870 1 0:02:29.989906
438 | 43600 872 840 0 0:02:29.967932
439 | 43700 874 874 1 0:02:29.878560
440 | 43800 876 648 0 0:02:29.746216
441 | 43900 878 878 1 0:02:29.868720
442 | 44000 880 870 0 0:02:29.887129
443 | 44100 882 882 1 0:02:29.831625
444 | 44200 884 884 1 0:02:29.907423
445 | 44300 886 886 1 0:02:29.848836
446 | 44400 888 718 0 0:02:29.699474
447 | 44500 890 890 1 0:02:30.080145
448 | 44600 892 442 0 0:02:29.891949
449 | 44700 894 894 1 0:02:29.743512
450 | 44800 896 896 1 0:02:29.778889
451 | 44900 898 508 0 0:02:29.825377
452 | 45000 900 900 1 0:02:29.843363
453 | 45100 902 902 1 0:02:29.768915
454 | 45200 904 905 0 0:02:29.945003
455 | 45300 906 906 1 0:02:29.963234
456 | 45400 908 895 0 0:02:29.794284
457 | 45500 910 910 1 0:02:29.881790
458 | 45600 912 912 1 0:02:29.877916
459 | 45700 914 914 1 0:02:29.756244
460 | 45800 916 916 1 0:02:29.751879
461 | 45900 918 918 1 0:02:29.787862
462 | 46000 920 672 0 0:02:29.753602
463 | 46100 922 922 1 0:02:29.704010
464 | 46200 924 924 1 0:02:29.787180
465 | 46300 926 926 1 0:02:29.900919
466 | 46400 928 927 0 0:02:29.837027
467 | 46500 930 964 0 0:02:29.734897
468 | 46600 932 932 1 0:02:29.875139
469 | 46700 934 934 1 0:02:29.888191
470 | 46800 936 936 1 0:02:29.858023
471 | 46900 938 938 1 0:02:29.856562
472 | 47000 940 940 1 0:02:29.896224
473 | 47100 942 942 1 0:02:29.828978
474 | 47200 944 946 0 0:02:29.884413
475 | 47300 946 946 1 0:02:29.885259
476 | 47400 948 948 1 0:02:29.742021
477 | 47500 950 950 1 0:02:30.007190
478 | 47600 952 90 0 0:02:29.852872
479 | 47700 954 954 1 0:02:29.902331
480 | 47800 956 956 1 0:02:29.907287
481 | 47900 958 483 0 0:02:29.967186
482 | 48000 960 960 1 0:02:30.014844
483 | 48100 962 962 1 0:02:30.087343
484 | 48200 964 961 0 0:02:29.911554
485 | 48300 966 630 0 0:02:30.058535
486 | 48400 968 968 1 0:02:29.935223
487 | 48500 970 795 0 0:02:29.779368
488 | 48600 972 979 0 0:02:29.745436
489 | 48700 974 974 1 0:02:29.752444
490 | 48800 976 724 0 0:02:29.804450
491 | 48900 978 701 0 0:02:30.000623
492 | 49000 980 980 1 0:02:29.829837
493 | 49100 982 982 1 0:02:29.936440
494 | 49200 984 984 1 0:02:29.820193
495 | 49300 986 986 1 0:02:30.066937
496 | 49400 988 988 1 0:02:29.911113
497 | 49500 990 990 1 0:02:30.062292
498 | 49600 992 992 1 0:02:29.921232
499 | 49700 994 994 1 0:02:30.089058
500 | 49800 996 996 1 0:02:30.022316
501 | 49900 998 998 1 0:02:29.815372
502 |
--------------------------------------------------------------------------------
/experiments.MD:
--------------------------------------------------------------------------------
1 | # Experiments
2 |
3 | This document describes how to replicate our results.
4 |
5 | First, train the models on ImageNet and CIFAR-10:
6 |
7 | ```
8 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_0.00 --batch 400 --noise 0.00
9 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_0.25 --batch 400 --noise 0.25
10 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_0.50 --batch 400 --noise 0.50
11 | python code/train.py imagenet resnet50 models/imagenet/resnet50/noise_1.00 --batch 400 --noise 1.00
12 |
13 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.00 --batch 400 --noise 0.00 --gpu [num]
14 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.12 --batch 400 --noise 0.12 --gpu [num]
15 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.25 --batch 400 --noise 0.25 --gpu [num]
16 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_0.50 --batch 400 --noise 0.50 --gpu [num]
17 | python code/train.py cifar10 cifar_resnet110 models/cifar10/resnet110/noise_1.00 --batch 400 --noise 1.00 --gpu [num]
18 | ```
19 | On ImageNet, `train.py` runs synchronous SGD across all available GPUs; on CIFAR-10 it uses a single GPU, selected with `--gpu`.
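Judging by the model directory names, `--noise` sets the standard deviation σ of the Gaussian data augmentation applied during training. The NumPy sketch below shows what that augmentation amounts to. It rests on two assumptions: pixels are already scaled to [0, 1], and a fresh noise sample is drawn for every image at every step. The name `gaussian_augment` is hypothetical and is not taken from `train.py`.

```python
import numpy as np


def gaussian_augment(batch: np.ndarray, noise_sd: float, rng: np.random.Generator) -> np.ndarray:
    """Return batch + N(0, noise_sd^2 I), drawing independent noise for every pixel.

    Assumption: inputs are in [0, 1] and any normalization happens later
    (e.g. inside the model), so no clipping is applied here.
    """
    if noise_sd == 0.0:
        return batch.copy()  # --noise 0.00 corresponds to ordinary (noise-free) training
    noise = rng.standard_normal(batch.shape) * noise_sd
    return (batch + noise).astype(batch.dtype)


# Hypothetical usage inside a training loop:
#   inputs = gaussian_augment(inputs, noise_sd=0.25, rng=rng)
#   loss = criterion(model(inputs), targets)
```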
20 |
21 | Then, certify a subsample of the test set on ImageNet and CIFAR-10 (every 100th ImageNet image and every 20th CIFAR-10 image, via `--skip`):
22 |
23 | ```
24 | python code/certify.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25 --skip 100 --batch 400
25 | python code/certify.py imagenet models/imagenet/resnet50/noise_0.50/checkpoint.pth.tar 0.50 data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50 --skip 100 --batch 400
26 | python code/certify.py imagenet models/imagenet/resnet50/noise_1.00/checkpoint.pth.tar 1.00 data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00 --skip 100 --batch 400
27 |
28 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.12/checkpoint.pth.tar 0.12 data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12 --skip 20 --batch 400
29 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.25/checkpoint.pth.tar 0.25 data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25 --skip 20 --batch 400
30 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.50/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50 --skip 20 --batch 400
31 | python code/certify.py cifar10 models/cifar10/resnet110/noise_1.00/checkpoint.pth.tar 1.00 data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00 --skip 20 --batch 400
32 | ```
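The radius that certification reports can be sketched in a few lines. Under the randomized-smoothing guarantee, suppose p_A is a one-sided lower confidence bound on the probability that the base classifier returns the top class under N(0, σ²I) noise. If p_A exceeds 1/2, the smoothed classifier is robust within ℓ2 radius R = σΦ⁻¹(p_A). Here σ is the noise level passed as the third positional argument to `certify.py`.

This sketch makes several assumptions:
- It uses an exact one-sided Clopper-Pearson bound, found by bisection.
- It takes the top-class count as given; choosing the top class from a separate, smaller sample is omitted.
- The helper names are hypothetical.

It is not the code in `certify.py`.

```python
import math
from statistics import NormalDist


def binom_sf(k: int, n: int, p: float) -> float:
    """P(X >= k) for X ~ Binomial(n, p); each term is computed in log space."""
    if k > n:
        return 0.0
    if k <= 0 or p >= 1.0:
        return 1.0
    if p <= 0.0:
        return 0.0
    log_p, log_q = math.log(p), math.log1p(-p)
    log_n_fact = math.lgamma(n + 1)
    total = 0.0
    for j in range(k, n + 1):
        log_term = (log_n_fact - math.lgamma(j + 1) - math.lgamma(n - j + 1)
                    + j * log_p + (n - j) * log_q)
        total += math.exp(log_term)
    return min(total, 1.0)


def clopper_pearson_lower(k: int, n: int, alpha: float) -> float:
    """One-sided (1 - alpha) Clopper-Pearson lower bound on p, given k successes in n trials."""
    if k == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    for _ in range(60):  # binom_sf is increasing in p, so bisection finds P(X >= k | p) = alpha
        mid = (lo + hi) / 2
        if binom_sf(k, n, mid) < alpha:
            lo = mid
        else:
            hi = mid
    return lo  # the lower end of the bracket keeps the bound conservative


def certify_from_counts(top_count: int, n: int, sigma: float, alpha: float = 0.001):
    """Return the certified l2 radius, or None to abstain when p_A < 1/2."""
    p_lower = clopper_pearson_lower(top_count, n, alpha)
    if p_lower < 0.5:
        return None
    return sigma * NormalDist().inv_cdf(p_lower)
```

For a fixed count, the radius scales linearly with σ. This is one reason a larger test-time σ can certify larger radii, provided the base classifier stays accurate under that noise.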
33 |
34 | Next, certify with a test-time noise level σ that differs from the training noise level:
35 | ```
36 | python code/certify.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.50 data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50 --skip 100 --batch 400
37 | python code/certify.py imagenet models/imagenet/resnet50/noise_1.00/checkpoint.pth.tar 0.50 data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50 --skip 100 --batch 400
38 |
39 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.00/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.00/test/sigma_0.50 --skip 20 --batch 400
40 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.12/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.50 --skip 20 --batch 400
41 | python code/certify.py cifar10 models/cifar10/resnet110/noise_0.25/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50 --skip 20 --batch 400
42 | python code/certify.py cifar10 models/cifar10/resnet110/noise_1.00/checkpoint.pth.tar 0.50 data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50 --skip 20 --batch 400
43 | ```
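Matched and mismatched noise levels are usually compared through certified accuracy at radius r. This is the fraction of certified test points that are classified correctly with a certified radius of at least r. A minimal sketch follows. It assumes each certification output is a whitespace-separated file with a header row that includes `radius` and `correct` columns. That layout is inferred by analogy with the prediction logs under `data/predict`, and `certified_accuracy` is a hypothetical helper.

```python
from typing import Iterable, List, Sequence


def certified_accuracy(lines: Iterable[str], radii: Sequence[float]) -> List[float]:
    """For each r in radii, return the fraction of rows with correct == 1 and radius >= r.

    Assumed input: a header row naming the columns (including 'radius' and
    'correct'), then one whitespace-separated row per certified test point.
    """
    rows = [line.split() for line in lines if line.strip()]
    if len(rows) < 2:
        raise ValueError("expected a header row followed by at least one data row")
    header, body = rows[0], rows[1:]
    col = {name: i for i, name in enumerate(header)}
    results = []
    for r in radii:
        hits = sum(
            1 for row in body
            if int(row[col["correct"]]) == 1 and float(row[col["radius"]]) >= r
        )
        results.append(hits / len(body))
    return results
```

For example, you could call `certified_accuracy(open(path), [0.0, 0.25, 0.5, 0.75, 1.0])` for `noise_0.25/test/sigma_0.50` and for `noise_0.50/test/sigma_0.50`. Comparing the two lists shows how a train/test noise mismatch affects the curve.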
44 |
45 | Then run the prediction experiments on ImageNet, varying the number of Monte Carlo samples `--N`:
46 | ```
47 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_100 --N 100 --skip 100 --batch 400
48 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_1000 --N 1000 --skip 100 --batch 400
49 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_10000 --N 10000 --skip 100 --batch 400
50 | python code/predict.py imagenet models/imagenet/resnet50/noise_0.25/checkpoint.pth.tar 0.25 data/predict/imagenet/resnet50/noise_0.25/test/N_100000 --N 100000 --skip 100 --batch 400
51 | ```
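Each prediction log above has the columns `idx label predict correct time`, where `predict` is `-1` when the smoothed classifier abstains. The sketch below summarizes one such file into accuracy, abstention rate, and mean time per image. It assumes the file on disk contains only the header and data rows, without the `N | ` numbering shown in this listing. `summarize_predict_log` is a hypothetical helper, not the interface of `analyze_predict.py`.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class PredictSummary:
    count: int           # number of evaluated test images
    accuracy: float      # fraction of rows with correct == 1
    abstain_rate: float  # fraction of rows with predict == -1 (abstentions)
    mean_seconds: float  # average wall-clock time per image


def parse_duration(text: str) -> float:
    """Convert an H:MM:SS.ffffff value from the 'time' column to seconds."""
    hours, minutes, seconds = text.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


def summarize_predict_log(lines: Iterable[str]) -> PredictSummary:
    rows = [line.split() for line in lines if line.strip()]
    if len(rows) < 2:
        raise ValueError("expected a header row followed by at least one data row")
    header, body = rows[0], rows[1:]
    col = {name: i for i, name in enumerate(header)}
    n = len(body)
    correct = sum(int(r[col["correct"]]) for r in body)
    abstained = sum(int(r[col["predict"]]) == -1 for r in body)
    seconds = sum(parse_duration(r[col["time"]]) for r in body)
    return PredictSummary(n, correct / n, abstained / n, seconds / n)
```

Calling it on each of `N_100`, `N_1000`, `N_10000` and `N_100000` gives a quick view of how the number of samples trades off against prediction time. In the logs above, that time grows from seconds to roughly two and a half minutes per image.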
52 |
53 | Finally, visualize example images at noise levels 0.0, 0.25, 0.5 and 1.0:
54 | ```
55 | python code/visualize.py imagenet figures/example_images/imagenet 100 0.0 0.25 0.5 1.0
56 | python code/visualize.py imagenet figures/example_images/imagenet 5400 0.0 0.25 0.5 1.0
57 | python code/visualize.py imagenet figures/example_images/imagenet 9025 0.0 0.25 0.5 1.0
58 | python code/visualize.py imagenet figures/example_images/imagenet 19411 0.0 0.25 0.5 1.0
59 |
60 | python code/visualize.py cifar10 figures/example_images/cifar10 10 0.0 0.25 0.5 1.0
61 | python code/visualize.py cifar10 figures/example_images/cifar10 20 0.0 0.25 0.5 1.0
62 | python code/visualize.py cifar10 figures/example_images/cifar10 70 0.0 0.25 0.5 1.0
63 | python code/visualize.py cifar10 figures/example_images/cifar10 110 0.0 0.25 0.5 1.0
64 | ```
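The files under `figures/example_images` appear to follow the pattern `{idx}_{100·σ}.png`; for example, `10_25.png` is CIFAR-10 image 10 at σ = 0.25. The sketch below produces one noisy image and its filename. It assumes pixel values in [0, 1] that are clipped back into that range for display. The clipping and the helper names `noisy_example` and `example_filename` are assumptions, not a description of `visualize.py`.

```python
import numpy as np


def noisy_example(image: np.ndarray, noise_sd: float, rng: np.random.Generator) -> np.ndarray:
    """Add N(0, noise_sd^2 I) noise, then clip to [0, 1] so the result can be saved as an image."""
    return np.clip(image + rng.standard_normal(image.shape) * noise_sd, 0.0, 1.0)


def example_filename(idx: int, noise_sd: float) -> str:
    """Filename pattern inferred from figures/example_images, e.g. (10, 0.25) -> '10_25.png'."""
    return f"{idx}_{round(noise_sd * 100)}.png"


# Filenames for the noise levels used in the commands above:
for sd in (0.0, 0.25, 0.5, 1.0):
    print(example_filename(10, sd))
```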
--------------------------------------------------------------------------------
/figures/compare_bounds.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/compare_bounds.pdf
--------------------------------------------------------------------------------
/figures/example_images/cifar10/10_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_0.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/10_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_100.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/10_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_25.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/10_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/10_50.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/110_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_0.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/110_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_100.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/110_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_25.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/110_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/110_50.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/20_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_0.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/20_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_100.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/20_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_25.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/20_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/20_50.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/70_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_0.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/70_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_100.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/70_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_25.png
--------------------------------------------------------------------------------
/figures/example_images/cifar10/70_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/cifar10/70_50.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/100_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_0.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/100_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_100.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/100_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_25.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/100_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/100_50.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/19411_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_0.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/19411_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_100.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/19411_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_25.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/19411_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/19411_50.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/3300_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_0.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/3300_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_100.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/3300_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_25.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/3300_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/3300_50.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/5400_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_0.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/5400_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_100.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/5400_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_25.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/5400_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/5400_50.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/9067_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_0.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/9067_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_100.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/9067_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_25.png
--------------------------------------------------------------------------------
/figures/example_images/imagenet/9067_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/example_images/imagenet/9067_50.png
--------------------------------------------------------------------------------
/figures/panda_0.25.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_0.25.gif
--------------------------------------------------------------------------------
/figures/panda_0.50.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_0.50.gif
--------------------------------------------------------------------------------
/figures/panda_1.00.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_1.00.gif
--------------------------------------------------------------------------------
/figures/panda_577.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/panda_577.png
--------------------------------------------------------------------------------
/figures/radiusslow.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/smoothing/78a4d949e4f627d000a78908e001f8ca66c92943/figures/radiusslow.pdf
--------------------------------------------------------------------------------