├── kernel.jl
├── types.jl
├── LICENSE
├── adv_bfgs.jl
├── README.md
├── adv_sgd.jl
├── data-example
│   ├── glass.test
│   ├── glass.train
│   └── glass.csv
├── util.jl
├── shared.jl
├── example.jl
├── example_primal.jl
├── example_kernel.jl
├── adv_cg.jl
└── adv_kernel_cg.jl

/kernel.jl:
--------------------------------------------------------------------------------
1 | function linear_kernel(x1::Vector, x2::Vector)
2 |     return dot(x1, x2)::Float64
3 | end
4 | 
5 | function polynomial_kernel(x1::Vector, x2::Vector, d::Integer=2)
6 |     return ((1 + dot(x1, x2)) ^ d)::Float64
7 | end
8 | 
9 | # function gaussian_kernel(x1::Vector, x2::Vector, sigma::Real=1.0)
10 |     # # k = exp(-1/(2*sigma^2) ||x1 - x2||^2 )
11 |     # return exp( -norm(x1 - x2)^2 / (2*sigma^2) )::Float64
12 | # end
13 | 
14 | function gaussian_kernel(x1::Vector, x2::Vector, gamma::Real=1.0)
15 |     # k = exp(-gamma ||x1 - x2||^2 )
16 |     return exp( -gamma * norm(x1 - x2)^2 )::Float64
17 | end
18 | 
--------------------------------------------------------------------------------
/types.jl:
--------------------------------------------------------------------------------
1 | abstract ClassificationModel
2 | 
3 | type MultiAdversarialModel <: ClassificationModel
4 |     w::Vector{Float64}
5 |     alpha::Vector{Float64}
6 |     constraints::Vector{Tuple{Integer, Vector{Integer}}}
7 |     n_class::Int
8 |     game_value_01::Float64
9 |     game_value_augmented::Float64
10 |     train_adv_loss::Float64
11 |     train_01_loss::Float64
12 | end
13 | 
14 | type KernelMultiAdversarialModel <: ClassificationModel
15 |     kernel::Symbol
16 |     kernel_params::Vector{Float64}
17 |     alpha::Vector{Float64}
18 |     constraints::Vector{Tuple{Integer, Vector{Integer}}}
19 |     n_class::Int
20 |     game_value_01::Float64
21 |     game_value_augmented::Float64
22 |     train_adv_loss::Float64
23 |     train_01_loss::Float64
24 | end
25 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2016 Rizal Fathony
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/adv_bfgs.jl:
--------------------------------------------------------------------------------
1 | using Optim
2 | 
3 | function objFunc(w::Vector, X1::Matrix, y::Vector, C::Real, n::Int64, n_c::Int64, idmi::Vector)
4 |     halfnorm = 0.5 * dot(w, w)
5 | 
6 |     xi = 0
7 |     for i = 1:n
8 |         psis = psi_list(w, X1, y, i, n_c, idmi)
9 |         psis_id, val = best_psis(psis) # most violated constraints
10 |         xi += val
11 |     end
12 | 
13 |     return halfnorm + C * xi
14 | end
15 | 
16 | function gradFunc!(w::Vector, g::Vector, X1::Matrix, y::Vector, C::Real, n::Int64, n_c::Int64, idmi::Vector)
17 |     m = length(g)
18 | 
19 |     dxi = zeros(m)
20 |     for i = 1:n
21 |         psis = psi_list(w, X1, y, i, n_c, idmi)
22 |         psis_id, val = best_psis(psis) # most violated constraints
23 |         dconst = calc_dconst((i,psis_id), X1, y, n_c, idmi)
24 |         dxi += dconst
25 |     end
26 | 
27 |     for i=1:m
28 |         g[i] = w[i] + C * dxi[i]
29 |     end
30 | end
31 | 
32 | # train adversarial method using LBFGS
33 | function train_adv_bfgs(X::Matrix, y::Vector, C::Real=1.0;
34 |                         ftol::Real=1e-6, grtol::Real=1e-6,
35 |                         show_trace::Bool=true, max_iter::Int=1000)
36 | 
37 |     n = length(y)
38 |     # append an intercept column of ones
39 |     X1 = [ones(n) X]' # transpose
40 |     m = size(X1, 1)
41 | 
42 |     # number of classes
43 |     n_c = maximum(y)
44 |     n_f = n_c * m # total number of parameters (per-class weights stacked)
45 | 
46 |     # parameters. random init in [-0.5, 0.5)
47 |     w = rand(n_f) - 0.5
48 | 
49 |     # prepare saved vars
50 |     idmi = map(i -> idi(m, i), collect(1:n_c))
51 | 
52 |     ## Lbfgs
53 |     res = Optim.optimize( x -> objFunc(x, X1, y, C, n, n_c, idmi),
54 |                           (x, g) -> gradFunc!(x, g, X1, y, C, n, n_c, idmi), w,
55 |                           LBFGS(linesearch! = Optim.mt_linesearch!),
56 |                           OptimizationOptions(show_trace = show_trace, iterations = max_iter,
57 |                                               f_tol = ftol, g_tol = grtol)
58 |     )
59 | 
60 |     w = res.minimum
61 | 
62 |     # finalizing losses
63 |     gv_aug = zeros(n)
64 |     gv_01 = zeros(n)
65 |     l_adv = zeros(n)
66 |     l_01 = zeros(n)
67 |     for i=1:n
68 |         psis = psi_list(w, X1, y, i, n_c, idmi)
69 |         psis_id, val = best_psis(psis) # most violated constraints
70 |         n_ps = length(psis_id)
71 | 
72 |         gv_aug[i] = val
73 | 
74 |         # compute probs
75 |         p_hat = zeros(n_c)
76 |         p_check = zeros(n_c)
77 |         for j=1:n_c
78 |             if j in psis_id
79 |                 p_hat[j] = ( (n_ps-1.0)*psis[j] - sum(psis[psis_id[psis_id .!= j]]) + 1.0 ) / n_ps
80 |                 p_check[j] = 1.0 / n_ps
81 |             else
82 |                 p_hat[j] = 0.0
83 |                 p_check[j] = 0.0
84 |             end
85 |         end
86 | 
87 |         C01 = 1 - eye(n_c) # 01 loss matrix
88 |         v = p_hat' * C01 * p_check # the result is a 1-element vector, not a scalar
89 |         gv_01[i] = v[1]
90 | 
91 |         ## training loss
92 |         l_adv[i] = 1.0 - p_hat[y[i]]
93 |         l_01[i] = 1.0 - round(Int, indmax(p_hat) == y[i])
94 |     end
95 | 
96 |     game_value_01 = mean(gv_01)
97 |     game_value_augmented = mean(gv_aug)
98 | 
99 |     # create model
100 |     adv_model = MultiAdversarialModel(w, zeros(0), Tuple{Integer, Vector}[], n_c, game_value_01, game_value_augmented, mean(l_adv), mean(l_01))
101 | 
102 |     return adv_model::MultiAdversarialModel
103 | end
104 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial Multiclass Classification: A Risk Minimization Perspective (NIPS 2016)
2 | This repository is a code example for the paper:
3 | [Adversarial Multiclass Classification: A Risk Minimization Perspective](https://papers.nips.cc/paper/6088-adversarial-multiclass-classification-a-risk-minimization-perspective)
4 | 
5 | Full paper: 
[https://www.cs.uic.edu/~rfathony/pdf/fathony2016adversarial.pdf](https://www.cs.uic.edu/~rfathony/pdf/fathony2016adversarial.pdf)
6 | 
7 | ### Abstract
8 | 
9 | Recently proposed adversarial classification methods have shown promising results for cost sensitive and multivariate losses. In contrast with empirical risk minimization (ERM) methods, which use convex surrogate losses to approximate the desired non-convex target loss function, adversarial methods minimize non-convex losses by treating the properties of the training data as being uncertain and worst case within a minimax game. Despite this difference in formulation, we recast adversarial classification under zero-one loss as an ERM method with a novel prescribed loss function. We demonstrate a number of theoretical and practical advantages over the very closely related hinge loss ERM methods. This establishes adversarial classification under the zero-one loss as a method that fills the long standing gap in multiclass hinge loss classification, simultaneously guaranteeing Fisher consistency and universal consistency, while also providing dual parameter sparsity and high accuracy predictions in practice.
10 | 
11 | 
12 | # Setup
13 | 
14 | The source code is written in [Julia](http://julialang.org/) version 0.5.0.
15 | 
16 | ### Dependency
17 | The code depends on the following Julia packages:
18 | 
19 | 1. [Optim.jl](https://github.com/JuliaOpt/Optim.jl)
20 | 2. [Gurobi.jl](https://github.com/JuliaOpt/Gurobi.jl)
21 | 3. [Mosek.jl](https://github.com/JuliaOpt/Mosek.jl)
22 | 
23 | [Optim.jl](https://github.com/JuliaOpt/Optim.jl) is used in the primal BFGS optimization.
24 | To run the dual constraint generation algorithm, a Quadratic Programming solver ([Gurobi.jl](https://github.com/JuliaOpt/Gurobi.jl)
25 | or [Mosek.jl](https://github.com/JuliaOpt/Mosek.jl)) is required. Please refer to each package's instructions for installation.
26 | 
27 | ### Example
28 | 
29 | Three example files are provided:
30 | 
31 | * `example.jl` :
32 | runs the dual constraint generation algorithm for training.
33 | 
34 | * `example_kernel.jl` :
35 | runs the dual constraint generation algorithm for training with a Gaussian kernel.
36 | 
37 | * `example_primal.jl`:
38 | runs a primal optimization algorithm (BFGS or SGD) for training.
39 | 
40 | In each file, the code will run training with k-fold cross validation for the example dataset (`glass`).
41 | After finding the best setting, it will run the testing phase.
42 | 
43 | To change the training settings, please edit the setting values directly in the given example. (A minimal sketch of calling the training functions directly appears at the end of this README.)
44 | 
45 | To run the code, execute (in terminal):
46 | ```
47 | julia example.jl
48 | ```
49 | 
50 | # Citation (BibTeX)
51 | ```
52 | @incollection{fathony2016adversarial,
53 |   title = {Adversarial Multiclass Classification: A Risk Minimization Perspective},
54 |   author = {Fathony, Rizal and Liu, Anqi and Asif, Kaiser and Ziebart, Brian},
55 |   booktitle = {Advances in Neural Information Processing Systems 29},
56 |   pages = {559--567},
57 |   year = {2016},
58 | }
59 | ```
60 | # Acknowledgements
61 | This research was supported as part of the Future of Life Institute (futureoflife.org) FLI-RFP-AI1 program, grant #2016-158710 and by NSF grant RI-#1526379.
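# Minimal usage sketch

For orientation only: the training functions can also be called directly, outside the example scripts. The sketch below mirrors the calls made in `example.jl` and `example_kernel.jl`; the random data is purely hypothetical, and the keyword values are the ones those scripts use (the full signature of `train_adv_cg` lives in `adv_cg.jl`).

```julia
include("types.jl"); include("shared.jl"); include("util.jl")
include("kernel.jl"); include("adv_cg.jl"); include("adv_kernel_cg.jl")

# hypothetical data: an n x d feature matrix and integer labels in 1..n_c
X_train = randn(100, 9); y_train = rand(1:3, 100)
X_test  = randn(50, 9);  y_test  = rand(1:3, 50)

# linear model via dual constraint generation (needs Gurobi.jl or Mosek.jl)
model = train_adv_cg(X_train, y_train, 1.0, perturb=1e-12, obj_reltol=0.0,
                     solver=:mosek, log=0, psdtol=1e-6, verbose=false)
result = test_adv(model, X_test, y_test)
println("accuracy : ", 1.0 - result[3])   # result[3] is the 0-1 loss

# Gaussian-kernel variant; gamma picked by the Jaakkola heuristic (util.jl)
gamma = jaakkola_heuristic(X_train, y_train)
kmodel = train_adv_kernel_cg(X_train, y_train, 1.0, :gaussian, [gamma],
                             perturb=1e-12, obj_reltol=0.0, solver=:mosek,
                             log=0, psdtol=1e-6, verbose=false)
kresult = test_adv_kernel(kmodel, X_test, y_test, X_train, y_train)
println("kernel accuracy : ", 1.0 - kresult[3])
```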
62 | 
--------------------------------------------------------------------------------
/adv_sgd.jl:
--------------------------------------------------------------------------------
1 | 
2 | function fg_one!(w::Vector, g::Vector, i::Int64, X1::Matrix, y::Vector, C::Real, n::Int64, n_c::Int64, idmi::Vector)
3 |     m = length(g)
4 |     halfnorm = 0.5 * dot(w, w)
5 | 
6 |     psis = psi_list(w, X1, y, i, n_c, idmi)
7 |     psis_id, xi = best_psis(psis) # most violated constraints
8 | 
9 |     dxi = calc_dconst((i,psis_id), X1, y, n_c, idmi)
10 | 
11 |     for k=1:m # loop variable renamed from i to avoid shadowing the sample-index argument
12 |         g[k] = w[k] + C * dxi[k]
13 |     end
14 | 
15 |     return halfnorm + C * xi
16 | end
17 | 
18 | # train adversarial method using SGD
19 | function train_adv_sgd(X::Matrix, y::Vector, C::Real=1.0;
20 |                        step::Float64=0.1, use_adagrad::Bool=false,
21 |                        ftol::Real=1e-6, grtol::Real=1e-6,
22 |                        show_trace::Bool=true, max_iter::Int=1000)
23 | 
24 |     n = length(y)
25 |     # append an intercept column of ones
26 |     X1 = [ones(n) X]' # transpose
27 |     m = size(X1, 1)
28 | 
29 |     # number of classes
30 |     n_c = maximum(y)
31 |     n_f = n_c * m # total number of parameters (per-class weights stacked)
32 | 
33 |     # parameters. random init in [-0.5, 0.5)
34 |     w = rand(n_f) - 0.5
35 |     grad = zeros(n_f)
36 | 
37 |     # prepare saved vars
38 |     idmi = map(i -> idi(m, i), collect(1:n_c))
39 | 
40 |     # adagrad
41 |     # the base learning rate must be set beforehand
42 |     rate = step
43 |     square_g = zeros(n_f) # accumulates historical squared gradients
44 | 
45 |     f_prev = Inf
46 |     iter = 0
47 | 
48 |     while true
49 |         iter = iter + 1
50 | 
51 |         ids = randperm(n)
52 |         sum_f = 0.0
53 | 
54 |         for i in ids
55 | 
56 |             f = fg_one!(w, grad, i, X1, y, C, n, n_c, idmi)
57 |             sum_f += f
58 |             if show_trace
59 |                 println("iter : ", iter, ", sample : ", i, ", C : ", C, ", f : ", f, ", abs grad : ", mean(abs(grad)))
60 |             end
61 | 
62 |             if use_adagrad
63 |                 # adagrad
64 |                 square_g += grad .^ 2
65 |                 w = w - (rate ./ sqrt(square_g)) .* grad
66 |             else
67 |                 w = w - step * grad
68 |             end
69 | 
70 |         end
71 | 
72 |         # discount step
73 |         step = step * 0.95
74 | 
75 |         if iter >= max_iter
76 |             if show_trace println("maximum iteration reached!!") end
77 |             break
78 |         end
79 | 
80 |         if mean(abs(grad)) < grtol
81 |             if show_trace println("gradient breaks!!") end
82 |             break
83 |         end
84 | 
85 |         f = sum_f / n
86 |         if abs(f_prev - f) < ftol
87 |             if show_trace println("function breaks!!") end
88 |             break
89 |         end
90 |         f_prev = f
91 | 
92 |     end
93 | 
94 |     # finalizing losses
95 |     gv_aug = zeros(n)
96 |     gv_01 = zeros(n)
97 |     l_adv = zeros(n)
98 |     l_01 = zeros(n)
99 |     for i=1:n
100 |         psis = psi_list(w, X1, y, i, n_c, idmi)
101 |         psis_id, val = best_psis(psis) # most violated constraints
102 |         n_ps = length(psis_id)
103 | 
104 |         gv_aug[i] = val
105 | 
106 |         # compute probs
107 |         p_hat = zeros(n_c)
108 |         p_check = zeros(n_c)
109 |         for j=1:n_c
110 |             if j in psis_id
111 |                 p_hat[j] = ( (n_ps-1.0)*psis[j] - sum(psis[psis_id[psis_id .!= j]]) + 1.0 ) / n_ps
112 |                 p_check[j] = 1.0 / n_ps
113 |             else
114 |                 p_hat[j] = 0.0
115 |                 p_check[j] = 0.0
116 |             end
117 |         end
118 | 
119 |         C01 = 1 - eye(n_c) # 01 loss matrix
120 |         v = p_hat' * C01 * p_check # the result is a 1-element vector, not a scalar
121 |         gv_01[i] = v[1]
122 | 
123 |         ## training loss
124 |         l_adv[i] = 1.0 - p_hat[y[i]]
125 |         l_01[i] = 1.0 - round(Int, indmax(p_hat) == y[i])
126 |     end
127 | 
128 |     game_value_01 = mean(gv_01)
129 |     game_value_augmented = mean(gv_aug)
130 | 
131 |     # create model
132 |     adv_model = MultiAdversarialModel(w, zeros(0), Tuple{Integer, Vector}[], n_c, game_value_01, game_value_augmented, mean(l_adv), mean(l_01))
133 | 
134 |     return adv_model::MultiAdversarialModel
135 | end
136 | 
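# Usage sketch (editor's note): train_adv_sgd expects an n x d feature matrix
# and integer labels 1..n_c, and relies on types.jl and shared.jl being
# included first (as example_primal.jl does). Kept as a comment so that
# including this file has no side effects; the data below is hypothetical.
#
#   X = randn(100, 9)          # hypothetical features
#   y = rand(1:3, 100)         # hypothetical labels in 1..3
#   model = train_adv_sgd(X, y, 1.0, step=0.1, use_adagrad=true,
#                         ftol=1e-4, grtol=1e-4, max_iter=100,
#                         show_trace=false)
#   println("train 0-1 loss : ", model.train_01_loss)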
-------------------------------------------------------------------------------- /data-example/glass.test: -------------------------------------------------------------------------------- 1 | 144,114,53,212,113,121,167,203,198,63,169,153,60,85,123,200,12,140,204,126,183,8,83,71,106,175,149,103,179,67,182,94,185,137,210,159,194,15,111,4,22,96,127,151,208,101,86,17,152,186,95,166,37,201,147,74,44,192,130,163,188,45,180,66,97 2 | 135,162,191,114,64,193,99,107,143,38,57,68,206,115,149,186,40,2,131,201,14,21,73,76,35,138,151,10,172,211,36,174,106,12,127,90,184,171,141,82,22,205,152,136,26,27,81,8,185,25,137,100,70,157,180,69,15,47,207,178,164,163,56,37,62 3 | 105,96,145,54,116,44,128,65,166,50,202,76,22,111,158,130,163,106,107,173,100,127,68,149,156,211,14,59,46,8,43,19,144,57,182,150,51,125,213,197,87,64,60,188,162,148,161,21,30,26,122,137,20,189,81,152,195,58,112,176,80,79,78,23,86 4 | 36,83,67,101,143,154,185,17,131,8,103,203,132,78,175,64,16,209,172,135,93,84,128,37,210,58,142,79,214,105,208,133,130,179,65,2,141,55,98,10,151,38,29,106,54,204,108,198,11,102,48,42,192,202,18,61,174,201,68,109,81,45,19,72,168 5 | 168,123,82,151,102,108,66,176,49,30,128,13,182,32,170,157,143,107,3,208,70,159,42,43,203,129,100,149,55,189,8,172,79,152,94,111,26,113,2,10,138,186,131,15,29,160,45,177,201,193,198,194,185,187,24,97,21,65,14,78,67,85,110,207,16 6 | 160,41,103,46,111,52,33,54,142,35,98,28,180,40,214,78,139,86,20,106,22,199,179,166,42,7,55,208,135,10,148,161,65,8,17,56,72,45,164,176,122,165,80,116,102,51,13,32,30,134,49,124,101,75,53,212,133,153,88,15,93,169,138,188,73 7 | 126,177,75,120,154,136,185,106,162,152,143,64,209,68,195,180,42,137,4,127,94,60,79,78,200,48,130,125,24,121,59,62,76,188,171,74,165,29,27,72,8,71,176,164,173,145,77,167,175,26,66,2,189,63,84,196,14,112,214,144,156,100,211,22,107 8 | 169,57,7,186,159,122,107,180,185,16,172,28,53,168,29,31,78,17,106,34,101,137,77,199,133,153,74,120,18,75,196,43,127,54,39,100,69,158,59,150,89,142,86,112,183,65,187,104,198,195,174,166,145,202,190,124,46,192,92,109,19,22,189,149,76 9 | 52,206,123,202,167,194,4,26,90,204,5,213,196,188,12,166,76,171,41,71,62,173,93,182,156,184,25,11,138,154,153,40,130,203,80,77,146,155,113,32,147,59,72,20,108,140,144,37,148,18,201,64,158,88,186,193,22,137,84,28,170,63,57,211,214 10 | 26,93,200,104,107,109,134,74,29,185,16,164,197,19,47,183,201,5,168,39,2,33,59,38,144,174,166,37,179,203,136,17,101,138,15,103,178,189,187,72,7,127,182,204,210,69,186,91,167,146,77,206,58,131,64,78,88,194,24,121,196,28,170,161,45 11 | 128,197,204,211,116,214,206,76,112,143,78,207,150,173,210,17,88,200,169,212,63,90,67,186,35,54,104,125,95,147,45,81,93,53,180,23,131,75,109,19,71,167,187,16,34,56,106,129,87,11,64,121,2,31,111,72,183,166,144,103,49,86,57,70,68 12 | 179,31,41,152,87,81,197,214,181,83,90,50,114,14,40,8,175,191,192,89,189,122,117,97,48,58,79,5,163,26,133,74,3,64,71,155,157,55,59,149,136,109,161,104,128,187,23,186,150,44,195,32,169,15,210,130,173,9,185,141,108,153,199,105,103 13 | 197,80,68,15,170,132,157,117,109,173,180,200,86,38,127,94,188,119,187,105,213,141,4,129,95,8,22,41,177,97,33,189,56,107,59,184,148,72,30,172,161,205,207,193,143,49,26,60,124,167,128,98,138,118,134,11,104,111,162,146,20,202,135,21,53 14 | 211,117,152,199,102,45,91,154,24,175,190,39,161,179,133,71,11,15,63,48,143,1,32,106,18,6,10,76,108,159,178,49,158,180,181,73,122,201,170,128,126,81,206,55,177,36,153,205,23,5,203,84,171,138,83,92,52,173,142,109,43,110,165,156,148 15 | 
61,68,40,204,59,25,108,78,151,143,49,27,92,72,199,158,132,10,198,122,79,54,205,176,214,171,17,71,22,137,195,102,30,124,134,14,38,88,201,111,118,83,162,20,67,207,196,33,37,178,7,60,43,153,8,123,89,16,168,113,47,24,177,28,75 16 | 91,170,97,173,128,143,90,66,120,32,3,122,152,179,191,92,65,79,208,130,134,83,204,67,206,181,61,46,115,101,125,167,71,45,104,22,211,180,33,141,75,58,86,178,51,112,160,201,149,147,88,182,105,27,163,136,119,186,197,202,118,106,31,70,78 17 | 133,25,176,87,43,204,82,211,111,77,42,141,19,173,102,73,156,188,110,50,126,124,31,209,62,143,201,181,105,171,79,112,183,122,195,98,78,116,68,38,163,28,75,137,58,177,64,35,60,119,101,161,134,14,149,121,44,92,90,71,55,16,12,135,192 18 | 98,212,194,35,87,135,97,74,178,143,57,174,214,123,16,144,56,68,115,170,133,33,14,10,101,100,51,2,96,166,137,146,7,208,205,20,32,55,204,108,37,13,210,63,71,27,150,85,151,17,76,132,154,88,189,173,168,19,125,145,152,43,175,140,169 19 | 65,210,112,126,120,167,127,101,154,41,183,8,86,16,186,109,151,179,134,153,166,64,135,38,21,29,20,73,132,211,145,204,40,187,141,182,15,31,30,173,139,56,27,198,103,4,156,5,137,121,106,207,67,194,54,212,68,168,50,143,171,70,100,71,95 20 | 22,164,103,210,79,51,63,28,78,172,47,140,17,204,176,64,154,160,65,83,197,119,111,16,159,39,24,189,29,186,117,4,74,155,123,90,112,150,55,56,138,46,101,135,15,73,94,49,69,143,21,169,80,12,165,136,201,133,115,142,196,202,198,23,141 21 | -------------------------------------------------------------------------------- /util.jl: -------------------------------------------------------------------------------- 1 | 2 | function standardize(data::Matrix) 3 | m,n = size(data) 4 | standardized = zeros(m,n) 5 | mean_vector = zeros(n) 6 | std_vector = zeros(n) 7 | for i = 1:n 8 | mean_vector[i] = mean(data[:,i]) 9 | std_vector[i] = std(data[:,i]) 10 | if std_vector[i] != 0 11 | standardized[:,i] = (data[:,i]-mean_vector[i])./std_vector[i] 12 | elseif mean_vector[i] < 1 && mean_vector[i] >=0 13 | standardized[:,i] = data[:,i] 14 | else 15 | standardized[:,i] = 1 16 | end 17 | end 18 | return standardized::Matrix{Float64}, mean_vector::Vector{Float64}, std_vector::Vector{Float64} 19 | end 20 | 21 | function standardize(data::Matrix, mean_vector::Vector{Float64}, std_vector::Vector{Float64}) 22 | m,n = size(data) 23 | standardized = zeros(m,n) 24 | for i = 1:n 25 | if std_vector[i] != 0 26 | standardized[:,i] = (data[:,i]-mean_vector[i]) ./ std_vector[i] 27 | elseif mean_vector[i] < 1 && mean_vector[i] >=0 28 | standardized[:,i] = data[:,i] 29 | else 30 | standardized[:,i] = 1 31 | end 32 | end 33 | return standardized 34 | end 35 | 36 | 37 | function normalize(X::Matrix) 38 | # normalize to 0 1 39 | r_nrm = 1.0 # range 40 | shift = 0.0 41 | X_max = maximum(X, 1) 42 | X_min = minimum(X, 1) 43 | X_nrm = (r_nrm * broadcast(-, X, X_min) ./ broadcast(-, X_max, X_min)) + shift 44 | 45 | return X_nrm 46 | end 47 | 48 | function k_fold(n::Int, k::Int) 49 | idx = randperm(n) 50 | 51 | # allocate folds 52 | folds = Vector[] 53 | n_f = round(Int, floor(n/k)) 54 | add_f = n % k 55 | j = 1 56 | for i=1:k 57 | if i <= add_f 58 | push!(folds, idx[j:j+n_f]) 59 | j += n_f+1 60 | else 61 | push!(folds, idx[j:j+n_f-1]) 62 | j += n_f 63 | end 64 | end 65 | 66 | return folds::Vector{Vector} 67 | end 68 | 69 | function jaakkola_heuristic(X::Matrix, y::Vector) 70 | n = size(X,1) 71 | Xt = X' # transpose is more efficient 72 | min_dist = zeros(n) 73 | for i=1:n 74 | ds = Inf 75 | for j=1:n 76 | if y[i] != y[j] 77 | d = norm(view(Xt,:,i) - view(Xt,:,j)) 78 | if d < ds 79 | ds = d 80 
| end
81 |             end
82 |         end
83 |         min_dist[i] = ds
84 |     end
85 | 
86 |     sigma = median(min_dist) # median distance to the nearest differently-labeled point
87 |     gamma = 0.5 / (sigma * sigma)
88 | 
89 |     return gamma
90 | end
91 | 
92 | function count_constraints(model::MultiAdversarialModel, y_train::Vector)
93 |     w = model.w
94 |     alpha = model.alpha
95 |     cs = model.constraints
96 | 
97 |     n_train = length(y_train)
98 | 
99 |     # via dec or full: detect whether the first n_train constraints are the initial ground-truth ones
100 |     if cs[1][1] == 1 && cs[1][2] == [y_train[1]]
101 |         optim = :dec
102 |         cs_added = cs[n_train+1:end]
103 |     else
104 |         optim = :full
105 |         cs_added = cs
106 |     end
107 | 
108 |     n_cs_added = length(cs_added)
109 |     const_num = zeros(n_train)
110 |     for i=1:n_train
111 |         const_num[i] = length(find(x::Tuple{Integer, Vector} -> x[1] == i, cs_added))
112 |     end
113 |     max_cs_added = maximum(const_num)
114 | 
115 |     zero_threshold = 1e-3
116 |     active_const_num = zeros(n_train)
117 |     for i=1:n_train
118 |         idx = find(x::Tuple{Integer, Vector} -> x[1] == i, cs_added)
119 |         active_const_num[i] = sum(alpha[idx] .> zero_threshold)
120 |     end
121 | 
122 |     n_cs_active = sum(active_const_num)
123 |     max_cs_active = maximum(active_const_num)
124 |     n_sv = sum(active_const_num .> 0)
125 | 
126 |     return n_cs_added, max_cs_added, n_cs_active, max_cs_active, n_sv
127 | end
128 | 
129 | function count_constraints(model::KernelMultiAdversarialModel, y_train::Vector)
130 |     alpha = model.alpha
131 |     cs = model.constraints
132 | 
133 |     n_train = length(y_train)
134 | 
135 |     # via dec or full: detect whether the first n_train constraints are the initial ground-truth ones
136 |     if cs[1][1] == 1 && cs[1][2] == [y_train[1]]
137 |         optim = :dec
138 |         cs_added = cs[n_train+1:end]
139 |     else
140 |         optim = :full
141 |         cs_added = cs
142 |     end
143 | 
144 |     n_cs_added = length(cs_added)
145 |     const_num = zeros(n_train)
146 |     for i=1:n_train
147 |         const_num[i] = length(find(x::Tuple{Integer, Vector} -> x[1] == i, cs_added))
148 |     end
149 |     max_cs_added = maximum(const_num)
150 | 
151 |     zero_threshold = 1e-3
152 |     active_const_num = zeros(n_train)
153 |     for i=1:n_train
154 |         idx = find(x::Tuple{Integer, Vector} -> x[1] == i, cs_added)
155 |         active_const_num[i] = sum(alpha[idx] .> zero_threshold)
156 |     end
157 | 
158 |     n_cs_active = sum(active_const_num)
159 |     max_cs_active = maximum(active_const_num)
160 |     n_sv = sum(active_const_num .> 0)
161 | 
162 |     return n_cs_added, max_cs_added, n_cs_active, max_cs_active, n_sv
163 | end
164 | 
--------------------------------------------------------------------------------
/shared.jl:
--------------------------------------------------------------------------------
1 | 
2 | function best_psis(psis::Vector) # most violated constraint: the set S maximizing (sum_{j in S} psi_j + |S| - 1) / |S|
3 |     k = length(psis)
4 |     idx = sortperm(psis, rev=true) # sort, get the indices
5 | 
6 |     max_id = 1
7 |     max_val = psis[idx[1]]
8 |     for j=2:k
9 |         v = ( sum(psis[idx[1:j]]) + j - 1 ) / j
10 |         if v >= max_val
11 |             max_val = v
12 |             max_id = j
13 |         else
14 |             break
15 |         end
16 |     end
17 | 
18 |     psi_id = idx[1:max_id]
19 | 
20 |     return psi_id::Vector{Int}, max_val::Float64
21 | end
22 | 
23 | function calc_const(psi_list::Vector, psi_id::Vector)
24 |     j = length(psi_id)
25 |     ret = ( sum(psi_list[psi_id]) + j - 1 ) / j
26 |     return ret::Float64
27 | end
28 | 
29 | 
30 | function calc_cconst(psi_id::Vector)
31 |     n_psi = length(psi_id)
32 |     ret = (n_psi - 1 ) / n_psi
33 |     return ret::Float64
34 | end
35 | 
36 | 
37 | ## get feature indices for class i
38 | function idi(m::Integer, i::Integer)
39 |     return ((i-1)*m+1 : i*m)::UnitRange{Int64}
40 | end
41 | 
42 | function psi_list(w::Vector, X::Matrix, y::Vector, i::Integer, c::Integer, idmi::Vector)
43 |     psis = zeros(c)
44 |     yi = y[i]
45 |     for j=1:c
46 |         if j != yi
47 |             v1 =
dot(w[idmi[j]], view(X, :, i)) 48 | v2 = dot(w[idmi[yi]], -view(X, :, i)) 49 | psis[j] = v1 + v2 50 | end 51 | end 52 | 53 | return psis::Vector{Float64} 54 | end 55 | 56 | ## no y. for prediction 57 | function psi_list(w::Vector, X::Matrix, i::Integer, c::Integer, idmi::Vector) 58 | psis = zeros(c) 59 | for j=1:c 60 | val = dot(w[idmi[j]], view(X, :, i)) 61 | psis[j] = val 62 | end 63 | 64 | return psis::Vector{Float64} 65 | end 66 | 67 | # in terms of dual variable (useful for kernel methods) 68 | function psi_list_dual(alpha::Vector, LPsi::Matrix, i::Integer, c::Integer) 69 | LPsi_i = view(LPsi, (i-1)*c+1 : i*c, :) 70 | psis = -(LPsi_i * alpha) 71 | return psis::Vector{Float64} 72 | end 73 | 74 | # in terms of dual variable (useful for kernel methods) dec 75 | function psi_list_dual(sLa::Vector, i::Integer, c::Integer) 76 | psis = -(sLa[(i-1)*c+1 : i*c]) 77 | return psis::Vector{Float64} 78 | end 79 | 80 | function eta_list(w::Vector, X::Matrix, i::Integer, c::Integer, idmi::Vector) 81 | etas = zeros(c) 82 | for j=1:c 83 | etas[j] = dot(w[idmi[j]], view(X, :, i)) 84 | end 85 | 86 | return etas::Vector{Float64} 87 | end 88 | 89 | function calc_dot(key::Tuple{Integer,Vector,Integer,Vector}, K_ij::Float64, y::Vector) 90 | 91 | i = key[1] 92 | j = key[3] 93 | psi_i = key[2] 94 | psi_j = key[4] 95 | li = length(psi_i) 96 | lj = length(psi_j) 97 | yi = y[i] 98 | yj = y[j] 99 | 100 | mult = 0.0 101 | 102 | inii = yi in psi_i 103 | inij = yi in psi_j 104 | inji = yj in psi_i 105 | injj = yj in psi_j 106 | 107 | if yi == yj 108 | mult += (li - round(Int, inii)) * (lj - round(Int, injj)) 109 | else 110 | if inij 111 | mult -= (li - round(Int, inii)) 112 | end 113 | if inji 114 | mult -= (lj - round(Int, injj)) 115 | end 116 | end 117 | 118 | ii = 1 119 | ij = 1 120 | while ii <= li && ij <= lj 121 | if psi_i[ii] > psi_j[ij] 122 | ij += 1 123 | elseif psi_i[ii] < psi_j[ij] 124 | ii += 1 125 | else # equal 126 | if psi_i[ii] != yi && psi_j[ij] != yj && psi_i[ii] != yj && psi_j[ij] != yi 127 | mult += 1 128 | end 129 | ii += 1 130 | ij += 1 131 | end 132 | end 133 | 134 | d = (mult * K_ij) / (li * lj) 135 | 136 | return d::Float64 137 | end 138 | 139 | function calc_dot(key::Tuple{Integer,Vector,Integer,Vector}, K::Matrix, y::Vector) 140 | i = key[1] 141 | j = key[3] 142 | K_ij = K[i,j] 143 | 144 | d = calc_dot(key, K_ij, y) 145 | 146 | return d::Float64 147 | end 148 | 149 | function calc_dconst(key::Tuple{Integer,Vector}, X::Matrix, y::Vector, n_c::Integer, idmi::Vector) 150 | 151 | i = key[1] 152 | psi_i = key[2] 153 | li = length(psi_i) 154 | yi = y[i] 155 | 156 | xi = view(X, :, i) 157 | 158 | m = length(xi) 159 | dc = zeros(m * n_c) 160 | 161 | inii = yi in psi_i 162 | dc[idmi[yi]] = - ( (li - round(Int, inii)) * xi ) / li 163 | 164 | for ii = 1:li 165 | if psi_i[ii] != yi 166 | dc[idmi[psi_i[ii]]] = xi / li 167 | end 168 | end 169 | 170 | return dc::Vector{Float64} 171 | end 172 | 173 | # calc Lambda * Phi for kernel prediction 174 | function calc_dotlphi(key::Tuple{Integer,Vector,Integer}, K_ij::Float64, y::Vector, n_c::Integer) 175 | i = key[1] 176 | psi_i = key[2] 177 | li = length(psi_i) 178 | yi = y[i] 179 | 180 | mults = zeros(n_c) 181 | for ii in psi_i 182 | if ii == yi 183 | mults[ii] = -(li - 1.0) / li 184 | else 185 | mults[ii] = 1.0 / li 186 | end 187 | end 188 | 189 | if !(yi in psi_i) 190 | mults[yi] = -1.0 191 | end 192 | 193 | return (mults * K_ij)::Vector{Float64} 194 | end 195 | 196 | function calc_dotlphi(key::Tuple{Integer,Vector,Integer}, K::Matrix, y::Vector, 
n_c::Integer) 197 | i = key[1] 198 | j = key[3] 199 | K_ij = K[i,j] 200 | return calc_dotlphi(key, K_ij, y, n_c)::Vector{Float64} 201 | end 202 | -------------------------------------------------------------------------------- /example.jl: -------------------------------------------------------------------------------- 1 | include("types.jl") 2 | include("shared.jl") 3 | include("util.jl") 4 | include("adv_cg.jl") 5 | 6 | ## set 7 | 8 | solver = :mosek 9 | # solver = :gurobi 10 | 11 | log = 0 12 | psdtol = 1e-6 13 | perturb = 1e-12 14 | obj_reltol_cv = 0.0 15 | obj_reltol_test = 0.0 16 | 17 | verbose = false 18 | ### prepare data 19 | 20 | dname = "glass" 21 | D_all = readcsv("data-example/" * dname * ".csv") 22 | id_train = readcsv("data-example/" * dname * ".train") 23 | id_test = readcsv("data-example/" * dname * ".test") 24 | 25 | id_train = round(Int64, id_train) 26 | id_test = round(Int64, id_test) 27 | 28 | println(dname) 29 | 30 | ### Cross Validation, using first split 31 | ## First stage 32 | 33 | id_tr = vec(id_train[1,:]) 34 | id_ts = vec(id_test[1,:]) 35 | X_train = D_all[id_tr,1:end-1] 36 | y_train = round(Int, D_all[id_tr, end]) 37 | 38 | X_test = D_all[id_ts,1:end-1] 39 | y_test = round(Int, D_all[id_ts, end]) 40 | 41 | X_train, mean_vector, std_vector = standardize(X_train) 42 | X_test = standardize(X_test, mean_vector, std_vector) 43 | 44 | Cs = [2.0^i for i=0:3:12] 45 | ncs = length(Cs) 46 | 47 | # fold 48 | n_train = size(X_train, 1) 49 | n_test = size(X_test, 1) 50 | kf = 5 51 | 52 | # k folds 53 | folds = k_fold(n_train, kf) 54 | 55 | loss_list = zeros(ncs) 56 | loss01_list = zeros(ncs) 57 | 58 | # The first stage of CV 59 | idx = randperm(n_train) 60 | X_train = X_train[idx,:] 61 | y_train = y_train[idx] 62 | 63 | for i = 1:ncs 64 | 65 | println(i, " | Adversarial | C = ", string(Cs[i])) 66 | 67 | losses = zeros(n_train) 68 | losses1 = zeros(n_train) 69 | # k fold 70 | for j = 1:kf 71 | # prepare training and validation 72 | id_tr = vcat(folds[[1:j-1; j+1:end]]...) 73 | id_val = folds[j] 74 | 75 | X_tr = X_train[id_tr, :]; y_tr = y_train[id_tr] 76 | X_val = X_train[id_val, :]; y_val = y_train[id_val] 77 | 78 | print(" ",j, "-th fold : ") 79 | @time model = train_adv_cg(X_tr, y_tr, Cs[i], perturb=perturb, obj_reltol=obj_reltol_cv, solver=solver, log=log, psdtol=psdtol, verbose=verbose) 80 | 81 | _, ls, _, ls1, _, _ = test_adv(model, X_val, y_val) 82 | 83 | losses[id_val] = ls 84 | losses1[id_val] = ls1 85 | 86 | end 87 | 88 | loss_list[i] = mean(losses) 89 | loss01_list[i] = mean(losses1) 90 | # println("loss : ", string(mean(losses))) 91 | println(" => loss01 : ", string(mean(losses1))) 92 | println() 93 | 94 | end 95 | 96 | ind_max= indmin(loss01_list) 97 | C0 = Cs[ind_max] 98 | Cs = [C0*2.0^(i-3) for i=1:5] 99 | ncs = length(Cs) 100 | 101 | ## Second stage 102 | idx = randperm(n_train) 103 | X_train = X_train[idx,:] 104 | y_train = y_train[idx] 105 | 106 | for i = 1:ncs 107 | 108 | println(i, " | Adversarial | C = ", string(Cs[i])) 109 | 110 | losses = zeros(n_train) 111 | losses1 = zeros(n_train) 112 | # k fold 113 | for j = 1:kf 114 | # prepare training and validation 115 | id_tr = vcat(folds[[1:j-1; j+1:end]]...) 
116 | id_val = folds[j] 117 | 118 | X_tr = X_train[id_tr, :]; y_tr = y_train[id_tr] 119 | X_val = X_train[id_val, :]; y_val = y_train[id_val] 120 | 121 | print(" ",j, "-th fold : ") 122 | @time model = train_adv_cg(X_tr, y_tr, Cs[i], perturb=perturb, obj_reltol=obj_reltol_cv, solver=solver, log=log, psdtol=psdtol, verbose=verbose) 123 | 124 | _, ls, _, ls1, _, _ = test_adv(model, X_val, y_val) 125 | losses[id_val] = ls 126 | losses1[id_val] = ls1 127 | #println("loss : ", string(ls)) 128 | #println("loss01 : ", string(ls)) 129 | 130 | end 131 | 132 | loss_list[i] = mean(losses) 133 | loss01_list[i] = mean(losses1) 134 | # println("loss : ", string(mean(losses))) 135 | println(" => loss01 : ", string(mean(losses1))) 136 | println() 137 | 138 | end 139 | 140 | ind_max= indmin(loss01_list) 141 | C_best = Cs[ind_max] 142 | 143 | 144 | ### Evaluation 145 | 146 | n_split = size(id_train, 1) 147 | 148 | v_model = Vector{ClassificationModel}() 149 | v_result = Vector{Tuple}() 150 | v_acc = zeros(n_split) 151 | v_cs_result = zeros(n_split, 5) 152 | 153 | for i = 1:n_split 154 | # standardize 155 | id_tr = vec(id_train[i,:]) 156 | id_ts = vec(id_test[i,:]) 157 | X_train = D_all[id_tr,1:end-1] 158 | y_train = round(Int, D_all[id_tr, end]) 159 | 160 | X_test = D_all[id_ts,1:end-1] 161 | y_test = round(Int, D_all[id_ts, end]) 162 | 163 | X_train, mean_vector, std_vector = standardize(X_train) 164 | X_test = standardize(X_test, mean_vector, std_vector) 165 | 166 | #train and test 167 | @time model = train_adv_cg(X_train, y_train, C_best, perturb=perturb, obj_reltol=obj_reltol_test, solver=solver, log=log, psdtol=psdtol, verbose=verbose) 168 | 169 | result = test_adv(model, X_test, y_test) 170 | loss01 = result[3] 171 | acc = 1.0 - loss01 172 | cs_result = count_constraints(model, y_train) 173 | 174 | println("accuracy : ", acc) 175 | 176 | push!(v_model, model) 177 | push!(v_result, result) 178 | v_acc[i] = acc 179 | v_cs_result[i, :] = collect(cs_result) 180 | end 181 | 182 | println(dname) 183 | println("mean accuracy : ", mean(v_acc)) 184 | println("std accuracy : ", std(v_acc)) 185 | -------------------------------------------------------------------------------- /example_primal.jl: -------------------------------------------------------------------------------- 1 | include("types.jl") 2 | include("shared.jl") 3 | include("util.jl") 4 | include("adv_cg.jl") 5 | include("adv_bfgs.jl") 6 | include("adv_sgd.jl") 7 | 8 | ## set 9 | 10 | # alg = :bfgs 11 | alg = :sgd 12 | 13 | # sgd setting 14 | step = 0.1 15 | use_adagrad = true 16 | 17 | verbose = false 18 | 19 | ### prepare data 20 | 21 | dname = "glass" 22 | D_all = readcsv("data-example/" * dname * ".csv") 23 | id_train = readcsv("data-example/" * dname * ".train") 24 | id_test = readcsv("data-example/" * dname * ".test") 25 | 26 | id_train = round(Int64, id_train) 27 | id_test = round(Int64, id_test) 28 | 29 | println(dname) 30 | 31 | ### Cross Validation, using first split 32 | ## First stage 33 | 34 | id_tr = vec(id_train[1,:]) 35 | id_ts = vec(id_test[1,:]) 36 | X_train = D_all[id_tr,1:end-1] 37 | y_train = round(Int, D_all[id_tr, end]) 38 | 39 | X_test = D_all[id_ts,1:end-1] 40 | y_test = round(Int, D_all[id_ts, end]) 41 | 42 | X_train, mean_vector, std_vector = standardize(X_train) 43 | X_test = standardize(X_test, mean_vector, std_vector) 44 | 45 | Cs = [2.0^i for i=0:3:12] 46 | ncs = length(Cs) 47 | 48 | # fold 49 | n_train = size(X_train, 1) 50 | n_test = size(X_test, 1) 51 | kf = 5 52 | 53 | # k folds 54 | folds = k_fold(n_train, kf) 55 | 
56 | loss_list = zeros(ncs) 57 | loss01_list = zeros(ncs) 58 | 59 | # The first stage of CV 60 | idx = randperm(n_train) 61 | X_train = X_train[idx,:] 62 | y_train = y_train[idx] 63 | 64 | for i = 1:ncs 65 | 66 | println(i, " | Adversarial | C = ", string(Cs[i])) 67 | 68 | losses = zeros(n_train) 69 | losses1 = zeros(n_train) 70 | # k fold 71 | for j = 1:kf 72 | # prepare training and validation 73 | id_tr = vcat(folds[[1:j-1; j+1:end]]...) 74 | id_val = folds[j] 75 | 76 | X_tr = X_train[id_tr, :]; y_tr = y_train[id_tr] 77 | X_val = X_train[id_val, :]; y_val = y_train[id_val] 78 | 79 | print(" ",j, "-th fold : ") 80 | tol = Cs[i] * length(y_tr) * 1e-6 81 | if alg == :bfgs 82 | @time model = train_adv_bfgs(X_tr, y_tr, Cs[i], ftol = tol, grtol = tol, max_iter=1000, show_trace = verbose) 83 | elseif alg == :sgd 84 | @time model = train_adv_sgd(X_tr, y_tr, Cs[i], ftol = tol, step = step, use_adagrad = use_adagrad, 85 | grtol = tol, max_iter=10000, show_trace = verbose) 86 | end 87 | 88 | _, ls, _, ls1, _, _ = test_adv(model, X_val, y_val) 89 | 90 | losses[id_val] = ls 91 | losses1[id_val] = ls1 92 | 93 | end 94 | 95 | loss_list[i] = mean(losses) 96 | loss01_list[i] = mean(losses1) 97 | # println("loss : ", string(mean(losses))) 98 | println(" => loss01 : ", string(mean(losses1))) 99 | println() 100 | 101 | end 102 | 103 | ind_max= indmin(loss01_list) 104 | C0 = Cs[ind_max] 105 | Cs = [C0*2.0^(i-3) for i=1:5] 106 | ncs = length(Cs) 107 | 108 | ## Second stage 109 | idx = randperm(n_train) 110 | X_train = X_train[idx,:] 111 | y_train = y_train[idx] 112 | 113 | for i = 1:ncs 114 | 115 | println(i, " | Adversarial | C = ", string(Cs[i])) 116 | 117 | losses = zeros(n_train) 118 | losses1 = zeros(n_train) 119 | # k fold 120 | for j = 1:kf 121 | # prepare training and validation 122 | id_tr = vcat(folds[[1:j-1; j+1:end]]...) 
123 | id_val = folds[j] 124 | 125 | X_tr = X_train[id_tr, :]; y_tr = y_train[id_tr] 126 | X_val = X_train[id_val, :]; y_val = y_train[id_val] 127 | 128 | print(" ",j, "-th fold : ") 129 | tol = Cs[i] * length(y_tr) * 1e-6 130 | if alg == :bfgs 131 | @time model = train_adv_bfgs(X_tr, y_tr, Cs[i], ftol = tol, grtol = tol, max_iter=1000, show_trace = verbose) 132 | elseif alg == :sgd 133 | @time model = train_adv_sgd(X_tr, y_tr, Cs[i], ftol = tol, step = step, use_adagrad = use_adagrad, 134 | grtol = tol, max_iter=10000, show_trace = verbose) 135 | end 136 | 137 | _, ls, _, ls1, _, _ = test_adv(model, X_val, y_val) 138 | losses[id_val] = ls 139 | losses1[id_val] = ls1 140 | #println("loss : ", string(ls)) 141 | #println("loss01 : ", string(ls)) 142 | 143 | end 144 | 145 | loss_list[i] = mean(losses) 146 | loss01_list[i] = mean(losses1) 147 | # println("loss : ", string(mean(losses))) 148 | println(" => loss01 : ", string(mean(losses1))) 149 | println() 150 | 151 | end 152 | 153 | ind_max= indmin(loss01_list) 154 | C_best = Cs[ind_max] 155 | 156 | 157 | ### Evaluation 158 | 159 | n_split = size(id_train, 1) 160 | 161 | v_model = Vector{ClassificationModel}() 162 | v_result = Vector{Tuple}() 163 | v_acc = zeros(n_split) 164 | 165 | for i = 1:n_split 166 | # standardize 167 | id_tr = vec(id_train[i,:]) 168 | id_ts = vec(id_test[i,:]) 169 | X_train = D_all[id_tr,1:end-1] 170 | y_train = round(Int, D_all[id_tr, end]) 171 | 172 | X_test = D_all[id_ts,1:end-1] 173 | y_test = round(Int, D_all[id_ts, end]) 174 | 175 | X_train, mean_vector, std_vector = standardize(X_train) 176 | X_test = standardize(X_test, mean_vector, std_vector) 177 | 178 | #train and test 179 | tol = C_best * length(y_train) * 1e-6 180 | if alg == :bfgs 181 | @time model = train_adv_bfgs(X_train, y_train, C_best, ftol = tol, grtol = tol, max_iter=1000, show_trace = verbose) 182 | elseif alg == :sgd 183 | @time model = train_adv_sgd(X_train, y_train, C_best, ftol = tol, step = step, use_adagrad = use_adagrad, 184 | grtol = tol, max_iter=10000, show_trace = verbose) 185 | end 186 | 187 | 188 | result = test_adv(model, X_test, y_test) 189 | loss01 = result[3] 190 | acc = 1.0 - loss01 191 | 192 | println("accuracy : ", acc) 193 | 194 | push!(v_model, model) 195 | push!(v_result, result) 196 | v_acc[i] = acc 197 | end 198 | 199 | println(dname) 200 | println("mean accuracy : ", mean(v_acc)) 201 | println("std accuracy : ", std(v_acc)) 202 | -------------------------------------------------------------------------------- /example_kernel.jl: -------------------------------------------------------------------------------- 1 | include("types.jl") 2 | include("shared.jl") 3 | include("util.jl") 4 | include("kernel.jl") 5 | include("adv_kernel_cg.jl") 6 | 7 | ## set 8 | solver = :mosek 9 | # solver = :gurobi 10 | 11 | log = 0 12 | psdtol = 1e-6 13 | perturb = 1e-12 14 | obj_reltol_cv = 0.0 15 | obj_reltol_test = 0.0 16 | verbose = false 17 | 18 | smaller_cv = false # 3 instead of 5 19 | 20 | ### prepare data 21 | 22 | dname = "glass" 23 | D_all = readcsv("data-example/" * dname * ".csv") 24 | id_train = readcsv("data-example/" * dname * ".train") 25 | id_test = readcsv("data-example/" * dname * ".test") 26 | 27 | id_train = round(Int64, id_train) 28 | id_test = round(Int64, id_test) 29 | 30 | println(dname) 31 | 32 | ### Cross Validation, using first split 33 | ## First stage 34 | 35 | id_tr = vec(id_train[1,:]) 36 | id_ts = vec(id_test[1,:]) 37 | X_train = D_all[id_tr,1:end-1] 38 | y_train = round(Int, D_all[id_tr, end]) 39 | 40 | X_test = 
D_all[id_ts,1:end-1] 41 | y_test = round(Int, D_all[id_ts, end]) 42 | 43 | X_train, mean_vector, std_vector = standardize(X_train) 44 | X_test = standardize(X_test, mean_vector, std_vector) 45 | 46 | if smaller_cv 47 | Cs = [2.0^i for i=0:4:8] 48 | Gs = [2.0^(-8+i) for i=0:4:8] 49 | else 50 | Cs = [2.0^i for i=0:3:12] 51 | Gs = [2.0^(-12+i) for i=0:3:12] 52 | end 53 | 54 | ncs = length(Cs) 55 | 56 | Pars = [ Tuple{Float64,Float64}((Cs[i], Gs[j])) for i=1:ncs, j=1:ncs ] 57 | Pars = vec(Pars) 58 | npar = length(Pars) 59 | 60 | # fold 61 | n_train = size(X_train, 1) 62 | n_test = size(X_test, 1) 63 | kf = 5 64 | 65 | # k folds 66 | folds = k_fold(n_train, kf) 67 | 68 | loss_list = zeros(npar) 69 | loss01_list = zeros(npar) 70 | 71 | # The first stage of CV 72 | idx = randperm(n_train) 73 | X_train = X_train[idx,:] 74 | y_train = y_train[idx] 75 | 76 | println("First CV") 77 | 78 | for i = 1:npar 79 | 80 | println(i, " | Adversarial | C = ", Pars[i][1], ", Gamma = ", Pars[i][2]) 81 | 82 | losses = zeros(n_train) 83 | losses1 = zeros(n_train) 84 | # k fold 85 | for j = 1:kf 86 | # prepare training and validation 87 | id_tr = vcat(folds[[1:j-1; j+1:end]]...) 88 | id_val = folds[j] 89 | 90 | X_tr = X_train[id_tr, :]; y_tr = y_train[id_tr] 91 | X_val = X_train[id_val, :]; y_val = y_train[id_val] 92 | 93 | C = Pars[i][1] 94 | gamma = Pars[i][2] 95 | 96 | print(" ",j, "-th fold : ") 97 | @time model = train_adv_kernel_cg(X_tr, y_tr, C, :gaussian, [gamma], perturb=perturb, obj_reltol=obj_reltol_cv, solver=solver, log=log, psdtol=psdtol, verbose=verbose) 98 | 99 | _, ls, _, ls1, _, _ = test_adv_kernel(model, X_val, y_val, X_tr, y_tr) 100 | 101 | losses[id_val] = ls 102 | losses1[id_val] = ls1 103 | 104 | end 105 | 106 | loss_list[i] = mean(losses) 107 | loss01_list[i] = mean(losses1) 108 | # println("loss : ", string(mean(losses))) 109 | println(" => loss01 : ", string(mean(losses1))) 110 | 111 | end 112 | 113 | ind_max= indmin(loss01_list) 114 | C0 = Pars[ind_max][1] 115 | G0 = Pars[ind_max][2] 116 | if smaller_cv 117 | Cs = [C0*2.0^(i-2) for i=1:3] 118 | Gs = [G0*2.0^(i-2) for i=1:3] 119 | else 120 | Cs = [C0*2.0^(i-3) for i=1:5] 121 | Gs = [G0*2.0^(i-3) for i=1:5] 122 | end 123 | ncs = length(Cs) 124 | 125 | Pars = [ Tuple{Float64,Float64}((Cs[i], Gs[j])) for i=1:ncs, j=1:ncs ] 126 | Pars = vec(Pars) 127 | npar = length(Pars) 128 | 129 | ## Second stage 130 | idx = randperm(n_train) 131 | X_train = X_train[idx,:] 132 | y_train = y_train[idx] 133 | 134 | println("Second CV") 135 | for i = 1:npar 136 | 137 | println(i, " | Adversarial | C = ", Pars[i][1], ", Gamma = ", Pars[i][2]) 138 | 139 | losses = zeros(n_train) 140 | losses1 = zeros(n_train) 141 | # k fold 142 | for j = 1:kf 143 | # prepare training and validation 144 | id_tr = vcat(folds[[1:j-1; j+1:end]]...) 
145 | id_val = folds[j] 146 | 147 | X_tr = X_train[id_tr, :]; y_tr = y_train[id_tr] 148 | X_val = X_train[id_val, :]; y_val = y_train[id_val] 149 | 150 | C = Pars[i][1] 151 | gamma = Pars[i][2] 152 | 153 | print(" ",j, "-th fold : ") 154 | @time model = train_adv_kernel_cg(X_tr, y_tr, C, :gaussian, [gamma], perturb=perturb, obj_reltol=obj_reltol_cv, solver=solver, log=log, psdtol=psdtol, verbose=verbose) 155 | 156 | _, ls, _, ls1, _, _ = test_adv_kernel(model, X_val, y_val, X_tr, y_tr) 157 | 158 | losses[id_val] = ls 159 | losses1[id_val] = ls1 160 | #println("loss : ", string(ls)) 161 | #println("loss01 : ", string(ls)) 162 | 163 | end 164 | 165 | loss_list[i] = mean(losses) 166 | loss01_list[i] = mean(losses1) 167 | #println("loss : ", string(mean(losses))) 168 | println(" => loss01 : ", string(mean(losses1))) 169 | 170 | end 171 | 172 | ind_max= indmin(loss01_list) 173 | C_best = Pars[ind_max][1] 174 | G_best = Pars[ind_max][2] 175 | 176 | ### Evaluation 177 | 178 | n_split = size(id_train, 1) 179 | 180 | v_model = Vector{ClassificationModel}() 181 | v_result = Vector{Tuple}() 182 | v_acc = zeros(n_split) 183 | v_cs_result = zeros(n_split, 5) 184 | 185 | println("Evaluation") 186 | for i = 1:n_split 187 | # standardize 188 | id_tr = vec(id_train[i,:]) 189 | id_ts = vec(id_test[i,:]) 190 | X_train = D_all[id_tr,1:end-1] 191 | y_train = round(Int, D_all[id_tr, end]) 192 | 193 | X_test = D_all[id_ts,1:end-1] 194 | y_test = round(Int, D_all[id_ts, end]) 195 | 196 | X_train, mean_vector, std_vector = standardize(X_train) 197 | X_test = standardize(X_test, mean_vector, std_vector) 198 | 199 | #train and test 200 | @time model = train_adv_kernel_cg(X_train, y_train, C_best, :gaussian, [G_best], perturb=perturb, obj_reltol=obj_reltol_test, solver=solver, log=log, psdtol=psdtol, verbose=verbose) 201 | 202 | result = test_adv_kernel(model, X_test, y_test, X_train, y_train) 203 | loss01 = result[3] 204 | acc = 1.0 - loss01 205 | cs_result = count_constraints(model, y_train) 206 | 207 | println("accuracy : ", acc) 208 | 209 | push!(v_model, model) 210 | push!(v_result, result) 211 | v_acc[i] = acc 212 | v_cs_result[i, :] = collect(cs_result) 213 | end 214 | 215 | println(dname) 216 | println("mean accuracy : ", mean(v_acc)) 217 | println("std accuracy : ", std(v_acc)) 218 | -------------------------------------------------------------------------------- /data-example/glass.train: -------------------------------------------------------------------------------- 1 | 78,28,199,116,79,129,39,178,112,7,145,46,70,69,50,100,81,30,206,122,75,115,102,72,107,64,21,59,48,98,38,57,134,162,190,29,19,40,193,109,24,77,93,117,160,23,36,82,136,195,26,189,214,90,141,56,108,92,174,88,187,9,105,76,155,131,87,73,158,211,118,171,197,68,164,196,18,61,16,142,5,31,146,110,80,41,1,6,172,205,91,58,143,27,120,11,32,133,176,128,119,177,132,25,161,33,154,84,10,191,104,202,207,170,34,14,55,165,43,184,65,181,20,138,148,99,3,89,13,35,47,150,209,42,173,168,139,2,157,49,52,62,213,54,51,124,125,156,135 2 | 214,53,182,125,55,95,66,98,209,1,189,96,177,23,140,105,88,111,170,202,9,83,154,60,132,126,104,93,29,34,41,213,208,67,49,28,145,24,168,210,113,173,3,75,33,212,200,190,117,61,102,17,77,54,175,194,160,198,51,11,44,31,199,181,5,43,187,4,165,139,120,129,110,19,89,166,79,188,116,45,108,18,128,153,146,130,156,124,167,42,109,86,6,63,58,204,84,101,196,46,20,192,103,16,59,144,179,158,97,74,87,134,80,155,78,52,118,123,161,85,65,119,112,203,122,133,183,148,39,176,92,150,32,159,48,197,142,13,94,91,147,50,71,7,169,195,30,121,72 3 | 
135,36,109,208,89,207,47,187,99,39,200,123,206,28,62,151,165,52,71,77,74,32,201,18,1,183,15,191,180,194,7,92,154,56,69,27,10,63,97,33,184,170,141,73,90,67,29,129,115,143,146,35,126,210,98,142,117,186,120,84,131,3,25,41,121,88,159,45,108,2,119,37,196,198,155,49,110,147,174,212,38,83,124,214,85,93,160,114,138,17,101,13,48,5,40,66,140,91,205,103,75,16,179,31,209,136,82,167,34,185,61,102,6,172,11,203,12,134,104,113,132,70,169,171,24,204,192,177,199,181,193,42,72,55,168,118,190,157,94,139,164,4,95,53,9,153,133,175,178 4 | 69,13,33,15,127,85,136,119,187,181,22,121,59,164,167,97,188,148,145,41,177,63,27,199,43,51,94,212,120,147,153,50,34,124,129,75,91,207,44,49,156,96,152,134,4,206,95,171,126,62,176,107,158,157,76,197,100,66,183,116,47,122,86,57,173,88,146,80,70,92,200,71,21,182,137,115,149,7,110,150,191,213,169,20,9,23,189,26,74,35,118,160,60,138,99,196,56,139,28,211,90,24,3,186,39,125,25,12,190,40,195,159,184,117,193,82,178,113,123,205,31,155,140,14,1,77,52,180,112,30,87,170,104,162,5,166,144,46,89,161,194,53,114,32,163,165,6,111,73 5 | 64,148,124,88,165,213,89,112,162,188,136,71,214,86,178,19,68,134,145,75,9,191,48,81,98,206,80,153,37,36,212,175,103,117,202,23,144,132,57,50,4,140,25,118,211,146,158,41,114,154,31,44,127,141,183,95,209,105,58,126,54,179,56,180,133,72,195,164,130,181,76,47,161,17,156,77,93,139,137,147,74,205,204,197,12,115,192,38,92,20,6,87,59,210,52,163,27,51,106,35,73,46,120,83,84,91,96,61,1,190,11,155,101,40,119,196,150,174,39,125,184,62,121,33,199,34,200,173,5,90,99,142,7,22,122,18,167,171,169,116,53,104,69,60,28,63,166,109,135 6 | 87,76,151,157,61,119,89,84,110,202,171,120,147,209,115,68,23,37,97,201,12,162,83,47,125,127,19,197,69,29,170,57,71,210,36,90,38,44,144,156,145,137,70,146,25,105,187,195,66,112,131,31,95,184,123,182,174,79,192,63,109,183,141,64,9,126,155,150,2,113,194,26,58,4,204,190,193,48,96,128,136,211,5,11,158,94,16,196,107,152,149,198,154,173,74,59,203,117,50,168,175,172,3,207,77,181,43,178,67,62,130,200,24,34,21,14,140,92,163,100,81,6,104,18,108,185,82,186,27,121,206,129,177,91,132,143,1,213,205,159,60,167,114,99,189,39,85,118,191 7 | 41,163,38,43,153,124,158,80,212,159,150,142,40,116,18,3,140,192,92,54,213,182,178,21,157,105,55,32,138,119,139,86,118,210,23,207,123,67,197,204,61,1,114,35,202,20,91,37,101,44,168,198,87,131,53,181,186,89,81,31,206,34,69,102,45,110,146,111,12,50,172,169,104,103,6,109,149,208,187,108,129,70,83,19,10,193,17,97,15,58,199,160,135,33,122,190,194,166,132,141,36,7,5,28,113,203,52,39,73,88,30,183,96,65,205,51,115,155,148,184,46,57,170,191,11,99,174,117,56,147,133,49,93,16,161,25,201,151,85,134,47,90,95,13,82,98,128,9,179 8 | 201,148,108,119,206,14,50,152,15,52,111,99,161,3,20,139,154,41,128,209,32,135,181,123,88,42,13,200,131,85,45,167,118,93,30,116,40,182,114,2,73,208,160,126,35,170,141,26,68,91,203,146,51,79,175,83,213,155,80,72,36,191,95,71,62,125,49,44,211,63,136,96,81,163,61,25,38,55,162,9,21,97,11,87,184,188,113,23,47,157,6,12,117,143,82,193,147,58,177,129,5,178,4,60,102,105,67,33,98,70,165,214,197,66,115,1,151,212,84,37,121,176,48,173,204,194,207,134,90,130,103,164,179,27,210,132,110,144,140,24,64,94,171,8,156,56,10,205,138 9 | 
39,67,19,198,87,94,187,125,81,69,74,112,145,133,152,111,126,13,104,8,43,14,92,55,159,82,78,168,174,134,189,121,70,183,86,79,85,65,33,83,129,169,139,151,149,96,205,30,143,17,97,197,150,38,178,195,98,135,51,120,177,105,49,141,34,103,190,10,161,1,164,172,180,117,191,179,209,101,176,15,127,91,68,160,212,29,56,115,47,31,24,16,200,119,124,157,175,7,27,21,128,6,163,162,95,181,122,192,207,73,118,54,114,110,102,58,208,165,142,44,61,46,2,23,136,89,99,48,100,9,109,210,185,116,36,66,60,199,3,42,45,107,131,132,53,75,35,50,106 10 | 160,142,4,193,150,85,113,176,191,152,158,149,192,162,114,184,130,63,83,10,82,48,169,172,128,31,140,54,208,53,106,124,57,42,32,14,40,110,80,27,120,22,73,181,145,52,129,97,8,117,60,115,67,123,46,87,165,105,100,9,155,12,177,30,157,11,154,132,156,211,35,188,108,23,3,96,65,102,143,122,44,180,153,41,89,133,198,209,56,159,202,112,139,125,62,71,68,195,13,25,126,173,99,141,151,81,135,51,92,147,207,95,36,111,6,90,70,55,212,86,213,34,18,75,84,119,43,94,61,21,1,118,66,148,190,98,50,171,163,20,175,137,199,205,116,49,214,79,76 11 | 194,191,114,50,110,10,94,85,20,132,44,84,141,140,12,181,171,201,145,113,66,149,170,51,27,175,41,98,177,136,100,178,77,91,122,1,21,162,119,97,134,99,133,39,158,127,137,9,124,89,60,160,25,179,199,4,161,82,123,155,58,24,47,22,176,7,15,48,148,18,43,42,193,102,213,3,80,146,174,126,189,168,165,188,96,61,59,40,182,79,135,130,107,205,52,195,172,46,164,154,152,30,28,198,6,202,151,153,163,184,120,105,115,8,37,196,118,33,142,185,5,38,92,156,190,108,138,209,83,62,157,13,73,117,203,65,29,159,14,55,69,208,36,139,26,32,192,101,74 12 | 92,30,113,184,118,132,160,137,88,182,142,35,19,47,75,49,51,20,34,112,180,37,194,101,206,200,76,183,16,24,107,36,60,126,62,57,84,204,120,196,11,73,17,129,12,123,52,98,158,53,28,209,91,190,159,45,139,80,174,10,154,78,69,213,54,61,95,6,178,68,77,135,188,148,29,144,115,151,21,86,46,25,110,177,134,100,42,143,202,65,18,156,43,38,4,145,212,201,205,33,168,94,7,70,121,127,102,208,27,93,193,85,124,13,170,211,72,67,22,106,2,207,167,147,119,198,1,99,63,131,140,116,203,176,164,39,125,82,146,66,56,172,138,165,111,166,171,96,162 13 | 28,54,62,66,17,196,114,12,16,103,88,18,57,126,199,92,125,58,35,40,76,52,174,185,90,192,37,136,19,145,91,96,210,48,24,194,121,139,201,10,44,176,47,149,36,82,46,69,61,3,115,214,195,1,7,137,65,74,163,101,147,79,27,154,142,93,140,131,171,13,209,155,144,63,122,55,71,182,178,39,51,166,84,113,153,168,204,34,102,133,164,208,43,31,183,165,110,212,179,100,83,203,2,151,159,70,112,211,75,175,99,67,14,191,89,186,6,156,152,77,78,85,150,130,190,198,42,116,120,45,25,169,108,206,106,9,23,160,181,73,123,81,32,64,50,158,5,87,29 14 | 209,47,12,118,104,72,162,2,112,187,26,31,137,195,98,114,29,136,210,139,85,50,196,145,97,37,193,123,197,164,54,25,113,59,150,213,149,42,167,200,185,172,157,67,62,8,27,129,3,16,44,74,94,202,204,40,176,151,4,19,90,141,189,70,56,135,46,183,168,116,35,99,53,28,208,77,182,115,33,30,61,89,41,14,169,107,186,75,121,174,132,78,69,95,188,66,80,13,21,207,86,131,79,144,198,155,134,212,147,214,65,130,7,38,51,58,160,64,34,120,9,125,57,194,93,60,119,184,103,105,87,166,163,191,124,140,192,17,82,127,96,22,20,88,68,100,111,101,146 15 | 
44,165,58,211,2,34,173,136,82,166,76,200,57,150,156,119,164,170,163,185,186,18,90,135,110,19,65,66,50,172,9,41,197,120,184,97,29,11,188,129,74,21,106,175,15,191,117,56,192,46,127,130,210,159,139,109,180,4,26,85,73,31,128,131,203,42,36,179,105,125,98,114,140,194,167,80,95,149,69,91,213,174,94,13,53,182,6,181,155,202,99,115,3,12,212,209,190,52,121,1,35,100,48,93,81,101,142,147,23,112,169,104,86,187,152,77,193,62,157,87,208,107,70,103,148,63,141,206,45,64,160,161,5,154,189,55,116,96,183,126,144,138,51,146,32,84,39,145,133 16 | 154,132,133,174,142,74,123,2,24,73,156,153,117,47,212,21,64,162,52,110,81,57,214,20,99,155,172,159,213,13,18,175,9,14,187,108,7,190,109,1,54,49,144,189,151,209,146,103,59,113,102,126,95,8,39,157,96,37,56,48,11,210,12,165,150,29,169,26,89,80,185,87,207,195,176,158,121,76,196,85,62,93,194,145,15,139,116,135,50,200,77,23,148,68,60,98,28,199,164,129,5,4,205,184,16,55,137,166,107,72,111,36,161,53,94,177,69,171,188,168,138,193,42,44,192,25,43,124,82,17,40,198,63,127,203,35,38,34,84,131,30,114,183,100,140,6,10,19,41 17 | 212,63,187,146,9,61,8,59,178,210,95,36,48,189,202,81,191,168,205,93,154,129,170,113,11,167,83,100,70,214,72,125,106,179,15,99,152,203,74,153,120,194,136,89,49,138,145,29,52,4,150,198,84,131,85,76,51,103,40,33,174,21,196,208,97,91,27,6,32,34,46,54,3,162,47,184,130,155,107,147,20,26,151,159,128,56,57,5,207,13,17,7,41,94,118,66,190,45,65,172,186,164,39,117,142,114,199,24,197,80,158,132,10,213,200,185,88,123,165,160,139,23,96,67,127,86,104,115,30,69,108,140,206,53,180,166,182,18,148,169,175,144,109,22,157,193,37,1,2 18 | 81,59,211,159,122,183,120,38,42,153,186,148,114,83,172,4,167,94,136,103,119,130,22,112,75,62,39,161,184,70,196,191,160,203,207,142,198,26,30,90,195,193,45,109,162,6,49,199,131,53,147,64,77,164,187,34,21,3,110,206,180,23,40,202,149,84,213,177,190,41,36,105,129,201,12,117,171,141,200,31,128,46,102,126,24,158,111,95,78,188,58,44,80,124,66,79,139,107,29,138,179,113,163,121,54,52,18,192,134,86,69,9,15,72,157,106,65,89,99,118,93,61,155,116,165,182,181,209,11,82,176,50,1,25,5,185,28,92,104,60,47,156,48,8,127,73,91,67,197 19 | 17,125,2,123,131,115,12,196,36,10,184,148,45,77,162,108,87,80,6,130,76,57,164,92,81,59,200,9,22,170,34,23,74,72,46,163,214,43,157,114,129,142,75,144,206,185,48,118,172,33,52,140,209,174,19,203,150,97,136,7,175,193,190,90,96,24,107,189,98,32,146,37,188,25,158,79,161,180,191,165,110,205,28,14,105,84,202,177,83,13,159,147,51,63,133,91,192,195,55,124,169,88,39,117,128,93,152,181,85,42,89,53,208,82,197,18,102,149,69,11,116,66,213,78,178,160,94,47,122,61,176,35,201,26,3,62,99,199,104,111,113,155,138,44,49,1,119,58,60 20 | 126,33,75,187,129,137,185,99,59,45,152,175,157,161,72,139,199,52,171,148,2,86,151,66,1,40,167,26,5,121,193,213,20,173,89,114,110,191,95,36,194,131,53,35,8,30,162,181,145,214,104,41,118,206,113,67,81,212,48,106,84,107,105,144,42,85,32,182,13,98,102,170,34,120,27,125,93,37,156,192,146,203,25,82,124,116,200,178,14,108,190,3,188,60,147,68,179,207,127,58,128,57,122,174,208,149,6,96,19,92,71,195,77,132,205,10,209,153,61,43,109,184,44,180,134,18,100,87,11,54,158,97,76,7,38,31,211,166,183,88,168,70,91,9,130,177,50,62,163 21 | -------------------------------------------------------------------------------- /data-example/glass.csv: -------------------------------------------------------------------------------- 1 | 1.52101,13.64,4.49,1.1,71.78,0.06,8.75,0,0,1 2 | 1.51761,13.89,3.6,1.36,72.73,0.48,7.83,0,0,1 3 | 1.51618,13.53,3.55,1.54,72.99,0.39,7.78,0,0,1 4 | 1.51766,13.21,3.69,1.29,72.61,0.57,8.22,0,0,1 5 | 
1.51742,13.27,3.62,1.24,73.08,0.55,8.07,0,0,1 6 | 1.51596,12.79,3.61,1.62,72.97,0.64,8.07,0,0.26,1 7 | 1.51743,13.3,3.6,1.14,73.09,0.58,8.17,0,0,1 8 | 1.51756,13.15,3.61,1.05,73.24,0.57,8.24,0,0,1 9 | 1.51918,14.04,3.58,1.37,72.08,0.56,8.3,0,0,1 10 | 1.51755,13,3.6,1.36,72.99,0.57,8.4,0,0.11,1 11 | 1.51571,12.72,3.46,1.56,73.2,0.67,8.09,0,0.24,1 12 | 1.51763,12.8,3.66,1.27,73.01,0.6,8.56,0,0,1 13 | 1.51589,12.88,3.43,1.4,73.28,0.69,8.05,0,0.24,1 14 | 1.51748,12.86,3.56,1.27,73.21,0.54,8.38,0,0.17,1 15 | 1.51763,12.61,3.59,1.31,73.29,0.58,8.5,0,0,1 16 | 1.51761,12.81,3.54,1.23,73.24,0.58,8.39,0,0,1 17 | 1.51784,12.68,3.67,1.16,73.11,0.61,8.7,0,0,1 18 | 1.52196,14.36,3.85,0.89,71.36,0.15,9.15,0,0,1 19 | 1.51911,13.9,3.73,1.18,72.12,0.06,8.89,0,0,1 20 | 1.51735,13.02,3.54,1.69,72.73,0.54,8.44,0,0.07,1 21 | 1.5175,12.82,3.55,1.49,72.75,0.54,8.52,0,0.19,1 22 | 1.51966,14.77,3.75,0.29,72.02,0.03,9,0,0,1 23 | 1.51736,12.78,3.62,1.29,72.79,0.59,8.7,0,0,1 24 | 1.51751,12.81,3.57,1.35,73.02,0.62,8.59,0,0,1 25 | 1.5172,13.38,3.5,1.15,72.85,0.5,8.43,0,0,1 26 | 1.51764,12.98,3.54,1.21,73,0.65,8.53,0,0,1 27 | 1.51793,13.21,3.48,1.41,72.64,0.59,8.43,0,0,1 28 | 1.51721,12.87,3.48,1.33,73.04,0.56,8.43,0,0,1 29 | 1.51768,12.56,3.52,1.43,73.15,0.57,8.54,0,0,1 30 | 1.51784,13.08,3.49,1.28,72.86,0.6,8.49,0,0,1 31 | 1.51768,12.65,3.56,1.3,73.08,0.61,8.69,0,0.14,1 32 | 1.51747,12.84,3.5,1.14,73.27,0.56,8.55,0,0,1 33 | 1.51775,12.85,3.48,1.23,72.97,0.61,8.56,0.09,0.22,1 34 | 1.51753,12.57,3.47,1.38,73.39,0.6,8.55,0,0.06,1 35 | 1.51783,12.69,3.54,1.34,72.95,0.57,8.75,0,0,1 36 | 1.51567,13.29,3.45,1.21,72.74,0.56,8.57,0,0,1 37 | 1.51909,13.89,3.53,1.32,71.81,0.51,8.78,0.11,0,1 38 | 1.51797,12.74,3.48,1.35,72.96,0.64,8.68,0,0,1 39 | 1.52213,14.21,3.82,0.47,71.77,0.11,9.57,0,0,1 40 | 1.52213,14.21,3.82,0.47,71.77,0.11,9.57,0,0,1 41 | 1.51793,12.79,3.5,1.12,73.03,0.64,8.77,0,0,1 42 | 1.51755,12.71,3.42,1.2,73.2,0.59,8.64,0,0,1 43 | 1.51779,13.21,3.39,1.33,72.76,0.59,8.59,0,0,1 44 | 1.5221,13.73,3.84,0.72,71.76,0.17,9.74,0,0,1 45 | 1.51786,12.73,3.43,1.19,72.95,0.62,8.76,0,0.3,1 46 | 1.519,13.49,3.48,1.35,71.95,0.55,9,0,0,1 47 | 1.51869,13.19,3.37,1.18,72.72,0.57,8.83,0,0.16,1 48 | 1.52667,13.99,3.7,0.71,71.57,0.02,9.82,0,0.1,1 49 | 1.52223,13.21,3.77,0.79,71.99,0.13,10.02,0,0,1 50 | 1.51898,13.58,3.35,1.23,72.08,0.59,8.91,0,0,1 51 | 1.5232,13.72,3.72,0.51,71.75,0.09,10.06,0,0.16,1 52 | 1.51926,13.2,3.33,1.28,72.36,0.6,9.14,0,0.11,1 53 | 1.51808,13.43,2.87,1.19,72.84,0.55,9.03,0,0,1 54 | 1.51837,13.14,2.84,1.28,72.85,0.55,9.07,0,0,1 55 | 1.51778,13.21,2.81,1.29,72.98,0.51,9.02,0,0.09,1 56 | 1.51769,12.45,2.71,1.29,73.7,0.56,9.06,0,0.24,1 57 | 1.51215,12.99,3.47,1.12,72.98,0.62,8.35,0,0.31,1 58 | 1.51824,12.87,3.48,1.29,72.95,0.6,8.43,0,0,1 59 | 1.51754,13.48,3.74,1.17,72.99,0.59,8.03,0,0,1 60 | 1.51754,13.39,3.66,1.19,72.79,0.57,8.27,0,0.11,1 61 | 1.51905,13.6,3.62,1.11,72.64,0.14,8.76,0,0,1 62 | 1.51977,13.81,3.58,1.32,71.72,0.12,8.67,0.69,0,1 63 | 1.52172,13.51,3.86,0.88,71.79,0.23,9.54,0,0.11,1 64 | 1.52227,14.17,3.81,0.78,71.35,0,9.69,0,0,1 65 | 1.52172,13.48,3.74,0.9,72.01,0.18,9.61,0,0.07,1 66 | 1.52099,13.69,3.59,1.12,71.96,0.09,9.4,0,0,1 67 | 1.52152,13.05,3.65,0.87,72.22,0.19,9.85,0,0.17,1 68 | 1.52152,13.05,3.65,0.87,72.32,0.19,9.85,0,0.17,1 69 | 1.52152,13.12,3.58,0.9,72.2,0.23,9.82,0,0.16,1 70 | 1.523,13.31,3.58,0.82,71.99,0.12,10.17,0,0.03,1 71 | 1.51574,14.86,3.67,1.74,71.87,0.16,7.36,0,0.12,2 72 | 1.51848,13.64,3.87,1.27,71.96,0.54,8.32,0,0.32,2 73 | 1.51593,13.09,3.59,1.52,73.1,0.67,7.83,0,0,2 74 | 
1.51631,13.34,3.57,1.57,72.87,0.61,7.89,0,0,2 75 | 1.51596,13.02,3.56,1.54,73.11,0.72,7.9,0,0,2 76 | 1.5159,13.02,3.58,1.51,73.12,0.69,7.96,0,0,2 77 | 1.51645,13.44,3.61,1.54,72.39,0.66,8.03,0,0,2 78 | 1.51627,13,3.58,1.54,72.83,0.61,8.04,0,0,2 79 | 1.51613,13.92,3.52,1.25,72.88,0.37,7.94,0,0.14,2 80 | 1.5159,12.82,3.52,1.9,72.86,0.69,7.97,0,0,2 81 | 1.51592,12.86,3.52,2.12,72.66,0.69,7.97,0,0,2 82 | 1.51593,13.25,3.45,1.43,73.17,0.61,7.86,0,0,2 83 | 1.51646,13.41,3.55,1.25,72.81,0.68,8.1,0,0,2 84 | 1.51594,13.09,3.52,1.55,72.87,0.68,8.05,0,0.09,2 85 | 1.51409,14.25,3.09,2.08,72.28,1.1,7.08,0,0,2 86 | 1.51625,13.36,3.58,1.49,72.72,0.45,8.21,0,0,2 87 | 1.51569,13.24,3.49,1.47,73.25,0.38,8.03,0,0,2 88 | 1.51645,13.4,3.49,1.52,72.65,0.67,8.08,0,0.1,2 89 | 1.51618,13.01,3.5,1.48,72.89,0.6,8.12,0,0,2 90 | 1.5164,12.55,3.48,1.87,73.23,0.63,8.08,0,0.09,2 91 | 1.51841,12.93,3.74,1.11,72.28,0.64,8.96,0,0.22,2 92 | 1.51605,12.9,3.44,1.45,73.06,0.44,8.27,0,0,2 93 | 1.51588,13.12,3.41,1.58,73.26,0.07,8.39,0,0.19,2 94 | 1.5159,13.24,3.34,1.47,73.1,0.39,8.22,0,0,2 95 | 1.51629,12.71,3.33,1.49,73.28,0.67,8.24,0,0,2 96 | 1.5186,13.36,3.43,1.43,72.26,0.51,8.6,0,0,2 97 | 1.51841,13.02,3.62,1.06,72.34,0.64,9.13,0,0.15,2 98 | 1.51743,12.2,3.25,1.16,73.55,0.62,8.9,0,0.24,2 99 | 1.51689,12.67,2.88,1.71,73.21,0.73,8.54,0,0,2 100 | 1.51811,12.96,2.96,1.43,72.92,0.6,8.79,0.14,0,2 101 | 1.51655,12.75,2.85,1.44,73.27,0.57,8.79,0.11,0.22,2 102 | 1.5173,12.35,2.72,1.63,72.87,0.7,9.23,0,0,2 103 | 1.5182,12.62,2.76,0.83,73.81,0.35,9.42,0,0.2,2 104 | 1.52725,13.8,3.15,0.66,70.57,0.08,11.64,0,0,2 105 | 1.5241,13.83,2.9,1.17,71.15,0.08,10.79,0,0,2 106 | 1.52475,11.45,0,1.88,72.19,0.81,13.24,0,0.34,2 107 | 1.53125,10.73,0,2.1,69.81,0.58,13.3,3.15,0.28,2 108 | 1.53393,12.3,0,1,70.16,0.12,16.19,0,0.24,2 109 | 1.52222,14.43,0,1,72.67,0.1,11.52,0,0.08,2 110 | 1.51818,13.72,0,0.56,74.45,0,10.99,0,0,2 111 | 1.52664,11.23,0,0.77,73.21,0,14.68,0,0,2 112 | 1.52739,11.02,0,0.75,73.08,0,14.96,0,0,2 113 | 1.52777,12.64,0,0.67,72.02,0.06,14.4,0,0,2 114 | 1.51892,13.46,3.83,1.26,72.55,0.57,8.21,0,0.14,2 115 | 1.51847,13.1,3.97,1.19,72.44,0.6,8.43,0,0,2 116 | 1.51846,13.41,3.89,1.33,72.38,0.51,8.28,0,0,2 117 | 1.51829,13.24,3.9,1.41,72.33,0.55,8.31,0,0.1,2 118 | 1.51708,13.72,3.68,1.81,72.06,0.64,7.88,0,0,2 119 | 1.51673,13.3,3.64,1.53,72.53,0.65,8.03,0,0.29,2 120 | 1.51652,13.56,3.57,1.47,72.45,0.64,7.96,0,0,2 121 | 1.51844,13.25,3.76,1.32,72.4,0.58,8.42,0,0,2 122 | 1.51663,12.93,3.54,1.62,72.96,0.64,8.03,0,0.21,2 123 | 1.51687,13.23,3.54,1.48,72.84,0.56,8.1,0,0,2 124 | 1.51707,13.48,3.48,1.71,72.52,0.62,7.99,0,0,2 125 | 1.52177,13.2,3.68,1.15,72.75,0.54,8.52,0,0,2 126 | 1.51872,12.93,3.66,1.56,72.51,0.58,8.55,0,0.12,2 127 | 1.51667,12.94,3.61,1.26,72.75,0.56,8.6,0,0,2 128 | 1.52081,13.78,2.28,1.43,71.99,0.49,9.85,0,0.17,2 129 | 1.52068,13.55,2.09,1.67,72.18,0.53,9.57,0.27,0.17,2 130 | 1.5202,13.98,1.35,1.63,71.76,0.39,10.56,0,0.18,2 131 | 1.52177,13.75,1.01,1.36,72.19,0.33,11.14,0,0,2 132 | 1.52614,13.7,0,1.36,71.24,0.19,13.44,0,0.1,2 133 | 1.51813,13.43,3.98,1.18,72.49,0.58,8.15,0,0,2 134 | 1.518,13.71,3.93,1.54,71.81,0.54,8.21,0,0.15,2 135 | 1.51811,13.33,3.85,1.25,72.78,0.52,8.12,0,0,2 136 | 1.51789,13.19,3.9,1.3,72.33,0.55,8.44,0,0.28,2 137 | 1.51806,13,3.8,1.08,73.07,0.56,8.38,0,0.12,2 138 | 1.51711,12.89,3.62,1.57,72.96,0.61,8.11,0,0,2 139 | 1.51674,12.79,3.52,1.54,73.36,0.66,7.9,0,0,2 140 | 1.51674,12.87,3.56,1.64,73.14,0.65,7.99,0,0,2 141 | 1.5169,13.33,3.54,1.61,72.54,0.68,8.11,0,0,2 142 | 
1.51851,13.2,3.63,1.07,72.83,0.57,8.41,0.09,0.17,2 143 | 1.51662,12.85,3.51,1.44,73.01,0.68,8.23,0.06,0.25,2 144 | 1.51709,13,3.47,1.79,72.72,0.66,8.18,0,0,2 145 | 1.5166,12.99,3.18,1.23,72.97,0.58,8.81,0,0.24,2 146 | 1.51839,12.85,3.67,1.24,72.57,0.62,8.68,0,0.35,2 147 | 1.51769,13.65,3.66,1.11,72.77,0.11,8.6,0,0,3 148 | 1.5161,13.33,3.53,1.34,72.67,0.56,8.33,0,0,3 149 | 1.5167,13.24,3.57,1.38,72.7,0.56,8.44,0,0.1,3 150 | 1.51643,12.16,3.52,1.35,72.89,0.57,8.53,0,0,3 151 | 1.51665,13.14,3.45,1.76,72.48,0.6,8.38,0,0.17,3 152 | 1.52127,14.32,3.9,0.83,71.5,0,9.49,0,0,3 153 | 1.51779,13.64,3.65,0.65,73,0.06,8.93,0,0,3 154 | 1.5161,13.42,3.4,1.22,72.69,0.59,8.32,0,0,3 155 | 1.51694,12.86,3.58,1.31,72.61,0.61,8.79,0,0,3 156 | 1.51646,13.04,3.4,1.26,73.01,0.52,8.58,0,0,3 157 | 1.51655,13.41,3.39,1.28,72.64,0.52,8.65,0,0,3 158 | 1.52121,14.03,3.76,0.58,71.79,0.11,9.65,0,0,3 159 | 1.51776,13.53,3.41,1.52,72.04,0.58,8.79,0,0,3 160 | 1.51796,13.5,3.36,1.63,71.94,0.57,8.81,0,0.09,3 161 | 1.51832,13.33,3.34,1.54,72.14,0.56,8.99,0,0,3 162 | 1.51934,13.64,3.54,0.75,72.65,0.16,8.89,0.15,0.24,3 163 | 1.52211,14.19,3.78,0.91,71.36,0.23,9.14,0,0.37,3 164 | 1.51514,14.01,2.68,3.5,69.89,1.68,5.87,2.2,0,4 165 | 1.51915,12.73,1.85,1.86,72.69,0.6,10.09,0,0,4 166 | 1.52171,11.56,1.88,1.56,72.86,0.47,11.41,0,0,4 167 | 1.52151,11.03,1.71,1.56,73.44,0.58,11.62,0,0,4 168 | 1.51969,12.64,0,1.65,73.75,0.38,11.53,0,0,4 169 | 1.51666,12.86,0,1.83,73.88,0.97,10.17,0,0,4 170 | 1.51994,13.27,0,1.76,73.03,0.47,11.32,0,0,4 171 | 1.52369,13.44,0,1.58,72.22,0.32,12.24,0,0,4 172 | 1.51316,13.02,0,3.04,70.48,6.21,6.96,0,0,4 173 | 1.51321,13,0,3.02,70.7,6.21,6.93,0,0,4 174 | 1.52043,13.38,0,1.4,72.25,0.33,12.5,0,0,4 175 | 1.52058,12.85,1.61,2.17,72.18,0.76,9.7,0.24,0.51,4 176 | 1.52119,12.97,0.33,1.51,73.39,0.13,11.27,0,0.28,4 177 | 1.51905,14,2.39,1.56,72.37,0,9.57,0,0,5 178 | 1.51937,13.79,2.41,1.19,72.76,0,9.77,0,0,5 179 | 1.51829,14.46,2.24,1.62,72.38,0,9.26,0,0,5 180 | 1.51852,14.09,2.19,1.66,72.67,0,9.32,0,0,5 181 | 1.51299,14.4,1.74,1.54,74.55,0,7.59,0,0,5 182 | 1.51888,14.99,0.78,1.74,72.5,0,9.95,0,0,5 183 | 1.51916,14.15,0,2.09,72.74,0,10.88,0,0,5 184 | 1.51969,14.56,0,0.56,73.48,0,11.22,0,0,5 185 | 1.51115,17.38,0,0.34,75.41,0,6.65,0,0,5 186 | 1.51131,13.69,3.2,1.81,72.81,1.76,5.43,1.19,0,6 187 | 1.51838,14.32,3.26,2.22,71.25,1.46,5.79,1.63,0,6 188 | 1.52315,13.44,3.34,1.23,72.38,0.6,8.83,0,0,6 189 | 1.52247,14.86,2.2,2.06,70.26,0.76,9.76,0,0,6 190 | 1.52365,15.79,1.83,1.31,70.43,0.31,8.61,1.68,0,6 191 | 1.51613,13.88,1.78,1.79,73.1,0,8.67,0.76,0,6 192 | 1.51602,14.85,0,2.38,73.28,0,8.76,0.64,0.09,6 193 | 1.51623,14.2,0,2.79,73.46,0.04,9.04,0.4,0.09,6 194 | 1.51719,14.75,0,2,73.02,0,8.53,1.59,0.08,6 195 | 1.51683,14.56,0,1.98,73.29,0,8.52,1.57,0.07,6 196 | 1.51545,14.14,0,2.68,73.39,0.08,9.07,0.61,0.05,6 197 | 1.51556,13.87,0,2.54,73.23,0.14,9.41,0.81,0.01,6 198 | 1.51727,14.7,0,2.34,73.28,0,8.95,0.66,0,6 199 | 1.51531,14.38,0,2.66,73.1,0.04,9.08,0.64,0,6 200 | 1.51609,15.01,0,2.51,73.05,0.05,8.83,0.53,0,6 201 | 1.51508,15.15,0,2.25,73.5,0,8.34,0.63,0,6 202 | 1.51653,11.95,0,1.19,75.18,2.7,8.93,0,0,6 203 | 1.51514,14.85,0,2.42,73.72,0,8.39,0.56,0,6 204 | 1.51658,14.8,0,1.99,73.11,0,8.28,1.71,0,6 205 | 1.51617,14.95,0,2.27,73.3,0,8.71,0.67,0,6 206 | 1.51732,14.95,0,1.8,72.99,0,8.61,1.55,0,6 207 | 1.51645,14.94,0,1.87,73.11,0,8.67,1.38,0,6 208 | 1.51831,14.39,0,1.82,72.86,1.41,6.47,2.88,0,6 209 | 1.5164,14.37,0,2.74,72.85,0,9.45,0.54,0,6 210 | 1.51623,14.14,0,2.88,72.61,0.08,9.18,1.06,0,6 211 | 
1.51685,14.92,0,1.99,73.06,0,8.4,1.59,0,6
212 | 1.52065,14.36,0,2.02,73.42,0,8.44,1.64,0,6
213 | 1.51651,14.38,0,1.94,73.61,0,8.48,1.57,0,6
214 | 1.51711,14.23,0,2.08,73.36,0,8.62,1.67,0,6
215 |
--------------------------------------------------------------------------------
/adv_cg.jl:
--------------------------------------------------------------------------------
1 | using Gurobi
2 | using Mosek
3 |
4 | # train adversarial method using constraint generation
5 | function train_adv_cg(X::Matrix, y::Vector, C::Real=1.0;
6 |     perturb::Real=0.0, tol::Real=1e-6, psdtol::Real=1e-6, obj_reltol::Real=0.0,
7 |     log::Real=0, n_thread::Int=0, solver::Symbol=:gurobi, verbose::Bool=true)
8 |
9 |     n = length(y)
10 |     # prepend a bias feature of ones
11 |     X1 = [ones(n) X]' # transpose
12 |     m = size(X1, 1)
13 |
14 |     # number of classes
15 |     n_c = maximum(y)
16 |     n_f = n_c * m # number of features
17 |
18 |     # parameters, initialized to zero
19 |     w = zeros(n_f)
20 |
21 |     alpha = zeros(n)
22 |
23 |     # constraint set: tuples of (sample index, sorted label subset)
24 |     constraints = Tuple{Integer, Vector}[]
25 |
26 |     # prepare saved vars
27 |     idmi = map(i -> idi(m, i), collect(1:n_c))
28 |
29 |     # precompute the Gram matrix x_i . x_j (lower triangle, then mirror)
30 |     K = [( i >= j ? dot(view(X1,:,i), view(X1,:,j)) : 0.0 )::Float64 for i=1:n, j=1:n]
31 |     K = [( i >= j ? K[i,j] : K[j,i] )::Float64 for i=1:n, j=1:n]
32 |
33 |     if solver == :gurobi
34 |         # gurobi solver
35 |         # gurobi environment
36 |         env = Gurobi.Env()
37 |         # Method : 0=primal simplex, 1=dual simplex, 2=barrier ; default for QP: barrier
38 |         # Threads : default = 0 (use all threads)
39 |         setparams!(env, PSDTol=psdtol, LogToConsole=log, Method=2, Threads=n_thread)
40 |     elseif solver == :mosek
41 |         # mosek environment
42 |         env = makeenv()
43 |     end
44 |
45 |     # params for Gurobi
46 |     Q = zeros(0,0)
47 |     nu = zeros(0)
48 |     A = spzeros(n,0) # sparse matrix
49 |     b = ones(n) * C
50 |
51 |     Q_prev = zeros(0,0)
52 |     nu_prev = zeros(0)
53 |     A_prev = spzeros(n,0) # sparse matrix
54 |
55 |     # additional for w
56 |     L = zeros(n_f,0)
57 |     L_prev = zeros(n_f,0)
58 |
59 |     iter = 0
60 |     dual_obj = 0.0
61 |     dual_obj_prev = -Inf
62 |
63 |     while true
64 |         iter += 1
65 |
66 |         if verbose
67 |             println("Iteration : ", iter)
68 |             tic();
69 |         end
70 |
71 |         # previous constraints
72 |         const_prev = copy(constraints)
73 |
74 |         ## add to constraints
75 |         const_added = Tuple{Integer, Vector}[]
76 |
77 |         # find the most violated constraint for each sample
78 |         for i=1:n
79 |             psis = psi_list(w, X1, y, i, n_c, idmi)
80 |             psis_id, val = best_psis(psis) # most violated constraints
81 |
82 |             # current xi_i
83 |             id_i = find(x -> x[1] == i, constraints)
84 |             xi_i_list = map(x -> x[2], constraints[id_i])
85 |             max_xi_i = 0.0
86 |             for j = 1:length(xi_i_list)
87 |                 a = calc_const(psis::Vector, xi_i_list[j])
88 |                 if a > max_xi_i
89 |                     max_xi_i = a
90 |                 end
91 |             end
92 |
93 |             if val > max_xi_i
94 |                 cs = (i, sort!(psis_id))
95 |                 if findfirst(constraints .== cs) == 0
96 |                     push!(constraints, cs)
97 |                     push!(const_added, cs)
98 |                 end
99 |             end
100 |         end
101 |
102 |         # stop when no new constraint was added
103 |         if length(const_added) == 0
104 |             break
105 |         end
106 |
107 |         n_const = length(constraints)
108 |
109 |         ### Start QP ###
110 |
111 |         if verbose
112 |             println(">> Start QP")
113 |             toc();
114 |             tic();
115 |         end
116 |
117 |         n_prev = length(const_prev)
118 |         n_added = length(const_added)
119 |
120 |         Q_aug = [ ( calc_dot((const_prev[i][1], const_prev[i][2], const_added[j][1], const_added[j][2]), K, y) )::Float64
121 |                   for i=1:n_prev, j=1:n_added]
122 |
123 |         Q_aug_diag = [
124 |             ( i >= j ? 
calc_dot((const_added[i][1], const_added[i][2], const_added[j][1], const_added[j][2]), K, y) : 0.0 )::Float64 125 | for i=1:n_added, j=1:n_added] 126 | Q_aug_diag = [ (i >= j ? Q_aug_diag[i,j] : Q_aug_diag[j,i] )::Float64 for i=1:n_added, j=1:n_added ] 127 | 128 | Q = [ 129 | ( (i <= n_prev && j <= n_prev) ? Q_prev[i, j] : ( i <= n_prev ? Q_aug[i, j-n_prev] : 130 | ( j <= n_prev ? Q_aug[j, i-n_prev] : Q_aug_diag[i-n_prev, j-n_prev] ) ) )::Float64 131 | for i=1:n_const, j=1:n_const] 132 | 133 | nu_aug = [ ( calc_cconst(const_added[i][2]) )::Float64 for i=1:n_added ] 134 | nu = [ ( i <= n_prev ? nu_prev[i] : nu_aug[i-n_prev] )::Float64 for i=1:n_const] 135 | 136 | A_aug = spzeros(n, n_added) 137 | for j = 1:n_added 138 | A_aug[const_added[j][1], j] = 1.0 139 | end 140 | A = [A_prev A_aug] 141 | 142 | ## add perturbation 143 | for i=1:n_const 144 | Q[i,i] = Q[i,i] + perturb 145 | end 146 | 147 | if verbose 148 | toc(); 149 | tic(); 150 | end 151 | 152 | if solver == :gurobi 153 | 154 | if verbose println(">> Optim :: Gurobi") end 155 | 156 | ## init model 157 | model = gurobi_model(env, 158 | sense = :minimize, 159 | H = Q, 160 | f = -nu, 161 | A = A, 162 | b = b, 163 | lb = zeros(n_const) 164 | ) 165 | # Print the model to check correctness 166 | # print(model) 167 | 168 | # Solve with Gurobi 169 | Gurobi.optimize(model) 170 | 171 | 172 | if verbose 173 | toc(); 174 | println("<< End QP") 175 | end 176 | 177 | dual_obj = -get_objval(model) 178 | # Solution 179 | if verbose println("Objective value: ", dual_obj) end 180 | 181 | # get alpha 182 | alpha = get_solution(model) 183 | 184 | if verbose println("n constraints = ", length(constraints)) end 185 | 186 | ### end QP ### 187 | 188 | elseif solver == :mosek 189 | 190 | if verbose println(">> Optim :: Mosek") end 191 | 192 | task = maketask(env) 193 | 194 | # set params 195 | putintparam(task, Mosek.MSK_IPAR_LOG, 1) 196 | putintparam(task, Mosek.MSK_IPAR_LOG_CHECK_CONVEXITY, 1) 197 | putdouparam(task, Mosek.MSK_DPAR_CHECK_CONVEXITY_REL_TOL, psdtol) 198 | 199 | # variables 200 | appendvars(task, n_const) 201 | # bound on var 202 | for i::Int32 = 1:n_const 203 | putbound(task, Mosek.MSK_ACC_VAR, i, Mosek.MSK_BK_RA, 0.0, C) 204 | end 205 | 206 | # objective 207 | putobjsense(task, Mosek.MSK_OBJECTIVE_SENSE_MINIMIZE) 208 | qi = zeros(Int32, (n_const * (n_const+1)) ÷ 2 ) 209 | qj = zeros(Int32, (n_const * (n_const+1)) ÷ 2 ) 210 | qv = zeros(Float64, (n_const * (n_const+1)) ÷ 2 ) 211 | ix = 1 212 | for j::Int32 = 1:n_const 213 | for i::Int32 = j:n_const 214 | qi[ix] = i 215 | qj[ix] = j 216 | qv[ix] = Q[i,j] 217 | ix += 1 218 | end 219 | end 220 | putqobj(task, qi, qj, qv) 221 | 222 | putclist(task, collect(1:n_const), -nu) 223 | 224 | # constraints 225 | ## sparse array 226 | appendcons(task, n) 227 | for i::Int32 = 1:n_const 228 | id_nz = A[:,i].nzind 229 | putacol(task, i, id_nz, ones(length(id_nz))) 230 | end 231 | 232 | for i::Int32 = 1:n 233 | putbound(task, Mosek.MSK_ACC_CON, i, Mosek.MSK_BK_RA, 0.0, C) 234 | end 235 | 236 | if verbose 237 | toc(); tic(); 238 | end 239 | 240 | Mosek.optimize(task) 241 | 242 | if verbose 243 | toc(); 244 | println("<< End QP") 245 | end 246 | 247 | # Solution 248 | dual_obj, _ = getsolutioninfo(task, Mosek.MSK_SOL_ITR) 249 | if verbose println("Objective value: ", -dual_obj) end 250 | 251 | # get alpha 252 | alpha = getxx(task, Mosek.MSK_SOL_ITR) 253 | 254 | if verbose println("n constraints = ", length(constraints)) end 255 | end 256 | 257 | L_aug = zeros(n_f, n_added) 258 | for i=1:n_added 259 | L_aug[:, 
i] = calc_dconst(const_added[i], X1, y, n_c, idmi)
260 |         end
261 |         L = [L_prev L_aug]
262 |
263 |         # recover w from the dual solution
264 |         w = zeros(n_f)
265 |         for i = 1:n_const
266 |             w -= alpha[i] * L[:, i]
267 |         end
268 |
269 |         if obj_reltol > 0.0
270 |             if (dual_obj - dual_obj_prev) / dual_obj_prev < obj_reltol && iter > 1
271 |                 if verbose println(">> Iteration STOPPED | Objective relative tolerance : ", obj_reltol) end
272 |                 break
273 |             end
274 |         end
275 |
276 |         if verbose println() end
277 |
278 |         Q_prev = Q
279 |         nu_prev = nu
280 |         A_prev = A
281 |         L_prev = L
282 |         dual_obj_prev = dual_obj
283 |     end
284 |
285 |     # finalize training losses and game values
286 |     gv_aug = zeros(n)
287 |     gv_01 = zeros(n)
288 |     l_adv = zeros(n)
289 |     l_01 = zeros(n)
290 |     for i=1:n
291 |         psis = psi_list(w, X1, y, i, n_c, idmi)
292 |         psis_id, val = best_psis(psis) # most violated constraints
293 |         n_ps = length(psis_id)
294 |
295 |         gv_aug[i] = val
296 |
297 |         # compute the equilibrium distributions over the active label set
298 |         p_hat = zeros(n_c)
299 |         p_check = zeros(n_c)
300 |         for j=1:n_c
301 |             if j in psis_id
302 |                 p_hat[j] = ( (n_ps-1.0)*psis[j] - sum(psis[psis_id[psis_id .!= j]]) + 1.0 ) / n_ps
303 |                 p_check[j] = 1.0 / n_ps
304 |             else
305 |                 p_hat[j] = 0.0
306 |                 p_check[j] = 0.0
307 |             end
308 |         end
309 |
310 |         C01 = 1 - eye(n_c) # 0-1 loss matrix
311 |         v = p_hat' * C01 * p_check # a 1-element vector, not a scalar
312 |         gv_01[i] = v[1]
313 |
314 |         ## training loss
315 |         l_adv[i] = 1.0 - p_hat[y[i]]
316 |         l_01[i] = 1.0 - round(Int, indmax(p_hat) == y[i])
317 |     end
318 |
319 |     game_value_01 = mean(gv_01)
320 |     game_value_augmented = mean(gv_aug)
321 |
322 |     # create model
323 |     adv_model = MultiAdversarialModel(w, alpha, constraints, n_c, game_value_01, game_value_augmented, mean(l_adv), mean(l_01))
324 |
325 |     return adv_model::MultiAdversarialModel
326 | end
327 |
328 |
329 | function predict_adv(model::MultiAdversarialModel, X_test::Matrix)
330 |
331 |     w = model.w
332 |     n_c = model.n_class
333 |     n = size(X_test, 1)
334 |
335 |     X1 = [ones(n) X_test]' # transpose
336 |     m = size(X1, 1)
337 |
338 |     # prepare saved vars
339 |     idmi = map(i -> idi(m, i), collect(1:n_c))
340 |
341 |     prob = zeros(n, n_c)
342 |     pred = zeros(n)
343 |     for i=1:n
344 |         psis = psi_list(w, X1, i, n_c, idmi) # psi scores only; no label argument at test time
345 |         psis_id, val = best_psis(psis) # most violated constraints
346 |         n_ps = length(psis_id)
347 |
348 |         for j=1:n_c
349 |             if j in psis_id
350 |                 prob[i,j] = ( (n_ps-1.0)*psis[j] - sum(psis[psis_id[psis_id .!= j]]) + 1.0 ) / n_ps
351 |             else
352 |                 prob[i,j] = 0.0
353 |             end
354 |         end
355 |
356 |         pred[i] = indmax(psis)
357 |
358 |     end
359 |
360 |     return prob::Matrix{Float64}, pred::Vector{Float64}
361 | end
362 |
363 | function test_adv(model::MultiAdversarialModel, X_test::Matrix, y_test::Vector)
364 |     n = size(X_test, 1)
365 |
366 |     y_prob, y_pred = predict_adv(model, X_test)
367 |
368 |     # calculate testing loss
369 |     losses = zeros(n)
370 |     losses01 = zeros(n)
371 |     for i=1:n
372 |         losses[i] = 1.0 - y_prob[i, y_test[i]]
373 |         losses01[i] = 1.0 - round(Int, y_pred[i] == y_test[i])
374 |     end
375 |
376 |     loss = sum(losses) / n
377 |     loss01 = sum(losses01) / n
378 |
379 |     return loss::Float64, losses::Vector{Float64}, loss01::Float64, losses01::Vector{Float64},
380 |            y_prob::Matrix{Float64}, y_pred::Vector{Float64}
381 |
382 | end
383 |
--------------------------------------------------------------------------------
/adv_kernel_cg.jl:
--------------------------------------------------------------------------------
1 | using Gurobi
2 | using Mosek
3 |
4 | # train adversarial method with kernel using constraint generation
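# Arguments: X is the n-by-d data matrix and y holds integer class labels in
# 1..n_c; C is the slack penalty. kernel is one of :linear, :gaussian, or
# :polynomial, and kernel_params is splatted into the matching function from
# kernel.jl (e.g. [gamma] for :gaussian, [d] for :polynomial).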
5 | function train_adv_kernel_cg(X::Matrix, y::Vector, C::Real=1.0,
6 |     kernel::Symbol=:linear, kernel_params::Vector=[];
7 |     perturb::Real=0.0, tol::Real=1e-6, psdtol::Real=1e-6, obj_reltol::Real=0.0,
8 |     log::Real=0, n_thread::Int=0, solver::Symbol=:gurobi, verbose::Bool=true)
9 |
10 |     n = length(y)
11 |     # prepend a bias feature of ones
12 |     X1 = [ones(n) X]' # transpose
13 |     m = size(X1, 1)
14 |
15 |     # number of classes
16 |     n_c = maximum(y)
17 |     n_f = n_c * m # number of features
18 |
19 |     alpha = zeros(0)
20 |
21 |     # constraint set: tuples of (sample index, sorted label subset)
22 |     constraints = Tuple{Integer, Vector}[]
23 |
24 |     # prepare saved vars
25 |     idmi = map(i -> idi(m, i), collect(1:n_c))
26 |
27 |     # select the kernel function
28 |     kernel_func = linear_kernel
29 |     if kernel == :gaussian
30 |         kernel_func = gaussian_kernel
31 |     elseif kernel == :polynomial
32 |         kernel_func = polynomial_kernel
33 |     end
34 |
35 |     # precompute the kernel matrix (lower triangle, then mirror)
36 |     K = [( i >= j ? kernel_func(X1[:,i], X1[:,j], kernel_params...) : 0.0 )::Float64 for i=1:n, j=1:n]
37 |     K = [( i >= j ? K[i,j] : K[j,i] )::Float64 for i=1:n, j=1:n]
38 |
39 |     if solver == :gurobi
40 |         # gurobi solver
41 |         # gurobi environment
42 |         env = Gurobi.Env()
43 |         # Method : 0=primal simplex, 1=dual simplex, 2=barrier ; default for QP: barrier
44 |         # Threads : default = 0 (use all threads)
45 |         setparams!(env, PSDTol=psdtol, LogToConsole=log, Method=2, Threads=n_thread)
46 |     elseif solver == :mosek
47 |         # mosek environment
48 |         env = makeenv()
49 |     end
50 |
51 |     # params for Gurobi
52 |     Q = zeros(0,0)
53 |     nu = zeros(0)
54 |     A = spzeros(n, 0) # sparse matrix
55 |     b = ones(n) * C
56 |
57 |     Q_prev = zeros(0,0)
58 |     nu_prev = zeros(0)
59 |     A_prev = spzeros(n, 0) # sparse matrix
60 |
61 |     # save lambda dot psi_i
62 |     n_lps = n * n_c
63 |     LPsi = zeros(n_lps, 0)
64 |     LPsi_prev = zeros(n_lps, 0)
65 |
66 |     iter = 0
67 |     dual_obj = 0.0
68 |     dual_obj_prev = -Inf # only read after the first iteration; matches adv_cg.jl
69 |
70 |     while true
71 |         iter += 1
72 |
73 |         if verbose
74 |             println("Iteration : ", iter)
75 |             tic();
76 |         end
77 |
78 |         # previous constraints
79 |         const_prev = copy(constraints)
80 |         # previous alpha
81 |         alpha_prev = copy(alpha)
82 |
83 |         ## add to constraints
84 |         const_added = Tuple{Integer, Vector}[]
85 |
86 |         # find the most violated constraint for each sample
87 |         for i=1:n
88 |             psis = psi_list_dual(alpha, LPsi, i, n_c)
89 |             psis_id, val = best_psis(psis) # most violated constraints
90 |
91 |             # current xi_i
92 |             id_i = find(x -> x[1] == i, constraints)
93 |             xi_i_list = map(x -> x[2], constraints[id_i])
94 |             max_xi_i = 0.0
95 |             for j = 1:length(xi_i_list)
96 |                 a = calc_const(psis::Vector, xi_i_list[j])
97 |                 if a > max_xi_i
98 |                     max_xi_i = a
99 |                 end
100 |             end
101 |
102 |             if val > max_xi_i
103 |                 cs = (i, sort!(psis_id))
104 |                 if findfirst(constraints .== cs) == 0
105 |                     push!(constraints, cs)
106 |                     push!(const_added, cs)
107 |                 end
108 |             end
109 |         end
110 |
111 |         # stop when no new constraint was added
112 |         if length(const_added) == 0
113 |             break
114 |         end
115 |
116 |         n_const = length(constraints)
117 |
118 |         ### Start QP ###
119 |
120 |         if verbose
121 |             println(">> Start QP")
122 |             toc();
123 |             tic();
124 |         end
125 |
126 |         n_prev = length(const_prev)
127 |         n_added = length(const_added)
128 |
129 |         # init alpha with previous iteration
130 |         alpha = zeros(n_const)
131 |         alpha[1:n_prev] = alpha_prev
132 |
133 |         Q_aug = [ ( calc_dot((const_prev[i][1], const_prev[i][2], const_added[j][1], const_added[j][2]), K, y) )::Float64
134 |                   for i=1:n_prev, j=1:n_added]
135 |
136 |         Q_aug_diag = [
137 |             ( i >= j ? 
calc_dot((const_added[i][1], const_added[i][2], const_added[j][1], const_added[j][2]), K, y) : 0.0 )::Float64 138 | for i=1:n_added, j=1:n_added] 139 | Q_aug_diag = [ (i >= j ? Q_aug_diag[i,j] : Q_aug_diag[j,i] )::Float64 for i=1:n_added, j=1:n_added ] 140 | 141 | Q = [ 142 | ( (i <= n_prev && j <= n_prev) ? Q_prev[i, j] : ( i <= n_prev ? Q_aug[i, j-n_prev] : 143 | ( j <= n_prev ? Q_aug[j, i-n_prev] : Q_aug_diag[i-n_prev, j-n_prev] ) ) )::Float64 144 | for i=1:n_const, j=1:n_const] 145 | 146 | nu_aug = [ ( calc_cconst(const_added[i][2]) )::Float64 for i=1:n_added ] 147 | nu = [ ( i <= n_prev ? nu_prev[i] : nu_aug[i-n_prev] )::Float64 for i=1:n_const] 148 | 149 | A_aug = spzeros(n, n_added) 150 | for j = 1:n_added 151 | A_aug[const_added[j][1], j] = 1.0 152 | end 153 | A = [A_prev A_aug] 154 | 155 | # update LPsi 156 | LPsi_aug = [ ( calc_dot((ceil(Int64, i/n_c), [(i%n_c==0)? n_c:i%n_c], const_added[j][1], const_added[j][2]), K, y) )::Float64 157 | for i=1:n_lps, j=1:n_added] 158 | LPsi = [LPsi_prev LPsi_aug] 159 | 160 | ## add perturbation 161 | for i=1:n_const 162 | Q[i,i] = Q[i,i] + perturb 163 | end 164 | 165 | if verbose 166 | toc(); 167 | tic(); 168 | end 169 | 170 | if solver == :gurobi 171 | 172 | if verbose println(">> Optim :: Gurobi") end 173 | 174 | ## init model 175 | model = gurobi_model(env, 176 | sense = :minimize, 177 | H = Q, 178 | f = -nu, 179 | A = A, 180 | b = b, 181 | lb = zeros(n_const) 182 | ) 183 | # Print the model to check correctness 184 | # print(model) 185 | 186 | # Solve with Gurobi 187 | Gurobi.optimize(model) 188 | 189 | if verbose 190 | toc(); 191 | println("<< End QP") 192 | end 193 | 194 | dual_obj = -get_objval(model) 195 | # Solution 196 | if verbose println("Objective value: ", dual_obj) end 197 | 198 | # get alpha 199 | alpha = get_solution(model) 200 | # println("alpha = ", alpha) 201 | # 202 | # println("constraints = ", constraints) 203 | 204 | if verbose println("n constraints = ", length(constraints)) end 205 | 206 | ### end QP ### 207 | 208 | elseif solver == :mosek 209 | 210 | if verbose println(">> Optim :: Mosek") end 211 | 212 | task = maketask(env) 213 | 214 | # set params 215 | putintparam(task, Mosek.MSK_IPAR_LOG, 1) 216 | putintparam(task, Mosek.MSK_IPAR_LOG_CHECK_CONVEXITY, 1) 217 | putdouparam(task, Mosek.MSK_DPAR_CHECK_CONVEXITY_REL_TOL, psdtol) 218 | 219 | # variables 220 | appendvars(task, n_const) 221 | # bound on var 222 | for i::Int32 = 1:n_const 223 | putbound(task, Mosek.MSK_ACC_VAR, i, Mosek.MSK_BK_RA, 0.0, C) 224 | end 225 | 226 | # objective 227 | putobjsense(task, Mosek.MSK_OBJECTIVE_SENSE_MINIMIZE) 228 | qi = zeros(Int32, (n_const * (n_const+1)) ÷ 2 ) 229 | qj = zeros(Int32, (n_const * (n_const+1)) ÷ 2 ) 230 | qv = zeros(Float64, (n_const * (n_const+1)) ÷ 2 ) 231 | ix = 1 232 | for j::Int32 = 1:n_const 233 | for i::Int32 = j:n_const 234 | qi[ix] = i 235 | qj[ix] = j 236 | qv[ix] = Q[i,j] 237 | ix += 1 238 | end 239 | end 240 | putqobj(task, qi, qj, qv) 241 | 242 | putclist(task, collect(1:n_const), -nu) 243 | 244 | # constraints 245 | ## sparse array 246 | appendcons(task, n) 247 | for i::Int32 = 1:n_const 248 | id_nz = A[:,i].nzind 249 | putacol(task, i, id_nz, ones(length(id_nz))) 250 | end 251 | 252 | for i::Int32 = 1:n 253 | putbound(task, Mosek.MSK_ACC_CON, i, Mosek.MSK_BK_RA, 0.0, C) 254 | end 255 | 256 | if verbose 257 | toc(); tic(); 258 | end 259 | 260 | Mosek.optimize(task) 261 | 262 | if verbose 263 | toc(); 264 | println("<< End QP") 265 | end 266 | 267 | # Solution 268 | dual_obj, _ = getsolutioninfo(task, 
Mosek.MSK_SOL_ITR)
269 |             if verbose println("Objective value: ", -dual_obj) end
270 |
271 |             # get alpha
272 |             alpha = getxx(task, Mosek.MSK_SOL_ITR)
273 |
274 |             if verbose println("n constraints = ", length(constraints)) end
275 |         end
276 |
277 |         if obj_reltol > 0.0
278 |             if (dual_obj - dual_obj_prev) / dual_obj_prev < obj_reltol && iter > 1
279 |                 if verbose
280 |                     println((dual_obj - dual_obj_prev) / dual_obj_prev)
281 |                     println(">> Iteration STOPPED | Objective relative tolerance : ", obj_reltol)
282 |                 end
283 |                 break
284 |             end
285 |         end
286 |
287 |         if verbose println() end
288 |
289 |         Q_prev = Q
290 |         nu_prev = nu
291 |         A_prev = A
292 |         LPsi_prev = LPsi
293 |         dual_obj_prev = dual_obj
294 |     end
295 |
296 |     # finalize training losses and game values
297 |     gv_aug = zeros(n)
298 |     gv_01 = zeros(n)
299 |     l_adv = zeros(n)
300 |     l_01 = zeros(n)
301 |     for i=1:n
302 |         psis = psi_list_dual(alpha, LPsi, i, n_c)
303 |         psis_id, val = best_psis(psis) # most violated constraints
304 |         n_ps = length(psis_id)
305 |
306 |         gv_aug[i] = val
307 |
308 |         # compute the equilibrium distributions over the active label set
309 |         p_hat = zeros(n_c)
310 |         p_check = zeros(n_c)
311 |         for j=1:n_c
312 |             if j in psis_id
313 |                 p_hat[j] = ( (n_ps-1.0)*psis[j] - sum(psis[psis_id[psis_id .!= j]]) + 1.0 ) / n_ps
314 |                 p_check[j] = 1.0 / n_ps
315 |             else
316 |                 p_hat[j] = 0.0
317 |                 p_check[j] = 0.0
318 |             end
319 |         end
320 |
321 |         C01 = 1 - eye(n_c) # 0-1 loss matrix
322 |         v = p_hat' * C01 * p_check # a 1-element vector, not a scalar
323 |         gv_01[i] = v[1]
324 |
325 |         ## training loss
326 |         l_adv[i] = 1.0 - p_hat[y[i]]
327 |         l_01[i] = 1.0 - round(Int, indmax(p_hat) == y[i])
328 |     end
329 |
330 |     game_value_01 = mean(gv_01)
331 |     game_value_augmented = mean(gv_aug)
332 |
333 |     # create model
334 |     adv_model = KernelMultiAdversarialModel(kernel, kernel_params, alpha, constraints, n_c, game_value_01, game_value_augmented, mean(l_adv), mean(l_01))
335 |
336 |     return adv_model::KernelMultiAdversarialModel
337 | end
338 |
339 | function predict_adv_kernel(model::KernelMultiAdversarialModel, X_test::Matrix, X_train::Matrix, y_train::Vector)
340 |
341 |     alpha = model.alpha
342 |     n_c = model.n_class
343 |     n = size(X_test, 1)
344 |
345 |     constraints = model.constraints
346 |     n_const = length(alpha)
347 |
348 |     X1 = [ones(n) X_test]' # transpose
349 |     m = size(X1, 1)
350 |
351 |     # training data
352 |     n_tr = size(X_train, 1)
353 |     X1_tr = [ones(n_tr) X_train]' # transpose
354 |
355 |     # kernel
356 |     kernel = model.kernel
357 |     kernel_params = model.kernel_params
358 |     # kernel function
359 |     kernel_func = linear_kernel
360 |     if kernel == :gaussian
361 |         kernel_func = gaussian_kernel
362 |     elseif kernel == :polynomial
363 |         kernel_func = polynomial_kernel
364 |     end
365 |
366 |     # compute the train-by-test kernel matrix
367 |     K = [ kernel_func(X1_tr[:,i], X1[:,j], kernel_params...)::Float64 for i=1:n_tr, j=1:n]
368 |
369 |     prob = zeros(n, n_c)
370 |     pred = zeros(n)
371 |     for i=1:n
372 |
373 |         psis = zeros(n_c)
374 |         for j = 1:n_const
375 |             psis -= alpha[j] * calc_dotlphi( (constraints[j][1], constraints[j][2], i), K, y_train, n_c)
376 |         end
377 |
378 |         psis_id, val = best_psis(psis) # most violated constraints
379 |         n_ps = length(psis_id)
380 |
381 |         for j=1:n_c
382 |             if j in psis_id
383 |                 prob[i,j] = ( (n_ps-1.0)*psis[j] - sum(psis[psis_id[psis_id .!= j]]) + 1.0 ) / n_ps
384 |             else
385 |                 prob[i,j] = 0.0
386 |             end
387 |         end
388 |
389 |         pred[i] = indmax(psis)
390 |
391 |     end
392 |
393 |     return prob::Matrix{Float64}, pred::Vector{Float64}
394 | end
395 |
396 | function test_adv_kernel(model::KernelMultiAdversarialModel, X_test::Matrix, 
y_test::Vector, X_train::Matrix, y_train::Vector) 397 | n = size(X_test, 1) 398 | 399 | y_prob, y_pred = predict_adv_kernel(model, X_test, X_train, y_train) 400 | 401 | # calculate testing loss 402 | losses = zeros(n) 403 | losses01 = zeros(n) 404 | for i=1:n 405 | losses[i] = 1.0 - y_prob[i, y_test[i]] 406 | losses01[i] = 1.0 - round(Int, y_pred[i] == y_test[i]) 407 | end 408 | 409 | loss = sum(losses) / n 410 | loss01 = sum(losses01) / n 411 | 412 | return loss::Float64, losses::Vector{Float64}, loss01::Float64, losses01::Vector{Float64}, 413 | y_prob::Matrix{Float64}, y_pred::Vector{Float64} 414 | 415 | end 416 | --------------------------------------------------------------------------------
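How these pieces fit together is demonstrated by the repo's example*.jl scripts, which are not reproduced above. The following is a minimal sketch rather than the repo's own example.jl: it assumes shared.jl and util.jl define the helpers referenced in the files above (psi_list, best_psis, idi, and the calc_* functions), and that glass.csv stores the nine feature columns first with the class label (an integer in 1..6) in the last column, as the rows above indicate.

include("types.jl")
include("kernel.jl")
include("util.jl")
include("shared.jl")
include("adv_cg.jl")

# load the glass data; readdlm yields a Float64 matrix
data = readdlm("data-example/glass.csv", ',')
X = data[:, 1:end-1]
y = round(Int, vec(data[:, end]))

# random train/test split, arbitrary sizes for the sketch; if the shuffle
# drops a class from the training part, n_c = maximum(y) inside train_adv_cg
# would be too small, so reshuffle in that case
srand(1)
perm = randperm(length(y))
tr, te = perm[1:150], perm[151:end]

# primal constraint-generation training, then held-out evaluation
model = train_adv_cg(X[tr, :], y[tr], 1.0; solver = :gurobi, verbose = false)
loss, _, loss01, _, _, _ = test_adv(model, X[te, :], y[te])
println("adversarial loss = ", loss, " | 0-1 loss = ", loss01)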
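The kernel path works the same way, under the same assumptions; the Gaussian bandwidth [0.1] below is an arbitrary illustration, not a tuned value. Note that prediction re-evaluates the kernel against the training set, so the training data must be passed along with the model.

include("adv_kernel_cg.jl")

# Gaussian kernel: kernel_params is splatted into gaussian_kernel, so pass [gamma]
kmodel = train_adv_kernel_cg(X[tr, :], y[tr], 1.0, :gaussian, [0.1];
                             solver = :gurobi, verbose = false)

# test_adv_kernel needs X_train / y_train to build the train-by-test kernel matrix
kloss, _, kloss01, _, _, _ = test_adv_kernel(kmodel, X[te, :], y[te], X[tr, :], y[tr])
println("kernel adversarial loss = ", kloss, " | 0-1 loss = ", kloss01)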