├── Full.pdf
├── cache
│   ├── covtype_Fmin
│   ├── ijcnn1_Fmin
│   ├── ijcnn1_SGD_white_gm
│   ├── mnist_SAGA_white_gm
│   ├── mnist_SGD_white_gm
│   ├── covtype_SAGA_white_gm
│   ├── covtype_SARAH_white_gm
│   ├── covtype_SGD_white_Krum
│   ├── covtype_SGD_white_gm
│   ├── covtype_SGD_white_mean
│   ├── covtype_SVRG_white_gm
│   ├── ijcnn1_SAGA_white_Krum
│   ├── ijcnn1_SAGA_white_gm
│   ├── ijcnn1_SAGA_white_mean
│   ├── ijcnn1_SARAH_white_gm
│   ├── ijcnn1_SGD_ZV_white_gm
│   ├── ijcnn1_SGD_baseline_gm
│   ├── ijcnn1_SGD_maxValue_gm
│   ├── ijcnn1_SGD_white_Krum
│   ├── ijcnn1_SGD_white_mean
│   ├── ijcnn1_SVRG_white_Krum
│   ├── ijcnn1_SVRG_white_gm
│   ├── ijcnn1_SVRG_white_mean
│   ├── mnist_SAGA_baseline_gm
│   ├── mnist_SAGA_maxValue_gm
│   ├── mnist_SAGA_white_mean
│   ├── mnist_SGD_baseline_gm
│   ├── mnist_SGD_maxValue_gm
│   ├── mnist_SGD_white_mean
│   ├── covtype_SAGA_baseline_gm
│   ├── covtype_SAGA_maxValue_gm
│   ├── covtype_SAGA_white_Krum
│   ├── covtype_SAGA_white_mean
│   ├── covtype_SARAH_white_Krum
│   ├── covtype_SARAH_white_mean
│   ├── covtype_SGD_baseline_gm
│   ├── covtype_SGD_maxValue_gm
│   ├── covtype_SGD_white_median
│   ├── covtype_SVRG_baseline_gm
│   ├── covtype_SVRG_maxValue_gm
│   ├── covtype_SVRG_white_Krum
│   ├── covtype_SVRG_white_mean
│   ├── ijcnn1_BatchSGD_white_gm
│   ├── ijcnn1_SAGA_ZV_white_gm
│   ├── ijcnn1_SAGA_baseline_gm
│   ├── ijcnn1_SAGA_maxValue_gm
│   ├── ijcnn1_SAGA_white_median
│   ├── ijcnn1_SARAH_baseline_gm
│   ├── ijcnn1_SARAH_maxValue_gm
│   ├── ijcnn1_SARAH_white_Krum
│   ├── ijcnn1_SARAH_white_mean
│   ├── ijcnn1_SGD_baseline_Krum
│   ├── ijcnn1_SGD_baseline_krum
│   ├── ijcnn1_SGD_baseline_mean
│   ├── ijcnn1_SGD_maxValue_Krum
│   ├── ijcnn1_SGD_maxValue_mean
│   ├── ijcnn1_SGD_white_median
│   ├── ijcnn1_SVRG_baseline_gm
│   ├── ijcnn1_SVRG_maxValue_gm
│   ├── ijcnn1_SVRG_white_median
│   ├── mnist_BatchSGD_white_gm
│   ├── mnist_SAGA_baseline_mean
│   ├── mnist_SAGA_maxValue_mean
│   ├── mnist_SGD_baseline_mean
│   ├── mnist_SGD_maxValue_mean
│   ├── covtype_BatchSGD_white_Krum
│   ├── covtype_BatchSGD_white_gm
│   ├── covtype_BatchSGD_white_mean
│   ├── covtype_SAGA_baseline_Krum
│   ├── covtype_SAGA_baseline_mean
│   ├── covtype_SAGA_maxValue_Krum
│   ├── covtype_SAGA_maxValue_mean
│   ├── covtype_SAGA_white_median
│   ├── covtype_SARAH_baseline_Krum
│   ├── covtype_SARAH_baseline_gm
│   ├── covtype_SARAH_baseline_mean
│   ├── covtype_SARAH_maxValue_Krum
│   ├── covtype_SARAH_maxValue_gm
│   ├── covtype_SARAH_maxValue_mean
│   ├── covtype_SARAH_white_median
│   ├── covtype_SGD_baseline_Krum
│   ├── covtype_SGD_baseline_mean
│   ├── covtype_SGD_baseline_median
│   ├── covtype_SGD_maxValue_Krum
│   ├── covtype_SGD_maxValue_mean
│   ├── covtype_SGD_maxValue_median
│   ├── covtype_SGD_zeroGradient_gm
│   ├── covtype_SVRG_baseline_Krum
│   ├── covtype_SVRG_baseline_mean
│   ├── covtype_SVRG_maxValue_Krum
│   ├── covtype_SVRG_maxValue_mean
│   ├── covtype_SVRG_white_median
│   ├── ijcnn1_BatchSGD_ZV_white_gm
│   ├── ijcnn1_BatchSGD_baseline_gm
│   ├── ijcnn1_BatchSGD_maxValue_gm
│   ├── ijcnn1_BatchSGD_white_mean
│   ├── ijcnn1_SAGA_ZV_baseline_gm
│   ├── ijcnn1_SAGA_ZV_maxValue_gm
│   ├── ijcnn1_SAGA_baseline_Krum
│   ├── ijcnn1_SAGA_baseline_mean
│   ├── ijcnn1_SAGA_baseline_median
│   ├── ijcnn1_SAGA_maxValue_Krum
│   ├── ijcnn1_SAGA_maxValue_mean
│   ├── ijcnn1_SAGA_maxValue_median
│   ├── ijcnn1_SAGA_zeroGradient_gm
│   ├── ijcnn1_SARAH_baseline_Krum
│   ├── ijcnn1_SARAH_baseline_mean
│   ├── ijcnn1_SARAH_maxValue_Krum
│   ├── ijcnn1_SARAH_maxValue_mean
│   ├── ijcnn1_SARAH_white_median
│   ├── ijcnn1_SGD_ZV_baseline_gm
│   ├── ijcnn1_SGD_ZV_maxValue_gm
│   ├── ijcnn1_SGD_baseline_median
│   ├── ijcnn1_SGD_maxValue_median
│   ├── ijcnn1_SGD_zeroGradient_gm
│   ├── ijcnn1_SVRG_baseline_Krum
│   ├── ijcnn1_SVRG_baseline_mean
│   ├── ijcnn1_SVRG_baseline_median
│   ├── ijcnn1_SVRG_maxValue_Krum
│   ├── ijcnn1_SVRG_maxValue_mean
│   ├── ijcnn1_SVRG_maxValue_median
│   ├── ijcnn1_SVRG_zeroGradient_gm
│   ├── mnist_BatchSGD_baseline_gm
│   ├── mnist_BatchSGD_maxValue_gm
│   ├── mnist_BatchSGD_white_mean
│   ├── mnist_SAGA_zeroGradient_gm
│   ├── mnist_SGD_zeroGradient_gm
│   ├── mnist_SGD_zeroGradient_mean
│   ├── covtype_BatchSGD_baseline_gm
│   ├── covtype_BatchSGD_maxValue_gm
│   ├── covtype_BatchSGD_white_median
│   ├── covtype_SAGA_baseline_median
│   ├── covtype_SAGA_maxValue_median
│   ├── covtype_SAGA_zeroGradient_gm
│   ├── covtype_SARAH_baseline_median
│   ├── covtype_SARAH_maxValue_median
│   ├── covtype_SARAH_zeroGradient_gm
│   ├── covtype_SGD_zeroGradient_Krum
│   ├── covtype_SGD_zeroGradient_mean
│   ├── covtype_SVRG_baseline_median
│   ├── covtype_SVRG_maxValue_median
│   ├── covtype_SVRG_zeroGradient_gm
│   ├── ijcnn1_BatchSGD_baseline_mean
│   ├── ijcnn1_BatchSGD_maxValue_mean
│   ├── ijcnn1_SAGA_ZV_baseline_mean
│   ├── ijcnn1_SAGA_zeroGradient_Krum
│   ├── ijcnn1_SAGA_zeroGradient_mean
│   ├── ijcnn1_SARAH_baseline_median
│   ├── ijcnn1_SARAH_maxValue_median
│   ├── ijcnn1_SARAH_zeroGradient_gm
│   ├── ijcnn1_SGD_SAGA_cmpVar_white
│   ├── ijcnn1_SGD_ZV_zeroGradient_gm
│   ├── ijcnn1_SGD_zeroGradient_Krum
│   ├── ijcnn1_SGD_zeroGradient_mean
│   ├── ijcnn1_SVRG_zeroGradient_Krum
│   ├── ijcnn1_SVRG_zeroGradient_mean
│   ├── mnist_BatchSGD_baseline_mean
│   ├── mnist_BatchSGD_maxValue_mean
│   ├── mnist_SAGA_zeroGradient_mean
│   ├── covtype_BatchSGD_baseline_Krum
│   ├── covtype_BatchSGD_baseline_mean
│   ├── covtype_BatchSGD_baseline_median
│   ├── covtype_BatchSGD_maxValue_Krum
│   ├── covtype_BatchSGD_maxValue_mean
│   ├── covtype_BatchSGD_maxValue_median
│   ├── covtype_BatchSGD_zeroGradient_gm
│   ├── covtype_SAGA_zeroGradient_Krum
│   ├── covtype_SAGA_zeroGradient_mean
│   ├── covtype_SAGA_zeroGradient_median
│   ├── covtype_SARAH_zeroGradient_Krum
│   ├── covtype_SARAH_zeroGradient_mean
│   ├── covtype_SGD_zeroGradient_median
│   ├── covtype_SVRG_zeroGradient_Krum
│   ├── covtype_SVRG_zeroGradient_mean
│   ├── covtype_SVRG_zeroGradient_median
│   ├── ijcnn1_BatchSGD_ZV_baseline_gm
│   ├── ijcnn1_BatchSGD_ZV_maxValue_gm
│   ├── ijcnn1_BatchSGD_zeroGradient_gm
│   ├── ijcnn1_SAGA_ZV_zeroGradient_gm
│   ├── ijcnn1_SAGA_zeroGradient_median
│   ├── ijcnn1_SARAH_zeroGradient_Krum
│   ├── ijcnn1_SARAH_zeroGradient_mean
│   ├── ijcnn1_SARAH_zeroGradient_median
│   ├── ijcnn1_SGD_SAGA_cmpVar_baseline
│   ├── ijcnn1_SGD_SAGA_cmpVar_maxValue
│   ├── ijcnn1_SGD_zeroGradient_median
│   ├── ijcnn1_SVRG_zeroGradient_median
│   ├── mnist_BatchSGD_zeroGradient_gm
│   ├── mnist_BatchSGD_zeroGradient_mean
│   ├── covtype_BatchSGD_zeroGradient_Krum
│   ├── covtype_BatchSGD_zeroGradient_mean
│   ├── covtype_SARAH_zeroGradient_median
│   ├── ijcnn1_BatchSGD_ZV_zeroGradient_gm
│   ├── ijcnn1_BatchSGD_zeroGradient_mean
│   ├── ijcnn1_SGD_ZV_baseline_gm_oldGamma
│   ├── ijcnn1_SGD_ZV_maxValue_gm_oldGamma
│   ├── covtype_BatchSGD_zeroGradient_median
│   ├── ijcnn1_BatchSGD_ZV_white_gm_oldGamma
│   ├── ijcnn1_SGD_SAGA_cmpVar_zeroGradient
│   ├── ijcnn1_BatchSGD_ZV_baseline_gm_oldGamma
│   ├── ijcnn1_BatchSGD_ZV_maxValue_gm_oldGamma
│   ├── ijcnn1_SGD_ZV_zeroGradient_gm_oldGamma
│   ├── CIFAR-10_ResNet_CentralSGD(20)_baseline_mean
│   ├── CIFAR-10_ResNet_CentralSGD(30)_baseline_mean
│   ├── CIFAR-10_ResNet_CentralSGD(5)_baseline_mean
│   └── ijcnn1_BatchSGD_ZV_zeroGradient_gm_oldGamma
├── readme.md
├── draw.ipynb
└── Byrd_SAGA_torch_ANN.ipynb

Every file listed above is mirrored at https://raw.githubusercontent.com/Zhaoxian-Wu/Byrd-SAGA/HEAD/<path relative to the repository root>. For example, /Full.pdf is served from https://raw.githubusercontent.com/Zhaoxian-Wu/Byrd-SAGA/HEAD/Full.pdf, and /cache/covtype_Fmin from https://raw.githubusercontent.com/Zhaoxian-Wu/Byrd-SAGA/HEAD/cache/covtype_Fmin; the same pattern applies to every entry under /cache.
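Below is a minimal sketch of how one of these cached result files could be fetched programmatically using the raw-URL pattern quoted above. It relies only on the Python standard library; the helper name fetch_cache and the example path are illustrative, and since the binary format of the cache files is not documented in this listing, the sketch stops at raw bytes.

```python
# Minimal sketch (not part of the repository): download one cached result file
# via the raw.githubusercontent.com pattern shown in the listing above.
# `fetch_cache` and the example path are illustrative; the on-disk format of
# the cache files is not specified here, so the function returns raw bytes.
from urllib.request import urlopen

RAW_BASE = "https://raw.githubusercontent.com/Zhaoxian-Wu/Byrd-SAGA/HEAD/"

def fetch_cache(relative_path: str) -> bytes:
    """Return the raw bytes of a repository file, e.g. 'cache/covtype_Fmin'."""
    with urlopen(RAW_BASE + relative_path) as response:
        return response.read()

if __name__ == "__main__":
    payload = fetch_cache("cache/ijcnn1_SGD_white_gm")
    print(f"cache/ijcnn1_SGD_white_gm: {len(payload)} bytes")
```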
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoxian-Wu/Byrd-SAGA/HEAD/cache/ijcnn1_BatchSGD_ZV_zeroGradient_gm_oldGamma -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Byrd-SAGA 2 | This repository stores the code for the papers *Byzantine-Resilient Distributed Finite-Sum Optimization over Networks* (short version) and *Federated Variance-Reduced Stochastic Gradient Descent with Robustness to Byzantine Attacks* (full version, provided in `Full.pdf`). The code should be run in *Jupyter Notebook*. 3 | 4 | ## Environment 5 | - python 3.7.4 6 | - pytorch 1.2.0 7 | - matplotlib 3.1.1 8 | 9 | ## Structure 10 | The main programs can be found in the following files: 11 | - Byrd_SAGA_torch_LinearRegression.ipynb: The experiment on linear regression. 12 | - Byrd_SAGA_torch_ANN.ipynb: The experiment on neural networks. 13 | - draw.ipynb: The script that draws the figures. 14 | 15 | ## Running 16 | Download the datasets to the folder `./dataset` and create a folder named `./cache`. The experiment outputs will be stored in `./cache`. 17 | 18 | ## Download datasets 19 | - *ijcnn1/covtype*: [https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/) 20 | - *MNIST*: [http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz](http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz) 21 | 22 | ## Acknowledgement 23 | We would like to thank Runhua Wang, SYSU, for helping us review and improve our code. -------------------------------------------------------------------------------- /draw.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "import pickle\n", 11 | "import math\n", 12 | "\n", 13 | "dataSetConfigs = [\n", 14 | "    {\n", 15 | "        'name': 'ijcnn1',\n", 16 | "\n", 17 | "        'dataSet' : 'ijcnn1',\n", 18 | "        'dataSetSize': 49990,\n", 19 | "        'maxFeature': 22,\n", 20 | "        'findingType': '1',\n", 21 | "\n", 22 | "        'honestNodeSize': 50,\n", 23 | "        'byzantineNodeSize': 20,\n", 24 | "\n", 25 | "        'rounds': 10,\n", 26 | "        'displayInterval': 4000,\n", 27 | "    },\n", 28 | "    {\n", 29 | "        'name': 'covtype',\n", 30 | "\n", 31 | "        'dataSet' : 'covtype.libsvm.binary.scale',\n", 32 | "        'dataSetSize': 581012,\n", 33 | "        'maxFeature': 54,\n", 34 | "        'findingType': '1',\n", 35 | "\n", 36 | "        'honestNodeSize': 50,\n", 37 | "        'byzantineNodeSize': 20,\n", 38 | "\n", 39 | "        'rounds': 10,\n", 40 | "        'displayInterval': 4000,\n", 41 | "    }\n", 42 | "]\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "# Comparison of aggregation methods" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "import matplotlib.pyplot as plt\n", 59 | "import pickle\n", 60 | "import math\n", 61 | "\n", 62 | "def logAxis(path, Fmin):\n", 63 | "    return [p-Fmin for p in path]\n", 64 | "\n", 65 | "# %matplotlib inline\n", 66 | "# %config InlineBackend.figure_format = 'svg'\n", 67 | "\n", 68 | "SCALE = 2.0\n", 69 | "FONT_SIZE = 12\n", 70 | "\n", 71 | "fig, axs = plt.subplots(4, 2)\n", 72 | "axs = tuple(zip(*axs))\n", 73 | "\n", 74 | "for axColumn, dataSetConfig in 
zip(axs, dataSetConfigs):\n", 75 | " \n", 76 | " CACHE_DIR = './cache/' + dataSetConfig['name']\n", 77 | " with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 78 | " obj = pickle.load(f)\n", 79 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 80 | "\n", 81 | " x_axis = list(range(dataSetConfig['rounds']+1))\n", 82 | "\n", 83 | " attackNames = [\n", 84 | " ('baseline', 'without attack'), \n", 85 | " ('white', 'Gaussian attack'), \n", 86 | " ('maxValue', 'max-value attack'), \n", 87 | " ('zeroGradient', 'zero-gradient attack'),\n", 88 | " ]\n", 89 | " # 文件名 图片中显示的名称 颜色\n", 90 | " aggregations = [\n", 91 | " ('mean', 'mean', 'royalblue'), \n", 92 | " ('gm', 'geometric median', 'darkorange'), \n", 93 | " ('Krum', 'Krum', 'darkgreen'), \n", 94 | " ('median', 'median', 'darkmagenta'),\n", 95 | " ]\n", 96 | " \n", 97 | " for ax, (attackName, title) in zip(axColumn, attackNames):\n", 98 | " # 画曲线\n", 99 | " for (aggregationName, showName, color) in aggregations:\n", 100 | " with open(CACHE_DIR + '_SAGA_' + attackName + '_' + aggregationName, 'rb') as f:\n", 101 | " record = pickle.load(f)\n", 102 | " path = record['path']\n", 103 | " path = logAxis(path, Fmin)\n", 104 | " ax.plot(x_axis, path, 'v-', color=color, label=showName)\n", 105 | " \n", 106 | " # 填小标题\n", 107 | " ax.set_title('{} ({})'.format(title, dataSetConfig['name'].upper()))\n", 108 | " \n", 109 | " # 坐标轴\n", 110 | " ax.set_yscale('log')\n", 111 | " ax.set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 112 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 113 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 114 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 115 | "# ax.set_ylim(top=1e2, bottom=1e-6)\n", 116 | "\n", 117 | "# 图例\n", 118 | "axs[0][-1].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 119 | " borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 120 | "# for ax in axs[1]:\n", 121 | "# ax.set_ylim(top=1e10)\n", 122 | "\n", 123 | "fig.set_size_inches((SCALE*4, SCALE*8))\n", 124 | "plt.subplots_adjust(hspace=0.4, wspace=0.35)\n", 125 | "plt.savefig('./attack_ijc_cov.eps', format='eps', bbox_inches='tight')\n", 126 | "plt.show()" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "import matplotlib.pyplot as plt\n", 136 | "import pickle\n", 137 | "import math\n", 138 | "\n", 139 | "def logAxis(path, Fmin):\n", 140 | " return [p-Fmin for p in path]\n", 141 | "\n", 142 | "# %matplotlib inline\n", 143 | "# %config InlineBackend.figure_format = 'svg'\n", 144 | "\n", 145 | "SCALE = 2.0\n", 146 | "FONT_SIZE = 12\n", 147 | "\n", 148 | "fig, axs = plt.subplots(2, 4)\n", 149 | "# axs = tuple(zip(*axs))\n", 150 | "\n", 151 | "for axColumn, dataSetConfig in zip(axs, dataSetConfigs):\n", 152 | " \n", 153 | " CACHE_DIR = './cache/' + dataSetConfig['name']\n", 154 | " with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 155 | " obj = pickle.load(f)\n", 156 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 157 | "\n", 158 | " x_axis = list(range(dataSetConfig['rounds']+1))\n", 159 | "\n", 160 | " attackNames = [\n", 161 | " ('baseline', 'without attack'), \n", 162 | " ('white', 'Gaussian attack'), \n", 163 | " ('maxValue', 'sign-flipping attack'), \n", 164 | " ('zeroGradient', 'zero-gradient attack'),\n", 165 | " ]\n", 166 | " # 文件名 图片中显示的名称 颜色\n", 167 | " aggregations = [\n", 168 | " ('mean', 'mean', 'royalblue'), \n", 169 | " ('gm', 'geometric median', 
'darkorange'), \n", 170 | " ('Krum', 'Krum', 'darkgreen'), \n", 171 | " ('median', 'median', 'darkmagenta'),\n", 172 | " ]\n", 173 | " \n", 174 | " for ax, (attackName, title) in zip(axColumn, attackNames):\n", 175 | " # 画曲线\n", 176 | " for (aggregationName, showName, color) in aggregations:\n", 177 | " with open(CACHE_DIR + '_SAGA_' + attackName + '_' + aggregationName, 'rb') as f:\n", 178 | " record = pickle.load(f)\n", 179 | " path = record['path']\n", 180 | " path = logAxis(path, Fmin)\n", 181 | " ax.plot(x_axis, path, 'v-', color=color, label=showName)\n", 182 | " \n", 183 | " # 填小标题\n", 184 | " ax.set_title('{} ({})'.format(title, dataSetConfig['name'].upper()))\n", 185 | " \n", 186 | " # 坐标轴\n", 187 | " ax.set_yscale('log')\n", 188 | " ax.set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 189 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 190 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 191 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 192 | "# ax.set_ylim(top=1e2, bottom=1e-6)\n", 193 | "\n", 194 | "# 图例\n", 195 | "# 两行\n", 196 | "axs[1][0].legend(loc='lower left', bbox_to_anchor=(1.2,-0.4), \n", 197 | " borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 198 | "\n", 199 | "fig.set_size_inches((SCALE*9, SCALE*4))\n", 200 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 201 | "plt.savefig('./attack_ijc_cov.eps', format='eps', bbox_inches='tight')\n", 202 | "plt.show()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "# 多聚合方式比较图(SGD-SAGA)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "import matplotlib.pyplot as plt\n", 219 | "import pickle\n", 220 | "import math\n", 221 | "\n", 222 | "def logAxis(path, Fmin):\n", 223 | " return [p-Fmin for p in path]\n", 224 | "\n", 225 | "# %matplotlib inline\n", 226 | "# %config InlineBackend.figure_format = 'svg'\n", 227 | "\n", 228 | "SCALE = 2.0\n", 229 | "FONT_SIZE = 12\n", 230 | "\n", 231 | "fig, axs = plt.subplots(4, 2)\n", 232 | "axs = tuple(zip(*axs))\n", 233 | "\n", 234 | "for axColumn, dataSetConfig in zip(axs, dataSetConfigs):\n", 235 | " \n", 236 | " CACHE_DIR = './cache/' + dataSetConfig['name']\n", 237 | " with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 238 | " obj = pickle.load(f)\n", 239 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 240 | "\n", 241 | " x_axis = list(range(dataSetConfig['rounds']+1))\n", 242 | "\n", 243 | " attackNames = [\n", 244 | " ('baseline', 'without attack'), \n", 245 | " ('white', 'Gaussian attack'), \n", 246 | " ('maxValue', 'max-value attack'), \n", 247 | " ('zeroGradient', 'zero-gradient attack'),\n", 248 | " ]\n", 249 | " # 文件名 图片中显示的名称 颜色\n", 250 | " aggregations = [\n", 251 | " ('mean', 'mean', 'royalblue'), \n", 252 | " ('gm', 'GM', 'darkorange'), \n", 253 | " ('Krum', 'Krum', 'darkgreen'), \n", 254 | " ('median', 'median', 'darkmagenta'),\n", 255 | " ]\n", 256 | " \n", 257 | " algorithms = [\n", 258 | " ('SGD', ':'),\n", 259 | " ('SAGA', '-'),\n", 260 | " ]\n", 261 | " \n", 262 | " for ax, (attackName, title) in zip(axColumn, attackNames):\n", 263 | " # 画曲线\n", 264 | " for (aggregationName, showName, color) in aggregations:\n", 265 | " for algorithm, line in algorithms:\n", 266 | " try:\n", 267 | " with open(CACHE_DIR + '_' + algorithm + '_' + attackName + '_' + aggregationName, 'rb') as f:\n", 268 | " record = pickle.load(f)\n", 269 | " path = 
record['path']\n", 270 | " path = logAxis(path, Fmin)\n", 271 | " ax.plot(x_axis, path, 'v'+line, color=color, label=algorithm + ' ' + showName)\n", 272 | " except Exception as e:\n", 273 | " print(e)\n", 274 | " \n", 275 | " # 填小标题\n", 276 | " ax.set_title('{} ({})'.format(title, dataSetConfig['name'].upper()))\n", 277 | " \n", 278 | " # 坐标轴\n", 279 | " ax.set_yscale('log')\n", 280 | " ax.set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 281 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 282 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 283 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 284 | "# ax.set_ylim(top=1e2, bottom=1e-6)\n", 285 | "\n", 286 | "# 图例\n", 287 | "axs[0][-1].legend(loc='lower left', bbox_to_anchor=(0.01,-0.6), \n", 288 | " borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 289 | "# for ax in axs[1]:\n", 290 | "# ax.set_ylim(top=1e10)\n", 291 | "\n", 292 | "fig.set_size_inches((SCALE*4, SCALE*8))\n", 293 | "plt.subplots_adjust(hspace=0.4, wspace=0.35)\n", 294 | "# plt.savefig('./attack_ijc_cov.jpg', format='jpg', bbox_inches='tight')\n", 295 | "# plt.savefig('./attack_ijc_cov.eps', format='eps', bbox_inches='tight')\n", 296 | "plt.show()" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "import matplotlib.pyplot as plt\n", 306 | "import pickle\n", 307 | "import math\n", 308 | "\n", 309 | "def logAxis(path, Fmin):\n", 310 | " return [p-Fmin for p in path]\n", 311 | "\n", 312 | "# %matplotlib inline\n", 313 | "# %config InlineBackend.figure_format = 'svg'\n", 314 | "\n", 315 | "SCALE = 2.0\n", 316 | "FONT_SIZE = 12\n", 317 | "\n", 318 | "fig, axs = plt.subplots(4, 2)\n", 319 | "axs = tuple(zip(*axs))\n", 320 | "\n", 321 | "for axColumn, dataSetConfig in zip(axs, dataSetConfigs):\n", 322 | " \n", 323 | " CACHE_DIR = './cache/' + dataSetConfig['name']\n", 324 | " with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 325 | " obj = pickle.load(f)\n", 326 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 327 | "\n", 328 | " x_axis = list(range(dataSetConfig['rounds']+1))\n", 329 | "\n", 330 | " attackNames = [\n", 331 | " ('baseline', 'without attack'), \n", 332 | " ('white', 'Gaussian attack'), \n", 333 | " ('maxValue', 'max-value attack'), \n", 334 | " ('zeroGradient', 'zero-gradient attack'),\n", 335 | " ]\n", 336 | " # 文件名 图片中显示的名称 颜色\n", 337 | " aggregations = [\n", 338 | " ('mean', 'mean', 'royalblue'), \n", 339 | " ('gm', 'GM', 'darkorange'), \n", 340 | " ('Krum', 'Krum', 'darkgreen'), \n", 341 | " ('median', 'median', 'darkmagenta'),\n", 342 | " ]\n", 343 | " \n", 344 | " algorithms = [\n", 345 | " ('SGD', ':'),\n", 346 | " ('SAGA', '-'),\n", 347 | " ]\n", 348 | " \n", 349 | " for ax, (attackName, title) in zip(axColumn, attackNames):\n", 350 | " # 画曲线\n", 351 | " for (aggregationName, showName, color) in aggregations:\n", 352 | " for algorithm, line in algorithms:\n", 353 | " try:\n", 354 | " with open(CACHE_DIR + '_' + algorithm + '_' + attackName + '_' + aggregationName, 'rb') as f:\n", 355 | " record = pickle.load(f)\n", 356 | " path = record['path']\n", 357 | " path = logAxis(path, Fmin)\n", 358 | " ax.plot(x_axis, path, 'v'+line, color=color, label=algorithm + ' ' + showName)\n", 359 | " except Exception as e:\n", 360 | " print(e)\n", 361 | " \n", 362 | " # 填小标题\n", 363 | " ax.set_title('{} ({})'.format(title, dataSetConfig['name'].upper()))\n", 364 | " \n", 365 | " # 坐标轴\n", 366 | 
" ax.set_yscale('log')\n", 367 | " ax.set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 368 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 369 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 370 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 371 | "# ax.set_ylim(top=1e2, bottom=1e-6)\n", 372 | "\n", 373 | "# 图例\n", 374 | "axs[0][-1].legend(loc='lower left', bbox_to_anchor=(0.01,-0.6), \n", 375 | " borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 376 | "# for ax in axs[1]:\n", 377 | "# ax.set_ylim(top=1e10)\n", 378 | "\n", 379 | "fig.set_size_inches((SCALE*4, SCALE*8))\n", 380 | "plt.subplots_adjust(hspace=0.4, wspace=0.35)\n", 381 | "# plt.savefig('./attack_ijc_cov.jpg', format='jpg', bbox_inches='tight')\n", 382 | "# plt.savefig('./attack_ijc_cov.eps', format='eps', bbox_inches='tight')\n", 383 | "plt.show()" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "# 神经网络对比" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "for optimizer in ['SAGA', 'SGD', 'BatchSGD']:\n", 400 | " for attack in ['baseline', 'white', 'maxValue', 'zeroGradient']:\n", 401 | " for aggegate in ['mean', 'gm']:\n", 402 | " file = optimizer + '_' + attack + '_' + aggegate\n", 403 | " try:\n", 404 | " with open('./cache/mnist_' + file, 'rb') as f:\n", 405 | " record = pickle.load(f)\n", 406 | " print('{:>4} {:>4f} {}'.format(record['gamma'], record['accPath'][-1]*100, file))\n", 407 | " except:\n", 408 | " print(file)" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "metadata": {}, 414 | "source": [ 415 | "# 步长比较" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": {}, 421 | "source": [ 422 | "ijcnn1 步长比较" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "for optimizer in ['SAGA', 'SVRG', 'SGD', 'BatchSGD']:\n", 432 | " for attack in ['baseline', 'white', 'maxValue', 'zeroGradient']:\n", 433 | " for aggegate in ['mean', 'gm']:\n", 434 | " file = optimizer + '_' + attack + '_' + aggegate\n", 435 | " try:\n", 436 | " with open('./cache/ijcnn1_' + file, 'rb') as f:\n", 437 | " record = pickle.load(f)\n", 438 | " print(record['gamma'], file)\n", 439 | " except Exception as e:\n", 440 | " print(e)" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": {}, 446 | "source": [ 447 | "covtype" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "for optimizer in ['SAGA', 'SVRG', 'SARAH', 'SGD', 'BatchSGD']:\n", 457 | " for attack in ['baseline', 'white', 'maxValue', 'zeroGradient']:\n", 458 | " for aggegate in ['mean', 'gm', 'median', 'Krum']:\n", 459 | " file = optimizer + '_' + attack + '_' + aggegate\n", 460 | " try:\n", 461 | " with open('./cache/covtype_' + file, 'rb') as f:\n", 462 | " record = pickle.load(f)\n", 463 | " print(record['gamma'], file)\n", 464 | " except Exception as e:\n", 465 | " print(e)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": {}, 471 | "source": [ 472 | "mnist 步长比较" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "with open('./cache/mnist_SAGA_baseline_mean', 'rb') as f:\n", 482 | 
" record = pickle.load(f)\n", 483 | "print('[SAGA] gamma={}'.format(record['gamma']))\n", 484 | "with open('./cache/mnist_SGD_baseline_mean', 'rb') as f:\n", 485 | " record = pickle.load(f)\n", 486 | "print('[SGD] gamma={}'.format(record['gamma']))" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "# 方差比较图(竖)" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | "ijcnn" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": { 507 | "code_folding": [] 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "import matplotlib.pyplot as plt\n", 512 | "import pickle\n", 513 | "import math\n", 514 | "\n", 515 | "def logAxis(path, Fmin):\n", 516 | " return [p-Fmin for p in path]\n", 517 | "\n", 518 | "# %matplotlib inline\n", 519 | "# %config InlineBackend.figure_format = 'svg'\n", 520 | "\n", 521 | "dataSetConfig = {\n", 522 | " 'name': 'ijcnn1',\n", 523 | "\n", 524 | " 'dataSet' : 'ijcnn1',\n", 525 | " 'dataSetSize': 49990,\n", 526 | " 'maxFeature': 22,\n", 527 | " 'findingType': '1',\n", 528 | "\n", 529 | " 'honestNodeSize': 50,\n", 530 | " 'byzantineNodeSize': 20,\n", 531 | "\n", 532 | " 'rounds': 10,\n", 533 | " 'displayInterval': 4000,\n", 534 | "}\n", 535 | "\n", 536 | "SCALE = 2.0\n", 537 | "FONT_SIZE = 12\n", 538 | "\n", 539 | "fig, axs = plt.subplots(4, 2)\n", 540 | "\n", 541 | "attackNames = [\n", 542 | " ('baseline', 'without attack'), \n", 543 | " ('white', 'Gaussian attack'), \n", 544 | " ('maxValue', 'max-value attack'), \n", 545 | " ('zeroGradient', 'zero-gradient attack'),\n", 546 | "]\n", 547 | "# 名称 颜色\n", 548 | "optimizers = [\n", 549 | " ('SGD', 'SGD', 'royalblue'), \n", 550 | " ('BatchSGD', 'BSGD', 'darkgreen'), \n", 551 | " ('SAGA', 'SAGA', 'darkorange'),\n", 552 | "]\n", 553 | "# 文件名 图片中显示的名称 纹理\n", 554 | "aggregations = [\n", 555 | " ('mean', 'mean', 'v-'), \n", 556 | " ('gm', 'geomed','v--'),\n", 557 | "]\n", 558 | "\n", 559 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 560 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 561 | " obj = pickle.load(f)\n", 562 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 563 | "\n", 564 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 565 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 566 | "\n", 567 | "for axline, (attackName, title) in zip(axs, attackNames):\n", 568 | " for optimizer, optimizerName, color in optimizers:\n", 569 | " for (aggregationName, showName, fmt) in aggregations:\n", 570 | " # 标签\n", 571 | " label = optimizerName + ' ' + showName\n", 572 | " \n", 573 | " # 画曲线\n", 574 | " with open(CACHE_DIR + '_' + optimizer+'_' + attackName + '_' + aggregationName, 'rb') as f:\n", 575 | " record = pickle.load(f)\n", 576 | " # 损失函数\n", 577 | " path = record['path']\n", 578 | " path = logAxis(path, Fmin)\n", 579 | " axline[0].plot(x_axis, path, fmt, color=color, label=label)\n", 580 | " # variance\n", 581 | " variancePath = record['variancePath']\n", 582 | " axline[1].plot(x_axis_minus_1, variancePath, fmt, color=color, label=label)\n", 583 | " \n", 584 | " # 填小标题\n", 585 | " axline[0].set_title('optimality gap ({})'.format(title))\n", 586 | " axline[1].set_title('variance ({})'.format(title))\n", 587 | " for ax in axline:\n", 588 | " # 坐标轴\n", 589 | " ax.set_yscale('log')\n", 590 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 591 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 592 
| " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 593 | "\n", 594 | " axline[0].set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 595 | " axline[1].set_ylabel(r'$D_{w\\notin B} [m_w^k]$', fontsize=FONT_SIZE)\n", 596 | " axline[1].set_xlim(left=0)\n", 597 | " \n", 598 | "# 图例\n", 599 | "# 一行\n", 600 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 601 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 602 | "# 两行\n", 603 | "axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.1,-0.6), \n", 604 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 605 | "\n", 606 | "fig.set_size_inches((SCALE*4, SCALE*8))\n", 607 | "plt.subplots_adjust(hspace=0.4, wspace=0.35)\n", 608 | "plt.savefig('./attack_loss_variance.eps', format='eps', bbox_inches='tight')\n", 609 | "# plt.savefig('./attack_loss_variance.jpg', format='jpg', bbox_inches='tight')\n", 610 | "plt.show()" 611 | ] 612 | }, 613 | { 614 | "cell_type": "markdown", 615 | "metadata": {}, 616 | "source": [ 617 | "covtype" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "metadata": { 624 | "code_folding": [] 625 | }, 626 | "outputs": [], 627 | "source": [ 628 | "import matplotlib.pyplot as plt\n", 629 | "import pickle\n", 630 | "import math\n", 631 | "\n", 632 | "def logAxis(path, Fmin):\n", 633 | " return [p-Fmin for p in path]\n", 634 | "\n", 635 | "# %matplotlib inline\n", 636 | "# %config InlineBackend.figure_format = 'svg'\n", 637 | "\n", 638 | "dataSetConfig = dataSetConfigs[1]\n", 639 | "\n", 640 | "SCALE = 2.0\n", 641 | "FONT_SIZE = 12\n", 642 | "\n", 643 | "fig, axs = plt.subplots(4, 2)\n", 644 | "\n", 645 | "attackNames = [\n", 646 | " ('baseline', 'without attack'), \n", 647 | " ('white', 'Gaussian attack'), \n", 648 | " ('maxValue', 'max-value attack'), \n", 649 | " ('zeroGradient', 'zero-gradient attack'),\n", 650 | "]\n", 651 | "# 名称 颜色\n", 652 | "optimizers = [\n", 653 | " ('SGD','royalblue'), \n", 654 | " ('BatchSGD', 'darkgreen'), \n", 655 | " ('SAGA', 'darkorange'),\n", 656 | "]\n", 657 | "# 文件名 图片中显示的名称 纹理\n", 658 | "aggregations = [\n", 659 | " ('mean', 'mean', 'v-'), \n", 660 | " ('gm', 'gm','v--'),\n", 661 | "]\n", 662 | "\n", 663 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 664 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 665 | " obj = pickle.load(f)\n", 666 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 667 | "\n", 668 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 669 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 670 | "\n", 671 | "for axline, (attackName, title) in zip(axs, attackNames):\n", 672 | " for optimizer, color in optimizers:\n", 673 | " for (aggregationName, showName, fmt) in aggregations:\n", 674 | " # 标签\n", 675 | " label = optimizer + ' ' + showName\n", 676 | " \n", 677 | " # 画曲线\n", 678 | " try:\n", 679 | " with open(CACHE_DIR + '_' + optimizer+'_' + attackName + '_' + aggregationName, 'rb') as f:\n", 680 | " record = pickle.load(f)\n", 681 | " # 损失函数\n", 682 | " path = record['path']\n", 683 | " path = logAxis(path, Fmin)\n", 684 | " axline[0].plot(x_axis, path, fmt, color=color, label=label)\n", 685 | " # variance\n", 686 | " variancePath = record['variancePath']\n", 687 | " axline[1].plot(x_axis_minus_1, variancePath, fmt, color=color, label=label)\n", 688 | " except Exception as e:\n", 689 | " print(e)\n", 690 | " \n", 691 | " # 填小标题\n", 692 | " axline[0].set_title('optimal error ({})'.format(title))\n", 693 | " axline[1].set_title('variance ({})'.format(title))\n", 694 | " 
for ax in axline:\n", 695 | " # 坐标轴\n", 696 | " ax.set_yscale('log')\n", 697 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 698 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 699 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 700 | "\n", 701 | " axline[0].set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 702 | " axline[1].set_ylabel(r'$D_{w\\notin B} [m_w^k]$', fontsize=FONT_SIZE)\n", 703 | " axline[1].set_xlim(left=0)\n", 704 | " \n", 705 | "# 图例\n", 706 | "# 一行\n", 707 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 708 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 709 | "# 两行\n", 710 | "axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.1,-0.6), \n", 711 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 712 | "\n", 713 | "fig.set_size_inches((SCALE*4, SCALE*8))\n", 714 | "plt.subplots_adjust(hspace=0.4, wspace=0.35)\n", 715 | "# plt.savefig('./attack_loss_variance.eps', format='eps', bbox_inches='tight')\n", 716 | "# plt.savefig('./attack_loss_variance.jpg', format='jpg', bbox_inches='tight')\n", 717 | "plt.show()" 718 | ] 719 | }, 720 | { 721 | "cell_type": "markdown", 722 | "metadata": {}, 723 | "source": [ 724 | "# 方差比较图(横)" 725 | ] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": {}, 730 | "source": [ 731 | "ijcnn" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": null, 737 | "metadata": { 738 | "code_folding": [], 739 | "scrolled": false 740 | }, 741 | "outputs": [], 742 | "source": [ 743 | "import matplotlib.pyplot as plt\n", 744 | "import pickle\n", 745 | "import math\n", 746 | "\n", 747 | "def logAxis(path, Fmin):\n", 748 | " return [p-Fmin for p in path]\n", 749 | "\n", 750 | "# %matplotlib inline\n", 751 | "# %config InlineBackend.figure_format = 'svg'\n", 752 | "\n", 753 | "dataSetConfig = dataSetConfigs[0]\n", 754 | "\n", 755 | "SCALE = 2.0\n", 756 | "FONT_SIZE = 12\n", 757 | "\n", 758 | "fig, axs = plt.subplots(2, 4)\n", 759 | "axs = tuple(zip(*axs))\n", 760 | "\n", 761 | "attackNames = [\n", 762 | " ('baseline', 'without attack'), \n", 763 | " ('white', 'Gaussian attack'), \n", 764 | " ('maxValue', 'sign-flipping attack'), \n", 765 | " ('zeroGradient', 'zero-gradient attack'),\n", 766 | "]\n", 767 | "# 名称 颜色\n", 768 | "optimizers = [\n", 769 | " ('SGD', 'SGD', 'royalblue'), \n", 770 | " ('BatchSGD', 'BSGD', 'darkgreen'), \n", 771 | " ('SAGA', 'SAGA', 'darkorange'),\n", 772 | "]\n", 773 | "# 文件名 图片中显示的名称 纹理\n", 774 | "aggregations = [\n", 775 | " ('mean', 'mean', 'v-'), \n", 776 | " ('gm', 'geomed','v--'),\n", 777 | "]\n", 778 | "\n", 779 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 780 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 781 | " obj = pickle.load(f)\n", 782 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 783 | "\n", 784 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 785 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 786 | "\n", 787 | "for axline, (attackName, title) in zip(axs, attackNames):\n", 788 | " for optimizer, optimizerName, color in optimizers:\n", 789 | " for (aggregationName, showName, fmt) in aggregations:\n", 790 | " # 标签\n", 791 | " label = optimizerName + ' ' + showName\n", 792 | " \n", 793 | " # 画曲线\n", 794 | " with open(CACHE_DIR + '_' + optimizer+'_' + attackName + '_' + aggregationName, 'rb') as f:\n", 795 | " record = pickle.load(f)\n", 796 | " # 损失函数\n", 797 | " path = record['path']\n", 798 | " path = logAxis(path, Fmin)\n", 
799 | " axline[0].plot(x_axis, path, fmt, color=color, label=label)\n", 800 | " # variance\n", 801 | " variancePath = record['variancePath']\n", 802 | " axline[1].plot(x_axis_minus_1, variancePath, fmt, color=color, label=label)\n", 803 | " \n", 804 | " # 填小标题\n", 805 | " axline[0].set_title('optimality gap ({})'.format(title))\n", 806 | " axline[1].set_title('variance ({})'.format(title))\n", 807 | " for ax in axline:\n", 808 | " # 坐标轴\n", 809 | " ax.set_yscale('log')\n", 810 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 811 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 812 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 813 | "\n", 814 | " axline[0].set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 815 | " axline[1].set_ylabel(r'$D_{w\\notin B} [m_w^k]$', fontsize=FONT_SIZE)\n", 816 | " axline[1].set_xlim(left=0)\n", 817 | "\n", 818 | "# 图例\n", 819 | "# 一行\n", 820 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 821 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 822 | "# 两行\n", 823 | "axs[0][1].legend(loc='lower left', bbox_to_anchor=(1.2,-0.5), \n", 824 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 825 | "# 右侧\n", 826 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(1.2, 0.3), \n", 827 | "# borderaxespad = 0., ncol=1, fontsize=FONT_SIZE)\n", 828 | "\n", 829 | "fig.set_size_inches((SCALE*9, SCALE*4))\n", 830 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 831 | "plt.savefig('./attack_loss_variance.eps', format='eps', bbox_inches='tight')\n", 832 | "# plt.savefig('./attack_loss_variance.jpg', format='jpg', bbox_inches='tight')\n", 833 | "plt.show()" 834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "execution_count": null, 839 | "metadata": { 840 | "code_folding": [] 841 | }, 842 | "outputs": [], 843 | "source": [ 844 | "import matplotlib.pyplot as plt\n", 845 | "import pickle\n", 846 | "import math\n", 847 | "\n", 848 | "def logAxis(path, Fmin):\n", 849 | " return [p-Fmin for p in path]\n", 850 | "\n", 851 | "# %matplotlib inline\n", 852 | "# %config InlineBackend.figure_format = 'svg'\n", 853 | "\n", 854 | "dataSetConfig = dataSetConfigs[1]\n", 855 | "\n", 856 | "SCALE = 2.0\n", 857 | "FONT_SIZE = 12\n", 858 | "\n", 859 | "fig, axs = plt.subplots(2, 4)\n", 860 | "axs = tuple(zip(*axs))\n", 861 | "\n", 862 | "attackNames = [\n", 863 | " ('baseline', 'without attack'), \n", 864 | " ('white', 'Gaussian attack'), \n", 865 | " ('maxValue', 'sign-flipping attack'), \n", 866 | " ('zeroGradient', 'zero-gradient attack'),\n", 867 | "]\n", 868 | "# 名称 颜色\n", 869 | "optimizers = [\n", 870 | " ('SGD', 'SGD', 'royalblue'), \n", 871 | " ('BatchSGD', 'BSGD', 'darkgreen'), \n", 872 | " ('SAGA', 'SAGA', 'darkorange'),\n", 873 | "]\n", 874 | "# 文件名 图片中显示的名称 纹理\n", 875 | "aggregations = [\n", 876 | " ('mean', 'mean', 'v-'), \n", 877 | " ('gm', 'geomed','v--'),\n", 878 | "]\n", 879 | "\n", 880 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 881 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 882 | " obj = pickle.load(f)\n", 883 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 884 | "\n", 885 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 886 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 887 | "\n", 888 | "for axline, (attackName, title) in zip(axs, attackNames):\n", 889 | " for optimizer, optimizerName, color in optimizers:\n", 890 | " for (aggregationName, showName, fmt) in aggregations:\n", 891 | " # 标签\n", 892 | " 
label = optimizerName + ' ' + showName\n", 893 | " \n", 894 | " # 画曲线\n", 895 | " with open(CACHE_DIR + '_' + optimizer+'_' + attackName + '_' + aggregationName, 'rb') as f:\n", 896 | " record = pickle.load(f)\n", 897 | " # 损失函数\n", 898 | " path = record['path']\n", 899 | " path = logAxis(path, Fmin)\n", 900 | " axline[0].plot(x_axis, path, fmt, color=color, label=label)\n", 901 | " # variance\n", 902 | " variancePath = record['variancePath']\n", 903 | " axline[1].plot(x_axis_minus_1, variancePath, fmt, color=color, label=label)\n", 904 | " \n", 905 | " # 填小标题\n", 906 | " axline[0].set_title('optimality gap ({})'.format(title))\n", 907 | " axline[1].set_title('variance ({})'.format(title))\n", 908 | " for ax in axline:\n", 909 | " # 坐标轴\n", 910 | " ax.set_yscale('log')\n", 911 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 912 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 913 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 914 | "\n", 915 | " axline[0].set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 916 | " axline[1].set_ylabel(r'$D_{w\\notin B} [m_w^k]$', fontsize=FONT_SIZE)\n", 917 | " axline[1].set_xlim(left=0)\n", 918 | "\n", 919 | "# 图例\n", 920 | "# 一行\n", 921 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 922 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 923 | "# 两行\n", 924 | "axs[0][1].legend(loc='lower left', bbox_to_anchor=(1.2,-0.5), \n", 925 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 926 | "# 右侧\n", 927 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(1.2, 0.3), \n", 928 | "# borderaxespad = 0., ncol=1, fontsize=FONT_SIZE)\n", 929 | "\n", 930 | "fig.set_size_inches((SCALE*9, SCALE*4))\n", 931 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 932 | "plt.savefig('./attack_loss_variance_covtype.eps', format='eps', bbox_inches='tight')\n", 933 | "# plt.savefig('./attack_loss_variance.jpg', format='jpg', bbox_inches='tight')\n", 934 | "plt.show()" 935 | ] 936 | }, 937 | { 938 | "cell_type": "code", 939 | "execution_count": null, 940 | "metadata": { 941 | "code_folding": [] 942 | }, 943 | "outputs": [], 944 | "source": [ 945 | "import matplotlib.pyplot as plt\n", 946 | "import pickle\n", 947 | "import math\n", 948 | "\n", 949 | "def logAxis(path, Fmin):\n", 950 | " return [p-Fmin for p in path]\n", 951 | "\n", 952 | "# %matplotlib inline\n", 953 | "# %config InlineBackend.figure_format = 'svg'\n", 954 | "\n", 955 | "dataSetConfig = dataSetConfigs[0]\n", 956 | "\n", 957 | "SCALE = 2.0\n", 958 | "FONT_SIZE = 12\n", 959 | "\n", 960 | "fig, axs = plt.subplots(2, 4)\n", 961 | "axs = tuple(zip(*axs))\n", 962 | "\n", 963 | "attackNames = [\n", 964 | " ('baseline', 'without attack'), \n", 965 | " ('white', 'Gaussian attack'), \n", 966 | " ('maxValue', 'max-value attack'), \n", 967 | " ('zeroGradient', 'zero-gradient attack'),\n", 968 | "]\n", 969 | "# 名称 颜色\n", 970 | "optimizers = [\n", 971 | " ('SGD', 'SGD', 'royalblue'), \n", 972 | " ('BatchSGD', 'BSGD', 'darkgreen'), \n", 973 | " ('SAGA', 'SAGA', 'darkorange'),\n", 974 | " ('SVRG', 'SVRG', 'blueviolet'),\n", 975 | " ('SARAH', 'SARAH', 'k'),\n", 976 | "]\n", 977 | "# 文件名 图片中显示的名称 纹理\n", 978 | "aggregations = [\n", 979 | " ('mean', 'mean', 'v-'), \n", 980 | " ('gm', 'geomed','v--'),\n", 981 | "]\n", 982 | "\n", 983 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 984 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 985 | " obj = pickle.load(f)\n", 986 | " Fmin, w_min = obj['Fmin'], 
obj['w_min']\n", 987 | "\n", 988 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 989 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 990 | "\n", 991 | "for axline, (attackName, title) in zip(axs, attackNames):\n", 992 | " for optimizer, optimizerName, color in optimizers:\n", 993 | " for (aggregationName, showName, fmt) in aggregations:\n", 994 | " # 标签\n", 995 | " label = optimizerName + ' ' + showName\n", 996 | " \n", 997 | " # 画曲线\n", 998 | " with open(CACHE_DIR + '_' + optimizer+'_' + attackName + '_' + aggregationName, 'rb') as f:\n", 999 | " record = pickle.load(f)\n", 1000 | " # 损失函数\n", 1001 | " path = record['path']\n", 1002 | " path = logAxis(path, Fmin)\n", 1003 | " axline[0].plot(x_axis, path, fmt, color=color, label=label)\n", 1004 | " # variance\n", 1005 | " variancePath = record['variancePath']\n", 1006 | " axline[1].plot(x_axis_minus_1, variancePath, fmt, color=color, label=label)\n", 1007 | " \n", 1008 | " # 填小标题\n", 1009 | " axline[0].set_title('optimality gap ({})'.format(title))\n", 1010 | " axline[1].set_title('variance ({})'.format(title))\n", 1011 | " for ax in axline:\n", 1012 | " # 坐标轴\n", 1013 | " ax.set_yscale('log')\n", 1014 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 1015 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 1016 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 1017 | "\n", 1018 | " axline[0].set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 1019 | " axline[1].set_ylabel(r'$D_{w\\notin B} [m_w^k]$', fontsize=FONT_SIZE)\n", 1020 | " axline[1].set_xlim(left=0)\n", 1021 | "\n", 1022 | "# 图例\n", 1023 | "# 一行\n", 1024 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 1025 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 1026 | "# 两行\n", 1027 | "axs[0][1].legend(loc='lower left', bbox_to_anchor=(1.0,-0.5), \n", 1028 | " borderaxespad = 0., ncol=5, fontsize=FONT_SIZE)\n", 1029 | "# 右侧\n", 1030 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(1.2, 0.3), \n", 1031 | "# borderaxespad = 0., ncol=1, fontsize=FONT_SIZE)\n", 1032 | "\n", 1033 | "fig.set_size_inches((SCALE*9, SCALE*4))\n", 1034 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 1035 | "# plt.savefig('./attack_loss_variance.eps', format='eps', bbox_inches='tight')\n", 1036 | "# plt.savefig('./attack_loss_variance.jpg', format='jpg', bbox_inches='tight')\n", 1037 | "plt.show()" 1038 | ] 1039 | }, 1040 | { 1041 | "cell_type": "code", 1042 | "execution_count": null, 1043 | "metadata": { 1044 | "code_folding": [] 1045 | }, 1046 | "outputs": [], 1047 | "source": [ 1048 | "import matplotlib.pyplot as plt\n", 1049 | "import pickle\n", 1050 | "import math\n", 1051 | "\n", 1052 | "def logAxis(path, Fmin):\n", 1053 | " return [p-Fmin for p in path]\n", 1054 | "\n", 1055 | "# %matplotlib inline\n", 1056 | "# %config InlineBackend.figure_format = 'svg'\n", 1057 | "\n", 1058 | "dataSetConfig = dataSetConfigs[1]\n", 1059 | "\n", 1060 | "SCALE = 2.0\n", 1061 | "FONT_SIZE = 12\n", 1062 | "\n", 1063 | "fig, axs = plt.subplots(2, 4)\n", 1064 | "axs = tuple(zip(*axs))\n", 1065 | "\n", 1066 | "attackNames = [\n", 1067 | " ('baseline', 'without attack'), \n", 1068 | " ('white', 'Gaussian attack'), \n", 1069 | " ('maxValue', 'max-value attack'), \n", 1070 | " ('zeroGradient', 'zero-gradient attack'),\n", 1071 | "]\n", 1072 | "# 名称 颜色\n", 1073 | "optimizers = [\n", 1074 | " ('SGD', 'SGD', 'royalblue'), \n", 1075 | " ('BatchSGD', 'BSGD', 'darkgreen'), \n", 1076 | " 
('SAGA', 'SAGA', 'darkorange'),\n", 1077 | " ('SVRG', 'SVRG', 'blueviolet'),\n", 1078 | "]\n", 1079 | "# 文件名 图片中显示的名称 纹理\n", 1080 | "aggregations = [\n", 1081 | " ('mean', 'mean', 'v-'), \n", 1082 | " ('gm', 'geomed','v--'),\n", 1083 | "]\n", 1084 | "\n", 1085 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 1086 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 1087 | " obj = pickle.load(f)\n", 1088 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 1089 | "\n", 1090 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 1091 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 1092 | "\n", 1093 | "for axline, (attackName, title) in zip(axs, attackNames):\n", 1094 | " for optimizer, optimizerName, color in optimizers:\n", 1095 | " for (aggregationName, showName, fmt) in aggregations:\n", 1096 | " # 标签\n", 1097 | " label = optimizerName + ' ' + showName\n", 1098 | " \n", 1099 | " # 画曲线\n", 1100 | " with open(CACHE_DIR + '_' + optimizer+'_' + attackName + '_' + aggregationName, 'rb') as f:\n", 1101 | " record = pickle.load(f)\n", 1102 | " # 损失函数\n", 1103 | " path = record['path']\n", 1104 | " path = logAxis(path, Fmin)\n", 1105 | " axline[0].plot(x_axis, path, fmt, color=color, label=label)\n", 1106 | " # variance\n", 1107 | " variancePath = record['variancePath']\n", 1108 | " axline[1].plot(x_axis_minus_1, variancePath, fmt, color=color, label=label)\n", 1109 | " \n", 1110 | " # 填小标题\n", 1111 | " axline[0].set_title('optimality gap ({})'.format(title))\n", 1112 | " axline[1].set_title('variance ({})'.format(title))\n", 1113 | " for ax in axline:\n", 1114 | " # 坐标轴\n", 1115 | " ax.set_yscale('log')\n", 1116 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 1117 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 1118 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 1119 | "\n", 1120 | " axline[0].set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 1121 | " axline[1].set_ylabel(r'$D_{w\\notin B} [m_w^k]$', fontsize=FONT_SIZE)\n", 1122 | " axline[1].set_xlim(left=0)\n", 1123 | "\n", 1124 | "# 图例\n", 1125 | "# 一行\n", 1126 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 1127 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 1128 | "# 两行\n", 1129 | "axs[0][1].legend(loc='lower left', bbox_to_anchor=(1.0,-0.5), \n", 1130 | " borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 1131 | "# 右侧\n", 1132 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(1.2, 0.3), \n", 1133 | "# borderaxespad = 0., ncol=1, fontsize=FONT_SIZE)\n", 1134 | "\n", 1135 | "fig.set_size_inches((SCALE*9, SCALE*4))\n", 1136 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 1137 | "# plt.savefig('./attack_loss_variance.eps', format='eps', bbox_inches='tight')\n", 1138 | "# plt.savefig('./attack_loss_variance.jpg', format='jpg', bbox_inches='tight')\n", 1139 | "plt.show()" 1140 | ] 1141 | }, 1142 | { 1143 | "cell_type": "markdown", 1144 | "metadata": {}, 1145 | "source": [ 1146 | "# 比较GM error, bias, variance" 1147 | ] 1148 | }, 1149 | { 1150 | "cell_type": "code", 1151 | "execution_count": null, 1152 | "metadata": {}, 1153 | "outputs": [], 1154 | "source": [ 1155 | "\n", 1156 | "dataSetConfig = {\n", 1157 | " 'name': 'ijcnn1',\n", 1158 | "\n", 1159 | " 'dataSet' : 'ijcnn1',\n", 1160 | " 'dataSetSize': 49990,\n", 1161 | " 'maxFeature': 22,\n", 1162 | " 'findingType': '1',\n", 1163 | "\n", 1164 | " 'honestNodeSize': 50,\n", 1165 | " 'byzantineNodeSize': 20,\n", 1166 | "\n", 1167 | " 'rounds': 
10,\n", 1168 | " 'displayInterval': 4000,\n", 1169 | "}\n", 1170 | "\n", 1171 | "SCALE = 2.0\n", 1172 | "FONT_SIZE = 12\n", 1173 | "\n", 1174 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 1175 | "\n", 1176 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 1177 | "\n", 1178 | "attackNames = [\n", 1179 | " ('baseline', 'without attack'), \n", 1180 | " ('white', 'Gaussian attack'), \n", 1181 | " ('maxValue', 'sign-flipping attack'), \n", 1182 | " ('zeroGradient', 'zero-gradient attack'),\n", 1183 | "]\n", 1184 | "\n", 1185 | "fig, axs = plt.subplots(2, 2)\n", 1186 | "axs = sum(axs.tolist(), [])\n", 1187 | "for ax, (attackName, title) in zip(axs, attackNames):\n", 1188 | " with open(CACHE_DIR + '_SGD_SAGA_cmpVar_' + attackName, 'rb') as f:\n", 1189 | " record = pickle.load(f)\n", 1190 | "\n", 1191 | "# SAGA_biasPath = record['SAGA_biasPath']\n", 1192 | " SAGA_variencePath = record['SAGA_variencePath']\n", 1193 | " SAGA_error_Path = record['SAGA_error_Path']\n", 1194 | "# SGD_biasPath = record['SGD_biasPath']\n", 1195 | " SGD_variencePath = record['SGD_variencePath']\n", 1196 | " SGD_error_Path = record['SGD_error_Path']\n", 1197 | "\n", 1198 | " # SAGA / SGD\n", 1199 | "# biasPro = [b_saga / b_sgd for b_saga, b_sgd in zip(SAGA_biasPath, SGD_biasPath)]\n", 1200 | " varPro = [v_saga / v_sgd for v_saga, v_sgd in zip(SAGA_variencePath, SGD_variencePath)]\n", 1201 | " errPro = [e_saga / e_sgd for e_saga, e_sgd in zip(SAGA_error_Path, SGD_error_Path)]\n", 1202 | "\n", 1203 | "# ax.plot(x_axis_minus_1, biasPro, label='error')\n", 1204 | " ax.plot(x_axis_minus_1, varPro, label='variance')\n", 1205 | " ax.plot(x_axis_minus_1, errPro, label='geomed error')\n", 1206 | "\n", 1207 | " # plt.plot(x_axis_minus_1, SAGA_biasPath, label='SAGA_biasPath')\n", 1208 | " # plt.plot(x_axis_minus_1, SAGA_variencePath, label='SAGA_variencePath')\n", 1209 | " # plt.plot(x_axis_minus_1, SAGA_error_Path, label='SAGA_error_Path')\n", 1210 | " # plt.plot(x_axis_minus_1, SGD_biasPath, label='SGD_biasPath')\n", 1211 | " # plt.plot(x_axis_minus_1, SGD_variencePath, label='SGD_variencePath')\n", 1212 | " # plt.plot(x_axis_minus_1, SGD_error_Path, label='SGD_error_Path')\n", 1213 | "\n", 1214 | " # 填小标题\n", 1215 | " ax.set_title('{}'.format(title))\n", 1216 | "\n", 1217 | " # 坐标轴\n", 1218 | " ax.set_xlabel(r'iteration $k$ / ${}$'.format(dataSetConfig['displayInterval']), fontsize=FONT_SIZE)\n", 1219 | " ax.set_ylabel(r'ratio', fontsize=FONT_SIZE)\n", 1220 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 1221 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 1222 | " \n", 1223 | "# ax.set_yscale('log')\n", 1224 | "\n", 1225 | "# 两行\n", 1226 | "axs[-2].legend(loc='lower left', bbox_to_anchor=(0.6,-0.4), \n", 1227 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 1228 | "\n", 1229 | "fig.set_size_inches((SCALE*4, SCALE*4))\n", 1230 | "plt.subplots_adjust(hspace=0.35, wspace=0.3)\n", 1231 | "plt.savefig('./cmp_variance.eps', format='eps', bbox_inches='tight')\n", 1232 | "plt.show()" 1233 | ] 1234 | }, 1235 | { 1236 | "cell_type": "markdown", 1237 | "metadata": {}, 1238 | "source": [ 1239 | "# 单个精度影响聚合精度示意图" 1240 | ] 1241 | }, 1242 | { 1243 | "cell_type": "code", 1244 | "execution_count": null, 1245 | "metadata": {}, 1246 | "outputs": [], 1247 | "source": [ 1248 | "import numpy as np\n", 1249 | "\n", 1250 | "def aggregate_geometric(wList):\n", 1251 | " max_iter = 1000\n", 1252 | " tol = 1e-7\n", 1253 | " guess = np.mean(wList, axis=0)\n", 1254 | " for _ in range(max_iter):\n", 1255 
| " dist_li = [np.linalg.norm(w - guess) for w in wList]\n", 1256 | " dist_li = [d if d != 0 else 1 for d in dist_li]\n", 1257 | " temp1 = np.sum([w / dist for w, dist in zip(wList, dist_li)], axis=0)\n", 1258 | " temp2 = np.sum([1.0 / dist for dist in dist_li])\n", 1259 | " guess_next = temp1 / temp2\n", 1260 | " guess_movement = np.linalg.norm(guess - guess_next)\n", 1261 | " guess = guess_next\n", 1262 | " if guess_movement <= tol:\n", 1263 | " break\n", 1264 | " return guess\n", 1265 | "\n", 1266 | "fig, axs = plt.subplots(1, 2)\n", 1267 | "\n", 1268 | "bias = 1\n", 1269 | "pointCount = 20\n", 1270 | "\n", 1271 | "mean = [4, 2]\n", 1272 | "largeVar = 1.0\n", 1273 | "smallVar = 0.1\n", 1274 | "\n", 1275 | "MARKER_SIZE = 10\n", 1276 | "\n", 1277 | "np.random.seed(500)\n", 1278 | "largeList = np.random.multivariate_normal(mean, largeVar*np.array([[1, 0], [0, 1]]), pointCount)\n", 1279 | "smallList = np.random.multivariate_normal(mean, smallVar*np.array([[1, 0], [0, 1]]), pointCount)\n", 1280 | "\n", 1281 | "byzantineNode = np.array([\n", 1282 | " [1, 4],\n", 1283 | " [1, 3],\n", 1284 | " [1, 2],\n", 1285 | " [1, 1],\n", 1286 | " [1.5, 4],\n", 1287 | " [1.5, 3],\n", 1288 | " [1.5, 2],\n", 1289 | " [1.5, 1],\n", 1290 | " [2, 4],\n", 1291 | " [2, 3],\n", 1292 | " [2, 2],\n", 1293 | " [2, 1]\n", 1294 | "])\n", 1295 | "[*byzantineXY] = zip(*byzantineNode)\n", 1296 | "\n", 1297 | "largeAll = np.vstack([largeList, byzantineNode])\n", 1298 | "smallAll = np.vstack([smallList, byzantineNode])\n", 1299 | "\n", 1300 | "largeGM = aggregate_geometric(largeAll)\n", 1301 | "smallGM = aggregate_geometric(smallAll)\n", 1302 | "\n", 1303 | "[*largeXY] = zip(*largeList)\n", 1304 | "[*smallXY] = zip(*smallList)\n", 1305 | "\n", 1306 | "# 画好节点\n", 1307 | "fmt = '.'\n", 1308 | "axs[0].plot(largeXY[0], largeXY[1], fmt, label='honest gradient')\n", 1309 | "axs[1].plot(smallXY[0], smallXY[1], fmt)\n", 1310 | "\n", 1311 | "# 画拜占庭节点\n", 1312 | "bfmt = 'r.'\n", 1313 | "for ax in axs:\n", 1314 | " ax.plot(byzantineXY[0], byzantineXY[1], bfmt, label='Byzantine gradient')\n", 1315 | " \n", 1316 | "# GM\n", 1317 | "GMfmt = 'k+'\n", 1318 | "axs[0].plot([largeGM[0]], [largeGM[1]], GMfmt, label='geometric median', markersize=MARKER_SIZE)\n", 1319 | "axs[1].plot([smallGM[0]], [smallGM[1]], GMfmt, markersize=MARKER_SIZE)\n", 1320 | "\n", 1321 | "# 真实梯度\n", 1322 | "trueFmt = 'k*'\n", 1323 | "for ax in axs:\n", 1324 | " ax.plot([mean[0]], [mean[1]], trueFmt, label='true gradient', markersize=MARKER_SIZE)\n", 1325 | "\n", 1326 | "for ax in axs:\n", 1327 | " ax.set_xlim(0, 6)\n", 1328 | " ax.set_ylim(0, 5)\n", 1329 | " ax.set_axis_off()\n", 1330 | "\n", 1331 | "axs[0].set_title('gradients with large variance')\n", 1332 | "axs[1].set_title('gradients with small variance')\n", 1333 | "SCALE = 1\n", 1334 | "FONT_SIZE = 12\n", 1335 | "fig.set_size_inches((SCALE*8, SCALE*4))\n", 1336 | "axs[0].legend(loc='lower left', bbox_to_anchor=(0.3,-0.2), \n", 1337 | " borderaxespad = 0., ncol=2, fontsize=FONT_SIZE)\n", 1338 | "# plt.savefig('./sketchMap_howVarianceAffectGM.eps', format='eps', bbox_inches='tight')\n", 1339 | "plt.show()" 1340 | ] 1341 | }, 1342 | { 1343 | "cell_type": "markdown", 1344 | "metadata": {}, 1345 | "source": [ 1346 | "# 零方差SAGA" 1347 | ] 1348 | }, 1349 | { 1350 | "cell_type": "code", 1351 | "execution_count": null, 1352 | "metadata": { 1353 | "code_folding": [] 1354 | }, 1355 | "outputs": [], 1356 | "source": [ 1357 | "import matplotlib.pyplot as plt\n", 1358 | "import pickle\n", 1359 | "import math\n", 1360 | "\n", 
1361 | "def logAxis(path, Fmin):\n", 1362 | " return [p-Fmin for p in path]\n", 1363 | "\n", 1364 | "# %matplotlib inline\n", 1365 | "# %config InlineBackend.figure_format = 'svg'\n", 1366 | "\n", 1367 | "dataSetConfig = dataSetConfigs[0]\n", 1368 | "\n", 1369 | "SCALE = 2.0\n", 1370 | "FONT_SIZE = 12\n", 1371 | "\n", 1372 | "fig, axs = plt.subplots(2, 4)\n", 1373 | "axs = tuple(zip(*axs))\n", 1374 | "\n", 1375 | "attackNames = [\n", 1376 | " ('baseline', 'without attack'), \n", 1377 | " ('white', 'Gaussian attack'), \n", 1378 | " ('maxValue', 'sign-flipping attack'), \n", 1379 | " ('zeroGradient', 'zero-gradient attack'),\n", 1380 | "]\n", 1381 | "# 名称 颜色\n", 1382 | "optimizers = [\n", 1383 | " ('SGD_ZV', 'SGD', 'royalblue'), \n", 1384 | " ('BatchSGD_ZV', 'BSGD', 'darkgreen'), \n", 1385 | " ('SAGA_ZV', 'SAGA', 'darkorange'),\n", 1386 | "]\n", 1387 | "\n", 1388 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 1389 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 1390 | " obj = pickle.load(f)\n", 1391 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 1392 | "\n", 1393 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 1394 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 1395 | "\n", 1396 | "for axline, (attackName, title) in zip(axs, attackNames):\n", 1397 | " for optimizer, optimizerName, color in optimizers:\n", 1398 | " aggregationName = 'gm'\n", 1399 | " showName = 'geomed'\n", 1400 | " fmt = 'v--'\n", 1401 | "\n", 1402 | " # 标签\n", 1403 | " label = optimizerName + ' ' + showName\n", 1404 | "\n", 1405 | " # 画曲线\n", 1406 | " with open(CACHE_DIR + '_' + optimizer + '_' + attackName + '_' + aggregationName, 'rb') as f:\n", 1407 | " record = pickle.load(f)\n", 1408 | " # 损失函数\n", 1409 | " path = record['path']\n", 1410 | " path = logAxis(path, Fmin)\n", 1411 | " axline[0].plot(x_axis, path, fmt, color=color, label=label)\n", 1412 | " # variance\n", 1413 | " variancePath = record['variancePath']\n", 1414 | " axline[1].plot(x_axis_minus_1, variancePath, fmt, color=color, label=label)\n", 1415 | " \n", 1416 | " # 填小标题\n", 1417 | " axline[0].set_title('optimality gap ({})'.format(title))\n", 1418 | " axline[1].set_title('variance ({})'.format(title))\n", 1419 | " for ax in axline:\n", 1420 | " # 坐标轴\n", 1421 | " ax.set_yscale('log')\n", 1422 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']*50), fontsize=FONT_SIZE)\n", 1423 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 1424 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 1425 | "\n", 1426 | " axline[0].set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 1427 | " axline[1].set_ylabel(r'$D_{w\\notin B} [m_w^k]$', fontsize=FONT_SIZE)\n", 1428 | " axline[1].set_xlim(left=0)\n", 1429 | "\n", 1430 | "# 图例\n", 1431 | "# 一行\n", 1432 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 1433 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 1434 | "# 两行\n", 1435 | "axs[0][1].legend(loc='lower left', bbox_to_anchor=(1.2,-0.5), \n", 1436 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 1437 | "# 右侧\n", 1438 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(1.2, 0.3), \n", 1439 | "# borderaxespad = 0., ncol=1, fontsize=FONT_SIZE)\n", 1440 | "\n", 1441 | "fig.set_size_inches((SCALE*9, SCALE*4))\n", 1442 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 1443 | "plt.savefig('./zeroOuterVariation.eps', format='eps', bbox_inches='tight')\n", 1444 | "# plt.savefig('./zeroOuterVariation.jpg', format='jpg', bbox_inches='tight')\n", 1445 
| "plt.show()" 1446 | ] 1447 | }, 1448 | { 1449 | "cell_type": "code", 1450 | "execution_count": null, 1451 | "metadata": { 1452 | "code_folding": [] 1453 | }, 1454 | "outputs": [], 1455 | "source": [ 1456 | "import matplotlib.pyplot as plt\n", 1457 | "import pickle\n", 1458 | "import math\n", 1459 | "\n", 1460 | "def logAxis(path, Fmin):\n", 1461 | " return [p-Fmin for p in path]\n", 1462 | "\n", 1463 | "# %matplotlib inline\n", 1464 | "# %config InlineBackend.figure_format = 'svg'\n", 1465 | "\n", 1466 | "dataSetConfig = dataSetConfigs[0]\n", 1467 | "\n", 1468 | "SCALE = 2.0\n", 1469 | "FONT_SIZE = 12\n", 1470 | "\n", 1471 | "fig, axs = plt.subplots(2, 2)\n", 1472 | "axs = list(axs[0]) + list(axs[1])\n", 1473 | "\n", 1474 | "attackNames = [\n", 1475 | " ('baseline', 'without attack'), \n", 1476 | " ('white', 'Gaussian attack'), \n", 1477 | " ('maxValue', 'sign-flipping'), \n", 1478 | " ('zeroGradient', 'zero-gradient attack'),\n", 1479 | "]\n", 1480 | "# 名称 颜色\n", 1481 | "optimizers = [\n", 1482 | " ('SGD_ZV', 'SGD', 'royalblue'), \n", 1483 | " ('BatchSGD_ZV', 'BSGD', 'darkgreen'), \n", 1484 | " ('SAGA_ZV', 'SAGA', 'darkorange'),\n", 1485 | "]\n", 1486 | "\n", 1487 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 1488 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 1489 | " obj = pickle.load(f)\n", 1490 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 1491 | "\n", 1492 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 1493 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 1494 | "\n", 1495 | "for ax, (attackName, title) in zip(axs, attackNames):\n", 1496 | " for optimizer, optimizerName, color in optimizers:\n", 1497 | " aggregationName = 'gm'\n", 1498 | " showName = 'geomed'\n", 1499 | " fmt = 'v--'\n", 1500 | "\n", 1501 | " # 标签\n", 1502 | " label = optimizerName + ' ' + showName\n", 1503 | "\n", 1504 | " # 画曲线\n", 1505 | " with open(CACHE_DIR + '_' + optimizer + '_' + attackName + '_' + aggregationName, 'rb') as f:\n", 1506 | " record = pickle.load(f)\n", 1507 | " # 损失函数\n", 1508 | " path = record['path']\n", 1509 | " path = logAxis(path, Fmin)\n", 1510 | " ax.plot(x_axis, path, fmt, color=color, label=label)\n", 1511 | " \n", 1512 | " # 填小标题\n", 1513 | " ax.set_title('optimality gap ({})'.format(title))\n", 1514 | "for ax in axs:\n", 1515 | " # 坐标轴\n", 1516 | " ax.set_yscale('log')\n", 1517 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']*50), fontsize=FONT_SIZE)\n", 1518 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 1519 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 1520 | "\n", 1521 | " ax.set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 1522 | " ax.set_xlim(left=0)\n", 1523 | "\n", 1524 | "# 图例\n", 1525 | "# 一行\n", 1526 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 1527 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 1528 | "# 两行\n", 1529 | "axs[2].legend(loc='lower left', bbox_to_anchor=(0.15,-0.5), \n", 1530 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 1531 | "# 右侧\n", 1532 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(1.2, 0.3), \n", 1533 | "# borderaxespad = 0., ncol=1, fontsize=FONT_SIZE)\n", 1534 | "\n", 1535 | "fig.set_size_inches((SCALE*4, SCALE*4))\n", 1536 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 1537 | "# plt.savefig('./zeroOuterVariation_small.eps', format='eps', bbox_inches='tight')\n", 1538 | "# plt.savefig('./zeroOuterVariation.jpg', format='jpg', bbox_inches='tight')\n", 1539 | "plt.show()" 1540 | ] 1541 
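Each of these plotting cells reads the same pickled cache layout: a '<name>_Fmin' file holding a dict with 'Fmin' and 'w_min', and one record per '<name>_<optimizer>_<attack>_<aggregation>' run holding 'path' and 'variancePath'. Below is a minimal loader sketch that makes that assumption explicit; the helper name load_gap_curve is hypothetical and not part of the notebook.

import pickle

# Hypothetical helper (not part of the notebook); it mirrors the file naming and the
# record keys ('Fmin', 'path', 'variancePath') accessed by the plotting cells above.
def load_gap_curve(cache_dir, optimizer, attack, aggregation='gm'):
    with open(cache_dir + '_Fmin', 'rb') as f:
        Fmin = pickle.load(f)['Fmin']
    with open('{}_{}_{}_{}'.format(cache_dir, optimizer, attack, aggregation), 'rb') as f:
        record = pickle.load(f)
    gap = [p - Fmin for p in record['path']]   # optimality gap f(x^k) - f(x^*)
    return gap, record.get('variancePath', [])

# Example with names taken from the cells above:
# gap, var = load_gap_curve('./cache/' + dataSetConfig['name'], 'SGD_ZV', 'white')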
| }, 1542 | { 1543 | "cell_type": "code", 1544 | "execution_count": null, 1545 | "metadata": { 1546 | "code_folding": [] 1547 | }, 1548 | "outputs": [], 1549 | "source": [ 1550 | "import matplotlib.pyplot as plt\n", 1551 | "import pickle\n", 1552 | "import math\n", 1553 | "\n", 1554 | "def logAxis(path, Fmin):\n", 1555 | " return [p-Fmin for p in path]\n", 1556 | "\n", 1557 | "# %matplotlib inline\n", 1558 | "# %config InlineBackend.figure_format = 'svg'\n", 1559 | "\n", 1560 | "dataSetConfig = dataSetConfigs[0]\n", 1561 | "\n", 1562 | "SCALE = 2.0\n", 1563 | "FONT_SIZE = 12\n", 1564 | "\n", 1565 | "fig, axs = plt.subplots(2, 2)\n", 1566 | "axs = list(axs[0]) + list(axs[1])\n", 1567 | "\n", 1568 | "attackNames = [\n", 1569 | " ('baseline', 'without attack'), \n", 1570 | " ('white', 'Gaussian attack'), \n", 1571 | " ('maxValue', 'max-value attack'), \n", 1572 | " ('zeroGradient', 'zero-gradient attack'),\n", 1573 | "]\n", 1574 | "# 名称 颜色\n", 1575 | "optimizers = [\n", 1576 | " ('SGD_ZV', 'SGD', 'royalblue'), \n", 1577 | " ('BatchSGD_ZV', 'BSGD', 'darkgreen'), \n", 1578 | " ('SAGA_ZV', 'SAGA', 'darkorange'),\n", 1579 | "]\n", 1580 | "\n", 1581 | "CACHE_DIR = './cache/' + dataSetConfig['name']\n", 1582 | "with open(CACHE_DIR + '_Fmin', 'rb') as f:\n", 1583 | " obj = pickle.load(f)\n", 1584 | " Fmin, w_min = obj['Fmin'], obj['w_min']\n", 1585 | "\n", 1586 | "x_axis = list(range(dataSetConfig['rounds']+1))\n", 1587 | "x_axis_minus_1 = list(range(1, dataSetConfig['rounds']+1))\n", 1588 | "\n", 1589 | "for ax, (attackName, title) in zip(axs, attackNames):\n", 1590 | " for optimizer, optimizerName, color in optimizers:\n", 1591 | " aggregationName = 'gm'\n", 1592 | " showName = 'geomed'\n", 1593 | " fmt = 'v--'\n", 1594 | "\n", 1595 | " # 标签\n", 1596 | " label = optimizerName + ' ' + showName\n", 1597 | "\n", 1598 | " # 画曲线\n", 1599 | " with open(CACHE_DIR + '_' + optimizer + '_' + attackName + '_' + aggregationName, 'rb') as f:\n", 1600 | " record = pickle.load(f)\n", 1601 | " # 损失函数\n", 1602 | " path = record['path']\n", 1603 | " path = logAxis(path, Fmin)\n", 1604 | " ax.plot(x_axis, path, fmt, color=color, label=label)\n", 1605 | " \n", 1606 | " # 填小标题\n", 1607 | " ax.set_title('optimality gap ({})'.format(title))\n", 1608 | "for ax in axs:\n", 1609 | " # 坐标轴\n", 1610 | " ax.set_yscale('log')\n", 1611 | " ax.set_xlabel(r'iteration k / ${}$'.format(dataSetConfig['displayInterval']*50), fontsize=FONT_SIZE)\n", 1612 | " labels = ax.get_xticklabels() + ax.get_yticklabels()\n", 1613 | " [label.set_fontsize(FONT_SIZE) for label in labels]\n", 1614 | "\n", 1615 | " ax.set_ylabel(r'$f(x^k)-f(x^*)$', fontsize=FONT_SIZE)\n", 1616 | " ax.set_xlim(left=0)\n", 1617 | "\n", 1618 | "# 图例\n", 1619 | "# 一行\n", 1620 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(0.01,-0.4), \n", 1621 | "# borderaxespad = 0., ncol=4, fontsize=FONT_SIZE)\n", 1622 | "# 两行\n", 1623 | "axs[2].legend(loc='lower left', bbox_to_anchor=(0.15,-0.5), \n", 1624 | " borderaxespad = 0., ncol=3, fontsize=FONT_SIZE)\n", 1625 | "# 右侧\n", 1626 | "# axs[-1][0].legend(loc='lower left', bbox_to_anchor=(1.2, 0.3), \n", 1627 | "# borderaxespad = 0., ncol=1, fontsize=FONT_SIZE)\n", 1628 | "\n", 1629 | "fig.set_size_inches((SCALE*4, SCALE*4))\n", 1630 | "plt.subplots_adjust(hspace=0.35, wspace=0.35)\n", 1631 | "# plt.savefig('./zeroOuterVariation_small.eps', format='eps', bbox_inches='tight')\n", 1632 | "# plt.savefig('./zeroOuterVariation.jpg', format='jpg', bbox_inches='tight')\n", 1633 | "plt.show()" 1634 | ] 1635 | } 1636 | ], 1637 | 
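Every curve plotted above uses the geometric-median ('geomed') aggregation of the worker messages. For reference, here is a standalone sketch of the Weiszfeld iteration that such an aggregator computes; it is illustrative only (the training notebook below defines its own gm), and the clamp used to avoid division by zero is an assumption rather than the notebook's exact rule.

import torch

# Illustrative Weiszfeld iteration for the geometric median of row vectors;
# wList is a 2-D tensor with one worker message per row.
def geometric_median(wList, max_iter=80, tol=1e-5, eps=1e-8):
    guess = wList.mean(dim=0)
    for _ in range(max_iter):
        dist = torch.norm(wList - guess, dim=1).clamp(min=eps)  # guard against zero distance
        weights = 1.0 / dist
        new_guess = (wList * weights.unsqueeze(1)).sum(dim=0) / weights.sum()
        if torch.norm(new_guess - guess) <= tol:
            return new_guess
        guess = new_guess
    return guess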
"metadata": { 1638 | "kernelspec": { 1639 | "display_name": "Python 3", 1640 | "language": "python", 1641 | "name": "python3" 1642 | }, 1643 | "language_info": { 1644 | "codemirror_mode": { 1645 | "name": "ipython", 1646 | "version": 3 1647 | }, 1648 | "file_extension": ".py", 1649 | "mimetype": "text/x-python", 1650 | "name": "python", 1651 | "nbconvert_exporter": "python", 1652 | "pygments_lexer": "ipython3", 1653 | "version": "3.7.4" 1654 | }, 1655 | "varInspector": { 1656 | "cols": { 1657 | "lenName": 16, 1658 | "lenType": 16, 1659 | "lenVar": 40 1660 | }, 1661 | "kernels_config": { 1662 | "python": { 1663 | "delete_cmd_postfix": "", 1664 | "delete_cmd_prefix": "del ", 1665 | "library": "var_list.py", 1666 | "varRefreshCmd": "print(var_dic_list())" 1667 | }, 1668 | "r": { 1669 | "delete_cmd_postfix": ") ", 1670 | "delete_cmd_prefix": "rm(", 1671 | "library": "var_list.r", 1672 | "varRefreshCmd": "cat(var_dic_list()) " 1673 | } 1674 | }, 1675 | "types_to_exclude": [ 1676 | "module", 1677 | "function", 1678 | "builtin_function_or_method", 1679 | "instance", 1680 | "_Feature" 1681 | ], 1682 | "window_display": false 1683 | } 1684 | }, 1685 | "nbformat": 4, 1686 | "nbformat_minor": 2 1687 | } 1688 | -------------------------------------------------------------------------------- /Byrd_SAGA_torch_ANN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 基本定义" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import matplotlib.pyplot as plt\n", 17 | "import random\n", 18 | "import time\n", 19 | "import pickle" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import torch\n", 29 | "# from torchvision.datasets import MNIST\n", 30 | "import torchvision\n", 31 | "from torchvision import transforms\n", 32 | "import matplotlib.pyplot as plt" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## 数据集/模型" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "### MLP + MNIST" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# MLP\n", 56 | "optConfig = {\n", 57 | " 'honestSize': 50,\n", 58 | " 'byzantineSize': 20,\n", 59 | "\n", 60 | " 'rounds': 15,\n", 61 | " 'displayInterval': 1000,\n", 62 | "\n", 63 | " 'weight_decay': 0.00,\n", 64 | " \n", 65 | " 'fixSeed': False,\n", 66 | " 'SEED': 100,\n", 67 | " \n", 68 | " 'batchSize': 5,\n", 69 | " 'shuffle': True,\n", 70 | "}\n", 71 | "\n", 72 | "# 数据集属性\n", 73 | "dataSetConfig = {\n", 74 | " 'name': 'mnist',\n", 75 | "\n", 76 | " 'dataSet' : 'mnist',\n", 77 | " 'dataSetSize': 60000,\n", 78 | " 'maxFeature': 784,\n", 79 | "\n", 80 | " 'honestNodeSize': 50,\n", 81 | " 'byzantineNodeSize': 20,\n", 82 | "\n", 83 | " 'rounds': 15,\n", 84 | " 'displayInterval': 1000,\n", 85 | "}\n", 86 | "\n", 87 | "SGDConfig = optConfig.copy()\n", 88 | "SGDConfig['gamma'] = 1e-1\n", 89 | "\n", 90 | "batchConfig = optConfig.copy()\n", 91 | "batchConfig['batchSize'] = 50\n", 92 | "batchConfig['gamma'] = 5e-1\n", 93 | "\n", 94 | "SVRGConfig = optConfig.copy()\n", 95 | "SVRGConfig['snapshotInterval'] = dataSetConfig['dataSetSize']\n", 96 | "SVRGConfig['gamma'] = 1e-1\n", 97 | "\n", 98 | 
"SAGAConfig = optConfig.copy()\n", 99 | "SAGAConfig['gamma'] = 1e-1\n", 100 | "\n", 101 | "SARAHConfig = optConfig.copy()\n", 102 | "SARAHConfig['gamma'] = 1e-1\n", 103 | "\n", 104 | "# 加载数据集\n", 105 | "train_transform = transforms.Compose([\n", 106 | " transforms.ToTensor(), # Convert a PIL Image or numpy.ndarray to tensor.\n", 107 | " # Normalize a tensor image with mean 0.1307 and standard deviation 0.3081\n", 108 | " transforms.Normalize((0.1307,), (0.3081,))\n", 109 | "])\n", 110 | "test_transform = transforms.Compose([\n", 111 | " transforms.ToTensor(),\n", 112 | " transforms.Normalize((0.1307,), (0.3081,))\n", 113 | "])\n", 114 | "train_dataset = torchvision.datasets.MNIST(root='./dataset/', \n", 115 | " train=True, \n", 116 | " transform=train_transform,\n", 117 | " download=True)\n", 118 | "validate_dataset = torchvision.datasets.MNIST(root='./dataset/', \n", 119 | " train=False, \n", 120 | " transform=test_transform,\n", 121 | " download=False)\n", 122 | "\n", 123 | "# 模型\n", 124 | "class MLP(torch.nn.Module):\n", 125 | " \"\"\"\n", 126 | " Inputs Linear/Function Output\n", 127 | " [128, 1, 28, 28] -> Linear(28*28, 100) -> [128, 100] # first hidden layer\n", 128 | " -> Tanh -> [128, 100] # Tanh activation function, may sigmoid\n", 129 | " -> Linear(100, 100) -> [128, 100] # third hidden layer\n", 130 | " -> Tanh -> [128, 100] # Tanh activation function, may sigmoid\n", 131 | " -> Linear(100, 10) -> [128, 10] # Classification Layer \n", 132 | " \"\"\"\n", 133 | " def __init__(self, input_size, hidden_size, output_size, SEED=100):\n", 134 | " super(MLP, self).__init__()\n", 135 | " self.hidden = torch.nn.Linear(input_size, hidden_size)\n", 136 | " self.classification_layer = torch.nn.Linear(hidden_size, output_size)\n", 137 | " \n", 138 | " self.tanh1 = torch.nn.Tanh()\n", 139 | " self.tanh2 = torch.nn.Tanh()\n", 140 | " \n", 141 | " self.softmax = torch.nn.Softmax(dim=1)\n", 142 | " \n", 143 | " def forward(self, x):\n", 144 | " \"\"\"Defines the computation performed at every call.\n", 145 | " Should be overridden by all subclasses.\n", 146 | " Args:\n", 147 | " x: [batch_size, channel, height, width], input for network\n", 148 | " Returns:\n", 149 | " out: [batch_size, n_classes], output from network\n", 150 | " \"\"\"\n", 151 | " \n", 152 | " out = x.view(x.size(0), -1) # flatten x in [128, 784]\n", 153 | " out = self.tanh1(out)\n", 154 | " out = self.hidden(out)\n", 155 | " out = self.tanh2(out)\n", 156 | " out = self.classification_layer(out)\n", 157 | " out = self.softmax(out)\n", 158 | " return out\n", 159 | " \n", 160 | "# 模型工厂\n", 161 | "def modelFactory(SEED=100):\n", 162 | " return MLP(784, 50, 10)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "### ResNet + CIFAR10" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "# # ResNet50 + CIFAR10\n", 179 | "# optConfig = {\n", 180 | "# 'honestSize': 10,\n", 181 | "# 'byzantineSize': 4,\n", 182 | "\n", 183 | "# 'rounds': 15,\n", 184 | "# 'displayInterval': 6000,\n", 185 | " \n", 186 | "# 'weight_decay': 0.0001,\n", 187 | " \n", 188 | "# 'fixSeed': False,\n", 189 | "# 'SEED': 100,\n", 190 | " \n", 191 | "# 'batchSize': 5,\n", 192 | "# 'shuffle': True,\n", 193 | "# }\n", 194 | "\n", 195 | "# SGDConfig = optConfig.copy()\n", 196 | "# SGDConfig['gamma'] = 1e-1\n", 197 | "\n", 198 | "# batchConfig = optConfig.copy()\n", 199 | "# batchConfig['batchSize'] = 50\n", 200 | "# 
batchConfig['gamma'] = 5e-1\n", 201 | "\n", 202 | "# SVRGConfig = optConfig.copy()\n", 203 | "# SVRGConfig['snapshotInterval'] = dataSetConfig['dataSetSize']\n", 204 | "# SVRGConfig['gamma'] = 1e-1\n", 205 | "\n", 206 | "# SAGAConfig = optConfig.copy()\n", 207 | "# SAGAConfig['gamma'] = 1e-1\n", 208 | "\n", 209 | "# SARAHConfig = optConfig.copy()\n", 210 | "# SARAHConfig['gamma'] = 1e-1\n", 211 | "\n", 212 | "# # 数据集属性\n", 213 | "# dataSetConfig = {\n", 214 | "# 'name': 'CIFAR-10',\n", 215 | "\n", 216 | "# 'dataSet' : 'CIFAR-10',\n", 217 | "# 'dataSetSize': 60000,\n", 218 | "# 'maxFeature': 32*32*3,\n", 219 | "# }\n", 220 | "\n", 221 | "# # 加载数据集\n", 222 | "# preprocess = transforms.Compose([\n", 223 | "# transforms.Resize(256),\n", 224 | "# transforms.CenterCrop(224),\n", 225 | "# transforms.ToTensor(),\n", 226 | "# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 227 | "# ])\n", 228 | "# train_dataset = torchvision.datasets.CIFAR10(root='./dataset/',\n", 229 | "# train=True, \n", 230 | "# transform=preprocess,\n", 231 | "# download=False)\n", 232 | "# validate_dataset = torchvision.datasets.CIFAR10(root='./dataset/',\n", 233 | "# train=False, \n", 234 | "# transform=preprocess)\n", 235 | "\n", 236 | "# 模型工厂\n", 237 | "# def modelFactory(SEED=100):\n", 238 | "# return torchvision.models.resnet50()" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "## 运行参数" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "CACHE_DIR = './cache/' + dataSetConfig['name'] + '_'" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "## 辅助函数" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "# 报告函数\n", 280 | "def log(*k, **kw):\n", 281 | " timeStamp = time.strftime('[%m-%d %H:%M:%S] ', time.localtime())\n", 282 | " print(timeStamp, end='')\n", 283 | " print(*k, **kw)\n", 284 | "def debug(*k, **kw):\n", 285 | " timeStamp = time.strftime('[%m-%d %H:%M:%S] (debug)', time.localtime())\n", 286 | " print(timeStamp, end='')\n", 287 | " print(*k, **kw)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "## 损失函数" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "loss_func = torch.nn.CrossEntropyLoss()" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "def getVarience(w_local, honestSize):\n", 313 | " avg = w_local[:honestSize].mean(dim=0)\n", 314 | " s = 0\n", 315 | " for w in w_local[:honestSize]:\n", 316 | " s += (w - avg).norm()**2\n", 317 | " s /= honestSize\n", 318 | " return s.item()" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "def calculateAccuracy(model, loader, device):\n", 328 | " loss = 0\n", 329 | " accuracy = 0\n", 330 | " total = 0\n", 331 | " \n", 332 | " for material, targets in loader:\n", 333 | 
" material, targets = material.to(device), targets.to(device)\n", 334 | " outputs = model(material)\n", 335 | " \n", 336 | " l = loss_func(outputs, targets)\n", 337 | "\n", 338 | " loss += l.item() * len(targets)\n", 339 | " _, predicted = torch.max(outputs.data, dim=1)\n", 340 | " accuracy += (predicted == targets).sum().item()\n", 341 | " total += len(targets)\n", 342 | " \n", 343 | " loss /= total\n", 344 | " accuracy /= total\n", 345 | " \n", 346 | " return loss, accuracy" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "## 聚合函数" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "def mean(wList):\n", 363 | " return torch.mean(wList, dim=0)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "def gm(wList):\n", 373 | " max_iter = 80\n", 374 | " tol = 1e-5\n", 375 | " guess = torch.mean(wList, dim=0)\n", 376 | " for _ in range(max_iter):\n", 377 | " dist_li = torch.norm(wList-guess, dim=1)\n", 378 | " for i in range(len(dist_li)):\n", 379 | " if dist_li[i] == 0:\n", 380 | " dist_li[i] = 1\n", 381 | " temp1 = torch.sum(torch.stack([w/d for w, d in zip(wList, dist_li)]), dim=0)\n", 382 | " temp2 = torch.sum(1/dist_li)\n", 383 | " guess_next = temp1 / temp2\n", 384 | " guess_movement = torch.norm(guess - guess_next)\n", 385 | " guess = guess_next\n", 386 | " if guess_movement <= tol:\n", 387 | " break\n", 388 | " return guess" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "def Krum_(nodeSize, byzantineSize):\n", 398 | " honestSize = nodeSize - byzantineSize\n", 399 | " dist = torch.zeros(nodeSize, nodeSize, dtype=torch.float32)\n", 400 | " def Krum(wList):\n", 401 | " for i in range(nodeSize):\n", 402 | " for j in range(i, nodeSize):\n", 403 | " distance = wList[i].data - wList[j].data\n", 404 | " distance = (distance*distance).sum()\n", 405 | " dist[i][j] = distance.data\n", 406 | " dist[j][i] = distance.data\n", 407 | " k = nodeSize - byzantineSize - 2 + 1 # 算上自己和自己的0.00\n", 408 | " topv, _ = dist.topk(k=k, dim=1)\n", 409 | " sumdist = -topv.sum(dim=1)\n", 410 | " resindex = sumdist.topk(1)[1].squeeze()\n", 411 | " return wList[resindex]\n", 412 | " return Krum" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [ 421 | "def median(wList):\n", 422 | " return wList.median(dim=0)[0]" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "## torch辅助函数" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "def flatten_list(message, byzantineSize):\n", 439 | " wList = [torch.cat([p.flatten() for p in parameters]) for parameters in message]\n", 440 | " wList.extend([torch.zeros_like(wList[0]) for _ in range(byzantineSize)])\n", 441 | " wList = torch.stack(wList)\n", 442 | " return wList\n", 443 | "def unflatten_vector(vector, model):\n", 444 | " paraGroup = []\n", 445 | " cum = 0\n", 446 | " for p in model.parameters():\n", 447 | " newP = vector[cum:cum+p.numel()]\n", 448 | " paraGroup.append(newP.view_as(p))\n", 449 | " cum += p.numel()\n", 450 | " return paraGroup" 451 | ] 452 | }, 453 | { 454 | "cell_type": 
"code", 455 | "execution_count": null, 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [ 459 | "def randomSample(dataset, batchSize):\n", 460 | " m, t = zip(*random.sample(dataset, batchSize))\n", 461 | " material, targets = torch.cat(m), torch.tensor(t)\n", 462 | " return material, targets" 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "execution_count": null, 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [ 471 | "def getPara(module, useString=True):\n", 472 | " para = sum([x.nelement() for x in module.parameters()])\n", 473 | " if not useString:\n", 474 | " return para\n", 475 | " elif para >= 2**20:\n", 476 | " return '{:.2f}M'.format(para / 2**20)\n", 477 | " elif para >= 2**10:\n", 478 | " return '{:.2f}K'.format(para / 2**10)\n", 479 | " else:\n", 480 | " return str(para)" 481 | ] 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "metadata": {}, 486 | "source": [ 487 | "# 优化算法" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "报告函数" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "def report(r, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy, var=None):\n", 504 | " varStr = '' if (var == None) else ' var={:.2e}'.format(var)\n", 505 | " log('[{}/{}](interval: {:.0f}) train: loss={:.4f} acc={:.2f} val: loss={:.4f} acc={:.2f}{}'\n", 506 | " .format(r, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy, varStr)\n", 507 | " )" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": { 513 | "heading_collapsed": true 514 | }, 515 | "source": [ 516 | "## CentralSGD" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": null, 522 | "metadata": { 523 | "code_folding": [], 524 | "hidden": true 525 | }, 526 | "outputs": [], 527 | "source": [ 528 | "def CentralSGD(model, gamma, aggregate, weight_decay, attack=None, \n", 529 | " rounds=10, displayInterval=1000, \n", 530 | " device='cpu', SEED=100, fixSeed=False, \n", 531 | " batchSize=1,\n", 532 | " **kw):\n", 533 | " if fixSeed:\n", 534 | " random.seed(SEED)\n", 535 | "\n", 536 | " # 顺序遍历loader\n", 537 | " train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)\n", 538 | " validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)\n", 539 | "\n", 540 | " # 随机取样器\n", 541 | " randomSampler = lambda dataset: torch.utils.data.sampler.RandomSampler(\n", 542 | " dataset, \n", 543 | " num_samples=rounds*displayInterval*batchSize, \n", 544 | " replacement=True\n", 545 | " )\n", 546 | " train_random_loaders_splited = [torch.utils.data.DataLoader(\n", 547 | " dataset=subset,\n", 548 | " batch_size=batchSize, \n", 549 | " sampler=randomSampler(subset),\n", 550 | " ) for subset in train_dataset_subset]\n", 551 | " randomIters = [iter(loader) for loader in train_random_loaders_splited]\n", 552 | " \n", 553 | " # 求初始误差\n", 554 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 555 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 556 | "\n", 557 | " trainLossPath = [trainLoss]\n", 558 | " trainAccPath = [trainAccuracy]\n", 559 | " valLossPath = [valLoss]\n", 560 | " valAccPath = [valAccuracy]\n", 561 | " \n", 562 | " report(0, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 563 | "\n", 564 | " for r 
in range(rounds):\n", 565 | " model.train()\n", 566 | " for k in range(displayInterval):\n", 567 | " # 读取数据\n", 568 | " material, targets = next(randomIter)\n", 569 | " material, targets = material.to(device), targets.to(device)\n", 570 | "\n", 571 | " # 随机梯度\n", 572 | " # --------------------\n", 573 | " # 预测\n", 574 | " outputs = model(material)\n", 575 | " loss = loss_func(outputs, targets)\n", 576 | " # 反向传播\n", 577 | " model.zero_grad()\n", 578 | " loss.backward()\n", 579 | "\n", 580 | " # 更新\n", 581 | " for para in model.parameters():\n", 582 | " para.data.add_(-gamma, para.grad)\n", 583 | " para.data.add_(-weight_decay, para)\n", 584 | " \n", 585 | " \n", 586 | " model.eval()\n", 587 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 588 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 589 | "\n", 590 | " trainLossPath.append(trainLoss)\n", 591 | " trainAccPath.append(trainAccuracy)\n", 592 | " valLossPath.append(valLoss)\n", 593 | " valAccPath.append(valAccuracy)\n", 594 | "\n", 595 | " report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 596 | " return model, trainLossPath, trainAccPath, valLossPath, valAccPath, []" 597 | ] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": { 602 | "heading_collapsed": true 603 | }, 604 | "source": [ 605 | "## Central SARAH" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": null, 611 | "metadata": { 612 | "code_folding": [], 613 | "hidden": true 614 | }, 615 | "outputs": [], 616 | "source": [ 617 | "def CentralSARAH(model, gamma, aggregate, weight_decay, \n", 618 | " snapshotInterval=len(train_dataset),\n", 619 | " rounds=10, displayInterval=1000, \n", 620 | " device='cpu', SEED=100, fixSeed=False, \n", 621 | " batchSize=5,\n", 622 | " **kw):\n", 623 | " \n", 624 | " if fixSeed:\n", 625 | " random.seed(SEED)\n", 626 | " \n", 627 | " # 初始化模型\n", 628 | " lastModel = modelFactory(SEED=SEED)\n", 629 | " lastModel = lastModel.to(device)\n", 630 | "\n", 631 | " # 随机的停止期限\n", 632 | " randomStop = 1\n", 633 | " \n", 634 | " # 顺序遍历loader\n", 635 | " train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)\n", 636 | " validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)\n", 637 | " \n", 638 | " # 随机取样器\n", 639 | " randomSampler = torch.utils.data.sampler.RandomSampler(\n", 640 | " train_dataset, \n", 641 | " num_samples=rounds*displayInterval*batchSize, \n", 642 | " replacement=True\n", 643 | " )\n", 644 | " randomLoader = torch.utils.data.DataLoader(\n", 645 | " dataset=train_dataset,\n", 646 | " batch_size=batchSize, \n", 647 | " sampler=randomSampler,\n", 648 | " )\n", 649 | " randomIter = iter(randomLoader)\n", 650 | " \n", 651 | " # 求初始误差\n", 652 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 653 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 654 | " \n", 655 | " trainLossPath = [trainLoss]\n", 656 | " trainAccPath = [trainAccuracy]\n", 657 | " valLossPath = [valLoss]\n", 658 | " valAccPath = [valAccuracy]\n", 659 | " \n", 660 | " log('[SARAH]初始 train: loss={:.6f} accuracy={:.2f} validation: loss={:.6f} accuracy={:.2f}'\n", 661 | " .format(trainLossPath[0], trainAccPath[0], valLossPath[0], valAccPath[0])\n", 662 | " )\n", 663 | "\n", 664 | " gradients = [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]\n", 665 | " \n", 
666 | " for r in range(rounds):\n", 667 | " for k in range(displayInterval):\n", 668 | " # snapshot\n", 669 | " if (r*displayInterval + k) % randomStop == 0:\n", 670 | " # 清空旧梯度\n", 671 | " for grad in gradients:\n", 672 | " grad.zero_()\n", 673 | " for material, targets in train_loader:\n", 674 | " material, targets = material.to(device), targets.to(device)\n", 675 | " # 预测\n", 676 | " outputs = model(material)\n", 677 | " loss = loss_func(outputs, targets)\n", 678 | " # 反向传播\n", 679 | " model.zero_grad()\n", 680 | " loss.backward()\n", 681 | "\n", 682 | " for grad, para in zip(gradients, model.parameters()):\n", 683 | " grad.data.add_(1/len(train_loader), para.grad.data)\n", 684 | " for grad, para in zip(gradients, model.parameters()):\n", 685 | " grad.data.add_(weight_decay, para.data)\n", 686 | " \n", 687 | " # 保存旧结果\n", 688 | " for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):\n", 689 | " oldPara.data.copy_(newPara)\n", 690 | " # 更新\n", 691 | " for para, grad in zip(model.parameters(), gradients):\n", 692 | " para.data.add_(-gamma, grad)\n", 693 | " # 指定下一次停止时间\n", 694 | " randomStop = random.randint(1, snapshotInterval-1)\n", 695 | " \n", 696 | " # 更新\n", 697 | " # 读取数据\n", 698 | " material, targets = next(randomIter)\n", 699 | " material, targets = material.to(device), targets.to(device)\n", 700 | "\n", 701 | " # 随机梯度\n", 702 | " # --------------------\n", 703 | " # 预测\n", 704 | " outputs = model(material)\n", 705 | " loss = loss_func(outputs, targets)\n", 706 | " # 反向传播\n", 707 | " model.zero_grad()\n", 708 | " loss.backward()\n", 709 | "\n", 710 | " # 修正梯度\n", 711 | " # --------------------\n", 712 | " # 预测\n", 713 | " outputs = lastModel(material)\n", 714 | " loss = loss_func(outputs, targets)\n", 715 | " # 反向传播\n", 716 | " lastModel.zero_grad()\n", 717 | " loss.backward()\n", 718 | "\n", 719 | " # 更新梯度表\n", 720 | " for pi, para in enumerate(model.parameters()):\n", 721 | " gradients[pi].data.add_(1, para.grad.data)\n", 722 | " gradients[pi].data.add_(weight_decay, para)\n", 723 | " for pi, para in enumerate(lastModel.parameters()):\n", 724 | " gradients[pi].data.sub_(1, para.grad.data)\n", 725 | " gradients[pi].data.sub_(weight_decay, para)\n", 726 | "\n", 727 | " # 保存旧结果\n", 728 | " for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):\n", 729 | " oldPara.data.copy_(newPara)\n", 730 | " # 更新\n", 731 | " for para, grad in zip(model.parameters(), gradients):\n", 732 | " para.data.add_(-gamma, grad)\n", 733 | " \n", 734 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 735 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 736 | "\n", 737 | " trainLossPath.append(trainLoss)\n", 738 | " trainAccPath.append(trainAccuracy)\n", 739 | " valLossPath.append(valLoss)\n", 740 | " valAccPath.append(valAccuracy)\n", 741 | " \n", 742 | " report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 743 | " return model, trainLossPath, trainAccPath, valLossPath, valAccPath, []" 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "metadata": { 749 | "heading_collapsed": true 750 | }, 751 | "source": [ 752 | "## SGD" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": null, 758 | "metadata": { 759 | "code_folding": [], 760 | "hidden": true 761 | }, 762 | "outputs": [], 763 | "source": [ 764 | "def SGD(model, gamma, aggregate, weight_decay, \n", 765 | " honestSize=0, byzantineSize=0, attack=None, \n", 766 | " rounds=10, 
displayInterval=1000, \n", 767 | " device='cpu', SEED=100, fixSeed=False, \n", 768 | " batchSize=5,\n", 769 | " **kw):\n", 770 | " assert byzantineSize == 0 or attack != None\n", 771 | " assert honestSize != 0\n", 772 | " \n", 773 | " if fixSeed:\n", 774 | " random.seed(SEED)\n", 775 | "\n", 776 | " nodeSize = honestSize + byzantineSize\n", 777 | "\n", 778 | " # 数据分片\n", 779 | " pieces = [(i*len(train_dataset)) // honestSize for i in range(honestSize+1)]\n", 780 | " dataPerNode = [pieces[i+1] - pieces[i] for i in range(honestSize)]\n", 781 | "\n", 782 | " # 回复的消息\n", 783 | " message = [\n", 784 | " [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]\n", 785 | " for _ in range(nodeSize)] \n", 786 | " \n", 787 | " # 顺序遍历loader\n", 788 | " train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)\n", 789 | " validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)\n", 790 | " \n", 791 | " train_dataset_subset = [torch.utils.data.Subset(train_dataset, range(pieces[i], pieces[i+1])) for i in range(honestSize)]\n", 792 | " train_loaders_splited = [\n", 793 | " torch.utils.data.DataLoader(dataset=subset, batch_size=batchSize, shuffle=False)\n", 794 | " for subset in train_dataset_subset\n", 795 | " ]\n", 796 | " \n", 797 | " # 随机取样器\n", 798 | " randomSampler = lambda dataset: torch.utils.data.sampler.RandomSampler( \n", 799 | " dataset, \n", 800 | " num_samples=rounds*displayInterval*batchSize, \n", 801 | " replacement=True #有放回取样\n", 802 | " )\n", 803 | " train_random_loaders_splited = [torch.utils.data.DataLoader(\n", 804 | " dataset=subset,\n", 805 | " batch_size=batchSize, \n", 806 | " sampler=randomSampler(subset),\n", 807 | " ) for subset in train_dataset_subset]\n", 808 | " \n", 809 | " randomIters = [iter(loader) for loader in train_random_loaders_splited]\n", 810 | " \n", 811 | " # 求初始误差\n", 812 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 813 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 814 | " \n", 815 | " trainLossPath = [trainLoss]\n", 816 | " trainAccPath = [trainAccuracy]\n", 817 | " valLossPath = [valLoss]\n", 818 | " valAccPath = [valAccuracy]\n", 819 | " variencePath = []\n", 820 | " \n", 821 | " report(0, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 822 | "\n", 823 | " for r in range(rounds):\n", 824 | " for k in range(displayInterval):\n", 825 | " # 诚实节点更新\n", 826 | " for node in range(honestSize):\n", 827 | " # 读取数据\n", 828 | " material, targets = next(randomIters[node]) \n", 829 | " \n", 830 | " # 随机梯度\n", 831 | " # --------------------\n", 832 | " # 预测\n", 833 | " outputs = model(material)\n", 834 | " loss = loss_func(outputs, targets)\n", 835 | " # 反向传播\n", 836 | " model.zero_grad()\n", 837 | " loss.backward()\n", 838 | "\n", 839 | " # 更新梯度表\n", 840 | " for pi, para in enumerate(model.parameters()):\n", 841 | " message[node][pi].data.zero_()\n", 842 | " message[node][pi].data.add_(1, para.grad.data)\n", 843 | " message[node][pi].data.add_(weight_decay, para)\n", 844 | "\n", 845 | " # 同步, Byzantine攻击\n", 846 | " message_f = flatten_list(message, byzantineSize) \n", 847 | " if attack != None:\n", 848 | " attack(message_f, byzantineSize)\n", 849 | " # 聚合\n", 850 | " g_vector = aggregate(message_f)\n", 851 | " # 展开\n", 852 | " g = unflatten_vector(g_vector, model) \n", 853 | " # 更新\n", 854 | " for para, grad in zip(model.parameters(), g):\n", 855 | " 
para.data.add_(-gamma, grad)\n", 856 | " \n", 857 | " var = getVarience(message_f, honestSize)\n", 858 | " variencePath.append(var)\n", 859 | " \n", 860 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 861 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 862 | "\n", 863 | " trainLossPath.append(trainLoss)\n", 864 | " trainAccPath.append(trainAccuracy)\n", 865 | " valLossPath.append(valLoss)\n", 866 | " valAccPath.append(valAccuracy)\n", 867 | " \n", 868 | " report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 869 | " return model, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath" 870 | ] 871 | }, 872 | { 873 | "cell_type": "markdown", 874 | "metadata": { 875 | "heading_collapsed": true 876 | }, 877 | "source": [ 878 | "## SAGA" 879 | ] 880 | }, 881 | { 882 | "cell_type": "code", 883 | "execution_count": null, 884 | "metadata": { 885 | "hidden": true 886 | }, 887 | "outputs": [], 888 | "source": [ 889 | "# 初始化本地模型\n", 890 | "def initModel(local_models, honestSize):\n", 891 | " stateDict = local_models[0].state_dict()\n", 892 | " for model in local_models[1:honestSize]:\n", 893 | " model.load_state_dict(stateDict)\n", 894 | "\n", 895 | "# 广播\n", 896 | "def broadcastPara(newPara, local_models):\n", 897 | " cum = 0\n", 898 | " for p in local_models[0].parameters():\n", 899 | " newP = newPara[cum:cum+p.numel()]\n", 900 | " p.data.copy_(newP.view_as(p))\n", 901 | " cum += p.numel()\n", 902 | " stateDict = local_models[0].state_dict()\n", 903 | " for model in local_models[1:]:\n", 904 | " model.load_state_dict(stateDict)" 905 | ] 906 | }, 907 | { 908 | "cell_type": "code", 909 | "execution_count": null, 910 | "metadata": { 911 | "code_folding": [ 912 | 37 913 | ], 914 | "hidden": true 915 | }, 916 | "outputs": [], 917 | "source": [ 918 | "def SAGA(model, gamma, aggregate, weight_decay, \n", 919 | " honestSize=0, byzantineSize=0, attack=None, \n", 920 | " rounds=10, displayInterval=1000, \n", 921 | " device='cpu', SEED=100, fixSeed=False, \n", 922 | " batchSize=1,\n", 923 | " **kw):\n", 924 | " assert byzantineSize == 0 or attack != None\n", 925 | " assert honestSize != 0\n", 926 | " \n", 927 | " if fixSeed:\n", 928 | " random.seed(SEED)\n", 929 | "\n", 930 | " nodeSize = honestSize + byzantineSize\n", 931 | " \n", 932 | " # 数据分片\n", 933 | " pieces = [(i*len(train_dataset)) // honestSize for i in range(honestSize+1)]\n", 934 | " dataPerNode = [pieces[i+1] - pieces[i] for i in range(honestSize)]\n", 935 | " \n", 936 | " #创建变量\n", 937 | " store = []\n", 938 | " \n", 939 | " # 顺序遍历loader\n", 940 | " train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)\n", 941 | " validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)\n", 942 | " \n", 943 | " train_dataset_subset = [torch.utils.data.Subset(train_dataset, range(pieces[i], pieces[i+1])) for i in range(honestSize)]\n", 944 | " train_loaders_splited = [\n", 945 | " torch.utils.data.DataLoader(dataset=subset, batch_size=batchSize, shuffle=False)\n", 946 | " for subset in train_dataset_subset\n", 947 | " ]\n", 948 | " \n", 949 | " # 随机取样器\n", 950 | " randomSampler = lambda dataset: torch.utils.data.sampler.RandomSampler( \n", 951 | " dataset, \n", 952 | " num_samples=rounds*displayInterval*batchSize, #取样规模:10*1500*batchSize \n", 953 | " replacement=True #有放回取样\n", 954 | " )\n", 955 | " train_random_loaders_splited = 
[torch.utils.data.DataLoader(\n", 956 | " dataset=subset,\n", 957 | " batch_size=batchSize, \n", 958 | " sampler=randomSampler(subset),\n", 959 | " ) for subset in train_dataset_subset]\n", 960 | " \n", 961 | " randomIters = [iter(loader) for loader in train_random_loaders_splited]\n", 962 | " \n", 963 | " # 求初始误差\n", 964 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 965 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 966 | " \n", 967 | " trainLossPath = [trainLoss]\n", 968 | " trainAccPath = [trainAccuracy]\n", 969 | " valLossPath = [valLoss]\n", 970 | " valAccPath = [valAccuracy]\n", 971 | " variencePath = []\n", 972 | " \n", 973 | " report(0, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 974 | " \n", 975 | " #对所有样本的权重梯度进行初始化\n", 976 | " for index, (material, targets) in enumerate(train_dataset):\n", 977 | " # 计算Loss\n", 978 | " outputs = model(material)\n", 979 | " targets = torch.tensor([targets]) \n", 980 | " loss = loss_func(outputs, targets)\n", 981 | " \n", 982 | " # 反向传播\n", 983 | " model.zero_grad()\n", 984 | " loss.backward()\n", 985 | " \n", 986 | " store.append([p.grad.clone().detach() for p in model.parameters()])\n", 987 | " \n", 988 | " # G_avg每一行是单个节点上存储的均值\n", 989 | " G_avg = []\n", 990 | " for i in range(honestSize):\n", 991 | " # storeInThisNode:该节点上梯度缓存的集合\n", 992 | " storeInThisNode = store[pieces[i]: pieces[i+1]]\n", 993 | " # para每一个元素是在对应节点上的一组参数\n", 994 | " (*paras,) = zip(*storeInThisNode)\n", 995 | " # 对所有单一节点上所有数据求平均\n", 996 | " G_avg.append([sum(para)/(pieces[i+1]-pieces[i]) for para in paras])\n", 997 | " \n", 998 | " # 回复的消息\n", 999 | " message = [\n", 1000 | " [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]\n", 1001 | " for _ in range(nodeSize)\n", 1002 | " ]\n", 1003 | " \n", 1004 | " for r in range(rounds):\n", 1005 | " for k in range(displayInterval):\n", 1006 | " # 诚实节点更新\n", 1007 | " for node in range(honestSize):\n", 1008 | " # 读取数据\n", 1009 | " index = random.randint(pieces[node], pieces[node+1]-1)\n", 1010 | " # 预测\n", 1011 | " material, targets = train_dataset[index]\n", 1012 | " # 计算Loss\n", 1013 | " outputs = model(material)\n", 1014 | " targets = torch.tensor([targets])\n", 1015 | " loss = loss_func(outputs, targets)\n", 1016 | " \n", 1017 | " # 反向传播\n", 1018 | " model.zero_grad() \n", 1019 | " loss.backward()\n", 1020 | "\n", 1021 | " # 更新梯度表\n", 1022 | " for pi, para in enumerate(model.parameters()):\n", 1023 | " old_G = store[index][pi]\n", 1024 | " new_G = para.grad.data.clone()\n", 1025 | " new_G.add_(weight_decay, para.data)\n", 1026 | "\n", 1027 | " gradient = new_G.data - old_G.data + G_avg[node][pi].data\n", 1028 | " \n", 1029 | " message[node][pi] =gradient\n", 1030 | "\n", 1031 | " G_avg[node][pi].add_(1 / dataPerNode[node],new_G.data - old_G.data)\n", 1032 | " \n", 1033 | " store[index][pi] = new_G.data\n", 1034 | " \n", 1035 | " #攻击\n", 1036 | " message_f = flatten_list(message, byzantineSize) #将原本parameters的tensor形式压缩成torch.Size([90, 39760])\n", 1037 | " if attack != None:\n", 1038 | " attack(message_f, byzantineSize)\n", 1039 | " # 聚合\n", 1040 | " g_vector = aggregate(message_f)\n", 1041 | " # 展开\n", 1042 | " g = unflatten_vector(g_vector, model) #展开成原本parameters的tensor形式\n", 1043 | " # 更新\n", 1044 | " for para, grad in zip(model.parameters(), g):\n", 1045 | " para.data.add_(-gamma, grad)\n", 1046 | " \n", 1047 | " var = getVarience(message_f, honestSize)\n", 1048 | " variencePath.append(var)\n", 1049 | 
" \n", 1050 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader, device)\n", 1051 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader, device)\n", 1052 | "\n", 1053 | " trainLossPath.append(trainLoss)\n", 1054 | " trainAccPath.append(trainAccuracy)\n", 1055 | " valLossPath.append(valLoss)\n", 1056 | " valAccPath.append(valAccuracy)\n", 1057 | " \n", 1058 | " report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 1059 | " return model, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath " 1060 | ] 1061 | }, 1062 | { 1063 | "cell_type": "markdown", 1064 | "metadata": { 1065 | "heading_collapsed": true 1066 | }, 1067 | "source": [ 1068 | "## SVRG" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "execution_count": null, 1074 | "metadata": { 1075 | "code_folding": [ 1076 | 37 1077 | ], 1078 | "hidden": true 1079 | }, 1080 | "outputs": [], 1081 | "source": [ 1082 | "def SVRG(w0, gamma, aggregate, weight_decay, honestSize=0, byzantineSize=0, attack=None, \n", 1083 | " snapshotInterval=6000, rounds=10, displayInterval=1000, SEED=100, fixSeed=False, **kw):\n", 1084 | " assert byzantineSize == 0 or attack != None\n", 1085 | " assert honestSize != 0\n", 1086 | " \n", 1087 | " if fixSeed:\n", 1088 | " random.seed(SEED)\n", 1089 | "\n", 1090 | " nodeSize = honestSize + byzantineSize\n", 1091 | " \n", 1092 | " # 初始化\n", 1093 | " w = w0.clone().detach()\n", 1094 | "\n", 1095 | " # 数据分片\n", 1096 | " pieces = [(i*len(dataset)) // honestSize for i in range(honestSize+1)]\n", 1097 | " dataPerNode = [pieces[i+1] - pieces[i] for i in range(honestSize)]\n", 1098 | "\n", 1099 | " snapshot_g = torch.zeros(honestSize, len(w0), dtype=torch.float64)\n", 1100 | " snapshot_w = torch.zeros(len(w0), dtype=torch.float64)\n", 1101 | "\n", 1102 | " path = [F(w, dataset, weight_decay)]\n", 1103 | " variencePath = []\n", 1104 | " log('[SVRG]初始 loss={:.6f}, accuracy={:.2f} gamma={:}'.format(path[0], accuracy(w, dataset), gamma))\n", 1105 | " \n", 1106 | " # 中间变量分配空间\n", 1107 | " message = torch.zeros(nodeSize, len(w0), dtype=torch.float64)\n", 1108 | "\n", 1109 | " log('开始迭代')\n", 1110 | " for r in range(rounds):\n", 1111 | " for k in range(displayInterval):\n", 1112 | " # snapshot\n", 1113 | " if (r*displayInterval + k) % snapshotInterval == 0:\n", 1114 | " snapshot_g.zero_()\n", 1115 | " for node in range(honestSize):\n", 1116 | " for index in range(pieces[node], pieces[node+1]):\n", 1117 | " x, y = dataset[index]\n", 1118 | " # 更新梯度表\n", 1119 | " predict = LogisticRegression(w, x)\n", 1120 | "\n", 1121 | " err = (predict-y).data\n", 1122 | " snapshot_g[node][:-1].add_(1/dataPerNode[node], err*x)\n", 1123 | " snapshot_g[node][-1].add_(1/dataPerNode[node], err)\n", 1124 | " snapshot_g[node].add_(weight_decay, w)\n", 1125 | " snapshot_w.copy_(w)\n", 1126 | " \n", 1127 | " # 诚实节点更新\n", 1128 | " message.zero_()\n", 1129 | " for node in range(honestSize):\n", 1130 | " index = random.randint(pieces[node], pieces[node+1]-1)\n", 1131 | "\n", 1132 | " x, y = dataset[index]\n", 1133 | " # 随机梯度\n", 1134 | " predict = LogisticRegression(w, x)\n", 1135 | " err = (predict-y).data\n", 1136 | " message[node][:-1].add_(err, x)\n", 1137 | " message[node][-1].add_(err, 1)\n", 1138 | " message[node].add_(weight_decay, w)\n", 1139 | " \n", 1140 | " # 修正梯度\n", 1141 | " predict = LogisticRegression(snapshot_w, x)\n", 1142 | " err = (predict-y).data\n", 1143 | " message[node][:-1].add_(-err, x)\n", 1144 | " message[node][-1].add_(-err, 1)\n", 1145 | " 
message[node].add_(-weight_decay, snapshot_w)\n", 1146 | " \n", 1147 | " message[node].add_(1, snapshot_g[node])\n", 1148 | " \n", 1149 | " # 同步\n", 1150 | " # Byzantine攻击\n", 1151 | " if attack != None:\n", 1152 | " attack(message, byzantineSize)\n", 1153 | " g = aggregate(message)\n", 1154 | " w.add_(-gamma, g)\n", 1155 | " \n", 1156 | " loss = F(w, dataset, weight_decay)\n", 1157 | " acc = accuracy(w, dataset)\n", 1158 | " path.append(loss)\n", 1159 | " var = getVarience(message, honestSize)\n", 1160 | " variencePath.append(var)\n", 1161 | " log('[SVRG]已迭代 {}/{} rounds (interval: {:.0f}), loss={:.9f}, accuracy={:.2f}, var={:.9f}'.format(\n", 1162 | " r+1, rounds, displayInterval, loss, acc, var\n", 1163 | " ))\n", 1164 | " return w, path, variencePath" 1165 | ] 1166 | }, 1167 | { 1168 | "cell_type": "markdown", 1169 | "metadata": { 1170 | "heading_collapsed": true 1171 | }, 1172 | "source": [ 1173 | "## SARAH" 1174 | ] 1175 | }, 1176 | { 1177 | "cell_type": "code", 1178 | "execution_count": null, 1179 | "metadata": { 1180 | "code_folding": [], 1181 | "hidden": true 1182 | }, 1183 | "outputs": [], 1184 | "source": [ 1185 | "def SARAH(model, gamma, aggregate, weight_decay, \n", 1186 | " snapshotInterval=len(train_dataset),\n", 1187 | " honestSize=0, byzantineSize=0, attack=None, \n", 1188 | " rounds=10, displayInterval=1000, \n", 1189 | " device='cpu', SEED=100, fixSeed=False, \n", 1190 | " batchSize=5,\n", 1191 | " **kw):\n", 1192 | " assert byzantineSize == 0 or attack != None\n", 1193 | " assert honestSize != 0\n", 1194 | " \n", 1195 | " if fixSeed:\n", 1196 | " random.seed(SEED)\n", 1197 | "\n", 1198 | " nodeSize = honestSize + byzantineSize\n", 1199 | " \n", 1200 | " # 初始化模型\n", 1201 | " lastModel = modelFactory(SEED=SEED)\n", 1202 | "\n", 1203 | " if device == 'cpu':\n", 1204 | " torch.manual_seed(SEED)#为CPU设置随机种子\n", 1205 | " else:\n", 1206 | " torch.cuda.manual_seed(seed)#为当前GPU设置随机种子\n", 1207 | " torch.cuda.manual_seed_all(seed)#为所有GPU设置随机种子\n", 1208 | " \n", 1209 | " # 数据分片\n", 1210 | " pieces = [(i*len(train_dataset)) // honestSize for i in range(honestSize+1)]\n", 1211 | " dataPerNode = [pieces[i+1] - pieces[i] for i in range(honestSize)]\n", 1212 | "\n", 1213 | " # 随机的停止期限\n", 1214 | " randomStop = 1\n", 1215 | " # 回复的消息\n", 1216 | " message = [\n", 1217 | " [torch.zeros_like(para, requires_grad=False) for para in model.parameters()]\n", 1218 | " for _ in range(nodeSize)\n", 1219 | " ]\n", 1220 | " \n", 1221 | " # 顺序遍历loader\n", 1222 | " train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=False)\n", 1223 | " validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=batchSize, shuffle=False)\n", 1224 | " \n", 1225 | " train_dataset_subset = [torch.utils.data.Subset(train_dataset, range(pieces[i], pieces[i+1])) for i in range(honestSize)]\n", 1226 | " train_loaders_splited = [\n", 1227 | " torch.utils.data.DataLoader(dataset=subset, batch_size=batchSize, shuffle=False)\n", 1228 | " for subset in train_dataset_subset\n", 1229 | " ]\n", 1230 | " \n", 1231 | " # 随机取样器\n", 1232 | " randomSampler = lambda dataset: torch.utils.data.sampler.RandomSampler(\n", 1233 | " dataset, \n", 1234 | " num_samples=rounds*displayInterval*batchSize, \n", 1235 | " replacement=True\n", 1236 | " )\n", 1237 | " train_random_loaders_splited = [torch.utils.data.DataLoader(\n", 1238 | " dataset=subset,\n", 1239 | " batch_size=batchSize, \n", 1240 | " sampler=randomSampler(subset),\n", 1241 | " ) for subset in train_dataset_subset]\n", 
1242 | " randomIters = [iter(loader) for loader in train_random_loaders_splited]\n", 1243 | " \n", 1244 | " # 求初始误差\n", 1245 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader)\n", 1246 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader)\n", 1247 | " \n", 1248 | " trainLossPath = [trainLoss]\n", 1249 | " trainAccPath = [trainAccuracy]\n", 1250 | " valLossPath = [valLoss]\n", 1251 | " valAccPath = [valAccuracy]\n", 1252 | " variencePath = []\n", 1253 | " \n", 1254 | " log('[SARAH]初始 train: loss={:.6f} accuracy={:.2f} validation: loss={:.6f} accuracy={:.2f}'\n", 1255 | " .format(trainLossPath[0], trainAccPath[0], valLossPath[0], valAccPath[0])\n", 1256 | " )\n", 1257 | "\n", 1258 | " for r in range(rounds):\n", 1259 | " for k in range(displayInterval):\n", 1260 | " # snapshot\n", 1261 | " if (r*displayInterval + k) % randomStop == 0:\n", 1262 | " for node in range(honestSize):\n", 1263 | " # 清空旧梯度\n", 1264 | " for grad in message[node]:\n", 1265 | " grad.zero_()\n", 1266 | " loader = train_loaders_splited[node]\n", 1267 | " for material, targets in loader:\n", 1268 | " # 预测\n", 1269 | " outputs = model(material)\n", 1270 | " loss = loss_func(outputs, targets)\n", 1271 | " # 反向传播\n", 1272 | " model.zero_grad()\n", 1273 | " loss.backward()\n", 1274 | " \n", 1275 | " for grad, para in zip(message[node], model.parameters()):\n", 1276 | " grad.data.add_(1/len(loader), para.grad.data)\n", 1277 | " for grad, para in zip(message[node], model.parameters()):\n", 1278 | " grad.data.add_(weight_decay, para.data)\n", 1279 | " \n", 1280 | " # 保存旧结果\n", 1281 | " for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):\n", 1282 | " oldPara.data.copy_(newPara)\n", 1283 | " # 同步, Byzantine攻击\n", 1284 | " message_f = flatten_list(message, byzantineSize)\n", 1285 | " if attack != None:\n", 1286 | " attack(message_f, byzantineSize)\n", 1287 | " # 聚合\n", 1288 | " g_vector = aggregate(message_f)\n", 1289 | " # 展开\n", 1290 | " g = unflatten_vector(g_vector, model)\n", 1291 | " # 更新\n", 1292 | " for para, grad in zip(model.parameters(), g):\n", 1293 | " para.data.add_(-gamma, grad)\n", 1294 | " # 指定下一次停止时间\n", 1295 | " randomStop = random.randint(1, snapshotInterval-1)\n", 1296 | " \n", 1297 | " # 诚实节点更新\n", 1298 | " for node in range(honestSize):\n", 1299 | " # 读取数据\n", 1300 | " material, targets = next(randomIters[node])\n", 1301 | " \n", 1302 | " # 随机梯度\n", 1303 | " # --------------------\n", 1304 | " # 预测\n", 1305 | " outputs = model(material)\n", 1306 | " loss = loss_func(outputs, targets)\n", 1307 | " # 反向传播\n", 1308 | " model.zero_grad()\n", 1309 | " loss.backward()\n", 1310 | " \n", 1311 | " # 修正梯度\n", 1312 | " # --------------------\n", 1313 | " # 预测\n", 1314 | " outputs = lastModel(material)\n", 1315 | " loss = loss_func(outputs, targets)\n", 1316 | " # 反向传播\n", 1317 | " lastModel.zero_grad()\n", 1318 | " loss.backward()\n", 1319 | "\n", 1320 | " # 更新梯度表\n", 1321 | " for pi, para in enumerate(model.parameters()):\n", 1322 | " message[node][pi].data.add_(1, para.grad.data)\n", 1323 | " message[node][pi].data.add_(weight_decay, para)\n", 1324 | " for pi, para in enumerate(lastModel.parameters()):\n", 1325 | " message[node][pi].data.sub_(1, para.grad.data)\n", 1326 | " message[node][pi].data.sub_(weight_decay, para)\n", 1327 | "\n", 1328 | " # 同步, Byzantine攻击\n", 1329 | " message_f = flatten_list(message, byzantineSize)\n", 1330 | " if attack != None:\n", 1331 | " attack(message_f, byzantineSize)\n", 1332 | " # 聚合\n", 1333 | " g_vector = 
aggregate(message_f)\n", 1334 | " # 展开\n", 1335 | " g = unflatten_vector(g_vector, model)\n", 1336 | " # 保存旧结果\n", 1337 | " for oldPara, newPara in zip(lastModel.parameters(), model.parameters()):\n", 1338 | " oldPara.data.copy_(newPara)\n", 1339 | " # 更新\n", 1340 | " for para, grad in zip(model.parameters(), g):\n", 1341 | " para.data.add_(-gamma, grad)\n", 1342 | " \n", 1343 | " var = getVarience(message_f, honestSize)\n", 1344 | " variencePath.append(var)\n", 1345 | " \n", 1346 | " trainLoss, trainAccuracy = calculateAccuracy(model, train_loader)\n", 1347 | " valLoss, valAccuracy = calculateAccuracy(model, validate_loader)\n", 1348 | "\n", 1349 | " trainLossPath.append(trainLoss)\n", 1350 | " trainAccPath.append(trainAccuracy)\n", 1351 | " valLossPath.append(valLoss)\n", 1352 | " valAccPath.append(valAccuracy)\n", 1353 | " \n", 1354 | " report(r+1, rounds, displayInterval, trainLoss, trainAccuracy, valLoss, valAccuracy)\n", 1355 | " return model, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath" 1356 | ] 1357 | }, 1358 | { 1359 | "cell_type": "markdown", 1360 | "metadata": { 1361 | "heading_collapsed": true 1362 | }, 1363 | "source": [ 1364 | "# 恶意攻击" 1365 | ] 1366 | }, 1367 | { 1368 | "cell_type": "code", 1369 | "execution_count": null, 1370 | "metadata": { 1371 | "hidden": true 1372 | }, 1373 | "outputs": [], 1374 | "source": [ 1375 | "def white(messages, byzantinesize):\n", 1376 | " # 均值相同,方差为30\n", 1377 | " mu = torch.mean(messages[0:-byzantinesize], dim=0)\n", 1378 | " messages[-byzantinesize:].copy_(mu)\n", 1379 | " noise = torch.randn((byzantinesize, messages.size(1)), dtype=torch.float64)\n", 1380 | " messages[-byzantinesize:].add_(30, noise)\n", 1381 | " \n", 1382 | "def maxValue(messages, byzantinesize):\n", 1383 | " mu = torch.mean(messages[0:-byzantinesize], dim=0)\n", 1384 | " meliciousMessage = -10*mu\n", 1385 | " messages[-byzantinesize:].copy_(meliciousMessage)\n", 1386 | " \n", 1387 | "def zeroGradient(messages, byzantinesize):\n", 1388 | " s = torch.sum(messages[0:-byzantinesize], dim=0)\n", 1389 | " messages[-byzantinesize:].copy_(-s / byzantinesize)" 1390 | ] 1391 | }, 1392 | { 1393 | "cell_type": "markdown", 1394 | "metadata": {}, 1395 | "source": [ 1396 | "# 训练函数" 1397 | ] 1398 | }, 1399 | { 1400 | "cell_type": "code", 1401 | "execution_count": null, 1402 | "metadata": {}, 1403 | "outputs": [], 1404 | "source": [ 1405 | "def train(model, loss_func, optimizer, trainloader, device, weight_decay):\n", 1406 | " \"\"\"\n", 1407 | " train model using loss_fn and optimizer in an epoch.\n", 1408 | " model: CNN networks\n", 1409 | " train_loader: a Dataloader object with training data\n", 1410 | " loss_func: loss function\n", 1411 | " device: train on cpu or gpu device\n", 1412 | " \"\"\"\n", 1413 | " model.train()\n", 1414 | " \n", 1415 | " trainAccuracy = 0\n", 1416 | " trainLoss = 0\n", 1417 | " total = 0\n", 1418 | " \n", 1419 | " for i, (*material, targets) in enumerate(trainloader):\n", 1420 | " if isinstance(material, torch.Tensor):\n", 1421 | " material = material.to(device)\n", 1422 | " else:\n", 1423 | " material = [m.to(device) for m in material]\n", 1424 | " \n", 1425 | " targets = targets.to(device)\n", 1426 | "\n", 1427 | " # forward\n", 1428 | " outputs = model(*material)\n", 1429 | " \n", 1430 | " loss = loss_func(outputs, targets)\n", 1431 | " trainLoss += loss.item()\n", 1432 | "\n", 1433 | " # backward and optimize\n", 1434 | " optimizer.zero_grad()\n", 1435 | " loss.backward()\n", 1436 | " optimizer.step()\n", 1437 | " \n", 1438 | " # 
AdamW - https://zhuanlan.zhihu.com/p/38945390\n", 1439 | "        for group in optimizer.param_groups:\n", 1440 | "            for param in group['params']:\n", 1441 | "                param.data = param.data.add(-weight_decay * group['lr'], param.data)\n", 1442 | "\n", 1443 | "        # torch.max returns the maximum value of each row of the input tensor in the\n", 1444 | "        # given dimension dim; the second return value is the index location\n", 1445 | "        # of each maximum value found (argmax)\n", 1446 | "        _, predicted = torch.max(outputs.data, dim=1)\n", 1447 | "        trainAccuracy += (predicted == targets).sum().item()\n", 1448 | "        \n", 1449 | "        total += len(targets)\n", 1450 | "    trainAccuracy /= total\n", 1451 | "    trainLoss /= total\n", 1452 | "    return trainLoss, trainAccuracy" 1453 | ] 1454 | },
 1455 | { 1456 | "cell_type": "code", 1457 | "execution_count": null, 1458 | "metadata": {}, 1459 | "outputs": [], 1460 | "source": [ 1461 | "def validate(model, loss_func, validateloader, device):\n", 1462 | "    # evaluate the model\n", 1463 | "    model.eval()\n", 1464 | "    # context manager that disables gradient computation\n", 1465 | "    with torch.no_grad():\n", 1466 | "        # =============================================================\n", 1467 | "        valAccuracy = 0\n", 1468 | "        valLoss = 0\n", 1469 | "        total = 0\n", 1470 | "        \n", 1471 | "        for i, (*material, targets) in enumerate(validateloader):\n", 1472 | "            if isinstance(material, torch.Tensor):\n", 1473 | "                material = material.to(device)\n", 1474 | "            else:\n", 1475 | "                material = [m.to(device) for m in material]\n", 1476 | "\n", 1477 | "            targets = targets.to(device)\n", 1478 | "            \n", 1479 | "            outputs = model(*material)\n", 1480 | "            \n", 1481 | "            loss = loss_func(outputs, targets)\n", 1482 | "            valLoss += loss.item()\n", 1483 | "            \n", 1484 | "            # torch.max returns the maximum value of each row of the input tensor in the\n", 1485 | "            # given dimension dim; the second return value is the index location\n", 1486 | "            # of each maximum value found (argmax)\n", 1487 | "            _, predicted = torch.max(outputs.data, dim=1)\n", 1488 | "            valAccuracy += (predicted == targets).sum().item()\n", 1489 | "            \n", 1490 | "            total += len(targets)\n", 1491 | "        valAccuracy /= total\n", 1492 | "        valLoss /= total\n", 1493 | "    return valLoss, valAccuracy" 1494 | ] 1495 | },
 1496 | { 1497 | "cell_type": "code", 1498 | "execution_count": null, 1499 | "metadata": {}, 1500 | "outputs": [], 1501 | "source": [ 1502 | "def test(model, testloader, classname=None, name='default'):\n", 1503 | "    import pandas as pd  # for the prediction DataFrame written below\n", 1504 | "    model.eval()  # evaluation mode\n", 1505 | "    # context manager that disables gradient computation\n", 1506 | "    with torch.no_grad():\n", 1507 | "        result = []\n", 1508 | "        test_cnt = 0\n", 1509 | "        for i, (*material, targets) in enumerate(testloader):\n", 1510 | "            if isinstance(material, torch.Tensor):\n", 1511 | "                material = material.to(device)\n", 1512 | "            else:\n", 1513 | "                material = [m.to(device) for m in material]\n", 1514 | "\n", 1515 | "            targets = targets.to(device)\n", 1516 | "\n", 1517 | "            outputs = model(*material)\n", 1518 | "\n", 1519 | "            _, predicted = torch.max(outputs.data, dim=1)\n", 1520 | "\n", 1521 | "            result.extend(predicted)\n", 1522 | "            test_cnt += len(targets)\n", 1523 | "\n", 1524 | "        if classname != None:\n", 1525 | "            result = [classname[i] for i in result]\n", 1526 | "\n", 1527 | "        log('predicted {} samples in total'.format(test_cnt))\n", 1528 | "        df_predict = pd.DataFrame({'id': list(range(1, len(result)+1)), 'polarity': result})\n", 1529 | "        df_predict.to_csv('{}.csv'.format(name), index=False)\n", 1530 | "        log('prediction finished')\n", 1531 | "    " 1532 | ] 1533 | },
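A minimal driver sketch showing how the train and validate helpers above fit together for a plain, non-distributed run. It assumes the globals defined earlier in this notebook (train_dataset, validate_dataset, modelFactory, loss_func, log, optConfig, SGDConfig, device) and uses an ordinary torch.optim.SGD optimizer; the Byzantine-robust experiments themselves are launched through the run() function defined further below, not through this loop.

import torch

# Hedged usage sketch; the epoch count and loader settings here are illustrative choices.
net = modelFactory().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=SGDConfig['gamma'])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=optConfig['batchSize'], shuffle=True)
validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=optConfig['batchSize'], shuffle=False)

list_trainLoss, list_trainAccuracy, list_valLoss, list_valAccuracy = [], [], [], []
for epoch in range(5):
    trainLoss, trainAccuracy = train(net, loss_func, optimizer, train_loader, device, optConfig['weight_decay'])
    valLoss, valAccuracy = validate(net, loss_func, validate_loader, device)
    list_trainLoss.append(trainLoss)
    list_trainAccuracy.append(trainAccuracy)
    list_valLoss.append(valLoss)
    list_valAccuracy.append(valAccuracy)
    log('[epoch {}] train acc={:.2f} val acc={:.2f}'.format(epoch + 1, trainAccuracy, valAccuracy))

# showCurve (defined in the next cell) can then plot the four lists collected above:
# showCurve(list_trainLoss, list_trainAccuracy, list_valLoss, list_valAccuracy)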
1534 | { 1535 | "cell_type": "code", 1536 | "execution_count": null, 1537 | "metadata": {}, 1538 | "outputs": [], 1539 | "source": [ 1540 | "def showCurve(list_trainLoss, list_trainAccuracy, list_valLoss, list_valAccuracy):\n", 1541 | " xAxis = list(range(len(list_trainLoss)))\n", 1542 | " fig, axs = plt.subplots(1, 2)\n", 1543 | "\n", 1544 | " axs[0].plot(xAxis, list_trainLoss, label='train')\n", 1545 | " axs[0].plot(xAxis, list_valLoss, label='validation')\n", 1546 | " axs[0].set_title('Loss')\n", 1547 | "\n", 1548 | " axs[1].plot(xAxis, list_trainAccuracy, label='train')\n", 1549 | " axs[1].plot(xAxis, list_valAccuracy, label='validation')\n", 1550 | " axs[1].set_title('Accuracy')\n", 1551 | "\n", 1552 | " for ax in axs:\n", 1553 | " ax.axis()\n", 1554 | " ax.set_xlabel('epoch')\n", 1555 | " ax.set_ylabel('{}'.format(ax.get_title()))\n", 1556 | " ax.legend()\n", 1557 | " fig.set_size_inches((8, 4))\n", 1558 | " plt.subplots_adjust(wspace=0.3)\n", 1559 | " plt.show()" 1560 | ] 1561 | }, 1562 | { 1563 | "cell_type": "markdown", 1564 | "metadata": {}, 1565 | "source": [ 1566 | "# Run function" 1567 | ] 1568 | }, 1569 | { 1570 | "cell_type": "code", 1571 | "execution_count": null, 1572 | "metadata": {}, 1573 | "outputs": [], 1574 | "source": [ 1575 | "def run(optimizer, aggregate, attack, config, device='cpu'):\n", 1576 | " # initialize the run configuration\n", 1577 | " _config = config.copy()\n", 1578 | " _config['aggregate'] = aggregate\n", 1579 | " _config['attack'] = attack\n", 1580 | " if attack is None:\n", 1581 | " _config['byzantineSize'] = 0\n", 1582 | " \n", 1583 | " model = modelFactory(SEED=_config['SEED'])\n", 1584 | " model = model.to(device)\n", 1585 | "\n", 1586 | " # build a title that records the run settings\n", 1587 | " attackName = 'baseline' if attack is None else attack.__name__\n", 1588 | " # e.g. 
Resnet50_SARAH(5)_baseline_mean\n", 1589 | " title = '{}_{}({})_{}_{}'.format(\n", 1590 | " model.__class__.__name__, \n", 1591 | " optimizer.__name__, \n", 1592 | " _config['batchSize'],\n", 1593 | " attackName, \n", 1594 | " aggregate.__name__\n", 1595 | " )\n", 1596 | " \n", 1597 | " # print run information\n", 1598 | " print('[Submitted task] ' + title)\n", 1599 | " print('[Run info]')\n", 1600 | " print('[Network] name={} parameters number={}'.format(model.__class__.__name__, getPara(model)))\n", 1601 | " print('[Method] name={} aggregation={} attack={}'.format(optimizer.__name__, aggregate.__name__, attackName))\n", 1602 | " print('[Dataset] name={} trainSize={} validationSize={}'.format(dataSetConfig['name'], len(train_dataset), len(validate_dataset)))\n", 1603 | " print('[Optimizer settings] gamma={} weight_decay={} batchSize={}'.format(_config['gamma'], _config['weight_decay'], _config['batchSize']))\n", 1604 | " print('[Node counts] honestSize={}, byzantineSize={}'.format(_config['honestSize'], _config['byzantineSize']))\n", 1605 | " print('[Rounds] rounds={}, displayInterval={}'.format(_config['rounds'], _config['displayInterval']))\n", 1606 | " print('[Torch settings] device={}, SEED={}, fixSeed={}'.format(device, _config['SEED'], _config['fixSeed']))\n", 1607 | " print('-------------------------------------------')\n", 1608 | " \n", 1609 | " # start the optimization\n", 1610 | " log('Optimization started')\n", 1611 | " res = optimizer(model, device=device, **_config)\n", 1612 | " [*model, trainLossPath, trainAccPath, valLossPath, valAccPath, variencePath] = res\n", 1613 | "\n", 1614 | " record = {\n", 1615 | " **dataSetConfig,\n", 1616 | " **{key:(_config[key].__name__ if hasattr(_config[key], '__call__') else _config[key]) for key in _config},\n", 1617 | " 'trainLossPath': trainLossPath, \n", 1618 | " 'trainAccPath': trainAccPath, \n", 1619 | " 'valLossPath': valLossPath, \n", 1620 | " 'valAccPath': valAccPath, \n", 1621 | " 'variencePath': variencePath,\n", 1622 | " }\n", 1623 | "\n", 1624 | " with open(CACHE_DIR + title, 'wb') as f:\n", 1625 | " pickle.dump(record, f)\n", 1626 | " \n", 1627 | " _, axis = plt.subplots(1, 2)\n", 1628 | " axis[0].plot(list(range(len(trainLossPath))), trainLossPath, label='train loss')\n", 1629 | " axis[0].plot(list(range(len(valLossPath))), valLossPath, label='validation loss')\n", 1630 | " axis[1].plot(list(range(len(trainAccPath))), trainAccPath, label='train accuracy')\n", 1631 | " axis[1].plot(list(range(len(valAccPath))), valAccPath, label='validation accuracy')\n", 1632 | " for ax in axis:\n", 1633 | " ax.legend()\n", 1634 | " plt.show()" 1635 | ] 1636 | }, 1637 | { 1638 | "cell_type": "markdown", 1639 | "metadata": {}, 1640 | "source": [ 1641 | "# Experiments" 1642 | ] 1643 | }, 1644 | { 1645 | "cell_type": "markdown", 1646 | "metadata": {}, 1647 | "source": [ 1648 | "## Centralized SGD tuning" 1649 | ] 1650 | }, 1651 | { 1652 | "cell_type": "code", 1653 | "execution_count": null, 1654 | "metadata": { 1655 | "scrolled": true 1656 | }, 1657 | "outputs": [], 1658 | "source": [ 1659 | "_config = SGDConfig.copy()\n", 1660 | "_config['gamma'] = 5e-1\n", 1661 | "_config['rounds'] = 50\n", 1662 | "_config['batchSize'] = 20\n", 1663 | "run(optimizer = CentralSGD, aggregate = mean, attack = None, config = _config, device=device)" 1664 | ] 1665 | }, 1666 | { 1667 | "cell_type": "markdown", 1668 | "metadata": {}, 1669 | "source": [ 1670 | "## Centralized SARAH tuning" 1671 | ] 1672 | }, 1673 | { 1674 | "cell_type": "code", 1675 | "execution_count": null, 1676 | "metadata": { 1677 | "scrolled": false 1678 | }, 1679 | "outputs": [], 1680 | "source": [ 1681 | "_config = 
SARAHConfig.copy()\n", 1682 | "_config['batchSize'] = 20\n", 1683 | "_config['gamma'] = 1e-4\n", 1684 | "_config['displayInterval'] = 100000\n", 1685 | "_config['rounds'] = 30\n", 1686 | "run(optimizer = CentralSARAH, aggregate = mean, attack = None, config = _config, device=device)" 1687 | ] 1688 | }, 1689 | { 1690 | "cell_type": "markdown", 1691 | "metadata": {}, 1692 | "source": [ 1693 | "## SGD" 1694 | ] 1695 | }, 1696 | { 1697 | "cell_type": "markdown", 1698 | "metadata": {}, 1699 | "source": [ 1700 | "### SGD - mean" 1701 | ] 1702 | }, 1703 | { 1704 | "cell_type": "code", 1705 | "execution_count": null, 1706 | "metadata": { 1707 | "scrolled": false 1708 | }, 1709 | "outputs": [], 1710 | "source": [ 1711 | "run(optimizer = SGD, aggregate = mean, attack = None, config = SGDConfig)" 1712 | ] 1713 | }, 1714 | { 1715 | "cell_type": "markdown", 1716 | "metadata": {}, 1717 | "source": [ 1718 | "white" 1719 | ] 1720 | }, 1721 | { 1722 | "cell_type": "code", 1723 | "execution_count": null, 1724 | "metadata": { 1725 | "scrolled": false 1726 | }, 1727 | "outputs": [], 1728 | "source": [ 1729 | "run(optimizer = SGD, aggregate = mean, attack = white, config = SGDConfig)" 1730 | ] 1731 | }, 1732 | { 1733 | "cell_type": "markdown", 1734 | "metadata": {}, 1735 | "source": [ 1736 | "max" 1737 | ] 1738 | }, 1739 | { 1740 | "cell_type": "code", 1741 | "execution_count": null, 1742 | "metadata": {}, 1743 | "outputs": [], 1744 | "source": [ 1745 | "run(optimizer = SGD, aggregate = mean, attack = maxValue, config = SGDConfig)" 1746 | ] 1747 | }, 1748 | { 1749 | "cell_type": "markdown", 1750 | "metadata": {}, 1751 | "source": [ 1752 | "zero Gradient" 1753 | ] 1754 | }, 1755 | { 1756 | "cell_type": "code", 1757 | "execution_count": null, 1758 | "metadata": {}, 1759 | "outputs": [], 1760 | "source": [ 1761 | "run(optimizer = SGD, aggregate = mean, attack = zeroGradient, config = SGDConfig)" 1762 | ] 1763 | }, 1764 | { 1765 | "cell_type": "markdown", 1766 | "metadata": {}, 1767 | "source": [ 1768 | "### SGD - geometric median" 1769 | ] 1770 | }, 1771 | { 1772 | "cell_type": "code", 1773 | "execution_count": null, 1774 | "metadata": { 1775 | "scrolled": false 1776 | }, 1777 | "outputs": [], 1778 | "source": [ 1779 | "run(optimizer = SGD, aggregate = gm, attack = None, config = SGDConfig)" 1780 | ] 1781 | }, 1782 | { 1783 | "cell_type": "markdown", 1784 | "metadata": {}, 1785 | "source": [ 1786 | "white" 1787 | ] 1788 | }, 1789 | { 1790 | "cell_type": "code", 1791 | "execution_count": null, 1792 | "metadata": { 1793 | "scrolled": false 1794 | }, 1795 | "outputs": [], 1796 | "source": [ 1797 | "run(optimizer = SGD, aggregate = gm, attack = white, config = SGDConfig)" 1798 | ] 1799 | }, 1800 | { 1801 | "cell_type": "markdown", 1802 | "metadata": {}, 1803 | "source": [ 1804 | "max" 1805 | ] 1806 | }, 1807 | { 1808 | "cell_type": "code", 1809 | "execution_count": null, 1810 | "metadata": { 1811 | "scrolled": false 1812 | }, 1813 | "outputs": [], 1814 | "source": [ 1815 | "run(optimizer = SGD, aggregate = gm, attack = maxValue, config = SGDConfig)" 1816 | ] 1817 | }, 1818 | { 1819 | "cell_type": "markdown", 1820 | "metadata": {}, 1821 | "source": [ 1822 | "zero Gradient" 1823 | ] 1824 | }, 1825 | { 1826 | "cell_type": "code", 1827 | "execution_count": null, 1828 | "metadata": { 1829 | "scrolled": false 1830 | }, 1831 | "outputs": [], 1832 | "source": [ 1833 | "run(optimizer = SGD, aggregate = gm, attack = zeroGradient, config = SGDConfig)" 1834 | ] 1835 | }, 1836 | { 1837 | "cell_type": "markdown", 1838 | 
"metadata": {}, 1839 | "source": [ 1840 | "### SGD - Krum" 1841 | ] 1842 | }, 1843 | { 1844 | "cell_type": "code", 1845 | "execution_count": null, 1846 | "metadata": { 1847 | "scrolled": false 1848 | }, 1849 | "outputs": [], 1850 | "source": [ 1851 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 1852 | "run(optimizer = SGD, aggregate = Krum, attack = None, config = SGDConfig)" 1853 | ] 1854 | }, 1855 | { 1856 | "cell_type": "markdown", 1857 | "metadata": {}, 1858 | "source": [ 1859 | "white" 1860 | ] 1861 | }, 1862 | { 1863 | "cell_type": "code", 1864 | "execution_count": null, 1865 | "metadata": { 1866 | "scrolled": false 1867 | }, 1868 | "outputs": [], 1869 | "source": [ 1870 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 1871 | "run(optimizer = SGD, aggregate = Krum, attack = white, config = SGDConfig)" 1872 | ] 1873 | }, 1874 | { 1875 | "cell_type": "markdown", 1876 | "metadata": {}, 1877 | "source": [ 1878 | "max" 1879 | ] 1880 | }, 1881 | { 1882 | "cell_type": "code", 1883 | "execution_count": null, 1884 | "metadata": { 1885 | "scrolled": false 1886 | }, 1887 | "outputs": [], 1888 | "source": [ 1889 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 1890 | "run(optimizer = SGD, aggregate = Krum, attack = maxValue, config = SGDConfig)" 1891 | ] 1892 | }, 1893 | { 1894 | "cell_type": "markdown", 1895 | "metadata": {}, 1896 | "source": [ 1897 | "zero Gradient" 1898 | ] 1899 | }, 1900 | { 1901 | "cell_type": "code", 1902 | "execution_count": null, 1903 | "metadata": { 1904 | "scrolled": true 1905 | }, 1906 | "outputs": [], 1907 | "source": [ 1908 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 1909 | "run(optimizer = SGD, aggregate = Krum, attack = zeroGradient, config = SGDConfig)" 1910 | ] 1911 | }, 1912 | { 1913 | "cell_type": "markdown", 1914 | "metadata": {}, 1915 | "source": [ 1916 | "### SGD - Median" 1917 | ] 1918 | }, 1919 | { 1920 | "cell_type": "code", 1921 | "execution_count": null, 1922 | "metadata": { 1923 | "scrolled": false 1924 | }, 1925 | "outputs": [], 1926 | "source": [ 1927 | "run(optimizer = SGD, aggregate = median, attack = None, config = SGDConfig)" 1928 | ] 1929 | }, 1930 | { 1931 | "cell_type": "markdown", 1932 | "metadata": {}, 1933 | "source": [ 1934 | "white" 1935 | ] 1936 | }, 1937 | { 1938 | "cell_type": "code", 1939 | "execution_count": null, 1940 | "metadata": { 1941 | "scrolled": false 1942 | }, 1943 | "outputs": [], 1944 | "source": [ 1945 | "run(optimizer = SGD, aggregate = median, attack = white, config = SGDConfig)" 1946 | ] 1947 | }, 1948 | { 1949 | "cell_type": "markdown", 1950 | "metadata": {}, 1951 | "source": [ 1952 | "max" 1953 | ] 1954 | }, 1955 | { 1956 | "cell_type": "code", 1957 | "execution_count": null, 1958 | "metadata": { 1959 | "scrolled": false 1960 | }, 1961 | "outputs": [], 1962 | "source": [ 1963 | "run(optimizer = SGD, aggregate = median, attack = maxValue, config = SGDConfig)" 1964 | ] 1965 | }, 1966 | { 1967 | "cell_type": "markdown", 1968 | "metadata": {}, 1969 | "source": [ 1970 | "zero Gradient" 1971 | ] 1972 | }, 1973 | { 1974 | "cell_type": "code", 1975 | "execution_count": null, 1976 | "metadata": { 1977 | "scrolled": true 1978 | }, 1979 | "outputs": [], 1980 | "source": [ 1981 | "run(optimizer = SGD, aggregate = median, attack = zeroGradient, config = SGDConfig)" 1982 | ] 1983 | }, 1984 | { 1985 | "cell_type": "markdown", 1986 | "metadata": {}, 1987 | "source": [ 1988 | "## BatchSGD" 1989 | ] 1990 | }, 
1991 | { 1992 | "cell_type": "markdown", 1993 | "metadata": {}, 1994 | "source": [ 1995 | "### BatchSGD - mean" 1996 | ] 1997 | }, 1998 | { 1999 | "cell_type": "code", 2000 | "execution_count": null, 2001 | "metadata": { 2002 | "scrolled": false 2003 | }, 2004 | "outputs": [], 2005 | "source": [ 2006 | "run(optimizer = BatchSGD, aggregate = mean, attack = None, config = batchConfig)" 2007 | ] 2008 | }, 2009 | { 2010 | "cell_type": "markdown", 2011 | "metadata": {}, 2012 | "source": [ 2013 | "white" 2014 | ] 2015 | }, 2016 | { 2017 | "cell_type": "code", 2018 | "execution_count": null, 2019 | "metadata": { 2020 | "scrolled": false 2021 | }, 2022 | "outputs": [], 2023 | "source": [ 2024 | "run(optimizer = BatchSGD, aggregate = mean, attack = white, config = batchConfig)" 2025 | ] 2026 | }, 2027 | { 2028 | "cell_type": "markdown", 2029 | "metadata": {}, 2030 | "source": [ 2031 | "max" 2032 | ] 2033 | }, 2034 | { 2035 | "cell_type": "code", 2036 | "execution_count": null, 2037 | "metadata": { 2038 | "code_folding": [], 2039 | "scrolled": false 2040 | }, 2041 | "outputs": [], 2042 | "source": [ 2043 | "run(optimizer = BatchSGD, aggregate = mean, attack = maxValue, config = batchConfig)" 2044 | ] 2045 | }, 2046 | { 2047 | "cell_type": "markdown", 2048 | "metadata": {}, 2049 | "source": [ 2050 | "zero Gradient" 2051 | ] 2052 | }, 2053 | { 2054 | "cell_type": "code", 2055 | "execution_count": null, 2056 | "metadata": { 2057 | "scrolled": true 2058 | }, 2059 | "outputs": [], 2060 | "source": [ 2061 | "run(optimizer = BatchSGD, aggregate = mean, attack = zeroGradient, config = batchConfig)" 2062 | ] 2063 | }, 2064 | { 2065 | "cell_type": "markdown", 2066 | "metadata": {}, 2067 | "source": [ 2068 | "### BatchSGD - geometric median" 2069 | ] 2070 | }, 2071 | { 2072 | "cell_type": "code", 2073 | "execution_count": null, 2074 | "metadata": { 2075 | "scrolled": false 2076 | }, 2077 | "outputs": [], 2078 | "source": [ 2079 | "run(optimizer = BatchSGD, aggregate = gm, attack = None, config = batchConfig)" 2080 | ] 2081 | }, 2082 | { 2083 | "cell_type": "markdown", 2084 | "metadata": {}, 2085 | "source": [ 2086 | "white" 2087 | ] 2088 | }, 2089 | { 2090 | "cell_type": "code", 2091 | "execution_count": null, 2092 | "metadata": { 2093 | "scrolled": false 2094 | }, 2095 | "outputs": [], 2096 | "source": [ 2097 | "run(optimizer = BatchSGD, aggregate = gm, attack = white, config = batchConfig)" 2098 | ] 2099 | }, 2100 | { 2101 | "cell_type": "markdown", 2102 | "metadata": {}, 2103 | "source": [ 2104 | "max" 2105 | ] 2106 | }, 2107 | { 2108 | "cell_type": "code", 2109 | "execution_count": null, 2110 | "metadata": { 2111 | "scrolled": false 2112 | }, 2113 | "outputs": [], 2114 | "source": [ 2115 | "run(optimizer = BatchSGD, aggregate = gm, attack = maxValue, config = batchConfig)" 2116 | ] 2117 | }, 2118 | { 2119 | "cell_type": "markdown", 2120 | "metadata": {}, 2121 | "source": [ 2122 | "zero Gradient" 2123 | ] 2124 | }, 2125 | { 2126 | "cell_type": "code", 2127 | "execution_count": null, 2128 | "metadata": { 2129 | "scrolled": false 2130 | }, 2131 | "outputs": [], 2132 | "source": [ 2133 | "run(optimizer = BatchSGD, aggregate = gm, attack = zeroGradient, config = batchConfig)" 2134 | ] 2135 | }, 2136 | { 2137 | "cell_type": "markdown", 2138 | "metadata": {}, 2139 | "source": [ 2140 | "## SAGA" 2141 | ] 2142 | }, 2143 | { 2144 | "cell_type": "markdown", 2145 | "metadata": {}, 2146 | "source": [ 2147 | "### SAGA - mean" 2148 | ] 2149 | }, 2150 | { 2151 | "cell_type": "code", 2152 | "execution_count": null, 2153 | 
"metadata": {}, 2154 | "outputs": [], 2155 | "source": [ 2156 | "run(optimizer = SAGA, aggregate = mean, attack = None, config = SAGAConfig)" 2157 | ] 2158 | }, 2159 | { 2160 | "cell_type": "markdown", 2161 | "metadata": {}, 2162 | "source": [ 2163 | "white" 2164 | ] 2165 | }, 2166 | { 2167 | "cell_type": "code", 2168 | "execution_count": null, 2169 | "metadata": { 2170 | "scrolled": false 2171 | }, 2172 | "outputs": [], 2173 | "source": [ 2174 | "run(optimizer = SAGA, aggregate = mean, attack = white, config = SAGAConfig)" 2175 | ] 2176 | }, 2177 | { 2178 | "cell_type": "markdown", 2179 | "metadata": {}, 2180 | "source": [ 2181 | "max" 2182 | ] 2183 | }, 2184 | { 2185 | "cell_type": "code", 2186 | "execution_count": null, 2187 | "metadata": { 2188 | "scrolled": false 2189 | }, 2190 | "outputs": [], 2191 | "source": [ 2192 | "run(optimizer = SAGA, aggregate = mean, attack = maxValue, config = SAGAConfig)" 2193 | ] 2194 | }, 2195 | { 2196 | "cell_type": "markdown", 2197 | "metadata": {}, 2198 | "source": [ 2199 | "zero Gradient" 2200 | ] 2201 | }, 2202 | { 2203 | "cell_type": "code", 2204 | "execution_count": null, 2205 | "metadata": { 2206 | "scrolled": false 2207 | }, 2208 | "outputs": [], 2209 | "source": [ 2210 | "run(optimizer = SAGA, aggregate = mean, attack = zeroGradient, config = SAGAConfig)" 2211 | ] 2212 | }, 2213 | { 2214 | "cell_type": "markdown", 2215 | "metadata": {}, 2216 | "source": [ 2217 | "### SAGA - geomtric median" 2218 | ] 2219 | }, 2220 | { 2221 | "cell_type": "code", 2222 | "execution_count": null, 2223 | "metadata": { 2224 | "scrolled": false 2225 | }, 2226 | "outputs": [], 2227 | "source": [ 2228 | "run(optimizer = SAGA, aggregate = gm, attack = None, config = SAGAConfig)" 2229 | ] 2230 | }, 2231 | { 2232 | "cell_type": "markdown", 2233 | "metadata": {}, 2234 | "source": [ 2235 | "white" 2236 | ] 2237 | }, 2238 | { 2239 | "cell_type": "code", 2240 | "execution_count": null, 2241 | "metadata": { 2242 | "scrolled": false 2243 | }, 2244 | "outputs": [], 2245 | "source": [ 2246 | "run(optimizer = SAGA, aggregate = gm, attack = white, config = SAGAConfig)" 2247 | ] 2248 | }, 2249 | { 2250 | "cell_type": "markdown", 2251 | "metadata": {}, 2252 | "source": [ 2253 | "max" 2254 | ] 2255 | }, 2256 | { 2257 | "cell_type": "code", 2258 | "execution_count": null, 2259 | "metadata": {}, 2260 | "outputs": [], 2261 | "source": [ 2262 | "run(optimizer = SAGA, aggregate = gm, attack = maxValue, config = SAGAConfig)" 2263 | ] 2264 | }, 2265 | { 2266 | "cell_type": "markdown", 2267 | "metadata": {}, 2268 | "source": [ 2269 | "zero Gradient" 2270 | ] 2271 | }, 2272 | { 2273 | "cell_type": "code", 2274 | "execution_count": null, 2275 | "metadata": { 2276 | "scrolled": false 2277 | }, 2278 | "outputs": [], 2279 | "source": [ 2280 | "run(optimizer = SAGA, aggregate = gm, attack = zeroGradient, config = SAGAConfig)" 2281 | ] 2282 | }, 2283 | { 2284 | "cell_type": "markdown", 2285 | "metadata": {}, 2286 | "source": [ 2287 | "### SAGA - Krum" 2288 | ] 2289 | }, 2290 | { 2291 | "cell_type": "code", 2292 | "execution_count": null, 2293 | "metadata": { 2294 | "scrolled": false 2295 | }, 2296 | "outputs": [], 2297 | "source": [ 2298 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 2299 | "run(optimizer = SAGA, aggregate = Krum, attack = None, config = SAGAConfig)" 2300 | ] 2301 | }, 2302 | { 2303 | "cell_type": "markdown", 2304 | "metadata": {}, 2305 | "source": [ 2306 | "white" 2307 | ] 2308 | }, 2309 | { 2310 | "cell_type": "code", 2311 | "execution_count": null, 
2312 | "metadata": { 2313 | "scrolled": false 2314 | }, 2315 | "outputs": [], 2316 | "source": [ 2317 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 2318 | "run(optimizer = SAGA, aggregate = Krum, attack = white, config = SAGAConfig)" 2319 | ] 2320 | }, 2321 | { 2322 | "cell_type": "markdown", 2323 | "metadata": {}, 2324 | "source": [ 2325 | "max" 2326 | ] 2327 | }, 2328 | { 2329 | "cell_type": "code", 2330 | "execution_count": null, 2331 | "metadata": { 2332 | "scrolled": false 2333 | }, 2334 | "outputs": [], 2335 | "source": [ 2336 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 2337 | "run(optimizer = SAGA, aggregate = Krum, attack = maxValue, config = SAGAConfig)" 2338 | ] 2339 | }, 2340 | { 2341 | "cell_type": "markdown", 2342 | "metadata": {}, 2343 | "source": [ 2344 | "zero Gradient" 2345 | ] 2346 | }, 2347 | { 2348 | "cell_type": "code", 2349 | "execution_count": null, 2350 | "metadata": { 2351 | "scrolled": true 2352 | }, 2353 | "outputs": [], 2354 | "source": [ 2355 | "Krum = Krum_(nodeSize=dataSetConfig['honestNodeSize'], byzantineSize=0)\n", 2356 | "run(optimizer = SAGA, aggregate = Krum, attack = zeroGradient, config = SAGAConfig)" 2357 | ] 2358 | }, 2359 | { 2360 | "cell_type": "markdown", 2361 | "metadata": {}, 2362 | "source": [ 2363 | "### SAGA - Median" 2364 | ] 2365 | }, 2366 | { 2367 | "cell_type": "code", 2368 | "execution_count": null, 2369 | "metadata": { 2370 | "scrolled": false 2371 | }, 2372 | "outputs": [], 2373 | "source": [ 2374 | "run(optimizer = SAGA, aggregate = median, attack = None, config = SAGAConfig)" 2375 | ] 2376 | }, 2377 | { 2378 | "cell_type": "markdown", 2379 | "metadata": {}, 2380 | "source": [ 2381 | "white" 2382 | ] 2383 | }, 2384 | { 2385 | "cell_type": "code", 2386 | "execution_count": null, 2387 | "metadata": { 2388 | "scrolled": false 2389 | }, 2390 | "outputs": [], 2391 | "source": [ 2392 | "run(optimizer = SAGA, aggregate = median, attack = white, config = SAGAConfig)" 2393 | ] 2394 | }, 2395 | { 2396 | "cell_type": "markdown", 2397 | "metadata": {}, 2398 | "source": [ 2399 | "max" 2400 | ] 2401 | }, 2402 | { 2403 | "cell_type": "code", 2404 | "execution_count": null, 2405 | "metadata": { 2406 | "scrolled": false 2407 | }, 2408 | "outputs": [], 2409 | "source": [ 2410 | "run(optimizer = SAGA, aggregate = median, attack = maxValue, config = SAGAConfig)" 2411 | ] 2412 | }, 2413 | { 2414 | "cell_type": "markdown", 2415 | "metadata": {}, 2416 | "source": [ 2417 | "zero Gradient" 2418 | ] 2419 | }, 2420 | { 2421 | "cell_type": "code", 2422 | "execution_count": null, 2423 | "metadata": { 2424 | "scrolled": true 2425 | }, 2426 | "outputs": [], 2427 | "source": [ 2428 | "run(optimizer = SAGA, aggregate = median, attack = zeroGradient, config = SAGAConfig)" 2429 | ] 2430 | } 2431 | ], 2432 | "metadata": { 2433 | "kernelspec": { 2434 | "display_name": "Python 3", 2435 | "language": "python", 2436 | "name": "python3" 2437 | }, 2438 | "language_info": { 2439 | "codemirror_mode": { 2440 | "name": "ipython", 2441 | "version": 3 2442 | }, 2443 | "file_extension": ".py", 2444 | "mimetype": "text/x-python", 2445 | "name": "python", 2446 | "nbconvert_exporter": "python", 2447 | "pygments_lexer": "ipython3", 2448 | "version": "3.7.4" 2449 | }, 2450 | "varInspector": { 2451 | "cols": { 2452 | "lenName": 16, 2453 | "lenType": 16, 2454 | "lenVar": 40 2455 | }, 2456 | "kernels_config": { 2457 | "python": { 2458 | "delete_cmd_postfix": "", 2459 | "delete_cmd_prefix": "del ", 2460 | "library": 
"var_list.py", 2461 | "varRefreshCmd": "print(var_dic_list())" 2462 | }, 2463 | "r": { 2464 | "delete_cmd_postfix": ") ", 2465 | "delete_cmd_prefix": "rm(", 2466 | "library": "var_list.r", 2467 | "varRefreshCmd": "cat(var_dic_list()) " 2468 | } 2469 | }, 2470 | "position": { 2471 | "height": "713px", 2472 | "left": "1485px", 2473 | "right": "20px", 2474 | "top": "120px", 2475 | "width": "350px" 2476 | }, 2477 | "types_to_exclude": [ 2478 | "module", 2479 | "function", 2480 | "builtin_function_or_method", 2481 | "instance", 2482 | "_Feature" 2483 | ], 2484 | "window_display": false 2485 | } 2486 | }, 2487 | "nbformat": 4, 2488 | "nbformat_minor": 2 2489 | } 2490 | --------------------------------------------------------------------------------