├── utils ├── __init__.py ├── arguments.py ├── defined_models.py ├── mrdatasets.py ├── utils.py └── Data_Prepper.py ├── Approximation comparison vs sampling ├── error_vs_N.csv ├── error_vs_d.csv ├── error_vs_N=5-10.csv ├── error_vs_d=1024-32768.csv ├── plot_error.py └── error.py ├── LICENSE ├── README.md ├── .gitignore ├── environment.yml └── main.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | # from .misc import * 4 | # from .logger import * 5 | # from .visualize import * 6 | # from .eval import * 7 | 8 | # progress bar 9 | # import os, sys 10 | # sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | # from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /Approximation comparison vs sampling/error_vs_N.csv: -------------------------------------------------------------------------------- 1 | n,l1diff_mean,l1diff_std,l1diff_hat_mean,l1diff_hat_std,l2diff_mean,l2diff_std,l2diff_hat_mean,l2diff_hat_std,pearsonr_mean,pearsonr_std,pearsonr_hat_mean,pearsonr_hat_std,cosines,sv hats 2 | 5,0.008091998100280762,0.0022460016211591116,2.41398811340332e-07,5.6151877884318745e-08,0.004669313970953226,0.0010289216679312796,1.3261171147860295e-07,2.8857416732577295e-08,0.983971064677535,0.00979651275598962,0.9999999999903852,7.223624793619214e-12,0.0002438896000000579,0.04280750059999998 3 | 6,0.009732875227928161,0.0029208887077009998,0.02095826417207718,0.025239504868412435,0.005180482845753431,0.0011251909626681718,0.010527192051867473,0.01308483604975957,0.9803676380264893,0.012760495825332596,0.9368630872583854,0.1269788589612159,0.00023987519999995043,0.0625431088 4 | 7,0.009452568491299947,0.0026150339660095676,0.02958042621612549,0.02516592636275709,0.004866639835139116,0.001172765459526596,0.014075724156803441,0.012344759463706985,0.981111714348765,0.01135251219832652,0.81847367965078,0.31377652463498207,0.00025484079999991136,0.07317314120000003 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Michael 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Approximation comparison vs sampling/error_vs_d.csv: -------------------------------------------------------------------------------- 1 | d,l1diff_mean,l1diff_std,l1diff_hat_mean,l1diff_hat_std,l2diff_mean,l2diff_std,l2diff_hat_mean,l2diff_hat_std,pearsonr_mean,pearsonr_std,pearsonr_hat_mean,pearsonr_hat_std,cosines,sv hats 2 | 1024,0.010020508244633675,0.0027466566124571767,0.03410668522119522,0.023495181348128252,0.004833548585884273,0.0010494290825008947,0.01529074134089541,0.010924339918180048,0.9831591586503853,0.010768545831461276,0.8251711910225825,0.270992879700879,0.0005306975999999963,4.115784334 3 | 2048,0.009702356457710266,0.002628454137926433,0.03643774926662445,0.0223370005453812,0.00448323605582118,0.0012219756207280763,0.01575039613819513,0.009992213560190148,0.9840992320278068,0.010161601323653811,0.8168065901883128,0.2452534343022872,0.000525275399999714,3.9986797304 4 | 4096,0.009152701248725255,0.0027501761634120114,0.038969969004392625,0.022162344866794505,0.00414364078703026,0.0013628797496670999,0.016330417120012916,0.009413805292273885,0.9837923195080249,0.01006132717679411,0.7781283985108084,0.24736190668123736,0.0005610317999995118,4.000483513400003 5 | 8192,0.008353987974779946,0.0032342641747702618,0.04113902917930058,0.021358209501296972,0.003743493545334786,0.001607186100951362,0.01710734957972022,0.009028263826809518,0.9844362806821139,0.010083608170416751,0.718447721931862,0.30910441759462937,0.0006285209999987273,4.0540632465999975 6 | 16384,0.0076129425317049025,0.003620956661870113,0.04178902823477983,0.020310665561547585,0.003394814202329144,0.0017702966866505617,0.017306823841149566,0.008586734516723125,0.98550658731108,0.009919978059013327,0.6935460341760626,0.3018107485231935,0.0007434163999988641,4.1105634444 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gradient Driven Rewards to Guarantee Fairness in Collaborative Machine Learning [NeurIPS'2021] 2 | Official code repository for our accepted work "Gradient Driven Rewards to Guarantee Fairness in Collaborative Machine Learning" in the Thirty-fifth Conference on Neural Information Processing Systems (NeurIPS) 2021: 3 | 4 | > Xinyi Xu*, Lingjuan Lyu\*, Xingjun Ma, Chenglin Miao, Chuan Sheng Foo, Bryan Kian Hsiang Low 5 | > 6 | > Gradient Driven Rewards to Guarantee Fairness in Collaborative Machine Learning [Paper](https://proceedings.neurips.cc/paper/2021/hash/8682cc30db9c025ecd3fee433f8ab54c-Abstract.html) 7 | 8 | ### Set up environment using conda 9 | 10 | Tested OS platform: Ubuntu 20.04 with Nvidia driver Version: 470.86 CUDA Version: 11.4 11 | 12 | ` conda env create -f environment.yml` 13 | 14 | ### Running the `main.py` 15 | 16 | Running on _MNIST_ dataset with 5 agents and uniform data split (i.e., I.I.D). Automatically uses GPU if available. 17 | 18 | `python main.py -D mnist -N 5 -split uni ` 19 | 20 | ### Results directory 21 | 22 | The results are saved in csv formats in a `RESULTS` directory (created if not exist) by default. 
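To take a quick look at the saved results, the CSVs can be loaded with pandas. This is a minimal sketch that only assumes the default `RESULTS` directory and `.csv` outputs mentioned above; the exact file names depend on the run.

```
import glob

import pandas as pd

# Read every CSV written under the default RESULTS directory.
for path in sorted(glob.glob('RESULTS/**/*.csv', recursive=True)):
    df = pd.read_csv(path)
    print(path, df.shape)
```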
23 | 24 | 25 | ## Citing 26 | If you have found our work to be useful in your research, please consider citing it with the following bibtex: 27 | ``` 28 | @inproceedings{Xu2021, 29 | author = {Xu, Xinyi and Lyu, Lingjuan and Ma, Xingjun and Miao, Chenglin and Foo, Chuan Sheng and Low, Bryan Kian Hsiang}, 30 | booktitle = {Advances in Neural Information Processing Systems}, 31 | editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan}, 32 | pages = {16104--16117}, 33 | publisher = {Curran Associates, Inc.}, 34 | title = {Gradient Driven Rewards to Guarantee Fairness in Collaborative Machine Learning}, 35 | volume = {34}, 36 | year = {2021} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /Approximation comparison vs sampling/error_vs_N=5-10.csv: -------------------------------------------------------------------------------- 1 | n,l1diff_mean,l1diff_std,l1diff_hat_mean,l1diff_hat_std,l2diff_mean,l2diff_std,l2diff_hat_mean,l2diff_hat_std,pearsonr_mean,pearsonr_std,pearsonr_hat_mean,pearsonr_hat_std,cosines,sv hats 2 | 5,0.009171876311302184,0.0026954266861583353,2.3096799850463867e-07,7.798432542933812e-08,0.004719050088897348,0.0012607978851724244,1.2377369884575273e-07,4.4717760714537175e-08,0.9759886712212655,0.020614353785632766,0.9999999999854916,2.016413079965404e-11,0.00020376839999998176,0.04084072640000003 3 | 6,0.00869121178984642,0.002684737451832855,0.014126814901828766,0.015791201201074887,0.004370174539508298,0.00132489325609156,0.007271581143605132,0.00809444302099282,0.9795621776599542,0.01788258091883257,0.8768511222094059,0.18509045396453067,0.0002221987000000425,0.05937654349999997 4 | 7,0.009545473009347915,0.003166961731762969,0.02451517308751742,0.021106502061252433,0.004537772088466833,0.0013761269243564743,0.011536425492104504,0.009573028029051179,0.9816828549393876,0.015553553067805682,0.8588889018691267,0.18806918083002513,0.0002635568999999727,0.07953291360000007 5 | 8,0.010160443000495433,0.0035798661671830514,0.029701840318739414,0.021080489519420643,0.004687007356551476,0.0014215370952947136,0.013678314991061846,0.009359330139022943,0.9821432749241559,0.014736900516640668,0.8541253782944389,0.17491450386522328,0.0004026877000000706,0.11774800939999985 6 | 9,0.010137430131435395,0.003311372853978838,0.03263758897781372,0.020069239066262436,0.004554644722957164,0.0013264949706016593,0.014546289621106752,0.008614904975406358,0.9831939463816035,0.013703971477418541,0.831418569095722,0.18013878645072626,0.000492392400000341,0.4793974747999995 7 | 10,0.010132978484034539,0.0031048988123833733,0.03601948954164982,0.02097701272620769,0.004434980861454581,0.0012607208418478771,0.015603769904467768,0.008577283199672166,0.9838800991213649,0.013292003576924663,0.8327431585720994,0.16641670578998258,0.0005893068000013102,4.225892503699999 8 | -------------------------------------------------------------------------------- /Approximation comparison vs sampling/error_vs_d=1024-32768.csv: -------------------------------------------------------------------------------- 1 | d,l1diff_mean,l1diff_std,l1diff_hat_mean,l1diff_hat_std,l2diff_mean,l2diff_std,l2diff_hat_mean,l2diff_hat_std,pearsonr_mean,pearsonr_std,pearsonr_hat_mean,pearsonr_hat_std,cosines,sv hats 2 | 
1024,0.010312746890953609,0.00295944025958162,0.03697453045419284,0.020003939067078345,0.004426389980861651,0.0011841413886605383,0.015707147953100124,0.00810138873239015,0.9833668558580976,0.015609888773573518,0.8332952469388546,0.15815323385973165,0.0005621512999979927,4.852790423800004 3 | 2048,0.010004725307226181,0.0029479960952859003,0.03870062511414289,0.019867559024566914,0.004245720243488904,0.0012209276189026174,0.01615865963406362,0.007857196189302095,0.9832453527347165,0.016082025067087172,0.8133475424704268,0.16663777917232583,0.0005698087000041597,4.419399778399997 4 | 4096,0.009531992591089673,0.003122165666972018,0.03940594055586391,0.019042799038970777,0.004031320021022111,0.0013135626193473203,0.016297273632863037,0.0074913053306994125,0.9833260485494423,0.015622290367436483,0.7906524587737681,0.17993939752013077,0.0005788668000008101,4.581610664799999 5 | 8192,0.0089887186139822,0.0033900868909086646,0.04055105626583099,0.018656620082275317,0.003784100858028978,0.0014533568064227203,0.016691482220869247,0.007325895776894984,0.9835665456193816,0.015045423154830031,0.7603427022755648,0.2107820890997765,0.0006593860000037921,4.642013692199993 6 | 16384,0.008413977853276513,0.0037132661213964972,0.041128960184075615,0.018055742077061204,0.0035298809869511224,0.0016037359464565759,0.016765730256637825,0.007061835551014405,0.9842341484039844,0.014531937341625986,0.7296600984041836,0.23724340012728243,0.000751572299992631,4.742844950000006 7 | 32768,0.007865127176046371,0.003998077704762082,0.041127730657656986,0.017442201245039576,0.003292862595602249,0.0017263988068967433,0.016789496104155137,0.006821920190619116,0.9848276463228708,0.01413783644935122,0.6873493462987206,0.28491902690359533,0.001566345599997021,5.265680147500007 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /utils/arguments.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | 4 | from utils.defined_models import CNN_Net, CNN_Text, CNN_Cifar10 5 | 6 | 7 | mnist_args = { 8 | 9 | # setting parameters 10 | 'dataset': 'mnist', 11 | 'sample_size_cap': 6000, 12 | 'split': 'powerlaw', #or 'classimbalance' 13 | 14 | 'batch_size' : 32, 15 | 'train_val_split_ratio': 0.9, 16 | 'lambda': 0.5, 17 | 'alpha': 0.95, 18 | 'Gamma': 0.5, 19 | 'lambda': 0, # coefficient between sign_cossim and modu_cossim 20 | 21 | # model parameters 22 | 'model_fn': CNN_Net, 23 | 'optimizer_fn': optim.SGD, 24 | 'loss_fn': nn.NLLLoss(), 25 | 'lr': 0.15, 26 | 'gamma':0.977, 27 | 'lr_decay':0.977, #0.977**100 ~= 0.1 28 | 29 | } 30 | 31 | 32 | sst_args = { 33 | 34 | # setting parameters 35 | 'dataset': 'sst', 36 | 'sample_size_cap': 5000, 37 | 'split': 'powerlaw', #or 'powerlaw' classimbalance 38 | 'batch_size' : 256, 39 | 40 | 'train_val_split_ratio': 0.9, 41 | 'alpha': 0.95, 42 | 'Gamma': 1, 43 | 'lambda': 1, # coefficient between sign_cossim and modu_cossim 44 | 45 | 46 | # model parameters 47 | 'model_fn': CNN_Text, 48 | 'embed_num': 20000, 49 | 'embed_dim': 300, 50 | 'class_num': 5, 51 | 'kernel_num': 128, 52 | 'kernel_sizes': [3,3,3], 53 | 'static':False, 54 | 55 | 'optimizer_fn': optim.Adam, 56 | 'loss_fn': nn.NLLLoss(), 57 | 'lr': 1e-4, 58 | # 'grad_clip':1e-3, 59 | 'gamma':0.977, 60 | 'lr_decay':0.977, #0.977**100 ~= 0.1 61 | } 62 | 63 | 64 | mr_args = { 65 | 66 | # setting parameters 67 | 'dataset': 'mr', 68 | 69 | 'batch_size' : 128, 70 | 'train_val_split_ratio': 0.9, 71 | 'alpha': 0.95, 72 | 'lambda': 0.5, # coefficient between sign_cossim and modu_cossim 73 | 'Gamma':1, 74 | 75 | # model parameters 76 | 'model_fn': CNN_Text, 77 | 'embed_num': 20000, 78 | 'embed_dim': 300, 79 | 'class_num': 2, 80 | 'kernel_num': 128, 81 | 'kernel_sizes': [3,3,3], 82 | 'static':False, 83 | 84 | 'optimizer_fn': optim.Adam, 85 | 'loss_fn': nn.NLLLoss(), 86 | 'lr': 5e-5, 87 | # 'grad_clip':1e-3, 88 | 'gamma':0.977, 89 | 'lr_decay':0.977, #0.977**100 ~= 0.1 90 | 91 | } 92 | 93 | 94 | cifar_cnn_args = { 95 | 96 | # setting parameters 97 | 'dataset': 'cifar10', 98 | 99 | 'batch_size' : 128, 100 | 'train_val_split_ratio': 0.8, 101 | 'alpha': 0.95, 102 | 'Gamma': 0.15, 103 | 'lambda': 0.5, # coefficient between sign_cossim and modu_cossim 104 | 105 | # model parameters 106 | 'model_fn': CNN_Cifar10, 107 | 'optimizer_fn': optim.SGD, 108 | 'loss_fn': nn.NLLLoss(), 109 | 'lr': 0.015, 110 | 'gamma':0.977, 111 | 
'lr_decay':0.977, #0.977**100 ~= 0.1 112 | 113 | 114 | } 115 | 116 | -------------------------------------------------------------------------------- /Approximation comparison vs sampling/plot_error.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import numpy as np 4 | 5 | # plt.rcParams["text.usetex"] = True 6 | plt.rcParams["font.family"] = "serif" 7 | plt.rcParams["font.serif"] = ["Times New Roman"] + plt.rcParams["font.serif"] 8 | 9 | 10 | vs_D_df = pd.read_csv('error_vs_d=1024-32768.csv') 11 | vs_N_df = pd.read_csv('error_vs_N=5-10.csv') 12 | 13 | colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] 14 | linestyle = ['-', '.-'] 15 | 16 | rename_dict = {'l1diff':'$||\\mathbb{{\\phi}} -\\mathbb{\\psi}||_{1} $', 17 | 'l1diff_hat':'$||\\mathbb{{\\phi}} -\\mathbb{\\bar{\\phi}}||_{1}$', 18 | 'l2diff':'$||\\mathbb{{\\phi}} -\\mathbb{\\psi}||_{2}$', 19 | 'l2diff_hat':'$||\\mathbb{{\\phi}} -\\mathbb{\\bar{\\phi}}||_{2}$',} 20 | 21 | def plot(df, vs_N=True): 22 | index_col = 'n' if vs_N else 'd' 23 | 24 | fig, ax1 = plt.subplots() 25 | ax2 = ax1.twinx() 26 | 27 | x = df[index_col] 28 | for key in df.columns: 29 | if key == index_col: continue 30 | value = df[key] 31 | 32 | linestyle = '--' if 'hat' in key else '-' 33 | 34 | if 'l1' in key: color = colors[0] 35 | elif 'l2' in key: color = colors[1] 36 | else: color = colors[2] 37 | 38 | if key in ['cosines', 'sv hats']: 39 | key = key.replace('cosines', 'time $\\mathbb{\\psi}$') 40 | key = key.replace('sv hats', 'time $\\mathbb{\\bar{\\phi}}$') 41 | ax2.plot(x, value, linestyle=linestyle, color=color, label=key, linewidth=6) 42 | elif 'diff' in key and '_mean' in key: 43 | key = key.replace('_mean', '') 44 | key = rename_dict[key] 45 | ax1.plot(x, value, linestyle=linestyle, color=color, label=key, linewidth=6.0) 46 | 47 | # ask matplotlib for the plotted objects and their labels 48 | lines, labels = ax1.get_legend_handles_labels() 49 | lines2, labels2 = ax2.get_legend_handles_labels() 50 | if vs_N: 51 | ax1.legend(lines + lines2, labels + labels2, loc=2, fontsize=20) 52 | 53 | # ax1.legend(lines + lines2, labels + labels2, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3, fancybox=True, shadow=True) 54 | 55 | 56 | x_label = '$N$ agents' if vs_N else '$D$' 57 | ax1.set_xlabel(x_label, fontsize=28) 58 | if vs_N: 59 | ax1.set_ylabel('Error', fontsize=28) 60 | if not vs_N: 61 | ax2.set_ylabel('Time (seconds)', fontsize=28) 62 | 63 | # if not vs_N: 64 | # ax2.set_xscale('log') 65 | # ax2.ticklabel_format(axis="x", style="sci", scilimits=(0,0)) 66 | 67 | ax1.tick_params(axis="x", labelsize=28) 68 | ax1.tick_params(axis="y", labelsize=28) 69 | ax2.tick_params(axis="y", labelsize=28) 70 | 71 | plt.tight_layout() 72 | 73 | save = True 74 | if save: 75 | figname = 'vs_N' if vs_N else 'vs_D' 76 | 77 | plt.savefig(figname+'.png') 78 | plt.clf() 79 | else: 80 | plt.show() 81 | 82 | 83 | plot(vs_N_df) 84 | plot(vs_D_df, False) 85 | -------------------------------------------------------------------------------- /utils/defined_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | 7 | # for MNIST 32*32 8 | class CNN_Net(nn.Module): 9 | 10 | def __init__(self, device=None): 11 | super(CNN_Net, self).__init__() 12 | self.conv1 = nn.Conv2d(1, 64, 3, 1) 13 | self.conv2 = nn.Conv2d(64, 16, 7, 
1) 14 | self.fc1 = nn.Linear(4 * 4 * 16, 200) 15 | self.fc2 = nn.Linear(200, 10) 16 | 17 | def forward(self, x): 18 | x = x.view(-1, 1, 32, 32) 19 | x = torch.tanh(self.conv1(x)) 20 | x = F.max_pool2d(x, 2, 2) 21 | x = torch.tanh(self.conv2(x)) 22 | x = F.max_pool2d(x, 2, 2) 23 | x = x.view(-1, 4 * 4 * 16) 24 | x = torch.tanh(self.fc1(x)) 25 | x = self.fc2(x) 26 | return F.log_softmax(x, dim=1) 27 | 28 | 29 | 30 | class CNN_Cifar10(nn.Module): 31 | def __init__(self, in_channels=3, n_kernels=16, out_dim=10): 32 | super(CNN_Cifar10, self).__init__() 33 | 34 | self.conv1 = nn.Conv2d(in_channels, n_kernels, 5) 35 | self.pool = nn.MaxPool2d(2, 2) 36 | self.conv2 = nn.Conv2d(n_kernels, 2 * n_kernels, 5) 37 | self.fc1 = nn.Linear(2 * n_kernels * 5 * 5, 120) 38 | self.fc2 = nn.Linear(120, 84) 39 | self.fc3 = nn.Linear(84, out_dim) 40 | 41 | def forward(self, x): 42 | x = self.pool(F.relu(self.conv1(x))) 43 | x = self.pool(F.relu(self.conv2(x))) 44 | x = x.view(x.shape[0], -1) 45 | x = F.relu(self.fc1(x)) 46 | x = F.relu(self.fc2(x)) 47 | x = self.fc3(x) 48 | # return x 49 | return F.log_softmax(x, dim=1) 50 | 51 | 52 | class CNN_Text(nn.Module): 53 | 54 | def __init__(self, args=None, device=None): 55 | super(CNN_Text,self).__init__() 56 | 57 | 58 | self.args = args 59 | self.device = device 60 | 61 | V = args['embed_num'] 62 | D = args['embed_dim'] 63 | C = args['class_num'] 64 | Ci = 1 65 | Co = args['kernel_num'] 66 | Ks = args['kernel_sizes'] 67 | 68 | self.embed = nn.Embedding(V, D) 69 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks]) 70 | ''' 71 | self.conv13 = nn.Conv2d(Ci, Co, (3, D)) 72 | self.conv14 = nn.Conv2d(Ci, Co, (4, D)) 73 | self.conv15 = nn.Conv2d(Ci, Co, (5, D)) 74 | ''' 75 | self.dropout = nn.Dropout(0.5) 76 | # self.dropout = nn.Dropout(args.dropout) 77 | self.fc1 = nn.Linear(len(Ks)*Co, C) 78 | 79 | def conv_and_pool(self, x, conv): 80 | x = F.relu(conv(x)).squeeze(3) #(N,Co,W) 81 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 82 | return x 83 | 84 | 85 | def forward(self, x): 86 | 87 | x = self.embed(x) # (W,N,D) 88 | # x = x.permute(1,0,2) # -> (N,W,D) 89 | # permute during loading the batches instead of in the forward function 90 | # in order to allow nn.DataParallel 91 | 92 | if not self.args or self.args['static']: 93 | x = Variable(x).to(self.device) 94 | 95 | x = x.unsqueeze(1) # (W,Ci,N,D) 96 | 97 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] #[(N,Co,W), ...]*len(Ks) 98 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks) 99 | x = torch.cat(x, 1) 100 | ''' 101 | x1 = self.conv_and_pool(x,self.conv13) #(N,Co) 102 | x2 = self.conv_and_pool(x,self.conv14) #(N,Co) 103 | x3 = self.conv_and_pool(x,self.conv15) #(N,Co) 104 | x = torch.cat((x1, x2, x3), 1) # (N,len(Ks)*Co) 105 | ''' 106 | x = self.dropout(x) # (N,len(Ks)*Co) 107 | logit = self.fc1(x) # (N,C) 108 | return F.log_softmax(logit, dim=1) 109 | # return logit -------------------------------------------------------------------------------- /utils/mrdatasets.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import random 4 | import tarfile 5 | from six.moves import urllib 6 | from torchtext import data 7 | import os 8 | 9 | class TarDataset(data.Dataset): 10 | """Defines a Dataset loaded from a downloadable tar archive. 11 | Attributes: 12 | url: URL where the tar archive can be downloaded. 13 | filename: Filename of the downloaded tar archive. 
14 | dirname: Name of the top-level directory within the zip archive that 15 | contains the data files. 16 | """ 17 | 18 | @classmethod 19 | def download_or_unzip(cls, root): 20 | path = os.path.join(root, cls.dirname) 21 | if not os.path.isdir(path): 22 | tpath = os.path.join(root, cls.filename) 23 | os.makedirs(root, exist_ok=True) 24 | if not os.path.isfile(tpath): 25 | print('downloading') 26 | urllib.request.urlretrieve(cls.url, tpath) 27 | with tarfile.open(tpath, 'r') as tfile: 28 | print('extracting') 29 | def is_within_directory(directory, target): 30 | 31 | abs_directory = os.path.abspath(directory) 32 | abs_target = os.path.abspath(target) 33 | 34 | prefix = os.path.commonprefix([abs_directory, abs_target]) 35 | 36 | return prefix == abs_directory 37 | 38 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 39 | 40 | for member in tar.getmembers(): 41 | member_path = os.path.join(path, member.name) 42 | if not is_within_directory(path, member_path): 43 | raise Exception("Attempted Path Traversal in Tar File") 44 | 45 | tar.extractall(path, members, numeric_owner=numeric_owner) 46 | 47 | 48 | safe_extract(tfile, root) 49 | return os.path.join(path, '') 50 | 51 | 52 | class MR(TarDataset): 53 | 54 | url = 'https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz' 55 | filename = 'rt-polaritydata.tar' 56 | dirname = 'rt-polaritydata' 57 | 58 | @staticmethod 59 | def sort_key(ex): 60 | return len(ex.text) 61 | 62 | def __init__(self, text_field, label_field, path=None, examples=None, **kwargs): 63 | """Create an MR dataset instance given a path and fields. 64 | Arguments: 65 | text_field: The field that will be used for text data. 66 | label_field: The field that will be used for label data. 67 | path: Path to the data file. 68 | examples: The examples contain all the data. 69 | Remaining keyword arguments: Passed to the constructor of 70 | data.Dataset. 71 | """ 72 | def clean_str(string): 73 | """ 74 | Tokenization/string cleaning for all datasets except for SST. 75 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 76 | """ 77 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) 78 | string = re.sub(r"\'s", " \'s", string) 79 | string = re.sub(r"\'ve", " \'ve", string) 80 | string = re.sub(r"n\'t", " n\'t", string) 81 | string = re.sub(r"\'re", " \'re", string) 82 | string = re.sub(r"\'d", " \'d", string) 83 | string = re.sub(r"\'ll", " \'ll", string) 84 | string = re.sub(r",", " , ", string) 85 | string = re.sub(r"!", " ! ", string) 86 | string = re.sub(r"\(", " \( ", string) 87 | string = re.sub(r"\)", " \) ", string) 88 | string = re.sub(r"\?", " \? 
", string) 89 | string = re.sub(r"\s{2,}", " ", string) 90 | return string.strip() 91 | 92 | text_field.preprocessing = data.Pipeline(clean_str) 93 | fields = [('text', text_field), ('label', label_field)] 94 | 95 | if examples is None: 96 | path = self.dirname if path is None else path 97 | examples = [] 98 | with open(os.path.join(path, 'rt-polarity.neg'), encoding = "ISO-8859-1") as f: 99 | examples += [ 100 | data.Example.fromlist([line, 0], fields) for line in f] 101 | with open(os.path.join(path, 'rt-polarity.pos'), encoding = "ISO-8859-1") as f: 102 | examples += [ 103 | data.Example.fromlist([line, 1], fields) for line in f] 104 | super(MR, self).__init__(examples, fields, **kwargs) 105 | 106 | @classmethod 107 | def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True ,root='.', **kwargs): 108 | """Create dataset objects for splits of the MR dataset. 109 | Arguments: 110 | text_field: The field that will be used for the sentence. 111 | label_field: The field that will be used for label data. 112 | dev_ratio: The ratio that will be used to get split validation dataset. 113 | shuffle: Whether to shuffle the data before split. 114 | root: The root directory that the dataset's zip archive will be 115 | expanded into; therefore the directory in whose trees 116 | subdirectory the data files will be stored. 117 | train: The filename of the train data. Default: 'train.txt'. 118 | Remaining keyword arguments: Passed to the splits method of 119 | Dataset. 120 | """ 121 | path = cls.download_or_unzip(root) 122 | examples = cls(text_field, label_field, path=path, **kwargs).examples 123 | if shuffle: random.shuffle(examples) 124 | dev_index = -1 * int(dev_ratio*len(examples)) 125 | 126 | return (cls(text_field, label_field, examples=examples[:dev_index]), 127 | cls(text_field, label_field, examples=examples[dev_index:])) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: FGCML 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=4.5=1_gnu 10 | - absl-py=0.14.0=pyhd8ed1ab_0 11 | - aiohttp=3.7.4=py38h27cfd23_1 12 | - async-timeout=3.0.1=py38h06a4308_0 13 | - attrs=21.2.0=pyhd3eb1b0_0 14 | - blas=1.0=mkl 15 | - blessings=1.7=py38h578d9bd_1004 16 | - blinker=1.4=py_1 17 | - blosc=1.21.0=h8c45485_0 18 | - boto3=1.18.21=pyhd3eb1b0_0 19 | - botocore=1.21.21=pyhd3eb1b0_1 20 | - brotli=1.0.9=he6710b0_2 21 | - brotlipy=0.7.0=py38h27cfd23_1003 22 | - bzip2=1.0.8=h7b6447c_0 23 | - c-ares=1.17.1=h27cfd23_0 24 | - ca-certificates=2021.10.26=h06a4308_2 25 | - cached-property=1.5.2=py_0 26 | - cachetools=4.2.2=pyhd3eb1b0_0 27 | - catalogue=2.0.4=py38h578d9bd_0 28 | - certifi=2021.10.8=py38h06a4308_2 29 | - cffi=1.14.6=py38h400218f_0 30 | - chardet=3.0.4=py38h06a4308_1003 31 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 32 | - click=7.1.2=pyhd3eb1b0_0 33 | - colorama=0.4.4=pyhd3eb1b0_0 34 | - coverage=5.5=py38h27cfd23_2 35 | - cryptography=3.4.7=py38hd23ed53_0 36 | - cudatoolkit=11.1.1=h6406543_8 37 | - cycler=0.10.0=py38_0 38 | - cymem=2.0.5=py38h2531618_0 39 | - cython=0.29.24=py38h295c915_0 40 | - cython-blis=0.7.4=py38h27cfd23_1 41 | - dataclasses=0.8=pyh6d0b6a4_7 42 | - dbus=1.13.18=hb2f20db_0 43 | - expat=2.4.1=h2531618_2 44 | - ffmpeg=4.2.2=h20bf706_0 45 | - fontconfig=2.13.1=h6c09931_0 46 | - fonttools=4.25.0=pyhd3eb1b0_0 47 | - freetype=2.10.4=h5ab3b9f_0 48 | - 
glib=2.69.1=h5202010_0 49 | - gmp=6.2.1=h2531618_2 50 | - gnutls=3.6.15=he1e5248_0 51 | - google-api-core=1.25.1=pyhd3eb1b0_0 52 | - google-auth=1.33.0=pyhd3eb1b0_0 53 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 54 | - google-cloud-core=1.7.1=pyhd3eb1b0_0 55 | - google-cloud-storage=1.41.0=pyhd3eb1b0_0 56 | - google-crc32c=1.1.2=py38h27cfd23_0 57 | - google-resumable-media=1.3.1=pyhd3eb1b0_1 58 | - googleapis-common-protos=1.53.0=py38h06a4308_0 59 | - grpcio=1.36.1=py38h2157cd5_1 60 | - gst-plugins-base=1.14.0=h8213a91_2 61 | - gstreamer=1.14.0=h28cd5cc_2 62 | - h5py=2.10.0=py38h7918eee_0 63 | - hdf5=1.10.4=hb1b8bf9_0 64 | - icu=58.2=he6710b0_3 65 | - idna=3.2=pyhd3eb1b0_0 66 | - importlib-metadata=4.8.1=py38h578d9bd_0 67 | - intel-openmp=2021.3.0=h06a4308_3350 68 | - jinja2=3.0.2=pyhd3eb1b0_0 69 | - jmespath=0.10.0=py_0 70 | - joblib=1.0.1=pyhd3eb1b0_0 71 | - jpeg=9b=h024ee3a_2 72 | - kiwisolver=1.3.1=py38h2531618_0 73 | - lame=3.100=h7b6447c_0 74 | - lcms2=2.12=h3be6417_0 75 | - ld_impl_linux-64=2.35.1=h7274673_9 76 | - libcrc32c=1.1.1=he6710b0_2 77 | - libffi=3.3=he6710b0_2 78 | - libgcc-ng=9.3.0=h5101ec6_17 79 | - libgfortran-ng=7.5.0=ha8ba4b0_17 80 | - libgfortran4=7.5.0=ha8ba4b0_17 81 | - libgomp=9.3.0=h5101ec6_17 82 | - libidn2=2.3.2=h7f8727e_0 83 | - libopus=1.3.1=h7b6447c_0 84 | - libpng=1.6.37=hbc83047_0 85 | - libprotobuf=3.17.2=h4ff587b_1 86 | - libstdcxx-ng=9.3.0=hd4cf53a_17 87 | - libtasn1=4.16.0=h27cfd23_0 88 | - libtiff=4.2.0=h85742a9_0 89 | - libunistring=0.9.10=h27cfd23_0 90 | - libuuid=1.0.3=h1bed415_2 91 | - libuv=1.40.0=h7b6447c_0 92 | - libvpx=1.7.0=h439df22_0 93 | - libwebp-base=1.2.0=h27cfd23_0 94 | - libxcb=1.14=h7b6447c_0 95 | - libxml2=2.9.12=h03d6c58_0 96 | - lz4-c=1.9.3=h295c915_1 97 | - lzo=2.10=h7b6447c_2 98 | - markdown=3.3.4=pyhd8ed1ab_0 99 | - markupsafe=2.0.1=py38h27cfd23_0 100 | - matplotlib=3.4.2=py38h06a4308_0 101 | - matplotlib-base=3.4.2=py38hab158f2_0 102 | - mkl=2021.3.0=h06a4308_520 103 | - mkl-service=2.4.0=py38h7f8727e_0 104 | - mkl_fft=1.3.0=py38h42c9631_2 105 | - mkl_random=1.2.2=py38h51133e4_0 106 | - mock=4.0.3=pyhd3eb1b0_0 107 | - multidict=5.1.0=py38h27cfd23_2 108 | - munkres=1.1.4=py_0 109 | - murmurhash=1.0.5=py38h2531618_0 110 | - ncurses=6.2=he6710b0_1 111 | - nettle=3.7.3=hbbd107a_1 112 | - ninja=1.10.2=hff7bd54_1 113 | - nltk=3.5=py_0 114 | - numexpr=2.7.3=py38h22e1b3c_1 115 | - numpy=1.20.3=py38hf144106_0 116 | - numpy-base=1.20.3=py38h74d4b33_0 117 | - nvidia-ml=7.352.0=py_0 118 | - oauthlib=3.1.1=pyhd8ed1ab_0 119 | - olefile=0.46=py_0 120 | - openh264=2.1.0=hd408876_0 121 | - openjpeg=2.4.0=h3ad879b_0 122 | - openssl=1.1.1m=h7f8727e_0 123 | - packaging=21.0=pyhd3eb1b0_0 124 | - pandas=1.1.3=py38he6710b0_0 125 | - pathlib=1.0.1=py_1 126 | - pathy=0.6.0=pyhd3eb1b0_0 127 | - pcre=8.45=h295c915_0 128 | - pillow=8.3.1=py38h2c7a002_0 129 | - pip=21.2.2=py38h06a4308_0 130 | - preshed=3.0.5=py38h2531618_4 131 | - progress=1.5=py_1 132 | - protobuf=3.17.2=py38h295c915_0 133 | - psutil=5.8.0=py38h497a2fe_1 134 | - pyasn1=0.4.8=py_0 135 | - pyasn1-modules=0.2.8=py_0 136 | - pycparser=2.20=py_2 137 | - pydantic=1.8.2=py38h27cfd23_0 138 | - pyjwt=2.1.0=pyhd8ed1ab_0 139 | - pyopenssl=20.0.1=pyhd3eb1b0_1 140 | - pyparsing=2.4.7=pyhd3eb1b0_0 141 | - pyqt=5.9.2=py38h05f1152_4 142 | - pysocks=1.7.1=py38h06a4308_0 143 | - pytables=3.6.1=py38h9fd0a39_0 144 | - python=3.8.8=hdb3f193_5 145 | - python-dateutil=2.8.2=pyhd3eb1b0_0 146 | - python_abi=3.8=1_cp38 147 | - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0 148 | - pytz=2021.1=pyhd3eb1b0_0 149 | - 
qt=5.9.7=h5867ecd_1 150 | - readline=8.1=h27cfd23_0 151 | - regex=2021.8.3=py38h7f8727e_0 152 | - requests=2.26.0=pyhd3eb1b0_0 153 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 154 | - rsa=4.7.2=pyhd3eb1b0_1 155 | - s3transfer=0.5.0=pyhd3eb1b0_0 156 | - scipy=1.6.2=py38had2a1c9_1 157 | - setuptools=52.0.0=py38h06a4308_0 158 | - shellingham=1.3.1=py_0 159 | - sip=4.19.13=py38he6710b0_0 160 | - six=1.16.0=pyhd3eb1b0_0 161 | - smart_open=5.1.0=pyhd3eb1b0_0 162 | - spacy=3.1.0=py38hfc89cab_0 163 | - spacy-legacy=3.0.8=pyhd8ed1ab_0 164 | - sqlite=3.36.0=hc218d9a_0 165 | - srsly=2.4.1=py38h2531618_0 166 | - tensorboard=2.6.0=pyhd8ed1ab_1 167 | - tensorboard-data-server=0.6.0=py38h2b97feb_0 168 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 169 | - thinc=8.0.7=py38hfc89cab_0 170 | - tk=8.6.10=hbc83047_0 171 | - torchvision=0.9.1=py38_cu111 172 | - tornado=6.1=py38h27cfd23_0 173 | - typer=0.3.2=pyhd3eb1b0_0 174 | - typing-extensions=3.10.0.0=hd3eb1b0_0 175 | - typing_extensions=3.10.0.0=pyh06a4308_0 176 | - unidecode=1.1.1=py_0 177 | - wasabi=0.8.2=pyhd3eb1b0_0 178 | - werkzeug=2.0.1=pyhd8ed1ab_0 179 | - wheel=0.37.0=pyhd3eb1b0_0 180 | - x264=1!157.20191217=h7b6447c_0 181 | - xz=5.2.5=h7b6447c_0 182 | - yarl=1.5.1=py38h7b6447c_0 183 | - zipp=3.5.0=pyhd8ed1ab_0 184 | - zlib=1.2.11=h7b6447c_3 185 | - zstd=1.4.9=haebb681_0 186 | - pip: 187 | - data==0.4 188 | - decorator==5.0.9 189 | - funcsigs==1.0.2 190 | - future==0.18.2 191 | - gpustat==0.6.0 192 | - sentencepiece==0.1.95 193 | - shutilwhich==1.1.0 194 | - tempdir==0.7.1 195 | - torchtext==0.5.0 196 | - tqdm==4.60.0 197 | - urllib3==1.26.4 198 | -------------------------------------------------------------------------------- /Approximation comparison vs sampling/error.py: -------------------------------------------------------------------------------- 1 | from itertools import chain, combinations 2 | 3 | def powerset(iterable): 4 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" 5 | s = list(iterable) 6 | return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) 7 | 8 | from utils.utils import choose 9 | 10 | 11 | from collections import defaultdict 12 | import pandas as pd 13 | 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | cos = F.cosine_similarity 19 | 20 | N = 3 21 | M = 5 22 | a = torch.arange(1, 1+M) 23 | b = 2 * a 24 | c = torch.div(1, torch.arange(1, 1+M)) 25 | d = torch.square(a) 26 | 27 | 28 | from torch.linalg import norm 29 | def v(coalition_v, grand_v, if_norm=False): 30 | coalition_v_ = torch.div(coalition_v, norm(coalition_v)) if (if_norm and norm(coalition_v) != 0) else coalition_v 31 | grand_v_ = torch.div(grand_v, norm(grand_v)) if (if_norm and norm(grand_v) != 0) else grand_v 32 | return cos(coalition_v_, grand_v_, 0) 33 | 34 | from math import factorial as fac 35 | 36 | 37 | def calculate_svs(vectors, N, d): 38 | grand = torch.stack(vectors).sum(dim=0) 39 | svs = torch.zeros(N) 40 | for coalition in powerset(range(N)): 41 | if not coalition: continue 42 | coalition_v = torch.zeros(d) 43 | for i in coalition: 44 | coalition_v += vectors[i] 45 | for i in coalition: 46 | with_i = v(coalition_v, grand) 47 | without_i = v(coalition_v - vectors[i], grand) 48 | 49 | svs[i] += 1.0 / choose(N-1, len(coalition)-1) * (with_i - without_i) 50 | return torch.div(svs, sum(svs)) 51 | 52 | from itertools import permutations 53 | from random import shuffle 54 | def calculate_sv_hats(vectors, N, d, K=30): 55 | grand = torch.stack(vectors).sum(dim=0) 56 | svs = torch.zeros(N) 
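    # Note on the estimator below: calculate_svs() above computes exact Shapley
    # values by enumerating all 2^N coalitions, which is only feasible for small N.
    # This function uses the Monte Carlo permutation-sampling estimator instead:
    # draw K random permutations, walk through each one, and credit agent i with
    # the marginal gain in utility v (cosine similarity to the grand-coalition
    # vector) obtained when its vector is appended to the running prefix.
    # experiment() below picks K = ceil(log(2/delta) * 1**2 / (2 * epsilon**2)),
    # a Hoeffding-style sample bound, and the estimates are normalized to sum to
    # one, matching the normalization applied to the exact values.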
57 | all_permutations = list(permutations(range(N))) 58 | shuffle(all_permutations) 59 | 60 | for permutation in all_permutations[:K]: 61 | permutation_v = torch.zeros(d) 62 | for i in permutation: 63 | without_i = v(permutation_v, grand) 64 | permutation_v += vectors[i] 65 | with_i = v(permutation_v, grand) 66 | svs[i] += with_i - without_i 67 | return torch.div(svs, sum(svs)) 68 | 69 | 70 | from scipy.stats import pearsonr 71 | 72 | pd.set_option('display.max_columns', None) 73 | pd.set_option('display.max_colwidth', None) 74 | 75 | trials = 50 76 | N = 10 77 | d = 1000 78 | 79 | 80 | from time import process_time 81 | 82 | def clock(name, start, time_dict): 83 | now = process_time() 84 | time_dict[name] += now - start 85 | return now 86 | 87 | results = defaultdict(list) 88 | 89 | def generate_random_vectors(N, d, uniform=True): 90 | 91 | if uniform: 92 | return [torch.randn(d) for i in range(N)] 93 | else: 94 | vectors = [] 95 | for i in range(N): 96 | rand_v = torch.zeros(d) 97 | for j in range(d): 98 | if j == 0: 99 | rand_v[j] = j 100 | else: 101 | rand_v[j] = (torch.randn(1) * rand_v[j-1])**2 + j 102 | rand_v += torch.randn(d) 103 | vectors.append(torch.div(rand_v, norm(rand_v)) ) 104 | return vectors 105 | 106 | def experiment(N, d, trials=50, epsilon=0.1, delta=0.1): 107 | time_dict = defaultdict(float) 108 | now = process_time() 109 | K = int(np.ceil( np.log(2.0/delta) * 1**2 / (2.0 * epsilon **2) )) 110 | 111 | # print("For epsilon {}, delta {}, sampling method needs {} samples.".format(epsilon, delta, K)) 112 | for i in range(trials): 113 | vectors = generate_random_vectors(N, d, True) 114 | grand = torch.stack(vectors).sum(dim=0) 115 | now = clock('init', now, time_dict) 116 | 117 | svs = calculate_svs(vectors, N, d) 118 | now = clock('true svs', now, time_dict) 119 | 120 | sv_hats = calculate_sv_hats(vectors, N, d, K) 121 | now = clock('sv hats', now, time_dict) 122 | 123 | cosines = torch.tensor([cos(v, grand, 0) for v in vectors]) 124 | cosines = torch.div(cosines, sum(cosines)) 125 | now = clock('cosines', now, time_dict) 126 | 127 | results['svs'].append(svs) 128 | results['sv_hats'].append(sv_hats) 129 | diff_cos = cosines - svs 130 | results['diff'].append(diff_cos) 131 | results['l1diff'].append(sum(np.abs(diff_cos)).item() ) 132 | results['l2diff'].append(norm(diff_cos).item()) 133 | results['cossim'].append( cos(cosines, svs, 0).item()) 134 | r, p = pearsonr(cosines, svs) 135 | results['pearsonr'].append(r) 136 | results['pearsonp'].append(p) 137 | 138 | diff_hat = svs - sv_hats 139 | results['diff_hat'].append(diff_hat) 140 | results['l1diff_hat'].append(sum(np.abs(diff_hat)).item() ) 141 | results['l2diff_hat'].append(norm(diff_hat).item()) 142 | results['cossim_hat'].append( cos(sv_hats, svs, 0).item()) 143 | r, p = pearsonr(sv_hats, svs) 144 | results['pearsonr_hat'].append(r) 145 | results['pearsonp_hat'].append(p) 146 | now = clock('results', now, time_dict) 147 | 148 | return results, time_dict 149 | 150 | import matplotlib.pyplot as plt 151 | 152 | trials = 10 153 | 154 | 155 | # Experiment vs N 156 | Nmin, Nmax = 5, 10 157 | d = 1000 158 | 159 | stats_dict = defaultdict(list) 160 | for n in range(5, Nmax+1): 161 | 162 | results, time_dict = experiment(n, d, trials=trials) 163 | df = pd.DataFrame(results, columns=['l1diff', 'l1diff_hat', 'l2diff', 'l2diff_hat', 'pearsonr', 'pearsonr_hat']) 164 | 165 | stats_dict['n'].append(n) 166 | for column in df.columns: 167 | stats_dict[column+'_mean'].append(df[column].mean()) 168 | 
stats_dict[column+'_std'].append(df[column].std()) 169 | 170 | stats_dict['cosines'].append(time_dict['cosines'] / trials) 171 | stats_dict['sv hats'].append(time_dict['sv hats'] / trials) 172 | 173 | stats_df = pd.DataFrame(stats_dict) 174 | stats_df.to_csv('error_vs_N={}-{}.csv'.format(Nmin, Nmax), index=False) 175 | 176 | 177 | # Experiment vs d 178 | dmin, dmax = 10, 15 179 | n = 10 180 | 181 | stats_dict = defaultdict(list) 182 | for d in range(10, dmax+1): 183 | d = 2**d 184 | results, time_dict = experiment(n, d, trials=trials) 185 | df = pd.DataFrame(results, columns=['l1diff', 'l1diff_hat', 'l2diff', 'l2diff_hat', 'pearsonr', 'pearsonr_hat']) 186 | 187 | stats_dict['d'].append(d) 188 | for column in df.columns: 189 | stats_dict[column+'_mean'].append(df[column].mean()) 190 | stats_dict[column+'_std'].append(df[column].std()) 191 | 192 | stats_dict['cosines'].append(time_dict['cosines'] / trials) 193 | stats_dict['sv hats'].append(time_dict['sv hats'] / trials) 194 | 195 | stats_df = pd.DataFrame(stats_dict) 196 | stats_df.to_csv('error_vs_d={}-{}.csv'.format(2**dmin, 2**dmax), index=False) 197 | 198 | exit() 199 | 200 | 201 | 202 | data = defaultdict(list) 203 | 204 | for coalition in powerset(range(N)): 205 | if not coalition: continue 206 | coalition_v = torch.zeros(M) 207 | for i in coalition: 208 | coalition_v += vectors[i] 209 | data['coalition'].append(coalition) 210 | data['utility'].append(cos(grand, coalition_v, 0).item() ) 211 | data['sum of svs'].append(sum([svs[i] for i in coalition]).item() ) 212 | 213 | df = pd.DataFrame(data) 214 | df['utility_left_over'] = df['utility'] - df['sum of svs'] 215 | df['efficient'] = df['utility_left_over'] == 0 216 | print(df) 217 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | import torch 4 | from torch import nn 5 | from torch.utils.data import DataLoader 6 | from torchtext.data import Batch 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | def compute_grad_update(old_model, new_model, device=None): 17 | # maybe later to implement on selected layers/parameters 18 | if device: 19 | old_model, new_model = old_model.to(device), new_model.to(device) 20 | return [(new_param.data - old_param.data) for old_param, new_param in zip(old_model.parameters(), new_model.parameters())] 21 | 22 | def add_update_to_model(model, update, weight=1.0, device=None): 23 | if not update: return model 24 | if device: 25 | model = model.to(device) 26 | update = [param.to(device) for param in update] 27 | 28 | for param_model, param_update in zip(model.parameters(), update): 29 | param_model.data += weight * param_update.data 30 | return model 31 | 32 | def add_gradient_updates(grad_update_1, grad_update_2, weight = 1.0): 33 | assert len(grad_update_1) == len( 34 | grad_update_2), "Lengths of the two grad_updates not equal" 35 | 36 | for param_1, param_2 in zip(grad_update_1, grad_update_2): 37 | param_1.data += param_2.data * weight 38 | 39 | 40 | def flatten(grad_update): 41 | return torch.cat([update.data.view(-1) for update in grad_update]) 42 | 43 | 44 | def unflatten(flattened, normal_shape): 45 | grad_update = [] 46 | for param in normal_shape: 47 | n_params = len(param.view(-1)) 48 | grad_update.append(torch.as_tensor(flattened[:n_params]).reshape(param.size()) ) 49 | flattened = flattened[n_params:] 50 | 51 | return grad_update 52 | 53 | 54 | def 
compute_distance_percentage(model, ref_model): 55 | percents, dists = [], [] 56 | for layer, ref_layer in zip(model.parameters(), ref_model.parameters()): 57 | dist = torch.norm(layer - ref_layer) 58 | dists.append(dist.item()) 59 | percents.append( (torch.div(dist, torch.norm(ref_layer))).item() ) 60 | 61 | return percents, dists 62 | 63 | 64 | 65 | def cosine_similarity(grad1, grad2, normalized=False): 66 | """ 67 | Input: two sets of gradients of the same shape 68 | Output range: [-1, 1] 69 | """ 70 | 71 | cos_sim = F.cosine_similarity(flatten(grad1), flatten(grad2), 0, 1e-10) 72 | if normalized: 73 | return (cos_sim + 1) / 2.0 74 | else: 75 | return cos_sim 76 | 77 | def evaluate(model, eval_loader, device, loss_fn=None, verbose=False): 78 | model.eval() 79 | model = model.to(device) 80 | correct = 0 81 | total = 0 82 | loss = 0 83 | 84 | with torch.no_grad(): 85 | for i, batch in enumerate(eval_loader): 86 | 87 | if isinstance(batch, Batch): 88 | batch_data, batch_target = batch.text, batch.label 89 | # batch_data.data.t_(), batch_target.data.sub_(1) # batch first, index align 90 | batch_data = batch_data.permute(1, 0) 91 | else: 92 | batch_data, batch_target = batch[0], batch[1] 93 | 94 | batch_data, batch_target = batch_data.to(device), batch_target.to(device) 95 | outputs = model(batch_data) 96 | 97 | if loss_fn: 98 | loss += loss_fn(outputs, batch_target) 99 | else: 100 | loss = None 101 | correct += (torch.max(outputs, 1)[1].view(batch_target.size()).data == batch_target.data).sum() 102 | total += len(batch_target) 103 | 104 | accuracy = correct.float() / total 105 | if loss_fn: 106 | loss /= total 107 | 108 | if verbose: 109 | print("Loss: {:.6f}. Accuracy: {:.4%}.".format(loss, accuracy)) 110 | return loss, accuracy 111 | 112 | from torchtext.data import Batch 113 | def train_model(model, loader, loss_fn, optimizer, device, E=1, **kwargs): 114 | 115 | model.train() 116 | for e in range(E): 117 | # running local epochs 118 | for _, batch in enumerate(loader): 119 | if isinstance(batch, Batch): 120 | data, label = batch.text, batch.label 121 | data = data.permute(1, 0) 122 | # data.data.t_(), label.data.sub_(1) # batch first, index align 123 | else: 124 | data, label = batch[0], batch[1] 125 | 126 | data, label = data.to(device), label.to(device) 127 | 128 | optimizer.zero_grad() 129 | pred = model(data) 130 | loss_fn(pred, label).backward() 131 | 132 | optimizer.step() 133 | 134 | if 'scheduler' in kwargs: kwargs['scheduler'].step() 135 | 136 | return model 137 | 138 | 139 | def mask_grad_update_by_order(grad_update, mask_order=None, mask_percentile=None, mode='all'): 140 | 141 | if mode == 'all': 142 | # mask all but the largest updates (by magnitude) to zero 143 | all_update_mod = torch.cat([update.data.view(-1).abs() 144 | for update in grad_update]) 145 | if not mask_order and mask_percentile is not None: 146 | mask_order = int(len(all_update_mod) * mask_percentile) 147 | 148 | if mask_order == 0: 149 | return mask_grad_update_by_magnitude(grad_update, float('inf')) 150 | else: 151 | topk, indices = torch.topk(all_update_mod, mask_order) 152 | return mask_grad_update_by_magnitude(grad_update, topk[-1]) 153 | 154 | elif mode == 'layer': # layer wise largest-values criterion 155 | grad_update = copy.deepcopy(grad_update) 156 | 157 | mask_percentile = max(0, mask_percentile) 158 | for i, layer in enumerate(grad_update): 159 | layer_mod = layer.data.view(-1).abs() 160 | if mask_percentile is not None: 161 | mask_order = math.ceil(len(layer_mod) * mask_percentile) 162 | 163 | if 
mask_order == 0: 164 | grad_update[i].data = torch.zeros(layer.data.shape, device=layer.device) 165 | else: 166 | topk, indices = torch.topk(layer_mod, min(mask_order, len(layer_mod)-1)) 167 | grad_update[i].data[layer.data.abs() < topk[-1]] = 0 168 | return grad_update 169 | 170 | def mask_grad_update_by_magnitude(grad_update, mask_constant): 171 | 172 | # mask all but the updates with larger magnitude than to zero 173 | # print('Masking all gradient updates with magnitude smaller than ', mask_constant) 174 | grad_update = copy.deepcopy(grad_update) 175 | for i, update in enumerate(grad_update): 176 | grad_update[i].data[update.data.abs() < mask_constant] = 0 177 | return grad_update 178 | 179 | 180 | import os 181 | from contextlib import contextmanager 182 | 183 | @contextmanager 184 | def cwd(path): 185 | oldpwd=os.getcwd() 186 | os.chdir(path) 187 | try: 188 | yield 189 | finally: 190 | os.chdir(oldpwd) 191 | 192 | 193 | ''' 194 | 195 | 196 | def sign(grad): 197 | return [torch.sign(update) for update in grad] 198 | def l2norm(grad): 199 | return torch.sqrt(torch.sum(torch.pow(flatten(grad), 2))) 200 | 201 | 202 | def cosine_similarity_modified(coalition_grad, coalition_grad_majority, grad_all, grad_all_majority, normalized=False, Lambda=0): 203 | sign_cossim = F.cosine_similarity(coalition_grad_majority, grad_all_majority, 0, 1e-10) 204 | modu_cossim = F.cosine_similarity(coalition_grad, grad_all, 0, 1e-10) 205 | 206 | return Lambda * sign_cossim + (1 - Lambda) * modu_cossim 207 | 208 | def mask_grad_update_by_indices(grad_update, indices=None): 209 | """ 210 | Mask the grad.data to be 0, if the position is not in the list of indices 211 | If indicies is empty, mask nothing. 212 | 213 | Arguments: 214 | grad_update: as in the shape of the model parameters. A list of tensors. 215 | indices: a tensor of integers, corresponding to the specific individual scalar values in the grad_update, 216 | as if the entire grad_update is flattened. 217 | 218 | e.g. 
219 | grad_update = [[1, 2, 3], [3, 2, 1]] 220 | indices = [4, 5] 221 | returning masked grad_update = [[0, 0, 0], [0, 2, 1]] 222 | """ 223 | 224 | grad_update = copy.deepcopy(grad_update) 225 | if indices is None or len(indices)==0: return grad_update 226 | 227 | #flatten and unflatten 228 | flattened = torch.cat([update.data.view(-1) for update in grad_update]) 229 | masked = torch.zeros_like(torch.arange(len(flattened)), device=flattened.device).float() 230 | masked.data[indices] = flattened.data[indices] 231 | 232 | pointer = 0 233 | for m, update in enumerate(grad_update): 234 | size_of_update = torch.prod(torch.tensor(update.shape)).long() 235 | grad_update[m].data = masked[pointer: pointer + size_of_update].reshape(update.shape) 236 | pointer += size_of_update 237 | return grad_update 238 | 239 | 240 | from itertools import chain, combinations 241 | def powerset(iterable): 242 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" 243 | s = list(iterable) 244 | return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) 245 | 246 | from math import factorial as f 247 | def choose(n, r): 248 | return f(n) // f(r) // f(n-r) 249 | 250 | def clip_gradient_update(grad_update, grad_clip): 251 | """ 252 | Return a copy of clipped grad update 253 | 254 | """ 255 | return [torch.clamp(param.data, min=-grad_clip, max=grad_clip) for param in grad_update] 256 | 257 | ''' 258 | -------------------------------------------------------------------------------- /utils/Data_Prepper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | from torch.utils.data.sampler import SubsetRandomSampler 10 | 11 | from torchvision.datasets import CIFAR10, CIFAR100 12 | from torchtext.data import Field, LabelField, BucketIterator 13 | 14 | 15 | class Data_Prepper: 16 | def __init__(self, name, train_batch_size, n_agents, 17 | sample_size_cap=-1, test_batch_size=100, valid_batch_size=None, 18 | train_val_split_ratio=0.8, device=None, args_dict=None): 19 | self.args = None 20 | self.args_dict = args_dict 21 | self.name = name 22 | self.device = device 23 | self.n_agents = n_agents 24 | self.sample_size_cap = sample_size_cap 25 | self.train_val_split_ratio = train_val_split_ratio 26 | 27 | self.init_batch_size(train_batch_size, test_batch_size, valid_batch_size) 28 | 29 | if name in ['sst', 'mr']: 30 | parser = argparse.ArgumentParser(description='CNN text classificer') 31 | self.args = {} 32 | 33 | self.train_datasets, self.validation_dataset, self.test_dataset = self.prepare_dataset(name) 34 | 35 | self.valid_loader = BucketIterator(self.validation_dataset, batch_size = 500, sort_key=lambda x: len(x.text), device=self.device ) 36 | self.test_loader = BucketIterator(self.test_dataset, batch_size = 500, sort_key=lambda x: len(x.text), device=self.device) 37 | 38 | self.args['embed_dim'] = self.args_dict['embed_dim'] 39 | self.args['kernel_num'] = self.args_dict['kernel_num'] 40 | self.args['kernel_sizes'] = self.args_dict['kernel_sizes'] 41 | self.args['static'] = self.args_dict['static'] 42 | 43 | train_size = sum([len(train_dataset) for train_dataset in self.train_datasets]) 44 | if self.n_agents > 5: 45 | print("Splitting all {} train data to {} parties. 
Caution against this due to the limited training size.".format(train_size, self.n_agents)) 46 | print("Model embedding arguments:", self.args) 47 | print('------') 48 | print("Train to split size: {}. Validation size: {}. Test size: {}".format(train_size, len(self.validation_dataset), len(self.test_dataset))) 49 | print('------') 50 | 51 | 52 | elif name in ['mnist', 'cifar10']: 53 | self.train_dataset, self.validation_dataset, self.test_dataset = self.prepare_dataset(name) 54 | 55 | print('------') 56 | print("Train to split size: {}. Validation size: {}. Test size: {}".format(len(self.train_dataset), len(self.validation_dataset), len(self.test_dataset))) 57 | print('------') 58 | 59 | self.valid_loader = DataLoader(self.validation_dataset, batch_size=self.test_batch_size) 60 | self.test_loader = DataLoader(self.test_dataset, batch_size=self.test_batch_size) 61 | 62 | else: 63 | raise NotImplementedError 64 | 65 | 66 | def init_batch_size(self, train_batch_size, test_batch_size, valid_batch_size): 67 | self.train_batch_size = train_batch_size 68 | self.test_batch_size = test_batch_size 69 | self.valid_batch_size = valid_batch_size if valid_batch_size else test_batch_size 70 | 71 | def get_valid_loader(self): 72 | return self.valid_loader 73 | 74 | def get_test_loader(self): 75 | return self.test_loader 76 | 77 | def get_train_loaders(self, n_agents, split='powerlaw', batch_size=None): 78 | if not batch_size: 79 | batch_size = self.train_batch_size 80 | 81 | if self.name not in ['sst', 'mr', 'mnist', 'cifar10']: raise NotImplementedError 82 | 83 | if self.name in ['sst', 'mr']: 84 | # sst, mr split is different from other datasets, so return here 85 | 86 | self.train_loaders = [BucketIterator(train_dataset, batch_size=self.train_batch_size, device=self.device, sort_key=lambda x: len(x.text),train=True) for train_dataset in self.train_datasets] 87 | self.shard_sizes = [(len(train_dataset)) for train_dataset in self.train_datasets] 88 | return self.train_loaders 89 | 90 | elif self.name in ['mnist', 'cifar10']: 91 | 92 | if split == 'classimbalance': 93 | if self.name not in ['mnist','cifar10']: 94 | raise NotImplementedError("Calling on dataset {}. 
Only mnist and cifar10 are implemnted for this split".format(self.name)) 95 | 96 | n_classes = 10 97 | data_indices = [torch.nonzero(self.train_dataset.targets == class_id).view(-1).tolist() for class_id in range(n_classes)] 98 | class_sizes = np.linspace(1, n_classes, n_agents, dtype='int') 99 | print("class_sizes for each party", class_sizes) 100 | party_mean = self.sample_size_cap // self.n_agents 101 | 102 | from collections import defaultdict 103 | party_indices = defaultdict(list) 104 | for party_id, class_sz in enumerate(class_sizes): 105 | classes = range(class_sz) # can customize classes for each party rather than just listing 106 | each_class_id_size = party_mean // class_sz 107 | # print("party each class size:", party_id, each_class_id_size) 108 | for i, class_id in enumerate(classes): 109 | # randomly pick from each class a certain number of samples, with replacement 110 | selected_indices = random.choices(data_indices[class_id], k=each_class_id_size) 111 | 112 | # randomly pick from each class a certain number of samples, without replacement 113 | ''' 114 | NEED TO MAKE SURE THAT EACH CLASS HAS MORE THAN each_class_id_size for no replacement sampling 115 | selected_indices = random.sample(data_indices[class_id],k=each_class_id_size) 116 | ''' 117 | party_indices[party_id].extend(selected_indices) 118 | 119 | # top up to make sure all parties have the same number of samples 120 | if i == len(classes) - 1 and len(party_indices[party_id]) < party_mean: 121 | extra_needed = party_mean - len(party_indices[party_id]) 122 | party_indices[party_id].extend(data_indices[class_id][:extra_needed]) 123 | data_indices[class_id] = data_indices[class_id][extra_needed:] 124 | 125 | indices_list = [party_index_list for party_id, party_index_list in party_indices.items()] 126 | 127 | elif split == 'powerlaw': 128 | indices_list = powerlaw(list(range(len(self.train_dataset))), n_agents) 129 | 130 | elif split in ['uniform']: 131 | indices_list = random_split(sample_indices=list(range(len(self.train_dataset))), m_bins=n_agents, equal=True) 132 | 133 | self.train_datasets = [Custom_Dataset(self.train_dataset.data[indices],self.train_dataset.targets[indices]) for indices in indices_list] 134 | 135 | self.shard_sizes = [len(indices) for indices in indices_list] 136 | agent_train_loaders = [DataLoader(self.train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(indices)) for indices in indices_list] 137 | self.train_loaders = agent_train_loaders 138 | return agent_train_loaders 139 | 140 | def prepare_dataset(self, name='mnist'): 141 | 142 | if name == 'mnist': 143 | 144 | train = FastMNIST('.data', train=True, download=True) 145 | test = FastMNIST('.data', train=False, download=True) 146 | 147 | train_indices, valid_indices = get_train_valid_indices(len(train), self.train_val_split_ratio, self.sample_size_cap) 148 | 149 | train_set = Custom_Dataset(train.data[train_indices], train.targets[train_indices], device=self.device) 150 | validation_set = Custom_Dataset(train.data[valid_indices],train.targets[valid_indices] , device=self.device) 151 | test_set = Custom_Dataset(test.data, test.targets, device=self.device) 152 | 153 | del train, test 154 | 155 | return train_set, validation_set, test_set 156 | 157 | elif name == 'cifar10': 158 | 159 | train = FastCIFAR10('.data', train=True, download=True)#, transform=transform_train) 160 | test = FastCIFAR10('.data', train=False, download=True)#, transform=transform_test) 161 | 162 | train_indices, valid_indices = get_train_valid_indices(len(train), 
self.train_val_split_ratio, self.sample_size_cap) 163 | 164 | train_set = Custom_Dataset(train.data[train_indices], train.targets[train_indices], device=self.device) 165 | validation_set = Custom_Dataset(train.data[valid_indices],train.targets[valid_indices] , device=self.device) 166 | test_set = Custom_Dataset(test.data, test.targets, device=self.device) 167 | del train, test 168 | 169 | return train_set, validation_set, test_set 170 | 171 | elif name == "sst": 172 | import torchtext.data as data 173 | text_field = data.Field(lower=True) 174 | from torch import long as torch_long 175 | label_field = LabelField(dtype = torch_long, sequential=False) 176 | 177 | import torchtext.datasets as datasets 178 | train_data, validation_data, test_data = datasets.SST.splits(text_field, label_field, root='.data', fine_grained=True) 179 | 180 | if self.args_dict['split'] == 'uniform': 181 | indices_list = random_split(sample_indices=list(range(len(train_data))), m_bins=self.n_agents, equal=True) 182 | else: 183 | indices_list = powerlaw(list(range(len(train_data))), self.n_agents) 184 | ratios = [len(indices) / len(train_data) for indices in indices_list] 185 | 186 | train_datasets = split_torchtext_dataset_ratios(train_data, ratios) 187 | 188 | text_field.build_vocab(*(train_datasets + [validation_data, test_data])) 189 | label_field.build_vocab(*(train_datasets + [validation_data, test_data])) 190 | 191 | self.args['embed_num'] = len(text_field.vocab) 192 | self.args['class_num'] = len(label_field.vocab) 193 | 194 | return train_datasets, validation_data, test_data 195 | 196 | elif name == 'mr': 197 | 198 | import torchtext.data as data 199 | from utils import mrdatasets 200 | 201 | text_field = data.Field(lower=True) 202 | from torch import long as torch_long 203 | label_field = LabelField(dtype = torch_long, sequential=False) 204 | # label_field = data.Field(sequential=False) 205 | 206 | train_data, dev_data = mrdatasets.MR.splits(text_field, label_field, root='.data', shuffle=False) 207 | 208 | validation_data, test_data = dev_data.split(split_ratio=0.5, random_state = random.seed(1234)) 209 | 210 | if self.args_dict['split'] == 'uniform': 211 | indices_list = random_split(sample_indices=list(range(len(train_data))), m_bins=self.n_agents, equal=True) 212 | else: 213 | indices_list = powerlaw(list(range(len(train_data))), self.n_agents) 214 | 215 | ratios = [len(indices) / len(train_data) for indices in indices_list] 216 | 217 | train_datasets = split_torchtext_dataset_ratios(train_data, ratios) 218 | 219 | text_field.build_vocab( *(train_datasets + [validation_data, test_data] )) 220 | label_field.build_vocab( *(train_datasets + [validation_data, test_data] )) 221 | 222 | 223 | self.args['embed_num'] = len(text_field.vocab) 224 | self.args['class_num'] = len(label_field.vocab) 225 | 226 | return train_datasets, validation_data, test_data 227 | else: 228 | raise NotImplementedError 229 | 230 | from torchvision.datasets import MNIST 231 | class FastMNIST(MNIST): 232 | def __init__(self, *args, **kwargs): 233 | super().__init__(*args, **kwargs) 234 | 235 | self.data = self.data.unsqueeze(1).float().div(255) 236 | from torch.nn import ZeroPad2d 237 | pad = ZeroPad2d(2) 238 | self.data = torch.stack([pad(sample.data) for sample in self.data]) 239 | 240 | self.targets = self.targets.long() 241 | 242 | self.data = self.data.sub_(self.data.mean()).div_(self.data.std()) 243 | # self.data = self.data.sub_(0.1307).div_(0.3081) 244 | # Put both data and targets on GPU in advance 245 | self.data, self.targets 
= self.data, self.targets 246 | print('MNIST data shape {}, targets shape {}'.format(self.data.shape, self.targets.shape)) 247 | 248 | def __getitem__(self, index): 249 | """ 250 | Args: 251 | index (int): Index 252 | 253 | Returns: 254 | tuple: (image, target) where target is index of the target class. 255 | """ 256 | img, target = self.data[index], self.targets[index] 257 | 258 | return img, target 259 | 260 | from torchvision.datasets import CIFAR10, CIFAR100 261 | class FastCIFAR10(CIFAR10): 262 | def __init__(self, *args, **kwargs): 263 | super().__init__(*args, **kwargs) 264 | 265 | # Scale data to [0,1] 266 | from torch import from_numpy 267 | self.data = from_numpy(self.data) 268 | self.data = self.data.float().div(255) 269 | self.data = self.data.permute(0, 3, 1, 2) 270 | 271 | self.targets = torch.Tensor(self.targets).long() 272 | 273 | 274 | # https://github.com/kuangliu/pytorch-cifar/issues/16 275 | # https://github.com/kuangliu/pytorch-cifar/issues/8 276 | for i, (mean, std) in enumerate(zip((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))): 277 | self.data[:,i].sub_(mean).div_(std) 278 | 279 | # Put both data and targets on GPU in advance 280 | self.data, self.targets = self.data, self.targets 281 | print('CIFAR10 data shape {}, targets shape {}'.format(self.data.shape, self.targets.shape)) 282 | 283 | def __getitem__(self, index): 284 | """ 285 | Args: 286 | index (int): Index 287 | 288 | Returns: 289 | tuple: (image, target) where target is index of the target class. 290 | """ 291 | img, target = self.data[index], self.targets[index] 292 | 293 | return img, target 294 | 295 | 296 | class Custom_Dataset(Dataset): 297 | 298 | def __init__(self, X, y, device=None, transform=None): 299 | self.data = X.to(device) 300 | self.targets = y.to(device) 301 | self.count = len(X) 302 | self.device = device 303 | self.transform = transform 304 | 305 | def __len__(self): 306 | return self.count 307 | 308 | def __getitem__(self, idx): 309 | if self.transform: 310 | return self.transform(self.data[idx]), self.targets[idx] 311 | 312 | return self.data[idx], self.targets[idx] 313 | 314 | 315 | 316 | def random_split(sample_indices, m_bins, equal=True): 317 | np.random.seed(1111) 318 | sample_indices = np.asarray(sample_indices) 319 | if equal: 320 | indices_list = np.array_split(sample_indices, m_bins) 321 | else: # unequal split: cut at m_bins - 1 random points 322 | split_points = np.random.choice( 323 | len(sample_indices) - 2, m_bins - 1, replace=False) + 1 324 | split_points.sort() 325 | indices_list = np.split(sample_indices, split_points) 326 | 327 | return indices_list 328 | 329 | def powerlaw(sample_indices, n_agents, alpha=1.65911332899, shuffle=False): 330 | # the smaller the alpha, the more extreme the division 331 | if shuffle: 332 | random.seed(1234) 333 | random.shuffle(sample_indices) 334 | 335 | from scipy.stats import powerlaw 336 | import math 337 | party_size = int(len(sample_indices) / n_agents) 338 | b = np.linspace(powerlaw.ppf(0.01, alpha), powerlaw.ppf(0.99, alpha), n_agents) 339 | shard_sizes = list(map(math.ceil, b/sum(b)*party_size*n_agents)) 340 | indices_list = [] 341 | accessed = 0 342 | for agent_id in range(n_agents): 343 | indices_list.append(sample_indices[accessed:accessed + shard_sizes[agent_id]]) 344 | accessed += shard_sizes[agent_id] 345 | return indices_list 346 | 347 | 348 | def get_train_valid_indices(n_samples, train_val_split_ratio, sample_size_cap=None): 349 | indices = list(range(n_samples)) 350 | random.seed(1111) 351 | random.shuffle(indices) 352 | split_point = int(n_samples *
train_val_split_ratio) 353 | train_indices, valid_indices = indices[:split_point], indices[split_point:] 354 | if sample_size_cap is not None: 355 | train_indices = indices[:min(split_point, sample_size_cap)] 356 | 357 | return train_indices, valid_indices 358 | 359 | 360 | def split_torchtext_dataset_ratios(data, ratios): 361 | train_datasets = [] 362 | while len(ratios) > 1: 363 | 364 | split_ratio = ratios[0] / sum(ratios) 365 | ratios.pop(0) 366 | train_dataset, data = data.split(split_ratio=split_ratio, random_state=random.seed(1234)) 367 | train_datasets.append(train_dataset) 368 | train_datasets.append(data) 369 | return train_datasets -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The main driving code 3 | 4 | 1. CML/FL Training 5 | 6 | 2. Compute/Approximate Cosine Gradient Shapley 7 | 8 | 3. Calculate and realize the fair gradient reward 9 | 10 | ''' 11 | 12 | import os, sys, json 13 | from os.path import join as oj 14 | import copy 15 | from copy import deepcopy as dcopy 16 | import time, datetime, random, pickle 17 | from collections import defaultdict 18 | from itertools import product 19 | import numpy as np 20 | import pandas as pd 21 | 22 | import torch 23 | from torch import nn, optim 24 | from torch.linalg import norm 25 | from torchtext.data import Batch 26 | import torch.nn.functional as F 27 | 28 | 29 | from utils.Data_Prepper import Data_Prepper 30 | from utils.arguments import mnist_args, cifar_cnn_args, mr_args, sst_args 31 | 32 | from utils.utils import cwd, train_model, evaluate, cosine_similarity, mask_grad_update_by_order, \ 33 | compute_grad_update, add_update_to_model, add_gradient_updates,\ 34 | flatten, unflatten, compute_distance_percentage 35 | 36 | 37 | import argparse 38 | 39 | parser = argparse.ArgumentParser(description='Process which dataset to run') 40 | parser.add_argument('-D', '--dataset', help='Pick the dataset to run.', type=str, required=True) 41 | parser.add_argument('-N', '--n_agents', help='The number of agents.', type=int, default=5) 42 | 43 | parser.add_argument('-nocuda', dest='cuda', help='Not to use cuda even if available.', action='store_false') 44 | parser.add_argument('-cuda', dest='cuda', help='Use cuda if available.', action='store_true') 45 | 46 | 47 | parser.add_argument('-split', '--split', dest='split', help='The type of data splits.', type=str, default='all', choices=['all', 'uni', 'cla', 'pow']) 48 | 49 | cmd_args = parser.parse_args() 50 | 51 | print(cmd_args) 52 | 53 | N = cmd_args.n_agents 54 | 55 | if torch.cuda.is_available() and cmd_args.cuda: 56 | device = torch.device('cuda') 57 | else: 58 | device = torch.device('cpu') 59 | 60 | if cmd_args.dataset == 'mnist': 61 | args = copy.deepcopy(mnist_args) 62 | 63 | if N > 0: 64 | agent_iterations = [[N, N*600]] 65 | else: 66 | agent_iterations = [[5,3000], [10, 6000], [20, 12000]] 67 | 68 | if cmd_args.split == 'uni': 69 | splits = ['uniform'] 70 | 71 | elif cmd_args.split == 'pow': 72 | splits = ['powerlaw'] 73 | 74 | elif cmd_args.split == 'cla': 75 | splits = ['classimbalance'] 76 | 77 | elif cmd_args.split == 'all': 78 | splits = ['uniform', 'powerlaw', 'classimbalance',] 79 | 80 | args['iterations'] = 200 81 | args['E'] = 3 82 | args['lr'] = 1e-3 83 | args['num_classes'] = 10 84 | args['lr_decay'] = 0.955 85 | 86 | elif cmd_args.dataset == 'cifar10': 87 | args = copy.deepcopy(cifar_cnn_args) 88 | 89 | if N > 0: 90 | agent_iterations = 
[[N, N*2000]] 91 | else: 92 | agent_iterations = [[10, 20000]] 93 | 94 | if cmd_args.split == 'uni': 95 | splits = ['uniform'] 96 | 97 | elif cmd_args.split == 'pow': 98 | splits = ['powerlaw'] 99 | 100 | elif cmd_args.split == 'cla': 101 | splits = ['classimbalance'] 102 | 103 | elif cmd_args.split == 'all': 104 | splits = ['uniform', 'powerlaw', 'classimbalance'] 105 | 106 | args['iterations'] = 200 107 | args['E'] = 3 108 | args['num_classes'] = 10 109 | 110 | elif cmd_args.dataset == 'sst': 111 | args = copy.deepcopy(sst_args) 112 | agent_iterations = [[5, 8000]] 113 | splits = ['powerlaw'] 114 | args['iterations'] = 200 115 | args['E'] = 3 116 | args['num_classes'] = 5 117 | 118 | elif cmd_args.dataset == 'mr': 119 | args = copy.deepcopy(mr_args) 120 | agent_iterations = [[5, 8000]] 121 | splits = ['powerlaw'] 122 | args['iterations'] = 200 123 | args['E'] = 3 124 | args['num_classes'] = 2 125 | 126 | 127 | E = args['E'] 128 | 129 | ts = time.time() 130 | time_str = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d-%H:%M') 131 | 132 | for N, sample_size_cap in agent_iterations: 133 | 134 | args.update(vars(cmd_args)) 135 | 136 | 137 | args['n_agents'] = N 138 | args['sample_size_cap'] = sample_size_cap 139 | # args['momentum'] = 1.5 / N 140 | 141 | for beta in [0.5, 1, 1.2, 1.5, 2, 1e7]: 142 | args['beta'] = beta 143 | 144 | for split in splits: 145 | args['split'] = split 146 | 147 | optimizer_fn = args['optimizer_fn'] 148 | loss_fn = args['loss_fn'] 149 | 150 | print(args) 151 | print("Data Split information for the agents:") 152 | data_prepper = Data_Prepper( 153 | args['dataset'], train_batch_size=args['batch_size'], n_agents=N, sample_size_cap=args['sample_size_cap'], 154 | train_val_split_ratio=args['train_val_split_ratio'], device=device, args_dict=args) 155 | 156 | # valid_loader = data_prepper.get_valid_loader() 157 | test_loader = data_prepper.get_test_loader() 158 | 159 | train_loaders = data_prepper.get_train_loaders(N, args['split']) 160 | shard_sizes = data_prepper.shard_sizes 161 | 162 | 163 | # shard sizes refer to the sizes of the local data of each agent 164 | shard_sizes = torch.tensor(shard_sizes).float() 165 | relative_shard_sizes = torch.div(shard_sizes, torch.sum(shard_sizes)) 166 | print("Shard sizes are: ", shard_sizes.tolist()) 167 | 168 | if args['dataset'] in ['mr', 'sst']: 169 | server_model = args['model_fn'](args=data_prepper.args).to(device) 170 | else: 171 | server_model = args['model_fn']().to(device) 172 | 173 | D = sum([p.numel() for p in server_model.parameters()]) 174 | init_backup = dcopy(server_model) 175 | 176 | # ---- init the agents ---- 177 | agent_models, agent_optimizers, agent_schedulers = [], [], [] 178 | 179 | for i in range(N): 180 | model = copy.deepcopy(server_model) 181 | # try: 182 | # optimizer = optimizer_fn(model.parameters(), lr=args['lr'], momentum=args['momentum']) 183 | # except: 184 | 185 | optimizer = optimizer_fn(model.parameters(), lr=args['lr']) 186 | 187 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200, 300], gamma=0.1) 188 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = args['lr_decay']) 189 | 190 | agent_models.append(model) 191 | agent_optimizers.append(optimizer) 192 | agent_schedulers.append(scheduler) 193 | 194 | 195 | # ---- book-keeping variables 196 | 197 | rs_dict, qs_dict = [], [] 198 | rs = torch.zeros(N, device=device) 199 | past_phis = [] 200 | 201 | # for performance analysis 202 | valid_perfs, local_perfs, fed_perfs = defaultdict(list), 
defaultdict(list), defaultdict(list) 203 | 204 | # for gradient/model parameter analysis 205 | dist_all_layer, dist_last_layer = defaultdict(list), defaultdict(list) 206 | reward_all_layer, reward_last_layer = defaultdict(list), defaultdict(list) 207 | 208 | # ---- CML/FL begins ---- 209 | for iteration in range(args['iterations']): 210 | 211 | gradients = [] 212 | for i in range(N): 213 | loader = train_loaders[i] 214 | model = agent_models[i] 215 | optimizer = agent_optimizers[i] 216 | scheduler = agent_schedulers[i] 217 | 218 | model.train() 219 | model = model.to(device) 220 | 221 | backup = copy.deepcopy(model) 222 | 223 | model = train_model(model, loader, loss_fn, optimizer, device=device, E=E, scheduler=scheduler) 224 | 225 | gradient = compute_grad_update(old_model=backup, new_model=model, device=device) 226 | 227 | 228 | # do not top up the local model with its own gradient; restore it from the backup instead 229 | model.load_state_dict(backup.state_dict()) 230 | # add_update_to_model(model, gradient, device=device) 231 | 232 | # append the normalized gradient (rescaled to norm Gamma) 233 | flattened = flatten(gradient) 234 | norm_value = norm(flattened) + 1e-7 # to prevent division by zero 235 | 236 | gradient = unflatten(torch.multiply(torch.tensor(args['Gamma']), torch.div(flattened, norm_value)), gradient) 237 | gradients.append(gradient) 238 | 239 | 240 | # ---- Server Aggregate ---- 241 | 242 | aggregated_gradient = [torch.zeros(param.shape).to(device) for param in server_model.parameters()] 243 | 244 | # aggregate and update server model 245 | 246 | if iteration == 0: 247 | # use FedAvg (shard-size) weighting in the first iteration 248 | weights = torch.div(shard_sizes, torch.sum(shard_sizes)) 249 | else: 250 | weights = rs 251 | 252 | for gradient, weight in zip(gradients, weights): 253 | add_gradient_updates(aggregated_gradient, gradient, weight=weight) 254 | 255 | add_update_to_model(server_model, aggregated_gradient) 256 | 257 | # update reputation and calculate reward gradients 258 | flat_aggre_grad = flatten(aggregated_gradient) 259 | 260 | # phis = torch.zeros(N, device=device) 261 | phis = torch.tensor([F.cosine_similarity(flatten(gradient), flat_aggre_grad, 0, 1e-10) for gradient in gradients], device=device) 262 | past_phis.append(phis) 263 | 264 | rs = args['alpha'] * rs + (1 - args['alpha']) * phis 265 | 266 | rs = torch.clamp(rs, min=1e-3) # keep the reputations strictly positive 267 | rs = torch.div(rs, rs.sum()) # normalize the reputations to sum to 1 268 | 269 | # ---- altruistic degree function: q_i = tanh(beta * r_i), normalized by the maximum ---- 270 | q_ratios = torch.tanh(args['beta'] * rs) 271 | q_ratios = torch.div(q_ratios, torch.max(q_ratios)) 272 | 273 | qs_dict.append(q_ratios) 274 | rs_dict.append(rs) 275 | 276 | 277 | for i in range(N): 278 | 279 | reward_gradient = mask_grad_update_by_order(aggregated_gradient, mask_percentile=q_ratios[i], mode='layer') 280 | 281 | add_update_to_model(agent_models[i], reward_gradient) 282 | 283 | 284 | ''' Analysis of the rewarded gradients: cosine similarity and L2 distance to the aggregated gradient ''' 285 | reward_all_layer[str(i)+'cos'].append(F.cosine_similarity(flatten(reward_gradient), flat_aggre_grad, 0, 1e-10).item()) 286 | reward_all_layer[str(i)+'l2'].append(norm(flatten(reward_gradient) - flat_aggre_grad).item()) 287 | 288 | reward_last_layer[str(i)+'cos'].append(F.cosine_similarity(flatten(reward_gradient[-2]), flatten(aggregated_gradient[-2]), 0, 1e-10).item()) 289 | reward_last_layer[str(i)+'l2'].append(norm(flatten(reward_gradient[-2]) - flatten(aggregated_gradient[-2])).item()) 290 | 291 | 292 | weights = torch.div(shard_sizes, torch.sum(shard_sizes)) if iteration == 0 else rs 293 |
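# ---- Evaluation and book-keeping: for each agent model and the server model, record test-loader performance (valid_perfs), the weighted sum of per-shard losses/accuracies under the current weights (fed_perfs), and each agent's performance on its own shard (local_perfs) ----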
294 | for i, model in enumerate(agent_models + [server_model]): 295 | 296 | loss, accuracy = evaluate(model, test_loader, loss_fn=loss_fn, device=device) 297 | 298 | valid_perfs[str(i)+'_loss'].append(loss.item()) 299 | valid_perfs[str(i)+'_accu'].append(accuracy.item()) 300 | 301 | fed_loss, fed_accu = 0, 0 302 | for j, train_loader in enumerate(train_loaders): 303 | loss, accuracy = evaluate(model, train_loader, loss_fn=loss_fn, device=device) 304 | 305 | fed_loss += weights[j] * loss.item() 306 | fed_accu += weights[j] * accuracy.item() 307 | if j == i: 308 | local_perfs[str(i)+'_loss'].append(loss.item()) 309 | local_perfs[str(i)+'_accu'].append(accuracy.item()) 310 | 311 | fed_perfs[str(i)+'_loss'].append(fed_loss.item()) 312 | fed_perfs[str(i)+'_accu'].append(fed_accu.item()) 313 | 314 | # ---- Record model distance to the server model ---- 315 | for i, model in enumerate(agent_models + [init_backup]) : 316 | 317 | percents, dists = compute_distance_percentage(model, server_model) 318 | 319 | dist_all_layer[str(i)+'dist'].append(np.mean(dists)) 320 | dist_last_layer[str(i)+'dist'].append(dists[-1]) 321 | 322 | dist_all_layer[str(i)+'perc'].append(np.mean(percents)) 323 | dist_last_layer[str(i)+'perc'].append(percents[-1]) 324 | 325 | 326 | # Saving results, into csvs 327 | agent_str = '{}-{}'.format(args['split'][:3].upper(), 'A'+str(N), ) 328 | 329 | folder = oj('RESULTS', args['dataset'], time_str, agent_str, 330 | 'beta-{}'.format(str(args['beta'])[:4]) ) 331 | 332 | os.makedirs(folder, exist_ok=True) 333 | 334 | with cwd(folder): 335 | 336 | # distance to the full gradient: all layers and only last layer of the model parameters 337 | pd.DataFrame(reward_all_layer).to_csv(('all_layer.csv'), index=False) 338 | 339 | pd.DataFrame(reward_last_layer).to_csv(('last_layer.csv'), index=False) 340 | 341 | # distance to server model parameters: all layers and only last layer of the model parameters 342 | pd.DataFrame(dist_all_layer).to_csv(('dist_all_layer.csv'), index=False) 343 | 344 | pd.DataFrame(dist_last_layer).to_csv(('dist_last_layer.csv'), index=False) 345 | 346 | 347 | # importance coefficients rs 348 | rs_dict = torch.stack(rs_dict).detach().cpu().numpy() 349 | df = pd.DataFrame(rs_dict) 350 | df.to_csv(('rs.csv'), index=False) 351 | 352 | # q values 353 | qs_dict = torch.stack(qs_dict).detach().cpu().numpy() 354 | df = pd.DataFrame(qs_dict) 355 | df.to_csv(('qs.csv'), index=False) 356 | 357 | # federated performance (local objectives weighted w.r.t the importance coefficient rs) 358 | df = pd.DataFrame(fed_perfs) 359 | df.to_csv(('fed.csv'), index=False) 360 | 361 | # validation performance 362 | df = pd.DataFrame(valid_perfs) 363 | df.to_csv(('valid.csv'), index=False) 364 | 365 | # local performance (only on local training set) 366 | df = pd.DataFrame(local_perfs) 367 | df.to_csv(('local.csv'), index=False) 368 | 369 | # store settings 370 | with open(('settings_dict.txt'), 'w') as file: 371 | [file.write(key + ' : ' + str(value) + '\n') for key, value in args.items()] 372 | 373 | with open(('settings_dict.pickle'), 'wb') as f: 374 | pickle.dump(args, f) --------------------------------------------------------------------------------