├── unbalanced_loss ├── __init__.py ├── label_smoothing.py ├── weight_ce_loss.py ├── GHM_loss.py ├── dice_loss_nlp.py ├── focal_loss.py └── dice_loss.py ├── out ├── acc.png ├── eval_loss.png ├── train_loss.png ├── train_loss_acc.png ├── ATTypePGD_UseATTrue.csv ├── ATTypeFGSM_UseATTrue.csv ├── ATTypeFGM_UseATTrue.csv ├── ATTypeFGM_UseATFalse.csv ├── ATTypeFreeAT_UseATTrue_epsilon0.8.csv └── plot_pic.py ├── bert_model ├── __pycache__ │ ├── model.cpython-38.pyc │ └── dataloader.cpython-38.pyc ├── model.py ├── dataloader.py └── train.py ├── adversarial_training ├── __pycache__ │ ├── FGM.cpython-38.pyc │ ├── FGSM.cpython-38.pyc │ ├── PGD.cpython-38.pyc │ └── FreeAT.cpython-38.pyc ├── README.md ├── FGSM.py ├── FGM.py ├── FreeAT.py └── PGD.py ├── data └── THUCNews │ └── news │ └── class.txt ├── scripts └── run_at.sh ├── LICENSE ├── test_loss.py ├── README.md └── test_bert.py /unbalanced_loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /out/acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/out/acc.png -------------------------------------------------------------------------------- /out/eval_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/out/eval_loss.png -------------------------------------------------------------------------------- /out/train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/out/train_loss.png -------------------------------------------------------------------------------- /out/train_loss_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/out/train_loss_acc.png -------------------------------------------------------------------------------- /bert_model/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/bert_model/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /bert_model/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/bert_model/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /adversarial_training/__pycache__/FGM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/adversarial_training/__pycache__/FGM.cpython-38.pyc -------------------------------------------------------------------------------- /adversarial_training/__pycache__/FGSM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/adversarial_training/__pycache__/FGSM.cpython-38.pyc -------------------------------------------------------------------------------- /adversarial_training/__pycache__/PGD.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/adversarial_training/__pycache__/PGD.cpython-38.pyc -------------------------------------------------------------------------------- /adversarial_training/__pycache__/FreeAT.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyi-code/NLP-Loss-Pytorch/HEAD/adversarial_training/__pycache__/FreeAT.cpython-38.pyc -------------------------------------------------------------------------------- /data/THUCNews/news/class.txt: -------------------------------------------------------------------------------- 1 | finance 0 2 | realty 1 3 | stocks 2 4 | education 3 5 | science 6 | society 5 7 | politics 6 8 | sports 7 9 | game 8 10 | entertainment 9 -------------------------------------------------------------------------------- /scripts/run_at.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | cd .. 3 | 4 | python test_bert.py --AT_type FGM --use_attack 1 --epoch 32 5 | 6 | #echo "train FreeAT" 7 | #python test_bert.py --AT_type FreeAT --use_attack 1 --epoch 32 8 | -------------------------------------------------------------------------------- /unbalanced_loss/label_smoothing.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class LabelSmoothingCrossEntropy(nn.Module): 5 | def __init__(self, eps=0.1, reduction='mean',ignore_index=-100): 6 | super(LabelSmoothingCrossEntropy, self).__init__() 7 | self.eps = eps 8 | self.reduction = reduction 9 | self.ignore_index = ignore_index 10 | 11 | def forward(self, output, target): 12 | c = output.size()[-1] 13 | log_preds = F.log_softmax(output, dim=-1) 14 | if self.reduction=='sum': 15 | loss = -log_preds.sum() 16 | else: 17 | loss = -log_preds.sum(dim=-1) 18 | if self.reduction=='mean': 19 | loss = loss.mean() 20 | return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction, 21 | ignore_index=self.ignore_index) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 shuxinyin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /adversarial_training/README.md: -------------------------------------------------------------------------------- 1 | Implementation of adversarial training methods such as FGM, FGSM, PGD, and FreeAT. 2 | 3 | ### Quick Start 4 | 5 | You can find a simple demo for bert classification in test_bert.py. 6 | You can run it easily with the script below. 7 | > cd scripts 8 | > sh run_at.sh 9 | 10 | 11 | 12 | Here is a simple demo of usage: 13 | You just need to rewrite the train function in PGD.py to match your model's inputs, then you can use adversarial 14 | training as below. 15 | 16 | ```python 17 | import transformers 18 | from model import bert_classification 19 | from adversarial_training.PGD import PGD 20 | 21 | batch_size, num_class = 64, 10 22 | # model = your_model() 23 | model = bert_classification() 24 | AT_Model = PGD(model) 25 | optimizer = transformers.AdamW(model.parameters(), lr=0.001) 26 | 27 | # rewrite your train function in pgd.py 28 | outputs, loss = AT_Model.train_bert(token, segment, mask, label, optimizer) 29 | ``` 30 | 31 | ### Adversarial Training Results Comparison 32 | ![acc](../out/train_loss_acc.png) 33 | 34 | | Adversarial Training | Time Cost(s/epoch ) | best_acc | 35 | |:----------------------:|:-------------------:|:--------:| 36 | | Normal(not add attack) | 23.77 | 0.773 | 37 | | FGSM | 45.95 | 0.7936 | 38 | | FGM | 47.28 | 0.8008 | 39 | | PGD(k=3) | 87.50 | 0.7963 | 40 | | FreeAT(k=3) | 93.26 | 0.7896 | 41 | 42 | 43 | 44 | ### Reference 45 | - https://github.com/locuslab/fast_adversarial 46 | - https://github.com/ashafahi/free_adv_train -------------------------------------------------------------------------------- /adversarial_training/FGSM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class FGSM: 6 | def __init__(self, model, epsilon=0.05, emb_name='embedding.'): 7 | self.model = model 8 | self.emb_backup = {} 9 | self.epsilon = epsilon 10 | self.emb_name = emb_name 11 | 12 | def train(self, input_data, labels, optimizer): 13 | ''' define your training process here according to your model definition 14 | ''' 15 | pass 16 | 17 | def train_bert(self, token, segment, mask, labels, optimizer, attack=False): 18 | ''' add disturbance during training 19 | ''' 20 | outputs = self.model(token, segment, mask) 21 | loss = F.cross_entropy(outputs, labels) 22 | loss.backward() 23 | 24 | if attack: 25 | self.attack_embedding() 26 | outputs = self.model(token, segment, mask) 27 | loss = F.cross_entropy(outputs, labels) 28 | # self.model.zero_grad() # compute adversarial samples' grad only 29 | loss.backward() 30 | self.restore_embedding() # recover 31 | optimizer.step() 32 | self.model.zero_grad() 33 | 34 | return outputs, loss 35 | 36 | def attack_param(self, name, param): 37 | # FGSM perturbation: r_at = epsilon * sign(grad) 38 | r_at = self.epsilon * torch.sign(param.grad) 39 | param.data.add_(r_at) 40 | 41 | def attack_embedding(self, backup=True): 42 | for name, param in self.model.named_parameters(): 43 | if param.requires_grad and self.emb_name in name: 44 | if backup: 45 | self.emb_backup[name] = param.data.clone() 46 | # attack embedding 47 | self.attack_param(name, param) 48 | 49 | def restore_embedding(self): 50 | for name, param in self.model.named_parameters(): 51 | if param.requires_grad and self.emb_name in name: 52 | assert name in self.emb_backup 53 | param.data = self.emb_backup[name] 54 |
self.emb_backup = {} 55 | -------------------------------------------------------------------------------- /out/ATTypePGD_UseATTrue.csv: -------------------------------------------------------------------------------- 1 | train_loss time acc eval_loss 2 | 2.237075458599042 83.45233535766602 0.4976 1.4369977772425695 3 | 1.6613552351541157 86.46612024307251 0.6631 0.9831372532685092 4 | 1.3488045556132804 87.38610363006592 0.7244 0.8744084649974373 5 | 1.1552119787923898 88.64773416519165 0.7576 0.7677854638380609 6 | 1.0182921078386187 87.31404995918274 0.7629 0.7842189069884787 7 | 0.9173169383771309 87.32622718811035 0.7696 0.7894695585795269 8 | 0.8316184749816684 87.33676671981812 0.774 0.773496703737101 9 | 0.7675504596738876 87.3484354019165 0.7761 0.8000537196447135 10 | 0.7121977768995591 87.47815442085266 0.7763 0.8553681345597194 11 | 0.6644578330501726 87.6738555431366 0.7865 0.7777208091726728 12 | 0.6247536281190204 87.81615138053894 0.7807 0.8440746126850699 13 | 0.5896623820176482 87.7217128276825 0.7746 0.8983549268762018 14 | 0.5582099286467562 87.4314775466919 0.7836 0.895342266673495 15 | 0.5300630825499298 87.53213310241699 0.7503 1.016948850861021 16 | 0.5038996183859648 87.72018194198608 0.7893 0.9250938386484316 17 | 0.48075908059298494 87.82800507545471 0.783 0.9435146401642235 18 | 0.460679778546977 87.96218657493591 0.7864 0.9188103512593895 19 | 0.44015228159466313 88.0583381652832 0.7857 1.0289483816380713 20 | 0.42164058785334596 88.03025031089783 0.7854 1.1166192225779696 21 | 0.4053912778520556 87.56653475761414 0.7917 1.0436718155434177 22 | 0.389954097527447 87.44577479362488 0.776 1.0933413016283589 23 | 0.37556892122014485 87.61701512336731 0.7785 1.1371298162325931 24 | 0.3618065541165219 87.6816611289978 0.7963 1.0783061248956212 25 | 0.34923851050737115 87.74892210960388 0.7756 1.1533686951941746 26 | 0.3375953037492766 87.85224175453186 0.7732 1.225065349013942 27 | 0.32733632383230815 87.96056890487671 0.7753 1.2293367929946466 28 | 0.31868766858481506 87.80440878868103 0.7674 1.237078681422077 29 | 0.3101716200717633 87.6871395111084 0.784 1.1537967795019697 30 | 0.302003020973783 87.3574857711792 0.779 1.1975945876851963 31 | 0.29463573093865 87.52284073829651 0.7779 1.2251497156870592 32 | 0.28663633111621895 87.64569282531738 0.7809 1.2975429419404383 33 | 0.27872003423429687 87.65008163452148 0.771 1.463192546823222 34 | -------------------------------------------------------------------------------- /out/ATTypeFGSM_UseATTrue.csv: -------------------------------------------------------------------------------- 1 | train_loss time acc eval_loss 2 | 2.6776645455179335 43.543582916259766 0.2168 2.080889650211213 3 | 2.1576888327357135 43.04824161529541 0.6264 1.1502349625347525 4 | 1.7534638223265797 44.00109505653381 0.6985 0.9455733544127957 5 | 1.4671906117774263 44.810537576675415 0.7146 0.9047348200325753 6 | 1.2631667987455295 45.47444438934326 0.7756 0.7439963895424156 7 | 1.1133458390203206 46.74590706825256 0.7864 0.7467929955310882 8 | 0.9959796576156107 45.82074570655823 0.785 0.790090296251379 9 | 0.9044466390262677 49.288307905197144 0.7804 0.7779491857928076 10 | 0.829159151495509 46.84054517745972 0.7769 0.8330396671014227 11 | 0.7656418678126757 47.686845779418945 0.7873 0.8556257359161499 12 | 0.7117522867274299 46.82540941238403 0.7842 0.8272688874773159 13 | 0.6663479475689589 46.40801930427551 0.7824 0.8833960617423817 14 | 0.6262950927269406 45.21188950538635 0.7856 0.8912214768730151 15 | 0.5908075425046316 44.69367003440857 
0.7819 0.947202683254412 16 | 0.5606648896251046 45.13036489486694 0.7562 1.0995947938815804 17 | 0.5340527448437731 44.976510524749756 0.7868 0.9588465116870631 18 | 0.5085824815014118 45.38303351402283 0.7617 1.1317627841879607 19 | 0.48795551810748927 44.66726350784302 0.7828 1.0526206342467836 20 | 0.4677473636906219 44.31538391113281 0.7778 1.1026592401752047 21 | 0.4480925619407687 44.40354347229004 0.7936 1.0769182444567893 22 | 0.4319763340391645 45.88613533973694 0.7589 1.2406504916584795 23 | 0.4182341055932773 46.918251514434814 0.7902 0.9706501920891416 24 | 0.40315868800796106 46.95930480957031 0.7636 1.2297042566499892 25 | 0.3904574846358624 45.71096181869507 0.7876 1.0571456545857107 26 | 0.3791653635234961 46.28270173072815 0.7761 1.0189071936402352 27 | 0.3675311717001618 45.81886649131775 0.771 1.2824938833523707 28 | 0.3561834233945722 44.58826923370361 0.7905 1.17407976650888 29 | 0.34567984684381403 43.96854019165039 0.7776 1.2393342770968274 30 | 0.3362837299618777 48.10841608047485 0.7729 1.2981403264080642 31 | 0.32837205010364107 48.99410319328308 0.7662 1.1172230958369127 32 | 0.3203710058604003 48.88341784477234 0.7791 1.2156062625396025 33 | 0.31274655728900796 49.09944534301758 0.779 1.2289809419470987 34 | -------------------------------------------------------------------------------- /out/ATTypeFGM_UseATTrue.csv: -------------------------------------------------------------------------------- 1 | train_loss time acc eval_loss 2 | 2.3333273960065237 43.381364822387695 0.4531 1.562754148510611 3 | 1.6886588534976863 45.142308950424194 0.7188 0.8847488274999485 4 | 1.340206026527952 45.33142638206482 0.7504 0.7940964020171742 5 | 1.123123235434671 45.80958271026611 0.7684 0.7564992623723996 6 | 0.974935200357739 45.8135244846344 0.7689 0.7668718708928224 7 | 0.8641455148028422 45.877368211746216 0.7825 0.7398840265858705 8 | 0.7777992735396242 46.61386275291443 0.7724 0.805544844858206 9 | 0.7073700550967191 46.75735831260681 0.7786 0.859818220660565 10 | 0.6498766840028193 46.9925901889801 0.7819 0.838290487125421 11 | 0.6043911423958555 47.14177656173706 0.7777 0.8826836675974974 12 | 0.5633314667811876 47.32115578651428 0.7699 0.9750690354852919 13 | 0.5294293419847006 47.71417284011841 0.7879 0.8509429387131314 14 | 0.4980318044414771 47.46504998207092 0.7766 0.9951246199524326 15 | 0.47300977116049847 47.504666328430176 0.7515 1.0647146940990617 16 | 0.4519219266022811 47.380953311920166 0.7899 0.9285912260318258 17 | 0.4299107054528438 47.50248718261719 0.7685 1.1504604248863877 18 | 0.41094699925143197 48.092716217041016 0.8008 0.9502677745215452 19 | 0.39239888204557694 47.53448152542114 0.7837 1.109953431756633 20 | 0.3776875585202825 47.919201374053955 0.7996 0.9209640406451787 21 | 0.36211077737639646 48.20289444923401 0.7973 1.0523225172396133 22 | 0.34824701532527286 47.954283714294434 0.7944 1.0328755428789147 23 | 0.3347199469811395 48.15190672874451 0.7952 1.103152217782417 24 | 0.32164264471615417 48.34456658363342 0.7778 1.4075998105820577 25 | 0.31149850223790415 48.07316446304321 0.7798 1.097483038048076 26 | 0.3019897498661517 48.11297607421875 0.7889 1.0612573566176235 27 | 0.2920884113673248 48.16046500205994 0.7875 1.2591359905281645 28 | 0.28304781502880827 48.16928839683533 0.7955 1.1827917640945713 29 | 0.2752550704100437 47.80799746513367 0.7703 1.1690012692076386 30 | 0.26825296289025763 48.329411029815674 0.7882 1.1108314302886368 31 | 0.2619254907041141 48.22082877159119 0.7791 1.1114299016393674 32 | 0.25665846668476294 
47.98287582397461 0.7854 1.1205626665406925 33 | 0.2506976429083208 48.243019580841064 0.7797 1.1829084668094945 34 | -------------------------------------------------------------------------------- /out/ATTypeFGM_UseATFalse.csv: -------------------------------------------------------------------------------- 1 | train_loss time acc eval_loss 2 | 2.5656534086299847 23.005752563476562 0.3932 1.6760724541867615 3 | 1.9351202026952672 23.912715435028076 0.6035 1.0626736604104376 4 | 1.580022838548266 23.773364543914795 0.7077 0.8821596322925227 5 | 1.355501955068564 23.608240604400635 0.7047 0.9010835878408638 6 | 1.1979745381240603 23.520110607147217 0.7261 0.8716525065291459 7 | 1.068298367382605 23.3923122882843 0.7503 0.868983195560753 8 | 0.9726326460486941 23.329983472824097 0.7556 0.8179388191479786 9 | 0.8911157061310508 23.332130432128906 0.7498 0.9097628469110295 10 | 0.8225903935866684 23.337834358215332 0.773 0.8072677204373536 11 | 0.7679196563325351 23.29317569732666 0.7691 0.8647188866973683 12 | 0.7212425871088809 23.196439027786255 0.7336 0.9963459839486772 13 | 0.6825848231602948 23.962465286254883 0.7365 1.0902006414001155 14 | 0.6453348279594214 24.248406410217285 0.7456 0.9555637108007814 15 | 0.6122434359939792 23.919344186782837 0.7628 0.9786352657588424 16 | 0.5807945935456436 23.065471410751343 0.7718 0.998969958751065 17 | 0.5528098571606428 23.035010814666748 0.7612 1.0639447281315069 18 | 0.5277961964551873 24.662909746170044 0.7324 1.1418056264519691 19 | 0.5068799018500957 24.89802885055542 0.7422 1.227999315567457 20 | 0.48611777831553954 24.654135704040527 0.7593 1.2218825145132224 21 | 0.46596512937729684 24.868656158447266 0.749 1.2956563744955003 22 | 0.44973682069442517 24.97143530845642 0.7579 1.1449188509374668 23 | 0.43389151174300733 24.446916341781616 0.7276 1.2804629438242334 24 | 0.41870288518869947 24.8124737739563 0.7543 1.1945647516637852 25 | 0.40431129785877445 24.63015866279602 0.7632 1.2685521355100497 26 | 0.3924142614285214 24.186723947525024 0.7636 1.139045369188497 27 | 0.3797581068778768 23.28633403778076 0.7675 1.235060898931163 28 | 0.3677599838822626 23.291670083999634 0.7626 1.3202310614999693 29 | 0.3571285070088535 23.234180212020874 0.7678 1.2097944138915675 30 | 0.34675005609557824 23.10203742980957 0.7691 1.2908580456000225 31 | 0.337177728058452 23.033085823059082 0.7499 1.4373560090353534 32 | 0.32958440354317137 23.511265754699707 0.7589 1.2843477542802786 33 | 0.32164161259618285 23.1623797416687 0.7618 1.3510286828894524 34 | -------------------------------------------------------------------------------- /out/ATTypeFreeAT_UseATTrue_epsilon0.8.csv: -------------------------------------------------------------------------------- 1 | train_loss time acc eval_loss 2 | 0.7558265738660777 86.75020909309387 0.6616 1.1887983835426865 3 | 0.5460665085930613 90.25732851028442 0.7638 0.7707689756610591 4 | 0.4369606377816276 90.43820118904114 0.7769 0.8158577247789711 5 | 0.36862465056083815 90.64329481124878 0.7588 0.9334954178067529 6 | 0.32489433801155304 96.03342795372009 0.746 0.9735628950178243 7 | 0.29432850834453783 90.92144107818604 0.7617 1.0137013402903916 8 | 0.27164953302720035 91.15593218803406 0.7896 0.8527051932682657 9 | 0.2513386958051779 91.28371262550354 0.7702 1.0220181503493315 10 | 0.23519230696324223 91.30850648880005 0.7816 0.9527276417916748 11 | 0.22146122558867629 91.34853172302246 0.7657 1.0251696777951187 12 | 0.20988262213672262 91.79116678237915 0.781 0.9523702913029178 13 | 0.20045320112808343 
95.76618218421936 0.7184 1.4448095213810566 14 | 0.1912665646015408 96.50476932525635 0.7842 0.955566512076718 15 | 0.1830400047684004 97.424471616745 0.7721 1.0058023301279468 16 | 0.17685886710284882 97.88513207435608 0.783 1.0016688990175344 17 | 0.17019966005530396 97.54570651054382 0.7676 1.112565981449595 18 | 0.16316204084040442 96.99987554550171 0.6892 1.6611272329167954 19 | 0.15728639387171953 96.4523434638977 0.7713 1.1483741025588694 20 | 0.152044323816378 97.82581353187561 0.7713 1.1751926061549005 21 | 0.14711767479915855 96.38065028190613 0.7265 1.506433735607536 22 | 0.14258787206417897 98.03607940673828 0.7521 1.3868515919308195 23 | 0.1394311841074924 91.77282738685608 0.788 1.0450309007339607 24 | 0.13583591310284318 91.3764922618866 0.7694 1.1765143142443648 25 | 0.13259396613744984 91.27464985847473 0.7621 1.2282011441554233 26 | 0.12869091514444944 91.43755221366882 0.7731 1.324908080564183 27 | 0.12520547374962135 95.51044940948486 0.7623 1.3370983962230623 28 | 0.12255265266880999 92.1047670841217 0.7095 1.7447330920844322 29 | 0.11973198762275988 91.8145592212677 0.7831 1.2829045978416302 30 | 0.1163079629079958 91.77716183662415 0.7531 1.6614306874715599 31 | 0.11348100599445027 91.54884028434753 0.7686 1.201888377784164 32 | 0.11039214750679646 91.56718468666077 0.7301 1.9073487332766983 33 | 0.10778030354239364 91.60365056991577 0.7868 1.283133428567534 34 | -------------------------------------------------------------------------------- /unbalanced_loss/weight_ce_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class WBCEWithLogitLoss(nn.Module): 6 | """ 7 | Weighted Binary Cross Entropy. 8 | `WBCE(p,t)=-β*t*log(p)-(1-t)*log(1-p)` 9 | To decrease the number of false negatives, set β>1. 10 | To decrease the number of false positives, set β<1. 
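For example, with β=2 a positive example (t=1) contributes -2*log(p), so a missed positive is penalised twice as heavily as under the unweighted BCE.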
11 | Args: 12 | @param weight: positive sample weight 13 | Shapes: 14 | output: A tensor of shape [N, 1,(d,), h, w] without sigmoid activation function applied 15 | target: A tensor of the same shape as output 16 | """ 17 | 18 | def __init__(self, weight=1.0, ignore_index=None, reduction='mean'): 19 | super(WBCEWithLogitLoss, self).__init__() 20 | assert reduction in ['none', 'mean', 'sum'] 21 | self.ignore_index = ignore_index 22 | weight = float(weight) 23 | self.weight = weight 24 | self.reduction = reduction 25 | self.smooth = 0.01 26 | 27 | def forward(self, output, target): 28 | assert output.shape[0] == target.shape[0], "output & target batch size don't match" 29 | 30 | if self.ignore_index is not None: 31 | valid_mask = (target != self.ignore_index).float() 32 | output = output.mul(valid_mask) # can not use inplace for bp 33 | target = target.float().mul(valid_mask) 34 | 35 | batch_size = output.size(0) 36 | output = output.view(batch_size, -1) 37 | target = target.view(batch_size, -1) 38 | 39 | output = torch.sigmoid(output) 40 | # avoid `nan` loss 41 | eps = 1e-6 42 | output = torch.clamp(output, min=eps, max=1.0 - eps) 43 | # soft label 44 | target = torch.clamp(target, min=self.smooth, max=1.0 - self.smooth) 45 | 46 | # loss = self.bce(output, target) 47 | loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output))) 48 | if self.reduction == 'mean': 49 | loss = torch.mean(loss) 50 | elif self.reduction == 'sum': 51 | loss = torch.sum(loss) 52 | elif self.reduction == 'none': 53 | loss = loss 54 | else: 55 | raise NotImplementedError 56 | return loss 57 | -------------------------------------------------------------------------------- /adversarial_training/FGM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class FGM: 6 | def __init__(self, model, emb_name='embedding.'): 7 | self.model = model 8 | self.emb_backup = {} # restore embedding parameters 9 | self.epsilon = 1.0 10 | self.emb_name = emb_name 11 | 12 | def train(self, input_data, labels, optimizer): 13 | ''' define your training process here according to your model definition 14 | ''' 15 | pass 16 | 17 | def train_bert(self, token, segment, mask, labels, optimizer, attack=False): 18 | ''' an adversarial training demo for bert 19 | ''' 20 | outputs = self.model(token, segment, mask) 21 | loss = F.cross_entropy(outputs, labels) 22 | loss.backward() 23 | 24 | if attack: 25 | self.attack_embedding() 26 | outputs = self.model(token, segment, mask) 27 | loss = F.cross_entropy(outputs, labels) 28 | # self.model.zero_grad() # compute adversarial samples' grad only 29 | loss.backward() 30 | self.restore_embedding() # recover 31 | optimizer.step() 32 | self.model.zero_grad() 33 | 34 | return outputs, loss 35 | 36 | def attack_embedding(self, backup=True): 37 | ''' add disturbance to the embedding layer you want 38 | ''' 39 | for name, param in self.model.named_parameters(): 40 | if param.requires_grad and self.emb_name in name: 41 | if backup: # store parameter 42 | self.emb_backup[name] = param.data.clone() 43 | 44 | self._add_disturbance(param) # add disturbance 45 | 46 | def restore_embedding(self): 47 | '''recover the embedding backed up before 48 | ''' 49 | for name, param in self.model.named_parameters(): 50 | if param.requires_grad and self.emb_name in name: 51 | assert name in self.emb_backup 52 | param.data = self.emb_backup[name] 53 | self.emb_backup = {} 54 | 55 | def _add_disturbance(self,
param): 56 | ''' add disturbance 57 | ''' 58 | norm = torch.norm(param.grad) 59 | if norm != 0: 60 | r_at = self.epsilon * param.grad / norm 61 | param.data.add_(r_at) 62 | -------------------------------------------------------------------------------- /out/plot_pic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import pandas as pd 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def read_data(file): 9 | df = pd.read_csv(file, header=0, sep='\t', encoding='utf-8') 10 | return df 11 | 12 | 13 | def get_time(df): 14 | time_list = df['time'].to_list() 15 | return np.mean(time_list) 16 | 17 | 18 | def draw_plot(l1, l2, l3, l4, l5, name='train_loss', location="upper right"): 19 | plt.plot(l1) 20 | plt.plot(l2) 21 | plt.plot(l3) 22 | plt.plot(l4) 23 | plt.plot(l5) 24 | legend = plt.legend(['Normal', 'FGSM', 'FGM', 'PGD', 'FreeAT'], title=name, loc=location) 25 | 26 | 27 | if __name__ == '__main__': 28 | file_dic = {"Normal": "ATTypeFGM_UseATFalse.csv", 29 | "FGSM": "ATTypeFGSM_UseATTrue.csv", 30 | "FGM": "ATTypeFGM_UseATTrue.csv", 31 | "PGD": "ATTypePGD_UseATTrue.csv", 32 | "FreeAT": "ATTypeFreeAT_UseATTrue_epsilon0.8.csv"} 33 | 34 | df_normal = read_data(file_dic['Normal']) 35 | df_fgsm = read_data(file_dic['FGSM']) 36 | df_fgm = read_data(file_dic['FGM']) 37 | df_pgd = read_data(file_dic['PGD']) 38 | df_freeat = read_data(file_dic['FreeAT']) 39 | print(df_normal.head(3)) 40 | 41 | t1, t2, t3, t4, t5 = get_time(df_normal), get_time(df_fgsm), \ 42 | get_time(df_fgm), get_time(df_pgd), \ 43 | get_time(df_freeat) 44 | print(t1, t2, t3, t4, t5) 45 | 46 | ax1 = plt.subplot(1, 2, 1) 47 | train_loss1 = df_normal['train_loss'].to_list() 48 | train_loss2 = df_fgsm['train_loss'].to_list() 49 | train_loss3 = df_fgm['train_loss'].to_list() 50 | train_loss4 = df_pgd['train_loss'].to_list() 51 | train_loss5 = df_freeat['train_loss'].to_list() 52 | 53 | draw_plot(train_loss1, train_loss2, train_loss3, train_loss4, train_loss5, name='train_loss') 54 | 55 | # eval_loss1 = df_normal['eval_loss'].to_list() 56 | # eval_loss2 = df_fgsm['eval_loss'].to_list() 57 | # eval_loss3 = df_fgm['eval_loss'].to_list() 58 | # eval_loss4 = df_pgd['eval_loss'].to_list() 59 | # eval_loss5 = df_freeat['eval_loss'].to_list() 60 | # draw_plot(eval_loss1, eval_loss2, eval_loss3, eval_loss4, eval_loss5, name='eval_loss') 61 | ax2 = plt.subplot(1, 2, 2) 62 | acc1 = df_normal['acc'].to_list() 63 | acc2 = df_fgsm['acc'].to_list() 64 | acc3 = df_fgm['acc'].to_list() 65 | acc4 = df_pgd['acc'].to_list() 66 | acc5 = df_freeat['acc'].to_list() 67 | draw_plot(acc1, acc2, acc3, acc4, acc5, name='acc', location="lower right") 68 | print(max(acc1), max(acc2), max(acc3), max(acc4), max(acc5)) 69 | plt.show() 70 | -------------------------------------------------------------------------------- /unbalanced_loss/GHM_loss.py: -------------------------------------------------------------------------------- 1 | # some code reforenced from https://github.com/DHPO/GHM_Loss.pytorch 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GHM_Loss(nn.Module): 9 | def __init__(self, bins=10, alpha=0.5): 10 | ''' 11 | bins: split to n bins 12 | alpha: hyper-parameter 13 | ''' 14 | super(GHM_Loss, self).__init__() 15 | self._bins = bins 16 | self._alpha = alpha 17 | self._last_bin_count = None 18 | 19 | def _g2bin(self, g): 20 | return torch.floor(g * (self._bins - 0.0001)).long() 21 | 22 | def _custom_loss(self, x, target, 
weight): 23 | raise NotImplementedError 24 | 25 | def _custom_loss_grad(self, x, target): 26 | raise NotImplementedError 27 | 28 | def forward(self, x, target): 29 | g = torch.abs(self._custom_loss_grad(x, target)).detach() 30 | 31 | bin_idx = self._g2bin(g) 32 | 33 | bin_count = torch.zeros((self._bins)) 34 | for i in range(self._bins): 35 | bin_count[i] = (bin_idx == i).sum().item() 36 | 37 | N = (x.size(0) * x.size(1)) 38 | 39 | if self._last_bin_count is None: 40 | self._last_bin_count = bin_count 41 | else: 42 | bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count 43 | self._last_bin_count = bin_count 44 | 45 | nonempty_bins = (bin_count > 0).sum().item() 46 | 47 | gd = bin_count * nonempty_bins 48 | gd = torch.clamp(gd, min=0.0001) 49 | beta = N / gd 50 | 51 | return self._custom_loss(x, target, beta[bin_idx]) 52 | 53 | 54 | class GHMC_Loss(GHM_Loss): 55 | ''' 56 | GHM_Loss for classification 57 | ''' 58 | 59 | def __init__(self, bins, alpha): 60 | super(GHMC_Loss, self).__init__(bins, alpha) 61 | 62 | def _custom_loss(self, x, target, weight): 63 | return F.binary_cross_entropy_with_logits(x, target, weight=weight) 64 | 65 | def _custom_loss_grad(self, x, target): 66 | return torch.sigmoid(x).detach() - target 67 | 68 | 69 | class GHMR_Loss(GHM_Loss): 70 | ''' 71 | GHM_Loss for regression 72 | ''' 73 | 74 | def __init__(self, bins, alpha, mu): 75 | super(GHMR_Loss, self).__init__(bins, alpha) 76 | self._mu = mu 77 | 78 | def _custom_loss(self, x, target, weight): 79 | d = x - target 80 | mu = self._mu 81 | loss = torch.sqrt(d * d + mu * mu) - mu 82 | N = x.size(0) * x.size(1) 83 | return (loss * weight).sum() / N 84 | 85 | def _custom_loss_grad(self, x, target): 86 | d = x - target 87 | mu = self._mu 88 | return d / torch.sqrt(d * d + mu * mu) 89 | -------------------------------------------------------------------------------- /bert_model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | BertLayerNorm = torch.nn.LayerNorm 5 | 6 | 7 | class MultiClass(nn.Module): 8 | """ text processed by bert model encode and get cls vector for multi classification 9 | """ 10 | 11 | def __init__(self, bert_encode_model, model_config, num_classes=10, pooling_type='first-last-avg'): 12 | super(MultiClass, self).__init__() 13 | self.bert = bert_encode_model 14 | self.num_classes = num_classes 15 | self.fc = nn.Linear(model_config.hidden_size, num_classes) 16 | self.pooling = pooling_type 17 | self.dropout = nn.Dropout(model_config.hidden_dropout_prob) 18 | self.layer_norm = BertLayerNorm(model_config.hidden_size) 19 | 20 | def forward(self, batch_token, batch_segment, batch_attention_mask): 21 | out = self.bert(batch_token, 22 | attention_mask=batch_attention_mask, 23 | token_type_ids=batch_segment, 24 | output_hidden_states=True) 25 | 26 | if self.pooling == 'cls': 27 | out = out.last_hidden_state[:, 0, :] # [batch, 768] 28 | elif self.pooling == 'pooler': 29 | out = out.pooler_output # [batch, 768] 30 | elif self.pooling == 'last-avg': 31 | last = out.last_hidden_state.transpose(1, 2) # [batch, 768, seqlen] 32 | out = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1) # [batch, 768] 33 | elif self.pooling == 'first-last-avg': 34 | first = out.hidden_states[1].transpose(1, 2) # [batch, 768, seqlen] 35 | last = out.hidden_states[-1].transpose(1, 2) # [batch, 768, seqlen] 36 | first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1) # [batch, 768] 37 | 
last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1) # [batch, 768] 38 | avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1) # [batch, 2, 768] 39 | out = torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1) # [batch, 768] 40 | else: 41 | raise "should define pooling type first!" 42 | 43 | out = self.layer_norm(out) 44 | out = self.dropout(out) 45 | out_fc = self.fc(out) 46 | return out_fc 47 | 48 | 49 | if __name__ == '__main__': 50 | path = "/data/Learn_Project/Backup_Data/bert_chinese" 51 | MultiClassModel = MultiClass 52 | # MultiClassModel = BertForMultiClassification 53 | multi_classification_model = MultiClassModel.from_pretrained(path, num_classes=10) 54 | if hasattr(multi_classification_model, 'bert'): 55 | print("-------------------------------------------------") 56 | else: 57 | print("**********************************************") 58 | -------------------------------------------------------------------------------- /adversarial_training/FreeAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class FreeAT: 6 | def __init__(self, model, epsilon=0.8, k=3, emb_name='embedding.'): 7 | self.model = model 8 | self.emb_backup = {} 9 | self.epsilon = epsilon 10 | self.K = k # attack times 11 | self.emb_name = emb_name # embedding layer name want to attack 12 | self.backup_emb() 13 | 14 | def train(self, input_data, labels, optimizer): 15 | ''' define process of training here according to your model define 16 | ''' 17 | pass 18 | 19 | def train_bert(self, token, segment, mask, labels, optimizer, attack=True): 20 | ''' add disturbance in training 21 | ''' 22 | outputs = self.model(token, segment, mask) 23 | loss = F.cross_entropy(outputs, labels) 24 | loss.backward() 25 | 26 | if attack: 27 | for t in range(self.K): 28 | outputs = self.model(token, segment, mask) 29 | self.model.zero_grad() 30 | loss = F.cross_entropy(outputs, labels) 31 | loss.backward() 32 | optimizer.step() 33 | self.attack_emb(backup=False) # accumulate projected disturb in embedding 34 | 35 | return outputs, loss 36 | 37 | def attack_param(self, name, param): 38 | '''add disturbance 39 | FreeAT Format: 40 | r[t+1] = r[t] + epsilon * sign(grad) 41 | r_at = epsilon * np.sign(param.grad) 42 | ''' 43 | norm = torch.norm(param.grad) 44 | if norm != 0: 45 | r_at = self.epsilon * param.grad / norm 46 | param.data.add_(r_at) 47 | param.data = self.project(name, param.data) 48 | 49 | def project(self, param_name, param_data): 50 | ''' projected disturbance like disturb cropping inside the pale (-eps, eps) 51 | ''' 52 | r = param_data - self.emb_backup[param_name] # compute disturbance 53 | if torch.norm(r) > self.epsilon: # disturbance cropping inside the pale (-eps, eps) 54 | r = self.epsilon * r / torch.norm(r) 55 | return self.emb_backup[param_name] + r 56 | 57 | def attack_emb(self, backup=False): 58 | for name, param in self.model.named_parameters(): 59 | if param.requires_grad and self.emb_name in name: 60 | if backup: # backup embedding 61 | self.emb_backup[name] = param.data.clone() 62 | self.attack_param(name, param) 63 | 64 | def backup_emb(self): 65 | for name, param in self.model.named_parameters(): 66 | if param.requires_grad and self.emb_name in name: 67 | self.emb_backup[name] = param.data.clone() 68 | 69 | def restore_emb(self): 70 | '''recover embedding''' 71 | for name, param in self.model.named_parameters(): 72 | if param.requires_grad and self.emb_name 
in name: 73 | assert name in self.emb_backup 74 | param.data = self.emb_backup[name] 75 | self.emb_backup = {} 76 | -------------------------------------------------------------------------------- /bert_model/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | from transformers import BertModel, AlbertModel, BertConfig, BertTokenizer 10 | from transformers import BertForSequenceClassification, AutoModelForMaskedLM 11 | 12 | 13 | def load_data(path): 14 | train = pd.read_csv(path, header=0, sep='\t', names=["text", "label"]) 15 | print(train.shape) 16 | 17 | texts = train.text.to_list() 18 | labels = train.label.map(int).to_list() 19 | # label_dic = dict(zip(train.label.unique(), range(len(train.label.unique())))) 20 | return texts, labels 21 | 22 | 23 | class TextDataset(Dataset): 24 | def __init__(self, filepath): 25 | super(TextDataset, self).__init__() 26 | self.train, self.label = load_data(filepath) 27 | 28 | def __len__(self): 29 | return len(self.train) 30 | 31 | def __getitem__(self, item): 32 | text = self.train[item] 33 | label = self.label[item] 34 | return text, label 35 | 36 | 37 | class BatchTextCall(object): 38 | """call function for tokenizing and getting batch text 39 | """ 40 | 41 | def __init__(self, tokenizer, max_len=64): 42 | self.tokenizer = tokenizer 43 | self.max_len = max_len 44 | 45 | def text2id(self, batch_text): 46 | return self.tokenizer(batch_text, max_length=self.max_len, 47 | truncation=True, padding='max_length', return_tensors='pt') 48 | 49 | def __call__(self, batch): 50 | batch_text = [item[0] for item in batch] 51 | batch_label = [item[1] for item in batch] 52 | 53 | source = self.text2id(batch_text) 54 | token = source.get('input_ids').squeeze(1) 55 | mask = source.get('attention_mask').squeeze(1) 56 | segment = source.get('token_type_ids').squeeze(1) 57 | label = torch.tensor(batch_label) 58 | 59 | return token, segment, mask, label 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | data_dir = "../data/THUCNews/news" 65 | pretrained_path = "/data/Learn_Project/Backup_Data/bert_chinese" 66 | 67 | label_dict = {'体育': 0, '娱乐': 1, '家居': 2, '房产': 3, '教育': 4, '时尚': 5, '时政': 6, '游戏': 7, '科技': 8, 68 | '财经': 9} 69 | 70 | tokenizer = BertTokenizer.from_pretrained(pretrained_path) 71 | model_config = BertConfig.from_pretrained(pretrained_path) 72 | model = BertModel.from_pretrained(pretrained_path, config=model_config) 73 | 74 | text_dataset = TextDataset(os.path.join(data_dir, "train.txt")) 75 | text_dataset_call = BatchTextCall(tokenizer) 76 | text_dataloader = DataLoader(text_dataset, batch_size=2, shuffle=True, num_workers=2, collate_fn=text_dataset_call) 77 | 78 | for i, (token, segment, mask, label) in enumerate(text_dataloader): 79 | print(i, token, segment, mask, label) 80 | out = model(input_ids=token, attention_mask=mask, token_type_ids=segment) 81 | # loss, logits = model(token, mask, segment)[:2] 82 | print(out) 83 | print(out.last_hidden_state.shape) 84 | break 85 | -------------------------------------------------------------------------------- /test_loss.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | 5 | import torch.nn as nn 6 | from torch import optim 7 | import matplotlib.pyplot as plt 8 | 9 | from unbalanced_loss.focal_loss import MultiFocalLoss, BinaryFocalLoss 10 | from 
unbalanced_loss.dice_loss_nlp import BinaryDSCLoss, MultiDSCLoss 11 | 12 | torch.manual_seed(123) 13 | 14 | 15 | class CNNModel(nn.Module): 16 | 17 | def __init__(self, num_class, kernel_size=3, padding=1, stride=1): 18 | super(CNNModel, self).__init__() 19 | self.model = nn.Sequential(*[nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1), 20 | nn.BatchNorm2d(16), 21 | nn.ReLU()]) 22 | self.fc = nn.Linear(32 * 32 * 16, num_class) # flatten length * width * channels 23 | 24 | def forward(self, data): 25 | output = self.model(data) 26 | output = output.view(output.size(0), -1) 27 | output = self.fc(output) 28 | 29 | return output 30 | 31 | 32 | def choose_loss(num_class, loss_type): 33 | ''' 34 | choose loss type 35 | ''' 36 | if loss_type == "binary_focal_loss": 37 | data_shape = (16, 3, 32, 32) 38 | target_shape = (16, ) # [batch, 1] 39 | 40 | datas = (torch.rand(data_shape)).cuda() 41 | target = torch.randint(0, 2, size=target_shape).cuda() 42 | Loss = BinaryFocalLoss() 43 | 44 | elif loss_type == "multi_class_focal_loss": 45 | data_shape = (16, 3, 32, 32) # [batch, channels, width, length] 46 | target_shape = (16,) # [batch, ] 47 | 48 | datas = (torch.rand(data_shape)).cuda() 49 | target = torch.randint(0, num_class, size=target_shape).cuda() 50 | Loss = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean') 51 | 52 | elif loss_type == "binary_dice_loss": # rewritten 53 | data_shape = (16, 3, 32, 32) 54 | target_shape = (16, ) # [batch, 1] 55 | 56 | datas = (torch.rand(data_shape)).cuda() 57 | target = torch.randint(0, 2, size=target_shape).cuda() 58 | Loss = BinaryDSCLoss() 59 | 60 | elif loss_type == "multi_class_dice_loss": 61 | data_shape = (16, 3, 32, 32) # [batch, channels, width, length] 62 | target_shape = (16,) # [batch,] 63 | 64 | datas = (torch.rand(data_shape)).cuda() 65 | target = torch.randint(0, num_class, size=target_shape).cuda() 66 | Loss = MultiDSCLoss(alpha=1.0, smooth=1.0, reduction="mean") 67 | 68 | return datas, target, Loss 69 | 70 | 71 | def main(): 72 | num_class = 5 73 | datas, target, Loss = choose_loss(num_class, loss_type="multi_class_focal_loss") 74 | target = target.long().cuda() 75 | # print(target.shape, datas.shape) 76 | 77 | model = CNNModel(num_class) 78 | model = model.cuda() 79 | 80 | optimizer = optim.Adam(params=model.parameters(), lr=0.001) 81 | 82 | losses = [] 83 | for i in range(32): 84 | output = model(datas) 85 | loss = Loss(output, target) 86 | losses.append(loss.item()) 87 | optimizer.zero_grad() 88 | loss.backward() 89 | optimizer.step() 90 | 91 | plt.plot(losses) 92 | plt.show() 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /adversarial_training/PGD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class PGD: 6 | def __init__(self, model, epsilon=1.0, alpha=0.3, k=3, emb_name='embedding.'): 7 | self.model = model 8 | self.emb_backup = {} 9 | self.grad_backup = {} 10 | self.epsilon = epsilon 11 | self.alpha = alpha 12 | self.K = k # PGD attack times 13 | self.emb_name = emb_name 14 | 15 | def train(self, input_data, labels, optimizer): 16 | ''' define your training process here according to your model definition 17 | ''' 18 | pass 19 | 20 | def train_bert(self, token, segment, mask, labels, optimizer, attack=True): 21 | ''' an adversarial training demo for bert 22 | ''' 23 | outputs = self.model(token, segment, mask) 24 | loss = F.cross_entropy(outputs,
labels) 25 | loss.backward() 26 | 27 | if attack: 28 | self.backup_grad() 29 | for t in range(self.K): 30 | self.attack_embedding(backup=(t == 0)) 31 | if t != self.K - 1: 32 | self.model.zero_grad() 33 | else: 34 | self.restore_grad() 35 | outputs = self.model(token, segment, mask) 36 | loss = F.cross_entropy(outputs, labels) 37 | loss.backward() 38 | self.restore_embedding() # recover embedding 39 | optimizer.step() 40 | self.model.zero_grad() 41 | 42 | return outputs, loss 43 | 44 | def attack_param(self, name, param): 45 | '''add disturbance 46 | PGD: r = epsilon * grad / norm(grad) 47 | ''' 48 | norm = torch.norm(param.grad) 49 | if norm != 0 and not torch.isnan(norm): 50 | r_at = self.alpha * param.grad / norm 51 | param.data.add_(r_at) 52 | param.data = self.project(name, param.data) 53 | 54 | def project(self, param_name, param_data): 55 | ''' projected disturbance like parameter cropping inside the pale 56 | ''' 57 | r = param_data - self.emb_backup[param_name] 58 | if torch.norm(r) > self.epsilon: 59 | r = self.epsilon * r / torch.norm(r) 60 | return self.emb_backup[param_name] + r 61 | 62 | def attack_embedding(self, backup=False): 63 | for name, param in self.model.named_parameters(): 64 | if param.requires_grad and self.emb_name in name: 65 | if backup: # backup embedding 66 | self.emb_backup[name] = param.data.clone() 67 | self.attack_param(name, param) 68 | 69 | def backup_embedding(self): 70 | for name, param in self.model.named_parameters(): 71 | if param.requires_grad and self.emb_name in name: 72 | self.emb_backup[name] = param.data.clone() 73 | 74 | def restore_embedding(self): 75 | '''recover embedding''' 76 | for name, param in self.model.named_parameters(): 77 | if param.requires_grad and self.emb_name in name: 78 | assert name in self.emb_backup 79 | param.data = self.emb_backup[name] 80 | self.emb_backup = {} 81 | 82 | def backup_grad(self): 83 | for name, param in self.model.named_parameters(): 84 | if param.requires_grad and param.grad is not None: 85 | self.grad_backup[name] = param.grad.clone() 86 | 87 | def restore_grad(self): 88 | '''recover grad back upped 89 | ''' 90 | for name, param in self.model.named_parameters(): 91 | if param.requires_grad and name in self.grad_backup: 92 | param.grad = self.grad_backup[name] 93 | -------------------------------------------------------------------------------- /unbalanced_loss/dice_loss_nlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BinaryDSCLoss(torch.nn.Module): 5 | r""" 6 | Creates a criterion that optimizes a multi-class Self-adjusting Dice Loss 7 | ("Dice Loss for Data-imbalanced NLP Tasks" paper) 8 | 9 | Args: 10 | alpha (float): a factor to push down the weight of easy examples 11 | gamma (float): a factor added to both the nominator and the denominator for smoothing purposes 12 | reduction (string): Specifies the reduction to apply to the output: 13 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 14 | ``'mean'``: the sum of the output will be divided by the number of 15 | elements in the output, ``'sum'``: the output will be summed. 16 | 17 | Shape: 18 | - logits: `(N, C)` where `N` is the batch size and `C` is the number of classes. 
19 | - targets: `(N)` where each value is in [0, C - 1] 20 | """ 21 | 22 | def __init__(self, alpha: float = 1.0, smooth: float = 1.0, reduction: str = "mean") -> None: 23 | super().__init__() 24 | self.alpha = alpha 25 | self.smooth = smooth 26 | self.reduction = reduction 27 | 28 | def forward(self, logits, targets): 29 | probs = torch.sigmoid(logits) 30 | probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1)) 31 | 32 | targets = targets.unsqueeze(dim=1) 33 | pos_mask = (targets == 1).float() 34 | neg_mask = (targets == 0).float() 35 | 36 | pos_weight = pos_mask * ((1 - probs) ** self.alpha) * probs 37 | pos_loss = 1 - (2 * pos_weight + self.smooth) / (pos_weight + 1 + self.smooth) 38 | 39 | neg_weight = neg_mask * ((1 - probs) ** self.alpha) * probs 40 | neg_loss = 1 - (2 * neg_weight + self.smooth) / (neg_weight + self.smooth) 41 | 42 | loss = pos_loss + neg_loss 43 | loss = loss.mean() 44 | return loss 45 | 46 | 47 | class MultiDSCLoss(torch.nn.Module): 48 | r""" 49 | Creates a criterion that optimizes a multi-class Self-adjusting Dice Loss 50 | ("Dice Loss for Data-imbalanced NLP Tasks" paper) 51 | 52 | Args: 53 | alpha (float): a factor to push down the weight of easy examples 54 | gamma (float): a factor added to both the nominator and the denominator for smoothing purposes 55 | reduction (string): Specifies the reduction to apply to the output: 56 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 57 | ``'mean'``: the sum of the output will be divided by the number of 58 | elements in the output, ``'sum'``: the output will be summed. 59 | 60 | Shape: 61 | - logits: `(N, C)` where `N` is the batch size and `C` is the number of classes. 62 | - targets: `(N)` where each value is in [0, C - 1] 63 | """ 64 | 65 | def __init__(self, alpha: float = 1.0, smooth: float = 1.0, reduction: str = "mean"): 66 | super().__init__() 67 | self.alpha = alpha 68 | self.smooth = smooth 69 | self.reduction = reduction 70 | 71 | def forward(self, logits, targets): 72 | probs = torch.softmax(logits, dim=1) 73 | probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1)) 74 | 75 | probs_with_factor = ((1 - probs) ** self.alpha) * probs 76 | loss = 1 - (2 * probs_with_factor + self.smooth) / (probs_with_factor + 1 + self.smooth) 77 | 78 | if self.reduction == "mean": 79 | return loss.mean() 80 | elif self.reduction == "sum": 81 | return loss.sum() 82 | elif self.reduction == "none" or self.reduction is None: 83 | return loss 84 | else: 85 | raise NotImplementedError(f"Reduction `{self.reduction}` is not supported.") 86 | -------------------------------------------------------------------------------- /unbalanced_loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BinaryFocalLoss(nn.Module): 8 | """ 9 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 10 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 11 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 12 | focus on hard misclassified example 13 | :param reduction: `none`|`mean`|`sum` 14 | """ 15 | 16 | def __init__(self, alpha=1, gamma=2, reduction='mean', **kwargs): 17 | super(BinaryFocalLoss, self).__init__() 18 | self.alpha = alpha 19 | self.gamma = gamma 20 | self.smooth = 1e-6 # set '1e-4' when train with FP16 21 | self.reduction = reduction 22 | 23 | assert 
self.reduction in ['none', 'mean', 'sum'] 24 | 25 | def forward(self, output, target): 26 | prob = torch.sigmoid(output) 27 | prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth) 28 | 29 | target = target.unsqueeze(dim=1) 30 | pos_mask = (target == 1).float() 31 | neg_mask = (target == 0).float() 32 | 33 | pos_weight = (pos_mask * torch.pow(1 - prob, self.gamma)).detach() 34 | pos_loss = -pos_weight * torch.log(prob) # / (torch.sum(pos_weight) + 1e-4) 35 | 36 | neg_weight = (neg_mask * torch.pow(prob, self.gamma)).detach() 37 | neg_loss = -self.alpha * neg_weight * F.logsigmoid(-output) # / (torch.sum(neg_weight) + 1e-4) 38 | 39 | loss = pos_loss + neg_loss 40 | loss = loss.mean() 41 | return loss 42 | 43 | 44 | class MultiFocalLoss(nn.Module): 45 | """ 46 | Focal_Loss= -1*alpha*((1-pt)**gamma)*log(pt) 47 | Args: 48 | num_class: number of classes 49 | alpha: class balance factor shape=[num_class, ] 50 | gamma: hyper-parameter 51 | reduction: reduction type 52 | """ 53 | 54 | def __init__(self, num_class, alpha=None, gamma=2, reduction='mean'): 55 | super(MultiFocalLoss, self).__init__() 56 | self.num_class = num_class 57 | self.gamma = gamma 58 | self.reduction = reduction 59 | self.smooth = 1e-4 60 | self.alpha = alpha 61 | if alpha is None: 62 | self.alpha = torch.ones(num_class, ) - 0.5 63 | elif isinstance(alpha, (int, float)): 64 | self.alpha = torch.as_tensor([alpha] * num_class) 65 | elif isinstance(alpha, (list, np.ndarray)): 66 | self.alpha = torch.as_tensor(alpha) 67 | if self.alpha.shape[0] != num_class: 68 | raise RuntimeError('the length not equal to number of class') 69 | 70 | def forward(self, logit, target): 71 | # assert isinstance(self.alpha,torch.Tensor)\ 72 | alpha = self.alpha.to(logit.device) 73 | prob = F.softmax(logit, dim=1) 74 | 75 | if prob.dim() > 2: 76 | # used for 3d-conv: N,C,d1,d2 -> N,C,m (m=d1*d2*...) 77 | N, C = logit.shape[:2] 78 | prob = prob.view(N, C, -1) 79 | prob = prob.transpose(1, 2).contiguous() # [N,C,d1*d2..] -> [N,d1*d2..,C] 80 | prob = prob.view(-1, prob.size(-1)) # [N,d1*d2..,C]-> [N*d1*d2..,C] 81 | 82 | ori_shp = target.shape 83 | target = target.view(-1, 1) 84 | 85 | prob = prob.gather(1, target).view(-1) + self.smooth # avoid nan 86 | logpt = torch.log(prob) 87 | # alpha_class = alpha.gather(0, target.squeeze(-1)) 88 | alpha_weight = alpha[target.squeeze().long()] 89 | loss = -alpha_weight * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt 90 | 91 | if self.reduction == 'mean': 92 | loss = loss.mean() 93 | elif self.reduction == 'none': 94 | loss = loss.view(ori_shp) 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Implementation of some unbalanced loss for NLP task like focal_loss, dice_loss, DSC Loss, GHM Loss et.al and adversarial 2 | training like FGM, FGSM, PGD, FreeAT. 
3 | 4 | ### Loss Summary 5 | 6 | Here is a loss implementation repository included unbalanced loss 7 | 8 | | Loss Name | paper | Notes | 9 | |:----------------:|:------------------------------------------------------------------------------------------------------------------------------------:|:-----:| 10 | | Weighted CE Loss | [UNet Architectures in Multiplanar Volumetric Segmentation -- Validated on Three Knee MRI Cohorts](https://arxiv.org/abs/2203.08194) | | 11 | | Focal Loss | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) | | 12 | | Dice Loss | [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/abs/1606.04797) | | 13 | | DSC Loss | [Dice Loss for Data-imbalanced NLP Tasks](https://arxiv.org/pdf/1911.02855.pdf) | | 14 | | GHM Loss | [Gradient Harmonized Single-stage Detector](https://www.aaai.org/ojs/index.php/AAAI/article/download/4877/4750) | | 15 | | Label Smoothing | [When Does Label Smoothing Help?](https://arxiv.org/pdf/1906.02629.pdf) | | 16 | 17 | #### How to use? 18 | 19 | You can find all the loss usage information in test_loss.py. 20 | 21 | Here is a simple demo of usage: 22 | 23 | ```python 24 | import torch 25 | from unbalanced_loss.focal_loss import MultiFocalLoss 26 | 27 | batch_size, num_class = 64, 10 28 | Loss_Func = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean') 29 | 30 | logits = torch.rand(batch_size, num_class, requires_grad=True) # (batch_size, num_classes) 31 | targets = torch.randint(0, num_class, size=(batch_size,)) # (batch_size, ) 32 | 33 | loss = Loss_Func(logits, targets) 34 | loss.backward() 35 | ``` 36 | 37 | ### Adversarial Training Summary 38 | 39 | Here is a Summary of Adversarial Training implementation. 40 | you can find more details in adversarial_training/README.md 41 | 42 | | Adversarial Training | paper | Notes | 43 | |:--------------------:|:-----------------------------------------------------------------------------------------------------:|:-----:| 44 | | FGM | [Fast Gradient Method](https://arxiv.org/pdf/1605.07725.pdf) | | 45 | | FGSM | [Fast Gradient Sign Method](https://arxiv.org/abs/1412.6572) | | 46 | | PGD | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/pdf/1706.06083.pdf) | | 47 | | FreeAT | [Free Adversarial Training](https://arxiv.org/pdf/1904.12843.pdf) | | 48 | | FreeLB | [Free Large Batch Adversarial Training](https://arxiv.org/pdf/1909.11764v5.pdf) | | 49 | 50 | #### How to use? 51 | 52 | **You can find a simple demo for bert classification in test_bert.py.** 53 | 54 | Here is a simple demo of usage: 55 | You just need to rewrite train function according to input for your model in file PGD.py, then you can use adversarial 56 | training like below. 57 | 58 | ```python 59 | import transformers 60 | from model import bert_classification 61 | from adversarial_training.PGD import PGD 62 | 63 | batch_size, num_class = 64, 10 64 | # model = your_model() 65 | model = bert_classification() 66 | AT_Model = PGD(model) 67 | optimizer = transformers.AdamW(model.parameters(), lr=0.001) 68 | 69 | # rewrite your train function in pgd.py 70 | outputs, loss = AT_Model.train_bert(token, segment, mask, label, optimizer) 71 | ``` 72 | 73 | #### Adversarial Training Results 74 | 75 | here are some results tested on THNews classification task based on bert. 
76 | You can run the code as below:
77 | > cd scripts
78 | > sh run_at.sh
79 | 
80 | |  Adversarial Training  | Time Cost (s/epoch) | best_acc |
81 | |:----------------------:|:-------------------:|:--------:|
82 | | Normal (no attack)     | 23.77               | 0.773    |
83 | | FGSM                   | 45.95               | 0.7936   |
84 | | FGM                    | 47.28               | 0.8008   |
85 | | PGD(k=3)               | 87.50               | 0.7963   |
86 | | FreeAT(k=3)            | 93.26               | 0.7896   |
--------------------------------------------------------------------------------
/unbalanced_loss/dice_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | 
7 | class BinaryDiceLoss(nn.Module):
8 |     """
9 |     Args:
10 |         ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
11 |         reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
12 |     Shapes:
13 |         output: A tensor of shape [N, *] without sigmoid activation function applied
14 |         target: A tensor of the same shape as output
15 |     Returns:
16 |         Loss tensor according to arg reduction
17 |     Raise:
18 |         Exception if unexpected reduction
19 |     """
20 | 
21 |     def __init__(self, ignore_index=None, reduction='mean', **kwargs):
22 |         super(BinaryDiceLoss, self).__init__()
23 |         self.smooth = 1  # suggest setting a larger value (e.g. 10 or 100) when the target area is large
24 |         self.ignore_index = ignore_index
25 |         self.reduction = reduction
26 |         self.batch_dice = False  # when True, compute dice over the whole batch as one sample
27 |         if 'batch_loss' in kwargs.keys():
28 |             self.batch_dice = kwargs['batch_loss']
29 | 
30 |     def forward(self, output, target, use_sigmoid=True):
31 |         assert output.shape[0] == target.shape[0], "output & target batch size don't match"
32 |         if use_sigmoid:
33 |             output = torch.sigmoid(output)
34 | 
35 |         if self.ignore_index is not None:
36 |             validmask = (target != self.ignore_index).float()
37 |             output = output.mul(validmask)  # cannot use in-place ops here, they would break backprop
38 |             target = target.float().mul(validmask)
39 | 
40 |         dim0 = output.shape[0]
41 |         if self.batch_dice:
42 |             dim0 = 1
43 | 
44 |         output = output.contiguous().view(dim0, -1)
45 |         target = target.contiguous().view(dim0, -1).float()
46 | 
47 |         num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
48 |         den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth
49 | 
50 |         loss = 1 - (num / den)
51 | 
52 |         if self.reduction == 'mean':
53 |             return loss.mean()
54 |         elif self.reduction == 'sum':
55 |             return loss.sum()
56 |         elif self.reduction == 'none':
57 |             return loss
58 |         else:
59 |             raise Exception('Unexpected reduction {}'.format(self.reduction))
60 | 
61 | 
62 | class DiceLoss(nn.Module):
63 |     """
64 |     Args:
65 |         weight: An array of shape [num_classes,]
66 |         ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
67 |         output: A tensor of shape [N, C, *]
68 |         target: A tensor of the same shape as output
69 |         other args pass to BinaryDiceLoss
70 |     Return:
71 |         same as BinaryDiceLoss
72 |     """
73 | 
74 |     def __init__(self, weight=None, ignore_index=None, **kwargs):
75 |         super(DiceLoss, self).__init__()
76 |         self.kwargs = kwargs
77 |         self.weight = weight
78 |         if isinstance(ignore_index, (int, float)):
79 |             self.ignore_index = [int(ignore_index)]
80 |         elif ignore_index is None:
81 |             self.ignore_index = []
82 |         elif isinstance(ignore_index, (list, tuple)):
83 |             self.ignore_index = ignore_index
84 |         else:
85 |             raise TypeError("Expect 'int|float|list|tuple', while get '{}'".format(type(ignore_index)))
86 | 
87 |     def forward(self, output, target):
88 |         assert output.shape == target.shape, 'output & target shape do not match'
89 |         dice = BinaryDiceLoss(**self.kwargs)
90 |         total_loss = 0
91 |         output = F.softmax(output, dim=1)
92 |         for i in range(target.shape[1]):
93 |             if i not in self.ignore_index:
94 |                 dice_loss = dice(output[:, i], target[:, i], use_sigmoid=False)
95 |                 if self.weight is not None:
96 |                     assert self.weight.shape[0] == target.shape[1], \
97 |                         'Expect weight shape [{}], got [{}]'.format(target.shape[1], self.weight.shape[0])
98 |                     dice_loss *= self.weight[i]
99 |                 total_loss += dice_loss
100 |         loss = total_loss / (target.size(1) - len(self.ignore_index))
101 |         return loss
102 | 
103 | 
104 | def test():
105 |     input = torch.rand((3, 1, 32, 32, 32))
106 |     model = nn.Conv3d(1, 4, 3, padding=1)
107 |     target = torch.randint(0, 4, (3, 1, 32, 32, 32)).float()
108 |     target = make_one_hot(target, num_classes=4)
109 |     criterion = DiceLoss(ignore_index=[2, 3], reduction='mean')
110 |     loss = criterion(model(input), target)
111 |     loss.backward()
112 |     print(loss.item())
113 | 
114 | 
115 | def make_one_hot(input, num_classes=None):
116 |     """Convert class index tensor to one hot encoding tensor.
117 | 
118 |     Args:
119 |         input: A tensor of shape [N, 1, *] holding class indices
120 |         num_classes: An int of number of classes
121 |     Shapes:
122 |         input: A tensor of shape [N, 1, *]
123 |         output: A one-hot tensor of shape [N, num_classes, *]
124 |     Returns:
125 |         A tensor of shape [N, num_classes, *]
126 |     """
127 |     if num_classes is None:
128 |         num_classes = input.max() + 1
129 |     shape = np.array(input.shape)
130 |     shape[1] = num_classes
131 |     shape = tuple(shape)
132 |     result = torch.zeros(shape)
133 |     result = result.scatter_(1, input.cpu().long(), 1)
134 |     return result
135 | 
136 | 
137 | if __name__ == '__main__':
138 |     test()
139 | 
--------------------------------------------------------------------------------
/bert_model/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | from sklearn import metrics
7 | 
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.data import Dataset, DataLoader
11 | 
12 | import transformers
13 | from transformers import BertModel, AlbertModel, BertConfig, BertTokenizer
14 | 
15 | from dataloader import TextDataset, BatchTextCall
16 | from model import MultiClass
17 | from adversarial_training.FGM import FGM
18 | 
19 | 
20 | def choose_bert_type(path, bert_type="tiny_albert"):
21 |     """
22 |     choose the bert variant for Chinese: tiny_albert or macbert (bert)
23 |     return: model_config, model
24 |     """
25 | 
26 |     if bert_type == "albert":
27 |         model_config = BertConfig.from_pretrained(path)
28 |         model = AlbertModel.from_pretrained(path, config=model_config)
29 |     elif bert_type == "bert" or bert_type == "roberta":
30 |         model_config = BertConfig.from_pretrained(path)
31 |         model = BertModel.from_pretrained(path, config=model_config)
32 |     else:
33 |         model_config, model = None, None
34 |         print("ERROR: unknown bert_type, no model chosen!")
35 | 
36 |     return model_config, model
37 | 
38 | 
39 | def evaluation(model, test_dataloader, loss_func, label2ind_dict, save_path, valid_or_test="test"):
40 |     # model.load_state_dict(torch.load(save_path))
41 | 
42 |     model.eval()
43 |     total_loss = 0
44 |     predict_all = np.array([], dtype=int)
45 |     labels_all = np.array([], dtype=int)
46 | 
47 |     for ind, (token, segment, mask,
label) in enumerate(test_dataloader): 48 | token = token.cuda() 49 | segment = segment.cuda() 50 | mask = mask.cuda() 51 | label = label.cuda() 52 | 53 | out = model(token, segment, mask) 54 | loss = loss_func(out, label) 55 | total_loss += loss.detach().item() 56 | 57 | label = label.data.cpu().numpy() 58 | predic = torch.max(out.data, 1)[1].cpu().numpy() 59 | labels_all = np.append(labels_all, label) 60 | predict_all = np.append(predict_all, predic) 61 | 62 | acc = metrics.accuracy_score(labels_all, predict_all) 63 | if valid_or_test == "test": 64 | report = metrics.classification_report(labels_all, predict_all, target_names=label2ind_dict.keys(), digits=4) 65 | confusion = metrics.confusion_matrix(labels_all, predict_all) 66 | return acc, total_loss / len(test_dataloader), report, confusion 67 | return acc, total_loss / len(test_dataloader) 68 | 69 | 70 | def train(config): 71 | label2ind_dict = {'finance': 0, 'realty': 1, 'stocks': 2, 'education': 3, 'science': 4, 'society': 5, 'politics': 6, 72 | 'sports': 7, 'game': 8, 'entertainment': 9} 73 | 74 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu 75 | torch.backends.cudnn.benchmark = True 76 | 77 | # load_data(os.path.join(data_dir, "cnews.train.txt"), label_dict) 78 | 79 | tokenizer = BertTokenizer.from_pretrained(config.pretrained_path) 80 | train_dataset_call = BatchTextCall(tokenizer, max_len=config.sent_max_len) 81 | 82 | train_dataset = TextDataset(os.path.join(config.data_dir, "train.txt")) 83 | train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=10, 84 | collate_fn=train_dataset_call) 85 | 86 | test_dataset = TextDataset(os.path.join(config.data_dir, "test.txt")) 87 | test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=True, num_workers=10, 88 | collate_fn=train_dataset_call) 89 | 90 | model_config, bert_encode_model = choose_bert_type(config.pretrained_path, bert_type=config.bert_type) 91 | multi_classification_model = MultiClass(bert_encode_model, model_config, 92 | num_classes=10, pooling_type=config.pooling_type) 93 | multi_classification_model.cuda() 94 | # multi_classification_model.load_state_dict(torch.load(config.save_path)) 95 | # AT_FGM = FGM(model) 96 | 97 | num_train_optimization_steps = len(train_dataloader) * config.epoch 98 | param_optimizer = list(multi_classification_model.named_parameters()) 99 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 100 | optimizer_grouped_parameters = [ 101 | {'params': [p for n, p in param_optimizer 102 | if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 103 | {'params': [p for n, p in param_optimizer 104 | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 105 | ] 106 | optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=config.lr) 107 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 108 | int(num_train_optimization_steps * config.warmup_proportion), 109 | num_train_optimization_steps) 110 | loss_func = F.cross_entropy 111 | 112 | loss_total, top_acc = [], 0 113 | for epoch in range(config.epoch): 114 | multi_classification_model.train() 115 | start_time = time.time() 116 | tqdm_bar = tqdm(train_dataloader, desc="Training epoch{epoch}".format(epoch=epoch)) 117 | for i, (token, segment, mask, label) in enumerate(tqdm_bar): 118 | token = token.cuda() 119 | segment = segment.cuda() 120 | mask = mask.cuda() 121 | label = label.cuda() 122 | 123 | multi_classification_model.zero_grad() 124 | out = multi_classification_model(token, segment, 
mask) 125 | loss = loss_func(out, label) 126 | loss.backward() 127 | optimizer.step() 128 | scheduler.step() 129 | optimizer.zero_grad() 130 | loss_total.append(loss.detach().item()) 131 | print("Epoch: %03d; loss = %.4f cost time %.4f" % (epoch, np.mean(loss_total), time.time() - start_time)) 132 | 133 | acc, loss, report, confusion = evaluation(multi_classification_model, 134 | test_dataloader, loss_func, label2ind_dict, 135 | config.save_path) 136 | print("Accuracy: %.4f Loss in test %.4f" % (acc, loss)) 137 | if top_acc < acc: 138 | top_acc = acc 139 | # torch.save(multi_classification_model.state_dict(), config.save_path) 140 | print(report, '\n', confusion) 141 | time.sleep(1) 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser(description='bert classification') 146 | parser.add_argument("--data_dir", type=str, default="../data/THUCNews/news") 147 | parser.add_argument("--save_path", type=str, default="../ckpt/bert_classification") 148 | parser.add_argument("--pretrained_path", type=str, default="/data/Learn_Project/Backup_Data/bert_chinese", 149 | help="pre-train model path") 150 | parser.add_argument("--bert_type", type=str, default="bert", help="bert or albert") 151 | parser.add_argument("--gpu", type=str, default='0') 152 | parser.add_argument("--epoch", type=int, default=20) 153 | parser.add_argument("--lr", type=float, default=0.005) 154 | parser.add_argument("--warmup_proportion", type=float, default=0.1) 155 | parser.add_argument("--pooling_type", type=str, default="first-last-avg") 156 | parser.add_argument("--batch_size", type=int, default=128) 157 | parser.add_argument("--sent_max_len", type=int, default=44) 158 | parser.add_argument("--do_lower_case", type=bool, default=True, 159 | help="Set this flag true if you are using an uncased model.") 160 | args = parser.parse_args() 161 | 162 | train(args) 163 | -------------------------------------------------------------------------------- /test_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from sklearn import metrics 8 | import matplotlib.pyplot as plt 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | import transformers 15 | from transformers import BertModel, AlbertModel, BertConfig, BertTokenizer 16 | 17 | from bert_model.dataloader import TextDataset, BatchTextCall 18 | from bert_model.model import MultiClass 19 | 20 | from adversarial_training.FGSM import FGSM 21 | from adversarial_training.FGM import FGM 22 | from adversarial_training.PGD import PGD 23 | from adversarial_training.FreeAT import FreeAT 24 | 25 | 26 | def choose_bert_type(path, bert_type="tiny_albert"): 27 | """ 28 | choose bert type for chinese, tiny_albert or macbert(bert) 29 | return: tokenizer, model 30 | """ 31 | 32 | if bert_type == "albert": 33 | model_config = BertConfig.from_pretrained(path) 34 | model = AlbertModel.from_pretrained(path, config=model_config) 35 | elif bert_type == "bert" or bert_type == "roberta": 36 | model_config = BertConfig.from_pretrained(path) 37 | model = BertModel.from_pretrained(path, config=model_config) 38 | else: 39 | model_config, model = None, None 40 | print("ERROR, not choose model!") 41 | 42 | return model_config, model 43 | 44 | 45 | def choose_attack_type(model, attack_type="FGM"): 46 | if attack_type == 'FGSM': 47 | attack_model = 
FGSM(model) 48 | elif attack_type == 'FGM': 49 | attack_model = FGM(model) 50 | elif attack_type == 'PGD': 51 | attack_model = PGD(model) 52 | elif attack_type == 'FreeAT': 53 | attack_model = FreeAT(model) 54 | return attack_model 55 | 56 | 57 | def evaluation(model, test_dataloader, loss_func, label2ind_dict, save_path, valid_or_test="test"): 58 | # model.load_state_dict(torch.load(save_path)) 59 | 60 | model.eval() 61 | total_loss = 0 62 | predict_all = np.array([], dtype=int) 63 | labels_all = np.array([], dtype=int) 64 | 65 | for ind, (token, segment, mask, label) in enumerate(test_dataloader): 66 | token = token.cuda() 67 | segment = segment.cuda() 68 | mask = mask.cuda() 69 | label = label.cuda() 70 | 71 | with torch.no_grad(): 72 | out = model(token, segment, mask) 73 | loss = loss_func(out, label) 74 | total_loss += loss.detach().item() 75 | 76 | label = label.data.cpu().numpy() 77 | predic = torch.max(out.data, 1)[1].cpu().numpy() 78 | labels_all = np.append(labels_all, label) 79 | predict_all = np.append(predict_all, predic) 80 | 81 | acc = metrics.accuracy_score(labels_all, predict_all) 82 | return acc, total_loss / len(test_dataloader) 83 | 84 | 85 | def train(config): 86 | label2ind_dict = {'finance': 0, 'realty': 1, 'stocks': 2, 'education': 3, 'science': 4, 'society': 5, 'politics': 6, 87 | 'sports': 7, 'game': 8, 'entertainment': 9} 88 | 89 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu 90 | torch.backends.cudnn.benchmark = True 91 | loss_func = F.cross_entropy 92 | 93 | # load_data(os.path.join(data_dir, "cnews.train.txt"), label_dict) 94 | 95 | tokenizer = BertTokenizer.from_pretrained(config.pretrained_path) 96 | train_dataset_call = BatchTextCall(tokenizer, max_len=config.sent_max_len) 97 | 98 | train_dataset = TextDataset(os.path.join(config.data_dir, "train.txt")) 99 | train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=10, 100 | collate_fn=train_dataset_call) 101 | 102 | test_dataset = TextDataset(os.path.join(config.data_dir, "test.txt")) 103 | test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size // 2, shuffle=False, num_workers=10, 104 | collate_fn=train_dataset_call) 105 | 106 | model_config, bert_encode_model = choose_bert_type(config.pretrained_path, bert_type=config.bert_type) 107 | multi_classification_model = MultiClass(bert_encode_model, model_config, 108 | num_classes=10, pooling_type=config.pooling_type) 109 | multi_classification_model.cuda() 110 | # multi_classification_model.load_state_dict(torch.load(config.save_path)) 111 | AT_Model = choose_attack_type(multi_classification_model, attack_type=config.AT_type) 112 | 113 | num_train_optimization_steps = len(train_dataloader) * config.epoch 114 | param_optimizer = list(multi_classification_model.named_parameters()) 115 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 116 | optimizer_grouped_parameters = [ 117 | {'params': [p for n, p in param_optimizer 118 | if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 119 | {'params': [p for n, p in param_optimizer 120 | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 121 | ] 122 | optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=config.lr) 123 | # optimizer = transformers.AdamW(multi_classification_model.parameters(), lr=config.lr) 124 | 125 | # loss_func = F.cross_entropy 126 | loss_total, top_acc = [], 0 127 | losses, acc_list, eval_loss_list, time_list = [], [], [], [] 128 | 129 | for epoch in range(config.epoch): 130 | 
multi_classification_model.train() 131 | start_time = time.time() 132 | tqdm_bar = tqdm(train_dataloader, desc="Training epoch{epoch}".format(epoch=epoch)) 133 | for i, (token, segment, mask, label) in enumerate(tqdm_bar): 134 | token = token.cuda() 135 | segment = segment.cuda() 136 | mask = mask.cuda() 137 | label = label.cuda() 138 | outputs, loss = AT_Model.train_bert(token, segment, mask, label, optimizer, attack=config.use_attack) 139 | 140 | loss_total.append(loss.detach().item()) 141 | print("Epoch: %03d; loss = %.4f cost time %.4f" % ( 142 | epoch, np.mean(loss_total), time.time() - start_time)) 143 | losses.append(np.mean(loss_total)) 144 | time_list.append(time.time() - start_time) 145 | time.sleep(0.5) 146 | 147 | acc, loss_test = evaluation(multi_classification_model, 148 | test_dataloader, loss_func, label2ind_dict, 149 | config.save_path) 150 | print("Accuracy: %.4f Loss in test %.4f" % (acc, loss_test)) 151 | acc_list.append(acc) 152 | eval_loss_list.append(loss_test) 153 | 154 | # plt.plot(losses) 155 | # plt.plot(acc_list) 156 | # plt.plot(eval_loss_list) 157 | # plt.show() 158 | 159 | result_file = f"out/ATType{config.AT_type}_UseAT{bool(config.use_attack)}.csv" 160 | df = pd.DataFrame( 161 | {'train_loss': losses, 'time': time_list, 'acc': acc_list, 'eval_loss': eval_loss_list}) 162 | df.to_csv(result_file, index=False, sep='\t', encoding='utf-8') 163 | 164 | 165 | if __name__ == "__main__": 166 | parser = argparse.ArgumentParser(description='bert classification') 167 | parser.add_argument("--data_dir", type=str, default="./data/THUCNews/news") 168 | parser.add_argument("--save_path", type=str, default="../ckpt/bert_classification") 169 | parser.add_argument("--pretrained_path", type=str, default="/data/Learn_Project/Backup_Data/bert_chinese", 170 | help="pre-train model path") 171 | parser.add_argument("--bert_type", type=str, default="bert", help="bert or albert") 172 | parser.add_argument("--AT_type", type=str, default="FGM", help="FGM, PGD or FreeAT") 173 | parser.add_argument("--use_attack", type=int, default=1, help="1 represents use") 174 | parser.add_argument("--gpu", type=str, default='0') 175 | parser.add_argument("--epoch", type=int, default=3) 176 | parser.add_argument("--lr", type=float, default=0.005) 177 | parser.add_argument("--warmup_proportion", type=float, default=0.1) 178 | parser.add_argument("--pooling_type", type=str, default="first-last-avg") 179 | parser.add_argument("--batch_size", type=int, default=128) 180 | parser.add_argument("--sent_max_len", type=int, default=44) 181 | parser.add_argument("--do_lower_case", type=bool, default=True, 182 | help="Set this flag true if you are using an uncased model.") 183 | args = parser.parse_args() 184 | 185 | print(args) 186 | train(args) 187 | --------------------------------------------------------------------------------
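test_bert.py above delegates the whole attack-and-update step to `AT_Model.train_bert(token, segment, mask, label, optimizer, attack=...)`, but the adversarial_training modules themselves are not reproduced in this listing. As a reading aid, here is a minimal sketch of what an FGM-style step typically does: perturb the word embeddings by epsilon * g / ||g||, add the adversarial loss, restore the embeddings, then take one optimizer step. The `attack()`/`restore()` helper names and the function below are illustrative assumptions, not the repository's exact FGM.py.

```python
def fgm_style_train_step(model, fgm, loss_func, token, segment, mask, label, optimizer):
    # 1) clean forward/backward: gradients on the embedding matrix are populated
    out = model(token, segment, mask)
    loss = loss_func(out, label)
    loss.backward()

    # 2) FGM attack (assumed helper): add r_adv = epsilon * grad / ||grad||
    #    to the word-embedding weights, keeping a backup of the originals
    fgm.attack()
    out_adv = model(token, segment, mask)
    loss_adv = loss_func(out_adv, label)
    loss_adv.backward()   # adversarial gradients accumulate onto the clean ones
    fgm.restore()         # assumed helper: put the original embeddings back

    # 3) one optimizer step on the accumulated (clean + adversarial) gradients
    optimizer.step()
    optimizer.zero_grad()
    return out_adv, loss_adv
```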