├── .github └── FUNDING.yml ├── IFL.md ├── LICENSE ├── README.md ├── _COCOGeneration ├── 1.Visualization.ipynb ├── 1.Visualization.py ├── 2.SplitGeneration.ipynb ├── 2.SplitGeneration.py ├── 3.DataGeneration.ipynb ├── 3.DataGeneration.py ├── 4.Analysis.ipynb ├── 4.Analysis.py ├── README.md ├── cocottributes_py3.jbl ├── mscoco_glt_crop.py └── mscoco_glt_generation.py ├── _ImageNetGeneration ├── 1.MoCoFeature.ipynb ├── 1.SupervisedFeature.ipynb ├── 1.UnsupervisedFeature.ipynb ├── 2.SplitGeneration.ipynb ├── 3.Analysis.ipynb ├── 4.ChangePath.ipynb ├── 5.visualize_tsne.ipynb ├── README.md ├── imagenet_glt_generation.py └── long-tail-distribution.pytorch ├── __init__.py ├── config ├── COCO_BL.yaml ├── COCO_LT.yaml ├── ColorMNIST_BL.yaml ├── ColorMNIST_LT.yaml ├── ImageNet_BL.yaml ├── ImageNet_LT.yaml ├── __init__.py └── algorithms_config.yaml ├── data ├── DT_COCO_LT.py ├── DT_ColorMNIST.py ├── DT_ImageNet_LT.py ├── Sampler_ClassAware.py ├── Sampler_MultiEnv.py ├── __init__.py └── dataloader.py ├── deprecated ├── _ColorMNISTGeneration │ └── 1.DataGeneration.ipynb └── __init__.py ├── figure ├── generalized-long-tail.jpg ├── glt_formulation.jpg ├── ifl.jpg ├── ifl_code.jpg ├── imagenet-glt-statistics.jpg ├── imagenet-glt-visualization.jpg ├── imagenet-glt.jpg ├── mscoco-glt-statistics.jpg ├── mscoco-glt-testgeneration.jpg ├── mscoco-glt-visualization.jpg └── mscoco-glt.jpg ├── main.py ├── models ├── ClassifierCOS.py ├── ClassifierFC.py ├── ClassifierLA.py ├── ClassifierLDAM.py ├── ClassifierLWS.py ├── ClassifierMultiHead.py ├── ClassifierRIDE.py ├── ClassifierTDE.py ├── ResNet.py ├── ResNet_BBN.py ├── ResNet_RIDE.py └── __init__.py ├── test_baseline.py ├── test_la.py ├── test_tde.py ├── train_baseline.py ├── train_bbn.py ├── train_center_dual.py ├── train_center_dual_mixup.py ├── train_center_ldam_dual.py ├── train_center_ride.py ├── train_center_ride_mixup.py ├── train_center_single.py ├── train_center_tade.py ├── train_center_triple.py ├── train_irm_dual.py ├── 
train_la.py ├── train_ldam.py ├── train_lff.py ├── train_mixup.py ├── train_ride.py ├── train_stage1.py ├── train_stage2.py ├── train_stage2_ride.py ├── train_tade.py ├── train_tde.py ├── utils ├── __init__.py ├── checkpoint_utils.py ├── general_utils.py ├── logger_utils.py ├── test_loader.py ├── train_loader.py └── training_utils.py └── visualization ├── figure1.ipynb ├── figure2.ipynb └── figure3.ipynb /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [KaihuaTang] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: tkhchipaomg 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: ['https://kaihuatang.github.io/donate'] 16 | -------------------------------------------------------------------------------- /IFL.md: -------------------------------------------------------------------------------- 1 | # Invariant Feature Learning 2 | 3 | ## The framework of the proposed IFL 4 | 5 |

Invariant Feature Learning.

6 |

Figure 1. Invariant Feature Learning.

7 | 8 | 9 | ## The pseudo code of IFL 10 | 11 |

Pseudo Code.

12 |

Figure 2. Pseudo Code.

-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NTUItive Dual License 2 | 3 | Copyright (c) 2022 Kaihua Tang 4 | 5 | NANYANG TECHNOLOGICAL UNIVERSITY - NTUITIVE PTE LTD (NTUITIVE) Dual License Agreement 6 | Non-Commercial Use Only 7 | This NTUITIVE License Agreement, including all exhibits ("NTUITIVE-LA") is a legal agreement between you and NTUITIVE (or “we”) located at 71 Nanyang Drive, NTU Innovation Centre, #01-109, Singapore 637722, a wholly owned subsidiary of Nanyang Technological University (“NTU”) for the software or data identified above, which may include source code, and any associated materials, text or speech files, associated media and "online" or electronic documentation and any updates we provide in our discretion (together, the "Software"). 8 | 9 | By installing, copying, or otherwise using this Software, found at *************** (researcher to provide link to software), you agree to be bound by the terms of this NTUITIVE-LA. If you do not agree, do not install copy or use the Software. The Software is protected by copyright and other intellectual property laws and is licensed, not sold. If you wish to obtain a commercial royalty bearing license to this software please contact us at XXX (provide email). 10 | 11 | SCOPE OF RIGHTS: 12 | You may use, copy, reproduce, and distribute this Software for any non-commercial purpose, subject to the restrictions in this NTUITIVE-LA. Some purposes which can be non-commercial are teaching, academic research, public demonstrations and personal experimentation. You may also distribute this Software with books or other teaching materials, or publish the Software on websites, that are intended to teach the use of the Software for academic or other non-commercial purposes. 13 | You may not use or distribute this Software or any derivative works in any form for commercial purposes. 
Examples of commercial purposes would be running business operations, licensing, leasing, or selling the Software, distributing the Software for use with commercial products, using the Software in the creation or use of commercial products or any other activity which purpose is to procure a commercial gain to you or others. 14 | If the Software includes source code or data, you may create derivative works of such portions of the Software and distribute the modified Software for non-commercial purposes, as provided herein. 15 | If you distribute the Software or any derivative works of the Software, you will distribute them under the same terms and conditions as in this license, and you will not grant other rights to the Software or derivative works that are different from those provided by this NTUITIVE-LA. 16 | If you have created derivative works of the Software, and distribute such derivative works, you will cause the modified files to carry prominent notices so that recipients know that they are not receiving the original Software. Such notices must state: (i) that you have changed the Software; and (ii) the date of any changes. 17 | 18 | You may not distribute this Software or any derivative works. 19 | In return, we simply require that you agree: 20 | 1. That you will not remove any copyright or other notices from the Software. 21 | 2. That if any of the Software is in binary format, you will not attempt to modify such portions of the Software, or to reverse engineer or decompile them, except and only to the extent authorized by applicable law. 22 | 3. That NTUITIVE is granted back, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer your modifications to and/or derivative works of the Software source code or data, for any purpose. 23 | 4. 
That any feedback about the Software provided by you to us is voluntarily given, and NTUITIVE shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. 24 | 5. THAT THE SOFTWARE COMES "AS IS", WITH NO WARRANTIES. THIS MEANS NO EXPRESS, IMPLIED OR STATUTORY WARRANTY, INCLUDING WITHOUT LIMITATION, WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, ANY WARRANTY AGAINST INTERFERENCE WITH YOUR ENJOYMENT OF THE SOFTWARE OR ANY WARRANTY OF TITLE OR NON-INFRINGEMENT. THERE IS NO WARRANTY THAT THIS SOFTWARE WILL FULFILL ANY OF YOUR PARTICULAR PURPOSES OR NEEDS. ALSO, YOU MUST PASS THIS DISCLAIMER ON WHENEVER YOU DISTRIBUTE THE SOFTWARE OR DERIVATIVE WORKS. 25 | 6. THAT NEITHER NTUITIVE NOR NTU NOR ANY CONTRIBUTOR TO THE SOFTWARE WILL BE LIABLE FOR ANY DAMAGES RELATED TO THE SOFTWARE OR THIS NTUITIVE-LA, INCLUDING DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL OR INCIDENTAL DAMAGES, TO THE MAXIMUM EXTENT THE LAW PERMITS, NO MATTER WHAT LEGAL THEORY IT IS BASED ON. ALSO, YOU MUST PASS THIS LIMITATION OF LIABILITY ON WHENEVER YOU DISTRIBUTE THE SOFTWARE OR DERIVATIVE WORKS. 26 | 7. That we have no duty of reasonable care or lack of negligence, and we are not obligated to (and will not) provide technical support for the Software. 27 | 8. That if you breach this NTUITIVE-LA or if you sue anyone over patents that you think may apply to or read on the Software or anyone's use of the Software, this NTUITIVE-LA (and your license and rights obtained herein) terminate automatically. Upon any such termination, you shall destroy all of your copies of the Software immediately. Sections 3, 4, 5, 6, 7, 8, 11 and 12 of this NTUITIVE-LA shall survive any termination of this NTUITIVE-LA. 28 | 9. That the patent rights, if any, granted to you in this NTUITIVE-LA only apply to the Software, not to any derivative works you make. 29 | 10. That the Software may be subject to U.S. 
export jurisdiction at the time it is licensed to you, and it may be subject to additional export or import laws in other places. You agree to comply with all such laws and regulations that may apply to the Software after delivery of the software to you. 30 | 11. That all rights not expressly granted to you in this NTUITIVE-LA are reserved. 31 | 12. That this NTUITIVE-LA shall be construed and controlled by the laws of the Republic of Singapore without regard to conflicts of law. If any provision of this NTUITIVE-LA shall be deemed unenforceable or contrary to law, the rest of this NTUITIVE-LA shall remain in full effect and interpreted in an enforceable manner that most nearly captures the intent of the original language. 32 | 33 | 34 | Do you accept all of the terms of the preceding NTUITIVE-LA license agreement? If you accept the terms, click “I Agree,” then “Next.” Otherwise click “Cancel.” 35 | 36 | Copyright (c) NTUITIVE. All rights reserved. 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /_COCOGeneration/1.Visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[2]: 5 | 6 | 7 | # Python 2 8 | #from sklearn.externals import joblib 9 | # Python 3 10 | import joblib 11 | 12 | from PIL import Image, ImageDraw 13 | from io import BytesIO 14 | import json 15 | import joblib 16 | import os 17 | import random 18 | 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | 22 | random.seed(25) 23 | 24 | COCO_IMAGE_PATH = '/data4/coco2014/images/' 25 | COCO_ATTRIBUTE_PATH = './cocottributes_py3.jbl' 26 | COCO_ANNOTATION_PATH = '/data4/coco2014/annotations/' 27 | 28 | 29 | # In[3]: 30 | 31 | 32 | cocottributes = joblib.load(COCO_ATTRIBUTE_PATH) 33 | 34 | 35 | # In[4]: 36 | 37 | 38 | def load_coco(root_path=COCO_ANNOTATION_PATH): 39 | # Load COCO Annotations in val2014 & train2014 40 | coco_data = {'images':[], 
'annotations':[]} 41 | with open(os.path.join(root_path, 'instances_train2014.json'), 'r') as f: 42 | train2014 = json.load(f) 43 | with open(os.path.join(root_path, 'instances_val2014.json'), 'r') as f: 44 | val2014 = json.load(f) 45 | coco_data['categories'] = train2014['categories'] 46 | coco_data['images'] += train2014['images'] 47 | coco_data['images'] += val2014['images'] 48 | coco_data['annotations'] += train2014['annotations'] 49 | coco_data['annotations'] += val2014['annotations'] 50 | return coco_data 51 | 52 | 53 | # In[5]: 54 | 55 | 56 | def id_to_path(data_path=COCO_IMAGE_PATH): 57 | id2path = {} 58 | splits = ['train2014', 'val2014'] 59 | for split in splits: 60 | for file in os.listdir(os.path.join(data_path, split)): 61 | if file.endswith(".jpg"): 62 | idx = int(file.split('.')[0].split('_')[-1]) 63 | id2path[idx] = os.path.join(data_path, split, file) 64 | return id2path 65 | 66 | 67 | # In[6]: 68 | 69 | 70 | def print_coco_attributes_instance(cocottributes, coco_data, id2path, ex_ind, sname): 71 | # List of COCO Attributes 72 | attr_details = sorted(cocottributes['attributes'], key=lambda x:x['id']) 73 | attr_names = [item['name'] for item in attr_details] 74 | 75 | # COCO Attributes instance ID for this example 76 | coco_attr_id = list(cocottributes['ann_vecs'].keys())[ex_ind] 77 | 78 | # COCO Attribute annotation vector, attributes in order sorted by dataset ID 79 | instance_attrs = cocottributes['ann_vecs'][coco_attr_id] 80 | 81 | # Print the image and positive attributes for this instance, attribute considered postive if worker vote is > 0.5 82 | pos_attrs = [a for ind, a in enumerate(attr_names) if instance_attrs[ind] > 0.5] 83 | coco_dataset_ann_id = cocottributes['patch_id_to_ann_id'][coco_attr_id] 84 | 85 | coco_annotation = [ann for ann in coco_data['annotations'] if ann['id'] == coco_dataset_ann_id][0] 86 | 87 | img_path = id2path[coco_annotation['image_id']] 88 | img = Image.open(img_path) 89 | polygon = 
coco_annotation['segmentation'][0] 90 | bbox = coco_annotation['bbox'] 91 | ImageDraw.Draw(img, 'RGBA').polygon(polygon, outline=(255,0,0), fill=(255,0,0,50)) 92 | ImageDraw.Draw(img, 'RGBA').rectangle(((bbox[0], bbox[1]), (bbox[0]+bbox[2], bbox[1]+bbox[3])), outline="red") 93 | img = img.crop((bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3])) 94 | img = np.array(img) 95 | category = [c['name'] for c in coco_data['categories'] if c['id'] == coco_annotation['category_id']][0] 96 | 97 | print_image_with_attributes(img, pos_attrs, category, sname) 98 | 99 | 100 | # In[7]: 101 | 102 | 103 | def print_image_with_attributes(img, attrs, category, sname): 104 | 105 | fig = plt.figure() 106 | plt.imshow(img) 107 | plt.axis('off') # clear x- and y-axes 108 | plt.title(category) 109 | for ind, a in enumerate(attrs): 110 | plt.text(min(img.shape[1]+10, 1000), (ind+1)*img.shape[1]*0.1, a, ha='left') 111 | 112 | #fig.savefig(sname, dpi = 300, bbox_inches='tight') 113 | 114 | 115 | # In[8]: 116 | 117 | 118 | coco_data = load_coco() 119 | id2path = id_to_path() 120 | 121 | 122 | # In[9]: 123 | 124 | 125 | ex_inds = [0,10,50,1000,2000,3000,4000,5000] 126 | sname = '/Users/tangkaihua/Desktop/example_cocottributes_annotation{}.jpg' 127 | for ex_ind in ex_inds: 128 | print_coco_attributes_instance(cocottributes, coco_data, id2path, ex_ind, sname.format(ex_ind)) 129 | 130 | 131 | # In[ ]: 132 | 133 | 134 | 135 | 136 | 137 | # In[ ]: 138 | 139 | 140 | 141 | 142 | 143 | # In[ ]: 144 | 145 | 146 | 147 | 148 | 149 | # In[ ]: 150 | 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /_COCOGeneration/3.DataGeneration.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | # Python 2 8 | #from sklearn.externals import joblib 9 | # Python 3 10 | import joblib 11 | 12 | from PIL import Image, ImageDraw 13 | from io import BytesIO 14 | 
import json 15 | import joblib 16 | import os 17 | import random 18 | 19 | import torch 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | 23 | random.seed(25) 24 | 25 | # annotation path 26 | ANNOTATION_LT = './coco_intra_lt_inter_lt.jbl' 27 | ANNOTATION_BL = './coco_intra_lt_inter_bl.jbl' 28 | COCO_IMAGE_PATH = '/data4/coco2014/images/' 29 | ROOT_PATH = '/data4/' 30 | 31 | 32 | # In[2]: 33 | 34 | 35 | def id_to_path(data_path=COCO_IMAGE_PATH): 36 | id2path = {} 37 | subpath = ['val2014', 'train2014'] 38 | for spath in subpath: 39 | for file in os.listdir(data_path + spath): 40 | if file.endswith(".jpg"): 41 | idx = int(file.split('.')[0].split('_')[-1]) 42 | id2path[idx] = os.path.join(data_path, spath, file) 43 | return id2path 44 | 45 | 46 | # In[ ]: 47 | 48 | 49 | 50 | 51 | 52 | # In[3]: 53 | 54 | 55 | def generate_images_labels(): 56 | annotations = {} 57 | cat2id = cocottributes_all['cat2id'] 58 | id2cat = {i:cat for cat,i in cat2id.items()} 59 | annotations['cat2id'] = cat2id 60 | annotations['id2cat'] = id2cat 61 | annotations['key2path'] = {} 62 | 63 | train_count_array = get_att_count(cocottributes_all, 'train') 64 | 65 | for setname in ['train', 'val', 'test_lt', 'test_bl', 'test_bbl']: 66 | annotations[setname] = {'label':{}, 'frequency':{}, 'attribute':{}, 'path':{}, 'attribute_score':{}} 67 | # check validity 68 | all_keys = list(cocottributes_all[setname]['label'].keys()) 69 | if len(all_keys) == 0: 70 | print('Skip {}'.format(setname)) 71 | continue 72 | 73 | # attribute distribution 74 | annotations[setname]['attribute_dist'] = get_att_count(cocottributes_all, setname) 75 | for cat_id in annotations[setname]['attribute_dist'].keys(): 76 | annotations[setname]['attribute_dist'][cat_id] = annotations[setname]['attribute_dist'][cat_id].tolist() 77 | 78 | # find attribute threshold 79 | for coco_attr_key in all_keys: 80 | cat_id = cocottributes_all[setname]['label'][coco_attr_key] 81 | att_array = 
cocottributes_all[setname]['attribute'][coco_attr_key] 82 | base_score = normalize_vector(train_count_array[cat_id]) 83 | attr_score = normalize_vector((torch.from_numpy(att_array) > 0.5).float()) 84 | annotations[setname]['attribute_score'][coco_attr_key] = (base_score * attr_score).sum().item() 85 | att_scores = list(annotations[setname]['attribute_score'].values()) 86 | att_scores.sort(reverse=True) 87 | attribute_high_mid_thres = att_scores[len(att_scores) // 3] 88 | attribute_mid_low_thres = att_scores[len(att_scores) // 3 * 2] 89 | 90 | for i, coco_attr_key in enumerate(all_keys): 91 | if (i%1000 == 0): 92 | print('==== Processing : {}'.format(i/len(all_keys))) 93 | # generate image 94 | print_coco_attributes_instance(cocottributes_all, id2path, coco_attr_key, OUTPUT_PATH.format(coco_attr_key), setname) 95 | # generate label 96 | annotations[setname]['label'][coco_attr_key] = cocottributes_all[setname]['label'][coco_attr_key] 97 | annotations[setname]['path'][coco_attr_key] = OUTPUT_PATH.format(coco_attr_key) 98 | annotations[setname]['frequency'][coco_attr_key] = cocottributes_all[setname]['frequency'][coco_attr_key] 99 | if annotations[setname]['attribute_score'][coco_attr_key] > attribute_high_mid_thres: 100 | annotations[setname]['attribute'][coco_attr_key] = 0 101 | elif annotations[setname]['attribute_score'][coco_attr_key] > attribute_mid_low_thres: 102 | annotations[setname]['attribute'][coco_attr_key] = 1 103 | else: 104 | annotations[setname]['attribute'][coco_attr_key] = 2 105 | 106 | with open(OUTPUT_ANNO, 'w') as outfile: 107 | json.dump(annotations, outfile) 108 | 109 | 110 | def normalize_vector(vector): 111 | output = vector / (vector.sum() + 1e-9) 112 | return output 113 | 114 | 115 | def get_att_count(cocottributes, setname): 116 | split_data = cocottributes[setname] 117 | cat2id = cocottributes['cat2id'] 118 | 119 | # update array count 120 | count_array = {} 121 | for item in set(cat2id.values()): 122 | count_array[item] = 
torch.FloatTensor([0 for i in range(len(cocottributes['attributes']))]) 123 | for key in split_data['label'].keys(): 124 | cat_id = split_data['label'][key] 125 | att_array = split_data['attribute'][key] 126 | count_array[cat_id] = count_array[cat_id] + (torch.from_numpy(att_array) > 0.5).float() 127 | 128 | return count_array 129 | 130 | 131 | def print_coco_attributes_instance(cocottributes, id2path, coco_attr_id, sname, setname): 132 | # List of COCO Attributes 133 | coco_annotation = cocottributes['annotations'][coco_attr_id] 134 | img_path = id2path[coco_annotation['image_id']] 135 | img = Image.open(img_path) 136 | bbox = coco_annotation['bbox'] 137 | 138 | # crop the object bounding box 139 | if bbox[2] < 100: 140 | x1 = max(bbox[0]-50,0) 141 | x2 = min(bbox[0]+50+bbox[2],img.size[0]) 142 | else: 143 | x1 = max(bbox[0]-bbox[2]*0.2,0) 144 | x2 = min(bbox[0]+1.2*bbox[2],img.size[0]) 145 | 146 | if bbox[3] < 100: 147 | y1 = max(bbox[1]-50,0) 148 | y2 = min(bbox[1]+50+bbox[3],img.size[1]) 149 | else: 150 | y1 = max(bbox[1]-bbox[3]*0.2,0) 151 | y2 = min(bbox[1]+1.2*bbox[3],img.size[1]) 152 | 153 | img = img.crop((x1, y1, x2, y2)) 154 | 155 | # padding the rectangular boxes to square boxes 156 | w, h = img.size 157 | pad_size = (max(w,h) - min(w,h))/2 158 | if w > h: 159 | img = img.crop((0, -pad_size, w, h+pad_size)) 160 | else: 161 | img = img.crop((-pad_size, 0, w+pad_size, h)) 162 | 163 | # save image 164 | img.save(sname) 165 | 166 | 167 | # In[ ]: 168 | 169 | 170 | 171 | 172 | 173 | # In[ ]: 174 | 175 | 176 | 177 | 178 | 179 | # In[4]: 180 | 181 | 182 | DATA_TYPE = 'coco_bl' 183 | # output path 184 | OUTPUT_PATH = ROOT_PATH + DATA_TYPE + '/images/{}.jpg' 185 | OUTPUT_ANNO = ROOT_PATH + DATA_TYPE + '/annotations/annotation.json' 186 | 187 | cocottributes_all = joblib.load(ANNOTATION_BL) 188 | 189 | id2path = id_to_path() 190 | generate_images_labels() 191 | 192 | 193 | # In[5]: 194 | 195 | 196 | DATA_TYPE = 'coco_lt' #'coco_lt' / 'coco_half_lt' 197 | # output 
path 198 | OUTPUT_PATH = ROOT_PATH + DATA_TYPE + '/images/{}.jpg' 199 | OUTPUT_ANNO = ROOT_PATH + DATA_TYPE + '/annotations/annotation.json' 200 | 201 | cocottributes_all = joblib.load(ANNOTATION_LT) 202 | 203 | id2path = id_to_path() 204 | generate_images_labels() 205 | 206 | 207 | # In[ ]: 208 | 209 | 210 | 211 | 212 | 213 | # In[ ]: 214 | 215 | 216 | 217 | 218 | 219 | # In[ ]: 220 | 221 | 222 | 223 | 224 | 225 | # In[ ]: 226 | 227 | 228 | 229 | 230 | 231 | # In[ ]: 232 | 233 | 234 | 235 | 236 | 237 | # In[ ]: 238 | 239 | 240 | 241 | 242 | 243 | # In[ ]: 244 | 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /_COCOGeneration/4.Analysis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | # Python 2 8 | #from sklearn.externals import joblib 9 | # Python 3 10 | import joblib 11 | 12 | from PIL import Image, ImageDraw 13 | from io import BytesIO 14 | import json 15 | import joblib 16 | import os 17 | import random 18 | import torch 19 | 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | 23 | random.seed(25) 24 | 25 | 26 | ANNOTATION_LT = '/data4/coco_lt/annotations/annotation.json' 27 | ANNOTATION_BL = '/data4/coco_bl/annotations/annotation.json' 28 | 29 | 30 | # In[2]: 31 | 32 | 33 | lt_annotation = json.load(open(ANNOTATION_LT)) 34 | bl_annotation = json.load(open(ANNOTATION_BL)) 35 | 36 | 37 | # In[3]: 38 | 39 | 40 | def show_statistics(vals): 41 | # sort your values in descending order 42 | indSort = np.argsort(vals)[::-1] 43 | # rearrange your data 44 | att_values = np.array(vals)[indSort] 45 | indexes = np.arange(len(vals)) 46 | bar_width = 0.35 47 | plt.bar(indexes, att_values) 48 | plt.show() 49 | 50 | 51 | # In[4]: 52 | 53 | 54 | lt_annotation.keys() 55 | 56 | 57 | # In[5]: 58 | 59 | 60 | def duplication_check(annotation): 61 | count = 0 62 | all_img = [] 63 | all_set = ['train', 
'val', 'test_lt', 'test_bl', 'test_bbl'] 64 | 65 | for key in all_set: 66 | count += len(annotation[key]['label']) 67 | all_img = all_img + list(annotation[key]['label'].keys()) 68 | 69 | all_img = set(all_img) 70 | print('Counting Result: {}'.format(count)) 71 | print('Number of Images: {}'.format(len(all_img))) 72 | 73 | 74 | # In[6]: 75 | 76 | 77 | duplication_check(lt_annotation) 78 | 79 | 80 | # In[7]: 81 | 82 | 83 | duplication_check(bl_annotation) 84 | 85 | 86 | # In[8]: 87 | 88 | 89 | def count_label_dist(dataset, num_cls=29): 90 | cls_sizes = [0] * num_cls 91 | for key, val in dataset['label'].items(): 92 | cls_sizes[int(val)] += 1 93 | print(len(dataset['label'])) 94 | return cls_sizes 95 | 96 | 97 | # In[9]: 98 | 99 | 100 | def count_attribuate_dist(dataset, num_cls=204): 101 | if 'attribute_dist' in dataset: 102 | att_all = list(dataset['attribute_dist'].values()) 103 | att_all = [torch.FloatTensor(item) for item in att_all] 104 | att_count = sum(att_all) 105 | print('std: ', (att_count / att_count.sum()).std().item()) 106 | att_count = att_count.tolist() 107 | att_count.sort(reverse=True) 108 | return att_count 109 | else: 110 | return [0] * num_cls 111 | 112 | 113 | # In[10]: 114 | 115 | 116 | show_statistics(count_label_dist(lt_annotation['train'])) 117 | 118 | 119 | # In[11]: 120 | 121 | 122 | show_statistics(count_label_dist(lt_annotation['val'])) 123 | 124 | 125 | # In[12]: 126 | 127 | 128 | show_statistics(count_label_dist(lt_annotation['test_lt'])) 129 | 130 | 131 | # In[13]: 132 | 133 | 134 | show_statistics(count_label_dist(lt_annotation['test_bl'])) 135 | 136 | 137 | # In[14]: 138 | 139 | 140 | show_statistics(count_label_dist(lt_annotation['test_bbl'])) 141 | 142 | 143 | # In[ ]: 144 | 145 | 146 | 147 | 148 | 149 | # In[15]: 150 | 151 | 152 | show_statistics(count_label_dist(bl_annotation['train'])) 153 | 154 | 155 | # In[16]: 156 | 157 | 158 | show_statistics(count_label_dist(bl_annotation['val'])) 159 | 160 | 161 | # In[17]: 162 | 163 | 
164 | show_statistics(count_label_dist(bl_annotation['test_lt'])) 165 | 166 | 167 | # In[18]: 168 | 169 | 170 | show_statistics(count_label_dist(bl_annotation['test_bl'])) 171 | 172 | 173 | # In[19]: 174 | 175 | 176 | show_statistics(count_label_dist(bl_annotation['test_bbl'])) 177 | 178 | 179 | # In[ ]: 180 | 181 | 182 | 183 | 184 | 185 | # In[20]: 186 | 187 | 188 | show_statistics(count_attribuate_dist(lt_annotation['train'])) 189 | 190 | 191 | # In[21]: 192 | 193 | 194 | show_statistics(count_attribuate_dist(lt_annotation['val'])) 195 | 196 | 197 | # In[22]: 198 | 199 | 200 | show_statistics(count_attribuate_dist(lt_annotation['test_lt'])) 201 | 202 | 203 | # In[23]: 204 | 205 | 206 | show_statistics(count_attribuate_dist(lt_annotation['test_bl'])) 207 | 208 | 209 | # In[24]: 210 | 211 | 212 | show_statistics(count_attribuate_dist(lt_annotation['test_bbl'])) 213 | 214 | 215 | # In[ ]: 216 | 217 | 218 | 219 | 220 | 221 | # In[25]: 222 | 223 | 224 | show_statistics(count_attribuate_dist(bl_annotation['train'])) 225 | 226 | 227 | # In[26]: 228 | 229 | 230 | show_statistics(count_attribuate_dist(bl_annotation['val'])) 231 | 232 | 233 | # In[27]: 234 | 235 | 236 | show_statistics(count_attribuate_dist(bl_annotation['test_lt'])) 237 | 238 | 239 | # In[28]: 240 | 241 | 242 | show_statistics(count_attribuate_dist(bl_annotation['test_bl'])) 243 | 244 | 245 | # In[29]: 246 | 247 | 248 | show_statistics(count_attribuate_dist(bl_annotation['test_bbl'])) 249 | 250 | 251 | # In[ ]: 252 | 253 | 254 | 255 | 256 | 257 | # In[ ]: 258 | 259 | 260 | 261 | 262 | 263 | # In[30]: 264 | 265 | 266 | JBL_LT = './coco_intra_lt_inter_lt.jbl' 267 | JBL_BL = './coco_intra_lt_inter_bl.jbl' 268 | cocottributes_lt = joblib.load(JBL_LT) 269 | cocottributes_bl = joblib.load(JBL_BL) 270 | 271 | 272 | # In[48]: 273 | 274 | 275 | def print_image_by_index(annotation, cocottributes, setname='train', index=1): 276 | key = list(annotation[setname]['label'].keys())[index] 277 | img_path = 
annotation[setname]['path'][key] 278 | img = Image.open(img_path) 279 | 280 | att_array = cocottributes[setname]['attribute'][key] 281 | # List of COCO Attributes 282 | attr_details = sorted(cocottributes['attributes'], key=lambda x:x['id']) 283 | attr_names = [item['name'] for item in attr_details] 284 | pos_attrs = [a for ind, a in enumerate(attr_names) if att_array[ind] > 0.5] 285 | # 286 | id2cat = {i:cat for cat,i in cocottributes['cat2id'].items()} 287 | category = id2cat[annotation[setname]['label'][key]] 288 | print_image_with_attributes(img, pos_attrs, category) 289 | 290 | def print_image_with_attributes(img, attrs, category): 291 | 292 | fig = plt.figure() 293 | plt.imshow(img) 294 | plt.axis('off') # clear x- and y-axes 295 | plt.title(category) 296 | for ind, a in enumerate(attrs): 297 | plt.text(min(img.size[1]+10, 1000), (ind+1)*img.size[1]*0.1, a, ha='left') 298 | 299 | 300 | # In[49]: 301 | 302 | 303 | ex_inds = [0,10,50,1000,2000,3000,4000,5000] 304 | 305 | 306 | # In[50]: 307 | 308 | 309 | for ex_ind in ex_inds: 310 | print_image_by_index(lt_annotation, cocottributes_lt, setname='train', index=ex_ind) 311 | 312 | 313 | # In[51]: 314 | 315 | 316 | for ex_ind in ex_inds: 317 | print_image_by_index(bl_annotation, cocottributes_bl, setname='train', index=ex_ind) 318 | 319 | 320 | # In[ ]: 321 | 322 | 323 | lt_annotation[''] 324 | 325 | -------------------------------------------------------------------------------- /_COCOGeneration/README.md: -------------------------------------------------------------------------------- 1 | # MSCOCO-GLT Dataset Generation 2 | 3 | ## Introduction 4 |

Generating MSCOCO-GLT.

5 |

Figure 1. Balancing the Attribute Distribution of MSCOCO-Attribute Through Minimized STD.

6 | 7 | For MSCOCO-GLT, although we can directly obtain the attribute annotation from [MSCOCO-Attribute](https://github.com/genp/cocottributes), each object may have multiple attributes, making strictly balancing the attribute distribution prohibitive, as every time we sample an object with a rare attribute, it usually contains co-occurred frequent attributes as well. Therefore, as illustrated in the above Figure 1, we collect a subset of images with relatively more balanced attribute distributions (by minimizing the STD of attributes) to serve as the attribute-wise balanced test set. The other long-tailed and class-wise splits are constructed in the same way as in ImageNet-GLT. 8 | 9 | Note that the size of the Train-CBL split is significantly smaller than Train-GLT, because one single super head class “person” contains over 40% of the entire data, making class-wise data re-balancing extremely expensive. The worse results on Train-CBL further prove the importance of long-tailed classification in real-world applications, as large long-tailed datasets are better than small balanced counterparts. 10 | 11 | ## Steps for MSCOCO-GLT Generation 12 | 1. Download the [MSCOCO dataset](https://cocodataset.org/#download) 13 | 2. Download the [MSCOCO-Attribute Annotations](https://github.com/genp/cocottributes) (you can skip this step, as it has already been included in this folder, i.e., cocottributes_py3.jbl) 14 | 3. Construct training sets and testing sets for GLT, see [SplitGeneration.ipynb](2.SplitGeneration.ipynb). Note that we have three evaluation protocols: 1) Class-wise Long Tail with (Train-GLT, Test-CBL), 2) Attribute-wise Long Tail with (Train-CBL, Test-GBL), and 3) Generalized Long Tail with (Train-GLT, Test-GBL), where the Class-wise Long Tail protocol and the Generalized Long Tail protocol share the same training set Train-GLT, so there are two annotation files corresponding to two different training sets. 15 | 1. coco_intra_lt_inter_lt.json: 16 | 1. 
Train-GLT: classes and attribute clusters are both long-tailed 17 | 2. Test-GLT: same as above 18 | 3. Test-CBL: classes are balanced but the attribute clusters are still i.i.d. sampled within each class, i.e., long-tailed intra-class distribution 19 | 4. Test-GBL: both classes and pretext attributes are balanced 20 | 2. coco_intra_lt_inter_bl.json: 21 | 1. Train-CBL: classes are balanced, but the attribute clusters are still i.i.d. sampled within each class, i.e., long-tailed intra-class distribution 22 | 2. Test-CBL: same as above 23 | 3. Test-GBL: both classes and pretext attributes are balanced 24 | 4. We further cropped each object from the original MSCOCO images to generate the image data, see [DataGeneration.ipynb](3.DataGeneration.ipynb). 25 | 26 | 27 | ## Download Our Generated Annotation Files 28 | You can directly download our generated json files from [the link](https://1drv.ms/u/s!AmRLLNf6bzcitKlpBoTerorv5yaeIw?e=0bw83U). 29 | 30 | Since OneDrive links might be broken in mainland China, we also provide the following alternate link using BaiduNetDisk: 31 | 32 | Link: https://pan.baidu.com/s/1VSpgMVfMFnhQ0RTMzKNIyw Extraction code: 1234 33 | 34 | ## Commands 35 | Run the following command to generate the MSCOCO-GLT annotations: 1) coco_intra_lt_inter_lt.json and 2) coco_intra_lt_inter_bl.json 36 | ``` 37 | python mscoco_glt_generation.py --data_path YOUR_COCO_IMAGE_FOLDER --anno_path YOUR_COCO_ANNOTATION_FOLDER --attribute_path ./cocottributes_py3.jbl 38 | ``` 39 | Run the following command to crop objects from MSCOCO images: 40 | ``` 41 | python mscoco_glt_crop.py --data_path YOUR_COCO_IMAGE_FOLDER --output_path OUTPUT_FOLDER 42 | ``` 43 | 44 | ## Visualization of Attributes and Dataset Statistics 45 | 46 |

Visualization of MSCOCO-GLT objects.

47 |

Figure 2. Examples of objects with attributes in MSCOCO-GLT.

48 | 49 |

Generating Test-GBL for MSCOCO-GLT.

50 |

Figure 3. The algorithm used to generate Test-GBL for MSCOCO-GLT.

51 | 52 | 53 |

Statistics of MSCOCO-GLT Dataset.

54 |

Figure 4. Class distributions and attribute distributions for each split of the MSCOCO-GLT benchmark, where the most frequent category is “person”. Although we cannot strictly balance the attribute distribution in the Test-GBL split, due to the fact that both head attributes and tail attributes can co-occur in one object, the selected Test-GBL has a lower standard deviation of attributes than the other splits.

55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /_COCOGeneration/cocottributes_py3.jbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/_COCOGeneration/cocottributes_py3.jbl -------------------------------------------------------------------------------- /_COCOGeneration/mscoco_glt_crop.py: -------------------------------------------------------------------------------- 1 | # Python 2 2 | #from sklearn.externals import joblib 3 | # Python 3 4 | import joblib 5 | 6 | import argparse 7 | from PIL import Image, ImageDraw 8 | from io import BytesIO 9 | import json 10 | import joblib 11 | import os 12 | import random 13 | 14 | import torch 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | 18 | # annotation path 19 | ANNOTATION_LT = './coco_intra_lt_inter_lt.jbl' 20 | ANNOTATION_BL = './coco_intra_lt_inter_bl.jbl' 21 | ROOT_PATH = '/data4/' 22 | 23 | 24 | # ============================================================================ 25 | # argument parser 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--data_path', default='/data4/coco2014/images/', type=str, help='indicate the path of images.') 28 | parser.add_argument('--output_path', default='/data4/coco_glt/', type=str, help='output path.') 29 | parser.add_argument('--seed', default=25, type=int, help='Fix the random seed for reproduction. 
args = parser.parse_args()


def id_to_path(data_path=args.data_path):
    """Build a mapping from MSCOCO image id -> absolute image path.

    Scans the 'val2014' and 'train2014' sub-folders of ``data_path`` and
    parses the numeric image id out of each '*.jpg' filename, e.g.
    'COCO_train2014_000000123456.jpg' -> 123456.

    NOTE: the default is bound to ``args.data_path`` at definition time,
    which is safe here because ``args`` is parsed just above.
    """
    id2path = {}
    for subdir in ['val2014', 'train2014']:
        # use os.path.join consistently (the original mixed string
        # concatenation and os.path.join)
        for fname in os.listdir(os.path.join(data_path, subdir)):
            if fname.endswith(".jpg"):
                idx = int(fname.split('.')[0].split('_')[-1])
                id2path[idx] = os.path.join(data_path, subdir, fname)
    return id2path


def generate_images_labels(output_img_path, output_ano_path, cocottributes_all):
    """Crop every object of every split to disk and dump a json annotation file.

    Args:
        output_img_path: format string for cropped-object images,
            e.g. '<dir>/{}.jpg', formatted with the attribute-annotation key.
        output_ano_path: destination path of the generated json annotation.
        cocottributes_all: split dictionary produced by
            mscoco_glt_generation.py (keys: 'cat2id', 'attributes',
            'annotations', plus one sub-dict per split name).

    NOTE(review): relies on the module-level ``id2path`` global being
    populated (via ``id_to_path()``) before this function is called —
    confirm at the call sites at the bottom of this script.
    """
    annotations = {}
    cat2id = cocottributes_all['cat2id']
    id2cat = {i: cat for cat, i in cat2id.items()}
    annotations['cat2id'] = cat2id
    annotations['id2cat'] = id2cat
    annotations['key2path'] = {}

    # attribute counts of the training split serve as the reference
    # distribution when scoring each object's attributes below
    train_count_array = get_att_count(cocottributes_all, 'train')

    for setname in ['train', 'val', 'test_lt', 'test_bl', 'test_bbl']:
        annotations[setname] = {'label': {}, 'frequency': {}, 'attribute': {}, 'path': {}, 'attribute_score': {}}
        # skip splits that are empty in this annotation file
        all_keys = list(cocottributes_all[setname]['label'].keys())
        if len(all_keys) == 0:
            print('Skip {}'.format(setname))
            continue

        # per-class attribute distribution, converted to json-serializable lists
        annotations[setname]['attribute_dist'] = get_att_count(cocottributes_all, setname)
        for cat_id in annotations[setname]['attribute_dist'].keys():
            annotations[setname]['attribute_dist'][cat_id] = annotations[setname]['attribute_dist'][cat_id].tolist()

        # score each object by how frequent its (binarized) attributes are
        # within its class, then split the scores into terciles
        for coco_attr_key in all_keys:
            cat_id = cocottributes_all[setname]['label'][coco_attr_key]
            att_array = cocottributes_all[setname]['attribute'][coco_attr_key]
            base_score = normalize_vector(train_count_array[cat_id])
            attr_score = normalize_vector((torch.from_numpy(att_array) > 0.5).float())
            annotations[setname]['attribute_score'][coco_attr_key] = (base_score * attr_score).sum().item()
        att_scores = list(annotations[setname]['attribute_score'].values())
        att_scores.sort(reverse=True)
        attribute_high_mid_thres = att_scores[len(att_scores) // 3]
        attribute_mid_low_thres = att_scores[len(att_scores) // 3 * 2]

        for i, coco_attr_key in enumerate(all_keys):
            if (i % 1000 == 0):
                print('==== Processing : {}'.format(i / len(all_keys)))
            # crop & save the object image
            print_coco_attributes_instance(cocottributes_all, id2path, coco_attr_key, output_img_path.format(coco_attr_key), setname)
            # record label / path / frequency / attribute tercile
            # (0 = head attributes, 1 = mid, 2 = tail)
            annotations[setname]['label'][coco_attr_key] = cocottributes_all[setname]['label'][coco_attr_key]
            annotations[setname]['path'][coco_attr_key] = output_img_path.format(coco_attr_key)
            annotations[setname]['frequency'][coco_attr_key] = cocottributes_all[setname]['frequency'][coco_attr_key]
            if annotations[setname]['attribute_score'][coco_attr_key] > attribute_high_mid_thres:
                annotations[setname]['attribute'][coco_attr_key] = 0
            elif annotations[setname]['attribute_score'][coco_attr_key] > attribute_mid_low_thres:
                annotations[setname]['attribute'][coco_attr_key] = 1
            else:
                annotations[setname]['attribute'][coco_attr_key] = 2

    with open(output_ano_path, 'w') as outfile:
        json.dump(annotations, outfile)


def normalize_vector(vector):
    """Return ``vector`` scaled to sum to 1 (epsilon guards an all-zero input)."""
    output = vector / (vector.sum() + 1e-9)
    return output


def get_att_count(cocottributes, setname):
    """Count binarized attribute occurrences per class for one split.

    Args:
        cocottributes: annotation dict with 'cat2id', 'attributes' and a
            ``setname`` sub-dict holding 'label' and 'attribute' maps.
        setname: split name, e.g. 'train' or 'test_bbl'.

    Returns:
        dict mapping class id -> FloatTensor of length len(attributes),
        where each entry counts objects whose attribute score exceeds 0.5.
    """
    split_data = cocottributes[setname]
    cat2id = cocottributes['cat2id']

    # one zero-initialized counter per class
    count_array = {}
    for item in set(cat2id.values()):
        count_array[item] = torch.FloatTensor([0 for i in range(len(cocottributes['attributes']))])
    for key in split_data['label'].keys():
        cat_id = split_data['label'][key]
        att_array = split_data['attribute'][key]
        # binarize the per-object attribute scores at 0.5 before counting
        count_array[cat_id] = count_array[cat_id] + (torch.from_numpy(att_array) > 0.5).float()

    return count_array
def print_coco_attributes_instance(cocottributes, id2path, coco_attr_id, sname, setname):
    """Crop one annotated object out of its MSCOCO image and save it to ``sname``.

    Args:
        cocottributes: annotation dict whose 'annotations' entry maps the
            attribute-annotation id to a COCO instance annotation.
        id2path: mapping image id -> image file path (see ``id_to_path``).
        coco_attr_id: key into ``cocottributes['annotations']``.
        sname: output file path for the cropped image.
        setname: split name; unused here, kept for interface compatibility.
    """
    coco_annotation = cocottributes['annotations'][coco_attr_id]
    img_path = id2path[coco_annotation['image_id']]
    img = Image.open(img_path)
    bbox = coco_annotation['bbox']  # [x, y, w, h] in pixels

    # Expand the box before cropping: small boxes (< 100 px) receive a fixed
    # 50 px margin, larger ones a proportional 20% margin, clamped to the image.
    if bbox[2] < 100:
        x1 = max(bbox[0] - 50, 0)
        x2 = min(bbox[0] + 50 + bbox[2], img.size[0])
    else:
        x1 = max(bbox[0] - bbox[2] * 0.2, 0)
        x2 = min(bbox[0] + 1.2 * bbox[2], img.size[0])

    if bbox[3] < 100:
        y1 = max(bbox[1] - 50, 0)
        y2 = min(bbox[1] + 50 + bbox[3], img.size[1])
    else:
        y1 = max(bbox[1] - bbox[3] * 0.2, 0)
        y2 = min(bbox[1] + 1.2 * bbox[3], img.size[1])

    img = img.crop((x1, y1, x2, y2))

    # Pad the rectangular crop to a square by symmetrically extending the
    # shorter side (PIL fills the out-of-bounds area with black).
    w, h = img.size
    pad_size = (max(w, h) - min(w, h)) / 2
    if w > h:
        img = img.crop((0, -pad_size, w, h + pad_size))
    else:
        img = img.crop((-pad_size, 0, w + pad_size, h))

    # save image
    img.save(sname)


DATA_TYPE = 'coco_bl'  # NOTE(review): unused constant, kept for compatibility

# ---------------- script entry: generate Train-GLT ----------------
# os.makedirs(..., exist_ok=True) replaces the previous chain of duplicated
# `if not os.path.exists(...): os.mkdir(...)` checks (args.output_path was
# even tested twice in the original).
os.makedirs(os.path.join(args.output_path, 'coco_glt', 'images'), exist_ok=True)
output_img_path = os.path.join(args.output_path, 'coco_glt', 'images') + '/{}.jpg'
output_ano_path = os.path.join(args.output_path, 'coco_glt', 'annotation.json')

cocottributes_all = joblib.load(ANNOTATION_LT)

id2path = id_to_path()
generate_images_labels(output_img_path, output_ano_path, cocottributes_all)


# generate Train-CBL
os.path.exists(os.path.join(args.output_path, 'coco_cbl')): 181 | os.mkdir(os.path.join(args.output_path, 'coco_cbl')) 182 | 183 | if not os.path.exists(os.path.join(args.output_path, 'coco_cbl', 'images')): 184 | os.mkdir(os.path.join(args.output_path, 'coco_cbl', 'images')) 185 | 186 | if not os.path.exists(args.output_path): 187 | os.mkdir(args.output_path) 188 | output_img_path = os.path.join(args.output_path, 'coco_cbl', 'images') + '/{}.jpg' 189 | output_ano_path = os.path.join(args.output_path, 'coco_cbl', 'annotation.json') 190 | 191 | cocottributes_all = joblib.load(ANNOTATION_BL) 192 | 193 | id2path = id_to_path() 194 | generate_images_labels(output_img_path, output_ano_path, cocottributes_all) 195 | -------------------------------------------------------------------------------- /_ImageNetGeneration/4.ChangePath.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "id": "60106164", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import json\n", 11 | "\n", 12 | "LT_PATH = './imagenet_sup_intra_lt_inter_lt.json'\n", 13 | "BL_PATH = './imagenet_sup_intra_lt_inter_bl.json'\n", 14 | "\n", 15 | "OLD_PATH = '/data4'\n", 16 | "NEW_PATH = '/home/kaihua.tkh/datasets'" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 9, 22 | "id": "05c1f024", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "anno_file = json.load(open(LT_PATH))\n", 27 | "\n", 28 | "for split in ('train', 'val', 'test_lt', 'test_bl', 'test_bbl'):\n", 29 | " split_new_data = {}\n", 30 | " split_data = anno_file[split]\n", 31 | " for k_type, val in split_data.items():\n", 32 | " split_new_data[k_type] = {}\n", 33 | " for path, label in val.items():\n", 34 | " assert path.startswith(OLD_PATH) \n", 35 | " new_path = NEW_PATH + path[len(OLD_PATH):]\n", 36 | " split_new_data[k_type][new_path] = label\n", 37 | " anno_file[split] = 
split_new_data\n", 38 | "\n", 39 | "with open(LT_PATH, 'w') as outfile:\n", 40 | " json.dump(anno_file, outfile)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 10, 46 | "id": "b843650f", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "anno_file = json.load(open(BL_PATH))\n", 51 | "\n", 52 | "for split in ('train', 'val', 'test_lt', 'test_bl', 'test_bbl'):\n", 53 | " split_new_data = {}\n", 54 | " split_data = anno_file[split]\n", 55 | " for k_type, val in split_data.items():\n", 56 | " split_new_data[k_type] = {}\n", 57 | " for path, label in val.items():\n", 58 | " assert path.startswith(OLD_PATH) \n", 59 | " new_path = NEW_PATH + path[len(OLD_PATH):]\n", 60 | " split_new_data[k_type][new_path] = label\n", 61 | " anno_file[split] = split_new_data\n", 62 | "\n", 63 | "with open(BL_PATH, 'w') as outfile:\n", 64 | " json.dump(anno_file, outfile)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "5a7bdaf8", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [] 74 | } 75 | ], 76 | "metadata": { 77 | "kernelspec": { 78 | "display_name": "Python 3", 79 | "language": "python", 80 | "name": "python3" 81 | }, 82 | "language_info": { 83 | "codemirror_mode": { 84 | "name": "ipython", 85 | "version": 3 86 | }, 87 | "file_extension": ".py", 88 | "mimetype": "text/x-python", 89 | "name": "python", 90 | "nbconvert_exporter": "python", 91 | "pygments_lexer": "ipython3", 92 | "version": "3.6.10" 93 | } 94 | }, 95 | "nbformat": 4, 96 | "nbformat_minor": 5 97 | } 98 | -------------------------------------------------------------------------------- /_ImageNetGeneration/README.md: -------------------------------------------------------------------------------- 1 | # ImageNet-GLT Dataset Generation 2 | 3 | ## Introduction 4 | 5 |

Generating ImageNet-GLT.

6 |

Figure 1. Collecting an “Attribute-Wise Balanced” Test Set for ImageNet.

7 | 8 | For [ImageNet](https://ieeexplore.ieee.org/abstract/document/5206848) dataset, there is no ground-truth attribute labels for us to monitor the intra-class distribution and construct the GLT benchmark. However, attribute-wise long-tailed distribution, i.e., intra-class imbalance, is commonly existing in all kinds of data. Therefore, we can apply the clustering algorithms (e.g. K-Means in our project) to create a set of clusters within each class as ``pretext attributes''. In other words, each cluster represents a meta attribute layout for this class, e.g., one cluster for cat category can be ginger cat in house, another cluster for cat category is black cat on street, etc. 9 | 10 | Note that our clustered attributes here are purely constructed based on intra-class feature distribution, so it stands for all kinds of factors causing the intra-class variance, including object-level attributes like textures, or image-level attributes like contexts. 11 | 12 | ## Steps for ImageNet-GLT Generation 13 | 1. Download the [ImageNet dataset](https://image-net.org/download.php) 14 | 2. Use a pre-trained model to extract features (our default choice is ii: A torchvision Pre-trained ResNet): 15 | 1. [MoCo Pre-trained Model](https://github.com/facebookresearch/moco), see [MoCoFeature.ipynb](1.MoCoFeature.ipynb) 16 | 2. Torchvision Pre-trained ResNet, see [SupervisedFeature.ipynb](1.SupervisedFeature.ipynb) 17 | 3. Using MMD-VAE with reconstruction loss to train a model and extract features, see [UnsupervisedFeature.ipynb](1.UnsupervisedFeature.ipynb) 18 | 3. Apply the K-Means algorithm to cluster images within each class into K (6 by default) pretext attributes based on extracted features 19 | 4. Construct training sets and testing sets for GLT, see [SplitGeneration.ipynb](2.SplitGeneration.ipynb). 
Note that we have three evaluation protocols: 1) Class-wise Long Tail with (Train-GLT, Test-CBL), 2) Attribute-wise Long Tail with (Train-CBL, Test-GBL), and 3) Generalized Long Tail with (Train-GLT, Test-GBL), where Class-wise Long Tail protocol and Generalized Long Tail protocol share the same training set Train-GLT, so there are two annotation files corresponding to two different training sets. 20 | 1. imagenet_sup_intra_lt_inter_lt.json: 21 | 1. Train-GLT: classes and attribute clusters are both long-tailed 22 | 2. Test-GLT: same as above 23 | 3. Test-CBL: classes are balanced but the attribute clusters are still i.i.d. sampled within each classes, i.e., long-tailed intra-class distribution 24 | 4. Test-GBL: both classes and pretext attributes are balanced 25 | 2. imagenet_sup_intra_lt_inter_bl.json: 26 | 1. Train-CBL: classes are balanced, but the attribute clusters are still i.i.d. sampled within each classes, i.e., long-tailed intra-class distribution 27 | 2. Test-CBL: same as above 28 | 3. Test-GBL: both classes and pretext attributes are balanced 29 | 30 | 31 | ## Download Our Generated Annotation files 32 | You can directly download our generated json files from [the link](https://1drv.ms/u/s!AmRLLNf6bzcitKlpBoTerorv5yaeIw?e=0bw83U). 33 | 34 | Since OneDrive links might be broken in mainland China, we also provide the following alternate link using BaiduNetDisk: 35 | 36 | Link:https://pan.baidu.com/s/1VSpgMVfMFnhQ0RTMzKNIyw Extraction code:1234 37 | 38 | ## Commands 39 | Run the following command to generate the ImageNet-GLT annotations: 1) imagenet_sup_intra_lt_inter_lt.json and 2) imagenet_sup_intra_lt_inter_bl.json 40 | ``` 41 | python imagenet_glt_generation.py --data_path YOUR_IMAGENET_FOLDER_FOR_TRAINSET 42 | ``` 43 | 44 | ## Visualization of Clusters and Dataset Statistics 45 | 46 |

Visualization of ImageNet-GLT Clusters.

47 |

Figure 2. Examples of feature clusters using KMeans (through t-SNE) and the corresponding visualized images for each cluster in the original ImageNet.

48 | 49 | 50 |

Statistics of ImageNet-GLT Dataset.

51 |

Figure 3. Class distributions and cluster distributions for each split of the ImageNet-GLT benchmark. Note that clusters may represent different attribute layouts in each class, so ImageNet-GLT actually has 6 × 1000 pretext attributes rather than 6, i.e., each column in the cluster distribution stands for 1000 pretext attributes having the same frequency.

52 | -------------------------------------------------------------------------------- /_ImageNetGeneration/long-tail-distribution.pytorch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/_ImageNetGeneration/long-tail-distribution.pytorch -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/__init__.py -------------------------------------------------------------------------------- /config/COCO_BL.yaml: -------------------------------------------------------------------------------- 1 | output_dir: null 2 | 3 | dataset: 4 | name: 'MSCOCO-BL' 5 | data_path: ./data/coco_bl/images 6 | anno_path: ./data/coco_bl/annotations/annotation.json 7 | # lt : test distribution is long-tailed in both cateogies and atttributes 8 | # bl : the category is balanced 9 | # bbl : both category and attribute are balanced 10 | testset: 'test_lt' # test_lt / test_bl / test_bbl 11 | rgb_mean: [0.473, 0.429, 0.370] 12 | rgb_std: [0.277, 0.268, 0.274] 13 | rand_aug: False 14 | 15 | sampler: default 16 | 17 | networks: 18 | type: ResNext 19 | ResNext: 20 | def_file: ./models/ResNet.py 21 | params: {m_type: 'resnext50'} 22 | BBN: 23 | def_file: ./models/ResNet_BBN.py 24 | params: {m_type: 'bbn_resnet50'} 25 | RIDE: 26 | def_file: ./models/ResNet_RIDE.py 27 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True} 28 | 29 | classifiers: 30 | type: FC 31 | FC: 32 | def_file: ./models/ClassifierFC.py 33 | params: {feat_dim: 2048, num_classes: 29} 34 | RIDE: 35 | def_file: ./models/ClassifierRIDE.py 36 | params: {feat_dim: 1536, num_classes: 29, num_experts: 3, use_norm: True} 37 
| BBN: 38 | def_file: ./models/ClassifierFC.py 39 | params: {feat_dim: 4096, num_classes: 29} 40 | COS: 41 | def_file: ./models/ClassifierCOS.py 42 | params: {feat_dim: 2048, num_classes: 29, num_head: 1, tau: 30.0} 43 | LDAM: 44 | def_file: ./models/ClassifierLDAM.py 45 | params: {feat_dim: 2048, num_classes: 29} 46 | TDE: 47 | def_file: ./models/ClassifierTDE.py 48 | params: {feat_dim: 2048, num_classes: 29, num_head: 2, tau: 16.0, alpha: 0.0, gamma: 0.03125} 49 | LA: 50 | def_file: ./models/ClassifierLA.py 51 | params: {feat_dim: 2048, num_classes: 29, posthoc: True, loss: False} 52 | LWS: 53 | def_file: ./models/ClassifierLWS.py 54 | params: {feat_dim: 2048, num_classes: 29} 55 | MultiHead: 56 | def_file: ./models/ClassifierMultiHead.py 57 | params: {feat_dim: 2048, num_classes: 29} 58 | 59 | training_opt: 60 | type: baseline # baseline / mixup / two_stage2 61 | num_epochs: 100 62 | batch_size: 256 63 | data_workers: 4 64 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM 65 | loss_params: {alpha: 1.0, gamma: 2.0} 66 | optimizer: 'SGD' # 'Adam' / 'SGD' 67 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 68 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep' 69 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]} 70 | 71 | testing_opt: 72 | type: baseline # baseline / TDE 73 | 74 | logger_opt: 75 | print_grad: false 76 | print_iter: 100 77 | 78 | checkpoint_opt: 79 | checkpoint_step: 10 80 | checkpoint_name: 'test_model.pth' 81 | 82 | saving_opt: 83 | save_all: false -------------------------------------------------------------------------------- /config/COCO_LT.yaml: -------------------------------------------------------------------------------- 1 | output_dir: null 2 | 3 | dataset: 4 | name: 'MSCOCO-LT' 5 | data_path: ./data/coco_lt/images 6 | anno_path: ./data/coco_lt/annotations/annotation.json 7 | # lt : test distribution is long-tailed in both cateogies and atttributes 8 | # bl : the 
category is balanced 9 | # bbl : both category and attribute are balanced 10 | testset: 'test_lt' # test_lt / test_bl / test_bbl 11 | rgb_mean: [0.473, 0.429, 0.370] 12 | rgb_std: [0.277, 0.268, 0.274] 13 | rand_aug: False 14 | 15 | sampler: default 16 | 17 | networks: 18 | type: ResNext 19 | ResNext: 20 | def_file: ./models/ResNet.py 21 | params: {m_type: 'resnext50'} 22 | BBN: 23 | def_file: ./models/ResNet_BBN.py 24 | params: {m_type: 'bbn_resnet50'} 25 | RIDE: 26 | def_file: ./models/ResNet_RIDE.py 27 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True} 28 | 29 | classifiers: 30 | type: FC 31 | FC: 32 | def_file: ./models/ClassifierFC.py 33 | params: {feat_dim: 2048, num_classes: 29} 34 | RIDE: 35 | def_file: ./models/ClassifierRIDE.py 36 | params: {feat_dim: 1536, num_classes: 29, num_experts: 3, use_norm: True} 37 | BBN: 38 | def_file: ./models/ClassifierFC.py 39 | params: {feat_dim: 4096, num_classes: 29} 40 | COS: 41 | def_file: ./models/ClassifierCOS.py 42 | params: {feat_dim: 2048, num_classes: 29, num_head: 1, tau: 30.0} 43 | LDAM: 44 | def_file: ./models/ClassifierLDAM.py 45 | params: {feat_dim: 2048, num_classes: 29} 46 | TDE: 47 | def_file: ./models/ClassifierTDE.py 48 | params: {feat_dim: 2048, num_classes: 29, num_head: 2, tau: 16.0, alpha: 1.0, gamma: 0.03125} 49 | LA: 50 | def_file: ./models/ClassifierLA.py 51 | params: {feat_dim: 2048, num_classes: 29, posthoc: True, loss: False} 52 | LWS: 53 | def_file: ./models/ClassifierLWS.py 54 | params: {feat_dim: 2048, num_classes: 29} 55 | MultiHead: 56 | def_file: ./models/ClassifierMultiHead.py 57 | params: {feat_dim: 2048, num_classes: 29} 58 | 59 | training_opt: 60 | type: baseline # baseline / mixup / two_stage2 61 | num_epochs: 100 62 | batch_size: 256 63 | data_workers: 4 64 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM 65 | loss_params: {alpha: 1.0, gamma: 2.0} 66 | optimizer: 'SGD' # 'Adam' / 'SGD' 67 | optim_params: {lr: 0.1, momentum: 0.9, 
weight_decay: 0.0005} 68 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep' 69 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]} 70 | 71 | testing_opt: 72 | type: baseline # baseline / TDE 73 | 74 | logger_opt: 75 | print_grad: false 76 | print_iter: 100 77 | 78 | checkpoint_opt: 79 | checkpoint_step: 10 80 | checkpoint_name: 'test_model.pth' 81 | 82 | saving_opt: 83 | save_all: false -------------------------------------------------------------------------------- /config/ColorMNIST_BL.yaml: -------------------------------------------------------------------------------- 1 | output_dir: null 2 | 3 | dataset: 4 | name: 'ColorMNIST-BL' 5 | data_path: './_ColorMNISTGeneration/' 6 | # lt : test distribution is long-tailed in both cateogies and atttributes 7 | # bl : the category is balanced 8 | # bbl : both category and attribute are balanced 9 | testset: 'test_lt' # test_lt / test_bl / test_bbl 10 | cat_ratio: 1.0 11 | att_ratio: 0.1 12 | rand_aug: False 13 | 14 | sampler: default 15 | 16 | networks: 17 | type: ResNext 18 | ResNext: 19 | def_file: ./models/ResNet.py 20 | params: {m_type: 'resnext50'} 21 | BBN: 22 | def_file: ./models/ResNet_BBN.py 23 | params: {m_type: 'bbn_resnet50'} 24 | RIDE: 25 | def_file: ./models/ResNet_RIDE.py 26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True} 27 | 28 | classifiers: 29 | type: FC 30 | FC: 31 | def_file: ./models/ClassifierFC.py 32 | params: {feat_dim: 2048, num_classes: 3} 33 | RIDE: 34 | def_file: ./models/ClassifierRIDE.py 35 | params: {feat_dim: 1536, num_classes: 3, num_experts: 3, use_norm: True} 36 | BBN: 37 | def_file: ./models/ClassifierFC.py 38 | params: {feat_dim: 4096, num_classes: 3} 39 | COS: 40 | def_file: ./models/ClassifierCOS.py 41 | params: {feat_dim: 2048, num_classes: 3, num_head: 1, tau: 30.0} 42 | LDAM: 43 | def_file: ./models/ClassifierLDAM.py 44 | params: {feat_dim: 2048, num_classes: 3} 45 | TDE: 46 | def_file: ./models/ClassifierTDE.py 47 | 
params: {feat_dim: 2048, num_classes: 3, num_head: 2, tau: 16.0, alpha: 0.0, gamma: 0.03125} 48 | LA: 49 | def_file: ./models/ClassifierLA.py 50 | params: {feat_dim: 2048, num_classes: 3, posthoc: True, loss: False} 51 | LWS: 52 | def_file: ./models/ClassifierLWS.py 53 | params: {feat_dim: 2048, num_classes: 3} 54 | MultiHead: 55 | def_file: ./models/ClassifierMultiHead.py 56 | params: {feat_dim: 2048, num_classes: 3} 57 | 58 | training_opt: 59 | type: baseline # baseline / mixup / two_stage2 60 | num_epochs: 100 61 | batch_size: 256 62 | data_workers: 4 63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM 64 | loss_params: {alpha: 1.0, gamma: 2.0} 65 | optimizer: 'SGD' # 'Adam' / 'SGD' 66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep' 68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]} 69 | 70 | testing_opt: 71 | type: baseline # baseline / TDE 72 | 73 | logger_opt: 74 | print_grad: false 75 | print_iter: 100 76 | 77 | checkpoint_opt: 78 | checkpoint_step: 10 79 | checkpoint_name: 'test_model.pth' 80 | 81 | saving_opt: 82 | save_all: false -------------------------------------------------------------------------------- /config/ColorMNIST_LT.yaml: -------------------------------------------------------------------------------- 1 | output_dir: null 2 | 3 | dataset: 4 | name: 'ColorMNIST-LT' 5 | data_path: './_ColorMNISTGeneration/' 6 | # lt : test distribution is long-tailed in both cateogies and atttributes 7 | # bl : the category is balanced 8 | # bbl : both category and attribute are balanced 9 | testset: 'test_lt' # test_lt / test_bl / test_bbl 10 | cat_ratio: 0.1 11 | att_ratio: 0.1 12 | rand_aug: False 13 | 14 | sampler: default 15 | 16 | networks: 17 | type: ResNext 18 | ResNext: 19 | def_file: ./models/ResNet.py 20 | params: {m_type: 'resnext50'} 21 | BBN: 22 | def_file: ./models/ResNet_BBN.py 23 | params: {m_type: 
'bbn_resnet50'} 24 | RIDE: 25 | def_file: ./models/ResNet_RIDE.py 26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True} 27 | 28 | classifiers: 29 | type: FC 30 | FC: 31 | def_file: ./models/ClassifierFC.py 32 | params: {feat_dim: 2048, num_classes: 3} 33 | RIDE: 34 | def_file: ./models/ClassifierRIDE.py 35 | params: {feat_dim: 1536, num_classes: 3, num_experts: 3, use_norm: True} 36 | BBN: 37 | def_file: ./models/ClassifierFC.py 38 | params: {feat_dim: 4096, num_classes: 3} 39 | COS: 40 | def_file: ./models/ClassifierCOS.py 41 | params: {feat_dim: 2048, num_classes: 3, num_head: 1, tau: 30.0} 42 | LDAM: 43 | def_file: ./models/ClassifierLDAM.py 44 | params: {feat_dim: 2048, num_classes: 3} 45 | TDE: 46 | def_file: ./models/ClassifierTDE.py 47 | params: {feat_dim: 2048, num_classes: 3, num_head: 2, tau: 16.0, alpha: 0.1, gamma: 0.03125} 48 | LA: 49 | def_file: ./models/ClassifierLA.py 50 | params: {feat_dim: 2048, num_classes: 3, posthoc: True, loss: False} 51 | LWS: 52 | def_file: ./models/ClassifierLWS.py 53 | params: {feat_dim: 2048, num_classes: 3} 54 | MultiHead: 55 | def_file: ./models/ClassifierMultiHead.py 56 | params: {feat_dim: 2048, num_classes: 3} 57 | 58 | training_opt: 59 | type: baseline # baseline / mixup / two_stage2 60 | num_epochs: 100 61 | batch_size: 256 62 | data_workers: 4 63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM 64 | loss_params: {alpha: 1.0, gamma: 2.0} 65 | optimizer: 'SGD' # 'Adam' / 'SGD' 66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep' 68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]} 69 | 70 | testing_opt: 71 | type: baseline # baseline / TDE 72 | 73 | logger_opt: 74 | print_grad: false 75 | print_iter: 100 76 | 77 | checkpoint_opt: 78 | checkpoint_step: 10 79 | checkpoint_name: 'test_model.pth' 80 | 81 | saving_opt: 82 | save_all: false 
-------------------------------------------------------------------------------- /config/ImageNet_BL.yaml: -------------------------------------------------------------------------------- 1 | output_dir: null 2 | 3 | dataset: 4 | name: 'ImageNet-BL' 5 | anno_path: ./_ImageNetGeneration/imagenet_sup_intra_lt_inter_bl.json 6 | # lt : test distribution is long-tailed in both cateogies and atttributes 7 | # bl : the category is balanced 8 | # bbl : both category and attribute are balanced 9 | testset: 'test_lt' # test_lt / test_bl / test_bbl 10 | rgb_mean: [0.485, 0.456, 0.406] 11 | rgb_std: [0.229, 0.224, 0.225] 12 | rand_aug: False 13 | 14 | sampler: default 15 | 16 | networks: 17 | type: ResNext 18 | ResNext: 19 | def_file: ./models/ResNet.py 20 | params: {m_type: 'resnext50'} 21 | BBN: 22 | def_file: ./models/ResNet_BBN.py 23 | params: {m_type: 'bbn_resnet50'} 24 | RIDE: 25 | def_file: ./models/ResNet_RIDE.py 26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True} 27 | 28 | classifiers: 29 | type: FC 30 | FC: 31 | def_file: ./models/ClassifierFC.py 32 | params: {feat_dim: 2048, num_classes: 1000} 33 | RIDE: 34 | def_file: ./models/ClassifierRIDE.py 35 | params: {feat_dim: 1536, num_classes: 1000, num_experts: 3, use_norm: True} 36 | BBN: 37 | def_file: ./models/ClassifierFC.py 38 | params: {feat_dim: 4096, num_classes: 1000} 39 | COS: 40 | def_file: ./models/ClassifierCOS.py 41 | params: {feat_dim: 2048, num_classes: 1000, num_head: 1, tau: 30.0} 42 | LDAM: 43 | def_file: ./models/ClassifierLDAM.py 44 | params: {feat_dim: 2048, num_classes: 1000} 45 | TDE: 46 | def_file: ./models/ClassifierTDE.py 47 | params: {feat_dim: 2048, num_classes: 1000, num_head: 2, tau: 16.0, alpha: 0.0, gamma: 0.03125} 48 | LA: 49 | def_file: ./models/ClassifierLA.py 50 | params: {feat_dim: 2048, num_classes: 1000, posthoc: True, loss: False} 51 | LWS: 52 | def_file: ./models/ClassifierLWS.py 53 | params: {feat_dim: 2048, num_classes: 1000} 54 | MultiHead: 55 | 
def_file: ./models/ClassifierMultiHead.py 56 | params: {feat_dim: 2048, num_classes: 1000} 57 | 58 | training_opt: 59 | type: baseline # baseline / mixup / two_stage2 60 | num_epochs: 100 61 | batch_size: 256 62 | data_workers: 4 63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM 64 | loss_params: {alpha: 1.0, gamma: 2.0} 65 | optimizer: 'SGD' # 'Adam' / 'SGD' 66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep' 68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]} 69 | 70 | testing_opt: 71 | type: baseline # baseline / TDE 72 | 73 | logger_opt: 74 | print_grad: false 75 | print_iter: 100 76 | 77 | checkpoint_opt: 78 | checkpoint_step: 10 79 | checkpoint_name: 'test_model.pth' 80 | 81 | saving_opt: 82 | save_all: false -------------------------------------------------------------------------------- /config/ImageNet_LT.yaml: -------------------------------------------------------------------------------- 1 | output_dir: null 2 | 3 | dataset: 4 | name: 'ImageNet-LT' 5 | anno_path: ./_ImageNetGeneration/imagenet_sup_intra_lt_inter_lt.json 6 | # lt : test distribution is long-tailed in both cateogies and atttributes 7 | # bl : the category is balanced 8 | # bbl : both category and attribute are balanced 9 | testset: 'test_lt' # test_lt / test_bl / test_bbl 10 | rgb_mean: [0.485, 0.456, 0.406] 11 | rgb_std: [0.229, 0.224, 0.225] 12 | rand_aug: False 13 | 14 | sampler: default 15 | 16 | networks: 17 | type: ResNext 18 | ResNext: 19 | def_file: ./models/ResNet.py 20 | params: {m_type: 'resnext50'} 21 | BBN: 22 | def_file: ./models/ResNet_BBN.py 23 | params: {m_type: 'bbn_resnet50'} 24 | RIDE: 25 | def_file: ./models/ResNet_RIDE.py 26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True} 27 | 28 | classifiers: 29 | type: FC 30 | FC: 31 | def_file: ./models/ClassifierFC.py 32 | params: {feat_dim: 2048, num_classes: 1000} 33 | 
RIDE: 34 | def_file: ./models/ClassifierRIDE.py 35 | params: {feat_dim: 1536, num_classes: 1000, num_experts: 3, use_norm: True} 36 | BBN: 37 | def_file: ./models/ClassifierFC.py 38 | params: {feat_dim: 4096, num_classes: 1000} 39 | COS: 40 | def_file: ./models/ClassifierCOS.py 41 | params: {feat_dim: 2048, num_classes: 1000, num_head: 1, tau: 30.0} 42 | LDAM: 43 | def_file: ./models/ClassifierLDAM.py 44 | params: {feat_dim: 2048, num_classes: 1000} 45 | TDE: 46 | def_file: ./models/ClassifierTDE.py 47 | params: {feat_dim: 2048, num_classes: 1000, num_head: 2, tau: 16.0, alpha: 2.0, gamma: 0.03125} 48 | LA: 49 | def_file: ./models/ClassifierLA.py 50 | params: {feat_dim: 2048, num_classes: 1000, posthoc: True, loss: False} 51 | LWS: 52 | def_file: ./models/ClassifierLWS.py 53 | params: {feat_dim: 2048, num_classes: 1000} 54 | MultiHead: 55 | def_file: ./models/ClassifierMultiHead.py 56 | params: {feat_dim: 2048, num_classes: 1000} 57 | 58 | training_opt: 59 | type: baseline # baseline / mixup / two_stage2 60 | num_epochs: 100 61 | batch_size: 256 62 | data_workers: 4 63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM 64 | loss_params: {alpha: 1.0, gamma: 2.0} 65 | optimizer: 'SGD' # 'Adam' / 'SGD' 66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep' 68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]} 69 | 70 | testing_opt: 71 | type: baseline # baseline / TDE 72 | 73 | logger_opt: 74 | print_grad: false 75 | print_iter: 100 76 | 77 | checkpoint_opt: 78 | checkpoint_step: 10 79 | checkpoint_name: 'test_model.pth' 80 | 81 | saving_opt: 82 | save_all: false -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/config/__init__.py -------------------------------------------------------------------------------- /data/DT_COCO_LT.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | 5 | 6 | import os 7 | import json 8 | 9 | import torch 10 | import torch.utils.data as data 11 | import torchvision.transforms as transforms 12 | from PIL import Image 13 | 14 | from randaugment import RandAugment 15 | 16 | class COCO_LT(data.Dataset): 17 | def __init__(self, phase, data_path, anno_path, testset, rgb_mean, rgb_std, rand_aug, output_path, logger): 18 | super(COCO_LT, self).__init__() 19 | valid_phase = ['train', 'val', 'test'] 20 | assert phase in valid_phase 21 | if phase == 'train': 22 | full_phase = 'train' 23 | elif phase == 'test': 24 | full_phase = testset 25 | else: 26 | full_phase = phase 27 | logger.info('====== The Current Split is : {}'.format(full_phase)) 28 | self.logger = logger 29 | 30 | self.dataset_info = {} 31 | self.phase = phase 32 | self.rand_aug = rand_aug 33 | self.data_path = data_path 34 | 35 | self.annotations = json.load(open(anno_path)) 36 | self.data = self.annotations[full_phase] 37 | self.transform = self.get_data_transform(phase, rgb_mean, rgb_std) 38 | 39 | # load dataset category info 40 | logger.info('=====> Load dataset category info') 41 | self.id2cat, self.cat2id = self.annotations['id2cat'], self.annotations['cat2id'] 42 | 43 | # load all image info 44 | logger.info('=====> Load image info') 45 | self.img_paths, self.labels, self.attributes, self.frequencies = self.load_img_info() 46 | 47 | # save dataset info 48 | logger.info('=====> Save dataset info') 49 | self.dataset_info['cat2id'] = self.cat2id 50 | self.dataset_info['id2cat'] = self.id2cat 51 | 
self.save_dataset_info(output_path) 52 | 53 | 54 | def __len__(self): 55 | return len(self.labels) 56 | 57 | 58 | def __getitem__(self, index): 59 | path = self.img_paths[index] 60 | label = self.labels[index] 61 | rarity = self.frequencies[index] 62 | 63 | with open(path, 'rb') as f: 64 | sample = Image.open(f).convert('RGB') 65 | 66 | if self.transform is not None: 67 | sample = self.transform(sample) 68 | 69 | # intra-class attribute SHOULD NOT be used during training 70 | if self.phase != 'train': 71 | attribute = self.attributes[index] 72 | return sample, label, rarity, attribute, index 73 | else: 74 | return sample, label, rarity, index 75 | 76 | 77 | ####################################### 78 | # Load image info 79 | ####################################### 80 | def load_img_info(self): 81 | img_paths = [] 82 | labels = [] 83 | attributes = [] 84 | frequencies = [] 85 | 86 | for key, label in self.data['label'].items(): 87 | img_paths.append(self.data['path'][key]) 88 | labels.append(int(label)) 89 | frequencies.append(int(self.data['frequency'][key])) 90 | 91 | # intra-class attribute SHOULD NOT be used in training 92 | if self.phase != 'train': 93 | att_label = int(self.data['attribute'][key]) 94 | attributes.append(att_label) 95 | 96 | # save dataset info 97 | self.dataset_info['img_paths'] = img_paths 98 | self.dataset_info['labels'] = labels 99 | self.dataset_info['attributes'] = attributes 100 | self.dataset_info['frequencies'] = frequencies 101 | 102 | return img_paths, labels, attributes, frequencies 103 | 104 | 105 | ####################################### 106 | # Save dataset info 107 | ####################################### 108 | def save_dataset_info(self, output_path): 109 | 110 | with open(os.path.join(output_path, 'dataset_info_{}.json'.format(self.phase)), 'w') as f: 111 | json.dump(self.dataset_info, f) 112 | 113 | del self.dataset_info 114 | 115 | 116 | ####################################### 117 | # transform 118 | 
####################################### 119 | def get_data_transform(self, phase, rgb_mean, rgb_std): 120 | transform_info = { 121 | 'rgb_mean': rgb_mean, 122 | 'rgb_std': rgb_std, 123 | } 124 | 125 | if phase == 'train': 126 | if self.rand_aug: 127 | self.logger.info('============= Using Rand Augmentation in Dataset ===========') 128 | trans = transforms.Compose([ 129 | transforms.RandomResizedCrop(112, scale=(0.5, 1.0)), 130 | transforms.RandomHorizontalFlip(), 131 | RandAugment(), 132 | transforms.ToTensor(), 133 | transforms.Normalize(rgb_mean, rgb_std) 134 | ]) 135 | transform_info['operations'] = ['RandomResizedCrop(112, scale=(0.5, 1.0)),', 'RandomHorizontalFlip()', 136 | 'RandAugment()', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)'] 137 | else: 138 | self.logger.info('============= Using normal transforms in Dataset ===========') 139 | trans = transforms.Compose([ 140 | transforms.RandomResizedCrop(112, scale=(0.5, 1.0)), 141 | transforms.RandomHorizontalFlip(), 142 | transforms.ToTensor(), 143 | transforms.Normalize(rgb_mean, rgb_std) 144 | ]) 145 | transform_info['operations'] = ['RandomResizedCrop(112, scale=(0.5, 1.0)),', 'RandomHorizontalFlip()', 146 | 'ToTensor()', 'Normalize(rgb_mean, rgb_std)'] 147 | else: 148 | trans = transforms.Compose([ 149 | transforms.Resize(128), 150 | transforms.CenterCrop(112), 151 | transforms.ToTensor(), 152 | transforms.Normalize(rgb_mean, rgb_std) 153 | ]) 154 | transform_info['operations'] = ['Resize(128)', 'CenterCrop(112)', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)'] 155 | 156 | # save dataset info 157 | self.dataset_info['transform_info'] = transform_info 158 | 159 | return trans -------------------------------------------------------------------------------- /data/DT_ColorMNIST.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | 5 | 6 | from PIL import Image, ImageDraw 7 | 
from io import BytesIO 8 | import json 9 | import os 10 | import random 11 | 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | 15 | import torch 16 | import torchvision 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | import torch.utils.data as data 20 | import torchvision.transforms as transforms 21 | 22 | from randaugment import RandAugment 23 | 24 | 25 | class ColorMNIST_LT(torchvision.datasets.MNIST): 26 | def __init__(self, phase, testset, data_path, logger, cat_ratio=1.0, att_ratio=0.1, rand_aug=False): 27 | super(ColorMNIST_LT, self).__init__(root=data_path, train=(phase == 'train'), download=True) 28 | # mnist dataset contains self.data, self.targets 29 | self.dig2label = {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 2, 8: 2, 9: 2} 30 | self.dig2attri = {} 31 | self.colors = {0:[1,0,0], 1:[0,1,0], 2:[0,0,1]} 32 | self.logger = logger 33 | self.phase = phase 34 | 35 | # ColorMNIST should not use rand augmentation, as its patterns are too simple 36 | # and its color confounder could be over-written by augmentation 37 | assert rand_aug == False 38 | 39 | valid_phase = ['train', 'val', 'test'] 40 | assert phase in valid_phase 41 | if phase == 'train': 42 | full_phase = 'train' 43 | # generate long-tailed data 44 | self.cat_ratio = cat_ratio 45 | self.att_ratio = att_ratio 46 | self.generate_lt_label(cat_ratio) 47 | elif phase == 'test': 48 | full_phase = testset 49 | # generate long-tailed data 50 | if full_phase == 'test_iid': 51 | self.cat_ratio = cat_ratio 52 | self.att_ratio = att_ratio 53 | self.generate_lt_label(cat_ratio) 54 | elif full_phase == 'test_half_bl': 55 | self.cat_ratio = 1.0 56 | self.att_ratio = att_ratio 57 | self.generate_lt_label(1.0) 58 | elif full_phase == 'test_bl': 59 | self.cat_ratio = 1.0 60 | self.att_ratio = 1.0 61 | self.generate_lt_label(1.0) 62 | else: 63 | full_phase = phase 64 | self.cat_ratio = cat_ratio 65 | self.att_ratio = att_ratio 66 | self.generate_lt_label(cat_ratio) 67 | 
logger.info('====== The Current Split is : {}'.format(full_phase)) 68 | 69 | 70 | 71 | def generate_lt_label(self, ratio=1.0): 72 | self.label2list = {i:[] for i in range(3)} 73 | for img, dig in zip(self.data, self.targets): 74 | label = self.dig2label[int(dig)] 75 | self.label2list[label].append(img) 76 | if ratio == 1.0: 77 | balance_size = min([len(val) for key, val in self.label2list.items()]) 78 | for key, val in self.label2list.items(): 79 | self.label2list[key] = val[:balance_size] 80 | elif ratio < 1.0: 81 | current_size = len(self.label2list[0]) 82 | for key, val in self.label2list.items(): 83 | max_size = len(val) 84 | self.label2list[key] = val[:min(max_size, current_size)] 85 | current_size = int(current_size * ratio) 86 | else: 87 | raise ValueError('Wrong Ratio in ColorMNIST') 88 | 89 | self.labels = [] 90 | self.imgs = [] 91 | for key, val in self.label2list.items(): 92 | for item in val: 93 | self.labels.append(key) 94 | self.imgs.append(item) 95 | self.logger.info('Generate ColorMNIST: label {} has {} images.'.format(key, len(val))) 96 | 97 | 98 | def __len__(self): 99 | return len(self.labels) 100 | 101 | def __getitem__(self, index): 102 | img = self.imgs[index].unsqueeze(0).repeat(3,1,1) 103 | label = self.labels[index] 104 | 105 | # generate tail colors 106 | if random.random() < self.att_ratio: 107 | att_label = random.randint(0,2) 108 | else: 109 | att_label = label 110 | color = self.colors[att_label] 111 | 112 | # assign attribute 113 | img = self.to_color(img, color) 114 | 115 | if self.phase != 'train': 116 | # attribute 117 | attribute = 1 - int(att_label == label) 118 | return img, label, label, attribute, index 119 | else: 120 | return img, label, label, index 121 | 122 | def to_color(self, img, rgb=[1,0,0]): 123 | return (img * torch.FloatTensor(rgb).unsqueeze(-1).unsqueeze(-1)).float() -------------------------------------------------------------------------------- /data/DT_ImageNet_LT.py: 
-------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | 5 | 6 | import os 7 | import json 8 | 9 | import torch 10 | import torch.utils.data as data 11 | import torchvision.transforms as transforms 12 | from PIL import Image 13 | 14 | from randaugment import RandAugment 15 | 16 | class ImageNet_LT(data.Dataset): 17 | def __init__(self, phase, anno_path, testset, rgb_mean, rgb_std, rand_aug, output_path, logger): 18 | super(ImageNet_LT, self).__init__() 19 | valid_phase = ['train', 'val', 'test'] 20 | assert phase in valid_phase 21 | if phase == 'train': 22 | full_phase = 'train' 23 | elif phase == 'test': 24 | full_phase = testset 25 | else: 26 | full_phase = phase 27 | logger.info('====== The Current Split is : {}'.format(full_phase)) 28 | self.logger = logger 29 | 30 | self.dataset_info = {} 31 | self.phase = phase 32 | self.rand_aug = rand_aug 33 | 34 | # load annotation 35 | self.annotations = json.load(open(anno_path)) 36 | self.data = self.annotations[full_phase] 37 | 38 | # get transform 39 | self.transform = self.get_data_transform(phase, rgb_mean, rgb_std) 40 | 41 | # load dataset category info 42 | logger.info('=====> Load dataset category info') 43 | self.id2cat, self.cat2id = self.annotations['id2cat'], self.annotations['cat2id'] 44 | 45 | # load all image info 46 | logger.info('=====> Load image info') 47 | self.img_paths, self.labels, self.attributes, self.frequencies = self.load_img_info() 48 | 49 | # save dataset info 50 | logger.info('=====> Save dataset info') 51 | self.dataset_info['cat2id'] = self.cat2id 52 | self.dataset_info['id2cat'] = self.id2cat 53 | self.save_dataset_info(output_path) 54 | 55 | 56 | def __len__(self): 57 | return len(self.labels) 58 | 59 | 60 | def __getitem__(self, index): 61 | path = self.img_paths[index] 62 | label = self.labels[index] 63 | rarity = self.frequencies[index] 64 | 65 | with 
open(path, 'rb') as f: 66 | sample = Image.open(f).convert('RGB') 67 | 68 | if self.transform is not None: 69 | sample = self.transform(sample) 70 | 71 | # intra-class attribute SHOULD NOT be used during training 72 | if self.phase != 'train': 73 | attribute = self.attributes[index] 74 | return sample, label, rarity, attribute, index 75 | else: 76 | return sample, label, rarity, index 77 | 78 | 79 | ####################################### 80 | # Load image info 81 | ####################################### 82 | def load_img_info(self): 83 | img_paths = [] 84 | labels = [] 85 | attributes = [] 86 | frequencies = [] 87 | 88 | 89 | for path, label in self.data['label'].items(): 90 | img_paths.append(path) 91 | labels.append(int(label)) 92 | frequencies.append(int(self.data['frequency'][path])) 93 | 94 | # intra-class attribute SHOULD NOT be used in training 95 | if self.phase != 'train': 96 | att_label = int(self.data['attribute'][path]) 97 | attributes.append(att_label) 98 | 99 | # save dataset info 100 | self.dataset_info['img_paths'] = img_paths 101 | self.dataset_info['labels'] = labels 102 | self.dataset_info['attributes'] = attributes 103 | self.dataset_info['frequencies'] = frequencies 104 | 105 | return img_paths, labels, attributes, frequencies 106 | 107 | 108 | ####################################### 109 | # Save dataset info 110 | ####################################### 111 | def save_dataset_info(self, output_path): 112 | 113 | with open(os.path.join(output_path, 'dataset_info_{}.json'.format(self.phase)), 'w') as f: 114 | json.dump(self.dataset_info, f) 115 | 116 | del self.dataset_info 117 | 118 | 119 | ####################################### 120 | # transform 121 | ####################################### 122 | def get_data_transform(self, phase, rgb_mean, rgb_std): 123 | transform_info = { 124 | 'rgb_mean': rgb_mean, 125 | 'rgb_std': rgb_std, 126 | } 127 | 128 | if phase == 'train': 129 | if self.rand_aug: 130 | self.logger.info('============= Using Rand 
Augmentation in Dataset ===========') 131 | trans = transforms.Compose([ 132 | transforms.RandomResizedCrop(112), 133 | transforms.RandomHorizontalFlip(), 134 | RandAugment(), 135 | transforms.ToTensor(), 136 | transforms.Normalize(rgb_mean, rgb_std) 137 | ]) 138 | transform_info['operations'] = ['RandomResizedCrop(112)', 'RandomHorizontalFlip()', 139 | 'RandAugment()', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)'] 140 | else: 141 | self.logger.info('============= Using normal transforms in Dataset ===========') 142 | trans = transforms.Compose([ 143 | transforms.RandomResizedCrop(112), 144 | transforms.RandomHorizontalFlip(), 145 | transforms.ToTensor(), 146 | transforms.Normalize(rgb_mean, rgb_std) 147 | ]) 148 | transform_info['operations'] = ['RandomResizedCrop(112)', 'RandomHorizontalFlip()', 149 | 'ToTensor()', 'Normalize(rgb_mean, rgb_std)'] 150 | else: 151 | trans = transforms.Compose([ 152 | transforms.Resize(128), 153 | transforms.CenterCrop(112), 154 | transforms.ToTensor(), 155 | transforms.Normalize(rgb_mean, rgb_std) 156 | ]) 157 | transform_info['operations'] = ['Resize(128)', 'CenterCrop(112)', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)'] 158 | 159 | # save dataset info 160 | self.dataset_info['transform_info'] = transform_info 161 | 162 | return trans -------------------------------------------------------------------------------- /data/Sampler_ClassAware.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import numpy as np 4 | from torch.utils.data.sampler import Sampler 5 | 6 | 7 | 8 | ################################## 9 | ## Class-aware sampling, partly implemented by frombeijingwithlove 10 | ## github: https://github.com/facebookresearch/classifier-balancing/blob/main/data/ClassAwareSampler.py 11 | ################################## 12 | 13 | class RandomCycleIter: 14 | 15 | def __init__ (self, data, test_mode=False): 16 | self.data_list = list(data) 17 | 
random.shuffle(self.data_list) 18 | self.length = len(self.data_list) 19 | self.i = self.length - 1 20 | self.test_mode = test_mode 21 | 22 | def __iter__ (self): 23 | return self 24 | 25 | def __next__ (self): 26 | self.i += 1 27 | 28 | if self.i == self.length: 29 | self.i = 0 30 | if not self.test_mode: 31 | random.shuffle(self.data_list) 32 | 33 | return self.data_list[self.i] 34 | 35 | def class_aware_sample_generator (cls_iter, data_iter_list, n, num_samples_cls=1): 36 | 37 | i = 0 38 | j = 0 39 | while i < n: 40 | 41 | # yield next(data_iter_list[next(cls_iter)]) 42 | 43 | if j >= num_samples_cls: 44 | j = 0 45 | 46 | if j == 0: 47 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls)) 48 | yield temp_tuple[j] 49 | else: 50 | yield temp_tuple[j] 51 | 52 | i += 1 53 | j += 1 54 | 55 | class ClassAwareSampler (Sampler): 56 | 57 | def __init__(self, data_source, num_samples_cls=1, max_aug=4): 58 | num_images = len(data_source.labels) 59 | num_classes = len(np.unique(data_source.labels)) 60 | self.class_iter = RandomCycleIter(range(num_classes)) 61 | cls_data_list = [list() for _ in range(num_classes)] 62 | for i, label in enumerate(data_source.labels): 63 | cls_data_list[label].append(i) 64 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 65 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 66 | if self.num_samples > (num_images * max_aug): 67 | self.num_samples = num_images * max_aug 68 | self.num_samples_cls = num_samples_cls 69 | 70 | def __iter__ (self): 71 | return class_aware_sample_generator(self.class_iter, self.data_iter_list, 72 | self.num_samples, self.num_samples_cls) 73 | 74 | def __len__ (self): 75 | return self.num_samples -------------------------------------------------------------------------------- /data/Sampler_MultiEnv.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | from 
torch.utils.data.sampler import Sampler 5 | 6 | 7 | class WeightedSampler(Sampler): 8 | def __init__(self, dataset): 9 | self.num_samples = len(dataset) 10 | self.indexes = torch.arange(self.num_samples) 11 | self.weight = torch.zeros_like(self.indexes).fill_(1.0).float() # init weight 12 | 13 | 14 | def __iter__(self): 15 | selected_inds = [] 16 | # MAKE SURE self.weight.sum() == self.num_samples 17 | while((self.weight >= 1.0).sum().item() > 0): 18 | inds = self.indexes[self.weight >= 1.0].tolist() 19 | selected_inds = selected_inds + inds 20 | self.weight = self.weight - 1.0 21 | selected_inds = torch.LongTensor(selected_inds) 22 | # shuffle 23 | current_size = selected_inds.shape[0] 24 | selected_inds = selected_inds[torch.randperm(current_size)] 25 | expand = torch.randperm(self.num_samples) % current_size 26 | indices = selected_inds[expand].tolist() 27 | 28 | assert len(indices) == self.num_samples 29 | return iter(indices) 30 | 31 | def __len__(self): 32 | return self.num_samples 33 | 34 | def set_parameter(self, weight): 35 | self.weight = weight.float() 36 | 37 | 38 | class DistributionSampler(Sampler): 39 | def __init__(self, dataset): 40 | self.num_samples = len(dataset) 41 | self.indexes = torch.arange(self.num_samples) 42 | self.weight = torch.zeros_like(self.indexes).fill_(1.0).float() # init weight 43 | 44 | 45 | def __iter__(self): 46 | self.prob = self.weight / self.weight.sum() 47 | 48 | indices = torch.multinomial(self.prob, self.num_samples, replacement=True).tolist() 49 | assert len(indices) == self.num_samples 50 | return iter(indices) 51 | 52 | def __len__(self): 53 | return self.num_samples 54 | 55 | def set_parameter(self, weight): 56 | self.weight = weight.float() 57 | 58 | 59 | class FixSeedSampler(Sampler): 60 | def __init__(self, dataset): 61 | self.dataset = dataset 62 | self.num_samples = len(dataset) 63 | 64 | def __iter__(self): 65 | # deterministically shuffle based on epoch 66 | g = torch.Generator() 67 | 
g.manual_seed(self.epoch) 68 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 69 | assert len(indices) == self.num_samples 70 | return iter(indices) 71 | 72 | def __len__(self): 73 | return self.num_samples 74 | 75 | def set_parameter(self, epoch): 76 | self.epoch = epoch 77 | 78 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/data/__init__.py -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | import math 5 | import random 6 | import numpy as np 7 | import torch 8 | import torch.utils.data as data 9 | import torch.distributed as dist 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | from .DT_COCO_LT import COCO_LT 14 | from .DT_ColorMNIST import ColorMNIST_LT 15 | from .DT_ImageNet_LT import ImageNet_LT 16 | 17 | from .Sampler_ClassAware import ClassAwareSampler 18 | from .Sampler_MultiEnv import WeightedSampler, DistributionSampler, FixSeedSampler 19 | 20 | ################################## 21 | # return a dataloader 22 | ################################## 23 | def get_loader(config, phase, testset, logger): 24 | if config['dataset']['name'] in ('MSCOCO-LT', 'MSCOCO-BL'): 25 | split = COCO_LT(phase=phase, 26 | data_path=config['dataset']['data_path'], 27 | anno_path=config['dataset']['anno_path'], 28 | testset=testset, 29 | rgb_mean=config['dataset']['rgb_mean'], 30 | rgb_std=config['dataset']['rgb_std'], 31 | rand_aug = config['dataset']['rand_aug'], 32 | output_path=config['output_dir'], 33 | logger=logger) 34 | elif 
config['dataset']['name'] in ('ColorMNIST-LT', 'ColorMNIST-BL'): 35 | split = ColorMNIST_LT(phase=phase, 36 | testset=testset, 37 | data_path=config['dataset']['data_path'], 38 | cat_ratio=config['dataset']['cat_ratio'], 39 | att_ratio=config['dataset']['att_ratio'], 40 | rand_aug = config['dataset']['rand_aug'], 41 | logger=logger) 42 | elif config['dataset']['name'] in ('ImageNet-LT', 'ImageNet-BL'): 43 | split = ImageNet_LT(phase=phase, 44 | anno_path=config['dataset']['anno_path'], 45 | testset=testset, 46 | rgb_mean=config['dataset']['rgb_mean'], 47 | rgb_std=config['dataset']['rgb_std'], 48 | rand_aug = config['dataset']['rand_aug'], 49 | output_path=config['output_dir'], 50 | logger=logger) 51 | else: 52 | logger.info('********** ERROR: unidentified dataset **********') 53 | 54 | 55 | 56 | 57 | # create data sampler 58 | sampler_type = config['sampler'] 59 | 60 | # class aware sampling (re-balancing) 61 | if sampler_type == 'ClassAwareSampler' and phase == 'train': 62 | logger.info('======> Sampler Type {}'.format(sampler_type)) 63 | sampler = ClassAwareSampler(split, num_samples_cls=4) 64 | loader = data.DataLoader(split, num_workers=config['training_opt']['data_workers'], 65 | batch_size=config['training_opt']['batch_size'], 66 | sampler=sampler, 67 | pin_memory=True,) 68 | # hard weighted sampling (don't sampling samples with weights smaller than 1.0) 69 | elif sampler_type == 'WeightedSampler' and phase == 'train': 70 | logger.info('======> Sampler Type {}, Sampler Number {}'.format(sampler_type, config['num_sampler'])) 71 | loader = [] 72 | num_sampler = config['num_sampler'] 73 | batch_size = config['training_opt']['batch_size'] 74 | if config['batch_split']: 75 | batch_size = batch_size // num_sampler 76 | for _ in range(num_sampler): 77 | loader.append(data.DataLoader(split, num_workers=config['training_opt']['data_workers'], 78 | batch_size=batch_size, 79 | sampler=WeightedSampler(split), 80 | pin_memory=True,)) 81 | # soft weighted sampling 
(samples items according to the provided weights)
"5a5db594", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from PIL import Image, ImageDraw\n", 11 | "from io import BytesIO\n", 12 | "import json\n", 13 | "import joblib\n", 14 | "import os\n", 15 | "import requests\n", 16 | "import random\n", 17 | "\n", 18 | "import cv2\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from sklearn.cluster import KMeans\n", 22 | "from skimage import feature as skif\n", 23 | "\n", 24 | "import torch\n", 25 | "import torchvision\n", 26 | "import torch.nn as nn\n", 27 | "import torch.optim as optim\n", 28 | "import torch.utils.data as data\n", 29 | "import torchvision.transforms as transforms\n", 30 | "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n", 31 | "\n", 32 | "random.seed(25)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 132, 38 | "id": "74f34318", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "class ColorMNIST_LT(torchvision.datasets.MNIST):\n", 43 | " def __init__(self, phase, test_type, output_path, logger, cat_ratio=1.0, att_ratio=0.1):\n", 44 | " super(ColorMNIST_LT, self).__init__(root='./', train=(phase == 'train'), download=True)\n", 45 | " # mnist dataset contains self.data, self.targets\n", 46 | " self.dig2label = {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 2, 8: 2, 9: 2}\n", 47 | " self.dig2attri = {}\n", 48 | " self.colors = {0:[1,0,0], 1:[0,1,0], 2:[0,0,1]}\n", 49 | " \n", 50 | " self.cat_ratio = cat_ratio\n", 51 | " self.att_ratio = att_ratio\n", 52 | " # generate long-tailed data\n", 53 | " self.generate_lt_label(cat_ratio)\n", 54 | " \n", 55 | " \n", 56 | " def generate_lt_label(self, ratio=1.0):\n", 57 | " self.label2list = {i:[] for i in range(3)}\n", 58 | " for img, dig in zip(self.data, self.targets):\n", 59 | " label = self.dig2label[int(dig)]\n", 60 | " self.label2list[label].append(img)\n", 61 | " if ratio == 1.0:\n", 62 | " balance_size = min([len(val) for key, val in 
self.label2list.items()])\n", 63 | " for key, val in self.label2list.items():\n", 64 | " self.label2list[key] = val[:balance_size]\n", 65 | " elif ratio < 1.0:\n", 66 | " current_size = len(self.label2list[0])\n", 67 | " for key, val in self.label2list.items():\n", 68 | " max_size = len(val)\n", 69 | " self.label2list[key] = val[:min(max_size, current_size)]\n", 70 | " current_size = int(current_size * ratio)\n", 71 | " else:\n", 72 | " raise ValueError('Wrong Ratio in ColorMNIST')\n", 73 | " \n", 74 | " self.lt_labels = []\n", 75 | " self.lt_imgs = []\n", 76 | " for key, val in self.label2list.items():\n", 77 | " for item in val:\n", 78 | " self.lt_labels.append(key)\n", 79 | " self.lt_imgs.append(item)\n", 80 | " print('Generate ColorMNIST: label {} has {} images.'.format(key, len(val)))\n", 81 | " \n", 82 | " \n", 83 | " def __len__(self):\n", 84 | " return len(self.lt_labels)\n", 85 | " \n", 86 | " def __getitem__(self, index):\n", 87 | " img = self.lt_imgs[index].unsqueeze(-1).repeat(1,1,3)\n", 88 | " label = self.lt_labels[index]\n", 89 | " \n", 90 | " # generate tail colors\n", 91 | " if random.random() < self.att_ratio:\n", 92 | " att_label = random.randint(0,2)\n", 93 | " color = self.colors[att_label]\n", 94 | " else:\n", 95 | " color = self.colors[label]\n", 96 | " \n", 97 | " # assign attribute\n", 98 | " img = self.to_color(img, color)\n", 99 | " \n", 100 | " return img, label\n", 101 | " \n", 102 | " def to_color(self, img, rgb=[1,0,0]):\n", 103 | " return (img * torch.FloatTensor(rgb).unsqueeze(0).unsqueeze(0)).byte()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 133, 109 | "id": "0bddd1e3", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def visualization(input_tensor):\n", 114 | " return Image.fromarray(input_tensor.numpy()).resize((64,64))" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 135, 120 | "id": "77a12b35", 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 
| "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "Generate ColorMNIST: label 0 has 24754 images.\n", 128 | "Generate ColorMNIST: label 1 has 2475 images.\n", 129 | "Generate ColorMNIST: label 2 has 247 images.\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "dataset = ColorMNIST_LT('train', 'test_iid', '.', None, 0.1, 0.1)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 165, 140 | "id": "72d73ef4", 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAGcUlEQVR4nO3a6W9T2RkG8J/tLI4hIdsEsk0yCYQtIQxQllLUUTuqqn7rp/5X/SP6D1QjVZVGoyJ1OtN2gGGAGZYQAgkhG1lIyOIsjt0PVwc5gypsx6FVm0f+YPn6nvM+55znfZ9z7mUf+9jHPvbx/4zYe+umkiQpaqgBadKssU6GbEktV5QzzH+PShrooY/j9IIRhhjiGYtskiu+5T0nEKOaVnoY4GPO0Q9+oJ16UowwyVbx7b8PAi38nGt0c4SOcKmLBO2c4Qv+VBKBvUUlzXzKH1ggwyaZvM8WGRb4Pa0ldbGHM1BFJx9zhUEaQAJhpCvDPxtooZ6Z4qW8VwQqaGSA33KZw+SIkQujvs0GiZCRUrQyzWu2i+yozIhU28FprnKWHoQMk2ObNZaYJ8tRmungGrWMMc0S6f8UgSo+4Cf8hrO0hd+jGYiR4TVPuckan/IJx6njAt9wm+8Zf88EEhyggSa6ucIVuomHP/yoZKZ5zizHWKGOOlrIsMzz90+ghkEGOUYP3RwOks0nkKOaJtpp4jWrTNFOiioa+YBUwf2WjcAhzvIr+ukKet3eWVwjDlVU0cwhKllghAo+JEsNtVS9ZwIVNNHDcdpCoNGKj/Q6xRqdQdDR1TSzbFNJDW1UESNbTCIqA4EKGviQLo5QGfSKdWYY4y6LXKEzpP9ZXjLBNJsc5hw1YeoK95i7JVCdV62OcjD0nWGNUW4xzCRx2hkiyxR3GeIlWbY4wwYHaaCNJlKsF1DXdkUgGvvz/I7zNOeN3Brj3OCPDNNIF+N8xQy3GWGaZXJBxxvEaOM4H9HMTAEWtUQCsVBrexngHB+CLDG2mGOE+yGjR+t7m+dMcof5vNa2WCEDkhyhlz4yzLxLDyUSiFPLUS5yitrwe5YcK0xwnxHWwCuGGCMRZP32iLxBDSeYY4P5vSBQQT1dnOE8vSTJssgrtljgAY+YCL4t/S5rEM/jUEknp3icVwfLSSDJCS5xgQE6SLHCd9xjnlfMMMViWBhFIR7UnCogHZVCoJZ+PuEUbSTBOP/gOi+CNKvZKNJaRog20NVUlJ3AG+32cJJWkizxkG/4J4+YZ5M4qVCMS0Bs56IqG4EK6miji45Q8J/xGV8znKe5XJBvCfv06K5sYfcWRyBFH2fpphqs8Zw73OdVXpe54kOP7fxe4Ay8U+U7UM9P+QVdYJ1JxvLWfcmI5YUbRR8vLLhCZyBOBa0MMEATWZYZZ5yFXR8o/Ih8ho3CMlihBA7QQh+dNFFBlnlGmWSzyHDfxmae3CMPu8hqAbNaKIGowkfRR3lzmxmeMFlSsn+DaPEk83Y8aaYYY6GMZi5BNcmgXawywUNGWS8p9AiH6eR4OJ6IssItvuRpAUNTdCF7I7XIQg7n+YUSkOQoVxkMdWOOIf7Olyy
Xl0BsZ67YZoW5UgUQDx77Etc4QQ1LDHOPEV4V1k4RBN7W09YuFk9kOS/ySy7RSI4JbnOb2YLbKWUJ5VhnlqXihz9BkgN8xGUuc5yDYf9wl295zGLBDZbihbKM8z0TxR9lHqCL05zlKr1UMcrXfM4Yk7wq+FhOCQRiYRO4WEw3ieACO+nnIoOcJMUoN/mCz0pakIUSeCPfLHEO00s9G4XdHhnYfgbpo4tGNhnjr/yFh6XKqbgZiOxxnFZO0kmK1XfdFeXKK/yMi7QSY4VhbvJnru+iFBZNIFr0MRo4xgVeMM8qGeIkqaWeJmpJcojTnKaPJuKhCD7gLqO7cyKFEsjlfSLUcJJfc48HTJEOR3QdHOMU7TTSSAuHwDJTzDLMDX4oJuHsikDsLY8ebb0vUkdTODxM0EIXPRyllcZw2pVjgWlGecwTHjDKcri65wQSxPMOnGM0k6KNAZZYI04zLdSR4kDwTmnmgkX7lq+YYJm1Up+uFk0gOipcI802ifBA4ABNdLHOOjHqg12NsMUSz3jGTKi134UN5+5RKIFVppngJd3U5q2lSLiV4Wi2Ou+udSZ4xN+4yyorvCxf9EUQSPOSMZ7QQjt1JPKSUiLsALfJkiHNJEPc4nPulC/oUghkSTPKdeboo5duGvP+E83Ja57zgnGeM8ZTRssbdR6KqAMZXvCaYfo5T3YngQhj3OAO9xgPxmbvHsEXZ6c3mA0BRXboKcmQRmJshPz4mKfM7U3Q+SjldZvIEtdwkNTOw49seIMmeonmv+7dh33sYx/72Mf/Gv4FIJ/2aCjskFoAAAAASUVORK5CYII=\n", 146 | "text/plain": [ 147 | "" 148 | ] 149 | }, 150 | "execution_count": 165, 151 | "metadata": {}, 152 | "output_type": "execute_result" 153 | } 154 | ], 155 | "source": [ 156 | "visualization(dataset[0][0])" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "id": "a9f1a89c", 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "36f89360", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "617f415f", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "Python 3", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.6.10" 201 | } 202 | }, 203 | 
"nbformat": 4, 204 | "nbformat_minor": 5 205 | } 206 | -------------------------------------------------------------------------------- /deprecated/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/deprecated/__init__.py -------------------------------------------------------------------------------- /figure/generalized-long-tail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/generalized-long-tail.jpg -------------------------------------------------------------------------------- /figure/glt_formulation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/glt_formulation.jpg -------------------------------------------------------------------------------- /figure/ifl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/ifl.jpg -------------------------------------------------------------------------------- /figure/ifl_code.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/ifl_code.jpg -------------------------------------------------------------------------------- /figure/imagenet-glt-statistics.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/imagenet-glt-statistics.jpg -------------------------------------------------------------------------------- /figure/imagenet-glt-visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/imagenet-glt-visualization.jpg -------------------------------------------------------------------------------- /figure/imagenet-glt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/imagenet-glt.jpg -------------------------------------------------------------------------------- /figure/mscoco-glt-statistics.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/mscoco-glt-statistics.jpg -------------------------------------------------------------------------------- /figure/mscoco-glt-testgeneration.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/mscoco-glt-testgeneration.jpg -------------------------------------------------------------------------------- /figure/mscoco-glt-visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/mscoco-glt-visualization.jpg 
######################################
# Kaihua Tang
######################################
# Entry point: parse arguments, merge the YAML configs, then dispatch to the
# training or testing pipeline selected by --phase.

import json
import yaml
import os
import argparse
import torch
import torch.nn as nn
import random
import utils.general_utils as utils
from utils.logger_utils import custom_logger
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *


from utils.train_loader import train_loader
from utils.test_loader import test_loader

# ============================================================================
# argument parser
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', default=None, type=str, help='Indicate the config file used for the training.')
parser.add_argument('--seed', default=25, type=int, help='Fix the random seed for reproduction. Default is 25.')
parser.add_argument('--phase', default='train', type=str, help='Indicate train/val/test phase.')
parser.add_argument('--load_dir', default=None, type=str, help='Load model from this directory for testing')
parser.add_argument('--output_dir', default=None, type=str, help='Output directory that saves everything.')
parser.add_argument('--require_eval', action='store_true', help='Require evaluation on val set during training.')
parser.add_argument('--logger_name', default='logger_eval', type=str, help='Name of TXT output for the logger.')
# update config settings
parser.add_argument('--lr', default=None, type=float, help='Learning Rate')
parser.add_argument('--testset', default=None, type=str, help='Reset the type of test set.')
parser.add_argument('--loss_type', default=None, type=str, help='Reset the type of loss function.')
parser.add_argument('--model_type', default=None, type=str, help='Reset the type of model.')
parser.add_argument('--train_type', default=None, type=str, help='Reset the type of training strategy.')
parser.add_argument('--sample_type', default=None, type=str, help='Reset the type of sampling strategy.')
parser.add_argument('--rand_aug', action='store_true', help='Apply Random Augmentation During Training.')
parser.add_argument('--save_all', action='store_true', help='Save All Output Information During Testing.')

args = parser.parse_args()

# ============================================================================
# init logger
# FIX: the old code only *printed* a warning for a missing --output_dir and
# then crashed on os.path.exists(None); abort with a proper usage error.
if args.output_dir is None:
    parser.error('Please specify output directory (--output_dir)')
# makedirs(exist_ok=True) also creates missing parents and avoids the race
# between the exists() check and mkdir().
os.makedirs(args.output_dir, exist_ok=True)
if args.phase != 'train':
    logger = custom_logger(args.output_dir, name='{}.txt'.format(args.logger_name))
else:
    logger = custom_logger(args.output_dir)
logger.info('========================= Start Main =========================')


# ============================================================================
# fix random seed
logger.info('=====> Using fixed random seed: ' + str(args.seed))
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# ============================================================================
# load config
logger.info('=====> Load config from yaml: ' + str(args.cfg))
# FIX: yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
# and raises TypeError on PyYAML >= 6; FullLoader keeps the old semantics
# while refusing arbitrary python object construction.
with open(args.cfg) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# load detailed settings for each algorithms
logger.info('=====> Load algorithm details from yaml: config/algorithms_config.yaml')
with open('config/algorithms_config.yaml') as f:
    algo_config = yaml.load(f, Loader=yaml.FullLoader)

# update config
logger.info('=====> Merge arguments from command')
config = utils.update(config, algo_config, args)

# save config
logger.info('=====> Save config as config.json')
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
    json.dump(config, f)
utils.print_config(config, logger)

# ============================================================================
# training
if args.phase == 'train':
    logger.info('========= The Current Training Strategy is {} ========='.format(config['training_opt']['type']))
    train_func = train_loader(config)
    training = train_func(args, config, logger, eval=args.require_eval)
    training.run()


else:
    # ============================================================================
    # create model
    logger.info('=====> Model construction from: ' + str(config['networks']['type']))
    model_type = config['networks']['type']
    model_file = config['networks'][model_type]['def_file']
    model_args = config['networks'][model_type]['params']
    logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
    classifier_type = config['classifiers']['type']
    classifier_file = config['classifiers'][classifier_type]['def_file']
    classifier_args = config['classifiers'][classifier_type]['params']
    model = utils.source_import(model_file).create_model(**model_args)
    classifier = utils.source_import(classifier_file).create_model(**classifier_args)

    model = nn.DataParallel(model).cuda()
    classifier = nn.DataParallel(classifier).cuda()

    # ============================================================================
    # load checkpoint
    checkpoint = Checkpoint(config)
    ckpt = checkpoint.load(model, classifier, args.load_dir, logger)

    # ============================================================================
    # testing
    test_func = test_loader(config)
    if args.phase == 'val':
        # run validation set
        testing = test_func(config, logger, model, classifier, val=True, add_ckpt=ckpt)
        testing.run_val(epoch=-1)
    else:
        assert args.phase == 'test'
        # Run a specific test split
        if args.testset:
            testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
            testing.run_test()
        # Run all test splits
        else:
            if 'LT' in config['dataset']['name']:
                config['dataset']['testset'] = 'test_lt'
                testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
                testing.run_test()

                config['dataset']['testset'] = 'test_bl'
                testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
                testing.run_test()

                config['dataset']['testset'] = 'test_bbl'
                testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
                testing.run_test()

logger.info('========================= Complete =========================')
######################################
# Kaihua Tang
######################################

import torch
import torch.nn as nn

import math

class ClassifierCOS(nn.Module):
    """Multi-head cosine-similarity classifier.

    Features and classifier weights are split into `num_head` chunks, each
    chunk is L2-normalized independently, and the logit is the scaled dot
    product of the re-concatenated chunks (tau-normalized classifier).
    """

    def __init__(self, feat_dim, num_classes=1000, num_head=2, tau=16.0):
        super(ClassifierCOS, self).__init__()

        # classifier weights
        # FIX: the parameter is created on CPU; device placement is the
        # caller's responsibility (main.py wraps modules in
        # nn.DataParallel(...).cuda()).  The previous hard-coded .cuda()
        # made the module impossible to build on CPU-only machines.
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim), requires_grad=True)
        self.reset_parameters(self.weight)

        # parameters
        self.scale = tau / num_head  # 16.0 / num_head
        self.num_head = num_head
        self.head_dim = feat_dim // num_head

    def reset_parameters(self, weight):
        """Uniform init in [-1/sqrt(feat_dim), 1/sqrt(feat_dim)] (like nn.Linear)."""
        stdv = 1. / math.sqrt(weight.size(1))
        weight.data.uniform_(-stdv, stdv)

    def forward(self, x, add_inputs=None):
        # add_inputs is unused; kept for a uniform classifier interface.
        normed_x = self.multi_head_call(self.l2_norm, x)
        normed_w = self.multi_head_call(self.l2_norm, self.weight)
        y = torch.mm(normed_x * self.scale, normed_w.t())
        return y

    def multi_head_call(self, func, x):
        """Apply `func` to each of the `num_head` feature chunks and re-concat."""
        assert len(x.shape) == 2
        x_list = torch.split(x, self.head_dim, dim=1)
        y_list = [func(item) for item in x_list]
        assert len(x_list) == self.num_head
        assert len(y_list) == self.num_head
        return torch.cat(y_list, dim=1)

    def l2_norm(self, x):
        # NOTE(review): divides by the row norm; an all-zero row yields NaN.
        normed_x = x / torch.norm(x, 2, 1, keepdim=True)
        return normed_x

def create_model(feat_dim=2048, num_classes=1000, num_head=2, tau=16.0):
    """Factory used by utils.source_import in main.py."""
    model = ClassifierCOS(feat_dim=feat_dim, num_classes=num_classes, num_head=num_head, tau=tau)
    return model
######################################
# Kaihua Tang
######################################

import torch
import torch.nn as nn
import torch.nn.functional as F

import math


class ClassifierFC(nn.Module):
    """Plain linear classifier (single fully-connected layer, no bias)."""

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierFC, self).__init__()
        self.fc = nn.Linear(feat_dim, num_classes, bias=False)

    def forward(self, x, add_inputs=None):
        # add_inputs is unused; kept for a uniform classifier interface.
        y = self.fc(x)
        return y


def create_model(feat_dim=2048, num_classes=1000):
    """Factory used by utils.source_import in main.py."""
    model = ClassifierFC(feat_dim=feat_dim, num_classes=num_classes)
    return model


class ClassifierLA(nn.Module):
    """Logit-Adjustment classifier.

    Exactly one of the two modes must be enabled:
      * loss:    during training, add `add_inputs['logit_adj']` to the logits
                 (LA as a loss adjustment);
      * posthoc: during inference, subtract `add_inputs['logit_adj']`
                 (post-hoc correction).
    """

    def __init__(self, feat_dim, num_classes=1000, posthoc=False, loss=False):
        super(ClassifierLA, self).__init__()

        self.posthoc = posthoc
        self.loss = loss
        # FIX: raise instead of assert so the check survives `python -O`
        # (asserts are stripped under optimization); the condition is the
        # same "exactly one of posthoc/loss" the old asserts expressed.
        if posthoc == loss:
            raise ValueError('ClassifierLA requires exactly one of posthoc/loss to be True')

        self.fc = nn.Linear(feat_dim, num_classes, bias=False)

    def forward(self, x, add_inputs=None):
        y = self.fc(x)
        if self.training and self.loss:
            logit_adj = add_inputs['logit_adj']
            y = y + logit_adj
        if (not self.training) and self.posthoc:
            logit_adj = add_inputs['logit_adj']
            y = y - logit_adj
        return y


def create_model(feat_dim=2048, num_classes=1000, posthoc=True, loss=False):
    """Factory used by utils.source_import in main.py."""
    model = ClassifierLA(feat_dim=feat_dim, num_classes=num_classes, posthoc=posthoc, loss=loss)
    return model


class ClassifierLDAM(nn.Module):
    """LDAM cosine classifier: logits are cosine similarities between the
    L2-normalized feature rows and the L2-normalized weight columns."""

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierLDAM, self).__init__()

        # FIX: the parameter is created on CPU; device placement is the
        # caller's job (main.py wraps modules in nn.DataParallel(...).cuda()).
        # The previous hard-coded .cuda() broke CPU-only construction.
        self.weight = nn.Parameter(torch.Tensor(feat_dim, num_classes), requires_grad=True)
        # init on the unit sphere (uniform direction, unit column norm)
        self.weight.data.uniform_(-1, 1)
        self.weight.data.renorm_(2, 1, 1e-5)
        self.weight.data.mul_(1e5)

    def forward(self, x, add_inputs=None):
        y = torch.mm(F.normalize(x, dim=1), F.normalize(self.weight, dim=0))
        return y


def create_model(feat_dim=2048, num_classes=1000):
    """Factory used by utils.source_import in main.py."""
    model = ClassifierLDAM(feat_dim=feat_dim, num_classes=num_classes)
    return model


class ClassifierLWS(nn.Module):
    """Learnable Weight Scaling classifier: the linear weights are frozen and
    only a learnable per-class scale is trained (stage-2 re-balancing)."""

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierLWS, self).__init__()

        self.fc = nn.Linear(feat_dim, num_classes, bias=False)

        self.scales = nn.Parameter(torch.ones(num_classes))
        # freeze the linear weights; only self.scales receives gradients
        for _, param in self.fc.named_parameters():
            param.requires_grad = False

    def forward(self, x, add_inputs=None):
        y = self.fc(x)
        y *= self.scales
        return y


def create_model(feat_dim=2048, num_classes=1000):
    """Factory used by utils.source_import in main.py."""
    model = ClassifierLWS(feat_dim=feat_dim, num_classes=num_classes)
    return model
######################################
# Kaihua Tang
######################################

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassifierMultiHead(nn.Module):
    """Three-headed classifier: a (detached) prediction head plus two
    feature-projection heads used for contrastive / metric objectives."""

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierMultiHead, self).__init__()
        self.fc1 = nn.Linear(feat_dim, num_classes, bias=False)
        self.fc2 = nn.Linear(feat_dim, feat_dim, bias=False)
        self.fc3 = nn.Linear(feat_dim, feat_dim, bias=False)

    def forward(self, x, add_inputs=None):
        # The prediction head sees a detached feature so the CE loss does not
        # back-propagate into the backbone (the head is re-trained in stage 2).
        pred = self.fc1(x.detach())
        contrast = self.fc2(x)   # contrastive head
        metric = self.fc3(x)     # metric head
        return pred, contrast, metric


def create_model(feat_dim=2048, num_classes=1000):
    """Factory used by utils.source_import in main.py."""
    return ClassifierMultiHead(feat_dim=feat_dim, num_classes=num_classes)


class NormedLinear(nn.Module):
    """Linear layer whose input rows and weight columns are L2-normalized,
    i.e. a cosine-similarity classifier head."""

    def __init__(self, in_features, out_features):
        super(NormedLinear, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        # random direction, unit column norm
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x):
        return F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))


class ClassifierRIDE(nn.Module):
    """RIDE multi-expert classifier: one (optionally normalized) linear head
    per expert, averaged into a single ensemble logit."""

    def __init__(self, feat_dim, num_classes=1000, num_experts=3, use_norm=True):
        super(ClassifierRIDE, self).__init__()
        self.num_experts = num_experts
        if use_norm:
            # cosine heads need a large scale to produce usable logits
            self.linears = nn.ModuleList(
                [NormedLinear(feat_dim, num_classes) for _ in range(num_experts)])
            self.s = 30
        else:
            self.linears = nn.ModuleList(
                [nn.Linear(feat_dim, num_classes) for _ in range(num_experts)])
            self.s = 1

    def forward(self, x, add_inputs=None, index=None):
        if index is not None:
            # single-expert path: x is (batch, feat_dim); note the original
            # applies no self.s scaling on this path
            return self.linears[index](x)
        # ensemble path: x is (batch, num_experts, feat_dim)
        logits = [head(x[:, i, :]) * self.s for i, head in enumerate(self.linears)]
        return torch.stack(logits, dim=1).mean(dim=1), logits


def create_model(feat_dim=2048, num_classes=1000, num_experts=3, use_norm=True):
    """Factory used by utils.source_import in main.py."""
    return ClassifierRIDE(feat_dim=feat_dim, num_classes=num_classes,
                          num_experts=num_experts, use_norm=use_norm)
######################################
# Kaihua Tang
######################################

import torch
import torch.nn as nn

import math

class ClassifierTDE(nn.Module):
    """Causal-TDE classifier (multi-head tau-normalized cosine classifier).

    During training it behaves like a cosine classifier with a "causal"
    weight normalization; at inference it removes the Total Direct Effect of
    the running mean feature (`add_inputs['embed']`) from each head.
    """

    def __init__(self, feat_dim, num_classes=1000, num_head=2, tau=16.0, alpha=3.0, gamma=0.03125):
        super(ClassifierTDE, self).__init__()

        # classifier weights
        # FIX: created on CPU; device placement is the caller's job (main.py
        # wraps modules in nn.DataParallel(...).cuda()).  The previous
        # hard-coded .cuda() broke CPU-only construction.
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim), requires_grad=True)
        self.reset_parameters(self.weight)

        # parameters
        self.scale = tau / num_head  # 16.0 / num_head
        self.norm_scale = gamma      # additive slack used by causal_norm
        self.alpha = alpha           # strength of the TDE removal
        self.num_head = num_head
        self.head_dim = feat_dim // num_head

    def reset_parameters(self, weight):
        """Uniform init in [-1/sqrt(feat_dim), 1/sqrt(feat_dim)] (like nn.Linear)."""
        stdv = 1. / math.sqrt(weight.size(1))
        weight.data.uniform_(-stdv, stdv)

    def forward(self, x, add_inputs=None):
        normed_x = self.multi_head_call(self.l2_norm, x)
        normed_w = self.multi_head_call(self.causal_norm, self.weight, weight=self.norm_scale)
        y = torch.mm(normed_x * self.scale, normed_w.t())

        # apply TDE during inference
        if (not self.training):
            # add_inputs['embed'] is the moving-average feature from training
            self.embed = add_inputs['embed']
            normed_c = self.multi_head_call(self.l2_norm, self.embed)
            x_list = torch.split(normed_x, self.head_dim, dim=1)
            c_list = torch.split(normed_c, self.head_dim, dim=1)
            w_list = torch.split(normed_w, self.head_dim, dim=1)
            output = []

            for nx, nc, nw in zip(x_list, c_list, w_list):
                cos_val, sin_val = self.get_cos_sin(nx, nc)
                # subtract the feature component along the mean direction
                y0 = torch.mm((nx - cos_val * self.alpha * nc) * self.scale, nw.t())
                output.append(y0)
            y = sum(output)
        return y

    def get_cos_sin(self, x, y):
        """Row-wise cosine/sine between x and y."""
        cos_val = (x * y).sum(-1, keepdim=True) / torch.norm(x, 2, 1, keepdim=True) / torch.norm(y, 2, 1, keepdim=True)
        sin_val = (1 - cos_val * cos_val).sqrt()
        return cos_val, sin_val

    def multi_head_call(self, func, x, weight=None):
        """Apply `func` (optionally with `weight`) to each head chunk and re-concat."""
        assert len(x.shape) == 2
        x_list = torch.split(x, self.head_dim, dim=1)
        # FIX: compare against None rather than truthiness so that weight=0.0
        # does not silently drop the extra argument (causal_norm requires it).
        if weight is not None:
            y_list = [func(item, weight) for item in x_list]
        else:
            y_list = [func(item) for item in x_list]
        assert len(x_list) == self.num_head
        assert len(y_list) == self.num_head
        return torch.cat(y_list, dim=1)

    def l2_norm(self, x):
        normed_x = x / torch.norm(x, 2, 1, keepdim=True)
        return normed_x

    def causal_norm(self, x, weight):
        # norm with an additive slack so small-norm rows are not blown up
        norm = torch.norm(x, 2, 1, keepdim=True)
        normed_x = x / (norm + weight)
        return normed_x

def create_model(feat_dim=2048, num_classes=1000, num_head=2, tau=16.0, alpha=3.0, gamma=0.03125):
    """Factory used by utils.source_import in main.py."""
    model = ClassifierTDE(feat_dim=feat_dim, num_classes=num_classes, num_head=num_head, tau=tau, alpha=alpha, gamma=gamma)
    return model


######################################
# Kaihua Tang
######################################
# (models/ResNet.py) torchvision backbone factory.

def create_model(m_type='resnet101'):
    """Create a torchvision ResNet/ResNeXt backbone whose final fc layer is
    replaced by a ReLU, so the module outputs pooled features.

    Raises:
        ValueError: for an unknown `m_type`.
    """
    # local import keeps torchvision an optional dependency for users that
    # only need the classifiers defined above
    import torchvision.models as models
    # create various resnet models
    if m_type == 'resnet18':
        model = models.resnet18(pretrained=False)
    elif m_type == 'resnet50':
        model = models.resnet50(pretrained=False)
    elif m_type == 'resnet101':
        model = models.resnet101(pretrained=False)
    elif m_type == 'resnext50':
        model = models.resnext50_32x4d(pretrained=False)
    elif m_type == 'resnext101':
        model = models.resnext101_32x8d(pretrained=False)
    else:
        raise ValueError('Wrong Model Type')
    model.fc = nn.ReLU()
    return model
self.bn2 = nn.BatchNorm2d(planes) 24 | self.relu2 = nn.ReLU(True) 25 | self.conv3 = nn.Conv2d( 26 | planes, planes * self.expansion, kernel_size=1, bias=False 27 | ) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | if stride != 1 or self.expansion * planes != inplanes: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d( 32 | inplanes, 33 | self.expansion * planes, 34 | kernel_size=1, 35 | stride=stride, 36 | bias=False, 37 | ), 38 | nn.BatchNorm2d(self.expansion * planes), 39 | ) 40 | else: 41 | self.downsample = None 42 | self.relu = nn.ReLU(True) 43 | 44 | def forward(self, x): 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | 47 | out = self.relu2(self.bn2(self.conv2(out))) 48 | 49 | out = self.bn3(self.conv3(out)) 50 | 51 | if self.downsample != None: 52 | residual = self.downsample(x) 53 | else: 54 | residual = x 55 | out = out + residual 56 | out = self.relu(out) 57 | return out 58 | 59 | 60 | 61 | class BBN_ResNet(nn.Module): 62 | def __init__( 63 | self, 64 | block_type, 65 | num_blocks, 66 | last_layer_stride=2, 67 | ): 68 | super(BBN_ResNet, self).__init__() 69 | self.inplanes = 64 70 | self.block = block_type 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.relu = nn.ReLU(True) 75 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 76 | 77 | self.layer1 = self._make_layer(num_blocks[0], 64) 78 | self.layer2 = self._make_layer(num_blocks[1], 128, stride=2) 79 | self.layer3 = self._make_layer(num_blocks[2], 256, stride=2) 80 | self.layer4 = self._make_layer(num_blocks[3] - 1, 512, stride=last_layer_stride) 81 | 82 | self.cb_block = self.block(self.inplanes, self.inplanes // 4, stride=1) 83 | self.rb_block = self.block(self.inplanes, self.inplanes // 4, stride=1) 84 | 85 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 86 | 87 | def load_model(self, pretrain): 88 | print("Loading Backbone pretrain model from {}......".format(pretrain)) 89 | model_dict 
= self.state_dict() 90 | pretrain_dict = torch.load(pretrain) 91 | pretrain_dict = pretrain_dict["state_dict"] if "state_dict" in pretrain_dict else pretrain_dict 92 | from collections import OrderedDict 93 | 94 | new_dict = OrderedDict() 95 | for k, v in pretrain_dict.items(): 96 | if k.startswith("module"): 97 | k = k[7:] 98 | if "fc" not in k and "classifier" not in k: 99 | k = k.replace("backbone.", "") 100 | new_dict[k] = v 101 | 102 | model_dict.update(new_dict) 103 | self.load_state_dict(model_dict) 104 | print("Backbone model has been loaded......") 105 | 106 | def _make_layer(self, num_block, planes, stride=1): 107 | strides = [stride] + [1] * (num_block - 1) 108 | layers = [] 109 | for now_stride in strides: 110 | layers.append(self.block(self.inplanes, planes, stride=now_stride)) 111 | self.inplanes = planes * self.block.expansion 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x, **kwargs): 115 | out = self.conv1(x) 116 | out = self.bn1(out) 117 | out = self.relu(out) 118 | out = self.pool(out) 119 | 120 | out = self.layer1(out) 121 | out = self.layer2(out) 122 | out = self.layer3(out) 123 | out = self.layer4(out) 124 | 125 | if "feature_cb" in kwargs: 126 | out = self.cb_block(out) 127 | elif "feature_rb" in kwargs: 128 | out = self.rb_block(out) 129 | else: 130 | out1 = self.cb_block(out) 131 | out2 = self.rb_block(out) 132 | out = torch.cat((out1, out2), dim=1) 133 | 134 | out = self.avgpool(out) 135 | out = out.view(x.shape[0], -1) 136 | return out 137 | 138 | def create_model(m_type='bbn_resnet50'): 139 | # create various resnet models 140 | model = BBN_ResNet(BottleNeck, [3, 4, 6, 3], last_layer_stride=2,) 141 | return model -------------------------------------------------------------------------------- /models/ResNet_RIDE.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | 5 | import math 
######################################
# Kaihua Tang
######################################

import math
from typing import ForwardRef
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (no bias)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Standard two-convolution residual block (expansion 1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return self.relu(h + shortcut)


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1, expansion 4) with
    ResNeXt-style grouped convolutions."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, is_last=False):
        super(Bottleneck, self).__init__()
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               groups=groups, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.is_last = is_last

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return self.relu(h + shortcut)


class ResNext(nn.Module):
    """ResNe(X)t trunk with shared layer1/layer2 and per-expert layer3/layer4
    branches (RIDE-style multi-expert backbone)."""

    def __init__(self, block, layers, groups=1, width_per_group=64, num_experts=1, reduce_dimension=False):
        self.inplanes = 64
        self.num_experts = num_experts
        super(ResNext, self).__init__()

        self.groups = groups
        self.base_width = width_per_group

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.inplanes = self.next_inplanes
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.inplanes = self.next_inplanes

        # optionally shrink the expert branches to cut parameters
        layer3_output_dim = 192 if reduce_dimension else 256
        layer4_output_dim = 384 if reduce_dimension else 512

        # each expert owns a private layer3/layer4; every _make_layer call
        # starts from the same self.inplanes, so all experts share one shape
        self.layer3s = nn.ModuleList(
            [self._make_layer(block, layer3_output_dim, layers[2], stride=2) for _ in range(num_experts)])
        self.inplanes = self.next_inplanes
        self.layer4s = nn.ModuleList(
            [self._make_layer(block, layer4_output_dim, layers[3], stride=2) for _ in range(num_experts)])
        self.inplanes = self.next_inplanes
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # He init for convs, unit-affine for batch norms
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, is_last=False):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stages = [block(self.inplanes, planes, stride, downsample,
                        groups=self.groups, base_width=self.base_width)]
        self.next_inplanes = planes * block.expansion
        for i in range(1, blocks):
            stages.append(block(self.next_inplanes, planes,
                                groups=self.groups, base_width=self.base_width,
                                is_last=(is_last and i == blocks - 1)))

        return nn.Sequential(*stages)

    def _separate_part(self, x, ind):
        # run one expert's private layer3+layer4 and pool to a flat feature
        h = self.layer3s[ind](x)
        h = self.layer4s[ind](h)
        h = self.avgpool(h)
        return h.view(h.size(0), -1)

    def forward(self, x, index=None):
        h = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        h = self.layer2(self.layer1(h))

        if index is not None:
            # single-expert feature: (batch, feat_dim)
            return self._separate_part(h, index)
        # all experts stacked: (batch, num_experts, feat_dim)
        return torch.stack(
            [self._separate_part(h, ind) for ind in range(self.num_experts)], dim=1)


class ResNeXt50Model(nn.Module):
    """ResNeXt-50 (32x4d) wrapper around the multi-expert ResNext trunk."""

    def __init__(self, reduce_dimension=False, num_experts=1):
        super(ResNeXt50Model, self).__init__()
        self.backbone = ResNext(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4,
                                reduce_dimension=reduce_dimension, num_experts=num_experts)

    def forward(self, x, index=None):
        return self.backbone(x, index)


def create_model(m_type='resnext50', num_experts=3, reduce_dimension=True):
    # NOTE(review): the original definition is truncated in this view; the
    # body below follows the create_model pattern of the sibling model files
    # (m_type accepted for interface uniformity) -- confirm against the repo.
    model = ResNeXt50Model(reduce_dimension=reduce_dimension, num_experts=num_experts)
    return model
######################################
# Kaihua Tang
######################################

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader

class train_baseline():
    """Plain single-branch trainer: backbone features -> classifier -> loss."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # instantiate backbone and classifier from the config-declared source files
        net_key = config['networks']['type']
        logger.info('=====> Model construction from: ' + str(net_key))
        net_cfg = config['networks'][net_key]
        cls_key = config['classifiers']['type']
        logger.info('=====> Classifier construction from: ' + str(cls_key))
        cls_cfg = config['classifiers'][cls_key]
        model = utils.source_import(net_cfg['def_file']).create_model(**net_cfg['params'])
        classifier = utils.source_import(cls_cfg['def_file']).create_model(**cls_cfg['params'])

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # bookkeeping
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # training data
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # training loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # optional validation hook
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Run the full training schedule and checkpoint after every epoch."""
        self.logger.info('=====> Start Baseline Training')

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                # forward pass
                predictions = self.classifier(self.model(inputs), add_inputs)

                # loss + backward
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss']: loss.sum().item()}
                loss.backward()
                self.optimizer.step()

                # batch accuracy
                num_correct = (predictions.max(1)[1] == labels).sum().float()
                accuracy = num_correct / predictions.shape[0]

                # logging
                iter_info_print.update({'Accuracy': accuracy.item(),
                                        'Loss': loss.sum().item(),
                                        'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print,
                                      self.config['logger_opt']['print_iter'])

                # gradient dump: always on the very first batch, then
                # periodically when print_grad is enabled
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0

            # checkpoint, then advance the LR schedule
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
######################################
# Kaihua Tang
######################################

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader

class train_bbn():
    """BBN-style trainer: two data branches (conventional + reversed
    distribution) mixed with a parabolically decaying coefficient."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # instantiate backbone and classifier from the config-declared source files
        net_key = config['networks']['type']
        logger.info('=====> Model construction from: ' + str(net_key))
        net_cfg = config['networks'][net_key]
        cls_key = config['classifiers']['type']
        logger.info('=====> Classifier construction from: ' + str(cls_key))
        cls_cfg = config['classifiers'][cls_key]
        model = utils.source_import(net_cfg['def_file']).create_model(**net_cfg['params'])
        classifier = utils.source_import(cls_cfg['def_file']).create_model(**cls_cfg['params'])

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # bookkeeping
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # training data (a pair of loaders for the two BBN branches)
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # training loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # optional validation hook
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def calculate_reverse_instance_weight(self, dataloader):
        """Per-instance sampling weights that reverse the class distribution.

        Each instance of class c gets weight (max_freq / freq_c) / freq_c, so
        rarer classes are sampled more often by the reversed-branch sampler.
        """
        # per-class frequency, ordered by class index
        counts = {}
        for lbl in dataloader.dataset.labels:
            counts[lbl] = counts.get(lbl, 0) + 1
        counts = dict(sorted(counts.items()))
        label_freq_array = torch.FloatTensor(list(counts.values()))
        reverse_class_weight = label_freq_array.max() / label_freq_array
        # spread each class weight over its instances
        reverse_instance_weight = torch.ones(len(dataloader.dataset))
        for idx, lbl in enumerate(dataloader.dataset.labels):
            reverse_instance_weight[idx] = reverse_class_weight[lbl] / (label_freq_array[lbl] + 1e-9)
        return reverse_instance_weight

    def run(self):
        """Run BBN training: mix features/losses of the two branches."""
        self.logger.info('=====> Start BBN Training')

        # the loader is a pair: conventional branch + reversed branch
        env1_loader, env2_loader = self.train_loader
        assert len(env1_loader) == len(env2_loader)
        total_batch = len(env1_loader)
        total_image = len(env1_loader.dataset)

        # branch 1 samples uniformly; branch 2 samples the reversed distribution
        instance_normal_weight = torch.ones(total_image)
        env1_loader.sampler.set_parameter(instance_normal_weight)
        instance_reverse_weight = self.calculate_reverse_instance_weight(env1_loader)
        env2_loader.sampler.set_parameter(instance_reverse_weight)

        num_epoch = self.training_opt['num_epochs']
        for epoch in range(num_epoch):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            for step, ((inputs1, labels1, _, indexs1), (inputs2, labels2, _, indexs2)) in enumerate(zip(env1_loader, env2_loader)):
                iter_info_print = {}

                self.optimizer.zero_grad()

                inputs1, inputs2 = inputs1.cuda(), inputs2.cuda()
                labels1, labels2 = labels1.cuda(), labels2.cuda()

                # one forward per branch (feature_cb / feature_rb select the branch)
                feature1 = self.model(inputs1, feature_cb=True)
                feature2 = self.model(inputs2, feature_rb=True)

                # parabolic decay of the mixing coefficient
                # NOTE(review): (epoch - 1) with epoch starting at 0 means l
                # starts just below 1 — presumably intentional; confirm vs. BBN
                l = 1 - ((epoch - 1) / num_epoch) ** 2

                mixed_feature = 2 * torch.cat((l * feature1, (1 - l) * feature2), dim=1)

                predictions = self.classifier(mixed_feature)

                # loss mixed with the same coefficient as the features
                loss = l * self.loss_fc(predictions, labels1) + (1 - l) * self.loss_fc(predictions, labels2)
                iter_info_print = {'BBN mixup loss': loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # accuracy is the same convex mix of per-branch correctness
                hits = predictions.max(1)[1]
                accuracy = l * (hits == labels1).float() + (1 - l) * (hits == labels2).float()
                accuracy = accuracy.sum() / accuracy.shape[0]

                # logging
                iter_info_print.update({'Accuracy': accuracy.item(),
                                        'Loss': loss.sum().item(),
                                        'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print,
                                      self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0

            # checkpoint, then advance the LR schedule
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
######################################
# Kaihua Tang
######################################

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader

class train_la():
    """Logit-adjustment trainer: baseline loop plus a fixed per-class logit
    offset computed from the training label frequencies."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # instantiate backbone and classifier from the config-declared source files
        net_key = config['networks']['type']
        logger.info('=====> Model construction from: ' + str(net_key))
        net_cfg = config['networks'][net_key]
        cls_key = config['classifiers']['type']
        logger.info('=====> Classifier construction from: ' + str(cls_key))
        cls_cfg = config['classifiers'][cls_key]
        model = utils.source_import(net_cfg['def_file']).create_model(**net_cfg['params'])
        classifier = utils.source_import(cls_cfg['def_file']).create_model(**cls_cfg['params'])

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # bookkeeping
        self.algorithm_opt = config['algorithm_opt']
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # training data
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # training loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # optional validation hook
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Run training; every batch receives the same class-prior logit offset."""
        self.logger.info('=====> Start Baseline Training')

        # per-class adjustment from label frequencies, scaled by 'tro';
        # fixed for the whole run, so gradients are disabled
        logit_adj = utils.compute_adjustment(self.train_loader, self.algorithm_opt['tro'])
        logit_adj.requires_grad = False

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                inputs, labels = inputs.cuda(), labels.cuda()
                # broadcast the (1, num_class) adjustment to every row of the batch
                batch_size = inputs.shape[0]
                add_inputs = {}
                add_inputs['logit_adj'] = logit_adj.to(inputs.device).view(1, -1).repeat(batch_size, 1)

                # forward pass
                predictions = self.classifier(self.model(inputs), add_inputs)

                # loss + backward
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss']: loss.sum().item()}
                loss.backward()
                self.optimizer.step()

                # batch accuracy
                num_correct = (predictions.max(1)[1] == labels).sum().float()
                accuracy = num_correct / predictions.shape[0]

                # logging
                iter_info_print.update({'Accuracy': accuracy.item(),
                                        'Loss': loss.sum().item(),
                                        'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print,
                                      self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0

            # checkpoint, then advance the LR schedule
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
######################################
# Kaihua Tang
######################################

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader

class train_ldam():
    """LDAM trainer: baseline loop with an epoch-dependent margin loss
    (the loss re-weights itself via set_weight at every epoch start)."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # instantiate backbone and classifier from the config-declared source files
        net_key = config['networks']['type']
        logger.info('=====> Model construction from: ' + str(net_key))
        net_cfg = config['networks'][net_key]
        cls_key = config['classifiers']['type']
        logger.info('=====> Classifier construction from: ' + str(cls_key))
        cls_cfg = config['classifiers'][cls_key]
        model = utils.source_import(net_cfg['def_file']).create_model(**net_cfg['params'])
        classifier = utils.source_import(cls_cfg['def_file']).create_model(**cls_cfg['params'])

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # bookkeeping
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # training data
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # LDAM margin loss
        self.loss_ldam = create_loss(logger, config, self.train_loader)

        # optional validation hook
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Run LDAM training; the loss weighting is refreshed each epoch."""
        self.logger.info('=====> Start LDAM Training')

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            total_batch = len(self.train_loader)

            # DRW-style schedule: the loss decides its per-class weights
            # from the current epoch
            self.loss_ldam.set_weight(epoch)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                # forward pass
                predictions = self.classifier(self.model(inputs), add_inputs)

                # loss + backward
                loss = self.loss_ldam(predictions, labels)
                iter_info_print = {self.training_opt['loss']: loss.sum().item()}
                loss.backward()
                self.optimizer.step()

                # batch accuracy
                num_correct = (predictions.max(1)[1] == labels).sum().float()
                accuracy = num_correct / predictions.shape[0]

                # logging
                iter_info_print.update({'Accuracy': accuracy.item(),
                                        'Loss': loss.sum().item(),
                                        'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print,
                                      self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0

            # checkpoint, then advance the LR schedule
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
######################################
# Kaihua Tang
######################################

import numpy as np  # FIX: 'np' was used below but only reachable through the
                    # wildcard 'from utils.training_utils import *'; make it explicit
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader

# hard example mining

class GeneralizedCELoss(nn.Module):
    """Generalized Cross Entropy used to train the *biased* model in LfF.

    Each sample's CE loss is scaled by p_y^q * q (p_y = softmax probability of
    the true class), which emphasizes easy/aligned samples for the biased model.
    """

    def __init__(self, q=0.7):
        super(GeneralizedCELoss, self).__init__()
        self.q = q  # GCE exponent; q -> 0 recovers CE, q = 1 gives MAE-like behavior

    def forward(self, logits, targets, requires_weight=False, weight_base=0):
        """Return the per-sample GCE loss (and optionally the weights).

        Args:
            logits: (batch, num_class) unnormalized scores.
            targets: (batch,) integer class labels.
            requires_weight: also return the detached per-sample weights.
            weight_base: constant added to every per-sample loss.

        Raises:
            NameError: if softmax probabilities contain NaN (kept for
                backward compatibility — callers may rely on this type).
        """
        p = F.softmax(logits, dim=1)
        if np.isnan(p.mean().item()):
            raise NameError('GCE_p')
        # probability assigned to the true class, shape (batch, 1)
        Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1))
        # modify gradient of cross entropy: weight = p_y^q * q (detached)
        loss_weight = (Yg.squeeze().detach() ** self.q) * self.q
        if np.isnan(Yg.mean().item()):
            raise NameError('GCE_Yg')

        loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight + weight_base
        if requires_weight:
            return loss, loss_weight
        return loss

class train_lff():
    """Learning-from-Failure trainer: a biased model (GCE) and a target model;
    the target model's CE is re-weighted toward samples the biased model fails on."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # create model: two backbone/classifier pairs, biased (b) and debiased/target (d)
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model_b = utils.source_import(model_file).create_model(**model_args)
        model_d = utils.source_import(model_file).create_model(**model_args)
        classifier_b = utils.source_import(classifier_file).create_model(**classifier_args)
        classifier_d = utils.source_import(classifier_file).create_model(**classifier_args)

        model_b = nn.DataParallel(model_b).cuda()
        model_d = nn.DataParallel(model_d).cuda()
        classifier_b = nn.DataParallel(classifier_b).cuda()
        classifier_d = nn.DataParallel(classifier_d).cuda()

        # other initialization (each pair gets its own optimizer/scheduler)
        self.algorithm_opt = config['algorithm_opt']
        self.config = config
        self.logger = logger
        self.model_b = model_b
        self.model_d = model_d
        self.classifier_b = classifier_b
        self.classifier_d = classifier_d
        self.optimizer_b = create_optimizer(model_b, classifier_b, logger, config)
        self.optimizer_d = create_optimizer(model_d, classifier_d, logger, config)
        self.scheduler_b = create_scheduler(self.optimizer_b, logger, config)
        self.scheduler_d = create_scheduler(self.optimizer_d, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # per-sample CE for the target model (kept unreduced for re-weighting)
        self.loss_fc = nn.CrossEntropyLoss(reduction='none')
        # biased loss
        self.loss_bias = GeneralizedCELoss()

        # set eval: only the target (debiased) pair is validated
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model_d, classifier_d, val=True)

    def run(self):
        """Train both models jointly; checkpoint the target (d) pair."""
        self.logger.info('=====> Start Baseline Training')

        # logit adjustment, fixed for the whole run
        logit_adj = utils.compute_adjustment(self.train_loader, self.algorithm_opt['tro'])
        logit_adj.requires_grad = False

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer_b.zero_grad()
                self.optimizer_d.zero_grad()

                # additional inputs: broadcast logit adjustment over the batch
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}
                batch_size = inputs.shape[0]
                add_inputs['logit_adj'] = logit_adj.to(inputs.device).view(1, -1).repeat(batch_size, 1)

                # biased prediction
                predictions_b = self.classifier_b(self.model_b(inputs), add_inputs)
                # targeted prediction
                predictions_d = self.classifier_d(self.model_d(inputs), add_inputs)

                # hard-example-mining weight: relative difficulty for the
                # biased model (detached — weights carry no gradient)
                loss_b = self.loss_fc(predictions_b, labels).detach()
                loss_d = self.loss_fc(predictions_d, labels).detach()
                loss_weight = loss_b / (loss_b + loss_d + 1e-8)

                # biased model trains with GCE; target model trains with
                # CE re-weighted toward the biased model's failures
                loss_b_update = self.loss_bias(predictions_b, labels)
                loss_d_update = self.loss_fc(predictions_d, labels) * loss_weight.cuda().detach()
                loss = loss_b_update.mean() + loss_d_update.mean()

                iter_info_print = {'biased loss': loss_b_update.mean().item(), 'target loss': loss_d_update.mean().item()}

                loss.backward()
                self.optimizer_b.step()
                self.optimizer_d.step()

                # accuracy of the target model only
                accuracy = (predictions_d.max(1)[1] == labels).sum().float() / predictions_d.shape[0]

                # log information
                iter_info_print.update({'Accuracy': accuracy.item(), 'Loss': loss.sum().item(), 'Poke LR': float(self.optimizer_d.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint (target pair only)
            self.checkpoint.save(self.model_d, self.classifier_d, epoch, self.logger, acc=val_acc)

            # update both schedulers
            self.scheduler_b.step()
            self.scheduler_d.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
######################################
# Kaihua Tang
######################################

import numpy as np  # FIX: np.random.beta was used below but only reachable through
                    # the wildcard 'from utils.training_utils import *'; make it explicit
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader

class train_mixup():
    """Mixup trainer: each batch is convexly mixed with a shuffled copy of
    itself, and the loss/accuracy are the matching convex combinations."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # create model
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def mixup_data(self, x, y, alpha=1.0):
        """Mix the batch with a random permutation of itself.

        Args:
            x: input batch.
            y: label batch.
            alpha: Beta(alpha, alpha) concentration; alpha <= 0 disables mixing
                (lam = 1 keeps the original batch).

        Returns:
            (mixed_x, y_a, y_b, lam) — mixed inputs, both label sets, and the
            mixing coefficient.
        """
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
        batch_size = x.shape[0]
        index = torch.randperm(batch_size).to(x.device)
        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_criterion(self, pred, y_a, y_b, lam):
        """Convex combination of the loss against both label sets."""
        return lam * self.loss_fc(pred, y_a) + (1 - lam) * self.loss_fc(pred, y_b)

    def mixup_accuracy(self, pred, y_a, y_b, lam):
        """Soft accuracy: per-sample correctness mixed with the same lam."""
        correct = lam * (pred.max(1)[1] == y_a) + (1 - lam) * (pred.max(1)[1] == y_b)
        accuracy = correct.sum().float() / pred.shape[0]
        return accuracy

    def run(self):
        """Run mixup training and checkpoint after every epoch."""
        self.logger.info('=====> Start Mixup Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                # mixup
                inputs, labels_a, labels_b, lam = self.mixup_data(inputs, labels)

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # calculate loss
                loss = self.mixup_criterion(predictions, labels_a, labels_b, lam)
                iter_info_print = {self.training_opt['loss']: loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy
                accuracy = self.mixup_accuracy(predictions, labels_a, labels_b, lam)

                # log information
                iter_info_print.update({'Accuracy': accuracy.item(), 'Loss': loss.sum().item(), 'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
class train_ride():
    """Trainer for the multi-expert RIDE model.

    The classifier returns (aggregated predictions, per-expert logits); the
    loss is either the joint RIDE loss (when `training_opt['loss'] == 'RIDE'`)
    or a plain per-expert sum of the configured loss.
    """

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # create model: backbone and classifier from their config-declared files
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Train for `num_epochs`, validating/checkpointing once per epoch."""
        # Start Training
        self.logger.info('=====> Start RIDE Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)
            # the RIDE loss anneals by epoch
            if self.training_opt['loss'] == 'RIDE':
                self.loss_fc.set_epoch(epoch)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                # FIX (consistency with train_stage2_ride): reset the logging
                # dict every step instead of once per epoch, so entries from a
                # previous iteration can never linger in the printed info.
                iter_info_print = {}

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                predictions, all_logits = self.classifier(features, add_inputs)

                # calculate loss: joint RIDE loss, or per-expert sum otherwise
                if self.training_opt['loss'] == 'RIDE':
                    extra_info = {'logits': all_logits}
                    loss = self.loss_fc(predictions, labels, extra_info)
                    iter_info_print[self.training_opt['loss']] = loss.sum().item()
                else:
                    all_loss = [self.loss_fc(logit, labels) for logit in all_logits]
                    loss = sum(all_loss)
                    for i, branch_loss in enumerate(all_loss):
                        iter_info_print[self.training_opt['loss'] + '_{}'.format(i)] = branch_loss.sum().item()

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (on the aggregated predictions)
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # dump gradients on the first batch ever, or every 1000 steps
                # when print_grad is enabled
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
class train_stage1():
    """Plain single-stage supervised trainer: one backbone, one classifier
    head, one loss, per-epoch validation and checkpointing."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # Build the backbone network from its registered definition file.
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        net_key = config['networks']['type']
        net_def = config['networks'][net_key]['def_file']
        net_params = config['networks'][net_key]['params']
        # Build the classifier head the same way.
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        head_key = config['classifiers']['type']
        head_def = config['classifiers'][head_key]['def_file']
        head_params = config['classifiers'][head_key]['params']
        backbone = utils.source_import(net_def).create_model(**net_params)
        head = utils.source_import(head_def).create_model(**head_params)

        self.model = nn.DataParallel(backbone).cuda()
        self.classifier = nn.DataParallel(head).cuda()

        # Bookkeeping and optimization state.
        self.config = config
        self.logger = logger
        self.optimizer = create_optimizer(self.model, self.classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # Training data.
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # Loss function (factory may use the loader's label statistics).
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # Optional per-epoch validation runner.
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, self.model, self.classifier, val=True)

    def run(self):
        """Run the standard supervised training loop over all epochs."""
        self.logger.info('=====> Start Stage 1 Training')

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                # Forward: features from the backbone, logits from the head.
                predictions = self.classifier(self.model(inputs), add_inputs)

                # Loss + backward + parameter update.
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss']: loss.sum().item()}
                loss.backward()
                self.optimizer.step()

                # Batch top-1 accuracy, for logging only.
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]
                iter_info_print.update({'Accuracy': accuracy.item(), 'Loss': loss.sum().item(), 'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # Gradient dump on the very first batch, or periodically when enabled.
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # Per-epoch validation, checkpoint, LR schedule.
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        self.checkpoint.save_best_model_path(self.logger)
class train_stage2():
    """Stage-2 trainer: restores stage-1 weights, then trains with the
    stage-2 optimizer factory (classifier-centric fine-tuning)."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # Instantiate backbone and classifier from their config-declared files.
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        net_key = config['networks']['type']
        net_def = config['networks'][net_key]['def_file']
        net_params = config['networks'][net_key]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        head_key = config['classifiers']['type']
        head_def = config['classifiers'][head_key]['def_file']
        head_params = config['classifiers'][head_key]['params']
        backbone = utils.source_import(net_def).create_model(**net_params)
        head = utils.source_import(head_def).create_model(**head_params)

        self.model = nn.DataParallel(backbone).cuda()
        self.classifier = nn.DataParallel(head).cuda()

        # Optimization state (stage-2 variant of the optimizer factory).
        self.config = config
        self.logger = logger
        self.optimizer = create_optimizer_stage2(self.model, self.classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # LWS restores both modules from stage 1; any other head keeps only
        # the pretrained backbone.
        if config['classifiers']['type'] == 'LWS':
            self.checkpoint.load(self.model, self.classifier, args.load_dir, logger)
        else:
            self.checkpoint.load_backbone(self.model, args.load_dir, logger)

        # Training data.
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # Loss function (factory may use the loader's label statistics).
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # Optional per-epoch validation runner.
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, self.model, self.classifier, val=True)

    def run(self):
        """Run the stage-2 supervised loop; structure matches stage 1."""
        self.logger.info('=====> Start Stage 2 Training')

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                # Forward: features from the backbone, logits from the head.
                predictions = self.classifier(self.model(inputs), add_inputs)

                # Loss + backward + parameter update.
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss']: loss.sum().item()}
                loss.backward()
                self.optimizer.step()

                # Batch top-1 accuracy, for logging only.
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]
                iter_info_print.update({'Accuracy': accuracy.item(), 'Loss': loss.sum().item(), 'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # Gradient dump on the very first batch, or periodically when enabled.
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # Per-epoch validation, checkpoint, LR schedule.
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        self.checkpoint.save_best_model_path(self.logger)
class train_stage2_ride():
    """Stage-2 trainer for the multi-expert RIDE head: restores stage-1
    weights, then optimizes the sum of one loss per expert branch."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # Instantiate backbone and classifier from their config-declared files.
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        net_key = config['networks']['type']
        net_def = config['networks'][net_key]['def_file']
        net_params = config['networks'][net_key]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        head_key = config['classifiers']['type']
        head_def = config['classifiers'][head_key]['def_file']
        head_params = config['classifiers'][head_key]['params']
        backbone = utils.source_import(net_def).create_model(**net_params)
        head = utils.source_import(head_def).create_model(**head_params)

        self.model = nn.DataParallel(backbone).cuda()
        self.classifier = nn.DataParallel(head).cuda()

        # Optimization state (stage-2 variant of the optimizer factory).
        self.config = config
        self.logger = logger
        self.optimizer = create_optimizer_stage2(self.model, self.classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # LWS restores both modules from stage 1; any other head keeps only
        # the pretrained backbone.
        if config['classifiers']['type'] == 'LWS':
            self.checkpoint.load(self.model, self.classifier, args.load_dir, logger)
        else:
            self.checkpoint.load_backbone(self.model, args.load_dir, logger)

        # Training data.
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # Loss function (factory may use the loader's label statistics).
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # Optional per-epoch validation runner.
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, self.model, self.classifier, val=True)

    def run(self):
        """Stage-2 loop: each expert's logits get their own loss term and
        the sum of all branch losses is optimized."""
        self.logger.info('=====> Start RIDE Stage 2 Training')

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                iter_info_print = {}
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                # Forward: aggregated predictions plus one logit set per expert.
                predictions, all_logits = self.classifier(self.model(inputs), add_inputs)

                # One loss per expert branch; optimize their sum.
                all_loss = [self.loss_fc(logit, labels) for logit in all_logits]
                loss = sum(all_loss)
                for i, branch_loss in enumerate(all_loss):
                    iter_info_print[self.training_opt['loss'] + '_{}'.format(i)] = branch_loss.sum().item()

                loss.backward()
                self.optimizer.step()

                # Batch top-1 accuracy on the aggregated predictions.
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]
                iter_info_print.update({'Accuracy': accuracy.item(), 'Loss': loss.sum().item(), 'Poke LR': float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # Gradient dump on the very first batch, or periodically when enabled.
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # Per-epoch validation, checkpoint, LR schedule.
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        self.checkpoint.save_best_model_path(self.logger)
class train_tade():
    """Trainer for the multi-expert TADE model.

    The classifier returns (aggregated predictions, per-expert logits); the
    loss is either the joint TADE loss (when `training_opt['loss'] == 'TADE'`)
    or a plain per-expert sum of the configured loss.
    """

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # create model: backbone and classifier from their config-declared files
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Train for `num_epochs`, validating/checkpointing once per epoch."""
        # Start Training
        self.logger.info('=====> Start TADE Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                # FIX (consistency with train_stage2_ride): reset the logging
                # dict every step instead of once per epoch, so entries from a
                # previous iteration can never linger in the printed info.
                iter_info_print = {}

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                predictions, all_logits = self.classifier(features, add_inputs)

                # calculate loss: joint TADE loss, or per-expert sum otherwise
                if self.training_opt['loss'] == 'TADE':
                    extra_info = {'logits': all_logits}
                    loss = self.loss_fc(predictions, labels, extra_info)
                    iter_info_print[self.training_opt['loss']] = loss.sum().item()
                else:
                    all_loss = [self.loss_fc(logit, labels) for logit in all_logits]
                    loss = sum(all_loss)
                    for i, branch_loss in enumerate(all_loss):
                        iter_info_print[self.training_opt['loss'] + '_{}'.format(i)] = branch_loss.sum().item()

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (on the aggregated predictions)
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # dump gradients on the first batch ever, or every 1000 steps
                # when print_grad is enabled
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
class train_tde():
    """Trainer for the TDE (total direct effect) model: a standard training
    loop that additionally maintains an exponential moving average of the
    batch-mean feature (`self.embed`), which is passed to validation and
    persisted in every checkpoint."""

    def __init__(self, args, config, logger, eval=False):
        # ============================================================================
        # create model: backbone and classifier from their config-declared files
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def update_embed(self, embed, input, momentum=0.995):
        """Return the EMA update of `embed` with the batch-mean of `input`.

        Generalized: `momentum` (default 0.995, the previously hard-coded
        decay) is now a parameter; existing callers are unaffected.
        `input` must be a 2-D (batch, feat) tensor. The parameter name
        `input` shadows the builtin but is kept for caller compatibility.
        """
        # embed is only updated during training
        assert len(input.shape) == 2
        with torch.no_grad():
            embed = embed * momentum + (1 - momentum) * input.mean(0).view(-1)
        return embed

    def run(self):
        """Train for `num_epochs`; the feature EMA is updated every batch and
        saved alongside each checkpoint."""
        # Start Training
        self.logger.info('=====> Start TDE Training')

        # init embed
        # NOTE(review): 2048 is hard-coded — presumably the backbone's feature
        # dimension; confirm against the configured network.
        self.embed = torch.zeros(2048).cuda()

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # update embed during training
                self.embed = self.update_embed(self.embed, features)

                # calculate loss
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss'] : loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # dump gradients on the first batch ever, or every 1000 steps
                # when print_grad is enabled
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or (self.config['logger_opt']['print_grad'] and step % 1000 == 0):
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set (the embed is needed at test time)
            if self.eval:
                val_acc = self.testing.run_val(epoch, self.embed)
            else:
                val_acc = 0.0

            # checkpoint: persist the feature EMA alongside the weights
            add_dict = {}
            add_dict['embed'] = self.embed.cpu()
            self.logger.info('Embed Mean: {}'.format(self.embed.mean().item()))
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc, add_dict=add_dict)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
class Checkpoint():
    """Saves/loads model+classifier checkpoints and tracks the best epoch
    (by validation accuracy) across a training run."""

    def __init__(self, config):
        self.config = config
        self.best_epoch = -1          # epoch of the best checkpoint so far
        self.best_performance = -1    # its validation accuracy
        self.best_model_path = None   # its file path

    def save(self, model, classifier, epoch, logger, add_dict=None, acc=None):
        """Save a checkpoint and update the best-model bookkeeping.

        Skips epochs that are neither a multiple of `checkpoint_step` nor the
        last epoch. `add_dict` carries extra entries to persist (e.g. TDE's
        embed); its default changed from a shared mutable `{}` to None
        (behaviorally equivalent — it was never mutated, but the mutable
        default is a classic hazard). If `acc` is None the newest checkpoint
        is always considered best.
        """
        # only save at certain steps or the last epoch
        if epoch % self.config['checkpoint_opt']['checkpoint_step'] != 0 and epoch < (self.config['training_opt']['num_epochs'] - 1):
            return

        output = {
            'model': model.state_dict(),
            'classifier': classifier.state_dict(),
            'epoch': epoch,
        }
        output.update(add_dict or {})

        model_name = 'epoch_{}_'.format(epoch) + self.config['checkpoint_opt']['checkpoint_name']
        model_path = os.path.join(self.config['output_dir'], model_name)

        logger.info('Model at epoch {} is saved to {}'.format(epoch, model_path))
        torch.save(output, model_path)

        # update best model
        if acc is not None:
            if float(acc) > self.best_performance:
                self.best_epoch = epoch
                self.best_performance = float(acc)
                self.best_model_path = model_path
        else:
            # if acc is None, the newest is always the best
            self.best_epoch = epoch
            self.best_model_path = model_path

        logger.info('Best model is at epoch {} with accuracy {:9.3f}'.format(self.best_epoch, self.best_performance))

    def save_best_model_path(self, logger):
        """Write '<path> <epoch> <acc>' for the best checkpoint to the
        `best_checkpoint` pointer file in the output directory."""
        logger.info('Best model is at epoch {} with accuracy {:9.3f} (Path: {})'.format(self.best_epoch, self.best_performance, self.best_model_path))
        with open(os.path.join(self.config['output_dir'], 'best_checkpoint'), 'w+') as f:
            f.write(self.best_model_path + ' ' + str(self.best_epoch) + ' ' + str(self.best_performance) + '\n')

    def load(self, model, classifier, path, logger):
        """Load model and classifier weights.

        `path` may be a .pth file or a directory containing a
        `best_checkpoint` pointer file written by save_best_model_path().
        Returns the raw checkpoint dict.
        """
        if path.split('.')[-1] != 'pth':
            with open(os.path.join(path, 'best_checkpoint')) as f:
                # FIX: file objects are not subscriptable — `f[0]` raised
                # TypeError. Read the single pointer line instead.
                path = f.readline().split(' ')[0]

        checkpoint = torch.load(path, map_location='cpu')
        logger.info('Loading checkpoint pretrained with epoch {}.'.format(checkpoint['epoch']))

        model_state = checkpoint['model']
        classifier_state = checkpoint['classifier']

        model = self.load_module(model, model_state, logger)
        classifier = self.load_module(classifier, classifier_state, logger)
        return checkpoint

    def load_backbone(self, model, path, logger):
        """Load only the backbone weights; the classifier is left untouched.
        Accepts the same file-or-directory `path` as load()."""
        if path.split('.')[-1] != 'pth':
            with open(os.path.join(path, 'best_checkpoint')) as f:
                # FIX: same `f[0]` TypeError as in load()
                path = f.readline().split(' ')[0]

        checkpoint = torch.load(path, map_location='cpu')
        logger.info('Loading checkpoint pretrained with epoch {}.'.format(checkpoint['epoch']))

        model_state = checkpoint['model']

        model = self.load_module(model, model_state, logger)
        return checkpoint

    def load_module(self, module, module_state, logger):
        """Copy matching keys from `module_state` into `module`, tolerating a
        'module.' (DataParallel) prefix mismatch; missing keys only log a
        warning and keep the module's current values."""
        x = module.state_dict()
        for key, _ in x.items():
            if key in module_state:
                x[key] = module_state[key]
                logger.info('Load {:>50} from checkpoint.'.format(key))
            elif 'module.' + key in module_state:
                x[key] = module_state['module.' + key]
                logger.info('Load {:>50} from checkpoint (rematch with module.).'.format(key))
            else:
                logger.info('WARNING: Key {} is missing in the checkpoint.'.format(key))

        module.load_state_dict(x)
        return module
f[0].split(' ')[0] 75 | 76 | checkpoint = torch.load(path, map_location='cpu') 77 | logger.info('Loading checkpoint pretrained with epoch {}.'.format(checkpoint['epoch'])) 78 | 79 | model_state = checkpoint['model'] 80 | 81 | model = self.load_module(model, model_state, logger) 82 | return checkpoint 83 | 84 | 85 | def load_module(self, module, module_state, logger): 86 | x = module.state_dict() 87 | for key, _ in x.items(): 88 | if key in module_state: 89 | x[key] = module_state[key] 90 | logger.info('Load {:>50} from checkpoint.'.format(key)) 91 | elif 'module.' + key in module_state: 92 | x[key] = module_state['module.' + key] 93 | logger.info('Load {:>50} from checkpoint (rematch with module.).'.format(key)) 94 | else: 95 | logger.info('WARNING: Key {} is missing in the checkpoint.'.format(key)) 96 | 97 | module.load_state_dict(x) 98 | return module -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | import torch 5 | import numpy as np 6 | import importlib 7 | 8 | 9 | def count_dataset(train_loader): 10 | label_freq = {} 11 | if isinstance(train_loader, list) or isinstance(train_loader, tuple): 12 | all_labels = train_loader[0].dataset.labels 13 | else: 14 | all_labels = train_loader.dataset.labels 15 | for label in all_labels: 16 | key = str(label) 17 | label_freq[key] = label_freq.get(key, 0) + 1 18 | label_freq = dict(sorted(label_freq.items())) 19 | label_freq_array = np.array(list(label_freq.values())) 20 | return label_freq_array 21 | 22 | 23 | def compute_adjustment(train_loader, tro=1.0): 24 | """compute the base probabilities""" 25 | label_freq = {} 26 | for key in train_loader.dataset.labels: 27 | label_freq[key] = label_freq.get(key, 0) + 1 28 | label_freq = dict(sorted(label_freq.items())) 29 | label_freq_array 
= np.array(list(label_freq.values())) 30 | label_freq_array = label_freq_array / label_freq_array.sum() 31 | adjustments = np.log(label_freq_array ** tro + 1e-12) 32 | adjustments = torch.from_numpy(adjustments) 33 | return adjustments 34 | 35 | 36 | def update(config, algo_config, args): 37 | if args.output_dir: 38 | config['output_dir'] = args.output_dir 39 | if args.load_dir: 40 | config['load_dir'] = args.load_dir 41 | 42 | # select the algorithm we use 43 | if args.train_type: 44 | config['training_opt']['type'] = args.train_type 45 | # update algorithm details based on training_type 46 | algo_info = algo_config[args.train_type] 47 | config['sampler'] = algo_info['sampler'] 48 | config['num_sampler'] = algo_info['num_sampler'] if 'num_sampler' in algo_info else 1 49 | config['batch_split'] = algo_info['batch_split'] if 'batch_split' in algo_info else False 50 | config['testing_opt']['type'] = algo_info['test_type'] 51 | config['training_opt']['loss'] = algo_info['loss_type'] 52 | config['networks']['type'] = algo_info['backbone_type'] 53 | config['classifiers']['type'] = algo_info['classifier_type'] 54 | config['algorithm_opt'] = algo_info['algorithm_opt'] 55 | config['dataset']['rand_aug'] = algo_info['rand_aug'] if 'rand_aug' in algo_info else False 56 | 57 | if 'num_epochs' in algo_info: 58 | config['training_opt']['num_epochs'] = algo_info['num_epochs'] 59 | if 'batch_size' in algo_info: 60 | config['training_opt']['batch_size'] = algo_info['batch_size'] 61 | if 'optim_params' in algo_info: 62 | config['training_opt']['optim_params'] = algo_info['optim_params'] 63 | if 'scheduler' in algo_info: 64 | config['training_opt']['scheduler'] = algo_info['scheduler'] 65 | if 'scheduler_params' in algo_info: 66 | config['training_opt']['scheduler_params'] = algo_info['scheduler_params'] 67 | 68 | 69 | # other updates 70 | if args.lr: 71 | config['training_opt']['optim_params']['lr'] = args.lr 72 | if args.testset: 73 | config['dataset']['testset'] = args.testset 74 
| if args.model_type: 75 | config['classifiers']['type'] = args.model_type 76 | if args.loss_type: 77 | config['training_opt']['loss'] = args.loss_type 78 | if args.sample_type: 79 | config['sampler'] = args.sample_type 80 | if args.rand_aug: 81 | config['dataset']['rand_aug'] = True 82 | if args.save_all: 83 | config['saving_opt']['save_all'] = True 84 | return config 85 | 86 | 87 | def source_import(file_path): 88 | """This function imports python module directly from source code using importlib""" 89 | spec = importlib.util.spec_from_file_location('', file_path) 90 | module = importlib.util.module_from_spec(spec) 91 | spec.loader.exec_module(module) 92 | return module 93 | 94 | def print_grad(named_parameters): 95 | """ show grads """ 96 | total_norm = 0 97 | param_to_norm = {} 98 | param_to_shape = {} 99 | for n, p in named_parameters: 100 | if p.grad is not None: 101 | param_norm = p.grad.data.norm(2) 102 | total_norm += param_norm ** 2 103 | param_to_norm[n] = param_norm 104 | param_to_shape[n] = p.size() 105 | total_norm = total_norm ** (1. 
/ 2) 106 | print('---Total norm {:.3f} -----------------'.format(total_norm)) 107 | for name, norm in sorted(param_to_norm.items(), key=lambda x: -x[1]): 108 | print("{:<50s}: {:.3f}, ({})".format(name, norm, param_to_shape[name])) 109 | print('-------------------------------', flush=True) 110 | return total_norm 111 | 112 | def print_config(config, logger, head=''): 113 | for key, val in config.items(): 114 | if isinstance(val, dict): 115 | logger.info(head + str(key)) 116 | print_config(val, logger, head=head + ' ') 117 | else: 118 | logger.info(head + '{} : {}'.format(str(key), str(val))) 119 | 120 | class TriggerAction(): 121 | def __init__(self, name): 122 | self.name = name 123 | self.action = {} 124 | 125 | def add_action(self, name, func): 126 | assert str(name) not in self.action 127 | self.action[str(name)] = func 128 | 129 | def remove_action(self, name): 130 | assert str(name) in self.action 131 | del self.action[str(name)] 132 | assert str(name) not in self.action 133 | 134 | def run_all(self, logger=None): 135 | for key, func in self.action.items(): 136 | if logger: 137 | logger.info('trigger {}'.format(key)) 138 | func() 139 | 140 | 141 | def calculate_recall(prediction, label, split_mask=None): 142 | recall = (prediction == label).float() 143 | if split_mask is not None: 144 | recall = recall[split_mask].mean().item() 145 | else: 146 | recall = recall.mean().item() 147 | return recall 148 | 149 | 150 | def calculate_precision(prediction, label, num_class, split_mask=None): 151 | pred_count = torch.zeros(num_class).to(label.device) 152 | for i in range(num_class): 153 | pred_count[i] = (prediction == i).float().sum() 154 | 155 | precision = (prediction == label).float() / pred_count[prediction].float() 156 | 157 | if split_mask is not None: 158 | available_class = len(set(label[split_mask].tolist())) 159 | precision = precision[split_mask].sum().item() / available_class 160 | else: 161 | available_class = len(set(label.tolist())) 162 | precision = 
precision.sum().item() / available_class 163 | 164 | return precision 165 | 166 | 167 | def calculate_f1(recall, precision): 168 | f1 = 2 * recall * precision / (recall + precision) 169 | return f1 -------------------------------------------------------------------------------- /utils/logger_utils.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | import os 5 | import csv 6 | from datetime import datetime 7 | 8 | class custom_logger(): 9 | def __init__(self, output_path, name='logger.txt'): 10 | now = datetime.now() 11 | logger_name = str(now.strftime("20%y_%h_%d_")) + name 12 | self.logger_path = os.path.join(output_path, logger_name) 13 | self.csv_path = os.path.join(output_path, 'results.csv') 14 | # init logger file 15 | f = open(self.logger_path, "w+") 16 | f.write(self.get_local_time() + 'Start Logging \n') 17 | f.close() 18 | # init csv 19 | with open(self.csv_path, 'w') as f: 20 | writer = csv.writer(f) 21 | writer.writerow([self.get_local_time(), ]) 22 | 23 | def get_local_time(self): 24 | now = datetime.now() 25 | return str(now.strftime("%y_%h_%d %H:%M:%S : ")) 26 | 27 | def info(self, log_str): 28 | print(str(log_str)) 29 | with open(self.logger_path, "a") as f: 30 | f.write(self.get_local_time() + str(log_str) + '\n') 31 | 32 | def raise_error(self, error): 33 | prototype = '************* Error: {} *************'.format(str(error)) 34 | self.info(prototype) 35 | raise ValueError(str(error)) 36 | 37 | def info_iter(self, epoch, batch, total_batch, info_dict, print_iter): 38 | if batch % print_iter != 0: 39 | pass 40 | else: 41 | acc_log = 'Epoch {:5d}, Batch {:6d}/{},'.format(epoch, batch, total_batch) 42 | for key, val in info_dict.items(): 43 | acc_log += ' {}: {:9.3f},'.format(str(key), float(val)) 44 | self.info(acc_log) 45 | 46 | def write_results(self, result_list): 47 | with open(self.csv_path, 'a') as f: 
48 | writer = csv.writer(f) 49 | writer.writerow(result_list) -------------------------------------------------------------------------------- /utils/test_loader.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | 5 | from test_baseline import test_baseline 6 | from test_tde import test_tde 7 | from test_la import test_la 8 | 9 | def test_loader(config): 10 | if config['testing_opt']['type'] in ('baseline'): 11 | return test_baseline 12 | elif config['testing_opt']['type'] in ('TDE'): 13 | return test_tde 14 | elif config['testing_opt']['type'] in ('LA'): 15 | return test_la 16 | else: 17 | raise ValueError('Wrong Test Pipeline') 18 | 19 | -------------------------------------------------------------------------------- /utils/train_loader.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # Kaihua Tang 3 | ###################################### 4 | 5 | import imp 6 | from train_baseline import train_baseline 7 | from train_la import train_la 8 | from train_bbn import train_bbn 9 | from train_tde import train_tde 10 | from train_ride import train_ride 11 | from train_tade import train_tade 12 | from train_ldam import train_ldam 13 | from train_mixup import train_mixup 14 | 15 | from train_lff import train_lff 16 | 17 | from train_stage1 import train_stage1 18 | from train_stage2 import train_stage2 19 | from train_stage2_ride import train_stage2_ride 20 | 21 | from train_irm_dual import train_irm_dual 22 | from train_center_dual import train_center_dual 23 | from train_center_single import train_center_single 24 | from train_center_triple import train_center_triple 25 | from train_center_tade import train_center_tade 26 | from train_center_ride import train_center_ride 27 | from train_center_ride_mixup import train_center_ride_mixup 28 | from 
train_center_ldam_dual import train_center_ldam_dual 29 | from train_center_dual_mixup import train_center_dual_mixup 30 | 31 | 32 | def train_loader(config): 33 | if config['training_opt']['type'] in ('baseline', 'Focal'): 34 | return train_baseline 35 | elif config['training_opt']['type'] in ('LFF', 'LFFLA'): 36 | return train_lff 37 | elif config['training_opt']['type'] in ('LA', 'FocalLA'): 38 | return train_la 39 | elif config['training_opt']['type'] in ('BBN'): 40 | return train_bbn 41 | elif config['training_opt']['type'] in ('TDE'): 42 | return train_tde 43 | elif config['training_opt']['type'] in ('mixup'): 44 | return train_mixup 45 | elif config['training_opt']['type'] in ('LDAM'): 46 | return train_ldam 47 | elif config['training_opt']['type'] in ('RIDE'): 48 | return train_ride 49 | elif config['training_opt']['type'] in ('TADE'): 50 | return train_tade 51 | elif config['training_opt']['type'] in ('stage1'): 52 | return train_stage1 53 | elif config['training_opt']['type'] in ('crt_stage2', 'lws_stage2'): 54 | return train_stage2 55 | elif config['training_opt']['type'] in ('ride_stage2'): 56 | return train_stage2_ride 57 | elif config['training_opt']['type'] in ('center_dual', 'env_dual'): 58 | return train_center_dual 59 | elif config['training_opt']['type'] in ('center_single'): 60 | return train_center_single 61 | elif config['training_opt']['type'] in ('center_triple'): 62 | return train_center_triple 63 | elif config['training_opt']['type'] in ('center_LDAM_dual'): 64 | return train_center_ldam_dual 65 | elif config['training_opt']['type'] in ('center_dual_mixup'): 66 | return train_center_dual_mixup 67 | elif config['training_opt']['type'] in ('center_tade'): 68 | return train_center_tade 69 | elif config['training_opt']['type'] in ('center_ride'): 70 | return train_center_ride 71 | elif config['training_opt']['type'] in ('center_ride_mixup'): 72 | return train_center_ride_mixup 73 | elif config['training_opt']['type'] in ('irm_dual'): 74 | 
return train_irm_dual 75 | else: 76 | raise ValueError('Wrong Train Type') --------------------------------------------------------------------------------